読者です 読者をやめる 読者になる 読者になる

福岡は今日も雨

情報系大学生のブログ。主に技術,音楽について。

(2/17)AlexNetを作ってcifar-10をしていてtupleDatasetにできなくてハマっている

曲者だじぇ
Cifar-10をAlexnetで分類するコードを書いているのだが,どうもなんだか動かない
Chainerではcifar-10を手にいれるchainer.datasets.tuple_dataset.TupleDataset型で手にいれるメソッドがあるのは知ってるけれど
今回は次に書くプログラムのためにそれを使わずにすることにした。

train_tup, test_tup = chainer.datasets.get_mnist()
# get_mnistでdataset.TupleDataset型のものを持ってくるメソッド

(array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.01176471,  0.07058824,  0.07058824,
         0.07058824,  0.49411768,  0.53333336,  0.68627453,  0.10196079,
         0.65098041,  1.        ,  0.96862751,  0.49803925,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.11764707,  0.14117648,  0.36862746,  0.60392159,
         0.66666669,  0.99215692,  0.99215692,  0.99215692,  0.99215692,
         0.99215692,  0.88235301,  0.67450982,  0.99215692,  0.94901967,
         0.76470596,  0.25098041,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.19215688,  0.9333334 ,
         0.99215692,  0.99215692,  0.99215692,  0.99215692,  0.99215692,
         0.99215692,  0.99215692,  0.99215692,  0.98431379,  0.36470589,
         0.32156864,  0.32156864,  0.21960786,  0.15294118,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.07058824,  0.8588236 ,  0.99215692,  0.99215692,
         0.99215692,  0.99215692,  0.99215692,  0.77647066,  0.71372551,
         0.96862751,  0.9450981 ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.3137255 ,  0.61176473,  0.41960788,  0.99215692,  0.99215692,
         0.80392164,  0.04313726,  0.        ,  0.16862746,  0.60392159,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.05490196,
         0.00392157,  0.60392159,  0.99215692,  0.35294119,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.54509807,
         0.99215692,  0.74509805,  0.00784314,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.04313726,  0.74509805,  0.99215692,
         0.27450982,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.13725491,  0.9450981 ,  0.88235301,  0.627451  ,
         0.42352945,  0.00392157,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.31764707,  0.94117653,  0.99215692,  0.99215692,  0.4666667 ,
         0.09803922,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.17647059,
         0.72941178,  0.99215692,  0.99215692,  0.58823532,  0.10588236,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.0627451 ,  0.36470589,
         0.98823535,  0.99215692,  0.73333335,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.97647065,  0.99215692,
         0.97647065,  0.25098041,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.18039216,  0.50980395,
         0.71764708,  0.99215692,  0.99215692,  0.81176478,  0.00784314,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.15294118,
         0.58039218,  0.89803928,  0.99215692,  0.99215692,  0.99215692,
         0.98039222,  0.71372551,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.09411766,  0.44705886,  0.86666673,  0.99215692,  0.99215692,
         0.99215692,  0.99215692,  0.78823537,  0.30588236,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.09019608,  0.25882354,  0.83529419,  0.99215692,
         0.99215692,  0.99215692,  0.99215692,  0.77647066,  0.31764707,
         0.00784314,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.07058824,  0.67058825,  0.8588236 ,
         0.99215692,  0.99215692,  0.99215692,  0.99215692,  0.76470596,
         0.3137255 ,  0.03529412,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.21568629,  0.67450982,
         0.88627458,  0.99215692,  0.99215692,  0.99215692,  0.99215692,
         0.95686281,  0.52156866,  0.04313726,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.53333336,  0.99215692,  0.99215692,  0.99215692,
         0.83137262,  0.52941179,  0.51764709,  0.0627451 ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ], dtype=float32), 5)

これがおそらくiteratorとして受け付ける形で,データとラベルのタプル。
んでこれを作ればいいんでしょ、って方向性で元々のデータをいじってみたのだが

def divive_l_t():
    ary = unpickle()
    data_ary = []
    label_ary = []
    for dict in ary:
        for data, label in zip(dict['data'], dict['labels']):
            data_ary.append((data / 255).astype(np.float32))
            label_ary.append(label)

    return data_ary, label_ary

