Source code for captum.module.binary_concrete_stochastic_gates

#!/usr/bin/env python3
import math
from typing import Optional

import torch
from captum.module.stochastic_gates_base import StochasticGatesBase
from torch import nn, Tensor


def _torch_empty(batch_size: int, n_gates: int, device: torch.device) -> Tensor:
    return torch.empty(batch_size, n_gates, device=device)


# torch.fx was introduced in PyTorch 1.8.0
if hasattr(torch, "fx"):
    torch.fx.wrap(_torch_empty)
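# Note (added for clarity): the wrap() call above registers _torch_empty as a
# leaf function for torch.fx symbolic tracing, so traced graphs record a call
# to it instead of tracing through the tensor construction with a concrete
# batch_size.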


def _logit(inp: Tensor) -> Tensor:
    # torch.logit was introduced in 1.7.0; fall back to the equivalent
    # log-odds formula on older versions
    if hasattr(torch, "logit"):
        return torch.logit(inp)
    else:
        return torch.log(inp) - torch.log(1 - inp)

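# For illustration (values assumed, not from the original module): both
# branches compute the log-odds log(u / (1 - u)), e.g. for
# u = torch.tensor([0.25, 0.5, 0.75]) either branch returns approximately
# tensor([-1.0986, 0.0000, 1.0986]).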

class BinaryConcreteStochasticGates(StochasticGatesBase):
    """
    Stochastic Gates with binary concrete distribution.

    Stochastic Gates is a practical solution to add L0 norm regularization for
    neural networks. L0 regularization, which explicitly penalizes any present
    (non-zero) parameters, can help network pruning and feature selection, but
    directly optimizing L0 is a non-differentiable combinatorial problem. To
    surrogate L0, Stochastic Gates uses certain continuous probability
    distributions (e.g., Concrete, Gaussian) with hard-sigmoid rectification
    as a continuous smoothed Bernoulli distribution determining the weight of
    a parameter, i.e., gate. Then L0 is equal to the gates' non-zero
    probability represented by the parameters of the continuous probability
    distribution. The gate value can also be reparameterized as a function of
    the distribution parameters and a noise, so the expected L0 can be
    optimized through learning the distribution parameters via stochastic
    gradients.

    BinaryConcreteStochasticGates adopts a "stretched" binary concrete
    distribution as the smoothed Bernoulli distribution of the gate. The
    binary concrete distribution does not include its lower and upper
    boundaries, 0 and 1, which are required by a Bernoulli distribution, so it
    needs to be linearly stretched beyond both boundaries. Hard-sigmoid
    rectification is then used to "fold" the parts smaller than 0 or larger
    than 1 back to 0 and 1.

    More details can be found in the
    `original paper <https://arxiv.org/abs/1712.01312>`_.
    """

    def __init__(
        self,
        n_gates: int,
        mask: Optional[Tensor] = None,
        reg_weight: float = 1.0,
        temperature: float = 2.0 / 3,
        lower_bound: float = -0.1,
        upper_bound: float = 1.1,
        eps: float = 1e-8,
        reg_reduction: str = "sum",
    ):
        """
        Args:
            n_gates (int): number of gates.

            mask (Tensor, optional): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have
                the same value.
                If not provided, each element in the input tensor
                (on dimensions other than dim 0, the batch dim) is gated
                separately.
                Default: None

            reg_weight (float, optional): rescaling weight for the L0
                regularization term.
                Default: 1.0

            temperature (float, optional): temperature of the concrete
                distribution; controls the degree of approximation, as 0 means
                the original Bernoulli without relaxation. The value should be
                between 0 and 1.
                Default: 2/3

            lower_bound (float, optional): the lower bound to "stretch" the
                binary concrete distribution.
                Default: -0.1

            upper_bound (float, optional): the upper bound to "stretch" the
                binary concrete distribution.
                Default: 1.1

            eps (float, optional): term to improve numerical stability in
                binary concrete sampling.
                Default: 1e-8

            reg_reduction (str, optional): the reduction to apply to the
                regularization: 'none' | 'mean' | 'sum'. 'none': no reduction
                will be applied and it will be the same as the return of
                get_active_probs; 'mean': the sum of the gates' non-zero
                probabilities will be divided by the number of gates; 'sum':
                the gates' non-zero probabilities will be summed.
                Default: 'sum'
        """
        super().__init__(
            n_gates, mask=mask, reg_weight=reg_weight, reg_reduction=reg_reduction
        )

        # Avoid changing the tensor's variable name: when the module is used
        # after compilation, users may directly access this tensor by name.
        log_alpha_param = torch.empty(n_gates)
        nn.init.normal_(log_alpha_param, mean=0.0, std=0.01)
        self.log_alpha_param = nn.Parameter(log_alpha_param)

        assert (
            0 < temperature < 1
        ), f"the temperature should be between 0 and 1, received {temperature}"
        self.temperature = temperature

        assert (
            lower_bound < 0
        ), f"the stretch lower bound should be smaller than 0, received {lower_bound}"
        self.lower_bound = lower_bound

        assert (
            upper_bound > 1
        ), f"the stretch upper bound should be larger than 1, received {upper_bound}"
        self.upper_bound = upper_bound

        self.eps = eps

        # pre-calculate the fixed term used in the active probability
        self.active_prob_offset = temperature * math.log(-lower_bound / upper_bound)

    def _sample_gate_values(self, batch_size: int) -> Tensor:
        """
        Sample gate values for each example in the batch from the binary
        concrete distributions.

        Args:
            batch_size (int): input batch size

        Returns:
            gate_values (Tensor): gate value tensor of shape (batch_size, n_gates)
        """
        if self.training:
            u = _torch_empty(
                batch_size, self.n_gates, device=self.log_alpha_param.device
            )
            u.uniform_(self.eps, 1 - self.eps)
            s = torch.sigmoid((_logit(u) + self.log_alpha_param) / self.temperature)
        else:
            s = torch.sigmoid(self.log_alpha_param)
            s = s.expand(batch_size, self.n_gates)

        s_bar = s * (self.upper_bound - self.lower_bound) + self.lower_bound

        return s_bar

    def _get_gate_values(self) -> Tensor:
        """
        Get the raw gate values, which are the means of the underlying gate
        distributions, derived from the learned log_alpha_param.

        Returns:
            gate_values (Tensor): value of each gate after the model is trained
        """
        gate_values = (
            torch.sigmoid(self.log_alpha_param) * (self.upper_bound - self.lower_bound)
            + self.lower_bound
        )
        return gate_values

    def _get_gate_active_probs(self) -> Tensor:
        """
        Get the active probability of each gate, i.e., the probability that
        the gate value is greater than 0, in the binary concrete distributions.

        Returns:
            probs (Tensor): probabilities tensor of the gates being active,
                of shape (n_gates)
        """
        return torch.sigmoid(self.log_alpha_param - self.active_prob_offset)

    @classmethod
    def _from_pretrained(cls, log_alpha_param: Tensor, *args, **kwargs):
        """
        Private factory method to create an instance with pretrained
        parameters.

        Args:
            log_alpha_param (Tensor): FloatTensor containing weights for
                the pretrained log_alpha

            mask (Tensor, optional): If provided, this allows grouping multiple
                input tensor elements to share the same stochastic gate.
                This tensor should be broadcastable to match the input shape
                and contain integers in the range 0 to n_gates - 1.
                Indices grouped to the same stochastic gate should have
                the same value.
                If not provided, each element in the input tensor
                (on dimensions other than dim 0, the batch dim) is gated
                separately.
                Default: None

            reg_weight (float, optional): rescaling weight for the L0
                regularization term.
                Default: 1.0

            temperature (float, optional): temperature of the concrete
                distribution; controls the degree of approximation, as 0 means
                the original Bernoulli without relaxation. The value should be
                between 0 and 1.
                Default: 2/3

            lower_bound (float, optional): the lower bound to "stretch" the
                binary concrete distribution.
                Default: -0.1

            upper_bound (float, optional): the upper bound to "stretch" the
                binary concrete distribution.
                Default: 1.1

            eps (float, optional): term to improve numerical stability in
                binary concrete sampling.
                Default: 1e-8

        Returns:
            stg (BinaryConcreteStochasticGates): StochasticGates instance
        """
        assert (
            log_alpha_param.dim() == 1
        ), "log_alpha_param is expected to be 1-dimensional"

        n_gates = log_alpha_param.numel()
        stg = cls(n_gates, *args, **kwargs)
        stg.load_state_dict({"log_alpha_param": log_alpha_param}, strict=False)
        return stg
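

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how
# this class is typically exercised. It assumes the forward pass inherited
# from StochasticGatesBase returns a (gated_input, l0_reg) tuple, as in the
# public captum API; the shapes and values below are made up for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_gates = 16
    stg = BinaryConcreteStochasticGates(n_gates, reg_reduction="sum")

    # In training mode, gate values are sampled with reparameterized uniform
    # noise, so gradients flow back to log_alpha_param.
    stg.train()
    inp = torch.randn(4, n_gates)  # (batch_size, n_gates); no mask provided
    gated_inp, l0_reg = stg(inp)

    # l0_reg approximates the expected L0 norm; adding it to the task loss
    # lets sparsity be learned jointly with the task objective.
    loss = gated_inp.sum() + l0_reg
    loss.backward()

    # In eval mode, gates take their deterministic (rectified) means instead.
    stg.eval()
    with torch.no_grad():
        gated_inp, _ = stg(inp)

    # Sanity check of the active-probability offset: with the default stretch
    # (-0.1, 1.1) and temperature 2/3, active_prob_offset is
    # (2/3) * log(0.1 / 1.1) ~= -1.599, so a freshly initialized gate
    # (log_alpha ~= 0) is active with probability sigmoid(1.599) ~= 0.83.
    probs = stg._get_gate_active_probs()

    # Restoring trained gates later via the private factory method:
    stg2 = BinaryConcreteStochasticGates._from_pretrained(
        stg.log_alpha_param.detach().clone()
    )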