
Some thoughts on the objective functions of GANs


It has now been accepted as common wisdom that Generative Adversarial Networks (GANs) minimize the Jensen-Shannon divergence between the real data distribution and the generated one (Goodfellow et al., 2014). In this article I'm going to take a closer look at the interpretation of the GAN losses in terms of divergence minimization.

Recap of the GAN game

The original GAN game can be described by the following formula, where \(D(x)\) is the discriminator model, \(\mathcal{P}(x)\) is the real data distribution, and \(\mathcal{Q}(x)\) is the generated data distribution (note that we are not making any assumption about how \(D\) and \(\mathcal{Q}\) are computed):

$$\min_{\mathcal{Q}} \quad \max_D \quad \mathbb{E}_{x \sim \mathcal{P}} \left[\log D(x)\right] + \mathbb{E}_{x \sim \mathcal{Q}} \left[\log (1 - D(x))\right] \tag{1}$$
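To make the notation concrete, here is a minimal sketch of how this value function is typically estimated from samples. This assumes a PyTorch setup, and the tensors `d_real`, `d_fake` (as well as the `D`, `G`, `x_real`, `z` mentioned in the comments) are hypothetical names, not anything from the original paper:

```python
import torch

def gan_value(d_real, d_fake):
    """Monte-Carlo estimate of the value function of eq. (1).

    d_real: discriminator outputs D(x) on samples x ~ P (values in (0, 1))
    d_fake: discriminator outputs D(x) on samples x ~ Q (values in (0, 1))
    """
    return torch.log(d_real).mean() + torch.log(1.0 - d_fake).mean()

# The discriminator ascends this quantity, the generator descends it, e.g.:
#   d_loss = -gan_value(D(x_real), D(G(z)))
#   g_loss =  gan_value(D(x_real), D(G(z)))
```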

From this we can prove that the first step (maximizing over \(D\)) is solved by:

$$D^{\star}(x) = \frac{\mathcal{P}(x)}{\mathcal{P}(x)+\mathcal{Q}(x)} \tag{2}$$
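One quick way to see this is to maximize pointwise: for a fixed \(x\), write \(a = \mathcal{P}(x)\), \(b = \mathcal{Q}(x)\) and \(t = D(x)\); writing the expectations of (1) as integrals, the integrand is \(a \log t + b \log(1-t)\), and setting its derivative in \(t\) to zero gives

$$\frac{a}{t} - \frac{b}{1-t} = 0 \quad \Longleftrightarrow \quad t = \frac{a}{a+b} = \frac{\mathcal{P}(x)}{\mathcal{P}(x)+\mathcal{Q}(x)}$$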

Plugging this solution back into the original equation, solving for \(\mathcal{Q}\) then amounts to:

$$\min_{\mathcal{Q}} \quad \mathbb{E}_{x \sim \mathcal{P}} \left[ \log \left( \frac{\mathcal{P}(x)}{\mathcal{P}(x)+\mathcal{Q}(x)} \right) \right] + \mathbb{E}_{x \sim \mathcal{Q}} \left[ \log \left( \frac{\mathcal{Q}(x)}{\mathcal{P}(x)+\mathcal{Q}(x)} \right) \right] \tag{3}$$

Which we can reformulate as

$$\min_{\mathcal{Q}} \quad \operatorname{D}_{KL}\left[ \mathcal{P} \middle\| \frac{\mathcal{P}+\mathcal{Q}}{2} \right] + \operatorname{D}_{KL}\left[ \mathcal{Q} \middle\| \frac{\mathcal{P}+\mathcal{Q}}{2} \right] - 2 \log(2) \tag{4}$$

Or

$$\min_{\mathcal{Q}} \quad 2 \operatorname{JSD} \left[ \mathcal{P} \middle\| \mathcal{Q} \right] - 2 \log(2) \tag{5}$$
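As a quick numerical sanity check (not part of the original derivation), we can verify that (3) and (5) agree on arbitrary discrete distributions; a small NumPy sketch:

```python
import numpy as np

rng = np.random.default_rng(0)

def kl(p, q):
    # Kullback-Leibler divergence between two discrete distributions
    return np.sum(p * np.log(p / q))

# Two random, strictly positive distributions over 10 outcomes
p = rng.random(10); p /= p.sum()
q = rng.random(10); q /= q.sum()
m = (p + q) / 2

lhs = np.sum(p * np.log(p / (p + q))) + np.sum(q * np.log(q / (p + q)))  # eq. (3)
jsd = 0.5 * kl(p, m) + 0.5 * kl(q, m)
rhs = 2 * jsd - 2 * np.log(2)                                            # eq. (5)

print(np.allclose(lhs, rhs))  # True
```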

And from this, like Goodfellow et al., 2014, we can conclude that the GAN game does indeed minimize the Jensen-Shannon divergence between \(\mathcal{P}\) and \(\mathcal{Q}\).

The devil is in the details

Now suppose we want to implement this with a parametric generative model for \(\mathcal{Q}\), say some \(\mathcal{Q}_{\theta}\), and to train it by gradient descent.

For example, this \(\mathcal{Q}_{\theta}\) could be implemented using a neural network \(G_{\theta}(z)\): \(\mathcal{Q}_{\theta}(\hat{x})\) would then be the density of the generated samples \(\hat{x} = G_{\theta}(z)\), with \(z\) sampled from some prior \(\rho(z)\).
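For concreteness, here is a minimal sketch of such a generator, assuming PyTorch and an entirely hypothetical architecture (a small MLP mapping a Gaussian prior to data space); the dimensions are arbitrary:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """G_theta: maps a latent code z ~ rho(z) to a generated sample x_hat."""

    def __init__(self, latent_dim=32, data_dim=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, data_dim),
        )

    def forward(self, z):
        return self.net(z)

G = Generator()
z = torch.randn(64, 32)   # prior rho(z) = N(0, I)
x_hat = G(z)              # samples from Q_theta, shape (64, 2)
```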

Let's further assume that we manage to obtain this optimal discriminator in a way that makes it helpful for training the generator without the gradients vanishing (which, as we know, is actually one of the biggest challenges of GANs, as detailed by Arjovsky et al., 2017):

$$D^{\star}_{\theta}(x) = \frac{\mathcal{P}(x)}{\mathcal{P}(x)+\mathcal{Q}_{\theta}(x)} \tag{6}$$

Note here that this optimal discriminator depends on \(\mathcal{Q}_{\theta}\), and thus on \(\theta\), the parameters of our generator model.

We now have to minimize the following loss to train our generator:

$$\mathcal{L}(\theta) = \mathbb{E}_{x \sim \mathcal{P}} \left[\log D^{\star}_{\theta}(x)\right] + \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log (1 - D^{\star}_{\theta}(x))\right] \tag{7}$$

This loss looks very familiar: it is the loss used to train the generator network in a GAN. Except that, when actually training the network, we drop the first term and optimize the following loss instead:

$$\widetilde{\mathcal{L}}(\theta) = \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log (1 - D^{\star}_{\theta}(x))\right] \tag{8}$$
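In code, this simplification is exactly what the usual generator update does: the term over real samples is simply absent, and the discriminator's parameters are treated as constants with respect to \(\theta\). A minimal sketch, reusing the hypothetical `Generator` above and assuming some discriminator `D` (defined elsewhere) that outputs probabilities:

```python
import torch

# Hypothetical models, assumed to be defined elsewhere:
#   G: Generator (see sketch above)
#   D: discriminator mapping samples to probabilities in (0, 1)
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4)

def generator_step(batch_size=64, latent_dim=32):
    z = torch.randn(batch_size, latent_dim)
    x_fake = G(z)
    d_fake = D(x_fake)
    # Eq. (8): only the E_{x ~ Q_theta}[log(1 - D(x))] term is kept.
    # D is not updated here, but gradients still flow through D(x_fake)
    # back into G's parameters.
    g_loss = torch.log(1.0 - d_fake).mean()
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()
    return g_loss.item()
```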

We have very good reasons for doing so: mostly, the first term's gradient with respect to \(\theta\) is not really computable, as it would require backpropagating through the whole training process of the discriminator.

We also have bad reasons for doing so: implicitly assuming that this term does not depend on \(\theta\) and that its gradient is thus \(0\).

Either way, the result is that we are no longer training our generator to minimize the Jensen-Shannon divergence, but rather the following quantity:

$$\widetilde{\mathcal{L}}(\theta) = \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right] - \log(2) \tag{9}$$
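Indeed, plugging (6) into (8) and applying the same \(\log(2)\) shift as in (4):

$$\mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log (1 - D^{\star}_{\theta}(x))\right] = \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[ \log \frac{\mathcal{Q}_{\theta}(x)}{\mathcal{P}(x)+\mathcal{Q}_{\theta}(x)} \right] = \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right] - \log(2)$$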

With half of the Jensen-Shannon divergence missing, we end up halfway between \(\operatorname{JSD}\left[\mathcal{Q}_{\theta}\|\mathcal{P}\right]\) and \(\operatorname{D}_{KL}\left[\mathcal{Q}_{\theta}\|\mathcal{P}\right]\). This may explain some of the observed dynamics of GANs: this \(\widetilde{\mathcal{L}}\) loss is more forgiving than the Jensen-Shannon divergence regarding mode dropping (it is the \(\operatorname{D}_{KL}\left[\mathcal{P} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right]\) half of the Jensen-Shannon divergence that becomes large when \(\mathcal{Q}_{\theta}(x)=0\) while \(\mathcal{P}(x) > 0\), not the one we've kept).

Playing with the generator's loss

Other losses for the generator can be considered, depending on the properties and dynamics we want:

$$\widetilde{\mathcal{L}_2}(\theta) = - \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log D^{\star}_{\theta}(x)\right] \tag{10}$$

which tends to provide better gradients to the generator, but does not obviously express a divergence. However, by substituting \(D^{\star}_{\theta}\) as before, we can re-express it like this:

$$\widetilde{\mathcal{L}_2}(\theta) = \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \mathcal{P} \right] - \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right] + \log(2) \tag{11}$$
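To spell out where (11) comes from, substitute (6) and split the logarithm:

$$- \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log D^{\star}_{\theta}(x)\right] = \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[ \log \frac{\mathcal{Q}_{\theta}(x)}{\mathcal{P}(x)} \right] - \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[ \log \frac{\mathcal{Q}_{\theta}(x)}{\mathcal{P}(x)+\mathcal{Q}_{\theta}(x)} \right]$$

and the second expectation is exactly \(\operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right] - \log(2)\), as computed for (9).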

And using the convexity of the Kullback-Leibler divergence in its second argument (applied to \(\frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\), the midpoint of \(\mathcal{P}\) and \(\mathcal{Q}_{\theta}\), together with \(\operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \mathcal{Q}_{\theta}\right] = 0\)), we can show that \(\operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \frac{\mathcal{P}+\mathcal{Q}_{\theta}}{2}\right] \leq \frac{1}{2} \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \mathcal{P} \right]\), and thus:

$$\widetilde{\mathcal{L}_2}(\theta) \geq \frac{1}{2} \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta} \middle\| \mathcal{P} \right] + \log(2) \tag{12}$$

which is much less forgiving than the original \(\widetilde{\mathcal{L}}\) regarding the generation of non-realistic samples, and might explain why it gives better gradients to the generator.

$$\widetilde{\mathcal{L}_3}(\theta) = - \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[\log \frac{D^{\star}_{\theta}(x)}{1-D^{\star}_{\theta}(x)}\right] \tag{13}$$

which is then equivalent to the KL divergence: \(\widetilde{\mathcal{L}_3}(\theta) = \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta}\middle\|\mathcal{P}\right]\), so possibly a better version of the previous one.
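To see this, note that with the optimal discriminator the log-odds recover the density ratio directly:

$$\log \frac{D^{\star}_{\theta}(x)}{1-D^{\star}_{\theta}(x)} = \log \frac{\mathcal{P}(x)}{\mathcal{Q}_{\theta}(x)} \qquad \Longrightarrow \qquad \widetilde{\mathcal{L}_3}(\theta) = \mathbb{E}_{x \sim \mathcal{Q}_{\theta}} \left[ \log \frac{\mathcal{Q}_{\theta}(x)}{\mathcal{P}(x)} \right] = \operatorname{D}_{KL}\left[\mathcal{Q}_{\theta}\middle\|\mathcal{P}\right]$$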