Wielowarstwowa sieć neuronowa w Pytorch – klasyfikacja CIFAR-10

Krok po kroku opisuję najważniejsze etapy tworzenia wielowarstwowej sieci neuronowej do klasyfikacji obrazów ze zbioru CIFAR-10. Skupimy się na wczytywaniu danych oraz omówimy jeden z głównych elementów, czyli klasę nn.Module reprezentującą sieć neuronową.

Wpis ten należy do serii artykułów o Pytorch, w każdym artykule opisuję pewien aspekt tej biblioteki oraz przykłady najpopularniejszych architektur sieci neuronowych.

Do tej pory opublikowałem:

Omówienie problemu i architektury sieci

Postawione przed nami zadanie polega na klasyfikacji zbioru obrazów ze zbioru CIFAR-10. Zawiera on 60000 kolorowych obrazów o rozmiarze 32×32 piksele, każdy obraz przynależy do jednej z 10 klas. Zbiór został podzielony na treningowy 50k oraz testowy 10k obrazów.

CIFAR 10 dataset klasyfikacja
Przykładowe obrazy z zbioru CIFAR 10.

Zadanie to zrealizujemy przy pomocy wielowarstwowej sieci neuronowej w Pytorch. Nie jest to obecnie najlepszy model do tego typu problemu, więc nie spodziewajcie się rewelacyjnych wyników. Zależy mi na tym, aby na praktycznym przykładzie wprowadzić was w świat Pytroch.

Architektura sieci


Model będzie składać się z trzech warstw ukrytych o rozmiarach 512, 256, 128. Proszę nie pytajcie skąd te rozmiary, wziąłem je z kosmosu :). Na obecnym etapie nie są istotne, oczywiście zachęcam do pobawienia się rozmiarami i zanotowaniu czasu treningu oraz dokładności klasyfikacji. Jako funkcję aktywacji wybierzemy RELU.

Na wejściu do sieci będziemy podawać surowe dane, każdy obraz „rozwiniemy”. Trój wymiarową macierz obrazu (RGB) przekształcimy wiersz po wierszu na długi wektor pikseli o wymiarach 3x32x32. Natomiast na wyjściu sieci damy 10 wymiarową warstwę liniową. Każdy neuron będzie miał obliczoną wartość przynależności do danej klasy, im wyższa wartość, tym wyższe prawdopodobieństwo przynależności do tej klasy.

Struktura skryptu uczącego

Skrypty zawierające kod modułu uczącego zazwyczaj są do siebie bardzo podobne. Można wyróżnić następujące elementy:

  • wczytywanie danych – w zależności od problemu, z jakim przyszło się nam zmierzyć może to być jedno z trudniejszych zadań 🙁
  • transformacja danych – przekształcenie danych do formatu pozwalającego na wczytanie do sieci (tensor o odpowiednich rozmiarach, zakodowanie danych nie numerycznych, augumentacja danych itp.)
  • definicja architektury sieci – w przypadku Pytorch będziemy pisać własną klasę dziedziczącą po nn.Module
  • wybór i deklaracja funkcji straty – to w jaki sposób obliczamy błąd na wyjściu sieci. W zależności od klasy problemu (klasyfikacja, regresja, klasyfikacja wielo-etykietowa itp.) powinniśmy użyć odpowiedniej funkcji
  • wybór i deklaracja metody optymalizacji – jaki algorytm optymalizacji wag zastosujemy, od niego często zależy zbieżność oraz czas treningu
  • pętla ucząca – scala wszystkie poprzednie etapy, tu odbywa się właściwy trening,
    • wczytywana jest paczka danych (batch),
    • sieć oblicza wyjście,
    • funkcja straty porównuje wyjście sieci z prawdziwymi etykietami i oblicza stratę (loss)
    • obliczamy gradienty
    • optymalizator aktualizuje wagi sieci
    • i powrót do początku
  • pętla walidująca – po zakończonej epoce, sprawdzamy i wyświetlamy postęp treningu sieci

Pytorch datasets i dataloaders, czyli łatwe wczytywanie danych

Wczytywanie danych jest jednym z najuciążliwszych zadań. Wymaga od nas radzenia z ogromem formatów i sposobów zapisu. Często musimy dokonać transformacji danych, a także należy pamiętać o samym sposobie ich serwowania. Dzielenia na paczki, sortowania, współbieżnego dostępu, kolejkowania odczytów itp.

Właśnie takie odpowiedzialności mają klasy Dataset i Dataloader. Dzięki nim możemy łatwo wpiąć się w cały pipeline Pytorch’a do pobierania danych. Dataset odpowiada za odczytywanie zbioru a Dataloader za serwowanie danych.
Warto przyjrzeć się dwóm pokrewnym projektom TorchVision i TorchText bo mają już wiele gotowych klas do wczytywania znanych zbiorów z Computer Vision lub NLP (natural language processing)

