# Source code for compiler_gym.spaces.scalar

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Optional

import numpy as np
from gym.spaces import Space

from compiler_gym.spaces.common import issubdtype


class Scalar(Space):
    """A scalar value.

    The space is described by an optional inclusive lower bound ``min``, an
    optional inclusive upper bound ``max``, and a ``dtype``. The ``dtype`` is
    used both to construct sampled values and to type-check candidate values
    in :meth:`contains`.
    """

    def __init__(
        self,
        name: str,
        min: Optional[float] = None,
        max: Optional[float] = None,
        dtype=np.float64,
    ):
        """Constructor.

        :param name: The name of the space.
        :param min: The lower bound for a value in this space. If None, there
            is no lower bound.
        :param max: The upper bound for a value in this space. If None, there
            is no upper bound.
        :param dtype: The type of this scalar.
        """
        self.name = name
        self.min = min
        self.max = max
        self.dtype = dtype

    def sample(self):
        """Draw a random value that lies within this space's bounds.

        An unbounded side defaults to the range [0, 1]. That default is
        shifted when necessary so the sampling range never crosses a bound
        that *is* set. For example, ``max=-5`` with no ``min`` samples from
        [-6, -5] rather than from [-5, 0].

        Integer dtypes are sampled uniformly over the integers in the closed
        range. Every other dtype is sampled with ``random.uniform`` and then
        cast to ``dtype``.

        :return: A value of type ``dtype``.
        :raises ValueError: If ``dtype`` is an integer type and no integer
            lies within the sampling range.
        """
        low, high = self.min, self.max
        # Keep the historical default of 0 for an open lower bound, unless
        # that would place the range above the upper bound.
        if low is None:
            low = 0 if high is None or high >= 0 else high - 1
        # Keep the historical default of 1 for an open upper bound, unless
        # that would place the range below the lower bound.
        if high is None:
            high = 1 if low <= 1 else low + 1

        if np.issubdtype(self.dtype, np.integer):
            # Truncating a uniform float toward zero would almost never yield
            # one of the endpoints, so sample the integers directly.
            int_low, int_high = int(np.ceil(low)), int(np.floor(high))
            if int_low > int_high:
                raise ValueError(f"No integer values in range [{low}, {high}]")
            return self.dtype(random.randint(int_low, int_high))
        return self.dtype(random.uniform(low, high))

    def contains(self, x):
        """Return whether ``x`` is a member of this space.

        :param x: The candidate value.
        :return: True if the type of ``x`` matches ``dtype`` according to
            ``issubdtype`` and ``x`` lies within the inclusive bounds.
            A missing bound is treated as unbounded.
        """
        if not issubdtype(type(x), self.dtype):
            return False
        low = -float("inf") if self.min is None else self.min
        high = float("inf") if self.max is None else self.max
        return low <= x <= high

    def __repr__(self):
        """Return the dtype name, followed by the bounds when any are set."""
        if self.min is None and self.max is None:
            return self.dtype.__name__
        lower_bound = "-inf" if self.min is None else self.min
        upper_bound = "inf" if self.max is None else self.max
        return f"{self.dtype.__name__}<{lower_bound},{upper_bound}>"

    def __eq__(self, rhs):
        """Equality test."""
        if not isinstance(rhs, Scalar):
            return False
        return (
            self.name == rhs.name
            and self.min == rhs.min
            and self.max == rhs.max
            and np.dtype(self.dtype) == np.dtype(rhs.dtype)
        )

    def __hash__(self):
        """Hash consistently with :meth:`__eq__`.

        Defining ``__eq__`` alone makes instances unhashable, so this hashes
        the same fields that ``__eq__`` compares.
        """
        return hash((self.name, self.min, self.max, np.dtype(self.dtype)))