tianmoucv.sim.simple_tmc_sim_advance 源代码

# Author:Taoyi
import torch
import torch.nn.functional as F
import cv2
import os, sys
import numpy as np
import random
from tqdm import tqdm
import re
import json
import torchvision
from tianmoucv.isp import SD2XY, upsample_cross_conv
### Load parameters

file_path = os.path.dirname(os.path.abspath(__file__))
# JSON is UTF-8 by specification (RFC 8259); pass the encoding explicitly so
# loading does not depend on the platform locale (e.g. cp936 on Windows).
with open(os.path.join(file_path, 'simpleSim_params.json'), 'r', encoding='utf-8') as f:
    sim_parama = json.load(f)

adc_bit_prec = sim_parama['adc_bit_prec']
dark_fpn_stat = sim_parama['dark_fpn_stat']
if adc_bit_prec == 8:
    # Divide every statistic by 128 (= 2 ** (adc_bit_prec - 1)), i.e. from
    # 8-bit ADC code units to the normalized scale in which diff_quant's LSB
    # is 1/128 (presumably the JSON stores codes -- confirm). The values must
    # be numbers: get_fpn_from_stat passes them as scalar mean/std to
    # torch.normal. The dict is updated in place, so
    # sim_parama['dark_fpn_stat'] sees the scaled values too.
    for key, value in dark_fpn_stat.items():
        dark_fpn_stat[key] = value / 128
else:
    raise ValueError("Only support 8bit ADC now!")

# Narrowest integer dtype that holds a signed adc_bit_prec-bit code
# (float32 beyond 24 bits).
digital_type = torch.int8 if adc_bit_prec <= 8 else torch.int16 if adc_bit_prec <= 16 else torch.int32 if adc_bit_prec <= 24 else torch.float32

def sort_filenames(filename):
    """Return the first run of decimal digits in ``filename`` as an int.

    Intended as a ``sorted`` key so frame files order numerically
    (``f2.png`` before ``f10.png``). Names without any digit map to -1.
    """
    first_number = re.search(r'[0-9]+', filename)
    return -1 if first_number is None else int(first_number.group())


def shift_rotate_tensor(tensor, shift_dx,shift_dy, rotate_degrees):
    """Translate and rotate an image tensor with torchvision's affine transform.

    Args:
        tensor: image tensor accepted by ``torchvision.transforms.functional.affine``
            (callers in this file pass [1, C, H, W]).
        shift_dx: horizontal translation in pixels.
        shift_dy: vertical translation in pixels.
        rotate_degrees: rotation angle in degrees.

    Returns:
        The transformed tensor (no scaling, no shear).
    """
    affine = torchvision.transforms.functional.affine
    return affine(
        tensor,
        angle=rotate_degrees,
        translate=(shift_dx, shift_dy),
        scale=1,
        shear=0,
    )

def push_to_fifo(tensor, x):
    """Drop the oldest FIFO entry and append new data at the tail.

    Args:
        tensor (torch.Tensor): current FIFO; index 0 along dim 0 is the oldest entry.
        x (torch.Tensor): new entries to append along dim 0 (e.g. [1, H, W]).

    Returns:
        torch.Tensor: ``tensor[1:]`` followed by ``x``; the depth is unchanged
        when ``x`` holds a single entry.
    """
    kept = tensor[1:]
    return torch.cat([kept, x], dim=0)
def visualize_diff(diff_out, vis_gain, color='rg'):
    """Render a single-channel TD/SD difference map as an 8-bit image.

    Args:
        diff_out (np.ndarray): 2-D difference map of shape (H, W). It is cast
            to int32 (float input is truncated toward zero) before scaling.
        vis_gain (int | float): gain multiplied into the map before clipping.
        color (str, optional): visualization mode. Defaults to 'rg'.
            - 'rg' / 'rgwb': 3-channel image. Positive values, clipped to 255,
              go to channel index 2 (red in OpenCV's BGR order). Negative
              magnitudes go to channel index 1 (green).
            - 'gray': single-channel image where 127 means zero. The value is
              clipped to [-127, 127] and offset by 127.

    Returns:
        np.ndarray: uint8 image, (H, W, 3) for 'rg'/'rgwb', (H, W) for 'gray'.

    Raises:
        ValueError: if ``color`` is not one of the supported modes.
    """
    if color not in ('rg', 'rgwb', 'gray'):
        raise ValueError(
            f"Unsupported color mode {color!r}; expected 'rg', 'rgwb' or 'gray'")
    height, width = diff_out.shape
    scaled = diff_out.astype(np.int32) * vis_gain
    if color == 'gray':
        return (scaled.clip(min=-127, max=127) + 127).astype(np.uint8)

    diff = scaled.clip(min=-255, max=255)
    diff_vis = np.zeros((height, width, 3), dtype=np.uint8)
    diff_vis[..., 2] = (diff * (diff > 0)).astype(np.uint8)
    diff_vis[..., 1] = (-diff * (diff < 0)).astype(np.uint8)
    return diff_vis
