22 changes: 22 additions & 0 deletions unet_segmentation/CMakeLists.txt
@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.8)
project(unet_segmentation)

find_package(ament_cmake REQUIRED)
find_package(ament_cmake_python REQUIRED)
find_package(rclpy REQUIRED)

ament_python_install_package(${PROJECT_NAME})

install(DIRECTORY
launch
config
model
DESTINATION share/${PROJECT_NAME}
)

install(PROGRAMS
${PROJECT_NAME}/interface.py
DESTINATION lib/${PROJECT_NAME}
)

ament_package()
17 changes: 17 additions & 0 deletions unet_segmentation/config/unet_segmentation.yaml
@@ -0,0 +1,17 @@
unet_segmentation_node:
ros__parameters:
model_path: "model/unet-simple-320-240-l-5-e10-b16(1).pth"
input_topic: "cam_down/image_color"
overlay_topic: "/segmentation/overlay"
mask_topic: "/segmentation/mask"
resize_width: 320
resize_height: 240
keep_original_size: true # upsample mask/overlay back to source image size
mask_threshold: 0.5
bilinear: false
simple: true
classes: 1
device: "cuda"
pred_color: [255, 0, 0]
overlay_alpha: 0.4
qos_depth: 3
37 changes: 37 additions & 0 deletions unet_segmentation/launch/unet_segmentation.launch.py
@@ -0,0 +1,37 @@
# launch/unet_segmentation.launch.py
from launch import LaunchDescription
from launch.actions import DeclareLaunchArgument
from launch.substitutions import LaunchConfiguration, PathJoinSubstitution
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare


def generate_launch_description():
pkg = DeclareLaunchArgument("package", default_value="unet_segmentation")
params_file = DeclareLaunchArgument(
"params_file",
default_value=PathJoinSubstitution(
[
FindPackageShare(LaunchConfiguration("package")),
"config",
"unet_segmentation.yaml",
]
),
)
input_topic = DeclareLaunchArgument("input_topic", default_value="/image_color")

node = Node(
package=LaunchConfiguration("package"),
executable="interface.py",
name="unet_segmentation_node",
parameters=[LaunchConfiguration("params_file")],
remappings=[
("/image_color", LaunchConfiguration("input_topic")),
# Overlay and mask topics come from YAML; remap here only if needed:

Contributor comment: Having the input_topic set both by the config file and by the launch configuration can be ambiguous. We should only use one of these. (A params-file-only sketch follows this file.)

# ("/segmentation/overlay", "/your/overlay"),
# ("/segmentation/mask", "/your/mask"),
],
output="screen",
)

return LaunchDescription([pkg, params_file, input_topic, node])
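To address the comment above, a minimal sketch of the params-file-only variant: the input_topic launch argument and its remapping are dropped, so the node subscribes to whatever input_topic the YAML names. This assumes the node reads the topic from its own config (as interface.py below does); only standard launch/launch_ros APIs are used.

# Sketch (assumption: the YAML is the single source of topic names,
# so no remapping or input_topic launch argument is needed).
from launch import LaunchDescription
from launch.substitutions import PathJoinSubstitution
from launch_ros.actions import Node
from launch_ros.substitutions import FindPackageShare


def generate_launch_description():
    params_file = PathJoinSubstitution(
        [FindPackageShare("unet_segmentation"), "config", "unet_segmentation.yaml"]
    )
    node = Node(
        package="unet_segmentation",
        executable="interface.py",
        name="unet_segmentation_node",
        parameters=[params_file],
        output="screen",
    )
    return LaunchDescription([node])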
Binary file not shown.
20 changes: 20 additions & 0 deletions unet_segmentation/package.xml
@@ -0,0 +1,20 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>unet_segmentation</name>
<version>0.0.0</version>
<description>YOLO inference on images, publishing detections and annotated outputs.</description>
Contributor comment: Update description.

Copilot AI (Oct 26, 2025): The package description references 'YOLO inference' but this package is for UNet segmentation. The description should be updated to accurately reflect the package's purpose, e.g., 'UNet-based image segmentation node for ROS2, publishing segmentation masks and annotated overlays.'

Suggested change:
- <description>YOLO inference on images, publishing detections and annotated outputs.</description>
+ <description>UNet-based image segmentation node for ROS2, publishing segmentation masks and annotated overlays.</description>
<maintainer email="[email protected]">kluge7</maintainer>
Contributor comment: Replace this with your own email and GitHub username.

<license>MIT</license>

<buildtool_depend>ament_cmake</buildtool_depend>

<depend>rclpy</depend>
<depend>sensor_msgs</depend>
<depend>cv_bridge</depend>
<depend>vision_msgs</depend>

<export>
<build_type>ament_cmake</build_type>
</export>
</package>
Empty file.
Empty file.
56 changes: 56 additions & 0 deletions unet_segmentation/unet_segmentation/config.py
@@ -0,0 +1,56 @@
from dataclasses import dataclass

import torch
import yaml
from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy


@dataclass
class UnetSegmentationConfig:
model_path: str
input_topic: str
overlay_topic: str
mask_topic: str
resize_width: int
resize_height: int
keep_original_size: bool
mask_threshold: float
bilinear: bool
simple: bool
n_classes: int
device_name: str
pred_color: tuple[int, int, int]
overlay_alpha: float
qos_depth: int

@staticmethod
def from_yaml(path: str) -> "UnetSegmentationConfig":
with open(path) as f:
data = yaml.safe_load(f)
p = data.get("unet_segmentation_node", {}).get("ros__parameters", {})
device_name = p.get("device", "cuda" if torch.cuda.is_available() else "cpu")
return UnetSegmentationConfig(
model_path=p["model_path"],
input_topic=p.get("input_topic", "/image_color"),
overlay_topic=p.get("overlay_topic", "/segmentation/overlay"),
mask_topic=p.get("mask_topic", "/segmentation/mask"),
resize_width=p.get("resize_width", 320),
resize_height=p.get("resize_height", 240),
keep_original_size=p.get("keep_original_size", True),
mask_threshold=p.get("mask_threshold", 0.5),
bilinear=p.get("bilinear", False),
simple=p.get("simple", True),
n_classes=p.get("classes", 1),
device_name=device_name,
pred_color=tuple(p.get("pred_color", [255, 0, 0])),
overlay_alpha=p.get("overlay_alpha", 0.4),
qos_depth=p.get("qos_depth", 3),
)

@staticmethod
def qos_profile(depth: int = 3) -> QoSProfile:
return QoSProfile(
reliability=QoSReliabilityPolicy.BEST_EFFORT,
history=QoSHistoryPolicy.KEEP_LAST,
depth=depth,
)
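As a quick sanity check, the loader can also be exercised standalone. A minimal sketch; the relative path here is an assumption (the node instead resolves the file via the package share directory):

# Sketch: standalone use of the config loader.
from unet_segmentation.config import UnetSegmentationConfig

cfg = UnetSegmentationConfig.from_yaml("config/unet_segmentation.yaml")
qos = UnetSegmentationConfig.qos_profile(cfg.qos_depth)
print(cfg.input_topic, cfg.device_name, cfg.n_classes)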
149 changes: 149 additions & 0 deletions unet_segmentation/unet_segmentation/interface.py
@@ -0,0 +1,149 @@
#!/usr/bin/env python3

from pathlib import Path

import cv2
import numpy as np
import rclpy
import torch
from ament_index_python.packages import get_package_share_directory
from cv_bridge import CvBridge
from PIL import Image as PILImage
from rclpy.node import Node
from sensor_msgs.msg import Image

from unet_segmentation.config import UnetSegmentationConfig
from unet_segmentation.utils import (
ResizeIfLargerKeepAspect,
build_image_transforms,
load_unet,
make_overlay,
mask_to_mono8,
predict_mask,
upsample_mask_nearest,
)


def default_config_path() -> str:
share_dir = Path(get_package_share_directory('unet_segmentation'))
return str(share_dir / 'config' / 'unet_segmentation.yaml')


class UnetSegmentationNode(Node):
def __init__(self, config_path: str | None = None):
super().__init__("unet_segmentation_node")

# 1) Resolve config path
cfg_path = Path(config_path) if config_path else Path(default_config_path())
if not cfg_path.is_file():
# Try package share as a fallback if a bad relative path was passed
pkg_cfg = Path(default_config_path())
if pkg_cfg.is_file():
cfg_path = pkg_cfg
else:
self.get_logger().fatal(
f"Config file not found. Tried: {config_path} and {pkg_cfg}"
)
raise SystemExit(1)

# 2) LOAD the config -> self.cfg (do this before using self.cfg)
self.cfg = UnetSegmentationConfig.from_yaml(str(cfg_path))

# 3) Validate model path AFTER self.cfg exists
model_p = Path(self.cfg.model_path).expanduser()
if not model_p.exists():
self.get_logger().fatal(f"Model file not found: {model_p}")
raise SystemExit(1)

self.device = torch.device(self.cfg.device_name)
self.bridge = CvBridge()

self.net = load_unet(
model_path=self.cfg.model_path,
n_classes=self.cfg.n_classes,
device=self.device,
bilinear=self.cfg.bilinear,
simple=self.cfg.simple,
logger=self.get_logger(),
)

self.image_transforms = build_image_transforms(
self.cfg.resize_width, self.cfg.resize_height
)
qos_profile = UnetSegmentationConfig.qos_profile(self.cfg.qos_depth)

self.subscription = self.create_subscription(
Image, self.cfg.input_topic, self.image_callback, qos_profile
)
self.overlay_pub = self.create_publisher(
Image, self.cfg.overlay_topic, qos_profile
)
self.mask_pub = self.create_publisher(Image, self.cfg.mask_topic, qos_profile)

self.get_logger().info(
f"Subscribing to '{self.cfg.input_topic}', publishing overlay to '{self.cfg.overlay_topic}' "
f"and mask to '{self.cfg.mask_topic}'."
)

def image_callback(self, msg: Image):
try:
cv_bgr = self.bridge.imgmsg_to_cv2(msg, "bgr8")
base_rgb = cv2.cvtColor(cv_bgr, cv2.COLOR_BGR2RGB)
orig_h, orig_w = base_rgb.shape[:2]
pil_img = PILImage.fromarray(base_rgb)

# Resize (downscale only) for inference
resized_pil = ResizeIfLargerKeepAspect(
self.cfg.resize_width, self.cfg.resize_height
)(pil_img)
resized_w, resized_h = resized_pil.size
image_tensor = self.image_transforms(resized_pil)

# Predict (in resized space)
pred_mask = predict_mask(
self.net,
image_tensor,
self.device,
out_threshold=self.cfg.mask_threshold,
)

# Optionally upsample mask to original size
if self.cfg.keep_original_size:
mask_out = upsample_mask_nearest(pred_mask, orig_w, orig_h)
base_for_overlay = base_rgb # original size
else:
mask_out = pred_mask
base_for_overlay = np.array(resized_pil)

# Build overlay at the chosen size
overlay_np = make_overlay(
base_for_overlay,
mask_out if self.cfg.keep_original_size else pred_mask,
color=self.cfg.pred_color,
alpha=self.cfg.overlay_alpha,
)

# Publish mask (mono8) and overlay (rgb8)
mask_mono8 = mask_to_mono8(mask_out)
mask_msg = self.bridge.cv2_to_imgmsg(mask_mono8, encoding="mono8")
mask_msg.header = msg.header
self.mask_pub.publish(mask_msg)

overlay_msg = self.bridge.cv2_to_imgmsg(overlay_np, encoding="rgb8")
overlay_msg.header = msg.header
self.overlay_pub.publish(overlay_msg)

except Exception as e:
self.get_logger().error(f'Failed to process image: {e}')


def main():
rclpy.init()
node = UnetSegmentationNode(default_config_path())
rclpy.spin(node)
node.destroy_node()
rclpy.shutdown()


if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions unet_segmentation/unet_segmentation/unet/__init__.py
@@ -0,0 +1,3 @@
from .unet_model import UNet as UNet

__all__ = ["UNet"]
Comment on lines +1 to +3
Contributor comment: Is this needed to run the code?
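For context, the re-export only shortens the import path; with it in place the two imports below are equivalent:

# With the re-export in __init__.py:
from unet_segmentation.unet import UNet
# Without it, callers must reach into the submodule:
from unet_segmentation.unet.unet_model import UNet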

84 changes: 84 additions & 0 deletions unet_segmentation/unet_segmentation/unet/unet_model.py
@@ -0,0 +1,84 @@
import torch
from torch import nn

from .unet_parts import DoubleConv, Down, OutConv, Up


class UNet(nn.Module):
def __init__(self, n_channels, n_classes, simple=False, bilinear=False):
"""The U-Net architecture.

:param n_channels: Number of input channels (e.g., 3 for RGB images)
:param n_classes: Number of output classes (e.g., 1 for binary segmentation)
:param simple: If True, creates a smaller U-Net with fewer layers.
:param bilinear: If True, use bilinear upsampling instead of transposed convolutions.
"""
super().__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.simple = simple

factor = 2 if bilinear else 1

if not self.simple:
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
else:
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256 // factor)
self.up1 = Up(256, 128 // factor, bilinear)
self.up2 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)

def forward(self, x):
if not self.simple:
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
else:
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x = self.up1(x3, x2)
x = self.up2(x, x1)
logits = self.outc(x)

return logits

def use_checkpointing(self):
"""Enable gradient checkpointing to save memory, but at the cost of additional computation during backpropagation."""
if not self.simple:
self.inc = torch.utils.checkpoint.checkpoint(self.inc)
self.down1 = torch.utils.checkpoint.checkpoint(self.down1)
self.down2 = torch.utils.checkpoint.checkpoint(self.down2)
self.down3 = torch.utils.checkpoint.checkpoint(self.down3)
self.down4 = torch.utils.checkpoint.checkpoint(self.down4)
self.up1 = torch.utils.checkpoint.checkpoint(self.up1)
self.up2 = torch.utils.checkpoint.checkpoint(self.up2)
self.up3 = torch.utils.checkpoint.checkpoint(self.up3)
self.up4 = torch.utils.checkpoint.checkpoint(self.up4)
self.outc = torch.utils.checkpoint.checkpoint(self.outc)
Comment on lines +68 to +77
Copilot AI (Nov 5, 2025): This implementation is incorrect. torch.utils.checkpoint.checkpoint is a function that performs checkpointing during the forward pass, not a wrapper for modules. Assigning it to module attributes will break the model. Gradient checkpointing should be applied in the forward() method by wrapping module calls with torch.utils.checkpoint.checkpoint(module, inputs), or by using the module's .gradient_checkpointing_enable() method if available. (A sketch of the forward-pass approach follows this file.)
else:
self.inc = torch.utils.checkpoint.checkpoint(self.inc)
self.down1 = torch.utils.checkpoint.checkpoint(self.down1)
self.down2 = torch.utils.checkpoint.checkpoint(self.down2)
self.up1 = torch.utils.checkpoint.checkpoint(self.up1)
self.up2 = torch.utils.checkpoint.checkpoint(self.up2)
self.outc = torch.utils.checkpoint.checkpoint(self.outc)
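A minimal sketch of the forward-pass approach suggested above. The _run helper and _checkpointing flag are illustrative names, not part of the PR; checkpoint(fn, *args, use_reentrant=False) is the documented torch call.

# Sketch: checkpoint each block call inside forward() instead of
# replacing module attributes (helper and flag names are illustrative).
from torch.utils.checkpoint import checkpoint


class UNetCheckpointed(UNet):
    def use_checkpointing(self):
        # Record intent only; the checkpointing happens in forward().
        self._checkpointing = True

    def _run(self, module, *inputs):
        # Checkpoint while training; otherwise call the module directly.
        if getattr(self, "_checkpointing", False) and self.training:
            return checkpoint(module, *inputs, use_reentrant=False)
        return module(*inputs)

    def forward(self, x):
        if not self.simple:
            x1 = self._run(self.inc, x)
            x2 = self._run(self.down1, x1)
            x3 = self._run(self.down2, x2)
            x4 = self._run(self.down3, x3)
            x5 = self._run(self.down4, x4)
            x = self._run(self.up1, x5, x4)
            x = self._run(self.up2, x, x3)
            x = self._run(self.up3, x, x2)
            x = self._run(self.up4, x, x1)
            return self._run(self.outc, x)
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x = self._run(self.up1, x3, x2)
        x = self._run(self.up2, x, x1)
        return self._run(self.outc, x)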