蜂蜜网站建设,天津小型网站建设,北京出名的室内设计公司,织梦更换网站模板前言
最近工作上需要在C上快速集成Tensorflow/Keras训练好的模型#xff0c;做算法验证。首先想到的就是opencv里面的dnn模块了#xff0c;但是它需要的格式文件比较郁闷#xff0c;是pb格式的模型#xff0c;但是keras通常保存的是h5文件#xff0c;查阅了很多资料…前言
最近工作上需要在C上快速集成Tensorflow/Keras训练好的模型做算法验证。首先想到的就是opencv里面的dnn模块了但是它需要的格式文件比较郁闷是pb格式的模型但是keras通常保存的是h5文件查阅了很多资料最后找到了很方便的方法。
国际惯例参考博客
Frozen_Graph_TensorFlow
这个地址的大佬用fashion mnist写的训练和测试我这里用更简单的线性回归为例。
训练
老样子引入相关的包创建数据集
import numpy as np
import tensorflow as tf# build data
input_x np.random.rand(1000, 4)
output_y np.dot(input_x,np.array([[3],[14],[6],[10]]))然后创建简单的模型
# build model
inputs tf.keras.layers.Input(shape(4,));
outputs tf.keras.layers.Dense(units1)(inputs)
model tf.keras.Model(inputsinputs,outputsoutputs)编译、训练
model.compile(optimizertf.keras.optimizers.SGD(learning_rate0.01),lossmse)
model.fit(xinput_x,youtput_y,epochs100,validation_split0.1)看看训练完毕的权重是否接近我们创建的权重
model.get_weights()[array([[ 2.669988],[13.562156],[ 5.622848],[ 9.587276]], dtypefloat32),array([0.80391794], dtypefloat32)]保存
保存的时候就需要注意了要保存成pb格式的不要直接model.save成h5格式的如果你已经保存完了可以model.load_model载入进来然后再执行如下函数
# Convert Keras model to ConcreteFunction
full_model tf.function(lambda x: model(x))
full_model full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))# Get frozen ConcreteFunction
frozen_func convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()layers [op.name for op in frozen_func.graph.get_operations()]
print(- * 50)
print(Frozen model layers: )
for layer in layers:print(layer)print(- * 50)
print(Frozen model inputs: )
print(frozen_func.inputs)
print(Frozen model outputs: )
print(frozen_func.outputs)# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_deffrozen_func.graph,logdir./frozen_models,namefrozen_graph.pb,as_textFalse)--------------------------------------------------
Frozen model layers:
x
model/dense/MatMul/ReadVariableOp/resource
model/dense/MatMul/ReadVariableOp
model/dense/MatMul
model/dense/BiasAdd/ReadVariableOp/resource
model/dense/BiasAdd/ReadVariableOp
model/dense/BiasAdd
Identity
--------------------------------------------------
Frozen model inputs:
[tf.Tensor x:0 shape(None, 4) dtypefloat32]
Frozen model outputs:
[tf.Tensor Identity:0 shape(None, 1) dtypefloat32]这样就会在./frozen_models/frozen_graph.pb目录下看到模型了。
使用TensorFlow调用
在参考博客中作者也提供了对应的调用方法
def wrap_frozen_graph(graph_def, inputs, outputs, print_graphFalse):def _imports_graph_def():tf.compat.v1.import_graph_def(graph_def, name)wrapped_import tf.compat.v1.wrap_function(_imports_graph_def, [])import_graph wrapped_import.graphprint(- * 50)print(Frozen model layers: )layers [op.name for op in import_graph.get_operations()]if print_graph True:for layer in layers:print(layer)print(- * 50)return wrapped_import.prune(tf.nest.map_structure(import_graph.as_graph_element, inputs),tf.nest.map_structure(import_graph.as_graph_element, outputs))通过tensorflow1.x载入模型
with tf.io.gfile.GFile(./frozen_models/frozen_graph.pb, rb) as f:graph_def tf.compat.v1.GraphDef()loaded graph_def.ParseFromString(f.read())随后生成预测函数
frozen_func wrap_frozen_graph(graph_defgraph_def,inputs[x:0],outputs[Identity:0],print_graphTrue)调用这个预测函数试验一下
test_x np.array([[1,1,1,1]],np.float32)
pred_y frozen_func(xtf.constant(test_x))[0]
print(pred_y)#tf.Tensor([[32.246185]], shape(1, 1), dtypefloat32)
true_y np.dot(test_x,np.array([[3],[14],[6],[10]]))
print(true_y) #[[33.]]使用OpenCV-python调用模型
这个非常简单了四句话完成创建测试数据、读取模型、载入数据、预测
test_x np.array([[1,1,1,1]],np.float32)
net cv2.dnn.readNetFromTensorflow(./frozen_models/frozen_graph.pb)
net.setInput(test_x)
pred net.forward()
print(pred)#[[32.246185]]与上面用tensorflow调用的结果一模一样。
使用opencv-C调用模型
核心就在于opencv的readNetFromTensorflow接受的输入固定是InputArray类型的这个类型等价于Mat、vector等具体可以上网查。这里我们需要创建一个大小为(1,4)的Mat数据
下面这个函数就是使用数组去初始化Mat数据
void InitMat(Mat m,float* num)
{for(int i0;im.rows;i)for(int j0;jm.cols;j)m.atfloat(i,j)*(numi*m.rowsj);
}创建一个Mat数据
float sz[] {1,1,1,1};
Mat input(1,4,CV_32F);
InitMat(input, sz);接下来就是读取模型、载入数据、预测
dnn::Net net dnn::readNetFromTensorflow(./frozen_models/frozen_graph.pb);
net.setInput(input);
Mat pred net.forward();
coutpredendl; //[32.246185]后记
上面基本贴出了所有的代码如果运行不了可以进入公众号在公众号简介的github中找到源码。