본문 바로가기

VISION

[논문] Generative Adversarial Nets & SRGAN for Super Resolution

  • Discriminative model : 데이터 X가 주어졌을 때 decision boundary를 찾아 classification/regression을 하는 일반적인 모델의 형태. 
  • Generative model : 데이터의 분포를 학습하는 모델. 클래스 마다 데이터들의 분포를 파악하기 때문에 클래스가 주어지면 새로운 데이터를 샘플링할 수 있다.

 

Adversarial Network

 이전까지 Generative model은 이미지 생성 시 필요한 수식들이 적분 불가능함으로 인해 높은 성능을 낼 수 없었다. 때문에 성능이 좋았던 대부분의 DNN은 discriminative model이었는데, GAN은 이 적분불가능함을 해결하고 두 가지 모델이 적대적(adversarial)으로 대치하며 서로를 이기기 위해 학습을 진행하는 구조로 모델을 구축하였다.

 Generative model(Generator)이 이미지를 생성해 전달하면, Discriminative model(Discriminator)은 입력받은 이미지가 real image와 fake image distribution 중 어디에서 추출된 이미지인지 판단한다. 아래 그림의 (a)를 보면 초반에는 generator(초록색 선)의 성능이 낮아 discriminator(파란색 선)가 쉽게 fake 이미지라고 판단할 수 있지만, 시간이 지날수록 generator는 discriminator를 속일 수 있는 방향(검정색 선)으로 학습을 진행한다(현재의 generator distribution을 real image distribution 방향으로 수정). 계속해서 학습을 진행하다가 geterator distribution이 real image distribution과 같아져(d) discriminator가 더 이상 real image와 fake image를 구분할 수 없게 되면(각각 50%의 정확도를 달성하면) 학습이 종료된다.

 

이를 objective function으로 정의하면 아래와 같이 두 모델이 서로 이기려고하는 minimize-maximize 함수가 된다.

Generator 입장에서는 Discriminator가 자신이 생성한 이미지를 1(real)로 판단하게끔 만들어 1-D(G(z))가 최소화되도록 학습한다. Discriminator는 real data(x) distribution에서 샘플링된 데이터에 대해서는 1이라고 판단해야 하며, generator data(z) distribution에서 샘플링된 데이터에 대해서는 D(G(z))=0(fake)이라고 판단해야 한다. 따라서 Discriminator 입장에서는 전체 함수가 maximize 되는 방향으로 학습하게 되므로, 이를 Discriminator 측면에서 다시 작성하면 다음과 같다.

 

Discriminator는 이 objective function을 최대화하려고 하지만, 학습이 진행될수록 Generator의 성능이 높아지기 때문에 함수의 결과값은 줄어들고, D(x)와 D(G(z))가 각각 1/2이 되는 시점에서 -log4의 minima로 수렴한다.

 이러한 함수 정의를 통해 Marcov chain 및 적분불가능한 함수의 approximation을 없애고 적분가능한 함수들을 이용해 generator network를 학습시킨 것이 GAN의 방식이다. GAN이 제안된 이후로 이를 활용한 많은 모델이 제안되었다. 그 중 하나가 이번에 읽은 SRGAN이다.

 

 

SRGAN: Super Resolution using GAN

 SRGAN은 이미지의 고해상도화(Super Resolution)에 GAN을 사용한 모델이다 기존의 고해상도화 모델들은 이미지를 생성할 때 ground truth 이미지와의 pixel-wise Mean Squared Error(MSE)를 criterion으로 잡기 때문에 같은 위치의 픽셀값이 비슷하면 학습을 종료시켰다. 아래의 첫번쨰, 두번째 이미지는 MSE의 역함수를 사용하는 PSNR 값이 세번째 이미지보다 더 높지만, 눈으로 보기에는 세 번째 output이 더 선명하다. 이를 보면 MSE를 criterion으로 사용하는 것이 효과적이지 않다는 것을 알 수 있다. MSE를 사용할 경우 여러 possible solution들의 평균값으로 수렴해 이미지가 흐릿해진다. SRGAN은 인지적으로(perceptually) 더 완벽한 output을 만들어내기 위해 perceptual loss라는 loss function을 정의한다. 

 

 

Perceptual Loss

  Perceptual loss는 content loss와 adversarial loss로 이루어져 있다. 

content loss는 이미지를 pre-trained VGG-19 네트워크에 통과시킨 후 얻은 high level feature 간 차이를 계산한다. 아래처럼 pixel-wise로 계산한 MSE 대신 

