Skip to content

Commit 6862dc3

Browse files
committed
add pre-trained models for classification and partseg
1 parent 546693c commit 6862dc3

10 files changed

Lines changed: 4411 additions & 2 deletions

File tree

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
This repo is implementation for [PointNet](http://openaccess.thecvf.com/content_cvpr_2017/papers/Qi_PointNet_Deep_Learning_CVPR_2017_paper.pdf) and [PointNet++](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf) in pytorch.
44

55
## Update
6-
**2021/03/26:**
6+
**2021/03/27:**
77

8-
Release pre-trained models for semantic segmentation, where PointNet++ can achieve **53.5\%** mIoU.
8+
(1) Release pre-trained models for semantic segmentation, where PointNet++ can achieve **53.5\%** mIoU.
9+
10+
(2) Release pre-trained models for classification and part segmentation in `log/`.
911

1012
**2021/03/20:** Update codes for classification, including:
1113

Binary file not shown.

log/classification/pointnet2_msg_normals/logs/pointnet2_cls_msg.txt

Lines changed: 852 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch.nn as nn
2+
import torch.nn.functional as F
3+
from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction
4+
5+
6+
class get_model(nn.Module):
7+
def __init__(self,num_class,normal_channel=True):
8+
super(get_model, self).__init__()
9+
in_channel = 3 if normal_channel else 0
10+
self.normal_channel = normal_channel
11+
self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
12+
self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
13+
self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
14+
self.fc1 = nn.Linear(1024, 512)
15+
self.bn1 = nn.BatchNorm1d(512)
16+
self.drop1 = nn.Dropout(0.4)
17+
self.fc2 = nn.Linear(512, 256)
18+
self.bn2 = nn.BatchNorm1d(256)
19+
self.drop2 = nn.Dropout(0.5)
20+
self.fc3 = nn.Linear(256, num_class)
21+
22+
def forward(self, xyz):
23+
B, _, _ = xyz.shape
24+
if self.normal_channel:
25+
norm = xyz[:, 3:, :]
26+
xyz = xyz[:, :3, :]
27+
else:
28+
norm = None
29+
l1_xyz, l1_points = self.sa1(xyz, norm)
30+
l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
31+
l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
32+
x = l3_points.view(B, 1024)
33+
x = self.drop1(F.relu(self.bn1(self.fc1(x)), inplace=True))
34+
x = self.drop2(F.relu(self.bn2(self.fc2(x)), inplace=True))
35+
x = self.fc3(x)
36+
x = F.log_softmax(x, -1)
37+
38+
39+
return x,l3_points
40+
41+
42+
class get_loss(nn.Module):
43+
def __init__(self):
44+
super(get_loss, self).__init__()
45+
46+
def forward(self, pred, target, trans_feat):
47+
total_loss = F.nll_loss(pred, target)
48+
49+
return total_loss
50+
51+
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from time import time
5+
import numpy as np
6+
7+
def timeit(tag, t):
8+
print("{}: {}s".format(tag, time() - t))
9+
return time()
10+
11+
def pc_normalize(pc):
12+
l = pc.shape[0]
13+
centroid = np.mean(pc, axis=0)
14+
pc = pc - centroid
15+
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
16+
pc = pc / m
17+
return pc
18+
19+
def square_distance(src, dst):
20+
"""
21+
Calculate Euclid distance between each two points.
22+
23+
src^T * dst = xn * xm + yn * ym + zn * zm;
24+
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
25+
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
26+
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
27+
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
28+
29+
Input:
30+
src: source points, [B, N, C]
31+
dst: target points, [B, M, C]
32+
Output:
33+
dist: per-point square distance, [B, N, M]
34+
"""
35+
B, N, _ = src.shape
36+
_, M, _ = dst.shape
37+
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
38+
dist += torch.sum(src ** 2, -1).view(B, N, 1)
39+
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
40+
return dist
41+
42+
43+
def index_points(points, idx):
44+
"""
45+
46+
Input:
47+
points: input points data, [B, N, C]
48+
idx: sample index data, [B, S]
49+
Return:
50+
new_points:, indexed points data, [B, S, C]
51+
"""
52+
device = points.device
53+
B = points.shape[0]
54+
view_shape = list(idx.shape)
55+
view_shape[1:] = [1] * (len(view_shape) - 1)
56+
repeat_shape = list(idx.shape)
57+
repeat_shape[0] = 1
58+
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
59+
new_points = points[batch_indices, idx, :]
60+
return new_points
61+
62+
63+
def farthest_point_sample(xyz, npoint):
64+
"""
65+
Input:
66+
xyz: pointcloud data, [B, N, 3]
67+
npoint: number of samples
68+
Return:
69+
centroids: sampled pointcloud index, [B, npoint]
70+
"""
71+
device = xyz.device
72+
B, N, C = xyz.shape
73+
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
74+
distance = torch.ones(B, N).to(device) * 1e10
75+
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
76+
batch_indices = torch.arange(B, dtype=torch.long).to(device)
77+
for i in range(npoint):
78+
centroids[:, i] = farthest
79+
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
80+
dist = torch.sum((xyz - centroid) ** 2, -1)
81+
mask = dist < distance
82+
distance[mask] = dist[mask]
83+
farthest = torch.max(distance, -1)[1]
84+
return centroids
85+
86+
87+
def query_ball_point(radius, nsample, xyz, new_xyz):
88+
"""
89+
Input:
90+
radius: local region radius
91+
nsample: max sample number in local region
92+
xyz: all points, [B, N, 3]
93+
new_xyz: query points, [B, S, 3]
94+
Return:
95+
group_idx: grouped points index, [B, S, nsample]
96+
"""
97+
device = xyz.device
98+
B, N, C = xyz.shape
99+
_, S, _ = new_xyz.shape
100+
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
101+
sqrdists = square_distance(new_xyz, xyz)
102+
group_idx[sqrdists > radius ** 2] = N
103+
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
104+
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
105+
mask = group_idx == N
106+
group_idx[mask] = group_first[mask]
107+
return group_idx
108+
109+
110+
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
111+
"""
112+
Input:
113+
npoint:
114+
radius:
115+
nsample:
116+
xyz: input points position data, [B, N, 3]
117+
points: input points data, [B, N, D]
118+
Return:
119+
new_xyz: sampled points position data, [B, npoint, nsample, 3]
120+
new_points: sampled points data, [B, npoint, nsample, 3+D]
121+
"""
122+
B, N, C = xyz.shape
123+
S = npoint
124+
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
125+
new_xyz = index_points(xyz, fps_idx)
126+
idx = query_ball_point(radius, nsample, xyz, new_xyz)
127+
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
128+
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
129+
130+
if points is not None:
131+
grouped_points = index_points(points, idx)
132+
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
133+
else:
134+
new_points = grouped_xyz_norm
135+
if returnfps:
136+
return new_xyz, new_points, grouped_xyz, fps_idx
137+
else:
138+
return new_xyz, new_points
139+
140+
141+
def sample_and_group_all(xyz, points):
142+
"""
143+
Input:
144+
xyz: input points position data, [B, N, 3]
145+
points: input points data, [B, N, D]
146+
Return:
147+
new_xyz: sampled points position data, [B, 1, 3]
148+
new_points: sampled points data, [B, 1, N, 3+D]
149+
"""
150+
device = xyz.device
151+
B, N, C = xyz.shape
152+
new_xyz = torch.zeros(B, 1, C).to(device)
153+
grouped_xyz = xyz.view(B, 1, N, C)
154+
if points is not None:
155+
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
156+
else:
157+
new_points = grouped_xyz
158+
return new_xyz, new_points
159+
160+
161+
class PointNetSetAbstraction(nn.Module):
162+
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
163+
super(PointNetSetAbstraction, self).__init__()
164+
self.npoint = npoint
165+
self.radius = radius
166+
self.nsample = nsample
167+
self.mlp_convs = nn.ModuleList()
168+
self.mlp_bns = nn.ModuleList()
169+
last_channel = in_channel
170+
for out_channel in mlp:
171+
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
172+
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
173+
last_channel = out_channel
174+
self.group_all = group_all
175+
176+
def forward(self, xyz, points):
177+
"""
178+
Input:
179+
xyz: input points position data, [B, C, N]
180+
points: input points data, [B, D, N]
181+
Return:
182+
new_xyz: sampled points position data, [B, C, S]
183+
new_points_concat: sample points feature data, [B, D', S]
184+
"""
185+
xyz = xyz.permute(0, 2, 1)
186+
if points is not None:
187+
points = points.permute(0, 2, 1)
188+
189+
if self.group_all:
190+
new_xyz, new_points = sample_and_group_all(xyz, points)
191+
else:
192+
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
193+
# new_xyz: sampled points position data, [B, npoint, C]
194+
# new_points: sampled points data, [B, npoint, nsample, C+D]
195+
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
196+
for i, conv in enumerate(self.mlp_convs):
197+
bn = self.mlp_bns[i]
198+
new_points = F.relu(bn(conv(new_points)), inplace=True)
199+
200+
new_points = torch.max(new_points, 2)[0]
201+
new_xyz = new_xyz.permute(0, 2, 1)
202+
return new_xyz, new_points
203+
204+
205+
class PointNetSetAbstractionMsg(nn.Module):
206+
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
207+
super(PointNetSetAbstractionMsg, self).__init__()
208+
self.npoint = npoint
209+
self.radius_list = radius_list
210+
self.nsample_list = nsample_list
211+
self.conv_blocks = nn.ModuleList()
212+
self.bn_blocks = nn.ModuleList()
213+
for i in range(len(mlp_list)):
214+
convs = nn.ModuleList()
215+
bns = nn.ModuleList()
216+
last_channel = in_channel + 3
217+
for out_channel in mlp_list[i]:
218+
convs.append(nn.Conv2d(last_channel, out_channel, 1))
219+
bns.append(nn.BatchNorm2d(out_channel))
220+
last_channel = out_channel
221+
self.conv_blocks.append(convs)
222+
self.bn_blocks.append(bns)
223+
224+
def forward(self, xyz, points):
225+
"""
226+
Input:
227+
xyz: input points position data, [B, C, N]
228+
points: input points data, [B, D, N]
229+
Return:
230+
new_xyz: sampled points position data, [B, C, S]
231+
new_points_concat: sample points feature data, [B, D', S]
232+
"""
233+
xyz = xyz.permute(0, 2, 1)
234+
if points is not None:
235+
points = points.permute(0, 2, 1)
236+
237+
B, N, C = xyz.shape
238+
S = self.npoint
239+
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
240+
new_points_list = []
241+
for i, radius in enumerate(self.radius_list):
242+
K = self.nsample_list[i]
243+
group_idx = query_ball_point(radius, K, xyz, new_xyz)
244+
grouped_xyz = index_points(xyz, group_idx)
245+
grouped_xyz -= new_xyz.view(B, S, 1, C)
246+
if points is not None:
247+
grouped_points = index_points(points, group_idx)
248+
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
249+
else:
250+
grouped_points = grouped_xyz
251+
252+
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
253+
for j in range(len(self.conv_blocks[i])):
254+
conv = self.conv_blocks[i][j]
255+
bn = self.bn_blocks[i][j]
256+
grouped_points = F.relu(bn(conv(grouped_points)), inplace=True)
257+
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
258+
new_points_list.append(new_points)
259+
260+
new_xyz = new_xyz.permute(0, 2, 1)
261+
new_points_concat = torch.cat(new_points_list, dim=1)
262+
return new_xyz, new_points_concat
263+
264+
265+
class PointNetFeaturePropagation(nn.Module):
266+
def __init__(self, in_channel, mlp):
267+
super(PointNetFeaturePropagation, self).__init__()
268+
self.mlp_convs = nn.ModuleList()
269+
self.mlp_bns = nn.ModuleList()
270+
last_channel = in_channel
271+
for out_channel in mlp:
272+
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
273+
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
274+
last_channel = out_channel
275+
276+
def forward(self, xyz1, xyz2, points1, points2):
277+
"""
278+
Input:
279+
xyz1: input points position data, [B, C, N]
280+
xyz2: sampled input points position data, [B, C, S]
281+
points1: input points data, [B, D, N]
282+
points2: input points data, [B, D, S]
283+
Return:
284+
new_points: upsampled points data, [B, D', N]
285+
"""
286+
xyz1 = xyz1.permute(0, 2, 1)
287+
xyz2 = xyz2.permute(0, 2, 1)
288+
289+
points2 = points2.permute(0, 2, 1)
290+
B, N, C = xyz1.shape
291+
_, S, _ = xyz2.shape
292+
293+
if S == 1:
294+
interpolated_points = points2.repeat(1, N, 1)
295+
else:
296+
dists = square_distance(xyz1, xyz2)
297+
dists, idx = dists.sort(dim=-1)
298+
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
299+
300+
dist_recip = 1.0 / (dists + 1e-8)
301+
norm = torch.sum(dist_recip, dim=2, keepdim=True)
302+
weight = dist_recip / norm
303+
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
304+
305+
if points1 is not None:
306+
points1 = points1.permute(0, 2, 1)
307+
new_points = torch.cat([points1, interpolated_points], dim=-1)
308+
else:
309+
new_points = interpolated_points
310+
311+
new_points = new_points.permute(0, 2, 1)
312+
for i, conv in enumerate(self.mlp_convs):
313+
bn = self.mlp_bns[i]
314+
new_points = F.relu(bn(conv(new_points)), inplace=True)
315+
return new_points
316+

0 commit comments

Comments
 (0)