# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module defines a utility function for computing LLVM observations."""
import subprocess
from pathlib import Path
from typing import List
import google.protobuf.text_format
from compiler_gym.service.proto import Event
from compiler_gym.util.commands import Popen
from compiler_gym.util.gym_type_hints import ObservationType
from compiler_gym.util.runfiles_path import runfiles_path
from compiler_gym.util.shell_format import plural
from compiler_gym.views.observation_space_spec import ObservationSpaceSpec
_COMPUTE_OBSERVATION_BIN = runfiles_path(
"compiler_gym/envs/llvm/service/compute_observation"
)
def pascal_case_to_enum(pascal_case: str) -> str:
"""Convert PascalCase to ENUM_CASE."""
word_arrays: List[List[str]] = [[]]
for c in pascal_case:
if c.isupper() and word_arrays[-1]:
word_arrays.append([c])
else:
word_arrays[-1].append(c.upper())
return "_".join(["".join(word) for word in word_arrays])
[docs]def compute_observation(
observation_space: ObservationSpaceSpec, bitcode: Path, timeout: float = 300
) -> ObservationType:
"""Compute an LLVM observation.
This is a utility function that uses a standalone C++ binary to compute an
observation from an LLVM bitcode file. It is intended for use cases where
you want to compute an observation without the overhead of initializing a
full environment.
Example usage:
>>> env = compiler_gym.make("llvm-v0")
>>> space = env.observation.spaces["Ir"]
>>> bitcode = Path("bitcode.bc")
>>> observation = llvm.compute_observation(space, bitcode, timeout=30)
.. warning::
This is not part of the core CompilerGym API and may change in a future
release.
:param observation_space: The observation that is to be computed.
:param bitcode: The path of an LLVM bitcode file.
:param timeout: The maximum number of seconds to allow the computation to
run before timing out.
:raises ValueError: If computing the observation fails.
:raises TimeoutError: If computing the observation times out.
:raises FileNotFoundError: If the given bitcode does not exist.
"""
if not Path(bitcode).is_file():
raise FileNotFoundError(bitcode)
observation_space_name = pascal_case_to_enum(observation_space.id)
try:
with Popen(
[str(_COMPUTE_OBSERVATION_BIN), observation_space_name, str(bitcode)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as process:
stdout, stderr = process.communicate(timeout=timeout)
if process.returncode:
try:
stderr = stderr.decode("utf-8")
raise ValueError(
f"Failed to compute {observation_space.id} observation: {stderr}"
)
except UnicodeDecodeError as e:
raise ValueError(
f"Failed to compute {observation_space.id} observation"
) from e
except subprocess.TimeoutExpired as e:
raise TimeoutError(
f"Failed to compute {observation_space.id} observation in "
f"{timeout:.1f} {plural(int(round(timeout)), 'second', 'seconds')}"
) from e
try:
stdout = stdout.decode("utf-8")
except UnicodeDecodeError as e:
raise ValueError(
f"Failed to parse {observation_space.id} observation: {e}"
) from e
observation = Event()
try:
google.protobuf.text_format.Parse(stdout, observation)
except google.protobuf.text_format.ParseError as e:
raise ValueError(f"Failed to parse {observation_space.id} observation") from e
return observation_space.translate(observation)