728x90
torch.multinomial
multinomial 함수는 텐서안에 있는 값들을 통해 무작위로 샘플링을 수행합니다.
텐서를 반환하며, 각 행은 텐서 입력의 해당 행에 위치한 다항 확률 분포(input_tensor)로부터 추출된 num_samples 개의 인덱스를 포함합니다.
입력의 행들은 합이 반드시 1이 될 필요는 없습니다.( 이 경우에는 값을 가중치로 사용) 그러나 음수가 아니어야 하며, 유한해야 하고, 합이 0이 아니어야 합니다.
함수 선언
import torch
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
파라미터
- input(Tensor) - 확률을 포함한 Tensor ( 꼭 합이 1일 필요가 없으며, 모든 요소가 0 이상의 정수이며 합이 0이상)
- num_samples(int) - 각 행에서 뽑을 sample 의 갯수
- replacement(bool, optional) - 복원 추출에 대한 여부입니다.(추출하고서 없앨 것인지 아닌지)
함수 사용하는 방법
import torch
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)
# 각 행에서 3개의 샘플을 추출
samples = torch.multinomial(probs, num_samples=3)
print(samples)
# Output
# tensor([[1, 2, 3]])
* 확률 높은 것으로 뽑는 것이 아닌 확률을 통해서 인덱스를 랜덤하게 뽑습니다.
replcement = True
import torch
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)
samples = torch.multinomial(probs, num_samples=3,replacement=True)
print(samples)
# Output
# tensor([[2, 2, 1]])
뽑은 index가 또 뽑힐 수 있습니다.
함수가 오류날 경우
import torch
# tensor 안에 inf, nan, 0보다 이하 요소가 포함될 경우
# RuntimeError 발생
probs = torch.tensor([[-0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)
samples = torch.multinomial(probs, num_samples=3)
print(samples)
# Output
# RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
'AI & DL > Pytorch' 카테고리의 다른 글
[PyTorch] DataLoader 기초 및 구현 (2) | 2024.03.17 |
---|---|
[PyTorch] Dataset 기초 및 구현 (0) | 2024.03.16 |
[PyTorch] RuntimeError: The NVIDIA driver on your system is too old. 오류 해결법 (0) | 2023.12.31 |
[PyTorch] torch.ne, torch.eq 에 대해 알아보자 (0) | 2023.09.13 |
[Pytorch] numpy to tensor, tensor to numpy, 넘파이와 텐서 변환 (0) | 2023.09.06 |