Ex 2: Restricted Boltzmann Machine features for digit classification

http://scikit-learn.org/stable/auto_examples/neural_networks/plot_rbm_logistic_classification.html#sphx-glr-auto-examples-neural-networks-plot-rbm-logistic-classification-py

此範例將使用BernoulliRBM特徵選取方法,提升手寫數字識別的精確率,伯努利限制玻爾茲曼機器模型(`BernoulliRBM

`)將可以對數據做有效的非線性 特徵提取的處理。 為了讓此模型訓練出來更為強健,將輸入的圖檔,分別做上左右下,一像素的平移,用以增加更多訓練資料, 訓練網路的參數是使用grid search演算法,但此訓練太耗費時間,因此不再這重現,。 此範例結果將比較, 1.使用原本的像素值做的邏輯回歸 2.使用BernoulliRBM做特徵選取的邏輯回歸 結果將顯示:使用BernoulliRBM將可以提升分類的準確度。

(一)引入函式庫與資料

from __future__ import print_function

print(__doc__)

# Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
# License: BSD

import numpy as np
import matplotlib.pyplot as plt

from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.model_selection import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

(二)資料前處理、讀取資料、選取模型

(三)設定模型參數與訓練模型

(四)評估模型的分辨準確率

圖1:使用RBM演算法後準確率為0.95

圖2:不使用任何特徵選取方法做的做的邏輯回歸準確率0.77

(五)畫出100個RBM萃取出的特徵

圖3:使用RBM演算法,尋找出來的特徵

(六)完整程式碼

Last updated