MNIST i SVM klasyfikacja ręcznie pisanych cyfr

Klasyfikacja odręcznie pisanych cyfr z zbioru MNIST jest swojego rodzaju ‚hello_world’ w dziedzinie uczenia maszynowego. Post ten prezentuje w jaki sposób w Pythonie wczytać, wyświetlić oraz wykorzystać algorytm Support Vector Machines (SVM) do klasyfikacji obrazów przedstawiających ręcznie pisane cyfry.

Jak wczytać zbiór MNIST?

Zbiór ten został skompletowany przez Yann’a Lecun (http://yann.lecun.com/exdb/mnist/) zawiera 70k obrazów w rozdzielczości 28x28px. Przyjęło się, że 60k obiektów stanowi zbiór treningowy a 10k przykładów są to elementy z zbioru testowego.  Obraz(cyfra) reprezentowany jest jako macierz 2D o wymaiarach 28×28, w której wartości 0-255 reprezentują odcień szarości pixela. Łatwiejszą formą do przetwarzania przez algorytmy uczenia maszynowego jest postać wektora. Możemy ją otrzymać odczytując wartości pixeli z obrazu wiersz po wierszu, stąd otrzymamy 784-wymiarowy wektor (28*28=784) reprezentujący jeden obrazek z cyfrą. Natomiast wczytując cały zbiór treningowy załadujemy do pamięci macierz o wymiarach 60000×784, która posłuży jako dane wejściowe do algorytmu uczącego.

Zbiór można wczytać na wiele sposobów, ściągnąć go z strony Yann’a, co będzie wymagało pewnego preprocessingu lub korzystając z modułu sklearn.datasets z biblioteki scikit-learn, dzięki czemu łatwo i szybko wczytamy zbiór do dwu-wymiarowej tablicy numpy.

Funkcja fetch_mldata automatycznie ściągnie niezbędne pliki i umieści w folderze data_home, jeżeli pliki zostały ściągnięte wcześniej to ściąganie zostanie pominięte. W ten sposób w zmiennych images (tablica 2D) oraz targets (tablica 1D) mamy macierz wczytanych cyfr oraz odpowiadające im poprawne etykiety. Warto zwrócić uwagę na wymiary tablic images oraz targets, ilość wierszy w images musi odpowiadać ilość elementów w targets. Spróbujmy wyświetlić np. 30000 element z naszego zbioru treningowego i odpowiadającą jemu etykietę:

Wynik powinine wyglądać następująco:

MNIST digit 4

W powyższym kodzie wykorzystaliśmy bibliotekę matplotlib oraz moduł pyplot do wyświetlenia obrazu, pamiętać należy aby nasz 784-wymiarowy wektor z powrotem przekształcić w obraz(macierz) o wymiarach 28×28 stąd wykorzystanie funkcji reshape.

Idea algorytmu Support Vector Machines

Chcąc utrzymać ten wpis w bardzo praktycznym tonie, nie będę się rozpisywał się na temat SVM (zasługuje to na oddzielny wpis), przedstawię tylko idee. Algorytm SVM jest jednym z przykładów uczenia z nauczycielem, czyli wraz z egzemplarzami uczącymi w postaci zbioru treningowego mamy, dla każdego egzemplarza podaną poprawną etykietę, klasę lub wartość w zależności od problemu.

W naszym przypadku mamy do czynienia z problemem klasyfikacji cyfr, więc zbiór etykiet składa się z 10 cyfr od 0 do 9. Stąd też wczytując nasze dane (zmienna images) wczytaliśmy także etykiety (targets). Algorytm ten na podstawie danych treningowych oraz etykiet poszukuje tzw. hiperpłaszczyzny w 784 wymiarowej przestrzeni, która najlepiej rozdzieli egzemplarze reprezentujące poszczególne cyfry od siebie.

W wielu metodach uczenia z nauczycielem, zestaw danych rozbijany jest na dwa zbiory: treningowy i testowy. Zbiór treningowy wraz z jego etykietami służy jak dane wejściowe dla algorytmu uczącego, który na jego podstawie uczy się rozpoznawać wzorce z naszego problemu. Następnie chcąc przetestować jak dobry jest nasz model, przewidujemy przy jego wykorzystaniu etykiety z zbioru testowego i porównujemy je z tymi rzeczywistymi.

Z praktycznego punktu widzenia potrzebujemy czterech tablic:

  • train – tablica 2D zawierająca 60k wierszy, a każdy wiersz jest 784-wymiarowym wektorem reprezentującym jedną cyfrę (obraz), przechowuje dane treningowe
  • train_labels – tablica 1D, zawierająca 60k elementów, wartościami są liczby 0-9, odpowiadające cyfrom z training
  • test – tablica 2D zawierająca 10k wierszy, przechowująca dane na podstawie których będziemy testować poprawność naszego wyuczonego modelu
  • test_labels – tablica 1D, zawierająca 10k elementów, wartościami są liczby 0-9, odpowiadające cyfrom z test

Klasyfikacja cyfr z MNIST wykorzystując SVM z scikit-learn

Przejdźmy teraz do implementacji, oczywiście nie będziemy implementowali algorytmu SVM sami, choć sam implementowałem go w swoim doktoracie na trylion różnych sposobów :). Wykorzystamy w tym celu świetną Python’ową bibliotekę scikit-learn. Wspomniana już wyżej biblioteka zawiera szereg algorytmów uczenia maszynowego oraz funkcji pomocniczych pozwalających operować na danych, podstawowych miarach dokładności itp.

