# z - 随机噪声
# X - 输入数据
# c - 输入的label
# ===== 训练判别器D =====
# 真数据输入到D中
D_real = D(X, c)
# 真数据D的判断结果应尽可能接近1
D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
# 生成随机噪声
z = torch.rand((batch_size, self.z_dim))
# G生成的伪数据,这一步的c可以用已知的,也可以重新随机生成一些label,但总之这些c所生成的数据都是伪的
G_sample = G(z, c)
# 伪数据输入到D中
D_fake = D(G_sample , c)
# 伪数据D的判断结果应尽可能接近0
D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
# D的loss定义为上面两部分之和,即真数据要尽可能接近1,伪数据要尽可能接近0
D_loss = D_loss_real + D_loss_fake
# 更新D的参数
D_loss.backward()
D_solver.step()
# 在训练G之前把梯度清零,也可以不这么做
reset_grad()
# ===== 训练生成器G =====
# 这里可以选择,有的实现是直接用上面的z
z = Variable(torch.randn(mb_size, Z_dim))
# 这里可以选择用已知的c,或者重新采样
c = 重新随机一些label
# 用G生成伪数据
G_sample = G(z, c)
# 伪数据输入到D中
D_fake = D(G_sample, c)
# 此时计算的是G的Loss,伪数据D的判断结果应尽可能接近1,因为G要试图骗过D
G_loss = nn.binary_cross_entropy(D_fake, ones_label)
# 更新G的参数
G_loss.backward()
G_solver.step()
原文:https://www.cnblogs.com/zhsuiy/p/9769403.html