生成对抗网络(GAN)教程 - 多图详解

一、生成对抗网络简介

1. 生成对抗网络模型主要包括两部分:生成模型和判别模型。 生成模型是指我们可以根据任务、通过模型训练由输入的数据生成文字、图像、视频等数据。            

[1]比如RNN部分讲的用于生成奥巴马演讲稿的RNN模型,通过输入开头词就能生成下来。

[2]或者由有马赛克的图像通过模型变成清晰的图像,第一张是真实,第四张是合成的。    


[3]或者我们为了生成一类图像时会通过输入指定的分布形态数据,这样经过训练数据和输入分布形态后,就可以将分布形态输入网络获得这一类的图像(可用于数据集的扩展,图片的合成,思路等下介绍),左图是看了大量街景记录生成的未见过的场景。



二、生成对抗网络 —— 基本机构

1. 生成模型从本质上是一种极大似然估计,用于产生指定分布数据的模型,生成模型的作用是捕捉样本数据的分布、将原输入信息的分布情况经过极大似然估计中参数的转化来将训练偏向转换为指定分布的样本。


2. 判别模型实际上是个二分类,会对生成模型生成的图像等数据进行判断,判断其是否是真实的训练数据中的数据。


三、生成对抗网络 —— 模型介绍

1. 对于GAN,一个简单的理解是可以将其看做博弈的过程,我们可以将生成模型和判别模型看作博弈的双方,比如在犯罪分子造假币和警察识别假币的过程中:            

[1]生成模型G相当于制造假币的一方,其目的是根据看到的钱币情况和警察的识别技术,去尽量生成更加真实的、警察识别不出的假币。            

[2]判别模型D相当于识别假币的一方,其目的是尽可能的识别出犯罪分子制造的假币。 这样通过造假者和识假者双方的较量和朝目的的改进,使得最后能达到生成模型能尽可能真的钱币、识假者判断不出真假的纳什均衡效果(真假币概率都为0.5)。

2. 我们可以将上面的场景映射成图片之间生成模型和判别模型之间的博弈过程,博弈的简单模式如下:生成模型生成一些图片->判别模型学习区分生成的图片和真实图片->生成模型根据判别模型改进自己,生成新的图片->判别模型再学习区分生成的图片和真实图片.....

上面的博弈场景会一直继续下去,直到生成模型和判别模型别无法提升自己,这样生成模型就会成为一个比较完美的模型。


3. 下图是基本GAN网络的模型结构,我们现在开始由下往上介绍,先看右图,右图是生成模型函数的训练网络。          

[1]首先我们先将正态分布的噪声数据z(必须统一的一类分布数据,因为训练模型是按分布情况转换的,模型作用是将一类分布转化为任务需要的数据分布情况)输入到网络中。


[2]噪音数据会通过生成模型网络G(z)生成造假的图像数据,因为我们的目的是制造尽可能让判别模型分不清往图像,所以会结合判别网络进行模型训练,通过这种训练来使生成模型有更好的造假效果。(结合判别模型知道误差)


[3]接下来会将生成模型输出的造假图片数据输入到判别模型网络D(x)中,之后进行网络的参数计算得到最后的判别输出,输出0-1之间的参数值,0表示是造假信息,1表示是真实数据,对于生成模型产生的造假信息,我们希望判别模型能够输出接近0的输出值,从而有效判断真假。        

在下面网络进行生成模型训练时,在判别模型部分生成误差后,我们在训练时判别网络的网络参数并不需要发生变化,只是把最后按生成模型目标函数计算的误差往前一直传,传到生成网络来更新生成网络的参数,这样就可以完成生成网络的训练。


4. 下图中上半部分是判别模型的训练模型,我们会结合真实样本集和造假样本集按批次进行判别模型网络的训练,根据其目标函数进行梯度下降:



四、生成对抗网络 —— 模型训练

1. 进行网络训练时,判别模型的目标函数是:


其中D(x)是判别模型的输出结果,是一个0-1范围内的实数值,用来判断图片是真实图片的概率,其中Pr和Pg分别代表真实图像的分布与生成图像的数据分布情况,可以看出目标函数是找到使得后面两个式子之和最大的判别模型函数D(z),后面两个式子是一个加和形式,其中:

[1]是指使得真实数据放入到判别模型D(x)输出的计算值和整个式子值尽可能大,

[2]是指使得造假数据放入到判别模型D(x)输出的计算值尽可能小和整个式子值尽可能大,这样整合下来就是使得目标函数尽可能大,因此在训练时就可以根据目标函数进行梯度提升。

2. 生成模型的目标是让判别模型无法区分真实图片和生成图片,其目标函数是:


也就是找到生成函数g(z)使得生成模型的目标函数尽量小,所以两者是对抗的。

3. 下图是GAN的一个算法流程,我们会使用目标函数在两个网络中进行参数的梯度改变。


4. 对于上面的最大最小化目标函数进行优化时,我们最直观的处理方式是将生成网络模型D和判别网络模型G进行交替迭代,在一段时间内,固定G网络内的参数,来优化网络D,另一段时间固定D网络中的参数,来优化G网络中的参数(这样的话,上图俩部分就是两个网络)。

