OverTheWire やってみた(Natas 編)

前回 に引き続き CTF(Capture The Flag)として OverTheWire をやってみる。今回はウェブセキュリティを扱った Natas 編をやる。ページは以下。

overthewire.org

ウェブセキュリティなんてほぼ知らないぞ。課題をクリアしていけるか分からんがとりあえずやってみよう。課題の状況を把握したうえで5分くらい手が止まったら解法をググろう。そんくらいの気楽さでやるのが大切。知らんけど。

Level 0

クライアントサイドの html ファイルのソースを見るとパスワード発見。まずはサーバサイドは関係ないのか。

Level 1

前の課題と同じく html ファイルのソース見るだけ。右クリックが禁止されているので右クリックからソースを表示できないってことか。最初から Chrome で F12 でソース見てたので最初課題の意味が分からなかった。

Level 2

pixcel.png という謎のファイルをサーバから受信しているのが怪しい。このファイルの URL を見ると files というディレクトリが存在していて、このディレクトリにある users.txt というファイルの中にパスワードを発見。静的コンテンツへのアクセス権限の設定誤りってことなんだろう。

Level 3

Google でも見つけられないってことは、このページのコンテンツからは到達できないページにパスワードが隠されているのだろうと推測。適当なディレクトリ名を叩いてみてもなにも見つからず。

解法をググった。robots.txt にヒントが書かれているらしい。そういえばこういう検索エンジンのクローリングを制御するためのファイルがあったのを思い出した。あとは簡単で、クローリング対象から除外されているディレクトリにアクセスすればパスワードが見つかる。

Level 4

index.php にアクセスするが、そのアクセス元(HTTP referer)が http://natas5.natas.labs.overthewire.org/ でないとパスワードが手に入らないらしい。そこで curl コマンドを使って referrer を偽装してリクエストを送信することでパスワードをゲット。

こういう操作は慣れていないがググりながらやればなんとかなる。Chrome の拡張とかもあるらしい。この先の課題次第ではそういうのをインストールするかも。

Level 5

ログインしてないからアクセスできないと起こられる。ブラウザに保存されているクッキーをみると loggedin というキーに 0 が設定されているのを発見。怪しいのでその値を 1 に変更して再度アクセスしてみるとパスワードを発見。ログイン状態の管理方法が杜撰だとこうなるってことか。

Level 6

文字列を入力するフォームがある。認証ロジックの PHPソースコードを公開してくれていて、どうやら入力された文字列が変数 $secret の値と一致するか調べているらしい。$secret の値は includes/secret.inc というファイルの中にあるっぽく、そのファイルを見るとキーを発見できた。

Level 7

クエリパラメータとしてパスを与えると、相対パスとして解釈されてそのファイルの中身にアクセスしてくれるらしい。ってことで /etc/natas_webpass/natas8 にアクセスできるような相対パスを書いてパスワードをゲット。いわゆるディレクトリトラバーサル

Level 8

文字列を入力するフォームがある。その認証ロジックでは bin2hex(strrev(base64_encode($secret))) のように入力文字列の変換を処理をした上で正解となる文字列と一致するかを見ている。これらの演算はすべて可逆なので、ひとつずつ逆変換をしてゆけば入力文字列が特定できる。逆変換の実装には PHP のプログラムをオンラインで書いて実行できるサービスを使った。

paiza.io

Level 9

ユーザが入力した文字列が、実行されるコマンドに埋め込まれる。その実装が grep -i $key dictionary.txt のように直接的に埋め込むようになっている。これを上手く利用して以下のようなコマンドが実行されるように入力文字列を設計した。結果としてパスワードを cat できるのでクリア。

grep -i xxx dictionary.txt ; cat /etc/natas_webpass/natas10 ; echo dictionary.txt

こういう任意のコマンドを実行できる脆弱性は一般に arbitrary code execution として知られているらしい。今回の脆弱性にもっと適切な名前があるのかもしれないが。

Level 10

前の課題とほとんど同じだが、grep コマンドが実行される前に文字列が ;, |, & の3つの記号を持つかどうかチェックしている。そのため前の課題のようにコマンドをパイプできない。とはいえ grep コマンドは cat に近いものがあるので、以下のように任意の文字列をマッチさせるコマンドを実行させてパスワードをゲット。

grep -i ".*" /etc/natas_webpass/natas11 dictionary.txt 

Level11

XOR 暗号(XOR cipher)のキーを見つける問題。PHParray( "showpassword"=>"no", "bgcolor"=>"#ffffff") みたいなオブジェクトが文字列化され、XOR 暗号にかけられ、Base64エンコードされたものがクッキーに保存される。XOR 暗号以外の部分は可逆操作なので XOR の前後の文字列が求まる。XOR 暗号のキーは暗号化の前後の2文を XOR にかければ求まるので、これを求める処理を PHP で書いた。見つかったキー qw8J を設定して array( "showpassword"=>"yes", "bgcolor"=>"#ffffff") を暗号化した文字列をキャッシュに設定し、それを読み出すとパスワードが表示される。

この課題は実装も面倒くさかったが、それ以外にもかなりハマった。というのも Chrome でキャッシュされてる文字列を見ると ClVLIh4ASCsCBE8lAxMacFMZV2hdVVotEhhUJQNVAmhSEV4sFxFeaAw%3D というものだった。最後の %3D= を URL エンコードしたものだと気づかず、見つかった XOR 暗号のキーを設定して array( "showpassword"=>"no", "bgcolor"=>"#ffffff") を変換しても上記のものとは一致しないことに頭を悩ませていた。2つの文字列の並べて初めて URL エンコードに気づいた。クッキーに保存されるタイミングで一部の記号が URL エンコードされるのか?PHP の仕様だと思うがよく分からん。

Level12

ファイルをアップロードできるうえ、アップロードしたファイルをサーバから取得できる。拡張子も手元の html を書き換えることで好きに設定できる。しかし攻撃できそうなところが見つからなかった。まったく分からん。

答えをググる。どうやら任意のファイルをアップロードできるので PHP ファイルをアップロードすると PHP を実行できるらしい。なるほど。そして /etc/natas_webpass/natas13cat するコマンドを実行させることでパスワードを入手できる。ウェブアプリとか作るときにはアップロードできるファイルの種類に制限をかけるのが安全ということなんだろう。

Level13

前の課題とほぼ同じであるうえ、アップロードできるのが画像のみに制限されたバージョン。すでにできる気がしない。とりあえず exif_imagetype をググってみたら以下のような記事を発見。

この関数はファイルの先頭バイトを見て画像であるかどうかをチェックしているだけなので、先頭だけごまかせば PHP のコードを書けるらしい。拡張子は filename 属性として、アップロードするファイルとは別に設定できるので、結果として拡張子が PHP のコードをアップロードできる。そしてパスワードゲット。

Level14

与えたユーザとパスワードが SQL に組み込まれて実行される。これは SQL インジェクションができそう。さすがに知っている。Web アプリ実装するときは SQL クエリを作るときに文字列結合はしたらダメ。SQL の WHERE 句の条件式に or "a" = "a" のような常に真になる条件が付くように入力値を設定すればパスワードゲット。

Level15

データベースの users テーブルにアカウントが存在するかどうかを調べる処理が実行される。SQL に任意の条件を埋め込めて、その結果が空かどうかだけ分かる状況。2分探索とかでパスワードの範囲を絞るとかはできそうだけどそれ用のプログラム書くか?

書くしか無さそうなのでプログラム書いてみた。SQLLIKE 文で指定するパターンに対して1文字ずつ末尾に追加してみてなおマッチするかどうかを調べた。case sensitive なマッチにするために BINARY password LIKE (候補) のような SQL を発行する。追加する文字の範囲は ASCII コードで33番目から126番目あたりまで広めに取ってみた。ただし LIKE を使うので _& は除外した。その結果パスワードらしき文字列が得られた。ググってみてもこうやるのが正解だったっぽい。こりゃ大変だ。

疲れたのでおしまい。Level16 以降はまた今度の機会にしよう。

所感

前回の badit 編より頭を使う課題が多かった印象。セキュリティも PHP も特に詳しく知らなかったが意外と解ける課題が多かった。たまにやるには楽しい。それにしても、悲しいことにウェブアプリには脆弱になりそうな箇所がたくさん存在するんだな。世の中のウェブエンジニアに今日もご苦労さまですと言いたい。

OverTheWire やってみた(Bandit 編)

