Source code for compiler_gym.wrappers.time_limit

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

from compiler_gym.envs import CompilerEnv
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.wrappers.core import CompilerEnvWrapper

class TimeLimit(CompilerEnvWrapper):
    """A step-limited wrapper that is compatible with CompilerGym.

    The wrapper counts calls to :meth:`step` since the last :meth:`reset` and
    reports ``done = True`` once the count reaches ``max_episode_steps``. If no
    limit is available (``max_episode_steps`` resolves to ``None``), episodes
    are never truncated by this wrapper.

    Example usage:

        >>> env = TimeLimit(env, max_episode_steps=3)
        >>> env.reset()
        >>> _, _, done, _ = env.step(0)
        >>> _, _, done, _ = env.step(0)
        >>> _, _, done, _ = env.step(0)
        >>> done
        True
    """

    def __init__(self, env: CompilerEnv, max_episode_steps: Optional[int] = None):
        """Wrap an environment with a step limit.

        :param env: The environment to wrap.
        :param max_episode_steps: The maximum number of steps per episode. If
            ``None`` and the environment has a spec, the limit is read from
            ``env.spec.max_episode_steps``. If the environment has a spec, the
            resolved limit is also written back to it.
        """
        super().__init__(env=env)
        if max_episode_steps is None and self.env.spec is not None:
            max_episode_steps = env.spec.max_episode_steps
        if self.env.spec is not None:
            self.env.spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        # None means reset() has not been called yet; step() checks for this.
        self._elapsed_steps = None

    def step(self, action: ActionType, **kwargs):
        """Take a step in the wrapped environment and apply the step limit.

        :param action: The action forwarded to the wrapped environment.
        :param kwargs: Extra keyword arguments forwarded to the wrapped
            environment's ``step()``.
        :return: The ``(observation, reward, done, info)`` tuple from the
            wrapped environment. When the step limit has been reached,
            ``done`` is forced to ``True`` and ``info["TimeLimit.truncated"]``
            records whether the episode was cut short by the limit (``True``)
            or had already ended on its own (``False``).
        :raises AssertionError: If called before :meth:`reset`.
        """
        assert (
            self._elapsed_steps is not None
        ), "Cannot call env.step() before calling reset()"
        observation, reward, done, info = self.env.step(action, **kwargs)
        self._elapsed_steps += 1
        # A limit of None means "no limit"; comparing an int against None
        # would otherwise raise TypeError on the first step.
        if (
            self._max_episode_steps is not None
            and self._elapsed_steps >= self._max_episode_steps
        ):
            info["TimeLimit.truncated"] = not done
            done = True
        return observation, reward, done, info

    def reset(self, **kwargs):
        """Reset the step counter and the wrapped environment.

        :param kwargs: Keyword arguments forwarded to the wrapped
            environment's ``reset()``.
        :return: Whatever the wrapped environment's ``reset()`` returns.
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)

    def fork(self) -> "TimeLimit":
        """Fork the wrapped environment.

        The time limit state of the forked environment is the same as the source
        state.

        :return: A new wrapper of the same type around a fork of the wrapped
            environment, with the same step limit and elapsed step count.
        """
        fkd = type(self)(env=self.env.fork(), max_episode_steps=self._max_episode_steps)
        fkd._elapsed_steps = self._elapsed_steps  # pylint: disable=protected-access
        return fkd