Re:ゼロから始めるML生活

どちらかといえばエミリア派です

自然言語処理について勉強してみた(その4:LSTM)

この前は基本的なRNNの仕組みについて勉強していました。

tsunotsuno.hatenablog.com

今回は、現在RNNの中でも代表的なモデルの一つであるLSTMについて勉強します。

今回も参考にしたのはこちらの本です。

ゼロから作るDeep Learning ? ―自然言語処理編

ゼロから作るDeep Learning ? ―自然言語処理編

毎回のことながら、今回も非常にわかりやすかったです。

今回もpytorchを使って、楽して実装を眺めながら勉強していきます。

LSTM

勾配消失/勾配爆発

単純なRNNの問題点として、勾配消失/勾配爆発があります。 RNNレイヤの中で、時系列方向の逆伝播を考えます。

f:id:nogawanogawa:20190210154115j:plain:w500

上の図より、時系列方向の距離が大きい要素ほど、多くの演算を通過することがわかります。 細かい説明は教科書に譲ります。(正直あまり良くわかってないです)

あとは、この辺も参考になるかもしれないです。

qiita.com

この時、時系列的に遠いところの入力は学習にほとんど影響を与えない、あるいは他の要素に対して大きすぎる影響を持つようになってしまいます。

単語の推定を行う場合などには、文脈を考慮するためにある程度の時系列範囲を考慮する必要があります。 しかし、単純なRNNでは、時系列を適切に活用することができないという問題があります。

LSTMの仕組み

そこで考案されたのがLSTM (Long Short Term Memory) です。 前回使ったRNNレイヤの部分を下のようなLSTMレイヤに置き換えます。

f:id:nogawanogawa:20190210154144j:plain

このようにすることで、時系列方向に関して勾配消失/勾配爆発することを回避します。 ざっくりいえば、時系列方向に残す信号と忘れる信号を管理することで、長期記憶を可能にしています。 なんでこうなるのかは、教科書読んでください。どの部分が何を表してあるかまで説明されています。 その他、最適化とかも紹介されています。

何はともあれ、この形にすれば勾配消失/勾配爆発を回避できるLSTMの出来上がりです。

Pytorchの実装

今回はこちらのチュートリアルをなぞってみます。

pytorch.org

環境構築

省略。どうせ前回と一緒なので。

チュートリアルを眺める

LSTMのモデルのコードはこんな感じです。

class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Before we've done anything, we dont have any hidden state.
        # Refer to the Pytorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

なんか上の図のLSTMレイヤはすでに関数が用意されているんですね。 なのでパラメータを設定して後続の活性化関数を仕込むだけで使えるんですね。 超便利じゃないですか!

感想

以前、LSTMを使って異常検知をやっていた方がいたのを思い出しました。 この辺りまで来ると、結構実用的なレベルで使えるようになるみたいですね。