5. 我们举个优化效果的例子,假设刚开始的真实样本分布、生成样本分布、判别模型分别对应左图的黑线、绿线、蓝线。

那么:[1]当我们固定生成模型,而优化判别模型时,我们发现判别模型会变得有很好的对黑线和蓝线的区分效果。(偏向于在两者中间)


2]当我们固定判别模型,改进生成模型时,我们发现生成模型生成的数据分布(绿线)会不断往真实数据分布(黑线)靠拢,也就如第三幅图,使得判别模型很难分离判断。          

[3]进行1、2过程进行大量迭代后,我们会得到最后图的效果,生成样本数据分布和真实数据分布基本吻合,判别模型处于纳什均衡,做不了判断。


五、生成对抗网络 —— 模型应用

1. 对于GAN的应用,我们可以分为两大类:            

[1]一类是学后联想任务,这类任务特点是只有大量的目标样例 [y1,y2,y3...],这类任务我们只知道我们想让网络模仿的东西,比如我们使用大量毕加索的画像去训练GAN,让其模仿生成毕加索的画像,此时生成模型网络的输入是正态分布即可,网络会学着把正态分布转成猫的图像的数据分布情况(可用于数据生成)。      


[2]另一类是目标引导的训练任务,这类任务是有目标指定的,训练数据是{[x1,y1],[x2,y2]...}这种形式,我们会学着让生成网络学会去除图片的马赛克、图像风格变化等任务。此时生成模型的输入应该是原始转换前的图片(按批次放,每批次数据称为一个噪音,噪音维度越小,图片共性越大)。


2. 早期的GAN模型中,生成模型输入的是一些服从某一简单分布(例如高斯分布)的随机噪声z,输出是训练图像相同尺寸的生成图像。 如能实现数据量的扩充,比如对手写数字图像的扩充,如下图输入高斯分布,将大量0数据放入生成对抗网络进行训练,我们就能获得很像9的右面的图片。再比如我们使用大量人类的脸或者猫的图片,GAN就能通过对高斯分布转成猫图像数据分布学习来想想生成新的猫的图像、或者合成场景、脸等数据。

3. 应用于图像方面,更强的改进模型是DCGAN(convolutional nn for Gan),DCGAN中生成模型用到的神经网络是与卷积神经网络相反的网络形式(反卷积),判别是正卷积,网络形式如下图。

为什么是这种形式,他与通常的CNN网络比较可以考虑成是画画和理解画两个不同的过程。            

[1]我们在画画时是先考虑构图(高维联想),之后开始画轮廓、画线条、填充颜色。  


[2]而我们在理解画时,是按照视觉皮层的理解思路,先看到边缘的即视的东西,之后将这些低维的东西组合进行抽象信息的联想。 所以我们现在也是按照初始分布数据画出一幅画的过程,所以与普通CNN是相反的。

所以我们现在也是按照初始分布数据画出一幅画的过程,所以与普通CNN是相反的。


5. 使用DCGAN我们可以实现图像合成的功能,如左图我们可以将左边的图像作为生成模型的输入(将图像像素不断按行输入到生成模型中,在生成模型的输出部分再拼合成造假的图片x),进行输入时会先将三个图片(x1,x2,x3)进行加减合成得到合成的图片分布z,之后将z输入到生成模型得到对应的造假图像x,x和y图像通过判别模型得到D(x)和D(y)结果,再根据目函数计算梯度值实现对生成模型和判别模型参数的更新。


还可以进行图像还原、生成卡通图片等。。。



六、生成对抗网络 —— 模型特点

1. 生成对抗网络仍存在的问题有:        

[1]解决不收敛的问题。所有的理论都认为 GAN 应该在纳什均衡上有卓越的表现,但梯度下降只有在凸函数的情况下才能保证实现纳什均衡。当博弈双方都由神经网络表示时,在没有实际达到均衡的情况下,让它们永远保持对自己策略的调整是可能的。        [2]难以训练(目标函数难构建)。GAN模型被定义为极小极大问题,没有损失函数,在训练过程中很难区分是否正在取得进展。GAN的学习过程可能发生崩溃问题,生成器开始退化,会生成同样的样本点,无法继续学习。当生成模型崩溃时,判别模型也会对相似的样本点指向相似的方向,训练无法继续。    

[3]判别器D效果越好,生成器梯度消失越严重,最后难训练。

2.对于GAN模型最难的是对目标函数的构建,构建的式子非常影响最后优化出模型的效果,而且要解决各种梯度消失和梯度爆炸问题。

所以很多改进模型都是在目标函数的构造上动手脚,如著名的wGAN方法,进行了以下几点的改进:              

[1]判别器的最后一层去掉sigmoid函数。              

[2]生成器和判别器的目标函数中不用log函数包装。            

[3]不要急于动量的优化算法,而是使用SGD、RMSprop等方法。            

[4]对判别器加上Lipschitz限制,当输入的样本稍微变化后,判别器给出的分数不能发生太过剧烈的变化。


来源:CSDN,作者:马飞飞,,转载此文目的在于传递更多信息,版权归原作者所有。
原文:https://blog.csdn.net/maqunfi/article/details/82220297
版权声明:本文为博主原创文章,转载请附上博文链接!

最新文章