def diff_quant(diff, th, bit_prec=None, out_dtype=None):
    """Dead-zone quantize an analog difference tensor into signed ADC codes.

    Values with ``|diff| <= th`` map to 0. Outside the dead zone the
    threshold is subtracted, and the remainder is divided by the LSB
    ``1 / 2 ** (bit_prec - 1)``. The result is rounded away from zero, so any
    crossing yields at least one code. It is then saturated to the signed
    range ``[-2 ** (bit_prec - 1), 2 ** (bit_prec - 1) - 1]``.

    Args:
        diff (torch.Tensor): float difference signal.
        th (float): dead-zone threshold, in the same units as ``diff``.
        bit_prec (int, optional): ADC bit precision. Defaults to the
            module-level ``adc_bit_prec``.
        out_dtype (torch.dtype, optional): output dtype. Integer dtypes
            int8/int16/int32 are cast to; anything else leaves the float
            result unchanged. Defaults to the module-level ``digital_type``.

    Returns:
        torch.Tensor: quantized codes with the same shape as ``diff``.
    """
    if bit_prec is None:
        bit_prec = adc_bit_prec
    if out_dtype is None:
        out_dtype = digital_type
    max_digi_num = 2 ** (bit_prec - 1)
    lin_lsb = 1.0 / max_digi_num
    # Dead zone: subtract the threshold so codes start just past +/- th.
    diff_th = (diff - th) * (diff > th) + (diff + th) * (diff < -th)
    codes = diff_th / lin_lsb
    codes = codes.ceil() * (diff > 0) + codes.floor() * (diff < 0)
    # Saturate AFTER rounding, to the two's-complement range. Clipping to
    # +max_digi_num (or rounding after clipping) would produce e.g. 128,
    # which wraps to -128 when cast to int8.
    codes = codes.clip(min=-max_digi_num, max=max_digi_num - 1)
    if out_dtype in (torch.int8, torch.int16, torch.int32):
        codes = codes.to(out_dtype)
    return codes
def _rod_spatial_diff(sd_cal):
    """Return the unquantized (SDL, SDR) spatial-difference maps of one [H, W] ROD frame."""
    H, W = sd_cal.shape
    sdl = torch.zeros_like(sd_cal)
    sdr = torch.zeros_like(sd_cal)
    # Correct SD formulas: minuends are ALWAYS odd rows.
    # Due to the ROD layout (even rows shifted left by 1 pixel):
    # - SDL (vertical in the ROD grid) is physically Center - Left
    # - SDR (diagonal in the ROD grid) is physically Center - Right
    # When H or W is 1 the slices below are empty no-ops and the maps stay zero.
    # SDL, even rows
    sdl[0:H-1:2, 0:W] = sd_cal[1:H:2, 0:W] - sd_cal[0:H-1:2, 0:W]
    # SDL, odd rows
    sdl[1:H-1:2, 0:W] = sd_cal[1:H-1:2, 0:W] - sd_cal[2:H:2, 0:W]
    # SDR, even rows
    sdr[0:H-1:2, 0:W-1] = sd_cal[1:H:2, 0:W-1] - sd_cal[0:H-1:2, 1:W]
    # SDR, odd rows
    sdr[1:H-1:2, 0:W-1] = sd_cal[1:H-1:2, 0:W-1] - sd_cal[2:H:2, 1:W]
    return sdl, sdr


