Skip to content

Commit

Permalink
Get rid of generics and work around NiklasRosenstein/python-databind#66
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed May 5, 2024
1 parent 1cc9423 commit 9b4b8dd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion cleanba/convlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def make_conv(self, **kwargs):


@dataclasses.dataclass(frozen=True)
class BaseLSTMConfig(PolicySpec["LSTMState"]):
class BaseLSTMConfig(PolicySpec):
n_recurrent: int = 1 # D in the paper
repeats_per_step: int = 1 # N in the paper
pool_and_inject: bool = True
Expand Down
12 changes: 6 additions & 6 deletions cleanba/network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import dataclasses
from typing import Any, Generic, Literal, SupportsFloat, TypeVar
from typing import Any, Literal, SupportsFloat

import flax.linen as nn
import gymnasium as gym
Expand Down Expand Up @@ -37,11 +37,11 @@ def __call__(self, x: jax.Array) -> jax.Array:
return x


PolicyCarryT = TypeVar("PolicyCarryT")
PolicyCarryT = Any


@dataclasses.dataclass(frozen=True)
class PolicySpec(abc.ABC, Generic[PolicyCarryT]):
class PolicySpec(abc.ABC):
yang_init: bool = False
norm: NormConfig = IdentityNorm()
normalize_input: bool = False
Expand Down Expand Up @@ -181,7 +181,7 @@ def initialize_carry(self, rng, input_shape):


@dataclasses.dataclass(frozen=True)
class AtariCNNSpec(PolicySpec[tuple[()]]):
class AtariCNNSpec(PolicySpec):
channels: tuple[int, ...] = (16, 32, 32) # the channels of the CNN
strides: tuple[int, ...] = (2, 2, 2)
mlp_hiddens: tuple[int, ...] = (256,) # the hiddens size of the MLP
Expand Down Expand Up @@ -341,7 +341,7 @@ def __call__(self, x):


@dataclasses.dataclass(frozen=True)
class SokobanResNetConfig(PolicySpec[tuple[()]]):
class SokobanResNetConfig(PolicySpec):
channels: tuple[int, ...] = (64, 64, 64) * 3
kernel_sizes: tuple[int, ...] = (4, 4, 4) * 3

Expand Down Expand Up @@ -481,7 +481,7 @@ def label_and_learning_rate_for_params(


@dataclasses.dataclass(frozen=True)
class GuezResNetConfig(PolicySpec[tuple[()]]):
class GuezResNetConfig(PolicySpec):
channels: tuple[int, ...] = (32, 32, 64, 64, 64, 64, 64, 64, 64)
strides: tuple[int, ...] = (1,) * 9
kernel_sizes: tuple[int, ...] = (4,) * 9
Expand Down

0 comments on commit 9b4b8dd

Please sign in to comment.