PyTorch の推論時のオーバーヘッドを調べてみた

論文投稿が終わって解放感がスゴい。勢いあまって PyTorch の推論の高速化について実験したのでその結果をここにまとめる。もちろん機械学習は勉強中なので手探りで実験しながら調べている感じである。アホみたいなことをしている可能性がある。

悩ましい(と思っている)ところ

ニューラルネットワーク用のライブラリといえば今は PyTorch やら TensofFlow やらが主流なんだろうけど、どちらも Python 向けに作られている。Python と聞くと遅いイメージがあるが、Python から API を叩くと裏では C++カリカリにチューニングされたプログラムが動いているので学習は効率的に行われる。別に不満はない。一方で推論時は、モデルに入力を食わせて出力が欲しいだけなのに、推論のたびに Python から C++ を呼び出すみたいなことが起こる。自分の機械学習の用途からすると1回の探索につき学習済みモデルを100万回くらい呼び出したいこともある。そうなると、言語をまたいでプログラムを実行するオーバヘッドが膨らみすぎるのである。処理の高速化のために機械学習を使いたいのに、機械学習を使うと実行時間が膨れ上がってしまう。これをどうにかしたい。そこでオーバヘッドを含む推論の高速化の方法を調べながら比較実験してみた。

適当なモデルを作る

まずは今回の実験用に適当なニューラルネットワークを定義する。以下のようにしてみた。

class MyNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 8)
        self.b = nn.Linear(8, 8)
        self.c = nn.Linear(8, 1)

    def forward(self, x):
        x = nn.ReLU()(self.a(x))
        x = nn.ReLU()(self.b(x))
        x = self.c(x)
        return x

隠れ層が1つだけのニューラルネットワークを定義している。そして適当に入力を用意して学習させてみる。PyTorch でよく見る感じの、確率的勾配法によって重みを調整する流れである。最後に、学習したパラメータを保存している。

(省略)
model.train()
for _ in range(n_epochs):
    for x, y in zip(X, Y):
        optimizer.zero_grad()
        predicted = model(x)
        loss = loss_fn(predicted, y)
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "model_py.pt")

コード全体は Gist に置いた。本記事のコードはすべて同じ Gist のページに置いている。今回は  f(x_1, x_2) = | x_1| + x_2 という関数を近似してみた。どんなパラメータであっても推論速度に差は出ないのでどうでもよいのだが。

Python でロードして呼び出す場合

ここからは推論にかかる時間を計測する。まずはおそらく推論時のもっとも素朴なやり方であるモデルをロードするやり方をやってみる。以下ではモデルをロードした上で100万回の推論を行っている。

size = 1_000_000
input_values = list(map(
    lambda _: [random.uniform(-10, 10),
               random.uniform(-10, 10)],
    range(size)))
X = torch.tensor(input_values)

model = MyNN()
model.eval()
model.load_state_dict(torch.load("model.pt"))

import time
start = time.time()
for x in X:
    model(x)
print(time.timeit() - start)

実行時間は 124 秒だった。これは遅い。日が暮れる。

Python で TorchScprit を呼び出す

次にやってみるのは TorchScript というテクノロジーである。モデルを Python の実行環境に依存しない形で保存したりロードしたりできるらしい。Python に非依存である代わりに TorchScript 言語専用のインタプリタが存在しているようなので、ある意味新たな言語を仲介してモデルにアクセスできるようになっていると思ったらよさそう。仕組みはよく分かっていないが JavaC++ からも TorchScript を叩きやすいらしい。もしかして Java から TorchScript を利用する場合は JVM 上で動く TorchScript のインタプリタが存在している?よく分からんが、まずは Python から使ってみよう。

TorchScript のチュートリアル を見ながら進める。モデルを TorchScript の形式で保存するのは簡単で、以下のようにモデルを保存する部分だけ書き換えればよい。

example = torch.tensor([random.uniform(-10, 10), random.uniform(-10, 10)])
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("../model_torchscript.pt")

モデルに適当な入力を食わせる tracing という方法でやってみた。どうやら TorchScript として保存するためには Python インタプリタを動かしてみることに意味があるらしい。動的解析みたいに見えるが、おそらくモデルの保存は安全に行われると思う。謎テクノロジーだ。

そして TorchScript をロードして100万回推論を実行してみた。

model = torch.jit.load("../model_torchscript.pt")
start = time.time()
for x in X:
    _ = model(x)