def diff_response(pix_v_out, threshold, sim_cnt, device, fpn=None, xy=False):
    '''Compute quantized temporal (TD) and spatial (SD) responses from ROD frames.

    :param pix_v_out: float tensor of shape [T, H, W] with T >= 2 (the
        simulators in this file pass T == 2). TD is frame 1 minus frame 0;
        SD is computed from the last frame.
    :param threshold: dict, {'td': float, 'sd': float}, the dead-zone
        thresholds passed to diff_quant.
    :param sim_cnt: int, number of frames simulated so far. When 0 there is
        no real previous frame, so TD is synthesized as frame 1 translated
        right/down by a random 1-4 pixel offset (cv2.warpAffine) minus frame 1.
        Its parity selects the *_even / *_odd FPN maps.
    :param device: torch.device on which the synthesized first-frame TD is computed.
    :param fpn: optional dict of fixed pattern noise maps with keys 'td_even',
        'td_odd', 'sdl_even', 'sdl_odd', 'sdr_even', 'sdr_odd' (as built by
        get_fpn_from_stat). The maps are added before quantization.
    :param xy: bool, if True convert (SDL, SDR) to (SDX, SDY) with SD2XY
        before quantizing.
    :return: (td_quantized, sdl_quant, sdr_quant), or
        (td_quantized, sdx_quant, sdy_quant) when xy=True.
    :raises ValueError: if pix_v_out is not 3-D or holds fewer than 2 frames.
    '''
    if pix_v_out.dim() != 3 or pix_v_out.shape[0] < 2:
        raise ValueError(
            f"pix_v_out must have shape [T>=2, H, W], got {tuple(pix_v_out.shape)}")
    even_frame = sim_cnt % 2 == 0

    ### TD calculation
    td_th = threshold['td']
    if sim_cnt == 0:
        x_s, y_s = int(random.uniform(1, 5)), int(random.uniform(1, 5))
        current = pix_v_out[1, :, :].to(device)
        response_org = current.cpu().numpy()
        mat_shift = np.float32([[1, 0, x_s], [0, 1, y_s]])
        response_shift = cv2.warpAffine(response_org, mat_shift,
                                        (response_org.shape[1], response_org.shape[0]))
        # Subtract as tensors on `device` so this also works when device is CUDA.
        temp_diff = torch.from_numpy(response_shift).to(device) - current
    else:
        temp_diff = pix_v_out[1, :, :] - pix_v_out[0, :, :]
    # ADVANCED FPN for TD
    if fpn is not None:
        temp_diff += fpn['td_even'] if even_frame else fpn['td_odd']
    td_quantized = diff_quant(temp_diff, td_th)

    ### SD calculation
    sd_th = threshold['sd']
    sdl, sdr = _rod_spatial_diff(pix_v_out[-1, :, :])
    ### ADVANCED NOISE SIMULATION: FPN for SD
    if fpn is not None:
        if even_frame:
            sdl += fpn['sdl_even']
            sdr += fpn['sdr_even']
        else:
            sdl += fpn['sdl_odd']
            sdr += fpn['sdr_odd']

    if xy:
        sd = torch.stack([sdl, sdr], dim=0).float()
        sdx, sdy = SD2XY(sd)
        sdx_quant = diff_quant(sdx, sd_th)
        sdy_quant = diff_quant(sdy, sd_th)
        return td_quantized, sdx_quant, sdy_quant
    # get quantized SD left and right
    sdl_quant = diff_quant(sdl, sd_th)
    sdr_quant = diff_quant(sdr, sd_th)
    return td_quantized, sdl_quant, sdr_quant
def get_fpn_from_stat(rod_height, rod_width, device):
    """
    Generate Fixed Pattern Noise (FPN) maps from the module-level dark_fpn_stat.

    Each of the six maps (TD/SDL/SDR for odd and even frames) is drawn once
    from a normal distribution whose scalar mean/std come from dark_fpn_stat.

    Args:
        rod_height (int): map height.
        rod_width (int): map width.
        device (torch.device): device for the generated tensors.

    Returns:
        dict: keys 'td_odd', 'sdl_odd', 'sdr_odd', 'td_even', 'sdl_even',
        'sdr_even', each a [rod_height, rod_width] tensor.
    """
    # (output key, mean key, std key), in the order the maps are sampled.
    noise_specs = (
        ('td_odd', 'td_odd_mean', 'td_odd_std'),
        ('sdl_odd', 'sdl_odd_mean', 'sdl_odd_std'),
        ('sdr_odd', 'sdr_odd_mean', 'sdr_odd_std'),
        ('td_even', 'td_even_mean', 'td_even_std'),
        ('sdl_even', 'sdl_even_mean', 'sdl_even_std'),
        ('sdr_even', 'sdr_even_mean', 'sdr_even_std'),
    )
    noise_shape = (rod_height, rod_width)
    fpn = {}
    for name, mean_key, std_key in noise_specs:
        fpn[name] = torch.normal(mean=dark_fpn_stat[mean_key],
                                 std=dark_fpn_stat[std_key],
                                 size=noise_shape,
                                 device=device)
    return fpn
def sample_sensor_fixed_noise(sensor_height, sensor_width, device, sensor_fixed_noise_prob,
                              sensor_fixed_noise_mean_ch0, sensor_fixed_noise_std_ch0,
                              sensor_fixed_noise_mean_ch12, sensor_fixed_noise_std_ch12):
    """Randomly draw a per-pixel fixed noise map for the three ROD channels.

    With probability ``sensor_fixed_noise_prob`` this returns a float32 tensor
    of shape [3, sensor_height // 2, sensor_width // 4] on ``device``;
    otherwise it returns ``None``.

    Channel 0 is drawn from N(mean_ch0, std_ch0); apply_sensor_noise adds it
    to TD. Channels 1 and 2 (added to SDL / SDR) share one map drawn from
    N(mean_ch12, std_ch12). A channel whose std is not positive stays all-zero.
    apply_sensor_noise multiplies the map by 128 before adding it to the
    quantized codes.
    """
    rod_height = sensor_height // 2
    rod_width = sensor_width // 4
    fixed_noise_map = None
    if sensor_fixed_noise_prob > 0 and random.random() < sensor_fixed_noise_prob:
        fixed_noise_map = np.zeros((3, rod_height, rod_width), dtype=np.float32)
        if sensor_fixed_noise_std_ch0 > 0:
            fixed_noise_map[0] = np.random.normal(loc=sensor_fixed_noise_mean_ch0,
                                                  scale=sensor_fixed_noise_std_ch0,
                                                  size=(rod_height, rod_width)).astype(np.float32)
        if sensor_fixed_noise_std_ch12 > 0:
            shared_map = np.random.normal(loc=sensor_fixed_noise_mean_ch12,
                                          scale=sensor_fixed_noise_std_ch12,
                                          size=(rod_height, rod_width)).astype(np.float32)
            fixed_noise_map[1] = shared_map
            fixed_noise_map[2] = shared_map
        fixed_noise_map = torch.from_numpy(fixed_noise_map).to(device)
    return fixed_noise_map


