保健品网站怎么做的,网站建设创意,作品集展示的网站,湖北海厦建设有限公司网站初学入门:01-02 01 基本介绍02 快速入门库处理数据集网络构建模型训练保存模型加载模型打卡-时间 01 基本介绍
MindSpore Data#xff08;数据处理层#xff09; ModelZoo#xff08;模型库#xff09; MindSpore Science#xff08;科学计算#xff09;#xff0c;包含… 初学入门:01-02 01 基本介绍02 快速入门库处理数据集网络构建模型训练保存模型加载模型打卡-时间 01 基本介绍
MindSpore Data数据处理层 ModelZoo模型库 MindSpore Science科学计算包含了业界领先的数据集、基础模型、预置高精度模型和前后处理工具 MindSpore Insight可视化调试调优工具能够可视化地查看训练过程、优化模型性能、调试精度问题、解释推理结果
02 快速入门
库
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset处理数据集 下载Mnist数据集 # Download data from open datasets
from download import downloadurl https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/ \notebook/datasets/MNIST_Data.zip
path download(url, ./, kindzip, replaceTrue)训练集、测试集 train_dataset MnistDataset(MNIST_Data/train)
test_dataset MnistDataset(MNIST_Data/test)列名图片 和 对应标签分类 数据处理流水线Data Processing Pipeline 参数数据集、batch_size
def datapipe(dataset, batch_size):image_transforms [ vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean(0.1307,), std(0.3081,)),vision.HWC2CHW()]label_transform transforms.TypeCast(mindspore.int32)dataset dataset.map(image_transforms, image)dataset dataset.map(label_transform, label)dataset dataset.batch(batch_size)return dataset首先数据变换Transforms1、对输入数据即图片2、对输出即标签 然后map对图像数据及标签进行变换处理 最后将处理好的数据集打包为大小为64的batch
train_dataset datapipe(train_dataset, 64)
test_dataset datapipe(test_dataset, 64)对数据集进行迭代访问 for data in test_dataset.create_dict_iterator():print(fShape of image [N, C, H, W]: {data[image].shape} {data[image].dtype})print(fShape of label: {data[label].shape} {data[label].dtype})break网络构建
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten nn.Flatten()self.dense_relu_sequential nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x self.flatten(x)logits self.dense_relu_sequential(x)return logitsmodel Network()
print(model)mindspore.nn类是构建所有网络的基类也是网络的基本单元。
自定义网络时可以继承nn.Cell类__init__包含所有网络层的定义construct类似前向传播包含数据Tensor的变换过程。
模型训练 定义损失函数、优化器 loss_fn nn.CrossEntropyLoss()
optimizer nn.SGD(model.trainable_params(), 1e-2)一个完整的训练过程step需要实现以下三步
1. 正向计算模型预测结果logits并与正确标签label求预测损失loss。 2. 反向传播利用自动微分机制自动求模型参数parameters对于loss的梯度gradients。 3. 参数优化将梯度更新到参数上。 定义正向计算函数。 def forward_fn(data, label):logits model(data)loss loss_fn(logits, label)return loss, logits使用value_and_grad通过函数变换获得梯度计算函数。 grad_fn mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_auxTrue)one-step training def train_step(data, label):(loss, _), grads grad_fn(data, label)optimizer(grads)return loss定义训练函数使用set_train设置为训练模式执行正向计算、反向传播和参数优化。 def train(model, dataset):size dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss train_step(data, label)if batch % 100 0:loss, current loss.asnumpy(), batchprint(floss: {loss:7f} [{current:3d}/{size:3d}])定义测试函数用来评估模型的性能。 def test(model, dataset, loss_fn):num_batches dataset.get_dataset_size()model.set_train(False)total, test_loss, correct 0, 0, 0for data, label in dataset.create_tuple_iterator():pred model(data)total len(data)test_loss loss_fn(pred, label).asnumpy()correct (pred.argmax(1) label).asnumpy().sum()test_loss / num_batchescorrect / totalprint(fTest: \n Accuracy: {(100*correct):0.1f}%, Avg loss: {test_loss:8f} \n)训练过程需多轮epoch训练数据集
epochs 3
for t in range(epochs):print(fEpoch {t1}\n-------------------------------)train(model, train_dataset)test(model, test_dataset, loss_fn)
print(Done!)保存模型
模型训练完成后需要保存其参数。
mindspore.save_checkpoint(model, model.ckpt)
print(Saved Model to model.ckpt)加载模型
加载保存的权重
# 1、重新实例化模型对象构造模型
model Network()
# 加载模型参数并将其加载至模型上。
param_dict mindspore.load_checkpoint(model.ckpt)
param_not_load, _ mindspore.load_param_into_net(model, param_dict)
print(param_not_load)param_not_load是未被加载的参数列表为空时代表所有参数均加载成功。
打卡-时间
from datetime import datetime
import pytz
# 设置时区为北京时区
beijing_tz pytz.timezone(Asia/shanghai)
# 获取当前时间并转为北京时间
current_beijing_time datetime.now(beijing_tz)
# 格式化时间输出
formatted_time current_beijing_time.strftime(%Y-%m-%d %H:%M:%S)
print(当前北京时间:,formatted_time,your name)