最近、CTF(Capture the flag)という言葉をよく聞くので入門者向けのやつをやってみた。YouTube 上にビギナー向けの CTF サービスを紹介している 動画 があり、それが OverTheWire の Wargames をオススメしていたのでやってみた。

overthewire.org

今回は初心者向けの bandit 編を順番に解いてみた。次の問題に行くにはパスワードが必要であり、そのパスワードを発見するのが各問題の課題という感じ。ちなみに自分は Linux は自宅用デスクトップ PC として普段使いしているので基本的なコマンドや OS の概念はだいたい分かっている。しかし、何事もあまりマニアックな使い方をしない性格なので、そういう知識を要するのがあったら詰みそう。

ではレッツゴー。

Level 0

ターミナルから ssh コマンドを叩いてリモートサーバにログインするだけ。ユーザは bandit0

Level1

cat コマンドを叩いてファイルの中身を見るとパスワードの文字列がある。これがユーザ bandit1 のパスワードになっているので再度 ssh でログインする。

Level2

パスワードは - という名前のファイル内にある。cat - だとコマンドオプションだと認識されるのが、ファイルのフルパスで指定すると意図通りのファイルが認識される。知らなかったがググればすぐ分かる話ではある。

Level3

パスワードを含むファイル名に空白文字(スペース)が含まれるパターン。ファイル名をエスケープして cat すればよいが、これは tab キーでファイル名を補完できるのでなんてことない。これは Linux を普段使いしているとたまにやる操作。

Level4

パスワードは隠しファイル内にあるだけ。簡単。

Level5

パスワードが複数のファイルのうちどれかの中にある。ほとんどのファイルが空っぽなので、find . -type f | xargs wc -l あたりでファイルごとの行数を見て発見した。

Level6

パスワードが複数のファイルのうちどれかの中にある。ファイルサイズがヒントとして与えられている。find . -type f | xargs -d '\n' wc | grep 1033 あたりでファイルサイズで調べて発見した。

Level7

パスワードが複数のファイルのうちどれかの中にある。user と group とファイルサイズがヒントとして与えられているので、find / -type f | xargs ls -l | grep bandit7 あたりでユーザ名で grep すると1件しか見つからなかったのでそれが答えだった。find コマンドのマニュアル読めばユーザの検索とかできそう。

Level8

パスワードが指定のファイル内にあることが分かっている。そのファイル内において millionth という文字列の横にパスワードがあるらしい。普通に grep したら発見。

Level9

パスワードが指定のファイル内にあることが分かっている。このファイルに含まれるユニークな文字列が正解らしい。リダイレクトしながら diff ってみた。diff <(sort -u data.txt) <(sort data.txt | uniq -D | uniq) で差分が1行だけになったのでこれが正解。なんか無理矢理感がある。ユニークな列だけ抽出するみたいなコマンドオプションがあればもっとスマートにできそう。

Level10

パスワードが指定のファイル内にあることが分かっている。このファイルに含まれる、人間が読めて いくつか = が続く文字列の直後がパスワードらしい。ってことで strings コマンドを用いて ASCII 文字列っぽいものを抽出してみた。strings data.txt | grep == で答えを発見。strings コマンド初めて使った。バイナリファイルの解析とかに使えそう。

Level11

パスワードが Base64エンコードされている。base64 -d data.txt みたいに base64 コマンドを使ってデコードするとパスワードを発見。

Level12

パスワードが ROT13 により暗号化されている。シーザー暗号の一種。ただし数字はそのまま。各文字を元に戻せばよいので tr コマンドで置換した。対応する文字を愚直に列挙した。パスワードの復号に成功。

cat data.txt | tr -s 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 'nopqrstuvwxyzabcdefghijklmNOPQRSTUVWXYZABCDEFGHIJKLM'

Level13

パスワードを含むファイルが何らかの方式で繰り返し圧縮されている。圧縮ファイルの先頭を見れば圧縮方式が分かるようになっており xxd コマンドで見ながら bzip2 ,gzip, tar コマンドを繰り返し適用した。けっこうダルかった。結果的にパスワードを含むテキストファイルが出てきた。

圧縮ファイルの方式は以下を参考にした。よく見ると2つ目の回答の file コマンドをたたくやり方の方が楽そう。

stackoverflow.com

Level14

ssh コマンドを bandit14 でログインするが、このときプライベートキーを食わるだけ。

Level15

telnet コマンドでポート 30000 に現在のパスワードを送信する。具体的には telnet localhost 30000 で接続したあとパスワードを貼り付ければオッケー。

Level16

SSL 通信でポート 30001 に接続する。openssls_client を使うと SSL 通信ができる。openssl s_client -connect localhost:30001 を実行すればオッケー。このあたりは使い慣れてないのでググりながらやった。

www.openssl.org

Level17

localhost のポートの範囲 31000 - 32000 のどれかに SSL 通信を受け継いているものがあるらしい。nmap コマンドを使って nmap -p31000-32000 localhost のようにスキャン。いくつか有効なポートが見つかるのでひとつずつ opensll s_client -connect で接続できるか試した。ひとつだけ秘密鍵を返すものがあたのでそれを次の SSH に食わせた。

nmap コマンドはよく知らなかったので man コマンドでマニュアルを見た。

Level18

diff コマンド叩くだけ。

Level19

パスワードは readme ファイルにあるが、.bashrc の設定により、ssh コマンドでログインすると即座にログアウトされてしまうらしい。ってことで ssh コマンドを叩くときに cat readme することでパスワードを取得できた。ssh って Docker 動かすときみたいにコマンド叩けるんだ。知らなかった。

Level20

/etc/bandit_pass/bandit20 にパスワードが書かれているが、現在のユーザ bandit19 では読み取り権限がない。そこで、bandit20-do コマンドを使って一時的に bandit20 の権限を獲得したうえでパスワードが書かれたファイルへアクセスすればよい。具体的には ./bandit20-do cat /etc/bandit_pass/bandit20 というコマンドを実行した。

こういう権限を一時的に獲得するような仕組みは setuid と呼ばれている。実行ファイルに対して chmod コマンドで設定できるらしい。

Level21

tmux コマンドでターミナルを2画面開く。片方のターミナルで nc -l -p 1234 のようにサーバを起動して接続待ちをする。それに対して、もう片方のターミナルで ./suconnect 1234 のようにアクセスすることで接続が確立される。nc コマンドを実行している方で前回のパスワードを送信すると次のパスワードが返される。nc コマンドの使い方がよく分かっていなかったのでけっこう時間かかった。

Level22

cron の設定ファイルが /etc/cron.d/ の中にいくつかある。それらはスクリプトファイル(sh ファイル)を実行するように設定されているが、その中に現在のユーザ bandit21 でも読み取り権限があるスクリプトがある。それの中を見るとパスワードの書かれた /etc/bandit_pass/bandit22cat していることが分かるので、そのリダイレクト先のファイルを見ればパスワードが手に入る。

ちなみに cron の設定ファイルである crontab ファイルの5つのアスタリスク * * * * * は指定されたコマンドを毎分実行するようなスケジュールを意味する。

Level23

前の問題と同様に現在のユーザ bandit22 で読み取り権限のあるスクリプトが1つだけある。その中身を見ると、whoami コマンドの実行結果をもとに計算したハッシュ値をファイル名としてパスワードを書き込んでいることが分かる。cron によってこのスクリプトが実行されるときのユーザは bandit23 なので bandit23 という文字列を与えてこのハッシュ値を計算してみる。すると書き込まれたファイル名が分かるのでパスワードを入手できる。

Level24

bandit24 によって定期実行される /usr/bin/cronjob_bandit24.shスクリプトの中身を見ると /var/spool/bandit24 の中のスクリプトが実行されることが分かる。よってパスワードを出力する cat コマンドをこのディレクトリ内に仕込めばよい。具体的には、結果を /tmp 配下にリダイレクトする cat コマンドをもつスクリプトを作成し bandit24 が実行できるように chmod し、cp コマンドによってコピーした。1分ほど待てばパスワードがリダイレクト先のファイルに書き込まれる。

Level25

brute-force で 0000 から 999930002 番ポートに送り続ける必要がある。そこで for i in {0000..9999}; do echo (省略) $i; done | nc localhost 30002 | tail のように for 文を書いて入力候補をすべて作ったあとで nc コマンドでポートに送信。やがて答えが見つかった。

Level26

getent passwd | grep bandit26 するとユーザ bandit26 のログインシェルが bash ではなく/usr/bin/showtext になっていることが分かる。このなかでは more コマンドが実行されている。このため bandit26 にログインしてもすぐにログアウトされてしまう。