def apply_gray_jitter(img, gray_weight_jitter, gray_gain_min, gray_gain_max):
    """Convert a BGR uint8 image to grayscale, optionally with random jitter.

    When ``gray_weight_jitter > 0``:
    - The BT.601 luma weights (0.299, 0.587, 0.114 for R, G, B) are each
      perturbed by U(-jitter, jitter).
    - They are clipped to be positive and renormalized to sum to 1.
    - The result is scaled by a gain drawn from U(gray_gain_min, gray_gain_max).

    Otherwise plain cv2.COLOR_BGR2GRAY is used and no gain is applied.

    Returns:
        np.ndarray: uint8 grayscale image with the same height/width as ``img``.
    """
    gray_weights = np.array([0.299, 0.587, 0.114], dtype=np.float32)
    gray_gain = 1.0
    if gray_weight_jitter > 0:
        delta = np.random.uniform(-gray_weight_jitter, gray_weight_jitter, size=3).astype(np.float32)
        gray_weights = np.clip(gray_weights + delta, 1e-6, None)
        gray_weights = gray_weights / np.sum(gray_weights)
        gray_gain = float(np.random.uniform(gray_gain_min, gray_gain_max))
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img_gray = np.tensordot(img_rgb, gray_weights, axes=([-1], [0]))
        img_gray = (img_gray * gray_gain * 255.0).clip(0, 255).astype(np.uint8)
    else:
        img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    return img_gray


def get_rod_img(img_gray_tensor, rod_height, rod_width, device):
    """Resample a full-resolution gray image onto the row-staggered ROD grid.

    A [C, H, W] input is first averaged over channels.

    The [H, W] image is then 2x2 box-binned to [H/2, W/2]:
    - Even ROD rows take the even columns of the even binned rows.
    - Odd ROD rows take the odd columns of the odd binned rows.

    Expects H even and W divisible by 4, with rod_height == H // 2 and
    rod_width == W // 4 (as the callers in this file compute them).
    Otherwise the slice assignments do not line up.
    """
    if img_gray_tensor.dim() == 3:  # [C, H, W]
        img_gray_tensor = torch.mean(img_gray_tensor, dim=0)  # [H, W]
    etron_img_rod = torch.zeros(size=(rod_height, rod_width), device=device)
    etron_img_bin = (img_gray_tensor[0::2, 0::2] + img_gray_tensor[1::2, 0::2] +
                     img_gray_tensor[0::2, 1::2] + img_gray_tensor[1::2, 1::2]) / 4
    etron_img_rod[0::2, :] = etron_img_bin[0::2, 0::2]
    etron_img_rod[1::2, :] = etron_img_bin[1::2, 1::2]
    return etron_img_rod


def apply_sensor_noise(td, sdl, sdr, fixed_noise_map, sensor_random_noise_prob, sensor_poisson_lambda,
                       sensor_random_noise_std, rod_height, rod_width, device):
    """Add fixed-pattern and random sensor noise to TD/SDL/SDR code maps.

    The inputs are cast to float. If ``fixed_noise_map`` is given, its three
    channels are multiplied by 128 and added to TD, SDL and SDR.

    With probability ``sensor_random_noise_prob``, zero-mean Skellam noise is
    added independently to each channel. This is the difference of two
    Poisson(mu) draws, with mu = max(sensor_poisson_lambda, 1e-6). It is
    rescaled so its standard deviation is sensor_random_noise_std * 128 codes.
    Nothing is added when that std is not positive.

    Returns:
        tuple: float tensors (td, sdl, sdr).
    """
    td, sdl, sdr = td.float(), sdl.float(), sdr.float()
    if fixed_noise_map is not None:
        td = td + (fixed_noise_map[0] * 128.0)
        sdl = sdl + (fixed_noise_map[1] * 128.0)
        sdr = sdr + (fixed_noise_map[2] * 128.0)
    if sensor_random_noise_prob > 0 and random.random() < sensor_random_noise_prob:
        mu = max(float(sensor_poisson_lambda), 1e-6)
        std_target = max(float(sensor_random_noise_std), 0.0)
        if std_target > 0:
            # Var(P1 - P2) = 2 * mu, so this scale yields std == std_target * 128.
            scale = (std_target / np.sqrt(2.0 * mu)) * 128.0
            p1 = torch.poisson(torch.full((3, rod_height, rod_width), mu, device=device))
            p2 = torch.poisson(torch.full((3, rod_height, rod_width), mu, device=device))
            skellam_noise = (p1 - p2) * scale
            td = td + skellam_noise[0]
            sdl = sdl + skellam_noise[1]
            sdr = sdr + skellam_noise[2]
    return td, sdl, sdr


