def _discriminator_step(netD, netG, real_imgs, zsize, criterion, optD, device):
    """Run one discriminator update on ``real_imgs`` and return its loss as a float."""
    b = real_imgs.size(0)  # actual batch size; the last batch may be smaller
    netD.zero_grad()

    # Real images from the loader should be classified as real (target 1).
    output_from_real = netD(real_imgs).reshape(b, -1)
    loss_from_real = criterion(output_from_real, torch.ones(b, 1, device=device))

    # Generated images should be classified as fake (target 0). The generator
    # is not updated here, so no graph through netG is needed.
    z = torch.randn(b, zsize, 1, 1, device=device)
    with torch.no_grad():
        fake_imgs = netG(z)
    output_from_fake = netD(fake_imgs).reshape(b, -1)
    loss_from_fake = criterion(output_from_fake, torch.zeros(b, 1, device=device))

    # The discriminator's total loss is the sum of the real and fake terms.
    loss_netD = loss_from_real + loss_from_fake
    loss_netD.backward()
    optD.step()
    # .item() detaches the value so the caller does not keep autograd history.
    return loss_netD.item()


def _generator_step(netD, netG, b, zsize, criterion, optG, device):
    """Run one generator update on ``b`` fresh noise vectors and return its loss as a float."""
    netG.zero_grad()
    z = torch.randn(b, zsize, 1, 1, device=device)
    fake_imgs = netG(z)
    # The generator is trained so that netD labels its output as real (target 1).
    # This backward also fills netD's grads; they are cleared by netD.zero_grad()
    # at the start of the next discriminator step.
    output_from_fake = netD(fake_imgs).reshape(b, -1)
    loss_netG = criterion(output_from_fake, torch.ones(b, 1, device=device))
    loss_netG.backward()
    optG.step()
    return loss_netG.item()


def train(netD, netG, batch_size, zsize, epochs, trainloader,
          lr=0.0002, betas=(0.5, 0.999), sample_every=10):
    """Train a GAN (discriminator ``netD``, generator ``netG``) with BCE loss and Adam.

    Each iteration first updates the discriminator on a real batch (target 1)
    and a generated batch (target 0). It then updates the generator so that the
    discriminator labels freshly generated images as real (target 1).

    Args:
        netD: Discriminator module. For a batch of ``b`` inputs its output is
            reshaped to ``(b, -1)`` and compared with ``(b, 1)`` targets by
            ``nn.BCELoss``. It must therefore yield one value in [0, 1] per input.
        netG: Generator module, called on noise of shape ``(b, zsize, 1, 1)``.
        batch_size: Number of noise vectors drawn when producing sample images.
            Training batches take their size from ``trainloader``, so a smaller
            final batch is handled.
        zsize: Number of noise channels fed to the generator.
        epochs: Number of passes over ``trainloader``.
        trainloader: Iterable of ``(images, labels)`` pairs; labels are ignored.
        lr: Adam learning rate for both optimizers.
        betas: Adam ``betas`` for both optimizers.
        sample_every: Controls periodic sampling. Every ``sample_every`` epochs,
            the first 8 of ``batch_size`` generated images are passed to
            ``imshow``. A falsy value disables sampling.

    Returns:
        ``(losses_netD, losses_netG)``: lists of per-epoch mean losses (floats).

    Raises:
        ValueError: If ``trainloader`` yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    netD = netD.to(device)
    netG = netG.to(device)

    criterion = nn.BCELoss()
    optD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
    optG = optim.Adam(netG.parameters(), lr=lr, betas=betas)

    losses_netD = []
    losses_netG = []

    for epoch in range(1, epochs + 1):
        running_loss_netD = 0.0
        running_loss_netG = 0.0
        num_batches = 0

        for real_imgs, _ in trainloader:
            real_imgs = real_imgs.to(device)
            running_loss_netD += _discriminator_step(
                netD, netG, real_imgs, zsize, criterion, optD, device)
            running_loss_netG += _generator_step(
                netD, netG, real_imgs.size(0), zsize, criterion, optG, device)
            num_batches += 1

        if num_batches == 0:
            raise ValueError('trainloader yielded no batches')

        running_loss_netD /= num_batches
        running_loss_netG /= num_batches
        print(f'epoch: {epoch}, netD loss: {running_loss_netD}, netG loss: {running_loss_netG}')
        losses_netD.append(running_loss_netD)
        losses_netG.append(running_loss_netG)

        if sample_every and epoch % sample_every == 0:
            # Visualization only: no_grad avoids building a graph, so imshow
            # receives a tensor that does not require grad.
            with torch.no_grad():
                z = torch.randn(batch_size, zsize, 1, 1, device=device)
                generated_imgs = netG(z).cpu()
            # NOTE(review): assumes netG emits 28x28 single-channel images
            # (e.g. MNIST) and batch_size >= 8 — TODO confirm.
            imshow(generated_imgs[0:8].reshape(8, 1, 28, 28))

    return losses_netD, losses_netG