问题产生的原因是使用nn.CrossEntropyLoss()来计算损失的时候,target的维度超过4
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 1, 256, 256))
criterion = nn.CrossEntropyLoss()
loss = criterion(logit, target)
如实target中的C不是1,则可以:
import torch
import torch.nn as nn
logit = torch.ones(size=(4, 32, 256, 256)) # b,c,h,w
target = torch.ones(size=(4, 2, 256, 256))
criterion = nn.CrossEntropyLoss()
losses = 0
for i in range(2):
loss = criterion(logit, target[:, i, ...].long())
losses += loss
可以看到代码里面有个.long(),如果不用的话则会报错:文章来源:https://www.toymoban.com/news/detail-508363.html
RuntimeError: expected scalar type Long but found Float文章来源地址https://www.toymoban.com/news/detail-508363.html
到了这里,关于only batches of spatial targets supported (3D tensors) but got targets of dimension的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!