UNet node for image segmentation #27
CMakeLists.txt (new file, +22 lines):

cmake_minimum_required(VERSION 3.8)
project(unet_segmentation)

find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(rclpy REQUIRED)

ament_python_install_package(${PROJECT_NAME})

install(DIRECTORY
  launch
  config
  model
  DESTINATION share/${PROJECT_NAME}
)

install(PROGRAMS
  ${PROJECT_NAME}/interface.py
  DESTINATION lib/${PROJECT_NAME}
)

ament_package()
config/unet_segmentation.yaml (new file, +17 lines):

unet_segmentation_node:
  ros__parameters:
    model_path: "model/unet-simple-320-240-l-5-e10-b16(1).pth"
    input_topic: "cam_down/image_color"
    overlay_topic: "/segmentation/overlay"
    mask_topic: "/segmentation/mask"
    resize_width: 320
    resize_height: 240
    keep_original_size: true  # upsample mask/overlay back to source image size
    mask_threshold: 0.5
    bilinear: false
    simple: true
    classes: 1
    device: "cuda"
    pred_color: [255, 0, 0]
    overlay_alpha: 0.4
    qos_depth: 3
launch/unet_segmentation.launch.py (new file, +37 lines):

# launch/unet_segmentation.launch.py
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare


def generate_launch_description():
    pkg = DeclareLaunchArgument("package", default_value="unet_segmentation")
    params_file = DeclareLaunchArgument(
        "params_file",
        default_value=PathJoinSubstitution(
            [
                FindPackageShare(LaunchConfiguration("package")),
                "config",
                "unet_segmentation.yaml",
            ]
        ),
    )
    input_topic = DeclareLaunchArgument("input_topic", default_value="/image_color")

    node = Node(
        package=LaunchConfiguration("package"),
        executable="interface.py",
        name="unet_segmentation_node",
        parameters=[LaunchConfiguration("params_file")],
        remappings=[
            ("/image_color", LaunchConfiguration("input_topic")),
            # Overlay and mask topics come from YAML; remap here only if needed:
            # ("/segmentation/overlay", "/your/overlay"),
            # ("/segmentation/mask", "/your/mask"),
        ],
        output="screen",
    )

    return LaunchDescription([pkg, params_file, input_topic, node])
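For reference, a minimal sketch of how another package's launch file could include this one and point it at the downward camera. The parent launch file itself is hypothetical and not part of this PR; only the package name, launch file path, and the input_topic argument come from the code above.

from launch import LaunchDescription
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch.substitutions import PathJoinSubstitution
from launch_ros.substitutions import FindPackageShare


def generate_launch_description():
    # Include the UNet segmentation launch file and wire it to a camera topic.
    unet_launch = IncludeLaunchDescription(
        PythonLaunchDescriptionSource(
            PathJoinSubstitution(
                [
                    FindPackageShare("unet_segmentation"),
                    "launch",
                    "unet_segmentation.launch.py",
                ]
            )
        ),
        launch_arguments={"input_topic": "/cam_down/image_color"}.items(),
    )
    return LaunchDescription([unet_launch])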
package.xml (new file, +20 lines; excerpt shown):

<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
  <name>unet_segmentation</name>
  <version>0.0.0</version>
  <description>YOLO inference on images, publishing detections and annotated outputs.</description>

Suggested change to the description line:
- <description>YOLO inference on images, publishing detections and annotated outputs.</description>
+ <description>UNet-based image segmentation node for ROS2, publishing segmentation masks and annotated overlays.</description>
Reviewer comment: Replace this with your own email and GitHub username.
unet_segmentation/config.py (new file, +56 lines):

from dataclasses import dataclass

import torch
import yaml
from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy


@dataclass
class UnetSegmentationConfig:
    model_path: str
    input_topic: str
    overlay_topic: str
    mask_topic: str
    resize_width: int
    resize_height: int
    keep_original_size: bool
    mask_threshold: float
    bilinear: bool
    simple: bool
    n_classes: int
    device_name: str
    pred_color: tuple[int, int, int]
    overlay_alpha: float
    qos_depth: int

    @staticmethod
    def from_yaml(path: str) -> "UnetSegmentationConfig":
        with open(path) as f:
            data = yaml.safe_load(f)
        p = data.get("unet_segmentation_node", {}).get("ros__parameters", {})
        device_name = p.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        return UnetSegmentationConfig(
            model_path=p["model_path"],
            input_topic=p.get("input_topic", "/image_color"),
            overlay_topic=p.get("overlay_topic", "/segmentation/overlay"),
            mask_topic=p.get("mask_topic", "/segmentation/mask"),
            resize_width=p.get("resize_width", 320),
            resize_height=p.get("resize_height", 240),
            keep_original_size=p.get("keep_original_size", True),
            mask_threshold=p.get("mask_threshold", 0.5),
            bilinear=p.get("bilinear", False),
            simple=p.get("simple", True),
            n_classes=p.get("classes", 1),
            device_name=device_name,
            pred_color=tuple(p.get("pred_color", [255, 0, 0])),
            overlay_alpha=p.get("overlay_alpha", 0.4),
            qos_depth=p.get("qos_depth", 3),
        )

    @staticmethod
    def qos_profile(depth: int = 3) -> QoSProfile:
        return QoSProfile(
            reliability=QoSReliabilityPolicy.BEST_EFFORT,
            history=QoSHistoryPolicy.KEEP_LAST,
            depth=depth,
        )
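As a quick illustration of how this class is consumed, a minimal sketch (not part of the PR; the YAML path is illustrative — the node resolves the installed copy via the package share directory):

from unet_segmentation.config import UnetSegmentationConfig

# Illustrative relative path; use the installed share/ copy in practice.
cfg = UnetSegmentationConfig.from_yaml("config/unet_segmentation.yaml")
print(cfg.input_topic, cfg.resize_width, cfg.resize_height, cfg.device_name)

# Build the matching best-effort QoS profile used for the subscriber and publishers.
qos = UnetSegmentationConfig.qos_profile(cfg.qos_depth)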
unet_segmentation/interface.py (new file, +149 lines):

#!/usr/bin/env python3

from pathlib import Path

import cv2
import numpy as np
import rclpy
import torch
from ament_index_python.packages import get_package_share_directory
from cv_bridge import CvBridge
from PIL import Image as PILImage
from rclpy.node import Node
from sensor_msgs.msg import Image

from unet_segmentation.config import UnetSegmentationConfig
from unet_segmentation.utils import (
    ResizeIfLargerKeepAspect,
    build_image_transforms,
    load_unet,
    make_overlay,
    mask_to_mono8,
    predict_mask,
    upsample_mask_nearest,
)


def default_config_path() -> str:
    share_dir = Path(get_package_share_directory('unet_segmentation'))
    return str(share_dir / 'config' / 'unet_segmentation.yaml')


class UnetSegmentationNode(Node):
    def __init__(self, config_path: str | None = None):
        super().__init__("unet_segmentation_node")

        # 1) Resolve config path
        cfg_path = Path(config_path) if config_path else Path(default_config_path())
        if not cfg_path.is_file():
            # Try package share as a fallback if a bad relative path was passed
            pkg_cfg = Path(default_config_path())
            if pkg_cfg.is_file():
                cfg_path = pkg_cfg
            else:
                self.get_logger().fatal(
                    f"Config file not found. Tried: {config_path} and {pkg_cfg}"
                )
                raise SystemExit(1)

        # 2) LOAD the config -> self.cfg (do this before using self.cfg)
        self.cfg = UnetSegmentationConfig.from_yaml(str(cfg_path))

        # 3) Validate model path AFTER self.cfg exists
        model_p = Path(self.cfg.model_path).expanduser()
        if not model_p.exists():
            self.get_logger().fatal(f"Model file not found: {model_p}")
            raise SystemExit(1)

        self.device = torch.device(self.cfg.device_name)
        self.bridge = CvBridge()

        self.net = load_unet(
            model_path=self.cfg.model_path,
            n_classes=self.cfg.n_classes,
            device=self.device,
            bilinear=self.cfg.bilinear,
            simple=self.cfg.simple,
            logger=self.get_logger(),
        )

        self.image_transforms = build_image_transforms(
            self.cfg.resize_width, self.cfg.resize_height
        )
        qos_profile = UnetSegmentationConfig.qos_profile(self.cfg.qos_depth)

        self.subscription = self.create_subscription(
            Image, self.cfg.input_topic, self.image_callback, qos_profile
        )
        self.overlay_pub = self.create_publisher(
            Image, self.cfg.overlay_topic, qos_profile
        )
        self.mask_pub = self.create_publisher(Image, self.cfg.mask_topic, qos_profile)

        self.get_logger().info(
            f"Subscribing to '{self.cfg.input_topic}', publishing overlay to '{self.cfg.overlay_topic}' "
            f"and mask to '{self.cfg.mask_topic}'."
        )

    def image_callback(self, msg: Image):
        try:
            cv_bgr = self.bridge.imgmsg_to_cv2(msg, "bgr8")
            base_rgb = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2RGB)
            orig_h, orig_w = base_rgb.shape[:2]
            pil_img = PILImage.fromarray(base_rgb)

            # Resize (downscale only) for inference
            resized_pil = ResizeIfLargerKeepAspect(
                self.cfg.resize_width, self.cfg.resize_height
            )(pil_img)
            resized_w, resized_h = resized_pil.size
            image_tensor = self.image_transforms(resized_pil)

            # Predict (in resized space)
            pred_mask = predict_mask(
                self.net,
                image_tensor,
                self.device,
                out_threshold=self.cfg.mask_threshold,
            )

            # Optionally upsample mask to original size
            if self.cfg.keep_original_size:
                mask_out = upsample_mask_nearest(pred_mask, orig_w, orig_h)
                base_for_overlay = base_rgb  # original size
            else:
                mask_out = pred_mask
                base_for_overlay = np.array(resized_pil)

            # Build overlay at the chosen size
            overlay_np = make_overlay(
                base_for_overlay,
                mask_out if self.cfg.keep_original_size else pred_mask,
                color=self.cfg.pred_color,
                alpha=self.cfg.overlay_alpha,
            )

            # Publish mask (mono8) and overlay (rgb8)
            mask_mono8 = mask_to_mono8(mask_out)
            mask_msg = self.bridge.cv2_to_imgmsg(mask_mono8, encoding="mono8")
            mask_msg.header = msg.header
            self.mask_pub.publish(mask_msg)

            overlay_msg = self.bridge.cv2_to_imgmsg(overlay_np, encoding="rgb8")
            overlay_msg.header = msg.header
            self.overlay_pub.publish(overlay_msg)

        except Exception as e:
            self.get_logger().error(f'Failed to process image: {e}')


def main():
    rclpy.init()
    node = UnetSegmentationNode(default_config_path())
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
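For reviewers who want to verify output quickly, a minimal listener sketch (not part of the PR). The topic name comes from the YAML above; the MaskListener node name is made up. Note the subscriber must also use best-effort QoS, otherwise it would not match the node's best-effort publisher.

import rclpy
from rclpy.node import Node
from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy
from sensor_msgs.msg import Image


class MaskListener(Node):
    def __init__(self):
        super().__init__("mask_listener")
        # Match the segmentation node's best-effort, keep-last QoS.
        qos = QoSProfile(
            reliability=QoSReliabilityPolicy.BEST_EFFORT,
            history=QoSHistoryPolicy.KEEP_LAST,
            depth=3,
        )
        self.create_subscription(Image, "/segmentation/mask", self.callback, qos)

    def callback(self, msg: Image):
        self.get_logger().info(f"mask {msg.width}x{msg.height} ({msg.encoding})")


def main():
    rclpy.init()
    node = MaskListener()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == "__main__":
    main()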
__init__.py for the U-Net model subpackage (new file, +3 lines):

from .unet_model import UNet as UNet

__all__ = ["UNet"]
Reviewer comment on lines +1 to +3 (Contributor): Is this needed to run the code?
unet_model.py (new file, +84 lines, same subpackage as the __init__.py above):

import torch
from torch import nn

from .unet_parts import DoubleConv, Down, OutConv, Up


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, simple=False, bilinear=False):
        """The U-Net architecture.

        :param n_channels: Number of input channels (e.g., 3 for RGB images)
        :param n_classes: Number of output classes (e.g., 1 for binary segmentation)
        :param simple: If True, creates a smaller U-Net with fewer layers.
        :param bilinear: If True, use bilinear upsampling instead of transposed convolutions.
        """
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.simple = simple

        factor = 2 if bilinear else 1

        if not self.simple:
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256)
            self.down3 = Down(256, 512)
            self.down4 = Down(512, 1024 // factor)
            self.up1 = Up(1024, 512 // factor, bilinear)
            self.up2 = Up(512, 256 // factor, bilinear)
            self.up3 = Up(256, 128 // factor, bilinear)
            self.up4 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)
        else:
            self.inc = DoubleConv(n_channels, 64)
            self.down1 = Down(64, 128)
            self.down2 = Down(128, 256 // factor)
            self.up1 = Up(256, 128 // factor, bilinear)
            self.up2 = Up(128, 64, bilinear)
            self.outc = OutConv(64, n_classes)

    def forward(self, x):
        if not self.simple:
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
        else:
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x = self.up1(x3, x2)
            x = self.up2(x, x1)
            logits = self.outc(x)

        return logits

    def use_checkpointing(self):
        """Enable gradient checkpointing to save memory, but at the cost of additional computation during backpropagation."""
        if not self.simple:
            self.inc = torch.utils.checkpoint.checkpoint(self.inc)
            self.down1 = torch.utils.checkpoint.checkpoint(self.down1)
            self.down2 = torch.utils.checkpoint.checkpoint(self.down2)
            self.down3 = torch.utils.checkpoint.checkpoint(self.down3)
            self.down4 = torch.utils.checkpoint.checkpoint(self.down4)
            self.up1 = torch.utils.checkpoint.checkpoint(self.up1)
            self.up2 = torch.utils.checkpoint.checkpoint(self.up2)
            self.up3 = torch.utils.checkpoint.checkpoint(self.up3)
            self.up4 = torch.utils.checkpoint.checkpoint(self.up4)
            self.outc = torch.utils.checkpoint.checkpoint(self.outc)
        else:
            self.inc = torch.utils.checkpoint.checkpoint(self.inc)
            self.down1 = torch.utils.checkpoint.checkpoint(self.down1)
            self.down2 = torch.utils.checkpoint.checkpoint(self.down2)
            self.up1 = torch.utils.checkpoint.checkpoint(self.up1)
            self.up2 = torch.utils.checkpoint.checkpoint(self.up2)
            self.outc = torch.utils.checkpoint.checkpoint(self.outc)
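As a quick sanity check of the architecture, a minimal sketch that instantiates the simple variant used by the YAML config and runs a dummy forward pass (not part of the PR; the import path is assumed and should be adjusted to wherever this subpackage lives in the repo):

import torch

# Hypothetical import path; adjust to the actual location of the UNet subpackage.
from unet_segmentation.unet_model import UNet

# simple=True, n_classes=1 mirrors the YAML defaults (simple: true, classes: 1).
net = UNet(n_channels=3, n_classes=1, simple=True, bilinear=False)
net.eval()

with torch.no_grad():
    # (batch, channels, H=resize_height, W=resize_width) from the YAML config.
    x = torch.randn(1, 3, 240, 320)
    logits = net(x)

print(logits.shape)  # expected: torch.Size([1, 1, 240, 320])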
Reviewer comment: Having the input_topic set both by the config file and by the launch configuration can be ambiguous. We should use only one of these.
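One possible resolution, sketched below and not decided in this PR: drop input_topic from the YAML, have interface.py subscribe to a fixed relative topic name such as "image_color", and retarget it only via remapping in the launch file. The fixed topic name here is an assumption for illustration.

from launch.substitutions import LaunchConfiguration
from launch_ros.actions import Node

# Sketch: the node always subscribes to "image_color"; the launch argument alone
# decides which camera topic it is wired to. The YAML no longer carries input_topic.
node = Node(
    package=LaunchConfiguration("package"),
    executable="interface.py",
    name="unet_segmentation_node",
    parameters=[LaunchConfiguration("params_file")],
    remappings=[("image_color", LaunchConfiguration("input_topic"))],
    output="screen",
)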