아래와 같이 VGG layer의 아웃풋인 Φ 간의 차이를 계산한다. 이를 통해 high frequency 데이터를 잃는 것을 방지하고 perceptual similarity를 loss에 적용할 수 있다. 더 깊은 네트워크에서 더 높은 레벨의 feature map을 사용할 수록 텍스처의 디테일이 잘 살아나는 것으로 확인되었다.

 

  adversarial loss는 앞서 본 GAN의 구조를 이용한 손실함수이다. 이전에 제안된 super resolution 모델 중 주변 픽셀값을 토대로 픽셀값을 예측하는 prediction-based 방식은 빠르지만 prediction 시 지나치게 일반화시켜 흐릿한(smoothed) 이미지를 생성했다. 이후 CNN을 적용한 Deep Neural Network를 이용하는 방식이 네트워크를 깊게 쌓을 수록 복잡한 mapping function을 더 정확하게 표현할 수 있어 성능이 좋았다.

 SRGAN에서도 CNN을 이용해 두 개의 feed-forward adversarial network를 형성한다. 먼저 high resolution(HR) 이미지에서 시작해, 가우시안 필터 적용 후 다운샘플링하여 low resolution(LR) 이미지를 만든다. 이 이미지를 generator에 통과시켜 G(LR) = HR' 이미지를 생성하면, discriminator network는 실제 고화질 이미지(real)와 저해상도에서 고해상도화된 이미지(fake)를 올바르게 판단해 loss값을 극대화하고자 한다. generator는 이를 속이려는 방향으로 학습해나가 loss를 줄인다. GAN 네트워크의 objective function은 다음과 같다.

 Generator가 지속적으로 discriminator를 속이는 방향으로 학습하기 때문에 real high resolution 이미지 분포에 접근해 실제 고해상도와 비슷한 수준의 이미지를 생성해낼 수 있다. 

 

 

모델 구조

 Generator Network는 ResNet을 기반으로 만들어졌다. 16개의 residual block으로 이루어져 있으며, LR 이미지가 들어왔을 때 HR 이미지로 변환한다. generator는 MSE를 기반으로 먼저 pre-training을 시킨 값으로 initialize해놓고 main training을 시작하여 local optima에 빠지는 것을 방지한다. test time에는 batch normalization을 생략하여 각 input 이미지에 대해서만 super resolution이 이루어지도록 하였다. Discriminator Network는 8개의 conv block 이후 2개의 fc layer로 이루어진다. 중간중간 stride를 주어 feature map을 다운샘플링한다. 마지막으로 sigmoid함수를 통과시켜 real 이미지인지 판단한다. 

 

 

Performance

VGG 네트워크에서 더 높은 레이어를 사용한 SRGAN-VGG54모델이 SRGAN-VGG22 모델이나 MSE에러를 사용한 SRResNet보다 더 좋은 성능을 보여주었다. 실제 Higg Resolution 이미지와는 여전히 차이가 크지만, 이전의 성능에 비해서 인지적으로 좋은 고해상도화 성능을 낼 수 있다는 의의가 있다.

 

 

Implementation

 이미지 고해상도화를 위해 필요한 upsampling block을 먼저 구현한다. 논문에서는 4x super resolution을 기준으로 하였기 때문에 up_scale 2로 주면 input channel 4배로 증가시키게 된다. 이후 PixelShuffle 함수를 이용해 앞서 생성된 (N, 4C, H, W)의 데이터를 (N, C, 2H, 2W)로 분배한다. 따라서 하나의 픽셀이 2x2=4개의 픽셀로 upsampling되는 효과가 발생한다.

class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

 

 Generator에서는 먼저 인풋 이미지를 64개 채널로 확대한 후, 일정하게 64채널을 유지하면서 Skip connection을 이용한 Residual block을 여러 개 거친다. 그리고 마지막 블록 직전에 앞서 정의한 upsample block을 적용한다. 여기서는 up_scale 2로 지정해 4x 고해상도화 작업을 하며, 이후 convolution을 통해 3개 채널로 컬러 이미지를 출력한다.

class Generator(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=9, padding=4),nn.PReLU())
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64))
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2

 

 Discriminator Conv – BatchNorm – LeakyReLU블록 8개를 이용하며, 채널 2배 증가할 때마다 stride를 주어 연산량을 줄인다. 마지막 레이어에서는 Adaptive average pooling으로 dimension 1x1으로 축소하고, 이후 convolution을 통해 1개 채널로 축소시킨다. Sigmoid 함수를 통해 0~1 사이의 값으로 매칭하는 구조이다.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))