Introduction: The previous installment showed how to build and train a model with PyTorch Lightning using only a training set. To ensure a model generalizes to unseen data, the dataset is usually split into a training set and a test set; the test set is kept independent of the training set and is used to measure generalization. This installment adds validation and test sets for that purpose, and also introduces checkpointing and an early-stopping strategy to obtain the best model weights.
Related link: https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html
Using the training, validation, and test sets
1. Add dependencies and load the training and test sets
Add the required imports and load the training and test splits of the MNIST dataset:
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# Load the data (the test split uses train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
2. Implement test_step and call test
Implement the test_step method when defining the LightningModule; then call the Trainer's test method from outside:
class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):  # testing; this method mirrors training_step
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)
# Initialize the Trainer
trainer = Trainer()
# Run the test loop
trainer.test(model, dataloaders=DataLoader(test_set))
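The snippets here pass DataLoader(test_set) with its defaults (batch_size=1, single process), which keeps the example short but is slow in practice. A minimal sketch of a more typical configuration; the batch size and worker count below are illustrative values, not part of the original tutorial:
test_loader = DataLoader(test_set, batch_size=64, num_workers=2)  # illustrative values, not from the original
trainer.test(model, dataloaders=test_loader)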
3. Implement and use the validation set
A portion of the training set is usually carved out as a validation set, using utilities from torch.utils.data:
# Use 20% of the training set as the validation set
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
# Split with data.random_split
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
As with the test set, implement the validation_step method in the LightningModule; then call fit from outside:
class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def test_step(self, batch, batch_idx):
        ...
# Wrap the training and validation sets with DataLoader from torch.utils.data
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)
# Pass valid_loader (the validation set) to the fit method
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)
Checkpoints
Checkpoints serve two purposes: they capture the model weights after each epoch so you can keep the best-performing ones, and they let you resume training from where it was interrupted or stopped. Unlike a plain PyTorch checkpoint, a Lightning checkpoint contains the model's entire internal state, so even in the most complex distributed training environments Lightning can save everything needed to restore the model (a sketch of keeping the best weights with a ModelCheckpoint callback follows the list below). A checkpoint includes the following state:
- 16-bit scaling factor (if training with 16-bit precision)
- Current epoch
- Global step
- LightningModule’s state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks (for stateful callbacks)
- State of datamodule (for stateful datamodules)
- The hyperparameters (init arguments) with which the model was created
- The hyperparameters (init arguments) with which the datamodule was created
- State of Loops
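With the default settings the Trainer keeps only the most recent checkpoint. To actually keep the best-performing weights mentioned above, you can add a ModelCheckpoint callback that monitors the logged val_loss. A minimal sketch, assuming validation_step logs "val_loss" as in this tutorial:
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep only the checkpoint with the lowest validation loss
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model, train_loader, valid_loader)
print(checkpoint_callback.best_model_path)  # path of the best checkpoint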
Saving and loading checkpoints
# Saving: default_root_dir sets where checkpoints are stored; if it is not set, they are saved automatically under the current working directory
trainer = Trainer(default_root_dir="some/path/")
# Loading
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval() # disable randomness, dropout, etc...
y_hat = model(x)
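As a concrete usage sketch with this tutorial's LitAutoEncoder (the checkpoint path is a placeholder, and the random tensor stands in for a real flattened MNIST image), forward returns the 3-dimensional embedding:
autoencoder = LitAutoEncoder.load_from_checkpoint("/path/to/checkpoint.ckpt")  # placeholder path
autoencoder.eval()
with torch.no_grad():
    embedding = autoencoder(torch.rand(1, 28 * 28))  # dummy input instead of a real MNIST image
print(embedding.shape)  # torch.Size([1, 3])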
A checkpoint can also be loaded with plain torch:
checkpoint = torch.load("/path/to/checkpoint.ckpt", map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}
The saved init arguments can also be restored or overridden when loading. For example, given a model created as LitModel(in_dim=32, out_dim=10):
# loads with the saved values in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# overrides to in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
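Restoring or overriding init arguments like this relies on the hyperparameters being stored in the checkpoint, which happens when the module calls self.save_hyperparameters() in __init__. A minimal sketch; LitModel and its single Linear layer are hypothetical, matching the example above:
import torch.nn as nn
from lightning.pytorch import LightningModule

class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.save_hyperparameters()  # stores in_dim/out_dim under "hyper_parameters" in the checkpoint
        self.layer = nn.Linear(in_dim, out_dim)  # hypothetical placeholder layer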
Lightning checkpoints are fully compatible with plain PyTorch:
checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]
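For this tutorial's LitAutoEncoder, the learnable parameters live under the checkpoint's "state_dict" key, with names prefixed by the submodule (e.g. "encoder.0.weight"). A minimal sketch of loading just the encoder weights into a plain nn.Module, assuming torch and torch.nn as nn are imported as in the full code below (CKPT_PATH is a placeholder):
checkpoint = torch.load(CKPT_PATH)
# Keep only the encoder entries and strip the "encoder." prefix
encoder_state = {k[len("encoder."):]: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
# Rebuild the encoder exactly as defined in LitAutoEncoder and load the weights
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
encoder.load_state_dict(encoder_state)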
To disable checkpointing:
trainer = Trainer(enable_checkpointing=False)
To resume training and restore the full training state:
model = LitModel()
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, etc.
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
Early stopping
EarlyStopping Callback
The steps to enable the early-stopping callback in Lightning are:
- Import the EarlyStopping callback.
- Log the metric you want to monitor using the log() method.
- Init the callback, and set monitor to the logged metric of your choice.
- Set the mode based on the metric being monitored.
- Pass the EarlyStopping callback to the Trainer's callbacks flag.
from lightning.pytorch.callbacks.early_stopping import EarlyStopping


class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)


model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)
# You can also customize the EarlyStopping policy
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
# EarlyStopping documentation: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
Note
- By default, the EarlyStopping callback runs at the end of every validation epoch. How often validation itself runs can be configured with Trainer arguments such as check_val_every_n_epoch and val_check_interval (see the sketch below).
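A minimal sketch of combining these Trainer arguments with the EarlyStopping callback from the snippet above; the values 5 and 0.25 are illustrative, not from the original tutorial:
# Run validation (and therefore EarlyStopping) only every 5 epochs
trainer = Trainer(check_val_every_n_epoch=5, callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
# Or run validation (and therefore EarlyStopping) every 25% of a training epoch
trainer = Trainer(val_check_interval=0.25, callbacks=[EarlyStopping(monitor="val_loss", mode="min")])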
Full code
# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L
# --------------------------------
# Step 1: Define the model
# --------------------------------
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):  # testing; this method mirrors training_step
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        # forward defines the prediction/inference behavior
        embedding = self.encoder(x)
        return embedding
# --------------------------------
# Step 2: Load the data + model
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
# Use 20% of the training set as the validation set
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
# Split with data.random_split
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)
autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: Train + validate + test
# --------------------------------
# Train + validate
trainer = L.Trainer(default_root_dir="some/path/")  # customize the checkpoint save path here
trainer.fit(autoencoder, train_loader, valid_loader)
# Test
trainer.test(autoencoder, dataloaders=DataLoader(test_set))
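To tie together everything in this tutorial (validation, checkpointing the best weights, and early stopping), Step 3 could be configured as follows. A minimal sketch; the callbacks, patience, and max_epochs values are illustrative additions, not part of the original script:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# Keep the best checkpoint by validation loss and stop early if it stops improving
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=3)
trainer = L.Trainer(
    default_root_dir="some/path/",
    max_epochs=20,  # illustrative value
    callbacks=[checkpoint_callback, early_stop_callback],
)
trainer.fit(autoencoder, train_loader, valid_loader)
# Evaluate with the best checkpoint found during training
trainer.test(ckpt_path="best", dataloaders=DataLoader(test_set))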