Wielowarstwowa sieć neuronowa w Pytorch – klasyfikacja obrazów CIFAR-10

Krok po kroku opisuję najważniejsze etapy tworzenia wielowarstwowej sieci neuronowej do klasyfikacji obrazów ze zbioru CIFAR-10. Skupimy się na wczytywaniu danych oraz omówimy jeden z głównych elementów, czyli klasę nn.Module reprezentującą sieć neuronową.

Wpis ten należy do serii artykułów o Pytorch, w każdym artykule opisuję pewien aspekt tej biblioteki oraz przykłady najpopularniejszych architektur sieci neuronowych.

Do tej pory opublikowałem:

Omówienie problemu i architektury sieci

Postawione przed nami zadanie polega na klasyfikacji zbioru obrazów ze zbioru CIFAR-10. Zawiera on 60000 kolorowych obrazów o rozmiarze 32×32 piksele, każdy obraz przynależy do jednej z 10 klas. Zbiór został podzielony na treningowy 50k oraz testowy 10k obrazów.

CIFAR 10 dataset klasyfikacja
Przykładowe obrazy z zbioru CIFAR 10.

Zadanie to zrealizujemy przy pomocy wielowarstwowej sieci neuronowej w Pytorch. Nie jest to obecnie najlepszy model do tego typu problemu, więc nie spodziewajcie się rewelacyjnych wyników. Zależy mi na tym, aby na praktycznym przykładzie wprowadzić was w świat Pytroch.

Architektura sieci


Model będzie składać się z trzech warstw ukrytych o rozmiarach 512, 256, 128. Proszę nie pytajcie skąd te rozmiary, wziąłem je z kosmosu :). Na obecnym etapie nie są istotne, oczywiście zachęcam do pobawienia się rozmiarami i zanotowaniu czasu treningu oraz dokładności klasyfikacji. Jako funkcję aktywacji wybierzemy RELU.

Na wejściu do sieci będziemy podawać surowe dane, każdy obraz “rozwiniemy”. Trój wymiarową macierz obrazu (RGB) przekształcimy wiersz po wierszu na długi wektor pikseli o wymiarach 3x32x32. Natomiast na wyjściu sieci damy 10 wymiarową warstwę liniową. Każdy neuron będzie miał obliczoną wartość przynależności do danej klasy, im wyższa wartość, tym wyższe prawdopodobieństwo przynależności do tej klasy.

Struktura skryptu uczącego

Skrypty zawierające kod modułu uczącego zazwyczaj są do siebie bardzo podobne. Można wyróżnić następujące elementy:

  • wczytywanie danych – w zależności od problemu, z jakim przyszło się nam zmierzyć może to być jedno z trudniejszych zadań 🙁
  • transformacja danych – przekształcenie danych do formatu pozwalającego na wczytanie do sieci (tensor o odpowiednich rozmiarach, zakodowanie danych nie numerycznych, augumentacja danych itp.)
  • definicja architektury sieci – w przypadku Pytorch będziemy pisać własną klasę dziedziczącą po nn.Module
  • wybór i deklaracja funkcji straty – to w jaki sposób obliczamy błąd na wyjściu sieci. W zależności od klasy problemu (klasyfikacja, regresja, klasyfikacja wielo-etykietowa itp.) powinniśmy użyć odpowiedniej funkcji
  • wybór i deklaracja metody optymalizacji – jaki algorytm optymalizacji wag zastosujemy, od niego często zależy zbieżność oraz czas treningu
  • pętla ucząca – scala wszystkie poprzednie etapy, tu odbywa się właściwy trening,
    • wczytywana jest paczka danych (batch),
    • sieć oblicza wyjście,
    • funkcja straty porównuje wyjście sieci z prawdziwymi etykietami i oblicza stratę (loss)
    • obliczamy gradienty
    • optymalizator aktualizuje wagi sieci
    • i powrót do początku
  • pętla walidująca – po zakończonej epoce, sprawdzamy i wyświetlamy postęp treningu sieci

Pytorch datasets i dataloaders, czyli łatwe wczytywanie danych

Wczytywanie danych jest jednym z najuciążliwszych zadań. Wymaga od nas radzenia z ogromem formatów i sposobów zapisu. Często musimy dokonać transformacji danych, a także należy pamiętać o samym sposobie ich serwowania. Dzielenia na paczki, sortowania, współbieżnego dostępu, kolejkowania odczytów itp.

Właśnie takie odpowiedzialności mają klasy Dataset i Dataloader. Dzięki nim możemy łatwo wpiąć się w cały pipeline Pytorch’a do pobierania danych. Dataset odpowiada za odczytywanie zbioru a Dataloader za serwowanie danych.
Warto przyjrzeć się dwóm pokrewnym projektom TorchVision i TorchText bo mają już wiele gotowych klas do wczytywania znanych zbiorów z Computer Vision lub NLP (natural language processing)

Ci z Was, którzy będą chcieli trenować modele na własnych danych to wystarczy, że stworzą klasy dziedziczące po Dataset lub Dataloader. W tym celu polecam tutorial “Data Loading and Processing Tutorial“.

Wczytywanie zbioru CIFAR-10 przy pomocy torchvision

Wracając do naszego zadania, aby wczytać CIFAR-10 wystarczy wywołać funkcję torchvision.datasets.CIFAR10. Przekazując do niej parametr download=True spowodujemy, że zbiór ten zostanie automatycznie ściągnięty i zapisany w ścieżce określonej przez parametr root.