print(time.time() - start)

実行時間は 54 秒だった。同じ Python でもオーソドックスなモデルの保存方法より推論時の性能が向上することが分かった。なにか最適化が走っているのだと思われる。Python でモデルを保存するときは TorchScript でやるのが今後の主流になるのかな。

Java で TorchScript を呼び出す

TorchScript として保存したモデルは他のプログラミング言語から呼び出すのが簡単らしい。ってことで次は Java からモデルをロードして推論を実行してみた。このソースコード を参考に実装した。

実装した Java コードは以下である。

Module mod = Module.load("../../model_torchscript.pt");
 int size = 1000000;
(中略)
long start = System.currentTimeMillis();
for (int i = 0; i < size; i++) {
    Tensor inputs = Tensor.fromBlob(
            data[i],
            new long[]{2} // shape
    );
    mod.forward(IValue.from(inputs));
}
long elapsed = (System.currentTimeMillis() - start);
System.out.println(elapsed);

実行時間は 35 秒であった。わりと速い。

以下、Java から PyTorch を動かすためのメモ:

まず、Java プロジェクトの pom.xml に以下を追加した。これで Pytorch を動かすための Java API が叩けるようになる。

    <dependencies>
        <dependency>
            <groupId>org.pytorch</groupId>
            <artifactId>pytorch_java_only</artifactId>
            <version>1.11</version>
        </dependency>
    </dependencies>

また、裏で動く Pytorch として ここの指示 の通りに LibTroch (version 1.11.0) をダウンロードしてきて Java の実行時の VM 引数に -Djava.library.path=/(省略)/libtorch/lib を指定した。

C++ から TorchScript を呼び出す

同様に、C++ から TorchScript を呼び出してみる。

model = torch::jit::load(argv[1]);
int size = 1000000;
(中略)
std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
for (size_t i = 0; i < size; i++) {
    int64_t shape[] = {2};
    std::vector<torch::jit::IValue> inputs{torch::from_blob(data[i], shape)};
    model.forward(inputs);
}
std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl;

実行時間は 44 秒であった。Java より速いかと思ったら意外とそんなことはなかった。なぜだろう。

C++ API からモデルを呼び出す

いろいろ調べたところ、TorchScript すら遅いと感じる場合は PyTorch の C++ API を使うとよいらしい。PyTorch の裏側で実行される LibTorch というライブラリを直接叩けるようだ。TorchScript みたいにインタプリタを挟むのではなく、C++ のピュアなデータ構造として表現されたニューラルネットワークを使って推論を実行できるってことだろう。

今回は学習から C++ API を使ってやる必要がある。このチュートリアル を見ながらやってみた。コードは以下のような感じである。

int main()
{
    (中略)
    for (size_t epoch = 1; epoch <= n_epochs; epoch++) {
        for (size_t i = 0; i < size / batch_size; i++) {
            torch::Tensor x = X[i];
            torch::Tensor y = Y[i];
            optimizer.zero_grad();
            torch::Tensor predicted = model->forward(x).reshape({batch_size});
            torch::Tensor loss = torch::mse_loss(predicted, y);
            loss.backward();
            optimizer.step();
        }
    }
    torch::save(model, "../model_cpp.pt");
}

基本的には Python から PyTorch を使う時と同じような流れで C++ でプログラムが組めるようになっている。

そして例によって100万回推論を実行してみる。

int main()
{
    c10::InferenceMode guard(true);
    auto model = std::make_shared<MyNN>();
    torch::load(model, "../model_cpp.pt");
    (中略)
    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    for (size_t i = 0; i < size; i++) {
        model->forward(input_values[i]);
    }
    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - begin).count() << "[ms]" << std::endl;
}

実行時間は 9 秒であった。めっちゃ速い。これだ。

まとめ

各設定で推論を100万回実行したときの実行時間を以下にまとめる。

設定 実行時間
Python 124 秒
TorchScript from Python 54 秒
TorchScript from Java 35 秒
TorchScript from C++ 44 秒
C++ API 9 秒

推論時のオーバヘッドを減らせる C++ API が最強だけど、Java から TorchScript を呼び出すくらいでも悪くないケースも多いかもしれない。Rust から C++ API を叩けるラッパーとかもあるっぽいので、そのあたりを触ってみたい。C++ よりは Rust でプログラム書きたいので。

github.com

おしまい。