# Source code of tianmoucv.proc.opticalflow.spy_net

import os
import torch.nn as nn
import numpy as np
import cv2
import torch.nn.functional as F
import torch
import time

from .basic import *
from tianmoucv.isp import *
from tianmoucv.proc.nn.spy_modules import *
from tianmoucv.tools import check_url_or_local_path,download_file

class TianmoucOF_SpyNet(nn.Module):
    '''
    NN-based dense optical-flow estimator wrapping a SpyNet network.

    Weights are read from ``ckpt_path``. When it is None, a default checkpoint
    URL is used. For a URL, the file is downloaded once and cached in the
    current working directory as 'of_spy_ver_best.ckpt'.

    parameter:
        :param imgsize: (w, h), list or tuple; stored as ``self.W`` / ``self.H``
        :param ckpt_path: string, local path or URL of the weight checkpoint
        :param _optim: bool; if True and torch >= 2 provides ``torch.compile``,
            the flow network is compiled after loading
    '''

    # temp network
    def __init__(self, imgsize, ckpt_path=None, _optim=True):
        super(TianmoucOF_SpyNet, self).__init__()
        if ckpt_path is None:
            ckpt_path = 'https://cloud.tsinghua.edu.cn/f/84ac6e32060443e2975d/?dl=1'
        status = check_url_or_local_path(ckpt_path)
        print('loading..:', ckpt_path)
        if status == 1:
            # Remote checkpoint: download once, then reuse the cached copy.
            # The cache lives relative to the current working directory.
            default_file_name = 'of_spy_ver_best.ckpt'
            if not os.path.exists(default_file_name):
                ckpt_path = download_file(url=ckpt_path, file_name=default_file_name)
            else:
                ckpt_path = default_file_name

        # 1 TD channel + 2 SD channels for each of the two SD frames.
        self.flowComp = SpyNet(dim=1 + 2 + 2)
        self.W, self.H = imgsize
        # Pixel-coordinate grids, each of shape (H, W).
        self.gridX, self.gridY = np.meshgrid(np.arange(self.W), np.arange(self.H))

        # NOTE(review): torch.load unpickles the file; a checkpoint fetched
        # from a URL must come from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        state_dict = checkpoint['state_dict_OF']
        # Keep only the flowComp weights and strip the leading dotted prefix
        # (e.g. 'x.flowComp.conv.weight' -> 'flowComp.conv.weight').
        dict_flowComp = {
            '.'.join(key.split('.')[1:]): state_dict[key]
            for key in state_dict
            if 'flowComp' in key
        }
        self.flowComp.load_state_dict(dict_flowComp, strict=True)
        print('load finished')

        # Inference only: freeze the network.
        self.eval()
        for param in self.flowComp.parameters():
            param.requires_grad = False

        # Parse the major version robustly (e.g. '2.1.0+cu118' -> 2) and accept
        # any release >= 2 that actually exposes torch.compile.
        torch_major = int(str(torch.__version__).split('.')[0])
        if torch_major >= 2 and _optim and hasattr(torch, 'compile'):
            print('compiling model for pytorch version>= 2.0.0')
            self.flowComp = torch.compile(self.flowComp)
            print('compiled!')

    @torch.no_grad()
    def forward_time_range(self, tsdiff: torch.Tensor, t1, t2, F0=None, print_fps=True):
        '''
        Compute the optical flow between time indices t1 and t2.

        Args:
            @tsdiff: torch tensor shaped [c,n,h,w] or [b,c,n,h,w], values in -1~1,
                i.e. the decoder outputs concatenated directly. Channel 0 is
                used as TD and channels 1: as SD (presumably c == 3,
                matching SpyNet(dim=1+2+2) -- TODO confirm).
            @t1, t2: int indices along the n axis. TD is summed over indices
                t1..t2-1; SD is taken at t1 and at t2, so t2 must be < n.
            @F0: unused; kept for interface compatibility.
            @print_fps: if True (default), print the throughput of the network call.
        Returns:
            The output of ``self.flowComp(TD_t_0, SD0, SD1)``. The original author's
            note says values are in 0~1 (unverified here).
        Raises:
            ValueError: if tsdiff is neither 4-D nor 5-D.
        '''
        # Run on whatever device the network's weights live on.
        first_param = next(iter(self.flowComp.parameters()), None)
        self.device = first_param.device if first_param is not None else tsdiff.device

        # The slicing below expects [b,c,n,h,w], so add a batch axis to [c,n,h,w].
        if tsdiff.dim() == 4:
            tsdiff = tsdiff.unsqueeze(0)
        elif tsdiff.dim() != 5:
            raise ValueError(
                'tsdiff must be [c,n,h,w] or [b,c,n,h,w], got shape %s' % (tuple(tsdiff.shape),)
            )
        tsdiff = tsdiff.to(self.device)

        TD_0_t = torch.sum(tsdiff[:, 0:1, t1:t2, ...], dim=2)
        SD0 = tsdiff[:, 1:, t1, ...]
        SD1 = tsdiff[:, 1:, t2, ...]
        TD_t_0 = -1 * TD_0_t  # accumulated TD with reversed sign

        # Part1. warp
        # NOTE: on CUDA the call is asynchronous, so this timing is approximate.
        stime = time.perf_counter()
        Flow_1_0 = self.flowComp(TD_t_0, SD0, SD1)
        frameTime = time.perf_counter() - stime
        if print_fps:
            print(1 / frameTime if frameTime > 0 else float('inf'), 'fps')
        return Flow_1_0