资源简介
模仿mnist数据集制作自己的数据集,并读取自己的数据集
代码片段和文件信息
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
#-*-coding:utf-8-*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import numpy
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.python.framework import dtypes
# Base URL of the original MNIST archive files.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from a byte stream.

    Args:
      bytestream: An object whose `read(n)` method returns up to `n` bytes.

    Returns:
      The decoded value as a numpy uint32 scalar.

    Raises:
      ValueError: If fewer than 4 bytes could be read from the stream.
    """
    raw = bytestream.read(4)
    # Fail with a clear message on a truncated stream instead of letting
    # numpy raise an opaque buffer-size or index error.
    if len(raw) != 4:
        raise ValueError(
            'Expected 4 bytes for a 32-bit header field, got %d' % len(raw))
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(raw, dtype=dt)[0]
def extract_images(f):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    Args:
      f: A file object that can be passed into a gzip reader. It must also
        expose a `name` attribute, which is printed and used in error
        messages.

    Returns:
      data: A 4D uint8 numpy array [index, y, x, depth] with depth == 1.

    Raises:
      ValueError: If the bytestream does not start with 2051.
    """
    print('Extracting', f.name)
    with gzip.GzipFile(fileobj=f) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                             (magic, f.name))
        # Convert the header fields to Python ints: multiplying numpy uint32
        # scalars can silently wrap around for large datasets.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data
def dense_to_one_hot(labels_dense, num_classes):
    """Convert class labels from scalars to one-hot vectors.

    Args:
      labels_dense: A numpy integer array of class indices; its first
        dimension is the number of labels.
      num_classes: Number of classes, i.e. the width of each one-hot row.

    Returns:
      A float numpy array of shape [num_labels, num_classes] with a single 1
      per row at the column given by the corresponding label.

    Raises:
      ValueError: If any label lies outside [0, num_classes).
    """
    labels = labels_dense.ravel()
    num_labels = labels_dense.shape[0]
    # An out-of-range label would otherwise set a cell in the wrong row (or
    # index past the end of the array), so reject it explicitly.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            'Labels must be in [0, %d), got values in [%d, %d]' %
            (num_classes, labels.min(), labels.max()))
    # Flat index of each row's "hot" cell: row start plus the label column.
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels] = 1
    return labels_one_hot
def extract_labels(f one_hot=False num_classes=2): #这里打num_classes手动设置为自己的类别数
“““Extract the labels into a 1D uint8 numpy array [index].
Args:
f: A file object that can be passed into a gzip reader.
one_hot: Does one hot encoding for the result.
num_classes: Number of classes for the one hot encoding.
Returns:
labels: a 1D uint8 numpy array.
Raises:
ValueError: If the bystream doesn‘t start with 2049.
“““
print(‘Extracting‘ f.name)
with gzip.GzipFile(fileobj=f) as bytestream:
magic = _read
评论
共有 条评论