File tree 1 file changed +6
-2
lines changed
1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -13,16 +13,18 @@ def point_grad_to(self, target):
13
13
Set .grad attribute of each parameter to be proportional
14
14
to the difference between self and target
15
15
'''
16
- is_cuda = next (self .parameters ()).is_cuda
17
16
for p , target_p in zip (self .parameters (), target .parameters ()):
18
17
if p .grad is None :
19
- if is_cuda :
18
+ if self . is_cuda () :
20
19
p .grad = Variable (torch .zeros (p .size ())).cuda ()
21
20
else :
22
21
p .grad = Variable (torch .zeros (p .size ()))
23
22
p .grad .data .zero_ () # not sure this is required
24
23
p .grad .data .add_ (p .data - target_p .data )
25
24
25
+ def is_cuda (self ):
26
+ return next (self .parameters ()).is_cuda
27
+
26
28
27
29
class OmniglotModel (ReptileModel ):
28
30
"""
@@ -77,6 +79,8 @@ def predict(self, prob):
77
79
def clone (self ):
78
80
clone = OmniglotModel (self .num_classes )
79
81
clone .load_state_dict (self .state_dict ())
82
+ if self .is_cuda ():
83
+ clone .cuda ()
80
84
return clone
81
85
82
86
You can’t perform that action at this time.
0 commit comments