Generative Adversarial Networks

生成对抗网络(Generative Adversarial Networks)越来越火,最开始是用于MLP的生成对抗网络,就是Ian J. Goodfellow论文中提出的。后来出现了CNN架构的,效果确实提高了。现在还有SeqGAN,我不敢做评论。既然是Simplified DeepLearning系列的,自然尝试最简单的。


Generative Adversarial Networks原理

GAN的原理其实很简单。首先我们要学习的是生成器G关于数据x的分布p_g,问题从哪生成呢,这就需要我们定义一个噪声先验p_z(z),如一个均匀分布。那么生成的样本就是G(z)。然后我们需要一个判别器D(y),判断y是来自G还是x,也就是D要判断出输入是伪造的还是真实的。这就是一个对抗学习的过程:G尽量生成逼近真实的数据,使D不能分辨真伪。D要足够厉害,能够分辨真伪。形象的图如下:

gan
图自slideshare

形式的代价函数如下:

    \[\underset{G}{\min}~ \underset{D}{\max} ~V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] +\mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))]\]

理解这个很重要,可以看看这个。这里其实是将两个loss拼起来了,而且省去了一部分。有了目标函数,训练则是交替的训练,论文中有详细描述。简单的说,就是一步固定G,训练D,然后固定D,训练G。作者提到,为了防止”the Helvetica scenario”,要训练D多步,然后训练G一步。


一维高斯分布

看一个简单的例子,也是论文中的例子,我觉得很多人没理解论文中那个用均匀分布生成高斯分布的例子。为了深刻理解,保证自己没有理解错误,我必须再现一下。网络结构很简单,没几行如下。

然后就是训练,具体见github。怎么知道我是对的呢。可以看以下训练过程。

init
初始状态
m1
训练好的判别器
m2
交替训练几步之后
m2
最后结果

 

我还能说什么,完全符合理论D^*(y) = \frac{p_{data}(y)}{p_{data}(y) + p_{g}(y)}以及最后D(y) = \frac{1}{2}。在强调一下,这是一维的情况,也就是从真实数据给出一个数字和从G中生成一个数字如4,判别器无法判断4来自真实数据还是伪造的,因为这两个分布产生这个数字的概率是一样的,如最后一张图。如果真实数据产生4的概率大一点,判别器就能稍微判断一下,如最后第二张图。


MNIST测试

又得用到果蝇MNIST了,上面的例子太简单,可能不能令人信服。代码其实差不多,稍微加了个dropout,见github。训练中间结果如下,注意这是一张只循环一次的gif,你可能要刷新一下。

mnist gan
mnist gan (gif)

可能注意到怎么只生成两个数字,这就是所谓的”the Helvetica scenario”,我还不知道哪里出了问题。反正能生成数字了,够了。


实践出真知,show me the code!

国庆终于把坑填完了,DOTA还拿了暴走,完成了千年辅助的梦想!3天没出宿舍楼,欢度国庆,开心!


链接