
CA code 1.2: Generating (square) masks and the Dataset

Posted: 2021-03-28 21:39:32

# random_bbox

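Determine the position and size of the region that will be blanked out (set to 0 in the corrupted image), given as [t, l, h, w] = (top, left, height, width).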

```python
import numpy as np
import torch

# Call site: randomly pick one 128x128 square region per image in the batch
# bboxes = random_bbox(config, batch_size=ground_truth.size(0))


def random_bbox(config, batch_size):
    """Generate random (top, left, height, width) boxes from the configuration.

    Args:
        config: dict providing 'image_shape', 'mask_shape', 'margin'
            and 'mask_batch_same'.
        batch_size: number of boxes to generate.

    Returns:
        torch.Tensor of shape [batch_size, 4] (int64), one row
        (top, left, height, width) per image, e.g.
        tensor([[ 73,  20, 128, 128],
                [ 73,  20, 128, 128],
                [ 73,  20, 128, 128],
                [ 73,  20, 128, 128],
                [ 73,  20, 128, 128]])
    """
    img_height, img_width, _ = config['image_shape']    # [256, 256, 3]
    h, w = config['mask_shape']                         # [128, 128]
    margin_height, margin_width = config['margin']      # [0, 0]
    maxt = img_height - margin_height - h               # maxt = 256 - 0 - 128 = 128
    maxl = img_width - margin_width - w                 # maxl = 256 - 0 - 128 = 128
    bbox_list = []
    # Whether all images in a batch share the same mask
    if config['mask_batch_same']:  # True
        t = np.random.randint(margin_height, maxt)  # t = randint(0, 128)
        l = np.random.randint(margin_width, maxl)   # l = randint(0, 128)
        bbox_list.append((t, l, h, w))              # (t, l, 128, 128)
        bbox_list = bbox_list * batch_size          # batch_size copies of (t, l, 128, 128)
    else:
        # Draw an independent box for every image
        for _ in range(batch_size):
            t = np.random.randint(margin_height, maxt)
            l = np.random.randint(margin_width, maxl)
            bbox_list.append((t, l, h, w))

    return torch.tensor(bbox_list, dtype=torch.int64)  # convert bbox_list to a Tensor
```
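A quick numeric check of the sampling range, written as a sketch with a hypothetical `config` dict: the keys are the ones read by `random_bbox` and `mask_image`, the values are taken from their inline comments (the project itself may load them from a config file). Because `np.random.randint(low, high)` excludes `high`, `t` and `l` are drawn from [0, 127], so the 128×128 box always fits inside the 256×256 image, although the position `t = maxt = 128` (flush with the bottom edge) is valid yet never sampled.

```python
import numpy as np

# Hypothetical config: keys as read by random_bbox / mask_image,
# values copied from the inline comments in the post.
config = {
    'image_shape': [256, 256, 3],
    'mask_shape': [128, 128],
    'margin': [0, 0],
    'mask_batch_same': True,
    'max_delta_shape': [32, 32],
    'mask_type': 'hole',
}

img_height, img_width, _ = config['image_shape']
h, w = config['mask_shape']
margin_height, margin_width = config['margin']
maxt = img_height - margin_height - h  # 128
maxl = img_width - margin_width - w    # 128

# Same sampling calls as random_bbox, drawn many times at once
ts = np.random.randint(margin_height, maxt, size=100_000)
ls = np.random.randint(margin_width, maxl, size=100_000)

# randint excludes its upper bound, so t and l never exceed maxt - 1 = 127
print(ts.max() <= maxt - 1, ls.max() <= maxl - 1)  # True True
# Every box therefore ends inside the image (last row/col index <= 254)
print((ts + h - 1).max() <= img_height - 1, (ls + w - 1).max() <= img_width - 1)  # True True
```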

# bbox2mask

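Convert the position parameters [t, l, h, w] into a mask tensor of shape [batch_size, 1, height, width], in which the hole region is 1 (so that `x * (1 - mask)` zeroes it out).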
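A minimal sketch of what `bbox2mask` could look like, assuming it writes 1s into a zero tensor of shape [batch_size, 1, height, width] and shrinks each side of a box by a random amount in [0, max_delta // 2]. The shrinking rule is an assumption inferred from the `max_delta_h` / `max_delta_w` arguments, not code taken from this post.

```python
import numpy as np
import torch


def bbox2mask(bboxes, height, width, max_delta_h, max_delta_w):
    """Sketch: [batch_size, 4] (t, l, h, w) boxes -> [batch_size, 1, H, W] float mask.

    Assumption: each side of a box is shrunk by a random delta in
    [0, max_delta // 2], so hole sizes vary slightly between images.
    """
    batch_size = bboxes.size(0)
    mask = torch.zeros((batch_size, 1, height, width), dtype=torch.float32)
    for i in range(batch_size):
        t, l, h, w = (int(v) for v in bboxes[i])
        # Assumed random shrink per side, controlled by max_delta_*
        delta_h = np.random.randint(max_delta_h // 2 + 1)
        delta_w = np.random.randint(max_delta_w // 2 + 1)
        # 1 marks the hole; mask_image then computes x * (1 - mask)
        mask[i, :, t + delta_h:t + h - delta_h, l + delta_w:l + w - delta_w] = 1.
    return mask


bboxes = torch.tensor([[73, 20, 128, 128]] * 5, dtype=torch.int64)
mask = bbox2mask(bboxes, 256, 256, 32, 32)
print(mask.shape)  # torch.Size([5, 1, 256, 256])
```

With `max_delta_h = max_delta_w = 32`, each hole is between 96×96 and 128×128 pixels; passing 0 for both deltas reproduces the exact 128×128 box.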


# mask_image

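Apply the mask to the input batch `x` to produce the corrupted image (corrupt_img), and return it together with the mask.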

```python
import torch.nn.functional as F


def mask_image(x, bboxes, config):
    """Mask the batch x with the given boxes; return (corrupt_img, mask)."""
    height, width, _ = config['image_shape']              # [256, 256, 3]
    max_delta_h, max_delta_w = config['max_delta_shape']  # [32, 32]
    # bbox2mask (see the bbox2mask section) returns mask: torch.Tensor [5, 1, 256, 256]
    mask = bbox2mask(bboxes, height, width, max_delta_h, max_delta_w)
    if x.is_cuda:
        mask = mask.cuda()

    if config['mask_type'] == 'hole':
        # Zero out the pixels inside the hole (where mask == 1)
        result = x * (1. - mask)
    elif config['mask_type'] == 'mosaic':
        # TODO: Matching the mosaic patch size and the mask size
        mosaic_unit_size = config['mosaic_unit_size']
        # Nearest-neighbour down- then up-sampling produces blocky mosaic tiles
        downsampled_image = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')
        upsampled_image = F.interpolate(downsampled_image, size=(height, width), mode='nearest')
        # Mosaic inside the hole, original pixels outside
        result = upsampled_image * mask + x * (1. - mask)
    else:
        raise NotImplementedError('Not implemented mask type.')
    # result is corrupt_img
    return result, mask
```
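To make the two branches of `mask_image` concrete, here is a toy sketch on a 4×4 tensor with a hand-built mask (instead of `bbox2mask`); `mosaic_unit_size = 2` is an assumed value, not one given in the post.

```python
import torch
import torch.nn.functional as F

# Toy 1x1x4x4 "image" with pixel values 0..15
x = torch.arange(16.).reshape(1, 1, 4, 4)

# Hand-made mask: 1 marks a 2x2 hole in the bottom-right corner
mask = torch.zeros_like(x)
mask[..., 2:, 2:] = 1.

# 'hole' branch: pixels inside the hole become 0
hole = x * (1. - mask)

# 'mosaic' branch with an assumed mosaic_unit_size of 2
mosaic_unit_size = 2
down = F.interpolate(x, scale_factor=1. / mosaic_unit_size, mode='nearest')  # [[0, 2], [8, 10]]
up = F.interpolate(down, size=(4, 4), mode='nearest')  # each value repeated as a 2x2 tile
mosaic = up * mask + x * (1. - mask)

print(hole[0, 0])
print(mosaic[0, 0])
```

Outside the hole both results equal `x`; inside it, the hole branch gives 0 while the mosaic branch repeats a single sampled value (10, the top-left pixel of that 2×2 block).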
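The corrupted image `x` and the mask are then passed to the generator:

```python
x1, x2, offset_flow = self.netG(x, masks)  # coarse (stage 1) and refined (stage 2) results, plus the offset-flow heat map
```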


Original article: https://www.cnblogs.com/Overture/p/14589637.html
