
Chinchilla Scaling Laws Calculator

Simple Chinchilla Calculator

How many tokens should we train with?

If you look for material about scaling laws online, you will find either rather technical material close to the original paper or very vague simplifications. I have not found anything like the following graph, which I think is one of the most useful takeaways of the paper:

This graph plots the compute-optimal relation between the number of training tokens and the model size. More precisely, for a given compute budget, the Chinchilla scaling laws predict that there is a specific combination of token count and model size such that training a smaller model on more data, or a bigger model on less data, would lead to worse performance. This graph shows these optimal token counts and model sizes for a wide range of compute budgets.

The relation between a given model size $N$ and the optimal number of training tokens $D^*$ can be derived as:

\[D^*(N) = \gamma N^\phi\]

where $\gamma \simeq 0.519$ and $\phi \simeq 1.214$.
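As a quick illustration, the power law above can be evaluated directly. This is a minimal sketch: the function name `optimal_tokens` and the example model sizes are illustrative choices, and the constants are simply the rounded values quoted above.

```python
GAMMA = 0.519  # prefactor gamma quoted in this post (rounded)
PHI = 1.214    # exponent phi quoted in this post (rounded)


def optimal_tokens(n_params: float) -> float:
    """Compute-optimal token count D*(N) = gamma * N**phi."""
    return GAMMA * n_params ** PHI


# Illustrative model sizes (number of parameters).
for n in (1e8, 1e9, 1e10):
    d = optimal_tokens(n)
    print(f"N = {n:.0e} params -> D* ~ {d:.2e} tokens ({d / n:.0f} tokens per parameter)")
```

Because $\phi > 1$, the optimal number of tokens per parameter grows with the model size under this fit.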

The Scaling Laws & Derivation

In the Chinchilla paper, the authors model the final loss as a function of the model size $N$ and the number of training tokens $D$, where each $(N, D)$ pair corresponds to a fixed amount of compute. They empirically fit the following parametric form:

\[L(N, D) = \frac{A}{N^\alpha} + \frac{B}{D^\beta} + E\]
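The parametric loss is straightforward to write down in code. The sketch below assumes the constants reported for the parametric fit in the Chinchilla paper ($A = 406.4$, $B = 410.7$, $\alpha = 0.34$, $\beta = 0.28$, $E = 1.69$); this post does not list them explicitly, so treat them as an assumption and check them against the paper. The function name `chinchilla_loss` is illustrative.

```python
# Assumed constants of the parametric fit, as reported in the Chinchilla paper.
# They are not listed in this post: verify against the paper before relying on them.
A, B, E = 406.4, 410.7, 1.69
ALPHA, BETA = 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """Parametric loss L(N, D) = A / N**alpha + B / D**beta + E."""
    return A / n_params ** ALPHA + B / n_tokens ** BETA + E


# Example: a 1B-parameter model trained on 20B tokens (arbitrary illustrative values).
print(f"L(1e9, 2e10) ~ {chinchilla_loss(1e9, 2e10):.3f}")
```

The constant $E$ acts as a floor: both other terms shrink as $N$ and $D$ grow, but the predicted loss never drops below $E$.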

Let’s suppose that we want to train a model of $N_m$ parameters in a compute-optimal way, without fixing the compute budget in advance. That is, we want to train our model on $D$ tokens, which will cost $C(N_m, D) \simeq 6N_mD = 6c$ FLOPs, such that, had we started with a budget of $6c$ FLOPs, the optimal training configuration would have been $(N_m, D)$.

Given a compute level $6c$, the paper's approximation $C \simeq 6ND$ ties the token count to the model size through $D = \frac{c}{N}$, so the loss at this budget is $L(N, \frac{c}{N})$, a function of $N$ alone. The optimal model size $N_o$ is the one that cancels its derivative:

\[\begin{align} &\frac{\mathrm{d}}{\mathrm{d}N} L\left(N, \frac{c}{N}\right) \Big|_{N = N_o} = 0 \\ &\implies -\alpha A N_o^{-\alpha - 1} + \beta B c^{-\beta} N_o^{\beta - 1} = 0 \\ &\implies N_o = \sqrt[\alpha + \beta]{\frac{\alpha A c^{\beta}}{\beta B}} \end{align}\]
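The closed form for $N_o$ can be sanity-checked numerically by comparing it with a brute-force search over model sizes at a fixed budget. This is only a sketch: the constants are the ones assumed to come from the paper's parametric fit, the budget `c = 1e21` is an arbitrary example, and the function names are illustrative.

```python
# Assumed constants of the parametric fit (verify against the paper).
A, B, E = 406.4, 410.7, 1.69
ALPHA, BETA = 0.34, 0.28


def loss_at_budget(n_params: float, c: float) -> float:
    """L(N, c / N): loss when the whole budget c = N * D is spent on a model of size N."""
    return A / n_params ** ALPHA + B / (c / n_params) ** BETA + E


def optimal_size(c: float) -> float:
    """Closed form N_o = (alpha * A * c**beta / (beta * B)) ** (1 / (alpha + beta))."""
    return (ALPHA * A * c ** BETA / (BETA * B)) ** (1.0 / (ALPHA + BETA))


c = 1e21  # arbitrary example budget, i.e. about 6e21 FLOPs in total
n_closed = optimal_size(c)

# Brute-force search on a log-spaced grid of model sizes between 1e6 and 1e12.
grid = [10 ** (6 + 6 * i / 6000) for i in range(6001)]
n_grid = min(grid, key=lambda n: loss_at_budget(n, c))

print(f"closed form: N_o ~ {n_closed:.3e}")
print(f"grid search: N_o ~ {n_grid:.3e}")
```

The two values should agree up to the resolution of the grid, since $L(N, \frac{c}{N})$ has a single minimum in $N$.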

Now let’s return to the setup where we want to train a model of size $N_m$. The compute $c_t$ for which $N_m$ would be the optimal model size is obtained by inverting the previous relation:

\[N_m = \sqrt[\alpha + \beta]{\frac{\alpha A c_t^{\beta}}{\beta B}} \implies c_t = \sqrt[\beta]{\frac{\beta B N_m^{\alpha + \beta}}{\alpha A}}\]

And then using the approximation $D^*(N_m) = \frac{c_t}{N_m}$, we get:

\[D^*(N_m) = \sqrt[\beta]{\frac{\beta B}{\alpha A}} N_m^{\frac{\alpha}{\beta}} = \gamma N_m^\phi\]

Using the fitted parameters from the paper, we get $\gamma \simeq 0.519$ and $\phi \simeq 1.214$.
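For reference, the two constants can be recomputed from the closed form in a few lines. The sketch assumes the values $A = 406.4$, $B = 410.7$, $\alpha = 0.34$ and $\beta = 0.28$ from the paper's parametric fit, which this post does not list explicitly.

```python
# Assumed constants of the parametric fit (verify against the paper).
A, B = 406.4, 410.7
ALPHA, BETA = 0.34, 0.28

phi = ALPHA / BETA                                # exponent of N_m
gamma = (BETA * B / (ALPHA * A)) ** (1.0 / BETA)  # prefactor

print(f"gamma ~ {gamma:.3f}, phi ~ {phi:.3f}")  # gamma ~ 0.519, phi ~ 1.214
```

Note that $E$ plays no role here: it shifts the loss but does not move the location of the optimum.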

Visualizing the Scaling Laws

The heatmap below helps visualize the loss surface interpolated by the Chinchilla fit:

The compute estimate assumes a throughput of 100 TFLOPS per A100 GPU, and is thus only a very rough guess. Please let me know in the comments if any calculation or value seems wrong!
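A back-of-the-envelope version of this kind of estimate could look as follows. It is a sketch, not the calculator's actual code: it assumes the $C \simeq 6ND$ approximation and a sustained throughput of $10^{14}$ floating-point operations per second per GPU (100 TFLOP/s, an assumption of this example), and the function name `training_gpu_hours` is hypothetical.

```python
def training_gpu_hours(n_params: float, n_tokens: float,
                       flops_per_second_per_gpu: float = 100e12) -> float:
    """Rough GPU-hours needed to train, using C ~ 6 * N * D.

    flops_per_second_per_gpu is an assumed sustained throughput
    (default: 1e14, i.e. 100 TFLOP/s per GPU), not a measured value.
    """
    total_flops = 6 * n_params * n_tokens
    seconds = total_flops / flops_per_second_per_gpu
    return seconds / 3600


# Example: 1B parameters on 20B tokens (arbitrary illustrative values).
hours = training_gpu_hours(1e9, 2e10)
print(f"~{hours:.0f} GPU-hours, i.e. ~{hours / 24:.1f} GPU-days")
```

Since the real throughput depends heavily on the hardware, the precision and the training setup, the result should be read as an order of magnitude at best.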

This work was funded by the PRAIRIE institute as part of a PhD contract at Inria Paris and Sorbonne Université.

This post is licensed under CC BY 4.0 by the author.