Source code for compiler_gym.spaces.permutation

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from numbers import Integral

import numpy as np

from compiler_gym.spaces.scalar import Scalar
from compiler_gym.spaces.sequence import Sequence

class Permutation(Sequence):
    """The space of permutations of all numbers in the range `scalar_range`.

    Every sample of the space holds each integer of the closed interval
    ``[scalar_range.min, scalar_range.max]`` exactly once.
    """

    def __init__(self, name: str, scalar_range: Scalar):
        """Constructor.

        :param name: The name of the permutation space.
        :param scalar_range: Range of numbers in the permutation. For example
            the scalar range [1, 3] would define permutations like [1, 2, 3]
            or [2, 1, 3], etc.
        :raises TypeError: If `scalar_range.dtype` is not an integral type.
        """
        if not issubclass(np.dtype(scalar_range.dtype).type, Integral):
            raise TypeError("Permutation space can have integral scalar range only.")
        # Both bounds of the scalar range are inclusive, hence the "+ 1".
        sz = scalar_range.max - scalar_range.min + 1
        # A permutation always has exactly `sz` elements, so the lower and
        # upper bounds of the sequence length are identical.
        super().__init__(
            name=name,
            size_range=(sz, sz),
            dtype=scalar_range.dtype,
            scalar_range=scalar_range,
        )

    def sample(self):
        """Draw a random permutation of the scalar range.

        :return: A NumPy array containing every value of the scalar range
            exactly once, in random order, using the dtype of the scalar
            range.
        """
        # NOTE(review): this draws from NumPy's global random state, so it is
        # only reproducible via `np.random.seed()` -- confirm whether the base
        # space provides its own seeded generator that should be used instead.
        zero_based = np.random.choice(
            self.size_range[0], size=self.size_range[1], replace=False
        )
        # Shift from [0, n) onto [min, max]. The arithmetic yields the platform
        # default integer type, so cast back to the dtype this space declares;
        # otherwise a space built with e.g. np.int32 would hand out samples of
        # a different dtype than its own.
        shifted = zero_based + self.scalar_range.min
        return shifted.astype(self.scalar_range.dtype)

    def __eq__(self, other) -> bool:
        """Compare this space with `other`.

        :param other: The object to compare against.
        :return: A falsy value unless this space is an instance of the class
            of `other`; otherwise the result of the base class comparison.
        """
        return isinstance(self, other.__class__) and super().__eq__(other)