# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from itertools import cycle
from typing import Callable, Iterable, Optional, Union
import numpy as np
from compiler_gym.datasets import Benchmark
from compiler_gym.envs import CompilerEnv
from compiler_gym.util.parallelization import thread_safe_tee
from compiler_gym.wrappers.core import CompilerEnvWrapper
# Type alias for values accepted as a benchmark by the wrappers in this file:
# either a string or a Benchmark instance.
BenchmarkLike = Union[str, Benchmark]
class IterateOverBenchmarks(CompilerEnvWrapper):
    """Iterate over a (possibly infinite) sequence of benchmarks on each call to
    reset(). Will raise :code:`StopIteration` on :meth:`reset()
    <compiler_gym.envs.CompilerEnv.reset>` once the iterator is exhausted. Use
    :class:`CycleOverBenchmarks` or :class:`RandomOrderBenchmarks` for wrappers
    which will loop over the benchmarks.
    """

    def __init__(
        self,
        env: CompilerEnv,
        benchmarks: Iterable[BenchmarkLike],
        fork_shares_iterator: bool = False,
    ):
        """Constructor.

        :param env: The environment to wrap.

        :param benchmarks: An iterable sequence of benchmarks.

        :param fork_shares_iterator: If :code:`True`, the :code:`benchmarks`
            iterator will be shared by a forked environment created by
            :meth:`env.fork() <compiler_gym.envs.CompilerEnv.fork>`. This means
            that calling :meth:`env.reset()
            <compiler_gym.envs.CompilerEnv.reset>` with one environment will
            advance the iterator in the other. If :code:`False`, forked
            environments will use :code:`itertools.tee()` to create a copy of
            the iterator so that each iterator may advance independently.
            However, this requires shared buffers between the environments which
            can lead to memory overheads if :meth:`env.reset()
            <compiler_gym.envs.CompilerEnv.reset>` is called many times more in
            one environment than the other.
        """
        super().__init__(env)
        # Normalize to an iterator so that reset() can simply call next().
        self.benchmarks = iter(benchmarks)
        self.fork_shares_iterator = fork_shares_iterator

    def reset(self, benchmark: Optional[BenchmarkLike] = None, **kwargs):
        """Reset the wrapped environment using the next benchmark.

        :param benchmark: Must be :code:`None`. The benchmark is always taken
            from the iterator, so the caller may not choose one.

        :param kwargs: Additional keyword arguments, forwarded unchanged to the
            wrapped environment's :code:`reset()`.

        :return: Whatever the wrapped environment's :code:`reset()` returns.

        :raises TypeError: If :code:`benchmark` is not :code:`None`.

        :raises StopIteration: If the benchmarks iterator is exhausted.
        """
        if benchmark is not None:
            raise TypeError("Benchmark passed to IterateOverBenchmarks.reset()")
        next_benchmark: BenchmarkLike = next(self.benchmarks)
        # Forward the remaining keyword arguments rather than silently
        # discarding them.
        return self.env.reset(benchmark=next_benchmark, **kwargs)

    def fork(self) -> "IterateOverBenchmarks":
        """Fork the wrapped environment and wrap the fork in a new
        :class:`IterateOverBenchmarks`.

        If :code:`fork_shares_iterator` is set, the fork receives the very same
        iterator object as this wrapper. Otherwise the iterator is split in two
        with :code:`thread_safe_tee()`: this wrapper keeps one half and the fork
        receives the other.

        :return: A new wrapper around the forked environment.
        """
        if self.fork_shares_iterator:
            other_benchmarks_iterator = self.benchmarks
        else:
            # Replace our own iterator too: the original must not be advanced
            # directly once it has been split.
            self.benchmarks, other_benchmarks_iterator = thread_safe_tee(
                self.benchmarks
            )
        return IterateOverBenchmarks(
            env=self.env.fork(),
            benchmarks=other_benchmarks_iterator,
            fork_shares_iterator=self.fork_shares_iterator,
        )
class CycleOverBenchmarks(IterateOverBenchmarks):
    """Loop endlessly over a list of benchmarks, advancing by one on each call
    to :meth:`reset() <compiler_gym.envs.CompilerEnv.reset>`. Behaves like
    :class:`IterateOverBenchmarks`, except that the sequence starts over from
    the beginning once it has been exhausted.
    """

    def __init__(
        self,
        env: CompilerEnv,
        benchmarks: Iterable[BenchmarkLike],
        fork_shares_iterator: bool = False,
    ):
        """Constructor.

        :param env: The environment to wrap.

        :param benchmarks: An iterable sequence of benchmarks.

        :param fork_shares_iterator: If :code:`True`, a forked environment
            created by :meth:`env.fork() <compiler_gym.envs.CompilerEnv.fork>`
            shares the :code:`benchmarks` iterator with this one, so a call to
            :meth:`env.reset() <compiler_gym.envs.CompilerEnv.reset>` on either
            environment advances the iterator for both. If :code:`False`,
            forked environments use :code:`itertools.tee()` to obtain a copy of
            the iterator, allowing each to advance independently. That relies
            on buffers shared between the environments, which can cause memory
            overheads when :meth:`env.reset()
            <compiler_gym.envs.CompilerEnv.reset>` is called many more times in
            one environment than in the other.
        """
        # itertools.cycle() turns the finite sequence into a repeating one.
        repeating_benchmarks = cycle(benchmarks)
        super().__init__(
            env,
            benchmarks=repeating_benchmarks,
            fork_shares_iterator=fork_shares_iterator,
        )
