PyTorchで畳み込みオートエンコーダーを作ってみよう作って試そう! ディープラーニング工作室(2/2 ページ)

» 2020年08月07日 05時00分 公開
[かわさきしんじDeep Insider編集部]
前のページへ 1|2       

学習

 では学習を行いましょう。といっても、そのコードはいつも通りです。

net = AutoEncoder2(enc, dec)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
EPOCHS = 100

output_and_label, losses = train(net, criterion, optimizer, EPOCHS, trainloader)

学習を実行

 学習が終わるまでには1時間弱かかりました(CPUのみを使用)。エポック数は100ですが、前回の100エポックの学習であまりデキがよくなかったヤツでも1時間半かかったことを考えると速度面ではかなりよいものといえるでしょう。

 でも、問題はどんな画像が復元されるかです。

img, org = output_and_label[-1]
imshow(org)
imshow(img)

最後に学習した結果から得た復元画像と元の画像を表示

 どんな結果になったでしょう。

実行結果 実行結果

 うーむ。全般にもやが掛かったようになっているのが気になります。復元度合いも正直これではいいとも悪いともいえません。ちなみにテストローダーから読み込んだ元画像、100エポックの学習を終えた時点のモデルに元画像を入力した結果、同じモデルでもう200エポックを学習させたものによる結果を以下に示します。

画像比較 画像比較

 100エポックのものも、300エポックのものもさほど変わらない結果となってしまいました。記事には掲載していませんが、追加で200エポックの学習をしている間に、損失が下がらなくなっていたので、実は予想できる事態だったのですが、ハッキリと結果が出るとがっかりするところではありますね。

 では、どうすれば改善できるでしょうか。少し考えてみます。

チャネル数とカーネルサイズを大きくしてみる

 CNNには畳み込みやプーリングによって、画像の特徴を抽出するという機能があります。ということは、チャネル数(カーネル数)を増やしたり、カーネルサイズを大きくしたりすることで、表現力が向上するのではないでしょうか。最後にこれを試してみることにします。

前のページへ 1|2       

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。