mao9o9の技術メモ

個人的に興味のある分野についての技術メモです。

SGDについて

はじめに

SGD(Stochastic gradient decession)について、設定した目的関数に対してどのようにパラメータを更新しているのかシミュレーションを行い調べた。

実装環境

ついでにPytorchでのSGDの挙動を調べるためPytorch版とnumpy版を作成して、ステップの更新を調べた。

github.com

SGD

SGDは目的関数 f(w)を最小化する手法。 f(w)導関数を求め、入力の勾配に応じて更新する度合を設定する。

モーメントを使用しない時、


w_t = w_{t-1} - \gamma \nabla_w f(w_{t-1})

momentumを使用するとき、


v_t = \mu v_{t-1} + \nabla_wf(w_{t-1}) \\
w_t = w_{t-1} - \gamma v_t

momentumを使用かつnestrovを適用する時、


v_t = \mu v_{t-1} + \nabla_wf(w_{t-1}) \\
w_t = w_{t-1} - \gamma (\nabla_wf(w_{t-1}) + \mu v_t)


結果

  • パラメータ設定

SGD:momentum = 0.9, nesterov = True

初期点: (x, y) = (20, 0)

学習率: lr = -0.0002 * iteration + 0.01

  • 目的関数: f(x,y) = x^{2} + y^{2}
f:id:mao9o9:20220219100122p:plain
図1  f(x, y) = x^2 + y^2

図1に示す目的関数は (x, y) = (0, 0)で極小値を持つ。

f:id:mao9o9:20220219100345p:plain
図2 パラメータの軌跡

目的関数の xy平面上におけるSGDのパラメータの軌跡を図2に示す。勾配が0となる点に向けて移動するため、 xの負の方向にパラメータが更新されている。

f:id:mao9o9:20220219100431p:plainf:id:mao9o9:20220219100444p:plain
図3 iterationと勾配および損失の関係

次に、iterationと勾配、および損失の関係を図3に示す。iterationが15の時、勾配は0となり損失は最小値をとった。その後の勾配は負であるため、目的関数の最小点を超えて坂道を駆け上っているような状態であると思われる。

  • 目的関数: f(x, y) = x^{2} - y^{2}
f:id:mao9o9:20220219100556p:plain
図4  f(x, y) = x^2 - y^2

図4に示すように (x, y) = (0, 0)で鞍点を持つ関数である。

f:id:mao9o9:20220219100649p:plain
図5 パラメータの軌跡

目的関数の xy平面上におけるSGDのパラメータの軌跡を図5に示す。勾配が0となる点に向けて移動するため、 xの負の方向にパラメータが更新されている。また、今回のような初期値の取り方では鞍点から抜け出せず、モデルに最適な最小値をとれていない。

f:id:mao9o9:20220219100731p:plainf:id:mao9o9:20220219100740p:plain
図6 iterationと勾配および損失の関係

次に、iterationと勾配、および損失の関係を図6に示す。鞍点にはまっているにも関わらず損失および勾配は図3と同様な傾向を示している。

従って、損失が0に近づいたとしても目的関数の最小値であるとは限らない。つまり、iterationにおける損失をみただけでは、目的関数が最小となる値を求められているかわからない。

  • PytorchとnumpyのSGDの比較

パラメータの更新式はPytorchのReferenceを参考に構成したため、異なる点は勾配の計算方法であると思われる。Pytorchでは中心差分で計算してるのかな?うまく再現できなかったため不明である。

参考

Pytorch – 確率的勾配降下法 (SGD)、Momentum について解説

6.1.2:SGD【ゼロつく1のノート(実装)】