文章目录
1.register_module_forward_pre_hook
在 PyTorch 中,register_module_forward_pre_hook 是一个方法,用于向模型的模块注册前向传播预钩子(forward pre-hook)。预钩子是在模块的前向传播之前被调用的函数,允许在模块接收输入之前对输入进行修改或记录。
import torch
import torch.nn as nn
# 定义一个前向传播预钩子函数
def forward_pre_hook(module, input):
print("Forward pre-hook called for module:", module)
print("Input shape:", input[0].shape)
# 创建一个模型类
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = MyModel()
# 注册前向传播预钩子
model.register_module_forward_pre_hook(forward_pre_hook)
# 输入数据
input_data = torch.randn(1, 10)
# 前向传播
output = model(input_data)
Forward pre-hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])
2.register_module_forward_hook
在 PyTorch 中,register_module_forward_hook 是一个方法,用于向模型的模块注册前向传播钩子(forward hook)。钩子是在模块的前向传播过程中被调用的函数,可以用于获取中间特征、对特征进行修改或记录等操作。
import torch
import torch.nn as nn
# 定义一个前向传播钩子函数
def forward_hook(module, input, output):
print("Forward hook called for module:", module)
print("Input shape:", input[0].shape)
print("Output shape:", output.shape)
# 创建一个模型类
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = MyModel()
# 注册前向传播钩子
model.register_forward_hook(forward_hook)
# 输入数据
input_data = torch.randn(1, 10)
# 前向传播
output = model(input_data)
Forward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 10])
3.register_module_backward_hook
在 PyTorch 中,register_module_backward_hook 是一个方法,用于向模型的模块注册反向传播钩子(backward hook)。钩子是在模块的反向传播过程中被调用的函数,可以用于获取梯度、对梯度进行修改或记录等操作。
import torch
import torch.nn as nn
# 定义一个反向传播钩子函数
def backward_hook(module, grad_input, grad_output):
print("Backward hook called for module:", module)
print("Grad input shape:", grad_input[0].shape)
print("Grad output shape:", grad_output[0].shape)
# 创建一个模型类
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
# 创建模型实例
model = MyModel()
# 注册反向传播钩子
model.register_backward_hook(backward_hook)
# 输入数据
input_data = torch.randn(1, 10)
target = torch.randn(1, 10)
# 前向传播和反向传播
output = model(input_data)
loss = nn.MSELoss()(output, target)
loss.backward()
Backward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Grad input shape: torch.Size([1, 10])
Grad output shape: torch.Size([1, 10])
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » Pytorch--Hooks For Module
发表评论 取消回复