import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns

class StackLinear(nn.Module):
    """Parameter-free layer that trades time steps for feature width by reshaping.

    Let ``k = 2 ** quant_factor`` (stored as ``latent_frame_size``).

    * Stack mode (``unstack=False``): a sequence-first ``(B, T, F)`` input is
      reshaped to ``(B, T // k, F * k)``; ``T`` must be divisible by ``k``.
    * Unstack mode (``unstack=True``): a sequence-first ``(B, T, F)`` input is
      reshaped to ``(B, T * k, F // k)``.

    With ``seq_first=False`` the input is read as ``(B, F, T)``: it is
    transposed to the sequence-first layout for the reshape and transposed
    back afterwards, so the output is again laid out as ``(B, F', T')``.
    """

    def __init__(self, quant_factor=2, unstack=False, seq_first=True):
        super().__init__()
        self.quant_factor = quant_factor
        self.unstack = unstack
        self.seq_first = seq_first
        # Number of consecutive frames merged into (or split out of) one step.
        self.latent_frame_size = 2 ** quant_factor

    def forward(self, x):
        """Reshape a 3-D tensor ``x`` according to the configured mode.

        Raises:
            AssertionError: in stack mode, when the time dimension is not a
                multiple of ``latent_frame_size`` (the check is an ``assert``,
                so it is skipped when Python runs with ``-O``).
        """
        group = self.latent_frame_size
        channels_first = not self.seq_first

        if channels_first:
            batch, feat, steps = x.shape
            # Work in (B, T, F) so the reshape groups neighbouring time steps.
            x = x.transpose(1, 2)
        else:
            batch, steps, feat = x.shape

        if self.unstack:
            target_shape = (batch, steps * group, feat // group)
        else:
            assert steps % group == 0, "T must be divisible by latent_frame_size"
            target_shape = (batch, steps // group, feat * group)

        x = x.reshape(*target_shape)

        # Restore the caller's (B, F, T) layout when that is what came in.
        return x.transpose(1, 2) if channels_first else x

def _check_output_shape(case_no, model, x, expected_output_shape):
    """Run ``model`` on ``x`` and assert the output has ``expected_output_shape``."""
    output = model(x)
    assert output.shape == expected_output_shape, (
        f"Output shape mismatch. Expected: {expected_output_shape}, Got: {output.shape}"
    )
    print(f"Test Case {case_no} passed: Output shape is correct.")


def _show_stacking_heatmaps():
    """Print one stacking example and plot input/output as side-by-side heatmaps."""
    model = StackLinear(quant_factor=2, unstack=False, seq_first=True)
    # Repeat every random frame 4 times along dim 1 -> input of shape (2, 8, 4).
    x = torch.randn(2, 2, 4).repeat_interleave(4, dim=1)
    print(x[0])
    output = model(x)
    print(output[0])

    # Flatten the batch dimension into rows so each tensor is one 2-D heatmap.
    x_np = x.detach().numpy().reshape(-1, 4)
    output_np = output.detach().numpy().reshape(-1, 16)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    sns.heatmap(x_np, cmap='viridis', cbar=True)
    plt.title('Input Tensor')

    plt.subplot(1, 2, 2)
    sns.heatmap(output_np, cmap='viridis', cbar=True)
    plt.title('Output Tensor')

    plt.tight_layout()
    plt.show()


def test_stack_linear():
    """Exercise ``StackLinear`` in all four mode/layout combinations.

    Checks output shapes for stacking and unstacking with both values of
    ``seq_first``, verifies that a non-divisible time dimension is rejected,
    and finally prints and plots a stacking example.

    Raises:
        AssertionError: if any shape check fails, or if the non-divisible
            input in Test Case 5 is accepted without an error.
    """
    # Test Case 1: Standard Input (Stacking)
    _check_output_shape(
        1,
        StackLinear(quant_factor=2, unstack=False, seq_first=True),
        torch.randn(2, 8, 4),
        (2, 2, 16),
    )

    # Test Case 2: Standard Input (Unstacking)
    _check_output_shape(
        2,
        StackLinear(quant_factor=2, unstack=True, seq_first=True),
        torch.randn(2, 4, 8),
        (2, 16, 2),
    )

    # Test Case 3: Different Sequence First (Stacking)
    _check_output_shape(
        3,
        StackLinear(quant_factor=2, unstack=False, seq_first=False),
        torch.randn(2, 8, 4),
        (2, 32, 1),  # (bs, c, t)
    )

    # Test Case 4: Different Sequence First (Unstacking)
    _check_output_shape(
        4,
        StackLinear(quant_factor=2, unstack=True, seq_first=False),
        torch.randn(2, 4, 8),  # (bs, c, t)
        (2, 1, 32),
    )

    # Test Case 5: Quant Factor not Divisible
    # quant_factor=3 means groups of 2**3 = 8 frames; T = 7 cannot be grouped.
    model = StackLinear(quant_factor=3, unstack=False, seq_first=True)
    x = torch.randn(2, 7, 4)
    try:
        model(x)
    except AssertionError:
        print("Test Case 5 passed: Caught expected assertion error.")
    else:
        # Without this branch the case would pass silently when nothing is raised.
        raise AssertionError(
            "Test Case 5 failed: expected an AssertionError for T not divisible by latent_frame_size."
        )

    # Visualize input and output tensors for Test Case 1
    _show_stacking_heatmaps()

# Run the test only when this file is executed directly (script or notebook
# cell), so importing the module neither runs the checks nor opens plot windows.
if __name__ == "__main__":
    test_stack_linear()

# Running the test

# NOTE(review): test_stack_linear_squash is not defined in the visible part of
# this file — confirm it is defined before this call; otherwise this line
# raises NameError.
test_stack_linear_squash()

# image.png  (exported image placeholder, not Python code)

# image.png  (exported image placeholder, not Python code)