728x90

torch.multinomial

multinomial 함수는 텐서안에 있는 값들을 통해 무작위로 샘플링을 수행합니다. 

텐서를 반환하며, 각 행은 텐서 입력의 해당 행에 위치한 다항 확률 분포(input_tensor)로부터 추출된 num_samples 개의 인덱스를 포함합니다.

 

입력의 행들은 합이 반드시 1이 될 필요는 없습니다.( 이 경우에는 값을 가중치로 사용) 그러나 음수가 아니어야 하며, 유한해야 하고, 합이 0이 아니어야 합니다.

 

함수 선언

import torch
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor

파라미터

  • input(Tensor) - 확률을 포함한 Tensor ( 꼭 합이 1일 필요가 없으며, 모든 요소가 0 이상의 정수이며 합이 0이상)
  • num_samples(int) - 각 행에서 뽑을 sample 의 갯수
  • replacement(bool, optional) - 복원 추출에 대한 여부입니다.(추출하고서 없앨 것인지 아닌지)

 

함수 사용하는 방법

import torch
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)

# 각 행에서 3개의 샘플을 추출
samples = torch.multinomial(probs, num_samples=3)

print(samples)

# Output
# tensor([[1, 2, 3]])

 * 확률 높은 것으로 뽑는 것이 아닌 확률을 통해서 인덱스를 랜덤하게 뽑습니다.

확률이 0인 4번째 index는 한번도 나오지 않는 것을 알 수 있습니다.

 

replcement = True

import torch
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)

samples = torch.multinomial(probs, num_samples=3,replacement=True)

print(samples)

# Output
# tensor([[2, 2, 1]])

뽑은 index가 또 뽑힐 수 있습니다.

 

함수가 오류날 경우

import torch
# tensor 안에 inf, nan, 0보다 이하 요소가 포함될 경우
# RuntimeError 발생
probs = torch.tensor([[-0.1, 0.2, 0.3, 0.4, 0.0]], dtype=torch.float)


samples = torch.multinomial(probs, num_samples=3)

print(samples)

# Output
# RuntimeError: probability tensor contains either `inf`, `nan` or element < 0

 

728x90