MnistModel.apply
MnistModel.apply(fn: Callable[[Module], None]) -> T
Applies fn recursively to every submodule (as returned by .children()) as well as to self. Each submodule is processed before the module that contains it, so fn is called on self last. Typical use includes initializing the parameters of a model (see also torch.nn.init).
Args:
    fn (Module -> None): function to be applied to each submodule
Returns:
    Module: self
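As a quick check of the traversal described above, the following minimal sketch records which modules apply() visits and in what order. It uses only standard torch.nn layers; the recorder function record and the nested model are hypothetical, chosen only for illustration. It also confirms that the return value is the very module apply() was called on.

```python
import torch.nn as nn

# Hypothetical recorder: stores the class name of every module apply() visits.
visited = []


def record(m: nn.Module) -> None:
    visited.append(type(m).__name__)


# A nested model: the inner Sequential is itself a child of the outer one.
net = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

result = net.apply(record)

# Children are handled before their parent, so the outer Sequential comes last.
print(visited)  # ['Linear', 'ReLU', 'Linear', 'Sequential', 'Sequential']
# apply() returns the module it was called on, which allows chaining.
print(result is net)  # True
```

Because apply() returns self, a call such as Model().apply(fn) can be assigned directly to a variable.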
Example:
>>> import torch
>>> from torch import nn
>>> @torch.no_grad()
... def init_weights(m):
...     print(m)
...     if type(m) == nn.Linear:
...         m.weight.fill_(1.0)
...         print(m.weight)
...
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> # init_weights runs on each Linear child, then on net; the REPL then echoes the returned net
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
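The description names parameter initialization as the typical use. The following minimal sketch pairs apply() with torch.nn.init. The MnistModel class below is a hypothetical stand-in: its layers are an assumption, since the actual architecture is not shown on this page. Glorot-uniform weights with zero biases is one common choice among several, not a recommendation from this documentation.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for MnistModel: a small MLP for 28x28 grayscale digits.
# The real model's layers are not documented here; this structure is an assumption.
class MnistModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


@torch.no_grad()
def init_linear(m: nn.Module) -> None:
    # Only Linear layers are touched; Flatten, ReLU and containers are skipped.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)  # Glorot-uniform weights
        if m.bias is not None:
            nn.init.zeros_(m.bias)  # start every bias at 0


torch.manual_seed(0)
model = MnistModel().apply(init_linear)  # apply() returns the model itself

# Sanity check: every Linear bias is now exactly zero.
print(all(torch.all(m.bias == 0) for m in model.modules() if isinstance(m, nn.Linear)))
logits = model(torch.randn(4, 1, 28, 28))
print(logits.shape)  # torch.Size([4, 10])
```

Unlike the type(m) == nn.Linear comparison in the example above, isinstance also matches subclasses of nn.Linear. The torch.no_grad() decorator is redundant here, because nn.init functions already avoid recording autograd history. It becomes necessary when fn writes to parameters directly, as fill_ does in the example above.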