nice_hooks
The top-level package.
Package Contents
Classes
- ActivationCache: A wrapper around a dictionary of cached activations from a model run.
- RemovableHandleCollection: Represents a collection of torch RemovableHandle objects.
- ModulePath: Represents a parsed module path. See Module Paths.
Functions
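- expand_module_path(model, path): Parses a path string to a ModulePath, then expands any wildcards in the name.
- register_forward_hook(module, path, hook): Registers forward hooks on submodules of a module.
- register_forward_pre_hook(module, path, hook): Registers forward pre hooks on submodules of a module.
- register_full_backward_hook(module, path, hook): Registers backward hooks on submodules of a module.
- register_full_backward_pre_hook(module, path, hook): Registers backward pre hooks on submodules of a module.
- run(module, *args, **kwargs): Runs the model, accepting some extra keyword parameters for various behaviours.
- patch_method_to_module(cls, fname): Call on a class inheriting from nn.Module to convert a method in that module to a submodule.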
Attributes
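- ActivationCacheLike: Any object that an ActivationCache can be constructed from.
- AnySlice: Anything that can be used as an index to a torch Tensor, or None.
- ModulePathsLike: Anything that can be parsed into a list of ModulePath objects.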
- class nice_hooks.ActivationCache
Bases: dict[str, torch.Tensor]
A wrapper around a dictionary of cached activations from a model run. Supports operations similar to those of torch.Tensor, which are usually applied elementwise, i.e. to each tensor in the cache. See ActivationCache.
- to(device)
Moves each tensor of the cache to a device.
- Parameters:
device (Union[str, torch.device]) – The device to move each tensor to.
- Return type:
ActivationCache
- index(key)
Applies t[key] to every tensor t in the activation cache. Because the tensors may have different shapes, this operation is mainly useful for indexing the batch dimension.
- Parameters:
key (Union[None, _int, slice, torch.Tensor, List, Tuple]) – The index to apply to each tensor.
- Return type:
ActivationCache
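To make the "applied elementwise" behaviour of index(key) concrete, here is a minimal sketch that uses plain nested lists in place of tensors. The class name ToyActivationCache is hypothetical and is not part of nice_hooks; it only illustrates the pattern of a dict subclass that applies value[key] to every entry.

```python
from typing import Any


class ToyActivationCache(dict):
    """Hypothetical stand-in for ActivationCache: a dict that applies an
    operation to every stored value. Nested lists stand in for tensors."""

    def index(self, key: Any) -> "ToyActivationCache":
        # Apply value[key] to each entry, mirroring "t[key] for every tensor".
        return ToyActivationCache({name: value[key] for name, value in self.items()})


cache = ToyActivationCache({
    "mlp": [[1, 2], [3, 4], [5, 6]],  # batch of 3, 2 features each
    "attn": [[7], [8], [9]],          # batch of 3, 1 feature each
})
first_two = cache.index(slice(0, 2))
print(first_two)  # {'mlp': [[1, 2], [3, 4]], 'attn': [[7], [8]]}
```

Slicing the leading (batch) dimension works for every entry because all of them share it, while an index into the feature dimension would be valid for "mlp" but not for "attn".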
- nice_hooks.ActivationCacheLike
Any object that an ActivationCache can be constructed from.
- nice_hooks.AnySlice
Anything that can be used as an index to a torch Tensor, or None.
- nice_hooks.ModulePathsLike
Anything that can be parsed into a list of ModulePath objects.
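The attribute descriptions only say what these aliases accept. As a rough sketch of how a "one path or many paths" type like ModulePathsLike is commonly normalized, consider the helper below. The names PathsLike and normalize_paths are hypothetical, and the real alias may accept more than plain strings (for example ModulePath objects).

```python
from typing import Iterable, List, Union

# Hypothetical alias: a single path string or an iterable of path strings.
PathsLike = Union[str, Iterable[str]]


def normalize_paths(paths: PathsLike) -> List[str]:
    """Turn a single path or a collection of paths into a list of strings."""
    if isinstance(paths, str):
        # A bare string is one path, not an iterable of characters.
        return [paths]
    return list(paths)


print(normalize_paths("layers.0.mlp"))            # ['layers.0.mlp']
print(normalize_paths(("layers.0", "layers.1")))  # ['layers.0', 'layers.1']
```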
- class nice_hooks.RemovableHandleCollection(handles)
Represents a collection of torch RemovableHandle objects.
Like RemovableHandle itself, this class can be used in a with statement, which removes the hooks when the block exits.
- Parameters:
handles (list[torch.utils.hooks.RemovableHandle]) – The handles to manage together.
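A minimal sketch of the with-statement behaviour, assuming the collection simply calls remove() on every handle it holds. FakeHandle and HandleCollection are hypothetical stand-ins so that the example runs without torch; they are not the nice_hooks classes.

```python
from typing import List


class FakeHandle:
    """Hypothetical stand-in for torch.utils.hooks.RemovableHandle."""

    def __init__(self) -> None:
        self.removed = False

    def remove(self) -> None:
        self.removed = True


class HandleCollection:
    """Sketch: remove every wrapped handle on remove() or on leaving a with block."""

    def __init__(self, handles: List[FakeHandle]) -> None:
        self.handles = handles

    def remove(self) -> None:
        for handle in self.handles:
            handle.remove()

    def __enter__(self) -> "HandleCollection":
        return self

    def __exit__(self, *exc_info) -> None:
        # Always clean up, even if the body raised; returning None re-raises.
        self.remove()


handles = [FakeHandle(), FakeHandle()]
with HandleCollection(handles):
    print([h.removed for h in handles])  # [False, False]
print([h.removed for h in handles])      # [True, True]
```

Grouping the handles means a caller that attached hooks to many submodules at once can detach all of them with a single call or a single with block.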
- class nice_hooks.ModulePath
Represents a parsed module path. See Module Paths
- name: str
The name of the module.
- slice: AnySlice
How the module’s output should be indexed. None indicates the module output is not changed.
- static parse(path)
Parse a string into a ModulePath. Wildcards are passed through unchanged.
- Parameters:
path (str) – The path to parse
- Returns:
The path, parsed into parts.
- Return type:
ModulePath
- __str__()
Return str(self).
- Return type:
str
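The following sketch splits a path string into a name and an optional index. It assumes a "name[index, ...]" syntax, such as "layers.0.mlp[:, 5]", which this page does not spell out (see Module Paths for the real grammar). parse_path is a hypothetical helper that returns a plain tuple rather than a ModulePath. As in ModulePath.parse, wildcards such as "*" in the name pass through unchanged.

```python
import re
from typing import Optional, Tuple, Union

IndexPart = Union[int, slice]


def _parse_part(text: str) -> IndexPart:
    """Parse one comma-separated index part such as '3', ':' or '1:4'."""
    text = text.strip()
    if ":" in text:
        # '' becomes None, so ':' -> slice(None, None) and '1:4' -> slice(1, 4).
        bounds = [int(b) if b.strip() else None for b in text.split(":")]
        return slice(*bounds)
    return int(text)


def parse_path(path: str) -> Tuple[str, Optional[Tuple[IndexPart, ...]]]:
    """Split 'name[idx, ...]' into (name, index tuple); no brackets gives None."""
    match = re.fullmatch(r"([^\[\]]+)(?:\[(.*)\])?", path.strip())
    if match is None:
        raise ValueError(f"Cannot parse module path: {path!r}")
    name, index_text = match.group(1), match.group(2)
    if index_text is None:
        return name, None  # None: the module output is used unindexed
    return name, tuple(_parse_part(p) for p in index_text.split(","))


print(parse_path("layers.0.mlp"))        # ('layers.0.mlp', None)
print(parse_path("layers.0.mlp[:, 5]"))  # ('layers.0.mlp', (slice(None, None, None), 5))
```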
- nice_hooks.expand_module_path(model, path)
Parses a path string to a ModulePath, then expands any wildcards in the name.
Wildcards in the slice are left unchanged.
- Parameters:
model (torch.nn.Module) – The root module whose submodules are matched against the path.
path (ModulePathsLike) – The path or paths to expand.
- Return type:
list[ModulePath, torch.nn.Module]
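As an illustration of wildcard expansion over submodule names, the sketch below assumes that "*" matches exactly one dot-separated component, and that the candidate names are those reported by named_modules(). Neither assumption is confirmed by this page. expand_wildcards and the module names are hypothetical.

```python
import re
from typing import Iterable, List


def expand_wildcards(pattern: str, names: Iterable[str]) -> List[str]:
    """Return the names matching pattern, where '*' matches one dotted component."""
    # Escape everything, then let each escaped '*' match a run of non-dot characters.
    regex = re.escape(pattern).replace(r"\*", r"[^.]+")
    return [name for name in names if re.fullmatch(regex, name)]


# Names as they might come from model.named_modules() for a hypothetical model.
module_names = ["", "embed", "layers", "layers.0", "layers.0.mlp",
                "layers.1", "layers.1.mlp", "unembed"]
print(expand_wildcards("layers.*.mlp", module_names))  # ['layers.0.mlp', 'layers.1.mlp']
```

fnmatch is not used here because its "*" would also match across dots, so "layers.*" would pick up "layers.0.mlp" as well as "layers.0".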
- nice_hooks.register_forward_hook(module, path, hook)
Registers forward hooks on submodules of a module.
- Parameters:
module (torch.nn.Module) – The root module to read named_modules() from.
path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.
hook (Callable) – A function accepting (module, path, args, output) arguments.
- Return type:
RemovableHandleCollection
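The hook passed here takes an extra path argument compared with PyTorch's own forward hooks, which are called as hook(module, args, output). One plausible way to bridge the two signatures is a closure that binds the path. The sketch below shows that idea. bind_path is hypothetical, and the call at the end only simulates what torch would do.

```python
from typing import Any, Callable, Tuple

PathHook = Callable[[Any, str, Tuple[Any, ...], Any], Any]


def bind_path(hook: PathHook, path: str) -> Callable[[Any, Tuple[Any, ...], Any], Any]:
    """Adapt a (module, path, args, output) hook to the (module, args, output)
    signature used by torch's nn.Module.register_forward_hook."""

    def torch_style_hook(module: Any, args: Tuple[Any, ...], output: Any) -> Any:
        # Pass the return value through: in torch, returning a value from a
        # forward hook replaces the module output; returning None keeps it.
        return hook(module, path, args, output)

    return torch_style_hook


seen = []


def record(module, path, args, output):
    # A path-aware hook: log which submodule fired and what it produced.
    seen.append((path, output))


adapted = bind_path(record, "layers.0.mlp")
adapted(object(), (1.0,), 2.0)  # simulate torch calling the hook
print(seen)  # [('layers.0.mlp', 2.0)]
```

Binding the path through a function parameter, rather than a lambda inside a loop, avoids Python's late binding, where every hook would otherwise see the last path in the loop.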
- nice_hooks.register_forward_pre_hook(module, path, hook)
Registers forward pre hooks on submodules of a module.
- Parameters:
module (torch.nn.Module) – The root module to read named_modules() from.
path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.
hook (Callable) – A function accepting (module, path, args) arguments.
- Return type:
RemovableHandleCollection
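In PyTorch, a forward pre hook runs before the module's forward pass. It can return None to leave the inputs alone, or return new inputs, and a non-tuple return value is wrapped into a 1-tuple. The toy function below emulates those semantics with the extra path argument documented above. call_with_pre_hooks is hypothetical and does not use torch.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

PreHook = Callable[[Any, str, Tuple[Any, ...]], Optional[Any]]


def call_with_pre_hooks(module: Any, path: str, forward: Callable[..., Any],
                        hooks: Sequence[PreHook], *args: Any) -> Any:
    """Toy emulation of torch's forward pre hook semantics, plus a path argument."""
    for hook in hooks:
        result = hook(module, path, args)
        if result is not None:
            # Like torch, a non-tuple return value is wrapped into a 1-tuple.
            args = result if isinstance(result, tuple) else (result,)
    return forward(*args)


def double_input(module, path, args):
    # Replace the first positional argument with twice its value.
    return (args[0] * 2,) + args[1:]


def log_input(module, path, args):
    print(f"{path} called with {args}")  # returning None leaves args unchanged


out = call_with_pre_hooks(None, "layers.0", lambda x: x + 1, [log_input, double_input], 3)
print(out)  # 7
```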
- nice_hooks.register_full_backward_hook(module, path, hook)
Registers backward hooks on submodules of a module.
- Parameters:
module (torch.nn.Module) – The root module to read named_modules() from.
path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.
hook (Callable) – A function accepting (module, path, grad_args, grad_output) arguments.
- Return type:
RemovableHandleCollection
- nice_hooks.register_full_backward_pre_hook(module, path, hook)
Registers backward pre hooks on submodules of a module.
- Parameters:
module (torch.nn.Module) – The root module to read named_modules() from.
path (ModulePathsLike) – A string or strings indicating which modules to attach the hook to.
hook (Callable) – A function accepting (module, path, grad_args) arguments.
- Return type:
RemovableHandleCollection
- nice_hooks.run(module, *args, return_activations=None, with_activations=None, forward_hooks=None, forward_pre_hooks=None, full_backward_hooks=None, full_backward_pre_hooks=None, **kwargs)
Runs the model, accepting some extra keyword parameters for various behaviours.
- Parameters:
module (torch.nn.Module) – The module to run
*args – Positional arguments passed to the model.
**kwargs – Keyword arguments passed to the model.
return_activations (ModulePathsLike) – If true, records activations into an activation cache as the module runs. run then returns a tuple of the model output and the activation cache.
with_activations (nice_hooks.activationcache.ActivationCacheLike) – If set, replaces the given activations with the supplied values during the forward pass.
forward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward hooks for this run only.
forward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers forward pre hooks for this run only.
full_backward_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward hooks for this run only.
full_backward_pre_hooks (dict[ModulePathsLike, Callable]) – If set, temporarily registers backward pre hooks for this run only.
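The "temporarily registers ... for this run only" behaviour follows a register, run, always-remove pattern. The sketch below shows that pattern on a toy model made of named stages. It records each stage's output in the way return_activations is described. ToyModel and run_with_activations are hypothetical, and nothing here calls nice_hooks or torch.

```python
from typing import Any, Callable, Dict, List, Tuple

Hook = Callable[[str, Any], None]


class ToyModel:
    """Hypothetical model: named stages applied in order, with forward hooks."""

    def __init__(self, stages: Dict[str, Callable[[Any], Any]]) -> None:
        self.stages = stages
        self.hooks: List[Hook] = []

    def __call__(self, x: Any) -> Any:
        for name, stage in self.stages.items():
            x = stage(x)
            for hook in self.hooks:
                hook(name, x)
        return x


def run_with_activations(model: ToyModel, x: Any) -> Tuple[Any, Dict[str, Any]]:
    """Register a recording hook for this call only, then always remove it."""
    cache: Dict[str, Any] = {}

    def record(name: str, output: Any) -> None:
        cache[name] = output

    model.hooks.append(record)
    try:
        output = model(x)
    finally:
        # Cleanup runs even if the forward pass raises.
        model.hooks.remove(record)
    return output, cache


model = ToyModel({"double": lambda v: v * 2, "inc": lambda v: v + 1})
output, cache = run_with_activations(model, 5)
print(output, cache)     # 11 {'double': 10, 'inc': 11}
print(len(model.hooks))  # 0
```

The try/finally is the important design choice: a hook that outlives its run would silently affect every later call to the model.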
- nice_hooks.patch_method_to_module(cls, fname)
Call on a class inheriting from nn.Module to convert one of its methods into a submodule. This is useful if you wish to add hooks to that method.
This uses monkey patching, so you only need to call it once on a class to affect all instances. It must be called before creating any instances of cls.
- Parameters:
cls (type) – The class to patch.
fname (str) – The name of the method on the class.
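To see why patching must happen before instances exist, consider the plain-Python sketch below. It patches the class's __init__ so that each new instance gets its own callable wrapper, which stands in for a submodule and holds hooks, and routes the method through that wrapper. MethodWrapper and patch_method are hypothetical. This is not the nice_hooks implementation, which creates a real submodule. It only illustrates the monkey-patching mechanism.

```python
from typing import Any, Callable, List


class MethodWrapper:
    """Hypothetical stand-in for a submodule: a callable that holds hooks and
    forwards to the original (bound) method."""

    def __init__(self, bound_method: Callable[..., Any]) -> None:
        self.bound_method = bound_method
        self.hooks: List[Callable[[Any], None]] = []

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        result = self.bound_method(*args, **kwargs)
        for hook in self.hooks:
            hook(result)
        return result


def patch_method(cls: type, fname: str) -> None:
    """Sketch: make every *new* instance of cls route cls.fname through a wrapper."""
    original_init = cls.__init__
    original_method = getattr(cls, fname)
    wrapper_attr = f"_{fname}_wrapper"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        original_init(self, *args, **kwargs)
        # The wrapper is created at construction time, which is why instances
        # created before patching never get one.
        setattr(self, wrapper_attr, MethodWrapper(original_method.__get__(self, cls)))

    def patched(self, *args: Any, **kwargs: Any) -> Any:
        return getattr(self, wrapper_attr)(*args, **kwargs)

    # Patching the class once affects every instance created afterwards.
    cls.__init__ = __init__
    setattr(cls, fname, patched)


class Block:
    def __init__(self, scale: int) -> None:
        self.scale = scale

    def act(self, x: int) -> int:
        return x * self.scale


patch_method(Block, "act")
block = Block(3)
captured = []
block._act_wrapper.hooks.append(captured.append)
print(block.act(2), captured)  # 6 [6]
```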