Re:ゼロから始めるML生活

どちらかといえばエミリア派です

簡易版のGANを書いてみた

この前はCNNを書いてみました。

tsunotsuno.hatenablog.com

今回はちょっぴりCNNの応用のGANってやつをやってみます。 今回参考にさせていただいたのはこちら。

elix-tech.github.io

めちゃくちゃ分かりやすかったです。 ありがとうございます。 こういういろんな人に参考にされるブログが書けると良いなあ、なんて思った次第です。

[1511.06434] Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

代表的なGANのモデルであるDCGANの論文です。パラメータとかの具体的な値について参考にしました。

GAN (Generative Adversarial Network)

全体感

まず、GANとはなんぞやというところを、いつも通りふわっと勉強してみました。 GANはざっくりと言うと、2つのCNNを競争させながら学習する機械学習モデルです。 イメージとしてはこんな感じです。

f:id:nogawanogawa:20180118214214j:plain

GeneratorとDiscreminatorがいて、Generatorが精製した嘘のデータをDiscriminatorは本物かどうか見抜き続け、お互いに学習し合います。 イメージは"ホコタテ"で、最終的に最強になるかどうかは置いておいて、いい感じの矛と盾が同時に出来上がります。

さてさて、ここらへんで私としては”???”ってなりました。 困ったのはこちらの2箇所。

  • Generatorどないすんねん、、、
  • 教師無し学習なんてやったことない。。。

ってことで、こちらの2つについて細かく見てみました。

Generatorどないすんねん、、、

Discriminatorは前回までのCNNでいけそうな気がしたんですが、Generatorは全くイメージつきませんでした。 そもそも、何を材料にして確率分布を生成すんねん、、、ってなって、上のサイトを参考にさせて頂きました。

どうやらイメージはこんな感じになっているみたいです。

f:id:nogawanogawa:20180118224151j:plain

入力はスカラ値のノイズで、そこから確率分布を生成するそうです。へー。

実際には、このネットワークが逆向きの畳み込みネットワークになっていたり、Pooling層がなかったりするんですが、まずはざっくりとした概要です。 初めは適当な画像ができるんですが、だんだんやってくうちにDiscriminatorを騙し切れるように学習していきます。

そんでもって、次に考えたのは畳み込みでどうやって行列のサイズを大きくしていくの?ってことです。 一般的に畳み込みは入力データに対して一回り小さなデータが生成されます。

f:id:nogawanogawa:20180127230420j:plain

下にある青いセルが入力データで、そこに対して畳み込み演算をすると出力は入力より一回り小さな緑のデータが生成されることがわかります。 データが小さくすることを防ぐには、入力データの外側にパディングでデータを詰めていく調整をすることが一般的です。 出力データを入力データより大きくするためには、かなりの量のパディングをすることになります。

それって結局強烈なバイアスをかけていることに近いので、画像生成にはならないと思います。

こんな感じで"???"ってなったので調べてみると、やっぱりやり方がありました。

qiita.com

f:id:nogawanogawa:20180127230355j:plain

下の青いセルが入力のデータで、そのデータを間引いて拡大します。 拡大後のデータに対してコンボリューションすることで、出力のサイズを入力のサイズより大きくすることができます。 こんな感じで入力データを拡大して、そこに畳み込み演算をしていくんですね。へー。

教師無し学習なんてやったことない。。。

次にコケたのが、「教師なし学習」のモデル設計です。

前回までやっていたCNNでは教師あり学習をやっていました。 画像に対してなんの画像なのかのラベルをくっつけることで、「入力された画像がなんの数字なのか」を判定していました。

で、また上の記事を参考に見て見たところ、考え方が違うんですね。 Discriminatorは「入力された画像が教師データなのか生成データなのか」を判定するんですね。

イメージはこんな感じです。

f:id:nogawanogawa:20180120085304j:plain

Discriminatorに入力された画像が生成されたものか教師データかは予めわかっているので、事実上の教師あり学習に変わります。 判定結果を正解と比較することでlossを計算し、Discriminatorは学習します。

実装

概念的なレベルの疑問が解消されたところで、ちょっとずつ実装に入っていきたいと思います。

