• 大小:
    文件类型: .gz
    金币: 1
    下载: 0 次
    发布日期: 2023-06-16
  • 语言: 其他
  • 标签: 深度学习  

资源简介

《动手学深度学习》(Dive into Deep Learning)原书中的MXNet实现改为PyTorch实现。

资源截图

代码片段和文件信息

import collections
import math
import os
import random
import sys
import tarfile
import time
import zipfile
from tqdm import tqdm
from collections import namedtuple

from IPython import display
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchtext
import torchtext.vocab as Vocab
import numpy as np


VOC_CLASSES = [‘background‘ ‘aeroplane‘ ‘bicycle‘ ‘bird‘ ‘boat‘
               ‘bottle‘ ‘bus‘ ‘car‘ ‘cat‘ ‘chair‘ ‘cow‘
               ‘diningtable‘ ‘dog‘ ‘horse‘ ‘motorbike‘ ‘person‘
               ‘potted plant‘ ‘sheep‘ ‘sofa‘ ‘train‘ ‘tv/monitor‘]


VOC_COLORMAP = [[0 0 0] [128 0 0] [0 128 0] [128 128 0]
                [0 0 128] [128 0 128] [0 128 128] [128 128 128]
                [64 0 0] [192 0 0] [64 128 0] [192 128 0]
                [64 0 128] [192 0 128] [64 128 128] [192 128 128]
                [0 64 0] [128 64 0] [0 192 0] [128 192 0]
                [0 64 128]]



# ###################### 3.2 ############################
def set_figsize(figsize=(3.5 2.5)):
    use_svg_display()
    # 设置图的尺寸
    plt.rcParams[‘figure.figsize‘] = figsize

def use_svg_display():
    “““Use svg format to display plot in jupyter“““
    display.set_matplotlib_formats(‘svg‘)

def data_iter(batch_size features labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # 样本的读取顺序是随机的
    for i in range(0 num_examples batch_size):
        j = torch.LongTensor(indices[i: min(i + batch_size num_examples)]) # 最后一次可能不足一个batch
        yield  features.index_select(0 j) labels.index_select(0 j) 

def linreg(X w b):
    return torch.mm(X w) + b

def squared_loss(y_hat y): 
    # 注意这里返回的是向量 另外 pytorch里的MSELoss并没有除以 2
    return ((y_hat - y.view(y_hat.size())) ** 2) / 2

def sgd(params lr batch_size):
    # 为了和原书保持一致,这里除以了batch_size,但是应该是不用除的,因为一般用PyTorch计算loss时就默认已经
    # 沿batch维求了平均了。
    for param in params:
        param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data



# ######################3##### 3.5 #############################
def get_fashion_mnist_labels(labels):
    text_labels = [‘t-shirt‘ ‘trouser‘ ‘pullover‘ ‘dress‘ ‘coat‘
                   ‘sandal‘ ‘shirt‘ ‘sneaker‘ ‘bag‘ ‘ankle boot‘]
    return [text_labels[int(i)] for i in labels]

def show_fashion_mnist(images labels):
    use_svg_display()
    # 这里的_表示我们忽略(不使用)的变量
    _ figs = plt.subplots(1 len(images) figsize=(12 12))
    for f img lbl in zip(figs images labels):
        f.imshow(img.view((28 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    # plt.show()

# 5.6 修改
# def load_data_fashion_mnist(batch_size root=‘~/Datase

评论

共有 条评论