[ad_1]
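split() splits a 1D or higher-dimensional tensor into one or more tensors, as shown below. *The 2nd argument is either an int, giving the size of every chunk (the last chunk may be smaller), or a tuple giving the size of each chunk in turn (the sizes must add up to the tensor's size along the split dimension). The split dimension is dim 0 by default and can be changed with the 3rd argument, dim: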
import torch

my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

# Each entry below is a 2nd argument for split(); the comment above it shows
# the tuple that both calls return. torch.split(my_tensor, arg) and
# my_tensor.split(arg) are the functional and method spellings of the same
# operation, and with no dim given they split along dim 0 (the rows).
#   * int   -> chunks of that many rows; the last chunk may be shorter.
#   * tuple -> one chunk per entry, each entry being that chunk's row count
#              (a 0 entry yields an empty tensor of shape (0, 4)).
_split_args = (
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    1,
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    2,
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)
    3,
    # (tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
    (0, 3),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
    (1, 2),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    (2, 1),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64))
    (3, 0),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    (1, 1, 1),
)

for _arg in _split_args:
    torch.split(my_tensor, _arg)
    my_tensor.split(_arg)

# Drop the loop helpers so only `torch` and `my_tensor` remain defined.
del _arg, _split_args
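vsplit() splits a 2D or higher-dimensional tensor vertically (along dim 0, the rows) into one or more tensors, as shown below. *The 2nd argument is either an int, giving the number of equal-sized sections, or a tuple of row indices at which to split. A tuple (i, j) returns tensor[:i], tensor[i:j] and tensor[j:], so an index smaller than the one before it produces an empty tensor: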
import torch

my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

# Each entry below is a 2nd argument for vsplit(); the comment above it shows
# the tuple that both torch.vsplit(my_tensor, arg) and my_tensor.vsplit(arg)
# return. The split is always across rows.
#   * int n     -> n equal groups of rows.
#   * tuple i,j -> the slices my_tensor[:i], my_tensor[i:j], my_tensor[j:].
#     Because these are plain slices, j < i gives an empty middle piece and a
#     last piece that starts back at row j (e.g. (1, 0) repeats row 0).
_vsplit_args = (
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)
    1,
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    3,
    # (tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
    (0, 0),
    # (tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
    (0, 1),
    # (tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    (0, 2),
    # (tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64))
    (0, 3),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
    (1, 0),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
    (1, 1),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]))
    (1, 2),
    # (tensor([[0, 1, 2, 3]]),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64))
    (1, 3),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
    (2, 0),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
    (2, 1),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[8, 9, 10, 11]]))
    (2, 2),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
    #  tensor([[8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64))
    (2, 3),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
    (3, 0),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
    (3, 1),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([[8, 9, 10, 11]]))
    (3, 2),
    # (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
    #  tensor([], size=(0, 4), dtype=torch.int64),
    #  tensor([], size=(0, 4), dtype=torch.int64))
    (3, 3),
)

for _idx in _vsplit_args:
    torch.vsplit(my_tensor, _idx)
    my_tensor.vsplit(_idx)

# Drop the loop helpers so only `torch` and `my_tensor` remain defined.
del _idx, _vsplit_args
[ad_2]
Source link
[elementor-template id="51130"]