ここからどう打開するか分からずこの問題の正解をググった。どうやら、ターミナルの画面サイズを小さくしておくことで more の画面がコマンドを受け付ける状態で止まり、ここで v を押すと vi が立ち上がるらしい。なるほど。vi 内で :set shell=/bin/bash に続けて :shell を実行すればシェルが起動する。これで bandit26 としてログインできたことになる。

Level27

前の手順に続き、bandit27-do なる実行ファイルがあるため、これに cat を引数として与えるとパスワードを表示できる。

Level28

/tmp ディレクトリに移動して git clone するだけ。clone したディレクトリにパスワードが書かれている。

Level29

/tmp ディレクトリに移動して git clone する。なぜかパスワードがマスキングされているが、git log を見るとマスキング前のデータがありそうなことが分かる。そこで git diff (前のバージョンのハッシュ値) してパスワードゲット。

Level30

/tmp ディレクトリに移動して git clone する。パスワードは not in production らしい。リモートブランチを見てみると dev ブランチとかがあって怪しいのでこれを pull してきて中身をみるとパスワード発見。

Level31

git tag で secret という名前のタグを発見できる。これを git show するとパスワードの文字列を発見。これはなかなか時間かかった。タグという発想はなかった。

Level32

.gitignore に注意しながら push するだけ。ここに来て簡単。

Level33

打った文字が大文字になって実行されるシェルがいきなり起動する。大文字に変換されても問題なく実行できるコマンドが必要そう。しかし思いつかずこの問題の解法をググった。どうやら $0sh が呼ばれるらしい。これでコマンドが叩けるようになったのでパスワードを cat できるようになった。

2021年8月現在ではこれにてすべての課題クリア。めでたし。

所感

初めて CTF 系の問題をやってみたが、各問題が上手く作られた感じがあって飽きずに楽しめた。普段はネットワーク系など低めのレイヤに触れることがあまりないので今回はよい勉強になった。それにしても、手がかりが見つからないときは本当に何をしてよいのか分からない。CTF が強い人はたくさんの攻めパターンを知っていて、それを順番に調べてみるみたいなことをするんだろうと思う。これぞハッカーの道だ。また気が向いたらこういうのやろう。

トップカンファレンスVLDB2021に通った話

筆頭論文がよいところに採録された。めでたい。業務でやった研究なので内容に踏み込んだことは書くべきでないが、外からでも観測できる情報と今の気持ちについてメモる。

なにが起こったのか

筆頭論文が国際論文誌 PVLDB(Proceedings of the VLDB Endowment)に採録された。この論文誌に採録されると国際会議 VLDB(International Conference on Very Large Data Bases)でプレゼンテーションできる仕組みになっている。つまり、論文誌として常に論文の投稿を受け付けており、採録されたものに関しては1年に1回開かれる国際会議に招待される。変わった仕組みだが、いつ投稿してもよいので投稿する側としては非常にありがたい。

VLDB はデータベース分野で最高難易度のトップカンファレンスであり、CORE ranking では上位5%を意味する最高ランク A* が付いている。いわばデータベース研究者の夢の国際会議である。まぁ自分はデータベースの研究者ではないので ICSE とか PLDI とかに通せたほうが嬉しいんだが。今回論文が採録された VLDB2021 の Research Track の採択率は 23% だったらしい。けっこう低いが一時期よりは 高いっぽい

ちなみに採録された論文は以下。テーマはプログラム合成。

arxiv.org

アカデミックで頑張るためのメモ

こういうめでたいときは自身を振り返るのにもよいタイミングである。ということで、どうやればコンピュータサイエンス系のトップカンファレンスに(運が良ければなんとか1本くらいは)論文を通せるのかを書きたい。本当はよくある「統計検定準1級合格のために勉強したこと〜問題集編〜」みたいなノリでまとめるのが面白そうだが、他人と違うことをするのが研究活動なのでなかなか上手く整理できそうにもない。

参考になったもの

とりあえず論文を書くにあたって参考になったものを書いてみる。ちなみに私自身はコンピュータサイエンス修士号だけ持っている、卒業してから4年くらい経った社会人。

English for Writing Research Papers

この書籍には英語論文を書くテクニックがたくさん載っている。めっちゃ参考になった。この本に書かれていることの多くはシンプルかつ明瞭な文を書く方法である。英語を書くときは「もっとかっこよく書きたいなぁ」と思うものだが、かっこよさよりシンプルで分かりやすい文を書くことが重要だと教えてくれる。ちなみに分量が多いので、いざ論文を書くとなってからすべてを読むようなものではない。

www.amazon.co.jp

松尾ぐみの論文の書き方

こちらは、かの有名な松尾先生が論文を書くときのノウハウをまとめたウェブページである。ここに書いてあるような内容を知らないとたぶんちゃんとした論文は書けない。個人的には「失点を少なく、守りの野球を」というメッセージが印象的だった。「2-0とか3-0とかいう試合をすべきで、10-8とか15-12とかの大味な試合をしてはいけません。 」と教えてくれる。さっと読めるので読もう。そして実践することが何より大切。

ymatsuo.com ymatsuo.com

How to Write a Great Research Paper

こちらは YouTube の動画である。Microsoft Research の Simon Peyton Jones 氏が研究の進め方と論文の書き方について指導している。以前、このスライドを誰かが和訳したものを見た気がするので興味ある人は探すべし。内容として印象的だったのは、論文はたった1つの "ping" すなわち one clear, sharp idea をプレゼンするようなものであるべきという話。だいたいの場合、論文執筆は内容を盛り込む作業より削る(洗練する)作業が重要になってくる。そういう感覚が身についた。

www.youtube.com

他には英文ライティングの本とかはいくつか読んだ。特に冠詞の付け方はちゃんと勉強した。もちろん他にもいろいろ勉強したんだが、ぱっと思い浮かぶのはこのあたり。

個人的に心がけていること

ここからはポエム。参考になるかも知れないし、ならないかも知れない。

Support All Claims

論文では、ストーリーのなかで手法の貢献を明確に主張(claim)し、それが実験等で裏付け(support)されるようにする。このような claim と support のペアを常に意識している。すべての claim が support されているかチェックするし、support できない claim は論文に書かないようにする。というのも、support されていない claim を含む研究はサイエンスとして破綻しており、レビュアーとしてはアクセプトするわけにはいかないからである。レビュアーの好みとかそういう次元に行く前に落とされる明確な理由になるので、すべての claim はしっかり support する。

論文をたくさん読む

論文を読んで知識を付けるという行為は量が質の変化をもたらす類のものだと思う。論文をたくさん読むと技術を体系的に整理できるのはもちろんだし、トップ研究者たちと同じ空気を吸っている感覚がわいてくる。「こんな研究ができたらこのくらいスゴイ」っていう感覚はその分野の論文をたくさん読まないと分からないと思う。論文投稿して reject を食らうにしても、同じ空気を吸えていたらレビューを受け入れられるし、同じ空気を吸えていないとレビューコメントが意味不明で「研究おもんな」ってなる。まぁ本当に意味不明なレビューもあるんだろうが。

コンピュータサイエンスの基礎を身につける

新しい研究を効率よく理解したいだけなら新しい論文をたくさん読めばよい。しかし、自分で研究をやるとなると基礎力のある人が強いと思う。というのも研究活動ってのは宝探しみたいなもんで、コンクリートの地面をスコップで掘ろうとしているみたいなことが容易に起こりうる。そしてその地面がコンクリートであることが大学の学部レベルの知識で分かることもある。たとえば、解きたい問題を真面目に定式化すると効率的な解法が知られていない問題に帰着するとか。それとは別に、研究に対する指摘として「ナイーブなアプローチではなぜダメなのか?」というのはどんな研究にもつきまとう。こういうのは素朴なやり方を知らないとディフェンスしようがない。

自分にコンピュータサイエンスの基礎が身についたと思うタイミングは2つある。1つは院試で、もう1つは院生時代の独学。

1つ目の院試について。学部時代は真面目に勉強してなかったので、院試は過去問を覚えながらなんとか乗り切った。それだけだとよくあるエピソードだが、自分の場合は M1 のときに留学生の院試勉強の手助けをする担当に任命された。とはいっても院試の内容なんて忘れてしまっていたので、自分が院試を受けた1年後にふたたび院試勉強をやり直すことなる。これがよかった。その間に卒論を書いたこともあり教科書的な内容の意義がよく理解できた。当たり前の話だが、教科書に載るような内容は汎用性が高く科学的な価値も高い。これに気づいた。

