BinaryConcreteStochasticGates

class captum.module.BinaryConcreteStochasticGates(n_gates, mask=None, reg_weight=1.0, temperature=2.0 / 3, lower_bound=-0.1, upper_bound=1.1, eps=1e-08, reg_reduction='sum')

Stochastic Gates with binary concrete distribution.

Stochastic Gates is a practical way to add L0 norm regularization to neural networks. L0 regularization, which explicitly penalizes every present (non-zero) parameter, can help with network pruning and feature selection, but optimizing L0 directly is a non-differentiable combinatorial problem. As a surrogate for L0, Stochastic Gates uses a continuous probability distribution (e.g., Concrete, Gaussian) with hard-sigmoid rectification as a continuous, smoothed Bernoulli distribution that determines the weight of a parameter, i.e., its gate. The expected L0 is then given by the gates' non-zero probabilities, which can be expressed through the parameters of the continuous probability distribution. The gate value can also be reparameterized as a function of the distribution parameters and a noise term, so the expected L0 can be optimized by learning the distribution parameters with stochastic gradients.

BinaryConcreteStochasticGates adopts a “stretched” binary concrete distribution as the smoothed Bernoulli distribution of the gate. The binary concrete distribution does not include its lower and upper boundaries, 0 and 1, which a Bernoulli distribution requires, so it is linearly stretched beyond both boundaries. Hard-sigmoid rectification then “folds” the parts smaller than 0 or larger than 1 back to 0 and 1.
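
The stretch-and-fold step can be illustrated with a minimal pure-Python sketch for a single gate. It follows the sampling formula from the paper linked below and is not taken from the Captum source; the function names sample_gate and _sigmoid and the scalar parameter log_alpha (the learnable location of the distribution) are assumptions made for illustration.

```python
import math
import random


def _sigmoid(x):
    # Numerically stable logistic function for scalars.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gate(log_alpha, temperature=2.0 / 3, lower_bound=-0.1,
                upper_bound=1.1, eps=1e-8, rng=random):
    """Draw one gate value from a stretched, rectified binary concrete
    distribution. log_alpha is a hypothetical name for the learnable
    location parameter; it does not come from the documentation above."""
    # Uniform noise kept away from 0 and 1 so both logs stay finite.
    u = rng.uniform(eps, 1.0 - eps)
    # Binary concrete sample in (0, 1), reparameterized through the noise u.
    s = _sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / temperature)
    # Linear stretch from (0, 1) to (lower_bound, upper_bound).
    s = s * (upper_bound - lower_bound) + lower_bound
    # Hard-sigmoid rectification folds the overshoot back onto 0 and 1.
    return min(1.0, max(0.0, s))
```

Because the stretched interval extends past 0 and 1, the rectified sample lands exactly on 0 or exactly on 1 with non-zero probability, which is what lets the gate behave like a Bernoulli variable while remaining differentiable in between.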

More details can be found in the original paper <https://arxiv.org/abs/1712.01312>.

Parameters
  • n_gates (int) – number of gates.

  • mask (Optional[Tensor]) – If provided, this allows grouping multiple input tensor elements to share the same stochastic gate. This tensor should be broadcastable to match the input shape and contain integers in the range 0 to n_gates - 1. Elements grouped into the same stochastic gate should have the same mask value. If not provided, each element in the input tensor (on all dimensions other than dim 0, the batch dimension) is gated separately. Default: None

  • reg_weight (Optional[float]) – rescaling weight for L0 regularization term. Default: 1.0

  • temperature (float) – temperature of the concrete distribution, which controls the degree of approximation; a temperature of 0 corresponds to the original Bernoulli distribution without relaxation. The value should be between 0 and 1. Default: 2/3

  • lower_bound (float) – the lower bound to “stretch” the binary concrete distribution. Default: -0.1

  • upper_bound (float) – the upper bound to “stretch” the binary concrete distribution. Default: 1.1

  • eps (float) – term to improve numerical stability in binary concrete sampling. Default: 1e-8

  • reg_reduction (str, optional) – the reduction to apply to the regularization: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction is applied, and the result is the same as the return value of get_gate_active_probs; ‘mean’: the sum of the gates’ non-zero probabilities is divided by the number of gates; ‘sum’: the gates’ non-zero probabilities are summed. Default: ‘sum’
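
The three reg_reduction modes can be sketched in plain Python as follows. This is an illustration of the behavior described above, not the library code; the function name reduce_l0 is hypothetical, and a list of floats stands in for the tensor of per-gate active probabilities. How reg_weight interacts with each mode is not shown here.

```python
def reduce_l0(active_probs, reg_reduction="sum"):
    """Reduce per-gate active probabilities into a regularization term.

    active_probs is a plain list standing in for a tensor of shape
    (n_gates,); reduce_l0 is a hypothetical helper, not a Captum function.
    """
    if reg_reduction == "sum":
        # Expected number of active gates.
        return sum(active_probs)
    if reg_reduction == "mean":
        # Expected fraction of active gates.
        return sum(active_probs) / len(active_probs)
    if reg_reduction == "none":
        # One value per gate, left unreduced.
        return list(active_probs)
    raise ValueError(f"unknown reg_reduction: {reg_reduction!r}")
```

With ‘sum’ the penalty grows with the number of gates, whereas ‘mean’ keeps it on a 0 to 1 scale regardless of n_gates, which may make reg_weight easier to carry over between models of different sizes.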

forward(input_tensor)
Parameters

input_tensor (Tensor) – Tensor to be gated with stochastic gates

Returns

  • gated_input (Tensor): Tensor of the same shape weighted by the sampled gate values

  • l0_reg (Tensor): L0 regularization term to be optimized together with model loss, e.g. loss(model_out, target) + l0_reg

Return type

tuple[Tensor, Tensor]
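
As a rough picture of how gate values are applied to the input, the sketch below weights a batch of flat feature lists by per-gate values, using a mask to let several features share one gate. It uses nested lists in place of tensors and does no broadcasting; the function name apply_gates and the fixed gate values are assumptions for illustration, whereas the real module samples the gate values on each forward pass.

```python
def apply_gates(batch, gate_values, mask=None):
    """Weight each feature by the value of the gate it is assigned to.

    batch:       list of samples, each a flat list of feature values
    gate_values: one value in [0, 1] per gate
    mask:        for each feature position, the index of its gate;
                 if None, feature i uses gate i (one gate per feature)
    """
    n_features = len(batch[0])
    if mask is None:
        mask = list(range(n_features))
    # The same mask is reused for every sample, i.e. along the batch dim.
    return [
        [x * gate_values[g] for x, g in zip(sample, mask)]
        for sample in batch
    ]


# Four features grouped into two gates; gate 1 is closed.
gated = apply_gates([[1.0, 2.0, 3.0, 4.0]], gate_values=[1.0, 0.0],
                    mask=[0, 0, 1, 1])
```

In training, the second return value of forward is simply added to the task loss, as in loss(model_out, target) + l0_reg, so that gates are closed only when doing so costs less in task loss than it saves in regularization.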

get_gate_active_probs()

Get the active probability of each gate, i.e., the probability that the gate value > 0.

Returns

  • probs (Tensor): probabilities that the gates are active, in shape(n_gates)

Return type

Tensor
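
For the stretched binary concrete distribution, the paper linked above gives a closed form for the probability that a gate is non-zero. The sketch below evaluates that formula for a single gate; it is an illustration under the assumption that the module follows the paper, and the names gate_active_prob and log_alpha are hypothetical.

```python
import math


def _sigmoid(x):
    # Numerically stable logistic function for scalars.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gate_active_prob(log_alpha, temperature=2.0 / 3,
                     lower_bound=-0.1, upper_bound=1.1):
    """P(gate > 0) for one gate with location parameter log_alpha.

    Requires lower_bound < 0 < upper_bound so the log argument is positive.
    """
    shift = temperature * math.log(-lower_bound / upper_bound)
    return _sigmoid(log_alpha - shift)
```

With the default bounds and temperature, a gate with log_alpha = 0 is active with probability of roughly 0.83, so driving a gate towards zero requires a clearly negative log_alpha.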

get_gate_values(clamp=True)

Get the gate values, which are the means of the underlying gate distributions, optionally clamped within 0 and 1.

Parameters

clamp (bool) – whether to clamp the gate values or not. As smoothed Bernoulli variables, gate values are clamped within 0 and 1 by default. Turn this off to get the raw means of the underlying distribution (e.g., concrete, Gaussian), which can be useful to differentiate the gates’ importance when multiple gate values are beyond 0 or 1. Default: True

Returns

  • gate_values (Tensor): value of each gate in shape(n_gates)

Return type

Tensor
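
The effect of clamp can be seen in a small sketch that uses the deterministic test-time estimate from the paper, a stretched sigmoid of the location parameter. Whether this matches the module's internal computation exactly is an assumption; the names gate_value and log_alpha are hypothetical.

```python
import math


def _sigmoid(x):
    # Numerically stable logistic function for scalars.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def gate_value(log_alpha, lower_bound=-0.1, upper_bound=1.1, clamp=True):
    """Deterministic gate value for one gate, optionally clamped to [0, 1]."""
    # Stretch the sigmoid output to (lower_bound, upper_bound).
    v = _sigmoid(log_alpha) * (upper_bound - lower_bound) + lower_bound
    # Clamping mirrors the hard-sigmoid rectification used when sampling.
    return min(1.0, max(0.0, v)) if clamp else v
```

Two gates with log_alpha of 10 and 12 both report 1.0 when clamped, while their unclamped values differ slightly and stay below upper_bound, which is the kind of ranking information that clamp=False preserves.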