Model Export & Prediction

1 Model Export

Run the following Python code to export a model. If checkpoint_path is not passed, the latest checkpoint under the model_dir set in the config at pipeline_config_path is used by default.

import easy_vision
easy_vision.export(export_dir, pipeline_config_path, checkpoint_path)

2 Input and Output Specification

To run prediction with an exported saved_model, you need to know its input and output tensors.

2.1 Input Placeholders

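| name | Description | shape | type |
|---|---|---|---|
| image | Batched image tensor, channels in RGB order | [batch_size, None, None, 3] | tf.uint8 |
| true_image_shape | True shape of each image, with the last dimension ordered [height, width, channel], e.g. [[224, 224, 3], [448, 448, 3]] | [batch_size, 3] | tf.int32 |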

Note: batch_size is the batch_size set in export_config of the pipeline config used at export time. Setting it to -1 gives a dynamic batch size; currently only classification models support a dynamic batch size.
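The two placeholders work together: image holds every image of the batch in one tensor, and true_image_shape records the real [height, width, channel] of each one. Below is a minimal NumPy sketch of building both from images of different sizes. It assumes zero padding at the bottom and right up to the largest height and width; that padding scheme and the helper make_batch are hypothetical, and predictor.batch from section 3 is the library's own way to do this. With a fixed (non -1) batch_size, the number of images presumably has to match the exported value.

```python
import numpy as np

def make_batch(images):
    """Pad a list of HxWx3 uint8 RGB images into one batch (hypothetical helper).

    Assumption: each image is zero-padded at the bottom/right to the largest
    height and width in the list; the exported model may expect another scheme.
    """
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    batch = np.zeros((len(images), max_h, max_w, 3), dtype=np.uint8)
    true_shapes = np.zeros((len(images), 3), dtype=np.int32)
    for i, img in enumerate(images):
        h, w, c = img.shape
        batch[i, :h, :w, :] = img   # copy the real pixels into the padded slot
        true_shapes[i] = [h, w, c]  # [height, width, channel]
    return batch, true_shapes

images = [np.zeros((224, 224, 3), np.uint8), np.zeros((448, 448, 3), np.uint8)]
batch, true_image_shape = make_batch(images)
print(batch.shape)                # (2, 448, 448, 3)
print(true_image_shape.tolist())  # [[224, 224, 3], [448, 448, 3]]
```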

2.2 Output Tensors

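The output is a list of JSON results, one per input image, so the list has the same length as the number of input images. Examples and field descriptions of the JSON result for each model type follow.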

feature_extractor

Example result:

{"feature": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4583122730255127, 0.0]}

Fields:

| name | Meaning | shape | type |
|---|---|---|---|
| feature | Output feature vector | [feature_dim] | float |
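A feature vector is usually compared with other feature vectors, for example for image retrieval. A minimal pure-Python sketch using cosine similarity, one common choice; the metric and the second, made-up vector result_b are illustrative assumptions, not part of EasyVision.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two "feature" vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # convention chosen here for all-zero vectors
    return dot / (norm_a * norm_b)

result_a = {"feature": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4583122730255127, 0.0]}
result_b = {"feature": [0.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.4, 0.0]}  # made-up example
print(cosine_similarity(result_a["feature"], result_b["feature"]))  # approximately 0.8
```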

classifier

Example result:

{
"class": 3, 
"class_name": "coho4",
"class_probs": {"coho1": 4.028851974258174e-10, 
          "coho2": 0.48115724325180054, 
          "coho3": 5.116515922054532e-07, 
          "coho4": 0.5188422446937221}
}
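
Fields:

| name | Meaning | shape | type |
|---|---|---|---|
| class | Class id | [] | int32 |
| class_name | Class name | [] | string |
| class_probs | Probabilities of all classes | [num_classes] | dict{key: string, value: float} |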
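In the example, class_name is the entry of class_probs with the highest probability. A small sketch that reads the top-k classes from class_probs, for example to show runner-up labels; top_k is a hypothetical helper, not an EasyVision function.

```python
result = {
    "class": 3,
    "class_name": "coho4",
    "class_probs": {"coho1": 4.028851974258174e-10,
                    "coho2": 0.48115724325180054,
                    "coho3": 5.116515922054532e-07,
                    "coho4": 0.5188422446937221},
}

def top_k(class_probs, k):
    """Return the k (class_name, probability) pairs with the highest probability."""
    return sorted(class_probs.items(), key=lambda kv: kv[1], reverse=True)[:k]

best_name, best_prob = top_k(result["class_probs"], 1)[0]
print(best_name == result["class_name"])  # True
print(top_k(result["class_probs"], 2))    # [('coho4', 0.5188422446937221), ('coho2', 0.48115724325180054)]
```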

multilabel_classifier

Example result:

{
"class": [2, 3],
"class_names": ["coho3", "coho4"],
"class_probs": {"coho1": 4.028851974258174e-10, 
          "coho2": 0.10115724325180054, 
          "coho3": 0.6188422446937221, 
          "coho4": 0.5188422446937221}
}
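
Fields:

| name | Meaning | shape | type |
|---|---|---|---|
| class | Class ids | [None] | int32 |
| class_names | Class names | [None] | string |
| class_probs | Probabilities of all classes | [num_classes] | dict{key: string, value: float} |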
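The example is consistent with selecting every class whose probability reaches a threshold such as 0.5. A sketch of that rule, assuming a fixed threshold of 0.5; the threshold the exported model actually uses is not documented here, and select_labels is a hypothetical helper.

```python
class_probs = {"coho1": 4.028851974258174e-10,
               "coho2": 0.10115724325180054,
               "coho3": 0.6188422446937221,
               "coho4": 0.5188422446937221}

THRESHOLD = 0.5  # assumption: not taken from the exported model

def select_labels(class_probs, threshold):
    """Return the names of all classes whose probability is at least threshold, in dict order."""
    return [name for name, prob in class_probs.items() if prob >= threshold]

print(select_labels(class_probs, THRESHOLD))  # ['coho3', 'coho4']
```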

detector

Note: instance segmentation is also supported.

Example result:

{
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_classes": [1, 1],
  "detection_class_names": ["text", "text"]
}

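Fields:

| name | Meaning | shape | type | Notes |
|---|---|---|---|---|
| detection_boxes | Detected object boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
| detection_scores | Detection score (probability) of each object | [num_detections] | float | |
| detection_classes | Class id of each detected region | [num_detections] | int | |
| detection_class_names | Class name of each detected region | [num_detections] | string | |
| detection_masks | Segmentation mask of each detected region | [num_detections, image_height, image_width] | float | Optional |
| detection_keypoints | Keypoints within each detected region | [num_detections, num_keypoints, 2] | float | Optional |
| detection_roi_features | Local feature map of each detected region | [num_detections, roi_height, roi_width, channels] | float | Optional |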
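detection_boxes uses [top, left, bottom, right], while many drawing and evaluation tools expect [x_min, y_min, width, height]. A sketch that converts the boxes and drops low-score detections. The example values look like absolute pixel coordinates, but that is an assumption here; the min_score parameter and both helpers are hypothetical.

```python
result = {
    "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797],
                        [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
    "detection_scores": [0.9942291975021362, 0.9940272569656372],
    "detection_classes": [1, 1],
}

def to_xywh(box):
    """Convert [top, left, bottom, right] to [x_min, y_min, width, height]."""
    top, left, bottom, right = box
    return [left, top, right - left, bottom - top]

def filter_detections(result, min_score):
    """Keep (class_id, score, xywh_box) triples whose score is at least min_score."""
    kept = []
    for box, score, cls in zip(result["detection_boxes"],
                               result["detection_scores"],
                               result["detection_classes"]):
        if score >= min_score:
            kept.append((cls, score, to_xywh(box)))
    return kept

for cls, score, (x, y, w, h) in filter_detections(result, min_score=0.5):
    print(cls, round(score, 3), round(x), round(y), round(w), round(h))
```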

detector_with_rpn

Example result:

{
  "proposal_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797]],
  "proposal_scores": [0.88, 0.56],
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_classes": [1, 1],
  "detection_class_names": ["text", "text"]
}

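Fields:

| name | Meaning | shape | type | Notes |
|---|---|---|---|---|
| proposal_boxes | Proposal boxes | [num_proposals, 4] | float | |
| proposal_scores | Proposal scores | [num_proposals] | float | |
| detection_boxes | Detected object boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
| detection_scores | Detection score (probability) of each object | [num_detections] | float | |
| detection_classes | Class id of each detected region | [num_detections] | int | |
| detection_class_names | Class name of each detected region | [num_detections] | string | |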

segmentor

Example result:

{
  "probs" : [[[0.8, 0.2], [0.3, 0.7]], [[0.6, 0.4], [0.1, 0.9]]],
  "preds" : [[0, 1], [0, 1]]
}

Fields:

| name | Meaning | shape | type |
|---|---|---|---|
| probs | Per-pixel class probabilities | [output_height, output_width, num_classes] | float |
| preds | Per-pixel class id | [output_height, output_width] | int |
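Given probs of shape [output_height, output_width, num_classes], a class map of shape [output_height, output_width] follows from an arg-max over the last axis. That preds is computed exactly this way is an assumption of this sketch, as is the per-class pixel count shown as a follow-up; note that the output size may differ from the input image size.

```python
import numpy as np

# probs: [output_height, output_width, num_classes] (illustrative values)
probs = np.array([[[0.8, 0.2], [0.3, 0.7]],
                  [[0.6, 0.4], [0.1, 0.9]]], dtype=np.float32)

# Assumption: preds is the per-pixel arg-max over the class axis.
preds = np.argmax(probs, axis=-1)
print(preds.tolist())  # [[0, 1], [0, 1]]

# Number of pixels assigned to each class id.
counts = np.bincount(preds.ravel(), minlength=probs.shape[-1])
print(counts.tolist())  # [2, 2]
```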

text_detector

Example result:

{
  "detection_keypoints": [[[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594], [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]], [[292.2718200683594, 114.44700622558594], [292.2237243652344, 164.684814453125], [571.1962890625, 164.931640625], [571.2444458007812, 114.67433166503906]]],
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_classes": [1, 1],
  "detection_class_names": ["text", "text"],
  "image_shape": [1024, 968, 3]
}

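Fields:

| name | Meaning | shape | type | Notes |
|---|---|---|---|---|
| detection_boxes | Detected text boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
| detection_scores | Text detection score (probability) | [num_detections] | float | |
| detection_classes | Class id of each text region | [num_detections] | int | |
| detection_class_names | Class name of each text region | [num_detections] | string | |
| detection_keypoints | Four corner points of each detected text region | [num_detections, 4, 2] | float | Each point is (y, x) |
| image_shape | Input image size | [3]: height, width, channel | list | |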
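Each corner in detection_keypoints is (y, x), while drawing functions such as OpenCV's cv2.polylines take points as (x, y). A pure-Python sketch that swaps the order and derives an axis-aligned box from the quadrilateral; the helpers are hypothetical and the corner order is kept as returned by the model.

```python
detection_keypoints = [
    [[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594],
     [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]],
    [[292.2718200683594, 114.44700622558594], [292.2237243652344, 164.684814453125],
     [571.1962890625, 164.931640625], [571.2444458007812, 114.67433166503906]],
]

def to_xy_polygons(keypoints):
    """Swap every (y, x) corner to (x, y), rounded to integer pixel positions."""
    return [[(round(x), round(y)) for y, x in quad] for quad in keypoints]

def bounding_box(quad_xy):
    """Axis-aligned (x_min, y_min, x_max, y_max) around one (x, y) polygon."""
    xs = [x for x, _ in quad_xy]
    ys = [y for _, y in quad_xy]
    return min(xs), min(ys), max(xs), max(ys)

polygons = to_xy_polygons(detection_keypoints)
print(polygons[0])                # [(199, 244), (248, 244), (247, 386), (198, 385)]
print(bounding_box(polygons[0]))  # (198, 244, 248, 386)
```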

text_recognizer

Example result:

{
  "sequence_predict_ids": [1,2,2008,12],
  "sequence_predict_texts": "这是示例",
  "sequence_probability": 0.88
}

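Fields:

| name | Meaning | shape | type |
|---|---|---|---|
| sequence_predict_ids | Class ids of the recognized characters in a single text line | [text_length] | int |
| sequence_predict_texts | Recognized text of the single line | [] | string |
| sequence_probability | Recognition probability of the line | [] | float |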

text_spotter/text_pipeline_predictor

Example result:

{
  "detection_keypoints": [[[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594], [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]], [[292.2718200683594, 114.44700622558594], [292.2237243652344, 164.684814453125], [571.1962890625, 164.931640625], [571.2444458007812, 114.67433166503906]]],
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_classes": [1, 1],
  "detection_class_names": ["text", "text"],
  "detection_texts_ids" : [[1,2,2008,12], [1,2,2008,12]],
  "detection_texts": ["这是示例", "这是示例"],
  "detection_texts_scores" : [0.88, 0.88],
  "image_shape": [1024, 968, 3]
}

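Fields:

| name | Meaning | shape | type | Notes |
|---|---|---|---|---|
| detection_boxes | Detected text boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
| detection_scores | Text detection score (probability) | [num_detections] | float | |
| detection_classes | Class id of each text region | [num_detections] | int | |
| detection_class_names | Class name of each text region | [num_detections] | string | |
| detection_keypoints | Four corner points of each detected text region | [num_detections, 4, 2] | float | Each point is (y, x) |
| detection_texts_ids | Character class ids recognized in each text line | [num_detections, max_text_length] | int | |
| detection_texts | Recognized text of each line | [num_detections] | string | |
| detection_texts_scores | Recognition probability of each line | [num_detections] | float | |
| image_shape | Input image size | [3]: height, width, channel | list | |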
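The detection and recognition fields are parallel lists indexed by detection, so the i-th box, text and score belong together. A sketch that pairs them, drops lines below a recognition score, and sorts them top-to-bottom then left-to-right using the [top, left, bottom, right] box order. The min_text_score parameter and the reading-order rule are assumptions for illustration.

```python
result = {
    "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797],
                        [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
    "detection_texts": ["这是示例", "这是示例"],
    "detection_texts_scores": [0.88, 0.88],
}

def read_lines(result, min_text_score):
    """Pair each text box with its recognized string, dropping low-confidence lines."""
    lines = []
    for box, text, score in zip(result["detection_boxes"],
                                result["detection_texts"],
                                result["detection_texts_scores"]):
        if score >= min_text_score:
            lines.append({"box": box, "text": text, "score": score})
    # Sort by top, then by left (box order is [top, left, bottom, right]).
    lines.sort(key=lambda line: (line["box"][0], line["box"][1]))
    return lines

for line in read_lines(result, min_text_score=0.5):
    print(line["text"], line["score"])
```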

3 Local Prediction

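EasyVision provides a Python prediction interface that loads saved models exported by easy-vision and runs prediction on them; see the API documentation for details of the prediction API. Usage examples: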

import easy_vision as ev
import numpy as np

# Classification
saved_model_path = 'xxx/xxx'
classifier = ev.Classifier(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)  # HxWx3 RGB image, uint8 as in the input spec
output_dict = classifier.predict([image])

# Detection
saved_model_path = 'xxx/xxx'
detector = ev.Detector(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = detector.predict([image])

# Text recognition
saved_model_path = 'xxx/xxx'
text_recognizer = ev.TextRecognizer(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_recognizer.predict([image])

# Text detection
saved_model_path = 'xxx/xxx'
text_detector = ev.TextDetector(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_detector.predict([image])

# End-to-end text recognition (text spotting)
saved_model_path = 'xxx/xxx'
text_spotter = ev.TextSpotter(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_spotter.predict([image])

# Basic predictor: build the input placeholders yourself
saved_model_path = 'xxx/xxx'
predictor = ev.Predictor(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
image_list = [image for _ in range(10)]
# batched_images feeds 'image'; origin_shapes feeds 'true_image_shape'
batched_images, origin_shapes = predictor.batch(image_list)
input_data = {
  'image': batched_images,
  'true_image_shape': origin_shapes,
}
output_data_dict = predictor.predict(input_data)