PRMLの第1章の曲線fittingをjavaで簡単に解いてみる

最近、周りでPRML本(パターン認識機械学習)がはやり始めたのでちょっとプログラムでもしてみるかなと思い立った。

ひとまず第1章の曲線フィッティングでもやってみる。
第1章の二乗誤差和最小と二乗誤差和+正則化で最小。

最近、言語処理が多くperlばかりだったので、たまにはJavaで書いてみようかと思った。

で、とりあえず解ければいいので、フリーのツールは使いまくることにした。
#まあ、これではアルゴリズムのとき方のお勉強にはならないですけどね。

で必要なのは、最終的には行列演算。
オプションで、結果を見てみるための関数描画。

行列演算のライブラリ
javaのライブラリはいろいろあるようなのだけど、どれも更新がとまっているようで、少し不安。
Java matrix package(http://math.nist.gov/javanumerics/jama/)
JAMA is comprised of six Java classes: Matrix, CholeskyDecomposition, LUDecomposition, QRDecomposition, SingularValueDecomposition and EigenvalueDecomposition.
ベーシック機能は全部そろってる。
・netliv-java(http://code.google.com/p/netlib-java/)
Netlib is a collection of mission-critical software components for linear algebra systems (i.e. working with vectors or matrices). Netlib libraries are written in C, Fortran or optimised assembly code. A Java translation has been provided by the F2J project but it does not take advantage of optimised system libraries.
Netlib(http://www.netlib.org/)をFortran2Javaを使ってソースコードレベルでJavaファイルに直したもの。
・Matrix toolkit java(http://code.google.com/p/matrix-toolkits-java/)
MTJ is designed to be used as a library for developing numerical applications, both for small and large scale computations. The library is based on BLAS and LAPACK for its dense and structured sparse computations, and on the Templates project for unstructured sparse operations
大規模でも大丈夫そうなことは書いてある。netlib-javaをバックで呼び出している。
・Universal Java Matrix Package(http://www.ujmp.org/)
The Universal Java Matrix Package (UJMP) is an open source Java library that provides sparse and dense matrix classes, as well as a large number of calculations for linear algebra like matrix multiplication or matrix inverse. Operations such as mean, correlation, standard deviation, replacement of missing values or the calculation of mutual information are supported also.
ビジュアライズ機能も含んでいるようで、一番でかい。Matrix toolkit javaをaddonできるようなことも書いてあった。

今回は、テスト用なので、単機能ですぐ使えそうなJava matrix packageにする。
ほかのものは、高機能なので、あとでゆっくり試すとしよう。

では、解いてみる。

package prml.chap1;

// 学習データ:(x,t)={(x_1,t_1),...,(x_n,t_n)
// 曲線フィット:y(x,w) = w_0 + w_1 * x^1 + w_2 * x^2 ... = ??_j=0〜M w_j * x^j 
// 学習データ(x,t)と次元数Mを入力に、重みベクトルwを出力する
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

import Jama.Matrix;

public class PolyCurveFit {

	ArrayList<double[]> training;
	int dim;

	public PolyCurveFit() {
		this.training = new ArrayList<double[]>();
		this.dim = -1;
	}

	public void addTrainingData(double[] data) {
		this.training.add(data);
	}

	public double[] getWeightsByMinSqrtErr(int m) {
		// 学習データ:(x,t)={(x_1,t_1),...,(x_n,t_n)
		// 曲線フィット:y(x,w) = w_0 + w_1 * x^1 + w_2 * x^2 ... = ??_j=0〜M w_j * x^j

		// Aw = T (A:m*m次元, T:m*1次元, w:m*1次元の重み)
		// A_ij = ??_n=1〜N (x_n)^{i+j}
		// T_i = ??_n=1〜N (x_n)^i * t_n
		// w重み

		// 行列A
		Matrix A = new Matrix(m + 1, m + 1);
		for (int i = 0; i < m + 1; i++) {
			for (int j = 0; j < m + 1; j++) {
				A.set(i, j, this.calcElemA(i, j));
			}
		}

		// 行列T
		Matrix T = new Matrix(m + 1, 1);
		for (int i = 0; i < m + 1; i++) {
			T.set(i, 0, this.calcElemT(i));
		}

		Matrix w = A.solve(T);

		// for debug
		// printMatrix(A);
		// printMatrix(T);
		// printMatrix(w);
		// Matrix Residual = A.times(w).minus(T);
		// printMatrix(Residual);

		return w.getColumnPackedCopy();

	}

	private double calcElemA(int i, int j) {
		double ret = 0.0;
		for (int n = 0; n < this.training.size(); n++) {
			ret += Math.pow(this.training.get(n)[0], i + j);
		}
		return ret;
	}

	private double calcElemT(int i) {
		double ret = 0.0;
		for (int n = 0; n < this.training.size(); n++) {
			ret += this.training.get(n)[1]
					* Math.pow(this.training.get(n)[0], i);
		}
		return ret;
	}

	public double[] getWeightsByMinSqrtErrReg(int m, double r) {
		// 学習データ:(x,t)={(x_1,t_1),...,(x_n,t_n)
		// 曲線フィット:y(x,w) = w_0 + w_1 * x^1 + w_2 * x^2 ... = ??_j=0〜M w_j * x^j

		// (A+λI)w = T (A:m*m次元, λ:正則化用係数, I:m*m次元単位行列, T:m*1次元, w:m*1次元の重み)
		// A_ij = ??_n=1〜N (x_n)^{i+j}
		// T_i = ??_n=1〜N (x_n)^i * t_n
		// w重み

		// 行列A
		Matrix A = new Matrix(m + 1, m + 1);
		for (int i = 0; i < m + 1; i++) {
			for (int j = 0; j < m + 1; j++) {
				A.set(i, j, this.calcElemA(i, j));
			}
		}

		// 行列λI
		Matrix rI = new Matrix(m + 1, m + 1);
		for (int i = 0; i < m + 1; i++) {
			for (int j = 0; j < m + 1; j++) {
				if (i == j)
					rI.set(i, j, r);
				else
					rI.set(i, j, 0.0);
			}
		}

		// 行列T
		Matrix T = new Matrix(m + 1, 1);
		for (int i = 0; i < m + 1; i++) {
			T.set(i, 0, this.calcElemT(i));
		}

		Matrix A_rI = A.plus(rI);
		Matrix w = A_rI.solve(T);

		// for debug
		// printMatrix(A);
		// printMatrix(T);
		// printMatrix(w);
		// Matrix Residual = A.times(w).minus(T);
		// printMatrix(Residual);

		return w.getColumnPackedCopy();

	}

	private static void printMatrix(Matrix x) {
		for (int j = 0; j < x.getColumnDimension(); j++) {
			System.out.print("\t");
			System.out.print("[" + j + "]");
		}
		System.out.println();
		for (int i = 0; i < x.getRowDimension(); i++) {
			System.out.print("[" + i + "]");
			for (int j = 0; j < x.getColumnDimension(); j++) {
				System.out.print("\t");
				System.out.print(x.get(i, j));
			}
			System.out.println();
		}
		return;
	}

	public static void main(String args[]) throws FileNotFoundException {
		// ファイル名
		String filename = args[0];
		// 次数
		int m = Integer.parseInt(args[1]);

		// フィッティングオブジェクト
		PolyCurveFit pcFitEM = new PolyCurveFit();

		// データをロードする
		BufferedReader br = new BufferedReader(new FileReader(
				new File(filename)));
		String line;

		try {
			while ( ( line = br.readLine()) != null) {
				// System.out.println(line);
				String recStr[] = line.split(" ", 2);
				double[] rec = { Double.parseDouble(recStr[0]),
						Double.parseDouble(recStr[1]) };
				pcFitEM.addTrainingData(rec);
			}
			br.close();
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}

		// 重み計算(二乗和誤差最小)
		double[] w_MinSqrtErr = pcFitEM.getWeightsByMinSqrtErr(m);
		System.out.println("----");
		for (int i = 0; i < w_MinSqrtErr.length; i++) {
			System.out.println(w_MinSqrtErr[i]);
		}

		// 重み計算(二乗和誤差+正則化最小)
		double r = Math.pow(Math.E, 0);
		double[] w_MinSqrtErrReg = pcFitEM.getWeightsByMinSqrtErrReg(m, r);
		System.out.println("----");
		for (int i = 0; i < w_MinSqrtErrReg.length; i++) {
			System.out.println(w_MinSqrtErrReg[i]);
		}
	}
}

結果が、ちょっずれてるなぁ、本の中の値と。

追記
こっちで、グラフも書いてみました。http://d.hatena.ne.jp/mzi/20121015