# Typing
from typing import Union, Callable, Optional
# Standard modules
import functools
# External modules
from gymnasium.spaces import Dict
from pettingzoo import AECEnv
from pettingzoo.utils import BaseWrapper
# Internal modules
from drama.restrictions import Restriction
from drama.restrictors import Restrictor
from drama.utils import flatten, RestrictionViolationException
# If no functions are provided for some or all restrictors, use these defaults
def _default_restrictor_reward_fn(env, rewards):
    """Compute the default restrictor reward: the social welfare.

    The social welfare is the sum of all agent rewards.

    Args:
        env: The environment after the agent step was taken (unused here)
        rewards: Dictionary holding the reward of each agent

    Returns:
        The sum over all values of ``rewards``
    """
    agent_rewards = rewards.values()
    return sum(agent_rewards)
def _default_preprocess_restrictor_observation_fn(env):
    """Derive the default restrictor observation from the environment.

    Args:
        env: The environment at the point in time the observation is requested

    Returns:
        Whatever ``env.state()`` returns
    """
    restrictor_observation = env.state()
    return restrictor_observation
def _default_postprocess_restriction_fn(restriction):
    """Pass the restriction through unchanged (default post-processing).

    Args:
        restriction: The restriction derived from the restrictor

    Returns:
        The very same ``restriction`` object
    """
    return restriction
def _default_restriction_violation_fn(env, action, restriction: Restriction):
    """Handle a restriction violation by raising (default behaviour).

    None of the arguments are inspected; the function raises unconditionally.

    Args:
        env: The environment after the agent step was taken
        action: The action which violated the restriction
        restriction: The restriction object corresponding to the action

    Raises:
        RestrictionViolationException: Always
    """
    violation = RestrictionViolationException()
    raise violation