def _upsample_tsdiff(tsdiff, sensor_height, sensor_width):
    """Upsample a [C, rod_h, rod_w] TD/SD stack to [C, sensor_height, sensor_width]."""
    # upsample_cross_conv expects [C, T, H, W]; after squeeze the result is [C, H, W*2]
    tsdiff_expand = upsample_cross_conv(tsdiff.unsqueeze(1)).squeeze(1)
    # bilinear upsample to full size
    return F.interpolate(tsdiff_expand.unsqueeze(0), size=(sensor_height, sensor_width),
                         mode='bilinear', align_corners=False).squeeze(0)


def _compose_tsd_viz(rgb_np, td_np, sdl_np, sdr_np):
    """Stack the RGB frame above side-by-side TD and merged-SD visualizations."""
    td_viz = visualize_diff(td_np, 8, color='rg')
    sdl_viz = visualize_diff(sdl_np, 8, color='rg')
    sdr_viz = visualize_diff(sdr_np, 8, color='rg')
    sd_merge_viz = cv2.addWeighted(sdl_viz, 0.5, sdr_viz, 0.5, 0)
    # The two panels must together be exactly as wide as the RGB frame for the
    # vertical concatenation. This must hold whether the maps are ROD-sized
    # or already upsampled to full resolution (interp=True).
    panel_h = td_viz.shape[0]
    rgb_w = rgb_np.shape[1]
    left_w = rgb_w // 2
    tsd_viz = np.concatenate((cv2.resize(td_viz, (left_w, panel_h)),
                              cv2.resize(sd_merge_viz, (rgb_w - left_w, panel_h))), axis=1)
    return np.concatenate((rgb_np, tsd_viz), axis=0)