2つ目は院生時代の独学。M1 あたりからコンピュータサイエンスの楽しさに気づいた。研究科の近くに理工学図書館があり、たくさんの技術書があることにも気づいた。そこで、院生の時間のうち自分の研究に割くのは4割くらいにして、残りの時間は図書館で本読んだりプログラム書いてみたりした。その過程で試行錯誤しながら独学できるようになったことが一番の収穫であった。本当によい時間だったと思う。結局のところ、技術を理解するっていうのは自分の中から湧き出る疑問を解消し続ける作業に他ならない。いろいろ勉強するほど疑問が湧いてくるし、それを解消するごとに理解も深まる。今でもこのブログにせっせと記事を書いてるのは、技術をドキュメンテーションする行為と疑問を解消する行為の親和性が高いからである。

雑感

いろいろ書いたが、なにごとも自分で実践して検証することが重要というのが現在の気持ち。どんな崇高な情報であっても、それを聞いただけで満足しているのは自己啓発セミナー受けてるのと何ら変わらない。結局のところ孤独に地道にやるしかない。

誤差逆伝播法とは何か

プログラマの良いところは離散的な考え方に親しんでいる点である。だいたいのプログラミングなんて組み合わせの列挙と計算量の解析ができればなんとかなる。微分計算なんて大学1年の数学で見納めたはずだった。しかしながら、機械学習を勉強するとやたら微分計算が登場するのである。特に、ニューラルネットワークバックプロパゲーション誤差逆伝播法)は意味が分からん。ってことで今回は気合を入れてバックプロパゲーションについて調べてメモってみた。

説明の方針

世の中のバックプロパゲーションの説明を読んでて何が苦しいかと言うと  i とか  j とか  k とか添字が多すぎて混乱を招くことである。人間は添字が多すぎるとやる気を失う。やる気を出して添字を細かく見ていっても、全体として結局なにが言いたかったのか腑に落ちない。

そこで本記事では、添字の煩わしさを軽減するためにベクトルや行列を "そのまま" 微分する方式で説明する。"ベクトルをベクトルで微分する" とか "スカラーを行列で微分する" みたいな説明になる。こういうのは Matrix calculus と呼ばれるらしい。たぶん日本語では "行列微分積分学" って感じだと思う。ちなみに英語版 WikipediaBackpropagation のページはこの方式を取っていて分かりやすい。説明の上手さに感心した。

プログラマであれば変数の "型" が明確でないのは気持ち悪いだろう。数学者の気持ちとしては「違うものが同じように見える」ことが嬉しいんだろうが、プログラマとしては異なる型を持つものは異なる表記であってほしい。そこで本記事ではベクトルは  \vec{v} のように矢印を付けて表記し、行列は  W のように大文字で、スカラー c のように小文字で書く。

準備:Matrix calculus

まずはベクトルや行列を偏微分するところから始める。内容は このページ を参考にした。

スカラーをベクトルで偏微分

スカラー  y をベクトル  \vec{x} = [x_1 x_2 \dots x_n]^{\top}偏微分した結果は以下のベクトルとなる。


\frac{\partial y}{\partial \vec{x}} = 
\begin{bmatrix}
\frac{\partial y}{\partial x_{1}} & \frac{\partial y}{\partial x_{2}} & \dots, \frac{\partial y}{\partial x_{n}} 
\end{bmatrix}

ベクトルをベクトルで偏微分

ベクトル  \vec{y} = [y_1 \dots y_m]^{\top} をベクトル  \vec{x} = [x_1 \dots x_n]^{\top} 偏微分した結果は以下の行列となる。


\frac{\partial \vec{y}}{\partial \vec{x}} = 
\begin{bmatrix}
\frac{\partial y_{1}}{\partial x_{1}} & \frac{\partial y_{1}}{\partial x_{2}} & \dots & \frac{\partial y_{1}}{\partial x_{n}} \\\
\frac{\partial y_{2}}{\partial x_{1}} & \frac{\partial y_{2}}{\partial x_{2}} & \dots & \frac{\partial y_{2}}{\partial x_{n}} \\\
\vdots &  \vdots & \ddots & \vdots \\\
\frac{\partial y_{m}}{\partial x_{1}} & \frac{\partial y_{m}}{\partial x_{2}} & \dots & \frac{\partial y_{m}}{\partial x_{n}}
\end{bmatrix}

スカラーを行列で偏微分

スカラー  y p \times q の行列  X偏微分した結果は以下の行列となる。


\frac{\partial y}{\partial X} = 
\begin{bmatrix}
\frac{y}{x_{11}} & \frac{y}{x_{21}} & \dots & \frac{y}{x_{p1}} \\\
\frac{y}{x_{12}} & \frac{y}{x_{22}} & \dots & \frac{y}{x_{p2}} \\\
\vdots &  \vdots & \ddots & \vdots \\\
\frac{y}{x_{1q}} & \frac{y}{x_{2q}} & \dots & \frac{y}{x_{pq}}
\end{bmatrix}

ニューラルネットワーク

本稿ではニューラルネットワークの構成要素を以下のように表記する。

  •  \vec{x}: 入力データ。特徴ベクトルのこと。
  •  \vec{y}: 出力データ。教師データとして与えられるやつ。
  •  c: 誤差関数(目的関数)。二乗誤差など。
  •  L: 層の総数
  •  W^{l}: 第  l 層への重みの行列。第  l-1 層の  k 番目のユニットから第  l 層の  j 番目のユニットへの重みを行列の成分 w^{l}_{jk} とする。
  •  f^{l}: 第  l 層の活性化関数。ロジスティック関数やシグモイド関数など。
  •  \vec{z}^{l}: 第  l 層の入力ベクトル
  •  \vec{a}^{l}: 第  l 層の出力ベクトル

これらの表記を用いてニューラルネットワークを図示すると以下のようになる。

f:id:t-keita:20210815021420p:plain:w600

よくあるニューラルネットワークの図といえばたくさんの円と線が並んだ複雑なやつだが、ここではすべての要素がベクトル or 行列で表現されている。実はニューラルネットワークを図示するにはこれで十分なのである。

ニューラルネットワークは、ベクトルを入力としてそのベクトルを順に処理してゆくだけの機械  g(\vec{x}) だと思えばよい。具体的には以下の関係が成り立つ。


g(\vec{x}) = f^{L} ( W^{L} f^{L-1} ( W^{L-1} \dots f^{1} ( W^{1} \vec{x}) \dots )  )

ニューラルネットワークの学習では、誤差関数  c を用いて  c(\vec{y}, g(\vec{x})) をなるべく小さくするような重 W^{1}, \dots, W^{L} を求める。この計算には確率的勾配降下法が用いられることが一般的であり、第  l 層の重み  W^{l} を更新するには誤差関数  c を行列  W^{l}偏微分した値  \frac{\partial c}{\partial W^{l}} が必要となる。ただし、この偏微分スカラーを行列で偏微分したものであるため結果も行列であることに注意する。

ちなみに確率的勾配降下法については以前の記事で紹介した。 t-keita.hatenadiary.jp

バックプロパゲーションの入出力

プログラマであれば手法を理解するときに入出力を明示してほしいと思うだろう。そこでバックプロパゲーションの入出力を以下に示す。

出力の偏微分(行列)を転置したものは勾配と呼ばれる。要するに、バックプロパゲーションは勾配法に用いるための勾配を求めるためのものである。

ちなみに、出力となる偏微分は出力層の第  L 層の  \frac{\partial c}{\partial W^{L}} から第1層の  \frac{\partial c}{\partial W^{1}} へと順に求まる。このような順序になるのはアルゴリズム動的計画法Dynamic Programming, いわゆる DP)になっているからである。詳細は後述する。

一方で、バックプロパゲーションは以下のようなものではない

  • 出力層に近い順に重みを更新するもの:重みが更新されるのは勾配法の実行時である。ちなみに、勾配法によって重みを更新するときは出力層に近い順である必然性はない。
  • 出力層から順に誤差を伝えるもの:バックプロパゲーションでは学習データを使わないので具体的な誤差は計算されない。具体的な誤差が考慮されるのは勾配法の実行時である。

とはいえ、バックプロパゲーションという言葉は学習アルゴリズム全体を指すようになってきているらしい。

バックプロパゲーションの仕組み