class CycleOverBenchmarksIterator(CompilerEnvWrapper):
    """Same as :class:`CycleOverBenchmarks
    <compiler_gym.wrappers.CycleOverBenchmarks>` except that the user generates
    the iterator.
    """

    def __init__(
        self,
        env: CompilerEnv,
        make_benchmark_iterator: Callable[[], Iterable[BenchmarkLike]],
    ):
        """Constructor.

        :param env: The environment to wrap.

        :param make_benchmark_iterator: A callback that returns an iterator over
            a sequence of benchmarks. Once the iterator is exhausted, this
            callback is called to produce a new iterator.
        """
        super().__init__(env)
        self.make_benchmark_iterator = make_benchmark_iterator
        # The callback is invoked once up front to create the first iterator.
        self.benchmarks = iter(self.make_benchmark_iterator())

    def reset(self, benchmark: Optional[BenchmarkLike] = None, **kwargs):
        """Reset the wrapped environment using the next benchmark.

        When the current iterator is exhausted, a new one is requested from
        :code:`make_benchmark_iterator` and iteration continues from its start.

        :param benchmark: Must be :code:`None`. The benchmark is always taken
            from the iterator, so the caller may not choose one.

        :param kwargs: Additional keyword arguments, forwarded unchanged to the
            wrapped environment's :code:`reset()`.

        :return: Whatever the wrapped environment's :code:`reset()` returns.

        :raises TypeError: If :code:`benchmark` is not :code:`None`.

        :raises StopIteration: If a freshly created iterator yields no
            benchmarks at all.
        """
        if benchmark is not None:
            raise TypeError(
                "Benchmark passed to CycleOverBenchmarksIterator.reset()"
            )
        try:
            next_benchmark: BenchmarkLike = next(self.benchmarks)
        except StopIteration:
            # Current pass is finished: ask the user callback for a new
            # iterator and take its first element.
            self.benchmarks = iter(self.make_benchmark_iterator())
            next_benchmark = next(self.benchmarks)
        # Forward the remaining keyword arguments rather than silently
        # discarding them.
        return self.env.reset(benchmark=next_benchmark, **kwargs)

    def fork(self) -> "CycleOverBenchmarksIterator":
        """Fork the wrapped environment and wrap the fork in a new
        :class:`CycleOverBenchmarksIterator`.

        The fork reuses the same :code:`make_benchmark_iterator` callback. Its
        constructor calls the callback to create its own iterator, so the
        current position of this wrapper's iterator is not carried over.

        :return: A new wrapper around the forked environment.
        """
        return CycleOverBenchmarksIterator(
            env=self.env.fork(),
            make_benchmark_iterator=self.make_benchmark_iterator,
        )
class RandomOrderBenchmarks(IterateOverBenchmarks):
    """Select randomly from a list of benchmarks on each call to :meth:`reset()
    <compiler_gym.envs.CompilerEnv.reset>`.

    .. note::

        Uniform random selection is provided by evaluating the input benchmarks
        iterator into a list and sampling randomly from the list. For very large
        and infinite iterables of benchmarks you must use the
        :class:`IterateOverBenchmarks
        <compiler_gym.wrappers.IterateOverBenchmarks>` wrapper with your own
        random sampling iterator.
    """

    def __init__(
        self,
        env: CompilerEnv,
        benchmarks: Iterable[BenchmarkLike],
        rng: Optional[np.random.Generator] = None,
    ):
        """Constructor.

        :param env: The environment to wrap.

        :param benchmarks: An iterable sequence of benchmarks. The entirety of
            this input iterator is evaluated during construction.

        :param rng: A random number generator to use for random benchmark
            selection. If not provided, a new generator is created with
            :code:`np.random.default_rng()`.
        """
        # Materialize the input so that it can be sampled from repeatedly and
        # handed on to forked environments.
        self._all_benchmarks = list(benchmarks)
        if rng is None:
            rng = np.random.default_rng()
        super().__init__(
            env,
            # iter(int, 1) never terminates (int() returns 0, which never equals
            # the sentinel 1), making this an endless stream of random choices.
            benchmarks=(rng.choice(self._all_benchmarks) for _ in iter(int, 1)),
            fork_shares_iterator=True,
        )

    def fork(self) -> "RandomOrderBenchmarks":
        """Fork the random order benchmark wrapper.

        The fork samples randomly from the same list of benchmarks as this
        wrapper.

        Note that RNG state is not copied to forked environments: the fork is
        constructed without an :code:`rng` argument and therefore creates its
        own generator.

        :return: A new :class:`RandomOrderBenchmarks` around the forked
            environment.
        """
        # Build another RandomOrderBenchmarks, not a plain IterateOverBenchmarks:
        # the latter would walk the list once in order and then raise
        # StopIteration, instead of sampling randomly forever.
        return RandomOrderBenchmarks(
            env=self.env.fork(), benchmarks=self._all_benchmarks
        )