简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风

这篇具有很好参考价值的文章主要介绍了简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

论文链接:https://arxiv.org/pdf/2304.03977.pdf

代码:https://github.com/tsb0601/EMP-SSL

其他学习链接:突破自监督学习效率极限!马毅、LeCun联合发布EMP-SSL:无需花哨trick,30个epoch即可实现SOTA


主要思想

如图,一张图片裁剪成不同的 patch,对不同的 patch 做数据增强,分别输入 encoder,得到多个 embedding,对它们求均值,得到  作为这张图片的 embedding。最后,拉近每个 patch 的 embedding 和图片的 embedding()之间的余弦距离;再用 Total Coding Rate(TCR) 防止坍塌(即 encoder 对所有输入都输出相同的 embedding)

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风,学习,人工智能,日常学习,自监督学习,深度学习,ssl,对比学习,特征空间

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风,学习,人工智能,日常学习,自监督学习,深度学习,ssl,对比学习,特征空间

Total Coding Rate(TCR)

公式如下:

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风,学习,人工智能,日常学习,自监督学习,深度学习,ssl,对比学习,特征空间

其中,det 表示求矩阵的行列式,d 是 feature vector 的 dimension,b 是 batch size

查了查该公式的含义:expand all features of Z as large as possible,即尽可能拉远矩阵中特征之间的距离。

源自 PPT 第 24 页:

https://s3.amazonaws.com/sf-web-assets-prod/wp-content/uploads/2021/06/15175515/Deep_Networks_from_First_Principles.pdf

至于为什么最大化该公式的值就可以拉远矩阵中特征之间的距离,这背后的数学原理真难啃啊 /(ㄒoㄒ)/~~


核心代码解读

数据处理

https://github.com/tsb0601/EMP-SSL/blob/main/dataset/aug.py#L116C1-L138C27

class ContrastiveLearningViewGenerator(object):
    def __init__(self, num_patch = 4):
    
        self.num_patch = num_patch
      
    def __call__(self, x):
    
    
        normalize = transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
        
        aug_transform = transforms.Compose([
            transforms.RandomResizedCrop(32,scale=(0.25, 0.25), ratio=(1,1)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.2)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GBlur(p=0.1),
            transforms.RandomApply([Solarization()], p=0.1),
            transforms.ToTensor(),  
            normalize
        ])
        augmented_x = [aug_transform(x) for i in range(self.num_patch)]
     
        return augmented_x

由此看出返回的 数据 为:长度为 num_patches 个 tensor 的列表。其中,每个 tensor 的 shape 为 (B, C, H, W)。

主函数

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L148C9-L162C63

for step, (data, label) in tqdm(enumerate(dataloader)):
    net.zero_grad()
    opt.zero_grad()
        
    data = torch.cat(data, dim=0) 
    data = data.cuda()
    z_proj = net(data)
            
    z_list = z_proj.chunk(num_patches, dim=0)
    z_avg = chunk_avg(z_proj, num_patches)
            
    # Contractive Loss
    loss_contract, _ = contractive_loss(z_list, z_avg)
    loss_TCR = cal_TCR(z_proj, criterion, num_patches)

这里要稍微注意一下几个变量的 shape:

  • data 被 cat 完后:(num_patches * B,C,H,W)
  • z_proj:(num_patches * B,C)
  • z_list:(num_patches,B,C)
  • z_avg:(B,C)

其中,chunk_avg 就是对来自同一张图片的不同 patch 的 embedding 求均值():

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L67

def chunk_avg(x,n_chunks=2,normalize=False):
    x_list = x.chunk(n_chunks,dim=0)
    x = torch.stack(x_list,dim=0)
    if not normalize:
        return x.mean(0)
    else:
        return F.normalize(x.mean(0),dim=1)

loss

contractive_loss 就是计算每个 patch 的 embedding 和均值()的余弦距离:

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L76

class Similarity_Loss(nn.Module):
    def __init__(self, ):
        super().__init__()
        pass

    def forward(self, z_list, z_avg):
        z_sim = 0
        num_patch = len(z_list)
        z_list = torch.stack(list(z_list), dim=0)
        z_avg = z_list.mean(dim=0)
        
        z_sim = 0
        for i in range(num_patch):
            z_sim += F.cosine_similarity(z_list[i], z_avg, dim=1).mean()
            
        z_sim = z_sim/num_patch
        z_sim_out = z_sim.clone().detach()
                
        return -z_sim, z_sim_out

TCR loss:最大化矩阵之间特征的距离,即拉远负样本(不是来自同一个样本的 patches)之间的距离

https://github.com/tsb0601/EMP-SSL/blob/main/main.py#L96

def cal_TCR(z, criterion, num_patches):
    z_list = z.chunk(num_patches,dim=0)
    loss = 0
    for i in range(num_patches):
        loss += criterion(z_list[i])
    loss = loss/num_patches
    return loss

需要注意:函数输入的 z 是 z_proj,形状为(num_patches * B,C)。

所以,函数内部 z_list 的形状为(num_patches,B,C),即将数据分为了 num_patches 个组,每个组包含了来自不同图片里 patch 的 embedding。再分别对每个组求 TCR loss,最大化组内(不同图片的 patch)特征的距离。

所以,公式中的  指的是一组来自不同图片里 patch 的 embedding,形状为(B,C)。

每个组内求 TCR loss 的代码按照公式计算,如下: 

简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风,学习,人工智能,日常学习,自监督学习,深度学习,ssl,对比学习,特征空间

https://github.com/tsb0601/EMP-SSL/blob/main/loss.py#L76文章来源地址https://www.toymoban.com/news/detail-649898.html

