BinaryConcreteStochasticGates¶
- class captum.module.BinaryConcreteStochasticGates(n_gates, mask=None, reg_weight=1.0, temperature=2.0 / 3, lower_bound=-0.1, upper_bound=1.1, eps=1e-8, reg_reduction='sum')[source]¶
Stochastic Gates with binary concrete distribution.
Stochastic Gates is a practical solution to add L0-norm regularization to neural networks. L0 regularization, which explicitly penalizes any present (non-zero) parameters, can help with network pruning and feature selection, but directly optimizing L0 is a non-differentiable combinatorial problem. To approximate L0, Stochastic Gates uses a continuous probability distribution (e.g., Concrete, Gaussian) with hard-sigmoid rectification as a smoothed Bernoulli distribution over the weight of each parameter, i.e., the gate. The expected L0 then equals the sum of the gates’ non-zero probabilities, expressed through the parameters of the continuous distribution. Each gate value can also be reparameterized as a deterministic function of the distribution parameters and a noise variable, so the expected L0 can be optimized by learning the distribution parameters via stochastic gradients.
BinaryConcreteStochasticGates adopts a “stretched” binary concrete distribution as the smoothed Bernoulli distribution of the gate. The binary concrete distribution does not include its boundaries, 0 and 1, which a Bernoulli distribution requires, so it is linearly stretched beyond both boundaries; a hard-sigmoid rectification then “folds” the parts below 0 and above 1 back to 0 and 1.
More details can be found in the original paper: https://arxiv.org/abs/1712.01312
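To make the construction concrete, here is a minimal sketch of the stretched binary concrete sampling and the corresponding expected-L0 penalty described above and in the paper. The variable names are illustrative and do not correspond to Captum’s internal attributes:
>>> import math
>>> import torch
>>> log_alpha = torch.zeros(5, requires_grad=True)  # learnable location parameter, one per gate
>>> temperature, lower, upper = 2 / 3, -0.1, 1.1
>>> u = torch.rand(5).clamp(1e-8, 1 - 1e-8)  # uniform noise for the reparameterization
>>> s = torch.sigmoid((u.log() - (1 - u).log() + log_alpha) / temperature)  # binary concrete sample
>>> gate = torch.clamp(s * (upper - lower) + lower, 0, 1)  # stretch, then hard-sigmoid rectify
>>> # expected L0 per gate: probability that the stretched sample is non-zero
>>> p_active = torch.sigmoid(log_alpha - temperature * math.log(-lower / upper))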
Examples:
>>> n_params = 5  # number of parameters
>>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(inputs)  # gate the inputs
- Parameters:
n_gates (int) – number of gates.
mask (Tensor, optional) – If provided, this allows grouping multiple input tensor elements to share the same stochastic gate. This tensor should be broadcastable to match the input shape and contain integers in the range 0 to n_gates - 1. Indices grouped to the same stochastic gate should have the same value. If not provided, each element in the input tensor (on dimensions other than dim 0, i.e., the batch dim) is gated separately. Default: None. See the gate-sharing sketch after this parameter list.
reg_weight (float, optional) – rescaling weight for L0 regularization term. Default: 1.0
temperature (float, optional) – temperature of the concrete distribution, which controls the degree of approximation: 0 recovers the original Bernoulli without relaxation. The value should be between 0 and 1. Default: 2/3
lower_bound (float, optional) – the lower bound to “stretch” the binary concrete distribution. Default: -0.1
upper_bound (float, optional) – the upper bound to “stretch” the binary concrete distribution. Default: 1.1
eps (float, optional) – term to improve numerical stability in binary concrete sampling. Default: 1e-8
reg_reduction (str, optional) – the reduction to apply to the regularization: 'none' | 'mean' | 'sum'.
'none': no reduction will be applied and the result will be the same as the return of get_active_probs.
'mean': the sum of the gates’ non-zero probabilities will be divided by the number of gates.
'sum': the gates’ non-zero probabilities will be summed.
Default: 'sum'
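For illustration, a sketch of gate sharing via mask, following the semantics described above for the mask parameter (the shapes and grouping below are illustrative choices):
>>> import torch
>>> from captum.module import BinaryConcreteStochasticGates
>>> # four input features: the first two share gate 0, the last two share gate 1
>>> mask = torch.tensor([0, 0, 1, 1])
>>> stg = BinaryConcreteStochasticGates(n_gates=2, mask=mask, reg_weight=0.01)
>>> inputs = torch.randn(3, 4)              # batch size of 3
>>> gated_inputs, l0_reg = stg(inputs)      # features sharing a gate receive the same weight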
- forward(input_tensor)¶
- Parameters:
input_tensor (Tensor) – Tensor to be gated with stochastic gates
- Returns:
- gated_input (Tensor): Tensor of the same shape as the input, weighted by the sampled gate values
- l0_reg (Tensor): L0 regularization term to be optimized together with the model loss, e.g. loss(model_out, target) + l0_reg; see the training sketch below
- Return type:
tuple[Tensor, Tensor]
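For illustration, a minimal sketch of how l0_reg is typically added to a task loss during training; model, target, and optimizer are hypothetical placeholders, and stg/inputs follow the example above:
>>> import torch.nn.functional as F
>>> gated_inputs, l0_reg = stg(inputs)
>>> logits = model(gated_inputs)                     # `model` is a hypothetical downstream network
>>> loss = F.cross_entropy(logits, target) + l0_reg  # `target` is a hypothetical label tensor
>>> loss.backward()
>>> optimizer.step()                                 # `optimizer` covers both model and stg parameters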
- get_gate_active_probs()¶
Get the active probability of each gate, i.e., the probability that the gate value is greater than 0
- Returns:
- probs (Tensor): probabilities that each gate is active, in shape (n_gates)
- Return type:
Tensor
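For example (a sketch, not part of the API): once training has converged, the active probabilities can be thresholded to select features; the 0.5 cutoff below is an arbitrary illustrative choice.
>>> probs = stg.get_gate_active_probs()                 # shape (n_gates,)
>>> selected = torch.nonzero(probs > 0.5).squeeze(-1)   # indices of gates likely to stay open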
- get_gate_values(clamp=True)¶
Get the gate values, which are the means of the underlying gate distributions, optionally clamped within 0 and 1.
- Parameters:
clamp (bool, optional) – whether to clamp the gate values or not. As smoothed Bernoulli variables, gate values are clamped within 0 and 1 by default. Turn this off to get the raw means of the underlying distribution (e.g., concrete, gaussian), which can be useful to differentiate the gates’ importance when multiple gate values fall beyond 0 or 1. Default: True
- Returns:
gate_values (Tensor): value of each gate, in shape (n_gates)
- Return type:
Tensor
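For example (a sketch using the stg module from the examples above), the unclamped means can be used to rank gates by importance even when several clamped values saturate at 0 or 1:
>>> raw_means = stg.get_gate_values(clamp=False)          # raw distribution means, may lie outside [0, 1]
>>> ranking = torch.argsort(raw_means, descending=True)   # gates ordered from most to least important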