07A: Scaling Laws
Materials:
Date: Tuesday, 10-Sep-2024
Pre-work:
- Lesson 2 of Stat503 on planning a simple comparative experiment
In-Class
- Review the sample size calculation for a t-test: how sample size is a function of data quality, confidence, and tolerance for errors in the conclusions. Except in simple cases, getting a good sense of “how much data” is a very hard question. But we can try.
- Myth Buster - more data is better
- It is treated like a tautology, and the notion rarely gets challenged. More data does not necessarily lead to better RoI. In fact, more of the same data cannot improve performance beyond a point, and in some specific cases improvement is impossible.
- Law of diminishing returns: collecting more data is not only expensive, but performance can also saturate, reaching a plateau. In fact, figuring out where this plateau is and what the ceiling is makes an interesting problem in itself. A well designed experiment can address this question.
- Can we estimate the sample size needed?
- back of the envelope calculations based on some idea about the data, based on two-sample sample size calculations
- from PAC theory bounds (Chapter 8 of An Elementary Introduction to Statistical Learning Theory)
- empirical scaling laws
Post-class
- [book] An Elementary Introduction to Statistical Learning Theory. See Chapter 8 for how the theory of VC Dimension can be applied to get an idea on the sample size.
- [youtube] PAC Learning: Ali Ghodsi, Lec 19 PAC Learning, STAT 441/841 Statistical Learning, University of Waterloo
- [paper] Training Compute-Optimal LLMs
Additional Reading (optional)
- [book-online] Linear Models - first regression course for EPFL mathematicians.
- [paper] Approximation Rates and VC-Dimension Bounds for (P)ReLU MLP Mixture of Experts
- [book] Understanding Machine Learning: From Theory to Algorithms - a classic from Shai Shalev-Shwartz and Shai Ben-David. Part-1. youtube playlist
- [paper] Scaling Laws for Neural Language Models
- [paper] An empirical study of scaling laws for transfer learning
- [paper] Scaling laws for Individual data points
- [paper] Scaling laws in Linear Regression
Notes
When embarking on a problem, the first set of questions we are asked are:
- What type/kind of data is needed or useful for the problem at hand?
- How many samples do you need? A new name for this is scaling laws :)
- How hard is the problem? Are all examples equally important?
Let us briefly discuss the first question about “what kind”.
What to collect?
We discussed ideation tools like Design Thinking, Human-centered Design to engage with the problem (to define what to solve). In the Project Canvas, we also discussed what is the ML Task, and what kind of data is needed to solve the business problem.
For example, suppose the problem is assessing the quality of food grains. The business problem is to objectively assess the grain quality and decide a fair price linked to the quality. The ML problem is to grade the grains and identify any foreign material/contamination, maybe based on the morphology of the grains. What kind of data is needed here? We can treat this problem as a combination of segmentation, object detection, object classification, morphology analysis, and granulometry. The data must support these tasks. See the blogs Image Annotation Definition, Usecases and Types, Best practices for successful image annotation, Data Annotation - A beginner’s guide to delve into annotating images further.
Arguably, annotating images for segmentation at grain level is very time consuming but can potentially be very useful. Treating this problem as a grading problem (accept/reject), as opposed to providing a quality assessment at grain and sample level, is relatively easy in terms of annotations. But with binary accept/reject labels for the lot at image level, one cannot pivot to a different ML problem formulation if the solution proves not useful. This is an inherent design challenge. Getting as many details as possible is desirable but may not be practical. Thankfully, with foundation models like SAM, some of this leg work can be automated. See this paper on sample selection for efficient image annotation. Most modern annotation studios like Labellerr offer some support to speed up the process.
While the above labels support the primary ML task, say the segmentation task, they may not be sufficient to offer SLAs or Model monitoring (battling the unknown unknowns) that we discussed in Model Monitoring. Suppose, for the same problem, the model is deployed at two grain collection sites. Training data was collected from the first site (for rice grains), and the model is being used at the second site. It so happens that the second site is using the model for “wheat” instead of “rice.” Obviously, the model will perform poorly. This may not have been anticipated ahead of time. The only way to detect this is to first log the data where the model is being used, and observe the performance by site. That means where the model is being used, in this case, and any other contextual data are necessary to monitor, diagnose, and improve models over time. This is but one of the dimensions to think about. Generally speaking, thinking about dimensions is the topic of knowledge representation and knowledge engineering. Dimensional Modeling (https://www.ibm.com/docs/en/ida/9.1.2?topic=modeling-dimensional) is a framework to adopt to collect data that can address the concerns of different stakeholders besides the ML and Data Scientists, such as a Software Engineer, QA, Architect, Product Manager, Site Manager etc. IBM’s book on Dimensional Modeling: In a Business Intelligence Environment is a definitive guide on this topic. A Knowledge Engineering Primer is perhaps a lighter read compared to the book.
Now let us move to the other question. How many?
How many?
This question gets very complicated in no time. And there can be more than one way to develop a heuristic to arrive at an approximate answer. The type of response and approach also depends on the scenario.
- The exact problem was solved before, or data is available that can be used as-is [best case]. Take the data, build a model, deploy it, and start using it. Such occurrences are rare but can happen. For example, counting people of a given demographic from traffic camera feeds.
- A similar problem was solved, except for differences in domain. Take a model trained on that data and collect data on the target domain. Start improving the model over a period of time. Use transfer learning/supervised fine-tuning methods. If data exists for that problem (in the literature or on Hugging Face/Kaggle, for example), plot accuracy vs sample size and pick a target that gives the desired performance. Iterate over it. Do not collect all data at once.
- Unlabelled data is available in large quantities. In such a case, run a pre-trained model or develop self-supervised learning tasks, get very good representations, and try to label those examples that are easy and/or important for the performance. See cords and other active learning methods. See the application of self-supervised pre-training techniques for multi-label classification in the Chest X-Ray paper.
- No data exists and a similar problem has not been solved before [worst case]. Then, one has to design an experiment, provide some inputs (design considerations or operating environment) about the problem, get a ballpark sense of the sample size (for the sake of budget, resource, and project planning), run a pilot data collection exercise, refine the strategy, and iterate. Below, we develop some heuristics to estimate the sample size.
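For the second scenario above, the "plot accuracy vs sample size and pick a target" step can be sketched as a small extrapolation. This is only an illustrative sketch: it assumes the validation error roughly follows a pure power law \(a \cdot n^{-b}\) (an assumption, not something the notes establish), and the pilot numbers and function names are hypothetical.

```python
import math

def fit_power_law(sizes, errors):
    """Least-squares fit of log(error) = log(a) - b * log(n); returns (a, b)."""
    xs = [math.log(n) for n in sizes]
    ys = [math.log(e) for e in errors]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = y_mean - slope * x_mean
    return math.exp(intercept), -slope

def samples_for_target(a, b, target_error):
    """Invert error = a * n**(-b) to get the sample size for a target error."""
    return math.ceil((a / target_error) ** (1.0 / b))

# Hypothetical pilot results: training-set sizes and the validation error at each size.
pilot_sizes = [100, 200, 400, 800, 1600]
pilot_errors = [0.30, 0.235, 0.19, 0.15, 0.12]

a, b = fit_power_law(pilot_sizes, pilot_errors)
n_needed = samples_for_target(a, b, target_error=0.08)
print(f"error ~ {a:.2f} * n^(-{b:.2f}); samples for 8% error ~ {n_needed}")
```

A pure power law has no floor, so it ignores the plateau discussed under diminishing returns. Treat the extrapolated number as a planning figure and re-fit after every new tranche of data.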
A Statistical Approach
Let us simplify and consider a regression problem. We are trying to learn a function \(f: [0,1] \rightarrow R\). Imagine you are fitting a decision tree to approximate this function. A decision tree partitions the input space, and in each of the partitions certain statistics like the mean and quantiles are computed. For a given instance, the prediction is given by, for example, the mean of all responses of the examples belonging to that partition. So, we can divide or cluster or partition the training data into \(K\) subsets and compute some statistic in each of these subsets. If unlabelled data is available, running a clustering algorithm will give an idea about \(K\). If we assume that the labels (responses) of the k-th partition, denoted by \(y_k\), follow \(N(\mu_k, \sigma^2)\), we can estimate the \((1-\alpha)\) level prediction interval (PI) for a new response in that partition as \[\bar{y}_k \pm z_{1-\frac{\alpha}{2}}\sigma \sqrt{1+\frac{1}{N}}\]
where \(\bar{y}_k\) is the sample mean (the estimate of \(\mu_k\)), \(N\) is the sample size, \(\sigma^2\) is the noise variance, and \(\alpha\) controls the confidence level (or type-1 error of the corresponding hypothesis test). So the “design inputs” needed to solicit a sample size are: \(\alpha\), \(\sigma^2\). Sometimes, the precision needed for the estimate can be asserted in terms of the width of the interval (PIW). In this case, PIW is given as \(PIW = 2z_{1-\frac{\alpha}{2}}\sigma \sqrt{1+\frac{1}{N}}\). Now, we can express the sample size as a function of \(PIW\), \(\alpha\), and \(\sigma\) as: \(N = \left( (\frac{PIW}{2z_{1-\frac{\alpha}{2}}\sigma})^{2}-1 \right)^{-1}\). Note that this requires \(PIW > 2z_{1-\frac{\alpha}{2}}\sigma\): the width shrinks toward \(2z_{1-\frac{\alpha}{2}}\sigma\) as \(N\) grows but can never go below it, because of the noise in a new response. If there are \(K\) partitions, we need to estimate that many \(\mu_k\)s. So, the total sample size will be \(NK\), assuming all partitions have the same variance. If not, it is not hard to update the formula. In some ways, the model complexity is captured by \(K\). In general, one does not know these numbers in advance and has to make an educated guess based on domain knowledge and refine the design inputs as the data collection drive is set in motion.
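As a rough illustration, the sketch below inverts the PIW expression \(PIW = 2z_{1-\frac{\alpha}{2}}\sigma \sqrt{1+\frac{1}{N}}\) for \(N\) and multiplies by \(K\). The function names are hypothetical, and equal variance across partitions is assumed, as in the text.

```python
import math
from statistics import NormalDist

def samples_per_partition(piw, sigma, alpha=0.05):
    """Smallest N such that 2 * z * sigma * sqrt(1 + 1/N) <= piw."""
    z = NormalDist().inv_cdf(1 - alpha / 2)   # z_{1 - alpha/2}
    floor_width = 2 * z * sigma               # limit of the PIW as N grows
    if piw <= floor_width:
        raise ValueError(
            f"PIW must exceed 2*z*sigma = {floor_width:.3f}; "
            "no sample size can achieve a narrower interval"
        )
    ratio_sq = (piw / floor_width) ** 2
    return math.ceil(1.0 / (ratio_sq - 1.0))

def total_samples(piw, sigma, k, alpha=0.05):
    """Total sample size N * K, assuming the same variance in all K partitions."""
    return k * samples_per_partition(piw, sigma, alpha)

# Example design inputs (illustrative values only).
n = samples_per_partition(piw=4.0, sigma=1.0, alpha=0.05)
print("per partition:", n)
print("total for K=10:", total_samples(piw=4.0, sigma=1.0, k=10))
```

The required \(N\) blows up as the requested width approaches \(2z_{1-\frac{\alpha}{2}}\sigma\), so it is worth trying a few PIW values before committing to one.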
What about the classification problem?
Assume it is a binary classification problem. The approach is still identical. Even in the classification setting, estimating the mean and taking the argmax to predict the label of the partition is still useful and applicable, except that the \(PI\) formula needs to be updated. For other types of tasks, a suitable estimate of the target variable has to be chosen, its PI derived, and that PI used to get an estimate of the sample size.
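The notes do not spell out the updated formula. One common stand-in, used here purely as an assumption, is the normal-approximation (Wald) confidence interval for the positive-class rate \(p_k\) in a partition, \(\hat{p}_k \pm z_{1-\frac{\alpha}{2}}\sqrt{\hat{p}_k(1-\hat{p}_k)/N}\). Strictly, this is a confidence interval on the class proportion rather than a prediction interval. The function name and inputs below are hypothetical.

```python
import math
from statistics import NormalDist

def samples_for_proportion(half_width, p_guess=0.5, alpha=0.05):
    """N such that z * sqrt(p * (1 - p) / N) <= half_width (Wald interval).

    p_guess = 0.5 is the most conservative choice (largest variance).
    """
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return math.ceil(z ** 2 * p_guess * (1 - p_guess) / half_width ** 2)

# Estimate the class rate within +/- 5 points at the 95% level, per partition.
n = samples_for_proportion(half_width=0.05)
print("per partition:", n, "| total for K=10:", 10 * n)
```

As in the regression case, the total is obtained by multiplying by the number of partitions \(K\). A better guess for \(p_k\) from pilot data reduces the estimate.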
An ML Approach
Can we relax the assumptions and yet come-up with some estimates for \(N\)?
We can invoke PAC theory. Suppose \(\epsilon\) is the tolerable error (that the test error should not exceed), \(\delta\) is the fraction of times the learning algorithm is allowed to produce a test error exceeding \(\epsilon\) (so \(1-\delta\) is the confidence), and \(V\) is the VC-dimension of the function class. The following bounds from PAC theory can guide us:
\[\frac{1}{\epsilon}\left(4\log(\frac{2}{\delta}) + 8V\log(\frac{13}{\epsilon}) \right) \le N \le \frac{64}{\epsilon^2}\left(2V \log(\frac{V}{\epsilon}) + \log(\frac{4}{\delta}) \right) \]
In particular, if one uses an MLP, the following is an upper bound on the VC-dimension: \[V \le 2(d+1)s \log(2.718s)\] where \(d\) is the number of features and \(s\) is the number of perceptrons in the network. So, in effect, given an MLP architecture (with \(s\) perceptrons), confidence (\(\delta\)), permissible error (\(\epsilon\)), and input features (\(d\)), one can get an idea of the sample size.
Caution: These bounds can be very loose, even vacuous, and may not be uniformly tight across the range of the inputs. For example, consider the inequality \(x^2 \ge x, \forall x \ge 1\). The gap \(|x^2-x|\) grows unboundedly and gets worse as \(x\) increases. So, always play with several input values and pick a sensible one. Do not just believe that they will work out of the box. No magic here.
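To get a feel for how loose these numbers can be, the sketch below simply evaluates the two expressions and the MLP bound exactly as written above. It assumes natural logarithms (the notes do not state the base) and uses \(e\) in place of 2.718. The function names and example inputs are hypothetical.

```python
import math

def mlp_vc_upper_bound(d, s):
    """V <= 2 (d + 1) s log(e s), with d features and s perceptrons."""
    return 2 * (d + 1) * s * math.log(math.e * s)

def pac_sample_bounds(eps, delta, v):
    """Evaluate the left- and right-hand expressions of the PAC inequality."""
    lower = (1.0 / eps) * (4 * math.log(2 / delta) + 8 * v * math.log(13 / eps))
    upper = (64.0 / eps ** 2) * (2 * v * math.log(v / eps) + math.log(4 / delta))
    return lower, upper

# Try several tolerable errors for a small MLP (illustrative values only).
v = mlp_vc_upper_bound(d=10, s=5)
for eps in (0.2, 0.1, 0.05):
    lo, hi = pac_sample_bounds(eps=eps, delta=0.05, v=v)
    print(f"eps={eps}: V<={v:.0f}, N between {lo:,.0f} and {hi:,.0f}")
```

Even for this tiny network the two expressions differ by orders of magnitude, which is the practical meaning of the caution above.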
A DL Approach
For tabular problems, one can use either of the above methods to get a sensible estimate. But for speech, vision, and language datasets, it is both complicated and simple at the same time. Simple in the sense that one could take pre-trained foundation models, work with their representations, and technically use PAC-theory-based heuristics. But the latent dimensions could be extremely large. So, if unlabeled data is available, run a clustering algorithm and figure out the intrinsic dimensionality, which can be plugged into the previous formulae. It is complicated when one has no prior knowledge and commits to a Deep Learning approach. In such cases, building a simple baseline is still going to be useful. Such a simple model can guide the data collection exercise.
What about Large Language Models?
Do we want to train custom LLMs or do we want to use them for inference?
For training LLMs, the variables at play are how much data is needed, the compute needed, the resulting performance, and emergence. This area is still developing, with results and counter-results. But some early works fit parametric curves to predict performance given the training data (counted in number of tokens) for a given compute budget. The following regression model, known as the Chinchilla scaling law, is considered in Training Compute-Optimal LLMs (using their notation): \[ L(N,D) = \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}} + E \] where \(L()\) is the loss, \(D\) is the training dataset size in tokens, \(N\) is the model size (number of parameters), and \(A,B,E,\alpha, \beta\) are unknowns to be fit from experimental data. Under this model, \(E\) is the smallest loss achievable (irreducible noise), even with infinite data and infinite compute. Based on large scale experiments, they fit \[ L(N,D) = \frac{406.4}{N^{0.34}} + \frac{410.7}{D^{0.28}} + 1.69 \] It may be better to use a more standard notation and rewrite it as: \[ E(K,N) = \frac{A}{K^{\alpha}} + \frac{B}{N^{\beta}} + \sigma^2 \] where \(K\) is the model size/complexity, \(N\) is the number of tokens, and \(\sigma^2\) is the noise variance. See Pathways Language Model and Model Scaling from Aakanksha for more on Scaling Laws.
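A minimal sketch that evaluates the fitted curve quoted above, \(L(N,D) = 406.4/N^{0.34} + 410.7/D^{0.28} + 1.69\), to see how the predicted loss flattens as tokens are added at a fixed model size. The function name and the example model/token sizes are illustrative choices, not values taken from the notes.

```python
def chinchilla_loss(n_params, n_tokens,
                    A=406.4, B=410.7, E=1.69, alpha=0.34, beta=0.28):
    """Predicted loss L(N, D) = A / N^alpha + B / D^beta + E."""
    return A / n_params ** alpha + B / n_tokens ** beta + E

# Fix the model size and grow the token budget (illustrative values).
n_params = 1e9
for n_tokens in (1e9, 1e10, 1e11, 1e12, 1e13):
    loss = chinchilla_loss(n_params, n_tokens)
    print(f"D={n_tokens:.0e} tokens -> predicted loss {loss:.3f}")
```

With the model size fixed, the data term shrinks toward zero while the model term \(A/N^{\alpha}\) and the floor \(E\) remain. This is the same plateau behaviour discussed under diminishing returns.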
Some observations w.r.t LLM scaling laws are:
- Training LLMs from scratch is an extremely specialized endeavor, requiring not only deep pockets, good understanding of the LLM science but also solid (distributed) systems engineering knowledge.
- Both a sufficiently large dataset and a sufficiently large model are needed to see emergence.
- For instruction fine-tuning, about 1k-6k instruction pairs is considered a good start. See LIMIT: Less Is More for Instruction Tuning Across Evaluation Paradigms. More would be better in this case. Like always, quality and representativeness matter.
- For In-Context Learning (ICL), few-shot learning is better. But instead of prompt engineering in an ad-hoc fashion, optimize prompts in a data-driven manner, using frameworks like DSPy. For example, how many and which examples to include in few-shot ICL can be optimized with DSPy. Few-shot is but one of the strategies to improve the performance of LLMs. See The Prompt Report for more details.
Take-aways
- Collecting data is an iterative exercise
- Play with several design inputs and pick a good starting point. Run several heuristics.
- Try to leverage past knowledge (datasets, models, and problem similarity) as much as possible.
- Do not collect all data in one tranche but collect often, refine the strategy and iterate.
- Incorporate practical constraints. Otherwise, data collection will not even begin.