目錄
unbind拆分子張量
1. 沿著第n個維度拆分(即按“批次”拆分)
split分割張量
常用用法:
總結:
unbind拆分子張量
import torchquaternions = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
result = torch.unbind(quaternions, -1)
print(result)
1. 沿著第n個維度拆分(即按“批次”拆分)
假設你有一個形狀為 (batch_size, n)
的張量,你可以沿著第一個維度(即批次維度)拆分它。
split分割張量
返回一個元組,其中包含分割后的子張量。
常用用法:
-
按指定大小分割: 當
split_size_or_sections
為一個整數時,表示每個子張量的大小。import torch# 創建一個張量 tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])# 按照大小 3 分割 result = torch.split(tensor, 3)# 輸出分割后的結果 for i, part in enumerate(result):print(f"Part {i}: {part}")
這個例子中,張量被分割成了 3 個大小為 3 的子張量和一個大小為 1 的子張量。
-
按指定分割長度分割: 當
split_size_or_sections
是一個列表或元組時,表示沿指定維度分割的塊數。每個數值對應要分割的子張量的大小。
# 創建一個張量
tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])# 按照每個塊的大小為 2 和 3 分割
result = torch.split(tensor, [2, 3, 5])# 輸出分割后的結果
for i, part in enumerate(result):print(f"Part {i}: {part}")
指定維度進行分割: 可以通過 dim
參數指定沿哪個維度進行分割。
# 創建一個二維張量
tensor = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8]])# 沿著維度 1(列)分割,每個子張量包含 2 列
result = torch.split(tensor, 2, dim=1)# 輸出分割后的結果
for i, part in enumerate(result):print(f"Part {i}: {part}")
總結:
torch.split
是一個非常實用的工具,能夠根據指定的大小或者分割長度將張量分割成多個子張量。常用的應用場景包括:
-
將數據按批次(batch)分割。
-
在處理大張量時按一定的塊進行分割,方便并行計算或逐塊處理。