class RestrictionWrapper(BaseWrapper):
    """Wrapper that implements the agent-restrictor-environment loop of DRAMA:

    Reset() -> Restrictor of Agent_0 -> Step() -> Agent_0 -> Step()
    -> Restrictor of Agent_1 -> Step() -> Agent_1 -> ...
    """

    def __init__(
        self,
        env: AECEnv,
        restrictors: Union[dict, Restrictor],
        *,
        agent_restrictor_mapping: Optional[dict] = None,
        restrictor_reward_fns: Optional[Union[dict, Callable]] = None,
        preprocess_restrictor_observation_fns: Optional[Union[dict, Callable]] = None,
        postprocess_restriction_fns: Optional[Union[dict, Callable]] = None,
        restriction_violation_fns: Optional[Union[dict, Callable]] = None,
        restriction_key: str = "restriction",
        observation_key: str = "observation",
        return_object: bool = False,
        **kwargs,
    ):
        """Constructor of :class:`RestrictionWrapper`.

        Args:
            env: The environment to apply the wrapper
            restrictors: The restrictors to apply before each agent's step.
                :class:`Dictionary` mapping IDs to restrictors or
                :class:`Restrictor` with default ID `restrictor_0`
            agent_restrictor_mapping: The assignment of restrictors to agents.
                :class:`Dictionary` mapping agent to restrictor IDs.
                Required if ``restrictors`` is a dictionary; ignored otherwise,
                in which case the single restrictor is assigned to all agents
            restrictor_reward_fns: The reward function for each restrictor.
                :class:`Dictionary` mapping restrictor IDs to reward functions
                or :class:`Callable` applied to all restrictor rewards.
                By default, the social welfare is used
            preprocess_restrictor_observation_fns: The pre-processing function
                for each restrictor observation.
                :class:`Dictionary` mapping restrictor IDs to pre-processing
                functions or :class:`Callable` applied to derive all restrictor
                observations likewise. Restrictors without an explicit function
                use their own ``preprocess_observation`` attribute if they have
                one; otherwise the state of the environment is returned
            postprocess_restriction_fns: The post-processing function for each
                restriction.
                :class:`Dictionary` mapping restrictor IDs to post-processing
                functions or :class:`Callable` applied to post-process all
                restrictions likewise. Restrictors without an explicit function
                use their own ``postprocess_restriction`` attribute if they
                have one; otherwise the unmodified restrictions apply
            restriction_violation_fns: The callback to handle restriction
                violations.
                :class:`Dictionary` mapping agent IDs to violation functions
                or :class:`Callable` applied to all restriction violations.
                By default, a :class:`RestrictionViolationException` is raised
            restriction_key: Key for the restriction in the agent observation
            observation_key: Key for the original observation in the agent
                observation
            return_object: Default for the ``return_object`` argument of
                :meth:`observe`
            **kwargs: Additional arguments for the flatten operation
        """
        super().__init__(env)

        if isinstance(restrictors, dict):
            assert agent_restrictor_mapping, "Agent-restrictor mapping required!"
            self.restrictors = restrictors
            self.agent_restrictor_mapping = agent_restrictor_mapping
        else:
            # Naming convention from PettingZoo
            self.restrictors = {"restrictor_0": restrictors}
            self.agent_restrictor_mapping = {
                agent: "restrictor_0" for agent in self.env.possible_agents
            }

        self.restrictor_reward_fns = self._build_fn_mapping(
            restrictor_reward_fns,
            self.restrictors,
            _default_restrictor_reward_fn,
        )

        # Pre-/post-processing: an explicitly passed function wins, then the
        # restrictor's own method (if it has one), then the module default.
        self.preprocess_restrictor_observation_fns = self._build_fn_mapping(
            preprocess_restrictor_observation_fns,
            self.restrictors,
            _default_preprocess_restrictor_observation_fn,
            fallbacks={
                name: restrictor.preprocess_observation
                for name, restrictor in self.restrictors.items()
                if hasattr(restrictor, "preprocess_observation")
            },
        )
        self.postprocess_restriction_fns = self._build_fn_mapping(
            postprocess_restriction_fns,
            self.restrictors,
            _default_postprocess_restriction_fn,
            fallbacks={
                name: restrictor.postprocess_restriction
                for name, restrictor in self.restrictors.items()
                if hasattr(restrictor, "postprocess_restriction")
            },
        )

        # Violation handlers are keyed by agent, not by restrictor
        self.restriction_violation_fns = self._build_fn_mapping(
            restriction_violation_fns,
            self.env.possible_agents,
            _default_restriction_violation_fn,
        )

        self.restriction_key = restriction_key
        self.observation_key = observation_key
        self.return_object = return_object
        self.kwargs = dict(kwargs)

        # self.restrictions is a dictionary which keeps the latest value for
        # each agent; it is populated by reset()
        self.restrictions = None

        self.possible_agents = self.possible_agents + list(self.restrictors)

        # Check if restrictor action spaces (after post-processing) match
        # agent action spaces
        for agent in self.env.possible_agents:
            restrictor_id = self.agent_restrictor_mapping[agent]
            restrictor = self.restrictors[restrictor_id]
            sample_restriction = self.postprocess_restriction_fns[restrictor_id](
                restrictor.action_space.sample()
            )
            assert sample_restriction.base_space == env.action_space(
                agent
            ), f"The action spaces of {self.agent_restrictor_mapping[agent]} and {agent} are not compatible!"

    @staticmethod
    def _build_fn_mapping(fns, keys, default_fn, fallbacks=None):
        """Expand a user-supplied function argument into a ``{key: fn}`` dict.

        Args:
            fns: ``None``, a callable used for every key, or a dictionary
                holding functions for some or all keys
            keys: Iterable of the keys the result must contain
            default_fn: Function used when nothing else is available for a key
            fallbacks: Optional dictionary of per-key functions that rank
                below an entry of ``fns`` but above ``default_fn``

        Returns:
            Dictionary mapping every key to exactly one function
        """
        if callable(fns):
            return {key: fns for key in keys}

        provided = fns if isinstance(fns, dict) else {}
        fallbacks = fallbacks or {}
        return {
            key: provided.get(key) or fallbacks.get(key) or default_fn
            for key in keys
        }

    def _agents_and_active_restrictors(self):
        """Return the env's current agents followed by their restrictors.

        Each restrictor appears once, in order of first appearance, so that the
        resulting list is deterministic.
        """
        restrictor_ids = dict.fromkeys(
            self.agent_restrictor_mapping[agent] for agent in self.env.agents
        )
        return self.env.agents + list(restrictor_ids)

    @functools.lru_cache(maxsize=None)
    def observation_space(self, agent):
        """Takes in agent or restrictor and returns the observation space for that agent or restrictor.

        For an agent, the space is a :class:`Dict` holding the original
        observation space under ``observation_key`` and the action space of the
        assigned restrictor under ``restriction_key``.
        """
        if agent in self.restrictors:
            return self.restrictors[agent].observation_space

        assigned_restrictor = self.restrictors[self.agent_restrictor_mapping[agent]]
        return Dict(
            {
                self.observation_key: self.env.observation_space(agent),
                self.restriction_key: assigned_restrictor.action_space,
            }
        )

    @functools.lru_cache(maxsize=None)
    def action_space(self, agent):
        """Takes in agent or restrictor and returns the action space for that agent or restrictor."""
        if agent in self.restrictors:
            return self.restrictors[agent].action_space
        return self.env.action_space(agent)

    def reset(self, seed=None, options=None):
        """Resets the agent-restrictor-environment loop to a starting state.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``
            options: Forwarded to the wrapped environment's ``reset``
        """
        # Pass by keyword so the values cannot land in the wrong parameter of
        # the wrapped environment's reset()
        self.env.reset(seed=seed, options=options)

        # Set properties: copy the agents' entries and add one per restrictor
        self.rewards = {
            **self.env.rewards,
            **{restrictor: 0.0 for restrictor in self.restrictors},
        }
        self.terminations = {
            **self.env.terminations,
            **{restrictor: False for restrictor in self.restrictors},
        }
        self.truncations = {
            **self.env.truncations,
            **{restrictor: False for restrictor in self.restrictors},
        }
        self.infos = {
            **self.env.infos,
            **{restrictor: {} for restrictor in self.restrictors},
        }
        self.agents = self._agents_and_active_restrictors()
        self._cumulative_rewards = {
            **self.env._cumulative_rewards,
            **{restrictor: 0.0 for restrictor in self.restrictors},
        }

        self.restrictions = {agent: None for agent in self.env.agents}

        # Start an episode with the restrictor of the first agent to obtain a
        # restriction
        self.agent_selection = self.agent_restrictor_mapping[self.env.agent_selection]

    def step(self, action):
        """Accepts and executes the action or restriction of the current agent_selection in the environment.

        Automatically switches control between the agents and restrictors.

        Args:
            action: A restriction if ``agent_selection`` is a restrictor,
                otherwise the agent's action (``None`` skips the violation
                check and is forwarded to the wrapped environment)
        """
        if self.agent_selection in self.restrictors:
            # If the action was taken by the restrictor, check if it was
            # terminated last step
            if self.terminations[self.agent_selection]:
                self._was_dead_step(action)
                self.agent_selection = self.env.agent_selection
                return

            # Otherwise set the restriction that applies to the next agent
            assert (
                self.agent_restrictor_mapping[self.env.agent_selection]
                == self.agent_selection
            )
            self.restrictions[
                self.env.agent_selection
            ] = self.postprocess_restriction_fns[self.agent_selection](action)

            # Switch to the next agent of the original environment
            self.agent_selection = self.env.agent_selection
            return

        # The action was taken by an agent. Compare against None rather than
        # relying on truthiness: falsy actions such as 0 must be checked too,
        # and array-valued actions have no unambiguous truth value.
        if action is not None and not self.restrictions[
            self.agent_selection
        ].contains(action):
            self.restriction_violation_fns[self.agent_selection](
                self.env, action, self.restrictions[self.agent_selection]
            )
        else:
            # Execute the action in the original environment
            self.env.step(action)

        # Update properties
        self.agents = self._agents_and_active_restrictors()

        # True if every agent still listed by the env is terminated or
        # truncated (also True when no agents are left)
        all_agents_done = all(
            self.env.terminations[agent] or self.env.truncations[agent]
            for agent in self.env.agents
        )

        self.rewards = {
            **self.env.rewards,
            **{
                restrictor: self.restrictor_reward_fns[restrictor](
                    self.env, self.env.rewards
                )
                for restrictor in self.restrictors
            },
        }
        self.terminations = {
            **self.env.terminations,
            **{restrictor: all_agents_done for restrictor in self.restrictors},
        }
        self.truncations = {
            **self.env.truncations,
            **{restrictor: False for restrictor in self.restrictors},
        }
        self.infos = {
            **self.env.infos,
            **{restrictor: {} for restrictor in self.restrictors},
        }
        self._cumulative_rewards = {
            **self.env._cumulative_rewards,
            **{
                restrictor: self.restrictor_reward_fns[restrictor](
                    self.env, self.env._cumulative_rewards
                )
                for restrictor in self.restrictors
            },
        }

        if self.env.agents and all_agents_done:
            # Every remaining agent is terminated or truncated: hand control
            # directly to the next agent without asking for a new restriction
            self.agent_selection = self.env.agent_selection
        else:
            # Otherwise, the restrictor of the next agent acts first
            self.agent_selection = self.agent_restrictor_mapping[
                self.env.agent_selection
            ]

    def observe(self, agent: str, return_object: bool = None, **kwargs):
        """Returns the observation an agent or restrictor currently can make.

        `last()` calls this function.

        Args:
            agent: ID of the agent or restrictor to observe for
            return_object: Overrides the ``return_object`` value given to the
                constructor when not ``None``
            **kwargs: Additional arguments for the flatten operation; they take
                precedence over the ones given to the constructor

        Returns:
            For a restrictor, the result of its observation pre-processing
            function. For an agent, a dictionary with the original observation
            under ``observation_key`` and the restriction under
            ``restriction_key``
        """
        if agent in self.restrictors:
            return self.preprocess_restrictor_observation_fns[agent](self.env)

        if return_object is None:
            return_object = self.return_object

        agent_observation = super().observe(agent)
        restriction = self.restrictions[agent]

        # NOTE(review): the object is only returned as-is when it is requested
        # AND flattenable; in every other case it is passed to flatten().
        # Confirm this is intended for non-flattenable restrictions.
        if return_object and restriction.is_np_flattenable:
            restriction_observation = restriction
        else:
            restriction_observation = flatten(
                self.restrictors[self.agent_restrictor_mapping[agent]].action_space,
                restriction,
                **{**self.kwargs, **kwargs},
            )

        return {
            self.observation_key: agent_observation,
            self.restriction_key: restriction_observation,
        }