Keras のテキスト分類の結果をscikit learn のmetricsで評価

前回、scikit-learnの GridSearch をおさらいした。今回は、前々回のコードを修正し、同じscikit-learnのデータを使ってKeras(TensorFlowバックエンド)での標準的実装で精度を出した。精度算出のメトリックを合わせるため、scikit-learnで提供されているmetrics系の関数を使って計算している。

結論からいうと、scikit-learnのlinear svmと、KerasでのMLPの精度は、全体平均精度は変わらず。
というか、今回は、全結合1層+活性化層1層+Dropoutしかないので、Deep learningじゃない。なので、scikit-learnのパーセプトロン実装にDropoutを加えただけという感じ。そういう意味では、パーセプトロンと比較して、Dropoutの効果として見るべきかもしれない。


実装 特徴量 precision recall f1-score
Keras(DeepLearning-512-ReLU-Dropout0.5, batch=1024) 単語カウント(記号除去) 0.71 0.70 0.70
Scikit-Learn(LinearSVM-GridSearch(2乗誤差-L2正則化)) TFIDF+正規化(ストップワード除去) 0.71 0.70 0.70
Scikit-Learn(Perceptron-50イテレーション) 単語カウント(ストップワード除去) 0.65 0.64 0.64


20 newsgroups dataset

  • 11314 documents - 13.782MB (training set)
  • 7532 documents - 8.262MB (test set)


  • Keras:特徴量の次元数 105373 (KerasのTokenizerはデフォルトで半角の記号群を除去)
  • Scikit-Learn:特徴量の次元数 101322(TfidfVectorizerでstop_words='english'を指定)



from __future__ import print_function
import sys
import os

from optparse import OptionParser
from time import time

import numpy as np

from sklearn.datasets import fetch_20newsgroups
from sklearn import metrics

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.utils import np_utils
from keras.preprocessing.text import Tokenizer

batch_size = 1024
nb_epoch = 5

def size_mb(docs):
    """Return the combined UTF-8 encoded size of *docs* in megabytes (10**6 bytes)."""
    total_bytes = 0
    for text in docs:
        # Measure encoded bytes, not characters, so multi-byte text counts fully.
        total_bytes += len(text.encode('utf-8'))
    return total_bytes / 1e6

def trim(s):
    """Shorten *s* so it fits an 80-column terminal line.

    Strings of 80 characters or fewer are returned unchanged; longer ones
    are cut to their first 77 characters followed by "...".
    """
    if len(s) > 80:
        # 77 kept characters + 3 dots = exactly 80 columns.
        return s[:77] + "..."
    return s

def main(argv=None):
    """Train a single-hidden-layer Keras MLP on 20 newsgroups and evaluate it.

    The documents are turned into per-document word-count vectors with the
    Keras ``Tokenizer``; the predictions on the test split are scored with
    scikit-learn's ``metrics`` functions so the numbers are comparable with
    scikit-learn classifiers.

    Args:
        argv: Command line arguments without the program name. Defaults to
            ``sys.argv[1:]``. Recognised flags: ``--report``,
            ``--confusion_matrix``, ``--all_categories``, ``--filtered``.
    """
    # __version__ and __updated__ are expected to be module-level names; they
    # are not visible in this excerpt of the file.
    program_version = "v%f" %  __version__
    program_build_date = "%s" % __updated__

    program_version_string = '%%prog %s (%s)' % (program_version, program_build_date)
    program_longdesc = 'Keras MLP text classification evaluated with scikit-learn metrics'
    program_license = "Copyright 2017 mzi                                            \
                Licensed under the Apache License 2.0\n"

    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    op = OptionParser(version=program_version_string, epilog=program_longdesc, description=program_license)
    op.add_option("--report",
                  action="store_true", dest="print_report",
                  help="Print a detailed classification report.")
    op.add_option("--confusion_matrix",
                  action="store_true", dest="print_cm",
                  help="Print the confusion matrix.")
    op.add_option("--all_categories",
                  action="store_true", dest="all_categories",
                  help="Whether to use all categories or not.")
    op.add_option("--filtered",
                  action="store_true", dest="filtered",
                  help="Remove newsgroup information that is easily overfit: "
                       "headers, signatures, and quoting.")

    # process options
    (opts, args) = op.parse_args(argv)

    if opts.all_categories:
        categories = None
    else:
        # NOTE(review): the default subset was lost from the source; these four
        # categories follow the scikit-learn text classification example.
        # Confirm against the original script.
        categories = [
            'alt.atheism',
            'talk.religion.misc',
            'comp.graphics',
            'sci.space',
        ]

    #Remove headers
    if opts.filtered:
        remove = ('headers', 'footers', 'quotes')
    else:
        remove = ()

    # MAIN BODY #
    print("Loading 20 newsgroups dataset for categories:")
    print(categories if categories else "all")
    data_train = fetch_20newsgroups(subset='train', categories=categories,
                                    shuffle=True, random_state=42,
                                    remove=remove)
    data_test = fetch_20newsgroups(subset='test', categories=categories,
                                   shuffle=True, random_state=42,
                                   remove=remove)
    print('data loaded')

    # order of labels in `target_names` can be different from `categories`
    target_names = data_train.target_names

    data_train_size_mb = size_mb(data_train.data)
    data_test_size_mb = size_mb(data_test.data)
    print("%d documents - %0.3fMB (training set)" % (
        len(data_train.data), data_train_size_mb))
    print("%d documents - %0.3fMB (test set)" % (
        len(data_test.data), data_test_size_mb))
    # Targets are 0-based class indices, so the class count is max + 1.
    nb_classes = np.max(data_train.target) + 1
    print(nb_classes, 'classes')

    print('Vectorizing sequence data...')
    tokenizer = Tokenizer(nb_words=None, lower=True, split=' ', char_level=False)
    # The vocabulary is built from the training split only; the test split is
    # encoded with that same vocabulary.
    tokenizer.fit_on_texts(data_train.data)
    X_train = tokenizer.texts_to_matrix(data_train.data, mode='count')
    X_test = tokenizer.texts_to_matrix(data_test.data, mode='count')
    print('X_train shape:', X_train.shape)
    print('X_test shape:', X_test.shape)
    # Input width of the network = number of columns in the count matrix.
    max_words = X_train.shape[1]

    print('Convert class vector to binary class matrix (for use with categorical_crossentropy)')
    Y_train = np_utils.to_categorical(data_train.target, nb_classes)
    Y_test = np_utils.to_categorical(data_test.target, nb_classes)
    print('Y_train shape:', Y_train.shape)
    print('Y_test shape:', Y_test.shape)

    print('Building model...')
    # Dense(512) -> ReLU -> Dropout(0.5) -> softmax over the classes.
    model = Sequential()
    model.add(Dense(512, input_shape=(max_words,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))
    # NOTE(review): the optimizer was lost from the source; 'adam' is an
    # assumption taken from the Keras MLP example. Confirm.
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    t0 = time()
    model.fit(X_train, Y_train,
              nb_epoch=nb_epoch, batch_size=batch_size,
              verbose=1, validation_split=0.1)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    # predict_classes yields class indices, which is the form the
    # scikit-learn metrics below compare against data_test.target.
    y_pred = model.predict_classes(X_test, batch_size=batch_size, verbose=1)
    test_time = time() - t0
    print("test time: %0.3fs" % test_time)

    if opts.print_report:
        print("classification report:")
        print(metrics.classification_report(data_test.target, y_pred,
                                            target_names=target_names))
    if opts.print_cm:
        print("confusion matrix:")
        print(metrics.confusion_matrix(data_test.target, y_pred))

# Script entry point. TESTRUN and PROFILE are expected to be module-level
# flags; they are not visible in this excerpt of the file.
if __name__ == "__main__":
    if TESTRUN:
        # Run any doctests embedded in this module before the real work.
        import doctest
        doctest.testmod()
    if PROFILE:
        # Profile one full run of main() and dump cumulative-time statistics.
        import cProfile
        import pstats
        profile_filename = '_profile.txt'
        cProfile.run('main()', profile_filename)
        # pstats writes text, so open in text mode; the context manager
        # guarantees the report file is flushed and closed.
        with open("profile_stats.txt", "w") as statsfile:
            p = pstats.Stats(profile_filename, stream=statsfile)
            stats = p.strip_dirs().sort_stats('cumulative')
            stats.print_stats()
        sys.exit(0)
    sys.exit(main())

コマンドライン --report --confusion_matrix   --filtered --all_categories


Loading 20 newsgroups dataset for categories:
data loaded
11314 documents - 13.782MB (training set)
7532 documents - 8.262MB (test set)

20 classes
Vectorizing sequence data...
X_train shape: (11314, 105373)
X_test shape: (7532, 105373)
Convert class vector to binary class matrix (for use with categorical_crossentropy)
Y_train shape: (11314, 20)
Y_test shape: (7532, 20)
Building model...
Train on 10182 samples, validate on 1132 samples
Epoch 1/5

10182/10182 [==============================] - 69s - loss: 2.6638 - acc: 0.3841 - val_loss: 2.1918 - val_acc: 0.6670
Epoch 2/5

10182/10182 [==============================] - 64s - loss: 1.6931 - acc: 0.7682 - val_loss: 1.6281 - val_acc: 0.7102
Epoch 3/5

10182/10182 [==============================] - 62s - loss: 1.1331 - acc: 0.8494 - val_loss: 1.3319 - val_acc: 0.7438
Epoch 4/5

10182/10182 [==============================] - 62s - loss: 0.8253 - acc: 0.8942 - val_loss: 1.2023 - val_acc: 0.7429
Epoch 5/5

10182/10182 [==============================] - 64s - loss: 0.6358 - acc: 0.9181 - val_loss: 1.1015 - val_acc: 0.7562
train time: 326.131s

test time: 55.487s
classification report:
                          precision    recall  f1-score   support

             alt.atheism       0.53      0.46      0.49       319
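           comp.graphics       0.68      0.73      0.71       389
 comp.os.ms-windows.misc       0.68      0.62      0.65       394
comp.sys.ibm.pc.hardware       0.69      0.62      0.65       392
   comp.sys.mac.hardware       0.71      0.70      0.71       385
          comp.windows.x       0.86      0.75      0.80       395
            misc.forsale       0.80      0.81      0.80       390
               rec.autos       0.81      0.74      0.77       396
         rec.motorcycles       0.74      0.79      0.76       398
      rec.sport.baseball       0.53      0.92      0.67       397
        rec.sport.hockey       0.96      0.86      0.91       399
               sci.crypt       0.74      0.74      0.74       396
         sci.electronics       0.63      0.64      0.63       393
                 sci.med       0.82      0.76      0.79       396
               sci.space       0.70      0.74      0.72       394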
  soc.religion.christian       0.67      0.81      0.73       398
      talk.politics.guns       0.60      0.71      0.65       364
   talk.politics.mideast       0.87      0.73      0.80       376
      talk.politics.misc       0.59      0.43      0.50       310
      talk.religion.misc       0.51      0.29      0.37       251

             avg / total       0.71      0.70      0.70      7532

confusion matrix:
[[147   1   1   0   1   3   0   5   7  18   1   4   5   3  17  46   7  12
    8  33]
 [  6 284  15   8  12  16   4   1   4  10   0  11   8   1   9   0   0   0
    0   0]
 [  2  24 243  35  16  16   3   1   4  20   0   5   1   4  12   1   2   0
    4   1]
 [  0  12  38 243  38   4   9   4   0   8   1   4  29   0   1   0   0   0
    0   1]
 [  0   7  10  22 270   3   8   6   6  16   0   5  24   3   4   1   0   0
    0   0]
 [  0  32  25   8   3 295   7   1   0  10   0   3   5   2   2   1   0   0
    1   0]
 [  1   1   3  13  13   0 316   6   6  12   0   1   8   0   5   0   4   0
    0   1]
 [  3   0   0   1   2   0  12 292  21  28   0   2  16   1   9   3   3   0
    3   0]
 [  2   1   0   0   1   2   7  16 313  18   1   1  11   4   9   1   6   0
    5   0]
 [  2   3   0   0   1   0   3   0   5 364   4   0   1   1   1   7   0   0
    5   0]
 [  5   1   0   0   0   0   2   1   2  32 343   2   0   2   3   1   3   1
    0   1]
 [  3   9   4   2   3   2   1   1   3  26   0 292  14   2   7   1  17   3
    5   1]
 [  1  11   8  18  15   1  12   7  10  17   0  19 253  10   8   0   1   1
    1   0]
 [  7   8   3   2   1   0   6   3  11  15   2   1   8 302   5  10   4   5
    3   0]
 [  5  13   1   0   1   0   2   4   4  21   1   5  11   9 293   6   4   1
   13   0]
 [ 18   3   3   0   0   0   0   1   1  17   1   1   2   4   4 321   1   1
    3  17]
 [  8   0   2   0   1   0   2   4   6  22   0  15   2   4   8   7 257   2
   16   8]
 [ 21   1   1   1   0   1   1   3   7  15   0   5   2   2   5   8  10 276
   16   1]
 [ 15   1   0   0   0   0   1   1   9  10   3  14   3   5   8   5  88   8
  133   6]
 [ 34   4   2   1   0   0   1   3   4  11   1   3   1  11   6  61  22   6
    8  72]]