资源简介
mmdetection在2019年12月13号进行了新版本的更新,其中对api/train.py增加torch.distributed,这块在windows下不支持,所以要在windows中训练的话需要把v1.0rc1的版本的train与新版本的train进行合并,主要是去除torch.distributed以及_non_dist_train的修改为主。
代码片段和文件信息
from __future__ import division
import logging
import random
import numpy as np
import re
from collections import OrderedDict
import torch
from mmcv.runner import Runner DistSamplerSeedHook obj_from_dict
from mmcv.parallel import MMDataParallel MMDistributedDataParallel
from mmdet import datasets
from mmdet.core import (DistOptimizerHook DistEvalmAPHook
CocoDistEvalRecallHook CocoDistEvalmAPHook
Fp16OptimizerHook)
from mmdet.datasets import build_dataloader DATASETS
from mmdet.models import RPN
# from .env import get_root_logger
def get_root_logger(log_file=None log_level=logging.INFO):
logger = logging.getLogger(‘mmdet‘)
# if the logger has been initialized just return it
if logger.hasHandlers():
return logger
logging.basicConfig(
format=‘%(asctime)s - %(levelname)s - %(message)s‘ level=log_level)
# rank _ = get_dist_info()
# if rank != 0:
# logger.setLevel(‘ERROR‘)
# elif log_file is not None:
# file_handler = logging.FileHandler(log_file ‘w‘)
# file_handler.setFormatter(
# logging.Formatter(‘%(asctime)s - %(levelname)s - %(message)s‘))
# file_handler.setLevel(log_level)
# logger.addHandler(file_handler)
return logger
def set_random_seed(seed deterministic=False):
“““Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend i.e. set ‘torch.backends.cudnn.deterministic‘
to True and ‘torch.backends.cudnn.benchmark‘ to False.
Default: False.
“““
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name loss_value in losses.items():
if isinstance(loss_value torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
‘{} is not a tensor or list of tensors‘.format(loss_name))
loss = sum(_value for _key _value in log_vars.items() if ‘loss‘ in _key)
log_vars[‘loss‘] = loss
for name in log_vars:
log_vars[name] = log_vars[name].item()
return loss log_vars
def batch_processor(model data train_mode):
losses = model(**data)
loss log_vars = parse_losses(losses)
outputs = dict(
loss=loss log_vars=log_vars num_samples=len(data[‘img‘].data))
return outputs
def train_detector(model
dataset
cfg
distributed=False
相关资源
- Python3入门与进阶
- 编译原理由正则表达式到NFA到DFA到最
- Pyboard利用两个Zigbee模块发送并接收
- python实现贪吃蛇小游戏
- Python-TensorFlow语义分割组件
- numpy-1.17.2+mkl-cp37-cp37m-win_amd64.rar
- python大作业--爬虫完美应付大作业.z
- pyltp python3.7可用版本,已编译好的.
- python小游戏大全——30个
- python3网络爬虫开发实战 无密码
- 山东大学抢课脚本
- Python赶集网北京地区招聘信息爬虫
- wordcloud-1.6.0-cp38-cp38-win32.whl
- Python从入门到精通(明日科技出版)
- python 批量修改文件夹和文件名 解压
- python 快速搭建blog demo
- 面向Arcgis的python脚本编程 中文教程英
- 第十届蓝桥杯大赛青少年创意编程P
- python实现Alphapose骨骼关键点信息的提
- python实现photoshop自动化
- 利用Python爬虫抓取网页上的图片含异
- python基础视频地址.txt
- Python实现社区发现算法-fast_unfolding算
- python-cwt时频图绘制
- python火焰检测代码
- 树莓派_python_PCA9685_16路舵机自定义角
- python从入门到实践课后试一试代码.
- Python PyQt5编写的天气预报
- Hopfield Neural Network——神经网络pytho
- 计算N50的python脚本
评论
共有 条评论