diff --git a/src/gradfuns.lua b/src/gradfuns.lua index b9c7b3b..0f20da6 100644 --- a/src/gradfuns.lua +++ b/src/gradfuns.lua @@ -152,7 +152,12 @@ functions.set = { return nil end, function(g, ans, x, k, v) - return g[k] + local gk = getValue(g[k]) + if type(gk) == 'number' then + return gk + else + return torch.clone(gk) + end end, } diff --git a/test/test.lua b/test/test.lua index 54b31c8..1d807db 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1687,6 +1687,11 @@ local tests = { return torch.sum(xc) end tester:assert(gradcheck(f4,{x=torch.randn(10,10),y=torch.randn(3)}), "Incorrect gradient") + local f5 = function(params) + params.x[2] = params.y*2.0 + return torch.sum(params.x) + end + tester:assert(gradcheck(f5,{x=torch.randn(10,10),y=torch.randn(10)}), "Incorrect gradient") end, ScalarSigmoid = function()