nice_hooks

The top-level package.

Package Contents

Classes

ActivationCache

A wrapper around a dictionary of cached activations from a model run.

RemovableHandleCollection

Represents a collection of torch RemovableHandle objects.

ModulePath

Represents a parsed module path. See Module Paths.

Functions

expand_module_path(model, path)

Parses a path string into a ModulePath, then expands any wildcards in the name.

register_forward_hook(module, path, hook)

Registers forward hooks on submodules of a module.

register_forward_pre_hook(module, path, hook)

Registers forward pre hooks on submodules of a module.

register_full_backward_hook(module, path, hook)

Registers backward hooks on submodules of a module.

register_full_backward_pre_hook(module, path, hook)

Registers backward pre hooks on submodules of a module.

run(module, *args[, return_activations, ...])

Runs the model, accepting extra keyword parameters for recording activations, replacing activations, and registering temporary hooks.

patch_method_to_module(cls, fname)

Call on a class inheriting from nn.Module to convert one of its methods into a submodule.

Attributes

ActivationCacheLike

Any object that an ActivationCache can be constructed from.

AnySlice

Anything that can be used to index a torch Tensor, or None.

ModulePathsLike

Anything that can be parsed into a list of ModulePath objects.

class nice_hooks.ActivationCache

Bases: dict[str, torch.Tensor]

A wrapper around a dictionary of cached activations from a model run. Supports operations similar to those of torch.Tensor, which are usually applied to each cached tensor individually. See ActivationCache.

to(device)

Moves each tensor of the cache to a device.

Parameters:

device (Union[str, torch.device]) – The device to move the tensors to.

Return type:

ActivationCache

index(key)

Applies t[key] to every tensor t in the activation cache. Because the tensors may have different shapes, this operation is mainly useful for indexing the batch dimension.

Parameters:

key (Union[None, _int, slice, torch.Tensor, List, Tuple]) – The index to apply to each tensor.

Return type:

ActivationCache
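Because the cached tensors can have different trailing shapes, the only index that is meaningful for all of them is one along the shared leading (batch) dimension. The sketch below imitates index(key) using nested Python lists in place of tensors so that it runs without torch; `ToyActivationCache` is a hypothetical name, not part of nice_hooks.

```python
class ToyActivationCache(dict):
    """Hypothetical stand-in for ActivationCache: maps module names to batched values.

    Nested Python lists play the role of tensors, with the batch as the outer axis.
    """

    def index(self, key):
        # Apply the same index to every cached value, as index(key) is described to do.
        return ToyActivationCache({name: value[key] for name, value in self.items()})


cache = ToyActivationCache({
    "layer1": [[1, 2], [3, 4], [5, 6]],  # batch of 3, feature size 2
    "layer2": [[7], [8], [9]],           # batch of 3, feature size 1
})

# Slicing the shared batch axis is valid for every entry...
first_two = cache.index(slice(0, 2))
# ...whereas an index into the feature axis (e.g. position 1) would not exist for "layer2".
```

The result is again a cache of the same type, so further per-entry operations can be chained.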

nice_hooks.ActivationCacheLike

Any object that an ActivationCache can be constructed from.

nice_hooks.AnySlice

Anything that can be used to index a torch Tensor, or None.

nice_hooks.ModulePathsLike

Anything that can be parsed into a list of ModulePath objects.

class nice_hooks.RemovableHandleCollection(handles)

Represents a collection of torch RemovableHandle objects.

Like RemovableHandle itself, this class can be used in a with statement, which removes the hooks when the block exits.

Parameters:

handles (list[torch.utils.hooks.RemovableHandle]) – The handles to manage together.
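The with-statement support can be sketched as a context manager whose `__exit__` removes every handle it holds. `ToyHandle` and `ToyHandleCollection` are hypothetical stand-ins (a dict plays the role of a module's hook registry), and the `remove()` method on the collection is an assumption modelled on RemovableHandle.remove().

```python
class ToyHandle:
    """Hypothetical stand-in for torch.utils.hooks.RemovableHandle."""

    def __init__(self, registry, key):
        self.registry = registry
        self.key = key

    def remove(self):
        # Unregister the hook; removing twice is harmless.
        self.registry.pop(self.key, None)


class ToyHandleCollection:
    """Hypothetical: removes a group of handles together, on demand or on with-exit."""

    def __init__(self, handles):
        self.handles = list(handles)

    def remove(self):
        for handle in self.handles:
            handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False  # do not swallow exceptions raised inside the with block


registry = {1: "hook_a", 2: "hook_b"}
with ToyHandleCollection([ToyHandle(registry, 1), ToyHandle(registry, 2)]):
    active_inside = len(registry)  # both hooks are registered here
active_after = len(registry)       # both hooks were removed on exit
```

Returning False from `__exit__` means an exception inside the block still propagates after the hooks are cleaned up.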

class nice_hooks.ModulePath

Represents a parsed module path. See Module Paths.

name: str

The name of the module.

slice: AnySlice

How the module’s output should be indexed. None indicates the module output is not changed.

static parse(path)

Parse a string into a ModulePath. Wildcards are passed through unchanged.

Parameters:

path (str) – The path to parse

Returns:

The path, parsed into parts.

Return type:

ModulePath

__str__()

Return str(self).

Return type:

str
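The exact path syntax is documented under Module Paths. Assuming it is a dotted module name followed by an optional bracketed index such as `[:, 0]`, parsing could be sketched as below. `ToyModulePath` and `parse_path` are hypothetical, and this toy only understands integers and colon ranges inside the brackets.

```python
import re
from dataclasses import dataclass
from typing import Optional, Tuple, Union

Index = Union[int, slice]


@dataclass
class ToyModulePath:
    """Hypothetical mirror of ModulePath: a name plus an optional index."""
    name: str
    slice: Optional[Tuple[Index, ...]]  # None: the module output is not changed


def _parse_index(part: str) -> Index:
    # "1:5:2" -> slice(1, 5, 2), ":" -> slice(None, None), "0" -> 0
    part = part.strip()
    if ":" in part:
        return slice(*[int(b) if b.strip() else None for b in part.split(":")])
    return int(part)


def parse_path(path: str) -> ToyModulePath:
    # Assumed syntax: dotted name, then an optional "[...]" index suffix.
    match = re.fullmatch(r"([^\[\]]*)(?:\[(.*)\])?", path.strip())
    if match is None:
        raise ValueError(f"cannot parse module path {path!r}")
    name, inner = match.group(1), match.group(2)
    if inner is None:
        return ToyModulePath(name, None)
    # Wildcards in the name are kept as-is; only the index is interpreted.
    return ToyModulePath(name, tuple(_parse_index(p) for p in inner.split(",")))


parsed = parse_path("blocks.*.mlp[:, 0]")
```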

nice_hooks.expand_module_path(model, path)

Parses a path string into a ModulePath, then expands any wildcards in the name.

Slice wildcards are left unchanged.

Parameters:
  • model (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – The path or paths to expand.

Return type:

list[ModulePath, torch.nn.Module]
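One way to picture the name expansion, assuming `*` matches exactly one dotted segment of the names returned by named_modules(), is to compare the pattern with each name segment by segment using fnmatch and copy any bracketed index onto each match unchanged. `expand_names` is a hypothetical helper: it works on plain strings and returns names rather than (ModulePath, module) pairs.

```python
from fnmatch import fnmatchcase


def expand_names(module_names, path):
    """Hypothetical wildcard expansion over dotted module names.

    Assumes '*' matches exactly one dotted segment and that any bracketed
    index suffix (e.g. "[:, 0]") is copied to each result unchanged.
    """
    name, bracket, rest = path.partition("[")
    suffix = bracket + rest  # empty string when the path has no index
    pattern = name.split(".")
    matches = []
    for candidate in module_names:
        parts = candidate.split(".")
        if len(parts) == len(pattern) and all(
            fnmatchcase(part, pat) for part, pat in zip(parts, pattern)
        ):
            matches.append(candidate + suffix)
    return matches


# Names as they might come from model.named_modules() for a two-block model.
names = ["", "blocks", "blocks.0", "blocks.0.attn", "blocks.0.mlp",
         "blocks.1", "blocks.1.attn", "blocks.1.mlp", "head"]
expanded = expand_names(names, "blocks.*.mlp[:, 0]")
```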

nice_hooks.register_forward_hook(module, path, hook)

Registers forward hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, args, output) arguments.

Return type:

RemovableHandleCollection
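The hook receives the matched path in addition to torch's usual (module, args, output). A minimal sketch of how such a hook could be adapted to the signature nn.Module.register_forward_hook expects is a closure that binds the path; `bind_path` is hypothetical rather than the library's code, and whether nice_hooks forwards the hook's return value in the same way is an assumption.

```python
import functools


def bind_path(hook, path):
    """Hypothetical adapter: turn a (module, path, args, output) hook into the
    (module, args, output) form that torch's register_forward_hook expects."""

    @functools.wraps(hook)
    def torch_style_hook(module, args, output):
        # Whatever the wrapped hook returns is passed back; in torch, a non-None
        # return value from a forward hook replaces the module's output.
        return hook(module, path, args, output)

    return torch_style_hook


seen = []


def record(module, path, args, output):
    seen.append((path, output))  # returns None, so the output is left unchanged


def zero_out(module, path, args, output):
    return 0  # a non-None return would replace the output


adapted = bind_path(record, "blocks.0.mlp")
result = adapted(object(), (3,), 42)                       # simulate torch calling the hook
replaced = bind_path(zero_out, "head")(object(), (3,), 42)
```

Binding the path per registration lets one hook function be attached to many wildcard-matched modules while still knowing which one fired.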

nice_hooks.register_forward_pre_hook(module, path, hook)

Registers forward pre hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, args) arguments.

Return type:

RemovableHandleCollection

nice_hooks.register_full_backward_hook(module, path, hook)

Registers backward hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, grad_args, grad_output) arguments.

Return type:

RemovableHandleCollection

nice_hooks.register_full_backward_pre_hook(module, path, hook)

Registers backward pre hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, grad_args) arguments.

Return type:

RemovableHandleCollection
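The four register functions differ only in the positional arguments passed to the hook. A hypothetical helper such as `accepts_hook_args` can catch a signature mismatch before the hook is registered; the parameter names in `HOOK_PARAMS` are transcribed from the entries above, and neither name is part of nice_hooks.

```python
import inspect

# Positional arguments each register function passes to its hook, transcribed from the docs above.
HOOK_PARAMS = {
    "register_forward_hook": ("module", "path", "args", "output"),
    "register_forward_pre_hook": ("module", "path", "args"),
    "register_full_backward_hook": ("module", "path", "grad_args", "grad_output"),
    "register_full_backward_pre_hook": ("module", "path", "grad_args"),
}


def accepts_hook_args(hook, register_fn_name):
    """Return True if `hook` can be called with the positional arguments that
    the named register function is documented to pass (hypothetical helper)."""
    placeholders = HOOK_PARAMS[register_fn_name]
    try:
        # bind() checks arity against the signature without calling the hook.
        inspect.signature(hook).bind(*placeholders)
    except TypeError:
        return False
    return True


def my_forward_hook(module, path, args, output):
    return None  # in torch, returning None leaves the output unchanged


ok = accepts_hook_args(my_forward_hook, "register_forward_hook")
wrong = accepts_hook_args(my_forward_hook, "register_forward_pre_hook")  # no value for output
```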

nice_hooks.run(module, *args, return_activations=None, with_activations=None, forward_hooks=None, forward_pre_hooks=None, full_backward_hooks=None, full_backward_pre_hooks=None, **kwargs)

Runs the model, accepting extra keyword parameters for recording activations, replacing activations, and registering temporary hooks.

Parameters:
  • module (torch.nn.Module) – The module to run.

  • *args – Positional arguments to pass to the model.

  • **kwargs – Keyword arguments to pass to the model.

  • return_activations (ModulePathsLike) – If true, records activations into an activation cache as the module runs, and returns a tuple of the model output and the activation cache.

  • with_activations (nice_hooks.activationcache.ActivationCacheLike) – If set, replaces the given activations with the supplied values when running the module forward.

  • forward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward hooks for this run only.

  • forward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward pre hooks for this run only.

  • full_backward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward hooks for this run only.

  • full_backward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward pre hooks for this run only.
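The recording and replacement options can be pictured with a toy model made of named functions. In the hypothetical `toy_run`, return_activations accepts True, a single name, or a collection of names, and with_activations maps names to replacement outputs; recording the value after replacement is a choice made for this sketch, not documented behaviour of run().

```python
def toy_run(layers, x, return_activations=None, with_activations=None):
    """Hypothetical sketch of run()'s recording and replacement options.

    `layers` is a list of (name, function) pairs standing in for submodules.
    """
    patches = dict(with_activations or {})
    if return_activations is True:
        wanted = {name for name, _ in layers}
    elif isinstance(return_activations, str):
        wanted = {return_activations}
    else:
        wanted = set(return_activations or ())
    cache = {}
    for name, fn in layers:
        x = fn(x)
        if name in patches:
            x = patches[name]  # later layers see the replacement value
        if name in wanted:
            cache[name] = x    # record what the rest of the model actually saw
    return (x, cache) if return_activations else x


layers = [("double", lambda v: v * 2), ("inc", lambda v: v + 1), ("square", lambda v: v * v)]
plain = toy_run(layers, 3)                                   # 3 -> 6 -> 7 -> 49
out, acts = toy_run(layers, 3, return_activations=True)
patched = toy_run(layers, 3, with_activations={"inc": 10})  # 3 -> 6 -> 10 -> 100
```

A typical pattern is to record activations from one input, then feed some of them back through with_activations while running a different input.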

nice_hooks.patch_method_to_module(cls, fname)

Call on a class inheriting from nn.Module to convert one of its methods into a submodule. This is useful if you want to add hooks to the method.

This uses monkey patching, so you only need to call it once on a class to affect all instances. It must be called before any instances of cls are created.

Parameters:
  • cls (type) – The class to patch.

  • fname (str) – The name of the method on the class.
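The requirement to patch before creating instances suggests that the submodule is attached when an instance is initialised. The sketch below imitates that idea without torch: it wraps `__init__` to attach a per-instance child object and reroutes the method through it. `MethodAsChild` and `patch_method_to_child` are hypothetical and only model the mechanism; the real function creates an actual nn.Module submodule.

```python
import functools


class MethodAsChild:
    """Hypothetical stand-in for the submodule that wraps one method."""

    def __init__(self, bound_method):
        self.bound_method = bound_method
        self.calls = 0  # stands in for the point where hooks would fire

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.bound_method(*args, **kwargs)


def patch_method_to_child(cls, fname):
    """Hypothetical sketch: route cls.<fname> through a per-instance child object."""
    original_init = cls.__init__
    original_method = getattr(cls, fname)
    child_attr = f"_{fname}_child"

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Each new instance gets its own child wrapping its bound method.
        setattr(self, child_attr, MethodAsChild(original_method.__get__(self, cls)))

    @functools.wraps(original_method)
    def patched_method(self, *args, **kwargs):
        return getattr(self, child_attr)(*args, **kwargs)

    cls.__init__ = patched_init
    setattr(cls, fname, patched_method)


class Scale:
    def __init__(self, k):
        self.k = k

    def apply(self, x):
        return x * self.k


early = Scale(2)                       # created before patching: gets no child
patch_method_to_child(Scale, "apply")  # patch once, affects all later instances
late = Scale(3)
value = late.apply(5)                  # early.apply(5) would now raise AttributeError
```

Patching the class rather than individual objects is what makes a single call affect every later instance, and it is also why instances created earlier are left without the child.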