d0tfi1e’s blog

趣味と日記

LMTの理論をわかりやすく

はじめにリンクを載せておきます。

要は、うまくネットワークを学習させることで、ある閾値未満の摂動に対してロバストにできることを理論的に示した、ということです。

この手法はLipschitz Margin Trainingといわれ、略してLMTです。

以下、説明に際してモデルはClassifierを考えます。つまり、入力ベクトルを、いくつかのクラスに分類するような問題を考えます。

どれだけ出力がずれても正しく分類されるか

Classifierでは、それぞれのクラスのconfidenceを求め、それがもっとも高くなるようなクラスに入力を分類します。 入力を xとし、真のラベルに対応するconfidenceを C_t(x)真のラベル以外のラベルでもっともconfidenceが高いものに対応するconfidenceを C_{t'}(x)と書くことにすると、正しく分類される条件は

 M(x) = C_t(x) - C_{t'}(x) > 0

です。 Mはmarginの頭文字です。ある程度入力を摂動させたとしてもネットワークが正しく分類するというのは、 Mが摂動に対して正の符号を保ったままように変化することです。このためには、ある程度 Mがデフォルトで大きな値をとっておく必要がありますよね。

LMTで一番大事な式

ネットワークのリプシッツ定数 Lを次の式を満たすような実数とします。以下、ベクトルのノルムはL2 norm, 行列ノルムはL2 normから誘導されるノルムであるspectral normを考えます。

 |F(x+\varepsilon) - F(x)| \le L|\varepsilon|

ここで Fはネットワークの関数、 xは入力です。このとき、次の式が成り立ちます。

 M(x) > \sqrt{2} L |\varepsilon| \Rightarrow M(x + \varepsilon) > 0

証明は元論文にあるのでここでは省略します。

つまり、入力 xに対して M(x)が常に、 \sqrt{2}cL cは定数)より大きくなるように学習させることができれば、 |\varepsilon| \le cを満たすような摂動に関しては、結果を変えません。

Lipschitz Margin Training

LMTでは、classifierが正しく分類したときに限り、 C_t(x)の値を C_t(x) - \sqrt{2}cLに補正して誤差を計算します。 これで学習がうまくいけば、このネットワークは正しく分類するときに、2番目の候補に対して \sqrt{2}cLより大きなconfidenceの差をつけて 判別するわけですから、上述したことから摂動に強くなることがわかります。

LMTの気持ち

と、ここまででLMTの説明としては十分なわけですが、最後にLMTがどんな場合でもうまくいくわけではない、ということを説明しておきます。

まず、上の説明でも出てきた定数 cの意味をよく考えてみましょう。実は cにはinvariant radii(不変半径)という名前がついており、名前の通り、この範囲内の摂動なら結果が変わらないように要請するものです。 c = \inftyの極限を考えると明らかに失敗する(すべての入力が同じクラスに分類される)ように、分類に対して適切な cの値を見積もることが重要です。 c=\inftyではネットワークが定値関数になってしまうように、 cはネットワークが表現する関数のなめらかさに関する制約だとみなすこともできます。

また、それとは別に、ネットワークが表現できる関数のクラスというものがあります。20層のネットワークは4層のネットワークに比べ、ずっと複雑なモデルを扱うことができます。複雑なモデルだと、途中のパラメータをいくらでも調整できるため、かなり大きな cを指定しても、うまくそのなめらかさの要請を満たしつつ、適切な関数を学習してくれます。しかし、Shallow Networkの場合、 cの要請を満たすようななめらかな関数がそもそもそのネットワークで表現不可能、ということがしばしばあります。こういった場合、 cの値がかなり小さく制限されてしまいます。

また、Adversarial Exampleの生成手法であるFGSM: Fast Gradient Sign Methodに対して、LMTはそれほど強くないということが知られています。 FGSMは、入力の各要素を摂動させたときの誤差の変化をもとに、誤差を増大する方向を調べ、その方向に入力を摂動させるというアルゴリズムです。 LMTで学習されたモデルの場合だと、ネットワークの関数があまりになめらかなため、FGSMの攻撃に弱いです。関数がなめらかだと数値微分の精度が上がってしまうので、誤差が増大する方向を的確に突き止められてしまいます。FGSMに強いネットワークというのは、LMTとは真逆で、ネットワークの関数が極めてギザギザしているような関数になっているネットワークです(こうすると数値微分の値がほとんどランダムになります)。

とはいってもぼくの実験では、4層のShallow Networkでも、入力の各要素で5%程度のスケールまでの摂動なら通常の学習をした場合に比べて、Advesarial Exampleでの攻撃成功率を半分程度に抑えることができました。