SSD测试时RuntimeError


问题提出

RuntimeError: Legacy autograd function with non-static forward method is deprecated

原因

Pytorch版本过高,源代码是要求低版本的Pytorch,当前版本要求的forward过程是静态的,所以要修改源代码

解决方案

  1. 删除文件detection.py里的 def __init__(),然后再在def forward()方法前面加上@staticmethod
  2. 修改def forward()的前半部分,如下
def forward(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh, loc_data, conf_data, prior_data): 
    """ 
    Args: 
        loc_data: (tensor) Loc preds from loc layers 
            Shape: [batch,num_priors*4] 
        conf_data: (tensor) Shape: Conf preds from conf layers 
            Shape: [batch*num_priors,num_classes] 
        prior_data: (tensor) Prior boxes and variances from priorbox layers 
            Shape: [1,num_priors,4] 
    """ 
    self.num_classes = num_classes 
    self.background_label = bkg_label 
    self.top_k = top_k 
    # Parameters used in nms. 
    self.nms_thresh = nms_thresh 
    if nms_thresh <= 0: 
        raise ValueError('nms_threshold must be non negative.') 
    self.conf_thresh = conf_thresh 
    self.variance = cfg['variance']
  1. 因为上述detect类发生了变化,所以在使用该类的时候需要改变,所以修改ssd.py中的初始化方法def __init__()改成如下
# 修改前 
if phase == 'test': 
    self.softmax = nn.Softmax(dim=-1) 
    self.detect = Detect(num_classes, 0, 200, 0.01, 0.45) 
    
# 修改后 
if phase == 'test': 
    self.softmax = nn.Softmax() 
    self.detect = Detect()
  1. 另外ssd.py文件下的def forward()方法也要改

python

# 修改前 
if self.phase == "test": 
    output = self.detect( 
        loc.view(loc.size(0), -1, 4), # loc preds 
        self.softmax(conf.view(conf.size(0), -1, 
        self.num_classes)), # conf preds 
        self.priors.type(type(x.data)) # default boxes 
        ) 
# 修改后 
if self.phase == "test": 
    output = self.detect.apply(21, 0, 200, 0.01, 0.45, 
    loc.view(loc.size(0), -1, 4), # loc preds 
    self.softmax(conf.view(-1, 21)), # conf preds 
    self.priors.type(type(x.data)) # default boxes 
    )

引用&致谢

(https://blog.csdn.net/baidu_39389949/article/details/107398256)
(https://github.com/sayakbanerjee1999/Single-Shot-Object-Detection-Updated)


文章作者: Jelly
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Jelly !
  目录