def run_sim(datapath, sensor_width, sensor_height, device, display=False, save=False,
            save_path=os.path.join(os.path.expanduser('~'), "temp"),
            interp=False,
            # Sensor noise parameters used for noise augmentation
            sensor_fixed_noise_prob=sim_parama.get('sensor_fixed_noise_prob', 0.0),
            sensor_random_noise_prob=sim_parama.get('sensor_random_noise_prob', 0.0),
            sensor_fixed_noise_mean_ch0=sim_parama.get('sensor_fixed_noise_mean_ch0', 0.2),
            sensor_fixed_noise_std_ch0=sim_parama.get('sensor_fixed_noise_std_ch0', 0.5 / 128.0),
            sensor_fixed_noise_mean_ch12=sim_parama.get('sensor_fixed_noise_mean_ch12', 0.0),
            sensor_fixed_noise_std_ch12=sim_parama.get('sensor_fixed_noise_std_ch12', 0.4 / 128.0),
            sensor_random_noise_std=sim_parama.get('sensor_random_noise_std', 1.0 / 128.0),
            sensor_poisson_lambda=sim_parama.get('sensor_poisson_lambda', 4),
            # Grayscale input perturbation parameters
            gray_weight_jitter=sim_parama.get('gray_weight_jitter', 0.0),
            gray_gain_min=sim_parama.get('gray_gain_min', 0.78),
            gray_gain_max=sim_parama.get('gray_gain_max', 0.88),
            # Simulation parameters
            sim_threshold_range=tuple(sim_parama.get('sim_threshold_range', [0.005, 0.02]))):
    """Simulate TianMouC TD/SD output for every .png/.jpg frame in ``datapath``.

    Frames are processed in the order given by sort_filenames. Each frame goes
    through these steps:
    1. Resize to (sensor_width, sensor_height).
    2. Gray-convert, optionally jittered.
    3. Corrupt with Poisson shot noise, then normalize by 255.
    4. Resample onto the ROD grid and add N(0, 0.008) read noise.
    5. Diff against the previous ROD frame via diff_response.
    6. Apply apply_sensor_noise.

    One threshold (shared by TD and SD) is drawn from ``sim_threshold_range``
    per call. An RGB frame is captured every 10 frames.

    Side effects: prints the input/output paths and shows an OpenCV window
    when display=True. When save=True, writes under ``save_path``:
    - fpn.npy
    - rgb/<frame> (every 10th frame)
    - viz/sim_viz_<n>.png
    - tsdiff/tsdiff_<n>.npy ([H, W, 3] TD/SDL/SDR)
    - noisy_gt/noisyGT_<n>.npy

    Unreadable image files are skipped.
    """
    print('读取:',datapath,'下的所有图像数据,按图像名称排序')
    print('存储到:',save_path,'')
    # 1. Sample the simulation threshold
    sim_th = random.uniform(sim_threshold_range[0], sim_threshold_range[1])
    # 2. Sample the sensor fixed-noise map
    fixed_noise_map = sample_sensor_fixed_noise(sensor_height, sensor_width, device,
                                                sensor_fixed_noise_prob,
                                                sensor_fixed_noise_mean_ch0, sensor_fixed_noise_std_ch0,
                                                sensor_fixed_noise_mean_ch12, sensor_fixed_noise_std_ch12)

    # Assumes a directory of 8-bit png/jpg frames; for other datasets write your own loader.
    flist = sorted(os.listdir(datapath), key=sort_filenames)
    cop_cone_skip = 10  # the interval of RGB generation
    sim_cnt = 0
    rod_height = sensor_height // 2
    rod_width = sensor_width // 4
    rod_v_buf = torch.zeros(size=(2, rod_height, rod_width), dtype=torch.float32, device=device)
    ### ADVANCED NOISE SIM: fixed pattern noise for odd and even frames
    fpn = get_fpn_from_stat(rod_height, rod_width, device)

    if save:
        rgb_save_path = os.path.join(save_path, 'rgb')
        tsdiff_save_path = os.path.join(save_path, 'tsdiff')
        gt_with_noise_save_path = os.path.join(save_path, 'noisy_gt')
        viz_path = os.path.join(save_path, 'viz')
        for folder in (save_path, rgb_save_path, tsdiff_save_path, gt_with_noise_save_path, viz_path):
            os.makedirs(folder, exist_ok=True)
        # Save the fixed pattern noise, stacked on the last axis
        all_fpn = np.stack([fpn['td_odd'].cpu().numpy(), fpn['sdl_odd'].cpu().numpy(),
                            fpn['sdr_odd'].cpu().numpy(), fpn['td_even'].cpu().numpy(),
                            fpn['sdl_even'].cpu().numpy(), fpn['sdr_even'].cpu().numpy()], axis=-1)
        np.save(os.path.join(save_path, "fpn.npy"), all_fpn)

    for img_file in tqdm(flist):
        # You can add more endings, like .bmp, .tiff, etc.
        if not (img_file.endswith(".png") or img_file.endswith(".jpg")):
            continue
        img = cv2.imread(os.path.join(datapath, img_file))
        if img is None:
            print(f"Skipping unreadable image: {img_file}")
            continue
        img = cv2.resize(img, (sensor_width, sensor_height))
        # Grayscale perturbation
        img_gray = apply_gray_jitter(img, gray_weight_jitter, gray_gain_min, gray_gain_max)
        img_gray_tensor = torch.FloatTensor(img_gray).to(device)
        if sim_cnt % cop_cone_skip == 0:
            rgb = torch.ShortTensor(img).to(device)
        ### ADVANCED NOISE SIM: Poisson (shot) noise first
        img_gray_tensor = torch.poisson(img_gray_tensor)
        # For 10-bit or higher-precision images, divide by that format's max value instead.
        img_diff_sim = img_gray_tensor / 255.0
        etron_img_rod = get_rod_img(img_diff_sim, rod_height, rod_width, device)
        ### ADVANCED NOISE SIM: additive Gaussian read noise
        etron_img_rod = etron_img_rod + torch.normal(mean=0, std=0.008, size=(rod_height, rod_width), device=device)
        rod_v_buf = push_to_fifo(rod_v_buf, etron_img_rod.unsqueeze(0))
        td, sdl, sdr = diff_response(rod_v_buf, {'td': sim_th, 'sd': sim_th}, sim_cnt, device, fpn=fpn)
        # Sensor noise augmentation (fixed + random noise)
        td, sdl, sdr = apply_sensor_noise(td, sdl, sdr, fixed_noise_map, sensor_random_noise_prob,
                                          sensor_poisson_lambda, sensor_random_noise_std,
                                          rod_height, rod_width, device)
        if interp:
            tsdiff_full = _upsample_tsdiff(torch.stack([td, sdl, sdr], dim=0), sensor_height, sensor_width)
            td, sdl, sdr = tsdiff_full[0], tsdiff_full[1], tsdiff_full[2]

        if display or save:
            rgb_np = rgb.cpu().numpy().astype(np.uint8)
            td_np = td.cpu().numpy()
            sdl_np = sdl.cpu().numpy()
            sdr_np = sdr.cpu().numpy()
            tsdiff_raw = np.stack([td_np, sdl_np, sdr_np], axis=-1)
            rgb_tsd_viz = _compose_tsd_viz(rgb_np, td_np, sdl_np, sdr_np)
            if display:
                cv2.imshow("all", rgb_tsd_viz)
                cv2.waitKey(1)
            if save:
                if sim_cnt % cop_cone_skip == 0:
                    cv2.imwrite(os.path.join(rgb_save_path, img_file), rgb_np)
                cv2.imwrite(os.path.join(viz_path, f"sim_viz_{sim_cnt}.png"), rgb_tsd_viz)
                np.save(os.path.join(tsdiff_save_path, f"tsdiff_{sim_cnt}.npy"), tsdiff_raw)
                np.save(os.path.join(gt_with_noise_save_path, f"noisyGT_{sim_cnt}.npy"),
                        etron_img_rod.cpu().numpy())
        # sim counter
        sim_cnt += 1


