Contents
Abstract
Mask R-CNN
Network Architecture
Backbone
RPN
Proposal Layer
ROIAlign
bbox Detection
Mask Segmentation
Loss Computation
Experiment Reproduction
Summary
Abstract
Mask R-CNN is an improved object detection and instance segmentation network based on Faster R-CNN. Faster R-CNN is primarily used for object detection, outputting the bounding boxes and class labels of objects, while Mask R-CNN adds the capability of pixel-level segmentation on top of Faster R-CNN, enabling the output of pixel-level masks for objects. Mask R-CNN employs the ROI Align layer, which addresses the issue of boundary pixel alignment in Faster R-CNN, thereby enhancing the precision of detection and segmentation. ROI Align uses bilinear interpolation to avoid quantization, extracting the spatially rich features corresponding to each RoI from the feature map more accurately, preserving spatial location information, and resolving the inaccurate localization of the RoI Pooling method used in Faster R-CNN. Mask R-CNN adds a parallel mask prediction branch to the Faster R-CNN architecture, using an FCN to predict an object mask on each RoI, which lets the network learn the spatial structure of objects in more detail. Mask R-CNN achieved the best segmentation and detection accuracy of its time on the challenging MS COCO benchmark.

Mask R-CNN
Paper: Mask R-CNN, arXiv:1703.06870v3 (https://arxiv.org/abs/1703.06870v3)
Project: Mask R-CNN
Mask R-CNN is a network model that outputs high-quality instance segmentations while detecting objects effectively. It extends Faster R-CNN by adding, in parallel with bbox detection, a branch that predicts segmentation masks. In essence, Mask R-CNN combines object detection and semantic segmentation to achieve instance segmentation; example results of the model are shown below. Before studying Mask R-CNN we need some familiarity with Faster R-CNN, which you can get from my earlier blog post.
Network Architecture
The Mask R-CNN network model is shown in the figure below.

Backbone
The model adopts ResNet101+FPN as the backbone network for image feature extraction. Using ResNet to extract features is familiar territory by now; the FPN module is introduced on top of it to strengthen the semantic features of the image and better predict objects of different sizes. The FPN scheme is illustrated in diagram (d) below. In diagram (d), the bottom of the pyramid holds shallow feature maps and the top holds deep feature maps. Shallow feature maps have small receptive fields and suit detecting small objects; deep feature maps have large receptive fields and suit detecting large objects. By fusing feature maps of different scales, FPN lets the model handle objects of different sizes at the same time.
The FPN network structure is shown below. It consists mainly of a bottom-up feature extraction path and a top-down feature fusion path. The bottom-up path is the forward pass of ResNet and extracts feature maps at different levels. The top-down path fuses the semantic information of high-level feature maps with the spatial information of low-level feature maps through upsampling and lateral connections, as in the sketch below.
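To make the top-down fusion concrete, here is a minimal FPN sketch. The class name SimpleFPN and the exact layer layout are illustrative assumptions, not this post's training code; torchvision ships a production version as torchvision.ops.FeaturePyramidNetwork.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleFPN(nn.Module):
    """Minimal FPN: lateral 1x1 convs plus a top-down pathway.

    in_channels_list holds the channel counts of the backbone stages
    C2..C5 (ResNet: 256, 512, 1024, 2048); every output map P2..P5 has
    out_channels channels (256 in the paper).
    """

    def __init__(self, in_channels_list=(256, 512, 1024, 2048), out_channels=256):
        super().__init__()
        # 1x1 convs project each backbone stage to a common channel width
        self.lateral = nn.ModuleList(
            nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels_list)
        # 3x3 convs smooth the merged maps to reduce upsampling aliasing
        self.smooth = nn.ModuleList(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            for _ in in_channels_list)

    def forward(self, feats):
        # feats: [C2, C3, C4, C5], ordered shallow -> deep
        laterals = [conv(f) for conv, f in zip(self.lateral, feats)]
        # Top-down: upsample the deeper map and add it to the lateral below
        for i in range(len(laterals) - 2, -1, -1):
            laterals[i] = laterals[i] + F.interpolate(
                laterals[i + 1], size=laterals[i].shape[-2:], mode="nearest")
        return [conv(m) for conv, m in zip(self.smooth, laterals)]  # [P2..P5]


if __name__ == "__main__":
    fpn = SimpleFPN()
    feats = [torch.randn(1, c, s, s)
             for c, s in zip((256, 512, 1024, 2048), (200, 100, 50, 25))]
    for p in fpn(feats):
        print(p.shape)  # all 256-channel, at strides 4/8/16/32
```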
RPN
The RPN mainly selects candidate regions on the feature maps extracted by the backbone; see the description in my Faster R-CNN post for details.
Proposal Layer
This layer takes the candidate boxes selected by the RPN as input and uses the rpn_bbox regression outputs to refine the chosen anchors, yielding corrected RoIs. It then discards anchors whose refined boxes exceed the image bounds, and keeps the top 6000 RoIs by RPN score. Finally, non-maximum suppression produces the regions that will actually be predicted and segmented, as in the sketch below.
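A minimal sketch of that pipeline under stated assumptions: the function name, the box decoding, and the thresholds mirror the description above rather than this post's repository code, and torchvision-style clipping stands in where the post describes discarding out-of-image boxes.

```python
import torch
from torchvision.ops import clip_boxes_to_image, nms


def proposal_layer(anchors, rpn_deltas, rpn_scores, image_size,
                   pre_nms_top_n=6000, nms_thresh=0.7, post_nms_top_n=1000):
    """anchors: [N, 4] (x1, y1, x2, y2); rpn_deltas: [N, 4] (dx, dy, dw, dh);
    rpn_scores: [N] objectness; image_size: (height, width)."""
    # Decode the regression deltas relative to the anchors
    widths = anchors[:, 2] - anchors[:, 0]
    heights = anchors[:, 3] - anchors[:, 1]
    ctr_x = anchors[:, 0] + 0.5 * widths
    ctr_y = anchors[:, 1] + 0.5 * heights

    dx, dy, dw, dh = rpn_deltas.unbind(dim=1)
    pred_ctr_x = ctr_x + dx * widths
    pred_ctr_y = ctr_y + dy * heights
    pred_w = widths * torch.exp(dw)
    pred_h = heights * torch.exp(dh)

    proposals = torch.stack([pred_ctr_x - 0.5 * pred_w,
                             pred_ctr_y - 0.5 * pred_h,
                             pred_ctr_x + 0.5 * pred_w,
                             pred_ctr_y + 0.5 * pred_h], dim=1)

    # Handle boxes sticking out of the image (clipped here, discarded in the post)
    proposals = clip_boxes_to_image(proposals, image_size)

    # Keep the top pre_nms_top_n proposals by objectness score
    top_idx = rpn_scores.topk(min(pre_nms_top_n, rpn_scores.numel())).indices
    proposals, scores = proposals[top_idx], rpn_scores[top_idx]

    # Non-maximum suppression, then cap the number of surviving RoIs
    keep = nms(proposals, scores, nms_thresh)[:post_nms_top_n]
    return proposals[keep], scores[keep]
```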
ROIAlign
ROIAlign was proposed to solve the region-mismatch problem of RoI Pooling in Faster R-CNN.
RoI Pooling
RoI Pooling is an indispensable step in Faster R-CNN because it produces a fixed-length feature vector, and only with a fixed-length feature vector can Softmax compute the classification loss. The method's region-mismatch problem is caused by the rounding operations inside RoI Pooling.

For example, feed an 800×800 image through a convolutional network with 5 downsampling stages, giving a 25×25 feature map.

First region mismatch: a 600×500 RoI in the input image corresponds to an 18.75×15.625 region after the network. RoI Pooling floors this, so the RoI's feature map becomes 18×15.

Second region mismatch: RoI Pooling then divides that feature map into blocks. If a 7×7 grid is required, each block should measure 18/7 × 15/7 ≈ 2.57 × 2.14; flooring again makes every block 2×2, so the feature map of the whole RoI shrinks to 14×14.

These two mismatches introduce errors of 4.75 and 1.625 feature-map pixels horizontally and vertically. For object detection in Faster R-CNN, an offset of a few pixels may be visually negligible, but for Mask R-CNN, which adds instance segmentation, it seriously harms accuracy. The arithmetic can be replayed directly, as in the sketch below.
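A quick numeric check of the two rounding steps described above (a minimal sketch; the numbers follow the 800×800 input, stride-32, 600×500 RoI example):

```python
import math

roi_w, roi_h, stride, bins = 600, 500, 32, 7

# First mismatch: project the RoI onto the feature map and floor it
feat_w, feat_h = roi_w / stride, roi_h / stride    # 18.75 x 15.625
q_w, q_h = math.floor(feat_w), math.floor(feat_h)  # 18 x 15

# Second mismatch: split into 7x7 blocks and floor the block size
bin_w, bin_h = q_w // bins, q_h // bins            # 2 x 2
cover_w, cover_h = bin_w * bins, bin_h * bins      # 14 x 14

print(feat_w - cover_w, feat_h - cover_h)  # 4.75, 1.625 feature-map pixels
# Scaled back to the input image, 4.75 * 32 = 152 pixels horizontally.
```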
ROIAlign
RoIAlign involves no rounding; floating point is used throughout.

1. Compute the side lengths of the RoI region on the feature map, without rounding.
2. Divide the RoI region evenly into k × k blocks, without rounding the block size.
3. The value of each block is obtained by bilinear interpolation from the four nearest feature-map values. Suppose the intersections of the white grid are points on the feature map and the blue box is the RoI feature region, divided into 7×7 blocks. The value at a sampling point p (the pink dot at a block's center) surrounded by grid points $Q_{11}$, $Q_{21}$, $Q_{12}$, $Q_{22}$ is computed with the following formula:

$$f(p) = (1-u)(1-v)\,f(Q_{11}) + u(1-v)\,f(Q_{21}) + (1-u)\,v\,f(Q_{12}) + u\,v\,f(Q_{22})$$

where u and v are the horizontal and vertical distances from the pink point to $Q_{11}$.
4. Use Max Pooling or Average Pooling to obtain the fixed-length feature vector.

The accuracy gains from using RoIAlign are quite visible, as the comparison figure below shows; a usage sketch follows.
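As a usage sketch, torchvision exposes both operators, so the difference is easy to observe on a synthetic feature map (the tensor values here are made up purely for illustration):

```python
import torch
from torchvision.ops import roi_align, roi_pool

# One 25x25 feature map, i.e. what a stride-32 network makes of an 800x800 image
feat = torch.arange(25 * 25, dtype=torch.float32).reshape(1, 1, 25, 25)

# A 600x500 RoI in input-image coordinates: (batch_idx, x1, y1, x2, y2)
rois = torch.tensor([[0, 0.0, 0.0, 600.0, 500.0]])

# RoI Pooling quantizes twice; RoI Align samples with bilinear interpolation
pooled = roi_pool(feat, rois, output_size=(7, 7), spatial_scale=1 / 32)
aligned = roi_align(feat, rois, output_size=(7, 7), spatial_scale=1 / 32,
                    sampling_ratio=2, aligned=True)

print(pooled.shape, aligned.shape)     # both [1, 1, 7, 7]
print((pooled - aligned).abs().max())  # nonzero: quantization shifts the values
```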
bbox Detection
The 7×7×256 feature map output by RoIAlign is flattened into a 1×1×1024 feature vector, after which classification and box prediction proceed separately, much as in Faster R-CNN (the grey region of the figure below). A minimal head is sketched after the figure.
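Here is a minimal sketch of such a detection head. The class and layer names are assumptions for illustration; only the 7×7×256 input and the 1024-wide fully connected representation come from the text above, and num_classes=91 assumes COCO's 90 classes plus background.

```python
import torch
import torch.nn as nn


class BoxHead(nn.Module):
    """Flatten the 7x7x256 RoI feature, run two 1024-wide FC layers,
    then predict class scores and per-class box regression deltas."""

    def __init__(self, in_channels=256, resolution=7, fc_dim=1024, num_classes=91):
        super().__init__()
        self.fc1 = nn.Linear(in_channels * resolution * resolution, fc_dim)
        self.fc2 = nn.Linear(fc_dim, fc_dim)
        self.cls_score = nn.Linear(fc_dim, num_classes)      # classification
        self.bbox_pred = nn.Linear(fc_dim, num_classes * 4)  # box regression

    def forward(self, x):  # x: [num_rois, 256, 7, 7]
        x = x.flatten(start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.cls_score(x), self.bbox_pred(x)


scores, deltas = BoxHead()(torch.randn(100, 256, 7, 7))
print(scores.shape, deltas.shape)  # [100, 91], [100, 364]
```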
Mask Segmentation
As shown in the lower half of the figure above, the Mask branch uses the traditional FCN approach to image segmentation and finally produces a 28×28×80 mask prediction. The final result is a soft mask: floating-point values in (0, 1) after a Sigmoid. A sketch of such a head follows.
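A minimal sketch of an FCN-style mask head with the stated 28×28×80 output. The 14×14 input resolution and the four-conv layout follow the paper's FPN mask head and are assumptions with respect to this post's code.

```python
import torch
import torch.nn as nn


class MaskHead(nn.Module):
    """Four 3x3 convs over the 14x14 RoI feature, one transposed conv
    up to 28x28, then a 1x1 conv predicting one mask per class."""

    def __init__(self, in_channels=256, num_classes=80):
        super().__init__()
        layers = []
        for _ in range(4):
            layers += [nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU()]
        self.convs = nn.Sequential(*layers)
        self.upsample = nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2)
        self.mask_logits = nn.Conv2d(in_channels, num_classes, 1)

    def forward(self, x):                 # x: [num_rois, 256, 14, 14]
        x = self.convs(x)
        x = torch.relu(self.upsample(x))  # -> [num_rois, 256, 28, 28]
        return self.mask_logits(x)        # -> [num_rois, 80, 28, 28]


logits = MaskHead()(torch.randn(100, 256, 14, 14))
soft_masks = torch.sigmoid(logits)  # soft masks in (0, 1)
print(soft_masks.shape)             # [100, 80, 28, 28]
```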
Loss Computation
Mask R-CNN adds a mask loss for semantic segmentation on top of Faster R-CNN's losses.

When an FCN predicts masks, segmentation and classification happen at once: every pixel must be assigned to one of the classes. Mask R-CNN instead decouples classification from semantic segmentation: each class independently predicts its own mask, and this decoupling improves segmentation quality, as shown in the figure below and in the loss sketch that follows.
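A sketch of that decoupled loss under stated assumptions (the function name and tensor shapes are illustrative): for each RoI, only the mask channel of the ground-truth class is penalized, with a per-pixel sigmoid and binary cross-entropy, so classes do not compete per pixel the way a softmax would force them to.

```python
import torch
import torch.nn.functional as F


def mask_rcnn_mask_loss(mask_logits, gt_masks, gt_labels):
    """mask_logits: [num_rois, num_classes, 28, 28] raw mask-head outputs;
    gt_masks: [num_rois, 28, 28] binary targets resized to the RoI;
    gt_labels: [num_rois] ground-truth class of each positive RoI."""
    idx = torch.arange(mask_logits.shape[0], device=mask_logits.device)
    picked = mask_logits[idx, gt_labels]  # [num_rois, 28, 28], one channel per RoI
    return F.binary_cross_entropy_with_logits(picked, gt_masks)


loss = mask_rcnn_mask_loss(torch.randn(8, 80, 28, 28),
                           torch.randint(0, 2, (8, 28, 28)).float(),
                           torch.randint(0, 80, (8,)))
print(loss)
```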
Experiment Reproduction
In this experiment the feature extraction network is a pretrained ResNet50, and Mask R-CNN is trained for one epoch on the COCO2017 dataset with a batch size of 8 and a learning rate of 0.08. Because resources were limited, only one epoch was trained; since the COCO dataset is fairly large, the resulting detection and segmentation quality is still acceptable. The data processing code is as follows:
```python
import os
import json

import torch
from PIL import Image
import torch.utils.data as data
from pycocotools.coco import COCO
from train_utils import coco_remove_images_without_annotations, convert_coco_poly_mask


class CocoDetection(data.Dataset):
    """`MS Coco Detection <https://cocodataset.org/>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        dataset (string): train or val.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(self, root, dataset="train", transforms=None, years="2017"):
        super(CocoDetection, self).__init__()
        assert dataset in ["train", "val"], 'dataset must be in ["train", "val"]'
        anno_file = f"instances_{dataset}{years}.json"
        assert os.path.exists(root), "file '{}' does not exist.".format(root)
        self.img_root = os.path.join(root, f"{dataset}{years}")
        assert os.path.exists(self.img_root), "path '{}' does not exist.".format(self.img_root)
        self.anno_path = os.path.join(root, "annotations", anno_file)
        assert os.path.exists(self.anno_path), "file '{}' does not exist.".format(self.anno_path)

        self.mode = dataset
        self.transforms = transforms
        self.coco = COCO(self.anno_path)

        # Map COCO category ids to category names.
        # Note: the 80 object categories are indexed non-contiguously;
        # the ids follow the stuff-91 numbering even though only 80 are used.
        data_classes = dict([(v["id"], v["name"]) for k, v in self.coco.cats.items()])
        max_index = max(data_classes.keys())  # 90
        # Fill the missing category ids with "N/A"
        coco_classes = {}
        for k in range(1, max_index + 1):
            if k in data_classes:
                coco_classes[k] = data_classes[k]
            else:
                coco_classes[k] = "N/A"

        if dataset == "train":
            json_str = json.dumps(coco_classes, indent=4)
            with open("coco91_indices.json", "w") as f:
                f.write(json_str)

        self.coco_classes = coco_classes

        ids = list(sorted(self.coco.imgs.keys()))
        if dataset == "train":
            # Remove images that have no targets, or whose targets are extremely small
            valid_ids = coco_remove_images_without_annotations(self.coco, ids)
            self.ids = valid_ids
        else:
            self.ids = ids

    def parse_targets(self,
                      img_id: int,
                      coco_targets: list,
                      w: int = None,
                      h: int = None):
        assert w > 0
        assert h > 0

        # Keep only single-object (non-crowd) annotations
        anno = [obj for obj in coco_targets if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # [xmin, ymin, w, h] -> [xmin, ymin, xmax, ymax]
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = [obj["category_id"] for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)

        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_mask(segmentations, h, w)

        # Keep only valid boxes, i.e. x_max > x_min and y_max > y_min
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        masks = masks[keep]
        area = area[keep]
        iscrowd = iscrowd[keep]

        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["masks"] = masks
        target["image_id"] = torch.tensor([img_id])

        # for conversion to coco api
        target["area"] = area
        target["iscrowd"] = iscrowd

        return target

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by coco.loadAnns.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        coco_target = coco.loadAnns(ann_ids)

        path = coco.loadImgs(img_id)[0]["file_name"]
        img = Image.open(os.path.join(self.img_root, path)).convert("RGB")

        w, h = img.size
        target = self.parse_targets(img_id, coco_target, w, h)
        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.ids)

    def get_height_and_width(self, index):
        coco = self.coco
        img_id = self.ids[index]
        img_info = coco.loadImgs(img_id)[0]
        w = img_info["width"]
        h = img_info["height"]
        return h, w

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))


if __name__ == '__main__':
    train = CocoDetection("/root/autodl-tmp/COCO2017", dataset="train")
    print(len(train))
    t = train[0]
```
The model training code is as follows:
```python
import os
import datetime

import torch
from torchvision.ops.misc import FrozenBatchNorm2d

import transforms
from network_files import MaskRCNN
from backbone import resnet50_fpn_backbone
from my_dataset_coco import CocoDetection
from my_dataset_voc import VOCInstances
from train_utils import train_eval_utils as utils
from train_utils import GroupedBatchSampler, create_aspect_ratio_groups


def create_model(num_classes, load_pretrain_weights=True):
    # If GPU memory is small and batch_size cannot be large, consider setting
    # norm_layer to FrozenBatchNorm2d (the default is nn.BatchNorm2d).
    # FrozenBatchNorm2d behaves like BatchNorm2d, but its parameters are not updated.
    # trainable_layers covers ["layer4", "layer3", "layer2", "layer1", "conv1"]; 5 trains all of them.
    # backbone = resnet50_fpn_backbone(norm_layer=FrozenBatchNorm2d,
    #                                  trainable_layers=3)
    # resnet50 imagenet weights url: https://download.pytorch.org/models/resnet50-0676ba61.pth
    backbone = resnet50_fpn_backbone(pretrain_path="./weight/resnet50.pth", trainable_layers=3)
    model = MaskRCNN(backbone, num_classes=num_classes)

    if load_pretrain_weights:
        # coco weights url: https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
        weights_dict = torch.load("./weight/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", map_location="cpu")
        for k in list(weights_dict.keys()):
            if ("box_predictor" in k) or ("mask_fcn_logits" in k):
                del weights_dict[k]

        print(model.load_state_dict(weights_dict, strict=False))

    return model


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print("Using {} device training.".format(device.type))

    # Files used to save coco_info
    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    det_results_file = f"det_results{now}.txt"
    seg_results_file = f"seg_results{now}.txt"

    data_transform = {
        "train": transforms.Compose([transforms.ToTensor(),
                                     transforms.RandomHorizontalFlip(0.5)]),
        "val": transforms.Compose([transforms.ToTensor()])
    }

    data_root = args.data_path

    # load train data set
    # coco2017 -> annotations -> instances_train2017.json
    train_dataset = CocoDetection(data_root, "train", data_transform["train"])
    # VOCdevkit -> VOC2012 -> ImageSets -> Main -> train.txt
    # train_dataset = VOCInstances(data_root, year="2012", txt_name="train.txt", transforms=data_transform["train"])
    train_sampler = None

    # Whether to batch images of similar aspect ratio together;
    # this reduces GPU memory needed for training and is enabled by default.
    if args.aspect_ratio_group_factor >= 0:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        # Bucket every image by its aspect ratio into bins
        group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
        # Each batch draws images from the same aspect-ratio bin
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)

    # Note: collate_fn is custom here because each sample holds image and targets,
    # which the default collate function cannot batch.
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using %g dataloader workers' % nw)

    if train_sampler:
        # When sampling by aspect ratio, the dataloader needs a batch_sampler
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_sampler=train_batch_sampler,
                                                        pin_memory=True,
                                                        num_workers=nw,
                                                        collate_fn=train_dataset.collate_fn)
    else:
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True,
                                                        pin_memory=True,
                                                        num_workers=nw,
                                                        collate_fn=train_dataset.collate_fn)

    # load validation data set
    # coco2017 -> annotations -> instances_val2017.json
    val_dataset = CocoDetection(data_root, "val", data_transform["val"])
    # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
    # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])
    val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  num_workers=nw,
                                                  collate_fn=train_dataset.collate_fn)

    # create model; num_classes equals object classes + background
    model = create_model(num_classes=args.num_classes + 1, load_pretrain_weights=args.pretrain)
    model.to(device)

    train_loss = []
    learning_rate = []
    val_map = []

    # define optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        milestones=args.lr_steps,
                                                        gamma=args.lr_gamma)

    # If a resume checkpoint (weights from the previous run) is given, continue from it
    if args.resume:
        # If map_location is missing, torch.load will first load the module to CPU
        # and then copy each parameter to where it was saved,
        # which would result in all processes on the same machine using the same set of devices.
        checkpoint = torch.load(args.resume, map_location='cpu')  # saved weights incl. optimizer and lr schedule
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch, printing every 50 iterations
        mean_loss, lr = utils.train_one_epoch(model, optimizer, train_data_loader,
                                              device, epoch, print_freq=50,
                                              warmup=True, scaler=scaler)
        train_loss.append(mean_loss.item())
        learning_rate.append(lr)

        # update the learning rate
        lr_scheduler.step()

        # evaluate on the test dataset
        det_info, seg_info = utils.evaluate(model, val_data_loader, device=device)

        # write detection into txt
        with open(det_results_file, "a") as f:
            # Record the COCO metrics plus loss and learning rate
            result_info = [f"{i:.4f}" for i in det_info + [mean_loss.item()]] + [f"{lr:.6f}"]
            txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
            f.write(txt + "\n")

        # write seg into txt
        with open(seg_results_file, "a") as f:
            # Record the COCO metrics plus loss and learning rate
            result_info = [f"{i:.4f}" for i in seg_info + [mean_loss.item()]] + [f"{lr:.6f}"]
            txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
            f.write(txt + "\n")

        val_map.append(det_info[1])  # pascal mAP

        # save weights
        save_files = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'epoch': epoch}
        if args.amp:
            save_files["scaler"] = scaler.state_dict()
        torch.save(save_files, "./save_weights/model_{}.pth".format(epoch))

    # plot loss and lr curve
    if len(train_loss) != 0 and len(learning_rate) != 0:
        from plot_curve import plot_loss_and_lr
        plot_loss_and_lr(train_loss, learning_rate)

    # plot mAP curve
    if len(val_map) != 0:
        from plot_curve import plot_map
        plot_map(val_map)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    # training device
    parser.add_argument('--device', default='cuda:0', help='device')
    # root directory of the training dataset
    parser.add_argument('--data-path', default='/root/autodl-tmp/COCO2017', help='dataset')
    # number of object classes (excluding background)
    parser.add_argument('--num-classes', default=90, type=int, help='num_classes')
    # directory for saving files
    parser.add_argument('--output-dir', default='./save_weights', help='path where to save')
    # checkpoint to resume training from, if any
    parser.add_argument('--resume', default='', type=str, help='resume from checkpoint')
    # epoch to start training from
    parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
    # total number of epochs
    parser.add_argument('--epochs', default=3, type=int, metavar='N',
                        help='number of total epochs to run')
    # learning rate
    parser.add_argument('--lr', default=0.004, type=float,
                        help='initial learning rate, 0.02 is the default value for training '
                             'on 8 gpus and 2 images_per_gpu')
    # SGD momentum
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # SGD weight_decay
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # parameters for torch.optim.lr_scheduler.MultiStepLR
    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
                        help='decrease lr every step-size epochs')
    # parameters for torch.optim.lr_scheduler.MultiStepLR
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    # training batch size (increase if memory/GPU memory allows)
    parser.add_argument('--batch_size', default=2, type=int, metavar='N',
                        help='batch size when training.')
    parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
    parser.add_argument("--pretrain", type=bool, default=True, help="load COCO pretrain weights.")
    # whether to use mixed precision training (requires GPU support)
    parser.add_argument("--amp", default=False, help="Use torch.cuda.amp for mixed precision training")

    args = parser.parse_args()
    print(args)

    # Create the directory for saved weights if it does not exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    main(args)
```

With limited resources only one epoch was trained, so accuracy on some categories is lacking; the evaluation of the training run is shown below.

The performance evaluation code is as follows. This script loads the trained model weights and computes the COCO metrics on the validation/test set, as well as per-category mAP (IoU=0.5):
```python
import os
import json

import torch
from tqdm import tqdm
import numpy as np

import transforms
from backbone import resnet50_fpn_backbone
from network_files import MaskRCNN
from my_dataset_coco import CocoDetection
from my_dataset_voc import VOCInstances
from train_utils import EvalCOCOMetric


def summarize(self, catId=None):
    """
    Compute and display summary metrics for evaluation results.
    Note this function can *only* be applied on the default parameter setting
    """

    def _summarize(ap=1, iouThr=None, areaRng='all', maxDets=100):
        p = self.params
        iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
        titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
        typeStr = '(AP)' if ap == 1 else '(AR)'
        iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
            if iouThr is None else '{:0.2f}'.format(iouThr)

        aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
        mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]

        if ap == 1:
            # dimension of precision: [TxRxKxAxM]
            s = self.eval['precision']
            # IoU
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]

            if isinstance(catId, int):
                s = s[:, :, catId, aind, mind]
            else:
                s = s[:, :, :, aind, mind]
        else:
            # dimension of recall: [TxKxAxM]
            s = self.eval['recall']
            if iouThr is not None:
                t = np.where(iouThr == p.iouThrs)[0]
                s = s[t]

            if isinstance(catId, int):
                s = s[:, catId, aind, mind]
            else:
                s = s[:, :, aind, mind]

        if len(s[s > -1]) == 0:
            mean_s = -1
        else:
            mean_s = np.mean(s[s > -1])

        print_string = iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)
        return mean_s, print_string

    stats, print_list = [0] * 12, [""] * 12
    stats[0], print_list[0] = _summarize(1)
    stats[1], print_list[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
    stats[2], print_list[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
    stats[3], print_list[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
    stats[4], print_list[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
    stats[5], print_list[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
    stats[6], print_list[6] = _summarize(0, maxDets=self.params.maxDets[0])
    stats[7], print_list[7] = _summarize(0, maxDets=self.params.maxDets[1])
    stats[8], print_list[8] = _summarize(0, maxDets=self.params.maxDets[2])
    stats[9], print_list[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
    stats[10], print_list[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
    stats[11], print_list[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])

    print_info = "\n".join(print_list)

    if not self.eval:
        raise Exception('Please run accumulate() first')

    return stats, print_info


def save_info(coco_evaluator,
              category_index: dict,
              save_name: str = "record_mAP.txt"):
    iou_type = coco_evaluator.params.iouType
    print(f"IoU metric: {iou_type}")
    # calculate COCO info for all classes
    coco_stats, print_coco = summarize(coco_evaluator)

    # calculate voc info for every class (IoU=0.5)
    classes = [v for v in category_index.values() if v != "N/A"]
    voc_map_info_list = []
    for i in range(len(classes)):
        stats, _ = summarize(coco_evaluator, catId=i)
        voc_map_info_list.append(" {:15}: {}".format(classes[i], stats[1]))

    print_voc = "\n".join(voc_map_info_list)
    print(print_voc)

    # Save the validation results to a txt file
    with open(save_name, "w") as f:
        record_lines = ["COCO results:",
                        print_coco,
                        "",
                        "mAP(IoU=0.5) for each category:",
                        print_voc]
        f.write("\n".join(record_lines))


def main(parser_data):
    device = torch.device(parser_data.device if torch.cuda.is_available() else "cpu")
    print("Using {} device training.".format(device.type))

    data_transform = {
        "val": transforms.Compose([transforms.ToTensor()])
    }

    # read class_indict
    label_json_path = parser_data.label_json_path
    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
    with open(label_json_path, 'r') as f:
        category_index = json.load(f)

    data_root = parser_data.data_path

    # Note: collate_fn is custom here because each sample holds image and targets,
    # which the default collate function cannot batch
    batch_size = parser_data.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using %g dataloader workers' % nw)

    # load validation data set
    val_dataset = CocoDetection(data_root, "val", data_transform["val"])
    # VOCdevkit -> VOC2012 -> ImageSets -> Main -> val.txt
    # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt", transforms=data_transform["val"])
    val_dataset_loader = torch.utils.data.DataLoader(val_dataset,
                                                     batch_size=batch_size,
                                                     shuffle=False,
                                                     pin_memory=True,
                                                     num_workers=nw,
                                                     collate_fn=val_dataset.collate_fn)

    # create model
    backbone = resnet50_fpn_backbone()
    model = MaskRCNN(backbone, num_classes=parser_data.num_classes + 1)

    # Load your own trained weights
    weights_path = parser_data.weights_path
    assert os.path.exists(weights_path), "not found {} file.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])
    # print(model)

    model.to(device)

    # evaluate on the val dataset
    cpu_device = torch.device("cpu")

    det_metric = EvalCOCOMetric(val_dataset.coco, "bbox", "det_results.json")
    seg_metric = EvalCOCOMetric(val_dataset.coco, "segm", "seg_results.json")
    model.eval()
    with torch.no_grad():
        for image, targets in tqdm(val_dataset_loader, desc="validation..."):
            # Move the images to the target device
            image = list(img.to(device) for img in image)

            # inference
            outputs = model(image)

            outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            det_metric.update(targets, outputs)
            seg_metric.update(targets, outputs)

    det_metric.synchronize_results()
    seg_metric.synchronize_results()
    det_metric.evaluate()
    seg_metric.evaluate()

    save_info(det_metric.coco_evaluator, category_index, "det_record_mAP.txt")
    save_info(seg_metric.coco_evaluator, category_index, "seg_record_mAP.txt")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)

    # device to use
    parser.add_argument('--device', default='cuda', help='device')
    # number of object classes (excluding background)
    parser.add_argument('--num-classes', type=int, default=90, help='number of classes')
    # dataset root directory
    parser.add_argument('--data-path', default='/root/autodl-tmp/COCO2017', help='dataset root')
    # trained weights file
    parser.add_argument('--weights-path', default='./save_weights/model_0.pth', type=str, help='training weights')
    # batch size (set to 1, don't change)
    parser.add_argument('--batch-size', default=1, type=int, metavar='N',
                        help='batch size when validation.')
    # mapping between category ids and category names
    parser.add_argument('--label-json-path', type=str, default="coco91_indices.json")

    args = parser.parse_args()

    main(args)
```

Prediction results: input images 1 and 2 and their rendered outputs are shown in the figures. The prediction code is as follows:
```python
import os
import time
import json

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

from network_files import MaskRCNN
from backbone import resnet50_fpn_backbone
from draw_box_utils import draw_objs


def create_model(num_classes, box_thresh=0.5):
    backbone = resnet50_fpn_backbone()
    model = MaskRCNN(backbone,
                     num_classes=num_classes,
                     rpn_score_thresh=box_thresh,
                     box_score_thresh=box_thresh)

    return model


def time_synchronized():
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    return time.time()


def main():
    num_classes = 90  # excluding background
    box_thresh = 0.5
    weights_path = "./save_weights/model_0.pth"
    img_path = "./street.png"
    label_json_path = "./coco91_indices.json"

    # get devices
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = create_model(num_classes=num_classes + 1, box_thresh=box_thresh)

    # load train weights
    assert os.path.exists(weights_path), "{} file does not exist.".format(weights_path)
    weights_dict = torch.load(weights_path, map_location='cpu')
    weights_dict = weights_dict["model"] if "model" in weights_dict else weights_dict
    model.load_state_dict(weights_dict)
    model.to(device)

    # read class_indict
    assert os.path.exists(label_json_path), "json file {} does not exist.".format(label_json_path)
    with open(label_json_path, 'r') as json_file:
        category_index = json.load(json_file)

    # load image
    assert os.path.exists(img_path), f"{img_path} does not exist."
    original_img = Image.open(img_path).convert('RGB')

    # from pil image to tensor, do not normalize image
    data_transform = transforms.Compose([transforms.ToTensor()])
    img = data_transform(original_img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    model.eval()  # switch to evaluation mode
    with torch.no_grad():
        # warm-up pass
        img_height, img_width = img.shape[-2:]
        init_img = torch.zeros((1, 3, img_height, img_width), device=device)
        model(init_img)

        t_start = time_synchronized()
        predictions = model(img.to(device))[0]
        t_end = time_synchronized()
        print("inference+NMS time: {}".format(t_end - t_start))

        predict_boxes = predictions["boxes"].to("cpu").numpy()
        predict_classes = predictions["labels"].to("cpu").numpy()
        predict_scores = predictions["scores"].to("cpu").numpy()
        predict_mask = predictions["masks"].to("cpu").numpy()
        predict_mask = np.squeeze(predict_mask, axis=1)  # [batch, 1, h, w] -> [batch, h, w]

        if len(predict_boxes) == 0:
            print("No targets detected!")
            return

        plot_img = draw_objs(original_img,
                             boxes=predict_boxes,
                             classes=predict_classes,
                             scores=predict_scores,
                             masks=predict_mask,
                             category_index=category_index,
                             line_thickness=3,
                             font='arial.ttf',
                             font_size=20)
        plt.imshow(plot_img)
        plt.show()
        # Save the rendered prediction
        plot_img.save("test_result.jpg")


if __name__ == '__main__':
    main()
```

Summary
By introducing the RoIAlign layer and a fully convolutional mask branch, Mask R-CNN not only improved segmentation accuracy but also delivered pixel-level mask output, greatly advancing object detection. This breakthrough achieved the best performance of its time on datasets such as COCO and deeply influenced subsequent research, inspiring optimization strategies including attention mechanisms, multimodal information fusion, and few-shot learning. Although Mask R-CNN still faces challenges in speed and parameter count, future directions such as incorporating reinforcement learning and lightweight model design promise to further improve performance and reduce computational cost, making it more practical for real-time applications and resource-constrained devices.