pytorchで機械学習(導入編)
唐突に機械学習がやりたくなったのでpytorchで遊んでみることにした。
pytorchとは
pytorchはchainerからforkされた機械学習のフレームワークらしい。
かなり抽象化されており比較的少ないコーディング量で機械学習が楽しめる。
公式
インストール
condaの環境が推奨されているようなのでそれに従うことにする。
condaのインストールについては割愛する。
pipとnumpyは最新にしておいたほうがよさげ。
https://pytorch.org/get-started/locally/
公式ページでOSやら自分の環境を選択すると適切なコマンドが表示されるのでそれに従う。
今回は以下の環境で実行した。
OS:Windows10 Pro
python3.6.1(conda上)
package:pip
Cuda:None
pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-win_amd64.whl
pip install torchvision
インストールが終わったらシェル上で
import torch
print(torch.__version__)
を実行してインストールができているか確認しよう。
pytorchについて
必要に応じて、numpy、scipy、CythonなどのPythonパッケージを再利用してPyTorchを拡張することができる模様。
パッケージ | 説明 |
---|---|
torch | NumPyのような強力なGPUサポートを備えたTensorライブラリ |
torch.autograd | torch内のすべての微分可能なテンソル操作をサポートするテープベースの自動微分ライブラリ |
torch.nn | 軟性を最大限に高めるように設計された自動微分機能と統合されたニューラルネットワークライブラリ |
torch.optim | SGD、RMSProp、LBFGS、Adamなどの標準的な最適化手法を使用してtorch.nnとともに使用する最適化パッケージ |
torch.multiprocessing | Pythonマルチプロセッシングではなく、プロセス全体にわたるトーチ・テンソルの魔法のメモリ共有が可能。データロードとホグワルドトレーニングに役立つ。 |
torch.utils | DataLoader、Trainerおよびその他のユーティリティ関数 |
torch.legacy.nn | 後方互換性の理由からトーチから移植されたレガシーコード、新しく始めるならまず使わないほうが良い。 |
torch.legacy.optim | 上に同じく |
学習済みのt7ファイルを読み込む
今回はちょっと遊んでみるだけなので自分でゴリゴリ計算(学習)するのではなくすでに学習済みのモデルを利用して動かしてみることにする。
拝借した学習モデルは、早稲田大学の研究者が作成したものです。
ここの記事に割と詳しく載ってます。
https://gigazine.net/news/20170501-globally-locally-consistent-image-completion/
画像の欠損した部分をかなりきれいに補完してくれるらしいです。
git clone https://github.com/satoshiiizuka/siggraph2017_inpainting.git
でファイルをクローンして中にあるdownload_model.sh
を実行する。
ダウンロードが開始されるのでしばらく待つと学習済みのモデルファイル(completionnet_places2.t7)が手に入る。
モデルファイルが手に入ったら早速読み込ませてみる。
from torch.utils.serialization import load_lua
model = load_lua("./completionnet_places2.t7")
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 608, in load_lua
return reader.read()
File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 595, in read
return self.read_table()
File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 523, in wrapper
result = fn(self, *args, **kwargs)
File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 572, in read_table
v = self.read()
File "C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\utils\serialization\read_lua_file.py", line 598, in read
"corrupted.".format(typeidx))
torch.utils.serialization.read_lua_file.T7ReaderException: unknown type id 1054128068. The file may be corrupted.
エラーを吐かれた。
え?ファイル壊れてんの!?
そんなバナナ!
windowsの場合はload_luaの引数にlong_size=8をつけておかないと上のようなエラーを吐れるらしい…(そんなの聞いてない)
というわけで改めて
from torch.utils.serialization import load_lua
model = load_lua("./completionnet_places2.t7", long_size=8)
print(model)
{'mean': tensor([0.4560, 0.4472, 0.4155]), 'model': nn.Sequential {
[input -> (0) -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8)
-> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17)
-> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> (25) -> (26)
-> (27) -> (28) -> (29) -> (30) -> (31) -> (32) -> (33) -> (34) -> (35)
-> (36) -> (37) -> (38) -> (39) -> (40) -> (41) -> (42) -> (43) -> (44)
-> (45) -> (46) -> (47) -> (48) -> (49) -> output]
(0): nn.SpatialConvolution(4 -> 64, 5x5, 1, 1, 2, 2)
(1): nn.SpatialBatchNormalization
(2): nn.ReLU
(3): nn.SpatialConvolution(64 -> 128, 3x3, 2, 2, 1, 1)
(4): nn.SpatialBatchNormalization
(5): nn.ReLU
(6): nn.SpatialConvolution(128 -> 128, 3x3, 1, 1, 1, 1)
(7): nn.SpatialBatchNormalization
(8): nn.ReLU
(9): nn.SpatialConvolution(128 -> 256, 3x3, 2, 2, 1, 1)
(10): nn.SpatialBatchNormalization
(11): nn.ReLU
(12): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 1, 1, 1, 1)
(13): nn.SpatialBatchNormalization
(14): nn.ReLU
(15): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 1, 1, 1, 1)
(16): nn.SpatialBatchNormalization
(17): nn.ReLU
(18): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 2, 2, 2, 2)
(19): nn.SpatialBatchNormalization
(20): nn.ReLU
(21): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 4, 4, 4, 4)
(22): nn.SpatialBatchNormalization
(23): nn.ReLU
(24): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 8, 8, 8, 8)
(25): nn.SpatialBatchNormalization
(26): nn.ReLU
(27): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 16, 16, 16, 16)
(28): nn.SpatialBatchNormalization
(29): nn.ReLU
(30): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 1, 1, 1, 1)
(31): nn.SpatialBatchNormalization
(32): nn.ReLU
(33): nn.SpatialDilatedConvolution(256 -> 256, 3x3, 1, 1, 1, 1, 1, 1)
(34): nn.SpatialBatchNormalization
(35): nn.ReLU
(36): nn.SpatialFullConvolution(256 -> 128, 4x4, 2, 2, 1, 1)
(37): nn.SpatialBatchNormalization
(38): nn.ReLU
(39): nn.SpatialConvolution(128 -> 128, 3x3, 1, 1, 1, 1)
(40): nn.SpatialBatchNormalization
(41): nn.ReLU
(42): nn.SpatialFullConvolution(128 -> 64, 4x4, 2, 2, 1, 1)
(43): nn.SpatialBatchNormalization
(44): nn.ReLU
(45): nn.SpatialConvolution(64 -> 32, 3x3, 1, 1, 1, 1)
(46): nn.SpatialBatchNormalization
(47): nn.ReLU
(48): nn.SpatialConvolution(32 -> 3, 3x3, 1, 1, 1, 1)
(49): nn.Sigmoid
}}
今度は無事に読み込めた。
疲れたので今回はこの辺でつづきはまた次回