Some important classic operations in PyTorch, with brief explanations:
Shape transformation operations
reshape() / view()
import torch

x = torch.randn(2, 3, 4)
print(f"Original shape: {x.shape}")

# reshape can handle non-contiguous tensors
y = x.reshape(6, 4)
print(f"After reshape: {y.shape}")

# view requires the tensor to be contiguous in memory
z = x.view(2, 12)
print(f"After view: {z.shape}")
transpose() / permute()
# transpose swaps two dimensions
x = torch.randn(2, 3, 4)
y = x.transpose(0, 2)  # swap dims 0 and 2
print(f"After transpose: {y.shape}")  # torch.Size([4, 3, 2])

# permute rearranges all dimensions at once
z = x.permute(2, 0, 1)  # reorder dims to (4, 2, 3)
print(f"After permute: {z.shape}")
Concatenation and splitting operations
cat() / stack()
# cat concatenates along an existing dimension
x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)

# concatenate along dim 0
cat_dim0 = torch.cat([x1, x2], dim=0)  # (4, 3)
# concatenate along dim 1
cat_dim1 = torch.cat([x1, x2], dim=1)  # (2, 6)

# stack creates a new dimension and then concatenates
stacked = torch.stack([x1, x2], dim=0)  # (2, 2, 3)
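A handy way to remember the difference: stack is equivalent to unsqueezing each tensor along the new dimension and then calling cat. A quick check:

import torch

x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)
a = torch.stack([x1, x2], dim=0)
b = torch.cat([x1.unsqueeze(0), x2.unsqueeze(0)], dim=0)
print(torch.equal(a, b))  # True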
chunk() / split()
x = torch.randn(6, 4)

# chunk splits into a given number of pieces
chunks = torch.chunk(x, 3, dim=0)  # 3 chunks, each (2, 4)

# split splits by chunk size
splits = torch.split(x, 2, dim=0)  # each chunk has size 2 along dim 0
splits_uneven = torch.split(x, [1, 2, 3], dim=0)  # uneven split
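When the dimension is not evenly divisible, chunk makes the pieces as equal as possible, with the last one smaller (it may even return fewer pieces than requested). A minimal demonstration:

import torch

x = torch.randn(7, 4)
chunks = torch.chunk(x, 3, dim=0)
print([c.shape[0] for c in chunks])  # [3, 3, 1]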
Indexing and selection operations
gather() / scatter()
# gather collects elements according to an index tensor
x = torch.randn(3, 4)
indices = torch.tensor([[0, 1], [2, 3], [1, 0]])
gathered = torch.gather(x, 1, indices)  # (3, 2)

# scatter writes elements according to an index tensor
src = torch.randn(3, 2)
scattered = torch.zeros(3, 4).scatter_(1, indices, src)
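The rule for gather with dim=1 is out[i][j] = x[i][indices[i][j]]; a concrete example makes the semantics easier to see:

import torch

x = torch.tensor([[10, 20, 30],
                  [40, 50, 60]])
idx = torch.tensor([[2, 0],
                    [1, 1]])
print(torch.gather(x, 1, idx))  # tensor([[30, 10], [50, 50]])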
masked_select() / where()
x = torch.randn(3, 4)
mask = x > 0

# select the elements that satisfy the condition
selected = torch.masked_select(x, mask)

# conditional selection
y = torch.randn(3, 4)
result = torch.where(mask, x, y)  # take x where mask is True, otherwise y
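Note the difference in output shapes: masked_select always returns a flattened 1-D tensor whose length depends on the data, while where preserves the input shape. A quick check:

import torch

x = torch.randn(3, 4)
mask = x > 0
print(torch.masked_select(x, mask).shape)  # 1-D, length = number of True entries
print(torch.where(mask, x, torch.zeros_like(x)).shape)  # torch.Size([3, 4])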
Math operations
clamp() / clip()
x = torch.randn(3, 4)

# restrict values to a range
clamped = torch.clamp(x, min=-1, max=1)
# equivalent to
clipped = torch.clip(x, -1, 1)
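Both bounds are optional, so one-sided clamping works as well; clamping from below at zero produces the same values as ReLU. A small sketch:

import torch

x = torch.randn(3, 4)
lower_only = torch.clamp(x, min=0)             # only a lower bound
print(torch.equal(lower_only, torch.relu(x)))  # True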
norm() / normalize()
x = torch.randn(3, 4)

# compute norms
l2_norm = torch.norm(x, p=2, dim=1)  # L2 norm
l1_norm = torch.norm(x, p=1, dim=1)  # L1 norm

# normalization
normalized = torch.nn.functional.normalize(x, p=2, dim=1)
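F.normalize divides each row by its L2 norm (with a small eps for numerical stability), so the rows of the result have unit norm. A quick check:

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)
normalized = F.normalize(x, p=2, dim=1)
print(normalized.norm(p=2, dim=1))  # all values close to 1.0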
Statistical operations
mean() / sum() / std()
x = torch.randn(3, 4, 5)

# various statistics
mean_all = x.mean()  # global mean
mean_dim = x.mean(dim=1)  # mean along dim 1
sum_keepdim = x.sum(dim=1, keepdim=True)  # keep the reduced dimension

# min / max (each returns both values and indices)
max_val, max_idx = torch.max(x, dim=1)
min_val, min_idx = torch.min(x, dim=1)
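keepdim=True is what makes a reduction broadcast cleanly against the original tensor, for example when centering along a dimension:

import torch

x = torch.randn(3, 4, 5)
centered = x - x.mean(dim=1, keepdim=True)  # (3, 4, 5) - (3, 1, 5) broadcasts
print(centered.mean(dim=1).abs().max())     # close to 0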
Broadcasting and repetition operations
expand() / repeat()
x = torch.randn(1, 3)

# expand does not copy data; it only creates a new view
expanded = x.expand(4, 3)  # (4, 3)

# repeat actually copies the data
repeated = x.repeat(4, 2)  # (4, 6)
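expand only works on dimensions of size 1 and returns a view that shares memory with the original tensor, whereas repeat allocates new storage. A quick check:

import torch

x = torch.randn(1, 3)
expanded = x.expand(4, 3)
repeated = x.repeat(4, 1)
print(x.data_ptr() == expanded.data_ptr())  # True: shared storage
print(x.data_ptr() == repeated.data_ptr())  # False: a real copy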
tile() / repeat_interleave()
x = torch.tensor([1, 2, 3])

# tile works like numpy's tile; a 1-D tensor is treated as (1, 3),
# so tiling 2x along rows and 3x along columns gives shape (2, 9)
tiled = x.tile(2, 3)

# repeat_interleave repeats each element in place
interleaved = x.repeat_interleave(2)  # [1, 1, 2, 2, 3, 3]
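repeat_interleave also accepts a dim argument and a per-element repeat count, which is handy for ragged repetition:

import torch

x = torch.tensor([1, 2, 3])
print(x.repeat_interleave(torch.tensor([1, 2, 3])))  # tensor([1, 2, 2, 3, 3, 3])

m = torch.tensor([[1, 2], [3, 4]])
print(m.repeat_interleave(2, dim=0))  # rows repeated: [[1, 2], [1, 2], [3, 4], [3, 4]]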
Type conversion operations
to() / type()
x = torch.randn(3, 4)

# dtype conversion
x_int = x.to(torch.int32)
x_float = x.type(torch.float64)

# device conversion
x_cuda = x.to('cuda')  # move to the GPU (if available)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x_device = x.to(device)
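A single to() call can change device and dtype at the same time, avoiding an intermediate copy:

import torch

x = torch.randn(3, 4)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
y = x.to(device, dtype=torch.float16)  # move and cast in one step
print(y.device, y.dtype)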
These tensor operations come up constantly in deep learning, especially in data preprocessing, model construction, and post-processing of inference results. Knowing them well makes tensor-handling code noticeably cleaner to write and, in many cases, more efficient to run.