# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Optional
import numpy as np
from gym.spaces import Space
from compiler_gym.spaces.common import issubdtype
class Scalar(Space):
    """A scalar value, optionally bounded below and/or above.

    Bounds are inclusive: a value ``x`` is contained in the space when
    ``min <= x <= max``, where a bound of None means "unbounded" on that side.
    """

    def __init__(
        self,
        name: str,
        min: Optional[float] = None,
        max: Optional[float] = None,
        dtype=np.float64,
    ):
        """Constructor.

        :param name: The name of the space.
        :param min: The lower bound for a value in this space. If None, there is
            no lower bound.
        :param max: The upper bound for a value in this space. If None, there is
            no upper bound.
        :param dtype: The type of this scalar.
        """
        # Run the gym Space base-class constructor so that any bookkeeping
        # attributes it defines exist on this instance. dtype is assigned
        # afterwards in case the base constructor also sets it.
        super().__init__()
        self.name = name
        self.min = min
        self.max = max
        self.dtype = dtype

    def sample(self):
        """Draw a uniformly random value from this space.

        A missing lower bound defaults to 0 and a missing upper bound to 1.
        If the single bound that *is* given lies outside that default
        interval, the missing side is placed one unit away from it instead,
        so the returned sample always respects the bound that was specified.

        :return: A value of type ``self.dtype``.
        """
        lower = 0 if self.min is None else self.min
        upper = 1 if self.max is None else self.max
        # Only adjust a side that was defaulted; explicitly given bounds are
        # used exactly as provided.
        if self.min is None and lower > upper:
            lower = upper - 1
        if self.max is None and upper < lower:
            upper = lower + 1
        # NOTE(review): for integer dtypes the float from random.uniform() is
        # truncated by the dtype constructor, so the upper bound itself is
        # practically never produced.
        return self.dtype(random.uniform(lower, upper))

    def contains(self, x):
        """Return whether ``x`` belongs to this space.

        ``x`` must pass the ``issubdtype`` check against ``self.dtype`` and lie
        within the inclusive bounds; a missing bound is treated as infinite.
        """
        if not issubdtype(type(x), self.dtype):
            return False
        lower = -float("inf") if self.min is None else self.min
        upper = float("inf") if self.max is None else self.max
        return lower <= x <= upper

    def __repr__(self):
        """Render as ``<dtype>`` or ``<dtype><lower,upper>`` when bounded."""
        if self.min is None and self.max is None:
            return self.dtype.__name__
        lower_bound = "-inf" if self.min is None else self.min
        upper_bound = "inf" if self.max is None else self.max
        return f"{self.dtype.__name__}<{lower_bound},{upper_bound}>"

    def __eq__(self, rhs):
        """Equality test.

        Two scalars are equal when their names, bounds, and dtypes (compared
        as ``np.dtype``) all match. Non-Scalar operands are never equal.
        """
        if not isinstance(rhs, Scalar):
            return False
        return (
            self.name == rhs.name
            and self.min == rhs.min
            and self.max == rhs.max
            and np.dtype(self.dtype) == np.dtype(rhs.dtype)
        )

    def __hash__(self):
        """Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone makes instances unhashable; this hashes the
        same fields that ``__eq__`` compares, normalizing dtype via ``np.dtype``.
        """
        return hash((self.name, self.min, self.max, np.dtype(self.dtype)))