Skip to content

Commit a09d865

Browse files
committed
fix cuda
1 parent 5c88856 commit a09d865

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

models.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,18 @@ def point_grad_to(self, target):
1313
Set .grad attribute of each parameter to be proportional
1414
to the difference between self and target
1515
'''
16-
is_cuda = next(self.parameters()).is_cuda
1716
for p, target_p in zip(self.parameters(), target.parameters()):
1817
if p.grad is None:
19-
if is_cuda:
18+
if self.is_cuda():
2019
p.grad = Variable(torch.zeros(p.size())).cuda()
2120
else:
2221
p.grad = Variable(torch.zeros(p.size()))
2322
p.grad.data.zero_() # not sure this is required
2423
p.grad.data.add_(p.data - target_p.data)
2524

25+
def is_cuda(self):
26+
return next(self.parameters()).is_cuda
27+
2628

2729
class OmniglotModel(ReptileModel):
2830
"""
@@ -77,6 +79,8 @@ def predict(self, prob):
7779
def clone(self):
7880
clone = OmniglotModel(self.num_classes)
7981
clone.load_state_dict(self.state_dict())
82+
if self.is_cuda():
83+
clone.cuda()
8084
return clone
8185

8286

0 commit comments

Comments
 (0)