torch.cdist高效计算大矩阵相似度
问题定义现有矩阵A∈RN×C,B∈RM×CA\in R^{N\times C}, B\in R^{M\times C}A∈RN×C,B∈RM×C,需要计算矩阵AAA和BBB的相似度(欧式距离)矩阵S∈RN×MS\in R^{N\times M}S∈RN×M,NNN和MMM很大。可以使用pytorch提供的torch.cdist方法,记得使用GPU计算。import torchN, M, C = 2
·
问题定义
现有矩阵A∈RN×C,B∈RM×CA\in R^{N\times C}, B\in R^{M\times C}A∈RN×C,B∈RM×C,需要计算矩阵AAA和BBB的相似度(欧式距离)矩阵S∈RN×MS\in R^{N\times M}S∈RN×M,NNN和MMM很大。可以使用pytorch提供的torch.cdist
方法,记得使用GPU计算。
import torch
N, M, C = 20000, 50000, 128
A = torch.rand((N, C)).cuda()
B = torch.rand((M, C)).cuda()
S = torch.cdist(A, B, p=2)
print(S.shape)

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐
所有评论(0)