import os
import torch
import argparse
from config import get_args_parser
import models
import datasets
from clip_loss import CLIP_loss
from torch.utils.data import DataLoader
import tqdm
import time
import json
import random
import glob
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
def validation(args):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Validation start using {device}...")
print("Loading models...")
DEE = models.DEE(args)
DEE = DEE.to(device)
print("Loading validation dataset...")
val_dataset = datasets.AudioExpressionDataset(args, split = 'val')
val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=True, num_workers=4)
print("Validation starts...")
checkpoint_path = f"{args.save_dir}/model_99.pt"
    checkpoint = torch.load(checkpoint_path, map_location=device)
DEE.load_state_dict(checkpoint)
DEE.eval()
for audio_samples, expression_samples in tqdm.tqdm(val_dataloader) :
'''
audio_samples : [emotion, intensity, audio_processed]
expression_samples : [emotion, intensity, expression_processed]
'''
emotion, intensity, expression_processed = expression_samples
audio_processed = audio_samples[2]
        emotion = emotion.unsqueeze(1).to(device)      # (BS, 1)
        intensity = intensity.unsqueeze(1).to(device)  # (BS, 1)
expression_processed = expression_processed.to(device)
audio_processed = audio_processed.to(device)
with torch.no_grad():
audio_embedding = DEE.encode_audio(audio_processed)
expression_embedding = DEE.encode_parameter(expression_processed)
DB_expression_transposed = expression_embedding.T # (512,BS)
DB_audio = audio_embedding # (BS,512)
DB_emosity = torch.cat((emotion, intensity), dim=1) # (BS,2)
# compute similarity btw audio, expression
sim_matrix = DB_audio @ DB_expression_transposed
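        # sim_matrix[i, j] is the dot-product similarity between audio i and expression j,
        # so each row ranks expressions for an audio query and each column ranks audio clips
        # for an expression query (assuming the encoders return comparable, e.g. L2-normalized,
        # embeddings, as in CLIP-style training).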
# compute accuracy
        expression_accuracy, audio_accuracy = retrival_accuracy(sim_matrix, DB_emosity)
print('expression retrieval accuracy :')
print(expression_accuracy)
print('audio retrieval accuracy : ')
print(audio_accuracy)
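
# `retrival_accuracy` is called above but not defined or imported in this file.
# A minimal sketch of one possible implementation (an assumption, not necessarily the
# project's actual helper), scoring top-1 cross-modal retrieval by matching the
# (emotion, intensity) labels of the query and the retrieved item.
def retrival_accuracy(sim_matrix, emosity):
    # sim_matrix : (BS, BS) audio-to-expression similarity matrix
    # emosity    : (BS, 2)  [emotion, intensity] labels shared by the paired samples
    best_expr = sim_matrix.argmax(dim=1)   # most similar expression for each audio query
    best_audio = sim_matrix.argmax(dim=0)  # most similar audio clip for each expression query
    expression_accuracy = (emosity[best_expr] == emosity).all(dim=1).float().mean().item()
    audio_accuracy = (emosity[best_audio] == emosity).all(dim=1).float().mean().item()
    return expression_accuracy, audio_accuracy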
def save_arguments_to_file(args, filename='arguments.json'):
with open(filename, 'w') as file:
json.dump(vars(args), file)
def visualize_embeddings(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Starting embedding visualization on {device}...")
print("Loading models...")
DEE = models.DEE(args)
DEE = DEE.to(device)
print("Loading validation dataset...")
val_dataset = datasets.AudioDataset(args, split = 'train')
val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=4)
checkpoint_path = f"{args.save_dir}/model_99.pt"
    checkpoint = torch.load(checkpoint_path, map_location=device)
DEE.load_state_dict(checkpoint)
DEE.eval()
audio_embeddings = []
emotions = []
for audio_samples in tqdm.tqdm(val_dataloader) :
'''
audio_samples : [emotion, intensity, audio_processed]
expression_samples : [emotion, intensity, expression_processed]
'''
emotion, intensity, audio_processed = audio_samples
        emotion = emotion.unsqueeze(1)
        intensity = intensity.unsqueeze(1)
audio_processed = audio_processed.to(device)
with torch.no_grad():
audio_embedding = DEE.encode_audio(audio_processed)
DB_audio = audio_embedding # (BS,512)
audio_embeddings.append(DB_audio)
emotions.append(emotion)
DB_emosity = torch.cat((emotion, intensity), dim=1) # (BS,2)
DB_audio = torch.cat(audio_embeddings, dim=0)
DB_audio = DB_audio.cpu().numpy()
emotion = torch.cat(emotions, dim=0)
emotion = emotion.numpy().reshape(DB_audio.shape[0],)
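    # Drop several emotion classes so only a few well-separated ones remain in the 2-D plot
    # (which ids correspond to which emotions depends on the dataset's label scheme).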
for i in [1,2,4,6, 7,8]:
DB_audio = np.delete(DB_audio, np.where(emotion == i), axis=0)
emotion = np.delete(emotion, np.where(emotion == i), axis=0)
print("DB_audio shape: ", DB_audio.shape)
print("emotion shape: ", emotion.shape)
X_tsne = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=40).fit_transform(DB_audio)
plt.figure(figsize=(8, 6))
print(np.unique(emotion))
for i in np.unique(emotion):
indices = np.where(emotion == i)
plt.scatter(X_tsne[indices, 0], X_tsne[indices, 1], label=f'Class {i}')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.title('t-SNE Visualization')
plt.legend()
plt.show()
if __name__ == '__main__':
parser = argparse.ArgumentParser('train', parents=[get_args_parser()])
args = parser.parse_args()
    if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
# train(args)
# save_arguments_to_file(args)
# validation(args)
    visualize_embeddings(args)
print("DONE!")
from torch.utils import data

class AudioDataset(data.Dataset):
def __init__(self, args, split='train'):
# split
self.split = split
self.data_path_file = args.data_path_file
# data path
self.audio_dir = args.audio_feature_dir
self.expression_feature_dir = args.expression_feature_dir
# max sequence length
self.audio_feature_len = args.audio_feature_len
self.expression_feature_len = args.expression_feature_len
# list for features
self.inputs = []
#load data path file
with open(self.data_path_file) as f:
uid_dict = json.load(f)
uid_list = uid_dict[self.split]
for uid in tqdm.tqdm(uid_list):
actor_name = 'Actor_' + uid.split('-')[-1]
emotion = int(uid.split('-')[2])
intensity = int(uid.split('-')[3])
audio_path = self.audio_dir +'/'+ actor_name +'/'+ uid + '.npy'
# audio_samples,sample_rate = librosa.core.load(audio_path, sr=16000)
audio_samples = np.load(audio_path)
audio_samples_lib = torch.tensor(audio_samples, dtype=torch.float32)
audio_samples = torch.squeeze(processor(audio_samples_lib, sampling_rate=16000, return_tensors="pt").input_values)
# generate input samples by slicing
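            # Assuming 16 kHz audio (see sampling_rate above): skip 12800 samples (0.8 s) at
            # each end of the clip and advance the window by 1600 samples (0.1 s) per slice.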
for audio_start in range(12800, audio_samples.shape[0] - self.audio_feature_len -12800, 1600):
audio_samples_slice = audio_samples[audio_start:audio_start+self.audio_feature_len]
self.inputs.append([emotion, intensity, audio_samples_slice])
            # if the dataset is made for the first time, save the samples
# print(len(self.inputs))
# print(audio_samples_slice.shape)
# print(expression_samples_slice.shape)
def __getitem__(self, index):
return self.inputs[index]
def __len__(self):
return len(self.inputs)