それでは本題のバックプロパゲーションアルゴリズムについて見てゆく。説明のうえで第  l-1 層と第  l 層の関係が重要になってくるので図を抜粋して再掲しておく。 f:id:t-keita:20210815103211p:plain:w600

まずは誤差関数について説明する。今回考える誤差関数は前述した通り以下である。


c(\vec{y}, f^{L} ( W^{L} f^{L-1} ( W^{L-1} \dots f^{1} ( W^{1} \vec{x}) \dots )  ) )

ここで、具体的な学習データを用いる代わりにニューラルネットワークの入出力として  (\vec{x}, \vec{y}) は固定であるとみなす。そのうえで重み  W^{1}, \dots, W^{L} を動かしたときの誤差関数 c の変化を調べることが目的である。

つぎに、誤差関数  c を第  l 層の入力ベクトル  \vec{z}^{l}偏微分して得られるベクトルを  \vec{\delta}^{l} とする。 \vec{\delta}^{l} に対して以下の関係が成り立つ。


\begin{align}
\vec{\delta}^{l} & = \frac{\partial c}{\partial \vec{z}^{l}} \\\
                        & = 
\frac{\partial c}{\partial \vec{a}^{L}} \cdot
\frac{\partial \vec{a}^{L}}{\partial \vec{z}^{L}} \cdot
\dots \cdot
\frac{\partial \vec{z}^{l+1}}{\partial \vec{a}^{l}} \cdot
\frac{\partial \vec{a}^{l}}{\partial \vec{z}^{l}}
\end{align}

ここで偏微分の連鎖律(Chain rule)を用いている。この積は出力層の第  L 層から第  l 層まで下っている。それゆえ、 \vec{\delta}^{l} \vec{\delta}^{l-1} について以下の再帰的な関係が成り立つ。


\vec{\delta}^{l-1} = 
\vec{\delta}^{l} \cdot
\frac{\partial \vec{z}^{l}}{\partial \vec{a}^{l-1}} \cdot
\frac{\partial \vec{a}^{l-1}}{\partial \vec{z}^{l-1}}

ここで偏微分計算について以下が成り立つ。ただし  I単位行列である。いずれもベクトルをベクトルで微分しているので結果は行列になる。


\frac{\partial \vec{z}^{l + 1}}{\partial \vec{a}^{l}} = W^{l}, \hspace{0.3cm}
\frac{\partial \vec{a}^{l}}{\partial \vec{z}^{l}} = (f^{l})' \cdot I

したがって  \vec{\delta}^{l}再帰式は以下で計算できる。


\vec{\delta}^{l-1} = \vec{\delta}^{l} \cdot W^{l} \cdot (f^{l-1})'

ここまでの話を整理すると、 \vec{\delta}^{L} が求まれば  \vec{\delta}^{L-1} が求まる。そして  \vec{\delta}^{L-1} が求まれば  \vec{\delta}^{L-2} が求まる。これを繰り返すことで  \vec{\delta}^{1} まで求めることができる。なお、 \vec{\delta}^{L} は定義通りに誤差関数を偏微分することで求まる。このように、漸化式を用いてボトムアップにすべての数列の値を求めるアプローチは動的計画法に他ならない。この計算過程が出力層に近い第  L 層から入力層の方向へ向かう。この方向はニューラルネットワークが入力から出力を計算する方向と逆であり、それが "バック" プロパゲーションという名前の由来である。

あと残っているのは、手法の目的だった  \frac{\partial c}{\partial W^{l}} の計算である。これは以下のように計算できる(らしい)。


\begin{align}
\frac{\partial c}{\partial W^{l}} & = \vec{a}^{l-1} \cdot \frac{\partial c}{\partial \vec{z}^{l}} \\\
                                              & = \vec{a}^{l-1} \cdot \vec{\delta}
\end{align}

すでに  \vec{\delta} は計算できることが分かっているので、手法の目的の  \frac{\partial c}{\partial W^{l}} も計算できることが分かった。めでたし。

メモ:この最後の計算は厳密にはよく理解できていない。というのも行列で偏微分するときは連鎖律が成り立たないらしく、計算を理解するには行列の各成分の添字を真面目に追わないといけないっぽい。そこまでする気力はなかった。この計算のイメージはおそらく、 \vec{z}^{l} = W^{l} \cdot \vec{a}^{l-1} であることから、重み  W^{l} の各成分を微小に動かしたときの誤差関数  c の変化が、対応する  \vec{a}^{l-1} の成分に比例するということだと思われる。

最後に、確率的勾配降下法の実行時の流れを整理する。勾配法の実行時には現在の重み  W^{1}, \dots, W^{L} が決まっている。ランダムに選択した学習データ  (\vec{x}, \vec{y}) のうち  \vec{x}ニューラルネットワークに与えるとすべての  \vec{z}^{l}, \vec{a}^{l} の値が定まる。また、 a^{L} \vec{y} の値を用いて  \vec{\delta}^{L} が計算され、その後  W^{l} を用いながらすべての  \vec{\delta}^{l} の値が定まる。そして  \frac{\partial c}{\partial W^{l}} の値も定まる。得られた勾配を下る方向に重みを更新する。この重みの更新を繰り返すという流れ。

おしまい。

確率的勾配降下法とは何か

機械学習において確率的勾配降下法(Stochastic gradient descent, SGD)ってよく耳にするけどよく分からない。「確率的勾配降下法 わかりやすく」でググった人は数知れないだろう。自分も SGD の本質がどこにあるのか分かっていなかったので改めて調べてみた。

準備:パラメータを "動かす"

本記事のトピックである勾配法では "パラメータを動かす" という考え方が重要となる。そこで、線型回帰を例としてパラメータを "動かす" とはどういうことか説明してみる。ここでは1変数の線形回帰を例として説明する。もちろん線形回帰は降下法なんか使わなくても最適なパラメータを求めることができる。あくまで説明のため線形回帰を例にしているに過ぎない。

1変数の線形回帰 とは、与えられた複数の点  (x_1, y_1), \dots (x_n, y_n) に対して誤差の和が最小になるような直線  y = f(x) を引く問題である。たとえば以下の図では赤い直線が解である。

f:id:t-keita:20210801020123p:plain:w500

直線と複数の点の誤差の和は 残差の二乗和(esidual sum of squares, RSS)として計算されることが一般的である。具体的には以下の式で与えられる。

 \text{RSS} = \Sigma^{n}_{i=1} (y_{i} - f(x_{i}))^{2}

いまは  y = f(x) は直線なので  f(x) = ax + b と表される。中学数学を思い出すと、 a は直線の傾きで、 by 切片である。これをふまえて上記の二乗和を a, b の関数  Q(a,b) とみなす。

 Q(a, b) = \Sigma^{n}_{i=1} (y_{i} - (ax_{i} + b)))^{2}

要するに、与えられた点  (x_1, y_1), \dots (x_n, y_n) は固定であり、パラメータ  a, b の値を変えることで関数  Q(a, b) の値が定まると考える。パラメータ  a, b を "動かす" ことで、目的関数  Q(a, b) の値をなるべく小さくするアプローチこそが勾配法である。

パラメータ  a, b を動かす様子を図示してみる。まずは直線の傾き  a のみを動かしてみる。以下の図は赤線で傾き  a=1 のとき、緑線で傾き  a=2 のときの直線を示している。

f:id:t-keita:20210801020216p:plain:w400

当然、 a の値を大きくすると傾きが大きくなり、小さくすると傾きが小さくなる。それと同時に点線の長さの和、すなわち残差の二乗和  Q(a, b) の値も変化する。 a の値を動かしながら  Q(a, b) の変化を観察すれば、 Q(a, b) の値を小さくするような  a の値が分かりそうである。

同様に、切片  b の値のみを動かしてみる。赤線で切片  b = -1、緑線で切片  b = 3 のときの直線を示している。

f:id:t-keita:20210801020405p:plain:w400

こちらも  b の値を動かすことで  Q(a, b) の値を小さくするような  b の値が分かりそうである。

最急降下法

ここまでの話で、パラメータを動かすと目的関数(誤差)の値も動くことを見てきた。このようにパラメータと目的関数の関係を用いて誤差を最小化する方法として 最急降下法(gradient descent)がある。ちなみにこの手法を適用するためには、目的関数  Q(a, b) がパラメータ  a, b によって偏微分可能である必要がある。