def _gray_bgr_tensor(img, sensor_width, sensor_height, device,
                     gray_weight_jitter, gray_gain_min, gray_gain_max):
    """Resize a BGR uint8 image, gray-convert it and return a [3, H, W] float tensor in [0, 1]."""
    img = cv2.resize(img, (sensor_width, sensor_height))
    img_gray = apply_gray_jitter(img, gray_weight_jitter, gray_gain_min, gray_gain_max)
    img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
    img_tensor = torch.FloatTensor(img_bgr).to(device) / 255.0
    return img_tensor.permute(2, 0, 1)


def _random_shift_rotate(img_tensor):
    """Shift a [C, H, W] tensor by a random integer offset (|d| <= 31 px per axis) and rotate it within +-4 degrees."""
    random_x = int(random.uniform(-32, 32))
    random_y = int(random.uniform(-32, 32))
    random_degrees = random.uniform(-4, 4)
    return shift_rotate_tensor(img_tensor.unsqueeze(0), random_x, random_y, random_degrees).squeeze(0)


def run_sim_singleimg(img_target=None, img_ref=None, sensor_width=640, sensor_height=320, xy=False, interp=False,
                      # Assumes 8-bit BGR input images
                      device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                      # Sensor noise parameters used for noise augmentation
                      sensor_fixed_noise_prob=sim_parama.get('sensor_fixed_noise_prob', 0.0),
                      sensor_random_noise_prob=sim_parama.get('sensor_random_noise_prob', 0.0),
                      sensor_fixed_noise_mean_ch0=sim_parama.get('sensor_fixed_noise_mean_ch0', 0.2),
                      sensor_fixed_noise_std_ch0=sim_parama.get('sensor_fixed_noise_std_ch0', 0.5 / 128.0),
                      sensor_fixed_noise_mean_ch12=sim_parama.get('sensor_fixed_noise_mean_ch12', 0.0),
                      sensor_fixed_noise_std_ch12=sim_parama.get('sensor_fixed_noise_std_ch12', 0.4 / 128.0),
                      sensor_random_noise_std=sim_parama.get('sensor_random_noise_std', 1.0 / 128.0),
                      sensor_poisson_lambda=sim_parama.get('sensor_poisson_lambda', 4),
                      # Grayscale input perturbation parameters
                      gray_weight_jitter=sim_parama.get('gray_weight_jitter', 0.0),
                      gray_gain_min=sim_parama.get('gray_gain_min', 0.78),
                      gray_gain_max=sim_parama.get('gray_gain_max', 0.88),
                      # Simulation parameters
                      sim_threshold_range=tuple(sim_parama.get('sim_threshold_range', [0.0, 0.0]))):
    """Simulate an SD-only frame followed by a TD+SD frame from a pair of images.

    Input: H x W x C uint8 BGR numpy arrays; at least one of img_target /
    img_ref must be given. A missing image is synthesized from the other by
    _random_shift_rotate.

    The threshold is drawn from ``sim_threshold_range`` when that range is
    non-empty. Otherwise sim_parama['threshold'] (default 0.0) is used.

    Returns:
        tuple: (img_ref, img_target, td, sd0, sd1), all on CPU, where:
        - img_ref / img_target: [3, H, W] float gray tensors in [0, 1].
        - td: the ref -> target temporal difference.
        - sd0: SD of the ref frame; sd1: SD of the target frame. Each is
          stacked on dim 0 as (SDL, SDR), or (SDX, SDY) when xy=True.
        With interp=True the TD/SD maps are upsampled to
        (sensor_height, sensor_width).

    Raises:
        ValueError: if both img_target and img_ref are None.
    """
    # 1. Sample the simulation threshold
    if sim_threshold_range[1] > sim_threshold_range[0]:
        sim_th = random.uniform(sim_threshold_range[0], sim_threshold_range[1])
    else:
        sim_th = sim_parama.get('threshold', 0.0)
    # 2. Sample the sensor fixed-noise map
    fixed_noise_map = sample_sensor_fixed_noise(sensor_height, sensor_width, device,
                                                sensor_fixed_noise_prob,
                                                sensor_fixed_noise_mean_ch0, sensor_fixed_noise_std_ch0,
                                                sensor_fixed_noise_mean_ch12, sensor_fixed_noise_std_ch12)
    rod_height = sensor_height // 2
    rod_width = sensor_width // 4
    fpn = get_fpn_from_stat(rod_height, rod_width, device)

    img_target_tensor = None
    img_ref_tensor = None
    if img_target is not None:
        img_target_tensor = _gray_bgr_tensor(img_target, sensor_width, sensor_height, device,
                                             gray_weight_jitter, gray_gain_min, gray_gain_max)
    if img_ref is not None:
        img_ref_tensor = _gray_bgr_tensor(img_ref, sensor_width, sensor_height, device,
                                          gray_weight_jitter, gray_gain_min, gray_gain_max)

    # Generate the missing frame if needed
    if img_target_tensor is None and img_ref_tensor is not None:
        # Generate target (current) from ref (previous)
        img_target_tensor = _random_shift_rotate(img_ref_tensor)
    elif img_ref_tensor is None and img_target_tensor is not None:
        # Generate ref (previous) from target (current)
        img_ref_tensor = _random_shift_rotate(img_target_tensor)
    elif img_target_tensor is None and img_ref_tensor is None:
        raise ValueError("Either img_target or img_ref must be provided to run_sim_singleimg")

    # ROD sampling
    intensity_ref = get_rod_img(img_ref_tensor, rod_height, rod_width, device).unsqueeze(0)
    intensity_target = get_rod_img(img_target_tensor, rod_height, rod_width, device).unsqueeze(0)

    # First frame: SD only (ref-ref buffer; the TD output is discarded)
    rod_v_buf = torch.cat([intensity_ref, intensity_ref], dim=0)
    _, sdl0, sdr0 = diff_response(rod_v_buf, {'td': sim_th, 'sd': sim_th}, 0, device, fpn=fpn, xy=xy)
    # Second frame: TD + SD (ref -> target)
    rod_v_buf2 = torch.cat([intensity_ref, intensity_target], dim=0)
    td, sdl1, sdr1 = diff_response(rod_v_buf2, {'td': sim_th, 'sd': sim_th}, 1, device, fpn=fpn, xy=xy)

    # Sensor noise augmentation (fixed + random noise)
    _, sdl0, sdr0 = apply_sensor_noise(torch.zeros_like(td), sdl0, sdr0, fixed_noise_map,
                                       sensor_random_noise_prob, sensor_poisson_lambda,
                                       sensor_random_noise_std, rod_height, rod_width, device)
    td, sdl1, sdr1 = apply_sensor_noise(td, sdl1, sdr1, fixed_noise_map,
                                        sensor_random_noise_prob, sensor_poisson_lambda,
                                        sensor_random_noise_std, rod_height, rod_width, device)

    if interp:
        if not xy:
            # First frame SD (no TD)
            tsdiff0_full = _upsample_tsdiff(torch.stack([torch.zeros_like(td), sdl0, sdr0], dim=0),
                                            sensor_height, sensor_width)
            sdl0, sdr0 = tsdiff0_full[1], tsdiff0_full[2]
            # Second frame TD + SD
            tsdiff1_full = _upsample_tsdiff(torch.stack([td, sdl1, sdr1], dim=0),
                                            sensor_height, sensor_width)
            td, sdl1, sdr1 = tsdiff1_full[0], tsdiff1_full[1], tsdiff1_full[2]
        else:
            # xy=True: upsample the SD maps directly with bilinear interpolation
            def up_direct(t):
                if t.dim() == 2:
                    t = t.unsqueeze(0).unsqueeze(0)
                elif t.dim() == 3:
                    t = t.unsqueeze(0)
                return F.interpolate(t, size=(sensor_height, sensor_width),
                                     mode='bilinear', align_corners=False).squeeze()

            # td still uses upsample_cross_conv, following the tianmoucData reference
            td_expand = upsample_cross_conv(td.view(1, 1, rod_height, rod_width)).squeeze()
            td = F.interpolate(td_expand.view(1, 1, rod_height, rod_width * 2),
                               size=(sensor_height, sensor_width),
                               mode='bilinear', align_corners=False).squeeze()
            sdl0 = up_direct(sdl0)
            sdr0 = up_direct(sdr0)
            sdl1 = up_direct(sdl1)
            sdr1 = up_direct(sdr1)

    sd0 = torch.stack([sdl0, sdr0], dim=0)
    sd1 = torch.stack([sdl1, sdr1], dim=0)
    return img_ref_tensor.cpu(), img_target_tensor.cpu(), td.cpu(), sd0.cpu(), sd1.cpu()