在TensorFlow中,保存模型的方法有以下几种:
- 使用
tf.keras.models.save_model()
函数保存整个模型,包括模型结构、模型权重和优化器状态等信息,可以通过tf.keras.models.load_model()
函数载入模型。
model.save('model.h5') loaded_model = tf.keras.models.load_model('model.h5')
- 使用
tf.saved_model.save()
函数保存模型为SavedModel格式,包括模型结构、权重和计算图等信息,可以通过tf.saved_model.load()
函数载入模型。
tf.saved_model.save(model, 'saved_model') loaded_model = tf.saved_model.load('saved_model')
- 使用
tf.train.Checkpoint
类保存模型的权重和优化器状态,可以通过restore()
方法恢复模型。
checkpoint = tf.train.Checkpoint(model=model) checkpoint.save('model_checkpoint') checkpoint.restore('model_checkpoint')
- 使用
tf.train.Saver
类保存和恢复模型的变量。
saver = tf.train.Saver() saver.save(sess, 'model.ckpt') saver.restore(sess, 'model.ckpt')
- 使用
tf.io.write_graph()
和tf.train.write_graph()
函数将模型导出为GraphDef格式或PB格式。
tf.io.write_graph(sess.graph_def, './', 'model.pb', as_text=False) tf.train.write_graph(sess.graph_def, './', 'model.pbtxt')