Ex 4: Label Propagation digits active learning
半監督式分類法/範例4 : Label Propagation digits active learning
本範例目的:
展示active learning(主動學習)進行以label propagation(標籤傳播法)學習辨識手寫數字
一、Active Learning 主動學習
在實際應用上,通常我們獲得到的數據,有一大部分是未標籤的,如果要套用在常用的分類法上,最直接的想法是標籤所有的數據,但一一標籤所有數據是非常耗時耗工的,因此,在面對未標籤的數據遠多於有標籤的數據之情況下,可以透過active learning,主動的挑選一些數據進行標籤。 Active learning分成兩部分:
從已標籤的數據中隨機抽取一小部分作為訓練集,訓練出一個分類模型
透過迭代,將分類器預測出來的結果再進行訓練。
二、引入函式與模型
stats用來進行統計與分析
LabelSpreading為半監督式學習的模型
confusion_matrix為混淆矩陣
classification_report用於觀察預測和實際數值的差異,包含precision、recall、f1-score及support
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from sklearn import datasets
from sklearn.semi_supervised import LabelSpreading
from sklearn.metrics import classification_report, confusion_matrix三、建立dataset
Dataset取自sklearn.datasets.load_digits,內容為0~9的手寫數字,共有1797筆
使用其中的330筆進行訓練(y_train),其中40筆為labeled,其餘290筆為unlabeled(標為-1)
迭代的次數設定為5次
scikit learn網站中的範例程式敘述為10筆labeled,但原始程式碼為40筆,因此在這邊以原始碼為主
四、利用Active learning進行模型訓練與預測
以下程式為每一次迭代所做的過程(for迴圈的內容)
每一次迭代都利用訓練過後的模型進行預測,得到predicted_labels,並與true_labels計算混淆矩陣與classification report
利用stats進行數據的統計,找出前5筆預測最不佳的結果,將其預測的label與true label和圖像顯示出來
每一次迭代的最後挑出上述的5筆預測最不佳的結果,進行下一次的迭代時,把相對應的true label替換給y_train測試集裡面,其餘(第40筆之後的數據)的label依然給予-1表示unlabeled
下列程式屬於for迴圈外圍
以下即為每一次迭代的結果,可以看到每一次迭代之後,micro avg逐漸上升
Out:

上圖的結果即為Active Learning訓練過程的結果,第一次迭代以330筆的資料進行訓練,其中包含40筆labeled的資料與290 unlabeled的資料,再對unlabeled的資料做預測,將預測出來的結果中,5個預測最不佳的結果顯示出來,即第一列的5張圖,將這5筆資料的從測試集中強制變為true label的結果,再下一次迭代中,labeled的資料就變成45筆,unlabeled的資料為285筆,總和為330筆的資料進行第二次的訓練,以此類推,因此可以看到,每一次訓練,labeled的資料會5筆、5筆的增加。
五、原始碼列表
Python source code: plot_label_propagation_digits_active_learning.py
Last updated