Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add agent communication wrapper. #1881

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2d03833
Add agent communication wrapper.
Gamenot Feb 22, 2023
05e1f87
Add docstrings.
Gamenot Feb 22, 2023
9f8d428
Add header.
Gamenot Feb 22, 2023
eacbeaa
Update configuration.
Gamenot Feb 22, 2023
7efa724
Mock up interface for message
Gamenot Feb 22, 2023
b2a4c8d
Add custom block to interface option.
Gamenot Feb 22, 2023
d2081ba
Bring wrapper close to completion.
Gamenot Feb 23, 2023
96a878e
Remove agent interface todo.
Gamenot Feb 23, 2023
b7dc02e
Fix docstring test.
Gamenot Feb 23, 2023
e9a22c0
Add agent communcation example.
Gamenot Feb 23, 2023
cde2091
Add agent communication wrapper.
Gamenot Feb 22, 2023
e021a06
Add docstrings.
Gamenot Feb 22, 2023
1e78270
Add header.
Gamenot Feb 22, 2023
356a922
Update configuration.
Gamenot Feb 22, 2023
4decba0
Mock up interface for message
Gamenot Feb 22, 2023
8b0ed19
Add custom block to interface option.
Gamenot Feb 22, 2023
8960f52
Bring wrapper close to completion.
Gamenot Feb 23, 2023
5268ebd
Remove agent interface todo.
Gamenot Feb 23, 2023
af425d0
Fix docstring test.
Gamenot Feb 23, 2023
f6355a6
Add agent communcation example.
Gamenot Feb 23, 2023
a2f3a1d
Merge branch 'tucker/add_agent_communication' of https://github.com/h…
Gamenot Mar 2, 2023
842f60e
Improve action space.
Gamenot Mar 2, 2023
f98350f
Add vehicle targetting communication wrapper.
Gamenot Mar 3, 2023
931800f
make gen-header
Gamenot Mar 10, 2023
b519785
Fix docstring test
Gamenot Mar 10, 2023
b1a03ef
Fix type test.
Gamenot Mar 10, 2023
85c7582
Fix remaining typing issues.
Gamenot Mar 10, 2023
2cef72e
Remove unused import.
Gamenot Mar 10, 2023
ca0b51e
Merge branch 'master' into tucker/add_agent_communication
Gamenot Mar 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions examples/control/agent_communcation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import sys
from pathlib import Path
from typing import Any, Dict, Generator, List, Tuple, Union

from smarts.core.agent import Agent
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.utils.episodes import episodes
from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1
from smarts.env.gymnasium.wrappers.agent_communication import (
Bands,
Header,
Message,
MessagePasser,
V2XReceiver,
V2XTransmitter,
)
from smarts.env.utils.action_conversion import ActionOptions
from smarts.env.utils.observation_conversion import ObservationOptions
from smarts.sstudio.scenario_construction import build_scenarios

sys.path.insert(0, str(Path(__file__).parents[2].absolute()))
import gymnasium as gym

from examples.tools.argument_parser import default_argument_parser

TIMESTEP = 0.1
BYTES_IN_MEGABIT = 125000
MESSAGE_MEGABITS_PER_SECOND = 10
MESSAGE_BYTES = int(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND / TIMESTEP)


def filter_useless(
transmissions: List[Tuple[Header, Message]]
) -> Generator[Tuple[Header, Message], None, None]:
"""A primitive example filter that takes in transmissions and outputs filtered transmissions."""
for header, msg in transmissions:
if header.sender in ("parked_agent", "broken_stoplight"):
continue
if header.sender_type in ("advertisement",):
continue
yield header, msg


class LaneFollowerAgent(Agent):
def act(self, obs: Dict[Any, Union[Any, Dict]]):
return (obs["waypoint_paths"]["speed_limit"][0][0], 0)


class GossiperAgent(Agent):
def __init__(self, id_: str, base_agent: Agent, filter_, friends):
self._filter = filter_
self._id = id_
self._friends = friends
self._base_agent = base_agent

def act(self, obs, **configs):
out_transmissions = []
for header, msg in self._filter(obs["transmissions"]):
header: Header = header
msg: Message = msg
if not {self._id, "__all__"}.intersection(header.cc | header.bcc):
continue
if header.channel == "position_request":
print()
print("On step: ", obs["steps_completed"])
print("Gossiper received position request: ", header)
out_transmissions.append(
(
Header(
channel="position",
sender=self._id,
sender_type="ad_vehicle",
cc={header.sender},
bcc={*self._friends},
format="position",
), # optimize this later
Message(
content=obs["ego_vehicle_state"]["position"],
), # optimize this later
)
)
print("Gossiper sent position: ", out_transmissions[0][1])

