Skip to content

Commit

Permalink
Better reset, release 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
yannbouteiller committed Mar 24, 2023
1 parent 4bf5593 commit 969799c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 10 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ This happens either because the environment has been 'paused', or because the sy
- The inference duration of the model, i.e. the elapsed duration between two calls of the step() function, may be too long for the time-step duration that the user is trying to use.
- The procedure that retrieves observations may take too much time or may be called too late (the latter can be tweaked in the configuration dictionary). Remember that, if observation capture is too long, it must not be part of the `get_obs_rew_terminated_info()` method of your interface. Instead, this method must simply retrieve the latest available observation from another process, and the action buffer must be long enough to handle the observation capture duration. This is described in the Appendix of [Reinforcement Learning with Random Delays](https://arxiv.org/abs/2010.02966).


A call to `reset()` starts the elastic `rtgym` clock.
Once the clock is started, it can be stopped via a call to the `wait()` API to artificially "pause" the environment.

`reset()` captures an initial observation and sends the default action, because Real-Time MDPs require an action to be applied at all time.
`reset()` captures an initial observation and sends the default action, since Real-Time MDPs require an action to be applied at all time.

The following figure illustrates how `rtgym` behaves around `reset` transitions when:
- the configuration dictionary has `"wait_on_done": True`
Expand Down
14 changes: 8 additions & 6 deletions rtgym/envs/real_time_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,6 @@ def _retrieve_obs_rew_terminated_truncated_info(self):
self.__o_set_flag = False
c = False
self.__o_lock.release()
if self.act_in_obs:
elt = tuple((*elt, *tuple(self.act_buf),))
return elt, r, d, t, i

def init_action_buffer(self):
Expand Down Expand Up @@ -506,9 +504,11 @@ def reset(self, seed=None, options=None):
self.options = options
self.current_step = 0
if self.reset_act_buf:
# fill the action buffer with default actions:
self.init_action_buffer()
else:
self.act_buf.append(self.default_action)
# replace the last (non-applied) action from the previous episode by the action that is going to be applied:
self.act_buf[-1] = self.default_action
elt, info = self.interface.reset(seed=seed, options=options)
if self.act_in_obs:
elt = elt + list(self.act_buf)
Expand Down Expand Up @@ -541,15 +541,17 @@ def step(self, action):
self.bench.start_step_time()
self._join_thread()
self.current_step += 1
self.act_buf.append(action)
self.act_buf.append(action) # the action is always appended to the buffer
if not self.real_time:
self._run_time_step(action)
obs, rew, terminated, truncated, info = self._retrieve_obs_rew_terminated_truncated_info()
done = (terminated or truncated)
if self.real_time and not done:
if not done: # apply action only when not done
self._run_time_step(action)
if done and self.wait_on_done:
elif self.wait_on_done:
self.wait()
if self.act_in_obs:
obs = tuple((*obs, *tuple(self.act_buf),))
if self.benchmark:
self.bench.end_step_time()
return obs, rew, terminated, truncated, info
Expand Down
23 changes: 21 additions & 2 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def get_obs_rew_terminated_info(self):
obs = [np.array([time.time()], dtype=np.float64),
np.array(self.control, dtype=np.float64),
np.array([self.control_time], dtype=np.float64)]
return obs, 0.0, False, {}
terminated = (self.control >= 9).item()
return obs, 0.0, terminated, {}

def get_observation_space(self):
ob = gymnasium.spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float64)
Expand All @@ -47,6 +48,8 @@ def get_default_action(self):
config["time_step_duration"] = 0.1
config["start_obs_capture"] = 0.1
config["act_buf_len"] = 1
config["wait_on_done"] = False
config["reset_act_buf"] = False


class TestEnv(unittest.TestCase):
Expand Down Expand Up @@ -82,7 +85,7 @@ def test_timing(self):
for i in range(10):
obs1 = obs2
act = np.array([float(i + 1)])
obs2, _, _, _, _ = env.step(act)
obs2, _, terminated, _, _ = env.step(act)
now = time.time()
self.assertEqual(obs2[3], act)
self.assertEqual(obs2[1], act - 1.0)
Expand All @@ -94,6 +97,22 @@ def test_timing(self):
self.assertGreater(obs2[0] - obs1[0], 0.1 - epsilon)
self.assertGreater(0.1 + epsilon, obs2[0] - obs1[0])

# terminated signal:
if i >= 9:
self.assertTrue(terminated)

# test reset:
obs1, info = env.reset()

# default action (buffer):
self.assertEqual(obs1[3], -1)

act = np.array([float(22)])
obs1, _, _, _, _ = env.step(act)

# new action (buffer):
self.assertEqual(obs1[3], 22)


if __name__ == '__main__':
unittest.main()

0 comments on commit 969799c

Please sign in to comment.