Klasyfikacja odręcznie pisanych cyfr z zbioru MNIST jest swojego rodzaju 'hello_world’ w dziedzinie uczenia maszynowego. Post ten prezentuje, w jaki sposób w Pythonie wczytać, wyświetlić oraz wykorzystać algorytm Support Vector Machines (SVM) do klasyfikacji obrazów przedstawiających ręcznie pisane cyfry.
Jak wczytać zbiór MNIST?
Zbiór ten został skompletowany przez Yann’a Lecun (http://yann.lecun.com/exdb/mnist/) zawiera 70k obrazów w rozdzielczości 28x28px. Przyjęło się, że 60k obiektów stanowi zbiór treningowy a 10k przykładów są to elementy z zbioru testowego. Obraz(cyfra) reprezentowany jest jako macierz 2D o wymaiarach 28×28, w której wartości 0-255 reprezentują odcień szarości pixela. Łatwiejszą formą do przetwarzania przez algorytmy uczenia maszynowego jest postać wektora. Możemy ją otrzymać odczytując wartości pixeli z obrazu wiersz po wierszu, stąd otrzymamy 784-wymiarowy wektor (28*28=784) reprezentujący jeden obrazek z cyfrą. Natomiast wczytując cały zbiór treningowy załadujemy do pamięci macierz o wymiarach 60000×784, która posłuży jako dane wejściowe do algorytmu uczącego.
Zbiór można wczytać na wiele sposobów, ściągnąć go z strony Yann’a, co będzie wymagało pewnego preprocessingu lub korzystając z modułu sklearn.datasets z biblioteki scikit-learn, dzięki czemu łatwo i szybko wczytamy zbiór do dwu-wymiarowej tablicy numpy.
import matplotlib.pyplot as plt import numpy as np # Import datasets, classifiers and performance metrics from sklearn import datasets, svm, metrics #fetch original mnist dataset from sklearn.datasets import fetch_mldata mnist = fetch_mldata('MNIST original', data_home='./') #minist object contains: data, COL_NAMES, DESCR, target fields #you can check it by running #mnist.keys() #data field is 70k x 784 array, each row represents pixels from 28x28=784 image images = mnist.data targets = mnist.target
Funkcja fetch_mldata automatycznie ściągnie niezbędne pliki i umieści w folderze data_home, jeżeli pliki zostały ściągnięte wcześniej to ściąganie zostanie pominięte. W ten sposób w zmiennych images (tablica 2D) oraz targets (tablica 1D) mamy macierz wczytanych cyfr oraz odpowiadające im poprawne etykiety. Warto zwrócić uwagę na wymiary tablic images oraz targets, ilość wierszy w images musi odpowiadać ilość elementów w targets. Spróbujmy wyświetlić np. 30000 element z naszego zbioru treningowego i odpowiadającą jemu etykietę:
plt.imshow(images[30000].reshape(28,28), cmap=plt.cm.gray_r, interpolation='nearest') plt.axis('off') plt.title("Digit: {}".format(targets[30000]))
Wynik powinien wyglądać następująco:
W powyższym kodzie wykorzystaliśmy bibliotekę matplotlib oraz moduł pyplot do wyświetlenia obrazu, pamiętać należy aby nasz 784-wymiarowy wektor z powrotem przekształcić w obraz(macierz) o wymiarach 28×28 stąd wykorzystanie funkcji reshape.
Idea algorytmu Support Vector Machines
Chcąc utrzymać ten wpis w bardzo praktycznym tonie, nie będę rozpisywał się na temat SVM (zasługuje to na oddzielny wpis), przedstawię tylko idee. Algorytm SVM jest jednym z przykładów uczenia z nauczycielem. Wraz z egzemplarzami uczącymi mamy dla każdego egzemplarza podaną poprawną etykietę, klasę lub wartość w zależności od problemu.
My rozważamy problem klasyfikacji cyfr. Zbiór etykiet składa się z 10 cyfr od 0 do 9. Stąd też wczytując nasze dane (zmienna images) wczytaliśmy także etykiety (targets). Algorytm ten na podstawie danych treningowych oraz etykiet poszukuje tzw. hiperpłaszczyzny w 784 wymiarowej przestrzeni, która najlepiej rozdzieli egzemplarze reprezentujące poszczególne cyfry od siebie.
W wielu metodach uczenia z nauczycielem zestaw danych rozbijany jest na dwa zbiory: treningowy i testowy. Zbiór treningowy wraz z jego etykietami służy jak dane wejściowe dla algorytmu uczącego. Na jego podstawie algorytm uczy się rozpoznawać wzorce z naszego problemu. Ostatnim krokiem jest określenia sprawności modelu. Przewidujemy przy jego wykorzystaniu etykiety z zbioru testowego i porównujemy je z tymi rzeczywistymi.
Z praktycznego punktu widzenia potrzebujemy czterech tablic:
- train – tablica 2D zawierająca 60k wierszy, a każdy wiersz jest 784-wymiarowym wektorem reprezentującym jedną cyfrę (obraz), przechowuje dane treningowe
- train_labels – tablica 1D, zawierająca 60k elementów, wartościami są liczby 0-9, odpowiadające cyfrom z training
- test – tablica 2D zawierająca 10k wierszy, przechowująca dane na podstawie których będziemy testować poprawność naszego wyuczonego modelu
- test_labels – tablica 1D, zawierająca 10k elementów, wartościami są liczby 0-9, odpowiadające cyfrom z test
Klasyfikacja cyfr z MNIST wykorzystując SVM z scikit-learn
Przejdźmy teraz do implementacji. Oczywiście nie będziemy implementowali algorytmu SVM sami, choć sam implementowałem go w swoim doktoracie na trylion różnych sposobów :). Wykorzystamy w tym celu świetną Python’ową bibliotekę scikit-learn. Wspomniana już wyżej biblioteka zawiera szereg algorytmów uczenia maszynowego oraz funkcji pomocniczych pozwalających operować na danych, podstawowych miarach dokładności itp.
W naszym przykładzie w celu szybszego działania fazy uczenia, wykorzystamy tylko 10k egzemplarzy. Stąd na początku z 70k wartości wybieramy losowych 10k indeksów z wykorzystaniem funkcji np.random.choice, następnie korzystając z funkcji pomocniczych z scikit-learn, dzielimy dane na dwa zbiory, treningowy (X_train, y_train) oraz testowy (X_test, y_test). Parametr test_size=0.15 określa, że 15% danych trafi do zbioru testowego, reszta do zbioru treningowego, więc nasze tablice powinny mieć rozmiary:
- X_train.shape == (8500,784) , y_train.shape==(8500,)
- X_test.shape == (1500,784), y_test.shape == (1500,)
#sample smaller size for testing rand_idx = np.random.choice(images.shape[0],10000) #scale data for [0,255] -> [0,1] X_data =images[rand_idx]/255.0 Y = targets[rand_idx] #split data to train and test from sklearn.cross_validation import train_test_split X_train, X_test, y_train, y_test = train_test_split(X_data, Y, test_size=0.15, random_state=42)
Trening SVM
Sam proces uczenia jest już stosunkowo prosty. Tworzymy obiekt classifier klasy SVC (support vector classification) z modułu sklearn.svm. Ustawiamy parametry C oraz gamma (o nich więcej w innym wpisie), a następnie wywołujemy funkcję fit, która dokonuje treningu. Zauważcie, że dodane jest proste mierzenie czasu działania funkcji fit. Na większych danych proces ten może zając dużo czasu. W zależności od ustawionych parametrów czas uczenia na całym zbiorze zajmował ok 2h.
# Create a classifier: a support vector classifier classifier = svm.SVC(C=1,gamma=0.001) import datetime as dt # We learn the digits on train part start_time = dt.datetime.now() print 'Start learning at {}'.format(str(start_time)) classifier.fit(X_train, y_train) end_time = dt.datetime.now() print 'Stop learning {}'.format(str(end_time)) elapsed_time= end_time - start_time print 'Elapsed learning {}'.format(str(elapsed_time))
Predykcja SVM
Po skończeniu działania funkcji fit, obiekt classifier przechowuje parametry wyuczonego modelu wewnętrznie w zmiennych obiektu. Chcąc sprawdzić dokładność klasyfikatora spróbujmy przewiedzieć etykiety dla zbioru testowego. Użyjmy funkcji classifier.predict, zwróci ona tablicę predict przewidzianych etykiet. Porównując ją z rzeczywistymi etykietami możemy określić w ilu procentach nasz model poprawnie rozpoznał cyfry.
# Now predict the value of the test expected = y_test predicted = classifier.predict(X_test) show_some_digits(X_test,predicted,title_text="Predicted {}") print("Classification report for classifier %s:\n%s\n" % (classifier, metrics.classification_report(expected, predicted))) cm = metrics.confusion_matrix(expected, predicted) print("Confusion matrix:\n%s" % cm)
Podsumowanie
Wpis przedstawia sposób klasyfikacji odręcznie pisanych cyfr z zbioru MNIST z wykorzystaniem algorytmu SVM. Przykład został zaimplementowana z wykorzystaniem biblioteki sciki-learn. Po niewielkich modyfikacjach można śmiało go wykorzystać przy innych projektach, w których mamy do czynienia z problemem klasyfikacji.
Zaznaczam, że jest kilka kwestii, które z premedytacją ominąłem. Nie opisywałem funkcji jądra w SVM (ang. kernel functions), preprocessingu danych np. elastic deformations, oraz kernel aproximations. Wszystkie wymienione techniki pomagają uzyskać lepszą dokładność, ale nie to było celem tego wpisu.
Cały projekt jest do ściągnięcia kod z GitHub: https://github.com/ksopyla/svm_mnist_digit_classification – przykład zawarty jest w pliku svm_mnist_classification.py
Jeżeli uważasz ten wpis za wartościowy to Zasubskrybuj bloga. Dostaniesz informacje o nowych artykułach.
Cześć,
Super wprowedznie do klasyfikacji MNIST.
W ćwiczeniu został użyty modułu o nazwie sklearn.cross_validation, w nowszej wersji tego modułu sklearn (wersja 0.20.2) ma on nazwę sklearn.model_selection.
Pozdrawiam,
dk
Dzięki będę musiał poprawić.
Warto też w importach uzupełnić użycie modułu 'mnist_helpers’ inaczej zamiast accuracy pojawi się błąd.
Dzięki za przyjazne wprowadzenie do ML, trochę musiałem też powalczyć z załadowaniem danych (m.in dostępny w openml.org nzestaw danych pod nazwą- 'mnist_784′) ale dalej było ok.
LeszekT