torch_scatter.scatter()的使用方法
学习目标:
在学习PyG时,遇到了 scatter 这个函数,经过学习加上自身的理解,记录如下以备复习
学习内容:
- src:表示输入的tensor,接下来被处理;
- index:表示tensor对应的索引;
- dim:该值取0或者1(-1),默认是1;当
dim=0
时,表示从行
进行分割成元素;当dim=1
时,表示从列
进行分割成元素。 - reduce:表示对应的操作
具体操作如下:
例子1
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)
1.首先是dim=0
表示对输入的tensor进行行
分割:[1,2,3],[4,5,6],[7,8,9]。
2.索引index=[0,0,1]表示处理的顺序:第一行元素和第二行元素进行处理,再是第三行的元素进行进行。对第一行元素[1,2,3]和第二行元素[4,5,6]进行reduce='mean'
得到[2.5,3.5,4.5],对第三行元素[7,8,9]进行reduce='mean'得到[7,8,9]
.
例子2
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = torch.tensor([0, 0, 1], dtype=torch.int64)
out = scatter(src, index, dim=1, reduce='mean')
print(out)
1.首先dim=1
表示对输入的tensor进行列向
分割元素[1,4,7]、[2,5,8]和[3,6,9]。
2.索引index=[0,0,1]
表示将[1,4,7]和[2,5,8]首先进行reduce='mean'
操作得到[1.5,4.5,7.5];[3,6,9]进行reduce=mean
操作后仍为[3,6,9],接着将其进行列向
拼接。
例子3–维度问题
from torch_scatter import scatter
src = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([1, 1, 0,2], dtype=torch.int64)
out = scatter(src, index, dim=0, reduce='mean')
print(out)
1.dim=0
表示从行向
进行分割
[1,2,3]
[4,5,6]
[7,8,9]
[10,11,12]
2.索引index=[1,1,0,2]
,从索引可以看出顺序为[7,8,9]——[1,2,3]和[4,5,6]——[10,11,12],分别进行reduce='mean'
操作得到[7,8,9]——[2.5,3.5,4.5]——[10,11,12]三个tensor,然后进行行向
拼接。
KevinLi945: 你的“conda activate tensorflow”是指你启用了一个名叫tensorflow的环境,名字容易误导别人而且你自己在4都说错了
寻找buff的小白兔: 为啥显示:ModuleNotFoundError: No module named 'tensorflow' ?????????
yufy001: 全都叫tensorflow……环境、内核分不清……弄不成
CSDN-Ada助手: 非常感谢您分享torch_scatter.scatter()的使用方法,这对于深度学习爱好者来说非常有用。我们期待您写更多类似的技术博客,并且建议您可以尝试写一篇介绍PyTorch的分布式训练的博文,分享一下您的经验和技巧,这对于需要进行大规模训练的用户来说会非常有价值。期待您的下一篇博客! 2023年博客之星「城市赛道」年中评选已开启(https://activity.csdn.net/creatActivity?id=10470&utm_source=blog_comment_city ), 博主的原力值在所在城市已经名列前茅,持续创作就有机会成为所在城市的 TOP1 博主(https://bbs.csdn.net/forums/blogstar2023?typeId=3152981&utm_source=blog_comment_city),更有丰厚奖品等你来拿~。
菜鸟要爱学习: 那就重新试一下呗