MnistModel.apply
MnistModel.apply(fn: Callable[[Module], None]) → T
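Apply fn recursively to every submodule (as returned by .children()) as well as to self. Submodules are visited before the module that contains them, so self is the last module passed to fn. Typical use includes initializing the parameters of a model (see also torch.nn.init).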
Args:
    fn (Module -> None): function to be applied to each submodule
Returns:
    Module: self
Example:
    >>> import torch
    >>> from torch import nn
    >>> @torch.no_grad()
    >>> def init_weights(m):
    >>>     print(m)
    >>>     if isinstance(m, nn.Linear):
    >>>         m.weight.fill_(1.0)
    >>>         print(m.weight)
    >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
    >>> _ = net.apply(init_weights)
    Linear(in_features=2, out_features=2, bias=True)
    Parameter containing:
    tensor([[1., 1.],
            [1., 1.]], requires_grad=True)
    Linear(in_features=2, out_features=2, bias=True)
    Parameter containing:
    tensor([[1., 1.],
            [1., 1.]], requires_grad=True)
    Sequential(
      (0): Linear(in_features=2, out_features=2, bias=True)
      (1): Linear(in_features=2, out_features=2, bias=True)
    )
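The traversal order of apply is easiest to see without tensors. The following is a minimal pure-Python sketch, not PyTorch's code: the Module class here is a hypothetical toy stand-in for torch.nn.Module that only tracks a name and its immediate children. Its apply method follows the documented behavior. It recurses into each child from children(), then calls fn on self, then returns self. Under that assumption, nested submodules are visited before their parents and the root comes last, matching the order of the printed output in the example above.

```python
from typing import Callable, Iterator, List


class Module:
    """Hypothetical toy stand-in for torch.nn.Module (illustration only)."""

    def __init__(self, name: str, *children: "Module") -> None:
        self.name = name
        self._children: List["Module"] = list(children)

    def children(self) -> Iterator["Module"]:
        # Immediate submodules only, analogous to nn.Module.children().
        return iter(self._children)

    def apply(self, fn: Callable[["Module"], None]) -> "Module":
        # Recurse into every child first...
        for child in self.children():
            child.apply(fn)
        # ...then visit self, so a parent is always visited after its subtree.
        fn(self)
        # Returning self allows chaining, e.g. net = build().apply(fn).
        return self


visited: List[str] = []
net = Module("net", Module("block", Module("fc1"), Module("fc2")), Module("head"))
result = net.apply(lambda m: visited.append(m.name))
print(visited)  # ['fc1', 'fc2', 'block', 'head', 'net']
```

Because apply returns self, an initialization function can be applied inline where a model is constructed, and the root module (here "net") is always the final call to fn.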