Skip to content

Commit

Permalink
increased performance for large action spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
drblallo committed Dec 29, 2024
1 parent 3cec76d commit 17a8ce0
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion python/ml/ppg/distr_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,59 @@
import torch as th
import torch.distributions as dis
from gym3.types import Discrete, Real, TensorType
from torch.distributions.utils import probs_to_logits

class Categorical:
def __init__(self, probs_shape):
# NOTE: probs_shape is supposed to be
# the shape of probs that will be
# produced by policy network
if len(probs_shape) < 1:
raise ValueError("`probs_shape` must be at least 1.")
self.probs_dim = len(probs_shape)
self.probs_shape = probs_shape
self._num_events = probs_shape[-1]
self._batch_shape = probs_shape[:-1] if self.probs_dim > 1 else th.Size()
self._event_shape=th.Size()

def set_probs_(self, probs):
self.probs = probs
self.logits = probs_to_logits(self.probs)

def set_probs(self, probs):
self.probs = probs / probs.sum(-1, keepdim=True)
self.logits = probs_to_logits(self.probs)

def set_logits(self, logits):
self.probs = th.softmax(logits, -1)
self.logits = logits

def sample(self, sample_shape=th.Size()):
if not isinstance(sample_shape, th.Size):
sample_shape = th.Size(sample_shape)
probs_2d = self.probs.reshape(-1, self._num_events)
samples_2d = th.multinomial(probs_2d, sample_shape.numel(), True).T
return samples_2d.reshape(sample_shape + self._batch_shape + self._event_shape)

def log_prob(self, value):
value = value.long().unsqueeze(-1)
value, log_pmf = th.broadcast_tensors(value, self.logits)
value = value[..., :1]
return log_pmf.gather(-1, value).squeeze(-1)

def entropy(self):
min_real = th.finfo(self.logits.dtype).min
logits = th.clamp(self.logits, min=min_real)
p_log_p = logits * self.probs
return -p_log_p.sum(-1)



def _make_categorical(x, ncat, shape):
x = x.reshape((*x.shape[:-1], *shape, ncat))
return dis.Categorical(logits=x)
cat = Categorical(x.shape)
cat.set_logits(x)
return cat


def _make_normal(x, shape):
Expand Down

0 comments on commit 17a8ce0

Please sign in to comment.