
Commit

dropped ray in favour of custom solution
drblallo committed Dec 28, 2024
1 parent a5e74c5 commit 3cec76d
Showing 18 changed files with 142 additions and 888 deletions.
16 changes: 10 additions & 6 deletions python/learn.py
@@ -7,7 +7,8 @@
 
 from tensorboard.program import TensorBoard
 from ml.ppg.train import train
-from os import makedirs
+from os import makedirs, path
+import tempfile
 
 def hypersearch_params():
     for lr in [1e-3, 1e-4, 1e-5]:
@@ -52,12 +53,15 @@ def main():
     args = parser.parse_args()
     program = load_program_from_args(args, True)
 
+    tmp_dir = path.join(tempfile.gettempdir(), "ppg")
+    league_play_nets_dir = path.join(tmp_dir, "nets")
+
     if not args.no_tensorboard:
         tb = TensorBoard()
-        tb.configure(argv=[None, "--logdir", "/tmp/ppg/"])
+        tb.configure(argv=[None, "--logdir", tmp_dir])
         url = tb.launch()
     if args.league_play:
-        makedirs("/tmp/ppg/nets", exist_ok=True)
+        makedirs(league_play_nets_dir, exist_ok=True)
 
     if args.hypersearch:
         for num, params in enumerate(hypersearch_params()):
@@ -69,7 +73,7 @@ def main():
                 path_to_weights=args.load,
                 output=args.output,
                 model_save_frequency=args.model_save_frequency,
-                log_dir=f"/tmp/ppg/{num}_{hypers}/",
+                log_dir=path.join(tmp_dir, f"{num}_{hypers}"),
                 **params
             )
     else:
@@ -84,8 +88,8 @@ def main():
             model_save_frequency=args.model_save_frequency,
             entcoef=args.entropy_coeff,
             nstep=args.steps_per_env,
-            log_dir="/tmp/ppg",
-            league_play_dir="" if not args.league_play else "/tmp/ppg/nets"
+            log_dir=tmp_dir,
+            league_play_dir="" if not args.league_play else league_play_nets_dir
         )
 
     program.cleanup()
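
The learn.py changes above replace hard-coded `/tmp/ppg` paths with a directory derived from `tempfile.gettempdir()`. A minimal standalone sketch of that pattern (standard library only, names mirror the diff):

```python
# Sketch of the portable log-directory setup used in learn.py above.
# tempfile.gettempdir() resolves to /tmp on most Linux systems, but also
# yields a valid location on macOS and Windows, unlike a hard-coded "/tmp/ppg".
import tempfile
from os import makedirs, path

tmp_dir = path.join(tempfile.gettempdir(), "ppg")
league_play_nets_dir = path.join(tmp_dir, "nets")
makedirs(league_play_nets_dir, exist_ok=True)

print(tmp_dir)  # e.g. "/tmp/ppg" on Linux
```
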
4 changes: 2 additions & 2 deletions python/ml/ppg/distr_builder.py
@@ -25,8 +25,8 @@ def tensor_distr_builder(ac_space):
     """
     assert isinstance(ac_space, TensorType)
     eltype = ac_space.eltype
-    if eltype == Discrete(2):
-        return (ac_space.size, partial(_make_bernoulli, shape=ac_space.shape))
+    # if eltype == Discrete(2):
+    #     return (ac_space.size, partial(_make_bernoulli, shape=ac_space.shape))
     if isinstance(eltype, Discrete):
         return (
             eltype.n * ac_space.size,
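
With the `Discrete(2)` special case commented out above, binary action elements now fall through to the generic categorical branch. The two parameterisations encode the same distribution; a quick check with torch distributions (illustrative only, not the repo's own distribution classes):

```python
import torch

# Two-class logits for a single binary action.
logits = torch.tensor([0.3, 1.1])
cat = torch.distributions.Categorical(logits=logits)
bern = torch.distributions.Bernoulli(logits=logits[1] - logits[0])

# P(action == 1) agrees under both parameterisations; the categorical head
# just spends two logits per binary action instead of one.
assert torch.allclose(cat.probs[1], bern.probs)
```
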
11 changes: 10 additions & 1 deletion python/ml/ppg/envs.py
@@ -188,7 +188,10 @@ def get_action_mask(self):
         return self.last_valid_action_mask
 
     def total_score(self, player_id):
-        return self.score_fn(self.state.state, player_id).value
+        score = self.score_fn(self.state.state, player_id)
+        if isinstance(score, float):
+            return score
+        return score.value
 
     def step_score(self, player_id):
         return self.current_score[player_id] - self.last_score[player_id]
@@ -261,6 +264,9 @@ def is_done_underling(self):
     def pretty_print(self):
         self.state.pretty_print()
 
+    def print(self):
+        self.state.print()
+
     def is_done_for_everyone(self):
         return self.state.state.resume_index == -1 and len(self.players_final_turn) == 0

@@ -350,6 +356,9 @@ def first_for_all_players(self, game_id):
     def pretty_print(self, game_id):
         self.games[game_id].pretty_print()
 
+    def print(self, game_id):
+        self.games[game_id].print()
+
     def get_user_defined_log_functions(self):
         return self.games[0].user_defined_log_functions

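
`total_score` above now accepts score functions that return either a plain float or an object exposing a `.value` attribute. A minimal sketch of that duck-typed unwrapping, with a hypothetical wrapper type standing in for the repo's score object:

```python
from dataclasses import dataclass

@dataclass
class Score:  # hypothetical wrapper, stands in for the score object in the repo
    value: float

def unwrap_score(score):
    # Mirrors the logic added to total_score(): plain floats pass through,
    # wrapper objects are unwrapped via .value.
    if isinstance(score, float):
        return score
    return score.value

assert unwrap_score(3.0) == 3.0
assert unwrap_score(Score(3.0)) == 3.0
```
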
11 changes: 11 additions & 0 deletions python/ml/ppg/ppg.py
@@ -25,6 +25,17 @@ class PpoModel(th.nn.Module):
     def forward(self, ob, first, state_in, action_mask) -> "pd, vpred, aux, state_out":
         raise NotImplementedError
 
+    @tu.no_grad
+    def act_logp(self, ob, first, state_in, action_mask):
+        pd, vpred, _, state_out = self(
+            ob=tree_util.tree_map(lambda x: x[:, None], ob),
+            first=first[:, None],
+            state_in=state_in,
+            action_mask=action_mask,
+        )
+
+        return pd
+
     @tu.no_grad
     def act(self, ob, first, state_in, action_mask):
         pd, vpred, _, state_out = self(
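
The new `act_logp` method runs the same forward pass as `act` but returns the policy distribution `pd` instead of a sampled action, so callers can inspect per-action probabilities. A hedged usage sketch, assuming `pd` behaves like a `torch.distributions.Categorical` (the repo's actual distribution type may differ):

```python
import torch

def greedy_action(model, ob, first, state_in, action_mask):
    # act() samples from the policy; act_logp() exposes the distribution itself,
    # which also allows picking the most likely action deterministically.
    pd = model.act_logp(ob, first, state_in, action_mask)
    return torch.argmax(pd.probs, dim=-1)
```
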
2 changes: 1 addition & 1 deletion python/ml/ppg/ppo.py
@@ -221,7 +221,7 @@ def train_pi_and_vf(**arrays):
         initial_state=model.initial_state(venv.num),
         keep_buf=10000,
         keep_non_rolling=log_save_opts.get("log_new_eps", False),
-        past_stragey_model=deepcopy(model),
+        past_stragey_model=deepcopy(model) if path_to_league_play_dir != "" else None,
     )
 
     lsh = learn_state.get("lsh") or log_save_helper.LogSaveHelper(
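
The ppo.py change above avoids building a past-strategy copy of the model unless league play is actually configured. A minimal sketch of the guard (names illustrative, not the repo's API):

```python
from copy import deepcopy

def make_past_strategy_model(model, league_play_dir: str):
    # Only pay the memory cost of a full model copy when a league-play
    # directory is configured; otherwise no past-strategy model is kept.
    return deepcopy(model) if league_play_dir != "" else None
```
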
2 changes: 1 addition & 1 deletion python/ml/ppg/train.py
@@ -122,7 +122,7 @@ def train_fn(
         n_epoch_pi=n_epoch_pi,
         clip_param=clip_param,
         kl_penalty=kl_penalty,
-        log_save_opts={"save_mode": "last", "num_players": 2},
+        log_save_opts={"save_mode": "last", "num_players": venv.get_num_players()},
         nstep=nstep,
         entcoef=entcoef,
         callbacks=[ModelSaver(model, output, model_save_frequency)],
Empty file removed python/ml/raylib/__init__.py
Empty file.
202 changes: 0 additions & 202 deletions python/ml/raylib/action_mask.py

This file was deleted.


