别人不能注册我的wordpress站,视频号视频怎么看下载链接,企业官网设计现状,事业单位 网络网站建设课程来源#xff1a;人工智能实践:Tensorflow笔记2 文章目录前言断点续训主要步骤参数提取主要步骤总结前言
本讲目标:断点续训#xff0c;存取最优模型#xff1b;保存可训练参数至文本 断点续训主要步骤
读取模型#xff1a; 先定义出存放模型的路径和文件名#xff0…课程来源人工智能实践:Tensorflow笔记2 文章目录前言断点续训主要步骤参数提取主要步骤总结前言
本讲目标:断点续训存取最优模型保存可训练参数至文本 断点续训主要步骤
读取模型 先定义出存放模型的路径和文件名命名为.ckpt文件。 生成ckpt文件的时候会同步生成索引表所以通过判断是否存在索引表来知晓是不是已经保存过模型参数。 如果有了索引表就利用load_weights函数读取已经保存的模型参数。 code: checkpoint_save_path ./checkpoint/fashion.ckpt
if os.path.exists(checkpoint_save_path .index):print(-------------load the model-----------------)model.load_weights(checkpoint_save_path)保存模型 保存模型参数可以使用TensorFlow给出的回调函数直接保存训练出来的模型参数 tf.keras.callbacks.ModelCheckpoint( filepath路径文件名(文件存储路径), save_weights_onlyTrue/False,(是否只保留参数模型) save_best_onlyTrue/False(是否只保留最优结果)) 执行训练过程中时加入callbacks选项: historymodel.fit(callbacks[cp_callback]) code:
cp_callback tf.keras.callbacks.ModelCheckpoint(filepathcheckpoint_save_path,save_weights_onlyTrue,save_best_onlyTrue)history model.fit(x_train, y_train, batch_size32, epochs5, validation_data(x_test, y_test), validation_freq1,callbacks[cp_callback])第一次运行 第二次运行可以发现模型并不是从初始训练而是在基于保存的模型开始训练的这一点可以从准确率和损失看出 全部代码
import tensorflow as tf
import osfashion tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) fashion.load_data()
x_train, x_test x_train / 255.0, x_test / 255.0model tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activationrelu),tf.keras.layers.Dense(10, activationsoftmax)
])model.compile(optimizeradam,losstf.keras.losses.SparseCategoricalCrossentropy(from_logitsFalse),metrics[sparse_categorical_accuracy])checkpoint_save_path ./checkpoint/fashion.ckpt
if os.path.exists(checkpoint_save_path .index):print(-------------load the model-----------------)model.load_weights(checkpoint_save_path)cp_callback tf.keras.callbacks.ModelCheckpoint(filepathcheckpoint_save_path,save_weights_onlyTrue,save_best_onlyTrue)history model.fit(x_train, y_train, batch_size32, epochs5, validation_data(x_test, y_test), validation_freq1,callbacks[cp_callback])
model.summary()
参数提取主要步骤
设置打印的格式使所有参数都打印出来 np.set_printoptions(thresholdnp.inf) print(model.trainable_variables) 将所有可训练参数存入文本
file open(./weights.txt, w)
for v in model.trainable_variables:file.write(str(v.name) \n)file.write(str(v.shape) \n)file.write(str(v.numpy()) \n)
file.close()
完整代码:
import tensorflow as tf
import os
import numpy as npnp.set_printoptions(thresholdnp.inf)fashion tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) fashion.load_data()
x_train, x_test x_train / 255.0, x_test / 255.0model tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activationrelu),tf.keras.layers.Dense(10, activationsoftmax)
])model.compile(optimizeradam,losstf.keras.losses.SparseCategoricalCrossentropy(from_logitsFalse),metrics[sparse_categorical_accuracy])checkpoint_save_path ./checkpoint/fashion.ckpt
if os.path.exists(checkpoint_save_path .index):print(-------------load the model-----------------)model.load_weights(checkpoint_save_path)cp_callback tf.keras.callbacks.ModelCheckpoint(filepathcheckpoint_save_path,save_weights_onlyTrue,save_best_onlyTrue)history model.fit(x_train, y_train, batch_size32, epochs5, validation_data(x_test, y_test), validation_freq1,callbacks[cp_callback])
model.summary()print(model.trainable_variables)
file open(./weights.txt, w)
for v in model.trainable_variables:file.write(str(v.name) \n)file.write(str(v.shape) \n)file.write(str(v.numpy()) \n)
file.close()
效果
总结 课程链接:MOOC人工智能实践TensorFlow笔记2