PytorchでDeepLeaning実装 ~クラス分類1~
ここ最近、AIによる技術の進化が激しく、深層学習も非常にトレンディーになってます。
リアルな画像生成や人物認識、姿勢推定によるスポーツ指導への応用等々...
私もニュースを見て非常に興味が出てきたので、個人でプログラミングによる深層学習を学んでそれをアウトプットしていきたいと思います。
深層学習を始めるにあたって
まずは、定義等の確認から。深層学習とは機械学習の一部です。
昔の機械学習といえば、情報を人間が取捨選択して、それを機械 (プログラム) で学習するのが一般的でした。
例えば、画像から犬と猫を認識したい場合は耳や口といった部分を切り取って、機械に入力し、その違いを学習させます。
しかし、この方法では人間が選択する必要があるので非常に面倒です。
そこで、画像の特徴量 (何をもって猫とするか犬とするかといった判断基準) も機械に学習させるのが深層学習です。
深層学習では畳み込みニューラルネットワークが用いられ、畳み込み演算等を何度も行うことで特徴量抽出を行い、クラス分類や物体検出を行います。
私もいまだに理解できていない部分が多いですが、画像処理でいうところのエッジ抽出や膨張処理、
ノイズ処理みたいなことを何度も行い、猫っぽい特徴や犬っぽい特徴を抽出するといったイメージでしょうか?
簡単に説明しましたが、理論は自分で勉強したほうが正確ですし理解も早いので、本やweb等で勉強するのがいいと思います。
一般的な深層学習の勉強はこちらの本がおススメです。
ただし、この本の1は非常に勉強になったのですが、3の購入は正直あまりお勧めしません (2は未購入) 。
3は深層学習で用いるフレームワークを自分で一から実装するというテーマなので、アルゴリズム等のより深い理解には役立ちますが、実践的な力を身に着けるには不向きだからです。
クラス分類や物体検出等のプログラムを動かしてみたいなら、1で基礎を勉強した後に参考書等やwebにあるソースコードを参考にすると実践的な力が身につくと思います。
準備
まず個人で深層学習をやるにあたって、準備するべきことやものは以下の通りです。
- プログラミング環境
- GPU
- 知識
- 学習がうまくいかなくてもめげない根性
プログラミング環境
深層学習の実装自体はC++とかの言語でも可能なようですが、フレームワークの選択肢・困ったときの情報量の多さ・
実装がweb上にあふれている、という理由でPythonがいいと思います。
私は基本的にVSCode+WSLを用いて実装し、画像等を手軽に見たいときはAnacondaでも実装しております。
VSCodeのほうが処理が早い (気がする) ので、おススメはVSCodeです。
また、Pythonの深層学習フレームワークはたくさんあります。
GitHubなんかを見てると主に使われるのは、PyTorch・Tensorflow・scikit-learnの3つでしょうか?
それぞれの"超"個人的な感想は以下の通りです。
GPU
計算高速化のためには必須です。
PCを理解していない私のようなものには判断が難しいのですが、家電量販店等で表示されているものはあくまで画像出力用であり、深層学習に使えるものではないらしいです。
新しくPCを購入する場合は、VRAM (GPUのメモリ) がきちんと記載されているやつを購入しましょう。
VRAMはあればあるほど安心ですが、最低でも4GBくらいは欲しいです。
PCでゲームをやる場合もぬるぬる動いてくれるので、深層学習への熱意を失っても無駄にはならないような気がします。
お金等の問題で購入できない場合は、Google Colabを使うのがよいでしょう。
Googleさんが提供している基本無料かつ個人でGPU使用が可能なものです。
ただし、一定時間ごとにログインしているか確認されるうえ、性能はランダムなようなので、使い勝手はよくないです。
私は、6時間の計算が4時間で中断された経験から、PC購入を決意しました。
知識
わからないことは調べましょう。大抵のことはネットで解決できます。
ただし、ちゃんとした基本を身に着けるという意味でも、基礎本は1冊ほしいです。
PyTorch実装場合は、この本が詳しい実装が載っているのでお勧めです。
論文実装の場合はGitHubのISSUEやREADMEを見ることになると思うので、英語も怖がらずに読む・翻訳する勇気が必要です
クラス分類の実装
今回は、Python+PyTorchを用いて、PyTorchのクラス分類チュートリアルを行いました。
PyTorchチュートリアル (英語)
pytorch.org
チュートリアルを参考に自分で実装したプログラムはGitHubへとアップロードしました。
GitHub初心者なので不安ですが、多分下記のやつをコピペで動くと思います。
github.com
特に変更するべきは79行目の
for epoch in range(10000):
です。
過学習するかを確認するため、こんなに大きなエポック数にしておりますが、100くらいで十分です。
このチュートリアルではThe cifar10と呼ばれる10クラス (飛行機、車、鳥等) 分類を学習し、
テストデータに対して各クラスどのくらいの精度で分類できたかを出力します。
おおよそですが、100エポックの場合は全クラス平均60%くらいの精度で分類できます。
ソースコードの詳しい説明は次回