Just Do IT

[논문구현] Generative Adversarial Networks 생성적 적대 신경망(GAN) with Pytorch lightning 본문

AI Study/논문 및 구현

[논문구현] Generative Adversarial Networks 생성적 적대 신경망(GAN) with Pytorch lightning

풀용 2023. 1. 20. 03:20

들어가며

이번 포스팅에서는 흔히 Vanilla GAN 혹은 Simple GAN이라고 불리는 가장 기본적인, 논문에서 제시한 알고리즘을 바탕으로 구현을 해볼 예정입니다. 가장 심플한 GAN이기 때문에 하이퍼 파라미터에 따라, 초기 noise에 따라 학습이 잘 안되기도 하고 나름 잘되기도 합니다. Pytorch lightning을 이용해서 구현해보도록 하겠습니다.

 

Pytorch lightning은 Pytorch 프레임워크를 베이스로 보일러 플레이트를 최대한 제거하고 공통된 스타일의 템플릿을 제공해주는 역할을 합니다.

 

보일러 플레이트란?https://charlezz.medium.com/%EB%B3%B4%EC%9D%BC%EB%9F%AC%ED%94%8C%EB%A0%88%EC%9D%B4%ED%8A%B8-%EC%BD%94%EB%93%9C%EB%9E%80-boilerplate-code-83009a8d3297

 

보일러플레이트 코드란?(Boilerplate code)

보일러플레이트란?

charlezz.medium.com

 

pytorch를 이용하여 구현했기 때문에 pytorch의 모든 기능을 사용할 수 있는게 큰 장점입니다.

 

이 블로그에서 Pytorch lightning의 핵심 모듈들을 잘 설명해 놓으셨습니다.

https://velog.io/@khs0415p/pytorch-lightning#lightningdatamodule

 

pytorch lightning

pytorch lightning

velog.io

0. Config

 

import pytorch_lightning as pl
import random
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset, random_split)
import torch.nn.functional as F
import torchvision.utils as utils
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
config = {
    'LATENT_SIZE':100,
    'HIDDEN_SIZE':256,
    'OUTPUT_SIZE':1,
    'EPOCHS':100,
    'LEARNING_RATE':0.0002,
    'BATCH_SIZE':128,
    'HEIGHT':28,
    'WIDTH':28,
    'CHANNEL':1,
    'SEED':42
}
class MNISTDataModule(pl.LightningDataModule):

    def __init__(self,config,data_dir: str = '/content/'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
        self.config = config

    def prepare_data(self):
        datasets.MNIST(self.data_dir,train=True,download=True)
        datasets.MNIST(self.data_dir,train=False,download=True)


    def setup(self, stage: str):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)


    def train_dataloader(self):
        return DataLoader(self.mnist_full, batch_size=config['BATCH_SIZE'])

GAN의 구조

들어가기에 앞서 GAN의 구조를 다시 한번 상기 시키겠습니다.

1. Generator

class Generator(nn.Module):

  def __init__(self, config):
    super(Generator, self).__init__()

    # 입력층 노드 수
    self.inode = config["LATENT_SIZE"] # 28x28보다 작거나 같다? Z
    # 은닉층 노드 수
    self.hnode = config["HIDDEN_SIZE"]
    # 출력층 노드 수: 생성해야 하는 노드 수
    self.onode = config["HEIGHT"] * config['WIDTH'] # 28x28

    # 신경망 설계
    self.net = nn.Sequential(nn.Linear(self.inode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.onode, bias=True),
                             nn.Tanh())

  def forward(self, input_features):
    hypothesis = self.net(input_features)
    hypothesis = hypothesis.view(hypothesis.size(0),-1)
    return hypothesis

generator는 random noise(여기에서는 normal distribution에서 random sample된 숫자)를 input으로 받아서 fake image를 생성하는 역할을 합니다.

단순히 Linear layer와 LeakyReLU activation function, dropout으로만 구성된 기본적인 구조로 되어있습니다.

 

  1. random noise의 차원은 hyper parameter다. 본 구현에서는 100차원으로 설정했다. (batch_size,100)의 normal distribution random sample을 input으로 넣는다.
  2. (batch_size,100) shape의 input을 시작으로 hidden layer를 거쳐 (batch_size,784)의 output을 내뱉는다. MNIST의 해상도는 28x28이므로 총 784개의 pixel로 이루어져있기 때문에 output의 shape이 (batch_size,784)가 된다.

