Blogue de Vector Research : Votre réseau neuronal est-il en danger ? Le piège des optimiseurs de gradient adaptatifs
12 mars 2024
12 mars 2024
Par Avery Ma, Yangchen Pan et Amir-massoud Farahmand
tl;dr : Nos analyses empiriques et théoriques révèlent que les modèles formés à l'aide de la descente de gradient stochastique présentent une robustesse aux perturbations des données d'entrée nettement supérieure à celle des modèles formés à l'aide des méthodes de gradient adaptatif. Cela signifie que certaines techniques de formation rendent les systèmes d'apprentissage automatique plus fiables et moins susceptibles d'être perturbés par des changements inattendus dans les données d'entrée.
Vous êtes-vous déjà interrogé sur les différences entre les modèles formés à l'aide de divers optimiseurs ? Les recherches en cours se concentrent sur l'impact de ces optimiseurs sur les performances de généralisation standard d'un modèle : leur précision sur l'ensemble de test d'origine. Dans ce billet, nous explorons comment ils peuvent faire ou défaire la robustesse des modèles face aux perturbations d'entrée, que vous soyez une équipe de descente de gradient stochastique (SGD) ou une équipe de gradient adaptatif.
Nous commençons par mettre côte à côte des modèles formés avec SGD, Adam et RMSProp. Le résultat est résumé dans la figure 1. Nous nous concentrons sur deux critères dans cette figure. Tout d'abord, les trois graphiques s'alignent sur le même axe Y, qui indique la précision du test standard. Les trois axes X montrent la précision du modèle sous différentes perturbations d'entrée. Les modèles formés par SGD, Adam et RMSProp sont marqués d'une étoile, d'un cercle et d'un losange, respectivement. Chaque triplet coloré indique les modèles sur le même ensemble de données.
Il y a un petit écart vertical entre chaque triplet, ce qui montre que les modèles ont des performances de généralisation standard similaires bien qu'ils aient été formés par des algorithmes différents.
D'autre part, pour les trois types de perturbations, il y a une grande portée horizontale avec l'étoile toujours positionnée à l'extrême droite parmi les trois. Cela indique que les modèles formés par SGD sont les grands gagnants en termes de robustesse face aux perturbations. Des résultats similaires peuvent être observés avec des transformateurs de vision ou d'autres modalités de données.
Pour comprendre ce phénomène, nous l'étudions sous l'angle d'une analyse du domaine des fréquences. Tout d'abord, nous remarquons que les ensembles de données naturelles contiennent certaines fréquences qui n'ont pas d'impact significatif sur les performances de généralisation standard des modèles. Mais voici le clou du spectacle : sous certains optimiseurs, ce type d'informations non pertinentes peut en fait rendre le modèle plus vulnérable. Plus précisément, notre principale affirmation est la suivante :
Pour optimiser l'objectif d'apprentissage standard, les modèles doivent uniquement apprendre à utiliser correctement les informations pertinentes contenues dans les données. Cependant, leur utilisation d'informations non pertinentes dans les données est sous-contrainte et peut conduire à des solutions sensibles aux perturbations.
C'est pourquoi, en injectant des perturbations dans des parties du signal qui contiennent des informations non pertinentes, nous observons que les modèles formés par différents algorithmes présentent des changements de performance très différents.
Pour démontrer qu'il existe des fréquences non pertinentes lors de la formation d'un classificateur de réseau neuronal, nous considérons une tâche d'apprentissage supervisé, en supprimant les informations non pertinentes de l'entrée de formation, puis en évaluant les performances du modèle à l'aide des données d'essai originales.
Lorsque nous modifions les données d'apprentissage en supprimant les parties du signal qui ont une faible énergie (figure 2, à gauche) ou une fréquence élevée (figure 2, à droite), nous constatons que cela n'affecte pas vraiment la précision des modèles sur l'ensemble de test d'origine. Cela suggère qu'il existe une quantité considérable d'informations non pertinentes du point de vue d'un réseau neuronal.
Cette observation conduit à la première partie de notre affirmation, à savoir que les modèles n'ont besoin que d'apprendre à utiliser correctement les informations cruciales définissant les classes à partir des données d'apprentissage afin d'optimiser l'objectif d'apprentissage. D'autre part, la mesure dans laquelle ils utilisent des informations non pertinentes dans les données n'est pas bien réglementée. Cela peut être problématique et conduire à des solutions sensibles aux perturbations.
Concentrons-nous à présent sur la deuxième partie de l'affirmation. Si les réponses des modèles aux perturbations le long des fréquences non pertinentes expliquent leur différence de robustesse, nous devrions nous attendre à une baisse de précision similaire entre les modèles lorsque les perturbations se font le long des fréquences pertinentes, mais à une baisse de précision beaucoup plus importante pour les modèles moins robustes lorsque les entrées des tests sont perturbées le long des fréquences non pertinentes.
Cela nous amène à l'expérience suivante. La figure 3 montre comment la précision de la classification se dégrade sous l'effet de différents bruits gaussiens à bande limitée sur CIFAR100 et Imagenette. On remarque que la perturbation de la bande la plus basse a un impact similaire sur tous les modèles, quel que soit l'algorithme par lequel ils ont été entraînés. Il y a cependant une différence notable dans la façon dont les modèles formés par les méthodes SGD et de gradient adaptatif réagissent aux perturbations provenant de bandes de fréquences plus élevées.
Cette observation montre que lorsque les modèles, au cours de leur phase d'apprentissage, ne disposent pas de mécanismes permettant de limiter leur utilisation de fréquences non pertinentes, leurs performances peuvent être compromises si les données correspondant à des fréquences non pertinentes sont corrompues au moment du test.
En plus des études empiriques, nous analysons théoriquement la dynamique d'apprentissage de la descente de gradient (GD) et de la descente de gradient de signe (signGD), une version sans mémoire d'Adam et de RMSProp, avec des modèles linéaires. Nous présentons brièvement la configuration du problème et résumons les principaux résultats. Pour plus de détails, nous renvoyons le lecteur à notre article.
Nous nous concentrons sur la régression des moindres carrés et comparons les risques standard et adverses des solutions asymptotiques obtenues par GD et signGD. Motivés par nos observations précédentes, nous concevons un ensemble de données synthétiques qui imite les propriétés d'un ensemble de données naturelles en spécifiant des fréquences qui ne sont pas pertinentes pour générer la vraie cible. Nous nous intéressons particulièrement au risque standard :
et le risque contradictoire en cas de perturbations limitées par la norme l2 :
Nos principaux résultats sont triples.
1. Les informations non pertinentes conduisent à de multiples minimiseurs de risque standard. Pour un minimiseur arbitraire, nous pouvons obtenir son risque contradictoire comme suit :
Cela signifie que la robustesse des modèles aux changements limités par la norme l2 est inversement proportionnelle à la norme de poids des paramètres du modèle : une norme de poids plus petite implique une meilleure robustesse.
2. Avec un taux d'apprentissage suffisamment faible, le risque standard des solutions obtenues par GD et signGD peut être proche de 0.
3. Considérons un espace d'entrée tridimensionnel. Le rapport entre le risque contradictoire de la solution GD et celui de la solution signGD est toujours supérieur à 1 :
où C>0 et sa valeur dépend de l'initialisation des poids et de la covariance des données.
Les deux derniers résultats sont particulièrement importants. Ils permettent d'expliquer les phénomènes observés dans la figure 1, en particulier les niveaux similaires de généralisation standard entre les modèles et les variations de leur robustesse. Les derniers résultats soulignent que le modèle linéaire tridimensionnel obtenu par la méthode GD présente systématiquement une plus grande robustesse par rapport à la méthode GD. -par rapport au modèle obtenu à partir de signGD.
Les premiers résultats de l'analyse linéaire montrent que, pour les minimiseurs de risque standard, la robustesse à l'égard de l'effet de levier est très faible. est proportionnelle à son poids. Pour généraliser ce résultat dans le cadre de l'apprentissage profond, nous établissons un lien entre la norme de poids et la Lipschitzness des réseaux neuronaux.
Considérons le réseau neuronal feed-forward comme une série de compositions de fonctions :
où chaque est une opération linéaire, une fonction d'activation ou des opérations de regroupement. En dénotant la constante de Lipschitz de la fonction comme nous pouvons établir une borne supérieure sur la constante de Lipschitz pour l'ensemble du réseau neuronal feed-forward en utilisant.
L'approximation de la Lipschitzness des composants des réseaux neuronaux, tels que les convolutions et les skip-connections, dépend souvent de la norme des poids. Cette méthode nous permet d'établir des liens entre la norme des poids d'un réseau neuronal et sa robustesse. Essentiellement, une norme de poids plus faible suggère une limite supérieure plus petite sur la constante de Lipschitz, ce qui indique que les modèles sont moins sujets aux perturbations.
Les résultats du tableau 1 montrent que les réseaux neuronaux formés par SGD ont des constantes de Lipschitz beaucoup plus petites, ce qui explique leur meilleure résistance aux perturbations des entrées que les réseaux formés par des méthodes de gradient adaptatif, comme le montre la figure 1.
Notre travail met en évidence l'importance de la sélection de l'optimiseur pour parvenir à la fois à la généralisation et à la robustesse. Cet aperçu fait non seulement progresser notre compréhension de la robustesse des réseaux neuronaux, mais oriente également les recherches futures sur le développement de stratégies d'optimisation qui maintiennent une grande précision tout en étant résistantes aux perturbations d'entrée, ouvrant ainsi la voie à des applications d'apprentissage automatique plus sûres et plus fiables.