nice_hooks.nice_hooks

Module Contents

Classes

RemovableHandleCollection

Represents a collection of torch RemovableHandle objects.

ModulePath

Represents a parsed module path. See Module Paths.

Functions

_to_paths(paths)

Converts path-like values to list[str].

expand_module_path(model, path)

Parses a path string to a ModulePath, then expands any wildcards in the name.

_is_wild_slice(sl)

Wild slices are those that may return multiple values in _expand_slice.

_expand_slice(tt, sl)

Computes tt[sl]. If sl has wildcard references, expands them according to the shape of tt.

_regroup(iter, key_fn[, value_fn])

Unordered group-by.

_do_hook(expanded_paths, hook, reg)

Like register_forward_hook, but for the results of expand_module_path.

register_forward_hook(module, path, hook)

Registers forward hooks on submodules of a module.

register_forward_pre_hook(module, path, hook)

Registers forward pre hooks on submodules of a module.

register_full_backward_hook(module, path, hook)

Registers backward hooks on submodules of a module.

register_full_backward_pre_hook(module, path, hook)

Registers backward pre hooks on submodules of a module.

run(module, *args[, return_activations, ...])

Runs the model, accepting some extra keyword parameters for various behaviours.

patch_method_to_module(cls, fname)

Call on a class inheriting from nn.Module to convert a method in that module to a submodule.

Attributes

AnySlice

Anything that can be used as an index into a torch Tensor, or None.

ModulePathsLike

Anything that can be parsed into a list of ModulePath objects.

nice_hooks.nice_hooks.AnySlice

Anything that can be used as an index into a torch Tensor, or None.

nice_hooks.nice_hooks.ModulePathsLike

Anything that can be parsed into a list of ModulePath objects.

nice_hooks.nice_hooks._to_paths(paths)

Converts path-like values to list[str].

Parameters:

paths (ModulePathsLike) –

Return type:

list[str]
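A minimal sketch of the normalization _to_paths performs, assuming (this is not stated in the reference) that a paths-like value is either a single path string or an iterable of path strings. The names to_paths_sketch and PathsLike are hypothetical and not part of nice_hooks.

```python
from collections.abc import Iterable
from typing import Union

# Assumption: a paths-like value is a single string or an iterable of strings.
PathsLike = Union[str, Iterable[str]]


def to_paths_sketch(paths: PathsLike) -> list[str]:
    """Normalize one path or a collection of paths to list[str]."""
    if isinstance(paths, str):
        # A bare string is one path, not an iterable of one-character paths.
        return [paths]
    # str() also accepts objects that define __str__, such as ModulePath.
    return [str(p) for p in paths]


print(to_paths_sketch("layers.0"))          # ['layers.0']
print(to_paths_sketch(("layers.0", "fc")))  # ['layers.0', 'fc']
```

The string check has to come first because a str is itself iterable; without it, "fc" would be split into ["f", "c"].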

class nice_hooks.nice_hooks.RemovableHandleCollection(handles)

Represents a collection of torch RemovableHandle objects.

Like RemovableHandle itself, this class supports with statements.

Parameters:

handles (list[torch.utils.hooks.RemovableHandle]) –
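The with support described above can be sketched as a small context manager that removes every handle when the block exits, mirroring how torch's RemovableHandle calls remove() in its own __exit__. This is an illustrative re-implementation, not the library's source; FakeHandle and HandleCollectionSketch are hypothetical, and FakeHandle stands in for torch.utils.hooks.RemovableHandle so the example runs without torch.

```python
class FakeHandle:
    """Hypothetical stand-in for torch.utils.hooks.RemovableHandle."""

    def __init__(self) -> None:
        self.removed = False

    def remove(self) -> None:
        self.removed = True


class HandleCollectionSketch:
    """Groups several removable handles so they can be removed together."""

    def __init__(self, handles: list) -> None:
        self.handles = list(handles)

    def remove(self) -> None:
        # Remove every hook this collection owns.
        for handle in self.handles:
            handle.remove()

    def __enter__(self) -> "HandleCollectionSketch":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Runs on normal exit and when an exception escapes the block;
        # returning None lets any exception propagate.
        self.remove()


handles = [FakeHandle(), FakeHandle()]
with HandleCollectionSketch(handles):
    pass  # hooks owned by the handles would be active here
print([h.removed for h in handles])  # [True, True]
```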

class nice_hooks.nice_hooks.ModulePath

Represents a parsed module path. See Module Paths.

name: str

The name of the module

slice: AnySlice

How the module’s output should be indexed. None indicates the module output is not changed.

static parse(path)

Parse a string into a ModulePath. Wildcards are passed through unchanged.

Parameters:

path (str) – The path to parse

Returns:

The path, parsed into parts.

Return type:

ModulePath

__str__()

Return str(self).

Return type:

str

nice_hooks.nice_hooks.expand_module_path(model, path)

Parses a path string to a ModulePath, then expands any wildcards in the name.

Slice wildcards are left unchanged.

Parameters:
  • model (torch.nn.Module) –

  • path (ModulePathsLike) –

Return type:

list[Tuple[ModulePath, torch.nn.Module]]
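A minimal sketch of expanding wildcards in a module name. It assumes glob-style * wildcards that match exactly one dot-separated segment of the names yielded by named_modules(); that syntax is an assumption (the real rules are described under Module Paths), and fnmatch is a stand-in technique, not necessarily what the library uses. To stay torch-free, the hypothetical expand_names_sketch takes a list of (name, module) pairs instead of a torch.nn.Module.

```python
from fnmatch import fnmatchcase


def expand_names_sketch(named_modules, pattern):
    """Return the (name, module) pairs whose dotted name matches pattern.

    Assumption: '*' matches exactly one dot-separated segment.
    """
    pattern_parts = pattern.split(".")
    matches = []
    for name, module in named_modules:
        name_parts = name.split(".")
        # Require the same depth, then match segment by segment.
        if len(name_parts) == len(pattern_parts) and all(
            fnmatchcase(part, pat) for part, pat in zip(name_parts, pattern_parts)
        ):
            matches.append((name, module))
    return matches


# Strings stand in for module objects.
named = [
    ("", "root"),
    ("layers", "seq"),
    ("layers.0", "block0"),
    ("layers.0.fc", "fc0"),
    ("layers.1", "block1"),
    ("layers.1.fc", "fc1"),
]
print(expand_names_sketch(named, "layers.*.fc"))
# [('layers.0.fc', 'fc0'), ('layers.1.fc', 'fc1')]
```

Matching per segment avoids a quirk of plain fnmatch, whose * also matches dots, so "layers.*" would otherwise select every nested submodule as well.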

nice_hooks.nice_hooks._is_wild_slice(sl)

Wild slices are those that may return multiple values in _expand_slice.

Parameters:

sl (AnySlice) –

Return type:

bool

nice_hooks.nice_hooks._expand_slice(tt, sl)

Computes tt[sl]. If sl has wildcard references, expands them according to the shape of tt.

Parameters:
  • tt (torch.Tensor) –

  • sl (AnySlice) –

Return type:

list[Tuple[torch.Tensor, AnySlice]]
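The idea behind _is_wild_slice and _expand_slice can be illustrated with a hypothetical sentinel WILD meaning "every index along this dimension". The sentinel, the tuple index form, and the helper names are all assumptions made for illustration; the sketch only computes concrete indices from a shape rather than indexing a real tensor.

```python
from itertools import product

# Hypothetical wildcard marker; the library's actual wildcard syntax may differ.
WILD = object()


def is_wild_sketch(index) -> bool:
    """True if the index may expand to more than one concrete index."""
    parts = index if isinstance(index, tuple) else (index,)
    return any(part is WILD for part in parts)


def expand_index_sketch(shape, index):
    """Replace each WILD with every valid position along its dimension.

    Assumes each entry indexes one dimension in order (no Ellipsis or None).
    """
    parts = index if isinstance(index, tuple) else (index,)
    choices = [
        range(shape[dim]) if part is WILD else [part]
        for dim, part in enumerate(parts)
    ]
    return [tuple(combo) for combo in product(*choices)]


print(expand_index_sketch((2, 3), (WILD, 1)))  # [(0, 1), (1, 1)]
print(expand_index_sketch((2, 3), (0, WILD)))  # [(0, 0), (0, 1), (0, 2)]
```

Each concrete index could then be applied as tt[idx] and paired with idx, which mirrors the list of (tensor, slice) pairs that _expand_slice returns.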

nice_hooks.nice_hooks._regroup(iter, key_fn, value_fn=None)

Unordered group-by.
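_regroup is documented only as an unordered group-by. A minimal sketch of what that usually means, assuming key_fn selects each item's group, the optional value_fn transforms an item before it is stored, and the result is a dict of lists (the return type is not documented, so this is a guess). regroup_sketch is hypothetical, and its first parameter is named items to avoid shadowing the builtin iter.

```python
from collections import defaultdict
from typing import Callable, Hashable, Iterable, Optional


def regroup_sketch(
    items: Iterable,
    key_fn: Callable[[object], Hashable],
    value_fn: Optional[Callable[[object], object]] = None,
) -> dict:
    """Group items by key_fn(item) without requiring sorted input."""
    groups = defaultdict(list)
    for item in items:
        # Assumption: value_fn, when given, transforms each stored item.
        value = item if value_fn is None else value_fn(item)
        groups[key_fn(item)].append(value)
    return dict(groups)


pairs = [("a", 1), ("b", 2), ("a", 3)]
print(regroup_sketch(pairs, key_fn=lambda p: p[0], value_fn=lambda p: p[1]))
# {'a': [1, 3], 'b': [2]}
```

itertools.groupby only merges consecutive items with equal keys, so unsorted input produces repeated groups; collecting into a dict removes that requirement, which is presumably what "unordered" refers to.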

nice_hooks.nice_hooks._do_hook(expanded_paths, hook, reg)

Like register_forward_hook, but for the results of expand_module_path. Handles wildcard slices.

Parameters:
  • expanded_paths (list[Tuple[ModulePath, torch.nn.Module]]) –

  • hook (Callable) –

Return type:

RemovableHandleCollection

nice_hooks.nice_hooks.register_forward_hook(module, path, hook)

Registers forward hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, args, output) arguments.

Return type:

RemovableHandleCollection

nice_hooks.nice_hooks.register_forward_pre_hook(module, path, hook)

Registers forward pre hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, args) arguments.

Return type:

RemovableHandleCollection

nice_hooks.nice_hooks.register_full_backward_hook(module, path, hook)

Registers backward hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, grad_args, grad_output) arguments.

Return type:

RemovableHandleCollection

nice_hooks.nice_hooks.register_full_backward_pre_hook(module, path, hook)

Registers backward pre hooks on submodules of a module.

Parameters:
  • module (torch.nn.Module) – The root module to read named_modules() from.

  • path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.

  • hook (Callable) – A function accepting (module, path, grad_args) arguments.

Return type:

RemovableHandleCollection
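Plain PyTorch hooks do not receive a module path. torch.nn.Module.register_forward_hook calls hook(module, args, output), register_forward_pre_hook calls hook(module, args), register_full_backward_hook calls hook(module, grad_input, grad_output), and register_full_backward_pre_hook calls hook(module, grad_output). The path-aware signatures above insert path after module. The sketch below shows one way such an adapter could be written; with_path is a hypothetical helper, not part of nice_hooks, and the example calls the adapted hook directly instead of registering it on a torch module.

```python
from typing import Any, Callable


def with_path(hook: Callable[..., Any], path: str) -> Callable[..., Any]:
    """Adapt a path-aware hook to torch's signature by binding path.

    Works for all four hook kinds, since each only inserts path
    as the second positional argument.
    """
    def torch_style_hook(module, *rest):
        # Pass the hook's return value through so it can still replace
        # outputs, inputs, or gradients the way torch allows.
        return hook(module, path, *rest)

    return torch_style_hook


# Usage: a path-aware forward hook that records which paths fired.
seen = []


def log_forward(module, path, args, output):
    seen.append((path, output))


adapted = with_path(log_forward, "layers.0")
adapted("dummy-module", (1,), 2)  # torch would pass (module, args, output)
print(seen)  # [('layers.0', 2)]
```

Building the wrapper in a separate function, rather than writing lambda m, *rest: hook(m, path, *rest) inside a loop over paths, binds each path when the wrapper is created and sidesteps Python's late-binding closure pitfall.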

nice_hooks.nice_hooks.run(module, *args, return_activations=None, with_activations=None, forward_hooks=None, forward_pre_hooks=None, full_backward_hooks=None, full_backward_pre_hooks=None, **kwargs)

Runs the model, accepting some extra keyword parameters for various behaviours.

Parameters:
  • module (torch.nn.Module) – The module to run

  • *args – Positional arguments to pass to the model

  • **kwargs – Keyword arguments to pass to the model

  • return_activations (ModulePathsLike) – If set, records activations into an activations cache as the module is run, and returns a tuple of the model output and the activations cache.

  • with_activations (nice_hooks.activationcache.ActivationCacheLike) – If set, replaces the given activations when running the module forward.

  • forward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward hooks for just this run

  • forward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward pre hooks for just this run

  • full_backward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward hooks for just this run

  • full_backward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward pre hooks for just this run
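The forward_hooks family of parameters registers hooks for a single call only. That pattern can be sketched with contextlib.ExitStack, which guarantees every temporary hook is removed even if the model raises. ToyModel, its register method, and run_sketch are hypothetical stand-ins; in nice_hooks the registration would go through the register_* functions above.

```python
from contextlib import ExitStack
from typing import Callable


class ToyModel:
    """Hypothetical model with named hook points, standing in for nn.Module."""

    def __init__(self) -> None:
        self.hooks: dict[str, list[Callable]] = {}

    def register(self, path: str, hook: Callable) -> Callable[[], None]:
        self.hooks.setdefault(path, []).append(hook)
        # Return a remover, playing the role of RemovableHandle.remove.
        return lambda: self.hooks[path].remove(hook)

    def __call__(self, x: int) -> int:
        hidden = x + 1
        # Call hooks with the (module, path, args, output) signature.
        for hook in self.hooks.get("hidden", []):
            hook(self, "hidden", (x,), hidden)
        return hidden * 2


def run_sketch(model: ToyModel, *args, forward_hooks=None):
    with ExitStack() as stack:
        for path, hook in (forward_hooks or {}).items():
            remove = model.register(path, hook)
            # Schedule removal for when the with block exits, however it exits.
            stack.callback(remove)
        return model(*args)


seen = []
model = ToyModel()
out = run_sketch(model, 3, forward_hooks={"hidden": lambda m, p, a, o: seen.append(o)})
print(out, seen, model.hooks)  # 8 [4] {'hidden': []}
```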

nice_hooks.nice_hooks.patch_method_to_module(cls, fname)

Call on a class inheriting from nn.Module to convert a method in that module to a submodule. This is useful if you wish to add hooks on the function.

This uses monkey patching, so you only need to call it once on a class to affect all instances. It must be called before any instances of cls are created.

Parameters:
  • cls (type) – The class to patch

  • fname (str) – The name of the method on the class.