您的位置 首页 > 数码极客

如何调用pb模型预测一张图

内容导读

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。至此,已经可以实现Keras离线训练,Java在线预测的功能。

实现python离线训练模型,Java在线预测部署。查看原文

目前深度学习主流使用python训练自己的模型,有非常多的框架提供了能快速搭建神经网络的功能,其中Keras提供了high-level的语法,底层可以使用Tensorflow或者theano。

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving)

手头上有一个用Keras训练的模型,网上关于Java调用Keras模型的资料不是很多,而且大部分是重复的,并且也没有讲的很详细。大致有两种方案,一种是基于Java的深度学习库导入Keras模型实现,另外一种是用tensorflow提供的Java接口调用。

Deeplearning4J

Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.

Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。遗憾的是,Deeplearning4j现在只覆盖了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在导入模型的时候可能会报错。

了解更多:

Keras Model Import: Supported Features

Importing Models From Keras to Deeplearning4j

Tensorflow

文档,Java的文档很少,不过调用模型的过程也很简单。采用这种方式调用模型需要先将Keras导出的模型转成tensorflow的protobuf协议的模型。

1、Keras的h5模型转为pb模型

在Keras中使用model.save)保存当前模型为HDF5格式的文件中。

Keras的后端框架使用的是tensorflow,所以先把模型导出为pb模型。在Java中只需要调用模型进行预测,所以将当前的graph中的Variable全部变成Constant,并且使用训练后的weight。以下是freeze graph的代码:

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): """ :param session: 需要转换的tensorflow的session :param keep_var_names:需要保留的variable,默认全部转换constant :param output_names:output的名字 :param clear_devices:是否移除设备指令以获得更好的可移植性 :return: """ from import convert_variables_to_constants graph = with gra(): freeze_var_names = list(se for v in ()).difference(keep_var_names or [])) output_names = output_names or [] # 如果指定了output名字,则复制一个新的Tensor,并且以指定的名字命名 if len(output_names) > 0: for i in range(output_names): # 当前graph中复制一个新的Tensor,指定名字 [i], name=output_names[i]) output_names += [v.op.name for v in ()] input_graph_def = gra() if clear_devices: for node in in: node.device = "" frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names) return frozen_graph

该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。

在Keras中,模型是这么定义的:

def create_model(self): input_tensor = Input(shape=,), name="input") x = Embedding(len) + 1, 200)(input_tensor) x = Bidirectional(LSTM(128))(x) x = Dense(256, activation="relu")(x) x = Dropou)(x) x = Dense(len), activation='softmax', name="output_softmax")(x) model = Model(inputs=input_tensor, outputs=x) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。

prin) prin)

训练好Keras模型后,转换为pb模型:

from keras import backend as K import tensorflow as tf model.load_model("model.h5") prin) prin) # 自定义output_names frozen_graph = freeze_session(), output_names=["output"]) (frozen_graph, "./", "model.pb", as_text=False) ### 输出: # input # output_softmax/Softmax # 如果不自定义output_name,则生成的pb模型的output_name为output_softmax/Softmax,如果自定义则以自定义名为output_name

运行之后会生成model.pb的模型,这将是之后调用的模型。

2、Java调用

新建一个maven项目,pom里面导入tensorflow包:

<dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.6.0</version> </dependency>

核心代码:

public void predict() throws Exception { try (Graph graph = new Graph()) { gra( "path/to; ))); try (Session sess = new Session(graph)) { // 自己构造一个输入 float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; try (Tensor x = Ten(input); // input是输入的name,output是输出的name Tensor y = ().feed("input", x).fetch("output").run().get(0)) { float[][] result = new float[1][y.shape[1]]; y.copyTo(result); Sy())); Sy(result[0])); } } } }

Graph和Tensor对象都是需要通过close()方法显式地释放占用的资源,代码中使用了try-with-resources的方法实现的。

至此,已经可以实现Keras离线训练,Java在线预测的功能。

责任编辑: 鲁达

1.内容基于多重复合算法人工智能语言模型创作,旨在以深度学习研究为目的传播信息知识,内容观点与本网站无关,反馈举报请
2.仅供读者参考,本网站未对该内容进行证实,对其原创性、真实性、完整性、及时性不作任何保证;
3.本站属于非营利性站点无毒无广告,请读者放心使用!

“如何调用pb模型预测一张图”边界阅读