728x90
안녕하세요. 오늘은 인공지능과 딥러닝에서 매우 많이 사용하는 PyTorch의 'Dataset'에 대해 자세하게 설명해보려고 합니다. 인공지능과 딥러닝에서 모델을 학습시키기 위해 Data를 처리하는 코드는 복잡하고 유지보수가 어려울 수 있습니다. 이를 위해 PyTorch에서는 더 나은 가독성과 모듈성을 위해 데이터셋 코드를 분리합니다.
이 글에서 PyTorch의 'Dataset'은 클래스를 직접 구현해 보며 이해해보려고 합니다.
Dataset이란?
PyTorch에서 Dataset은 데이터와 레이블을 저장하고 있으며, 데이터에 쉽게 접근할 수 있도록 도와주는 추상 클래스입니다. 이 클래스를 사용하여 다양한 데이터 소스(예: 파일, 데이터베이스, 메모리 등)에서 데이터를 불러오고, 필요한 전처리 작업을 수행할 수 있습니다.
요약하자면, Dataset은 샘플과 정답들을 저장하는 클래스라고 생각하시면 됩니다.
기본적으로 Dataset 클래스는 다음 세 가지 메서드를 구현해야 합니다.
- __init__: 필요한 변수들을 선언합니다. init 함수는 Dataset 객체가 생성될 때 한 번만 실행됩니다.
- __len__: 데이터셋의 총 데이터 개수를 반환합니다
- __getitem__: 주어진 인덱스에 해당하는 샘플을 데이터셋에서 불러와 반환합니다.
구현 예제
import torch
from torch.utils.data import Dataset
# 사용자 정의 데이터셋 클래스
class CustomDataset(Dataset):
# 생성자, 데이터를 전처리하는 부분
def __init__(self, length=100):
self.data = torch.randn(length, 10) # 임의의 데이터 생성 (예: 100x10 텐서)
self.labels = torch.randint(0, 2, (length,)) # 임의의 레이블 생성 (0 또는 1)
# Dataset 클래스의 길이를 반환
def __len__(self):
return len(self.data)
# 데이터셋의 데이터중 idx위치의 Data를 반환하는 코드
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 데이터셋 생성
dataset = CustomDataset()
실행결과
>>>len(dataset) # 데이터셋의 길이는 100
100
>>>dataset[0] # 데이터셋의 첫번째 데이터의 값(샘플, label)
(tensor([ 0.5283, -0.5272, -1.0905, 0.4210, 0.2976, -0.2760, -0.8738, 1.0800,
-1.9537, -0.2197]),
tensor(0))
마치며
PyTorch의 Dataset 클래스를 사용하면 다양한 데이터 소스에서 데이터를 효율적으로 불러오고, 전처리하는 과정을 간소화할 수 있습니다. 사용자 정의 Dataset을 만들 때는 __len__과 __getitem__ 메서드를 구현해야 한다는 점을 기억해주세요.
'AI & DL > Pytorch' 카테고리의 다른 글
[PyTorch] DataLoader 기초 및 구현 (2) | 2024.03.17 |
---|---|
[PyTorch] RuntimeError: The NVIDIA driver on your system is too old. 오류 해결법 (0) | 2023.12.31 |
[PyTorch] torch.ne, torch.eq 에 대해 알아보자 (0) | 2023.09.13 |
[Pytorch] torch.multinomial, multinomial (0) | 2023.09.07 |
[Pytorch] numpy to tensor, tensor to numpy, 넘파이와 텐서 변환 (0) | 2023.09.06 |