资源简介
使用 PyTorch 复现了 CSRNet 人群计数模型。下载文档后如有任何问题,均可私信博主讨论。
代码片段和文件信息
import os
import numpy as np
import scipy
import scipy.io as io
from scipy import spatial
from scipy.ndimage.filters import gaussian_filter
import glob
from matplotlib import pyplot as plt
import h5py
# Gaussian-kernel density map generation (geometry-adaptive kernels).
def gaussian_filter_density(gt, beta=0.3, n_neighbors=3):
    """Convert a dot-annotation map into a crowd density map.

    Every non-zero pixel of ``gt`` is treated as one head annotation and is
    replaced by a normalised 2-D Gaussian. The Gaussian's standard deviation
    adapts to local crowding:
    ``sigma = beta * mean distance to the n_neighbors nearest other heads``.
    With the defaults this equals ``(d1 + d2 + d3) * 0.1``.

    Args:
        gt: 2-D array; non-zero entries mark head positions.
        beta: scale factor applied to the mean neighbour distance.
        n_neighbors: maximum number of nearest neighbours to average. Fewer
            are used when the image has fewer other heads.

    Returns:
        float32 array with the same shape as ``gt``. Each head contributes a
        mass of 1, minus whatever falls outside the image
        (``mode='constant'``). ``density.sum()`` therefore approximates the
        head count.

    Raises:
        ValueError: if ``n_neighbors`` is smaller than 1.
    """
    # Imported from the public namespace: ``scipy.ndimage.filters`` is deprecated.
    from scipy.ndimage import gaussian_filter
    from scipy.spatial import KDTree

    if n_neighbors < 1:
        raise ValueError('n_neighbors must be >= 1')

    print(gt.shape)
    density = np.zeros(gt.shape, dtype=np.float32)
    rows, cols = np.nonzero(gt)
    gt_count = len(rows)
    if gt_count == 0:
        return density

    # One (x, y) = (column, row) pair per annotated head.
    pts = np.column_stack((cols, rows))
    if gt_count > 1:
        # Index 0 of every query result is the point itself (distance 0).
        # Cap k by the number of points, because KDTree reports missing
        # neighbours with an infinite distance. Without the cap, 2-3 heads
        # would produce an infinite sigma.
        k = min(n_neighbors + 1, gt_count)
        distances, _ = KDTree(pts, leafsize=2048).query(pts, k=k)

    print('generate density...')
    for i, (x, y) in enumerate(pts):
        pt2d = np.zeros(gt.shape, dtype=np.float32)
        pt2d[y, x] = 1.
        if gt_count > 1:
            sigma = beta * np.mean(distances[i][1:])
        else:
            # A single head has no neighbours: fall back to a quarter of
            # the mean image side length.
            sigma = np.average(np.array(gt.shape)) / 2. / 2.
        density += gaussian_filter(pt2d, sigma, mode='constant')
    print('done.')
    return density
# Generate and save ground-truth density maps for every image in the given folders.
def create_ground_truth_density(path_sets):
    """Write an ``.h5`` density map for every ``*.jpg`` found in ``path_sets``.

    File naming follows the ``replace`` calls below. For
    ``.../images/IMG_<n>.jpg``, annotations are read from
    ``.../ground_truth/GT_IMG_<n>.mat``. The density map is written to
    ``.../ground_truth/IMG_<n>.h5`` under the dataset key ``'density'``.

    Args:
        path_sets: iterable of directories containing the ``.jpg`` images.
    """
    img_paths = [img_path
                 for path in path_sets
                 for img_path in glob.glob(os.path.join(path, '*.jpg'))]
    print('图片数量:', len(img_paths))
    for img_path in img_paths:
        print(img_path)
        # NOTE(review): str.replace rewrites every 'images' in the full path,
        # including parent directories. Confirm the dataset root does not
        # contain that word.
        mat_path = (img_path.replace('images', 'ground_truth')
                    .replace('IMG_', 'GT_IMG_')
                    .replace('.jpg', '.mat'))
        mat = io.loadmat(mat_path)
        img = plt.imread(img_path)
        height, width = img.shape[0], img.shape[1]

        # Rasterise the (x, y) head annotations into a binary dot map.
        gt_density_map = np.zeros((height, width))
        gt = mat["image_info"][0, 0][0, 0][0]
        for point in gt:
            col, row = int(point[0]), int(point[1])
            # Also reject negative coordinates: they would otherwise wrap
            # around to the last row/column through negative indexing.
            if 0 <= row < height and 0 <= col < width:
                gt_density_map[row, col] = 1
        gt_density_map = gaussian_filter_density(gt_density_map)

        # Save the density map next to the annotations.
        h5_path = img_path.replace('images', 'ground_truth').replace('.jpg', '.h5')
        with h5py.File(h5_path, 'w') as hf:
            hf['density'] = gt_density_map

        # Sanity check: the map should integrate to roughly the head count.
        print('总数量=', len(gt))
        print('密度图=', gt_density_map.sum())