2. Discriminator

class Discriminator(nn.Module):

  def __init__(self, config):
    super(Discriminator, self).__init__()

    # 입력층 노드 수
    self.inode = config["HEIGHT"] * config['WIDTH'] # 28x28
    # 은닉층 노드 수
    self.hnode = config["HIDDEN_SIZE"]
    # 출력층 노드 수: 분류해야 하는 레이블 수
    self.onode = config["OUTPUT_SIZE"] # real? fake?

    # 신경망 설계
    self.net = nn.Sequential(nn.Linear(self.inode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.hnode, bias=True),
                             nn.LeakyReLU(),
                             nn.Dropout(0.1),
                             nn.Linear(self.hnode, self.onode, bias=True),
                             nn.Sigmoid())
    
  def forward(self, input_features):
    hypothesis = self.net(input_features)
    return hypothesis

discriminator는 단순히 real image와 generator에서 나온 fake 이미지를 구별하는 역할을 합니다.

 

  1. (batch_size,784)의 입력을 받아 Linear layer을 거친 후 real과 fake를 구별한다.
  2. 따라서 output의 shape는 (batch_size,1)이다.

Generator와 Discriminator의 마지막 activation fucntion에 관해서

generator는 왜 마지막에 tanh를 쓰고 discriminator는 마지막에 sigmoid를 쓰는지 궁금하실 수 있을것 같습니다. 먼저 discriminator가 sigmoid를 쓰는이유는 자명합니다. discriminator는 real인지 fake인지를 구분해야하기 때문에 (fake : 0, real : 1) 당연히 sigmoid를 써야하는 것입니다. 그렇다면 왜 generator는 tanh를 쓸까요? 바로 이미지를 생성할 때 보통 이미지를 [0,1] 혹은 [-1,1]로 normalize하기 때문입니다. 왜 normalize를 해야하냐고 한다면 

 

https://velog.io/@cha-suyeon/Normalization%EC%9D%B4%EB%9E%80-Normalize%EB%A5%BC-%ED%95%B4%EC%95%BC%ED%95%98%EB%8A%94-%EC%9D%B4%EC%9C%A0

 

Normalization이란? Normalize를 해야하는 이유

stackoverflow와 youtube 영상의 출처는 아래에 있습니다.Data Normalization은 데이터의 범위를 사용자가 원하는 범위로 제한하는 것을 의미합니다.I understand that sometimes, when for example the inpu

velog.io

 

위 포스팅이 잘 설명되어있습니다. 여기서 [0,1]사이로 normalize를 하기위해 sigmoid를 사용하지 않고 [-1,1]사이의 tanh로 normalize하는 이유는 real image를 전처리 하는 과정을 보면 이해하기 쉽습니다.

class MNISTDataModule(pl.LightningDataModule):

    def __init__(self,config,data_dir: str = '/content/'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
        self.config = config

    def setup(self, stage: str):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            self.mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)

다시한번 가져온 datamodule입니다. transform의 ToTensor를 통해 먼저 image를 Tensor로 바꿔줍니다. 여기서 ToTensor는 image를 0~1사이의 값으로 바꿔줍니다. 그 이후 mean : 0.5, std : 0.5로 normalize하게되면 $ Z = \frac{X - mean}{std} $ 으로 min = (0 - 0.5) / 0.5 = -1 max = (1 - 0.5) / 0.5 = 1로 [-1,1]이 되게됩니다. 따라서 tanh함수를 이용한 gnerator의 output의 범위와 전처리된 real image의 범위가 같게 됩니다.

 

https://stats.stackexchange.com/questions/498508/why-use-tanh-function-at-the-last-layer-of-generator-in-gan

 

Why use tanh function at the last layer of generator in GAN?

While studying GAN, I found out that ReLU activation is used at the intermediate layers, and tanh or sigmoid is used at the last layer of the generator. I'm curious about why sigmoid or tanh is use...

