a = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()))
b = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), *dim*=(1,2)))
c = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), *dim*=-1))

image.png

Case a

1 + logvar - mu.pow(2) - logvar.exp() 의 shape는 (BS, T/q, latent dim)이기 때문에 Batchsize에 대해서도 모두 더해버린다 → 그렇기 때문에 batch size invariant되지 않기 때문에 torch.mean을 하더라도 똑같음

Case b

batch size에 대해서는 sum하지 않기 때문에 sum하면 size가 (BS,)가 된다. 그다음 torch.mean을 한다는 것은 batch size에 대해서 평균을 취하는 것. 이러한 구현은 만약 quant size가 달라지거나, input lenght가 달라지면 시간축이 늘어나거나 줄어들기 때문에 그때마다 KLD의 weight를 reweighting해줘야함

Case C

latent dim에 대해서만 sum하고 나머지 축에 대해선 mean을 때리기 때문에 timedim, batchsize에 대해 invariant함. 해당 loss를 쓰는 것이 batch size, input time에 대해 상관없이 같은 kld weight를 사용하면 됨.