Source code for captum.robust._core.metrics.attack_comparator

#!/usr/bin/env python3
import warnings
from collections import namedtuple
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from captum._utils.common import (
    _expand_additional_forward_args,
    _format_additional_forward_args,
    _reduce_list,
)
from captum.attr import Max, Mean, Min, Summarizer
from captum.robust._core.perturbation import Perturbation
from torch import Tensor

ORIGINAL_KEY = "Original"

MetricResultType = TypeVar(
    "MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...]
)


class AttackInfo(NamedTuple):
    attack_fn: Union[Perturbation, Callable]
    name: str
    num_attempts: int
    apply_before_preproc: bool
    attack_kwargs: Dict[str, Any]
    additional_args: List[str]


def agg_metric(inp):
    if isinstance(inp, Tensor):
        return inp.mean(dim=0)
    elif isinstance(inp, tuple):
        return tuple(agg_metric(elem) for elem in inp)
    return inp


[docs]class AttackComparator(Generic[MetricResultType]): r""" Allows measuring model robustness for a given attack or set of attacks. This class can be used with any metric(s) as well as any set of attacks, either based on attacks / perturbations from captum.robust such as FGSM or PGD or external augmentation methods or perturbations such as torchvision transforms. """ def __init__( self, forward_func: Callable, metric: Callable[..., MetricResultType], preproc_fn: Callable = None, ) -> None: r""" Args: forward_func (callable or torch.nn.Module): This can either be an instance of pytorch model or any modification of a model's forward function. metric (callable): This function is applied to the model output in order to compute the desired performance metric or metrics. This function should have the following signature:: >>> def model_metric(model_out: Tensor, **kwargs: Any) >>> -> Union[float, Tensor, Tuple[Union[float, Tensor], ...]: All kwargs provided to evaluate are provided to the metric function, following the model output. A single metric can be returned as a float or tensor, and multiple metrics should be returned as either a tuple or named tuple of floats or tensors. For a tensor metric, the first dimension should match the batch size, corresponding to metrics for each example. Tensor metrics are averaged over the first dimension when aggregating multiple batch results. If tensor metrics represent results for the full batch, the size of the first dimension should be 1. preproc_fn (callable, optional): Optional method applied to inputs. Output of preproc_fn is then provided as input to model, in addition to additional_forward_args provided to evaluate. """ self.forward_func = forward_func self.metric: Callable = metric self.preproc_fn = preproc_fn self.attacks: Dict[str, AttackInfo] = {} self.summary_results: Dict[str, Summarizer] = {} self.metric_aggregator = agg_metric self.batch_stats = [Mean, Min, Max] self.aggregate_stats = [Mean] self.summary_results = {} self.out_format = None
[docs] def add_attack( self, attack: Union[Perturbation, Callable], name: Optional[str] = None, num_attempts: int = 1, apply_before_preproc: bool = True, attack_kwargs: Optional[Dict[str, Any]] = None, additional_attack_arg_names: Optional[List[str]] = None, ) -> None: r""" Adds attack to be evaluated when calling evaluate. Args: attack (perturbation or callable): This can either be an instance of a Captum Perturbation / Attack or any other perturbation or attack function such as a torchvision transform. name (optional, str): Name or identifier for attack, used as key for attack results. This defaults to attack.__class__.__name__ if not provided and must be unique for all added attacks. num_attempts (int): Number of attempts that attack should be repeated. This should only be set to > 1 for non-deterministic attacks. The minimum, maximum, and average (best, worst, and average case) are tracked for attack attempts. apply_before_preproc (bool): Defines whether attack should be applied before or after preproc function. attack_kwargs (dict): Additional arguments to be provided to given attack. This should be provided as a dictionary of keyword arguments. additional_attack_arg_names (list[str]): Any additional arguments for the attack which are specific to the particular input example or batch. An example of this is target, which is necessary for some attacks such as FGSM or PGD. These arguments are included if provided as a kwarg to evaluate. """ if name is None: name = attack.__class__.__name__ if attack_kwargs is None: attack_kwargs = {} if additional_attack_arg_names is None: additional_attack_arg_names = [] if name in self.attacks: raise RuntimeError( "Cannot add attack with same name as existing attack {}".format(name) ) self.attacks[name] = AttackInfo( attack_fn=attack, name=name, num_attempts=num_attempts, apply_before_preproc=apply_before_preproc, attack_kwargs=attack_kwargs, additional_args=additional_attack_arg_names, )
def _format_summary( self, summary: Union[Dict, List[Dict]] ) -> Dict[str, MetricResultType]: r""" This method reformats a given summary; particularly for tuples, the Summarizer's summary format is a list of dictionaries, each containing the summary for the corresponding elements. We reformat this to return a dictionary with tuples containing the summary results. """ if isinstance(summary, dict): return summary else: summary_dict: Dict[str, Tuple] = {} for key in summary[0]: summary_dict[key] = tuple(s[key] for s in summary) if self.out_format: summary_dict[key] = self.out_format(*summary_dict[key]) return summary_dict # type: ignore def _update_out_format( self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]] ) -> None: if ( not self.out_format and isinstance(out_metric, tuple) and hasattr(out_metric, "_fields") ): self.out_format = namedtuple( # type: ignore type(out_metric).__name__, cast(NamedTuple, out_metric)._fields ) def _evaluate_batch( self, input_list: List[Any], additional_forward_args: Optional[Tuple], key_list: List[str], batch_summarizers: Dict[str, Summarizer], metric_kwargs: Dict[str, Any], ) -> None: if additional_forward_args is None: additional_forward_args = () if len(input_list) == 1: model_out = self.forward_func(input_list[0], *additional_forward_args) out_metric = self.metric(model_out, **metric_kwargs) self._update_out_format(out_metric) batch_summarizers[key_list[0]].update(out_metric) else: batched_inps = _reduce_list(input_list) model_out = self.forward_func(batched_inps, *additional_forward_args) current_count = 0 for i in range(len(input_list)): batch_size = ( input_list[i].shape[0] if isinstance(input_list[i], Tensor) else input_list[i][0].shape[0] ) out_metric = self.metric( model_out[current_count : current_count + batch_size], **metric_kwargs, ) self._update_out_format(out_metric) batch_summarizers[key_list[i]].update(out_metric) current_count += batch_size
[docs] def evaluate( self, inputs: Any, additional_forward_args: Any = None, perturbations_per_eval: int = 1, **kwargs, ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: r""" Evaluate model and attack performance on provided inputs Args: inputs (any): Input for which attack metrics are computed. It can be provided as a tensor, tuple of tensors, or any raw input type (e.g. PIL image or text string). This input is provided directly as input to preproc function as well as any attack applied before preprocessing. If no pre-processing function is provided, this input is provided directly to the main model and all attacks. additional_forward_args (any, optional): If the forward function requires additional arguments other than the preprocessing outputs (or inputs if preproc_fn is None), this argument can be provided. It must be either a single additional argument of a Tensor or arbitrary (non-tuple) type or a tuple containing multiple additional arguments including tensors or any arbitrary python types. These arguments are provided to forward_func in order following the arguments in inputs. For a tensor, the first dimension of the tensor must correspond to the number of examples. For all other types, the given argument is used for all forward evaluations. Default: None perturbations_per_eval (int, optional): Allows perturbations of multiple attacks to be grouped and evaluated in one call of forward_fn Each forward pass will contain a maximum of perturbations_per_eval * #examples samples. For DataParallel models, each batch is split among the available devices, so evaluations on each available device contain at most (perturbations_per_eval * #examples) / num_devices samples. In order to apply this functionality, the output of preproc_fn (or inputs itself if no preproc_fn is provided) must be a tensor or tuple of tensors. Default: 1 kwargs (any, optional): Additional keyword arguments provided to metric function as well as selected attacks based on chosen additional_args Returns: - **attack results** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]: Dictionary containing attack results for provided batch. Maps attack name to dictionary, containing best-case, worst-case and average-case results for attack. Dictionary contains keys "mean", "max" and "min" when num_attempts > 1 and only "mean" for num_attempts = 1, which contains the (single) metric result for the attack attempt. An additional key of 'Original' is included with metric results without any perturbations. Examples:: >>> def accuracy_metric(model_out: Tensor, targets: Tensor): >>> return torch.argmax(model_out, dim=1) == targets).float() >>> attack_metric = AttackComparator(model=resnet18, metric=accuracy_metric, preproc_fn=normalize) >>> random_rotation = transforms.RandomRotation() >>> jitter = transforms.ColorJitter() >>> attack_metric.add_attack(random_rotation, "Random Rotation", >>> num_attempts = 5) >>> attack_metric.add_attack((jitter, "Jitter", num_attempts = 1) >>> attack_metric.add_attack(FGSM(resnet18), "FGSM 0.1", num_attempts = 1, >>> apply_before_preproc=False, >>> attack_kwargs={epsilon: 0.1}, >>> additional_args=["targets"]) >>> for images, labels in dataloader: >>> batch_results = attack_metric.evaluate(inputs=images, targets=labels) """ additional_forward_args = _format_additional_forward_args( additional_forward_args ) expanded_additional_args = ( _expand_additional_forward_args( additional_forward_args, perturbations_per_eval ) if perturbations_per_eval > 1 else additional_forward_args ) preproc_input = None if self.preproc_fn is not None: preproc_input = self.preproc_fn(inputs) else: preproc_input = inputs input_list = [preproc_input] key_list = [ORIGINAL_KEY] batch_summarizers = {ORIGINAL_KEY: Summarizer([Mean()])} if ORIGINAL_KEY not in self.summary_results: self.summary_results[ORIGINAL_KEY] = Summarizer( [stat() for stat in self.aggregate_stats] ) def _check_and_evaluate(input_list, key_list): if len(input_list) == perturbations_per_eval: self._evaluate_batch( input_list, expanded_additional_args, key_list, batch_summarizers, kwargs, ) return [], [] return input_list, key_list input_list, key_list = _check_and_evaluate(input_list, key_list) for attack_key in self.attacks: attack = self.attacks[attack_key] if attack.num_attempts > 1: stats = [stat() for stat in self.batch_stats] else: stats = [Mean()] batch_summarizers[attack.name] = Summarizer(stats) additional_attack_args = {} for key in attack.additional_args: if key not in kwargs: warnings.warn( f"Additional sample arg {key} not provided for {attack_key}" ) else: additional_attack_args[key] = kwargs[key] for _ in range(attack.num_attempts): if attack.apply_before_preproc: attacked_inp = attack.attack_fn( inputs, **additional_attack_args, **attack.attack_kwargs ) preproc_attacked_inp = ( self.preproc_fn(attacked_inp) if self.preproc_fn else attacked_inp ) else: preproc_attacked_inp = attack.attack_fn( preproc_input, **additional_attack_args, **attack.attack_kwargs ) input_list.append(preproc_attacked_inp) key_list.append(attack.name) input_list, key_list = _check_and_evaluate(input_list, key_list) if len(input_list) > 0: final_add_args = _expand_additional_forward_args( additional_forward_args, len(input_list) ) self._evaluate_batch( input_list, final_add_args, key_list, batch_summarizers, kwargs ) return self._parse_and_update_results(batch_summarizers)
def _parse_and_update_results( self, batch_summarizers: Dict[str, Summarizer] ) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]: results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = { ORIGINAL_KEY: self._format_summary( cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary) )["mean"] } self.summary_results[ORIGINAL_KEY].update( self.metric_aggregator(results[ORIGINAL_KEY]) ) for attack_key in self.attacks: attack = self.attacks[attack_key] attack_results = self._format_summary( cast(Union[Dict, List], batch_summarizers[attack.name].summary) ) results[attack.name] = attack_results if len(attack_results) == 1: key = next(iter(attack_results)) if attack.name not in self.summary_results: self.summary_results[attack.name] = Summarizer( [stat() for stat in self.aggregate_stats] ) self.summary_results[attack.name].update( self.metric_aggregator(attack_results[key]) ) else: for key in attack_results: summary_key = f"{attack.name} {key.title()} Attempt" if summary_key not in self.summary_results: self.summary_results[summary_key] = Summarizer( [stat() for stat in self.aggregate_stats] ) self.summary_results[summary_key].update( self.metric_aggregator(attack_results[key]) ) return results
[docs] def summary(self) -> Dict[str, Dict[str, MetricResultType]]: r""" Returns average results over all previous batches evaluated. Returns: - **summary** Dict: str -> Dict[str, Union[Tensor, Tuple[Tensor, ...]]]: Dictionary containing summarized average attack results. Maps attack name (with "Mean Attempt", "Max Attempt" and "Min Attempt" suffixes if num_attempts > 1) to dictionary containing a key of "mean" maintaining summarized results, which is the running mean of results over all batches since construction or previous reset call. Tensor metrics are averaged over dimension 0 for each batch, in order to aggregte metrics collected per batch. """ return { key: self._format_summary( cast(Union[Dict, List], self.summary_results[key].summary) ) for key in self.summary_results }
[docs] def reset(self) -> None: r""" Reset stored average summary results for previous batches """ self.summary_results = {}