Registering hooks

nice_hooks provides the functions nice_hooks.register_forward_hook(), nice_hooks.register_forward_pre_hook(), nice_hooks.register_full_backward_hook() and nice_hooks.register_full_backward_pre_hook(), which correspond to their counterparts on torch.nn.Module.

The main differences are:
  • They are functions, not methods of Module

  • They accept an additional argument path, which lets you select which modules of the model to attach to.

  • The hook functions are called with an additional argument path, which identifies the module the hook is currently being invoked on.

Because of the path argument, you can register hooks on your root model and specify the particular module you are interested in as a string. The path argument is a module path, so it can specify multiple modules with wildcards, and index specific values in the activation.

def my_hook(module: nn.Module, path: nice_hooks.ModulePath, args, output):
    print(f"forward pass through module at {path}")

nice_hooks.register_forward_hook(my_model, "", my_hook)
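To give an intuition for wildcard paths, they behave much like glob patterns over the dotted module names a model exposes. A minimal stdlib-only sketch of that matching idea (hypothetical helper; the library's actual matcher may differ, e.g. in how indices are handled):

```python
from fnmatch import fnmatch

# Hypothetical module names, as they might appear in model.named_modules().
module_names = ["encoder", "encoder.linear1", "encoder.linear2", "decoder"]

def matching_modules(names, path):
    """Return the module names a wildcard path would select."""
    return [n for n in names if fnmatch(n, path)]

print(matching_modules(module_names, "encoder.*"))
# selects encoder.linear1 and encoder.linear2, but not encoder itself
```

A path with no wildcard, such as "decoder", selects exactly one module, while "" on its own names the root model.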

Running the model

nice_hooks comes with a function run that runs the model.

result =, *args, **kwargs)
# equivalent to:
# result = model(*args, **kwargs)

run accepts several keyword arguments that control how the model is run, described in the sections below.

Running the model with hooks

Hooks can be set on the model for the duration of a single run:

result =, *args, forward_hooks={'mod1': hook1})
# Equivalent to
# with nice_hooks.register_forward_hook(model, 'mod1', hook1):
#     result = model(*args)

See Registering hooks for details.

Recording activations

You can get the activations associated with an evaluation of the model with:

result, cache =, *args, return_activations=True)

The returned cache is a dictionary with a key for each module name, mapping to the tensor that module output during the run.

Because storing all activations uses memory, you can instead specify which modules you are interested in, using the module path syntax.

result, cache =, *args_to_model, return_activations=["mod1.*", "mod2"])
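The filtering step can be pictured as hooks that only record outputs whose module name matches one of the requested patterns. A stdlib-only sketch under that assumption (hypothetical names; nice_hooks' own cache type may offer more than a plain dict):

```python
from fnmatch import fnmatch

def make_cache(patterns):
    """Build a cache dict plus a recorder that keeps only matching modules."""
    cache = {}

    def record(name, output):
        if any(fnmatch(name, p) for p in patterns):
            cache[name] = output

    return cache, record

cache, record = make_cache(["mod1.*", "mod2"])
record("mod1.linear", [0.1, 0.2])   # matches "mod1.*" -> stored
record("mod2", [0.3])               # exact match     -> stored
record("mod3", [0.4])               # no match        -> skipped
print(sorted(cache))
```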

Activation patching

You can replace specific activations with a known value:

result =, *args, with_activations={'mod1': t.ones(5)})

This replaces the output of the module named mod1 with the given tensor. Replacing an entire layer's output is rarely useful, so you will likely want to use a path with an index:

result =, *args, with_activations={'mod1[:,3:5]': t.ones(2)})
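The index part of the path can be read as slicing into the module's output: the patch is written into that slice rather than replacing the whole tensor. A plain-Python sketch of the effect of a [:,3:5] patch, with a list of lists standing in for a 2x6 activation tensor:

```python
# Stand-in for a (2, 6) activation tensor, initially all zeros.
output = [[0.0] * 6 for _ in range(2)]
patch = [1.0, 1.0]  # one value per patched column, reused for every row

# Apply the "[:, 3:5]" index: every row, columns 3 and 4.
for row in output:
    row[3:5] = patch

print(output[0])
# columns 3 and 4 are now 1.0; the rest of the output is untouched
```

With real tensors, t.ones(2) would broadcast across the batch dimension in the same way.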