Evaluating Keras text classification results with scikit-learn metrics
In the previous post I reviewed scikit-learn's GridSearch. This time I modified the code from the post before that and measured accuracy with a standard Keras (TensorFlow backend) implementation on the same scikit-learn dataset. To keep the evaluation metrics comparable, the scores are computed with scikit-learn's metrics functions.
The bottom line: the scikit-learn linear SVM and the Keras MLP end up with the same overall average precision.
That said, this is not really deep learning: the network is only one fully connected layer, one activation layer, and Dropout. In effect it is the scikit-learn perceptron with Dropout added, so it may be fairer to read the result as the effect of Dropout relative to a plain perceptron (a rough sketch of such a baseline follows the results table below).
Another way to look at it: the fully connected layer first reduces the input down to 512 dimensions, and if that step works as a kind of feature extraction, it is plausible that the network reaches roughly the same performance as the grid-searched SVM.
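The point of the exercise is the evaluation step: instead of relying on Keras's own `evaluate`, the predicted class indices are passed to scikit-learn's metrics functions, so the numbers are computed exactly as in the earlier scikit-learn experiments. A minimal, self-contained sketch of that step (the toy labels here are illustrative; in the full script below, `y_pred` comes from `model.predict_classes(X_test)` and the true labels from `fetch_20newsgroups`):

```python
import numpy as np
from sklearn import metrics

# Illustrative stand-ins: y_true would be data_test.target from
# fetch_20newsgroups, y_pred the class indices predicted by the Keras model.
y_true = np.array([0, 2, 1, 1, 0, 2])
y_pred = np.array([0, 2, 1, 0, 0, 2])

# Same calls as in the script below, so Keras and scikit-learn models
# are scored with identical precision/recall/f1 definitions.
print(metrics.classification_report(y_true, y_pred))
print(metrics.confusion_matrix(y_true, y_pred))
```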
Results
Implementation | Features | precision | recall | f1-score |
---|---|---|---|---|
Keras (MLP: Dense 512 - ReLU - Dropout 0.5, batch=1024) | word counts (punctuation removed) | 0.71 | 0.70 | 0.70 |
Scikit-Learn (LinearSVM, GridSearch (squared error loss, L2 regularization)) | TF-IDF + normalization (stop words removed) | 0.71 | 0.70 | 0.70 |
Scikit-Learn (Perceptron, 50 iterations) | word counts (stop words removed) | 0.65 | 0.64 | 0.64 |
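For reference, the perceptron row comes from the earlier scikit-learn experiment. A rough reconstruction of such a baseline could look like the following; the vectorizer and hyperparameters here are assumptions, not the exact settings of that post:

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import Perceptron
from sklearn import metrics

remove = ('headers', 'footers', 'quotes')
data_train = fetch_20newsgroups(subset='train', remove=remove)
data_test = fetch_20newsgroups(subset='test', remove=remove)

# Word-count features with English stop words removed, roughly matching
# the "word counts (stop words removed)" row in the table above.
vectorizer = CountVectorizer(stop_words='english')
X_train = vectorizer.fit_transform(data_train.data)
X_test = vectorizer.transform(data_test.data)

# 50 training iterations; older scikit-learn releases spell this n_iter=50.
clf = Perceptron(max_iter=50)
clf.fit(X_train, data_train.target)
pred = clf.predict(X_test)

print(metrics.classification_report(data_test.target, pred,
                                    target_names=data_test.target_names))
```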
Input data
20 newsgroups dataset
- 11314 documents - 13.782MB (training set)
- 7532 documents - 8.262MB (test set)
Number of features
- Keras: 105373 (the Keras Tokenizer strips ASCII punctuation by default)
- Scikit-Learn: 101322 (TfidfVectorizer with stop_words='english')
Headers are removed at run time via the --filtered option (to curb overfitting); the two vectorization setups are sketched below.
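The two feature counts come from two different vectorizers: the Keras Tokenizer (whose default `filters` argument strips ASCII punctuation) for the MLP, and a TfidfVectorizer with English stop words for the scikit-learn models. A minimal sketch of both setups, using the Keras 1.x keyword `nb_words` as in the script below (newer Keras spells it `num_words`); the exact TfidfVectorizer options of the earlier post may differ:

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from keras.preprocessing.text import Tokenizer

remove = ('headers', 'footers', 'quotes')   # same effect as --filtered
data_train = fetch_20newsgroups(subset='train', remove=remove)

# Keras side: bag-of-words count matrix. The default `filters` strips
# ASCII punctuation, which yields the 105373-column vocabulary in the run below.
tokenizer = Tokenizer(nb_words=None, lower=True, split=' ', char_level=False)
tokenizer.fit_on_texts(data_train.data)
X_train_keras = tokenizer.texts_to_matrix(data_train.data, mode='count')
print('Keras features:', X_train_keras.shape[1])

# scikit-learn side: TF-IDF (L2-normalized by default) with English stop words,
# giving the 101322-feature vocabulary of the earlier post.
vectorizer = TfidfVectorizer(stop_words='english')
X_train_sk = vectorizer.fit_transform(data_train.data)
print('scikit-learn features:', X_train_sk.shape[1])
```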
Source code
```python
# encoding: utf-8
'''
 -- Keras example text classification with scikit leran metrics

Created on 2017/02/23

@author:     mzi
@copyright:  2017 mzi. All rights reserved.
@license:    Apache Licence 2.0
'''
from __future__ import print_function

import sys
import os
from optparse import OptionParser
from time import time

import numpy as np

from sklearn.datasets import fetch_20newsgroups
from sklearn import metrics

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.utils import np_utils
from keras.preprocessing.text import Tokenizer

__all__ = []
__version__ = 0.1
__date__ = '2017-02-23'
__updated__ = '2017-02-23'

TESTRUN = 0
PROFILE = 0

batch_size = 1024
nb_epoch = 5


def size_mb(docs):
    return sum(len(s.encode('utf-8')) for s in docs) / 1e6


def trim(s):
    """Trim string to fit on terminal (assuming 80-column display)"""
    return s if len(s) <= 80 else s[:77] + "..."


def main(argv=None):
    '''Command line options.'''
    program_name = os.path.basename(sys.argv[0])
    program_version = "v%f" % __version__
    program_build_date = "%s" % __updated__
    program_version_string = '%%prog %s (%s)' % (program_version, program_build_date)
    program_longdesc = 'GridSearh for scikit learn - LinearSVC with TextData'
    program_license = "Copyright 2017 mzi \
        Licensed under the Apache License 2.0\nhttp://www.apache.org/licenses/LICENSE-2.0"

    if argv is None:
        argv = sys.argv[1:]

    # setup option parser
    op = OptionParser(version=program_version_string,
                      epilog=program_longdesc,
                      description=program_license)
    op.add_option("--report",
                  action="store_true", dest="print_report",
                  help="Print a detailed classification report.")
    op.add_option("--confusion_matrix",
                  action="store_true", dest="print_cm",
                  help="Print the confusion matrix.")
    op.add_option("--all_categories",
                  action="store_true", dest="all_categories",
                  help="Whether to use all categories or not.")
    op.add_option("--filtered",
                  action="store_true",
                  help="Remove newsgroup information that is easily overfit: "
                       "headers, signatures, and quoting.")

    # process options
    (opts, args) = op.parse_args(argv)

    # Categories
    if opts.all_categories:
        categories = None
    else:
        categories = [
            'alt.atheism',
            'talk.religion.misc',
            'comp.graphics',
            'sci.space',
        ]

    # Remove headers
    if opts.filtered:
        remove = ('headers', 'footers', 'quotes')
    else:
        remove = ()

    print(__doc__)
    op.print_help()
    print()

    # MAIN BODY #
    print("Loading 20 newsgroups dataset for categories:")
    print(categories if categories else "all")

    data_train = fetch_20newsgroups(subset='train', categories=categories,
                                    shuffle=True, random_state=42,
                                    remove=remove)
    data_test = fetch_20newsgroups(subset='test', categories=categories,
                                   shuffle=True, random_state=42,
                                   remove=remove)
    print('data loaded')

    # order of labels in `target_names` can be different from `categories`
    target_names = data_train.target_names

    data_train_size_mb = size_mb(data_train.data)
    data_test_size_mb = size_mb(data_test.data)

    print("%d documents - %0.3fMB (training set)" % (
        len(data_train.data), data_train_size_mb))
    print("%d documents - %0.3fMB (test set)" % (
        len(data_test.data), data_test_size_mb))
    print()

    nb_classes = np.max(data_train.target) + 1
    print(nb_classes, 'classes')

    print('Vectorizing sequence data...')
    tokenizer = Tokenizer(nb_words=None, lower=True, split=' ', char_level=False)
    tokenizer.fit_on_texts(data_train.data)
    X_train = tokenizer.texts_to_matrix(data_train.data, mode='count')
    X_test = tokenizer.texts_to_matrix(data_test.data, mode='count')
    print('X_train shape:', X_train.shape)
    print('X_test shape:', X_test.shape)

    print('Convert class vector to binary class matrix (for use with categorical_crossentropy)')
    Y_train = np_utils.to_categorical(data_train.target, nb_classes)
    Y_test = np_utils.to_categorical(data_test.target, nb_classes)
    print('Y_train shape:', Y_train.shape)
    print('Y_test shape:', Y_test.shape)

    max_words = X_train.shape[1]

    print('Building model...')
    model = Sequential()
    model.add(Dense(512, input_shape=(max_words,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes))
    model.add(Activation('softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    t0 = time()
    history = model.fit(X_train, Y_train,
                        nb_epoch=nb_epoch, batch_size=batch_size,
                        verbose=1, validation_split=0.1)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    #score = model.evaluate(X_test, Y_test,
    #                       batch_size=batch_size, verbose=1)
    #print('Test score:', score[0])
    #print('Test accuracy:', score[1])
    #Y_pred = model.predict(X_test, batch_size=batch_size, verbose=1)
    y_pred = model.predict_classes(X_test, batch_size=batch_size, verbose=1)
    test_time = time() - t0
    print("test time: %0.3fs" % test_time)

    if opts.print_report:
        print("classification report:")
        print(metrics.classification_report(data_test.target, y_pred,
                                            target_names=target_names))

    if opts.print_cm:
        print("confusion matrix:")
        print(metrics.confusion_matrix(data_test.target, y_pred))
    return


if __name__ == "__main__":
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = '_profile.txt'
        cProfile.run('main()', profile_filename)
        statsfile = open("profile_stats.txt", "wb")
        p = pstats.Stats(profile_filename, stream=statsfile)
        stats = p.strip_dirs().sort_stats('cumulative')
        stats.print_stats()
        statsfile.close()
        sys.exit(0)
    sys.exit(main())
```
Command line
mlp_textclf.py --report --confusion_matrix --filtered --all_categories
Output
```
Using TensorFlow backend.

 -- Keras example text classification with scikit leran metrics

Created on 2017/02/23

@author:     mzi
@copyright:  2017 mzi. All rights reserved.
@license:    Apache Licence 2.0

Usage: mlp_textclf.py [options]

Copyright 2017 mzi         Licensed under the Apache License 2.0
http://www.apache.org/licenses/LICENSE-2.0

Options:
  --version           show program's version number and exit
  -h, --help          show this help message and exit
  --report            Print a detailed classification report.
  --confusion_matrix  Print the confusion matrix.
  --all_categories    Whether to use all categories or not.
  --filtered          Remove newsgroup information that is easily overfit:
                      headers, signatures, and quoting.

GridSearh for scikit learn - LinearSVC with TextData

Loading 20 newsgroups dataset for categories:
all
data loaded
11314 documents - 13.782MB (training set)
7532 documents - 8.262MB (test set)

20 classes
Vectorizing sequence data...
X_train shape: (11314, 105373)
X_test shape: (7532, 105373)
Convert class vector to binary class matrix (for use with categorical_crossentropy)
Y_train shape: (11314, 20)
Y_test shape: (7532, 20)
Building model...
Train on 10182 samples, validate on 1132 samples
Epoch 1/5
10182/10182 [==============================] - 69s - loss: 2.6638 - acc: 0.3841 - val_loss: 2.1918 - val_acc: 0.6670
Epoch 2/5
10182/10182 [==============================] - 64s - loss: 1.6931 - acc: 0.7682 - val_loss: 1.6281 - val_acc: 0.7102
Epoch 3/5
10182/10182 [==============================] - 62s - loss: 1.1331 - acc: 0.8494 - val_loss: 1.3319 - val_acc: 0.7438
Epoch 4/5
10182/10182 [==============================] - 62s - loss: 0.8253 - acc: 0.8942 - val_loss: 1.2023 - val_acc: 0.7429
Epoch 5/5
10182/10182 [==============================] - 64s - loss: 0.6358 - acc: 0.9181 - val_loss: 1.1015 - val_acc: 0.7562
train time: 326.131s
7532/7532 [==============================] - 40s
test time: 55.487s
classification report:
                          precision    recall  f1-score   support

             alt.atheism       0.53      0.46      0.49       319
           comp.graphics       0.68      0.73      0.71       389
 comp.os.ms-windows.misc       0.68      0.62      0.65       394
comp.sys.ibm.pc.hardware       0.69      0.62      0.65       392
   comp.sys.mac.hardware       0.71      0.70      0.71       385
          comp.windows.x       0.86      0.75      0.80       395
            misc.forsale       0.80      0.81      0.80       390
               rec.autos       0.81      0.74      0.77       396
         rec.motorcycles       0.74      0.79      0.76       398
      rec.sport.baseball       0.53      0.92      0.67       397
        rec.sport.hockey       0.96      0.86      0.91       399
               sci.crypt       0.74      0.74      0.74       396
         sci.electronics       0.63      0.64      0.63       393
                 sci.med       0.82      0.76      0.79       396
               sci.space       0.70      0.74      0.72       394
  soc.religion.christian       0.67      0.81      0.73       398
      talk.politics.guns       0.60      0.71      0.65       364
   talk.politics.mideast       0.87      0.73      0.80       376
      talk.politics.misc       0.59      0.43      0.50       310
      talk.religion.misc       0.51      0.29      0.37       251

             avg / total       0.71      0.70      0.70      7532

confusion matrix:
[[147   1   1   0   1   3   0   5   7  18   1   4   5   3  17  46   7  12   8  33]
 [  6 284  15   8  12  16   4   1   4  10   0  11   8   1   9   0   0   0   0   0]
 [  2  24 243  35  16  16   3   1   4  20   0   5   1   4  12   1   2   0   4   1]
 [  0  12  38 243  38   4   9   4   0   8   1   4  29   0   1   0   0   0   0   1]
 [  0   7  10  22 270   3   8   6   6  16   0   5  24   3   4   1   0   0   0   0]
 [  0  32  25   8   3 295   7   1   0  10   0   3   5   2   2   1   0   0   1   0]
 [  1   1   3  13  13   0 316   6   6  12   0   1   8   0   5   0   4   0   0   1]
 [  3   0   0   1   2   0  12 292  21  28   0   2  16   1   9   3   3   0   3   0]
 [  2   1   0   0   1   2   7  16 313  18   1   1  11   4   9   1   6   0   5   0]
 [  2   3   0   0   1   0   3   0   5 364   4   0   1   1   1   7   0   0   5   0]
 [  5   1   0   0   0   0   2   1   2  32 343   2   0   2   3   1   3   1   0   1]
 [  3   9   4   2   3   2   1   1   3  26   0 292  14   2   7   1  17   3   5   1]
 [  1  11   8  18  15   1  12   7  10  17   0  19 253  10   8   0   1   1   1   0]
 [  7   8   3   2   1   0   6   3  11  15   2   1   8 302   5  10   4   5   3   0]
 [  5  13   1   0   1   0   2   4   4  21   1   5  11   9 293   6   4   1  13   0]
 [ 18   3   3   0   0   0   0   1   1  17   1   1   2   4   4 321   1   1   3  17]
 [  8   0   2   0   1   0   2   4   6  22   0  15   2   4   8   7 257   2  16   8]
 [ 21   1   1   1   0   1   1   3   7  15   0   5   2   2   5   8  10 276  16   1]
 [ 15   1   0   0   0   0   1   1   9  10   3  14   3   5   8   5  88   8 133   6]
 [ 34   4   2   1   0   0   1   3   4  11   1   3   1  11   6  61  22   6   8  72]]
```