GAN中生成器与判别器的对抗训练
GAN由生成器和判别器组成,通过对抗训练相互博弈。请解释两个网络的结构设计、损失函数以及纳什均衡的收敛过程。
回答
我还是少年
GAN(Generative Adversarial Network)由Goodfellow等人于2014年提出,开创了生成对抗学习的范式。
生成器(Generator, G):
- 输入:随机噪声向量z(通常从高斯分布或均匀分布采样)。
- 输出:生成样本G(z),希望逼近真实数据分布p_data。
- 目标:最大化判别器将其分类为真实数据的概率。
判别器(Discriminator, D):
- 输入:真实样本x或生成样本G(z)。
- 输出:样本来自真实分布的概率D(x)。
- 目标:区分真实和生成样本。
对抗训练与损失函数: min_G max_D V(D,G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1-D(G(z)))]
交替训练过程:
- 更新判别器(固定G):最大化log D(x) + log(1-D(G(z)))。
- 更新生成器(固定D):最小化log(1-D(G(z))),实践中常改为最大化log D(G(z))以缓解梯度消失。
纳什均衡:
- 理想情况下,当G生成的数据分布p_g完全等于真实分布p_data时,达到纳什均衡。
- 此时判别器最优解D*(x)=p_data(x)/(p_data(x)+p_g(x))=1/2,即无法区分真假。
- 训练中交替优化,理论上通过凸博弈分析可在无限容量下收敛到均衡。
训练难点:
- 模式崩塌(Mode Collapse):生成器只学习到部分模式。
- 不收敛:G和D的损失震荡,难以达到均衡。
- 梯度消失:判别器太强时,生成器梯度趋近于0。
改进方案:WGAN(Wasserstein距离代替JS散度)、WGAN-GP(梯度惩罚)、CGAN(条件生成)、LSGAN(最小二乘损失)等。