July 7, 2021
Self-supervised learning (where machines learn directly from whatever text, images, or other data they’re given — without relying on carefully curated and labeled data sets) is one of the most promising areas of AI research today. But many important open questions remain about how best to teach machines without annotated data.
We’re sharing a theory that attempts to explain one of these mysteries: why so-called non-contrastive self-supervised learning often works well. With this approach, an AI system learns only from positive sample pairs. For example, the training data might contain two versions of the same photo of a cat: the original in color and a copy in black and white. The model is not given any negative examples (such as an unrelated photo of a mountain).
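To make the idea of a positive pair concrete, here is a minimal NumPy sketch that builds the color/black-and-white pair from the cat example. The helper names are hypothetical, the random array stands in for a real photo, and the BT.601 luma weights are one common grayscale convention that we assume here. Real pipelines typically apply many random augmentations rather than a single fixed one.

```python
import numpy as np

def grayscale(img):
    """Convert an (H, W, 3) RGB image with values in [0, 1] into a 3-channel grayscale image."""
    luma = img @ np.array([0.299, 0.587, 0.114])    # ITU-R BT.601 luma weights, shape (H, W)
    return np.repeat(luma[..., None], 3, axis=-1)   # keep 3 channels so both views share a shape

def make_positive_pair(img):
    """Two views of the same image: the original and a black-and-white copy."""
    return img, grayscale(img)

rng = np.random.default_rng(0)
cat_photo = rng.random((32, 32, 3))   # random stand-in for a real photo
view_a, view_b = make_positive_pair(cat_photo)
```

No negative example appears anywhere: the model only ever sees `view_a` and `view_b` as two views of one image.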
This is different from contrastive self-supervised learning, which includes both negative and positive examples and is one of the most effective methods to learn good representations. The loss function of contrastive learning is intuitively simple: minimize the distance in representation space between positive sample pairs while maximizing the distance between negative sample pairs.
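One common way to instantiate this contrastive objective is an InfoNCE-style loss, in the spirit of SimCLR’s normalized temperature-scaled cross-entropy. The sketch below is a simplified, one-directional NumPy version that we assume for illustration. For each anchor `z1[i]`, the matching `z2[i]` is the positive and every other `z2[j]` in the batch acts as a negative. Lowering the loss therefore pulls positives together and pushes negatives apart, measured by cosine similarity.

```python
import numpy as np

def l2_normalize(z):
    """Scale every row to unit length so dot products become cosine similarities."""
    return z / np.linalg.norm(z, axis=1, keepdims=True)

def info_nce_loss(z1, z2, temperature=0.1):
    """Simplified InfoNCE-style loss for a batch of N positive pairs (z1[i], z2[i]).

    Every z2[j] with j != i serves as a negative for the anchor z1[i].
    """
    z1, z2 = l2_normalize(z1), l2_normalize(z2)
    logits = z1 @ z2.T / temperature                 # (N, N) scaled cosine similarities
    logits -= logits.max(axis=1, keepdims=True)      # subtract row max for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))              # positives sit on the diagonal
```

If `z2` matches `z1` row by row, the loss is small. Reordering the rows so that each anchor’s true partner lands in a different column raises it, because the loss now rewards similarity to what are effectively negatives.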
Non-contrastive self-supervised learning is counterintuitive, however. When a model is trained with only positive sample pairs (and only minimizes the distance between them), it might seem that the representation would collapse into a constant solution, where all inputs map to the same output. With a collapsed representation, the loss function would reach zero, its minimal possible value.
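The collapse concern can be checked numerically. This sketch assumes a positive-only loss defined as the squared distance between L2-normalized embeddings, a common choice that we use here for illustration. With that loss, an encoder that maps every input to the same vector scores exactly zero, while an ordinary random linear encoder does not.

```python
import numpy as np

def positive_only_loss(z1, z2):
    """Mean squared distance between L2-normalized embeddings of positive pairs."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    return np.mean(np.sum((z1 - z2) ** 2, axis=1))

def collapsed(x):
    """A degenerate encoder that maps every input to the same 4-dim vector."""
    return np.ones((x.shape[0], 4))

rng = np.random.default_rng(0)
x1 = rng.normal(size=(64, 10))              # first view of each input
x2 = x1 + 0.5 * rng.normal(size=(64, 10))   # second, perturbed view of the same inputs
W = rng.normal(size=(4, 10))                # a random linear encoder

print(positive_only_loss(x1 @ W.T, x2 @ W.T))            # strictly positive
print(positive_only_loss(collapsed(x1), collapsed(x2)))  # 0.0
```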
In fact, these models can still learn good representations. We’ve found that training a non-contrastive self-supervised learning framework converges to a useful local minimum rather than the trivial global one. Our work attempts to show why.
We’re also sharing a new method called DirectPred, which sets the predictor weights directly instead of training them with gradient updates. Using a linear predictor, DirectPred performs on par with existing non-contrastive self-supervised approaches like Bootstrap Your Own Latent (BYOL).
We focused on analyzing the model’s dynamics during training: how the weights change over time and why it doesn’t collapse to trivial solutions.
To learn this, we started with a highly simplified version of a non-contrastive self-supervised learning model: one that contains a linear trunk network W (and its moving-averaged version, Wa) plus an extra linear predictor, Wp. Despite this setting’s simplicity, our analysis turns out to be surprisingly consistent with real-world settings, where the trunk network is highly nonlinear.
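A minimal NumPy sketch of this simplified setting follows. It assumes a squared-error loss between the online prediction Wp W x1 and the target Wa x2, with the target treated as a constant (stop-gradient). Additive Gaussian noise stands in for data augmentation. The dimensions, learning rates, weight decay, and EMA coefficient `tau` are illustrative values, not the settings used in the paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 8, 4                                # input and representation dimensions (illustrative)
W = rng.normal(scale=0.5, size=(h, d))     # online linear trunk
W_a = W.copy()                             # target trunk: moving average of W
W_p = rng.normal(scale=0.1, size=(h, h))   # linear predictor, online side only

def loss_and_grads(x1, x2, W, W_a, W_p):
    """Loss 0.5 * mean ||W_p W x1 - W_a x2||^2, with the target branch held constant.

    Returns the loss and its gradients w.r.t. W and W_p; W_a gets no gradient (stop-gradient).
    """
    n = x1.shape[0]
    f1 = x1 @ W.T                 # online features, shape (n, h)
    target = x2 @ W_a.T           # target features, treated as constants
    err = f1 @ W_p.T - target     # prediction error, shape (n, h)
    loss = 0.5 * np.mean(np.sum(err ** 2, axis=1))
    grad_W_p = err.T @ f1 / n
    grad_W = W_p.T @ err.T @ x1 / n
    return loss, grad_W, grad_W_p

lr, lr_pred, wd, tau = 0.02, 0.02, 1e-3, 0.99   # illustrative hyperparameters
for _ in range(100):
    x = rng.normal(size=(128, d))
    x1 = x + 0.1 * rng.normal(size=x.shape)     # two "augmented" views of each input
    x2 = x + 0.1 * rng.normal(size=x.shape)
    loss, g_W, g_W_p = loss_and_grads(x1, x2, W, W_a, W_p)
    W -= lr * (g_W + wd * W)                    # gradient step plus weight decay
    W_p -= lr_pred * (g_W_p + wd * W_p)
    W_a = tau * W_a + (1.0 - tau) * W           # exponential moving average of the trunk
```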
We first showed that two things are essential for non-contrastive self-supervised learning: there needs to be an extra predictor on the online side, and the gradient must not be back-propagated through the target side. We showed that if either condition is not met, the model does not work. (The weights of the trunk network simply shrink to zero, and no learning happens.) This phenomenon had previously been observed empirically in two non-contrastive self-supervised methods, BYOL and SimSiam, but our work now explains it theoretically.
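The effect of removing the predictor can be reproduced in a toy simulation. The sketch below integrates what we derived as the expected gradient flow of the linear model’s loss 0.5 * E||Wp W x1 - StopGrad(Wa x2)||^2. It assumes two independent augmentations of each input, so that E[x1 x1^T] = X + X' and E[x2 x1^T] = X, with augmentation covariance X' = sigma2 * I. Having no predictor is modeled as fixing Wp to the identity. All constants are illustrative assumptions, and this is not the paper’s code.

```python
import numpy as np

def simulate(use_predictor, x_eigs=(1.0, 0.01), sigma2=0.1, eta=0.01,
             alpha_p=1.0, beta=1.0, dt=0.05, steps=8000):
    """Euler-integrate the expected gradient flow of the toy linear model; return the trunk W."""
    d = len(x_eigs)
    X = np.diag(x_eigs)                    # E[x2 x1^T]: correlation of the clean inputs
    X_aug = X + sigma2 * np.eye(d)         # E[x1 x1^T] = X + X', with X' = sigma2 * I
    W = np.eye(d)                          # online trunk
    W_a = W.copy()                         # EMA target trunk
    # "No predictor" is modeled as a predictor frozen at the identity.
    W_p = np.zeros((d, d)) if use_predictor else np.eye(d)
    for _ in range(steps):
        resid = W_a @ X - W_p @ W @ X_aug  # shared factor of both negative gradients
        dW = W_p.T @ resid - eta * W
        dW_a = beta * (W - W_a)
        if use_predictor:
            W_p = W_p + dt * (alpha_p * resid @ W.T - eta * W_p)
        W = W + dt * dW
        W_a = W_a + dt * dW_a
    return W

W_with = simulate(use_predictor=True)
W_without = simulate(use_predictor=False)
print(np.round(np.diag(W_with), 3))     # the strong-signal direction stays nonzero
print(np.round(np.diag(W_without), 3))  # both directions shrink toward zero
```

With these constants, the trunk without a predictor decays toward zero in every direction. With a trained predictor, the strong-signal direction settles at a nonzero value. The weak direction, whose input variance of 0.01 is smaller than the augmentation noise, also fades in this toy run. That outcome follows from the chosen constants and should not be read as a general claim.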
Despite our example model’s simplicity, it is still difficult to analyze, and no closed form of its dynamics can be derived. Although both the predictor and the trunk are linear, the training dynamics are still highly nonlinear. Fortunately, we still arrived at an interesting finding: a phenomenon called eigenspace alignment between the predictor Wp and the correlation matrix F = WXW^T (where X is the correlation matrix of the input data), under the assumption that Wp is a symmetric matrix.
Roughly speaking, the eigenspace of a symmetric matrix characterizes how the matrix behaves along different directions in high-dimensional space. Our analysis shows that, under certain conditions, the eigenspace of the predictor gradually aligns with that of the correlation matrix of its input as training proceeds. This phenomenon appears not only in our simple theoretical model but also in real experiments on the CIFAR-10 data set with ResNet18 as the trunk network, where the eigenspaces start to align almost perfectly after roughly 50 epochs.
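Alignment can be measured directly. The hypothetical helper below takes each eigenvector u_j of F and reports the absolute cosine between Wp u_j and u_j. A value of 1 means u_j is also an eigenvector of Wp. This score is our own illustrative choice and is not necessarily the metric used in the paper. The matrices are random stand-ins rather than trained weights.

```python
import numpy as np

def eigenspace_alignment(W_p, F):
    """|cos| between W_p u_j and u_j for every eigenvector u_j of the symmetric matrix F."""
    _, U = np.linalg.eigh(F)                  # columns of U: orthonormal eigenvectors of F
    WU = W_p @ U
    # u_j^T W_p u_j / ||W_p u_j|| has absolute value 1 exactly when u_j is an eigenvector of W_p
    cos = np.sum(U * WU, axis=0) / (np.linalg.norm(WU, axis=0) + 1e-12)
    return np.abs(cos)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))                   # stand-in trunk weights
X = np.cov(rng.normal(size=(8, 500)))         # stand-in 8x8 input correlation matrix
F = W @ X @ W.T                               # F = W X W^T

aligned = F @ F + 0.5 * F                     # a polynomial in F shares F's eigenvectors
A = rng.normal(size=(4, 4))
unaligned = A + A.T                           # a random symmetric matrix generally does not

print(eigenspace_alignment(aligned, F))       # all entries close to 1
print(eigenspace_alignment(unaligned, F))
```

When F has distinct eigenvalues, perfect alignment with a symmetric Wp is equivalent to the two matrices commuting. In that case, the norm of Wp F - F Wp gives another simple check.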
While the training procedure itself may be very complex and hard to interpret, the significance of this alignment is clear and easy to understand: The two matrices, the predictor and the correlation matrix of the input, finally reach an “agreement” after the training procedure.
Two natural questions follow. First, what is the most important part of the training? Is it the process of reaching an agreement or the final agreement between the two matrices? Second, if the predictor and correlation matrix need many epochs of training in order to reach an agreement, why not make it happen immediately?
The first question leads to our detailed analysis of why non-contrastive SSL doesn’t converge to a trivial solution. It turns out that when we optimize with gradient descent, the optimization procedure itself is what keeps the weights from reaching trivial solutions and lets them converge to a meaningful result. This analysis also gives insight into the roles of three hyperparameters: the learning rate of the predictor relative to that of the trunk, the weight decay, and the rate of the exponential moving average. (More details are available in this paper.)
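The interplay among these hyperparameters can be illustrated on a single eigen-direction of the toy linear model. The sketch makes its own simplifying assumptions. Along the direction, x is the input variance, sigma2 is the augmentation noise variance, p is the predictor eigenvalue, and w is the trunk weight. alpha_p is the predictor learning rate relative to the trunk, and eta is the weight decay. Setting the time derivatives to zero gives p(x - p(x + sigma2)) = eta and w = p / sqrt(alpha_p). A nonzero solution therefore exists only when eta <= x^2 / (4(x + sigma2)); otherwise collapse is the only fixed point. The EMA rate does not affect where the fixed point lies, because at any fixed point the target equals the online weights. In this toy model it can affect only the path toward that point and its stability. This is a sketch of our reading of the toy model, not the paper’s full analysis.

```python
import math

def nontrivial_fixed_point(x, sigma2, eta, alpha_p=1.0):
    """Nonzero fixed point (p, w) of one eigen-direction of the toy linear model.

    Assumed scalar dynamics (our derivation; w_a = w at any fixed point of the EMA):
        dp/dt = alpha_p * w**2 * (x - p * (x + sigma2)) - eta * p
        dw/dt = p * w * (x - p * (x + sigma2)) - eta * w
    Returns None when weight decay is too strong for a nonzero solution to exist.
    """
    disc = x * x - 4.0 * eta * (x + sigma2)
    if disc < 0:
        return None                                   # only the collapsed solution w = 0 remains
    p = (x + math.sqrt(disc)) / (2.0 * (x + sigma2))  # larger root of p * (x - p * (x + sigma2)) = eta
    w = p / math.sqrt(alpha_p)                        # combining dp/dt = 0 with dw/dt = 0
    return p, w

for eta in (0.01, 0.1, 0.3):
    print(eta, nontrivial_fixed_point(x=1.0, sigma2=0.1, eta=eta))
```

For example, with x = 1, sigma2 = 0.1, and eta = 0.01, the helper returns p ≈ 0.899. Raising eta to 0.3 returns None. The helper returns the larger root on purpose: a quick slow-manifold argument suggests that root is stable and the smaller one is not. Treat that argument as our assumption.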
With this understanding of the importance of the training procedure itself, it is reasonable to wonder whether the basic building block in deep model training, that is, the gradient update, is the culprit. Will non-contrastive learning still work if we don’t use gradient descent and instead reach an agreement faster? It turns out we can circumvent the gradient update of the predictor and directly set it according to the correlation matrix at each training stage. This ensures that there is always an agreement throughout the training.
Following this idea, we developed DirectPred. Surprisingly, on ImageNet, the downstream performance obtained by pretraining with DirectPred is better than that obtained by training a linear predictor with gradient updates, and it is comparable to SoTA non-contrastive SSL methods like BYOL, which uses a 2-layer nonlinear predictor with BatchNorm and ReLU. With 300 epochs of training, DirectPred’s Top-5 accuracy is even 0.2 percent higher than that of vanilla BYOL. On the CIFAR-10 and STL-10 data sets, DirectPred also achieves downstream performance comparable to that of other non-contrastive SSL methods.
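The core idea of DirectPred can be sketched in a few lines of NumPy. The sketch keeps a running estimate `F_hat` of the correlation matrix of the online features and eigendecomposes it. It then builds a symmetric predictor with the same eigenvectors, so the two always agree. The specific eigenvalue rule is our assumption for illustration: take the square root of each eigenvalue of `F_hat`, divide by the largest, and add a small constant `eps`. The class name and the constants `rho` and `eps` are likewise assumptions. Consult the paper and its released code for the exact recipe.

```python
import numpy as np

class DirectPredSketch:
    """Sets a symmetric linear predictor from a running feature-correlation estimate.

    Hypothetical helper: the square-root eigenvalue rule, the max-normalization,
    and the default rho and eps are assumptions made for illustration.
    """

    def __init__(self, dim, rho=0.3, eps=0.1):
        self.F_hat = np.zeros((dim, dim))  # running estimate of F = E[f f^T]
        self.rho = rho                     # weight kept on the previous estimate
        self.eps = eps                     # small boost so weak directions are not zeroed out

    def predictor(self, feats):
        """feats: (batch, dim) online trunk outputs f = W x for the current batch."""
        batch_F = feats.T @ feats / feats.shape[0]
        self.F_hat = self.rho * self.F_hat + (1.0 - self.rho) * batch_F
        s, U = np.linalg.eigh(self.F_hat)    # F_hat = U diag(s) U^T, s ascending
        p = np.sqrt(np.clip(s, 0.0, None))   # clip tiny negative eigenvalues from round-off
        p = p / max(p.max(), 1e-12) + self.eps
        return U @ np.diag(p) @ U.T          # shares F_hat's eigenvectors by construction

rng = np.random.default_rng(0)
direct_pred = DirectPredSketch(dim=4)
scales = np.array([3.0, 1.0, 0.5, 0.1])     # stand-in feature scales per direction
for _ in range(5):
    feats = rng.normal(size=(256, 4)) * scales
    W_p = direct_pred.predictor(feats)       # recomputed each step, never trained by gradient
```

In a training loop, the predictor would be recomputed from each batch of online features and applied as Wp f. Gradients would flow only into the trunk, and the predictor itself would receive no gradient update.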
Because it doesn’t rely on annotated data, self-supervised learning enables AI researchers to teach machines in new, more powerful ways. Machines can be trained with billions of examples, for instance, since there is no need to hand-curate the data set. They can also learn even when annotated data simply isn’t available.
The AI research community is in the early stages of applying self-supervised learning, so it’s important to develop new techniques such as the non-contrastive methods discussed here. Compared to contrastive learning, non-contrastive approaches are conceptually simple and do not need a large batch size or a large memory bank to store negative samples, which saves both memory and computation during pretraining. Furthermore, with better theoretical insight into why non-contrastive self-supervised learning works well, the AI research community will be able to design new approaches that further improve these methods and to focus on the model components that matter most.
The finding that our DirectPred algorithm rivals existing non-contrastive self-supervised learning methods is also noteworthy. It shows that by improving our theoretical understanding of non-contrastive self-supervised learning, we can achieve strong performance in practice and use our discoveries to design novel approaches. Self-supervised representation learning techniques have progressed astonishingly quickly in recent years. But we hope that our work, which follows a long line of scientific efforts to develop a theoretical understanding of neural networks (e.g., understanding “lottery tickets” or the phenomenon of student specialization), will show other researchers that a deep theoretical understanding of existing methods makes it possible to come up with valuable, fundamentally new approaches.