A look at some of the basic research on contrastive learning
Earlier this year I had to write a technical report on the topic of contrastive learning (link). This post is an effort to create something more digestible and pleasant to read. The goal is to give an introduction to representation learning, which is done in the first section, along with its advantages and related topics such as self-supervised and transfer learning. The second section is dedicated to the techniques for representation learning labelled as “contrastive”. We will go over some of the most important topics: loss functions, data generation techniques, negative examples and applications.
Consider some neural network such as the following:
Neural networks are compositions of transformations, where every intermediate layer applies a parameterized function to the input coming from the previous one. The output of each intermediate layer is a new view or representation of the input data. These representations are nothing but vectors, whose values are features or characteristics of the input. Of course, neural networks are not directly interpretable, and these vectors do not necessarily hold any meaning for us. The network learns to progressively transform the data into a set of features that is useful to perform the task we are training it for. For example, if we are training for image classification of animals, the network might learn a representation with the number of legs, the shape of the eyes, the color…
At the end of the network, we have a final layer that transforms the last representation into the prediction, such as a regression value or a probability distribution over the different classes in our classification problem.
When we train a network on one dataset, we are not only teaching it how to make predictions, we are teaching it to extract intermediate representations with features that are useful for the task. This is interesting because these representations might be useful for other tasks as well. For example, imagine we want to add a new class of animals in our classification example. Perhaps the network is already capable of extracting features to predict that class. We can then modify the final head and freeze the rest of the weights of the network. With fewer parameters to learn, we can train with less data, much faster and with lower energy consumption. In fact, if the learned representations are rich enough, we might be able to use the same network for a myriad of different tasks, training only a small head at the end of the network. This is called transfer learning: first pretraining on a pretext task, and then fine-tuning the model (perhaps modifying the head) on the downstream task we are actually interested in solving. The following diagram shows this process.
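Beyond the diagram, here is a minimal code sketch of this freeze-and-replace-the-head setup. It assumes torchvision is available and uses an ImageNet-pretrained ResNet-18 as a stand-in for a pretrained encoder; the 10-class head is just an illustrative choice.

```python
import torch
import torch.nn as nn
from torchvision import models

# Load a backbone pretrained on an upstream task (ImageNet classification here).
# Requires a recent torchvision (>= 0.13) for the weights API.
backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze every pretrained parameter so only the new head will be trained.
for param in backbone.parameters():
    param.requires_grad = False

# Replace the final layer with a new head for the downstream task (10 classes as an example).
backbone.fc = nn.Linear(backbone.fc.in_features, 10)

# Only the head's parameters are passed to the optimizer.
optimizer = torch.optim.Adam(backbone.fc.parameters(), lr=1e-3)
```

With the backbone frozen, each training step only updates the small head, which is what makes fine-tuning cheap in data and compute.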
Another important and related concept is that of self-supervised learning. This term, although poorly coined in my opinion, can be useful to refer to a particular set of tasks: supervised tasks where we can generate a training signal automatically, without the need for human labeling. One example is the training procedure of the language model BERT, which is pretrained to predict tokens that have been randomly masked out of the text, so the labels come directly from the data itself and arbitrarily large training sets can be built from raw text.
Let’s consider another example, this time with images. Images allow for very simple forms of data augmentation: we can crop, rotate, translate, convert to greyscale, modify the brightness… We can train a neural network to produce similar representations for augmented views of the same image, and different representations for views of different images. Again, we can generate examples from any image we find, allowing us to create very large datasets and train massive networks that obtain general representations. As we will see, this is our first example of contrastive learning.
Representation learning, together with transfer learning and self-supervised learning, has produced a revolution in the field. Some of the benefits we get from these techniques are:
It can greatly improve performance for some tasks. Let’s consider the example of sentiment analysis in text. Textual data is very diverse and nuanced. Even the same word can represent different sentiments depending on the context. A small dataset of examples labeled by humans is unlikely to produce a model with detailed understanding of text. By pretraining a large model on a difficult task with a gigantic set of examples, we get a model with nuanced understanding of text, and this is likely to produce better results on our downstream task as well.
Another way to put the previous point is that it alleviates data bottlenecks. If we don’t have enough data to train a decent model for a task, we can try to pretrain on some pretext task and then fine-tune small parts of the model with the supervised data we have.
It improves generalization. Let’s consider the sentiment analysis task again. A pretrained model learns representations for many different tokens in a lot of different contexts, since the dataset is very large. In this common feature space, we may find specific features about sentiment, which may be used by the final head we train on supervised data. Since the feature space is shared across a large number of tokens and contexts, even for tokens and contexts that do not appear in the supervised data, the model will be able to extrapolate what it has learned.
It allows us to use large and powerful models at a fraction of the cost. We find many open source models nowadays, and using them for inference requires much fewer resources than for training.
Some models have zero-shot capabilities. This means that we can perform inference on tasks they weren’t trained for without any need for fine-tuning. We have all seen ChatGPT performing a myriad of different tasks, even though it was only trained to do next-token prediction. Another example is CLIP, which can classify images into arbitrary categories described in natural language, even though it was only trained to match images with their captions.
Of course, not everything is bright. Training big and general models is very expensive, and only big companies with large resources can afford to do so. We may also find that these models are overkill for some tasks. A smaller and more specialized model can sometimes perform better, especially if the learned representations aren’t very good for the particular task. Finally, although zero-shot capabilities are very interesting, their performance tends to be subpar.
Roughly speaking, in contrastive learning we try to learn vector representations of the input data by comparing and contrasting examples with each other. For instances that we deem similar in some sense, we want their vector representations to be close according to some distance or similarity function. On the other hand, the vector representations of dissimilar examples should be further apart. Images provide an easy example. We would expect pictures of animals to be close together, and far apart from pictures of buildings. Inside the cluster of pictures of animals, we would also like pictures of the same animal to be closer than pictures of different animals, and so on.
Let’s establish a framework and notation for the rest of the text.
We first consider a set of networks, called encoders, that map the input data to a fixed-size vector representation. Additionally, we can assume without loss of generality that a final head maps the encoder output to the space where the comparison happens (the head can always be taken to be the identity). The metric is given by a distance or similarity function defined on pairs of these representations, such as the Euclidean distance or the cosine similarity. A loss function is then built from this metric, rewarding representations of similar examples for being close and representations of dissimilar examples for being far apart. Finally, given an example, we will refer to the instances that are similar to it as its positive examples, and to the dissimilar ones as its negative examples.
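To make this concrete, one possible (non-canonical) choice of notation, which the sketches below will loosely follow, is:

$$
f_\theta : \mathcal{X} \to \mathbb{R}^d, \qquad z_i = g_\phi\big(f_\theta(x_i)\big), \qquad \operatorname{sim}(z_i, z_j) = \frac{z_i^\top z_j}{\lVert z_i \rVert \, \lVert z_j \rVert},
$$

where $f_\theta$ is the encoder, $g_\phi$ the final head, $z_i$ the metric representation of example $x_i$, and $\operatorname{sim}$ the similarity function (a distance such as $d(z_i, z_j) = \lVert z_i - z_j \rVert_2$ can be used instead). For an anchor $x_i$, we will write $P(i)$ for its set of positive examples and $N(i)$ for its set of negative examples.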
This structure is shown graphically for two examples in the following diagram:
Loss functions are built from the distance/similarity metric applied to different pairs of examples. They should penalize similar examples being far away and dissimilar examples being close. The latter is as important as the former; otherwise, the network could learn the trivial representation that maps every example to the same vector. Let’s take a look at some common functions.
Let’s start with the simplest formulations, which compare only two or three examples at a time. A pairwise (contrastive) loss takes a pair of examples together with a label indicating whether they are similar: it penalizes the distance between positive pairs, and penalizes negative pairs only when their distance falls below some margin. The triplet loss compares an anchor with one positive and one negative example at the same time, and penalizes the model whenever the negative is not farther from the anchor than the positive by at least a margin.
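As a hedged illustration (the exact formulations in the references may differ slightly), here is a small PyTorch sketch of a margin-based pairwise loss and of the triplet loss, both using Euclidean distance:

```python
import torch
import torch.nn.functional as F

def pairwise_contrastive_loss(z1, z2, is_positive, margin=1.0):
    """Pull positive pairs together; push negative pairs beyond a margin."""
    dist = F.pairwise_distance(z1, z2)                           # Euclidean distance per pair
    pos_term = is_positive * dist.pow(2)                         # penalize distant positives
    neg_term = (1 - is_positive) * F.relu(margin - dist).pow(2)  # penalize close negatives
    return (pos_term + neg_term).mean()

def triplet_loss(anchor, positive, negative, margin=1.0):
    """The negative should be farther from the anchor than the positive, by at least a margin."""
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return F.relu(d_pos - d_neg + margin).mean()

# Example usage with random embeddings (batch of 8, dimension 128).
a, p, n = torch.randn(8, 128), torch.randn(8, 128), torch.randn(8, 128)
labels = torch.randint(0, 2, (8,)).float()  # 1 = positive pair, 0 = negative pair
print(pairwise_contrastive_loss(a, p, labels).item(), triplet_loss(a, p, n).item())
```

PyTorch also ships a built-in torch.nn.TripletMarginLoss that implements essentially the same idea.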
With the triplet loss, a full training example now requires both a similar and a dissimilar instance. However, the number of interactions between different examples is still very limited. The use of many negative examples, especially those that are hard for the model, is crucial for a good and efficient learning process. This makes sense intuitively: there are many more ways in which images can be different than ways in which they can be similar.
To increase the number of interactions between instances at once, Song et al. proposed the Lifted Embedding loss. For every positive pair in a batch, it aggregates all the negative examples available for either element of the pair through a smooth (log-sum-exp) maximum, adds the distance between the positive pair, and squares the resulting hinge. This loss can therefore be interpreted as an adaptation of the triplet loss that tries to mine the hardest negative example in the set for each positive pair, while also squaring the result. As usual in deep learning, instead of using the full set of examples, the loss is approximated with a batch of smaller size.
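For reference, the lifted loss of Song et al. can be written (adapting it to the notation sketched above, with $D_{i,j}$ the Euclidean distance between the embeddings of $x_i$ and $x_j$, $P$ and $N$ the sets of positive and negative pairs, and $\alpha$ a margin) as:

$$
\tilde{J}_{i,j} = \log\Big(\sum_{(i,k)\in N} e^{\alpha - D_{i,k}} + \sum_{(j,l)\in N} e^{\alpha - D_{j,l}}\Big) + D_{i,j},
\qquad
\tilde{\mathcal{L}} = \frac{1}{2\lvert P \rvert}\sum_{(i,j)\in P} \max\big(0, \tilde{J}_{i,j}\big)^2.
$$

The log-sum-exp term is a smooth upper bound of the maximum over the negatives, which is what allows interpreting the loss as a smoothed form of hardest-negative mining.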
Let’s now take a probabilistic perspective. In Noise Contrastive Estimation (NCE), a model is trained to distinguish samples coming from the data distribution from samples coming from a noise distribution, which plays the role of the negative examples. The objective is an expected classification loss, where the expected value would be approximated by its sample mean (e.g. with a single batch). Keeping the previous notation for the sets of positive pairs and negative examples, we can write an analogous objective for contrastive representation learning.
A more recent loss function is InfoNCE, introduced in the context of Contrastive Predictive Coding. Given an anchor, the model is presented with one positive example and a set of negative examples and has to identify the positive among them; the score of each candidate is given by the (exponentiated) similarity of its representation with that of the anchor. Minimizing the negative log-likelihood of the true positive yields a softmax cross-entropy over the similarities. As we will see, it is a common setting to have batches of pairs of similar instances, where each example uses the other element of its pair as the positive and the elements of all the other pairs as negatives; this batch-wise variant with cosine similarity is often called NT-Xent (normalized temperature-scaled cross-entropy). A temperature parameter divides the similarities before the softmax, controlling how peaked the resulting distribution is. A small value of the temperature sharpens the distribution, increasing the relative weight of the hardest negatives, while a large value spreads the penalty more uniformly across negatives.
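A minimal PyTorch sketch of this batch-wise NT-Xent / InfoNCE objective (my own implementation, not a reference one) could look like this:

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.1):
    """NT-Xent over a batch of N positive pairs (z1[i], z2[i]).

    For each of the 2N embeddings, its partner is the positive and the
    remaining 2N - 2 embeddings in the batch act as negatives.
    """
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)    # 2N x d, unit norm
    sim = (z @ z.T) / temperature                         # scaled cosine similarities
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))            # exclude self-similarity
    # The positive of example i is example i + n (and vice versa).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

# Example: a batch of 16 pairs of 128-dimensional embeddings.
z1, z2 = torch.randn(16, 128), torch.randn(16, 128)
print(nt_xent_loss(z1, z2).item())
```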
We have already seen how to create a loss function to train our model, but how do we get appropriate data? In most cases, the question is how to generate pairs of similar examples, as you can often get away with the assumption that examples in different pairs are dissimilar. This is the case when you have a huge domain, such as a large collection of internet images: getting two similar examples in the same batch by sampling randomly is unlikely. Still, with this assumption we may end up with false negatives, something we will discuss in the next section. Let’s see now some of the ways in which datasets are generated for contrastive learning.
Of course, human annotation is always an option, albeit a very expensive one. Unfortunately, it is sometimes necessary. We previously talked about CLIP, a model that learns representations for text and images, mapping a text that correctly describes an image to an embedding that is close to the embedding of that image. In this case, there is no way around getting someone to create captions for different images (although in the future, synthetic data generation with multimodal models might be an option).
One way to reduce the amount of data necessary when human labeling is required is to first pretrain the encoders with some other self-supervised technique. For example, in CLIP, you could train the text and image encoders separately with a denoising task, and then use the supervised data to align the representations of both encoders to a common metric space.
What if we want to generate large amounts of data from existing sources automatically? As explained earlier, we will generate pairs of similar instances, and assume that instances in different pairs are dissimilar. The following are just some examples:
Data augmentation. In the particular context of contrastive learning, any transformation that does not change the instance semantically can be applied to generate positive examples. To illustrate this, assume that we are contrasting images, and we want to map images based on the concepts inside them. Then, if we crop, blur, or saturate the color of an image of a cat, it is still an image of a cat. This way, augmentations of similar or equal instances are also similar, whereas augmentations of dissimilar examples remain dissimilar (a minimal code sketch of this strategy is given after this list).
In the case of images, we find transformations such as rotations, translations, cutouts, cropping and resizing, blurring, applying noise… These were used, for example, in models such as SimCLR.
Textual data is a bit more complex. Fang et al. generate augmented versions of a sentence through back-translation, translating the sentence to another language and back to obtain a paraphrase that preserves its meaning.
Careful exploration of data augmentation techniques can be very important. For example, Chen et al. studied many combinations of augmentations when developing SimCLR, finding that the composition of random cropping and color distortion was particularly important for learning good representations.
Multi-sensor. When multiple inputs are captured simultaneously, we can treat the inputs captured at corresponding times as similar. For example, the same scene captured from different angles by multiple cameras, or the audio and the frames of a video.
Continuity. This can be applied to domains where there are sequences with some continuity. For example, videos are sequences of images where most contiguous frames are almost equal. We can take images very close in time and assume they are very likely to be similar.
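As a concrete (if simplified) example of the data augmentation strategy above, assuming torchvision is available, one can generate two independently augmented views of each image and treat them as a positive pair; the exact transformations and parameters below are illustrative choices, not those of any specific paper.

```python
import torch
from torchvision import transforms

# A simplified augmentation pipeline; papers differ in the exact transformations used.
augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23),
    transforms.ToTensor(),
])

class TwoViews:
    """Return two independently augmented views of the same image (a positive pair)."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        return self.transform(image), self.transform(image)
```

Each dataset item then yields a pair of views that can be fed to a loss such as NT-Xent, with the rest of the batch acting as negatives.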
The loss functions we have seen thus far use both positive and negative examples. We already gave a theoretical reason for this: if negative examples were not used, a trivial representation mapping everything to one vector would obtain perfect performance. There is also empirical evidence showing that performance increases when contrasting with many negative examples. For example, Chen et al. observed in SimCLR that larger batch sizes, which provide more negatives for each example, led to better representations.
The concept of negative examples or, more generally, preventing representational collapse, is quite important in contrastive learning, and there are some topics around it that are worth discussing.
In the data generation section we already talked about the possibility of false negatives when there is no supervision. This is because we are drawing from the full distribution of examples instead of the distribution of negative examples, creating a bias. Some work has been done on alleviating this while not requiring manual labels for negatives. While we won’t get into detail, the reader may find an example in the Debiased Contrastive loss proposed by Chuang et al.
The use of examples in the same batch as negatives presents a major drawback in terms of computing resources. If we want to use many negative examples, we are forced to select a very large batch size, requiring a lot of GPU memory. Some work has been done on decoupling the batch size from the number of negatives by sampling negatives from an offline memory bank. In this setting, an encoded representation is stored for some or all examples outside the current computation graph, and the loss function is not back-propagated through them. As the encoder is updated during training, these stored representations become outdated. The main difference between approaches lies in the method used to keep them up to date as the encoder is optimized.
Wu et al. keep a memory bank with one stored representation per training example, updating the entries of the examples that appear in the current batch with their newly computed representations. He et al. propose MoCo, which instead keeps a queue of recent representations produced by a separate momentum encoder, whose weights are an exponential moving average of the weights of the main encoder. This smoother update yielded better empirical performance.
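A rough sketch of this idea, in the style of MoCo (the encoder objects and the fixed-size queue below are my own toy stand-ins, not the reference implementation), might look as follows:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def update_momentum_encoder(online, momentum, m=0.999):
    """Exponential moving average of the online encoder's weights into the momentum encoder."""
    for p_online, p_momentum in zip(online.parameters(), momentum.parameters()):
        p_momentum.data.mul_(m).add_(p_online.data, alpha=1 - m)

@torch.no_grad()
def update_queue(queue, new_keys, ptr):
    """Replace the oldest entries of the negative queue with the newest key embeddings.

    Assumes the queue length is a multiple of the batch size.
    """
    batch_size = new_keys.size(0)
    queue[ptr:ptr + batch_size] = new_keys
    return (ptr + batch_size) % queue.size(0)

# Toy example: two small encoders and a queue of 4096 stored embeddings.
online_encoder = nn.Linear(32, 128)
momentum_encoder = nn.Linear(32, 128)
momentum_encoder.load_state_dict(online_encoder.state_dict())  # start from the same weights

queue, ptr = torch.randn(4096, 128), 0
keys = momentum_encoder(torch.randn(256, 32)).detach()  # keys come from the momentum encoder
ptr = update_queue(queue, keys, ptr)
update_momentum_encoder(online_encoder, momentum_encoder)
```

The loss is then computed between the online encoder's outputs and the queue, and gradients never flow through the stored keys.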
While increasing the number of negative examples has been observed to improve performance, this might be due to the increased probability of finding meaningful negatives to learn from.
Conceptually, we might intuit that it is more difficult for a model to distinguish between a dog and a cat than between a dog and a building. In practice, we can select hard negative examples based on their representations: examples whose representations are close to that of the anchor are difficult for the model to distinguish. Kalantidis et al. go one step further, synthesizing additional hard negatives by mixing the hardest existing negatives directly in the embedding space.
However, hard negative mining has some disadvantages, such as the increased time complexity from sampling close neighbours and the increased probability of drawing false negatives in the self-supervised setting as the encoder gets better.
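As an illustration of selecting hard negatives by representation similarity (a simplified version of the general idea, not the exact procedure of the cited works), one can rank the stored negatives by their similarity to the anchor and keep the top ones:

```python
import torch
import torch.nn.functional as F

def hardest_negatives(anchor, candidates, k=16):
    """Return the k candidate embeddings most similar to the anchor (the hardest negatives)."""
    anchor = F.normalize(anchor, dim=0)
    candidates = F.normalize(candidates, dim=1)
    similarities = candidates @ anchor        # cosine similarity of each candidate to the anchor
    _, indices = similarities.topk(k)
    return candidates[indices]

# Example: pick the 16 hardest negatives out of 4096 stored embeddings.
anchor = torch.randn(128)
bank = torch.randn(4096, 128)
hard = hardest_negatives(anchor, bank)
```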
If the only reason to use negative examples is to prevent the representations from collapsing onto one single vector, then other techniques to prevent it could allow us to remove negative examples.
The work by Grill et al. (BYOL) trains an online network to predict the representation that a target network produces for a different augmented view of the same example, without using any negative examples; the target network is not trained by gradient descent, and its weights are instead an exponential moving average of the online network’s weights. It is not obvious to me that this approach would not converge to a collapsed representation. The authors argue that the update to the target parameters is not exactly according to the gradient of the loss with respect to those parameters, so the training dynamics do not correspond to jointly minimizing a single loss, and in practice collapse is not observed.
SimSiam simplifies this setup even further, removing the momentum encoder: both views go through the same encoder, one branch is followed by a small predictor network, and a stop-gradient operation is applied to the other branch, so the (symmetrized) negative cosine similarity between the prediction and the stopped branch only sends gradients through the branch that ends in the predictor. The authors show that the stop-gradient is the key ingredient preventing collapse.
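A minimal sketch of this SimSiam-style objective (my own simplified version, with an encoder f and a predictor h as placeholders) is shown below:

```python
import torch
import torch.nn.functional as F

def simsiam_loss(f, h, x1, x2):
    """Symmetrized negative cosine similarity with a stop-gradient on the target branch."""
    z1, z2 = f(x1), f(x2)          # representations of the two views
    p1, p2 = h(z1), h(z2)          # predictions

    def neg_cos(p, z):
        # detach() implements the stop-gradient: no gradient flows through the target branch.
        return -F.cosine_similarity(p, z.detach(), dim=1).mean()

    return 0.5 * neg_cos(p1, z2) + 0.5 * neg_cos(p2, z1)

# Example with toy encoder and predictor networks.
f = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
h = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
x1, x2 = torch.randn(8, 32), torch.randn(8, 32)
print(simsiam_loss(f, h, x1, x2).item())
```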
Bardes et al. propose VICReg, which prevents collapse explicitly through regularization terms computed on the embeddings of each batch, without any negative examples, momentum encoder or stop-gradient. The loss is composed of three terms:
Invariance. Forces the positive examples to be close together, using the mean squared distance between the embeddings of the two views of each example.
Variance. Forces the variance (in practice, the standard deviation) of each feature across the batch to stay above some threshold value, so the embeddings cannot collapse to a single point.
Covariance. Forces the features to be uncorrelated, pushing the off-diagonal entries of the covariance matrix of the embeddings towards zero.
The final loss is a weighted sum of the previous terms, where the relative weights are hyperparameters of the method.
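A hedged sketch of these three terms (with illustrative weights and threshold, not necessarily the exact hyperparameters of the paper) could look like this:

```python
import torch
import torch.nn.functional as F

def vicreg_loss(z1, z2, inv_weight=25.0, var_weight=25.0, cov_weight=1.0, eps=1e-4):
    """Invariance + variance + covariance terms computed over a batch of embedding pairs."""
    n, d = z1.shape

    # Invariance: the two views of each example should have similar embeddings.
    invariance = F.mse_loss(z1, z2)

    # Variance: keep the standard deviation of every feature above a threshold (here 1).
    def variance_term(z):
        std = torch.sqrt(z.var(dim=0) + eps)
        return F.relu(1.0 - std).mean()
    variance = variance_term(z1) + variance_term(z2)

    # Covariance: push off-diagonal covariance entries towards zero to decorrelate features.
    def covariance_term(z):
        z = z - z.mean(dim=0)
        cov = (z.T @ z) / (n - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return off_diag.pow(2).sum() / d
    covariance = covariance_term(z1) + covariance_term(z2)

    return inv_weight * invariance + var_weight * variance + cov_weight * covariance

# Example with random embeddings (batch of 64, dimension 128).
z1, z2 = torch.randn(64, 128), torch.randn(64, 128)
print(vicreg_loss(z1, z2).item())
```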
We have already talked about the benefits of representation learning in general. But does contrastive learning have any particular advantages? This is what we discuss here.
The first advantage is that it provides additional training ideas for representation learning. These techniques have shown great success, reaching state-of-the-art performance in semi-supervised and transfer learning. The SimCLR model for images, for instance, achieved results competitive with supervised baselines on ImageNet by training only a linear classifier on top of its frozen representations, and strong semi-supervised results when fine-tuned with a small fraction of the labels.
A second advantage is the metric representation that they learn. These representations can later be used in very efficient similarity-search systems, for example retrieving the nearest neighbours of a query embedding with an approximate nearest-neighbour index.
Although not exclusive to contrastive learning, another advantage is that these techniques allow us to align different encoders in a common representation space, in particular encoders of different modalities. This is the case of the CLIP model mentioned earlier, where a text encoder and an image encoder are trained so that an image and a caption describing it end up close together in the shared embedding space.
Finally, models trained through contrastive learning have zero-shot classification capabilities. If you have some anchor examples that represent a category, you can classify a new example by selecting the category with the closest metric representation. In the case of CLIP, since it has a text encoder, you can provide any textual label that describes the category and apply this inference technique. If there is no textual encoder, you could still use a small set of labeled data to create a metric representation of each class. Of course, performance is not as good as with fully-trained models, especially for very niche categories, but it might still be useful when very few resources are available.
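As a small sketch of this inference strategy (assuming we already have embeddings from some pretrained contrastive model; the encode_image and encode_text names in the comments are placeholders, not a real API), zero-shot classification reduces to a nearest-neighbour search among the class embeddings:

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embedding, class_embeddings, class_names):
    """Assign the class whose anchor embedding is closest (in cosine similarity) to the image."""
    image_embedding = F.normalize(image_embedding, dim=0)
    class_embeddings = F.normalize(class_embeddings, dim=1)
    similarities = class_embeddings @ image_embedding
    return class_names[similarities.argmax().item()]

# Example with random embeddings standing in for the outputs of a pretrained model.
class_names = ["cat", "dog", "building"]
class_embeddings = torch.randn(3, 512)   # e.g. encode_text(["a photo of a cat", ...])
image_embedding = torch.randn(512)       # e.g. encode_image(image)
print(zero_shot_classify(image_embedding, class_embeddings, class_names))
```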
That is it for this post. I hope you found it useful and not too boring! As always, the best way to learn about anything is to get your hands dirty. In this case, you can try programming a training loop with a small dataset and a loss function such as NT-Xent. You can also try programming some of the techniques to avoid representational collapse and see how they perform. Don’t hesitate to hit me up with any feedback or corrections!