# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Dict, List
from compiler_gym.errors import ServiceError
from compiler_gym.service.proto import ObservationSpace
from compiler_gym.util.gym_type_hints import (
    ActionType,
    ObservationType,
    RewardType,
    StepType,
)
from compiler_gym.views.observation_space_spec import ObservationSpaceSpec
class ObservationView:
    """A view into the available observation spaces of a service.

    Example usage:

        >>> env = gym.make("llvm-v0")
        >>> env.reset()
        >>> env.observation.spaces.keys()
        ["Autophase", "Ir"]
        >>> env.observation.spaces["Autophase"].space
        Box(56,)
        >>> env.observation["Autophase"]
        [0, 1, ..., 2]
        >>> env.observation["Ir"]
        int main() {...}
    """

    def __init__(
        self,
        raw_step: Callable[
            [List[ActionType], List[ObservationType], List[RewardType]], StepType
        ],
        spaces: List[ObservationSpace],
    ):
        """Constructor.

        :param raw_step: The callback used to compute observations. It is
            called with the keyword arguments :code:`actions`,
            :code:`observation_spaces`, and :code:`reward_spaces`, and must
            return a 4-tuple whose first, third, and fourth elements are the
            list of observations, a "done" flag, and an info mapping.
        :param spaces: The observation spaces to register. Each space is
            passed to :meth:`ObservationSpaceSpec.from_proto` together with
            its position in this list.
        :raises ValueError: If :code:`spaces` is empty.
        """
        if not spaces:
            raise ValueError("No observation spaces")
        self.spaces: Dict[str, ObservationSpaceSpec] = {}
        self._raw_step = raw_step
        for i, s in enumerate(spaces):
            self._add_space(ObservationSpaceSpec.from_proto(i, s))

    def __getitem__(self, observation_space: str) -> ObservationType:
        """Request an observation from the given space.

        :param observation_space: The observation space to query.
        :return: An observation.
        :raises KeyError: If the requested observation space does not exist.
        :raises SessionNotFound: If :meth:`env.reset()
            <compiler_gym.envs.CompilerEnv.reset>` has not been called.
        :raises ServiceError: If the backend service fails to compute the
            observation, or reports that a terminal state has been reached.
        """
        # Look the spec up under a distinct local name rather than re-binding
        # the string parameter to a value of a different type.
        space: ObservationSpaceSpec = self.spaces[observation_space]
        # An empty action list means "compute observations only".
        observations, _, done, info = self._raw_step(
            actions=[], observation_spaces=[space], reward_spaces=[]
        )
        if done:
            # Computing an observation should never cause a terminal state since
            # no action has been applied.
            msg = f"Failed to compute observation '{space.id}'"
            if info.get("error_details"):
                msg += f": {info['error_details']}"
            raise ServiceError(msg)
        if len(observations) != 1:
            raise ServiceError(
                f"Expected 1 '{space.id}' observation "
                f"but the service returned {len(observations)}"
            )
        return observations[0]

    def _add_space(self, space: ObservationSpaceSpec) -> None:
        """Register a new space and expose it as a zero-argument method.

        After registering a space with ID ``FooBar``, its observation can be
        computed with either ``env.observation["FooBar"]`` or
        ``env.observation.FooBar()``.

        :param space: The space to register. A previously registered space
            with the same ID is replaced.
        """
        self.spaces[space.id] = space
        # Bind a callback that computes the given observation space. ``space``
        # is a parameter of this call, so each lambda captures its own ID (no
        # late-binding problem across repeated calls).
        # NOTE(review): an ID equal to an existing attribute name (e.g.
        # ``spaces``) would overwrite that attribute on this instance.
        setattr(self, space.id, lambda: self[space.id])

    def add_derived_space(
        self,
        id: str,
        base_id: str,
        **kwargs,
    ) -> None:
        """Internal API for adding a new observation space.

        :param id: The ID of the new derived space.
        :param base_id: The ID of an already-registered space to derive from.
        :param kwargs: Additional arguments forwarded to
            :meth:`ObservationSpaceSpec.make_derived_space`.
        :raises KeyError: If :code:`base_id` is not a registered space.
        """
        base_space = self.spaces[base_id]
        self._add_space(base_space.make_derived_space(id=id, **kwargs))

    def __repr__(self):
        return f"ObservationView[{', '.join(sorted(self.spaces.keys()))}]"