Tworzymy rekurencyjną sieć LSTM do analizy wydźwięku recenzji filmowych. Nie jest to jednak typowy przykład o sentiment analysis postanowiłem go rozbudować i wykorzystać bardziej zaawansowaną technikę zwaną „Truncated Backpropagation through Time”.
W artykule chciałbym przedstawić technikę, która jest szczególnie przydatna przy analizie długich sekwencji. Wspomniana na początku „Truncated Backpropagation Through Time” (ucięta propagacja wsteczna poprzez czas – wybacznie nie wiem jak to przetłumaczyć) pozwala na uczenie rekurencyjnej sieci, gdy ciąg sekwencji jest bardzo długi. Najczęściej wykorzystywana jest przy budowie „Neural language model”, czyli modeli języka naturalnego pozwalającego na przewidywanie następnego słowa. Można ją także wykorzystać w klasycznych problemach, chociażby takich jak klasyfikacja.
Właśnie na przykładzie klasyfikacji sentymentu, chciałbym Wam przybliżyć TBTT (Truncated Backpropagation Through Time) dla sieci LSTM w Pytorch. W artykule posłużę się znanym zbiorem danych recenzji filmowych IMDB. Co prawda nie wymaga on zastosowania TBTT, bo teksty nie są długie, ale z uwagi na jego prostotę i dostępność pozwoli szybko przedstawić ideę.
Wpis podzielony jest na dwie części:
- pierwsza teoretyczna – wprowadzająca w problem analizy sentymentu oraz omawiająca sam algorytm Truncated Backpropagation (TBTT)
- druga praktyczna – przykład sieci LSTM z TBPTT w PyTorch
Zachęcam także do przeczytania kolejnych tutoriali o PyTorch.
- Dlaczego porzuciłem Tensorflow na rzecz Pytorch
- Wielowarstwowa sieć neuronowa w Pytorch – klasyfikacja CIFAR-10
- Sieć konwolucyjna w Pytorch – klasyfikacja obrazów CIFAR-10
- Sieć rekurencyjna LSTM do zliczania znaków – wprowadzenie
- Sieć LSTM do analizy sentymentu recenzji filmowych z IMDB
Analiza sentymentu
Analiza sentymentu lub ocena wydźwięku jest to problem z dziedziny NLP (natural language processing) polegający na identyfikacji i wyodrębnieniu z tekstu subiektywnej oceny lub emocji autora. Przyjmuje on najczęściej postać klasyfikacji wieloklasowej (multiclass) lub wieloetykietowej (multilabel). W najprostszej konfiguracji klasyfikacji binarnej automat ocenia czy wypowiedź (recenzja, opinia) jest pozytywna czy negatywna.
W przypadku zbioru IMDB pracujemy na tekstach o długości ok 200 słów, każdy z nich ma doczepioną etykietę 'pos’ lub 'neg’ (pozytywna, negatywna). Na podstawie tekstu recenzji będziemy uczyć sieć rozpoznania recenzji pozytywnej od negatywnej.
Truncated Backpropagation Through Time
Kamieniem milowym w uczeniu sztucznych sieci neuronowych było opracowanie w latach 70 algorytmu wstecznej propagacji błędów (Backpropagation). Do tej pory jest on kluczową techniką uczenia sieci neuronowych (głównie typu feedforward).
Niestety, w przypadku sieci rekurencyjnych nie można go zastosować wprost. Wynika to z faktu, że wyjście sieci nie zależy tylko od bieżącego wejścia i stanów pośrednich, ale także od wszystkich stanów wcześniejszych w czasie (’A’). Rozwijając sieć rekurencyjną w czasie łatwiej możemy dostrzec tę zależność. Na każdym kroku otrzymujemy ukryty stan ’A’ oraz wartość na wyjściu ’h_i’. Wartość stanu ’A’ z poprzedniego kroku wraz z bieżącymi danymi wejściowymi brana jest do obliczenia następnego stanu ukrytego. Dzięki temu możemy uchwycić kontekst sekwencji.
Na ostateczną wartość wyjścia sieci ’h_t’ mają wpływ wszystkie poprzednie stany. Stąd aktualizacja wag następuje nie tylko w dół, lecz także w lewo. Czyli nie aktualizujemy wag tylko w stronę bieżącego wejścia 'x_t’, ale także w czasie. Na bazie tej obserwacji powstał algorytm wstecznej propagacji w czasie (Backpropagation Through Time). Podczas treningu sieci dokonuje on aktualizacji wag od danego stanu ’h_t’ także w czasie.
W praktyce dla długich sekwencji podejście to nie dawało satysfakcjonujących rezultatów. Obliczenie wartości gradientów w długim łańcuchu (np. 1000 kroków wstecz w czasie) powoduje, że są one numerycznie niestabilne. Na kolejnych etapach stają się one bardzo małe (vanishing gradient), albo ich wartości zaczynają rosnąć wykładniczo (exploding gradient).
Rozwiązanie zaprezentował Pan Ilya Sutskever w swojej pracy doktorskiej (https://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf). Zaproponował i udowodnił, że podejście pracujące na uciętych ostatnich K-krokach (elementech sekwencji) jest wystarczające i efektywne w procesie treningu sieci rekurencyjnej. Swoją drogą warto tego pana śledzić, jest on jednym z prekursorów deep learningu, maczał palce w większości najbardziej cytowanych publikacji.
Wróćmy do idei uciętej propagacji wstecznej. Bardzo dobrze prezentuje to poniższy diagram.
Spróbujmy przeanalizować to na powyższym schemacie. Górna linia symbolizuje wyjście sieci, dolna stan ukryty a strzałki na dole wejście. Przechodzimy po sekwencji co k1=3 kroki (górna linia), patrzymy na ostatnią wartość wyjściową i dokonujemy propagacji wstecznej o k2=5 kroków (dolna linia).
Wariacji tego schematu może być więcej. Możemy brać nie jedno ostatnie wyjście pod uwagę, ale trzy (te, o które przeskakujemy). Innym podejściem jest zrównanie k1 i k2. Wiele tych podejść zależy od waszego problemu i w normalnej pracy badawczej warto je przetestować.
Z ostatnich moich doświadczeń z pracy z długimi tekstami (średnio ponad 200 000 słów) pierwsze podejście sprawdziło się najlepiej, głównie ze względu na mniejsze zużycie zasobów pamięciowych.
Truncated Backpropagation throug Time to metoda uczenia sieci rekurencyjnej w której co k1 kroków dokonujemy aktualizacji wag k2 kroków wstecz.
Część praktyczna – omówienie kodu
Przejdźmy teraz do praktycznej części artykułu, w której zdefiniujemy pomocnicze funkcje, pobierzemy i przetworzymy dane o recenzjach. Następnie zbudujemy architekturę sieci jako PyTorch nn.Module, w którym dodamy niezbędne funkcje do obliczania propagacji wstecznej tylko na części sekwencji. Na koniec napiszemy pętlę uczącą, w której trening będzie odbywał się na uciętych danych. Po każdej epoce dokonamy walidacji dotychczas przetrenowanego modelu.
Importy, stałe i funkcje pomocnicze
Na początek zacznijmy od importów oraz zdefiniowania oraz pomocniczych funkcji.
from torch import nn, optim from torchtext import data, datasets from datetime import datetime from progress.bar import Bar import numpy as np # set random seeds for reproducibility torch.manual_seed(12) torch.cuda.manual_seed(12) np.random.seed(12) # check if cuda is enabled USE_GPU=1 # Device configuration device = torch.device('cuda' if (torch.cuda.is_available() and USE_GPU) else 'cpu')
Importujemy biblioteki od PyTorch (nn i optim), zawierające moduły do budowy sieci neuronowych oraz algorytmy optymalizacji. Następnie wczytujemy biblioteki do pracy z tekstem (TorchText). Na końcu bibliotekę do manipulacji datą i czasem (potrzebna do obliczenia czasu treningu), bibliotekę do wyświetlania ładnego paska postępu oraz numpy.
Ustalmy ziarno generatorów losowych tak, abyście łatwo mogli dokonać powtórzenia wyników. Ostatni etap to zbadanie określenie czy chcemy dokonywać obliczeń na GPU oraz czy jest ono dostępne na naszym komputerze.
W przykładzie wyodrębniłem oddzielną funkcję tokenize, którą ma za zadanie podzielić nasz tekst na słowa. Funkcję tę przekażemy do definicji pola TEXT(Field) z TorchText. Oczywiście można użyć czegoś bardziej wyrafinowanego (spacy lub nltk).
def tokenize(text): # simple tokenizer words = text.lower().split() return words
Pomocnicza funkcja accuracy, pomoże nam podczas etapu walidacji policzyć dokładność klasyfikacji. Do funkcji przekazujemy bezpośrednie wyjście sieci preds oraz docelowe wartości y (wszystkie zmienne są tablicami 2D, gdzie pierwszy wymiar oznacza rozmiar paczki – batch). Zmienna preds przechowuje nieznormalizowane wartości więc na początku zamieniamy je na prawdopodobieństwa (0 do 1). Wybieramy indeks większego z wyjść (pos, neg) i porównujemy z docelowymi etykietami. Na koniec porównujemy ile się pokrywa i obliczamy z tego śrenią, czyli ostateczne acc.
def accuracy(preds, y): """ Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8 """ # apply softmax preds = torch.nn.functional.softmax(preds, dim=1) # get max values along rows _, indices = preds.max(dim=1) correct = (indices == y).float() # convert into float for division acc = correct.sum()/len(correct) return acc
Ostatnia funkcja split_batch pozwala na podział paczki danych na części do algorytmu uciętej propagacji wstecznej. Zmienną seq_batch (o rozmiarze [seq_len x batch_size] ) dzielimy wzdłuż sekwencji, dzięki czemu otrzymujemy listę tensorów o rozmiarach [seq_len/bptt x batch_size].
def split_batch(seq_batch, bptt): """ Split torch.tensor batch by bptt steps, Split seqence dim by bptt """ batch_splits = seq_batch.split(bptt,dim=0) return batch_splits
Wczytanie danych IMDB przy pomocy TorchText
Kolejnym etapem jest wczytanie danych z IMDB. Wykorzystamy do tego bibliotekę TorchText do zadań NLP.
Definiujemy dwa pola (kolumny) z naszego datasetu, pierwsze (TEXT) będzie odpowiadało za przetworzenie tekstu, a drugie etykiety (’neg’ lub 'pos’).
# set up fields TEXT = data.Field(lower=True, include_lengths=True, tokenize=tokenize) LABEL = data.LabelField()
Następnie wczytujemy zbiór IMDB, jeżeli go nie macie to zostanie automatycznie ściągnięty (spakowany zajmuje ok 87MB). Dzielimy go na zbiór treningowy i walidacyjny. Oba mają po 25k przykładów.
# make splits for data train_ds, valid_ds = datasets.IMDB.splits(TEXT, LABEL) # take a portion of datasets, for testing # train_ds, _ = train_ds.split(0.5) # valid_ds, _ = valid_ds.split(0.5) print(f'train={len(train_ds)} valid={len(valid_ds)}')
Budowa słownika
W NLP etap budowy słownika jest bardzo istotny. Na jego bazie zostaną zakodowane reprezentacje one-hot poszczególnych wyrazów. Wobec czego sekwencję (zdanie) przedstawimy jako ciąg liczb np. [23, 45, 1124, 9833, …] odnoszących się do numerów słów ze słownika.
W przykładzie budujemy dwa słowniki. Pierwszy na podstawie tekstu ze zbioru treningowego oraz definicji pola tekstowego. W funkcji build_vocab określamy, że słownik będzie maksymalnie zawierał 10000 najczęściej występujących słów, a słowo musi minimalnie wystąpić 10 razy w tekstach. Przyjęty rozmiar słownika nie jest duży (możecie śmiało spróbować na 30k, 50k). Ktoś może zadać pytanie, po co budować słownik na bazie etykiet? W IMDB etykiety są tekstowe (’neg’ i 'pos’), a nie liczbowe, więc trzeba na bazie słownika je zamienić na liczby (’0′ i '1′).
# build the vocabulary TEXT.build_vocab(train_ds,min_freq=10, max_size=10000 ) #, vectors=GloVe(name='6B', dim=300)) LABEL.build_vocab(train_ds) print(TEXT.vocab.freqs.most_common(10)) print(TEXT.vocab.freqs.most_common()[:-11:-1]) vocab = TEXT.vocab vocab_size = len(vocab) print(f'vocab_size={vocab_size}') print(list(vocab.stoi.keys())[0:10]) print(LABEL.vocab.stoi.keys())
Wyświetlając rozmiar zbudowanego słownika dla tekstu zobaczycie, że ma rozmiar 10002 a nie 10000. Otóż TorchText automatycznie doda dwa słowa: '<unk>'(unknown) i '<pad>’ (padding). Pierwsze służy jako zamiennik dla słów, które nie trafiły do słownika. Może się tak zdarzyć, czy to ze względu na mały rozmiar słownika, czy z powodu, że w zbiorze treningowym to słowo nie wystąpiło a może się pojawić z zbiorze testowym. Drugie jako wyrównanie (dopełnienie do najdłuższego zdania) dla sekwencji o różnych rozmiarach.
Deklaracja iteratorów dla danych
Ostatnim krokiem związanym z danymi jest utworzenie iteratorów dla danych ze zbioru treningowego i walidacyjnego. W tym celu wykorzystujemy BucketIterator, który zgrupuje nam teksty o podobnej długości i pozwoli pobierać je w paczkach o rozmiarze batch_size.
batch_size = 32 train_iter = data.BucketIterator( train_ds, batch_size=batch_size, sort_key=lambda x: len(x.text), sort_within_batch=True, device=device) valid_iter = data.BucketIterator( valid_ds, batch_size=batch_size, sort_key=lambda x: len(x.text), sort_within_batch=True, device=device)
Budowa modułu LSTM (Truncated Backpropagation)
Definiujemy klasę LongSeqTbttRnn jako PyTorch nn.Module. Właściwie, jeżeli czytaliście poprzedni wpis o wprowadzeniu do sieci LSTM do zliczania znaków to architektura wygląda podobnie. Jedyną różnicą są dwie pomocnicze funkcje repackage_rnn_state oraz _detach_rnn_state. Służą one do „przerwania” połączenia łańcucha gradientów.
Chodzi o to, że PyTorch zapamiętuje cały ciąg operacji, które wykonujemy na naszym modelu i na ich podstawie metoda backward oblicza gradienty. My w naszej implementacji nie chcemy pamiętać całej historii operacji (wszystkich obliczeń na danych), chcemy, aby gradienty były obliczone tylko dla K kroków wstecz. Po każdych K krokach musimy wymusić, aby PyTorch zapomniał o poprzednich powiązanych obliczeniach. Możemy to zrobić poprzez odłączenie (metoda detach). Ważne, że metoda ta nie zeruje parametrów (wag) warstwy dzięki czemu kolejne obliczenia zaczniemy już od ostatnio obliczonego stanu sieci. Polecam dwa wpisy, aby lepiej zrozumieć działanie detach:
- https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226
- http://www.bnikolic.co.uk/blog/pytorch-detach.html
Metoda repackage_rnn_state wywoływana jest w trakcie przebiegu pętli uczącej właśnie co K kroków. Wywołuje ona rekurencyjną metodę _detach_rnn_state, która kolejno przechodzi po wszystkich parametrach warstwy rekurencyjnej i wywołuje na poszczególnych tensorach detach.
Wejściem dla metody repackage_rnn_state jest przechowywana w obiekcie wartość ostatniego stanu self.rnn_state.
class LongSeqTbttRnn(nn.Module): """ RNN example for long sequence with TBPTT truncated backpropagation throu time """ def __init__(self, input_dim, output_dim, embed_size, hidden_size, num_layers=1, dropout=0.1,vectors=None ): super().__init__() self.embed_size = embed_size self.hidden_size = hidden_size self.output_dim = output_dim self.num_layers = num_layers self.embed = nn.Embedding(input_dim, embed_size) # if we want to copy embedding vectors if vectors: self.embed.weight.data.copy_(vectors) #after the embedding we can add dropout self.drop = nn.Dropout(dropout) self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=False) # we need this for storing last rnn state self.rnn_state = None self.linear = nn.Linear(hidden_size, output_dim) def repackage_rnn_state(self): self.rnn_state = self._detach_rnn_state(self.rnn_state) def _detach_rnn_state(self, h): """Wraps hidden states in new Tensors, to detach them from their history. based on repackage_hidden function from https://github.com/pytorch/examples/blob/master/word_language_model/main.py """ if isinstance(h, torch.Tensor): return h.detach() else: return tuple(self._detach_rnn_state(v) for v in h) def forward(self, seq): # Embed word ids to vectors len_seq, bs = seq.shape w_embed = self.embed(seq) w_embed = self.drop(w_embed) output, self.rnn_state = self.rnn(w_embed, self.rnn_state) # this does .squeeze(0) now hidden has size [batch, hid dim] last_output = output[-1, :, :] last_output = self.drop(last_output) out = self.linear(last_output) return out
Sama metoda forward ma bardzo klasyczną postać. Metoda otrzymuje jako argument seq, który w naszym przypadku jest kawałkiem pierwotnej sekwencji (fragment o rozmiarze K). Jest to tensor o rozmiarach [seq/K x batch_size] przechowujący ciąg numerów kolejnych słów.
Na początku, aby wydobyć semantykę słów, kodujemy poszczególne tokeny (słowa) na gęste wektory z wykorzystaniem podejścia „word embedding” (self.embed(seq)). Otrzymujemy w ten sposób tensor o rozmiarze [seq/K x batch_size x embed_dim]. Po czym, po zakodowaniu stosujemy dropout z prawdopodobieństwem zamaskowania neuronu równym 0.1. Następnie przekazujemy sekwencje do warstwy rekurencyjnej do jej rozwinięcia i obliczenia poszczególnych stanów i wyjść. Zmienna output przechowuje wartości wyjść na każdym etapie sekwencji (dla każdego słowa), a zmienna self.rnn_state ostatni stan sieci. Kolejne K kroków zostanie rozpoczęte już od początkowego stanu self.rnn_state, dzięki czemu będziemy zachowywać kontekst.
Wyjście naszego modelu obliczane jest na bazie ostatniego wyjścia last_output, które jest rzutowane przez warstwę liniową na dwu wymiarowe wyjście (’pos’, 'neg’).
Definicja zadania – funkcja starty i algortym optymalizacji
Naszym zadaniem jest klasyfikacja, więc jako funkcję straty (ang. loss function) użyjemy CrossEntropyLoss. Jako algorytmy klasyfikacji wybrałem Adam (najbardziej polecany w chwili pisania).
model = LongSeqTbttRnn(input_dim=input_dim, output_dim=output_dim, embed_size=n_embed, hidden_size=n_hid) model.to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters())
TBTT – Pętla ucząca
Pętla ucząca w podejściu uciętej propagacji wstecznej (Truncated Backpropagation Throut Time – TBTT) ma trochę inną postać niż zwykle. Do modelu nie przekazujemy całej sekwnecji tylko jej kawałek oraz po przetworzeniu każdego kawałka w PyTorch musimy pamiętać o „przerwaniu” grafu obliczen przy zachowaniu dotychczasowego stanu sieci.
Początek jest standardowy, iterujemy po epokach (for e in range(epoch) ), a następnie po zbiorze treningowym wykorzystując zadeklarowany wcześniej iterator train_iter. Przed obliczeniem wartości modelu zerujemy gradienty oraz zerujemy stan warstwy rekurencyjnej.
Dzielimy paczkę tekstów (batch_text) na części o rozmiarze (bptt=50) z wykorzystaniem funkcji split_batch. Otrzymujemy listę tensorów bptt_batch_chunks, po których kolejno przechodzimy w następnej pętli. Obliczamy na danej części wyjście sieci i dokonujemy aktualizacji wag tylko na tej części (obliczamy loss, wykonujemy loss.backward i optimizer.step). Po dokonaniu aktualizacji modelu wywołujemy metodę „przerywającą gradient flow”. Następnie iterujemy po kolejnych częściach sekwencji, a gdy się skończą to pobieramy kolejną paczkę danych.
for e in range(epoch): start_time = datetime.now() # train loop model.train() # progress bar = Bar(f'Training Epoch {e}/{epoch}', max=len(train_iter)) for batch_idx, batch in enumerate(train_iter): model.zero_grad() # before each tbptt zero state model.rnn_state = None # move data to device (GPU if enabled, else CPU do nothing) batch_text = batch.text[0].to(device) # include lengths at [1] batch_label = batch.label.to(device) bptt_loss= 0 bptt_batch_chunks = split_batch(batch_text, bptt) # second TBPTT loop, split batch and learn in chunks of batch for text_chunk in bptt_batch_chunks: model.zero_grad() predictions = model(text_chunk) # for each bptt size we have the same batch_labels loss = criterion(predictions, batch_label) bptt_loss += loss.item() # do back propagation for bptt steps in time loss.backward() optimizer.step() # after doing back prob, detach rnn state in order to implement TBPTT (truncated backpropagation through time startegy) # now rnn_state was detached and chain of gradeints was broken model.repackage_rnn_state() bar.next() epoch_loss += bptt_loss bar.finish() # mean epoch loss epoch_loss = epoch_loss / len(train_iter) time_elapsed = datetime.now() - start_time # progress bar = Bar(f'Validation Epoch {e}/{epoch}', max=len(valid_iter)) # evaluation loop model.eval() with torch.no_grad(): for batch_idx, batch in enumerate(valid_iter): # print(f'batch_idx={batch_idx}') batch_text = batch.text[0] #batch.text is a tuple batch_label = batch.label #reset to zero model state model.rnn_state = None # get model output predictions = model(batch_text) # compute batch validation accuracy acc = accuracy(predictions, batch_label) epoch_acc += acc bar.next() epoch_acc = epoch_acc/len(valid_iter) bar.finish() # show summary print( f'Epoch {e}/{epoch} loss={epoch_loss} acc={epoch_acc} time={time_elapsed}') epoch_loss = 0 epoch_acc = 0
W ramach pętli uczącej po każdej epoce dokonujemy ewaluacji naszego modelu na zbiorze walidacyjnym (valid_iter) po to, aby zobaczyć jak wygląda postęp uczenia. Zauważcie, że tym razem już nie korzystamy z TBTT, tylko przekazujemy całą sekwencję.
Dokładność tego modelu wynosi ok 0.85 na parametrach:
bptt=50
batch_size = 64
hidden size
n_hid=256
embed size
n_embed=100
number of layers
n_layers=1
Wydruk powinien wyglądać podobnie do poniższego: !!!!!!!!!!!!!!!!!
```
Training Epoch 0/10 |################################| 782/782
Validation Epoch 0/10 |################################| 782/782
Epoch 0/10 loss=3.116480209295402 acc=0.8292838931083679 time=0:01:19.544885
Training Epoch 1/10 |################################| 782/782
Validation Epoch 1/10 |################################| 782/782
Epoch 1/10 loss=1.9649102331837043 acc=0.8706841468811035 time=0:01:18.703602
Training Epoch 2/10 |################################| 782/782
Validation Epoch 2/10 |################################| 782/782
Epoch 2/10 loss=1.2844358206490802 acc=0.8699648380279541 time=0:01:18.822965
Training Epoch 3/10 |################################| 782/782
Validation Epoch 3/10 |################################| 782/782
Epoch 3/10 loss=0.7612629080134089 acc=0.8631713390350342 time=0:01:18.438742
Training Epoch 4/10 |################################| 782/782
Validation Epoch 4/10 |################################| 782/782
Epoch 4/10 loss=0.46042709653039493 acc=0.8654091954231262 time=0:01:18.089986
Training Epoch 5/10 |################################| 782/782
Validation Epoch 5/10 |################################| 782/782
Epoch 5/10 loss=0.3314593220629808 acc=0.8596547245979309 time=0:01:18.294183
Training Epoch 6/10 |################################| 782/782
Validation Epoch 6/10 |################################| 782/782
Epoch 6/10 loss=0.2812261906621592 acc=0.8589354157447815 time=0:01:18.187062
Training Epoch 7/10 |################################| 782/782
Validation Epoch 7/10 |################################| 782/782
Epoch 7/10 loss=0.2437611708150762 acc=0.8552589416503906 time=0:01:17.948963
Training Epoch 8/10 |################################| 782/782
Validation Epoch 8/10 |################################| 782/782
Epoch 8/10 loss=0.2500312502574547 acc=0.8591752052307129 time=0:01:18.136995
Training Epoch 9/10 |################################| 782/782
Validation Epoch 9/10 |################################| 782/782
Epoch 9/10 loss=0.2074765177977169 acc=0.8543797731399536 time=0:01:18.278987
```
Podsumowanie
W tym artykule przedstawiłem przykład analizy sentymentu recenzji filmowych z bazy IMDB z wykorzystaniem sieci rekurencyjnej. Dodatkowo został przedstawiony algorytm uciętej propagacji wstecznej (Truncated Backpropagation Through Time). Technika ta jest niezbędna podczas uczenia sieci rekurencyjnej na długich sekwencjach.
Materiały
W ramach dalszej nauki polecam do przestudiowania:
- wpis na blogu Jasona Brownlee o Backpropagation Through Time – A Gentle Introduction to Backpropagation Through Time
- dyskusję na forum PyTroch, podającą szczegóły i sposoby implementacji uciętej propagacji wstecznej
Jak każdy wpis na blogu potraktujcie go jako wprowadzenie do własnych eksperymentów oraz dalszej nauki.
Dajcie znać w komentarzach, jeżeli Wam się przydał lub coś nie jest jasne.
Cały kod tego przykładu znajduje się na moim github’ie w projekcie „Pytorch neural networks tutorial” w pliku lstm_imdb_tbptt.py
Sposób uruchomienia szczegółowo opisany jest w README.md. Wszystkie niezbędne zależności zainstalujecie z wykorzystaniem pipenv. W skrócie należy:
- Install Python.
- Install pipenv
- Git clone the repository
- Install all necessary python packages executing this command in terminal
git clone https://github.com/ksopyla/pytorch_neural_networks.git
cd pytorch_neural_networks
pipenv install
Jeżeli uważasz ten wpis za wartościowy to Zasubskrybuj bloga. Dostaniesz informacje o nowych artykułach.
Photo by Daniela Cuevas on Unsplash