Sieć rekurencyjna LSTM do zliczania znaków – wprowadzenie

Sieci rekurencyjne zdolne są do rozpoznawania zależności wynikających z połączenia i częstości występowania symboli. W tym tutorialu stworzymy sieć LSTM, która będzie stanowiła dobre wprowadzenie do bardziej skomplikowanych modeli.

W tym artykule chciałbym pomóc Wam w rozpoczęciu implementacji własnej sieci rekurencyjnej przetwarzającej znaki. Zbudujemy rekurencyjną sieć neuronową LSTM, która będzie rozpoznawała jaki symbol dominuje w całej sekwencji.

Do tego celu wykorzystamy bibliotekę TorchText. Pomoże nam ona w wczytywaniu tekstu, budowie słownika oraz zamianie tekstu na wektor.

Wpis ten należy do serii artykułów o Pytorch, w każdym artykule opisuję pewien aspekt tej biblioteki oraz przykłady najpopularniejszych architektur sieci neuronowych.

Do tej pory opublikowałem:

Celem całej serii jest pomoc w nauce i wskazaniu dalszych kierunków rozwoju.

Jak działa sieć rekurencyjna?

Drogi czytelniku, w tym miejscu muszę cię przeprosić. Pisząc ten artykuł zacząłem wgłębiać się jak działają i jak są zbudowane sieci rekurencyjne. W trakcie zdałem sobie sprawę, że materiał jest na tyle obszerny, że poświęcę mu osobny wpis. Chcąc utrzymać praktyczny ton tutorialu pozwól, że odeślę Cię do innych sprawdzonych opracowań.

Materiały o sieciach rekurencyjnych

Sprawdzona i polecana lista materiałów:

Możecie także obejrzeć mój wykład z konferencji 4Developers z 2018, w którym omawiam rekurencyjną architekturę dla “Char Language model”. Kod jest w Tensorflow, ale początek powinien wam dać dobre wprowadzenie do idei sieci rekurencyjnych.

Omówienie problemu

Na wstępie chciałbym się wytłumaczyć, dlaczego wybrałem taki przykład. Zliczanie znaków można zrealizować na wiele prostszych sposobów, sieć rekurencyjna nie jest idealnym modelem do tego typu zadania.

Niemniej jednak przykład ten dobrze nadaje się, aby pokazać sekwencyjny charakter sieci. Chciałem uwypuklić umiejętność zapamiętywania stanu (aby zliczać trzeba pamiętać co do tej pory widzieliśmy), co w przykładzie ze zliczaniem jest łatwe do zrozumienia.

Ponadto zadanie to jest proste ideowo. Łatwo sobie wyobrazić co ma być na wyjściu.

Minusem tego podejścia jest brak zależności semantycznych pomiędzy poszczególnymi elementami w sekwencji jak np. w analizie sentymentu. Na tego typu przykłady przyjdzie jeszcze pora.

Skąd weźmiemy dane do sieci?

Wygenerujemy je! Napiszmy funkcję, która będzie generowała ciągi losowych znaków. Jako wartość docelową, której będziemy chcieli się nauczyć podamy znak który występuje najczęściej.

Całość zapiszemy w Pandas’owym dataframe’ie. I ten dataframe posłuży nam jako dataset. Tu małe sprostowanie, pytorch nie ma klasy, która pozwala na pobieranie danych z pands dataframe. W przykładzie wykorzystam własną klasę DataframeDataset oraz bibliotekę TorchText.

Przetwarzanie sekwencji z TorchText

Wiele zadań z NLP (natural language processing) wymaga wstępnego przetwarzania tekstu. Do standardowych zadań możemy zaliczyć: wczytywanie danych, budowa słownika, tworzenie wektorów. W naszym przykładzie posłużę się biblioteką TorchText, która ułatwia i upraszcza część zadań.

Sieć LSTM do zliczania znaków – implementacja

Link do całego przykładu na githubie znajduje się na końcu artykułu.

