# 目的是求a、b之间的余弦相似度。一般是[bs,h]形状。以[2,4]为例。


a = torch.rand(2, 4)
b = torch.rand(2, 4)
print('初始输入\n', a, '\n', b)

ab = torch.matmul(a, b.t())
print('ab相乘\n', ab)

aa = a*a
print('a*a\n',aa)
aa = aa.sum(1, keepdim=True)
print('求和aa')
aa = aa ** 0.5
print('开方aa\n', aa)

bb = b * b
print('b*b\n',bb)
bb = bb.sum(1, keepdim=True)
print('求和bb')
bb = bb ** 0.5
print('开方bb\n', bb)

   
c = ab/aa/bb
print('自己算的cosi相似度\n', c)

d = torch.nn.functional.cosine_similarity(a.unsqueeze(1), b.unsqueeze(0),dim=2)    
print('固有算法求cos相似度\n', d)    

这段代码实际上没有实现。你会发现只有对角线的位置计算正确,而非对角线位置错误。

a:[2,4]  
按行分块 结果可写为:
A1
A2

b:[2,4]按行分块
B1
B2

a*b.t()相当于是
A1B1 A1B2
A2B1 A2B2

a*b.t()/aa/bb相当于是
A1B1/|A1||B1| A1B2/|A1||B1|
A2B1/|A2||B2| A2B2/|A2||B2|

|A1|表示向量2范数。就是求aa和bb的过程。那这时就能知道为什么错了。

那应该怎样才能实现cos这个功能呢?实际上就是怎么得到除法中的被除数。仍然是矩阵乘法。

e = ab / (aa.matmul(bb.t()))
print('改进后的余弦相似度计算\n', e)

代入验证。可以发现实现了余弦相似度计算功能。

写这个的目的就是,每次遇到了都要花上半天才能搞明白。

cosine_similarity的结果,每个位置是算[1,h]和[1,h]向量的余弦相似度。这个计算就是ab/|a||b|。而|a|=(a1*a1+a2*a2+...)**0.5。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