콘텐츠로 이동

슬라이싱

슬라이싱은 tensor[start:end:step] 형식으로 구간을 선택합니다. end포함하지 않습니다(exclusive).

import torch
v = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
print(v[2:6]) # tensor([2, 3, 4, 5]) ← 인덱스 2~5
print(v[::2]) # tensor([0, 2, 4, 6, 8]) ← 2칸 간격
print(v[1::2]) # tensor([1, 3, 5, 7, 9]) ← 1부터 2칸 간격
print(v[::-1]) # tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) ← 역순
print(v[:5]) # tensor([0, 1, 2, 3, 4]) ← 처음부터 5개
print(v[5:]) # tensor([5, 6, 7, 8, 9]) ← 5번째부터 끝까지
print(v[:]) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) ← 전체
표현의미
v[a:b]인덱스 a 이상 b 미만
v[a:]인덱스 a부터 끝까지
v[:b]처음부터 인덱스 b 미만
v[:]전체
v[a:b:s]a 이상 b 미만, s 간격
v[::-1]역순 전체

각 차원의 슬라이스를 쉼표로 구분합니다.

import torch
m = torch.arange(1, 25).reshape(4, 6)
# tensor([[ 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12],
# [13, 14, 15, 16, 17, 18],
# [19, 20, 21, 22, 23, 24]])
# 처음 2행, 처음 3열
print(m[0:2, 0:3])
# tensor([[1, 2, 3],
# [7, 8, 9]])
# 마지막 2행 전체
print(m[-2:, :])
# tensor([[13, 14, 15, 16, 17, 18],
# [19, 20, 21, 22, 23, 24]])
# 짝수 행, 홀수 열
print(m[::2, 1::2])
# tensor([[ 2, 4, 6],
# [14, 16, 18]])

딥러닝에서 자주 만나는 (배치, 채널, 시퀀스) 형태의 예시입니다.

import torch
# shape: (배치=4, 채널=3, 시퀀스=10)
t = torch.randn(4, 3, 10)
# 처음 2개 배치만
print(t[0:2].shape) # torch.Size([2, 3, 10])
# 모든 배치의 첫 번째 채널만
print(t[:, 0:1, :].shape) # torch.Size([4, 1, 10])
# 모든 배치, 모든 채널, 처음 5개 시퀀스
print(t[:, :, :5].shape) # torch.Size([4, 3, 5])
# 마지막 배치, 두 번째 채널, 홀수 위치
print(t[-1, 1, ::2].shape) # torch.Size([5])

... (Ellipsis) 는 나머지 모든 차원:로 채우는 축약 표현입니다. 차원 수가 많거나 가변적일 때 유용합니다.

import torch
t = torch.randn(2, 3, 4, 5)
# 마지막 차원만 슬라이싱: 나머지 차원은 전체 선택
print(t[..., :3].shape) # torch.Size([2, 3, 4, 3])
# 위는 t[:, :, :, :3]과 동일
# 첫 번째 차원 고정, 나머지 전체
print(t[0, ...].shape) # torch.Size([3, 4, 5])
# 위는 t[0, :, :, :]와 동일
# 앞 차원 고정, 마지막 차원 슬라이싱
print(t[0, ..., :3].shape) # torch.Size([3, 4, 3])

슬라이싱은 데이터를 복사하지 않고 원본 storage를 공유하는 뷰 를 반환합니다.

import torch
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y = x[1:4] # 뷰 반환
print(y) # tensor([2., 3., 4.])
# y를 수정하면 x도 바뀜
y[0] = 99.0
print(x) # tensor([ 1., 99., 3., 4., 5.]) ← x도 변경됨!
print(y) # tensor([99., 3., 4.])
메모리: [ 1., 99., 3., 4., 5. ]
↑ ↑ ↑ ↑ ↑
x x[0] x[1] x[2] x[3] x[4]
y y[0] y[1] y[2]
(같은 메모리를 공유)
import torch
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# clone()으로 독립 복사
y = x[1:4].clone()
y[0] = 99.0
print(x) # tensor([1., 2., 3., 4., 5.]) ← x는 변경되지 않음
print(y) # tensor([99., 3., 4.])

슬라이싱은 정수 인덱싱과 달리 차원을 유지 합니다.

import torch
m = torch.randn(4, 6)
# 정수 인덱싱: 차원 감소
print(m[0].shape) # torch.Size([6]) ← 1D
# 슬라이싱: 차원 유지
print(m[0:1].shape) # torch.Size([1, 6]) ← 2D 유지
print(m[0:3].shape) # torch.Size([3, 6])
# 열 방향도 동일
print(m[:, 0].shape) # torch.Size([4]) ← 정수, 차원 감소
print(m[:, 0:1].shape) # torch.Size([4, 1]) ← 슬라이싱, 차원 유지

배치 처리 코드에서 차원을 유지해야 할 때 0:1 형태의 슬라이싱을 자주 사용합니다.