TensorFlowチュートリアル:ローカルのMNIST datasetを使う方法

投稿者: | 2017-07-25

TensorFlowのチュートリアルとしてMNIST dataset(手書き数字データセット)の認識が用意されています。

基本はチュートリアルに沿って進めれば問題ないのですが、デフォルトでは学習/評価データであるMNIST datasetをインターネット経由で取得する仕様になっておりインターネットがない環境では動作しません。そこで、ローカルに保存したMNIST datasetを使う方法を調べたので紹介します。

なお、試した環境は

  • OS: CentOS 7
  • TensorFlow: 1.2.1
  • Python: 2.7

です。

MNIST datasetの格納

MNIST datasetはYann LeCun’s website等で公開されており以下の4つのファイルになります。

  • train-images-idx3-ubyte.gz: 学習用データ(画像)
  • train-labels-idx1-ubyte.gz: 学習用データ(ラベル)
  • t10k-images-idx3-ubyte.gz: 評価用データ(画像)
  • t10k-labels-idx1-ubyte.gz: 評価用データ(ラベル)

この4つのデータをローカルディレクトリに保存しておきます。ここでは/home/tfuser/dataset/mnistディレクトリに置いたとします。

mnist_softmax.pyの修正

チュートリアルのコードであるmnist_softmax.pyはデフォルトではインターネット経由でMNIST datasetを取得しようとします。

データを読み込むために用意されているinput_data.read_data_sets関数は

  • 第一引数で指定されたパスに学習/評価用データがあるか探す
  • 指定されたパスになければインターネット経由からデータ取得

という仕様になっているので第一引数に適切なパスを指定することでローカルにあるデータを使うようにできます。

具体的にはmnist_softmax.pyの35~37行目で

def main(_):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

となっているのを

def main(_):
  # Import data
  FLAGS.data_dir = '/home/tfuser/dataset/mnist/'
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

とすると/home/tfuser/dataset/mnist/配下のMNIST datasetを読みに行くようになりインターネット接続の無い環境でもチュートリアルを行うことができます。

スポンサーリンク


コメントを残す

メールアドレスが公開されることはありません。