Skip to content
This repository has been archived by the owner on Nov 1, 2021. It is now read-only.

Commit

Permalink
Merge pull request #107 from as1986/finitedifference
Browse files Browse the repository at this point in the history
alternative finite difference gradient check
  • Loading branch information
alexbw committed May 23, 2016
2 parents 3d07934 + 8e4687a commit cc6d1f6
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/gradcheck.lua
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,41 @@ local function jacobianFromFiniteDifferences(func, inputs, var)
return grads
end

local function gradcheckvar2(func, inputs, var, randomizeInput)
-- Random input:
if randomizeInput then
var:uniform(-10,10)
end

-- Estimate grads with fprop:
local jacobian = jacobianFromAutograd(func, inputs, var)
local originalLoss = func(table.unpack(inputs))
local noise = jacobian:view(-1):clone():zero()
local idx = math.random(1, noise:size(1))

local originalVar = var:clone()
noise:narrow(1,idx,1):uniform(-perturbation, perturbation)
var:add(torch.view(noise, var:size()))

local perturbedLoss = func(table.unpack(inputs))
local approxPerturbed = originalLoss + torch.dot(jacobian, noise)

-- Error:
local err = math.abs((perturbedLoss - approxPerturbed)) /
(math.max(math.abs(perturbedLoss), math.abs(originalLoss))+perturbation)

-- Threhold?
local pass = err < threshold
if not pass then
print('original loss = '..originalLoss)
print('perturbed loss = '..perturbedLoss)
print('approximated perturbed loss = '..approxPerturbed)
print('error = ' .. err)
end
var:copy(originalVar)
return pass, err
end

local function gradcheckvar(func, inputs, var, randomizeInput)
-- Random input:
if randomizeInput then
Expand Down Expand Up @@ -102,10 +137,15 @@ return function(opt)
local args = {...}
-- get all vars:
local vars = autograd.util.sortedFlatten(args[1])
local max_err = 0
local ok = true
for i,var in ipairs(vars) do
local t, err = gradcheckvar2(func, args, var, randomizeInput)
ok = ok and t
if err > max_err then max_err = err end
ok = ok and gradcheckvar(func, args, var, randomizeInput)
end
print('[gradcheck2] maximum error = '..max_err)
return ok
end

Expand Down

0 comments on commit cc6d1f6

Please sign in to comment.