import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
class StackLinear(nn.Module):
    """Fold groups of time steps into the feature axis, or unfold them again.

    The module has no learnable parameters; it only reshapes its input.
    With ``latent_frame_size = 2 ** quant_factor``:

    * stacking (``unstack=False``) maps ``(B, T, F)`` to
      ``(B, T // latent_frame_size, F * latent_frame_size)``;
    * unstacking (``unstack=True``) maps ``(B, T, F)`` to
      ``(B, T * latent_frame_size, F // latent_frame_size)``.

    Args:
        quant_factor: Exponent of two that gives the number of consecutive
            time steps merged into (or split out of) one frame.
        unstack: If ``False`` the module stacks; if ``True`` it unstacks.
        seq_first: If ``True`` the input/output layout is ``(B, T, F)``.
            If ``False`` the layout is ``(B, F, T)``; the tensor is permuted
            to ``(B, T, F)`` internally and permuted back before returning.
    """

    def __init__(
        self,
        quant_factor: int = 2,
        unstack: bool = False,
        seq_first: bool = True,
    ) -> None:
        super().__init__()
        self.quant_factor = quant_factor
        # Number of consecutive time steps that make up one stacked frame.
        self.latent_frame_size = 2**quant_factor
        self.unstack = unstack
        self.seq_first = seq_first

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Stack or unstack ``x`` along the time axis.

        Args:
            x: 3-D tensor laid out as ``(B, T, F)`` when ``seq_first`` is
                true, otherwise ``(B, F, T)``.

        Returns:
            The reshaped tensor in the same layout convention as the input.

        Raises:
            AssertionError: When stacking and ``T`` is not a multiple of
                ``latent_frame_size``.
            RuntimeError: When unstacking and ``F`` is not a multiple of
                ``latent_frame_size``.
        """
        if self.seq_first:
            batch, steps, feats = x.shape
        else:
            batch, feats, steps = x.shape
            # Work in (B, T, F) so the reshape merges/splits whole time steps.
            x = x.permute(0, 2, 1)

        group = self.latent_frame_size
        if self.unstack:
            # Fail with a readable message instead of relying on reshape's
            # generic size-mismatch error. RuntimeError is kept because that
            # is what the reshape itself raised for this case.
            if feats % group != 0:
                raise RuntimeError(
                    f"F must be divisible by latent_frame_size "
                    f"(got F={feats}, latent_frame_size={group})"
                )
            x = x.reshape(batch, steps * group, feats // group)
        else:
            # Raised explicitly (not via ``assert``) so the check still runs
            # under ``python -O``; exception type and message are unchanged.
            if steps % group != 0:
                raise AssertionError("T must be divisible by latent_frame_size")
            x = x.reshape(batch, steps // group, feats * group)

        if not self.seq_first:
            # Restore the caller's (B, F, T) layout.
            x = x.permute(0, 2, 1)
        return x
def test_stack_linear():
    """Smoke-test ``StackLinear`` output shapes and plot one stacking example.

    Test Cases 1-4 check the output shape for each combination of
    ``unstack`` and ``seq_first`` with ``quant_factor=2``. Test Case 5 checks
    that stacking a sequence whose length is not a multiple of the frame size
    raises ``AssertionError``. Finally an input/output pair is printed and
    drawn as two heatmaps; ``plt.show()`` is called at the end.

    Raises:
        AssertionError: If any shape check fails, or if Test Case 5 does not
            raise the expected error.
    """
    # (unstack, seq_first, input shape, expected output shape)
    # Cases 3 and 4 use the (bs, c, t) layout.
    shape_cases = [
        (False, True, (2, 8, 4), (2, 2, 16)),   # 1: standard input, stacking
        (True, True, (2, 4, 8), (2, 16, 2)),    # 2: standard input, unstacking
        (False, False, (2, 8, 4), (2, 32, 1)),  # 3: channels first, stacking
        (True, False, (2, 4, 8), (2, 1, 32)),   # 4: channels first, unstacking
    ]
    for case_no, (unstack, seq_first, in_shape, expected_output_shape) in enumerate(
        shape_cases, start=1
    ):
        model = StackLinear(quant_factor=2, unstack=unstack, seq_first=seq_first)
        output = model(torch.randn(*in_shape))
        assert output.shape == expected_output_shape, f"Output shape mismatch. Expected: {expected_output_shape}, Got: {output.shape}"
        print(f"Test Case {case_no} passed: Output shape is correct.")

    # Test Case 5: sequence length (7) not divisible by the frame size (2**3).
    model = StackLinear(quant_factor=3, unstack=False, seq_first=True)
    x = torch.randn(2, 7, 4)
    try:
        model(x)
    except AssertionError:
        print("Test Case 5 passed: Caught expected assertion error.")
    else:
        # Without this branch the case passed silently when nothing was raised.
        raise AssertionError(
            "Test Case 5 failed: expected an AssertionError for a "
            "non-divisible sequence length."
        )

    # Visualize input and output tensors for Test Case 1.
    # Each time step is repeated 4 times, so every stacked frame is built
    # from 4 identical rows, which makes the regrouping easy to see.
    model = StackLinear(quant_factor=2, unstack=False, seq_first=True)
    x = torch.randn(2, 2, 4).repeat_interleave(4, dim=1)
    print(x[0])
    output = model(x)
    print(output[0])

    # Flatten the batch dimension so each tensor is a 2-D heatmap.
    x_np = x.detach().numpy().reshape(-1, 4)
    output_np = output.detach().numpy().reshape(-1, 16)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    sns.heatmap(x_np, cmap='viridis', cbar=True)
    plt.title('Input Tensor')
    plt.subplot(1, 2, 2)
    sns.heatmap(output_np, cmap='viridis', cbar=True)
    plt.title('Output Tensor')
    plt.tight_layout()
    plt.show()
# Running the test
# Guarded so that importing this module does not run the tests or open a plot.
if __name__ == "__main__":
    test_stack_linear()
    # NOTE(review): a call to test_stack_linear_squash() used to follow here,
    # but no function of that name is defined or imported in this file, so the
    # script always ended with a NameError. Restore the call once that test
    # exists.