W naszym przykładzie w celu szybszego działania fazy uczenia, wykorzystamy tylko 10k egzemplarzy. Stąd na początku z 70k wartości wybieramy losowych 10k indeksów z wykorzystaniem funkcji np.random.choice, następnie korzystając z funkcji pomocniczych z scikit-learn, dzielimy dane na dwa zbiory, treningowy (X_train, y_train) oraz testowy (X_test, y_test). Parametr test_size=0.15 określa, że 15% danych trafi do zbioru testowego, reszta do zbioru treningowego, więc nasze tablice powinny mieć rozmiary:

  • X_train.shape == (8500,784) , y_train.shape==(8500,)
  • X_test.shape == (1500,784), y_test.shape == (1500,)

Trening SVM

Sam proces uczenia jest już stosunkowo prosty, tworzymy obiekt classifier klasy SVC (support vector classification) z modułu sklearn.svm zaimportowanego na samym początku. Ustawiamy parametry C oraz gamma (o nich więcej w innym wpisie)a następnie wywołujemy funkcję fit, która dokonuje uczenia. Zauważcie, że dodatkowo dodane jest proste mierzenie czasu działania funkcji fit,  gdyż na większych danych proces ten może zając dużo czasu. W zależności od ustawionych parametrów czas uczenia na całym zbiorze u mnie w PLON’ie (online python IDE) zajmował ok 2h.

Predykcja SVM

Po skończeniu działania funkcji fit, obiekt classifier przechowuje parametry wyuczonego modelu wewnętrznie w zmiennych obiektu. Chcąc sprawdzić działanie i dokładność klasyfikatora spróbujmy przewiedzieć etykiety dla zbioru testowego w tym celu użyjmy funkcji classifier.predict zwróci ona tablicę predict przewidzianych etykiet. Porównując ją z rzeczywistymi etykietami możemy określić w ilu procentach nasz model poprawnie rozpoznał cyfry.

Podsumowanie

Wpis przedstawia sposób klasyfikacji odręcznie pisanych cyfr z zbioru MNIST z wykorzystaniem algorytmu SVM całość została zaimplementowana z wykorzystaniem biblioteki sciki-learn. Po niewielkich modyfikacjach kodu, przedstawiony przykład można śmiało wykorzystać przy innych projektach, w których mamy do czynienia z problemem klasyfikacji.

Zaznaczam, że jest kilka kwestii, które z premedytacją ominąłem, jak np. funkcje jądra w SVM (ang. kernel functions), preprocessing danych np. elastic deformations, oraz kernel aproximations ale o tych metodach będzie w kolejnych wpisach.

Cały projekt jest do ściągnięcia lub uruchomienia online:

1 Comment MNIST i SVM klasyfikacja ręcznie pisanych cyfr

  1. Pingback: Klasyfikacja cyfr z MNIST przy pomocy sieci neuronowej w Tesnsorflow - About Data

Dodaj komentarz

Twój adres email nie zostanie opublikowany. Pola, których wypełnienie jest wymagane, są oznaczone symbolem *