Source code for captum.attr._utils.baselines
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# pyre-strict
import random
from typing import Any, Dict, List, Tuple, Union
[docs]
class ProductBaselines:
"""
A Callable Baselines class that returns a sample from the Cartesian product of
the inputs' available baselines.
Args:
baseline_values (List or Dict): A list or dict of lists containing
the possible values for each feature. If a dict is provided, the keys
can a string of the feature name and the values is a list of available
baselines. The keys can also be a tuple of strings to group
multiple features whose baselines are not independent to each other.
If the key is a tuple, the value must be a list of tuples of
the corresponding values.
"""
def __init__(
self,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
baseline_values: Union[
List[List[Any]],
Dict[Union[str, Tuple[str, ...]], List[Any]],
],
) -> None:
if isinstance(baseline_values, dict):
dict_keys = list(baseline_values.keys())
baseline_values = [baseline_values[k] for k in dict_keys]
else:
dict_keys = []
# pyre-fixme[4]: Attribute must be annotated.
self.dict_keys = dict_keys
self.baseline_values = baseline_values
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def sample(self) -> Union[List[Any], Dict[str, Any]]:
baselines = [
random.choice(baseline_list) for baseline_list in self.baseline_values
]
if not self.dict_keys:
return baselines
dict_baselines = {}
for key, val in zip(self.dict_keys, baselines):
if not isinstance(key, tuple):
key, val = (key,), (val,)
for k, v in zip(key, val):
dict_baselines[k] = v
return dict_baselines
# pyre-fixme[3]: Return annotation cannot contain `Any`.
def __call__(self) -> Union[List[Any], Dict[str, Any]]:
"""
Returns:
baselines (List or Dict): A sample from the Cartesian product of
the inputs' available baselines
"""
return self.sample()