GaussianStochasticGates

class captum.module.GaussianStochasticGates(n_gates, mask=None, reg_weight=1.0, std=0.5, reg_reduction='sum')[source]

Stochastic Gates with Gaussian distribution.

Stochastic Gates is a practical approach to adding L0-norm regularization to neural networks. L0 regularization, which explicitly penalizes any present (non-zero) parameter, can help with network pruning and feature selection, but directly optimizing the L0 norm is a non-differentiable combinatorial problem. As a surrogate for L0, Stochastic Gates uses a continuous probability distribution (e.g., Concrete, Gaussian) with hard-sigmoid rectification as a smoothed Bernoulli distribution that determines the weight of a parameter, i.e., the gate. The L0 norm then equals the gates' probabilities of being non-zero, expressed in terms of the parameters of the continuous distribution. The gate value can also be reparameterized as a function of the distribution parameters and a noise sample, so the expected L0 can be optimized by learning the distribution parameters via stochastic gradients.

GaussianStochasticGates adopts a Gaussian distribution as the smoothed Bernoulli distribution of the gate. While a smoothed Bernoulli distribution should lie within 0 and 1, a Gaussian is unbounded, so hard-sigmoid rectification is used to “fold” the parts smaller than 0 or larger than 1 back to 0 and 1.
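
As a minimal sketch of this sampling scheme (not Captum's internal implementation), a gate can be drawn with the reparameterization trick and then rectified:

>>> import torch
>>> def sample_gate(mu, std=0.5):
...     # reparameterization: z = mu + std * eps, with eps ~ N(0, 1),
...     # so gradients can flow to the learnable means mu
...     eps = torch.randn_like(mu)
...     z = mu + std * eps
...     # hard-sigmoid rectification folds z back into [0, 1]
...     return torch.clamp(z, 0.0, 1.0)
>>> gates = sample_gate(torch.full((5,), 0.5))  # one sample per gate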

More details can be found in the original paper: https://arxiv.org/abs/1810.04247

Examples:

>>> import torch
>>> from captum.module import GaussianStochasticGates
>>> n_params = 5  # number of gates
>>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(inputs)  # gate the inputs

Parameters:
  • n_gates (int) – number of gates.

  • mask (Tensor, optional) – If provided, this allows grouping multiple elements of the input tensor to share the same stochastic gate. This tensor should be broadcastable to match the input shape and contain integers in the range 0 to n_gates - 1. Input elements grouped under the same stochastic gate should have the same mask value. If not provided, each element in the input tensor (on dimensions other than dim 0, i.e., the batch dim) is gated separately. Default: None (see the usage sketch after this list)

  • reg_weight (float, optional) – rescaling weight for L0 regularization term. Default: 1.0

  • std (float, optional) – standard deviation that will be fixed throughout. Default: 0.5

  • reg_reduction (str, optional) – the reduction to apply to the regularization: 'none' | 'mean' | 'sum'. 'none': no reduction is applied, matching the return of get_gate_active_probs; 'mean': the sum of the gates' non-zero probabilities is divided by the number of gates; 'sum': the gates' non-zero probabilities are summed. Default: 'sum'
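
The mask option can be sketched as follows; the grouping below is hypothetical and only for illustration (elements that share a mask value share one gate):

>>> # features 0-1 share gate 0, features 2-3 share gate 1
>>> mask = torch.tensor([0, 0, 1, 1])
>>> stg = GaussianStochasticGates(n_gates=2, mask=mask)
>>> gated, reg = stg(torch.randn(8, 4))  # batch of 8, 4 features, 2 gates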

forward(input_tensor)
Parameters:

input_tensor (Tensor) – Tensor to be gated with stochastic gates

Returns:

  • gated_input (Tensor): Tensor of the same shape weighted by the sampled gate values

  • l0_reg (Tensor): L0 regularization term to be optimized together with the model loss, e.g. loss(model_out, target) + l0_reg

Return type:

tuple[Tensor, Tensor]
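
In practice, l0_reg is added to the task loss so the gate parameters and model weights are trained jointly. The loop below is a minimal sketch; model, criterion, optimizer, and loader are hypothetical placeholders:

>>> model = torch.nn.Linear(n_params, 1)
>>> stg = GaussianStochasticGates(n_params, reg_weight=0.1)
>>> criterion = torch.nn.MSELoss()
>>> params = list(model.parameters()) + list(stg.parameters())
>>> optimizer = torch.optim.Adam(params, lr=1e-2)
>>> for x, y in loader:  # loader is a placeholder dataloader
...     gated_x, l0_reg = stg(x)
...     loss = criterion(model(gated_x), y) + l0_reg
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()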

get_gate_active_probs()

Get the active probability of each gate, i.e., the probability that the gate value > 0

Returns:

  • probs (Tensor): probability of each gate being active, in shape(n_gates)

Return type:

Tensor
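
A minimal sketch of using the active probabilities for feature selection; the 0.5 threshold here is an arbitrary illustrative choice:

>>> probs = stg.get_gate_active_probs()  # shape (n_gates,)
>>> selected = (probs > 0.5).nonzero().flatten()  # gates likely to stay open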

get_gate_values(clamp=True)

Get the gate values, which are the means of the underlying gate distributions, optionally clamped within 0 and 1.

Parameters:

clamp (bool, optional) – whether to clamp the gate values. As smoothed Bernoulli variables, gate values are clamped within 0 and 1 by default. Turn this off to get the raw means of the underlying distribution (e.g., Concrete, Gaussian), which can be useful for differentiating the gates' importance when multiple gate values fall outside [0, 1]. Default: True

Returns:

  • gate_values (Tensor): value of each gate in shape(n_gates)

Return type:

Tensor
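
For example, comparing clamped and raw gate values (a minimal sketch):

>>> values = stg.get_gate_values()          # means clamped to [0, 1]
>>> raw = stg.get_gate_values(clamp=False)  # unclamped distribution means
>>> ranking = raw.argsort(descending=True)  # rank gates by raw mean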