Saving and Loading Models - PyTorch Tutorials 1.10.1+cu102 documentation

transfer_learning_ipynb의_사본.ipynb

Saving and loading with torch.save() / torch.load()


# Both options below read and write the same file inside MODEL_PATH.
checkpoint_path = os.path.join(MODEL_PATH, "model.pt")

# Option 1: persist only the model's parameters (its state_dict) ...
torch.save(model.state_dict(), checkpoint_path)
# ... then restore them into a freshly constructed instance of the same class.
new_model = CustomModelClass()
restored_state = torch.load(checkpoint_path)
new_model.load_state_dict(restored_state)

# Option 2: serialize the whole model object instead of just its parameters.
# NOTE: this reuses checkpoint_path, so it overwrites the file written above.
torch.save(model, checkpoint_path)
# Loading returns the model object directly; no constructor call is needed here.
# NOTE(review): presumably CustomModelClass must be importable at load time — confirm.
new_model = torch.load(checkpoint_path)

Checkpoints

Transfer Learning

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build VGG16 with pretrained weights and move it onto `device`.
# The keyword must be spelled `pretrained`: a misspelled keyword is passed
# through to the model constructor and fails with TypeError.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; switch if this notebook is moved to a newer torchvision.
vgg = models.vgg16(pretrained=True).to(device)