stats.stackexchange.com

 

3. GAN

class GAN(pl.LightningModule):
    def __init__(self,config):
        super(GAN,self).__init__()

        self.config = config
        self.G = Generator(config)
        self.D = Discriminator(config)
        self.automatic_optimization = False

    def forward(self, z):
        return self.G(z)


    def get_noise(self,batch_size):
        return torch.randn(batch_size, self.config['LATENT_SIZE'])

    def imshow(self, img):
        img = (img+1)/2    
        img = img.squeeze()
        np_img = img.numpy()
        plt.imshow(np_img, cmap='gray')
        plt.show()

    def imshow_grid(self,img):
        img = utils.make_grid(img.cpu().detach())
        img = (img+1)/2
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1,2,0)))
        plt.show()

    def loss(self, y_hat,y):
        return F.binary_cross_entropy(y_hat,y)

    def configure_optimizers(self):
        optimizer_G = optim.Adam(self.G.parameters(),lr = config['LEARNING_RATE'],betas=(0.9,0.999))
        optimizer_D = optim.Adam(self.D.parameters(),lr = config['LEARNING_RATE'],betas=(0.9,0.999))
        return [optimizer_G,optimizer_D]

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        
        batch_size = imgs.shape[0]
        imgs = imgs.view(batch_size,-1)

        optimizers = self.optimizers()
        opt_G,opt_D = optimizers[0],optimizers[1]

        real_label = torch.ones((batch_size,1))
        real_label = real_label.type_as(imgs) #이런식으로도 device 변경가능
        fake_label = torch.zeros((batch_size,1),device=self.device) #이런식으로도 가능

        # noise sample
        z = self.get_noise(batch_size)
        z = z.type_as(imgs) # dynamic 하게 device를 변경시켜줌

        G_x = self(z)

        # discriminator

        #real image
        D_x = self.D(imgs)
        D_loss = self.loss(D_x,real_label)

        #fake image
        D_z = self.D(G_x.detach())
        Z_loss = self.loss(D_z,fake_label)

        total_loss = D_loss + Z_loss

        opt_D.zero_grad()
        self.manual_backward(total_loss)
        opt_D.step()

        # generator

        D_z = self.D(G_x)
        G_loss = self.loss(D_z,real_label)

        opt_G.zero_grad()
        self.manual_backward(G_loss)
        opt_G.step()

        return {'D_loss':total_loss,'G_loss':G_loss}

    def on_train_epoch_end(self):
    
        if self.current_epoch % 10 == 0:
            z = self.get_noise(16)
            z = z.type_as(self.G.net[0].weight)

            generate_imgs = self(z)
            generate_imgs_ = generate_imgs.reshape((-1,28,28)).unsqueeze(1)
            self.imshow_grid(generate_imgs_)

앞에서 만든 모듈들을 pytorch lightning 모듈을 이용해 하나로 합치고 학습 루프를 만듭니다.

def __init__(self,config):
        super(GAN,self).__init__()

        self.config = config
        self.G = Generator(config)
        self.D = Discriminator(config)
        self.automatic_optimization = False

 

  • torch lightning은 기본적으로 zero_grad, backward, step과 같은 function을 내부적으로 수행해 줍니다. 하지만 GAN에서는 두가지의 optimizer을 순차적으로 사용하기 때문에 automatic_optimization을 False로 설정해 커스텀 할 수 있게 만듭니다.
