TensorFlowのチュートリアルとしてMNIST dataset(手書き数字データセット)の認識が用意されています。
- MNIST For ML Beginners: TensorFlow公式サイトのチュートリアル(英語)
- TensorFlow : Get Started : ML 初心者向けの MNIST: クラスキャット社が公開しているチュートリアルの翻訳&解説。
基本はチュートリアルに沿って進めれば問題ないのですが、デフォルトでは学習/評価データであるMNIST datasetをインターネット経由で取得する仕様になっておりインターネットがない環境では動作しません。そこで、ローカルに保存したMNIST datasetを使う方法を調べたので紹介します。
なお、試した環境は
- OS: CentOS 7
- TensorFlow: 1.2.1
- Python: 2.7
です。
MNIST datasetの格納
MNIST datasetはYann LeCun’s website等で公開されており以下の4つのファイルになります。
- train-images-idx3-ubyte.gz: 学習用データ(画像)
- train-labels-idx1-ubyte.gz: 学習用データ(ラベル)
- t10k-images-idx3-ubyte.gz: 評価用データ(画像)
- t10k-labels-idx1-ubyte.gz: 評価用データ(ラベル)
この4つのデータをローカルディレクトリに保存しておきます。ここでは/home/tfuser/dataset/mnistディレクトリに置いたとします。
mnist_softmax.pyの修正
チュートリアルのコードであるmnist_softmax.pyはデフォルトではインターネット経由でMNIST datasetを取得しようとします。
データを読み込むために用意されているinput_data.read_data_sets関数は
- 第一引数で指定されたパスに学習/評価用データがあるか探す
- 指定されたパスになければインターネット経由からデータ取得
という仕様になっているので第一引数に適切なパスを指定することでローカルにあるデータを使うようにできます。
具体的にはmnist_softmax.pyの35~37行目で
def main(_): # Import data mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
となっているのを
def main(_): # Import data FLAGS.data_dir = '/home/tfuser/dataset/mnist/' mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
とすると/home/tfuser/dataset/mnist/配下のMNIST datasetを読みに行くようになりインターネット接続の無い環境でもチュートリアルを行うことができます。