In which we investigate mini-batch size and learn that we have a problem with forgetfulness

When we left off last time, we had inherited an 18-layer ResNet and learning rate schedule from the fastest single-GPU DAWNBench entry for CIFAR10. Training to 94% test accuracy took 341s, and with some minor adjustments to the network and data loading we had reduced this to 297s.

Our training thus far uses a batch size of 128. Larger batches should allow for more efficient computation, so let’s see what happens if we increase the batch size to 512. If we’re to approximate the previous setup, we need to make sure that the learning rate and other hyperparameters are suitably adjusted.

SGD with mini-batches is similar to training one example at a time; the difference is that parameter updates are delayed until the end of each batch. In the limit of low learning rates, one can argue that this delay is a higher order effect and that batching changes nothing to first order, so long as gradients are summed, not averaged, over mini-batches. We also apply weight decay after each batch, and this should be increased by a factor of the batch size to compensate for the reduction in the number of batches processed. If gradients are averaged over mini-batches instead, then the learning rate should be scaled up to undo the effect, and weight decay left alone, since our weight decay update already incorporates a factor of the learning rate.
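
As a concrete illustration, here is a minimal sketch of the hyperparameter scaling described above. It assumes plain SGD with the weight decay update applied once per batch; the function name and the base values in the usage example are illustrative rather than taken from our actual training code.

```python
def scale_hyperparams(base_lr, base_wd, base_batch, new_batch, grad_reduction="mean"):
    """Rescale learning rate and weight decay when changing the batch size,
    so that training approximates the behaviour at the original batch size."""
    k = new_batch / base_batch  # e.g. 512 / 128 = 4
    if grad_reduction == "sum":
        # Summed gradients: the per-example step size is unchanged, but we now take
        # k times fewer weight decay steps per epoch, so boost weight decay.
        return base_lr, base_wd * k
    if grad_reduction == "mean":
        # Averaged gradients: scale the learning rate by k to undo the averaging;
        # weight decay is left alone because the decay update already carries a
        # factor of the learning rate.
        return base_lr * k, base_wd
    raise ValueError(f"unknown grad_reduction: {grad_reduction}")

# Illustrative usage when moving from batch size 128 to 512 with averaged gradients:
lr_512, wd_512 = scale_hyperparams(base_lr=0.1, base_wd=5e-4,
                                   base_batch=128, new_batch=512)
```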

So without further ado, let’s train with batch size 512. Training completes in 256s, and with one minor adjustment to the learning rate – increasing it by 10% – we are able to match the training curve of the base runs at batch size 128, with 3/5 runs reaching 94% test accuracy. The noisier validation results during training at batch size 512 are expected because of batch norm effects, but more on this in a later post. Even larger batches may be possible with a little care, but for now we’ll settle for 512.

Now speedups are great, but this result is rather surprising.

Our discussion of the equivalence of training with different mini-batch sizes was naive in at least two ways. First, we argued that delaying updates until the end of a mini-batch is a higher order effect and should be ok in the limit of small learning rates. It is not at all clear that this limit applies. Indeed, the fast training speed of the current setup comes in large part from the use of high learning rates. In the context of convex optimisation (or just gradient descent on a quadratic), one achieves maximum training speed by setting the learning rate at the point where second order effects start to balance first order ones, so that any benefit from larger first order steps is offset by curvature effects. Assuming that we are in this regime, the delayed updates from mini-batching should incur the same curvature penalties as a corresponding increase in learning rate, and training should become unstable. In short, if higher order effects can be neglected, you are not training fast enough.
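
To make the quadratic intuition concrete, here is the standard one step calculation that the argument above alludes to (textbook material rather than anything specific to our setup):

```latex
% Gradient descent on the quadratic L(w) = \tfrac{1}{2}\lambda w^2 with step size \eta:
\Delta L \;=\; L(w - \eta\,\nabla L) - L(w)
         \;=\; -\,\eta\,\lVert\nabla L\rVert^2 \;+\; \tfrac{1}{2}\,\eta^2\,\lambda\,\lVert\nabla L\rVert^2 .
% The first term is the first order gain from a larger step, the second is the
% curvature penalty. The decrease is greatest at \eta = 1/\lambda, where the two
% balance, and the step increases the loss altogether once \eta > 2/\lambda.
```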

A second issue with the argument is that it applies only to a single training step, whereas training is a long-running process which continues for at least O(1/learning rate) steps in order to allow O(1) changes to the parameters. Thus, second order differences between small and large batch training could accumulate over time and lead to substantially different training trajectories. We will revisit this point in a later post, when we’ve understood a little more about the long run dynamics of training, but for now we will concern ourselves with the first issue.

So how can it be that we are simultaneously at the speed limit of training and able to increase the batch size without triggering instability from curvature effects? The answer, presumably, is that something else is limiting the achievable learning rates and we are not in the regime where curvature effects dominate. We will argue that this something else is the alarmingly named Catastrophic Forgetting, and that this, rather than the curvature of the loss, is what limits learning rates at small batch sizes.

First we should explain what we mean. Usually this term applies to a situation where a model is trained on one task and then on one or more further tasks: learning the later tasks degrades performance on the earlier ones, and sometimes this effect is Catastrophic. In our case, the tasks in question are different parts of the same training set, and forgetfulness can occur within a single epoch at high enough learning rates. The larger the learning rate, the more the parameters move about during a single training epoch, and at some point this must impair the model’s ability to absorb information from the whole dataset: earlier batches will be effectively forgotten.

We have already seen a first piece of evidence for our claim: increasing the batch size does not immediately lead to training instability, as it should if curvature were the issue; this is exactly what we would expect if the issue is forgetfulness, which should be mostly unaffected by batch size.

Next, we run an experiment to separate out the effects of curvature, which depend primarily on the learning rate, from those of forgetfulness, which depend jointly on the learning rate and the dataset size. We plot the final training and test losses at batch size 128 when we train on subsets of the training set of different sizes, using the original learning rate schedule rescaled by a range of factors between 1/8 and 16.
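
In outline, the sweep looks something like the following sketch. The `train_to_completion` hook and the exact grid of subset sizes are stand-ins for the real training script rather than its actual interface.

```python
import numpy as np

def sweep(train_to_completion, full_size=50000):
    """Sweep over training set size and learning rate factor at batch size 128.

    `train_to_completion(dataset_size, batch_size, lr_schedule_scale)` is a
    hypothetical hook standing in for the real training script; it should return
    the final (train_loss, test_loss) for the given configuration.
    """
    lr_factors = 2.0 ** np.arange(-3, 5)        # 1/8, 1/4, ..., 8, 16
    subset_sizes = [full_size // 8, full_size // 4,
                    full_size // 2, full_size]  # e.g. 6250, 12500, 25000, 50000
    results = {}
    for n in subset_sizes:
        for factor in lr_factors:
            results[(n, factor)] = train_to_completion(
                dataset_size=n, batch_size=128, lr_schedule_scale=factor)
    return results
```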

We can see a number of interesting things in the plots. First, training and test losses both become suddenly unstable at a similar learning rate (about 8 times the original learning rate) independent of training set size. This is a strong indication that curvature effects become important at this point. Conversely, in a large range around the original learning rate (learning rate factor=1 in the plots) training and test losses are stable.

The optimal learning rate factor (measured by test set loss) is close to one for the full training set, as expected since this has been hand optimised. For smaller datasets the optimal learning rate factor is higher, and for the smallest, of size 6250, the optimum is close to the point at which curvature effects destabilise training. This is in line with our hypothesis above: for a dataset small enough that forgetfulness is no longer an issue, learning rates should be pushed close to the limit allowed by curvature, whereas for larger datasets the optimal point can be significantly lower because of the forgetfulness effect.

It’s also interesting to plot results at batch size 512. Because of the 4× larger steps at this batch size, we might expect to find ourselves closer to the curvature instability, which should set in at a learning rate factor of about 2 instead of 8. We also expect the optimal learning rate factors and losses to be similar to those at batch size 128, since the speed of forgetting is unaffected by batch size and curvature effects are not yet dominant at the optimum. The results are as we would hope:

We can directly observe the effects of forgetfulness with the following experiment. We set the batch size to 128 and train with a learning rate schedule which increases linearly for the first 5 epochs and then remains constant at a fixed maximal rate for a further 25 epochs, so that the training and test losses stabilise at the given learning rate. We compare training runs on two different datasets: a) 50% of the full training set with no data augmentation, and b) the full dataset with our standard augmentation. We then freeze the final model from run b) and recompute losses over the last several epochs of batches from the training run just completed. The idea of recomputing losses in this way is to compare the model’s loss on batches seen most recently with its loss on batches seen longer ago, as a test of the model’s memory.
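
A minimal sketch of this memory probe, assuming we have kept hold of the batches from the last few epochs in the order they were seen (the `model`, `loss_fn` and `replayed_batches` arguments are placeholders for the corresponding objects in the real training loop):

```python
import torch

@torch.no_grad()
def probe_memory(model, replayed_batches, loss_fn):
    """Recompute losses with a frozen model on batches in the order they were seen.

    `replayed_batches` is assumed to be a list of (inputs, targets) pairs covering
    the last several epochs of training, oldest first. Plotting the returned losses
    against how long ago each batch was seen shows how quickly it is forgotten.
    """
    model.eval()  # freeze batch norm statistics; no parameter updates happen here
    return [loss_fn(model(inputs), targets).item()
            for inputs, targets in replayed_batches]
```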

Here are results for a maximum learning rate 4× higher than the original training setup:

And here are results for a maximum learning rate 4× lower than the original training:

Several things stand out from these results. Focussing on the first three plots, corresponding to the high learning rate, we observe that the test loss is almost the same whether the model is trained on 50% of the dataset with no augmentation or on the full dataset with augmentation. This implies that training is unable to extract information from the full dataset and that 50% of the unaugmented dataset already contains (almost) as much information as the model can learn in this regime. The far right plot shows why this is so: the most recently seen training batches have a significantly lower loss than older ones, but the loss reverts to the level of the test set of unseen examples within half a training epoch. This is clear evidence that the model is forgetting what it saw earlier in the same training epoch, and that this limits the amount of information it can absorb at this learning rate.

The second row shows, for contrast, what happens with a low learning rate. The full (augmented) dataset leads to a lower test loss, and recently seen batches outperform random batches for many epochs into the past (note the different scale on the x-axis for the final plots in the two rows).

Discussion

The results above suggest that if one wishes to train a neural network at high learning rates then there are two regimes to consider. For the current model and dataset, at batch size 128 we are safely in the regime where forgetfulness dominates, and we should either focus on methods to reduce it (e.g. larger models with sparse updates, or perhaps natural gradient descent) or push batch sizes higher. At batch size 512 we enter the regime where curvature effects dominate, and the focus should shift to mitigating these.

For a larger dataset such as ImageNet-1K, which has roughly 25× as many training examples as CIFAR10, the effects of forgetfulness are likely to be much more severe. This would explain why attempts to speed up training at small batch sizes with very high learning rates have failed on this dataset, whilst training with batches of size 8000 or more across multiple machines has been successful. At the very largest batch sizes, curvature effects dominate once again, and for this reason there is substantial overlap between the techniques used in large batch training of ImageNet and fast single GPU training of CIFAR10.

In Part 3 we speed up batch norms, add some regularisation and overtake another benchmark.