Preface
PyTorch Lightning (pl) is a lightweight PyTorch library that my advisor recommended I learn. Its code is clean and concise, and using pl makes ML code easier to follow, which is quite friendly for a beginner like me.
PyTorch Lightning official documentation:
https://lightning.ai/docs/pytorch/stable/levels/core_skills.html
Multilayer perceptron in PyTorch Lightning
1. Import the libraries
Code:
import os
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from torch.utils import data
# work around the duplicate OpenMP runtime shipped by both Anaconda and PyTorch
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
2. Load the data
Code (you can change download to True to download the dataset directly):
def load_data_fashion_mnist(batch_size, resize=None):  # images are 28*28*1
    """Load the Fashion-MNIST dataset from a local directory."""
    trans = [transforms.ToTensor()]  # convert images to PyTorch tensors
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="D:/python_project/fashion-mnist-master/fashion-mnist-master/data/fashion",
        train=True,
        transform=trans,
        download=False
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="D:/python_project/fashion-mnist-master/fashion-mnist-master/data/fashion",
        train=False,
        transform=trans,
        download=False
    )
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=0),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=0))
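A quick way to check that the loaders work is to pull one batch and inspect its shape. The snippet below is a minimal sketch, not part of the original code; the names train_iter and test_iter are mine, and with the default transforms each batch should contain 1*28*28 grayscale images:
train_iter, test_iter = load_data_fashion_mnist(batch_size=256)
X, y = next(iter(train_iter))
print(X.shape)  # torch.Size([256, 1, 28, 28])
print(y.shape)  # torch.Size([256])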
3. Two-layer perceptron in pl
# two-layer perceptron
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28*28, 256), nn.ReLU(), nn.Linear(256, 10))

    def forward(self, x):
        return self.l1(x)


class Perceptron(pl.LightningModule):
    # the pl module wraps the nn module
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop
        x, y = batch
        x = x.view(x.size(0), -1)  # flatten 28*28 images into vectors
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)
        print("train_loss=", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self.encoder(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1)
        return optimizer
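# A hypothetical variant, not in the original post: instead of print(), Lightning's
# self.log() can record the training loss so it shows up in the progress bar and in
# the default logger. Only a sketch, assuming the same batch format as above.
class PerceptronWithLogging(Perceptron):
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)  # logged instead of printed
        return loss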
batch_size = 256
# train/test dataloaders
train_loader, test_loader = load_data_fashion_mnist(batch_size)
# model
model = Perceptron(Encoder())
# train the model
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader)
# test
trainer.test(dataloaders=test_loader)
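After training, the model can also be used directly for inference. The snippet below is a minimal sketch that is not from the original post, assuming everything runs on the CPU: it switches the model to eval mode, pushes one test batch through the network, and takes the argmax over the 10 class scores as the prediction.
model.eval()
with torch.no_grad():
    X, y = next(iter(test_loader))
    logits = model.encoder(X.view(X.size(0), -1))  # shape [batch, 10]
    preds = logits.argmax(dim=1)
    acc = (preds == y).float().mean()
    print("batch accuracy:", acc.item())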
Summary
For more of pl's features, see the official PyTorch Lightning documentation.