Error importing a trained TensorFlow SavedModel with the importTensorFlowNetwork function

I am trying to import a trained TensorFlow neural network model. The trained model was originally saved in checkpoint (ckpt) format. I converted the ckpt to SavedModel (.pb) format so that I could use it with the importTensorFlowNetwork function, but running the function produces the following error:
>> unzip('EXP1_pb_model.zip')
>> net = importTensorFlowNetwork('EXP1_pb_model');
Importing the saved model...
Unrecognized field name "object_graph_def".

Error in nnet.internal.cnn.tensorflow.savedmodel.TFSavedModel (line 32)
obj.KerasManager = savedmodel.TFKerasManager(smstruct.meta_graphs.object_graph_def, obj.SavedModelPath, kerasImporterOptions, importNetwork);

Error in nnet.internal.cnn.tensorflow.importTensorFlowNetwork (line 21)
sm = savedmodel.TFSavedModel(path, options, true);

Error in importTensorFlowNetwork (line 107)
Network = nnet.internal.cnn.tensorflow.importTensorFlowNetwork(modelFolder, varargin{:});
The Python code I used to convert the ckpt to a SavedModel is as follows:
import os
import tensorflow as tf

trained_checkpoint_prefix = '/exp_1/model.ckpt'
export_dir = os.path.join('model', 'EXP1_model')

graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)

    # Export checkpoint to SavedModel
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                         strip_default_attrs=True)
    builder.save()
Both the ckpt and pb models are attached.
I would appreciate any help in resolving this issue.
Thanks

Accepted Answer

Anshika Chaurasia on 14 Sep 2021
Hi Iman,
We currently support importing TensorFlow models that were saved using the Keras Sequential and Functional model APIs (https://keras.io/guides/sequential_model/ and https://keras.io/guides/functional_api/). The models in your case were not saved with either of these APIs.
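importTensorFlowNetwork looks for the object_graph_def field in the SavedModel's meta graph. That field is written when a model is saved through Keras (for example, model.save in TensorFlow 2), but not by the TF1-style SavedModelBuilder used in your conversion script. Because the field is missing from your model, MATLAB reports the "Unrecognized field name" error.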
Hope it helps!


Release: R2021a
