Source code for captum.module.binary_concrete_stochastic_gates

#!/usr/bin/env python3

# pyre-strict
import math
from typing import Any, Optional

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
from torch import nn, Tensor


def _torch_empty(batch_size: int, n_gates: int, device: torch.device) -> Tensor:
    """
    Allocate an uninitialized 2D tensor with one row per example and one
    column per gate.

    Args:
        batch_size (int): size of the first dimension of the tensor
        n_gates (int): size of the second dimension of the tensor
        device (torch.device): device on which the tensor is allocated

    Returns:
        Tensor: uninitialized tensor of shape(batch_size, n_gates) on ``device``
    """
    shape = [batch_size, n_gates]
    return torch.empty(shape, device=device)


# torch.fx is introduced in 1.8.0, so the registration below is guarded and
# silently skipped when the installed torch has no ``fx`` attribute.
if hasattr(torch, "fx"):
    # Register the module-level helper with torch.fx.wrap. Per the torch.fx
    # documentation, a wrapped function is recorded as a single call during
    # symbolic tracing instead of being traced through.
    torch.fx.wrap(_torch_empty)


class BinaryConcreteStochasticGates(StochasticGatesBase):
    """
    Stochastic Gates with binary concrete distribution.

    Stochastic Gates is a practical solution to add L0 norm regularization for
    neural networks. L0 regularization, which explicitly penalizes any present
    (non-zero) parameters, can help network pruning and feature selection,
    but directly optimizing L0 is a non-differentiable combinatorial problem.
    To surrogate L0, Stochastic Gate uses certain continuous probability
    distributions (e.g., Concrete, Gaussian) with hard-sigmoid rectification
    as a continuous smoothed Bernoulli distribution determining the weight of
    a parameter, i.e., gate. Then L0 is equal to the gates's non-zero
    probability represented by the parameters of the continuous probability
    distribution. The gate value can also be reparameterized to the
    distribution parameters with a noise. So the expected L0 can be optimized
    through learning the distribution parameters via stochastic gradients.

    BinaryConcreteStochasticGates adopts a "stretched" binary concrete
    distribution as the smoothed Bernoulli distribution of gate. The binary
    concrete distribution does not include its lower and upper boundaries,
    0 and 1, which are required by a Bernoulli distribution, so it needs to be
    linearly stretched beyond both boundaries. Then use hard-sigmoid
    rectification to "fold" the parts smaller than 0 or larger than 1 back to
    0 and 1.

    More details can be found in the original paper:
    https://arxiv.org/abs/1712.01312

    Examples::

        >>> n_params = 5  # number of parameters
        >>> stg = BinaryConcreteStochasticGates(n_params, reg_weight=0.01)
        >>> inputs = torch.randn(3, n_params)  # mock inputs with batch size of 3
        >>> gated_inputs, reg = stg(inputs)  # gate the inputs
    """

    def __init__(
        self,
        n_gates: int,
        mask: Optional[Tensor] = None,
        reg_weight: float = 1.0,
        temperature: float = 2.0 / 3,
        lower_bound: float = -0.1,
        upper_bound: float = 1.1,
        eps: float = 1e-8,
        reg_reduction: str = "sum",
    ) -> None:
        """
        Args:
            n_gates (int): number of gates.

            mask (Tensor, optional): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have the same
                value. If not provided, each element in the input tensor
                (on dimensions other than dim 0, i.e., batch dim) is gated
                separately.
                Default: None

            reg_weight (float, optional): rescaling weight for L0 regularization
                term.
                Default: 1.0

            temperature (float, optional): temperature of the concrete
                distribution, controls the degree of approximation, as 0 means
                the original Bernoulli without relaxation. The value should be
                between 0 and 1.
                Default: 2/3

            lower_bound (float, optional): the lower bound to "stretch" the
                binary concrete distribution
                Default: -0.1

            upper_bound (float, optional): the upper bound to "stretch" the
                binary concrete distribution
                Default: 1.1

            eps (float, optional): term to improve numerical stability in
                binary concrete sampling
                Default: 1e-8

            reg_reduction (str, optional): the reduction to apply to the
                regularization: ``'none'`` | ``'mean'`` | ``'sum'``.
                ``'none'``: no reduction will be applied and it will be the
                same as the return of ``get_active_probs``,
                ``'mean'``: the sum of the gates non-zero probabilities will be
                divided by the number of gates,
                ``'sum'``: the gates non-zero probabilities will be summed.
                Default: ``'sum'``

        Raises:
            AssertionError: if ``temperature`` is not strictly between 0 and 1,
                ``lower_bound`` is not smaller than 0, or ``upper_bound`` is
                not larger than 1.
        """
        super().__init__(
            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
        )

        # avoid changing the tensor's variable name
        # when the module is used after compilation,
        # users may directly access this tensor by name
        log_alpha_param = torch.empty(n_gates)
        nn.init.normal_(log_alpha_param, mean=0.0, std=0.01)
        self.log_alpha_param = nn.Parameter(log_alpha_param)

        assert (
            0 < temperature < 1
        ), f"the temperature should be bwteen 0 and 1, received {temperature}"
        self.temperature = temperature

        assert (
            lower_bound < 0
        ), f"the stretch lower bound should smaller than 0, received {lower_bound}"
        self.lower_bound = lower_bound

        assert (
            upper_bound > 1
        ), f"the stretch upper bound should larger than 1, received {upper_bound}"
        self.upper_bound = upper_bound

        self.eps = eps

        # pre-calculate the fixed term used in active prob; the bound checks
        # above guarantee -lower_bound / upper_bound is positive, so the log
        # is well defined
        self.active_prob_offset: float = temperature * math.log(
            -lower_bound / upper_bound
        )

    def _sample_gate_values(self, batch_size: int) -> Tensor:
        """
        Sample gate values for each example in the batch from the binary concrete
        distributions

        Args:
            batch_size (int): input batch size

        Returns:
            gate_values (Tensor): gate value tensor of shape(batch_size, n_gates)
        """
        if self.training:
            # reparameterization: draw uniform noise u in [eps, 1 - eps) and
            # push logit(u) + log_alpha through a temperature-scaled sigmoid
            u = _torch_empty(
                batch_size, self.n_gates, device=self.log_alpha_param.device
            )
            u.uniform_(self.eps, 1 - self.eps)
            s = torch.sigmoid(
                (torch.logit(u) + self.log_alpha_param) / self.temperature
            )
        else:
            # no noise in eval mode: every example gets the same gate values
            s = torch.sigmoid(self.log_alpha_param)
            s = s.expand(batch_size, self.n_gates)

        # linearly stretch from (0, 1) to (lower_bound, upper_bound)
        s_bar = s * (self.upper_bound - self.lower_bound) + self.lower_bound

        return s_bar

    def _get_gate_values(self) -> Tensor:
        """
        Get the raw gate values, which are the means of the underneath gate
        distributions, derived from learned log_alpha_param

        Returns:
            gate_values (Tensor): value of each gate after model is trained
        """
        gate_values = (
            torch.sigmoid(self.log_alpha_param) * (self.upper_bound - self.lower_bound)
            + self.lower_bound
        )
        return gate_values

    def _get_gate_active_probs(self) -> Tensor:
        """
        Get the active probability of each gate, i.e, gate value > 0, in the binary
        concrete distributions

        Returns:
            probs (Tensor): probabilities tensor of the gates are active
                in shape(n_gates)
        """
        return torch.sigmoid(self.log_alpha_param - self.active_prob_offset)

    @classmethod
    def _from_pretrained(
        cls, log_alpha_param: Tensor, *args: Any, **kwargs: Any
    ) -> "BinaryConcreteStochasticGates":
        """
        Private factory method to create an instance with pretrained parameters

        Args:
            log_alpha_param (Tensor): FloatTensor containing weights for
                the pretrained log_alpha

            mask (Tensor, optional): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have the same
                value. If not provided, each element in the input tensor
                (on dimensions other than dim 0 - batch dim) is gated separately.
                Default: None

            reg_weight (float, optional): rescaling weight for L0 regularization
                term.
                Default: 1.0

            temperature (float, optional): temperature of the concrete
                distribution, controls the degree of approximation, as 0 means
                the original Bernoulli without relaxation. The value should be
                between 0 and 1.
                Default: 2/3

            lower_bound (float, optional): the lower bound to "stretch" the
                binary concrete distribution
                Default: -0.1

            upper_bound (float, optional): the upper bound to "stretch" the
                binary concrete distribution
                Default: 1.1

            eps (float, optional): term to improve numerical stability in
                binary concrete sampling
                Default: 1e-8

            reg_reduction (str, optional): the reduction to apply to the
                regularization: ``'none'`` | ``'mean'`` | ``'sum'``.
                ``'none'``: no reduction will be applied and it will be the
                same as the return of ``get_active_probs``,
                ``'mean'``: the sum of the gates non-zero probabilities will be
                divided by the number of gates,
                ``'sum'``: the gates non-zero probabilities will be summed.
                Default: ``'sum'``

        Returns:
            stg (BinaryConcreteStochasticGates): StochasticGates instance
        """
        assert (
            log_alpha_param.dim() == 1
        ), "log_alpha_param is expected to be 1-dimensional"

        # the number of gates is implied by the length of the pretrained vector;
        # remaining constructor arguments are forwarded untouched
        n_gates = log_alpha_param.numel()
        stg = cls(n_gates, *args, **kwargs)
        # strict=False: only log_alpha_param is supplied, other state entries
        # (if any) keep the values set by the constructor
        stg.load_state_dict({"log_alpha_param": log_alpha_param}, strict=False)

        return stg