Ci z Was, którzy będą chcieli trenować modele na własnych danych to wystarczy, że stworzą klasy dziedziczące po Dataset lub Dataloader. W tym celu polecam tutorial „Data Loading and Processing Tutorial„.

Wczytywanie zbioru CIFAR-10 przy pomocy torchvision

Wracając do naszego zadania, aby wczytać CIFAR-10 wystarczy wywołać funkcję torchvision.datasets.CIFAR10. Przekazując do niej parametr download=True spowodujemy, że zbiór ten zostanie automatycznie ściągnięty i zapisany w ścieżce określonej przez parametr root.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print(f'Working on device={device}')

# Hyper-parameters 

# each cifar image is RGB 32x32, so it is an 3D array [3,32,32]
# we will flatten the image as vector dim=3*32*32 
input_size = 3*32*32

hidden_size = 512
# we have 10 classes
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

num_classes = 10
num_epochs = 5
batch_size = 16
learning_rate = 0.001

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

Zauważcie, że oddzielnie tworzymy trainset i i testset i od razu dla nich dataloadery. Dodatkowo tworzymy tablicę z klasami określającymi transformacje na obrazie (transform). Przekazane są one do konstruktora klasy Dataset. Podczas odczytu elementy ze zbioru są przekształcane na tensory i znormalizowane (każdy kanał oddzielnie mean=0.5 std=0.5)

Pytorch nn.Module klasa bazowa dla twojej sieci

Dobra, teraz zabieramy się do właściwej roboty. Zdefiniujmy nasz moduł określający sieć neuronową.

W Pytorch jest kilka sposobów, aby to zrobić. Do mnie najbardziej trafia stworzenie własnej klasy dziedziczącej po nn.Module. Konwencja nakazuje, zdefiniowanie poszczególnych warstw sieci w metodzie __init__, to tutaj tworzymy pola odpowiadające poszczególnym warstwom.
Model musi mieć zdefiniowaną metodę forward. Jest ona automatycznie wywoływana w celu obliczenia wyjścia sieci. Oczywiście to naszym obowiązkiem jest napisanie poprawnego kodu, który kolejno przekształca jedną warstwę w drugą aż do wyjściowej. Pytorch jedynie zapewnia wywołanie tej metody z przekazanymi danymi.

W naszym przypadku definiujemy 3 liniowe warstwy w pełni połączone (fc1, fc2, fc3). Po każdej z nich będziemy wywoływać nieliniową funkcję aktywacji relu (relu1, relu2, relu3). Ostatnia warstwa wyjściowa ’output’ przyjmuje wyjście warstwy 3 o rozmiarze hidden3 i zwraca nam 10 liczb, określających przynależność do 1 z 10 klas z CIFAR-10.

class MultilayerNeuralNet(nn.Module):
    def __init__(self, input_size, num_classes):
        '''
        Fully connected neural network with 3 hidden layers
        '''
        super(MultilayerNeuralNet, self).__init__()
        
        # hidden layers sizes, you can play with it as you wish!
        hidden1 = 512
        hidden2 = 256
        hidden3 = 128

        # input to first hidden layer parameters
        self.fc1 = nn.Linear(input_size, hidden1) 
        self.relu1 = nn.ReLU()

        # second hidden layer
        self.fc2 = nn.Linear(hidden1, hidden2) 
        self.relu2 = nn.ReLU()
        
        # third hidden layer
        self.fc3 = nn.Linear(hidden2, hidden3)
        self.relu3 = nn.ReLU()

        # last output layer
        self.output = nn.Linear(hidden3, num_classes) 

    
    def forward(self, x):
        '''
        This method takes an input x and layer after layer compute network states.
        Last layer gives us predictions.
        '''
        state = self.fc1(x)
        state = self.relu1(state)

        state = self.fc2(state)
        state = self.relu2(state)

        state = self.fc3(state)
        state = self.relu3(state)

        state = self.output(state)
        
        return state

model = MultilayerNeuralNet(input_size, num_classes).to(device)

Metoda forward w tym przypadku jest bardzo prosta. Jako parametr dostaje wejście „x” będące trój-wymiarowym tensorem. Przechowuje on paczkę danych (batch_size=16). Następnie dane te kolejno przechodzą przez warstwy sieci, wyobraźcie sobie jakby to były filtry. Na wejście warstwy trafia wyjście warstwy poprzedniej.

Pętla ucząca i propagacja wsteczna

Teraz przejdziemy do magicznej części. Bo jedna rzecz to obliczenie wyjścia sieci a druga to sprawienie, aby na wyjściu pojawiały się poprawne wartości. Rozumiane jako adekwatne predykcje.

