mao9o9の技術メモ

個人的に興味のある分野についての技術メモです。

WideResNetを試してみた

はじめに

ディープラーニングについて学習する上で、WideResNetを実装して試行毎の損失、正解率をグラフ化してみた。

github.com

WideResNetについて

従来ではResNetの層の深さに対するアプローチが行われていたが、層の幅を広くすることで、浅い層でも高い精度を実現した手法[1]。モデルの作成において[2][3][4]を参考にした。

f:id:mao9o9:20220203220343p:plain:w400
図1 WideResNetのネットワークの構成 [1]より引用

実装環境

バージョン情報

  • torch 1.10.0+cu111
  • torchvision 0.11.1+cu111
  • numpy 1.19.5
  • sklearn 1.0.2

訓練データはCIFAR-100を利用した。

実験条件

  • 入力サイズ:[3, 32, 32]
  • バッチサイズ:128
  • 学習率:0.1 (60 epochからは0.2倍)
  • Network width:2, 10
  • weight decay:5e-4 (Network width: 10の時 1e-4)
  • optimizer: SGD (momentum: 0.9, nesterov)
  • transform: RandAugmentation (magnitude: 10)
  • Drop out: 0.3

Network widthが10の時、初期学習率が0.1だと損失が発散してしまったため、0.01とした。ネットワークの深さは28層とし、100 epochで確かめた。

結果

f:id:mao9o9:20220203220426p:plain:w500
図2 wrn-28-2の訓練と精度

f:id:mao9o9:20220203220452p:plain:w500
図3 wrn-28-10の訓練と精度

精度は式(1)に示す正解率である。

 \displaystyle
\rm{Accuracy} = \frac{\rm{TP + NF}}{\rm{TP + NP + TF + NF}} \ \ \ \ (1)

wrn-28-2では実行時間は1時間ほどで、60 epochで学習率を0.2倍した際に精度は上昇した。50層のResNetと比較すると、50 epochまでであるため、1/4の時間でほぼ同様の精度となった。

wrn-28-10では3時間ほどであったが、精度はwrn-28-2よりも高くなった。また、wrn-28-10では30 epochからTrainよりもValの損失が大きくなり、過学習の傾向が見られた。

結論

ネットワークが広くなるほど(チャンネル数が大きくなるほど)実行時間は増加するが、精度は高くなり、ネットワークの深さだけでなく広さに対しても実行時間と精度はトレードオフの関係にあることが確認できた。

学習が進むほど損失は小さくなるため、損失に応じた学習率の設定が精度向上に有効? → 次回

参考

[1] S. Zagoruyko, et.al., "Wide Residual Networks", 2016.

[2] WideResNet作成時に引っかかった点

[3] 論文の勉強3 WideResNet

[4] Pytorch - Wide ResNet の仕組みと実装について解説