Implementing word2vec in Python Using Theano and Lasagne

Edward Newell

word2vec is an algorithm for generating word embeddings, originally described in Distributed Representations of Words and Phrases and their Compositionality by Mikolov et al.

I recently implemented word2vec in Python using Theano and Lasagne. This helped me deepen my understanding of word2vec, and gave me some practice in using Theano and Lasagne. Here I aim to offer a detailed but accessible description of how word2vec works, and how to implement it in Python using Theano and Lasagne.

What is word2vec?

Just want to get the code? pip install theano-word2vec or git clone https://github.com/enewe101/word2vec.git, then check out the API examples below.

word2vec is a machine learning algorithm that, in some sense, learns the meaning of words, and encodes that meaning in a mathematical form that can be manipulated by machines.

It works by taking a very large corpus, say 1 million New York Times articles, and analyzing how words are distributed within it. Other than the text itself, it doesn’t need any additional information: it learns the meaning of words simply from how they are used. This makes it an unsupervised learning algorithm.

This implementation

After explaining the ideas behind word2vec, I’ll walk through an implementation of it. You can get the working code from github, or, if you just want to use it in your own projects, you can pip install theano-word2vec. The implementation is backed by unit tests, and soon I will benchmark it against other implementations. Performance aside, its major edge over the other implementations is that it’s implemented in Python using Lasagne and Theano, so it’s (hopefully) easy to hack on and integrate into deep learning architectures of your own creation!

word2vec sometimes gets referred to as an example of deep learning. Actually, it’s a very shallow architecture, but the embeddings it yields often get applied in deep architectures.

Underlying assumptions

word2vec is built on two assumptions. The first is the distributional hypothesis: the idea that the meaning of a word can be understood from the words that tend to be near it. For example, “bread” might tend to show up near “eat”, “bake”, “butter”, “toast”, etc., and this entourage gives a signal of what “bread” means.

The second assumption is that the meaning of words can be encoded in vectors. The idea is that each word is assigned a vector, and the relative orientations of these vectors capture something about the semantic character of the corresponding words. A word embedding is a map that assigns a k-dimensional vector to every word in a vocabulary, with typical values of k from 50 to 500. A good word embedding encodes semantic (and syntactic) information about the words: words with similar meanings, for example, should be assigned similar vectors. word2vec is an algorithm that generates state-of-the-art embeddings (as of 2016). To generate embeddings, word2vec uses statistics about which words tend to occur together in a corpus, hence its reliance on the distributional hypothesis.
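
As a toy illustration (made-up words and numbers, not the output of any trained model), an embedding is simply a map from words to vectors, and words with similar meanings should end up with vectors that point in similar directions:

import numpy as np

# A tiny, hand-made "embedding" with k = 3 dimensions.
embedding = {
    'bread': np.array([0.9, 0.1, 0.3]),
    'toast': np.array([0.8, 0.2, 0.4]),
    'planet': np.array([-0.5, 0.7, 0.0]),
}

def cosine_similarity(u, v):
    # Vectors pointing in similar directions give values near 1.
    return u.dot(v) / (np.linalg.norm(u) * np.linalg.norm(v))

# In a good embedding, cosine_similarity(embedding['bread'], embedding['toast'])
# comes out larger than cosine_similarity(embedding['bread'], embedding['planet']).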

How does word2vec work?

word2vec begins with random vectors assigned to each word. It then processes a corpus and, based on how words co-occur with one another, gradually adjusts the embeddings. Like many learning algorithms, the adjustments made to the vectors during learning are guided by an objective function: a function whose value indicates how well the embeddings reflect the statistics of the training corpus.

The objective acts as a compass, showing how the embedding should be modified to improve it. The particular objective used by word2vec is based on noise-contrastive estimation, and this is the method’s key innovation. To understand word2vec in full detail, it is necessary to understand the objective function that it uses.

Noise-contrastive estimation

A description of noise-contrastive estimation can be found in Noise-contrastive estimation of unnormalized statistical models, with applications to natural image statistics by Michael U. Gutmann and Aapo Hyvärinen. At the heart of this estimation procedure is the idea that you can learn a distribution by learning to distinguish it from other distributions. As a very rough intuition, this is a bit like learning what a house cat is by learning to tell it apart from lynxes and sphinxes.

Previous methods use what might seem like a more straightforward application of the distributional hypothesis. If we take the distributional hypothesis as guidance, the embedding for a given word should capture something about what kinds of words tend to be found near it. Certain previous methods do this by learning embeddings that model, for a given query word, \(w\), the conditional probability of finding a given context-word, \(c\), in its vicinity: \(p(c|w)\).

Noise-contrastive estimation, and by extension word2vec, is instead based on taking two different samples, and training the model to tell them apart. The first sample, which I’ll call the signal, is a set of query-word context-word pairs, \(\langle w,c\rangle\), that are sampled from a corpus of text (e.g. a bunch of New York Times articles), according to their natural rate of occurrence, \(p(w,c)\). This sampling could be done by uniformly and randomly selecting a token \(w\) in the corpus, and then selecting a neighbor token \(c\) occurring in a window of \(\pm 3\) words.

The second sample, which I’ll call the noise, is generated according to some other distribution \(p_n(w,c)\), whose most relevant property for the moment is that it isn’t the natural distribution of query-context pairs in the corpus: \(p_n(w,c) \neq p(w,c)\). The details of \(p_n(w,c)\) matter, but we’ll come to that in a minute.

During training, signal and noise examples are drawn, and the model is adjusted so as to be able to tell them apart. Let \(C\) indicate the source of a particular query-context pair, \(\langle w,c \rangle\), and let \(C=1\) mean that \(\langle w,c \rangle\) was sampled from the training corpus (according to \(p(w,c)\)), and let \(C=0\) mean that \(\langle w,c \rangle\) was generated from the noise distribution \(p_n(w,c)\). In the noise-contrastive approach, we try to learn a model of \(p(C=1|w,c)\), rather than \(p(c|w)\).

It turns out that these are highly related. An optimal model of \(p(C=1|w,c)\) implicitly yields an optimal model of \(p(w,c)\), which can be recovered if you know the noise distribution, and from this you can calculate \(p(c|w)\) if you know the prevalence of individual words, \(p(w)\). That this works makes some intuitive sense (think of distinguishing house cats from leopards), but it is worth pondering the proof and training your intuition around that fact. The proof can be found in the first theorem of the article by Gutmann and Hyvärinen.

Using noise-contrastive estimation to learn embeddings

So far we’ve covered that word2vec learns embeddings that are optimized to distinguish query-word context-word pairs \(\langle w,c\rangle\) sampled from the corpus from pairs generated by the noise distribution. What’s missing, though, is how embeddings actually relate to this classification task. This brings us to the basic architecture of word2vec. Given a query-word \(w\) and a context-word \(c\), the embeddings for each are looked up, yielding two vectors, and then the dot-product of these two vectors is taken, yielding a scalar which we’ll call the “match-score”. The match-score can be any real number, positive or negative. This match-score is then converted into a probability, \(p(C=1|w,c)\), by acting on it with the sigmoid function, \(\sigma(a) = \frac{1}{1+e^{-a}}\). This overall architecture is depicted below in Figure 1:

\[p(C=1|w,c) = \sigma(v_w \cdot v_c)\]
Figure 1. The query-word \(w\) and the context-word \(c\) are translated into their respective embeddings, \(v_w\) and \(v_c\), each represented visually as a row of boxes. The probability that the pair \(\langle w, c\rangle\) was sampled from the corpus (as opposed to noise) is found by taking the dot product of the embeddings, and passing the result through the sigmoid activation function.

This provides us with the needed link between the noise-contrastive approach and the word embeddings. The word2vec algorithm learns embeddings that perform well at distinguishing “true” signal examples taken from the corpus from “fake” examples generated from the noise distribution. Embeddings that assign high probability to signal examples and low probability to noise examples are favored.
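
To make Figure 1 concrete, here is a minimal NumPy sketch of the computation (not the package’s code; the vocabulary size, dimensionality, and random initialization are assumed): look up the two embeddings and pass their dot product through the sigmoid.

import numpy as np

vocabulary_size = 10000      # assumed for this sketch
num_dimensions = 50          # assumed embedding dimensionality

# Separate, randomly initialized query and context embeddings.
query_embeddings = 0.01 * np.random.randn(vocabulary_size, num_dimensions)
context_embeddings = 0.01 * np.random.randn(vocabulary_size, num_dimensions)

def match_probability(w_id, c_id):
    """p(C=1|w,c) = sigmoid(v_w . v_c), for word and context indices."""
    match_score = query_embeddings[w_id].dot(context_embeddings[c_id])
    return 1.0 / (1.0 + np.exp(-match_score))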

Notice how the unsupervised problem of learning word embeddings has been recast as a supervised classification problem. In fact, the loss function that is used is the cross-entropy between the actual and predicted source, which is a common objective for classification: \[J = \sum_{\langle w,c \rangle}\Big( C\ln\sigma(v_w \cdot v_c) + (1-C)\ln\big(1-\sigma(v_w \cdot v_c)\big) \Big)\]

This objective function is written so that classifier accuracy increases with \(J\). To match the convention of using a loss function, we can instead minimize \(-J\), which is the (summed) cross-entropy.
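
As a preview of the implementation, the same objective can be written as a Theano symbolic expression. The following is only a sketch under assumed variable names (match_scores holding \(v_w \cdot v_c\) for each pair, and targets holding \(C\)), not the package’s actual code:

from theano import tensor

match_scores = tensor.dvector('match_scores')   # v_w . v_c for each pair (assumed)
targets = tensor.dvector('targets')             # C: 1 for signal, 0 for noise (assumed)

predictions = tensor.nnet.sigmoid(match_scores)  # p(C=1|w,c)
objective = tensor.sum(
    targets * tensor.log(predictions)
    + (1 - targets) * tensor.log(1 - predictions)
)
loss = -objective    # minimize the negative of J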

Sampling

With the overall approach in view, it’s worth discussing a few details about sampling. Let's come back to a question that we set aside at the beginning: what should we use as the noise distribution? Mikolov et al generate noise samples based on the unigram distribution, which just means choosing \(w\) and \(c\) based on the relative frequencies of individual words in the corpus. Notice that this is different from the signal distribution: in the noise distribution we assemble \(\langle w,c \rangle \) by sampling \(w\) and \(c\) independently based on their frequencies in the corpus, whereas, for the signal distribution, we again sample \(w\) based on frequency in the corpus, but we then sample \(c\) conditional on being within \(\pm 3\) tokens of \(w\).

Notice that this means that, within the embeddings, we are encoding how \(p(w,c)\) differs from \(p(w)p(c)\), and this is precisely the aspect of \(p(w,c)\) that should be informative if we take the distributional hypothesis seriously. For example, the word "ink" may not be overly common in the context of "octopus", but it is probably more common in that context than at random, because octopuses produce ink. It is how a given word alters its context statistically, relative to overall word prevalences, that is informative, and that’s what the word embeddings capture.

While it makes good intuitive sense to take \(p_n(w,c)=p(w)p(c)\), Mikolov et al found that smoothing unigram probabilities somewhat, by taking \(p'(w) = p(w)^{3/4}/Z\), actually yields better results (where \(Z\) is a normalization constant).
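
A rough sketch of building this smoothed unigram distribution (an assumed helper, not part of the package) might look like this:

import numpy as np
from collections import Counter

def smoothed_unigram(tokens, alpha=0.75):
    # Count word frequencies, raise them to the 3/4 power, and normalize.
    counts = Counter(tokens)
    vocabulary = list(counts.keys())
    weights = np.array([counts[w] for w in vocabulary], dtype=float) ** alpha
    probabilities = weights / weights.sum()    # dividing by the sum is the Z normalization
    return vocabulary, probabilities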

Sampling from the corpus and from the noise distribution can be carried out simultaneously as follows. First, we sample a query word \(w\) by selecting a token at random in the corpus (i.e. according to \(p(w)\)). Next, we generate the signal pair by randomly drawing a context word from the words near \(w\), within a window of \(\pm k\), say with \(k=3\). A noise example can then be obtained by sampling \(c\) independently with probability \(p'(c)\).

One additional modification that Mikolov et al found to help pertains to the initial selection of the query word \(w\). They found that, instead of sampling according to \(p(w)\), the most frequent words should be downsampled, giving less-frequent words more opportunity to influence the embeddings. During sampling of \(w\), they discard the obtained token with probability \[ p(\mathrm{discard}|w) = \max \left( 0, 1 - \sqrt{\frac{t}{p(w)}} \right) \] with \(t\) set empirically to \(10^{-5}\).

Another detail is that Mikolov et al found that providing more noise examples than signal examples improves the model: as many as 15 noise examples for every signal example improved the results.
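
Putting the last few details together, a simplified sampler might look like the following. This is only a sketch with assumed helper arguments (vocabulary, noise_probabilities, and unigram_probabilities as produced by something like smoothed_unigram above), and it differs in detail from the package’s own sampling code:

import numpy as np

def sample_pairs(tokens, vocabulary, noise_probabilities, unigram_probabilities,
                 window=3, noise_ratio=15, t=1e-5):
    signal_pairs, noise_pairs = [], []
    # Pick a query word by choosing a token uniformly at random.
    i = np.random.randint(len(tokens))
    w = tokens[i]
    # Down-sample frequent query words.
    p_w = unigram_probabilities[vocabulary.index(w)]
    if np.random.rand() < max(0.0, 1.0 - np.sqrt(t / p_w)):
        return signal_pairs, noise_pairs    # discard this query word
    # Signal pair: a context word within +/- `window` tokens of w.
    low, high = max(0, i - window), min(len(tokens), i + window + 1)
    neighbors = [tokens[j] for j in range(low, high) if j != i]
    signal_pairs.append((w, np.random.choice(neighbors)))
    # Noise pairs: sample the context independently from the smoothed unigram p'(c).
    for _ in range(noise_ratio):
        noise_pairs.append((w, np.random.choice(vocabulary, p=noise_probabilities)))
    return signal_pairs, noise_pairs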

Implementation overview

Now that we’ve covered all of the details of how word2vec is meant to work, let’s implement it in Python using Theano and Lasagne. Our first step will be to write the code that holds the embedding parameters and uses them to generate predictions according to the calculation depicted in Figure 1. We will then use that to set up the loss function as a Theano symbolic expression, and define stochastic gradient descent updates that adjust the embeddings so as to minimize the loss. Finally, we’ll handle training on an input corpus by generating the signal and noise samples. This covers the core functionality of the theano-word2vec package, save for less interesting details like loading and saving models.

Implementing the word2vec basic architecture

We begin by implementing the architecture in Figure 1. We first import the needed modules and classes from Theano and Lasagne. Lasagne helps because it defines an EmbeddingLayer class, which makes it easy to generate a Theano expression that includes trainable embedding parameters.


import lasagne
from theano import tensor
from lasagne.layers import get_output
To provide some modularity, I bundle the functionality related to the calculation expressed by the architecture in Figure 1 inside a class called Word2VecEmbedder. The following code, which builds that architecture, is that class’s __init__ function. It’s important to know that this class generally won’t be instantiated directly, but will instead be wrapped in a bigger class later. Let’s first have a look at its call signature:

class Word2VecEmbedder(object):
    def __init__(
        self,
        input_var,
        batch_size,
        vocabulary_size=100000,
        num_embedding_dimensions=500,
        ...  # I've left out other kwargs that are less important
    ):  
Here, input_var is a Theano symbolic variable that will hold a batch of word-context pairs. Each word-context pair will be a separate row, so we’ll need a batch_size by 2 matrix to hold them. The batch_size must be specified in advance, which enables better optimization by Theano. vocabulary_size defines the number of unique words to be embedded, and num_embedding_dimensions is the dimensionality of the embedding to be learned.
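
For illustration only (this is not part of the package, just what the input data looks like), such a batch is an integer matrix with one query-context pair of word indices per row:

import numpy as np

batch_size = 4    # a small batch, just for illustration
example_batch = np.array([    # each row is [query_word_id, context_word_id]
    [12, 407],
    [12, 9],
    [881, 3],
    [881, 10552],
], dtype='int32')
assert example_batch.shape == (batch_size, 2)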

In the next step, we register input_var in the class’s namespace, and then split input_var along axis 1, which gives us two parallel arrays: one containing all of the query words, the other containing all of the corresponding context words. This lets us use separate query and context embeddings.

        self.input_var = input_var
        self.query_input = input_var[:,0]
        self.context_input = input_var[:,1]

We will then instantiate two lasagne.layers.EmbeddingLayers. Before we do that, however, we need to pass the inputs through two separate lasagne.layers.InputLayers. This is just a feature of Lasagne’s approach to building neural architectures: all in-bound data must pass through an InputLayer before entering any other kind of layer.


        # Make separate input layers for query and context words
        self.l_in_query = lasagne.layers.InputLayer(
            shape=(batch_size,), input_var=self.query_input
        )
        self.l_in_context = lasagne.layers.InputLayer(
            shape=(batch_size,), input_var=self.context_input
        )   

        # Make separate embedding layers for query and context words
        self.l_embed_query = lasagne.layers.EmbeddingLayer(
            incoming=self.l_in_query,
            input_size=vocabulary_size, 
            output_size=num_embedding_dimensions, 
            W=word_embedding_init # This was a kwarg omitted
        )                         # from the call signature above.

        self.l_embed_context = lasagne.layers.EmbeddingLayer(
            incoming=self.l_in_context,
            input_size=vocabulary_size, 
            output_size=num_embedding_dimensions, 
            W=context_embedding_init # Another omitted kwarg.
        )                            # Both have reasonable defaults.

Next, we get references to the embedding outputs by calling Lasagne’s get_output on each embedding layer:


        self.query_embedding = get_output(self.l_embed_query)
        self.context_embedding = get_output(self.l_embed_context)
These two symbolic variables will hold a vector for each query- and context-word respectively, and hence are batch_size by num_embedding_dimensions matrices.

We now want to take the dot product of each query-context embedding pair. One way to achieve that is to matrix-multiply self.query_embedding by self.context_embedding.T (note that we have taken the transpose of the context embedding matrix). Recall that multiplying a matrix \(A\) by \(B^T\) yields a matrix whose \((i,j)\)th entry is the dot product of row \(i\) in \(A\) with row \(j\) in \(B\). Of course, we only want the dot products of words and contexts that come from the same pair, i.e. with \(i=j\), meaning that they lie on the matrix product’s main diagonal:


        self.match_scores = tensor.dot(
            self.query_embedding, self.context_embedding.T
        ).diagonal()

It may seem wasteful to multiply these matrices rather than loop through the pairs and compute only the dot products we need. However, this code is not actually computing any dot products; it is merely assembling the symbolic expression representing our computation, which will later be compiled by Theano into code that does. This turns out to be faster than using theano.scan to define the same computation, probably because the expression based on matrix multiplication is more transparent to Theano’s optimization procedure.

Finally, we pass each dot-product through the sigmoid function (elementwise):


        self.output = sigmoid(self.match_scores)
where I have defined the sigmoid outside the class as:

def sigmoid(tensor_var):
    return 1 / (1 + tensor.exp(-tensor_var))

The remaining code in the package wraps around this class, providing convenience functions for handling an input corpus and for saving and loading models, and taking care of compiling the Theano functions behind the scenes. This keeps the API simple and lets users interact with word2vec without directly touching Theano, if so desired. Alternatively, Theano and Lasagne users can integrate word2vec into larger architectures very straightforwardly, all of which is described in the API details below.

theano-word2vec API

To start working with theano-word2vec, import and instantiate a Word2Vec object:

    >>> from word2vec import Word2Vec
    >>> word2vec = Word2Vec()
Train an embedding on a corpus:

    >>> word2vec.train_on_corpus(
    ...     open('my-corpus.txt').read(),
    ...     num_embedding_dimensions=500
    ... )
Get the embeddings for a specific bunch of words using a trained model:

    >>> embeddings = word2vec.embed(
    ...     'this will produce a list of vectors.  '
    ...     'If the input is a string it gets tokenized.'
    ... ) 
You can pass in pre-tokenized text too:

    >>> other_embeddings = word2vec.embed([
    ...     'control', 'tokenization', 'by', 'passing', 'a', 
    ...     'tokenized', 'list' 
    ... ])
Do analogical arithmetic, and find the word with the nearest embedding:

    >>> king, man, woman = word2vec.embed('king man woman') 
    >>> queen = king - man + woman
    >>> print word2vec.nearest(queen)
    'queen'
Save and load models

    >>> word2vec.save('my-embedding.npz')
    >>> word2vec.load('my-embedding.npz')
Get a fresh Lasagne layer out of the trained model, and use it seamlessly with other Lasagne layers:

    >>> some_input_var = theano.tensor.imatrix('input')
    >>> input_layer = lasagne.layers.InputLayer(shape, input_var=some_input_var)
    >>> embedding_layer = word2vec.layer(input_layer)
    >>> my_cool_architecture = lasagne.layers.DenseLayer(embedding_layer, num_units=100)