# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, ClassVar, Optional, Union
from gym.spaces import Space
from compiler_gym.service.proto import Event, ObservationSpace, py_converters
from compiler_gym.util.gym_type_hints import ObservationType
from compiler_gym.util.shell_format import indent
class ObservationSpaceSpec:
    """Specification of an observation space.

    :ivar id: The name of the observation space.
    :vartype id: str

    :ivar index: The index into the list of observation spaces that the service
        supports.
    :vartype index: int

    :ivar space: The space.
    :vartype space: Space

    :ivar deterministic: Whether the observation space is deterministic.
    :vartype deterministic: bool

    :ivar platform_dependent: Whether the observation values depend on the
        execution environment of the service.
    :vartype platform_dependent: bool

    :ivar default_value: A default observation. This value will be returned by
        :func:`CompilerEnv.step() <compiler_gym.envs.CompilerEnv.step>` if
        :func:`CompilerEnv.observation_space <compiler_gym.envs.CompilerEnv.observation_space>`
        is set and the service terminates.

    :ivar translate: A callback that converts an observation value (or, per its
        type hint, an :code:`Event` message) into the value exposed by this
        space. For spaces created by :meth:`from_proto` this is the class's
        :code:`message_converter`; derived spaces chain their own callback
        after the base space's one.
    :vartype translate: Callable[[Union[ObservationType, Event]], ObservationType]

    :ivar to_string: A callback that renders an observation of this space as a
        string. Spaces created by :meth:`from_proto` use :code:`str`.
    :vartype to_string: Callable[[ObservationType], str]
    """

    # Converter shared by all instances. from_proto() uses it both to build the
    # space from an ObservationSpace message and to convert the message's
    # default observation, and installs it as the space's `translate` callback.
    message_converter: ClassVar[
        Callable[[Any], Any]
    ] = py_converters.make_message_default_converter()

    def __init__(
        self,
        id: str,
        index: int,
        space: Space,
        translate: Callable[[Union[ObservationType, Event]], ObservationType],
        to_string: Callable[[ObservationType], str],
        deterministic: bool,
        platform_dependent: bool,
        default_value: ObservationType,
    ):
        """Constructor. Don't call directly; use :meth:`from_proto` or
        :meth:`make_derived_space` instead.
        """
        self.id: str = id
        self.index: int = index
        self.space = space
        self.deterministic = deterministic
        self.platform_dependent = platform_dependent
        self.default_value = default_value
        self.translate = translate
        self.to_string = to_string

    def __hash__(self) -> int:
        # Quickly hash observation spaces by comparing the index into the list
        # of spaces returned by the environment. This means that you should not
        # hash between observation spaces from different environments as this
        # will cause collisions, e.g.
        #
        #     # not okay:
        #     >>> obs = set(env.observation.spaces).union(
        #         other_env.observation.spaces
        #     )
        #
        # If you want to hash between environments, consider using the string id
        # to identify the observation spaces.
        return self.index

    def __repr__(self) -> str:
        return f"ObservationSpaceSpec({self.id})"

    def __eq__(self, rhs) -> bool:
        """Equality check.

        A spec is equal to a :code:`str` that matches its :code:`id`, and to
        another :code:`ObservationSpaceSpec` whose :code:`id`, :code:`index`,
        :code:`space`, :code:`platform_dependent`, and :code:`deterministic`
        all match. The :code:`translate`, :code:`to_string`, and
        :code:`default_value` attributes are not compared. Any other operand
        compares unequal.

        Note that string equality is not consistent with :meth:`__hash__`
        (which returns the integer index), so do not mix specs and their string
        ids as keys of the same set or dict.
        """
        if isinstance(rhs, str):
            return self.id == rhs
        elif isinstance(rhs, ObservationSpaceSpec):
            return (
                self.id == rhs.id
                and self.index == rhs.index
                and self.space == rhs.space
                and self.platform_dependent == rhs.platform_dependent
                and self.deterministic == rhs.deterministic
            )
        return False

    @classmethod
    def from_proto(cls, index: int, proto: ObservationSpace) -> "ObservationSpaceSpec":
        """Create an observation space from a ObservationSpace protocol buffer.

        :param index: The index of this observation space into the list of
            observation spaces that the compiler service supports.

        :param proto: An ObservationSpace protocol buffer.

        :return: A new ObservationSpaceSpec whose :code:`translate` callback is
            the class's :code:`message_converter` and whose :code:`to_string`
            callback is :code:`str`.

        :raises ValueError: If protocol buffer is invalid.
        """
        # Use `cls` rather than naming the class explicitly so that a subclass
        # overriding `message_converter` is respected.
        converter = cls.message_converter
        try:
            spec = converter(proto)
        except ValueError as e:
            # Render the message explicitly as text so that formatting this
            # diagnostic cannot itself fail and mask the original error,
            # regardless of how indent() handles non-str input.
            raise ValueError(
                f"Error interpreting description of observation space '{proto.name}'.\n"
                f"Error: {e}\n"
                f"ObservationSpace message:\n"
                f"{indent(str(proto.space), n=2)}"
            ) from e
        # TODO(cummins): Additional validation of the observation space
        # specification would be useful here, such as making sure that the size
        # of {low, high} tensors for box shapes match. At present, these errors
        # tend not to show up until later, making it more difficult to debug.
        return cls(
            id=proto.name,
            index=index,
            space=spec,
            translate=converter,
            to_string=str,
            deterministic=proto.deterministic,
            platform_dependent=proto.platform_dependent,
            default_value=converter(proto.default_observation),
        )

    def make_derived_space(
        self,
        id: str,
        translate: Callable[[ObservationType], ObservationType],
        space: Optional[Space] = None,
        deterministic: Optional[bool] = None,
        default_value: Optional[ObservationType] = None,
        platform_dependent: Optional[bool] = None,
        to_string: Optional[Callable[[ObservationType], str]] = None,
    ) -> "ObservationSpaceSpec":
        """Create a derived observation space.

        The derived space keeps this space's :code:`index`, so it hashes
        identically to the base space (see :meth:`__hash__`). Its
        :code:`translate` callback first applies this space's
        :code:`translate` and then the supplied :code:`translate`.

        :param id: The name of the derived observation space.

        :param translate: A callback function to compute a derived observation
            from the base observation.

        :param space: The :code:`gym.Space` describing the observation space.
            If not provided, the base observation space's :code:`space` is
            used.

        :param deterministic: Whether the observation space is deterministic.
            If not provided, the value is inherited from the base observation
            space.

        :param default_value: The default value for the observation space. If
            not provided, the value is computed immediately by applying
            :code:`translate` to the default value of the base observation
            space.

        :param platform_dependent: Whether the derived observation space is
            platform-dependent. If not provided, the value is inherited from
            the base observation space.

        :param to_string: A callback to convert an observation to a string
            representation. If not provided, the callback is inherited from the
            base observation space.

        :return: A new ObservationSpaceSpec.
        """
        return ObservationSpaceSpec(
            id=id,
            index=self.index,
            # Compare against None rather than relying on truthiness: a Space
            # subclass that defines __len__ (such as gym's Tuple and Dict
            # container spaces) is falsy when empty and would otherwise be
            # silently replaced by the base space.
            space=self.space if space is None else space,
            # Looks up self.translate at call time, so the derived callback
            # always chains onto the base space's current translate.
            translate=lambda observation: translate(self.translate(observation)),
            to_string=self.to_string if to_string is None else to_string,
            default_value=(
                translate(self.default_value)
                if default_value is None
                else default_value
            ),
            deterministic=(
                self.deterministic if deterministic is None else deterministic
            ),
            platform_dependent=(
                self.platform_dependent
                if platform_dependent is None
                else platform_dependent
            ),
        )