Ex 4: Label Propagation digits active learning

半監督式分類法/範例4 : Label Propagation digits active learning

本範例目的:

  • 展示active learning(主動學習)進行以label propagation(標籤傳播法)學習辨識手寫數字

一、Active Learning 主動學習

在實際應用上,通常我們獲得到的數據,有一大部分是未標籤的,如果要套用在常用的分類法上,最直接的想法是標籤所有的數據,但一一標籤所有數據是非常耗時耗工的,因此,在面對未標籤的數據遠多於有標籤的數據之情況下,可以透過active learning,主動的挑選一些數據進行標籤。 Active learning分成兩部分:

  • 從已標籤的數據中隨機抽取一小部分作為訓練集,訓練出一個分類模型

  • 透過迭代,將分類器預測出來的結果再進行訓練。

二、引入函式與模型

  • stats用來進行統計與分析

  • LabelSpreading為半監督式學習的模型

  • confusion_matrix為混淆矩陣

  • classification_report用於觀察預測和實際數值的差異,包含precision、recall、f1-score及support

import numpy as np
import matplotlib.pyplot as plt

from scipy import stats
from sklearn import datasets
from sklearn.semi_supervised import LabelSpreading
from sklearn.metrics import classification_report, confusion_matrix

三、建立dataset

  • Dataset取自sklearn.datasets.load_digits,內容為0~9的手寫數字,共有1797筆

  • 使用其中的330筆進行訓練(y_train),其中40筆為labeled,其餘290筆為unlabeled(標為-1)

  • 迭代的次數設定為5次

  • scikit learn網站中的範例程式敘述為10筆labeled,但原始程式碼為40筆,因此在這邊以原始碼為主

四、利用Active learning進行模型訓練與預測

  • 以下程式為每一次迭代所做的過程(for迴圈的內容)

  • 每一次迭代都利用訓練過後的模型進行預測,得到predicted_labels,並與true_labels計算混淆矩陣與classification report

  • 利用stats進行數據的統計,找出前5筆預測最不佳的結果,將其預測的label與true label和圖像顯示出來

  • 每一次迭代的最後挑出上述的5筆預測最不佳的結果,進行下一次的迭代時,把相對應的true label替換給y_train測試集裡面,其餘(第40筆之後的數據)的label依然給予-1表示unlabeled

  • 下列程式屬於for迴圈外圍

  • 以下即為每一次迭代的結果,可以看到每一次迭代之後,micro avg逐漸上升

Out:

png

上圖的結果即為Active Learning訓練過程的結果,第一次迭代以330筆的資料進行訓練,其中包含40筆labeled的資料與290 unlabeled的資料,再對unlabeled的資料做預測,將預測出來的結果中,5個預測最不佳的結果顯示出來,即第一列的5張圖,將這5筆資料的從測試集中強制變為true label的結果,再下一次迭代中,labeled的資料就變成45筆,unlabeled的資料為285筆,總和為330筆的資料進行第二次的訓練,以此類推,因此可以看到,每一次訓練,labeled的資料會5筆、5筆的增加。

五、原始碼列表

Python source code: plot_label_propagation_digits_active_learning.py

https://scikit-learn.org/stable/auto_examples/semi_supervised/plot_label_propagation_digits_active_learning.html

Last updated