PRML 曲線フィッティングをJavaで書いてみる
- 作者: C. M.ビショップ,元田浩,栗田多喜夫,樋口知之,松本裕治,村田昇
- 出版社/メーカー: シュプリンガー・ジャパン株式会社
- 発売日: 2007/12/10
- メディア: 単行本
- 購入: 18人 クリック: 1,588回
- この商品を含むブログ (110件) を見る
PRML本の曲線フィッティング1.1節を実装してみたのはだいぶ前。
でも、weightがPRML本と微妙に違っていたので、グラフを書いてみることにした。
グラフの見た目では、PRML本とそれほど大きな違いではなさそうなことがわかった。数値解析的な面を無視してしまってるからなぁ。
なお、今回は、グラフ描画に、JFreeChartを使った。これ、意外と簡単でよかった。JFreeChartを使ってXYグラフだけを書くラップクラスも載せておきます。
グラフ描画した結果。
- M=9。
- 赤=正解sin(2πx)
- 青点=学習データ(PRMLのサイトからダウンロード)
- 緑=二乗誤差最小のみ
- 黄=二乗誤差+正則(ln λ = -18)
訓練データが10点しかないので、モデルが9では、二乗誤差最小のみでは過学習しているのがわかる(すべての学習データ点は通過しているが、正解にはほど遠い)。正則化するとかなり改善される。
コード
package prml.chap1; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.List; import Jama.Matrix; public class PolyCurveFit { ArrayList<double[]> training; public PolyCurveFit() { this.training = new ArrayList<double[]>(); } public void addTrainingData(double[] data) { this.training.add(data); } public List<double[]> getTrainingDataValues() { return this.training; } public double[] getWeightsByMinSqrtErr(int m) { // 学習データ:(x,t)={(x_1,t_1),...,(x_n,t_n) // 曲線フィット:y(x,w) = w_0 + w_1 * x^1 + w_2 * x^2 ... = ??_j=0〜M w_j * x^j // Aw = T (A:m*m次元, T:m*1次元, w:m*1次元の重み) // A_ij = ??_n=1〜N (x_n)^{i+j} // T_i = ??_n=1〜N (x_n)^i * t_n // w重み //行列Aを直接計算 Matrix A = new Matrix(m + 1, m + 1); for (int i = 0; i < m + 1; i++) { for (int j = 0; j < m + 1; j++) { double val_Aij = 0; for (int n = 0; n < this.training.size(); n++) { val_Aij += Math.pow(this.training.get(n)[0], i + j); } A.set(i, j, val_Aij); } } // 行列T Matrix T = new Matrix(m + 1, 1); for (int i = 0; i < m + 1; i++) { double val_Ti = 0; for (int n = 0; n < this.training.size(); n++) { val_Ti += this.training.get(n)[1] * Math.pow(this.training.get(n)[0], i); } T.set(i, 0, val_Ti); } Matrix w = A.solve(T); // for debug // printMatrix(A); // printMatrix(T); // printMatrix(w); // Matrix Residual = A.times(w).minus(T); // printMatrix(Residual); return w.getColumnPackedCopy(); } public double[] getWeightsByMinSqrtErrReg(int m, double r) { // 学習データ:(x,t)={(x_1,t_1),...,(x_n,t_n) // 曲線フィット:y(x,w) = w_0 + w_1 * x^1 + w_2 * x^2 ... = ??_j=0〜M w_j * x^j // (A+λI)w = T (A:m*m次元, λ:正則化用係数, I:m*m次元単位行列, T:m*1次元, w:m*1次元の重み) // A_ij = ??_n=1〜N (x_n)^{i+j} // T_i = ??_n=1〜N (x_n)^i * t_n // w重み // 行列A Matrix A = new Matrix(m + 1, m + 1); for (int i = 0; i < m + 1; i++) { for (int j = 0; j < m + 1; j++) { double val_Aij = 0; for (int n = 0; n < this.training.size(); n++) { val_Aij += Math.pow(this.training.get(n)[0], i + j); } A.set(i, j, val_Aij); } } // 行列λI Matrix rI = new Matrix(m + 1, m + 1); for (int i = 0; i < m + 1; i++) { for (int j = 0; j < m + 1; j++) { if (i == j) rI.set(i, j, r); else rI.set(i, j, 0.0); } } // 行列T Matrix T = new Matrix(m + 1, 1); for (int i = 0; i < m + 1; i++) { double val_Ti = 0; for (int n = 0; n < this.training.size(); n++) { val_Ti += this.training.get(n)[1] * Math.pow(this.training.get(n)[0], i); } T.set(i, 0, val_Ti); } Matrix A_rI = A.plus(rI); Matrix w = A_rI.solve(T); // for debug // printMatrix(A); // printMatrix(T); // printMatrix(w); // Matrix Residual = A.times(w).minus(T); // printMatrix(Residual); return w.getColumnPackedCopy(); } private static void printMatrix(Matrix x) { for (int j = 0; j < x.getColumnDimension(); j++) { System.out.print("\t"); System.out.print("[" + j + "]"); } System.out.println(); for (int i = 0; i < x.getRowDimension(); i++) { System.out.print("[" + i + "]"); for (int j = 0; j < x.getColumnDimension(); j++) { System.out.print("\t"); System.out.print(x.get(i, j)); } System.out.println(); } return; } public static void main(String args[]) throws IOException { // 学習データファイル名 String filename = args[0]; // 次数 int m = Integer.parseInt(args[1]); // ln λ int ln_r = Integer.parseInt(args[2]); // フィッティングオブジェクト PolyCurveFit pcFitEM = new PolyCurveFit(); // データをロードして学習データとして追加 BufferedReader br = new BufferedReader(new FileReader( new File(filename))); String line; try { while ((line = br.readLine()) != null) { // System.out.println(line); String recStr[] = line.split(" ", 2); double[] rec = { Double.parseDouble(recStr[0]), Double.parseDouble(recStr[1]) }; pcFitEM.addTrainingData(rec); } br.close(); } catch (IOException e) { e.printStackTrace(); } // 重み計算(二乗和誤差最小) double[] w_MinSqrtErr = pcFitEM.getWeightsByMinSqrtErr(m); System.out.println("----"); for (int i = 0; i < w_MinSqrtErr.length; i++) { System.out.println(w_MinSqrtErr[i]); } // 重み計算(二乗和誤差+正則化最小) double r = Math.pow(Math.E, ln_r); double[] w_MinSqrtErrReg = pcFitEM.getWeightsByMinSqrtErrReg(m, r); System.out.println("----"); for (int i = 0; i < w_MinSqrtErrReg.length; i++) { System.out.println(w_MinSqrtErrReg[i]); } //グラフ描画用 //正解データ double[][] sineValues = makeSineValues(); //訓練データ double[][] trainValues = makeTrainValues(pcFitEM.getTrainingDataValues()); //学習結果(二乗和) double[][] resultValues = makeResultValues(w_MinSqrtErr, m); //学習結果(二乗和) double[][] resultValuesReg = makeResultValues(w_MinSqrtErrReg, m); //グラフ表示 XYGraph xyGraph = new XYGraph("Fit Sine(m="+ String.valueOf(m)+")", "X", "Y"); xyGraph.addDataValues("Sin(x)", sineValues, true); xyGraph.addDataValues("Training", trainValues, false); xyGraph.addDataValues("MinSqrtErr", resultValues, true); xyGraph.addDataValues("MinSqrtErrReg(ln rabmda = "+String.valueOf(ln_r)+")", resultValuesReg, true); xyGraph.rangeX(0.0, 1.0); xyGraph.rangeY(-1.0, 1.0); xyGraph.saveGraphAsPNG("sin.png", 500, 300); //viewメソッドの後に呼び出すと、動作がおかしいので注意 xyGraph.view(700, 700); } private static double[][] makeSineValues() { double[][] ret = new double[2][101]; //0-1を100個のデータで埋める for(int i=0; i<=100; i++){ ret[0][i] = i/100.0; //X ret[1][i] = Math.sin(2.0 * Math.PI * ret[0][i]); //Y } return ret; } private static double[][] makeTrainValues(List<double[]> trainingDataValues) { double[][] ret = new double[2][trainingDataValues.size()]; int i=0; for(double[] rec: trainingDataValues){ ret[0][i] = rec[0]; //X ret[1][i] = rec[1]; //Y i++; } return ret; } private static double[][] makeResultValues(double[] w, int m) { double[][] ret = new double[2][101]; //0-1を100個のデータで埋める for(int i=0; i<=100; i++){ ret[0][i] = i/100.0; //X for(int j=0; j<=m; j++){ //Y ret[1][i] += w[j] * Math.pow(ret[0][i],j); } } return ret; } }
グラフ表示用クラス。JFreeChartを利用。
package prml.chap1; import java.io.File; import java.io.IOException; import javax.swing.JFrame; import org.jfree.chart.ChartFactory; import org.jfree.chart.ChartPanel; import org.jfree.chart.ChartUtilities; import org.jfree.chart.JFreeChart; import org.jfree.chart.axis.ValueAxis; import org.jfree.chart.plot.DatasetRenderingOrder; import org.jfree.chart.plot.PlotOrientation; import org.jfree.chart.plot.XYPlot; import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer; import org.jfree.data.xy.DefaultXYDataset; public class XYGraph { private JFreeChart chart; public XYGraph(String graphName, String labelX, String labelY){ this.chart = ChartFactory.createXYLineChart( graphName, // The chart title labelX, // x axis label labelY, // y axis label null, // The dataset for the chart PlotOrientation.VERTICAL, true, // Is a legend required? false, // Use tooltips false // Configure chart to generate URLs? ); } //データを追加(double[0]=X座標リスト、double[1]=Y座標リスト)。線も表示するか否か public void addDataValues(String name, double[][] dataValuses, boolean line){ DefaultXYDataset dataset = new DefaultXYDataset(); dataset.addSeries(name, dataValuses); XYPlot plot = this.chart.getXYPlot(); int datasetCnt = plot.getDatasetCount(); plot.setDataset(datasetCnt, dataset); XYLineAndShapeRenderer renderer; if(line){ renderer = new XYLineAndShapeRenderer( true, // 線を表示しない false // 点を表示する ); }else{ renderer = new XYLineAndShapeRenderer( false, // 線を表示しない true // 点を表示する ); } plot.setRenderer(datasetCnt, renderer); plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD); } //X座標の表示区間 public void rangeX(double min, double max){ ValueAxis xAxis = this.chart.getXYPlot().getDomainAxis(); xAxis.setAutoRange(false); xAxis.setRange(min, max); } //Y座標の表示区間 public void rangeY(double min, double max){ ValueAxis yAxis = this.chart.getXYPlot().getRangeAxis(); yAxis.setAutoRange(false); yAxis.setRange(min, max); } //表示(画面サイズを指定) public void view(int sizeX, int sizeY){ String graphName = this.chart.getTitle().getText(); JFrame frame = new JFrame(graphName); frame.getContentPane().add(new ChartPanel(this.chart)); frame.setSize(sizeX, sizeY); frame.setVisible(true); frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); } //保存(ファイル名と画面サイズを指定) public void saveGraphAsPNG(String filename, int sizeX, int sizeY) throws IOException { File imageFile = new File(filename); ChartUtilities.saveChartAsPNG(imageFile, this.chart, sizeX, sizeY); } //テスト用メイン public static void main(String[] args) throws IOException{ //サンプルデータ(sinとcosとy=x-1) final int NUM_VALUES = 60; double[][] sineValues = new double[2][NUM_VALUES]; double[][] cosineValues = new double[2][NUM_VALUES]; // X values for (int i = 0; i < NUM_VALUES; i++) { sineValues[0][i] = i / 10.0; cosineValues[0][i] = i / 10.0; } // Y values for (int i = 0; i < NUM_VALUES; i++) { sineValues[1][i] = Math.sin(sineValues[0][i]); cosineValues[1][i] = Math.cos(cosineValues[0][i]); } //サンプルデータ(y=a*x-1) double[][] scatterplotValues = new double[2][10]; for (int i = 0; i < 10; i++) { scatterplotValues[0][i] = (double) i; // x scatterplotValues[1][i] = (double) i * (0.25 / 2.0) - 1.0; // y } //表示テスト XYGraph xyGraph = new XYGraph("Sine / Cosine Curves", "X", "Y"); xyGraph.addDataValues("Sin(x)", sineValues, true); xyGraph.addDataValues("Cos(x)", cosineValues, true); xyGraph.addDataValues("y=a*x-1", scatterplotValues, false); xyGraph.saveGraphAsPNG("sin-cos.png", 500, 300); //先に保存しないと、グラフがおかしくなる xyGraph.view(700, 700); } }