🗓️ 2025-05-20 🗂️ 深度学习

Pytorch中的张量操作

数学符号

1. * 运算符

如果张量的形状完全一样,张量 * 张量,其实就是对应元素相乘,也可以$\odot$表示。

如果张量的形状不一样,则要求他们的形状只能在一个维度上不一样,且其中一个张量这个维度的size为1,这样可以用广播机制补全后相乘。

如果张量 * 数字,则是全体元素乘以这个数字。

函数、方法

1. permute()

可以作为张量对象的方法使用,接受置换后的维度。

import torch

table = torch.randn(2,3,4)
print(table)
print(table.permute(2,1,0))

这段代码就是将第一维和第三维进行置换。

由已知的张量还有置换后的顺序,怎么写出新的张量呢,这个问题一直让我困扰。

今天的学习,我脑子里突然蹦出一个邻居的概念。对于一个四维的张量元素,它应该有四个不同维度的邻居,但如果仅仅把视角局限在矩阵里,那它就只有两个邻居,一个行维度上的,一个列维度上的,这很不对。得把视野打开,通道维度和批次维度的邻居可能隔了很远,但它们也是邻居。

2. reshape()view()

深度学习中,经常会把高维张量降维,计算后再展开,要小心翼翼,错位会带来灾难性的错误。

对于一个(N,F,M)的张量,F是其特征维度。如果想把N和M合并,应该先把F变成最高维度或者最低维度,让需要合并的维度接壤

import torch

table = torch.randn(2,3,4)
print(table)
a = table.reshape(3,8)
b = table.permute(1,0,2).reshape(3,8)

print(a)
print(b)

这里先把张量形状变成了(3,2,4),再合并,可以保证张量维度不会错位,如果直接改写性质为(3,8),则会错位。

reshape()view()是一样的,前者可以作用于numpy数组。

3. flatten()

将高维张量展平为一维张量。

4. torch.cat()

把张量按给定的顺序和维度进行拼接,除了拼接维度外,张量必须具有相同的形状。

可用于特征维度的拓展,加入新的特征。

5. unsqueeze()

torch.unsqueeze

torch.unsqueeze(input, dim) → Tensor

为张量插入一个新的size为1的维度。用于升维。配合expand()食用更佳。

6. expand()

torch.Tensor.expand

这个只有张量方法,没有函数。

Tensor.expand(*sizes) → Tensor

将张量沿着某个size为1的维度进行复制拓展。

total_nbr_fea = torch.cat(
	[atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
	atom_nbr_fea, nbr_fea], dim=2)

atom_in_fea张量的形状本是(N,self.atom_fea_len),通过unsqueeze(1),变成(N,1,self.atom_fea_len)。

再通过expand(N, M, self.atom_fea_len),变成(N, M, self.atom_fea_len)形状。

7. chunk()

torch.chunk

torch.chunk(input: Tensor , chunks: int , dim: int = 0) → Tuple[Tensor, …]

尝试把张量分成指定数量的块。也就是分割张量。可指定维度分割。但,数量可能会比指定的少,一般最好成倍数时使用。

可用于分割特征维度。

8. torch.sum()

torch.sum

torch.sum(input, dim, keepdim=False, ***, dtype=None) → Tensor

将张量元素沿指定维度进行加和。

切片

1. 简单切片

import torch

a = torch.tensor(range(12)).view(3,4)
b = a[1,2]
print(a)
print(b)

在张量切片中,逗号表示维度间隔。在这个简单切片中,整数代表索引。

也可以接受start:stop:step格式的索引进行切片,一个简单:代表全选这一维度。

import torch

a = torch.tensor(range(24)).view(2,3,4)
b = a[1,:]
print(a)
print(b)

这里是一个三维张量,但是只用了一个逗号,这是一种简写,其实是逗号后续的所有维度都全选。

2. 高级切片

张量的某个维度的索引也可以是张量。

如果索引是一维张量,据测试,其实是列表也行。就是多个整数索引组成的列表或者一维张量。根据这些整数确定这个维度怎么被取。

import torch
import numpy as np

atom_in_fea = torch.randn(4,5,6)

nbr_fea_idx = torch.tensor([0, 1, 4])

atom_nbr_fea1 = atom_in_fea[:,nbr_fea_idx,:]
atom_nbr_fea2 = atom_in_fea[:,[0,1,4],:]

print(atom_nbr_fea1==atom_nbr_fea2)

这里dim=1维度的切片索引就是一个一维张量,里面的整数是0,1,4。说明切片时,取这个维度的第1,第2,第5个数据。

如果索引是二维张量,在切片的同时其实进行了升维。

import torch
import numpy as np

atom_in_fea = torch.randn(4,5,6)

#shape(2,2)
nbr_fea_idx = torch.tensor([[0, 1],[1,2]])

atom_nbr_fea = atom_in_fea[:,nbr_fea_idx,:]

print(atom_nbr_fea.shape)
#torch.Size([4, 2, 2, 6])

这里dim=1维度的切片索引就是一个二维张量,它会按照内部每一行的整数对这个维度进行切片,并根据内部的行数形成新的一维。

听起来很复杂,这里给出一个具体的应用场景,在CGCNN的卷积操作中。

atom_in_fea的性质为(N,features),N是批次内的总原子数

nbr_fea_idx是一个二维张量(N,M),记录了N个原子的邻居原子的原子id。

atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]

从形状上来看,这样的切片结果就是把(N,features)中的N替换成(N,M),最终变成(N,M,features)

这样,就得到了,每个原子的M个邻居原子的原子特征。

这里nbr_fea_idx的记录很关键,CIFData给出的是每个晶体内的,collate_pool函数会把它更新成batch内的。模型接受的参数是经collate_pool函数集成后的。

Comment