Vector Research Blog: Is Your Neural Network at Risk? The Pitfall of Adaptive Gradient Optimizers

March 12, 2024


By Avery Ma, Yangchen Pan, and Amir-massoud Farahmand

tl;dr: Our empirical and theoretical analyses reveal that models trained using stochastic gradient descent exhibit significantly higher robustness to input perturbations than those trained via adaptive gradient methods. This means that certain training techniques make machine learning systems more reliable and less likely to be thrown off by unexpected changes in the input data.

Have you ever wondered about the differences between models trained with various optimizers? Ongoing research focuses on how these optimizers impact a model’s standard generalization performance, that is, its accuracy on the original test set. In this post, we explore how they can make or break a model’s robustness against input perturbations, whether you are team stochastic gradient descent (SGD) or team adaptive gradient.

Figure 1: Comparison between models trained using SGD, Adam, and RMSProp. Models trained by different algorithms have similar test accuracy, but there is a distinct robustness difference.

We start by putting models trained with SGD, Adam, and RMSProp side by side; the results are summarized in Figure 1. The figure captures two criteria. All three plots share the same Y-axis, which indicates standard test accuracy, while the three X-axes show the accuracy of the models under various input perturbations. Models trained by SGD, Adam, and RMSProp are marked with a star, a circle, and a diamond, respectively, and each colored triplet denotes models trained on the same dataset.

There is only a small vertical gap within each triplet, showing that the models have similar standard generalization performance despite being trained by different algorithms.

Under all three types of perturbations, however, there is a large horizontal span within each triplet, with the star always positioned furthest to the right. This indicates that models trained by SGD are the clear winners in terms of robustness against perturbations. Similar results can be observed with vision transformers and other data modalities.

Why do models behave differently under perturbations?

To understand this phenomenon, we investigate it through the lens of a frequency-domain analysis. First, we notice that natural datasets contain some frequencies that do not significantly impact the standard generalization performance of models. But here is the twist: under certain optimizers, this type of irrelevant information can actually make the model more vulnerable. Specifically, our main claim is that:

To optimize the standard training objective, models only need to learn how to correctly use relevant information in the data. However, their use of irrelevant information in the data is under-constrained and can lead to solutions sensitive to perturbations.

Because of this, when we inject perturbations into the parts of the signal that carry irrelevant information, models trained by different algorithms exhibit very different drops in performance.

Observation I: Irrelevant Frequencies in Natural Signals

To demonstrate that irrelevant frequencies exist when training a neural network classifier, we consider a supervised learning task, remove the irrelevant information from the training inputs, and then assess the model’s performance using the original test data.

Figure 2: Irrelevant frequencies exist in the natural data. Accuracy on the original test set remains high when the training inputs are modified by removing parts of the signal with low spectrum energy (left) and high frequencies (right).

When we modify the training data by removing parts of the signal that either have low energy (Figure 2, left) or are of high frequency (Figure 2, right), we find that doing so barely affects the models’ accuracy on the original test set. This suggests that, from the perspective of a neural network, natural data contains a considerable amount of irrelevant information.
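As a rough illustration of this kind of preprocessing (a minimal NumPy sketch, not the authors’ exact pipeline; the cutoff radius and the random batch are placeholders), the function below removes high-frequency content from a batch of images. The low-energy variant would instead keep only the largest-magnitude frequency coefficients.

```python
import numpy as np

def remove_high_frequencies(images, keep_radius=8):
    """Zero out frequency components outside a centered low-frequency square.

    images: array of shape (N, H, W) with values in [0, 1].
    keep_radius: half-width of the retained low-frequency band (placeholder choice).
    """
    _, h, w = images.shape
    # Move the zero-frequency component to the center of the spectrum.
    spectrum = np.fft.fftshift(np.fft.fft2(images), axes=(-2, -1))

    # Build a centered low-pass mask and apply it.
    cy, cx = h // 2, w // 2
    mask = np.zeros((h, w), dtype=bool)
    mask[cy - keep_radius:cy + keep_radius, cx - keep_radius:cx + keep_radius] = True
    filtered = spectrum * mask

    # Transform back to the pixel domain.
    return np.real(np.fft.ifft2(np.fft.ifftshift(filtered, axes=(-2, -1))))

# Example with a random batch standing in for the training set.
batch = np.random.rand(16, 32, 32)
low_pass_batch = remove_high_frequencies(batch, keep_radius=8)
```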

This observation leads to the first part of our claim, that models only need to learn how to correctly use the crucial class-defining information from the training data to optimize the training objective. On the other hand, the extent to which they utilize irrelevant information in the data is not well-regulated. This can be problematic and lead to solutions sensitive to perturbations.

Observation II: Model Robustness along Irrelevant Frequencies

Let us now focus on the second part of the claim. If models’ responses to perturbations along the irrelevant frequencies explain their robustness difference, then we should expect a similar accuracy drop between models when perturbations are along relevant frequencies, but a much larger accuracy drop on less robust models when test inputs are perturbed along irrelevant frequencies.

Figure 3: The effect of band-limited Gaussian perturbations on models trained using SGD, Adam, and RMSProp. Perturbations from the lowest band have a similar effect on all the models, while models’ responses vary significantly when the perturbation focuses on higher frequency bands.

This leads to our next experiment. Figure 3 demonstrates how classification accuracy degrades under band-limited Gaussian noise on CIFAR100 and Imagenette. Notice that perturbations from the lowest band have a similar impact on all the models, regardless of the algorithm they are trained by. There is, however, a noticeable difference in how models trained by SGD and by adaptive gradient methods respond to perturbations from higher frequency bands.
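Band-limited Gaussian noise of this kind can be constructed by sampling white Gaussian noise, keeping only a ring of frequencies, and rescaling to a fixed l2 norm. The sketch below is our own construction for illustration; the exact bands and norm budget used in the experiments may differ.

```python
import numpy as np

def band_limited_gaussian_noise(h, w, band_low, band_high, eps, rng):
    """Gaussian noise restricted to frequencies whose radius lies in
    [band_low, band_high), rescaled to have l2 norm eps."""
    noise = rng.standard_normal((h, w))
    spectrum = np.fft.fftshift(np.fft.fft2(noise))

    # Radial distance of each frequency from the (centered) zero frequency.
    yy, xx = np.meshgrid(np.arange(h) - h // 2, np.arange(w) - w // 2, indexing="ij")
    radius = np.sqrt(yy**2 + xx**2)
    mask = (radius >= band_low) & (radius < band_high)

    filtered = np.real(np.fft.ifft2(np.fft.ifftshift(spectrum * mask)))
    return eps * filtered / np.linalg.norm(filtered)

rng = np.random.default_rng(0)
delta = band_limited_gaussian_noise(32, 32, band_low=8, band_high=12, eps=1.0, rng=rng)
# perturbed = np.clip(image + delta, 0.0, 1.0)  # applied to a test image before evaluation
```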

This observation shows that when models have no mechanism during training to limit their use of irrelevant frequencies, their performance can be compromised if the data along those frequencies becomes corrupted at test time.

Linear Regression Analysis with an Over-parameterized Model

In addition to the empirical studies, we theoretically analyze the learning dynamics of gradient descent (GD) and sign gradient descent (signGD), a memory-free version of Adam and RMSProp, with linear models. We briefly introduce the problem setup and summarize key results. For more details, we direct the reader to our paper. 
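For reference, the two update rules being compared take the familiar forms below, with learning rate $\eta$, training loss $L$, and the sign applied elementwise:

$$w_{t+1} = w_t - \eta \, \nabla_w L(w_t) \quad \text{(GD)}, \qquad w_{t+1} = w_t - \eta \, \mathrm{sign}\big(\nabla_w L(w_t)\big) \quad \text{(signGD)}.$$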

We focus on least-squares regression and compare the standard and adversarial risks of the asymptotic solutions obtained by GD and signGD. Motivated by our previous observations, we design a synthetic dataset that mimics the properties of a natural dataset by specifying frequencies that are irrelevant to generating the true target. We are particularly interested in the standard risk and the adversarial risk under l2-norm bounded perturbations.
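As a sketch of these two quantities under the usual least-squares setup (the notation here is ours; the paper gives the precise definitions), for a linear model with weights $w$, input $x$, target $y$, and perturbation budget $\epsilon$:

$$\mathcal{R}(w) = \mathbb{E}_{(x,y)}\big[(w^\top x - y)^2\big], \qquad \mathcal{R}_{\mathrm{adv}}(w) = \mathbb{E}_{(x,y)}\Big[\max_{\|\delta\|_2 \le \epsilon} \big(w^\top (x+\delta) - y\big)^2\Big].$$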

Our main results are threefold.

1. Irrelevant information leads to multiple standard risk minimizers. For an arbitrary minimizer, we can express its adversarial risk in terms of the norm of its weights (see the worked sketch after this list). This means that a model’s robustness to l2-norm bounded perturbations is inversely related to its weight norm: a smaller weight norm implies better robustness.

2. With a sufficiently small learning rate, the standard risks of the solutions obtained by GD and signGD can both be close to 0.

3. Consider a three-dimensional input space. The ratio between the adversarial risk of the signGD solution and that of the GD solution is always greater than 1; the gap is characterized by a constant C > 0 whose value depends on the weight initialization and the data covariance.
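To make the role of the weight norm in the first result concrete, here is a short worked sketch under the simplifying assumption that the minimizer $w$ interpolates the data, i.e., $w^\top x = y$ and its standard risk is zero (the statement in the paper is more general):

$$\mathcal{R}_{\mathrm{adv}}(w) = \mathbb{E}\Big[\max_{\|\delta\|_2 \le \epsilon} \big(w^\top x + w^\top \delta - y\big)^2\Big] = \mathbb{E}\Big[\max_{\|\delta\|_2 \le \epsilon} \big(w^\top \delta\big)^2\Big] = \epsilon^2 \|w\|_2^2,$$

since the inner maximum is attained at $\delta = \pm \epsilon \, w / \|w\|_2$. A larger weight norm therefore translates directly into a larger adversarial risk.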

The latter two findings are particularly important. They provide insights that help explain the phenomena observed in Figure 1, namely the similar levels of standard generalization across models and the variation in their robustness. The last result highlights that the three-dimensional linear model obtained through GD consistently exhibits greater robustness against l2-norm bounded perturbations compared to the model obtained from signGD.
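As a toy illustration of these dynamics (our own simplified construction, not the paper’s synthetic dataset; the dimensions, learning rate, and step count are arbitrary choices), the sketch below trains an over-parameterized least-squares problem with GD and signGD and compares the resulting weight norms, which by the worked sketch above control the l2 adversarial risk.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 50                      # fewer samples than features: over-parameterized
X = rng.standard_normal((n, d))
y = X[:, 0]                        # only the first feature is relevant to the target

def train(transform, lr=1e-3, steps=50_000):
    """Run (sign) gradient descent on the least-squares loss."""
    w = np.zeros(d)
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / n          # gradient of the mean squared error
        w -= lr * transform(grad)
    return w

w_gd = train(lambda g: g)                      # gradient descent
w_sign = train(lambda g: np.sign(g))           # sign gradient descent

for name, w in [("GD", w_gd), ("signGD", w_sign)]:
    loss = np.mean((X @ w - y) ** 2)
    print(f"{name}: training loss {loss:.2e}, weight norm {np.linalg.norm(w):.3f}")
# Both optimizers drive the training loss to a small value, but signGD typically
# ends up with a noticeably larger weight norm, i.e., a less robust solution.
```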

Connecting the Norm of Linear Models to the Lipschitzness of Neural Networks

The first result from the linear analysis shows that, among standard risk minimizers, a model’s robustness against perturbations is inversely related to its weight norm. To generalize this result to the deep learning setting, we make a connection between the weight norm and the Lipschitzness of neural networks.

Consider a feed-forward neural network as a series of function compositions:

$$f = f_L \circ f_{L-1} \circ \cdots \circ f_1,$$

where each $f_i$ is a linear operation, an activation function, or a pooling operation. Denoting the Lipschitz constant of $f_i$ by $\mathrm{Lip}(f_i)$, we can establish an upper bound on the Lipschitz constant of the entire feed-forward neural network using $\mathrm{Lip}(f) \le \prod_{i=1}^{L} \mathrm{Lip}(f_i)$.

The Lipschitz constants of neural network components, such as convolutions and skip connections, can be approximated in terms of the norms of their weights. This lets us draw a connection between a neural network’s weight norm and its robustness: a lower weight norm suggests a smaller upper bound on the Lipschitz constant, indicating that the model is less sensitive to perturbations.
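As a minimal sketch of this bookkeeping for a plain fully-connected ReLU network (convolutions and skip connections need more careful treatment), the product of the linear layers’ spectral norms gives such an upper bound, since ReLU is 1-Lipschitz; the architecture below is only an example.

```python
import torch
import torch.nn as nn

def lipschitz_upper_bound(model: nn.Module) -> float:
    """Upper bound on the Lipschitz constant of a feed-forward ReLU network:
    the product of the spectral norms of its linear layers."""
    bound = 1.0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Spectral norm = largest singular value of the weight matrix.
            bound *= torch.linalg.matrix_norm(module.weight, ord=2).item()
    return bound

model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(),
                      nn.Linear(256, 64), nn.ReLU(),
                      nn.Linear(64, 10))
print(f"Upper bound on the Lipschitz constant: {lipschitz_upper_bound(model):.2f}")
```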

Table 1: Comparing the upper bound on the Lipschitz constant and the average robust accuracy of neural networks. Notice that across all selected datasets, models trained by SGD have a considerably smaller upper bound compared to models trained by Adam and RMSProp.

Results in Table 1 demonstrate that SGD-trained neural networks have considerably smaller upper bounds on their Lipschitz constants, which explains their better robustness to input perturbations compared with models trained by adaptive gradient methods, as shown in Figure 1.

Our work highlights the importance of optimizer selection in achieving both generalization and robustness. This insight not only advances our understanding of neural network robustness but also guides future research in developing optimization strategies that maintain high accuracy while being resilient to input perturbations, paving the way for more secure and reliable machine learning applications.
