Source code for drama.utils

# Typing
from typing import Any

# Standard modules
from collections import OrderedDict
from functools import singledispatch

# External modules
import gymnasium.spaces.utils
import numpy as np
from numpy.typing import NDArray
from gymnasium import Space
from gymnasium.spaces import Dict
from gymnasium.spaces.utils import FlatType, T

# Internal modules
from drama.restrictions import (
    DiscreteSetRestriction,
    DiscreteVectorRestriction,
    IntervalUnionRestriction,
    BucketSpaceRestriction,
    Restriction,
)
from drama.restrictors import (
    DiscreteSetActionSpace,
    DiscreteVectorActionSpace,
    IntervalUnionActionSpace,
    BucketSpaceActionSpace,
)


[docs]class IntervalsOutOfBoundException(Exception): """Thrown when the number of intervals exceeds the specified maximum representation length""" def __init__(self, *args) -> None: super().__init__(*args)
[docs]class RestrictionViolationException(Exception): """Thrown when a restriction is violated""" def __init__(self, *args) -> None: super().__init__(*args)
def random_replacement_violation_fn(env, action, restriction: Restriction): """Random replacement of restriction violations. Args: env: The environment after the agent step was taken action: The action which violated the restriction restriction: The restriction object corresponding to the action Returns: Randomly sampled action from the allowed action space """ env.step(restriction.sample()) def projection_violation_fn(env, action, restriction: IntervalUnionRestriction): """Euclidean projection of restriction violations. Args: env: The environment after the agent step was taken action: The action which violated the restriction restriction: The restriction object corresponding to the action Returns: Closest allowed action minimizing the euclidean distance """ env.step(np.array([restriction.nearest_element(action.item())], dtype=np.float32)) # flatten functions for restriction classes @singledispatch def flatten(space: Space, x: T, **kwargs) -> FlatType: """Flatten a data point from a space. Args: space: The space that ``x`` is flattened by x: The value to flatten **kwargs: Additional flattening parameters passed to subsequent flatten operations Returns: The flattened datapoint """ return gymnasium.spaces.utils.flatten(space, x) @flatten.register(Dict) def _flatten_dict(space: Dict, x: T, **kwargs): if space.is_np_flattenable: return np.concatenate( [np.array(flatten(s, x[key], **kwargs)) for key, s in space.spaces.items()] ) return OrderedDict( (key, flatten(s, x[key], **kwargs)) for key, s in space.spaces.items() ) @flatten.register(DiscreteSetActionSpace) def _flatten_discrete_set_action_space( space: DiscreteSetActionSpace, x: DiscreteSetRestriction ) -> NDArray: return np.asarray(list(x.allowed_actions), dtype=space.dtype) @flatten.register(DiscreteVectorActionSpace) def _flatten_discrete_vector_action_space( space: DiscreteVectorActionSpace, x: DiscreteVectorRestriction ): return np.asarray(x.allowed_actions, dtype=space.dtype) @flatten.register(IntervalUnionActionSpace) def _flatten_interval_union_restriction( space: IntervalUnionActionSpace, x: IntervalUnionRestriction, pad: bool = True, clamp: bool = True, max_len: int = 7, pad_value: float = 0.0, raise_error: bool = True, ): intervals = np.array(x.intervals(), dtype=np.float32) if raise_error and intervals.shape[0] > max_len: raise IntervalsOutOfBoundException if clamp: intervals = intervals[:max_len] if pad: padding = np.full((max_len - intervals.shape[0], 2), pad_value) return ( np.concatenate([intervals, padding], axis=0, dtype=np.float32).flatten() if len(intervals) > 0 else padding.flatten() ) return intervals.flatten() @flatten.register(BucketSpaceActionSpace) def _flatten_bucket_space_action_space( space: BucketSpaceActionSpace, x: BucketSpaceRestriction ): return np.concatenate( [ np.array( [ space.base_space.low.item(), space.base_space.high.item(), space.number_of_buckets, ] ), x.buckets, ] ) # flatdim functions for restriction classes @singledispatch def flatdim(space: Space[Any]) -> int: """Return the number of dimensions a flattened equivalent of this space would have. Args: space: The space to return the number of dimensions of the flattened spaces Returns: The number of dimensions for the flattened spaces Raises: NotImplementedError: if the space is not defined in :mod:`gym.spaces`. ValueError: if the space cannot be flattened into a :class:`gymnasium.spaces.Box` """ return gymnasium.spaces.utils.flatdim(space) @flatdim.register(DiscreteSetActionSpace) def _flatdim_discrete_set_action_space(space: DiscreteSetActionSpace): raise ValueError( "Cannot get flattened size as the DiscreteSetActionSpace has a dynamic size." ) @flatdim.register(DiscreteVectorActionSpace) def _flatdim_discrete_vector_action_space(space: DiscreteVectorActionSpace) -> int: return space.base_space.n @flatdim.register(IntervalUnionActionSpace) def _flatdim_graph(space: IntervalUnionActionSpace): raise ValueError( "Cannot get flattened size as the IntervalUnionActionSpace has a dynamic size." ) @flatdim.register(BucketSpaceActionSpace) def _flatdim_bucket_space_action_space(space: BucketSpaceActionSpace) -> int: return space.number_of_buckets + 3 # unflatten functions for restriction classes @singledispatch def unflatten(space: Space[T], x: FlatType) -> T: """Unflatten a data point from a space. Args: space: The space used to unflatten ``x`` x: The array to unflatten Returns: A point with a structure that matches the space. Raises: NotImplementedError: if the space is not defined in :mod:`gymnasium.spaces`. """ return gymnasium.spaces.utils.unflatten(space, x) @unflatten.register(DiscreteSetActionSpace) def _unflatten_discrete_set_action_space( space: DiscreteSetActionSpace, x: FlatType ) -> DiscreteSetRestriction: return DiscreteSetRestriction(space.base_space, allowed_actions=set(x)) @unflatten.register(DiscreteVectorActionSpace) def _unflatten_discrete_vector_action_space( space: DiscreteVectorActionSpace, x: FlatType ) -> DiscreteVectorRestriction: return DiscreteVectorRestriction(space.base_space, allowed_actions=np.array(x)) @unflatten.register(IntervalUnionActionSpace) def _unflatten_interval_union_action_space( space: IntervalUnionActionSpace, x: FlatType ) -> T: unpadded_intervals = np.array( [[x[i], x[i + 1]] for i in range(0, len(x), 2) if x[i] != x[i + 1]] ).flatten() interval_union_space = IntervalUnionRestriction(space.base_space) for i in list(range(0, len(unpadded_intervals) + 2, 2)): if i == 0: if unpadded_intervals[i] != space.base_space.low.item(): interval_union_space.remove( space.base_space.low[0], unpadded_intervals[i - 1] ) elif i == len(unpadded_intervals): if unpadded_intervals[i - 1] != space.base_space.high.item(): interval_union_space.remove( unpadded_intervals[i - 1], space.base_space.high.item() ) else: interval_union_space.remove( unpadded_intervals[i - 1], unpadded_intervals[i] ) return interval_union_space @unflatten.register(BucketSpaceActionSpace) def _unflatten_bucket_space_action_space( space: BucketSpaceActionSpace, x: FlatType ) -> BucketSpaceRestriction: return BucketSpaceRestriction( space.base_space, space.bucket_width, space.epsilon, x[3:] )