• 大小: 66.94MB
    文件类型: .zip
    金币: 2
    下载: 2 次
    发布日期: 2022-01-12
  • 语言: 其他
  • 标签: 人群计数  CSRNet  

资源简介

使用pytorch实现了CSRNet人群计数模型的复现,如果下载文档之后有任何问题均可以私信博主进行讨论

资源截图

代码片段和文件信息


import os
import numpy as np
import scipy
import scipy.io as io
from scipy import spatial
from scipy.ndimage.filters import gaussian_filter
import glob
from matplotlib import pyplot as plt
import h5py

#高斯核函数
def gaussian_filter_density(gt):
    print(gt.shape)

    density = np.zeros(gt.shape dtype=np.float32)
    gt_count = np.count_nonzero(gt)
    if gt_count == 0:
        return density
    pts = np.array(list(zip(np.nonzero(gt)[1] np.nonzero(gt)[0])))

    #构造KDTree寻找相邻的人头位置
    tree = scipy.spatial.KDTree(pts.copy() leafsize=2048)
    distances locations = tree.query(pts k=4)

    print(‘generate density...‘)
    for i pt in enumerate(pts):
        pt2d = np.zeros(gt.shape dtype=np.float32)
        pt2d[pt[1]pt[0]] = 1.
        if gt_count > 1:
            #相邻三个人头的平均距离,其中beta=0.3
            sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1
        else:
            sigma = np.average(np.array(gt.shape))/2./2. #case: 1 point
        density += scipy.ndimage.filters.gaussian_filter(pt2d sigma mode=‘constant‘)
    print(‘done.‘)
    return density

#生成密度图
def create_ground_truth_density(path_sets):
    img_paths = []

    for path in path_sets:
        for img_path in glob.glob(os.path.join(path ‘*.jpg‘)):
            img_paths.append(img_path)
    print(‘图片数量:‘ len(img_paths))

    for img_path in img_paths:
        print(img_path)
        # 获取每张图片对应的mat标记文件
        mat = io.loadmat(img_path.replace(‘images‘ ‘ground_truth‘).replace(‘IMG_‘ ‘GT_IMG_‘).replace(‘.jpg‘ ‘.mat‘))
        img = plt.imread(img_path)
        # 生成密度图
        gt_density_map = np.zeros((img.shape[0] img.shape[1]))
        gt = mat[“image_info“][0 0][0 0][0]
        for i in range(0 len(gt)):
            if int(gt[i][1]) < img.shape[0] and int(gt[i][0]) < img.shape[1]:
                gt_density_map[int(gt[i][1]) int(gt[i][0])] = 1
        gt_density_map = gaussian_filter_density(gt_density_map)
        # 保存生成的密度图
        with h5py.File(img_path.replace(‘images‘ ‘ground_truth‘).replace(‘.jpg‘ ‘.h5‘) ‘w‘) as hf:
            hf[‘density‘] = gt_density_map

        #测试
        print(‘总数量=‘len(gt))
        print(‘密度图=‘gt_density_map.sum())


# 查看原始图片和生成的密度图
def show(img_path):
    from PIL import Image
    from matplotlib import cm as CM

    plt.imshow(Image.open(img_path))
    plt.show()
    gt_file = h5py.File(img_path.replace(‘.jpg‘ ‘.h5‘).replace(‘images‘ ‘ground_truth‘) ‘r‘)
    groundtruth = np.asarray(gt_file[‘density‘])
    plt.imshow(groundtruth cmap=CM.jet)
    plt.show()
    print(‘总人数为:‘np.sum(groundtruth))


if __name__ == ‘__main__‘:
    # set the root to the Shanghai dataset you download
    root = ‘E:/数据集/ShanghaiTech_Crowd_Counting_Dataset/‘

    # now generate the ShanghaiA‘s ground truth
    part_A_train = os.path.join(root ‘part_A_final/train_data‘ ‘images‘)
    part_A_test = os.path.join(root ‘part_A_final/test_data‘ ‘images‘)
    part_B_train = os.path.join(root ‘part_B_final/train_data‘ ‘images‘)
    par

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2020-08-21 15:39  CSRNet-pytorch\
     目录           0  2020-08-19 09:36  CSRNet-pytorch\__pycache__\
     文件        2280  2020-08-14 15:55  CSRNet-pytorch\__pycache__\model.cpython-38.pyc
     文件        3462  2020-08-21 15:39  CSRNet-pytorch\make_dataset.py
     文件        2241  2020-08-17 16:44  CSRNet-pytorch\model.py
     目录           0  2020-08-19 10:31  CSRNet-pytorch\test_image\
     文件    65060360  2020-08-19 09:35  CSRNet-pytorch\test_image\CSRNet_0032.pt
     文件     3147776  2020-08-19 09:35  CSRNet-pytorch\test_image\IMG_1.h5
     文件      143045  2020-08-19 09:35  CSRNet-pytorch\test_image\IMG_1.jpg
     目录           0  2020-08-19 09:56  CSRNet-pytorch\test_image\__pycache__\
     文件        2157  2020-08-19 09:35  CSRNet-pytorch\test_image\__pycache__\model.cpython-37.pyc
     文件        2174  2020-08-19 09:56  CSRNet-pytorch\test_image\__pycache__\model.cpython-38.pyc
     文件      710600  2019-07-11 18:59  CSRNet-pytorch\test_image\classroom1.jpg
     文件      725792  2019-07-11 18:59  CSRNet-pytorch\test_image\classroom2.jpg
     文件     2289534  2019-07-11 18:59  CSRNet-pytorch\test_image\classroom3.png
     文件     2163774  2019-07-11 19:00  CSRNet-pytorch\test_image\classroom4.png
     文件     2154330  2019-07-11 19:00  CSRNet-pytorch\test_image\classroom5.png
     文件        2241  2020-08-19 09:35  CSRNet-pytorch\test_image\model.py
     文件        1772  2020-08-19 10:31  CSRNet-pytorch\test_image\test_single_image.py
     文件        4941  2020-08-19 09:44  CSRNet-pytorch\train_model.py

评论

共有 条评论