Source code for drama.restrictions

# Typing
from typing import Optional, Set, Callable, Any, Union

# Standard modules
import math
from decimal import Decimal, getcontext
from abc import ABC
import random

# External modules
import numpy as np
import gymnasium as gym
from gymnasium.spaces import Box


class Restriction(ABC, gym.Space):
    """Common base class of every restriction.

    A restriction describes a subset of some ``base_space`` and is itself a valid
    :class:`gymnasium.spaces.Space` with the same shape and dtype as that base space.
    """

    def __init__(
        self,
        base_space: gym.Space,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        """Create a restriction over ``base_space``.

        Args:
            base_space: The :class:`gymnasium.spaces.Space` whose subsets this
                restriction can represent.
            seed: Seed for the sampling random number generator. Defaults to None.
        """
        shape, dtype = base_space.shape, base_space.dtype
        super(Restriction, self).__init__(shape, dtype, seed)
        self.base_space = base_space

    def __repr__(self) -> str:
        """Return the bare class name of this restriction."""
        return type(self).__name__
class DiscreteRestriction(Restriction, ABC):
    """Abstract restriction of a :class:`gymnasium.spaces.Discrete` space."""

    def __init__(
        self,
        base_space: gym.spaces.Discrete,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        """Create a restriction over a discrete base space.

        Args:
            base_space: The :class:`gymnasium.spaces.Discrete` space being restricted.
            seed: Seed for the sampling random number generator. Defaults to None.
        """
        super(DiscreteRestriction, self).__init__(base_space, seed=seed)
class ContinuousRestriction(Restriction, ABC):
    """Abstract restriction of a :class:`gymnasium.spaces.Box` space."""

    def __init__(
        self,
        base_space: gym.spaces.Box,
        *,
        seed: int | np.random.Generator | None = None,
    ):
        """Create a restriction over a continuous (Box) base space.

        Args:
            base_space: The :class:`gymnasium.spaces.Box` space being restricted.
            seed: Seed for the sampling random number generator. Defaults to None.
        """
        super(ContinuousRestriction, self).__init__(base_space, seed=seed)
class DiscreteSetRestriction(DiscreteRestriction):
    """Representation of a :class:`gymnasium.spaces.Discrete` restriction as a set of allowed actions."""

    def __init__(
        self,
        base_space: gym.spaces.Discrete,
        *,
        allowed_actions: Optional[Set[int]] = None,
        seed: int | np.random.Generator | None = None,
    ):
        """Constructor of :class:`DiscreteSetRestriction`.

        Args:
            base_space: :class:`gymnasium.spaces.Discrete` whose subsets can be represented by the restriction
            allowed_actions: Optional, initial set of allowed actions. The set is copied, so later
                calls to :meth:`add` / :meth:`remove` never mutate the caller's object.
                Defaults to every action of ``base_space``.
            seed: Random seed for sampling. Defaults to None
        """
        super().__init__(base_space, seed=seed)
        self.allowed_actions = (
            set(allowed_actions)
            if allowed_actions is not None
            else set(range(base_space.start, base_space.start + base_space.n))
        )

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.

        Returns:
            `True`
        """
        return True

    def sample(self, mask: Any | None = None) -> int:
        """Randomly sample an action from the allowed set.

        Uses the space's own generator (``self.np_random``) so that the ``seed``
        given to the constructor or to :meth:`seed` makes sampling reproducible.

        Args:
            mask: The mask used for sampling (currently no effect)

        Returns:
            Valid discrete action

        Raises:
            IndexError: If no action is allowed.
        """
        candidates = tuple(self.allowed_actions)
        if not candidates:
            raise IndexError("Cannot sample from an empty restriction")
        return candidates[self.np_random.integers(len(candidates))]

    def contains(self, x: int) -> bool:
        """Check if a discrete action is allowed.

        Args:
            x: The discrete action (an int, numpy integer scalar or 0-d array)

        Returns:
            `True` if the action is allowed. `False` otherwise
        """
        if isinstance(x, np.ndarray):
            # numpy arrays are unhashable; only 0-d arrays denote a single action.
            if x.shape != ():
                return False
            x = x.item()
        return x in self.allowed_actions

    def add(self, x: int) -> None:
        """Adds a discrete action to the set of allowed actions.

        Args:
            x: The discrete action
        """
        self.allowed_actions.add(x)

    def remove(self, x: int) -> None:
        """Removes a discrete action from the set of allowed actions.

        Args:
            x: The discrete action

        Raises:
            KeyError: If ``x`` is not currently allowed.
        """
        self.allowed_actions.remove(x)

    def __eq__(self, __value: object) -> bool:
        """Check if two instances of :class:`DiscreteSetRestriction` are equal."""
        return (
            isinstance(__value, DiscreteSetRestriction)
            and self.base_space == __value.base_space
            and self.allowed_actions == __value.allowed_actions
        )

    def __repr__(self) -> str:
        """Representation of the :class:`DiscreteSetRestriction`."""
        return f"{self.__class__.__name__}({self.allowed_actions})"
class DiscreteVectorRestriction(DiscreteRestriction):
    """Representation of a :class:`gymnasium.spaces.Discrete` restriction as a boolean vector of allowed and
    forbidden actions.

    Entry ``i`` of :attr:`allowed_actions` refers to action ``base_space.start + i``.
    """

    def __init__(
        self,
        base_space: gym.spaces.Discrete,
        *,
        allowed_actions: Optional[np.ndarray[bool]] = None,
        seed: int | np.random.Generator | None = None,
    ):
        """Constructor of :class:`DiscreteVectorRestriction`.

        Args:
            base_space: :class:`gymnasium.spaces.Discrete` whose subsets can be represented by the restriction
            allowed_actions: Optional, initial binary vector indicating allowed actions.
                Defaults to all actions allowed.
            seed: Random seed for sampling. Defaults to None
        """
        super().__init__(base_space, seed=seed)
        self.allowed_actions = (
            allowed_actions
            if allowed_actions is not None
            else np.ones(base_space.n, dtype=np.bool_)
        )

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.

        Returns:
            `True`
        """
        return True

    def sample(self, mask: Any | None = None) -> int:
        """Randomly sample an action from the allowed set.

        Uses the space's own generator (``self.np_random``) so that the ``seed``
        given to the constructor or to :meth:`seed` makes sampling reproducible.

        Args:
            mask: The mask used for sampling (currently no effect)

        Returns:
            Valid discrete action

        Raises:
            IndexError: If no action is allowed.
        """
        allowed_offsets = np.flatnonzero(self.allowed_actions)
        if allowed_offsets.size == 0:
            raise IndexError("Cannot sample from an empty restriction")
        offset = int(allowed_offsets[self.np_random.integers(allowed_offsets.size)])
        return self.base_space.start + offset

    def contains(self, x: int) -> bool:
        """Check if a discrete action is allowed.

        Args:
            x: The discrete action

        Returns:
            `True` if the action is allowed. `False` otherwise (including actions
            outside of ``base_space``).
        """
        # Range/type check first: without it, x < start would produce a negative
        # index that silently wraps around to the end of the vector.
        if not self.base_space.contains(x):
            return False
        return bool(self.allowed_actions[int(x) - self.base_space.start])

    def __repr__(self) -> str:
        """Representation of the :class:`DiscreteVectorRestriction`."""
        return f"{self.__class__.__name__}({self.allowed_actions})"
class Node(object):
    """Node in the AVL tree of intervals.

    A single instance of :class:`Node` represents an allowed interval with lower
    bound :attr:`x` and upper bound :attr:`y`, both stored as :class:`Decimal`.
    """

    def __init__(
        self,
        x: Optional[float] = None,
        y: Optional[float] = None,
        left: Optional["Node"] = None,
        right: Optional["Node"] = None,
        height: int = 1,
    ):
        """Constructor of :class:`Node`.

        Args:
            x: Lower bound of the interval (converted via its string form to Decimal)
            y: Upper bound of the interval (converted via its string form to Decimal)
            left: Left, smaller interval
            right: Right, larger interval
            height: Height of the node in the AVL tree. Defaults to 1 (a leaf)
        """
        # Converting through the string form avoids importing binary float noise
        # (Decimal(0.1) would be 0.1000000000000000055...).
        self.x: Optional[Decimal] = Decimal(f"{x}") if x is not None else None
        self.y: Optional[Decimal] = Decimal(f"{y}") if y is not None else None
        self.left = left
        self.right = right
        self.height = height

    def __str__(self):
        """String representation of the :class:`Node`."""
        return (
            f"<Node ({self.x},{self.y}), height: {self.height}, "
            f"left: {self.left}, right: {self.right}>"
        )

    def __repr__(self):
        """Representation of the :class:`Node`."""
        return self.__str__()
class IntervalUnionRestriction(ContinuousRestriction):
    """Representation of a one-dimensional :class:`gymnasium.spaces.Box` restriction
    as an AVL tree of allowed intervals."""

    # Root of the AVL tree of allowed intervals; None when nothing is allowed.
    root_tree = None
    # Total length of all allowed intervals (a Decimal after __init__).
    size: Decimal = 0
    # Remaining distance of an in-progress draw in :meth:`sample`; None between draws.
    draw = None

    def __init__(self, base_space: Box):
        """Constructor of :class:`IntervalUnionRestriction`.

        Args:
            base_space: :class:`gymnasium.spaces.Box` whose subsets can be represented by the restriction
        """
        super().__init__(base_space)
        # NOTE: this sets the precision of the active thread's decimal context
        # (28 is Python's default), which affects all Decimal arithmetic there.
        getcontext().prec = 28
        self.root_tree = Node(base_space.low[0], base_space.high[0])
        self.size = Decimal(f"{base_space.high[0]}") - Decimal(f"{base_space.low[0]}")

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.

        Returns:
            `True`
        """
        return True

    def __contains__(self, item):
        """Check if a continuous action is allowed.

        Args:
            item: The continuous action

        Returns:
            `True` if the action is allowed. `False` otherwise
        """
        return self.contains(item)

    def contains(self, x: Union[np.ndarray, float], root: object = "root"):
        """Check if a continuous action is allowed.

        Args:
            x: The continuous action (scalar or single-element array)
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            `True` if the action lies in an allowed (closed) interval. `False` otherwise
        """
        if root == "root":
            root = self.root_tree
        if isinstance(x, np.ndarray):
            x = x.item()
        x = Decimal(f"{x}")
        if not root:
            return False
        elif root.x <= x <= root.y:
            return True
        elif root.x > x:
            return self.contains(x, root.left)
        else:
            return self.contains(x, root.right)

    def nearest_elements(self, x, root: Node = "root"):
        """Finds the closest allowed actions for a continuous action.

        Args:
            x: The continuous action
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            Ascending list of the nearest elements in the allowed action space (two
            elements on an exact tie). Returns ``[x]`` if x is valid or the tree is empty.
        """
        if root == "root":
            root = self.root_tree
        x = Decimal(f"{x}")
        if root and x > root.y:
            return self._nearest_elements(x, x - root.y, root.y, root.right)
        elif root and x < root.x:
            return self._nearest_elements(x, root.x - x, root.x, root.left)
        else:
            return [x]

    def _nearest_elements(self, x, min_diff, min_value, root: Node = "root"):
        """Helper function to find the closest allowed actions for a continuous action.

        Walks down the search tree towards x, keeping the best candidate bound(s)
        seen so far. The descent never stops early: a node farther away than the
        current best can still have closer intervals in the subtree towards x.

        Args:
            x: The continuous action
            min_diff: Minimum distance of an allowed action to x that has been found so far
            min_value: Allowed action with the minimum distance to x that has been found so far
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            Ascending list of the nearest elements in the allowed action space. Returns ``[x]`` if x is valid.
        """
        if root == "root":
            root = self.root_tree
        x = Decimal(f"{x}")
        best_diff = Decimal(f"{min_diff}")
        best = [Decimal(f"{min_value}")]
        node = root
        while node is not None:
            if node.x <= x <= node.y:
                return [x]
            if x > node.y:
                distance, candidate, next_node = x - node.y, node.y, node.right
            else:
                distance, candidate, next_node = node.x - x, node.x, node.left
            if distance < best_diff:
                best_diff, best = distance, [candidate]
            elif distance == best_diff and candidate not in best:
                best.append(candidate)
            node = next_node
        return sorted(best)

    def nearest_element(self, x, root: Node = "root"):
        """Finds the minimum closest allowed action for a continuous action.

        Args:
            x: The continuous action
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            Nearest element in the allowed action space; on a tie the smaller one. Returns x if it is valid.
        """
        if root == "root":
            root = self.root_tree
        x = Decimal(f"{x}")
        return min(self.nearest_elements(x, root))

    def last_interval_before_or_within(self, x, root: Node = "root"):
        """The last interval before or within a continuous action

        Args:
            x: The continuous action
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            Tuple containing the lower and upper boundaries of the interval and a variable indicating
            if the number lies in the interval. For example: (root.x, root.y), True.
            ``((None, None), False)`` if no such interval exists.
        """
        if root == "root":
            root = self.root_tree
        if root is None:
            return (None, None), False
        x = Decimal(f"{x}")
        if root.x <= x <= root.y:
            return (root.x, root.y), True
        elif x < root.x:
            return (
                self.last_interval_before_or_within(x, root.left)
                if root.left is not None
                else ((None, None), False)
            )
        else:
            if root.right is not None:
                interval, flag = self.last_interval_before_or_within(x, root.right)
                if interval[0] is None:
                    # Nothing in the right subtree precedes x, so this node is the last one.
                    interval, flag = (root.x, root.y), False
            else:
                interval, flag = (root.x, root.y), False
            return interval, flag

    def first_interval_after_or_within(self, x, root: Node = "root"):
        """The first interval after or within a continuous action

        Args:
            x: The continuous action
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            Tuple containing the lower and upper boundaries of the interval and a variable indicating
            if the number lies in the interval. For example: (root.x, root.y), True.
            ``((None, None), False)`` if no such interval exists.
        """
        if root == "root":
            root = self.root_tree
        if root is None:
            return (None, None), False
        x = Decimal(f"{x}")
        if root.x <= x <= root.y:
            return (root.x, root.y), True
        elif x > root.y:
            return (
                self.first_interval_after_or_within(x, root.right)
                if root.right is not None
                else ((None, None), False)
            )
        else:
            if root.left is not None:
                interval, flag = self.first_interval_after_or_within(x, root.left)
                if interval[0] is None:
                    # Nothing in the left subtree follows x, so this node is the first one.
                    interval, flag = (root.x, root.y), False
            else:
                interval, flag = (root.x, root.y), False
            return interval, flag

    def smallest_interval(self, root: Node = "root"):
        """Return the Node of the smallest interval

        Args:
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            :class:`Node` of the smallest interval
        """
        if root == "root":
            root = self.root_tree
        if root is None or root.left is None:
            return root
        else:
            return self.smallest_interval(root.left)

    def add(self, x, y, root: Node = "root"):
        """Add an interval to the action space

        Args:
            x: Lower bound of the interval
            y: Upper bound of the interval
            root: :class:`Node` to start the insertion from or 'root' for inserting over the whole tree, default is 'root'

        Returns:
            Updated root :class:`Node` of the action space
        """
        assert y > x, "Upper must be larger than lower bound"
        if root == "root":
            root = self.root_tree
        # Convert before any arithmetic: self.size is a Decimal and Decimal does
        # not support mixed arithmetic with float (adding to an empty tree used
        # to raise TypeError).
        x = Decimal(f"{x}")
        y = Decimal(f"{y}")
        if root is None:
            self.root_tree = Node(x, y)
            self.size += y - x
            return self.root_tree
        if y < root.x:
            root.left = self.add(x, y, root.left)
        elif x > root.y:
            root.right = self.add(x, y, root.right)
        else:
            # Overlapping interval: widen this node and absorb overlapping children.
            old_size = root.y - root.x
            root.x = min(root.x, x)
            root.y = max(root.y, y)
            self.size += root.y - root.x - old_size
            updated = False
            if root.right is not None and root.y >= root.right.x:
                self.size -= root.y - root.right.y
                root.y = root.right.y
                updated = True
            if root.left is not None and root.x <= root.left.y:
                self.size -= root.left.x - root.x
                root.x = root.left.x
                updated = True
            root.right = self.remove(root.x, root.y, root.right)
            root.left = self.remove(root.x, root.y, root.left)
            if updated:
                root = self.add(x, y, root)
        root.height = 1 + max(self.getHeight(root.left), self.getHeight(root.right))
        b = self.getBal(root)
        if b > 1 and y < root.left.x and self.getBal(root.left) > 0:
            self.root_tree = self.rRotate(root)
            return self.root_tree
        if b < -1 and x > root.right.y and self.getBal(root.right) < 0:
            self.root_tree = self.lRotate(root)
            return self.root_tree
        if b > 1 and x > root.left.y and self.getBal(root.left) < 0:
            root.left = self.lRotate(root.left)
            self.root_tree = self.rRotate(root)
            return self.root_tree
        if b < -1 and y < root.right.x and self.getBal(root.right) > 0:
            root.right = self.rRotate(root.right)
            self.root_tree = self.lRotate(root)
            return self.root_tree
        self.root_tree = root
        return root

    def sample(self, root: Node = "root") -> np.ndarray:
        """Randomly sample a continuous action from a uniform distribution over the allowed action space

        A distance is drawn uniformly from ``[0, size]`` and consumed by walking
        the intervals; the interval in which it runs out contains the sample.

        Args:
            root: :class:`Node` node of the action space, default is 'root'

        Returns:
            Sampled continuous action as a float32 array of shape ``(1,)``. If the
            tree is empty, a sample of the unrestricted base space is returned.
        """
        top_level = root == "root"
        if top_level:
            root = self.root_tree
        if root is None:
            # raise Exception("Empty Action Space") or
            return self.base_space.sample()
        if self.draw is None:
            # Clamp: float(self.size) may round up and random.uniform may return
            # its upper bound, which could leave the draw unconsumed.
            self.draw = min(
                Decimal(f"{random.uniform(0.0, float(self.size))}"), self.size
            )
        self.draw -= root.y - root.x
        if self.draw > 0:
            result = None
            if root.left is not None:
                result = self.sample(root.left)
            # `is None`, not truthiness: a sampled array([0.0]) is falsy.
            if result is None and root.right is not None:
                result = self.sample(root.right)
            if top_level and result is None:
                # Never leak a half-consumed draw into the next call.
                self.draw = None
            return result
        result = float(root.y + self.draw)
        self.draw = None
        return np.array([result], dtype=np.float32)

    def remove(self, x, y, root: Node = "root", adjust_size: bool = True):
        """Removes an interval from the action space

        Args:
            x: Lower bound of the interval
            y: Upper bound of the interval
            root: :class:`Node` to start the removal from or 'root' for removing over the whole tree, default is 'root'
            adjust_size: Whether the size attribute of the tree should be modified

        Returns:
            Updated root :class:`Node` of the action space
        """
        assert y > x, "Upper must be larger than lower bound"
        if root == "root":
            root = self.root_tree
        if root is None:
            return root
        x = Decimal(f"{x}")
        y = Decimal(f"{y}")
        if x > root.x and y < root.y:
            # Removal strictly inside this interval: split it in two.
            self.size -= root.y - x
            old_maximum = root.y
            root.y = x
            root = self.add(y, old_maximum, root)
        elif x == root.x and y < root.y:
            self.size -= y - x
            root.x = y
        elif x > root.x and y == root.y:
            self.size -= y - x
            root.y = x
        elif x < root.x < y < root.y:
            self.size -= y - root.x
            root.x = y
            root.left = self.remove(x, y, root.left, adjust_size)
        elif root.x < x < root.y < y:
            self.size -= root.y - x
            root.y = x
            root.right = self.remove(x, y, root.right, adjust_size)
        elif y <= root.x:
            root.left = self.remove(x, y, root.left, adjust_size)
        elif x >= root.y:
            root.right = self.remove(x, y, root.right, adjust_size)
        else:
            # The removed range covers this whole interval: delete the node.
            if adjust_size:
                self.size -= root.y - root.x
            if root.left is None:
                self.root_tree = self.remove(x, y, root.right, adjust_size)
                return self.root_tree
            elif root.right is None:
                self.root_tree = self.remove(x, y, root.left, adjust_size)
                return self.root_tree
            # Two children: replace by the in-order successor, then retry.
            rgt = self.smallest_interval(root.right)
            root.x = rgt.x
            root.y = rgt.y
            root.right = self.remove(rgt.x, rgt.y, root.right, adjust_size=False)
            root = self.remove(x, y, root, adjust_size)
        if root is None:
            return None
        root.height = 1 + max(self.getHeight(root.left), self.getHeight(root.right))
        b = self.getBal(root)
        if b > 1 and self.getBal(root.left) > 0:
            self.root_tree = self.rRotate(root)
            return self.root_tree
        if b < -1 and self.getBal(root.right) < 0:
            self.root_tree = self.lRotate(root)
            return self.root_tree
        if b > 1 and self.getBal(root.left) < 0:
            root.left = self.lRotate(root.left)
            self.root_tree = self.rRotate(root)
            return self.root_tree
        if b < -1 and self.getBal(root.right) > 0:
            root.right = self.rRotate(root.right)
            self.root_tree = self.lRotate(root)
            return self.root_tree
        self.root_tree = root
        return root

    def lRotate(self, z: Node):
        """Performs a left rotation. Switches roles of parent and child nodes.

        Args:
            z: Parent :class:`Node` for the rotation

        Returns:
            Updated parent :class:`Node`
        """
        y = z.right
        T2 = y.left
        y.left = z
        z.right = T2
        z.height = 1 + max(self.getHeight(z.left), self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))
        return y

    def rRotate(self, z: Node):
        """Performs a right rotation. Switches roles of parent and child nodes.

        Args:
            z: Parent :class:`Node` for the rotation

        Returns:
            Updated parent :class:`Node`
        """
        y = z.left
        T3 = y.right
        y.right = z
        z.left = T3
        z.height = 1 + max(self.getHeight(z.left), self.getHeight(z.right))
        y.height = 1 + max(self.getHeight(y.left), self.getHeight(y.right))
        return y

    def getHeight(self, root: Node = "root"):
        """Returns the height of a Node

        Args:
            root: :class:`Node` to return the height from or 'root' for the height of the whole tree, default is 'root'

        Returns:
            The height of the node in the tree (0 for None)
        """
        if root == "root":
            root = self.root_tree
        if not root:
            return 0
        return root.height

    def getBal(self, root: Node = "root"):
        """Calculate the balance factor

        Args:
            root: :class:`Node` to calculate the balance factor for or 'root' for the balance factor of the whole tree, default is 'root'

        Returns:
            The balance factor (height of left minus height of right subtree)
        """
        if root == "root":
            root = self.root_tree
        if not root:
            return 0
        return self.getHeight(root.left) - self.getHeight(root.right)

    def intervals(self):
        """Return all intervals of the allowed action space in an ordered way.

        Returns:
            List of tuples containing the ordered intervals. For example: [(0.1,0.5), (0.7,0.9)]
        """
        return self._intervals()

    def _intervals(self, root: Node = "root"):
        """Return all allowed intervals starting from a specific node in an ordered way.

        Args:
            root: :class:`Node` to start the search from or 'root' for searching the whole tree, default is 'root'

        Returns:
            List of tuples containing the ordered intervals. For example: [(0.1,0.5), (0.7,0.9)]
        """
        if root == "root":
            root = self.root_tree
        if root is None:
            return []
        ordered = []
        if root.left is not None:
            ordered = ordered + self._intervals(root.left)
        ordered.append((float(root.x), float(root.y)))
        if root.right is not None:
            ordered = ordered + self._intervals(root.right)
        return ordered

    def __str__(self):
        """String representation of the :class:`IntervalUnionRestriction`."""
        return f"{self.__class__.__name__}({self.intervals()})"

    def __repr__(self):
        """Representation of the :class:`IntervalUnionRestriction`."""
        return self.__str__()
class BucketSpaceRestriction(ContinuousRestriction):
    """Representation of a one-dimensional :class:`gymnasium.spaces.Box` restriction as a binary vector
    indicating the availability of equally sized buckets.

    Bucket ``i`` covers the half-open range ``[a + i * bucket_width, a + (i + 1) * bucket_width)``
    (the last bucket is cut off at ``b``).
    """

    def __init__(
        self,
        base_space: Box,
        bucket_width=1.0,
        epsilon=0.01,
        available_buckets: np.ndarray = None,
    ) -> None:
        """Constructor of :class:`BucketSpaceRestriction`.

        Args:
            base_space: :class:`gymnasium.spaces.Box` whose subsets can be represented by the restriction
            bucket_width: The width of each bucket
            epsilon: The radius in which buckets are set valid/invalid around a specific point
            available_buckets: The binary vector indicating the allowed subsets. It is copied,
                so restrictions never share (and mutate) one indicator array.
        """
        super().__init__(base_space)
        assert isinstance(self.base_space, Box)
        self.a = Decimal(f"{self.base_space.low.item()}")
        self.b = Decimal(f"{self.base_space.high.item()}")
        self.bucket_width = Decimal(f"{bucket_width}")
        self.epsilon = Decimal(f"{epsilon}")
        self.number_of_buckets = math.ceil((self.b - self.a) / self.bucket_width)
        # `is not None`: truth-testing a numpy array with more than one element
        # raises ValueError (this broke clone()), and [False] would be ignored.
        if available_buckets is not None:
            assert (
                len(available_buckets) == self.number_of_buckets
            ), "Not all available bucket indicators provided!"
            assert np.all(
                [index in [1.0, 0.0] for index in available_buckets]
            ), "No boolean bucket indicators!"
            self.buckets = np.array(available_buckets, dtype=bool)
        else:
            self.buckets = np.ones((self.number_of_buckets,), dtype=bool)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.

        Returns:
            `True`
        """
        return True

    def contains(self, x):
        """Check if a continuous action is allowed.

        Args:
            x: The continuous action (number, Decimal, or single-element array)

        Returns:
            `True` if the action is allowed. `False` otherwise
        """
        # Convert first: Decimal does not support arithmetic with float.
        x = self._to_decimal(x)
        if x < self.a or x >= self.b:
            return False
        return bool(self.buckets[self._bucket(x)])

    def sample(self, mask: None = None):
        """Randomly sample a continuous action from a uniform distribution over the allowed action space

        Args:
            mask: The mask used for sampling (currently no effect)

        Returns:
            Sampled continuous action as :class:`Decimal`, or None if nothing is allowed
        """
        intervals = self.intervals
        if not intervals:
            return None
        bounds = [(Decimal(f"{low}"), Decimal(f"{high}")) for low, high in intervals]
        # Draw over the total *allowed* length (not the full [a, b) range) so the
        # distribution is uniform over the allowed set.
        total = sum((high - low for low, high in bounds), Decimal(0))
        x = min(Decimal(f"{random.uniform(0.0, float(total))}"), total)
        for low, high in bounds:
            width = high - low
            if x > width:
                x -= width
            else:
                return low + x
        return bounds[-1][1]

    def clone(self):
        """Returns a copy of the :class:`BucketSpaceRestriction`

        Returns:
            :class:`BucketSpaceRestriction` copy with its own bucket array
        """
        assert isinstance(self.base_space, Box)
        return BucketSpaceRestriction(
            self.base_space,
            bucket_width=float(self.bucket_width),
            epsilon=float(self.epsilon),
            available_buckets=self.buckets,
        )

    def clone_and_remove(self, x):
        """Returns a copy of the :class:`BucketSpaceRestriction` without buckets containing a specific value

        Args:
            x: Buckets containing this value should be removed from the allowed action space

        Returns:
            :class:`BucketSpaceRestriction` copy (the original is left unchanged)
        """
        space = self.clone()
        space.remove(x)
        return space

    def remove(self, x, with_epsilon=True):
        """Remove buckets containing a specific value from the allowed action space

        Args:
            x: Buckets containing this value should be removed from the allowed action space
            with_epsilon: If `True`, a subset of epsilon around x is removed. Otherwise, only buckets containing
                the specific value are removed.
        """
        x = self._to_decimal(x)
        if with_epsilon:
            self._set(x, False)
        else:
            self._set_bucket(x, False)

    def add(self, x, with_epsilon=True):
        """Add buckets containing a specific value to the allowed action space

        Args:
            x: Buckets containing this value should be added to the allowed action space
            with_epsilon: If `True`, a subset of epsilon around x is added. Otherwise, only buckets containing
                the specific value are added.
        """
        x = self._to_decimal(x)
        if with_epsilon:
            self._set(x)
        else:
            self._set_bucket(x, True)

    @property
    def intervals(self):
        """Return all intervals of the allowed action space in an ordered way.

        Returns:
            List of tuples containing the ordered intervals. For example: [(0.1,0.5), (0.7,0.9)]
        """
        start, intervals = None, []
        for i in range(self.number_of_buckets):
            if self.buckets[i]:
                if start is None:
                    start = self.a + i * self.bucket_width
            elif start is not None:
                intervals.append((float(start), float(self.a + i * self.bucket_width)))
                start = None
        # Close a run that extends to the last bucket (including one that starts there).
        if start is not None:
            intervals.append((float(start), float(self.b)))
        return intervals

    @staticmethod
    def _to_decimal(x):
        """Convert a number, Decimal or single-element numpy array to :class:`Decimal`."""
        if isinstance(x, np.ndarray):
            x = x.item()
        return Decimal(f"{x}")

    def _bucket(self, x):
        """Return the bucket which contains a specific value

        Args:
            x: Value for which the bucket has to be found

        Returns:
            Indicator of the bucket (may be out of range for values outside [a, b))
        """
        return math.floor((x - self.a) / self.bucket_width)

    def _set_bucket(self, x, value):
        """Set the indicator of the single bucket containing x; values outside [a, b) are ignored."""
        index = self._bucket(x)
        if 0 <= index < self.number_of_buckets:
            self.buckets[index] = value

    def _set(self, x, value=True):
        """Set the indicator value for all buckets within epsilon of a specific value.

        Args:
            x: Buckets intersecting ``[x - epsilon, x + epsilon]`` are modified
            value: If `True`, those buckets belong to the allowed action space. Otherwise, they are unavailable.
        """
        # Clamp to valid indices; otherwise a negative upper index (x far below a)
        # would wrap around and modify buckets at the end of the array.
        lower_bucket = max(self._bucket(x - self.epsilon), 0)
        upper_bucket = min(self._bucket(x + self.epsilon), self.number_of_buckets - 1)
        if lower_bucket <= upper_bucket:
            self.buckets[lower_bucket : upper_bucket + 1] = value

    def reset(self):
        """Resets the action space to the unrestricted state"""
        self.buckets = np.ones((self.number_of_buckets,), dtype=bool)

    def __str__(self):
        """String representation of the :class:`BucketSpaceRestriction`."""
        intervals = (
            " ".join(f"[{float(a)}, {float(b)})" for a, b in self.intervals)
            if self.intervals
            else "()"
        )
        return f"<BucketSpace {intervals}>"

    def __repr__(self):
        """Representation of the :class:`BucketSpaceRestriction`."""
        return self.__str__()

    def __bool__(self):
        """`True` if at least one bucket is allowed."""
        return bool(np.any(self.buckets))

    def __contains__(self, item):
        return self.contains(item)

    def __hash__(self):
        return hash((self.a, self.b, self.bucket_width, tuple(self.intervals)))

    def __eq__(self, other):
        if not isinstance(other, BucketSpaceRestriction):
            return NotImplemented
        return (self.a, self.b, self.bucket_width, tuple(self.intervals)) == (
            other.a,
            other.b,
            other.bucket_width,
            tuple(other.intervals),
        )
class PredicateRestriction(Restriction):
    """Representation of an arbitrary space as the set of elements for which a predicate is True."""

    def __init__(
        self,
        base_space: gym.Space,
        *,
        predicate: Optional[Callable[[Any], bool]] = None,
        seed: int | np.random.Generator | None = None,
    ):
        """Constructor of :class:`PredicateRestriction`.

        Args:
            base_space: :class:`gymnasium.spaces.Space` whose subsets can be represented by the restriction
            predicate: Callable deciding whether an element of ``base_space`` is allowed.
                Defaults to a predicate that allows every element.
            seed: Random seed for sampling. Defaults to None
        """
        super().__init__(base_space, seed=seed)
        self.predicate = predicate if predicate is not None else (lambda x: True)

    @property
    def is_np_flattenable(self) -> bool:
        """Checks whether this space can be flattened to a :class:`gymnasium.spaces.Box`.

        Returns:
            `False`
        """
        return False

    def sample(self, mask: Any | None = None) -> Any:
        """Sampling is not supported for predicate restrictions.

        Args:
            mask: The mask used for sampling (currently no effect)

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def contains(self, x: Any) -> bool:
        """Check if an action belongs to the base space and satisfies the predicate.

        Args:
            x: The action

        Returns:
            `True` if the action is allowed. `False` otherwise
        """
        # bool(): the predicate may return any truthy/falsy value.
        return bool(self.base_space.contains(x) and self.predicate(x))