W uczeniu maszynowym kluczowym jest odpowiedni wybór funkcji starty (ang. loss function – idee funkcji straty opisałem paragrafie we wpisie o klasyfikacji MNIST w Tensorflow warto tam sięgnąć jeżeli nie masz intuicji czym jest funkcja straty).

W przypadku klasyfikacji o stosunkowo niewielkiej liczbie klas sprawdza się CrossEntropyLoss. Oblicza „różnicę” pomiędzy naszym wyjściem a tym, co powinno wyjść.

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  

Sama pętla ucząca przechodzi po danych num_epoch razy. Wewnętrzna pętla w paczkach oblicza predykcję (prediction).
Następnie funkcji straty przekazujemy tensor z predykcjami (prediction) o wymiarach batch_size x output (16×10) oraz tablicę z numerami klas (labels np. [1,8,9, …]) o wymiarach batch_size (16).
Obliczamy wartość straty oraz gradient(loss.backward), który wykorzystujemy w algorytmie optymalizacji (Adam) do wykonania kroku optymalizacji (optimizer.step). Adam na podstawie wartości gradientów dokonuje aktualizacji wag w sieci wykorzystując alg. wstecznej propagacji błędów. Istotna uwaga w Pytorch musimy za każdym razem wyzerować gradienty, gdyż ich wartości się akumulują. Pozwala to np. na obliczenie gradientów kolejnego rzędu, ale w naszym przypadku nie jest to potrzebne.

# set our model in the training mode
model.train()
for epoch in range(num_epochs):

    epoch_loss = 0
    # data loop, iterate over chunk of data(batch) eg. 32 elements
    # compute model prediction
    # update weights
    for i, batch_sample in enumerate(train_loader):

        # print(batch_sample)
        images, labels = batch_sample

        # flatten the image and move to device
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward pass, compute prediction,
        # method 'forward' is automatically called
        prediction = model(images)
        # Compute loss, quantify how wrong our predictions are
        # small loss means a small error
        loss = criterion(prediction, labels)
        epoch_loss += loss.item()

        # Backward and optimize
        model.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = epoch_loss / len(train_loader)

    # Test the model

    # set our model in the training mode
    model.eval()
    # In test phase, we don't need to compute gradients (for memory efficiency)
    with torch.no_grad():
        correct = 0
        total = 0

        for images, labels in test_loader:
            # reshape image
            images = images.reshape(-1, input_size).to(device)
            labels = labels.to(device)

            # predict classes
            prediction = model(images)

            # compute accuracy
            _, predicted = torch.max(prediction.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = correct/total

        # Accuracy of the network on the 10000 test images
        print(f'Epoch [{epoch+1}/{num_epochs}]], Loss: {epoch_loss:.4f} Test acc: {acc}')

W trakcie każdej epoki dokonujemy także klasyfikacji zbioru testowego, aby móc zaraportować postęp uczenia. Po uruchomieniu powinniśmy osiągnąć wyniki na poziomie 51%. Dodam, że przykładowa sieć jedno warstwowa osiąga tylko 48% dokładności. Przykładowy wydruk na konsoli:

Epoch [1/5]], Loss: 1.6971 Test acc: 0.4627
Epoch [2/5]], Loss: 1.4874 Test acc: 0.4947
Epoch [3/5]], Loss: 1.3887 Test acc: 0.493
Epoch [4/5]], Loss: 1.3108 Test acc: 0.5144
Epoch [5/5]], Loss: 1.2406 Test acc: 0.5166

Podsumowanie i materiały dodatkowe

W tym artykule przedstawiłem krok po kroku budowę wielowarstwowej sieci neuronowej. Chciałbym, aby stanowił dla Was kompletny materiał do nauki i dalszych eksperymentów.
Dajcie znać w komentarzach jeżeli Wam się przydał lub coś nie jest jasne.

Cały kod tego przykładu znajduje się na moim github’ie w projekcie „Pytorch neural networks tutorial” w pliku feedforward_3_hid_nn.py

Znajdują się (lub będą znajdować zależy kiedy czytasz) także inne zaimplementowane modele i przykłady.

Polecam dobre opracowanie związane z funkcjami aktywacji, w szczególności ReLu i jej pochodnych: PReLu, LeakyRelu, ELU, Relu-6

Zachęcam także do zostawienia maila i w ten sposób zasubskrybowania bloga. Dostaniesz tylko informacje o artykułach i planowanych przez mnie kursach.

Photo by Clint Adair on Unsplash

Ciekawe, wartościowe, podziel się proszę opinią!

Witryna wykorzystuje Akismet, aby ograniczyć spam. Dowiedz się więcej jak przetwarzane są dane komentarzy.