Keras入门(二)模型的保存、读取及加载
off999 2024-11-22 19:03 27 浏览 0 评论
本文将会介绍如何利用Keras来实现模型的保存、读取以及加载。
??本文使用的模型为解决IRIS数据集的多分类问题而设计的深度神经网络(DNN)模型,模型的结构示意图如下:
具体的模型参数可以参考文章:Keras入门(一)搭建深度神经网络(DNN)解决多分类问题。
模型保存
??Keras使用HDF5文件系统来保存模型。模型保存的方法很容易,只需要使用save()方法即可。
??以Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的DNN模型为例,整个模型的变量为model,我们设置模型共训练10次,在原先的代码中加入Python代码即可保存模型:
# save model
print("Saving model to disk \n")
mp = "E://logs/iris_model.h5"
model.save(mp)
保存的模型文件(iris_model.h5)如下:
模型读取
??保存后的iris_model.h5以HDF5文件系统的形式储存,在我们使用Python读取h5文件里面的数据之前,我们先用HDF5的可视化工具HDFView来查看里面的数据:
??我们感兴趣的是这个模型中的各个神经层之间的连接权重及偏重,也就是上图中的红色部分,model_weights里面包含了各个神经层之间的连接权重及偏重,分别位于dense_1,dense_2,dense_3中。蓝色部分为dense_3/dense_3/kernel:0的数据,即最后输出层的连接权重矩阵。
??有了对模型参数的直观认识,我们要做的下一步工作就是读取各个神经层之间的连接权重及偏重。我们使用Python的h5py这个模块来这个iris_model.h5这个文件。关于h5py的快速入门指南,可以参考文章:h5py快速入门指南(https://www.jianshu.com/p/a6328c4f4986)。
??使用以下Python代码可以读取各个神经层之间的连接权重及偏重数据:
import h5py
# 模型地址
MODEL_PATH = 'E://logs/iris_model.h5'
# 获取每一层的连接权重及偏重
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
dense_1 = f['/model_weights/dense_1/dense_1']
dense_1_bias = dense_1['bias:0'][:]
dense_1_kernel = dense_1['kernel:0'][:]
dense_2 = f['/model_weights/dense_2/dense_2']
dense_2_bias = dense_2['bias:0'][:]
dense_2_kernel = dense_2['kernel:0'][:]
dense_3 = f['/model_weights/dense_3/dense_3']
dense_3_bias = dense_3['bias:0'][:]
dense_3_kernel = dense_3['kernel:0'][:]
print("第一层的连接权重矩阵:\n%s\n"%dense_1_kernel)
print("第一层的连接偏重矩阵:\n%s\n"%dense_1_bias)
print("第二层的连接权重矩阵:\n%s\n"%dense_2_kernel)
print("第二层的连接偏重矩阵:\n%s\n"%dense_2_bias)
print("第三层的连接权重矩阵:\n%s\n"%dense_3_kernel)
print("第三层的连接偏重矩阵:\n%s\n"%dense_3_bias)
输出的结果如下:
读取模型中...
第一层的连接权重矩阵:
[[ 0.04141677 0.03080632 -0.02768146 0.14334357 0.06242227]
[-0.41209617 -0.77948487 0.5648218 -0.699587 -0.19246106]
[ 0.6856315 0.28241938 -0.91930366 -0.07989818 0.47165248]
[ 0.8655262 0.72175753 0.36529952 -0.53172135 0.26573092]]
第一层的连接偏重矩阵:
[-0.16441862 -0.02462054 -0.14060321 0. -0.14293939]
第二层的连接权重矩阵:
[[ 0.39296603 0.01864707 0.12538083 0.07935872 0.27940807 -0.4565802 ]
[-0.34312084 0.6446907 -0.92546445 -0.00538039 0.95466876 -0.32819661]
[-0.7593299 -0.07227057 0.20751365 0.40547106 0.35726753 0.8884158 ]
[-0.48096 0.11294878 -0.29462305 -0.410536 -0.23620337 -0.72703975]
[ 0.7666149 -0.41720924 0.29576775 -0.6328017 0.43118536 0.6589351 ]]
第二层的连接偏重矩阵:
[-0.1899569 0. -0.09710662 -0.12964155 -0.26443157 0.6050924 ]
第三层的连接权重矩阵:
[[-0.44450542 0.09977101 0.12196152]
[ 0.14334357 0.18546402 -0.23861367]
[-0.7284191 0.7859063 -0.878823 ]
[ 0.0876545 0.51531947 0.09671918]
[-0.7964963 -0.16435687 0.49531657]
[ 0.8645698 0.4439873 0.24599855]]
第三层的连接偏重矩阵:
[ 0.39192322 -0.1266532 -0.29631865]
值得注意的是,我们得到的这些矩阵的数据类型都是numpy.ndarray。
??OK,既然我们已经得到了各个神经层之间的连接权重及偏重的数据,那我们能做什么呢?当然是去做一些有趣的事啦,那就是用我们自己的方法来实现新数据的预测向量(softmax函数作用后的向量)。so, really?
??新的输入向量为[6.1, 3.1, 5.1, 1.1],使用以下Python代码即可输出新数据的预测向量:
import h5py
import numpy as np
# 模型地址
MODEL_PATH = 'E://logs/iris_model.h5'
# 获取每一层的连接权重及偏重
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
dense_1 = f['/model_weights/dense_1/dense_1']
dense_1_bias = dense_1['bias:0'][:]
dense_1_kernel = dense_1['kernel:0'][:]
dense_2 = f['/model_weights/dense_2/dense_2']
dense_2_bias = dense_2['bias:0'][:]
dense_2_kernel = dense_2['kernel:0'][:]
dense_3 = f['/model_weights/dense_3/dense_3']
dense_3_bias = dense_3['bias:0'][:]
dense_3_kernel = dense_3['kernel:0'][:]
# 模拟每个神经层的计算,得到该层的输出
def layer_output(input, kernel, bias):
return np.dot(input, kernel) + bias
# 实现ReLU函数
relu = np.vectorize(lambda x: x if x >=0 else 0)
# 实现softmax函数
def softmax_func(arr):
exp_arr = np.exp(arr)
arr_sum = np.sum(exp_arr)
softmax_arr = exp_arr/arr_sum
return softmax_arr
# 输入向量
unkown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
# 第一层的输出
print("模型计算中...")
output_1 = layer_output(unkown, dense_1_kernel, dense_1_bias)
output_1 = relu(output_1)
# 第二层的输出
output_2 = layer_output(output_1, dense_2_kernel, dense_2_bias)
output_2 = relu(output_2)
# 第三层的输出
output_3 = layer_output(output_2, dense_3_kernel, dense_3_bias)
output_3 = softmax_func(output_3)
# 最终的输出的softmax值
np.set_printoptions(precision=4)
print("最终的预测值向量为: %s"%output_3)
其输出的结果如下:
读取模型中...
模型计算中...
最终的预测值向量为: [[0.0242 0.6763 0.2995]]
??额,这个输出的预测值向量会是我们的DNN模型的预测值向量吗?这时候,我们就需要回过头来看看Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的代码了,注意,为了保证数值的可比较性,笔者已经将DNN模型的训练次数改为10次了。让我们来看看原来代码的输出结果吧:
Using model to predict species for features:
[[6.1 3.1 5.1 1.1]]
Predicted softmax vector is:
[[0.0242 0.6763 0.2995]]
Predicted species is:
Iris-versicolor
Yes,两者的预测值向量完全一致!因此,我们用自己的方法也实现了这个DNN模型的预测功能,棒!
模型加载
??当然,在实际的使用中,我们不需要再用自己的方法来实现模型的预测功能,只需使用Keras给我们提供好的模型导入功能(keras.models.load_model())即可。使用以下Python代码即可加载模型
# 模型的加载及使用
from keras.models import load_model
print("Using loaded model to predict...")
load_model = load_model("E://logs/iris_model.h5")
np.set_printoptions(precision=4)
unknown = np.array([[6.1, 3.1, 5.1, 1.1]], dtype=np.float32)
predicted = load_model.predict(unknown)
print("Using model to predict species for features: ")
print(unknown)
print("\nPredicted softmax vector is: ")
print(predicted)
species_dict = {v: k for k, v in Class_dict.items()}
print("\nPredicted species is: ")
print(species_dict[np.argmax(predicted)])
输出结果如下:
Using loaded model to predict...
Using model to predict species for features:
[[6.1 3.1 5.1 1.1]]
Predicted softmax vector is:
[[0.0242 0.6763 0.2995]]
Predicted species is:
Iris-versicolor
总结
??本文主要介绍如何利用Keras来实现模型的保存、读取以及加载。
??本文将不再给出完整的Python代码,如需完整的代码,请参考Github地址:https://github.com/percent4/Keras_4_multiclass.
注意:本人现已开通微信公众号: NLP奇幻之旅(微信号为:easy_web_scrape), 欢迎大家关注哦~~
相关推荐
- 笔记本电脑选哪个品牌比较好
-
1、苹果APPLE/美国2、戴尔DELL/美国3、华为HUAWEI/中国4、小米MI/中国5、微软Microsoft/美国6、联想LENOVO/中国7、惠普HP/美国8、华硕ASUS/...
- 10系列显卡排名(10系显卡性能排行)
-
十系显卡指NVIDIAGeForce10系列,是英伟达研发并推出的图形处理器系列,被用以取代NVIDIAGeForce900系列图形处理器。新系列采用帕斯卡微架构来代替之前的麦克斯韦微架构,并...
-
- 最新win7系统下载(windows7最新版本下载)
-
最简单的方法就是,下载完镜像文件后,直接把镜像文件解压,解压到非C盘,然后在解压文件里面找到setup.exe,点击运行即可。安装系统完成后,在C盘找到一个Windows.old(好几个GB,是旧系统打包在这里,垃圾文件了)删除即可。扩展资...
-
2026-01-15 06:43 off999
- 哪个电脑管家软件好用(哪个电脑管家好用些)
-
腾讯电脑管家吧,因为这个是杀毒和管理合一的,占用内存小,因此显得更为简洁,使电脑运行更加流畅此外电脑诊所,工具箱以及4+1的杀毒模式让腾讯电脑管家也收到了广泛的关注4+1杀毒引擎,管家反病毒引擎、金山...
- 怎么进入win7安全模式(怎么进入win7安全模式界面)
-
方法如下:1、首先进入Win7系统,然后使用Win键+R组合键打开运行框,输入“Msconfig”回车进入系统配置。2、在打开的系统配置中,找到“引导”选项,然后单击,选择Win7的引导项,然后在“安...
- 怎么分区固态硬盘(怎样分区固态硬盘)
-
固态硬盘的分区方法与传统机械硬盘基本相同,以下是一个简单的步骤:1.打开磁盘管理工具:在Windows操作系统中,按下Win+X键,选择"磁盘管理"。或者打开控制面板,在"...
-
- 笔记本声卡驱动怎么下载(笔记本如何下载声卡)
-
1、在浏览器中输入并搜索,然后下载并安装。2、安装完成后打开360驱动大师,它就会自动检测你的电脑需要安装或升级的驱动。3、检测完毕后,我们可以看到我们的声卡驱动需要安装或升级,点击安装或升级,就会开始自动安装或升级声卡了。4、升级过程中会...
-
2026-01-15 05:43 off999
- win10加快开机启动速度(加快开机速度 win10)
-
一、启用快速启动功能1.按win+r键调出“运行”在输入框输入“gpedit.msc”按回车调出“组策略编辑器”?2.在“本地组策略编辑器”依次打开“计算机配置——管理模块——系统——关机”在右侧...
-
- excel的快捷键一览表(excel的快捷键一览表超全)
-
Excel快捷键大全的一些操作如下我在工作中经常使用诸如word或Excel之类的办公软件。我相信每个人都不太熟悉这些办公软件的快捷键。使用快捷键将提高办公效率,并使您的工作更加轻松快捷。。例如,在复制时,请使用CtrI+C进行复制,...
-
2026-01-15 05:03 off999
- 华硕u盘启动按f几(华硕u盘装系统按f几进入)
-
F8。1、开机的同时按F8进入BIOS。2、在Boot菜单中,置secure为disabled。3、BootListOption置为UEFI。4、在1stBootPriority中usb—HD...
- 手机云电脑怎么用(手机云端电脑)
-
使用手机云电脑,您首先需要安装相应的云电脑应用。例如,华为云电脑APP。在安装并打开应用后,您将看到一个显示器的图标,这就是您的云电脑。点击这个图标,您将被连接到一个预装有Windows操作系统和必要...
- ie11浏览器怎么安装(ie11浏览器安装步骤)
-
如果IE浏览器11版本你发现无法正常安装,那么很可能是这样几个原因,一个就是电脑的存储空间不够到时无法安装,再有就是网络的问题,如果没有办法安装的话就不要再安装了,本身这个IE浏览器并不是多好用,你最...
- 台式机重装系统win7(台式机怎么重装win7)
-
下面主要介绍两种方法以重装系统:一、U盘重装系统准备:一台正常开机的电脑和一个U盘1、百度下载“U大师”(老毛桃、大白菜也可以),把这个软件下载并安装在电脑上。2、插上U盘,选择一键制作U盘启动(制作...
- 字母下划线怎么打出来(字母下的下划线怎么去不掉)
-
第一步,在电脑上找到文字处理软件WPS,双击即自动新建一个新文档。第二步,在文档录入需要处理的字母和数字,双击鼠标或拖动鼠标选择要处理的内容。第三步,在页面的左上方的横向菜单栏,找到字母U的按纽,点击...
欢迎 你 发表评论:
- 一周热门
-
-
抖音上好看的小姐姐,Python给你都下载了
-
全网最简单易懂!495页Python漫画教程,高清PDF版免费下载
-
飞牛NAS部署TVGate Docker项目,实现内网一键转发、代理、jx
-
Python 3.14 的 UUIDv6/v7/v8 上新,别再用 uuid4 () 啦!
-
python入门到脱坑 输入与输出—str()函数
-
Python三目运算基础与进阶_python三目运算符判断三个变量
-
(新版)Python 分布式爬虫与 JS 逆向进阶实战吾爱分享
-
失业程序员复习python笔记——条件与循环
-
系统u盘安装(win11系统u盘安装)
-
Python 批量卸载关联包 pip-autoremove
-
- 最近发表
- 标签列表
-
- python计时 (73)
- python安装路径 (56)
- python类型转换 (93)
- python进度条 (67)
- python吧 (67)
- python的for循环 (65)
- python格式化字符串 (61)
- python静态方法 (57)
- python列表切片 (59)
- python面向对象编程 (60)
- python 代码加密 (65)
- python串口编程 (77)
- python封装 (57)
- python写入txt (66)
- python读取文件夹下所有文件 (59)
- python操作mysql数据库 (66)
- python获取列表的长度 (64)
- python接口 (63)
- python调用函数 (57)
- python多态 (60)
- python匿名函数 (59)
- python打印九九乘法表 (65)
- python赋值 (62)
- python异常 (69)
- python元祖 (57)
