How word frequency affects language models
When I started my PhD a few months ago, I believed that contextual embeddings from pre-trained language models aimed to represent words in a vector space the way humans might in their thought process. In other words, I thought BERT-like embedding spaces tended to learn the underlying metric of word relatedness in context.
This belief is easy to fall into because these embeddings adapt so well during fine-tuning or domain adaptation. That adaptability suggests that the embedding space already captures most of the semantic and syntactic substance of words, which then only needs to be extracted and slightly adjusted for downstream tasks.
(Arguable) beliefs about Word2Vec models that may have shaped my understanding of contextual representations. These embeddings were isotropic in the sense that most of the representations were close to orthogonal to each other, which makes sense in a high-dimensional space. (source: OpenClassrooms)
A probability distribution on a vector space is said to be isotropic if its density looks the same in every direction of the space. The covariance matrix of a perfectly isotropic distribution is proportional to the identity $I_D$.
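As a rough illustration of this definition, the sketch below compares an isotropic Gaussian cloud with the same cloud shifted along one shared direction. The mean pairwise cosine similarity is used here as a simple proxy for anisotropy. The function name `mean_cosine_similarity`, the dimensions and the size of the shift are illustrative assumptions, not values taken from any model.

```python
import numpy as np

def mean_cosine_similarity(x: np.ndarray) -> float:
    """Average cosine similarity over all distinct pairs of rows of x."""
    unit = x / np.linalg.norm(x, axis=1, keepdims=True)
    sims = unit @ unit.T
    n = len(x)
    # Remove the diagonal: self-similarity is always 1.
    return float((sims.sum() - np.trace(sims)) / (n * (n - 1)))

rng = np.random.default_rng(0)
dim, n = 64, 500

# Isotropic cloud: standard Gaussian, covariance close to the identity.
isotropic = rng.standard_normal((n, dim))

# Anisotropic cloud: the same points shifted along one shared direction,
# a crude stand-in for representations that all lean the same way.
shift = np.zeros(dim)
shift[0] = 10.0
anisotropic = isotropic + shift

iso_score = mean_cosine_similarity(isotropic)
aniso_score = mean_cosine_similarity(anisotropic)
print(f"isotropic:   {iso_score:.3f}")
print(f"anisotropic: {aniso_score:.3f}")
```

The isotropic cloud should score close to zero (random high-dimensional vectors are nearly orthogonal), while the shifted cloud scores much higher because every vector shares a common component.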
It is with this belief in mind that I started reading about the representation degeneration problem.
Representation degeneration of language models
To be honest, I already knew that LM representations were far from perfect in the sense I describe above. When working on downstream tasks for industrial applications, or when trying to use naive zero-shot sentence embeddings, one can clearly see inherent flaws, for instance when it comes to comparing metrics like cosine similarity with our human perception of similarity. This raises concerns about the architectures we use, the data we pre-train with, the optimization parameters we take into account, the scale of our models, among other things.
But one factor that was (is?) often underestimated is the learning objective itself, i.e. the loss. That is, until you read the excellent paper Representation Degeneration Problem in Training Natural Language Generation Models (Gao et al., ICLR 2019).
The cone problem: if you scatter-plot contextual representations along their first singular directions, a cone appears.
This paper argues that cross-entropy loss inherently pushes low-frequency categories (e.g. tokens) away from the origin in a given direction. Let’s dive into this.
The problem with cross-entropy and rare words
The uniformly negative direction
The cross-entropy loss is used to learn how to predict tokens in many language models. It goes like this:
\[\mathcal{L}_{ce} = \frac{1}{N} \sum_{k=1}^N - \log\left(\frac{\exp(\langle h_k,w_k \rangle)}{\sum_{l=1}^{V}\exp(\langle h_k,w_l \rangle)}\right)\]where $h_k$ denotes the hidden state of the $k$-th of the $N$ training tokens, $w_k$ in the numerator is the output vector of that token's target word, and $w_l$ are the rows of the last projection layer, which maps hidden states to vocabulary predictions in $\mathbb{R}^V$.
A way to explain the softmax computation of BERT-like models as an operation between word representations and hidden states
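The loss above can be written in a few lines of NumPy. This is a minimal sketch rather than the implementation of any particular model: the function name `cross_entropy`, the array shapes and the random inputs are illustrative assumptions.

```python
import numpy as np

def cross_entropy(hidden: np.ndarray, weights: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy with logits <h_k, w_l>.

    hidden:  (N, D) hidden states h_k
    weights: (V, D) rows w_l of the output projection
    targets: (N,)   index of the correct token for each position
    """
    logits = hidden @ weights.T                           # (N, V) inner products
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct token at each position.
    return float(-log_probs[np.arange(len(targets)), targets].mean())

rng = np.random.default_rng(0)
N, D, V = 8, 16, 100
hidden = rng.standard_normal((N, D))
weights = np.zeros((V, D))          # all-zero output vectors: uniform predictions
targets = rng.integers(0, V, size=N)

loss = cross_entropy(hidden, weights, targets)
print(loss, np.log(V))
```

With all output vectors set to zero, every token gets probability $1/V$, so the loss equals $\log V$, which makes for a handy sanity check.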
Now let’s focus on the extreme case of word $S$ (for sad) that never appears in the training set but is included in the vocabulary. If we freeze all parameters except for $w_S$, the optimization problem becomes:
\[\min_{w_S}{\frac{1}{N} \sum_{k=1}^N -\log\frac{C_1}{\sum_{l\neq S}\exp(\langle h_k,w_l \rangle) + \exp(\langle h_k,w_S \rangle)}} = \min_{w_S}{\frac{1}{N} \sum_{k=1}^N \log(\exp(\langle h_k,w_S \rangle) + C_2)}\]where $C_1$ and $C_2$ are constant with respect to $w_S$.
Let’s imagine that the model is on the edge of converging, and that the structure of the hidden states $h_k$ doesn’t evolve much with new samples. What Gao et al. show is that there will most likely exist a specific direction $v$ such that $\langle h_k,v \rangle$ is negative for all $k$. In other words, for a fixed structure of the hidden states, there will almost surely exist a direction along which moving $w_S$ keeps decreasing the loss.
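One way to see why such a direction helps is to restrict $w_S$ to the ray $w_S = t\,v$ with a scalar $t \geq 0$. This parametrization is an illustrative assumption, not a step taken from the paper. Each term of the objective then becomes
\[f_k(t) = \log\left(\exp(t\langle h_k,v \rangle) + C_2\right), \qquad f_k'(t) = \langle h_k,v \rangle \, \frac{\exp(t\langle h_k,v \rangle)}{\exp(t\langle h_k,v \rangle) + C_2} < 0 \quad \text{whenever } \langle h_k,v \rangle < 0.\]
Every term therefore decreases as $t$ grows and only approaches its lower bound $\log C_2$ in the limit $t \to \infty$, so no finite $w_S$ along $v$ attains the minimum.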
This uniformly negative direction thus emerges as an optimal direction for the loss function when optimizing $w_S$, possibly pushing it towards infinite magnitude to amplify the effect. This means that even after the hidden states have converged, the embeddings $w_S$ of unused words (which are also the input embeddings when weights are tied) are still endlessly pushed in a specific direction.
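The drift can be reproduced on toy data. The sketch below runs plain gradient descent on the reduced objective $\frac{1}{N}\sum_k \log(\exp(\langle h_k,w_S \rangle) + C_2)$ with everything but $w_S$ frozen. The hidden states are hypothetical: they are built so that $v = -e_0$ is uniformly negative by construction, and the sizes, scales and learning rate are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, V = 200, 8, 50

# Hypothetical hidden states: small noise, with a first coordinate >= 1,
# so that v = -e_0 satisfies <h_k, v> < 0 for every k.
hidden = rng.standard_normal((N, D)) * 0.3
hidden[:, 0] = np.abs(hidden[:, 0]) + 1.0

# Frozen output vectors of the other V-1 tokens.
# c2[k] = sum over l != S of exp(<h_k, w_l>), constant with respect to w_S.
others = rng.standard_normal((V - 1, D)) * 0.1
c2 = np.exp(hidden @ others.T).sum(axis=1)

def loss(w_s: np.ndarray) -> float:
    return float(np.mean(np.log(np.exp(hidden @ w_s) + c2)))

def grad(w_s: np.ndarray) -> np.ndarray:
    # Gradient of log(exp(<h, w>) + c2) is p_S * h, where p_S is the
    # probability the model assigns to the unseen token S.
    e = np.exp(hidden @ w_s)
    p_s = e / (e + c2)
    return (p_s[:, None] * hidden).mean(axis=0)

w_s = np.zeros(D)
norms, losses = [], []
for _ in range(2000):
    w_s = w_s - 1.0 * grad(w_s)   # learning rate of 1.0, chosen arbitrarily
    norms.append(float(np.linalg.norm(w_s)))
    losses.append(loss(w_s))

print(f"norm of w_S after 1, 1000, 2000 steps: {norms[0]:.3f}, {norms[999]:.3f}, {norms[-1]:.3f}")
print(f"loss after 1 and 2000 steps: {losses[0]:.4f}, {losses[-1]:.4f}")
```

In this setup the loss keeps shrinking towards its lower bound while the norm of $w_S$ keeps growing, and the vector drifts along the negative first axis, i.e. along $v$.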
Rare words
Gao et al. show that this phenomenon smoothly decays as word frequency increases, meaning it also happens for rare words. The thing is, natural language is full of rare words. Here is a plot of the distribution of BERT tokens in the Wikitext-103 dataset: