PyTorch の Transformer に入門する

PyTorch の公式チュートリアルに Transformer を使ったテキスト処理の章がある。これが何やってるのか分からん。もちろん Transformer の仕組みは分からん。それは想定内。問題は Transformer の外側で何をやってるのか分からんってこと。今回はそれを解説しようと思う。

問題のチュートリアルは以下である。例によって、本記事に載っているソースコードはこのチュートリアルからの引用である。

pytorch.org

以降では、このチュートリアルソースコードを使ったまま、このチュートリアルが何をやりたいのかを解説する。説明の順序はオリジナルからかなり変えている。そして本記事では Transformer の仕組みについては解説しない。

データセットの読み込み

まずは学習に用いるデータセットを読み込む部分から説明する。その後 Transformer に入力するためのテンソルへ変換する部分を理解しよう。

以下のように data_process 関数が定義されている。これはデータセットイテレータをベクトルに直す関数である。

def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:
    """Converts raw text into a flat Tensor."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

この関数の data = の右辺の操作が何をやっているのか分かりづらいので、以下のように for 文で書き直してイテレータの要素がどんなベクトルに変換されるのか見てみた。

train_iter, _, _ = WikiText2()
for str in train_iter:
    print("str=", str)
    tokens = tokenizer(str)
    print("tokens=", tokens)
    vocabs = vocab(tokens)
    print("vocabs=", vocabs)
    print("tensor=", torch.tensor(vocabs, dtype=torch.long))
# -> (以下は実行結果)
...
str=  = Valkyria Chronicles III = 
tokens= ['=', 'valkyria', 'chronicles', 'iii', '=']
vocabs= [9, 3849, 3869, 881, 9]
tensor= tensor([   9, 3849, 3869,  881,    9])
...

イテレータから取り出した1文を単語ごと(tokens)に分割し、単語ごとに対応する整数値(vocabs)に変換する。そして、それを要素にもつテンソルが出来上がる。上記の例では5単語なので長さ5の1次元テンソル(ベクトル)に変換されている。つまり1文をベクトルにエンコードする処理である。

data_process 関数内では、この処理が イテレータに含まれる各文に対して行われ、得られたベクトル同士が結合される。プログラミングが得意な人に向けて言うと、data_process 関数は入力されたイテレータに対して flatMap を行うイメージである。

以上の話を踏まえると、以下では各イテレータが長いベクトルに変換されることが分かる。

train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

バッチへの分割

次に、Transformer の入力として適当な形にするため、得られた長いベクトルを適当に分割してバッチ化する。この処理を行うのが batchify 関数である。

この処理の様子は公式チュートリアルに載っている以下の図がとても分かりやすい。以下の例では左辺が長いベクトルで、右辺がバッチ化されたテンソルである。バッチサイズが4であるときの例であるため、テンソルには4つの要素が存在している。つまりバッチサイズとは「何個のバッチを作るか」のことである。

f:id:t-keita:20210908225020p:plain:w600

ここで注意点がある。バッチ化をする batchify 関数であるが、上記の図のバッチを作ったあとにしれっと t() という関数呼び出しで テンソルの0次元と1次元を入れ替えている。つまり、batchify 関数が返すものはバッチ分割しただけのものではない。これがめっちゃ紛らわしい。正しくは以下の図が batchify 関数の正しい挙動である。

f:id:t-keita:20210909014055p:plain:w600

この図をイメージすると、以下の batchify でどんな処理が行われているのか理解しやすいだろう。たとえば train_data というベクトルは20個のバッチに分割された後、各バッチの i 番目の要素が1つのベクトルにまとめられる。

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)  # shape [seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

入力データとターゲットの作成

ここからは Transformer に入力するための "入力データ" と、その結果の正しさを検証するための "ターゲット" データを作成してゆく。このペアは get_batch 関数によって得られる。公式チュートリアルに載っている以下の図を見てみよう。

f:id:t-keita:20210908230717p:plain:w500

この図に例では i=0 かつ bptt=2 であるため、Input として0番目からの2行が抽出されている。Target は Input の行を1つだけずらしたものである。

この Input と Target の意味を説明する。Transformer は、文中のある単語の次に出現する単語を予測するモデルである。Transformer は Input の各要素に対して「次に出現する単語」を推定する。図の例では、Input の "A" を見たときに「次に "B" が来ること」を予想できることが望ましい。(なぜなら実際に "A" に続くのは "B" であるため。)そういう意味で、Target の "B" が Input の ”A" と同じ左上の位置に配置されている。Input の各要素に対する推定の正解が Target の同じ位置に配置されているというわけである。

また、Transformer は文脈を考慮して次に出現する単語を推定する。どれだけの長さの文脈を考慮するかを決めるのが bptt というパラメータである。図の例では bptt=2 であるため直前の単語のみ考慮される。たとえば、Input の "H" に対して次の単語を予想するときは直前に "G" が来たことを考慮する。これを考慮した上で、正解の "I" が推定できるかという話である。

モデルの実行

Transformer を学習するときも評価するときも、上記で作成した入力データ(テンソル)をモデルに与えて実行することになる。ということで Transformer のモデルを実行する部分を見ていこう。

例えば以下のようにモデルを実行してみる。モデルの実行にはバッチが必要なので、テキストを単語分割したあとにバッチ化している。

original_text = "The capital of Japan is Tokyo."
data = torch.tensor([vocab(tokenizer(original_text))]).t().to(device)
src_mask = generate_square_subsequent_mask(len(data)).to(device)
output = best_model(data, src_mask)

ここでは data という入力データを作成し model に与えている。加えて、 src_mask というものがモデルに与えられている。これは次に出現する単語の先読みを許さないようにするものらしいが詳しいことは分からない。とにかくTransformer を正しく動かすために必要な引数だと思っておく。

モデルの出力である output はやや複雑な構造を持っている。上での述べたとおり、Transformer は次に出現する単語を推定するものである。具体的には、辞書に含まれる各単語に対して、次の単語として出現する確率を割り当てる。例えば、利用可能な単語が10個であったとき、1つ目の単語に5%、2つ目の単語に8%、3つ目の単語に14%、... といった感じである。このチュートリアルの単語数は28,782個らしいので、かなり選択肢の多い確率分布を推定することになる。

では output の構造について見ていこう。入力データである data の shape もついでに見てみる。

print(data.shape)
print(output.shape)
# -> (以下は実行結果)
torch.Size([6, 1])
torch.Size([6, 1, 28782])

入力データの data の shape は [6, 1] である。これは6単語からなるデータであることを意味する。そして、モデルの出力である output の shape は [6, 1, 28782] である。すべての単語数が28,782個であることを思い出すと、これは入力データの各単語ごとに「次に来そうな単語」を予想し、それを確率分布として保持しているのである。

実際に、次にどんな単語が来ると推定されたのか見てみよう。入力の各単語ごとに、確率分布のうち最大値をもつ単語を表示してみる。

for i, e in enumerate(output):
    max_index = torch.argmax(e[0])
    print(vocab.lookup_token(data[i][0]), "->", vocab.lookup_token(max_index))
# -> (以下は実行結果)
the -> first
capital -> of
of -> the
japan -> ,
is -> a
tokyo -> ,

結果として「capital の後に of が来る」という予想しか当たっていない。残念。このチュートリアルのような手軽な学習ではこんな感じのよく分からない推定結果になっているが、Transformer そのものはすごいものなんだろう。

ここまで分かれば、学習や評価で具体的に何が行われているのかソースコードを読めば分かると思われる。めでたし。

所感

結局のところ、公式チュートリアルbatchify の図があまりに紛らわしい。それと get_batch 関数の出力の意味について説明が足りてない。なぜあんなデータ構造にするのか初学者には分からん。ついでに、学習済みの Transformer を使って実際に単語を予測してみる例がないので、全体を通して何を作っているのか分かりづらい。満足な精度が出ないから載せてないのかもしれないが、何を出力するものか分からずに精度が改善されてゆく様子だけ見せられてもよく分からん。

次は Transformer そのものや mask が何だったのか調べてみよう。