# How does it work?
optnet internally performs four levels of memory optimization. Those are:
- Addition of in-place operations
- Reuse of internal module temporary buffers
- Removal of `gradWeight` and `gradBias` for inference mode
- `output`/`gradInput` reuse once they are not needed anymore
The following sections explain each of those optimizations in more detail.
This is a pretty basic functionality at the moment. It currently only runs over the network and sets all modules that support in-place operations to in-place mode.
As it currently does not analyse the flow of computations to check whether chaining in-place operations produces correct gradients, it might be better to turn this option off and manually set to in-place mode only the modules that you know are safe.
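As a sketch of what a manual, conservative version of this could look like (the selection criterion here is illustrative; only `nn.ReLU` is touched, since its in-place mode is safe in most architectures):

```lua
require 'nn'

-- net is assumed to be an nn container (e.g. nn.Sequential)
-- Module:apply visits every module in the network
net:apply(function(m)
   -- enable in-place mode only for modules we know to be safe
   if torch.typename(m) == 'nn.ReLU' then
      m.inplace = true
   end
end)
```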
Several modules, like convolutions and max pooling from nn, have internal buffers to store intermediate information, such as the unfolded image or the indices of the maximum values. We use a simple heuristic for buffer sharing: we manually annotated the most commonly used modules, marking which of their buffers are and are not needed for computing correct gradients. The buffers that are only needed for the forward pass or the backward pass are then shared among the other instances of the same module in the network.
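The idea can be sketched as follows for `nn.SpatialConvolution`, whose `finput` buffer (the unfolded image) is only needed within a single forward call, so all convolution instances can point to the same tensor (a simplified illustration, not the package's exact implementation):

```lua
require 'nn'

-- one shared buffer for all convolutions in the network
local sharedBuffer = torch.Tensor()

net:apply(function(m)
   if torch.typename(m) == 'nn.SpatialConvolution' then
      -- finput is an internal scratch buffer; sharing it is safe
      -- because each module fully rewrites it during its own forward
      m.finput = sharedBuffer
   end
end)
```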
When using networks only for inference, one does not need the gradients with respect to the weights and biases, which are allocated by default by nn. This function removes all such gradients from the model, keeping track of any sharing between them so that they can be exactly reconstructed when needed.
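A minimal sketch of this removal (illustrative only; the real implementation also records sharing so the tensors can be rebuilt exactly, and a network stripped this way can no longer run `backward`):

```lua
require 'nn'

-- remember each removed gradient's size so it could be reallocated later
local removed = {}
net:apply(function(m)
   for _, name in ipairs{'gradWeight', 'gradBias'} do
      if m[name] then
         removed[m] = removed[m] or {}
         removed[m][name] = m[name]:size()
         m[name] = nil
      end
   end
end)
```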
The main part of the optnet package, this function reuses the outputs (or gradInputs in training mode) by keeping track of the life of the storage of each output. Once a specific output is not used anymore, it adds it to a list of storages that can be reused.
In the discussion that follows, we will focus on the inference mode for the optimization, but a similar reasoning is used for the training mode.
The first step is then to find, for each output, the first time it is defined and the last time it is used.
Given that nn is very flexible and only enforces that self.output is defined after the forward pass and is the returned value,
the way optnet handles new/generic modules is to infer the structure of the network by running a forward pass.
If we can analyse the flow of each storage during the forward pass, then we can infer the structure of the network
as well as the life cycle of each output.
The question is then how to track when each output is defined and used, without having to write specific code for each module.
We solve this by temporarily overwriting the updateOutput function.
By using upvalues, we can keep track of the input that is fed to each module, as well as the output that it generates.
Here is an example snippet that illustrates the idea that is employed:
```lua
local inputs = {}
net:apply(function(x)
   -- keep a reference to the original forward function
   local orig_updateOutput = x.updateOutput
   -- overwrite it to do some more things before/after
   -- the original function is executed
   x.updateOutput = function(self, input)
      -- inputs is an upvalue, and we do not need to change the
      -- function signature, so we can actually inspect
      -- each module during the forward call
      table.insert(inputs, input)
      print('hello from '..torch.typename(self))
      return orig_updateOutput(self, input)
   end
end)
```

With that in mind, we are then able to extract, for each output, the first time it is defined and the last time it is used.
Once we have the life span of each storage, the memory-sharing step tries to find the minimum number of sets of temporally non-overlapping storages, and attributes a single storage to each of those sets.
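One simple way to compute such an assignment is a greedy pass over the life spans, reusing the first shared storage whose previous occupant is already dead. This is a sketch under the assumption that `lifespans` is sorted by first use; the function name and table layout are hypothetical:

```lua
-- lifespans: array of {first = step, last = step}, sorted by `first`
local function assignStorages(lifespans)
   -- pools[i] holds the last step at which shared storage i is in use
   local pools = {}
   local assignment = {}
   for id, span in ipairs(lifespans) do
      local chosen
      for i, freeAt in ipairs(pools) do
         if freeAt < span.first then
            chosen = i
            break
         end
      end
      if chosen then
         pools[chosen] = span.last
      else
         table.insert(pools, span.last)
         chosen = #pools
      end
      assignment[id] = chosen
   end
   return assignment
end

-- with life spans {1,3}, {2,5}, {4,6}: the first and third outputs
-- share one storage, while the second needs its own
```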
With these assignments in hand, we can then change the storage of each output to match the assignment that was found for it.