资源简介
共包含三个文件：
demo_video.py：使用训练好的模型进行视频目标检测；
pr-curve.py：绘制 P-R 曲线（方法一）；
pascal_voc.py：绘制 P-R 曲线（方法二）。
代码片段和文件信息
#!/usr/bin/env python
# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""
Demo script showing Faster R-CNN detections on video frames.

Each image handed to ``demo`` is run through the trained detector, and the
detections are drawn onto it and shown with OpenCV.
See README.md for installation instructions before running.
"""
import argparse
import os
import sys

# Imported only for its side effect (no name from it is used here); presumably
# it puts caffe / fast_rcnn / utils on sys.path, so keep it above those imports.
import _init_paths  # noqa: F401
import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio

import caffe
from fast_rcnn.config import cfg
from fast_rcnn.nms_wrapper import nms
from fast_rcnn.test import im_detect
from utils.timer import Timer
# Detection class names. Index 0 is the background class, which ``demo``
# skips by iterating over CLASSES[1:].
CLASSES = ('__background__',
           'sea cucumber')

# Network name -> (model name, trained weights file). Presumably keyed by the
# --net command-line value; verify against the model-loading code.
NETS = {'vgg16': ('VGG16',
                  'VGG16_faster_rcnn_final.caffemodel'),
        'zf': ('ZF',
               'ZF_faster_rcnn_final.caffemodel')}
def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw one class's detections onto ``im`` and show the result.

    Every detection whose score is at least ``thresh`` gets a green box and a
    ``"<class_name> <score>"`` label at the box's lower-left corner. The image
    is displayed in the OpenCV window ``"im"`` even when no detection passes
    the threshold.

    Args:
        im: Image array; drawn on in place by ``cv2.putText`` /
            ``cv2.rectangle``. Presumably a BGR frame from OpenCV -- confirm
            with the caller.
        class_name: Text used in each box label.
        dets: Array of rows ``[x1, y1, x2, y2, score]`` (as built in ``demo``).
        thresh: Minimum score for a detection to be drawn.
    """
    inds = np.where(dets[:, -1] >= thresh)[0]
    font = cv2.FONT_HERSHEY_SIMPLEX
    for i in inds:
        x1, y1, x2, y2 = (int(v) for v in dets[i, :4])
        score = dets[i, -1]
        # OpenCV points are (x, y): (x1, y2) is the box's lower-left corner.
        cv2.putText(im, '{} {:.3f}'.format(class_name, score), (x1, y2),
                    font, 1, (0, 255, 0), 2)
        cv2.rectangle(im, (x1, y2), (x2, y1), (0, 255, 0), 5)
    # Always show the frame. Skipping this when nothing passes the threshold
    # leaves a video window stuck on the last frame that had a detection.
    cv2.imshow("im", im)
def demo(net, im, conf_thresh=0.8, nms_thresh=0.3):
    """Detect all object classes in ``im`` and draw the results onto it.

    Runs ``im_detect`` once and prints how long it took. Then, for every
    non-background class, it applies NMS to that class's boxes and passes the
    survivors to ``vis_detections``, which draws on ``im`` in place.

    Args:
        net: Detection network passed straight to ``im_detect`` (presumably a
            loaded ``caffe.Net`` -- confirm with the caller).
        im: Image / video frame to run detection on.
        conf_thresh: Minimum score for a detection to be drawn.
        nms_thresh: Overlap threshold passed to ``nms``.
    """
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    timer.toc()
    # print(...) with a single formatted string behaves the same on Python 2
    # and 3; the old `print (...).format(...)` raised AttributeError on Py3.
    print('Detection took {:.3f}s for {:d} object proposals'.format(
        timer.total_time, boxes.shape[0]))

    # start=1 because CLASSES[0] is the background class, which is skipped.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        # Class-specific regressed boxes occupy 4 columns per class.
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, nms_thresh)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=conf_thresh)
def parse_args():
“““Parse input arguments.“““
parser = argparse.ArgumentParser(description=‘Faster R-CNN demo‘)
parser.add_argument(‘--gpu‘ dest=‘gpu_id‘ help=‘GPU device id to use [0]‘
default=0 type=int)
parser.add_argument(‘--cpu‘ dest=‘cpu_mode‘
help=‘Use CPU mode (overrides --gpu)‘
action=‘store_true‘)
parser.add_argument(‘--net‘ dest=‘demo_net‘ help=‘Network to use [zf]‘
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 4444 2019-03-14 09:00 Faster_RCNN\demo_video.py
文件 14465 2019-03-14 06:47 Faster_RCNN\pascal_voc.py
文件 1595 2019-03-14 14:53 Faster_RCNN\pr-curve.py
目录 0 2019-03-14 17:26 Faster_RCNN\
- 上一篇:Caffe-ssd的宽高比聚类
- 下一篇:RTCM3.1标准协议
评论
共有 条评论