Commit 3fac647

Author: Zach Teed
Commit message: fixed problems with variational dropout

1 parent dd91321, commit 3fac647

File tree: 5 files changed, +22 -8 lines

RAFT.png (binary image, 199 KB)

README.md

Lines changed: 4 additions & 2 deletions

@@ -4,6 +4,8 @@ This repository contains the source code for our paper:
 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
 Zachary Teed and Jia Deng<br/>
 
+<img src="RAFT.png">
+
 ## Requirements
 Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
 
@@ -84,11 +86,11 @@ python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps
 You can evaluate a model on Sintel and KITTI by running
 
 ```Shell
-python evaluate.py --model=checkpoints/chairs+things.pth
+python evaluate.py --model=models/chairs+things.pth
 ```
 
 or the small model by including the `small` flag
 
 ```Shell
-python evaluate.py --model=checkpoints/small.pth --small
+python evaluate.py --model=models/small.pth --small
 ```

core/modules/update.py

Lines changed: 14 additions & 2 deletions

@@ -133,8 +133,20 @@ def __init__(self, args, hidden_dim=96):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
+        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
+        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+
+    def reset_mask(self, net, inp):
+        self.drop_inp.reset_mask(inp)
+        self.drop_net.reset_mask(net)
+
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
+
+        if self.training:
+            net = self.drop_net(net)
+            inp = self.drop_inp(inp)
+
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
@@ -157,12 +169,12 @@ def reset_mask(self, net, inp):
 
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
 
         if self.training:
            net = self.drop_net(net)
            inp = self.drop_inp(inp)
-
+
+        inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
 

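For context: variational ("locked") dropout samples one Bernoulli mask per forward pass and reuses it unchanged across all recurrent GRU iterations, instead of resampling on every call. That is also why the second hunk moves `torch.cat` after the dropout: the mask created by `reset_mask` matches the shape of the raw `inp` features, so dropout has to run before the motion features are concatenated. The repository's `VariationalHidDropout` implementation is not part of this diff; the following is a minimal sketch consistent with the `reset_mask`/`forward` calls above, with the inverted-dropout scaling being an assumption.

```Python
import torch
import torch.nn as nn

class VariationalHidDropout(nn.Module):
    """Sketch: dropout whose mask is sampled once (reset_mask) and then
    reused unchanged across all recurrent update iterations."""

    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        # One Bernoulli keep-mask per forward pass; inverted-dropout
        # scaling (an assumed detail) keeps activations unbiased.
        keep = 1.0 - self.dropout
        self.mask = torch.zeros_like(x).bernoulli_(keep) / keep

    def forward(self, x):
        if not self.training or self.dropout == 0:
            return x
        return self.mask * x  # same mask on every iteration
```

Presumably `reset_mask` is called once per RAFT forward pass, before the iterative updates begin; that call site is outside this diff.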
core/raft.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ def __init__(self, args):
         args.corr_levels = 4
         args.corr_radius = 4
 
-        if 'dropout' not in args._get_kwargs():
+        if not hasattr(args, 'dropout'):
             args.dropout = 0
 
         # feature network, context network, and update block
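
This one-line change is the core of the dropout fix: `argparse.Namespace._get_kwargs()` returns a list of `(name, value)` tuples, so the string `'dropout'` is never a member, the old condition was always true, and any dropout value passed in was silently overwritten with 0. A quick demonstration:

```Python
from argparse import Namespace

args = Namespace(dropout=0.1)

print(args._get_kwargs())                   # [('dropout', 0.1)]
print('dropout' not in args._get_kwargs())  # True  -- old check always fired
print(not hasattr(args, 'dropout'))         # False -- new check is correct
```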

train.py

Lines changed: 3 additions & 3 deletions

@@ -21,7 +21,7 @@
 
 # exclude extremly large displacements
 MAX_FLOW = 1000
-SUM_FREQ = 100
+SUM_FREQ = 200
 VAL_FREQ = 5000
 
 
@@ -56,7 +56,7 @@ def sequence_loss(flow_preds, flow_gt, valid):
 
 
 def fetch_dataloader(args):
-    """ Create the data loader for the corresponding trainign set """
+    """ Create the data loader for the corresponding training set """
 
     if args.dataset == 'chairs':
         train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
@@ -86,7 +86,7 @@ def fetch_optimizer(args, model):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
 
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear', final_div_factor=1.0)
+        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
 
     return optimizer, scheduler

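Removing `final_div_factor=1.0` lets `OneCycleLR` fall back to its default `final_div_factor` of 1e4, so the learning rate now anneals down to roughly `max_lr / (div_factor * 1e4)` instead of ending back at the warm-up starting value. A minimal sketch with illustrative numbers (the toy model, learning rate, and step count are placeholders, not values from `train.py`):

```Python
import torch
from torch import optim

model = torch.nn.Linear(10, 2)  # stand-in for RAFT
optimizer = optim.AdamW(model.parameters(), lr=4e-4)

# 20% linear warm-up, then linear decay; with the default final_div_factor
# the final LR is ~max_lr / (25 * 1e4) rather than max_lr / 25.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=4e-4, total_steps=100,
    pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')

for _ in range(100):
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # effectively zero at the end
```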
0 commit comments
