这几天跑代码的时候,跑着跑着就显示被killed掉(整个人都不好了)。查系统日志发现是内存不够(out of memory),没……办……法……了,直接放弃!当然这是不可能的,笔者怎么可能是个轻言放弃的人呢,哈哈。文章来源:https://www.toymoban.com/news/detail-536177.html
言归正传,笔者用的设备是3T的硬盘,跑的程序batch_size=1024,共划分了2000个batch,每跑一个batch内存占用率就会升高,0.5%左右,无奈之下只能一句一句debug,最后后发现是损失累加造成的,如下放所示,代码共计算了三个损失:BPR Loss , Reg Loss , InfoNCE Loss,不能直接累加作为total_loss!而是通过.item()将损失值取出,再累加。文章来源地址https://www.toymoban.com/news/detail-536177.html
# BPR Loss
bpr_loss = -torch.sum(F.logsigmoid(sup_logits))
# Reg Loss
reg_loss = l2_loss(
self.lightgcn.user_embeddings(bat_users),
self.lightgcn.item_embeddings(bat_pos_items),
self.lightgcn.item_embeddings(bat_neg_items),
)
# InfoNCE Loss
clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
infonce_loss = torch.sum(clogits_user + clogits_item)
loss = bpr_loss + self.ssl_reg * infonce_loss + self.reg * reg_loss
total_loss = total_loss + loss.item()
total_bpr_loss += bpr_loss.item()
total_reg_loss += self.reg * reg_loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
到了这里,关于Pytorch堆叠多个损失造成内存爆炸的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!