手書き数字認識;(バッチ版)交差エントロピー誤差の実装でのつまづき

こんにちは。まだプログラミング続いてます。

例の本「ゼロから作るDeep Learning」のp94です。手書き数字認識の、交差エントロピー誤差をバッチで実装するパートがあります。

ラベル表現の場合の交差エントロピー誤差を実装するのがよくわからなかったので、自分へのノートとして書きます。見苦しいですが、お許しください。

ここに交差エントロピー誤差の公式を入れる

交差エントロピー誤差

one-hot表現

one-hot表現だと簡単に計算できますね。tに入っているのは0または1で、正解(0~9の数字)のインデックスに1が入っており、他は全部0です。なので、y(予想の出力、0~9の各数字に対してパーセプトロンを処理した結果が入っています)とtを掛け算すれば、正解の数字に対応する出力が計算されたことになります。

例を挙げます。ある画像に対して、正解は6だとします。tは[0,0,0,0,0,0,1,0,0,0]となりますね。6個目の要素(インデックス6)に1が、他には0が入っています。交差エントロピー誤差(以下、単に誤差と呼びます)を計算するのに必要なのは、yのうちインデックス6の要素だけです。yは例えば[0.3,0.05,0.0,0.0,0.1,0.3,0.7,0.1,0.1,0.2]とかになってるわけです。つまり必要なのは6番目の要素である0.7ですね(0.7×1=0.7)。

ラベル表現

ここで私は詰まってしまいました。考え方のフローをメモとして残しときます。

実はp91の式(4.3)がそのまま使えるのはone-hot表現でtをインストールしたときのみです。ラベル表現には違う計算方法を使うしかありません。で、その説明が例の本では正直わかりにくい。すごく悩みます。one-hot表現ってなんぞや!って人は、p76に戻り、mnistデータをダウンロードするコードの部分を見てください。one_hot_label=Falseというコードが見えますね。これがラベル表現でダウンロードするコードです。詳細なる説明はp92の下にあります。

ラベル表現は厄介です。なぜならtは正解の数字そのもの、3とか9だったりします。なので、数字によって出力yが誤差に与える影響が変わってしまいます。例えばtが0だったら、いくら予想の精度が悪くてもゼロになっちゃいますね。

本題です。交差エントロピー誤差は「正解に対応する出力の自然対数を合計してデータ数で割ったもの」で定義されます(間違ってたらすみません)。

上の例なら、tはt=6なのです。なので必要な出力はyの6番目ですから、y[6]として取得できますね。この方法をデータ個数が増えた形に応用してあげれば計算できます。

少し拡張して、以下のようになっているとします。〇個目というのは画像データが何個目かを表します。

  1個目 2個目 3個目 4個目 5個目
t(正解) 2 7 0 9 4

このとき、yはどうなっているでしょうか?

y=[[A],[B],・・[E]] 

A = (1個目のデータに対する出力10個)

B = (2個目のデータに対する出力10個)・・・

となっています(いやここが一番合ってるか自分でもモヤモヤしてますが)。Aを表すには、y[0]と書き、yのうち0番目の要素として取り出せばよさそうです。

さらに、Aの2番目の要素、Bの7番目の要素、、、Eの4番目の要素のそれぞれの自然対数を合計し、4(データの個数)で割れば誤差が計算されます。Aの2番目の要素はA[2]とすれば取り出せそう。すなわちyで表現すると、y[0][2]です。同じくy[1][7], y[2][0], y[3][9], y[4][4]ですね。

本書ではこれがy[0,2], y[1,7], y[2,0], y[3,9], y[4,4]と書ける、と書いてあります。そんな記述これまであったっけ?とつまづいてしまいました。

p15の1.5.6 要素へのアクセスでも、X [0][1]で(0, 1)の要素を取り出せると書いてあるのみで、X[0, 1]と書けるとは言っていません。

でも試してみたらX[0,1]で同じ結果を得ました。こう書けるみたいですね!

なんとなーく読んでるだけやと「書けるやろ」で流してる箇所ですが、ちょっとつまづいてしまいました。

さいごに

教訓として、「腹落ちしないならコード書いてみろ」と学べましたね。「書けるやろ」と思ったら本当に試してみて覚えていこうと思いました。

まぁこんな感じで初歩的なつまづきを共有していきます。

最後までお読みいただきありがとうございました。

シェアする

  • このエントリーをはてなブックマークに追加

フォローする