PyTorchで全結合型のGANを作ってみよう作って試そう! ディープラーニング工作室(2/2 ページ)

» 2020年10月09日 05時00分 公開
[かわさきしんじDeep Insider編集部]
前のページへ 1|2       

学習

 2つのニューラルネットワーククラスが定義できたので、次は実際に学習を行うコードの番です。ここではGPUを使って計算を高速に行うことにしました(Google ColabでGPUを有効にする方法については「PyTorchからGPUを使って畳み込みオートエンコーダーの学習を高速化してみよう」を参照してください)。

 GPUを使うには、計算に必要なニューラルネットワークモデル、計算対象などを全てGPUに転送しておかなければなりませんでした。というわけで、学習に使用するニューラルネットワークは次のようにして、GPUに転送しています。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'using: {device}')

netD = Discriminator(in_features).to(device)
netG = Generator(zsize, in_features).to(device)

ニューラルネットワークモデルをGPUに転送

 また、損失関数としてはここではtorch.nn.BCELossを、最適化アルゴリズムにはtorch.optim.Adamクラスを使うことにしました。これについては、PyTorchのドキュメント「DCGAN TUTORIAL」などを参考にしてください(Adamはこれまでに使っていたSGDによる最適化をより巧妙に行えるようにしたものだと考えておいてください)。

 ここで覚えておきたいのは、次のことです。

  • 識別器では、訓練データが入力されたら出力が1になり、生成器によるデータが入力されたら出力が0となることが理想的
  • 生成器では、それが生成したデータを識別器に入力したときに、その出力が1となることが理想的

 これを実現するような損失関数が必要なのですが、それをうまく取り扱ってくれるのがBCELossクラスです。

 識別器では、訓練データを入力したときと生成器からのデータを入力したときという2つの条件があることに注意してください。学習時には識別器に訓練データを入力してその結果と正解ラベル1から損失を計算し、次に生成器からのデータを入力してその結果と正解ラベル0から損失を計算し、それら2つをまとめたものが実際の損失となります。

 一方、生成器の学習に関しては、生成器から得たデータを識別器に入力して、その出力と1から損失を計算することになります。そのため、学習を行うコードは少しばかり長ったらしくなってしまいます。なるべく数式を使わないように説明することを目指しているので少し分かりにくかったかもしれませんが、だいたいこんなもんだと思ってくれれば十分です。

 というわけで、以下は損失関数と最適化アルゴリズムを指定するコードです。間にあるのは、固定の正解ラベルです。one_labelsは正解ラベル1をデータローダーから読み込む数(バッチサイズ)だけ並べたもので、zero_labelsはその正解ラベル0版です。

criterion = torch.nn.BCELoss().to(device)

one_labels = torch.ones(batch_size).to(device)
zero_labels = torch.zeros(batch_size).to(device)

optimizer_netD = optim.Adam(netD.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_netG = optim.Adam(netG.parameters(), lr=0.0002, betas=[0.5, 0.999])

損失関数と最適化アルゴリズムの選択

 そして、今述べたような損失計算を含んだ学習コードが以下です。

losses_netD = []
losses_netG = []
EPOCHS = 50

for epoch in range(1, EPOCHS+1):
    running_loss_netD = 0.0
    running_loss_netG = 0.0
    for count, (real_imgs, _) in enumerate(trainloader, 1):
        netD.zero_grad()

        # 識別器の学習
        real_imgs = real_imgs.to(device)

        # データローダーからデータを読み込み、識別器に入力し、損失を計算
        output_real_imgs = netD(real_imgs.reshape(batch_size, -1))
        output_real_imgs = output_real_imgs.reshape(batch_size)
        loss_real_imgs = criterion(output_real_imgs, one_labels)
        loss_real_imgs.backward()

        # 生成器から得たデータを、識別器に入力し、損失を計算
        z = torch.randn(batch_size, zsize).to(device)
        fake_imgs = netG(z)
        output_fake_imgs = netD(fake_imgs.detach()).reshape(batch_size)
        loss_fake_imgs = criterion(output_fake_imgs, zero_labels)
        loss_fake_imgs.backward()

        # それらをまとめたものが最終的な損失
        loss_netD = loss_real_imgs + loss_fake_imgs
        optimizer_netD.step()
        running_loss_netD += loss_netD

        # 生成器の学習
        netG.zero_grad()
        z = torch.randn(batch_size, zsize).to(device)
        fake_imgs = netG(z)
        output_fake_imgs = netD(fake_imgs).reshape(batch_size)
        loss_netG = criterion(output_fake_imgs, one_labels)
        loss_netG.backward()
        optimizer_netG.step()
        running_loss_netG += loss_netG

    running_loss_netD /= count
    running_loss_netG /= count
    print(f'epoch: {epoch}, netD loss: {running_loss_netD}, netG loss: {running_loss_netG}')
    losses_netD.append(running_loss_netD.cpu())
    losses_netG.append(running_loss_netG.cpu())
    if epoch % 10 == 0:
        z = torch.randn(batch_size, zsize).to(device)
        generated_imgs = netG(z).cpu()
        imshow(generated_imgs[0:8].reshape(8, 1, 28, 28))

学習を行うコード

 このコードでは、エポック数を50として、10エポックごとにその時点での生成器を使って画像を生成して、表示するようにしています。実際に実行してみると次のようになります。生成された画像は最後の3つだけを掲載します。

実行結果(一部) 実行結果(一部)

 どうでしょう。うん、ダメですね。MNIST風といえばMNIST風ですが、これで識別器をだますなんて無理だと思います。もう少しまともな画像にならないものかと考えましたが、一番簡単なのは全結合層をもう一つ増やしてみることな気がします。というわけで、DiscriminatorクラスとGeneratorクラスを次のようにしてみましょう。

前のページへ 1|2       

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。