# Source code for compiler_gym.views.observation

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List

from compiler_gym.errors import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
    ActionType,
    ObservationType,
    RewardType,
    StepType,
)
from compiler_gym.views.observation_space_spec import ObservationSpaceSpec


class ObservationView:
    """A view into the available observation spaces of a service.

    Example usage:

        >>> env = gym.make("llvm-v0")
        >>> env.reset()
        >>> env.observation.spaces.keys()
        ["Autophase", "Ir"]
        >>> env.observation.spaces["Autophase"].space
        Box(56,)
        >>> env.observation["Autophase"]
        [0, 1, ..., 2]
        >>> env.observation["Ir"]
        int main() {...}
    """

    def __init__(
        self,
        raw_step: Callable[
            [List[ActionType], List[ObservationType], List[RewardType]], StepType
        ],
        spaces: List[ObservationSpace],
    ):
        """Build the view from the service's observation space descriptions.

        :param raw_step: Callback used to compute observations. It is invoked
            with the keyword arguments ``actions``, ``observation_spaces`` and
            ``reward_spaces`` and must return a 4-tuple whose first, third and
            fourth elements are the observations, a "done" flag and an info
            dictionary.

        :param spaces: The observation space protos to expose. Each one is
            converted with :code:`ObservationSpaceSpec.from_proto()`, which
            is passed the proto's position in this list.

        :raises ValueError: If :code:`spaces` is empty.
        """
        if not spaces:
            raise ValueError("No observation spaces")
        # Maps observation space ID to its specification.
        self.spaces: Dict[str, ObservationSpaceSpec] = {}

        self._raw_step = raw_step

        for i, s in enumerate(spaces):
            self._add_space(ObservationSpaceSpec.from_proto(i, s))

    def __getitem__(self, observation_space: str) -> ObservationType:
        """Request an observation from the given space.

        :param observation_space: The observation space to query.

        :return: An observation.

        :raises KeyError: If the requested observation space does not exist.

        :raises SessionNotFound: If :meth:`env.reset()
            <compiler_gym.envs.CompilerEnv.reset>` has not been called.

        :raises ServiceError: If the backend service fails to compute the
            observation, or reports that a terminal state has been reached.
        """
        # Keep the looked-up spec under its own name instead of rebinding the
        # `str` parameter to a value of a different type.
        space: ObservationSpaceSpec = self.spaces[observation_space]

        # An empty action list means only the observation is requested.
        observations, _, done, info = self._raw_step(
            actions=[], observation_spaces=[space], reward_spaces=[]
        )

        if done:
            # Computing an observation should never cause a terminal state since
            # no action has been applied.
            msg = f"Failed to compute observation '{space.id}'"
            if info.get("error_details"):
                msg += f": {info['error_details']}"
            raise ServiceError(msg)

        # Exactly one space was requested, so exactly one result is expected.
        if len(observations) != 1:
            raise ServiceError(
                f"Expected 1 '{space.id}' observation "
                f"but the service returned {len(observations)}"
            )

        return observations[0]

    def _add_space(self, space: ObservationSpaceSpec) -> None:
        """Register a new space.

        The space is stored in :code:`self.spaces` under :code:`space.id`, and
        a zero-argument callable with the same name is set on this instance.

        :param space: The observation space specification to register.
        """
        self.spaces[space.id] = space
        # Bind a new method to this class that is a callback to compute the
        # given observation space. E.g. if a new space is added with ID
        # `FooBar`, this observation can be computed using
        # env.observation.FooBar().
        setattr(self, space.id, lambda: self[space.id])

    def add_derived_space(
        self,
        id: str,
        base_id: str,
        **kwargs,
    ) -> None:
        """Internal API for adding a new observation space.

        :param id: The ID of the new observation space.

        :param base_id: The ID of the already-registered space to derive from.

        :param kwargs: Extra arguments forwarded unchanged to the base space's
            :code:`make_derived_space()`.

        :raises KeyError: If :code:`base_id` is not a registered space.
        """
        base_space = self.spaces[base_id]
        self._add_space(base_space.make_derived_space(id=id, **kwargs))

    def __repr__(self) -> str:
        """Return the class name followed by the sorted space IDs."""
        return f"ObservationView[{', '.join(sorted(self.spaces.keys()))}]"