A teraz zabieramy się do pracy. Cały skrypt można podzielić na następujące części:

  • Generowanie dataframe z danymi, wykorzystamy własny pomocniczy skrypt ‘data_helpers/data_gen_utils.py’.
  • Określenie typów danych i ich sposobów przetwarzania z wykorzystaniem TorchText.data.Field.
  • Opakowanie dataframe z danymi w klasę DataFrameDataset, aby móc wpiąć się w cały proces przetwarzania tekstu z TorchText.
  • Budowanie słownika, czyli listy tokenów, z jakich składa się nasz tekst. Zazwyczaj jest to lista słów, a w naszym przypadku będzie to ponumerowana lista znaków. Zwróćcie uwagę na dodatkowe tokeny <unk>, <pad>.
  • Stworzenie klas iteratorów po dataset’cie, my wykorzystamy BucketIterator z TorchText. Zaletą jego jest grupowanie tekstów o podobnej długości. Dane do sieci w ramach paczki (‘batch’) muszą być równe. Wyrównanie polega na dodaniu na końcu krótszych sekwencji dodatkowych tokenów <pad>. Pogrupowanie pozwala na efektywniejsze wykorzystanie pamięci (GPU).
  • Stworzenie modułu SeqLSTM
  • Określenie funkcji straty i optymalizatora
  • Pętla ucząca wraz z walidacją wyników

Generowanie danych

W tym celu wykorzystamy własną funkcję gen_df z modułu data_helpers.data_gen_utils.py, pamiętajcie o jego zaimportowaniu.

Funkcja ta wygeneruje nam dwie pandas’owe ramki train_df i valid_df. Treningowa będzie zawierała 1000 losowych tekstów, walidacyjna 200. Teksty (sekwencje) będą o długości od 100 do 300 znaków i składają się z pierwszych 10 znaków alfabetu (od ‘a’ do ‘j’). Podczas generowania 2-krotnie podbijam wystąpienie jednego z 10 znaków. Tak, aby mieć pewność, że jeden z nich będzie rzeczywiście dominował.

Wczytywanie danych z TorchText

Praca z TorchText wymaga od nas zdefiniowania, w jaki sposób będziemy traktować poszczególne pola w naszym zbiorze danych. Pole może być tekstem, liczbą, stałą sekwencją itp.

My zdefiniowaliśmy dwa pola, TEXT i LABEL. Pierwsze będzie traktowane jako sekwencja, zostanie przepuszczone przez naszą funkcję tokenizującą ‘tokenize’ i będzie pozwalało na budowanie słownika (use_vocab domyślnie ma wartość True). Drugie LABEL określa etykietę, będącą liczbą oznaczającą, który numer znaku dominuje.

Następnie opakowujemy dataframe’y we własną klasę rozszerzającą TorchText.Datasets DataFrameDataset i budujemy słownik.

Zostaje jeszcze stworzenie iteratorów po naszych data setach

Do iteratorów przekazujemy obiekty dataset’ów (train_ds, valid_ds) skąd będą czerpały dane. Musimy podać rozmiar paczki danych (‘batch_size‘) oraz podajemy funkcję sortującą do grupowania (sort_key) w naszym wypadku będzie to po prostu długość tekstu. Ostatnim parametrem jest device określający gdzie będziemy dokonywać obliczeń ‘CPU’ czy ‘GPU’. Jest on ustawiony na samym początku skryptu w zależności od tego czy macie GPU u siebie na pokładzie.

Budujemy sieć LSTM

Sam moduł do przetwarzania sekwencji ma dość standardową budowę. Do konstruktora przekazujemy niezbędne rozmiary, wejścia (vocab_size) wyjścia (output_size), rozmiar warstwy embedding (embed_size) oraz rozmiar warstwy ukrytej w LSTM (hidden_size).

Kolejno definiujemy wykorzystywane warstwy sieci. Pierwszą warstwą jest warstwa do zanurzeń (Embedding)self.embed. W naszym przykładzie nie jest ona konieczna, bo przetwarzamy losowe litery i nie będzie ona w stanie odtworzyć znaczeniowych powiązań pomiędzy sekwencjami. Niemniej jednak dodałem ją, bo gdy będziecie chcieli zrobić coś swojego to pewnie się ona przyda. Warstwa ta na wejściu przyjmuje wektor zawierający numery tokenów w sekwencji i transformuje każdy token do wektora o rozmiarze (embed_size).

