|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | +from time import time |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +def timeit(tag, t): |
| 8 | + print("{}: {}s".format(tag, time() - t)) |
| 9 | + return time() |
| 10 | + |
| 11 | +def pc_normalize(pc): |
| 12 | + l = pc.shape[0] |
| 13 | + centroid = np.mean(pc, axis=0) |
| 14 | + pc = pc - centroid |
| 15 | + m = np.max(np.sqrt(np.sum(pc**2, axis=1))) |
| 16 | + pc = pc / m |
| 17 | + return pc |
| 18 | + |
| 19 | +def square_distance(src, dst): |
| 20 | + """ |
| 21 | + Calculate Euclid distance between each two points. |
| 22 | +
|
| 23 | + src^T * dst = xn * xm + yn * ym + zn * zm; |
| 24 | + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; |
| 25 | + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; |
| 26 | + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 |
| 27 | + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst |
| 28 | +
|
| 29 | + Input: |
| 30 | + src: source points, [B, N, C] |
| 31 | + dst: target points, [B, M, C] |
| 32 | + Output: |
| 33 | + dist: per-point square distance, [B, N, M] |
| 34 | + """ |
| 35 | + B, N, _ = src.shape |
| 36 | + _, M, _ = dst.shape |
| 37 | + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) |
| 38 | + dist += torch.sum(src ** 2, -1).view(B, N, 1) |
| 39 | + dist += torch.sum(dst ** 2, -1).view(B, 1, M) |
| 40 | + return dist |
| 41 | + |
| 42 | + |
| 43 | +def index_points(points, idx): |
| 44 | + """ |
| 45 | +
|
| 46 | + Input: |
| 47 | + points: input points data, [B, N, C] |
| 48 | + idx: sample index data, [B, S] |
| 49 | + Return: |
| 50 | + new_points:, indexed points data, [B, S, C] |
| 51 | + """ |
| 52 | + device = points.device |
| 53 | + B = points.shape[0] |
| 54 | + view_shape = list(idx.shape) |
| 55 | + view_shape[1:] = [1] * (len(view_shape) - 1) |
| 56 | + repeat_shape = list(idx.shape) |
| 57 | + repeat_shape[0] = 1 |
| 58 | + batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) |
| 59 | + new_points = points[batch_indices, idx, :] |
| 60 | + return new_points |
| 61 | + |
| 62 | + |
| 63 | +def farthest_point_sample(xyz, npoint): |
| 64 | + """ |
| 65 | + Input: |
| 66 | + xyz: pointcloud data, [B, N, 3] |
| 67 | + npoint: number of samples |
| 68 | + Return: |
| 69 | + centroids: sampled pointcloud index, [B, npoint] |
| 70 | + """ |
| 71 | + device = xyz.device |
| 72 | + B, N, C = xyz.shape |
| 73 | + centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) |
| 74 | + distance = torch.ones(B, N).to(device) * 1e10 |
| 75 | + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) |
| 76 | + batch_indices = torch.arange(B, dtype=torch.long).to(device) |
| 77 | + for i in range(npoint): |
| 78 | + centroids[:, i] = farthest |
| 79 | + centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) |
| 80 | + dist = torch.sum((xyz - centroid) ** 2, -1) |
| 81 | + mask = dist < distance |
| 82 | + distance[mask] = dist[mask] |
| 83 | + farthest = torch.max(distance, -1)[1] |
| 84 | + return centroids |
| 85 | + |
| 86 | + |
| 87 | +def query_ball_point(radius, nsample, xyz, new_xyz): |
| 88 | + """ |
| 89 | + Input: |
| 90 | + radius: local region radius |
| 91 | + nsample: max sample number in local region |
| 92 | + xyz: all points, [B, N, 3] |
| 93 | + new_xyz: query points, [B, S, 3] |
| 94 | + Return: |
| 95 | + group_idx: grouped points index, [B, S, nsample] |
| 96 | + """ |
| 97 | + device = xyz.device |
| 98 | + B, N, C = xyz.shape |
| 99 | + _, S, _ = new_xyz.shape |
| 100 | + group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) |
| 101 | + sqrdists = square_distance(new_xyz, xyz) |
| 102 | + group_idx[sqrdists > radius ** 2] = N |
| 103 | + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] |
| 104 | + group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) |
| 105 | + mask = group_idx == N |
| 106 | + group_idx[mask] = group_first[mask] |
| 107 | + return group_idx |
| 108 | + |
| 109 | + |
| 110 | +def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): |
| 111 | + """ |
| 112 | + Input: |
| 113 | + npoint: |
| 114 | + radius: |
| 115 | + nsample: |
| 116 | + xyz: input points position data, [B, N, 3] |
| 117 | + points: input points data, [B, N, D] |
| 118 | + Return: |
| 119 | + new_xyz: sampled points position data, [B, npoint, nsample, 3] |
| 120 | + new_points: sampled points data, [B, npoint, nsample, 3+D] |
| 121 | + """ |
| 122 | + B, N, C = xyz.shape |
| 123 | + S = npoint |
| 124 | + fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] |
| 125 | + new_xyz = index_points(xyz, fps_idx) |
| 126 | + idx = query_ball_point(radius, nsample, xyz, new_xyz) |
| 127 | + grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] |
| 128 | + grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) |
| 129 | + |
| 130 | + if points is not None: |
| 131 | + grouped_points = index_points(points, idx) |
| 132 | + new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] |
| 133 | + else: |
| 134 | + new_points = grouped_xyz_norm |
| 135 | + if returnfps: |
| 136 | + return new_xyz, new_points, grouped_xyz, fps_idx |
| 137 | + else: |
| 138 | + return new_xyz, new_points |
| 139 | + |
| 140 | + |
| 141 | +def sample_and_group_all(xyz, points): |
| 142 | + """ |
| 143 | + Input: |
| 144 | + xyz: input points position data, [B, N, 3] |
| 145 | + points: input points data, [B, N, D] |
| 146 | + Return: |
| 147 | + new_xyz: sampled points position data, [B, 1, 3] |
| 148 | + new_points: sampled points data, [B, 1, N, 3+D] |
| 149 | + """ |
| 150 | + device = xyz.device |
| 151 | + B, N, C = xyz.shape |
| 152 | + new_xyz = torch.zeros(B, 1, C).to(device) |
| 153 | + grouped_xyz = xyz.view(B, 1, N, C) |
| 154 | + if points is not None: |
| 155 | + new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) |
| 156 | + else: |
| 157 | + new_points = grouped_xyz |
| 158 | + return new_xyz, new_points |
| 159 | + |
| 160 | + |
| 161 | +class PointNetSetAbstraction(nn.Module): |
| 162 | + def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): |
| 163 | + super(PointNetSetAbstraction, self).__init__() |
| 164 | + self.npoint = npoint |
| 165 | + self.radius = radius |
| 166 | + self.nsample = nsample |
| 167 | + self.mlp_convs = nn.ModuleList() |
| 168 | + self.mlp_bns = nn.ModuleList() |
| 169 | + last_channel = in_channel |
| 170 | + for out_channel in mlp: |
| 171 | + self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) |
| 172 | + self.mlp_bns.append(nn.BatchNorm2d(out_channel)) |
| 173 | + last_channel = out_channel |
| 174 | + self.group_all = group_all |
| 175 | + |
| 176 | + def forward(self, xyz, points): |
| 177 | + """ |
| 178 | + Input: |
| 179 | + xyz: input points position data, [B, C, N] |
| 180 | + points: input points data, [B, D, N] |
| 181 | + Return: |
| 182 | + new_xyz: sampled points position data, [B, C, S] |
| 183 | + new_points_concat: sample points feature data, [B, D', S] |
| 184 | + """ |
| 185 | + xyz = xyz.permute(0, 2, 1) |
| 186 | + if points is not None: |
| 187 | + points = points.permute(0, 2, 1) |
| 188 | + |
| 189 | + if self.group_all: |
| 190 | + new_xyz, new_points = sample_and_group_all(xyz, points) |
| 191 | + else: |
| 192 | + new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) |
| 193 | + # new_xyz: sampled points position data, [B, npoint, C] |
| 194 | + # new_points: sampled points data, [B, npoint, nsample, C+D] |
| 195 | + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] |
| 196 | + for i, conv in enumerate(self.mlp_convs): |
| 197 | + bn = self.mlp_bns[i] |
| 198 | + new_points = F.relu(bn(conv(new_points)), inplace=True) |
| 199 | + |
| 200 | + new_points = torch.max(new_points, 2)[0] |
| 201 | + new_xyz = new_xyz.permute(0, 2, 1) |
| 202 | + return new_xyz, new_points |
| 203 | + |
| 204 | + |
| 205 | +class PointNetSetAbstractionMsg(nn.Module): |
| 206 | + def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): |
| 207 | + super(PointNetSetAbstractionMsg, self).__init__() |
| 208 | + self.npoint = npoint |
| 209 | + self.radius_list = radius_list |
| 210 | + self.nsample_list = nsample_list |
| 211 | + self.conv_blocks = nn.ModuleList() |
| 212 | + self.bn_blocks = nn.ModuleList() |
| 213 | + for i in range(len(mlp_list)): |
| 214 | + convs = nn.ModuleList() |
| 215 | + bns = nn.ModuleList() |
| 216 | + last_channel = in_channel + 3 |
| 217 | + for out_channel in mlp_list[i]: |
| 218 | + convs.append(nn.Conv2d(last_channel, out_channel, 1)) |
| 219 | + bns.append(nn.BatchNorm2d(out_channel)) |
| 220 | + last_channel = out_channel |
| 221 | + self.conv_blocks.append(convs) |
| 222 | + self.bn_blocks.append(bns) |
| 223 | + |
| 224 | + def forward(self, xyz, points): |
| 225 | + """ |
| 226 | + Input: |
| 227 | + xyz: input points position data, [B, C, N] |
| 228 | + points: input points data, [B, D, N] |
| 229 | + Return: |
| 230 | + new_xyz: sampled points position data, [B, C, S] |
| 231 | + new_points_concat: sample points feature data, [B, D', S] |
| 232 | + """ |
| 233 | + xyz = xyz.permute(0, 2, 1) |
| 234 | + if points is not None: |
| 235 | + points = points.permute(0, 2, 1) |
| 236 | + |
| 237 | + B, N, C = xyz.shape |
| 238 | + S = self.npoint |
| 239 | + new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) |
| 240 | + new_points_list = [] |
| 241 | + for i, radius in enumerate(self.radius_list): |
| 242 | + K = self.nsample_list[i] |
| 243 | + group_idx = query_ball_point(radius, K, xyz, new_xyz) |
| 244 | + grouped_xyz = index_points(xyz, group_idx) |
| 245 | + grouped_xyz -= new_xyz.view(B, S, 1, C) |
| 246 | + if points is not None: |
| 247 | + grouped_points = index_points(points, group_idx) |
| 248 | + grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) |
| 249 | + else: |
| 250 | + grouped_points = grouped_xyz |
| 251 | + |
| 252 | + grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] |
| 253 | + for j in range(len(self.conv_blocks[i])): |
| 254 | + conv = self.conv_blocks[i][j] |
| 255 | + bn = self.bn_blocks[i][j] |
| 256 | + grouped_points = F.relu(bn(conv(grouped_points)), inplace=True) |
| 257 | + new_points = torch.max(grouped_points, 2)[0] # [B, D', S] |
| 258 | + new_points_list.append(new_points) |
| 259 | + |
| 260 | + new_xyz = new_xyz.permute(0, 2, 1) |
| 261 | + new_points_concat = torch.cat(new_points_list, dim=1) |
| 262 | + return new_xyz, new_points_concat |
| 263 | + |
| 264 | + |
| 265 | +class PointNetFeaturePropagation(nn.Module): |
| 266 | + def __init__(self, in_channel, mlp): |
| 267 | + super(PointNetFeaturePropagation, self).__init__() |
| 268 | + self.mlp_convs = nn.ModuleList() |
| 269 | + self.mlp_bns = nn.ModuleList() |
| 270 | + last_channel = in_channel |
| 271 | + for out_channel in mlp: |
| 272 | + self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) |
| 273 | + self.mlp_bns.append(nn.BatchNorm1d(out_channel)) |
| 274 | + last_channel = out_channel |
| 275 | + |
| 276 | + def forward(self, xyz1, xyz2, points1, points2): |
| 277 | + """ |
| 278 | + Input: |
| 279 | + xyz1: input points position data, [B, C, N] |
| 280 | + xyz2: sampled input points position data, [B, C, S] |
| 281 | + points1: input points data, [B, D, N] |
| 282 | + points2: input points data, [B, D, S] |
| 283 | + Return: |
| 284 | + new_points: upsampled points data, [B, D', N] |
| 285 | + """ |
| 286 | + xyz1 = xyz1.permute(0, 2, 1) |
| 287 | + xyz2 = xyz2.permute(0, 2, 1) |
| 288 | + |
| 289 | + points2 = points2.permute(0, 2, 1) |
| 290 | + B, N, C = xyz1.shape |
| 291 | + _, S, _ = xyz2.shape |
| 292 | + |
| 293 | + if S == 1: |
| 294 | + interpolated_points = points2.repeat(1, N, 1) |
| 295 | + else: |
| 296 | + dists = square_distance(xyz1, xyz2) |
| 297 | + dists, idx = dists.sort(dim=-1) |
| 298 | + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] |
| 299 | + |
| 300 | + dist_recip = 1.0 / (dists + 1e-8) |
| 301 | + norm = torch.sum(dist_recip, dim=2, keepdim=True) |
| 302 | + weight = dist_recip / norm |
| 303 | + interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) |
| 304 | + |
| 305 | + if points1 is not None: |
| 306 | + points1 = points1.permute(0, 2, 1) |
| 307 | + new_points = torch.cat([points1, interpolated_points], dim=-1) |
| 308 | + else: |
| 309 | + new_points = interpolated_points |
| 310 | + |
| 311 | + new_points = new_points.permute(0, 2, 1) |
| 312 | + for i, conv in enumerate(self.mlp_convs): |
| 313 | + bn = self.mlp_bns[i] |
| 314 | + new_points = F.relu(bn(conv(new_points)), inplace=True) |
| 315 | + return new_points |
| 316 | + |
0 commit comments