# Display an original image and its generated density map.
def show(img_path):
    """Plot ``img_path``, then its ground-truth density map, and print the map's total.

    The density map is read from the ``.h5`` file whose path is derived from
    ``img_path`` (``.jpg`` -> ``.h5``, ``images`` -> ``ground_truth``), using
    the dataset key ``'density'``.

    Args:
        img_path: path to a ``.jpg`` image.
    """
    from PIL import Image
    from matplotlib import cm as CM

    # Close the image file once matplotlib has taken its pixel data.
    with Image.open(img_path) as image:
        plt.imshow(image)
    plt.show()

    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth')
    # Copy the data into memory and release the HDF5 handle immediately.
    with h5py.File(gt_path, 'r') as gt_file:
        groundtruth = np.asarray(gt_file['density'])
    plt.imshow(groundtruth, cmap=CM.jet)
    plt.show()
    print('总人数为:', np.sum(groundtruth))
if __name__ == ‘__main__‘:
# set the root to the Shanghai dataset you download
root = ‘E:/数据集/ShanghaiTech_Crowd_Counting_Dataset/‘
# now generate the ShanghaiA‘s ground truth
part_A_train = os.path.join(root ‘part_A_final/train_data‘ ‘images‘)
part_A_test = os.path.join(root ‘part_A_final/test_data‘ ‘images‘)
part_B_train = os.path.join(root ‘part_B_final/train_data‘ ‘images‘)
par
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2020-08-21 15:39 CSRNet-pytorch\
目录 0 2020-08-19 09:36 CSRNet-pytorch\__pycache__\
文件 2280 2020-08-14 15:55 CSRNet-pytorch\__pycache__\model.cpython-38.pyc
文件 3462 2020-08-21 15:39 CSRNet-pytorch\make_dataset.py
文件 2241 2020-08-17 16:44 CSRNet-pytorch\model.py
目录 0 2020-08-19 10:31 CSRNet-pytorch\test_image\
文件 65060360 2020-08-19 09:35 CSRNet-pytorch\test_image\CSRNet_0032.pt
文件 3147776 2020-08-19 09:35 CSRNet-pytorch\test_image\IMG_1.h5
文件 143045 2020-08-19 09:35 CSRNet-pytorch\test_image\IMG_1.jpg
目录 0 2020-08-19 09:56 CSRNet-pytorch\test_image\__pycache__\
文件 2157 2020-08-19 09:35 CSRNet-pytorch\test_image\__pycache__\model.cpython-37.pyc
文件 2174 2020-08-19 09:56 CSRNet-pytorch\test_image\__pycache__\model.cpython-38.pyc
文件 710600 2019-07-11 18:59 CSRNet-pytorch\test_image\classroom1.jpg
文件 725792 2019-07-11 18:59 CSRNet-pytorch\test_image\classroom2.jpg
文件 2289534 2019-07-11 18:59 CSRNet-pytorch\test_image\classroom3.png
文件 2163774 2019-07-11 19:00 CSRNet-pytorch\test_image\classroom4.png
文件 2154330 2019-07-11 19:00 CSRNet-pytorch\test_image\classroom5.png
文件 2241 2020-08-19 09:35 CSRNet-pytorch\test_image\model.py
文件 1772 2020-08-19 10:31 CSRNet-pytorch\test_image\test_single_image.py
文件 4941 2020-08-19 09:44 CSRNet-pytorch\train_model.py
- 上一篇:Xshell5Xftp破解版.rar
- 下一篇:车牌字符集
评论
共有 条评论