GANの学習の定義

今回は全体構成としてこんな感じで全体を構成します。

f:id:nogawanogawa:20180304142524j:plain:w450

そんでもってコードとしてはこんな感じにしてみました。

色々試行錯誤してみた残骸(コード)はこちら。

https://github.com/nogawanogawa/simpleGAN.git

Generator

Generatorのネットワークの構成としてはこんな感じになります。

f:id:nogawanogawa:20180121230809j:plain

登場する層としては、DeconvolutionとReLUだけとなります。 GeneratorのDeconvolutionをして画像を生成するので、そこだけ異なってきます。

Deconvolution

基本的にはConvolutionと同じ処理になりますが、入力に対して前処理を行って入力のサイズを拡大してからConvolutionを実行するみたいです。

f:id:nogawanogawa:20180127230554j:plain

前回までのConvolutionの層の前に入力のPadding処理を入れてから実行しています。 Tensorflowとかのライブラリを見る感じ、出力のサイズによってPaddingや間引きのサイズにルールがあるみたいですが、今回は簡略化してやってみます。

www.monthly-hack.com

forward
def forward(self,x): 

    # フィルターと出力の形状
    FN, C, FH, FW = self.W.shape
    N, C, H, W = x.shape
    out_h = 1 + int((H + 2*self.pad - FH) / self.stride)
    out_w = 1 + int((W + 2*self.pad - FW) / self.stride)

    # transpose処理
    x_padded = np.zeros((1, 1, H*2, W*2), dtype=np.float32)
    x_padded[:, :, ::self.pad_stride, ::self.pad_stride] = x

    out = self.convolution(x_padded)

    return out
backward
def backward(self, dout):
    FN, C, FH, FW = self.W.shape
    N, C, H, W = dout.shape
    dout = dout.transpose(0,2,3,1).reshape(-1, FN)

    # affine層と同様の逆伝播
    self.db = np.sum(dout, axis=0)
    self.dW = np.dot(self.col.T, dout)
    self.dW = self.dW.transpose(1, 0).reshape(FN, C, FH, FW)

    dcol = np.dot(dout, self.col_W.T)
    dx = col2im(dcol, self.x.shape, FH, FW, self.stride, self.pad)

    # transpose処理(縮小方向)
    H_ = int(H/2)
    W_ = int(W/2)
    dx_ = np.zeros((1, 1, H_, W_), dtype=np.float32)
    dx_ = dx[:, :, ::self.pad_stride, ::self.pad_stride]

    return dx_

LeakyReLU

LeakyReLUはこんな関数になっています。

f:id:nogawanogawa:20180304135012p:plain:w360

x<0の領域で0一定ではなく負の値をもつReLU関数のイメージです。 パラメータとかはめんどくさいので、論文を参考にしました。

forward
def forward(self, x):
    out = np.maximum(0.2 * x, x)
    return out
backward
def backward(self, dout):
    out = np.minimum(5*dout, dout)
    return out

Discriminator

Discriminatorは以前のCNNを流用、改良したので省略します。

結果

まず、教師データはこちらです。数字の7ですね。 f:id:nogawanogawa:20180304121332p:plain:w300

次にGANによって生成された画像がこちら。 f:id:nogawanogawa:20180304121345p:plain:w300

とりあえずなんか画像は出ました。 ぼやっと対角線方向に分布が形成されていることはわかります。

ちなみに1800イテレーション終了時の画像がこちらで、これ以上学習させると今度は過学習なのか画像が真っ黒になります。 鮮明な画像が出ないのは単純にGeneratorが簡易版で実装しているためですね。 デコンボリューション時のバイアスの影響をもろに受けているので、格子状の模様が見えていますし。

感想

世の中に出回っているレベルの鮮明な画像生成ができるようになるにはまだまだ時間がかかりそうです。 今回は理屈がどんなもんで、ちゃんと動いてそれっぽい画像が出てきたのでOKとします。。。

あと、二度と(半)フルスクラッチでGANは書きたくないです。笑

専用のライブラリを使っていかないと、理屈から勉強してたら時間がいくらあっても足りませんね。 次はTensorFlowで実装したもので色々遊ぼうかと思います。