Evaluating Keras text classification results with scikit-learn metrics

Last time, I reviewed GridSearch in scikit-learn. This time, I modified the code from two posts back and measured accuracy with a standard Keras (TensorFlow backend) implementation on the same scikit-learn dataset. To keep the evaluation metrics consistent, the scores are computed with the metrics functions provided by scikit-learn.

Bottom line: in overall average scores, scikit-learn's linear SVM and the Keras MLP come out the same.
Then again, the model here is just one fully connected layer, one activation layer, and Dropout, so it is hardly deep learning; it amounts to scikit-learn's perceptron with Dropout added on top. In that sense, the comparison against the perceptron is perhaps best read as showing the effect of Dropout.
Another reading: the fully connected layer first projects the input down to 512 dimensions, and this projection may be doing the feature extraction well enough to match the grid-searched SVM.
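
The key point is that scikit-learn's metrics functions only need two arrays of class labels, so predictions from a Keras model can be scored in exactly the same way as the SVM baseline. A minimal, self-contained illustration with toy labels (not the actual run):

from sklearn import metrics
import numpy as np

# Toy labels: in the real script, y_true comes from data_test.target and
# y_pred from the Keras model's predict_classes() output.
y_true = np.array([0, 2, 1, 1, 0])
y_pred = np.array([0, 2, 1, 0, 0])

# Same metrics functions as used for the linear SVM and perceptron baselines
print(metrics.classification_report(y_true, y_pred))
print(metrics.confusion_matrix(y_true, y_pred))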

Results

Implementation | Features | precision | recall | f1-score
Keras (Dense 512 + ReLU + Dropout 0.5, batch=1024) | word counts (punctuation removed) | 0.71 | 0.70 | 0.70
scikit-learn (LinearSVM, GridSearch: squared loss, L2 regularization) | TF-IDF + normalization (stop words removed) | 0.71 | 0.70 | 0.70
scikit-learn (Perceptron, 50 iterations) | word counts (stop words removed) | 0.65 | 0.64 | 0.64
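
For reference, the two scikit-learn baselines in the table can be roughly reproduced by the sketch below. This is a sketch inferred from the table's description; the grid, C values, and iteration settings are assumptions, not the exact code of the earlier posts.

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.svm import LinearSVC
from sklearn.linear_model import Perceptron
from sklearn.model_selection import GridSearchCV
from sklearn import metrics

remove = ('headers', 'footers', 'quotes')
train = fetch_20newsgroups(subset='train', remove=remove)
test = fetch_20newsgroups(subset='test', remove=remove)

# LinearSVM baseline: L2-normalized TF-IDF with English stop words, grid search over C (assumed grid)
tfidf = TfidfVectorizer(stop_words='english')
Xtr, Xte = tfidf.fit_transform(train.data), tfidf.transform(test.data)
svm = GridSearchCV(LinearSVC(loss='squared_hinge', penalty='l2'), {'C': [0.01, 0.1, 1, 10]}, cv=3)
svm.fit(Xtr, train.target)
print(metrics.classification_report(test.target, svm.predict(Xte)))

# Perceptron baseline: raw word counts with English stop words, 50 passes over the data
counts = CountVectorizer(stop_words='english')
Xtr, Xte = counts.fit_transform(train.data), counts.transform(test.data)
perc = Perceptron(max_iter=50, tol=None)  # older scikit-learn versions: Perceptron(n_iter=50)
perc.fit(Xtr, train.target)
print(metrics.classification_report(test.target, perc.predict(Xte)))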

Input data

20 newsgroups dataset

  • 11314 documents - 13.782MB (training set)
  • 7532 documents - 8.262MB (test set)

Number of features

  • Keras: 105373 (Keras's Tokenizer strips ASCII punctuation by default)
  • scikit-learn: 101322 (TfidfVectorizer with stop_words='english')

Headers were removed at run time with the --filtered option (to curb overfitting).
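
The gap between the two feature counts above comes from the different preprocessing defaults: Keras's Tokenizer lowercases text and strips ASCII punctuation, while the TfidfVectorizer here additionally drops English stop words. A small sketch of the two vocabularies side by side (illustrative documents, not the 20 newsgroups data):

from keras.preprocessing.text import Tokenizer
from sklearn.feature_extraction.text import TfidfVectorizer

docs = ["The space shuttle, launched in 1981, orbited the Earth.",
        "Graphics cards render 3-D scenes quickly."]

# Keras: lowercasing plus removal of !"#$%&()*+,-./:;<=>?@[\]^_`{|}~ by default
tok = Tokenizer()
tok.fit_on_texts(docs)
print(len(tok.word_index))      # vocabulary size seen by Keras

# scikit-learn: tokenization plus English stop-word removal
tfidf = TfidfVectorizer(stop_words='english')
tfidf.fit(docs)
print(len(tfidf.vocabulary_))   # vocabulary size seen by TfidfVectorizer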

Source code

# encoding: utf-8
'''
 -- Keras example text classification with scikit learn metrics

Created on 2017/02/23

@author:     mzi

@copyright:  2017 mzi. All rights reserved.

@license:    Apache Licence 2.0

'''
from __future__ import print_function
import sys
import os

from optparse import OptionParser
from time import time

import numpy as np

from sklearn.datasets import fetch_20newsgroups
from sklearn import metrics

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.utils import np_utils
from keras.preprocessing.text import Tokenizer

__all__ = []
__version__ = 0.1
__date__ = '2017-02-23'
__updated__ = '2017-02-23'

TESTRUN = 0
PROFILE = 0

batch_size = 1024
nb_epoch = 5

def size_mb(docs):
    return sum(len(s.encode('utf-8')) for s in docs) / 1e6

def trim(s):
    """Trim string to fit on terminal (assuming 80-column display)"""
    return s if len(s) <= 80 else s[:77] + "..."


def main(argv=None):
    '''Command line options.'''
    program_name = os.path.basename(sys.argv[0])
    program_version = "v%f" %  __version__
    program_build_date = "%s" % __updated__

    program_version_string = '%%prog %s (%s)' % (program_version, program_build_date)
    program_longdesc = 'Keras MLP text classification with scikit-learn metrics' 
    program_license = "Copyright 2017 mzi                                            \
                Licensed under the Apache License 2.0\nhttp://www.apache.org/licenses/LICENSE-2.0"

    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    op = OptionParser(version=program_version_string, epilog=program_longdesc, description=program_license)
    op.add_option("--report",
                  action="store_true", dest="print_report",
                  help="Print a detailed classification report.")
    op.add_option("--confusion_matrix",
                  action="store_true", dest="print_cm",
                  help="Print the confusion matrix.")
    op.add_option("--all_categories",
                  action="store_true", dest="all_categories",
                  help="Whether to use all categories or not.")
    op.add_option("--filtered",
                  action="store_true",
                  help="Remove newsgroup information that is easily overfit: "
                       "headers, signatures, and quoting.")

    # process options
    (opts, args) = op.parse_args(argv)

    #Categories
    if opts.all_categories:
        categories = None
    else:
        categories = [
            'alt.atheism',
            'talk.religion.misc',
            'comp.graphics',
            'sci.space',
            ]

    #Remove headers
    if opts.filtered:
        remove = ('headers', 'footers', 'quotes')
    else:
        remove = ()

    print(__doc__)
    op.print_help()
    print()
    
    # MAIN BODY #
    print("Loading 20 newsgroups dataset for categories:")
    print(categories if categories else "all")
    
    data_train = fetch_20newsgroups(subset='train', categories=categories,
                                    shuffle=True, random_state=42,
                                    remove=remove)
    
    data_test = fetch_20newsgroups(subset='test', categories=categories,
                                   shuffle=True, random_state=42,
                                   remove=remove)
    print('data loaded')

    # order of labels in `target_names` can be different from `categories`
    target_names = data_train.target_names

    data_train_size_mb = size_mb(data_train.data)
    data_test_size_mb = size_mb(data_test.data)
    print("%d documents - %0.3fMB (training set)" % (
        len(data_train.data), data_train_size_mb))
    print("%d documents - %0.3fMB (test set)" % (
        len(data_test.data), data_test_size_mb))
    print()
    nb_classes = np.max(data_train.target) + 1
    print(nb_classes, 'classes')

    print('Vectorizing sequence data...')
    tokenizer = Tokenizer(nb_words=None, lower=True, split=' ', char_level=False)
    tokenizer.fit_on_texts(data_train.data)
    X_train = tokenizer.texts_to_matrix(data_train.data, mode='count')
    X_test = tokenizer.texts_to_matrix(data_test.data, mode='count')
    print('X_train shape:', X_train.shape)
    print('X_test shape:', X_test.shape)

    print('Convert class vector to binary class matrix (for use with categorical_crossentropy)')
    Y_train = np_utils.to_categorical(data_train.target, nb_classes)
    Y_test = np_utils.to_categorical(data_test.target, nb_classes)
    print('Y_train shape:', Y_train.shape)
    print('Y_test shape:', Y_test.shape)

    max_words=X_train.shape[1]
    print('Building model...')
    model = Sequential()
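    # single hidden layer: project the ~100k-dimensional count vectors down to 512 features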
    model.add(Dense(512, input_shape=(max_words,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    
    t0 = time()
    history = model.fit(X_train, Y_train,
                        nb_epoch=nb_epoch, batch_size=batch_size,
                        verbose=1, validation_split=0.1)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    #score = model.evaluate(X_test, Y_test,
    #                       batch_size=batch_size, verbose=1)
    #print('Test score:', score[0])
    #print('Test accuracy:', score[1])
    #Y_pred = model.predict(X_test, batch_size=batch_size, verbose=1)
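    # predict_classes() returns integer class indices (argmax of the softmax output),
    # which scikit-learn's metrics functions accept directly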
    y_pred = model.predict_classes(X_test, batch_size=batch_size, verbose=1)
    test_time = time() - t0
    print("test time: %0.3fs" % test_time)

    if opts.print_report:
        print("classification report:")
        print(metrics.classification_report(data_test.target, y_pred,
                                            target_names=target_names))
    if opts.print_cm:
        print("confusion matrix:")
        print(metrics.confusion_matrix(data_test.target, y_pred))

    return
    
if __name__ == "__main__":
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = '_profile.txt'
        cProfile.run('main()', profile_filename)
        statsfile = open("profile_stats.txt", "wb")
        p = pstats.Stats(profile_filename, stream=statsfile)
        stats = p.strip_dirs().sort_stats('cumulative')
        stats.print_stats()
        statsfile.close()
        sys.exit(0)
    sys.exit(main())
    

Command line

mlp_textclf.py --report --confusion_matrix   --filtered --all_categories

Output

Using TensorFlow backend.

 -- Keras example text classification with scikit learn metrics

Created on 2017/02/23

@author:     mzi

@copyright:  2017 mzi. All rights reserved.

@license:    Apache Licence 2.0


Usage: mlp_textclf.py [options]

Copyright 2017 mzi
Licensed under the Apache License 2.0
http://www.apache.org/licenses/LICENSE-2.0

Options:
  --version           show program's version number and exit
  -h, --help          show this help message and exit
  --report            Print a detailed classification report.
  --confusion_matrix  Print the confusion matrix.
  --all_categories    Whether to use all categories or not.
  --filtered          Remove newsgroup information that is easily overfit:
                      headers, signatures, and quoting.

Keras MLP text classification with scikit-learn metrics

Loading 20 newsgroups dataset for categories:
all
data loaded
11314 documents - 13.782MB (training set)
7532 documents - 8.262MB (test set)

20 classes
Vectorizing sequence data...
X_train shape: (11314, 105373)
X_test shape: (7532, 105373)
Convert class vector to binary class matrix (for use with categorical_crossentropy)
Y_train shape: (11314, 20)
Y_test shape: (7532, 20)
Building model...
Train on 10182 samples, validate on 1132 samples
Epoch 1/5

 1024/10182 [==>...........................] - ETA: 95s - loss: 3.0274 - acc: 0.0518
 2048/10182 [=====>........................] - ETA: 62s - loss: 3.0016 - acc: 0.0894
 3072/10182 [========>.....................] - ETA: 48s - loss: 2.9685 - acc: 0.1387
 4096/10182 [===========>..................] - ETA: 39s - loss: 2.9335 - acc: 0.1799
 5120/10182 [==============>...............] - ETA: 31s - loss: 2.8859 - acc: 0.2254
 6144/10182 [=================>............] - ETA: 24s - loss: 2.8427 - acc: 0.2664
 7168/10182 [====================>.........] - ETA: 18s - loss: 2.7968 - acc: 0.3040
 8192/10182 [=======================>......] - ETA: 12s - loss: 2.7480 - acc: 0.3346
 9216/10182 [==========================>...] - ETA: 5s - loss: 2.6972 - acc: 0.3655 
10182/10182 [==============================] - 69s - loss: 2.6638 - acc: 0.3841 - val_loss: 2.1918 - val_acc: 0.6670
Epoch 2/5

 1024/10182 [==>...........................] - ETA: 57s - loss: 2.0440 - acc: 0.7549
 2048/10182 [=====>........................] - ETA: 47s - loss: 1.9434 - acc: 0.7656
 3072/10182 [========>.....................] - ETA: 44s - loss: 1.9161 - acc: 0.7614
 4096/10182 [===========>..................] - ETA: 35s - loss: 1.8767 - acc: 0.7666
 5120/10182 [==============>...............] - ETA: 29s - loss: 1.8473 - acc: 0.7637
 6144/10182 [=================>............] - ETA: 24s - loss: 1.8071 - acc: 0.7651
 7168/10182 [====================>.........] - ETA: 17s - loss: 1.7737 - acc: 0.7662
 8192/10182 [=======================>......] - ETA: 11s - loss: 1.7445 - acc: 0.7670
 9216/10182 [==========================>...] - ETA: 5s - loss: 1.7181 - acc: 0.7666 
10182/10182 [==============================] - 64s - loss: 1.6931 - acc: 0.7682 - val_loss: 1.6281 - val_acc: 0.7102
Epoch 3/5

 1024/10182 [==>...........................] - ETA: 53s - loss: 1.3333 - acc: 0.8359
 2048/10182 [=====>........................] - ETA: 45s - loss: 1.2573 - acc: 0.8413
 3072/10182 [========>.....................] - ETA: 38s - loss: 1.2245 - acc: 0.8480
 4096/10182 [===========>..................] - ETA: 33s - loss: 1.2217 - acc: 0.8459
 5120/10182 [==============>...............] - ETA: 27s - loss: 1.2054 - acc: 0.8443
 6144/10182 [=================>............] - ETA: 21s - loss: 1.1911 - acc: 0.8433
 7168/10182 [====================>.........] - ETA: 16s - loss: 1.1727 - acc: 0.8470
 8192/10182 [=======================>......] - ETA: 11s - loss: 1.1616 - acc: 0.8478
 9216/10182 [==========================>...] - ETA: 5s - loss: 1.1485 - acc: 0.8477 
10182/10182 [==============================] - 62s - loss: 1.1331 - acc: 0.8494 - val_loss: 1.3319 - val_acc: 0.7438
Epoch 4/5

 1024/10182 [==>...........................] - ETA: 56s - loss: 0.9577 - acc: 0.8896
 2048/10182 [=====>........................] - ETA: 46s - loss: 0.8776 - acc: 0.8955
 3072/10182 [========>.....................] - ETA: 40s - loss: 0.8612 - acc: 0.8916
 4096/10182 [===========>..................] - ETA: 34s - loss: 0.8486 - acc: 0.8955
 5120/10182 [==============>...............] - ETA: 28s - loss: 0.8433 - acc: 0.8943
 6144/10182 [=================>............] - ETA: 22s - loss: 0.8497 - acc: 0.8914
 7168/10182 [====================>.........] - ETA: 16s - loss: 0.8415 - acc: 0.8923
 8192/10182 [=======================>......] - ETA: 11s - loss: 0.8343 - acc: 0.8940
 9216/10182 [==========================>...] - ETA: 5s - loss: 0.8357 - acc: 0.8939 
10182/10182 [==============================] - 62s - loss: 0.8253 - acc: 0.8942 - val_loss: 1.2023 - val_acc: 0.7429
Epoch 5/5

 1024/10182 [==>...........................] - ETA: 68s - loss: 0.7332 - acc: 0.9023
 2048/10182 [=====>........................] - ETA: 61s - loss: 0.7494 - acc: 0.9019
 3072/10182 [========>.....................] - ETA: 50s - loss: 0.7110 - acc: 0.9082
 4096/10182 [===========>..................] - ETA: 39s - loss: 0.6886 - acc: 0.9141
 5120/10182 [==============>...............] - ETA: 31s - loss: 0.6811 - acc: 0.9129
 6144/10182 [=================>............] - ETA: 25s - loss: 0.6698 - acc: 0.9157
 7168/10182 [====================>.........] - ETA: 18s - loss: 0.6632 - acc: 0.9157
 8192/10182 [=======================>......] - ETA: 12s - loss: 0.6540 - acc: 0.9163
 9216/10182 [==========================>...] - ETA: 5s - loss: 0.6445 - acc: 0.9173 
10182/10182 [==============================] - 64s - loss: 0.6358 - acc: 0.9181 - val_loss: 1.1015 - val_acc: 0.7562
train time: 326.131s

1024/7532 [===>..........................] - ETA: 115s
2048/7532 [=======>......................] - ETA: 57s 
3072/7532 [===========>..................] - ETA: 34s
4096/7532 [===============>..............] - ETA: 22s
5120/7532 [===================>..........] - ETA: 13s
6144/7532 [=======================>......] - ETA: 7s 
7168/7532 [===========================>..] - ETA: 1s
7532/7532 [==============================] - 40s    
test time: 55.487s
classification report:
                          precision    recall  f1-score   support

             alt.atheism       0.53      0.46      0.49       319
           comp.graphics       0.68      0.73      0.71       389
 comp.os.ms-windows.misc       0.68      0.62      0.65       394
comp.sys.ibm.pc.hardware       0.69      0.62      0.65       392
   comp.sys.mac.hardware       0.71      0.70      0.71       385
          comp.windows.x       0.86      0.75      0.80       395
            misc.forsale       0.80      0.81      0.80       390
               rec.autos       0.81      0.74      0.77       396
         rec.motorcycles       0.74      0.79      0.76       398
      rec.sport.baseball       0.53      0.92      0.67       397
        rec.sport.hockey       0.96      0.86      0.91       399
               sci.crypt       0.74      0.74      0.74       396
         sci.electronics       0.63      0.64      0.63       393
                 sci.med       0.82      0.76      0.79       396
               sci.space       0.70      0.74      0.72       394
  soc.religion.christian       0.67      0.81      0.73       398
      talk.politics.guns       0.60      0.71      0.65       364
   talk.politics.mideast       0.87      0.73      0.80       376
      talk.politics.misc       0.59      0.43      0.50       310
      talk.religion.misc       0.51      0.29      0.37       251

             avg / total       0.71      0.70      0.70      7532

confusion matrix:
[[147   1   1   0   1   3   0   5   7  18   1   4   5   3  17  46   7  12
    8  33]
 [  6 284  15   8  12  16   4   1   4  10   0  11   8   1   9   0   0   0
    0   0]
 [  2  24 243  35  16  16   3   1   4  20   0   5   1   4  12   1   2   0
    4   1]
 [  0  12  38 243  38   4   9   4   0   8   1   4  29   0   1   0   0   0
    0   1]
 [  0   7  10  22 270   3   8   6   6  16   0   5  24   3   4   1   0   0
    0   0]
 [  0  32  25   8   3 295   7   1   0  10   0   3   5   2   2   1   0   0
    1   0]
 [  1   1   3  13  13   0 316   6   6  12   0   1   8   0   5   0   4   0
    0   1]
 [  3   0   0   1   2   0  12 292  21  28   0   2  16   1   9   3   3   0
    3   0]
 [  2   1   0   0   1   2   7  16 313  18   1   1  11   4   9   1   6   0
    5   0]
 [  2   3   0   0   1   0   3   0   5 364   4   0   1   1   1   7   0   0
    5   0]
 [  5   1   0   0   0   0   2   1   2  32 343   2   0   2   3   1   3   1
    0   1]
 [  3   9   4   2   3   2   1   1   3  26   0 292  14   2   7   1  17   3
    5   1]
 [  1  11   8  18  15   1  12   7  10  17   0  19 253  10   8   0   1   1
    1   0]
 [  7   8   3   2   1   0   6   3  11  15   2   1   8 302   5  10   4   5
    3   0]
 [  5  13   1   0   1   0   2   4   4  21   1   5  11   9 293   6   4   1
   13   0]
 [ 18   3   3   0   0   0   0   1   1  17   1   1   2   4   4 321   1   1
    3  17]
 [  8   0   2   0   1   0   2   4   6  22   0  15   2   4   8   7 257   2
   16   8]
 [ 21   1   1   1   0   1   1   3   7  15   0   5   2   2   5   8  10 276
   16   1]
 [ 15   1   0   0   0   0   1   1   9  10   3  14   3   5   8   5  88   8
  133   6]
 [ 34   4   2   1   0   0   1   3   4  11   1   3   1  11   6  61  22   6
    8  72]]