Zauważcie, że oddzielnie tworzymy trainset i i testset i od razu dla nich dataloadery. Dodatkowo tworzymy tablicę z klasami określającymi transformacje na obrazie (transform). Przekazane są one do konstruktora klasy Dataset. Podczas odczytu elementy ze zbioru są przekształcane na tensory i znormalizowane (każdy kanał oddzielnie mean=0.5 std=0.5)

Pytorch nn.Module klasa bazowa dla twojej sieci

Dobra, teraz zabieramy się do właściwej roboty. Zdefiniujmy nasz moduł określający sieć neuronową.

W Pytorch jest kilka sposobów, aby to zrobić. Do mnie najbardziej trafia stworzenie własnej klasy dziedziczącej po nn.Module. Konwencja nakazuje, zdefiniowanie poszczególnych warstw sieci w metodzie __init__, to tutaj tworzymy pola odpowiadające poszczególnym warstwom.
Model musi mieć zdefiniowaną metodę forward. Jest ona automatycznie wywoływana w celu obliczenia wyjścia sieci. Oczywiście to naszym obowiązkiem jest napisanie poprawnego kodu, który kolejno przekształca jedną warstwę w drugą aż do wyjściowej. Pytorch jedynie zapewnia wywołanie tej metody z przekazanymi danymi.

W naszym przypadku definiujemy 3 liniowe warstwy w pełni połączone (fc1, fc2, fc3). Po każdej z nich będziemy wywoływać nieliniową funkcję aktywacji relu (relu1, relu2, relu3). Ostatnia warstwa wyjściowa ‘output‘ przyjmuje wyjście warstwy 3 o rozmiarze hidden3 i zwraca nam 10 liczb, określających przynależność do 1 z 10 klas z CIFAR-10.

Metoda forward w tym przypadku jest bardzo prosta. Jako parametr dostaje wejście “x” będące trój-wymiarowym tensorem. Przechowuje on paczkę danych (batch_size=16). Następnie dane te kolejno przechodzą przez warstwy sieci, wyobraźcie sobie jakby to były filtry. Na wejście warstwy trafia wyjście warstwy poprzedniej.

Pętla ucząca i propagacja wsteczna

Teraz przejdziemy do magicznej części. Bo jedna rzecz to obliczenie wyjścia sieci a druga to sprawienie, aby na wyjściu pojawiały się poprawne wartości. Rozumiane jako adekwatne predykcje.

W uczeniu maszynowym kluczowym jest odpowiedni wybór funkcji starty (ang. loss function – idee funkcji straty opisałem paragrafie we wpisie o klasyfikacji MNIST w Tensorflow warto tam sięgnąć jeżeli nie masz intuicji czym jest funkcja straty).

W przypadku klasyfikacji o stosunkowo niewielkiej liczbie klas sprawdza się CrossEntropyLoss. Oblicza “różnicę” pomiędzy naszym wyjściem a tym, co powinno wyjść.

Sama pętla ucząca przechodzi po danych num_epoch razy. Wewnętrzna pętla w paczkach oblicza predykcję (prediction).
Następnie funkcji straty przekazujemy tensor z predykcjami (prediction) o wymiarach batch_size x output (16×10) oraz tablicę z numerami klas (labels np. [1,8,9, …]) o wymiarach batch_size (16).
Obliczamy wartość straty oraz gradient(loss.backward), który wykorzystujemy w algorytmie optymalizacji (Adam) do wykonania kroku optymalizacji (optimizer.step). Adam na podstawie wartości gradientów dokonuje aktualizacji wag w sieci wykorzystując alg. wstecznej propagacji błędów. Istotna uwaga w Pytorch musimy za każdym razem wyzerować gradienty, gdyż ich wartości się akumulują. Pozwala to np. na obliczenie gradientów kolejnego rzędu, ale w naszym przypadku nie jest to potrzebne.

W trakcie każdej epoki dokonujemy także klasyfikacji zbioru testowego, aby móc zaraportować postęp uczenia. Po uruchomieniu powinniśmy osiągnąć wyniki na poziomie 51%. Dodam, że przykładowa sieć jedno warstwowa osiąga tylko 48% dokładności. Przykładowy wydruk na konsoli:

Podsumowanie i materiały dodatkowe

W tym artykule przedstawiłem krok po kroku budowę wielowarstwowej sieci neuronowej. Chciałbym, aby stanowił dla Was kompletny materiał do nauki i dalszych eksperymentów.
Dajcie znać w komentarzach jeżeli Wam się przydał lub coś nie jest jasne.

Cały kod tego przykładu znajduje się na moim github’ie w projekcie “Pytorch neural networks tutorial” w pliku feedforward_3_hid_nn.py

Znajdują się (lub będą znajdować zależy kiedy czytasz) także inne zaimplementowane modele i przykłady.

Polecam dobre opracowanie związane z funkcjami aktywacji, w szczególności ReLu i jej pochodnych: PReLu, LeakyRelu, ELU, Relu-6

Zachęcam także do zostawienia maila i w ten sposób zasubskrybowania bloga. Dostaniesz tylko informacje o artykułach i planowanych przez mnie kursach.

Photo by Clint Adair on Unsplash

3 Comments Wielowarstwowa sieć neuronowa w Pytorch – klasyfikacja obrazów CIFAR-10

  1. Pingback: Wielowarstwowa sieć neuronowa w Tensorflow do klasyfikacji cyfr z MNIST - About Data

  2. Pingback: Sieć konwolucyjna w Pytorch - klasyfikacja obrazów CIFAR-10 - About Data

  3. Pingback: Sieć rekurencyjna LSTM do zliczania znaków - wprowadzenie - About Data

Dodaj komentarz

This site uses Akismet to reduce spam. Learn how your comment data is processed.