Source code for compiler_gym.spaces.sequence

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple

import numpy as np
from gym.spaces import Space

from compiler_gym.spaces.common import issubdtype
from compiler_gym.spaces.scalar import Scalar


class Sequence(Space):
    """A sequence of values.

    Each element of the sequence is of `dtype`. The length of the sequence is
    bounded by `size_range`.

    Example:

    ::

        >>> space = Sequence(name="text", size_range=(0, None), dtype=str)
        >>> space.contains("Hello, world!")
        True

    ::

        >>> space = Sequence(name="blob", size_range=(256, 256), dtype=bytes)
        >>> space.contains("Hello, world!")
        False

    :ivar name: The name of the space.

    :ivar size_range: A tuple indicating the `(lower, upper)` bounds for
        sequence lengths. An upper bound of `None` means no upper bound. All
        sequences must have a lower bound of length >= 0.

    :ivar dtype: The data type for each element in a sequence.

    :ivar opaque_data_format: An optional string describing an opaque data
        format, e.g. a data structure that is serialized to a string/binary
        array for transmission to the client. It is up to the client and
        service to agree on how to decode observations using this value. For
        example, an opaque_data_format of `string_json` could be used to
        indicate that the observation is a string-serialized JSON value.

    :ivar scalar_range: An optional :code:`Scalar` space describing the legal
        range of each element, or `None` if elements are unbounded.
    """

    def __init__(
        self,
        name: str,
        size_range: Tuple[int, Optional[int]] = (0, None),
        dtype=bytes,
        opaque_data_format: Optional[str] = None,
        scalar_range: Optional[Scalar] = None,
    ) -> None:
        """Constructor.

        :param name: The name of the space.

        :param size_range: A tuple indicating the `(lower, upper)` bounds for
            sequence lengths. An upper bound of `None` means no upper bound.
            All sequences must have a lower bound of length >= 0.

        :param dtype: The data type for each element in a sequence.

        :param opaque_data_format: An optional string describing an opaque
            data format, e.g. a data structure that is serialized to a
            string/binary array for transmission to the client. It is up to
            the client and service to agree on how to decode observations
            using this value. For example, an opaque_data_format of
            `string_json` could be used to indicate that the observation is a
            string-serialized JSON value.

        :param scalar_range: If specified, this denotes the legal range of
            each element in the sequence. This is enforced by
            :meth:`contains() <compiler_gym.spaces.Sequence.contains>` checks.
        """
        # Let the gym base class set up whatever state it maintains.
        # NOTE(review): exactly which attributes it creates depends on the
        # installed gym version - confirm against the pinned release.
        super().__init__()
        self.name = name
        self.size_range = size_range
        # Assigned after the base-class constructor so that the value given
        # here is the one that is kept.
        self.dtype = dtype
        self.opaque_data_format = opaque_data_format
        self.scalar_range = scalar_range

    def __repr__(self) -> str:
        """Return a compact description of the element type and length bounds.

        The opaque data format, when set, is appended after an arrow.
        """
        lower_bound = int(self.size_range[0])
        upper_bound = "inf" if self.size_range[1] is None else self.size_range[1]
        # Plain Python and numpy scalar types have __name__; fall back to the
        # string form for anything else (e.g. an np.dtype instance).
        type_name = getattr(self.dtype, "__name__", str(self.dtype))
        suffix = f" -> {self.opaque_data_format}" if self.opaque_data_format else ""
        # NOTE(review): the ")" after "]" looks unbalanced. It is kept
        # verbatim because callers may compare against this exact string.
        return f"{type_name}_list<>[{lower_bound},{upper_bound}]){suffix}"

    def contains(self, x) -> bool:
        """Return whether `x` is a member of this space.

        A value is a member if its length lies within `size_range` and it
        passes the type check that applies to `dtype`.

        :param x: The candidate value.

        :return: `True` if `x` belongs to the space, else `False`. Values
            that have no length are never members.
        """
        try:
            length = len(x)
        except TypeError:
            # Unsized objects (None, a bare int, ...) cannot be sequences.
            return False

        lower_bound, upper_bound = self.size_range
        if length < lower_bound:
            return False
        if upper_bound is not None and length > upper_bound:
            return False

        # TODO(cummins): The dtype API is inconsistent. When dtype=str or
        # dtype=bytes, we expect this to be the type of the entire sequence.
        # But for dtype=int, we expect this to be the type of each element. We
        # should distinguish these differences better.
        if self.dtype in {str, bytes}:
            return isinstance(x, self.dtype)

        # Array-like values carry a single dtype for all of their elements.
        if hasattr(x, "dtype"):
            return bool(issubdtype(x.dtype, self.dtype))

        # Run the bounds check on every scalar element, if there is a scalar
        # range specified. In this case the per-element dtype check below is
        # not performed; only the scalar range is consulted.
        if self.scalar_range is not None:
            return all(self.scalar_range.contains(s) for s in x)

        return all(issubdtype(type(element), self.dtype) for element in x)

    def sample(self):
        """
        .. warning::
            The `Sequence` space cannot be sampled from.

        :raises NotImplementedError: Not supported.
        """
        raise NotImplementedError

    def __eq__(self, other) -> bool:
        """Return whether `other` is a `Sequence` with the same attributes.

        The two `dtype` values are compared after conversion with `np.dtype`.
        """
        if not isinstance(other, Sequence):
            return False
        return (
            self.name == other.name
            and self.size_range == other.size_range
            and np.dtype(self.dtype) == np.dtype(other.dtype)
            and self.opaque_data_format == other.opaque_data_format
            and self.scalar_range == other.scalar_range
        )