base_action = self._base_agent.act(obs)
return (base_action, out_transmissions)


class SchemerAgent(Agent):
def __init__(self, id_: str, base_agent: Agent, request_freq) -> None:
self._base_agent = base_agent
self._id = id_
self._request_freq = request_freq

def act(self, obs, **configs):
out_transmissions = []
for header, msg in obs["transmissions"]:
header: Header = header
msg: Message = msg
if header.channel == "position":
print()
print("On step: ", obs["steps_completed"])
print("Schemer received position: ", msg)

if obs["steps_completed"] % self._request_freq == 0:
print()
print("On step: ", obs["steps_completed"])
out_transmissions.append(
(
Header(
channel="position_request",
sender=self._id,
sender_type="ad_vehicle",
cc=set(),
bcc={"__all__"},
format="position_request",
),
Message(content=None),
)
)
print("Schemer requested position with: ", out_transmissions[0][0])

base_action = self._base_agent.act(obs)
return (base_action, out_transmissions)


def main(scenarios, headless, num_episodes, max_episode_steps=None):
agent_interface = AgentInterface.from_type(
AgentType.LanerWithSpeed, max_episode_steps=max_episode_steps
)
hiwayv1env = HiWayEnvV1(
scenarios=scenarios,
agent_interfaces={"gossiper0": agent_interface, "schemer": agent_interface},
headless=headless,
observation_options=ObservationOptions.multi_agent,
action_options=ActionOptions.default,
)
# for now
env = MessagePasser(
hiwayv1env,
max_message_bytes=MESSAGE_BYTES,
message_config={
"gossiper0": (
V2XTransmitter(
bands=Bands.ALL,
range=100,
# available_channels=["position_request", "position"]
),
V2XReceiver(
bands=Bands.ALL,
aliases={"tim"},
blacklist_channels={"self_control"},
),
),
"schemer": (
V2XTransmitter(
bands=Bands.ALL,
range=100,
),
V2XReceiver(
bands=Bands.ALL,
aliases=set(),
),
),
},
)
agents = {
"gossiper0": GossiperAgent(
"gossiper0",
base_agent=LaneFollowerAgent(),
filter_=filter_useless,
friends={"schemer"},
),
"schemer": SchemerAgent(
"schemer", base_agent=LaneFollowerAgent(), request_freq=100
),
}

# then just the standard gym interface with no modifications
for episode in episodes(n=num_episodes):
observation, info = env.reset()
episode.record_scenario(env.scenario_log)

terminated = {"__all__": False}
while not terminated["__all__"]:
agent_action = {
agent_id: agents[agent_id].act(obs)
for agent_id, obs in observation.items()
}
observation, reward, terminated, truncated, info = env.step(agent_action)
episode.record_step(observation, reward, terminated, info)

env.close()


if __name__ == "__main__":
parser = default_argument_parser("single-agent-example")
args = parser.parse_args()

if not args.scenarios:
args.scenarios = [
str(Path(__file__).absolute().parents[2] / "scenarios" / "sumo" / "loop")
]

build_scenarios(scenarios=args.scenarios)

main(
scenarios=args.scenarios,
headless=args.headless,
num_episodes=args.episodes,
)
2 changes: 1 addition & 1 deletion smarts/core/smarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def providers(self) -> List[Provider]:
"""The current providers controlling actors within the simulation."""
return self._providers

def get_provider_by_type(self, requested_type) -> Optional[Provider]:
def get_provider_by_type(self, requested_type: type) -> Optional[Provider]:
"""Get The first provider that matches the requested type."""
self._check_valid()
for provider in self._providers:
Expand Down
2 changes: 1 addition & 1 deletion smarts/core/vehicle_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def vehicleitems(self) -> Iterator[Tuple[str, Vehicle]]:
return map(lambda x: (self._2id_to_id[x[0]], x[1]), self._vehicles.items())

@cache
def vehicle_by_id(self, vehicle_id, default=...):
def vehicle_by_id(self, vehicle_id, default=...) -> Vehicle:
"""Get a vehicle by its id."""
vehicle_id = _2id(vehicle_id)
if default is ...:
Expand Down
Loading