x.1 with torch.no_grad()简述及例子
torch.no_grad()
是PyTorch中的一个上下文管理器(context manager),用于指定在其内部的代码块中不进行梯度计算。当你不需要计算梯度时,可以使用该上下文管理器来提高代码的执行效率,尤其是在推断(inference)阶段和梯度裁剪(grad clip)阶段的时候。
在使用torch.autograd
进行自动求导时,PyTorch会默认跟踪并计算张量的梯度。然而,有时我们只关心前向传播的结果,而不需要计算梯度,这时就可以使用torch.no_grad()
来关闭自动求导功能。在torch.no_grad()
的上下文中执行的张量运算不会被跟踪,也不会产生梯度信息,从而提高计算效率并节省内存。
下面举例一个在关闭梯度跟踪torch.no_grad()
后仍然要更新梯度矩阵y.backward()
的错误例子:
import torch
# 创建两个张量
x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)
# 在计算阶段使用 torch.no_grad()
with torch.no_grad():
y = x * w
# 输出结果,不会计算梯度
print(y) # tensor([6.])
# 尝试对 y 进行反向传播(会报错)
y.backward() # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
在上面的例子中,我们通过将x和w张量的requires_grad
属性设置为True,表示我们希望计算它们的梯度。然而,在torch.no_grad()
的上下文中,对于y的计算不会被跟踪,也不会生成梯度信息。因此,在执行y.backward()
时会报错。
x.2 with torch.no_grad()在训练阶段使用
with torch.no_grad()
常见于eval()验证集和测试集中,但是有时候我们仍然会在train()训练集中看到,如下:
@d2l.add_to_class(d2l.Trainer) #@save
def prepare_batch(self, batch):
return batch
@d2l.add_to_class(d2l.Trainer) #@save
def fit_epoch(self):
self.model.train()
for batch in self.train_dataloader:
loss = self.model.training_step(self.prepare_batch(batch))
self.optim.zero_grad()
with torch.no_grad():
loss.backward()
if self.gradient_clip_val > 0: # To be discussed later
self.clip_gradients(self.gradient_clip_val, self.model)
self.optim.step()
self.train_batch_idx += 1
if self.val_dataloader is None:
return
self.model.eval()
for batch in self.val_dataloader:
with torch.no_grad():
self.model.validation_step(self.prepare_batch(batch))
self.val_batch_idx += 1
这是因为我们进行了梯度裁剪,在上述代码中,torch.no_grad()
的作用是在计算梯度之前执行梯度裁剪操作。loss.backward()
会计算损失的梯度,但在这个特定的上下文中,我们不希望梯度裁剪的操作被跟踪和计算梯度。因此,我们使用torch.no_grad()
将裁剪操作放在一个没有梯度跟踪的上下文中,以避免计算和存储与梯度裁剪无关的梯度信息。文章来源:https://www.toymoban.com/news/detail-476070.html
而梯度的记录和跟踪实际上已经在loss = self.model.training_step(self.prepare_batch(batch))
中完成了(类似output = model(input)
),而loss.backward()
只是计算梯度并更新了model的梯度矩阵。文章来源地址https://www.toymoban.com/news/detail-476070.html
到了这里,关于with torch.no_grad()解答的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!