1. Download the MNIST dataset
Shared via Baidu Netdisk: mnist
Link: https://pan.baidu.com/s/1ia3vFA73hEtWK9qU-O-4iQ?pwd=mnis  Extraction code: mnis
After downloading, put the dataset in a path that contains no Chinese characters.
This article puts the downloaded dataset under C:\DeepLearning\dataset, so the code below uses dataset_dir = r'C:\DeepLearning\dataset'.
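To confirm the four archives landed in the right place, a quick sanity check like this can help (my own sketch, not from the original post; the file names match the key_file table in mnist.py below):

import os

dataset_dir = r'C:\DeepLearning\dataset'  # adjust to your own path
for name in ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']:
    path = os.path.join(dataset_dir, name)
    print(name + ':', 'found' if os.path.exists(path) else 'MISSING')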
2. Load the MNIST dataset
Below is mnist.py, which loads the dataset.
# coding: utf-8
try:
    import urllib.request
except ImportError:
    raise ImportError('You should use Python 3.x')
import os.path
import gzip
import pickle
import os
import numpy as np

# url_base = 'https://ossci-datasets.s3.amazonaws.com/mnist/'  # mirror site
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz'
}

# Put the downloaded dataset under C:\DeepLearning\dataset
# (a raw string avoids backslash-escape problems in Windows paths)
dataset_dir = r'C:\DeepLearning\dataset'
save_file = dataset_dir + '/mnist.pkl'

train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784

# Downloading is commented out; the local files are used instead
# def _download(file_name):
#     file_path = dataset_dir + '/' + file_name
#     if os.path.exists(file_path):
#         return
#     print('Downloading ' + file_name + ' ... ')
#     urllib.request.urlretrieve(url_base + file_name, file_path)
#     print('Done')

# def download_mnist():
#     for v in key_file.values():
#         _download(v)


def _load_label(file_name):
    file_path = dataset_dir + '/' + file_name

    print('Converting ' + file_name + ' to NumPy Array ...')
    with gzip.open(file_path, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)
    print('Done')

    return labels


def _load_img(file_name):
    file_path = dataset_dir + '/' + file_name

    print('Converting ' + file_name + ' to NumPy Array ...')
    with gzip.open(file_path, 'rb') as f:
        data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape(-1, img_size)
    print('Done')

    return data


def _convert_numpy():
    dataset = {}
    dataset['train_img'] = _load_img(key_file['train_img'])
    dataset['train_label'] = _load_label(key_file['train_label'])
    dataset['test_img'] = _load_img(key_file['test_img'])
    dataset['test_label'] = _load_label(key_file['test_label'])
    return dataset


def init_mnist():
    # download_mnist()  # downloading disabled
    dataset = _convert_numpy()
    print('Creating pickle file ...')
    with open(save_file, 'wb') as f:
        pickle.dump(dataset, f, -1)
    print('Done!')


def _change_one_hot_label(X):
    T = np.zeros((X.size, 10))
    for idx, row in enumerate(T):
        row[X[idx]] = 1
    return T


def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset

    Parameters
    ----------
    normalize : normalize pixel values to 0.0~1.0
    one_hot_label : if True, return labels as one-hot arrays,
        i.e. arrays like [0,0,1,0,0,0,0,0,0,0]
    flatten : whether to flatten each image into a one-dimensional array

    Returns
    -------
    (training images, training labels), (test images, test labels)
    """
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as f:
        dataset = pickle.load(f)

    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32)
            dataset[key] /= 255.0

    if one_hot_label:
        dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
        dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

    if not flatten:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

    return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label'])


if __name__ == '__main__':
    init_mnist()
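As a quick smoke test (my own sketch, not part of the original post), you can call load_mnist from a script whose parent directory contains the dataset folder. The first call converts the .gz archives and caches mnist.pkl; the expected shapes follow from train_num = 60000, test_num = 10000, and img_size = 784 above:

import sys, os
sys.path.append(os.pardir)  # assumes dataset/mnist.py is in the parent directory

from dataset.mnist import load_mnist

# First call converts the .gz files and writes mnist.pkl; later calls just unpickle it
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True)
print(x_train.shape)  # (60000, 784)
print(t_train.shape)  # (60000,)
print(x_test.shape)   # (10000, 784)
print(t_test.shape)   # (10000,)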
3. Using the dataset
mnist_show.py is used to access the dataset.
Note the import on line 3 of the script below: it adds the parent directory to the module search path, and that parent directory must contain a dataset folder with mnist.py inside, or the import of mnist.py will fail.
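If you are unsure about your layout, a check like the following (a hypothetical helper of my own, not from the original post) confirms the import will resolve before you run the script:

import os

# mnist_show.py does sys.path.append(os.pardir), so the parent
# directory must contain dataset/mnist.py for the import to work.
expected = os.path.join(os.pardir, 'dataset', 'mnist.py')
print(expected, 'found' if os.path.exists(expected) else 'MISSING')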
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # setting so that files in the parent directory can be imported
import numpy as np
from dataset.mnist import load_mnist  # requires a dataset folder containing mnist.py under the parent directory
from PIL import Image


def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))
    pil_img.show()


(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)

img = x_train[0]
label = t_train[0]
print(label)  # 5

print(img.shape)           # (784,)
img = img.reshape(28, 28)  # restore the image to its original shape
print(img.shape)           # (28, 28)

img_show(img)
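Because load_mnist was called with flatten=True, each image comes back as a flat 784-element vector and must be reshaped before display. A small variation of my own: loading with flatten=False keeps the (1, 28, 28) shape, so no reshape is needed:

from dataset.mnist import load_mnist

(x_train, t_train), _ = load_mnist(flatten=False, normalize=False)
print(x_train[0].shape)  # (1, 28, 28) -- channel-first, no reshape needed
img = x_train[0][0]      # drop the channel axis to get a 28x28 array for PIL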
4. Batch processing
neuralnet_mnist_batch.py runs inference on the test set in batches of 100. It imports sigmoid and softmax from common/functions.py and loads pretrained weights from sample_weight.pkl, so both must be available alongside the script.
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # setting so that files in the parent directory can be imported
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax


def get_data():
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return x_test, t_test


def init_network():
    # Pretrained weights; sample_weight.pkl must sit in the same directory
    with open('sample_weight.pkl', 'rb') as f:
        network = pickle.load(f)
    return network


def predict(network, x):
    w1, w2, w3 = network['W1'], network['W2'], network['W3']
    b1, b2, b3 = network['b1'], network['b2'], network['b3']

    a1 = np.dot(x, w1) + b1
    z1 = sigmoid(a1)
    a2 = np.dot(z1, w2) + b2
    z2 = sigmoid(a2)
    a3 = np.dot(z2, w3) + b3
    y = softmax(a3)

    return y


x, t = get_data()
network = init_network()

batch_size = 100  # batch size
accuracy_cnt = 0

for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])

print('Accuracy:' + str(float(accuracy_cnt) / len(x)))
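The key detail in the batch loop is np.argmax(y_batch, axis=1): predict returns a (batch_size, 10) array of class scores, and axis=1 picks the highest-scoring class for each row, so the comparison with the label slice counts correct predictions for the whole batch at once. A tiny standalone illustration (my own, with made-up scores):

import numpy as np

# Scores for a made-up batch of 3 images over 3 classes
y_batch = np.array([[0.1, 0.8, 0.1],
                    [0.3, 0.1, 0.6],
                    [0.2, 0.5, 0.3]])
p = np.argmax(y_batch, axis=1)  # best class per row
print(p)                        # [1 2 1]

t = np.array([1, 2, 0])         # made-up true labels
print(np.sum(p == t))           # 2 correct in this batch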