• 大小: 6KB
    文件类型: .py
    金币: 2
    下载: 1 次
    发布日期: 2021-06-06
  • 语言: Python
  • 标签: yolo  yolov3  kmeans  聚类  

资源简介

找了好久找到的,求出的kmeans相对效果会好一些 - 作用:根据darknet训练txt的存储文件夹,经过多次计算求anchors,最后取2次avg_iou相等时的结果 - 需要修改:txt储存的文件夹、生成anchors.txt路径、anchors数量、训练宽度、训练高度 - 运行:python kmeans.py

资源截图

代码片段和文件信息

# -*- coding: utf-8 -*-
import numpy as np
import random
import argparse
import os
#参数名称
parser = argparse.ArgumentParser(description=‘使用该脚本生成YOLO-V3的anchor boxes\n‘)
parser.add_argument(‘--input_annotation_txt_dir‘default=‘labels‘type=strhelp=‘输入存储图片的标注txt文件(注意不要有中文)‘)
parser.add_argument(‘--output_anchors_txt‘default=‘anchors.txt‘type=strhelp=‘输出的存储Anchor boxes的文本文件‘)
parser.add_argument(‘--input_num_anchors‘default=6type=inthelp=‘输入要计算的聚类(Anchor boxes的个数)‘)
parser.add_argument(‘--input_cfg_width‘default=224type=inthelp=“配置文件中width“)
parser.add_argument(‘--input_cfg_height‘default=224type=inthelp=“配置文件中height“)
args = parser.parse_args()
‘‘‘
centroids 聚类点 尺寸是 numx2类型是ndarray
annotation_array 其中之一的标注框
‘‘‘
def IOU(annotation_arraycentroids):
    #
    similarities = []
    #其中一个标注框
    wh = annotation_array
    for centroid in centroids:
        c_wc_h = centroid
        if c_w >=w and c_h >= h:#第1中情况
            similarity = w*h/(c_w*c_h)
        elif c_w >= w and c_h <= h:#第2中情况
            similarity = w*c_h/(w*h + (c_w - w)*c_h)
        elif c_w <= w and c_h >= h:#第3种情况
            similarity = c_w*h/(w*h +(c_h - h)*c_w)
        else:#第3种情况
            similarity = (c_w*c_h)/(w*h)
        similarities.append(similarity)
    #将列表转换为ndarray
    return np.array(similaritiesnp.float32) #返回的是一维数组,尺寸为(num)
 
‘‘‘
k_means:k均值聚类
annotations_array 所有的标注框的宽高,N个标注框,尺寸是Nx2类型是ndarray
centroids 聚类点 尺寸是 numx2类型是ndarray
‘‘‘
def k_means(annotations_arraycentroidseps=0.00005iterations=200000):
    #
    N = annotations_array.shape[0]#C=2
    num = centroids.shape[0]
    #损失函数
    distance_sum_pre = -1
    assignments_pre = -1*np.ones(Ndtype=np.int64)
    #
    iteration = 0
    #循环处理
    while(True):
        #
        iteration += 1
        #
        distances = []
        #循环计算每一个标注框与所有的聚类点的距离(IOU)
        for i in range(N):
            distance = 1 - IOU(annotations_array[i]centroids)
            distances.append(distance)
        #列表转换成ndarray
        distances_array = np.array(distancesnp.float32)#该ndarray的尺寸为 Nxnum
        #找出每一个标注框到当前聚类点最近的点
        assignments = np.argmin(distances_arrayaxis=1)#计算每一行的最小值的位置索引
        #计算距离的总和,相当于k均值聚类的损失函数
        distances_sum = np.sum(distances_array)
        #计算新的聚类点
        centr

评论

共有 条评论