前言
在使用Pytorch训练模型时,用到python中的item()函数,如:
train_loss += loss.item()
现对item()函数用法做出总结。item()函数的作用是从包含单个元素的张量中取出该元素值,并保持该元素的类型不变。,即:该元素为整形,则返回整形,该元素为浮点型,则返回浮点型。官网解释如下:
Pytorch官网:https://pytorch.org/docs/stable/tensors.html?highlight=item#torch.Tensor.item
实验
做个测试:文章来源:https://www.toymoban.com/news/detail-546205.html
import torch
x = torch.randn(2, 2)
print(x)
print(x[0,0])
print(x[0,0].item())
Output:文章来源地址https://www.toymoban.com/news/detail-546205.html
tensor([[-0.1405, 2.4767],
[-0.6847, 0.0057]])
tensor(-0.1405)
-0.14052967727184296
总结
- 计算loss或者accuracy时,经常使用item()函数,而不是直接取对应的元素x[i,j]。
- item()函数取值时,保持该元素的类型不变。
到了这里,关于Pytorch/Python中item()的用法的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!