Następnie tworzymy obiekt właściwego modułu LSTM. W naszym przykładzie podajemy parametry:

  • embed_size – rozmiar pojedynczego wektora z danymi, nie licząc długości sekwencji oraz batch.
  • hidden_size – rozmiar warstwy ukrytej w LSTM
  • num_layers – liczba warstw LSTM, my dla prostoty ustawiamy na 1. Więcej warstw == mocniejszy model, ale i bardziej wymagający obliczeniowo
  • batch_first – określa sposób ułożenia danych. Mamy dwa podejścia sequence_len first (domyślne) oraz batch_first. W pierwszym podejściu tensor wejściowy ma rozmiar (seq_len, batch_size, embed_size) a w drugim (batch_size, seq_len, embed_size). Parametr ten pozwala na zachowanie kompatybilności z różnymi podejściami. W Tensorflow dominuje raczej podejście batch_first a w Pytroch’u sequence_first, kwestia konwencji i tego, jak kto woli.

Ostatnia warstwa to warstwa liniowa w pełni połączona, wyjściowa, która mapuje wektor stanu na wyjście. W naszym przypadku mapuje na 10-elementowy wektor zawierający wartości określające, który znak dominuje. Wybieramy ten, dla którego wartość jest największa.

Metoda forward zawiera przepis na przepuszczanie tensora wejściowego przez kolejne warstwy. Najpierw embed, potem lstm. Następnie spłaszczamy (output[-1, :, :]) trój-wymiarowy tensor wyjściowy z lstm do dwu- wymiarowego wektora o rozmiarach [batch, hidden_size*embed_size].

Pętla ucząca i walidacja modelu

W pętli uczącej przechodzimy w N epokach po dataset’cie używając dataloaderów.

  • Pobieramy paczke danych i przenosimy na wybrane przez nas urządzenie (CPU lub GPU) inputs, labels = inputs.to(device), labels.to(device).
  • Zerujemy gradienty modelu (model.zero_grad())
  • obliczamy predykcję (predictions = model(inputs))
  • obliczamy funkcję starty (loss = criterion(predictions, labels)) porównując rzeczywiste etykiety z tymi przewidzianymi przez nasz model
  • obliczamy gradienty dla optymalizatora (loss.backward())
  • dokonujemy aktualizacji wag w sieci (optimizer.step() )

Pętla walidacyjna jest zbudowana analogicznie do uczącej, z tym że co epokę obliczamy dokładność klasyfikacji.

Poniżej znajduje się wydruk z konsoli. Całość uruchomiłem na GPU (Geforce 960) na 60 epokach.

Podsumowanie i materiały dodatkowe

W tym artykule przedstawiłem krok po kroku budowę rekurencyjnej sieci neuronowej na przykładzie zliczania znaków. Wpis potraktujcie jako wprowadzenie do własnych eksperymentów oraz dalszej nauki.
Dajcie znać w komentarzach, jeżeli Wam się przydał lub coś nie jest jasne.

Cały kod tego przykładu znajduje się na moim github’ie w projekcie “Pytorch neural networks tutorial” w pliku lstm_net_counting_chars.py

Sposób uruchomienia szczegółowo opisany jest w README.md. Wszystkie niezbędne zależności zainstalujecie z wykorzystaniem pipenv. W skrócie należy:

  1. Install Python.
  2. Install pipenv
  3. Git clone the repository
  4. Install all necessary python packages executing this command in terminal

Zachęcam także do zostawienia maila i w ten sposób zasubskrybowania bloga. Dostaniesz tylko informacje o artykułach i planowanych przez mnie kursach.

Photo by Thomas Tucker on Unsplash

1 Comment Sieć rekurencyjna LSTM do zliczania znaków – wprowadzenie

  1. Pingback: Implemetanacja Pandas DataFrame Dataset w TorchText - About Data

Dodaj komentarz

This site uses Akismet to reduce spam. Learn how your comment data is processed.