diff --git a/src/gradfuns.lua b/src/gradfuns.lua index 80fade1..24306e7 100644 --- a/src/gradfuns.lua +++ b/src/gradfuns.lua @@ -152,7 +152,12 @@ functions.set = { return nil end, function(g, ans, x, k, v) - return g[k] + local gk = getValue(g[k]) + if type(gk) == 'number' then + return gk + else + return torch.clone(gk) + end end, }