GaussianStochasticGates
- class captum.module.GaussianStochasticGates(n_gates, mask=None, reg_weight=1.0, std=0.5, reg_reduction='sum')
Stochastic Gates with Gaussian distribution.
Stochastic Gates is a practical approach to adding L0-norm regularization to neural networks. L0 regularization, which explicitly penalizes every present (non-zero) parameter, can help with network pruning and feature selection, but optimizing L0 directly is a non-differentiable combinatorial problem. As a differentiable surrogate, Stochastic Gates uses continuous probability distributions (e.g., Concrete, Gaussian) with hard-sigmoid rectification as smoothed Bernoulli distributions that determine the weight of each parameter, i.e., the gate. The L0 norm then equals the gates' non-zero probability, expressed in terms of the parameters of the continuous distribution. Since the gate value can be reparameterized as a function of the distribution parameters and a noise variable, the expected L0 can be optimized by learning the distribution parameters via stochastic gradients.
GaussianStochasticGates adopts a Gaussian distribution as the smoothed Bernoulli distribution of each gate. While a smoothed Bernoulli distribution should lie between 0 and 1, a Gaussian has no boundaries, so hard-sigmoid rectification is used to “fold” the parts smaller than 0 or larger than 1 back to 0 and 1.
More details can be found in the original paper: https://arxiv.org/abs/1810.04247
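To make the rectification and the L0 term concrete, here is a minimal sketch of the Gaussian gate math (illustrative only; the module's internals may differ): with a learnable mean mu and a fixed std, a gate value is sampled by reparameterization and rectified with a hard sigmoid, and the per-gate L0 term is the probability of the gate being non-zero, Phi(mu / std).
>>> import torch
>>> # illustrative sketch, not the module's internal implementation
>>> mu = torch.tensor([0.3], requires_grad=True)  # learnable gate mean
>>> std = 0.5                                     # fixed standard deviation
>>> eps = torch.randn(1)                          # reparameterization noise
>>> gate = torch.clamp(mu + std * eps, 0, 1)      # hard-sigmoid rectification
>>> # P(gate > 0) = P(mu + std * eps > 0) = Phi(mu / std)
>>> l0 = torch.distributions.Normal(0.0, 1.0).cdf(mu / std)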
Examples:
>>> n_params = 5  # number of gates
>>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(inputs)  # gate the inputs
- Parameters:
n_gates (int) – number of gates.
mask (Tensor, optional) – If provided, this allows grouping multiple input tensor elements to share the same stochastic gate. This tensor should be broadcastable to match the input shape and contain integers in the range 0 to n_gates - 1. Indices grouped to the same stochastic gate should have the same value. If not provided, each element in the input tensor (on dimensions other than dim 0, i.e., the batch dim) is gated separately (see the mask sketch after this parameter list). Default: None
reg_weight (float, optional) – rescaling weight for L0 regularization term. Default: 1.0
std (float, optional) – standard deviation of the Gaussian gate distributions, fixed throughout training. Default: 0.5
reg_reduction (str, optional) – the reduction to apply to the regularization: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, and the result is the same as the return of get_gate_active_probs; 'mean': the sum of the gates' non-zero probabilities will be divided by the number of gates; 'sum': the gates' non-zero probabilities will be summed. Default: 'sum'
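As referenced in the mask parameter above, a small sketch of mask usage (shapes and values are illustrative): two feature pairs share two gates, so a (3, 4) input needs only n_gates=2.
>>> import torch
>>> from captum.module import GaussianStochasticGates
>>> # features 0 & 1 share gate 0; features 2 & 3 share gate 1
>>> mask = torch.tensor([0, 0, 1, 1])
>>> stg = GaussianStochasticGates(n_gates=2, mask=mask)
>>> inputs = torch.randn(3, 4)  # batch size of 3, 4 features
>>> gated_inputs, reg = stg(inputs)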
- forward(input_tensor)
- Parameters:
input_tensor (Tensor) – Tensor to be gated with stochastic gates
- Returns:
- gated_input (Tensor): Tensor of the same shape as the input, weighted by the sampled gate values
- l0_reg (Tensor): L0 regularization term to be optimized together with the model loss, e.g. loss(model_out, target) + l0_reg
- Return type:
tuple[Tensor, Tensor]
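A hedged sketch of combining the returned l0_reg with a task loss during training; the downstream model, optimizer, and data below are placeholders, not part of this API.
>>> import torch
>>> from captum.module import GaussianStochasticGates
>>> n_feats = 5
>>> stg = GaussianStochasticGates(n_feats, reg_weight=0.01)
>>> model = torch.nn.Linear(n_feats, 1)  # placeholder downstream model
>>> optim = torch.optim.Adam(list(model.parameters()) + list(stg.parameters()))
>>> inputs, target = torch.randn(8, n_feats), torch.randn(8, 1)
>>> gated, l0_reg = stg(inputs)  # reg_weight scales l0_reg inside the module
>>> loss = torch.nn.functional.mse_loss(model(gated), target) + l0_reg
>>> optim.zero_grad()
>>> loss.backward()
>>> optim.step()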
- get_gate_active_probs()
Get the active probability of each gate, i.e., the probability that the gate value is > 0
- Returns:
- probs (Tensor): probability of each gate being active, in shape (n_gates)
- Return type:
Tensor
- get_gate_values(clamp=True)
Get the gate values, which are the means of the underlying gate distributions, optionally clamped between 0 and 1.
- Parameters:
clamp (bool, optional) – whether to clamp the gate values. As smoothed Bernoulli variables, gate values are clamped between 0 and 1 by default. Turn this off to get the raw means of the underlying distribution (e.g., Concrete, Gaussian), which can be useful for differentiating the gates' importance when multiple gate values fall outside 0 or 1. Default: True
- Returns:
gate_values (Tensor): value of each gate in shape (n_gates)
- Return type:
Tensor
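After training, the two getters above can be combined to rank and select features; stg is the module from the earlier examples, and the 0.5 cutoff is purely illustrative.
>>> gate_values = stg.get_gate_values()           # clamped means in [0, 1]
>>> raw_means = stg.get_gate_values(clamp=False)  # raw means, for finer ranking
>>> active_probs = stg.get_gate_active_probs()    # P(gate > 0) per gate
>>> selected = (active_probs > 0.5).nonzero().flatten()  # illustrative cutoff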