生成对抗网络(Generative Adversarial Networks)越来越火,最开始是用于MLP的生成对抗网络,就是Ian J. Goodfellow论文中提出的。后来出现了CNN架构的,效果确实提高了。现在还有SeqGAN,我不敢做评论。既然是Simplified DeepLearning系列的,自然尝试最简单的。
Generative Adversarial Networks原理
GAN的原理其实很简单。首先我们要学习的是生成器G关于数据x的分布,问题从哪生成呢,这就需要我们定义一个噪声先验
,如一个均匀分布。那么生成的样本就是G(z)。然后我们需要一个判别器D(y),判断y是来自G还是x,也就是D要判断出输入是伪造的还是真实的。这就是一个对抗学习的过程:G尽量生成逼近真实的数据,使D不能分辨真伪。D要足够厉害,能够分辨真伪。形象的图如下:

形式的代价函数如下:
理解这个很重要,可以看看这个。这里其实是将两个loss拼起来了,而且省去了一部分。有了目标函数,训练则是交替的训练,论文中有详细描述。简单的说,就是一步固定G,训练D,然后固定D,训练G。作者提到,为了防止”the Helvetica scenario”,要训练D多步,然后训练G一步。
一维高斯分布
看一个简单的例子,也是论文中的例子,我觉得很多人没理解论文中那个用均匀分布生成高斯分布的例子。为了深刻理解,保证自己没有理解错误,我必须再现一下。网络结构很简单,没几行如下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | # simple mlp discriminator discriminator = Sequential() discriminator.add(Dense(mid_dim, input_dim=data_dim, activation='tanh')) discriminator.add(Dense(mid_dim, activation='tanh')) discriminator.add(Dense(1, activation='sigmoid')) # simple mlp generator generator = Sequential() generator.add(Dense(mid_dim, input_dim=sample_dim, activation='tanh')) generator.add(Dense(mid_dim, activation='tanh')) generator.add(Dense(data_dim, activation='tanh')) generator.add(Lambda(lambda x: x * 4 * sigma + mu)) # scale sample_fake = K.function([generator.input], generator.output) discriminator.trainable = False generator.add(discriminator) # training setting opt_g = Adam(lr=.0001) generator.compile(loss='binary_crossentropy', optimizer=opt_g) opt_d = Adam(lr=.0002) discriminator.trainable = True discriminator.compile(loss='binary_crossentropy', optimizer=opt_d) |
然后就是训练,具体见github。怎么知道我是对的呢。可以看以下训练过程。




我还能说什么,完全符合理论以及最后
。在强调一下,这是一维的情况,也就是从真实数据给出一个数字和从G中生成一个数字如4,判别器无法判断4来自真实数据还是伪造的,因为这两个分布产生这个数字的概率是一样的,如最后一张图。如果真实数据产生4的概率大一点,判别器就能稍微判断一下,如最后第二张图。
MNIST测试
又得用到果蝇MNIST了,上面的例子太简单,可能不能令人信服。代码其实差不多,稍微加了个dropout,见github。训练中间结果如下,注意这是一张只循环一次的gif,你可能要刷新一下。

可能注意到怎么只生成两个数字,这就是所谓的”the Helvetica scenario”,我还不知道哪里出了问题。反正能生成数字了,够了。
实践出真知,show me the code!
国庆终于把坑填完了,DOTA还拿了暴走,完成了千年辅助的梦想!3天没出宿舍楼,欢度国庆,开心!
请问下为什么要把fake data和real data的prediction加起来做为D的曲线? 就是tmp_d和tmp_m, 我看到别的地方是用fake data的prediction作为D的曲线.
这里就模拟了一下,表明符合理论。你看那个D(y)就是加起来的
大神,你用的什么版本的kera,我这里代码运行就出错TypeError: `outputs` of a TensorFlow backend function should be a list or tuple.(这是运行simpleGAN的错误),还有运行mnist那个也报错
TypeError: Value passed to parameter ‘shape’ has DataType float32 not in list of allowed values: int32, int64
很老的版本了,sample_fake = K.function([generator.input], generator.output)可以改为sample_fake = K.function([generator.input], [generator.output])。shape那个强制转为int就可以了。