def training_step(self, batch, batch_idx):
        imgs, _ = batch
        
        batch_size = imgs.shape[0]
        imgs = imgs.view(batch_size,-1)

 

  • training step이 학습 루프를 담당합니다. batch에서 이미지를 가져오고 view를 통해 28x28을 차원 하나로 만들어 줍니다.
        real_label = torch.ones((batch_size,1))
        real_label = real_label.type_as(imgs) #이런식으로도 device 변경가능
        fake_label = torch.zeros((batch_size,1),device=self.device) #이런식으로도 가능

 

  • pytorch lightning에서는 device를 자동으로 설정해 줍니다. gpu환경이면 gpu를 이용할 수 있고 cpu환경이면 cpu를 따로 설정하지 않아도 이용할 수 있습니다. 따라서 type_as나 self.device를 이용해서 dynamic하게 device를 이용 할 수 있게 설정해 놓으면 편합니다.
  • 먼저 real image을 위해 1로 label된 matrix와 fake를 위해 0으로 label된 matrix를 준비합니다.
  •  
	def get_noise(self,batch_size):
        return torch.randn(batch_size, self.config['LATENT_SIZE'])
        # noise sample
        z = self.get_noise(batch_size)
        z = z.type_as(imgs) # dynamic 하게 device를 변경시켜줌

        G_x = self(z)

 

  • 그리고 noise를 생성하고 Generator에 통과시킵니다.

 

    def loss(self, y_hat,y):
        return F.binary_cross_entropy(y_hat,y)
        # discriminator

        #real image
        D_x = self.D(imgs)
        D_loss = self.loss(D_x,real_label)

        #fake image
        D_z = self.D(G_x.detach())
        Z_loss = self.loss(D_z,fake_label)

        total_loss = D_loss + Z_loss

        opt_D.zero_grad()
        self.manual_backward(total_loss)
        opt_D.step()

 

  • discriminator를 먼저 학습시킵니다. real image를 discriminator에 통과시키는 것은 1, generator를 통과한 fake image는 0으로 cross entropy loss를 계산하고 두 loss를 더해 total loss를 계산합니다. 그리고 backpropagation을 진행합니다. 여기서 주의해야할 점은 generator에서 나온 fake image를 discriminator에 input으로 넣을때 detach를 해줘야 하는 점입니다.
  • discriminator를 학습할 때는 generator는 고정시키고 discriminator의 parameters들만 backpropagation시켜야 합니다. 하지만 기본적으로 generator를 통과한 tensor에는 backpropagation을 위한 계산 그래프가 붙어있습니다. 따라서 이 계산그래프를 detach시켜야 오류가 나지않고 discriminator만 backpropagation시킬 수 있게 됩니다.
        # generator

        D_z = self.D(G_x)
        G_loss = self.loss(D_z,real_label)

        opt_G.zero_grad()
        self.manual_backward(G_loss)
        opt_G.step()

        return {'D_loss':total_loss,'G_loss':G_loss}

 

  • 이후 discriminator를 고정시킨 후 generator만 update합니다. generator는 discriminator가 자신을 real image로 판단해야 하므로 1로 라벨링한 값을 loss로 계산합니다.
    def on_train_epoch_end(self):
    
        if self.current_epoch % 10 == 0:
            z = self.get_noise(16)
            z = z.type_as(self.G.net[0].weight)

            generate_imgs = self(z)
            generate_imgs_ = generate_imgs.reshape((-1,28,28)).unsqueeze(1)
            self.imshow_grid(generate_imgs_)

 

  • 이후 10번째 epoch마다 generator의 학습 상태를 이미지로 출력합니다.
dm = MNISTDataModule(config)
model = GAN(config)
trainer = pl.Trainer(accelerator='auto',devices='auto',max_epochs=config['EPOCHS'],callbacks=[pl.callbacks.progress.TQDMProgressBar(refresh_rate=20)])
trainer.fit(model,dm)

 

  • trainer를 만든후 학습을 시키면 완성입니다. accelerator를 auto로 지정하면 알아서 사용 가능한 gpu,cpu를 파악하고 devices를 auto로 지정하면 accelerator 몇개를 사용 가능한지 알아서 판단하여 사용합니다.

4. 결과

총 100 epoch를 돌려 10회 마다 출력한 결과물 입니다. 처음에는 단순 noise에 불과했지만 점점 MNIST데이터와 비슷하게 generator가 생성하는 것을 볼 수 있습니다.

 

물론 아주 간단한 GAN이기 때문에 하이퍼 파라미터와 초기 noise에 따라 학습이 진행이 되지 않을 수도 있습니다. generator가 다양한 fake image를 생성하지 못하고 비슷한 data만 생성하는 mode collapse의 문제도 발생할 수 있습니다.

 

그럼에도 모든 GAN의 base가 되는 기초이기 때문에 알아두면 좋을 것 같습니다.

 
 
 
 
 
 
 
 
Comments