码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • SoftTriple Loss


    目录

    19-ICCV-SoftTriple Loss:Deep Metric Learning Without Triplet Sampling

    SoftTriple Loss

    Multiple Centers

    Adaptive Number of Centers


    19-ICCV-SoftTriple Loss:Deep Metric Learning Without Triplet Sampling

    1)SoftMax loss is equivalent to a smoothed triplet loss where each class has a single center.

    现实中一个类不只有一个中心,例如鸟有很多姿势(从细粒度角度解释)。扩展SoftMax loss,每个类有多中心。

    2)learn the embeddings without the sampling phase by mildly increasing the size of the last fully connected layer.不需要采样。
     

    SoftTriple Loss

    最小化有平滑项 λ的normalized SoftMax loss=最大化平滑的triplet loss

    这接下来都是证明推导了些啥??

     

    Multiple Centers

    每个类c有k个中心。

    对于样本xi选择相似度最大的中心。

     样本xi与所属类yi的距离比其他类j小。

     

    Inspired by the SoftMax loss, improve the robustness by smoothing the max operator.

    原本是直接选最大值。

    现在是对所有值加权求和,为保证和最大,原本较大的值对应的权值q一定也大。

     

     

     

    类中心越多,类内方差越小;中心数=样本数时,类内方差为0。

    Adaptive Number of Centers

    一个中心到其他中心的距离

     K个中心间的L2距离求和

    共N个样本,最小化中心间的距离,为0时即合并。

     

    1. class SoftTriple(nn.Module):
    2. def __init__(self, la, gamma, tau, margin, dim, cN, K):
    3. super(SoftTriple, self).__init__()
    4. self.la = la
    5. self.gamma = 1./gamma
    6. self.tau = tau
    7. self.margin = margin
    8. self.cN = cN # 有cN个类
    9. self.K = K # 每个类K个中心
    10. self.fc = Parameter(torch.Tensor(dim, cN*K))
    11. self.weight = torch.zeros(cN*K, cN*K, dtype=torch.bool).cuda()
    12. for i in range(0, cN):
    13. for j in range(0, K):
    14. self.weight[i*K+j, i*K+j+1:(i+1)*K] = 1
    15. init.kaiming_uniform_(self.fc, a=math.sqrt(5))
    16. return
    17. def forward(self, input, target):
    18. centers = F.normalize(self.fc, p=2, dim=0)
    19. simInd = input.matmul(centers)
    20. simStruc = simInd.reshape(-1, self.cN, self.K)
    21. prob = F.softmax(simStruc*self.gamma, dim=2)
    22. simClass = torch.sum(prob*simStruc, dim=2)
    23. marginM = torch.zeros(simClass.shape).cuda()
    24. marginM[torch.arange(0, marginM.shape[0]), target] = self.margin
    25. lossClassify = F.cross_entropy(self.la*(simClass-marginM), target)
    26. if self.tau > 0 and self.K > 1:
    27. simCenter = centers.t().matmul(centers)
    28. reg = torch.sum(torch.sqrt(2.0+1e-5-2.*simCenter[self.weight]))/(self.cN*self.K*(self.K-1.))
    29. return lossClassify+self.tau*reg
    30. else:
    31. return lossClassify

     代码里Sij也减去marginM了?

    代码:GitHub - idstcv/SoftTriple: PyTorch Implementation for SoftTriple Loss

  • 相关阅读:
    [2023-09-12]Oracle备库查询报ORA-01187
    C++ while 循环
    关于企业中台规划和 IT 架构微服务转型
    【深入浅出Docker原理及实战】「原理实战体系」零基础+全方位带你学习探索Docker容器开发实战指南(底层实现系列)
    Flink watermark与乱序消息处理机制
    SpringBoot第49讲:SpringBoot定时任务 - 基础quartz实现方式
    Spring入门
    爬虫都可以干什么?
    家用洗地机什么牌子最好用?2023年洗地机排行榜
    通讯网关软件014——利用CommGate X2HTTP实现HTTP访问OPC Server
  • 原文地址:https://blog.csdn.net/weixin_44742887/article/details/125521006
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号