PyTorch-Lightning:trining_step的自动优化
使用PyTorch-Lightning时,在trining_step定义损失,在没有定义损失,没有任何返回的情况下没有报错,在定义一个包含loss的多个元素字典返回时,也可以正常训练,那么到底lightning是怎么完成训练过程的。
总结:
在自动优化中,training_step必须返回一个tensor或者dict或者None(跳过),对于简单的使用,在training_step可以return一个tensor会作为Loss回传,也可以return一个字典,其中必须包括key"loss",字典中的"loss"会提取出来作为Loss回传,具体过程主要包含在lightning\pytorch\loop\sautomatic.py中的_ AutomaticOptimization()类。
class _ AutomaticOptimization()
实现自动优化(前向,梯度清零,后向,optimizer step)
在training_epoch_loop中会调用这个类的run函数。
def run
首先通过 _make_closure得到一个closure,详见def _make_closure,最后返回一个字典,如果我们在training_step只return了一个loss tensor则字典只有一个’loss’键值对,如果return了一个字典,则包含其他键值对。
可以看到调用了_ optimizer_step,_ optimizer_step经过层层调用,最后会调用torch默认的optimizer.zero_grad,而我们通过 make_closure得到的closure作为参数传入,具体而言是调用了closure类的_ call __()方法。
def run(self, optimizer: Optimizer, batch_idx: int, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
closure = self._make_closure(kwargs, optimizer, batch_idx)
if (
# when the strategy handles accumulation, we want to always call the optimizer step
not self.trainer.strategy.handles_gradient_accumulation and self.trainer.fit_loop._should_accumulate()
):
# For gradient accumulation
# -------------------
# calculate loss (train step + train step end)
# -------------------
# automatic_optimization=True: perform ddp sync only when performing optimizer_step
with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
closure()
# ------------------------------
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
else:
self._optimizer_step(batch_idx, closure)
result = closure.consume_result()
if result.loss is None:
return {}
return result.asdict()
def _make_closure
创建一个closure对象,来捕捉给定的参数并且运行’training_step’和可选的其他如backword和zero_grad函数
比较重要的是step_fn,在这里调用了_training_step,得到的是一个存储我们在定义模型时重写的training step的输出所构成ClosureResult数据类。具体见def _training_step
def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer, batch_idx: int) -> Closure:
step_fn = self._make_step_fn(kwargs)
backward_fn = self._make_backward_fn(optimizer)
zero_grad_fn = self._make_zero_grad_fn(batch_idx, optimizer)
return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)
def _training_step
通过hook函数实现真正的训练step,返回一个存储training step输出的ClosureResult数据类。
将我们在定义模型时定义的lightning.pytorch.core.LightningModule.training_step的输出作为参数传入存储容器class ClosureResult的from_training_step_output方法,见class Closure
class ClosureResult():
一个数据类,包含closure_loss,loss,extra
closure_loss: Optional[Tensor]
loss: Optional[Tensor] = field(init=False, default=None)
extra: Dict[str, Any] = field(default_factory=dict)
def from_training_step_output
一个类方法,如果我们在training_step定义的返回是一个字典,则我们会将key值为"loss"的value赋值给closure_loss,而其余的键值对赋值给extra字典,如果返回的既不是包含"loss"的字典也不是tensor,则会报错。当我们在training_step不设定返回,则自然为None,但是不会报错。
class Closure
闭包是将外部作用域中的变量绑定到对这些变量进行计算的函数变量,而不将它们明确地作为输入。这样做的好处是可以将闭包传递给对象,之后可以像函数一样调用它,但不需要传入任何参数。
在lightning,用于自动优化的Closure类将training_step和backward, zero_grad三个基本的闭包结合在一起。
这个Closure得到training循环中的结果之后传入torch.optim.Optimizer.step。
参数:文章来源:https://www.toymoban.com/news/detail-852831.html
- step_fn: 这里是一个存储了training step输出的ClosureResult数据类,见def _training_step
- backward_fn: 梯度回传函数
- zero_grad_fn: 梯度清零函数
按照顺序,会先检查得到loss,之后调用梯度清零函数,最后调用梯度回传函数文章来源地址https://www.toymoban.com/news/detail-852831.html
class Closure(AbstractClosure[ClosureResult]):
warning_cache = WarningCache()
def __init__(
self,
step_fn: Callable[[], ClosureResult],
backward_fn: Optional[Callable[[Tensor], None]] = None,
zero_grad_fn: Optional[Callable[[], None]] = None,
):
super().__init__()
self._step_fn = step_fn
self._backward_fn = backward_fn
self._zero_grad_fn = zero_grad_fn
@override
@torch.enable_grad()
def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
step_output = self._step_fn()
if step_output.closure_loss is None:
self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
if self._zero_grad_fn is not None:
self._zero_grad_fn()
if self._backward_fn is not None and step_output.closure_loss is not None:
self._backward_fn(step_output.closure_loss)
return step_output
@override
def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
self._result = self.closure(*args, **kwargs)
return self._result.loss
到了这里,关于PyTorch-Lightning:trining_step的自动优化的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!