完善资料让更多小伙伴认识你,还能领取20积分哦, 立即完善>
keras可视化可以帮助我们直观的查看所搭建的模型拓扑结构,以及模型的训练的过程,方便我们优化模型。
模型可视化又分为模型拓扑结构可视化以及训练过程可视化。 以上一讲的mnist为例,演示不同可视化方法: 1 Netron 查看h5模型 参考《TFlite之格式解析》 Netron部分,Netron 是一款常见的可视化工具,支持网页查看常见的AI模型,支持非常丰富的格式(ONNX, Tensorflow, Pytorch, Keras, Caffe等),网页地址: https://netron.app/ 将上一讲生成的keras_mnist.h5导入,得到模型结构,如下图: 2 keras的model.summary()方法 对于一些简单的模型,可以直接使用keras提供的model.summary()方法,如上一讲的mnist模型,代码中: # 搭建好模型后,加上这一句 print("model:") model.summary()输出模型如下: model: Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten (Flatten) (None, 784) 0 dense (Dense) (None, 784) 615440 dense_1 (Dense) (None, 10) 7850 ================================================================= Total params: 623,290 Trainable params: 623,290 Non-trainable params: 0 _________________________________________________________________可见模型有一个flatten层,两个全连接层。 3 keras的graphviz功能 keras.utils.vis_utils模块提供了画出Keras模型的函数(利用graphviz) import tensorflow as tf import tensorflow.keras as keras # 搭建好模型后,加上下一句 keras.utils.plot_model(model, to_file='model.svg', show_shapes=True)注意:我这里按照官网生成model.png会失败,寻找网上的解决方案也无法解决,只好换用svg格式,svg格式可以通过浏览器直接打开,但是可能显示不全,可以修改model.svg的参数解决,参照参考2: # 要调节width、height参数,以及viewBox参数 <svg width="825pt" height="825pt" viewBox="0.00 0.00 825.00 825.00" # viewBox="x, y, width, height"浏览器解析,模型如下: plot_model函数定义: tf.keras.utils.plot_model( model, # keras model句柄 to_file='model.png', # 保存文件名及格式(这里使用svg格式) show_shapes=False, # 是否显示形状信息,默认不显示 show_layer_names=True, # 显示layer名 rankdir='TB', # 横向显示(LR), 纵向显示(TB) expand_nested=False, # 是否将嵌套模型扩展到聚类中 dpi=96 )4 训练历史可视化 Keras Model 上的 fit() 方法返回一个 History 对象。History.history 是一个记录了连续迭代的训练/验证损失值和评估值的字典。可以通过matplotlib将数据展示出来(这里就不使用matplotlib画图了,将数据打印出来): # 截取部分代码如下: import tensorflow as tf import tensorflow.keras as keras # step4: train history = model.fit(x_train, y_train, batch_size=64, epochs=5) # 打印history字典中的keys值 print(history.history.keys()) # 获取验证准确率数据 print(history.history['accuracy']) # 获取训练时的损失值 print(history.history['loss'])执行结果为: dict_keys(['loss', 'accuracy']) [0.9143333435058594, 0.9519500136375427, 0.9597333073616028, 0.9615499973297119, 0.9619333148002625] [3.5814223289489746, 0.38597372174263, 0.26168128848075867, 0.23524414002895355, 0.24421487748622894]5 训练过程的可视化:keras + Tensorboard Tensorboard提供训练过程可视化的功能,是通过keras的回调函数来实现的。 # 截取部分代码如下: import tensorflow as tf import tensorflow.keras as keras from keras.callbacks import TensorBoard tbCallBack = TensorBoard() # 默认日志放到./logs 文件夹下 history = model.fit(x_train, y_train, batch_size=64, epochs=5, callbacks=[tbCallBack]) 通过shell下执行如下命令,然后使用浏览器可打开tensorboard面板: tensorboard --logdir /path/to/logsTensorboard 功能强大,对tensorflow、keras、还是pytorch都提供良好支持,这里先不做展开,可以参考3。 |
|
相关推荐 |
|
只有小组成员才能发言,加入小组>>
在软件SDK中选择不同的下载模式时,是哪个部件更改了QSPI0中寄存器的值?
419 浏览 2 评论
cmt_instret_ena的使能为什么要排除branch等指令造成流水线冲刷的情况?
562 浏览 1 评论
e203 rom启动仅仅是引导到itcm执行指令吗?flash启动就是加载指令到itcm中吗?
593 浏览 1 评论
小黑屋| 手机版| Archiver| 电子发烧友 ( 湘ICP备2023018690号 )
GMT+8, 2025-1-11 03:11 , Processed in 0.634175 second(s), Total 62, Slave 48 queries .
Powered by 电子发烧友网
© 2015 bbs.elecfans.com
关注我们的微信
下载发烧友APP
电子发烧友观察
版权所有 © 湖南华秋数字科技有限公司
电子发烧友 (电路图) 湘公网安备 43011202000918 号 电信与信息服务业务经营许可证:合字B2-20210191 工商网监 湘ICP备2023018690号