class TotalCodingRate(nn.Module):
    def __init__(self, eps=0.01):
        super(TotalCodingRate, self).__init__()
        self.eps = eps
        
    def compute_discrimn_loss(self, W):
        """Discriminative Loss."""
        p, m = W.shape  #[d, B]
        I = torch.eye(p,device=W.device)
        scalar = p / (m * self.eps)
        logdet = torch.logdet(I + scalar * W.matmul(W.T))
        return logdet / 2.
    
    def forward(self,X):
        return - self.compute_discrimn_loss(X.T)

到了这里,关于简单谈谈 EMP-SSL:自监督对比学习的一种极简主义风的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【论文阅读】基于纤维束成像的新型微结构信息引导的监督对比学习,自动识别视网膜丘视觉通路

    Li, S., Zhang, W., Yao, S., He, J., Zhu, C., Gao, J., Xue, T., Xie, G., Chen, Y., Torio, E. F., Feng, Y., Bastos, D. C. A., Rathi, Y., Makris, N., Kikinis, R., Bi, W. L., Golby, A. J., O’Donnell, L. J., Zhang, F. (2024). Tractography-based automated identification of the retinogeniculate visual pathway with novel microstructure-informed supervised contrast

    2024年02月01日
    浏览(39)
  • 图像融合论文阅读:CS2Fusion: 通过估计特征补偿图谱实现自监督红外和可见光图像融合的对比学习

    @article{wang2024cs2fusion, title={CS2Fusion: Contrastive learning for Self-Supervised infrared and visible image fusion by estimating feature compensation map}, author={Wang, Xue and Guan, Zheng and Qian, Wenhua and Cao, Jinde and Liang, Shu and Yan, Jin}, journal={Information Fusion}, volume={102}, pages={102039}, year={2024}, publisher={Elsevier} } 论文级

    2024年01月22日
    浏览(60)
  • 自监督LIGHTLY SSL教程

    Lightly SSL 是一个用于自监督学习的计算机视觉框架。 github链接:GitHub - lightly-ai/lightly: A python library for self-supervised learning on images. Documentation:Documentation — lightly 1.4.20 documentation 以下内容主要来自Documentation,部分内容省略,部分专业名字不翻译,主要复现教程6。 下图显示了

    2024年02月05日
    浏览(34)
  • 一文谈谈文心一言对比ChatGPT4.0的差距

    对于想体验文心一言的朋友,可以进行申请尝试,快速入口 如果想体验ChatGPT的朋友,可以自行fq注册;但是由于现在限制注册并且不稳定,对于不会用梯子不想注册的朋友可以使用这个进行访问,快速入口 关于ChatGPT对我们的帮助,可以参考我往期博客 看到一篇国金证券的研

    2024年02月08日
    浏览(61)
  • 【137期】面试官问:RocketMQ 与 Kafka 对比,谈谈两者的差异?(1)

    先自我介绍一下,小编浙江大学毕业,去过华为、字节跳动等大厂,目前阿里P7 深知大多数程序员,想要提升技能,往往是自己摸索成长,但自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前! 因此收集整理了一份《2024年最新Java开发全套学习资料》,

    2024年04月27日
    浏览(34)
  • ML:机器学习中有监督学习算法的四种最基础模型的简介(基于概率的模型、线性模型、树模型-树类模型、神经网络模型)、【线性模型/非线性模型、树类模型/基于样本距离的模型】多种对比(假设/特点/决策形式等

    ML:机器学习中有监督学习算法的四种最基础模型的简介(基于概率的模型、线性模型、树模型-树类模型、神经网络模型)、【线性模型/非线性模型、树类模型/基于样本距离的模型】多种对比(假设/特点/决策形式等) 目录

    2024年02月09日
    浏览(64)
  • 【Nginx37】Nginx学习:SSL模块(一)简单配置与指令介绍

    又是一个重点模块,SSL 模块,其实就是我们常见的 HTTPS 所需要的配置模块。HTTPS 的重要性不用多说了吧,现在所有的 App、小程序 都强制要求是 HTTPS 的,即使是网站开发,百度也明确了对 HTTPS 的收录会更好。也就是说,HTTPS 已经成为了事实上的正式环境协议标准。 在 Ngin

    2024年02月06日
    浏览(34)
  • 简单谈谈Feign

    本文只是简单粗略的分析一下 feign 的源码与过程原理 Feign 是 Netflix 开发的声明式、模板化的 HTTP 客户端, Feign 可以 帮助我们更快捷、优雅地调用 HTTP API 。 Spring Cloud 对 Feign 进行了增强,整合了 Spring Cloud Ribbon 和 Spring Cloud Hystrix ,除了提供这两者的强大功能外,还提供了一

    2024年02月09日
    浏览(31)
  • 无/自监督去噪(1)——一个变迁:N2N→N2V→HQ-SSL

    知乎同名账号同步发表 1. 前沿 N2N,即 Noise2Noise: Learning Image Restoration without Clean Data ,2018 ICML的文章。 N2V,即 Noise2Void - Learning Denoising from Single Noisy Images ,2019 CVPR的文章。 这两个工作都是无监督去噪的重要开山之作,本文先对其进行简单总结,然后引出一个变体:HQ-SSL(2

    2024年01月20日
    浏览(38)
  • 无监督去噪的一个变迁(1)——N2N→N2V→HQ-SSL

    知乎同名账号同步发表 1. 前沿 N2N,即 Noise2Noise: Learning Image Restoration without Clean Data ,2018 ICML的文章。 N2V,即 Noise2Void - Learning Denoising from Single Noisy Images ,2019 CVPR的文章。 这两个工作都是无监督去噪的重要开山之作,本文先对其进行简单总结,然后引出一个变体:HQ-SSL(2

    2024年01月17日
    浏览(46)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包