#!/usr/bin/env python3
# pyre-strict
import warnings
from collections import namedtuple
from typing import (
Any,
Callable,
cast,
Dict,
Generic,
List,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
)
from captum._utils.common import (
_expand_additional_forward_args,
_format_additional_forward_args,
_reduce_list,
)
from captum.attr import Max, Mean, Min, Summarizer
from captum.log import log_usage
from captum.robust._core.perturbation import Perturbation
from torch import Tensor
ORIGINAL_KEY = "Original"
MetricResultType = TypeVar(
"MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...]
)
class AttackInfo(NamedTuple):
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
attack_fn: Union[Perturbation, Callable]
name: str
num_attempts: int
apply_before_preproc: bool
attack_kwargs: Dict[str, Any]
additional_args: List[str]
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def agg_metric(inp):
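    """Aggregate a batch metric: Tensor metrics are averaged over dim 0,
    tuple metrics are aggregated elementwise, and any other value is
    returned unchanged."""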
if isinstance(inp, Tensor):
return inp.mean(dim=0)
elif isinstance(inp, tuple):
return tuple(agg_metric(elem) for elem in inp)
return inp
class AttackComparator(Generic[MetricResultType]):
r"""
    Allows measuring model robustness for a given attack or set of attacks. This class
    can be used with any metric(s) as well as any set of attacks, whether they are
    attacks / perturbations from captum.robust such as FGSM or PGD, or external
    augmentation methods / perturbations such as torchvision transforms.
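
    A typical workflow (a minimal sketch, assuming a hypothetical classification
    model ``model``, a metric function ``accuracy_metric`` taking a ``targets``
    keyword argument, a preprocessing function ``normalize``, and an iterable
    ``dataloader`` of (images, labels) batches) is to construct the comparator,
    register attacks, evaluate batches, and read the aggregated results::

    >>> comparator = AttackComparator(forward_func=model,
    >>>                               metric=accuracy_metric,
    >>>                               preproc_fn=normalize)
    >>> comparator.add_attack(transforms.ColorJitter(), name="Jitter")
    >>> for images, labels in dataloader:
    >>>     comparator.evaluate(inputs=images, targets=labels)
    >>> averaged_results = comparator.summary()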
"""
def __init__(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
forward_func: Callable,
metric: Callable[..., MetricResultType],
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
preproc_fn: Optional[Callable] = None,
) -> None:
r"""
Args:
            forward_func (Callable or torch.nn.Module): This can either be an instance
                of a PyTorch model or any modification of a model's forward
                function.
metric (Callable): This function is applied to the model output in
order to compute the desired performance metric or metrics.
This function should have the following signature::
                >>> def model_metric(model_out: Tensor, **kwargs: Any)
                >>>     -> Union[float, Tensor, Tuple[Union[float, Tensor], ...]]:
All kwargs provided to evaluate are provided to the metric function,
following the model output. A single metric can be returned as
a float or tensor, and multiple metrics should be returned as either
a tuple or named tuple of floats or tensors. For a tensor metric,
the first dimension should match the batch size, corresponding to
metrics for each example. Tensor metrics are averaged over the first
dimension when aggregating multiple batch results.
If tensor metrics represent results for the full batch, the size of the
first dimension should be 1.
preproc_fn (Callable, optional): Optional method applied to inputs. Output
of preproc_fn is then provided as input to model, in addition to
additional_forward_args provided to evaluate.
Default: ``None``
"""
self.forward_func = forward_func
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
self.metric: Callable = metric
self.preproc_fn = preproc_fn
self.attacks: Dict[str, AttackInfo] = {}
self.summary_results: Dict[str, Summarizer] = {}
# pyre-fixme[4]: Attribute must be annotated.
self.metric_aggregator = agg_metric
self.batch_stats = [Mean, Min, Max]
self.aggregate_stats = [Mean]
self.summary_results = {}
# pyre-fixme[4]: Attribute must be annotated.
self.out_format = None
def add_attack(
self,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
attack: Union[Perturbation, Callable],
name: Optional[str] = None,
num_attempts: int = 1,
apply_before_preproc: bool = True,
attack_kwargs: Optional[Dict[str, Any]] = None,
additional_attack_arg_names: Optional[List[str]] = None,
) -> None:
r"""
Adds attack to be evaluated when calling evaluate.
Args:
attack (Perturbation or Callable): This can either be an instance
of a Captum Perturbation / Attack
or any other perturbation or attack function such
as a torchvision transform.
name (str, optional): Name or identifier for attack, used as key for
attack results. This defaults to attack.__class__.__name__
if not provided and must be unique for all added attacks.
Default: ``None``
            num_attempts (int, optional): Number of times the attack should be
                repeated. This should only be set to > 1 for non-deterministic
                attacks. The minimum, maximum, and average (best, worst, and
                average case) metrics are tracked over the attack attempts.
                Default: ``1``
            apply_before_preproc (bool, optional): Defines whether the attack should
                be applied before or after the preproc function.
                Default: ``True``
attack_kwargs (dict, optional): Additional arguments to be provided to
given attack. This should be provided as a dictionary of keyword
arguments.
Default: ``None``
additional_attack_arg_names (list[str], optional): Any additional
arguments for the attack which are specific to the particular input
example or batch. An example of this is target, which is necessary
for some attacks such as FGSM or PGD. These arguments are included
if provided as a kwarg to evaluate.
Default: ``None``
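
        Example (a minimal sketch, assuming ``attack_metric`` is a constructed
        ``AttackComparator`` and using a torchvision transform as the
        perturbation)::

        >>> blur = transforms.GaussianBlur(kernel_size=5)
        >>> attack_metric.add_attack(blur, name="Gaussian Blur", num_attempts=1)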
"""
if name is None:
name = attack.__class__.__name__
if attack_kwargs is None:
attack_kwargs = {}
if additional_attack_arg_names is None:
additional_attack_arg_names = []
if name in self.attacks:
raise RuntimeError(
"Cannot add attack with same name as existing attack {}".format(name)
)
self.attacks[name] = AttackInfo(
attack_fn=attack,
name=name,
num_attempts=num_attempts,
apply_before_preproc=apply_before_preproc,
attack_kwargs=attack_kwargs,
additional_args=additional_attack_arg_names,
)
def _format_summary(
self,
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
# errors.
summary: Union[Dict, List[Dict]],
) -> Dict[str, MetricResultType]:
r"""
        This method reformats a given summary. In particular, for tuple metrics
        the Summarizer's summary format is a list of dictionaries, one per tuple
        element; this is reformatted into a single dictionary whose values are
        tuples of the per-element summary results.
"""
if isinstance(summary, dict):
return summary
else:
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
summary_dict: Dict[str, Tuple] = {}
for key in summary[0]:
summary_dict[key] = tuple(s[key] for s in summary)
if self.out_format:
summary_dict[key] = self.out_format(*summary_dict[key])
return summary_dict # type: ignore
def _update_out_format(
self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
) -> None:
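        # Capture the namedtuple type of the metric output the first time a
        # namedtuple metric is observed, so that summaries can later be
        # returned in the same named format.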
if (
not self.out_format
and isinstance(out_metric, tuple)
and hasattr(out_metric, "_fields")
):
self.out_format = namedtuple( # type: ignore
type(out_metric).__name__, cast(NamedTuple, out_metric)._fields
)
def _evaluate_batch(
self,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
input_list: List[Any],
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
additional_forward_args: Optional[Tuple],
key_list: List[str],
batch_summarizers: Dict[str, Summarizer],
metric_kwargs: Dict[str, Any],
) -> None:
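        # Run the forward function on the (possibly concatenated) inputs, compute
        # the metric on each input's slice of the model output, and update the
        # batch summarizer corresponding to each key.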
if additional_forward_args is None:
additional_forward_args = ()
if len(input_list) == 1:
model_out = self.forward_func(input_list[0], *additional_forward_args)
out_metric = self.metric(model_out, **metric_kwargs)
self._update_out_format(out_metric)
batch_summarizers[key_list[0]].update(out_metric)
else:
batched_inps = _reduce_list(input_list)
model_out = self.forward_func(batched_inps, *additional_forward_args)
current_count = 0
for i in range(len(input_list)):
batch_size = (
input_list[i].shape[0]
if isinstance(input_list[i], Tensor)
else input_list[i][0].shape[0]
)
out_metric = self.metric(
model_out[current_count : current_count + batch_size],
**metric_kwargs,
)
self._update_out_format(out_metric)
batch_summarizers[key_list[i]].update(out_metric)
current_count += batch_size
@log_usage()
def evaluate(
self,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
inputs: Any,
additional_forward_args: Optional[object] = None,
perturbations_per_eval: int = 1,
# pyre-fixme[2]: Parameter must be annotated.
**kwargs,
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
r"""
        Evaluate model and attack performance on the provided inputs.
Args:
inputs (Any): Input for which attack metrics
are computed. It can be provided as a tensor, tuple of tensors,
or any raw input type (e.g. PIL image or text string).
This input is provided directly as input to preproc function as well
as any attack applied before preprocessing. If no pre-processing
function is provided, this input is provided directly to the main
model and all attacks.
additional_forward_args (Any, optional): If the forward function
requires additional arguments other than the preprocessing
outputs (or inputs if preproc_fn is None), this argument
can be provided. It must be either a single additional
argument of a Tensor or arbitrary (non-tuple) type or a
tuple containing multiple additional arguments including
tensors or any arbitrary python types. These arguments
are provided to forward_func in order following the
arguments in inputs.
For a tensor, the first dimension of the tensor must
correspond to the number of examples. For all other types,
the given argument is used for all forward evaluations.
Default: ``None``
perturbations_per_eval (int, optional): Allows perturbations of multiple
attacks to be grouped and evaluated in one call of forward_fn
Each forward pass will contain a maximum of
perturbations_per_eval * #examples samples.
For DataParallel models, each batch is split among the
available devices, so evaluations on each available
device contain at most
(perturbations_per_eval * #examples) / num_devices
samples.
In order to apply this functionality, the output of preproc_fn
(or inputs itself if no preproc_fn is provided) must be a tensor
or tuple of tensors.
Default: ``1``
kwargs (Any, optional): Additional keyword arguments provided to metric
function as well as selected attacks based on chosen additional_args.
Default: ``None``
Returns:
- **attack results** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]:
Dictionary containing attack results for provided batch.
Maps attack name to dictionary,
containing best-case, worst-case and average-case results for attack.
Dictionary contains keys "mean", "max" and "min" when num_attempts > 1
and only "mean" for num_attempts = 1, which contains the (single) metric
result for the attack attempt.
An additional key of 'Original' is included with metric results
without any perturbations.
Examples::
        >>> def accuracy_metric(model_out: Tensor, targets: Tensor):
        >>>     return (torch.argmax(model_out, dim=1) == targets).float()
        >>> attack_metric = AttackComparator(forward_func=resnet18,
        >>>                                  metric=accuracy_metric,
        >>>                                  preproc_fn=normalize)
        >>> random_rotation = transforms.RandomRotation(degrees=30)
        >>> jitter = transforms.ColorJitter()
        >>> attack_metric.add_attack(random_rotation, "Random Rotation",
        >>>                          num_attempts=5)
        >>> attack_metric.add_attack(jitter, "Jitter", num_attempts=1)
        >>> attack_metric.add_attack(FGSM(resnet18), "FGSM 0.1", num_attempts=1,
        >>>                          apply_before_preproc=False,
        >>>                          attack_kwargs={"epsilon": 0.1},
        >>>                          additional_attack_arg_names=["targets"])
        >>> for images, labels in dataloader:
        >>>     batch_results = attack_metric.evaluate(inputs=images, targets=labels)
"""
additional_forward_args = _format_additional_forward_args(
additional_forward_args
)
expanded_additional_args = (
_expand_additional_forward_args(
additional_forward_args, perturbations_per_eval
)
if perturbations_per_eval > 1
else additional_forward_args
)
preproc_input = None
if self.preproc_fn is not None:
preproc_input = self.preproc_fn(inputs)
else:
preproc_input = inputs
input_list = [preproc_input]
key_list = [ORIGINAL_KEY]
batch_summarizers = {ORIGINAL_KEY: Summarizer([Mean()])}
if ORIGINAL_KEY not in self.summary_results:
self.summary_results[ORIGINAL_KEY] = Summarizer(
[stat() for stat in self.aggregate_stats]
)
# pyre-fixme[53]: Captured variable `batch_summarizers` is not annotated.
# pyre-fixme[53]: Captured variable `expanded_additional_args` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _check_and_evaluate(input_list, key_list):
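            # Evaluate the accumulated inputs once the group reaches
            # perturbations_per_eval entries, then start a new group.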
if len(input_list) == perturbations_per_eval:
self._evaluate_batch(
input_list,
expanded_additional_args,
key_list,
batch_summarizers,
kwargs,
)
return [], []
return input_list, key_list
input_list, key_list = _check_and_evaluate(input_list, key_list)
for attack_key in self.attacks:
attack = self.attacks[attack_key]
if attack.num_attempts > 1:
stats = [stat() for stat in self.batch_stats]
else:
stats = [Mean()]
batch_summarizers[attack.name] = Summarizer(stats)
additional_attack_args = {}
for key in attack.additional_args:
if key not in kwargs:
warnings.warn(
f"Additional sample arg {key} not provided for {attack_key}",
stacklevel=1,
)
else:
additional_attack_args[key] = kwargs[key]
for _ in range(attack.num_attempts):
if attack.apply_before_preproc:
attacked_inp = attack.attack_fn(
inputs, **additional_attack_args, **attack.attack_kwargs
)
preproc_attacked_inp = (
self.preproc_fn(attacked_inp)
if self.preproc_fn
else attacked_inp
)
else:
preproc_attacked_inp = attack.attack_fn(
preproc_input, **additional_attack_args, **attack.attack_kwargs
)
input_list.append(preproc_attacked_inp)
key_list.append(attack.name)
input_list, key_list = _check_and_evaluate(input_list, key_list)
if len(input_list) > 0:
final_add_args = _expand_additional_forward_args(
additional_forward_args, len(input_list)
)
self._evaluate_batch(
input_list, final_add_args, key_list, batch_summarizers, kwargs
)
return self._parse_and_update_results(batch_summarizers)
def _parse_and_update_results(
self, batch_summarizers: Dict[str, Summarizer]
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
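        # Format the per-attack summaries for the current batch and fold them
        # into the running cross-batch summaries in self.summary_results.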
results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = {
ORIGINAL_KEY: self._format_summary(
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime
# subscripting errors.
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List[<element type>]` to avoid runtime subscripting errors.
cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary)
)["mean"]
}
self.summary_results[ORIGINAL_KEY].update(
self.metric_aggregator(results[ORIGINAL_KEY])
)
for attack_key in self.attacks:
attack = self.attacks[attack_key]
attack_results = self._format_summary(
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime
# subscripting errors.
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List[<element type>]` to avoid runtime subscripting errors.
cast(Union[Dict, List], batch_summarizers[attack.name].summary)
)
results[attack.name] = attack_results
if len(attack_results) == 1:
key = next(iter(attack_results))
if attack.name not in self.summary_results:
self.summary_results[attack.name] = Summarizer(
[stat() for stat in self.aggregate_stats]
)
self.summary_results[attack.name].update(
self.metric_aggregator(attack_results[key])
)
else:
for key in attack_results:
summary_key = f"{attack.name} {key.title()} Attempt"
if summary_key not in self.summary_results:
self.summary_results[summary_key] = Summarizer(
[stat() for stat in self.aggregate_stats]
)
self.summary_results[summary_key].update(
self.metric_aggregator(attack_results[key])
)
return results
def summary(self) -> Dict[str, Dict[str, MetricResultType]]:
r"""
Returns average results over all previous batches evaluated.
Returns:
- **summary** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]:
Dictionary containing summarized average attack results.
                Maps attack name (with "Mean Attempt", "Max Attempt", and "Min Attempt"
                suffixes if num_attempts > 1) to a dictionary with a "mean" key
                containing the summarized results: the running mean of results over
                all batches since construction or the previous reset call. Tensor
                metrics are averaged over dimension 0 for each batch in order to
                aggregate the metrics collected per batch.
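
        Example (a sketch, assuming ``attack_metric`` has already been used to
        ``evaluate`` one or more batches, with an attack named "FGSM 0.1" added
        with ``num_attempts=1``)::

        >>> averaged = attack_metric.summary()
        >>> # averaged["FGSM 0.1"]["mean"] holds the running mean of the metric
        >>> # under that attack across all evaluated batches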
"""
return {
key: self._format_summary(
# pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
# `typing.Dict[<key type>, <value type>]` to avoid runtime
# subscripting errors.
# pyre-fixme[24]: Generic type `list` expects 1 type parameter, use
# `typing.List[<element type>]` to avoid runtime subscripting errors.
cast(Union[Dict, List], self.summary_results[key].summary)
)
for key in self.summary_results
}
def reset(self) -> None:
r"""
        Reset stored average summary results from previous batches.
"""
self.summary_results = {}