福岡は今日も雨

情報系大学生のブログ。主に技術,音楽について。

AlexNetを(とりあえず)CPUモードで動かす

numpyとかへの慣れの少なさか,データセットを用意してchainerに流すところで詰まっている感じ
というかここが一番面倒くさいかも
chainerからとってこれるメソッドで呼び出すと以下

class AlexNet(chainer.Chain):

    input_size = 227

    def __init__(self):
        super(AlexNet, self).__init__(
            conv1 = L.Convolution2D(None, 96, 11, stride=4),
            conv2 = L.Convolution2D(None, 256, 3, pad=2),
            conv3 = L.Convolution2D(None, 384, 3, pad=1),
            conv4 = L.Convolution2D(None, 384, 3, pad=1),
            conv5 = L.Convolution2D(None, 256, 3, pad=1),
            fc6 = L.Linear(None, 4096),
            fc7 = L.Linear(None, 4096),
            fc8 = L.Linear(None, 10))

    def __call__(self, x):
        h = F.max_pooling_2d(F.local_response_normalization(F.relu(self.conv1(x))), 3, stride=2)
        h = F.max_pooling_2d(F.local_response_normalization(F.relu(self.conv2(h))), 3, stride=2)
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        h = F.max_pooling_2d(F.relu(self.conv5(h)), 3, stride=2)
        h = F.dropout(F.relu(self.fc6(h)))
        h = F.dropout(F.relu(self.fc7(h)))
        h = F.relu(self.fc8(h))

        return h

# モデルのインスタンス化
model = L.Classifier(AlexNet())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# 訓練データとテストデータに分割
train, test = chainer.datasets.get_cifar10()
# print(len(imageData[0][1]))
# threshold = np.int32(len(imageData)/8*7)
# train = tuple_dataset.TupleDataset(imageData[0:threshold], labelData[0:threshold])
# test = tuple_dataset.TupleDataset(imageData[threshold:], labelData[threshold:])

# print(len(train[0][0]))

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100, repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (100, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())

trainer.run()

頑張った跡が見られる
ちょっと今から普通の画像でもなんとかできるようになんとかします。