PyTorch に入門する
論文とかに載ってるニューラルネットワークのアーキテクチャを再現して自由自在にカスタマイズできるようになりたい。最近はニューラルネットワーク触るとなると PyTorch らしい。ってことで PyTorch のチュートリアルやってみる。チュートリアルのページは以下。
ニューラルネットワークの学習を理解するために重要な確率的勾配降下法と誤差逆伝播法については以前の記事で解説を書いた。
これらの概念はしっかり理解しているつもりなので、PyTorch というフレームワーク上でどう表現されているのかを見ていきたい。なお、本記事のソースコードはすべてこのチュートリアルからの引用である。
0章 Quickstart
まずは全体像を学ぶ章から見てゆく。
PyTorch でデータ扱うには以下の2つが重要らしい。
torch.utils.data.Dataset
- サンプルとラベルからなるデータ
torch.utils.data.DataLoader
Dataset
をイテラブルにしたデータ
以下のように Dataset
を DataLoader
に変換できる。このときバッチサイズを指定することで、1回のイテレーションでそのバッチサイズだけのデータが取得できる。
batch_size = 64 # Create data loaders. train_dataloader = DataLoader(training_data, batch_size) test_dataloader = DataLoader(test_data, batch_size) for X, y in test_dataloader: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break
モデルを定義するときは以下のように nn.Module
というクラスを継承するらしい。そして、__init__
関数内にニューラルネットワークのレイヤーを定義して、forward
関数内にデータがどう受け渡されてゆくのか定義する。
class NeuralNetwork(nn.Module): def __init__(self): super(NeuralNetwork, self).__init__() self.flatten = nn.Flatten() self.liner_relu_stack = nn.Sequential( nn.Linear(28*28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 18), ) def forward(self, x): x = self.flatten(x) logits = self.liner_relu_stack(x) return logits
そして、学習時に用いる損失関数やオプティマイザを設定する。今回の例ではクロスエントロピーと確率的勾配降下法を指定している。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
あとは学習データの各バッチを読み込みながらパラメータを調整してゆく。バックプロパゲーションは以下のような記述をするらしい。
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
学習の流れの雰囲気は掴めた。要するにデータを用意してモデルを定義して学習を実行すればオッケー。
疑問点を解消してゆく
次は "1章 Tensors" なので次はテンソルについて学んでいこう。...と言いたいところだが上記の確認だけで雰囲気が分かってしまったので、ここからは自分が抱いた疑問点を解消しながら理解を深めていきたい。
学習モデルのインスタンス変数はなぜ必要か?
独自のモデルを定義するときは nn.Module
を継承するのであった。このとき、__init__
関数内でインスタンス変数(self
でアクセスする変数)を定義している。
def __init__(self): super(NeuralNetwork, self).__init__() self.flatten = nn.Flatten() self.liner_relu_stack = nn.Sequential(...) def forward(self, x): x = self.flatten(x) logits = self.liner_relu_stack(x) return logits
しかし、これらのインスタンス変数は forward
関数でしか使われない。だったらなぜインスタンス変数として定義するのか。以下のように forward
関数内に記述すればよいのでは?
def __init__(self): super(NeuralNetwork, self).__init__() def forward(self, x): x = nn.Flatten(x) logits = nn.Sequential(...)(x) return logits
しかし実際にこのコードを実行してみると以下のようなエラーメッセージが表示される。
ValueError: optimizer got an empty parameter list
オプティマイザに渡すモデルのパラメータが空であると怒っている。学習によって調整するパラメータはモデルのインスタンス変数として定義されていなければならないのか。理解した。考えてみれば、モデルのオブジェクトが調整されるべきパラメータの情報を持っているっていうのは直感的で分かりやすい。
Flatten 関数とは何か?
モデルを定義するときに nn.Flatten()
という関数が使われている。おそらくテンソルを1次元ベクトルに直す関数だと予想されるが念のため調べる。
どうやら引数に与えられた範囲の次元を1つのテンソルにするということらしい。つまり常にベクトル(1次元テンソル)を返すとは限らない。ちょっと実験してみよう。
まずはデフォルトの start_dim = 1
の場合に2次元のテンソルを与える。
t = torch.Tensor([[1,2],[3,4],[5,6]]) nn.Flatten()(t) # => tensor([[1., 2.], [3., 4.], [5., 6.]])
2次元以降はすでに1つのテンソルなので、もちろん何も変化がない。
次は同じ条件で3次元のテンソルを与えてみる。
t = torch.Tensor([[[1,1],[2,2]],[[3,3],[4,4]]]) nn.Flatten()(t) # => tensor([[1., 1., 2., 2.], [3., 3., 4., 4.]])
今度は1次元の構造はそのままで2次元以降が flatten されていることを確認できた。
なぜ nn.flatten
の引数の start_dim
のデフォルト値が 1
であるのか。それは、1つの学習データは1つのテンソルに対応するため、1つのバッチは複数のテンソルに対応するからである。たとえば、バッチサイズが64である場合はバッチは以下の次元をもつ。
for X, y in train_dataloader: print(X.shape) # => torch.Size([64, 1, 28, 28])
flatten
関数がこのように挙動するおかげで、forward
関数を実行するときにバッチごと入力できて、その出力を1つのテンソルとして返せる。以下はその例である。
for X, y in train_dataloader: print(X.shape) print(nn.Flatten()(X).shape) print(model.forward(X.to(device)).shape) # => torch.Size([64, 1, 28, 28]) torch.Size([64, 784]) torch.Size([64, 18])
flatten の処理だけでなく forward
関数の動きにもイメージが湧いた。
train 関数ってなに?
モデルの学習を行うところで以下のように model.train()
という謎の呼び出しがある。パラメータの更新は for ループの中でやっているのに、なぜループの外にこんな呼び出しがあるんだろう。
def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) model.train() for batch, (X, y) in enumerate(dataloader): ...
公式ドキュメント を見てみるとnn.Module
の中に train(mode=True)
という関数についての説明がある。どうやらモデルには training mode と evaluation mode というのがあって、モジュールによってはこの値が影響するらしい。今から学習しますよ、って意味での model.train()
というわけか。ちなみに test
関数の中には model.eval()
という呼び出しがある。これは evaluation mode であることを明示している。今回の例ではこれらをコメントアウトしても挙動に変化がなかったので今は気にしなくてよさそう。おそらく発展的な機能。
パラメータの更新はどういう流れ?
おそらく誰にとっても一番よく分からないところ。以下のバックプロパゲーションの処理を記述したコードを再掲する。3行あって3行ともよく分からない。
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
さすがに意味不明なのでチュートリアルの Optimization の章 を見てみる。以下のような記述がある。
Call
optimizer.zero_grad()
to reset the gradients of model parameters. Gradients by default add up; to prevent double-counting, we explicitly zero them at each iteration.
optimizer.zero_grad()
の呼び出しは勾配をリセットする役割があるらしい。勾配は累積してゆくのがデフォルトの仕様なので、パラメータを更新するたびにゼロにする必要があるってことらしい。なぜ累積されるのがデフォルトなのかよく分からんがこういうものなんだろう。
Backpropagate the prediction loss with a call to
loss.backwards()
. PyTorch deposits the gradients of the loss w.r.t. each parameter.
loss.backward()
の呼び出しによって誤差を逆伝播させる。これはニューラルネットワークの各パラメータの勾配を計算しているのだと思われる。
Once we have our gradients, we call
optimizer.step()
to adjust the parameters by the gradients collected in the backward pass.
optimizer.step()
の呼び出しによってパラメータの更新を行う。今回は SGD(確率的勾配降下法)を用いているので勾配に学習率をかけてパラメータの更新を行っているはず。
この処理の流れは直感的である。しかし次の疑問が湧いてくる。
optimizer と loss はどう関係している?
ここで気持ち悪いのが optimizer
と loss
が独立した変数ではないという点である。とくに loss.backward()
の結果が optimizer.step()
で使われるという部分のデータフローがどうなっているのか意味不明。変数の依存関係を考えるためこれらの変数の初期化部分などを見てみる。関連しそうな部分を以下に抽出した。
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
...
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer
はモデルのパラメータの更新を管理してそう。現在のパラメータの値と更新式を保持していると考えるのが妥当だろう。そして loss
は誤差だけ持っていそう。やはりこの2つの変数間でどうデータがやりとりされているのか分からない。そもそも loss.backward()
でモデルのパラメータの情報が必要そうだが、モデルの情報を知っているようには見えない。
もっと深堀りして調べてみよう。まさに同じ疑問が Stack Overflow に投稿されていた。気になるよなやっぱり。
この回答によると話はなかなか複雑らしい。ちょっと抜き出すと this Tensor object has a grad_fn prop in which there stores tensors it is derived from.
とか言っている。つまり、Tensor オブジェクトは生みの親を覚えているらしい。それなら納得できないことはない。たとえば、 loss
はただのテンソルではなく、その誤差を生み出した原因となるテンソル(パラメータの重み)を知っているので勾配の伝播ができるみたいな話だと思われる。これが PyTorch の自動微分の仕組みに関係しているらしく、以下のチュートリアルが参考として与えられている。
このチュートリアルは次の機会にやってみよう。現時点での正確な理解は断念。
所感
今回は初めて PyTorch を触ってみた。けっこう直感的に使える気がする。PyTorch のテンソルがお化けオブジェクトというイメージが湧いた。油断してはいけない。たぶんだけどテンソル経由で様々な情報がやりとりされている。しかしなんでこんな設計にしたんだろう。Python 好きな人にとってはデータフローが隠蔽されることは別にイヤじゃないのかな。分からん。Java のコード読むほうが10倍くらい楽だ。
追記:以下の記事で PyTorch の自動微分の仕組みを勉強してみた。この記事の疑問がそれなりに解消できた。