diff --git a/metadrive/component/lane/scenario_lane.py b/metadrive/component/lane/scenario_lane.py index ec4c1c916..66c169afb 100644 --- a/metadrive/component/lane/scenario_lane.py +++ b/metadrive/component/lane/scenario_lane.py @@ -52,6 +52,7 @@ def __init__(self, lane_id: int, map_data: dict, need_lane_localization): self.exit_lanes = map_data[lane_id].get(ScenarioDescription.EXIT, None) self.left_lanes = map_data[lane_id].get(ScenarioDescription.LEFT_NEIGHBORS, None) self.right_lanes = map_data[lane_id].get(ScenarioDescription.RIGHT_NEIGHBORS, None) + self.turns = map_data[lane_id].get(ScenarioDescription.TURNS, None) @staticmethod def try_get_polygon(map_data, lane_id): @@ -133,6 +134,7 @@ def destroy(self): self.exit_lanes = None self.left_lanes = None self.right_lanes = None + self.turns = None super(ScenarioLane, self).destroy() diff --git a/metadrive/component/navigation_module/edge_network_navigation.py b/metadrive/component/navigation_module/edge_network_navigation.py index 80bd992a8..fb9c15c0b 100644 --- a/metadrive/component/navigation_module/edge_network_navigation.py +++ b/metadrive/component/navigation_module/edge_network_navigation.py @@ -49,7 +49,6 @@ def reset(self, vehicle): lane, new_l_index = possible_lanes[0][:-1] dest = vehicle.config["destination"] - current_lane = lane destination = dest if dest is not None else None assert current_lane is not None, "spawn place is not on road!" @@ -71,6 +70,9 @@ def set_route(self, current_lane_index: str, destination: str): # self.checkpoints.append(current_lane_index) self._target_checkpoints_index = [0, 1] # update routing info + if len(self.checkpoints) == 0: + self.checkpoints.append(current_lane_index) + self.checkpoints.append(current_lane_index) assert len(self.checkpoints) > 0, "Can not find a route from {} to {}".format(current_lane_index, destination) self.final_lane = self.map.road_network.get_lane(self.checkpoints[-1]) self._navi_info.fill(0.0) @@ -243,6 +245,6 @@ def _update_current_lane(self, ego_vehicle): if self.FORCE_CALCULATE: lane_index, _ = self.map.road_network.get_closest_lane_index(ego_vehicle.position) lane = self.map.road_network.get_lane(lane_index) - self.current_lane = lane + self._current_lane = lane assert lane_index == lane.index, "lane index mismatch!" return lane, lane_index diff --git a/metadrive/component/road_network/edge_road_network.py b/metadrive/component/road_network/edge_road_network.py index e161e1a2d..22aaf3751 100644 --- a/metadrive/component/road_network/edge_road_network.py +++ b/metadrive/component/road_network/edge_road_network.py @@ -8,7 +8,7 @@ from metadrive.utils.math import get_boxes_bounding_box from metadrive.utils.pg.utils import get_lanes_bounding_box -lane_info = namedtuple("edge_lane", ["lane", "entry_lanes", "exit_lanes", "left_lanes", "right_lanes"]) +lane_info = namedtuple("edge_lane", ["lane", "entry_lanes", "exit_lanes", "left_lanes", "right_lanes", "turns"]) class EdgeRoadNetwork(BaseRoadNetwork): @@ -27,8 +27,36 @@ def add_lane(self, lane) -> None: entry_lanes=lane.entry_lanes or [], exit_lanes=lane.exit_lanes or [], left_lanes=lane.left_lanes or [], - right_lanes=lane.right_lanes or [] + right_lanes=lane.right_lanes or [], + turns=lane.turns or [] ) + + def find_rightmost_lane_by_road_id(self, original_road_id): + target = str(original_road_id) + candidates = [] + + for lane_key in self.graph.keys(): + if not isinstance(lane_key, str) or not lane_key.startswith("lane_"): + continue + + parts = lane_key.split("_") + if len(parts) < 3: + continue + + try: + edge_id = parts[1] + lane_index = int(parts[2]) + except (ValueError, IndexError): + continue + + if edge_id == target or edge_id == f"-{target}": + candidates.append((lane_key, lane_index)) + + if not candidates: + return None + + rightmost = max(candidates, key=lambda x: x[1]) + return rightmost[0] def get_lane(self, index: LaneIndex): return self.graph[index].lane @@ -68,6 +96,7 @@ def bfs_paths(self, start: str, goal: str) -> List[List[str]]: :param goal: goal edge :return: list of paths from start to goal. """ + lanes = self.graph[start].left_lanes + self.graph[start].right_lanes + [start] queue = [(lane, [lane]) for lane in lanes] @@ -81,6 +110,8 @@ def bfs_paths(self, start: str, goal: str) -> List[List[str]]: if _next in path: # circle continue + if goal is None and len(path) > 3: + yield path + [_next] if _next == goal: yield path + [_next] elif _next in self.graph: @@ -90,9 +121,9 @@ def get_peer_lanes_from_index(self, lane_index): info: lane_info = self.graph[lane_index] ret = [self.graph[lane_index].lane] for left_n in info.left_lanes: - ret.append(self.graph[left_n["id"]].lane) + ret.append(self.graph[left_n].lane) for right_n in info.right_lanes: - ret.append(self.graph[right_n["id"]].lane) + ret.append(self.graph[right_n].lane) return ret def destroy(self): @@ -124,6 +155,7 @@ def get_map_features(self, interval=2): SD.EXIT: lane_info.exit_lanes, SD.LEFT_NEIGHBORS: lane_info.left_lanes, SD.RIGHT_NEIGHBORS: lane_info.right_lanes, + SD.TURNS: lane_info.turns, "speed_limit_kmh": lane_info.lane.speed_limit } return ret diff --git a/metadrive/component/vehicle/base_vehicle.py b/metadrive/component/vehicle/base_vehicle.py index ca06d6a95..a22cd4fef 100644 --- a/metadrive/component/vehicle/base_vehicle.py +++ b/metadrive/component/vehicle/base_vehicle.py @@ -28,6 +28,7 @@ from metadrive.utils.math import wrap_to_pi from metadrive.utils.pg.utils import rect_region_detection from metadrive.utils.utils import get_object_from_node +from metadrive.component.road_network.edge_road_network import EdgeRoadNetwork logger = get_logger() @@ -341,7 +342,11 @@ def reset( position = [0, 0] heading = 0 else: - lane = map.road_network.get_lane(self.config["spawn_lane_index"]) + if map.road_network_type == EdgeRoadNetwork: + lane_key = map.road_network.find_rightmost_lane_by_road_id(self.config["spawn_lane_index"]) + else: + lane_key = self.config["spawn_lane_index"] + lane = map.road_network.get_lane(lane_key) position = lane.position(self.config["spawn_longitude"], self.config["spawn_lateral"]) heading = lane.heading_theta_at(self.config["spawn_longitude"]) else: @@ -361,7 +366,6 @@ def reset( self.set_position(position[:2], height=position[-1]) else: raise ValueError() - self.reset_navigation() self.body.clearForces() self.body.setLinearVelocity(Vec3(0, 0, 0)) diff --git a/metadrive/engine/top_down_renderer.py b/metadrive/engine/top_down_renderer.py index 15380cc15..8fb6b424c 100644 --- a/metadrive/engine/top_down_renderer.py +++ b/metadrive/engine/top_down_renderer.py @@ -20,6 +20,65 @@ color_white = (255, 255, 255) +import math + + +def draw_turn_sign(surface, start_pos, directions, color=(255, 255, 255), bg_color=(0, 0, 255, 128), sign_size=18, first=True): + + if not directions: + return + if first: + half = sign_size // 2 + rect = pygame.Rect(start_pos[0] - half, start_pos[1] - half, sign_size + 4, sign_size + 4) + + s = pygame.Surface((sign_size, sign_size), pygame.SRCALPHA) + s.fill(bg_color) + surface.blit(s, (start_pos[0] - half, start_pos[1] - half)) + + center_local = (0, 0) + arrow_length = sign_size * 0.5 + arrow_size = sign_size * 0.15 + + local_angle_map = { + 's': -math.pi / 2, + 'r': 0, + 'l': math.pi, + 't': math.pi / 2, + } + + if set(directions) == {'s', 'r'}: + offsets = {'s': (-3, 0), 'r': (3, 0)} + elif set(directions) == {'s', 'l'}: + offsets = {'s': (3, 0), 'l': (-3, 0)} + elif set(directions) == {'l', 's', 'r'}: + offsets = {'l': (-6, 0), 's': (0, 0), 'r': (6, 0)} + else: + offsets = {d: (0, 0) for d in directions} + + for d in directions: + if d not in local_angle_map: + continue + angle = local_angle_map[d] + dx = arrow_length * math.cos(angle) + dy = arrow_length * math.sin(angle) + + ox, oy = offsets.get(d, (0, 0)) + sp = (start_pos[0] + ox, start_pos[1] + oy) + ep = (sp[0] + dx, sp[1] + dy) + + pygame.draw.line(surface, color, sp, ep, 4) + + arrow_tip_angle = math.pi / 6 + left_tip = ( + ep[0] + arrow_size * math.cos(angle + math.pi - arrow_tip_angle), + ep[1] + arrow_size * math.sin(angle + math.pi - arrow_tip_angle) + ) + right_tip = ( + ep[0] + arrow_size * math.cos(angle + math.pi + arrow_tip_angle), + ep[1] + arrow_size * math.sin(angle + math.pi + arrow_tip_angle) + ) + pygame.draw.line(surface, color, ep, left_tip, 4) + pygame.draw.line(surface, color, ep, right_tip, 4) def draw_top_down_map_native( map, @@ -244,6 +303,19 @@ def __init__( # LQY: do not delete the above line !!!!! # Setup some useful flags + + self.sign_icon_raw = {} + self.sign_icon_surfaces = {} + if hasattr(self.engine, 'traffic_sign_manager'): + sign_mgr = self.engine.traffic_sign_manager + for sign in sign_mgr.signs: + sign_type = type(sign).__name__ + icon_path = sign.icon_path + try: + self.sign_icon_raw[sign_type] = pygame.image.load(icon_path) + except Exception as e: + print(f"Failed to load icon for {sign_type}: {e}") + self.logger = get_logger() if num_stack < 1: self.logger.warning("num_stack should be greater than 0. Current value: {}. Set to 1".format(num_stack)) @@ -524,6 +596,55 @@ def _draw(self, *args, **kwargs): radius=5 ) self._deads.append(v) + + if not self.sign_icon_surfaces: + if pygame.display.get_init(): + for name, img in self.sign_icon_raw.items(): + scaled = pygame.transform.smoothscale(img, (24, 24)) + self.sign_icon_surfaces[name] = scaled.convert_alpha() + + if (self.current_track_agent is not None and + hasattr(self.current_track_agent, 'navigation') and + hasattr(self.current_track_agent.navigation, 'final_lane') and + self.current_track_agent.navigation.final_lane is not None): + + final_lane = self.current_track_agent.navigation.final_lane + target_pos = final_lane.position(final_lane.length, 0) + + pixel_x, pixel_y = self._frame_canvas.pos2pix(target_pos[0], target_pos[1]) + + pygame.draw.circle( + surface=self._frame_canvas, + color=(255, 0, 0), + center=(pixel_x, pixel_y), + radius=6, + width=0 + ) + pygame.draw.circle( + surface=self._frame_canvas, + color=(255, 255, 255), + center=(pixel_x, pixel_y), + radius=6, + width=2 + ) + + if hasattr(self.engine, 'traffic_sign_manager'): + sign_mgr = self.engine.traffic_sign_manager + for sign in sign_mgr.signs: + sign_type = type(sign).__name__ + icon = self.sign_icon_surfaces.get(sign_type) + if sign_type == "DirectionSign": + dir_order = {'l': 0, 's': 1, 'r': 2, 't': 3} + sorted_dirs = sorted(sign.lane.turns, key=lambda d: dir_order.get(d["direction"], 99)) + screen_end = self._frame_canvas.pos2pix(sign.position[0], sign.position[1]) + first = True + for d in sorted_dirs: + draw_turn_sign(self._frame_canvas, screen_end, d["direction"], color=(255, 255, 255), first=first) + first = False + elif icon is not None and hasattr(sign, 'position'): + pixel_x, pixel_y = self._frame_canvas.pos2pix(sign.position[0], sign.position[1]) + rect = icon.get_rect(center=(pixel_x, pixel_y)) + self._frame_canvas.blit(icon, rect) v = self.current_track_agent canvas = self._frame_canvas diff --git a/metadrive/envs/base_env.py b/metadrive/envs/base_env.py index caafa54d3..541ad61cf 100644 --- a/metadrive/envs/base_env.py +++ b/metadrive/envs/base_env.py @@ -37,7 +37,7 @@ # Whether randomize the car model for the agent, randomly choosing from 4 types of cars random_agent_model=False, # The ego config is: env_config["vehicle_config"].update(env_config"[agent_configs"]["default_agent"]) - agent_configs={DEFAULT_AGENT: dict(use_special_color=True, spawn_lane_index=None)}, + agent_configs={DEFAULT_AGENT: dict(use_special_color=True)}, # ===== multi-agent ===== # This should be >1 in MARL envs, or set to -1 for spawning as many vehicles as possible. diff --git a/metadrive/scenario/scenario_description.py b/metadrive/scenario/scenario_description.py index 7f805db81..921812903 100644 --- a/metadrive/scenario/scenario_description.py +++ b/metadrive/scenario/scenario_description.py @@ -139,10 +139,11 @@ class ScenarioDescription(dict): POLYGON = "polygon" LEFT_BOUNDARIES = "left_boundaries" RIGHT_BOUNDARIES = "right_boundaries" - LEFT_NEIGHBORS = "left_neighbor" - RIGHT_NEIGHBORS = "right_neighbor" + LEFT_NEIGHBORS = "left_lanes" + RIGHT_NEIGHBORS = "right_lanes" ENTRY = "entry_lanes" EXIT = "exit_lanes" + TURNS = "turns" # object TYPE = "type" diff --git a/metadrive/type.py b/metadrive/type.py index 00f9fdd1a..d6cc308ab 100644 --- a/metadrive/type.py +++ b/metadrive/type.py @@ -37,6 +37,7 @@ class MetaDriveType: BOUNDARY_MEDIAN = "ROAD_EDGE_MEDIAN" # line BOUNDARY_SIDEWALK = "ROAD_EDGE_SIDEWALK" # polygon STOP_SIGN = "STOP_SIGN" + LANE_DIRECTION = "LANE_DIRECTION" CROSSWALK = "CROSSWALK" SPEED_BUMP = "SPEED_BUMP" DRIVEWAY = "DRIVEWAY" diff --git a/metadrive/utils/sumo/map_utils.py b/metadrive/utils/sumo/map_utils.py index fcee0abf7..8b01d55ee 100644 --- a/metadrive/utils/sumo/map_utils.py +++ b/metadrive/utils/sumo/map_utils.py @@ -18,6 +18,7 @@ raise ImportError("Please install sumolib before running this script via: pip install sumolib") from shapely.geometry import LineString, MultiPolygon, Polygon from shapely.geometry.base import CAP_STYLE, JOIN_STYLE +from collections import defaultdict def buffered_shape(shape, width: float = 1.0) -> Polygon: @@ -229,10 +230,9 @@ def __init__( conns = junction.sumolib_obj.getConnections() for conn in conns: - from_lane_id = conn.getFromLane().getID() # Link lanes + from_lane_id = conn.getFromLane().getID() to_lane_id = conn.getToLane().getID() via_lane_id = conn.getViaLaneID() - from_road_id = conn.getFrom().getID() # Link roads to_road_id = conn.getTo().getID() if via_lane_id == '': # Maybe we could skip this, but not sure @@ -319,17 +319,27 @@ def extract_map_features(graph): # } # build map lanes + from shapely.geometry import Polygon + ret = {} + + # Сначала создадим все полосы без связей + lane_names = set() for road_id, road in graph.roads.items(): for lane in road.lanes: - id = "lane_{}".format(lane.name) - + lane_names.add(id) boundary_polygon = [(x, y) for x, y in lane.shape.shape.exterior.coords] if lane.type == 'driving': ret[id] = { SD.TYPE: MetaDriveType.LANE_SURFACE_STREET, SD.POLYLINE: lane.sumolib_obj.getShape(), SD.POLYGON: boundary_polygon, + # Заглушки — заполним ниже + "entry_lanes": [], + "turns": [], + "exit_lanes": [], + "left_lanes": [], + "right_lanes": [], } elif lane.type == 'sidewalk': ret[id] = { @@ -347,6 +357,71 @@ def extract_map_features(graph): SD.TYPE: MetaDriveType.CROSSWALK, SD.POLYGON: boundary_polygon, } + + + for road_id, road in graph.roads.items(): + for lane in road.lanes: + if lane.type != 'driving': + continue + lane_id = f"lane_{lane.name}" + if lane_id not in ret: + continue + + left_lanes = [] + right_lanes = [] + if lane.left_neigh and lane.left_neigh.type == 'driving': + left_lanes.append(f"lane_{lane.left_neigh.name}") + if lane.right_neigh and lane.right_neigh.type == 'driving': + right_lanes.append(f"lane_{lane.right_neigh.name}") + + ret[lane_id]["left_lanes"] = left_lanes + ret[lane_id]["right_lanes"] = right_lanes + + for lane_name, lane_node in graph.lanes.items(): + lane_id = "lane_{}".format(lane_name) + if lane_id not in ret: + continue + + exit_ids = [] + for out_lane in lane_node.outgoing: + out_id = "lane_{}".format(out_lane.name) + if out_id in ret: + exit_ids.append(out_id) + ret[lane_id]["exit_lanes"] = exit_ids + + entry_ids = [] + for in_lane in lane_node.incoming: + in_id = "lane_{}".format(in_lane.name) + if in_id in ret: + entry_ids.append(in_id) + ret[lane_id]["entry_lanes"] = entry_ids + + for lane_name, lane_node in graph.lanes.items(): + base_lane_id = f"lane_{lane_name}" + if base_lane_id not in ret or lane_node.type != 'driving': + continue + + turns = [] + for out_lane in lane_node.outgoing: + direction = None + for conn in lane_node.sumolib_obj.getOutgoing(): + if conn.getToLane() == out_lane.sumolib_obj or conn.getViaLaneID() == out_lane.name: + direction = conn.getDirection() + if direction in ('s', 'l', 'r', 't'): + break + + if direction is None: + continue + + to_lane_id = f"lane_{out_lane.name}" + if to_lane_id in ret: # только если целевая полоса сохранена + turns.append({ + "direction": direction, + "to_lane": to_lane_id + }) + + if turns: + ret[base_lane_id]["turns"] = turns for lane_divider_id, lane_divider in enumerate(graph.lane_dividers): id = "lane_divider_{}".format(lane_divider_id)