• 大小: 23.93MB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2023-07-07
  • 语言: 其他
  • 标签: 机器学习  Python  

资源简介

创建基本的神经网络,通过使用MNIST训练数据集进行训练,使用MNIST测试集和自己创建的手写数字图像数据对神经网络进行测试

资源截图

代码片段和文件信息

import numpy
import scipy.special
import matplotlib.pyplot
import imageio
import glob

# 神经网络类定义
class neuralNetwork:

    # 初始化神经网络
    def __init__(self inputnodes hiddennodes outputnodes learningrate):
        # 设置每个输入、隐藏、输出层的节点数
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # 链接权值矩阵,wih and who
        # 数组中的权重是w_i_j,其中链路是从节点i到下一层的节点j
        # w11 w21
        # w12 w22 etc
        self.wih = numpy.random.normal(0.0 pow(self.inodes -0.5) (self.hnodes self.inodes))
        self.who = numpy.random.normal(0.0 pow(self.hnodes -0.5) (self.onodes self.hnodes))

        # 学习速率
        self.lr = learningrate

        # 激活函数是s型函数
        self.activation_function = lambda x: scipy.special.expit(x)

        pass

    # 训练神经网络
    def train(self inputs_list targets_list):
        # 将输入列表转换为二维数组
        inputs = numpy.array(inputs_list ndmin=2).T
        targets = numpy.array(targets_list ndmin=2).T

        # 计算信号到隐藏层
        hidden_inputs = numpy.dot(self.wih inputs)
        # 计算从隐含层出现的信号
        hidden_outputs = self.activation_function(hidden_inputs)

        # 计算信号到最终的输出层
        final_inputs = numpy.dot(self.who hidden_outputs)
        # 计算从最终输出层出现的信号
        final_outputs = self.activation_function(final_inputs)

        # 输出层误差为(目标值-实际值)
        output_errors = targets - final_outputs
        # 隐藏层错误是output_errors,按权重分割,在隐藏节点处重新组合
        hidden_errors = numpy.dot(self.who.T output_errors)

        # 更新隐藏层和输出层之间的链接的权重
        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs))
                                        numpy.transpose(hidden_outputs))

        # 更新输入层和隐藏层之间的链接的权值
        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs))
                                        numpy.transpose(inputs))

        pass

    # 查询神经网络
    def query(self inputs_list):
        # 将输入列表转换为二维数组
        inputs = numpy.array(inputs_list ndmin=2).T

        # 计算信号到隐藏层
        hidden_inputs = numpy.dot(self.wih inputs)
        # 计算从隐含层出现的信号
        hidden_outputs = self.activation_function(hidden_inputs)

        # 计算信号到最终的输出层
        final_inputs = numpy.dot(self.who hidden_outputs)
        # 计算从最终输出层出现的信号
        final_outputs = self.activation_function(final_inputs)

        return final_outputs

# 输入、隐藏和输出节点的数量
input_nodes = 784
hidden_nodes = 200
output_nodes = 10

# 学习速率
learning_rate = 0.1

# 创建神经网络实例
n = neuralNetwork(input_nodeshidden_nodesoutput_nodes learning_rate)

# 将mnist训练数据CSV文件加载到列表中
training_data_file = open(“MNIST_data/mnist_train.csv“ ‘r‘)
training_data_list = training_data_file.readlines()
training_data_file.close()

# 训练神经网络

# epochs是训练数据集用于训练的次数
epochs = 10

for e in range(epochs):
    # 检查训练数据集中的所有记录
    for record in training_data_list:
        # 用逗号分隔记录
        all_values

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2020-06-15 12:43  Handwritten_digit_recognition\
     文件        6987  2020-06-15 11:46  Handwritten_digit_recognition\Handwritten.py
     目录           0  2020-06-15 12:42  Handwritten_digit_recognition\MNIST_data\
     文件    18299443  2020-06-14 22:42  Handwritten_digit_recognition\MNIST_data\mnist_test.csv
     文件   109635994  2020-06-14 22:42  Handwritten_digit_recognition\MNIST_data\mnist_train.csv
     文件     7840016  1998-01-26 23:07  Handwritten_digit_recognition\MNIST_data\t10k-images.idx3-ubyte
     文件       10008  1998-01-26 23:07  Handwritten_digit_recognition\MNIST_data\t10k-labels.idx1-ubyte
     文件    47040016  1996-11-18 23:36  Handwritten_digit_recognition\MNIST_data\train-images.idx3-ubyte
     文件       60008  1996-11-18 23:36  Handwritten_digit_recognition\MNIST_data\train-labels.idx1-ubyte
     文件       17699  2020-06-14 23:03  Handwritten_digit_recognition\Number2.png
     文件       31649  2020-06-15 11:43  Handwritten_digit_recognition\Number4.png
     文件       20778  2020-06-15 09:18  Handwritten_digit_recognition\Number6.png
     文件           0  2020-03-04 20:55  Handwritten_digit_recognition\__init__.py

评论

共有 条评论