A certain engineer "COMPLEX"

開発メモ その171 ChainerのモデルデータをONNXに変換してみる

Introduction


ChainerはPythonで動きます。
学習はPythonで良くてもシステムに組み込む際にPythonは聊か都合が悪いです。
なので、学習結果であるモデルファイルをPython以外の言語で扱えるように、共通データフォーマットであるONNX (Open Neural Network Exchange)に変換してみます。

How to


シンプルなMNISTのサンプルで生成されるモデルファイルを変換してみます。
まずはChainerの環境を構築。Windows上の仮想環境を想定。

環境構築

Introduction以前、CUDAを含むChainerのインストール手法を書きましたが、もう少しわかりやすく整理しました。ResolutionCUDA_PATHの確認Chainerで利用するCUDAを確認し...

を参考にし、下記のコマンドを実行します。


> python -m pip install --upgrade pip
> python -m pip install cupy==5.3.0
> python -m pip install cupy-cuda92==5.3.0
> python -m pip install chainer==5.3.0
> python -m pip install onnx-chainer

インストール後はCUDAを認識しているかを確認します。


> python -c "import chainer; chainer.print_runtime_info()"
Platform: Windows-10-10.0.17763-SP0
Chainer: 5.3.0
NumPy: 1.16.2
CuPy:
CuPy Version : 5.3.0
CUDA Root : C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v9.2
CUDA Build Version : 9020
CUDA Driver Version : 10000
CUDA Runtime Version : 9020
cuDNN Build Version : 7402
cuDNN Version : 7402
NCCL Build Version : None
NCCL Runtime Version : None
iDeep: Not Available

学習と変換

ソースをクローンし、train_mnist.pyを修正します。


> git clone https://github.com/chainer/chainer
> cd chainer
> git checkout v5.3.0
> cd examples\mnist

修正前はモデルファイルを出力しないためです。
修正は下記2箇所。

箇所1


#!/usr/bin/env python
import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

+ import os
+ from chainer import serializers
+ import onnx_chainer
+ import cupy as xp

箇所2


if args.resume:
# Resume from a snapshot
chainer.serializers.load_npz(args.resume, trainer)

# Run the training
trainer.run()

+ model_file = os.path.join(args.out, 'mnist.model')
+ onnx_file = os.path.join(args.out, 'mnist.onnx')
+ serializers.save_npz(model_file, model)
+ chainer.config.train = False
+ # ダミーデータ
+ x = xp.zeros((1, 1, 28, 28), dtype=xp.float32)
+ onnx_chainer.export(model.predictor, x, filename=onnx_file)

if __name__ == '__main__':
main()

下記を実行します。


> python train_mnist.py --gpu 0
GPU: 0
# unit: 1000
# Minibatch-size: 100
# epoch: 20

Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...
D:\Works\Python\envs\Chainer_5.3.0\lib\site-packages\chainer\training\extensions\plot_report.py:25: UserWarning: matplotlib is not installed on your environment, so nothing will be plotted at this time. Please install matplotlib to plot figures.

$ pip install matplotlib

warnings.warn('matplotlib is not installed on your environment, '
epoch main/loss validation/main/loss main/accuracy validation/main/accuracy elapsed_time
1 0.189437 0.0972368 0.94235 0.9701 8.70826
2 0.0736203 0.0654 0.977182 0.9789 11.5016
3 0.0483072 0.0657967 0.984433 0.9786 13.9685
4 0.0344889 0.0775722 0.988614 0.9779 16.4064
5 0.0307565 0.073484 0.990132 0.9808 18.8182
6 0.0247791 0.0676734 0.991932 0.981 21.2477
7 0.0187978 0.0756141 0.993648 0.9822 23.7888
8 0.0151227 0.0729054 0.994965 0.9836 26.2349
9 0.0162507 0.09349 0.994649 0.9805 28.6772
10 0.0186082 0.07771 0.994565 0.983 31.1291
11 0.0121478 0.0878187 0.996282 0.9829 33.575
12 0.0139532 0.103692 0.995515 0.9803 36.1334
13 0.0128902 0.0869704 0.996015 0.9812 38.5139
14 0.00887171 0.0960986 0.997182 0.9807 40.966
15 0.00950978 0.0941747 0.997016 0.9835 43.492
16 0.0117687 0.112775 0.996632 0.9805 45.8972
17 0.012322 0.109815 0.996782 0.9803 48.2932
18 0.00880352 0.10031 0.997383 0.9818 50.731
19 0.00929279 0.118552 0.997365 0.979 53.3326
20 0.00607901 0.0849677 0.99825 0.9852 55.8271

学習の結果、result\mnist.modelresult\mnist.onnxが出力されています。

コメントを残す

メールアドレスが公開されることはありません。

%d人のブロガーが「いいね」をつけました。