# Source code for captum.module.gaussian_stochastic_gates

#!/usr/bin/env python3

# pyre-strict
import math
from typing import Optional

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
from torch import nn, Tensor

[docs]
class GaussianStochasticGates(StochasticGatesBase):
"""
Stochastic Gates with Gaussian distribution.

Stochastic Gates is a practical solution to add L0 norm regularization for neural
networks. L0 regularization, which explicitly penalizes any present (non-zero)
parameters, can help network pruning and feature selection, but directly optimizing
L0 is a non-differentiable combinatorial problem. To surrogate L0, Stochastic Gate
uses certain continuous probability distributions (e.g., Concrete, Gaussian) with
hard-sigmoid rectification as a continuous smoothed Bernoulli distribution
determining the weight of a parameter, i.e., gate. Then L0 is equal to the gates's
non-zero probability represented by the parameters of the continuous probability
distribution. The gate value can also be reparameterized to the distribution
parameters with a noise. So the expected L0 can be optimized through learning
the distribution parameters via stochastic gradients.

GaussianStochasticGates adopts a gaussian distribution as the smoothed Bernoulli
distribution of gate. While the smoothed Bernoulli distribution should be
within 0 and 1, gaussian does not have boundaries. So hard-sigmoid rectification
is used to "fold" the parts smaller than 0 or larger than 1 back to 0 and 1.

More details can be found in the original paper:
https://arxiv.org/abs/1810.04247

Examples::

>>> n_params = 5  # number of gates
>>> stg = GaussianStochasticGates(n_params, reg_weight=0.01)
>>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
>>> gated_inputs, reg = stg(mock_inputs)  # gate the inputs
"""

# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
n_gates: int,
reg_weight: Optional[float] = 1.0,
std: Optional[float] = 0.5,
reg_reduction: str = "sum",
):
"""
Args:
n_gates (int): number of gates.

mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
Indices grouped to the same stochastic gate should have the same value.
If not provided, each element in the input tensor
(on dimensions other than dim 0, i.e., batch dim) is gated separately.
Default: None

reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

std (float, optional): standard deviation that will be fixed throughout.
Default: 0.5

reg_reduction (str, optional): the reduction to apply to the regularization:
'none' | 'mean' | 'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided
by the number of gates, 'sum': the gates non-zero probabilities will
be summed.
Default: 'sum'
"""
super().__init__(
n_gates,
# pyre-fixme[6]: For 3rd argument expected float but got
#  Optional[float].
reg_weight=reg_weight,
reg_reduction=reg_reduction,
)

mu = torch.empty(n_gates)
nn.init.normal_(mu, mean=0.5, std=0.01)
self.mu = nn.Parameter(mu)

# pyre-fixme[58]: < is not supported for operand types int and
#  Optional[float].
assert 0 < std, f"the standard deviation should be positive, received {std}"
self.std = std

def _sample_gate_values(self, batch_size: int) -> Tensor:
"""
Sample gate values for each example in the batch from the Gaussian distribution

Args:
batch_size (int): input batch size

Returns:
gate_values (Tensor): gate value tensor of shape(batch_size, n_gates)
"""

if self.training:
n = torch.empty(batch_size, self.n_gates, device=self.mu.device)
# pyre-fixme[6]: For 2nd argument expected float but got
#  Optional[float].
n.normal_(mean=0, std=self.std)
return self.mu + n

return self.mu.expand(batch_size, self.n_gates)

def _get_gate_values(self) -> Tensor:
"""
Get the raw gate values, which are the means of the underneath gate
distributions, the learned mu

Returns:
gate_values (Tensor): value of each gate after model is trained
"""
return self.mu

def _get_gate_active_probs(self) -> Tensor:
"""
Get the active probability of each gate, i.e, gate value > 0, in the
Gaussian distribution

Returns:
probs (Tensor): probabilities tensor of the gates are active
in shape(n_gates)
"""
x = self.mu / self.std
return 0.5 * (1 + torch.erf(x / math.sqrt(2)))

@classmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _from_pretrained(cls, mu: Tensor, *args, **kwargs):
"""
Private factory method to create an instance with pretrained parameters

Args:
mu (Tensor): FloatTensor containing weights for the pretrained mu

mask (Tensor, optional): If provided, this allows grouping multiple
input tensor elements to share the same stochastic gate.
This tensor should be broadcastable to match the input shape
and contain integers in the range 0 to n_gates - 1.
Indices grouped to the same stochastic gate should have the same value.
If not provided, each element in the input tensor
(on dimensions other than dim 0 - batch dim) is gated separately.
Default: None

reg_weight (float, optional): rescaling weight for L0 regularization term.
Default: 1.0

std (float, optional): standard deviation that will be fixed throughout.
Default: 0.5

reg_reduction (str, optional): the reduction to apply to the regularization:
'none' | 'mean' | 'sum'. 'none': no reduction will be
applied and it will be the same as the return of get_active_probs,
'mean': the sum of the gates non-zero probabilities will be divided
by the number of gates, 'sum': the gates non-zero probabilities will
be summed.
Default: 'sum'

Returns:
stg (GaussianStochasticGates): StochasticGates instance
"""
n_gates = mu.numel()
stg = cls(n_gates, *args, **kwargs)