# Typing
import math
from typing import Any
# Standard modules
from abc import ABC
# External modules
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Discrete
from pettingzoo import AECEnv
# Internal modules
from drama.restrictions import (
Restriction,
IntervalUnionRestriction,
DiscreteVectorRestriction,
DiscreteSetRestriction,
BucketSpaceRestriction,
)
[docs]class RestrictorActionSpace(ABC, gym.Space):
"""Action space representing sets of restrictions for a given base space."""
def __init__(
self, base_space: gym.Space, seed: int | np.random.Generator | None = None
):
"""Constructor of :class:`RestrictorActionSpace`.
Args:
base_space: The base action space which is restricted
seed: Optionally, you can use this argument to seed the RNG that is used to sample from the space
"""
super().__init__(None, None, seed)
self.base_space = base_space
[docs] def contains(self, x: Restriction) -> bool:
"""Check if a restriction was created with the same base space.
Args:
x: The restriction
Returns:
`True` if the restriction is compatible with the restrictor action space.
`False` otherwise
"""
return x.base_space == self.base_space
[docs] def sample(self, mask: Any | None = None) -> Restriction:
"""Randomly sample a restriction from the restrictor action space.
Args:
mask: The mask used for sampling
Returns:
A sampled restriction
"""
raise NotImplementedError
[docs] def is_compatible_with(self, action_space: gym.Space):
"""Check if a action space is compatible with the restrictor action space.
Args:
action_space: The action space which is checked for compatibility
Returns:
`True` if the action space is compatible with the restrictor action space.
`False` otherwise
"""
return self.base_space == action_space
def __repr__(self) -> str:
"""String representation of the restrictor action space."""
return f"{self.__class__.__name__}(base_space={self.base_space})"
[docs]class Restrictor(ABC):
"""An agent whose actions are restrictions."""
def __init__(self, observation_space, action_space) -> None:
"""Constructor of :class:`Restrictor`.
Args:
observation_space: The observation space of the restrictor
action_space: The action space of the restrictor
"""
self.observation_space = observation_space
self.action_space = action_space
[docs] def preprocess_observation(self, env: AECEnv) -> Any:
"""Pre-processing function applied by the :class:`RestrictionWrapper`
before the observation is forwarded to act().
Args:
env: The environment at the point in time
Returns:
The restrictor observation
"""
return env.state()
[docs] def act(self, observation: gym.Space) -> Restriction:
"""Compute the restriction for a observation.
Args:
observation: The observation used to compute the restriction
Returns:
The computed restriction
"""
raise NotImplementedError
[docs]class DiscreteSetActionSpace(RestrictorActionSpace):
"""Action space representing valid restrictions for a discrete action space as a set of allowed actions."""
def __init__(self, base_space: Discrete):
"""Constructor of :class:`DiscreteSetActionSpace`.
Args:
base_space: The :class:`gymnasium.spaces.Discrete` action space which is restricted
"""
super().__init__(base_space)
@property
def is_np_flattenable(self) -> bool:
"""Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.
Returns:
`True`
"""
return True
[docs] def sample(self, mask: Any | None = None) -> DiscreteSetRestriction:
"""Randomly sample an instance of :class:`DiscreteSetRestriction` from the :class:`DiscreteSetActionSpace`.
Args:
mask: The mask used for sampling (currently no effect)
Returns:
A sampled :class:`DiscreteSetRestriction`
"""
assert isinstance(self.base_space, Discrete)
discrete_set = DiscreteSetRestriction(
self.base_space,
allowed_actions=set(
np.arange(
self.base_space.start, self.base_space.start + self.base_space.n
)[np.random.choice([True, False], size=self.base_space.n)]
),
)
return discrete_set
[docs]class DiscreteVectorActionSpace(RestrictorActionSpace):
"""Action space representing valid restrictions for a discrete action space as a binary vector."""
def __init__(self, base_space: Discrete):
"""Constructor of :class:`DiscreteVectorActionSpace`.
Args:
base_space: The :class:`gymnasium.spaces.Discrete` action space which is restricted
"""
super().__init__(base_space)
@property
def is_np_flattenable(self) -> bool:
"""Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.
Returns:
`True`
"""
return True
[docs] def sample(self, mask: Any | None = None) -> DiscreteVectorRestriction:
"""Randomly sample an instance of :class:`DiscreteVectorRestriction` from the :class:`DiscreteVectorActionSpace`.
Args:
mask: The mask used for sampling (currently no effect)
Returns:
A sampled :class:`DiscreteVectorRestriction`
"""
assert isinstance(self.base_space, Discrete)
discrete_vector = DiscreteVectorRestriction(
self.base_space,
allowed_actions=np.random.choice([True, False], self.base_space.n),
)
return discrete_vector
[docs]class IntervalUnionActionSpace(RestrictorActionSpace):
"""Action space representing valid restrictions for a :class:`gymnasium.spaces.Box` action space
as a union of intervals."""
def __init__(self, base_space: Box):
"""Constructor of :class:`IntervalUnionActionSpace`.
Args:
base_space: The :class:`gymnasium.spaces.Box` action space which is restricted
"""
super().__init__(base_space)
@property
def is_np_flattenable(self) -> bool:
"""Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.
Returns:
`True`
"""
return True
[docs] def sample(self, mask: Any | None = None) -> IntervalUnionRestriction:
"""Randomly sample an instance of :class:`IntervalUnionRestriction` from the :class:`IntervalUnionActionSpace`.
Args:
mask: The mask used for sampling (currently no effect)
Returns:
A sampled :class:`IntervalUnionRestriction`
"""
assert isinstance(self.base_space, Box)
interval_union = IntervalUnionRestriction(self.base_space)
num_intervals = self.np_random.geometric(0.25)
for _ in range(num_intervals):
interval_start = self.np_random.uniform(
self.base_space.low[0], self.base_space.high[0]
)
interval_union.remove(
interval_start,
self.np_random.uniform(interval_start, self.base_space.high[0]),
)
return interval_union
[docs]class BucketSpaceActionSpace(RestrictorActionSpace):
"""Action space representing valid restrictions for a :class:`gymnasium.spaces.Box` action space
as a binary indicator vector for evenly split buckets."""
def __init__(self, base_space: Box, bucket_width=1.0, epsilon=0.01):
"""Constructor of :class:`BucketSpaceActionSpace`.
Args:
base_space: The :class:`gymnasium.spaces.Box` action space which is restricted
bucket_width: The width of each bucket
epsilon: The radius in which buckets are set valid/invalid around a specific point
"""
super().__init__(base_space)
assert isinstance(self.base_space, Box)
self.bucket_width = bucket_width
self.epsilon = epsilon
self.number_of_buckets = math.ceil(
(self.base_space.high.item() - self.base_space.low.item())
/ self.bucket_width
)
@property
def is_np_flattenable(self) -> bool:
"""Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.
Returns:
`True`
"""
return True
[docs] def sample(self, mask: Any | None = None) -> BucketSpaceRestriction:
"""Randomly sample an instance of :class:`BucketSpaceRestriction` from the :class:`BucketSpaceActionSpace`.
Args:
mask: The mask used for sampling (currently no effect)
Returns:
A sampled :class:`BucketSpaceRestriction`
"""
assert isinstance(self.base_space, Box)
return BucketSpaceRestriction(
self.base_space,
self.bucket_width,
self.epsilon,
available_buckets=np.random.choice([True, False], self.number_of_buckets),
)
class PredicateActionSpace(RestrictorActionSpace):
pass