最急降下法の説明をする前に、先ほどの線型回帰においてパラメータ  a, b に応じて目的関数  Q(a, b) がとる値を等高線として図示してみる。この図は横軸が  a の値、縦軸が  b の値になっている。この平面上の点ごとに回帰直線が一意に定まることに注意したい。この平面上を動くことがまさにパラメータを動かすことになる。

f:id:t-keita:20210801020833p:plain:w400

結果として図示されるものは同心円状の等高線になっており、中心に行くほど目的関数(誤差)の値が小さくなっている。つまり、中心に行くほど望ましいパラメータになっている。

最急降下法は、適当なパラメータ  a, b からスタートして、この等高線と直交するようなルートをたどって最も低いところへ進んでゆく手法である。いわばアリジゴクの底に落ちるようなルートをたどる。具体的には、点  (a, b) 上において進む方向はその点における勾配を下る方向であり、その勾配はベクトル  \nabla Q = (\frac{\partial Q}{\partial a}, \frac{\partial Q}{\partial b}) (a, b) の値を代入することで計算できる。したがって、適当な  (a_0, b_0) からスタートし、以下のようにパラメータ  (a_n, b_n) を更新してゆけばよい。

更新式: (a_{n+1}, b_{n+1}) = (a_n, b_n) - \eta \nabla Q(a_n, b_n)

ただし、 \eta学習率(Learning rate)であり、更新の大きさを調整する役割がある。

たとえば線型回帰の場合、 \nabla Q は以下の2つの要素からなる。

  •  \frac{\partial Q}{\partial a} = -2 \Sigma^{n}_{i=1} x_{i}(y_{i} - (ax_{i} + b))
  •  \frac{\partial Q}{\partial b} = -2 \Sigma^{n}_{i=1} (y_{i} - (ax_{i} + b))

これらを用いて、パラメータの初期値  (a, b) = (3.5, -2) として最急降下法を実行した結果は以下である。ただし学習率は 0.001 とし、勾配が十分に小さくなるまでパラメータの更新を繰り返した。

f:id:t-keita:20210801113515p:plain:w400

初期値に対応する点からアリジゴクの坂を下るように底に落ちてゆく様子が見て取れる。ここで重要なことは、初期値が同じであれば何度実行しても同じルートをたどって同じ値に落ち着くということである。つまり "確率的" な挙動はしない。

最急勾配法はパラメータの更新にすべての点との誤差  Q(a,b)偏微分を用いる。このようにパラメータの更新にすべてのデータを使用する学習方法は "バッチ学習" と呼ばれている。実はバッチ学習はあまり実用的でない。というのも、バッチ学習の計算量は学習データや各データの複雑さに依存するため、たとえば学習データが1億件ある場合、1回のパラメータの更新に1億件に対する誤差の和を計算する必要がある。学習データが1億件というのは現代の機械学習では珍しいことではないというのが怖いところである。

確率的勾配降下法

最急勾配法の課題を解決するため、最急勾配法を乱択アルゴリズムとして近似したのが 確率的勾配降下法(Stochastic gradient descent, SGD)である。

SGD の基本的な仕組みは最急勾配法と同じであるが、パラメータの更新ごとにランダムに選んだデータ1つのみを用いるのが改良点である。つまり、パラメータの更新ごとにすべてのデータを用いる代わりに1つのデータだけ用いることで計算量を大幅に削減する。具体的には、ランダムに  i を選んだ点  (x_i, y_i) と現在の直線の誤差をもとにパラメータの更新を行う。

更新式: (a_{n+1}, b_{n+1}) = (a_n, b_n) - \eta \nabla Q_{i}(a_n, b_n)

ここで、 Q_{i} i 番目のデータと直線の誤差であり、線型回帰の例では  Q_{i} = (y_{i} - (ax_{i} + b)))^{2} である。これを各パラメータ  a, b によって偏微分して得られるベクトルが  \nabla Q_{i} である。線型回帰の例では具体的に以下である。ここに  \Sigma はない。

  •  \frac{\partial Q_{i}}{\partial a} = -2 x_{i}(y_{i} - (ax_{i} + b))
  •  \frac{\partial Q_{i}}{\partial b} = -2 (y_{i} - (ax_{i} + b))

より理解を深めるために補足すると、SGD では現状のパラメータに対応する直線(下図の赤色の直線)とランダムに選んだ1点との距離のみを考慮してパラメータを更新する。下図の例だと直線の傾き a も切片  b も増加する方向にパラメータが更新されるはずである。

f:id:t-keita:20210801121148p:plain:w400

そして、先ほどと同様にパラメータの初期値を  (a, b) = (3.5, -2) として SGD を3回実行した様子を以下に示す。

f:id:t-keita:20210801130155p:plain:w600

3回の実行をそれぞれ赤色、青色、緑色で示している。見てのとおり実行ごとに異なるルートをたどっている。これはもちろん学習に用いるデータをランダムに決めているからであり、乱択アルゴリズムとしての特徴が現れているといえる。これゆえ SGD は "確率的" な挙動をする。

ちなみに、SGD のように1つのデータだけを考慮してパラメータの更新を行うやり方は "オンライン学習" と呼ばれる。SGD はオンライン学習であるが、SGD の本質(stochastic と呼ばれる理由)は乱択アルゴリズムである部分にある。世の中には「最急勾配法をオンライン学習にしたものが SGD である」という説明が見られるが、これだけではランダムネスに言及できていないので説明として不十分であると思ったほうがよさそう。実際、少数のデータをパラメータ更新に用いる "ミチバッチ学習" も(オンライン学習ではないが) SGD の一種として有名である。

所感

今回初めて SGD を実装して実験的に動作を確認してみた。本記事の図を書いたコードは Gitst 上に置いた。勾配法の更新式を見ただけでは、代入してから偏微分するのか、偏微分してから代入するのか深く考えていなかったが、実装するにあたり後者が正しいことを認識できた。まぁ当たり前みたいな話なんだが。

マルコフ決定過程で脱出ゲーム

マルコフ決定過程と仲良くなりたい。一昨日書いた記事ではベルマン方程式の話をした。本を読んで理論を理解したつもりでも、自分で実装してみて初めて見える世界ってのもある。ってことで簡単なゲームを作って、マルコフ決定過程問題を解くアルゴリズムを実装してみた。

ベルマン方程式については前回の記事を参照のこと。

t-keita.hatenadiary.jp

問題設定

典型的なマルコフ決定過程の問題を解くだけでは面白さが足りないので自分でゲームを考えてみた。その名も、邪魔してくる人を避けながら部屋を脱出するゲーム。

ゲーム画面はこんな感じ。

f:id:t-keita:20210701222747g:plain:w300 

ゲームの設定を説明しよう。主役は A さんで、ゲーム画面上に黄色い文字で「A」って書かれているやつ。A さんは5マス × 5マスの部屋に閉じ込められていて、すぐにでも部屋から脱出したい。部屋の出口は赤い文字で「G」って書かれたところである。A さんは左上のマスからスタートして出口の G を目指す。これだけなら何の難しさもないゲームである。

ここで登場するのが A さんの移動を邪魔する B さんである。ゲーム画面上で黄緑色の文字で「B」って書かれたイヤなヤツである。B さんは A さんの右下からスタートして、その後は A さんの動きをマネし続ける。A さんが右に動けば B さんも右に動くし、A さんが下に動けば B さんも下に動く。ただし、B さんが壁際にいるときに A さんが壁の方向に動くとさすがに A さんのマネはできない。こういうときは B さんは動かない。一方で、A さんだけが壁際にいて壁の方向に動こうとすると、B さんはつられて同じ方向に動いてしまうものとする。要するにフェイントである。

お分かりのとおり、B さんがいる限りは A さんは G に到達できない。これではゲームとして成立しない。もうひと工夫しよう。

f:id:t-keita:20210701224448g:plain:w300

そこで導入するのが「柱」であり、ゲーム画面上にある白い長方形である。マスに柱があると、A さんも B さんもそのマスに移動できなくなる。もちろん、B さんが A さんの動きをマネして動こうとした先に柱がある場合は B さんは動けない。この柱があれば A さんは B さんを回避して出口 G まで辿り着けそうである。ゲームの設定は以上である。

なんとかして A さんを部屋から出してあげたい。もっというと 出口 G に到達するための最善の動き方 を A さんに学習してほしい。これが今回の最終的な目標である。

マルコフ決定過程として定式化する

このゲーム設定をマルコフ決定過程(MDP)として定式化し、それを解くことで A さんの最善の動きを求めよう。前回の記事でも載せた MDP のベイジアンネットワークを以下に示しておく。 s \in \mathcal{S} が状態で、a \in  \mathcal{A} が行動であり、 r \in \mathcal{R} が報酬である。

