MnistModel.register_forward_hook

MnistModel.register_forward_hook(hook: Union[Callable[[T, Tuple[Any, ...], Any], Optional[Any]], Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) -> RemovableHandle

Register a forward hook on the module.

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments are passed only to forward(), not to the hooks. The hook can modify the output. It can also modify the input in place, but this has no effect on forward(), because the hook runs after forward() has already been called. The hook should have the following signature:

hook(module, args, output) -> None or modified output
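A minimal sketch of the positional form. MnistModel's definition is not shown here, so a small nn.Sequential stands in for it, and the hook names save_activation and double_output are illustrative only. The first hook records its input shape and output and returns None, leaving the output unchanged; the second returns a new tensor, which replaces the module's output.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for MnistModel; every nn.Module inherits register_forward_hook.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

captured = {}

def save_activation(module, args, output):
    # args is a tuple holding only the positional inputs to forward().
    captured["input_shape"] = tuple(args[0].shape)
    captured["output"] = output.detach()
    # Returning None (implicitly) leaves the output unchanged.

def double_output(module, args, output):
    # A non-None return value replaces the module's output.
    return output * 2

model[0].register_forward_hook(save_activation)
model[2].register_forward_hook(double_output)

x = torch.randn(5, 4)
y = model(x)  # both hooks fire during this call
```

Hooks fire only when the module is called as module(x); calling module.forward(x) directly bypasses them.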

If with_kwargs is True, the forward hook is also passed the keyword arguments given to forward(), and it may return a modified output. The hook should have the following signature:

hook(module, args, kwargs, output) -> None or modified output
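A sketch of the keyword-argument form, assuming a hypothetical Scaler module whose forward() accepts a scale keyword. With with_kwargs=True the hook receives that keyword in kwargs; it also clamps the output, showing that a returned value replaces what forward() produced.

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    """Hypothetical module whose forward() accepts a keyword argument."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale

seen = []

def kwargs_hook(module, args, kwargs, output):
    # args holds the positional inputs; kwargs holds the keyword inputs to forward().
    seen.append((len(args), dict(kwargs)))
    # The returned tensor replaces the module's output.
    return output.clamp(min=0.0)

model = Scaler()
model.register_forward_hook(kwargs_hook, with_kwargs=True)

out = model(torch.randn(3, 4), scale=2.0)
print(seen)  # [(1, {'scale': 2.0})]
```

Without with_kwargs=True, forward() would still receive scale=2.0, but a hook with the three-argument signature would never see it.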

Args:
    hook (Callable): The user-defined hook to be registered.
    prepend (bool): If True, the provided hook will be fired before all existing forward hooks on this torch.nn.modules.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.modules.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False
    with_kwargs (bool): If True, the hook will be passed the kwargs given to the forward function. Default: False
    always_call (bool): If True, the hook will be run regardless of whether an exception is raised while calling the Module. Default: False

Returns:
    torch.utils.hooks.RemovableHandle: a handle that can be used to remove the added hook by calling handle.remove().
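A minimal sketch of removing a hook through the returned handle, with nn.Linear again standing in for MnistModel. RemovableHandle also implements the context-manager protocol, so a with block removes the hook automatically when it exits.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # hypothetical stand-in for MnistModel
calls = []

handle = model.register_forward_hook(lambda m, args, out: calls.append("handle"))
model(torch.ones(2, 3))  # hook fires
handle.remove()          # unregister the hook
model(torch.ones(2, 3))  # hook no longer fires

# Used as a context manager, the handle calls remove() when the block exits.
with model.register_forward_hook(lambda m, args, out: calls.append("ctx")):
    model(torch.ones(2, 3))  # hook fires
model(torch.ones(2, 3))      # hook already removed

print(calls)  # ['handle', 'ctx']
```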