import os
import torch

import argparse
from config import get_args_parser
import models
import datasets
from clip_loss import CLIP_loss
from torch.utils.data import DataLoader
import tqdm
import time
import json
import random
import glob
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt

def validation(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Validation start using {device}...")

    
    print("Loading models...")
    DEE = models.DEE(args)
    DEE = DEE.to(device)
    
    print("Loading validation dataset...")
    val_dataset = datasets.AudioExpressionDataset(args, split='val')
    # a single full-size batch so the similarity matrix covers every audio/expression pair
    val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=True, num_workers=4)
  

    print("Validation starts...")

    checkpoint_path = f"{args.save_dir}/model_99.pt"

    # map_location lets the checkpoint load on CPU-only machines as well
    checkpoint = torch.load(checkpoint_path, map_location=device)
    DEE.load_state_dict(checkpoint)

    DEE.eval()
    for audio_samples, expression_samples in tqdm.tqdm(val_dataloader):
        '''
        audio_samples : [emotion, intensity, audio_processed]
        expression_samples : [emotion, intensity, expression_processed]
        '''
        emotion, intensity, expression_processed = expression_samples
        audio_processed = audio_samples[2]
        # labels are already collated into tensors by the DataLoader; just add a feature dim
        emotion = emotion.unsqueeze(1).to(device)       # (BS, 1)
        intensity = intensity.unsqueeze(1).to(device)   # (BS, 1)

        expression_processed = expression_processed.to(device)
        audio_processed = audio_processed.to(device)
        with torch.no_grad():
            audio_embedding = DEE.encode_audio(audio_processed)
            expression_embedding = DEE.encode_parameter(expression_processed)

        DB_expression_transposed = expression_embedding.T # (512,BS)
        DB_audio = audio_embedding # (BS,512)
        DB_emosity = torch.cat((emotion, intensity), dim=1) # (BS,2)

        # compute similarity btw audio, expression
        sim_matrix = DB_audio @ DB_expression_transposed
    
    # compute accuracy (the loop above runs exactly once, since batch_size == len(val_dataset))
    expression_accuracy, audio_accuracy = retrival_accuracy(sim_matrix, DB_emosity)

    print('expression retrieval accuracy :', expression_accuracy)
    print('audio retrieval accuracy :', audio_accuracy)
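
# NOTE: `retrival_accuracy` is called above but not defined anywhere in this file.
# The implementation below is a hypothetical sketch (name kept as spelled in the
# original call): it scores top-1 retrieval, counting a retrieval as correct when the
# retrieved item shares the same (emotion, intensity) pair as the query. Which
# direction maps to `expression_accuracy` vs `audio_accuracy` is an assumption.
def retrival_accuracy(sim_matrix, DB_emosity):
    # sim_matrix: (BS, BS) similarities, rows = audio queries, columns = expressions
    # DB_emosity: (BS, 2) emotion/intensity labels shared by both modalities
    same_label = (DB_emosity.unsqueeze(1) == DB_emosity.unsqueeze(0)).all(dim=-1)  # (BS, BS)
    idx = torch.arange(sim_matrix.shape[0], device=sim_matrix.device)

    # audio -> expression retrieval: pick the best expression column for each audio row
    retrieved_expr = sim_matrix.argmax(dim=1)
    expression_accuracy = same_label[idx, retrieved_expr].float().mean().item()

    # expression -> audio retrieval: pick the best audio row for each expression column
    retrieved_audio = sim_matrix.argmax(dim=0)
    audio_accuracy = same_label[retrieved_audio, idx].float().mean().item()

    return expression_accuracy, audio_accuracy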
        

def save_arguments_to_file(args, filename='arguments.json'):
    with open(filename, 'w') as file:
        json.dump(vars(args), file)
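
# Convenience sketch (not part of the original file): reload arguments saved by
# `save_arguments_to_file` back into an argparse.Namespace.
def load_arguments_from_file(filename='arguments.json'):
    with open(filename) as file:
        return argparse.Namespace(**json.load(file))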

def visualize_embeddings(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Embedding visualization start using {device}...")
    
    print("Loading models...")
    DEE = models.DEE(args)
    DEE = DEE.to(device)
    
    print("Loading validation dataset...")
    val_dataset = datasets.AudioDataset(args, split = 'train')
    val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=4)
  

    checkpoint_path = f"{args.save_dir}/model_99.pt"
    checkpoint = torch.load(checkpoint_path, map_location=device)
    DEE.load_state_dict(checkpoint)

    DEE.eval()
    audio_embeddings = []
    emotions = []
    for audio_samples in tqdm.tqdm(val_dataloader):
        '''
        audio_samples : [emotion, intensity, audio_processed]
        '''
        emotion, intensity, audio_processed = audio_samples

        # labels are already collated into tensors by the DataLoader
        emotion = emotion.unsqueeze(1)
        intensity = intensity.unsqueeze(1)
        audio_processed = audio_processed.to(device)

        with torch.no_grad():
            audio_embedding = DEE.encode_audio(audio_processed)
        # print(audio_embedding)  # debug: inspect raw embeddings

 
        DB_audio = audio_embedding # (BS,512)
        audio_embeddings.append(DB_audio)
        emotions.append(emotion)
        DB_emosity = torch.cat((emotion, intensity), dim=1)  # (BS,2); kept for parity with validation(), unused below
    DB_audio = torch.cat(audio_embeddings, dim=0)
    DB_audio = DB_audio.cpu().numpy()
    emotion = torch.cat(emotions, dim=0)
    emotion = emotion.numpy().reshape(DB_audio.shape[0],)
    # drop every emotion class except ids 3 and 5 so the 2-D plot stays readable
    for i in [1, 2, 4, 6, 7, 8]:
        DB_audio = np.delete(DB_audio, np.where(emotion == i), axis=0)
        emotion = np.delete(emotion, np.where(emotion == i), axis=0)
    print("DB_audio shape: ", DB_audio.shape)
    print("emotion shape: ", emotion.shape)
    X_tsne = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=40).fit_transform(DB_audio)
    
    plt.figure(figsize=(8, 6))
    print(np.unique(emotion))
    for i in np.unique(emotion):
        indices = np.where(emotion == i)
        plt.scatter(X_tsne[indices, 0], X_tsne[indices, 1], label=f'Class {i}')

    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.title('t-SNE Visualization')
    plt.legend()
    plt.show()
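
# Hypothetical sketch: a PCA counterpart to the t-SNE plot above (the original code
# also referenced sklearn's PCA for this purpose). It expects embeddings and emotion
# labels extracted exactly as in visualize_embeddings.
from sklearn.decomposition import PCA

def plot_pca(embeddings, labels):
    # embeddings: (N, 512) numpy array, labels: (N,) numpy array of emotion ids
    X_pca = PCA(n_components=2).fit_transform(embeddings)
    plt.figure(figsize=(8, 6))
    for i in np.unique(labels):
        indices = np.where(labels == i)
        plt.scatter(X_pca[indices, 0], X_pca[indices, 1], label=f'Class {i}')
    plt.xlabel('PCA Dimension 1')
    plt.ylabel('PCA Dimension 2')
    plt.title('PCA Visualization')
    plt.legend()
    plt.show()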

if __name__ == '__main__':

    parser = argparse.ArgumentParser('train', parents=[get_args_parser()])
    args = parser.parse_args()
    
    os.makedirs(args.save_dir, exist_ok=True)
        
    # train(args)
    # save_arguments_to_file(args)
    # validation(args)
    visualize_embeddings(args)

    print("DONE!")
# NOTE: AudioDataset needs `data` (torch.utils.data) and a waveform `processor` in
# scope. The `processor(...)` call below matches the Hugging Face Wav2Vec2
# feature-extractor API; which checkpoint was used is unknown, e.g.:
#   from transformers import Wav2Vec2Processor
#   processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
from torch.utils import data

class AudioDataset(data.Dataset):
    def __init__(self, args, split='train'):
        # split
        self.split = split
        self.data_path_file = args.data_path_file
        
        # data path
        self.audio_dir = args.audio_feature_dir
        self.expression_feature_dir = args.expression_feature_dir
      
        # max sequence length
        self.audio_feature_len = args.audio_feature_len
        self.expression_feature_len = args.expression_feature_len
        
        # list for features
        self.inputs = []
        
        #load data path file
        with open(self.data_path_file) as f:
            uid_dict = json.load(f)
            
        uid_list = uid_dict[self.split]
        
        for uid in tqdm.tqdm(uid_list):
            actor_name = 'Actor_' + uid.split('-')[-1]
            emotion = int(uid.split('-')[2])
            intensity = int(uid.split('-')[3])

            audio_path = os.path.join(self.audio_dir, actor_name, uid + '.npy')
            # audio_samples,sample_rate = librosa.core.load(audio_path, sr=16000)
            audio_samples = np.load(audio_path)
            audio_samples_lib = torch.tensor(audio_samples, dtype=torch.float32)
            audio_samples = torch.squeeze(processor(audio_samples_lib, sampling_rate=16000, return_tensors="pt").input_values)

            # generate input samples by slicing: skip 0.8 s (12800 samples at 16 kHz) at each
            # end of the clip and hop by 0.1 s (1600 samples) between consecutive windows
            for audio_start in range(12800, audio_samples.shape[0] - self.audio_feature_len - 12800, 1600):

                audio_samples_slice = audio_samples[audio_start:audio_start+self.audio_feature_len]
                

                self.inputs.append([emotion, intensity, audio_samples_slice])
            # if the dataset is built for the first time, the slices could be cached here
    
    def __getitem__(self, index):
        return self.inputs[index]
    
    def __len__(self):
        return len(self.inputs)
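
# Illustration (hedged): the uid strings in data_path_file are assumed to follow the
# RAVDESS naming scheme "modality-channel-emotion-intensity-statement-repetition-actor",
# e.g. "03-01-05-02-01-01-12" -> emotion 5, intensity 2, Actor_12. This helper only
# documents how AudioDataset.__init__ parses those fields; it is not used elsewhere.
def parse_ravdess_uid(uid):
    fields = uid.split('-')
    return {
        'emotion': int(fields[2]),
        'intensity': int(fields[3]),
        'actor': 'Actor_' + fields[-1],
    }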