VQ-VAE: Neural Discrete Representation Learning

VQ-VAE: Neural Discrete Representation Learning
 

기본 개념

  • prior
    • p(z)를 학습하는 모델
    • 인코더는 “입력 x가 있을 때”만 z를 뽑을 수 있음 (압축/복원용)
    • prior는 “입력 x 없이도” 그럴듯한 z를 뽑아, 새로운 x를 생성할 수 있게 해줌. (생성용)
      • notion image
  • discrete latent variable
    • 값의 집합이 연속 공간이 아니라, 유한/가산 집합인 잠재변수
      • VQ-VAE에서는 정해진 숫자의 latent vector 집합으로 이루어짐
      • 즉 D dimension의 K개의 embedding 집합으로 구성 됨
    • continuous latent variable의 경우 해석도 어렵고 효율이 떨어짐
 

VQ-VAE

모델 구조

notion image
  • 인코더 - 코드북 - 디코더의 구조
  • 코드북
    • D차원의 K개의 임베딩으로 이루어져 있음
  • 인코더 → 코드북
    • 인코더의 아웃풋인 z와 가장 거리가 가까운 임베딩을 코드북에서 찾음
      • notion image
    • 즉 연속된 latent를, 유한한 심볼의 집합에서 하나를 고르도록 하는 것
  • 코드북 → 디코더
    • 코드북에서 찾은 임베딩을 디코더에 전달하여 foward
  • z는 2차원, 3차원이 될 수도 있음
    • 예를 들어 이미지넷을 학습시킬 때에는 z의 차원은 32x32였음
      • 128x128x3 차원의 이미지를 32x32x1 (K=512)로 줄였다가 다시 128x128x3로 복원함
 

학습 방법

  • straight-through gradient estimation
    • forward시에는 코드북을 거쳐서 가지만, 코드북은 미분 불가이기 때문에 backward가 안 됨
    • 따라서 backward시에는 decoder의 인풋에서 encoder의 output으로 곧바로 연결함
  • loss
    • notion image
    • sg는 stopgradient operator로, 순전파 시에는 값을 그대로 사용하고, 역전파시에는 gradient가 흐르지 않도록 하여 상수처럼 취급되도록 함
      1. 첫번째 loss항
          • reconstruction 항
          • 디코더가 원본을 잘 복원할 수 있도록
      1. 두번째 loss항
          • 코드북 업데이트 항. Vector Quantisation (VQ).
          • straight-through gradient estimation로 인해 코드북은 backward시 gradient가 연결되는 부분이 없음.
            • 이 loss를 통해 선택된 코드북의 임베딩이, 인코더에서 나온 Ze(x)와 비슷해지도록 학습
            • notion image
      1. 세번째 loss항
          • commitment loss
          • 인코더 출력 Ze(x)가 선택된 코드북 임베딩과 비슷해지도록 학습 (너무 커지지 않도록 함)
 

왜 이렇게 구성하였는가?

  • posterior collapse를 겪지 않음
    • posterior collapse란?
      • 현상: posterior q(z∣x)가 prior p(z)와 거의 같아진다
      • 결과: pixcelCNN과 같이 autoregressive 모델에서 디코더가 latent를 무시하게 됨
        • latent를 사용하지 않고 Xt-1까지의 정보만 가지고 Xt를 예측함.
      • 왜 발생하는가?
        • posterior collapse가 생기는 이유
          • ELBO
            • notion image
            • VAE의 maximum likelihood
            • VAE가 최적화하려는 이론적 목적함수
            • 따라서 ELBO를 키우려면 reconstruction은 키우고 KL은 줄여야 함
            • KL을 최소화하는 최선의 방법은 q(z∣x)를 p(z)와 같게 만드는 것
      • 왜 VQ-VAE는 posterior collapse가 발생하지 않는가?
        • VQ-VAE 학습은 2단계로 이루어지는데,
          • 1단계, prior는 유니폼으로 두고, 인코더 - 코드북 - 디코더 학습
          • 2단계, 인코더 - 코드북 - 디코더를 프리징하고, z값을 뽑아서 prior 학습
          • 1단계에서 KL이 이미 상수여서 z를 무시하는 쪽으로 학습되지 않는다고 함
  • 기존 discrete VAE보다 gradient variance가 덜 큼
    • 미분 불가능한 구간들로 인해서 사용한 근사적 추정기들로 인해 큰 분산이 발생한다고 함
    • 분산이 크면 업데이트 방향이 요동치고, 수렴이 느려지며, 학습이 불안정해짐. 학습률을 낮춰야 하는 상황도 자주 생김
    • VQ-VAE는 straight-through 추정기로 gradient를 직접 전달하거나, codebook을 이용해 결정론적 방식으로 분산이 적다함
  • discrete latent 모델이지만 continuous latent 모델만큼의 성능이 나옴
 

실험 결과

notion image
  • 정량 비교: CIFAR-10 bits/dim
    • 목적: 이산 latent 써도 연속 VAE급 성능 나오는지 확인함
    • 결과(bits/dim, 낮을수록 좋음)
      • Continuous VAE: 4.51
      • VQ-VAE: 4.67
      • VIMCO(discrete VAE 계열): 5.14
      • bits/dim
        • notion image
        • 의미: 데이터 x를 설명하는 데 필요한 비트
    • VQ-VAE가 연속 VAE에 근접함, 기존 discrete 추정기보다 훨씬 나음
  • 그 외 이미지 생성/비디오/오디오 task에서도 좋은 성능이 나옴
Share article

kjyong