Analysis of PyTorch's LinearNet Subclassing Code


LinearRegression code outline

  • Prepare dataset
  • Design model using Class
    • inherit from nn.Module
  • Construct loss and optimizer
    • using the PyTorch API
  • Training cycle
    • forward, backward, update

Code analysis

Prepare dataset

```python
import torch
import numpy as np

num_inputs = 2        # number of input features
num_examples = 1000   # number of samples
true_w = [2, -3.4]    # the two true weights
true_b = 4.2          # the true bias
# Draw a 1000 x 2 feature tensor from a normal distribution with mean 0 and scale 1
features = torch.from_numpy(np.random.normal(0, 1, (num_examples, num_inputs)))
# Ground-truth labels y, computed with vectorized operations
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# Add a small amount of Gaussian noise to the ground-truth labels
labels += torch.from_numpy(np.random.normal(0, 0.01, size=labels.size()))
```
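
The training loop at the end of the post iterates over `data_iter`, which the code shown here never constructs. A minimal sketch of how it could be built with `torch.utils.data` (the batch size of 10 is an assumption, not stated in the post):

```python
import torch.utils.data as Data

batch_size = 10   # assumed batch size; the post does not specify one
# Wrap the features and labels into a dataset and draw shuffled mini-batches from it
dataset = Data.TensorDataset(features, labels)
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
```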

Design model using Class

```python
import torch.nn as nn
from torch.nn import init

class LinearNet(nn.Module):
    def __init__(self, n_feature):
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(n_feature, 1)

    # forward defines the forward pass
    def forward(self, x):
        y = self.linear(x)
        return y

net = nn.Sequential()
net.add_module('linear', nn.Linear(num_inputs, 1))

# Initialize the model parameters
init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)
```
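
Side note: the post defines the `LinearNet` subclass but then actually builds `net` with `nn.Sequential`; both describe the same one-layer linear model. A quick sanity check (the dummy batch `X` below is our own, not from the post):

```python
model = LinearNet(num_inputs)    # the subclass version of the same model
X = torch.randn(4, num_inputs)   # a dummy batch of 4 samples
print(model(X).shape)            # torch.Size([4, 1])
print(net(X).shape)              # torch.Size([4, 1])
```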

In the code above, the line y = self.linear(x) is puzzling: self.linear is an nn.Linear object, yet it is being called like a function.

Solution

Consult the official Python documentation:
3.3.6. Emulating callable objects
object.__call__(self[, args...]) 
Called when the instance is “called” as a function; 
if this method is defined, x(arg1, arg2, ...) roughly translates to type(x).__call__(x, arg1, ...).

In other words: if a class defines a __call__ method, its instances become callable like functions.

Inspect the source of nn.Module (the base class of Linear, and where the call machinery lives):
```python
def _call_impl(self, *input, **kwargs):
    # Do not call functions when jit is used
    full_backward_hooks, non_full_backward_hooks = [], []
    if len(self._backward_hooks) > 0 or len(_global_backward_hooks) > 0:
        full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()

    for hook in itertools.chain(
            _global_forward_pre_hooks.values(),
            self._forward_pre_hooks.values()):
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result

    bw_hook = None
    if len(full_backward_hooks) > 0:
        bw_hook = hooks.BackwardHook(self, full_backward_hooks)
        input = bw_hook.setup_input_hook(input)

    # forward is executed here
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in itertools.chain(
            _global_forward_hooks.values(),
            self._forward_hooks.values()):
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result

    if bw_hook:
        result = bw_hook.setup_output_hook(result)

    # Handle the non-full backward hooks
    if len(non_full_backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in non_full_backward_hooks:
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
            self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

    return result

__call__ : Callable[..., Any] = _call_impl
```
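
Most of _call_impl is bookkeeping for hooks wrapped around the single self.forward(*input, **kwargs) call. A minimal sketch (a toy of our own, not from the post) showing a forward hook firing when the module object is called:

```python
import torch
import torch.nn as nn

layer = nn.Linear(2, 1)
# The hook runs inside _call_impl, right after forward returns
handle = layer.register_forward_hook(
    lambda module, inp, out: print("forward hook fired, output shape:", out.shape))
_ = layer(torch.randn(3, 2))   # prints: forward hook fired, output shape: torch.Size([3, 1])
handle.remove()                # detach the hook when done
```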
Let's write a small example:

```python
class Foobar:
    def __init__(self):
        pass

    # *args collects the unnamed positional arguments into a tuple
    # **kwargs collects the keyword arguments into a dict
    def __call__(self, *args, **kwargs):
        print("hello", str(args[0]))

if __name__ == '__main__':
    f = Foobar()
    f("maozhongkuan")
```

Output

hello maozhongkuan
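
As the documentation quoted above puts it, f(arg1, arg2, ...) roughly translates to type(f).__call__(f, arg1, ...), which we can verify directly:

```python
f = Foobar()
type(f).__call__(f, "maozhongkuan")   # prints: hello maozhongkuan, same as f("maozhongkuan")
```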

Construct loss and optimizer

```python
import torch.optim as optim

loss = nn.MSELoss()                               # loss function
optimizer = optim.SGD(net.parameters(), lr=0.03)  # optimizer
```
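
With its default reduction, nn.MSELoss averages the squared errors, i.e. it computes mean((y_hat - y)^2). A small manual check on toy tensors of our own:

```python
y_hat = torch.tensor([1.0, 2.0])
y = torch.tensor([1.5, 1.5])
print(loss(y_hat, y))             # tensor(0.2500)
print(((y_hat - y) ** 2).mean())  # tensor(0.2500), the same value computed by hand
```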

Training cycle

```python
num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for x, y in data_iter:
        output = net(x.float())                          # compute y_hat
        l = loss(output.float(), y.view(-1, 1).float())  # compute the loss
        optimizer.zero_grad()   # zero the gradients; equivalent to net.zero_grad()
        l.backward()            # compute gradients of l w.r.t. w and b
        optimizer.step()        # gradient-descent step toward a smaller loss
    print('epoch %d, loss: %f' % (epoch, l.item()))

dense = net[0]
print(true_w, dense.weight)   # compare the learned weights with true_w
print(true_b, dense.bias)     # compare the learned bias with true_b
```

References

https://www.bilibili.com/video/BV1Y7411d7Ys?p=5 (around the 35-minute mark)

