Deep Learning - Cross Entropy and Softmax



#deep learning  #machine learning  #tutorial  #math 

Cross Entropy

In this post we will talk about cross entropy and the softmax function. They are used in most neural network classification problems. The standard setup is to use the output layer to represent the distribution over label classes. For example, if we have \(M\) distinct labels, then the output layer of the neural network has \(M\) neurons, and the value of each neuron represents the probability of the corresponding label.

[Figure: classification neural network]

Now we need to answer the following two questions:

How to represent the label in the training data?

As mentioned earlier, the output layer of the neural network represents a probability distribution, so it's natural to convert the labels in the training data to a distribution representation as well. Suppose the \(i^{th}\) training data point is in class \(c\), i.e. \(y_i = c\). This is a scalar value, and we need to convert it into a distribution. We further assume there are \(M\) different classes, so the probability distribution can be represented by a vector of length \(M\).

Let \(p_{c_j}^{i}\) denote the probability that the \(i^{th}\) data point in the training data belongs to class \(c_j\). We then have

$$ p_{c_j}^{i}(y_i) = \begin{cases} 1 & \;\;\; \text{if} \; c_j = c \\ 0 & \;\;\; \text{otherwise} \end{cases} $$
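This representation is usually called one-hot encoding. Here is a minimal Python sketch, assuming the classes are given as a list whose order fixes the position of each class in the target vector. The helper `one_hot` and the labels `cat`, `dog`, `bird` are hypothetical and used only for illustration.

```python
def one_hot(label, classes):
    """Return the target distribution for one label: 1.0 at the label's
    position in `classes` and 0.0 everywhere else (a vector of length M)."""
    return [1.0 if c == label else 0.0 for c in classes]

classes = ["cat", "dog", "bird"]  # hypothetical label set, M = 3
target = one_hot("dog", classes)
print(target)  # [0.0, 1.0, 0.0]
```

Every entry is 0 except the one for the true class, so the vector sums to 1 and is a valid (degenerate) probability distribution.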

How to measure the difference between a predicted value and a target value?

It's clear that we need something to quantify the difference between two distributions. We have a few options here.

Some notes about these distance measures:

Special Case: Binary Cross Entropy

One special case is binary cross entropy, where there are only two classes. Suppose we have class A and class B: if a data point belongs to class A, we set \(y_i\) to 1, otherwise \(y_i\) is set to 0. Let \(p_i\) be the predicted probability that the data point belongs to class A. The binary cross entropy is then

$$ -\left[ y_i \log(p_i) + (1 - y_i) \log(1 - p_i) \right] $$

Note that the class representation is different as well. For binary classification, we only need one node to represent two classes (instead of using two nodes).
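To see that the binary formula is the two-class case of cross entropy, here is a minimal Python sketch of the general definition \(H(p, q) = -\sum_j p_j \log q_j\) (target \(p\), prediction \(q\)) next to the binary formula. The function names are hypothetical, and the clipping by `eps` is an added assumption to avoid \(\log 0\); it is not part of the definition.

```python
import math

def cross_entropy(target, predicted, eps=1e-12):
    """H(target, predicted) = -sum_j t_j * log(q_j).
    Predicted probabilities are clipped to [eps, 1 - eps] to avoid log(0)."""
    return -sum(t * math.log(min(max(q, eps), 1 - eps))
                for t, q in zip(target, predicted))

def binary_cross_entropy(y, p, eps=1e-12):
    """Binary cross entropy for a label y in {0, 1}, where p is the
    predicted probability of class A (the single output node)."""
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# A point in class A (y = 1), predicted as class A with probability 0.8.
bce = binary_cross_entropy(1, 0.8)
# The same point written as a two-class distribution: target (1, 0), prediction (0.8, 0.2).
ce = cross_entropy([1.0, 0.0], [0.8, 0.2])
print(round(bce, 4), round(ce, 4))  # 0.2231 0.2231
```

With a one-hot target, every term except the true class vanishes, so cross entropy reduces to \(-\log\) of the probability assigned to the true class; here that is \(-\log 0.8\).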

Softmax

Softmax is only one of many ways to convert a vector of scores into a probability distribution. Suppose we have a vector of scores \( (z_1, z_2, ..., z_n) \) and we want to convert it into a probability distribution vector \( (y_1, y_2, ..., y_n) \). The softmax formula is:

$$ y_i = \frac{e^{z_i}}{\sum\limits_{k=1}^{n}e^{z_k}} $$
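A minimal Python sketch of the formula follows. Subtracting \(\max_k z_k\) before exponentiating is a common numerical safeguard rather than part of the definition: the factor \(e^{-\max_k z_k}\) cancels between numerator and denominator, so the result is unchanged, but `math.exp` no longer overflows for large scores. The `sigmoid` helper is added only to connect back to the binary case.

```python
import math

def softmax(scores):
    """Convert scores (z_1, ..., z_n) into probabilities (y_1, ..., y_n).
    Shifting by max(z) leaves the ratios unchanged but avoids overflow."""
    m = max(scores)
    exps = [math.exp(z - m) for z in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(z):
    """Logistic function, used here only for the two-class comparison."""
    return 1.0 / (1.0 + math.exp(-z))

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
print(softmax([1000.0, 0.0]))        # [1.0, 0.0], no overflow
# With two classes, softmax over (z, 0) equals sigmoid(z).
print(math.isclose(softmax([1.5, 0.0])[0], sigmoid(1.5)))  # True
```

The last line illustrates why a single output node is enough for binary classification: a two-class softmax with one score fixed at 0 is the logistic sigmoid of the other score.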

Hierarchical Softmax

Technically speaking, hierarchical softmax is an optimization of the probability calculation. It was proposed by Morin and Bengio in the paper "Hierarchical Probabilistic Neural Network Language Model". The paper is set in an NLP context, and the method aims to speed up the calculation of \( P(\omega_t|\omega_{t-1}, ..., \omega_{t-l+1}) \), where \( \omega_t \) is a word in a document. In this setting, the \(n\) in the softmax expression above is the size of the vocabulary, denoted by \( |V| \).

There are two important observations mentioned in the paper:

  1. The time complexity of softmax is O(n) (or O(|V|)) because the denominator sums over every class.
  2. By grouping words into classes and using conditional probabilities, the time complexity can be reduced.

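The idea is to represent each word as a leaf of a binary tree. The path from the root to the leaf acts as an encoding of the word, i.e. a sequence of binary decisions \( b_1(v), ..., b_m(v) \), and the probability of a word \(v\) is given by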

$$ P(v|\omega_{t-1}, ..., \omega_{t-l+1}) = \prod \limits_{j=1}^{m} P(b_j(v) | b_1(v), ..., b_{j-1}(v), \omega_{t-1}, ..., \omega_{t-l+1}) $$

where \( m \) is the length of the path from the root to the word in the binary tree.

[Figure: hierarchical softmax binary tree]

The probability of a word becomes the probability of following the corresponding path in the binary tree. In the figure above, there are three probabilities to be calculated, and each one costs O(1) with respect to the vocabulary size because it only normalizes over 2 outcomes (e.g. left/right, or 0/1 in the encoding). The overall time complexity is therefore proportional to the length of the path, which is at most the height of the binary tree. If the binary tree is balanced, the height is \( O(\log|V|) \).
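The product over the path can be sketched in Python. This is a minimal sketch under assumptions the post does not state: the tree is a complete balanced binary tree (so \(|V|\) is a power of two), every internal node has its own parameter vector, and the probability of taking branch 1 at a node is the logistic sigmoid of that vector's dot product with a context vector, which is one common parametrization. The sizes, the random parameters, and the helpers `word_code` and `word_probability` are hypothetical placeholders.

```python
import math
import random

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def word_code(v, depth):
    """Bits b_1(v), ..., b_m(v): the root-to-leaf path of word v, read off
    the binary representation of v (most significant bit first)."""
    return [(v >> (depth - 1 - j)) & 1 for j in range(depth)]

def word_probability(v, context, node_vectors, depth):
    """P(v | context) as a product of one binary decision per tree level.
    Internal nodes use heap numbering: root = 1, children of n are 2n and 2n + 1."""
    prob, node = 1.0, 1
    for bit in word_code(v, depth):
        score = sum(c * w for c, w in zip(context, node_vectors[node]))
        p_one = sigmoid(score)                  # probability of taking bit 1
        prob *= p_one if bit == 1 else 1.0 - p_one
        node = 2 * node + bit                   # move to the chosen child
    return prob

random.seed(0)
vocab_size, dim = 8, 4                   # hypothetical sizes; |V| is a power of two
depth = int(math.log2(vocab_size))       # m = log2|V| decisions per word
# Hypothetical random parameters: one vector per internal node (nodes 1..|V|-1).
node_vectors = {n: [random.gauss(0, 1) for _ in range(dim)]
                for n in range(1, vocab_size)}
context = [random.gauss(0, 1) for _ in range(dim)]  # stands in for the encoded history

probs = [word_probability(v, context, node_vectors, depth) for v in range(vocab_size)]
print(depth)                  # 3
print(round(sum(probs), 10))  # 1.0
```

Each call to `word_probability` touches only `depth` internal nodes, yet the leaf probabilities still sum to 1 because the two branch probabilities at every node sum to 1. No normalization over all \(|V|\) words is needed, which is where the \(O(\log|V|)\) cost comes from.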
