资源简介
mmdetection在2019年12月13号进行了新版本的更新,其中对api/train.py增加torch.distributed,这块在windows下不支持,所以要在windows中训练的话需要把v1.0rc1的版本的train与新版本的train进行合并,主要是去除torch.distributed以及_non_dist_train的修改为主。
代码片段和文件信息
from __future__ import division
import logging
import random
import numpy as np
import re
from collections import OrderedDict
import torch
from mmcv.runner import Runner DistSamplerSeedHook obj_from_dict
from mmcv.parallel import MMDataParallel MMDistributedDataParallel
from mmdet import datasets
from mmdet.core import (DistOptimizerHook DistEvalmAPHook
CocoDistEvalRecallHook CocoDistEvalmAPHook
Fp16OptimizerHook)
from mmdet.datasets import build_dataloader DATASETS
from mmdet.models import RPN
# from .env import get_root_logger
def get_root_logger(log_file=None log_level=logging.INFO):
logger = logging.getLogger(‘mmdet‘)
# if the logger has been initialized just return it
if logger.hasHandlers():
return logger
logging.basicConfig(
format=‘%(asctime)s - %(levelname)s - %(message)s‘ level=log_level)
# rank _ = get_dist_info()
# if rank != 0:
# logger.setLevel(‘ERROR‘)
# elif log_file is not None:
# file_handler = logging.FileHandler(log_file ‘w‘)
# file_handler.setFormatter(
# logging.Formatter(‘%(asctime)s - %(levelname)s - %(message)s‘))
# file_handler.setLevel(log_level)
# logger.addHandler(file_handler)
return logger
def set_random_seed(seed deterministic=False):
“““Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend i.e. set ‘torch.backends.cudnn.deterministic‘
to True and ‘torch.backends.cudnn.benchmark‘ to False.
Default: False.
“““
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name loss_value in losses.items():
if isinstance(loss_value torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
‘{} is not a tensor or list of tensors‘.format(loss_name))
loss = sum(_value for _key _value in log_vars.items() if ‘loss‘ in _key)
log_vars[‘loss‘] = loss
for name in log_vars:
log_vars[name] = log_vars[name].item()
return loss log_vars
def batch_processor(model data train_mode):
losses = model(**data)
loss log_vars = parse_losses(losses)
outputs = dict(
loss=loss log_vars=log_vars num_samples=len(data[‘img‘].data))
return outputs
def train_detector(model
dataset
cfg
distributed=False
相关资源
- 二级考试python试题12套(包括选择题和
- pywin32_python3.6_64位
- python+ selenium教程
- PycURL(Windows7/Win32)Python2.7安装包 P
- 英文原版-Scientific Computing with Python
- 7.图像风格迁移 基于深度学习 pyt
- 基于Python的学生管理系统
- A Byte of Python(简明Python教程)(第
- Python实例174946
- Python 人脸识别
- Python 人事管理系统
- 基于python-flask的个人博客系统
- 计算机视觉应用开发流程
- python 调用sftp断点续传文件
- python socket游戏
- 基于Python爬虫爬取天气预报信息
- python函数编程和讲解
- Python开发的个人博客
- 基于python的三层神经网络模型搭建
- python实现自动操作windows应用
- python人脸识别(opencv)
- python 绘图(方形、线条、圆形)
- python疫情卡UN管控
- python 连连看小游戏源码
- 基于PyQt5的视频播放器设计
- 一个简单的python爬虫
- csv文件行列转换python实现代码
- Python操作Mysql教程手册
- Python Machine Learning Case Studies
- python获取硬件信息
评论
共有 条评论