Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions fuse_conv_and_bn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch


def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = (
torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)

# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

# Prepare spatial bias
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

return fusedconv
104 changes: 46 additions & 58 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import numpy as np
import cv2
from NeuFlow.neuflow import NeuFlow

from NeuFlow.backbone_v7 import ConvBlock
from data_utils import flow_viz
from fuse_conv_and_bn import fuse_conv_and_bn


image_width = 768
Expand All @@ -20,80 +22,66 @@ def get_cuda_image(image_path):
return image[None].cuda()


def fuse_conv_and_bn(conv, bn):
"""Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/."""
fusedconv = (
torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)

# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))

# Prepare spatial bias
b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

return fusedconv
def main(hugging_face: bool) -> None:
image_path_list = sorted(glob('test_images/*.jpg'))
vis_path = 'test_results/'

device = torch.device('cuda')

image_path_list = sorted(glob('test_images/*.jpg'))
vis_path = 'test_results/'
if hugging_face:
print("from_pretrained: Study-is-happy/neuflow-v2")
model = NeuFlow.from_pretrained("Study-is-happy/neuflow-v2").to(device)
else:
model = NeuFlow().to(device)
print("load: neuflow_mixed.pth")
checkpoint = torch.load('neuflow_mixed.pth', map_location='cuda')
model.load_state_dict(checkpoint['model'], strict=True)

device = torch.device('cuda')
for m in model.modules():
if type(m) is ConvBlock:
m.conv1 = fuse_conv_and_bn(m.conv1, m.norm1) # update conv
m.conv2 = fuse_conv_and_bn(m.conv2, m.norm2) # update conv
delattr(m, "norm1") # remove batchnorm
delattr(m, "norm2") # remove batchnorm
m.forward = m.forward_fuse # update forward

model = NeuFlow().to(device)
model.eval()
model.half()

checkpoint = torch.load('neuflow_mixed.pth', map_location='cuda')
model.init_bhwd(1, image_height, image_width, 'cuda')

model.load_state_dict(checkpoint['model'], strict=True)
if not os.path.exists(vis_path):
os.makedirs(vis_path)

for m in model.modules():
if type(m) is ConvBlock:
m.conv1 = fuse_conv_and_bn(m.conv1, m.norm1) # update conv
m.conv2 = fuse_conv_and_bn(m.conv2, m.norm2) # update conv
delattr(m, "norm1") # remove batchnorm
delattr(m, "norm2") # remove batchnorm
m.forward = m.forward_fuse # update forward
for image_path_0, image_path_1 in zip(image_path_list[:-1], image_path_list[1:]):

model.eval()
model.half()
print(image_path_0)

model.init_bhwd(1, image_height, image_width, 'cuda')
image_0 = get_cuda_image(image_path_0)
image_1 = get_cuda_image(image_path_1)

if not os.path.exists(vis_path):
os.makedirs(vis_path)
file_name = os.path.basename(image_path_0)

for image_path_0, image_path_1 in zip(image_path_list[:-1], image_path_list[1:]):
with torch.no_grad():

print(image_path_0)
flow = model(image_0, image_1)[-1][0]

image_0 = get_cuda_image(image_path_0)
image_1 = get_cuda_image(image_path_1)
flow = flow.permute(1,2,0).cpu().numpy()

file_name = os.path.basename(image_path_0)
flow = flow_viz.flow_to_image(flow)

with torch.no_grad():
image_0 = cv2.resize(cv2.imread(image_path_0), (image_width, image_height))

flow = model(image_0, image_1)[-1][0]
cv2.imwrite(vis_path + file_name, np.vstack([image_0, flow]))

flow = flow.permute(1,2,0).cpu().numpy()

flow = flow_viz.flow_to_image(flow)

image_0 = cv2.resize(cv2.imread(image_path_0), (image_width, image_height))
if __name__ == "__main__":
from argparse import ArgumentParser

cv2.imwrite(vis_path + file_name, np.vstack([image_0, flow]))
parser = ArgumentParser()
parser.add_argument(
"--hugging-face", action="store_true",
help="load model form hugging face (Study-is-happy/neuflow-v2)"
)
args = parser.parse_args()
main(args.hugging_face)
96 changes: 0 additions & 96 deletions infer_hf.py

This file was deleted.