def divive_tr_t():
    data_ary, label_ary = divive_l_t()
    train = []
    test = []
    for i in range(50000):
        if i < 39999:
            # ひとつひとつのデータについて,型をchainerで決められたtupleの型に合わせていく。
            tup = (data_ary[i], label_ary[i])
            # print(len(tup))
            train.append(tup)
        else:
            tup = (data_ary[i], label_ary[i])
            test.append(tup)

    train_tup = chainer.datasets.tuple_dataset.TupleDataset(train)
    test_tup = chainer.datasets.tuple_dataset.TupleDataset(test)

    return train_tup, test_tup

こうしてみても、

((array([ 0.23137255,  0.16862746,  0.19607843, ...,  0.54901963,
          0.32941177,  0.28235295], dtype=float32), 6),)

という奇妙なデータ構造になっちゃうんだよね。
この辺りをどう解決すればいいのか...もう少しchainer.datasets.tuple_dataset.TupleDataset()について調べておくべきか..
以下全コード。ちょっと違うけれど実験中なので許して。出来次第しっかり書く予定

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
import numpy as np
import _pickle as cpickle
import os

def unpickle():
    base_path = os.path.dirname(os.path.abspath(__file__))
    cifar_path = os.path.normpath(os.path.join(base_path, '../cifar-10-batches-py'))
    ary = []
    for i in range(1, 6):
        file_path = cifar_path +'/data_batch_' + str(i)
        fo = open(file_path, 'rb')
        tmp_dic = cpickle.load(fo, encoding='latin1')
        ary.append(tmp_dic)
        fo.close()

    return ary

def divive_l_t():
    ary = unpickle()
    data_ary = []
    label_ary = []
    for dict in ary:
        for data, label in zip(dict['data'], dict['labels']):
            data_ary.append((data / 255).astype(np.float32))
            label_ary.append(label)

    return data_ary, label_ary

def divive_tr_t():
    data_ary, label_ary = divive_l_t()
    train = []
    test = []
    for i in range(50000):
        if i < 39999:
            # ひとつひとつのデータについて,型をchainerで決められたtupleの型に合わせていく。
            tup = (data_ary[i], label_ary[i])
            # print(len(tup))
            train.append(tup)
        else:
            tup = (data_ary[i], label_ary[i])
            test.append(tup)

    train_tup = chainer.datasets.tuple_dataset.TupleDataset(train)
    test_tup = chainer.datasets.tuple_dataset.TupleDataset(test)

    return train_tup, test_tup

class AlexNet(chainer.Chain):

    input_size = 227

    def __init__(self):
        super(AlexNet, self).__init__(
            conv1 = L.Convolution2D(None, 96, 11, stride=4),
            conv2 = L.Convolution2D(None, 256, 5, pad=2),
            conv3 = L.Convolution2D(None, 384, 3, pad=1),
            conv4 = L.Convolution2D(None, 384, 3, pad=1),
            conv5 = L.Convolution2D(None, 256, 3, pad=1),
            fc6 = L.Linear(None, 4096),
            fc7 = L.Linear(None, 4096),
            fc8 = L.Linear(None, 10))

    def __call__(self, x):
        h = F.max_pooling_2d(F.local_response_normalization(F.relu(self.conv1(x))), 3, stride=2)
        h = F.max_pooling_2d(F.local_response_normalization(F.relu(self.conv2(h))), 3, stride=2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
        h = F.dropout(F.relu(self.fc6(h)))
        h = F.dropout(F.relu(self.fc7(h)))
        h = F.relu(self.fc8(h))

        return h

# モデルのインスタンス化
model = L.Classifier(AlexNet())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# 訓練データとテストデータに分割
# 今回はchainerのTupleDatasetのメソッドを使っているが,うまくキャストしているのかが心配
# TODO: ここのtupleのデータ構造(型)と,もともとのデータ構造(型)を合わせないとうまくいかないと思う
train, test = divive_tr_t()
# train_tup, test_tup = chainer.datasets.get_mnist()
# mnistとデータを合わせればうまくいくはず

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (100, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()