f:id:t-keita:20210630021132p:plain:w400

MDP として以下のように定式化した。MDP の構成要素については Wikipedia のページ を参照のこと。

  • 状態は A さんの座標と B さんの座標の組とした。すなわち、状態集合  \mathcal{S} は2人の座標のすべての組み合わせである。
  • 行動  \mathcal{A} は、A さんが「左」「右」「上」「下」に移動するという4つからなる。
  • 状態遷移確率 p(s' \mid s, a) は A さんと B さんの移動に対応する決定論的な遷移とした。たとえば、A さんと B さんの座標がそれぞれ (2, 3), (3, 4) である状態を  s として、行動  a を「右」とすると、遷移後の状態  s' は (3, 3), (4, 4) となる。これは p(s' \mid s, a) = 1 とすることで意図する遷移を表現できる。
  • 報酬関数  g(r \mid s, a) は、A さんが G 以外の場所にいるときは -1、G にいるときは 0 とした。この報酬は毎時間ごとに発生するので、A さんが G 以外の場所にいる時間が長くなるほど報酬が小さくなってしまう。この報酬の設計を通して、なるべく早く G に到達したほうがよいことを A さんに教えてあげている。
  • 目的関数は  G_t = R_{t+1} + R_{t+2} + R_{t+3} + \dots とした。

この MDP 問題を解くことで、A さんの移動を決める方策  \pi(a \mid s) を最適化する。すなわち最適方策を求める。A さんがどの状態にいても最善の行動が取れるようになるわけである。

価値反復法を実装する

前回の記事の通り、MDP を解くにはベルマン最適方程式を利用して価値関数を繰り返し更新すればよい。以下のそのコードを示す。言語は Kotlin である。

    // Initialize the value function
    val values = mutableMapOf<State, Int>().apply {
        allStates().forEach { this[it] = 0 }
    }
    // Value iteration using Bellman optimality equation
    do {
        var delta = 0
        allStates().forEach { state ->
            val vTmp = values[state]!!
            // update the value of a state
            values[state] = Action.all().maxOf { action ->
                reward(state, action) + values[nextState(state, action)]!!
            }
            delta = max(delta, abs(vTmp - values[state]!!))
        }
        println(delta)
    } while (delta > 0)

なんの工夫もなく価値反復法を実装している。このコードの全体は Gist 上に置いた

コードを実行してみる

それでは A さんの最終的な行動を見てみよう。頼むぞ、A さん。

柱を真ん中に置いた場合

柱を真ん中において MDP を解いた結果得られた最適方策は以下の通りである。

f:id:t-keita:20210701234124g:plain:w300

見事に A さんは B さんを柱に引っかけて出口 G に到達できている。A さんは寄り道することもなく、最短経路で G に到達している。自分の報酬設計は間違っていなかった。ありがとう、A さん。

柱を端に置いた場合

次は柱を端っこの方に置いたときの A さんの行動を見てみよう。さっきより一筋縄では行かなそうな雰囲気があるぞ。

f:id:t-keita:20210701224352g:plain:w300

こちらも見事に B さんを柱の陰に封印して出口 G に到達できてる。最初に B さんだけが左に移動したのは、A さんが左に移動しようとしたからである。いわばフェイントを使って B さんを壁際に移動させたのである。やるなぁ、A さん。

所感

今回は適当なゲーム設定を作ってマルコフ決定過程として定式化し、それを解くアルゴリズムを実装してみた。状態遷移をどう設定するか、報酬をどう設定するか、いざ考えてみると少し悩ましかった。それゆえけっこう勉強になった。

MDP を解くアルゴリズムの実装は全体で100行くらい。最適化計算には1秒もかからなかった。こんなちょっとしか書いてないのに一瞬でスマートに振る舞う A さんが誕生するのはスゴい。MDP として定式化できる問題は理論上なんでも同じように解けるんだから、この仕組み考えたベルマンさんマジで天才。

あと別の気付きとして MDP の解法が Dynamic Programming と呼ばれている意味を理解した。不動点を近似計算するのは DP の要件ではないので、これまでなぜこの種のアルゴリズムが DP と呼ばれているのが理解できていなかった。結局のところ、価値が最大になる行動を greedy に選べば最適方策が得られるという構造が DP になっている。これはダイクストラ法の構造とまったく同じ。価値関数の計算が DP なのではなく、価値関数の使い方が DP ってことなんだよなきっと。f:id:t-keita:20210702000458p:plain:w0

ベルマン方程式とは何か

今日はベルマン方程式について。強化学習を勉強していると最初の方に出てくるやつ。専門書を読みながら式変形を追うのに精いっぱいになっていると、いつの間にかその式がもつ意味を見失っちゃいますよね。

ってことで本記事では、細かい式変形は追わずに「ベルマン方程式の気持ち」を説明することを目指す。説明の流れを重視するので残念ながら厳密性は大いに欠いている。解の一意性とか極限の収束性とか。ちなみに、ベルマン方程式は強化学習だけでなく経済学などにも応用があるらしい。どんな経済の問題を解くのに使われるんだろうか。

本稿の目次は以下の通りである。

なお本稿の内容や数式の記法は Richard S. Sutton 氏らの Reinforcement Learning: An Introduction, second edition に基づいている。この本の表記は、確率変数は大文字、その他の変数は小文字で表記されているのがよい。

準備:Fixed-Point Iteration

まずは関数の不動点を求める数値計算のテクニックから紹介する。

関数 f不動点(fixed point)とは、 x = f(x) を満たすような x である。つまり、関数を適用しても値が変わらない入力が不動点である。

ある "よい性質" をもつ関数 f に対して不動点、すなわち方程式  x = f(x) の解を求めるテクニックとして Fixed-Point Iteration というやり方が知られている。Fixed-Point Iteration では、 n= 0, 1, 2, \dots に対して  x_{n+1} = f(x_{n}) を適用してゆくことで数列  x_0, x_1, x_2, \dots を得る。このとき  x_0 の値は適当に設定する。この数列の極限が  x に収束するとき、 x が関数 f不動点になることが知られている。要するに、関数 f を適当な  x_0 に適用しまくることで不動点が得られるわけである。非常に簡単な話である。ただし、この不動点の求め方ができるのは関数 f が "よい性質" をもつ場合のみであった。この条件について詳しくは Wikipedia のページ を参照のこと。

Wikipedia のページに載っている例を紹介する。ここでは方程式  x = \text{sin}(x) の解、すなわち関数  \text{sin}(x)不動点を求めることを考える。たとえば x_0 = 2 を初期値として  x_1 = \text{sin}(x_0), x_2 = \text{sin}(x_1), x_3 = \dots を求めてゆく。この数列の極限は 0 に収束するため、求めたい不動点は 0 であることが分かる。実際、 \text{sin}(0) = 0 でありこの解は正しい。この極限を求める過程を示したのが以下の図である。

f:id:t-keita:20210628232800p:plain:w600

青色の線で描かれているのが曲線  y = \text{sin}(x) と直線  y = x である。 \text{sin}(2) などの関数の出力を次の関数の入力にするために、直線  y = x によって y 軸の値を x 軸の値へ変換している。赤色の線が関数  \text{sin}(x) の入出力の値を示しており、徐々に不動点x = 0 上の点(原点)に近づいている。この計算はコンピュータで計算可能であり、関数適用の回数を増やすほど精度の高い解が得られる。

重要なこととして、ベルマン方程式はまさに  x = f(x) という形をもった方程式である。しかも、ベルマン方程式における関数 f は "よい性質" を持っている。したがって、関数  f を適用しまくるだけでベルマン方程式の解が求まるのである。

マルコフ決定過程

次にマルコフ決定過程Markov decision process)について説明する。とはいっても真面目に説明すると長くなるので、ここでは本稿の表記を明確にするくらいにしておく。このモデルの意味や背景など詳しくは世の中の書籍を参照のこと。

  • 状態  s \in \mathcal{S}, 行動  a \in \mathcal{A}, 報酬  r \in \mathcal{R}
  • 確率を出力とするダイナミクス関数  p(s', r \mid s, a)
  • 時刻  t における確率変数:状態  S_t, 行動  A_t, 報酬  R_t
  • 目的関数: G_t = R_{t+1} + \gamma R_{t+2} + \gamma^{2} R_{t+3} + \dots

典型的なマルコフ決定過程の問題は目的関数  G_t を最大化するための方策  \pi(s \mid a) を求めることである。また、方策  \pi に従ったときの状態  s の価値関数を  v_{\pi}(s) = \mathbb{E}[G_t \mid S_t = s] とする。

おまけとして、方策  \pi(s \mid a) を含めたマルコフ決定過程ベイジアンネットワークを以下に示す。状態  S_t において方策が出力した行動  A_t が、報酬  R_t や次の状態  S_{t+1} に影響を与えることがひと目で分かる。

f:id:t-keita:20210630021132p:plain:w400

方策から評価関数の計算(ベルマン期待方程式)

まずは方策反復評価(iterative policy evaluation)と呼ばれる、方策から価値関数を計算するアプローチについて説明する。

前提として価値関数  v_{\pi}(s) は方策  \pi に依存する。よい方策を選ぶほど価値関数の値は大きくなる。ただし、方策  \pi が与えられたときに価値関数  v_{\pi}(s) の値を計算する方法は自明でない。この計算方法を以下に示す。

まず以下の ベルマン期待方程式(Bellman equation) を考える。"方程式" という名前が付いているが未知数は含んでいないことに注意する。英語の "equation" という言葉は必ずしも未知数を含まない単なる "等式" くらいの意味である。

 v_{\pi}(s) = \sum_{a} \pi(a \mid s) \sum_{s'} \sum_{r} p(s', r \mid s, a)[r + \gamma v_{\pi}(s')]

この式は  v_{\pi}(s) の定義から導出されたものである。ポイントは、状態 s から期待される合計報酬  v_{\pi}(s) を、遷移先の状態  s' から期待される合計報酬  v_{\pi}(s') によって表現している点である。すなわち再帰的な構造になっている。

式をスッキリさせるためにベルマン期待方程式の右辺を  v の関数  \textbf{B}_{\pi}(v) で置換し以下の式を得る。

 v_{\pi} = \textbf{B}_{\pi}(v_{\pi})

ただし  (\textbf{B}_{\pi}(v))(s) = \sum_{a} \pi(a \mid s) \sum_{s'} \sum_{r} p(s', r \mid s, a)[r + \gamma v(s')]

プログラミングっぽく言うと、関数  \textbf{B}_{\pi}(v) は関数を引数とする高階関数である。関数  v の値域は状態集合であり有限なので、 v は数学的な関数というよりは Python連想配列Java の Map 型オブジェクトをイメージすればよい。連想配列の同値性が同じキーに対して同じ値がマッピングされているかどうかで判定されるように、 v_{\pi} = \textbf{B}_{\pi}(v_{\pi}) も同じやり方で関数としての同値が成立することを示している。

ここで  v_{\pi} は以下の方程式の解である。

 v = \textbf{B}_{\pi}(v)

なお、価値関数  v_{\pi}と異なる価値関数  v' はこの方程式の解にならないことが知られている。つまり、ベルマン期待方程式を一般化したこの方程式を解くことで方策  \pi に従う価値関数  v_{\pi}(s) を求めることができる。この方程式をなんとかして解きたい。

この方程式を見れば分かるが、 v は 関数  \textbf{B}_{\pi}不動点になっている。そのためこの方程式を解くには不動点の計算を行えばよい。具体的には、適当な  v_0 に対して関数  \textbf{B}_{\pi} を繰り返し適用し  v_1, v_2, \dots の値が収束すればそれが  v_{\pi} である。これにて、与えられた方策  \pi から価値関数  v_{\pi}(s) を導くことができた。めでたし。

以上が方策反復評価と呼ばれるアプローチである。ここまでの話を簡単に整理する。特定の値  v_{\pi} を求めたい状況であり、 v_{\pi} v_{\pi} = \textbf{B}(v_{\pi}) を満たすことは分かっていた。そこで、一般化した方程式  v = \textbf{B}(v) を解くことでその解として  v_{\pi} の値を知ろうとした。この方程式の解を不動点の計算によって求めた。

価値反復法(ベルマン最適方程式)

上記の方策反復評価では方策  \pi から価値関数  v_{\pi} を求める方法を示した。ここでマルコフ決定過程の問題を思い出すと、やりたいことは目的関数  G_t を最大化するための方策  \pi を求めることであった。そこで、この最適な方策を 最適方策 と呼び  \pi_{\ast} と表すことにする。また、最適方策  \pi_{\ast} のもとでの価値関数を  v_{\ast}(s) とする。

実は、最適な価値関数  v_{\ast}(s) さえ求まれば、そこから最適方策  \pi_{\ast} を作ることは簡単である。具体的には、各状態  s から遷移できる状態   s' のうち価値関数  v_{\ast}(s') が最大の状態に常に遷移するような方策が最適方策  \pi_{\ast} である。貪欲(greedy)によい状態を選ぶ方策がよいということである。

すなわち、最適方策の価値関数  v_{\ast}(s) を求めることがマルコフ決定過程問題を解くことに直結する。以下、 v_{\ast}(s) を求めるための価値反復法(value iteration)の流れを説明する。

まずはベルマン期待方程式を立ててみる。

 v_{\ast}(s) = \sum_{a} \pi_{\ast}(a \mid s) \sum_{s'} \sum_{r} p(s', r \mid s, a)[r + \gamma v_{\ast}(s')]

これで不動点を求める計算をすれば価値関数  v_{\ast} が求まりそうであるが1つ問題がある。それは、右辺の方策  \pi_{\ast} こそが今から求めたいものであり未知なのである。よってこの式からは価値関数  v_{\ast} を上手く求められない。困った。

そこで登場するのが ベルマン最適方程式(Bellman optimality equation) である。

 v_{\ast}(s) = \text{max}_{a} \sum_{s'} \sum_{r} p(s', r \mid s, a)[r + \gamma v_{\ast}(s')]

この式のポイントは最適方策  \pi_{\ast} を含まない  v_{\ast}(s)再帰的な式になっていることである。最適な価値関数を考えているので、確率的に状態遷移する方策の代わりに貪欲な方策、すなわち  \text{max} を使えばよいのである。非常によくできている。

あとは簡単で、不動点を求めるために関数  \textbf{B}_{\ast}(v) を用いて以下のように変形する。

 v_{\ast} = \textbf{B}_{\ast}(v_{\ast})

ただし  (\textbf{B}_{\ast}(v))(s) = \text{max}_{a} \sum_{s'} \sum_{r} p(s', r \mid s, a)[r + \gamma v_{\ast}(s')]

例によって  v_{\ast} は以下の方程式の解である。

 v = \textbf{B}_{\ast}(v)

これは 関数  \textbf{B}_{\ast}不動点なので、適当な  v_0 に関数  \textbf{B}_{\ast} を繰り返し適用することで  v_{\ast} に収束する。あとは貪欲に最適方策  \pi_{\ast} を求めればおしまい。マルコフ決定過程が解けた。めでたしめでたし。

ベルマン方程式の気持ちを整理する

最後に2つのベルマン方程式を以下の図に整理する。

f:id:t-keita:20210630023224p:plain:w600

ベルマン期待方程式は、方策  \pi と価値関数  v_{\pi} が満たす一般的な関係である。方策が既知であれば価値関数のみが未知数であるため方程式を解くことで価値関数を求めることができる。しかし最適方策を求めたいときは最適な価値関数も未知であるため、1つの方程式の中に2つの未知数があることになり、最適方策を求めるのには使えなかった。

一方で、ベルマン最適方程式は最適な価値関数のみが満たす特別な関係である。方程式を解くことで最適な価値関数を求めることができる。最適だからこそ方策を含まない等式が立てられるのである。max を含む方程式を解くのは難しそうに見えるが、実際は不動点計算をするだけなので難なく解を得られる。

だいたいの書籍ではこれらの方程式が同じような感じで説明されるが、それぞれがもつ意味合いも使い方も異なるので注意したい。

さいごに

ベルマン方程式は数学的な厳密性を守りつつ説明しようとするとだいたいお腹いっぱいな感じになってしまう。今回はストーリー重視なので、かなりいい加減な感じの説明になったのも仕方ないとしよう。まぁ厳密に語れる自信もないんだが。

書籍 Reinforcement Learning: An Introduction の英語版と翻訳版が手元にあるが、オリジナルの英語版のほうが読みやすいという。翻訳版読んでも全然あたまに入ってこない...。

追記

後日、ベルマン方程式を解いて最適な行動を求めるコードを書いてみた。

t-keita.hatenadiary.jp