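Model Export & Prediction
1 Model Export
Run the following Python code to export a model. If checkpoint_path is not passed, the latest checkpoint under model_dir (as set in the config file at pipeline_config_path) is used by default.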
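import easy_vision
# export_dir: output directory of the saved_model; pipeline_config_path: training config file
# checkpoint_path is optional; if omitted, the latest checkpoint in model_dir is used
easy_vision.export(export_dir, pipeline_config_path, checkpoint_path)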
2 Input and Output Information
To run prediction with the exported saved_model, you need to know its input and output tensor nodes.
2.1 Input placeholders
name | Description | shape | type |
---|---|---|---|
image | Batched image tensor, channels in RGB order | [batch_size, None, None, 3] | tf.uint8 |
true_image_shape | Actual shape of each image; the last dimension is ordered [height, width, channel], e.g. [[224, 224, 3], [448, 448, 3]] | [batch_size, 3] | tf.int32 |
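Note: batch_size is the batch_size set in export_config of the pipeline_config used when exporting the model. A batch_size of -1 means a dynamic batch size; currently only classification models support a dynamic batch_size.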
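The two placeholders work together: images of different sizes are packed into one `image` tensor, and `true_image_shape` records each image's original size. Below is a minimal NumPy sketch of building these inputs by hand. It assumes zero-padding at the bottom and right of each image; the padding scheme used by EasyVision's own `predictor.batch` is not documented here, and the helper name `pad_and_batch` is hypothetical.

```python
import numpy as np


def pad_and_batch(images):
    """Pack HxWx3 uint8 RGB images into one batch (hypothetical helper).

    Assumption: each image is zero-padded at the bottom/right to the largest
    height and width in the list. The original [height, width, channel] of
    every image is returned separately as the true_image_shape input.
    """
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    batch = np.zeros([len(images), max_h, max_w, 3], dtype=np.uint8)
    shapes = np.zeros([len(images), 3], dtype=np.int32)
    for i, img in enumerate(images):
        h, w, c = img.shape
        batch[i, :h, :w, :] = img  # copy into the top-left corner
        shapes[i] = [h, w, c]
    return batch, shapes


images = [np.full([224, 224, 3], 255, dtype=np.uint8),
          np.full([448, 448, 3], 255, dtype=np.uint8)]
image_batch, true_image_shape = pad_and_batch(images)
print(image_batch.shape)          # (2, 448, 448, 3)
print(true_image_shape.tolist())  # [[224, 224, 3], [448, 448, 3]]
```

If the model was exported with a fixed batch_size, the list must contain exactly that many images.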
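2.2 Output tensors
The output is a list of JSON results, one per input image, so the list length equals the number of input images. Example results and field descriptions for each model type follow.
feature_extractor
Example result: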
{"feature": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4583122730255127, 0.0]}
Field descriptions
name | Meaning | shape | type |
---|---|---|---|
feature | Output feature vector | [feature_dim] | float |
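A common use of the `feature` vector is comparing images. A minimal sketch, assuming cosine similarity as the comparison metric (this page does not prescribe one) and two hypothetical results already parsed from the JSON output:

```python
import json
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # an all-zero feature has no direction
    return dot / (norm_a * norm_b)


# Hypothetical feature_extractor results for two images
result_a = json.loads('{"feature": [0.0, 0.5, 0.0, 0.5]}')
result_b = json.loads('{"feature": [0.0, 1.0, 0.0, 1.0]}')
sim = cosine_similarity(result_a["feature"], result_b["feature"])
print(round(sim, 6))  # 1.0
```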
classifier
Example result
{
"class": 3,
"class_name": "coho4",
"class_probs": {"coho1": 4.028851974258174e-10,
"coho2": 0.48115724325180054,
"coho3": 5.116515922054532e-07,
"coho4": 0.5188422446937221}
}
Field descriptions
name | Meaning | shape | type |
---|---|---|---|
class | Class id | [] | int32 |
class_name | Class name | [] | string |
class_probs | Probabilities of all classes | [num_classes] | dict{key: string, value: float} |
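In the example above, `class_name` is the entry with the largest value in `class_probs` (coho4 at 0.519 versus coho2 at 0.481). A minimal sketch that ranks `class_probs` to get top-k predictions from a parsed result; the helper name `top_k` is hypothetical:

```python
import json

result = json.loads('''{
  "class": 3,
  "class_name": "coho4",
  "class_probs": {"coho1": 4.028851974258174e-10,
                  "coho2": 0.48115724325180054,
                  "coho3": 5.116515922054532e-07,
                  "coho4": 0.5188422446937221}
}''')


def top_k(class_probs, k):
    """Return the k (name, probability) pairs with the highest probability."""
    return sorted(class_probs.items(), key=lambda kv: kv[1], reverse=True)[:k]


best = top_k(result["class_probs"], 2)
print([name for name, _ in best])  # ['coho4', 'coho2']
```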
multilabel_classifier
Example result
{
"class": [3, 4],
"class_names": ["coho3", "coho4"],
"class_probs": {"coho1": 4.028851974258174e-10,
"coho2": 0.10115724325180054,
"coho3": 0.6188422446937221,
"coho4": 0.5188422446937221}
}
Field descriptions
name | Meaning | shape | type |
---|---|---|---|
class | Class ids | [None] | int32 |
class_names | Class names | [None] | string |
class_probs | Probabilities of all classes | [num_classes] | dict{key: string, value: float} |
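Unlike `classifier`, `multilabel_classifier` can return several classes per image. In the example above, the returned `class_names` are exactly the classes whose probability exceeds 0.5. A minimal sketch that re-applies a threshold to `class_probs` on the client side; the 0.5 default and the helper name `labels_above` are assumptions, not EasyVision's documented decision rule:

```python
import json

result = json.loads('''{
  "class": [3, 4],
  "class_names": ["coho3", "coho4"],
  "class_probs": {"coho1": 4.028851974258174e-10,
                  "coho2": 0.10115724325180054,
                  "coho3": 0.6188422446937221,
                  "coho4": 0.5188422446937221}
}''')


def labels_above(class_probs, threshold=0.5):
    """Names whose probability exceeds threshold (assumed rule), sorted by name."""
    return sorted(name for name, p in class_probs.items() if p > threshold)


print(labels_above(result["class_probs"]))       # ['coho3', 'coho4']
print(labels_above(result["class_probs"], 0.6))  # ['coho3']
```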
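detector
Note: instance segmentation models are also supported; their masks are returned in the optional detection_masks field.
Example result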
{
"detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
"detection_scores": [0.9942291975021362, 0.9940272569656372],
"detection_classes": [1, 1],
"detection_classe_names": ["text", "text"]
}
Field descriptions
name | Meaning | shape | type | Remarks |
---|---|---|---|---|
detection_boxes | Detected object boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
detection_scores | Detection score of each object | [num_detections] | float | |
detection_classes | Class id of each detected region | [num_detections] | int | |
detection_class_names | Class name of each detected region | [num_detections] | string | |
detection_masks | Segmentation mask of each detected region | [num_detections, image_height, image_width] | float | Optional |
detection_keypoints | Keypoints within each detected region | [num_detections, num_keypoints, 2] | float | Optional |
detection_roi_features | Local feature map of each detected region | [num_detections, roi_height, roi_width, channels] | float | Optional |
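Because `detection_boxes` are ordered [top, left, bottom, right] (y before x), they usually need reordering before being passed to drawing or evaluation tools that expect (x, y, width, height). A minimal sketch that filters detections by score and converts the boxes; the score threshold and the helper name `to_xywh` are hypothetical:

```python
import json

result = json.loads('''{
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797],
                      [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_classes": [1, 1]
}''')


def to_xywh(box):
    """Convert [top, left, bottom, right] to [x_min, y_min, width, height]."""
    top, left, bottom, right = box
    return [left, top, right - left, bottom - top]


score_threshold = 0.5  # hypothetical cut-off
kept = [
    (to_xywh(box), score, cls)
    for box, score, cls in zip(result["detection_boxes"],
                               result["detection_scores"],
                               result["detection_classes"])
    if score >= score_threshold
]
for xywh, score, cls in kept:
    print(cls, [round(v, 1) for v in xywh], round(score, 3))
```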
detector_with_rpn
Example result
{
"proposal_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797]],
"proposal_scores": [0.88, 0.56],
"detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
"detection_scores": [0.9942291975021362, 0.9940272569656372],
"detection_classes": [1, 1],
"detection_classe_names": ["text", "text"]
}
Field descriptions
name | Meaning | shape | type | Remarks |
---|---|---|---|---|
proposal_boxes | Proposal boxes | [num_proposals, 4] | float | |
proposal_scores | Proposal scores | [num_proposals] | float | |
detection_boxes | Detected object boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
detection_scores | Detection score of each object | [num_detections] | float | |
detection_classes | Class id of each detected region | [num_detections] | int | |
detection_class_names | Class name of each detected region | [num_detections] | string | |
segmentor
Example result
{
"probs" : [[[0.8, 0.8], [0.6, 0.7]],[[0.8, 0.5], [0.4, 0.3]]],
"preds" : [[[1,1], [0, 0]], [[0, 0], [1,1]]]
}
Field descriptions
name | Meaning | shape | type |
---|---|---|---|
probs | Per-pixel class probabilities | [output_height, output_width, num_classes] | float |
preds | Per-pixel class id | [output_height, output_width] | int |
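`preds` holds one class id per pixel and `probs` holds the per-class probabilities. Assuming `preds` is the arg-max of `probs` over the class axis (the usual convention, not stated explicitly on this page), the relationship and a per-class pixel count look like this in NumPy; the 2×2 arrays are made-up data:

```python
import numpy as np

# Made-up probs for a 2x2 output with num_classes = 3: [height, width, num_classes]
probs = np.array([[[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]],
                  [[0.2, 0.2, 0.6], [0.5, 0.3, 0.2]]], dtype=np.float32)

# Assumed relationship: class id per pixel = arg-max over the class axis
preds = np.argmax(probs, axis=-1)  # shape [output_height, output_width]
print(preds.tolist())  # [[0, 1], [2, 0]]

# Number of pixels assigned to each class id
counts = np.bincount(preds.ravel(), minlength=probs.shape[-1])
print(counts.tolist())  # [2, 1, 1]
```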
text_detector
Example result
{
"detection_keypoints": [[[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594], [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]], [[292.2718200683594, 114.44700622558594], [292.2237243652344, 164.684814453125], [571.1962890625, 164.931640625], [571.2444458007812, 114.67433166503906]]],
"detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
"detection_scores": [0.9942291975021362, 0.9940272569656372],
"detection_classes": [1, 1],
"detection_classe_names": ["text", "text"],
"image_shape": [1024, 968, 3]
}
Field descriptions
name | Meaning | shape | type | Remarks |
---|---|---|---|---|
detection_boxes | Detected text boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
detection_scores | Text detection score | [num_detections] | float | |
detection_classes | Class id of each text region | [num_detections] | int | |
detection_class_names | Class name of each text region | [num_detections] | string | |
detection_keypoints | Four corner points of each detected text region | [num_detections, 4, 2] | float | Each point is (y, x) |
image_shape | Input image size | [3] | list | Order: height, width, channel |
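Each keypoint in `detection_keypoints` is stored as (y, x), whereas most drawing libraries expect (x, y). A minimal sketch that swaps the coordinates of each quadrilateral and computes its axis-aligned enclosing box; the helper names are hypothetical, and because the box is derived from the keypoints it may differ slightly from the model's own `detection_boxes`:

```python
import json

result = json.loads('''{
  "detection_keypoints": [[[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594],
                           [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]]]
}''')


def quad_to_xy(quad):
    """Swap each (y, x) corner to (x, y)."""
    return [(x, y) for y, x in quad]


def enclosing_box(quad):
    """Axis-aligned [top, left, bottom, right] around a (y, x) quadrilateral."""
    ys = [p[0] for p in quad]
    xs = [p[1] for p in quad]
    return [min(ys), min(xs), max(ys), max(xs)]


for quad in result["detection_keypoints"]:
    print(quad_to_xy(quad))
    print(enclosing_box(quad))
```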
text_recognizer
Example result
{
"sequence_predict_ids": [1,2,2008,12],
"sequence_predict_texts": "这是示例",
"sequence_probability": 0.88
}
Field descriptions
name | Meaning | shape | type |
---|---|---|---|
sequence_predict_ids | Character ids of the recognized text line | [text_length] | int |
sequence_predict_texts | Recognized text of the line | [] | string |
sequence_probability | Recognition probability of the line | [] | float |
text_spotter/text_pipeline_predictor
Example result
{
"detection_keypoints": [[[243.57516479492188, 198.84210205078125], [243.91038513183594, 247.62425231933594], [385.5513916015625, 246.61660766601562], [385.2197570800781, 197.79345703125]], [[292.2718200683594, 114.44700622558594], [292.2237243652344, 164.684814453125], [571.1962890625, 164.931640625], [571.2444458007812, 114.67433166503906]]],
"detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797], [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
"detection_scores": [0.9942291975021362, 0.9940272569656372],
"detection_classes": [1, 1],
"detection_classe_names": ["text", "text"],
"detection_texts_ids" : [[1,2,2008,12], [1,2,2008,12]],
"detection_texts": ["这是示例", "这是示例"],
"detection_texts_scores" : [0.88, 0.88],
"image_shape": [1024, 968, 3]
}
Field descriptions
name | Meaning | shape | type | Remarks |
---|---|---|---|---|
detection_boxes | Detected text boxes | [num_detections, 4] | float | Coordinate order [top, left, bottom, right] |
detection_scores | Text detection score | [num_detections] | float | |
detection_classes | Class id of each text region | [num_detections] | int | |
detection_class_names | Class name of each text region | [num_detections] | string | |
detection_keypoints | Four corner points of each detected text region | [num_detections, 4, 2] | float | Each point is (y, x) |
detection_texts_ids | Character ids of each recognized text line | [num_detections, max_text_length] | int | |
detection_texts | Recognized text of each line | [num_detections] | string | |
detection_texts_scores | Recognition probability of each line | [num_detections] | float | |
image_shape | Input image size | [3] | list | Order: height, width, channel |
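The per-detection fields of `text_spotter` are parallel lists indexed by detection. A minimal sketch that zips them into one record per text line and drops low-confidence recognitions; the 0.5 threshold and the record keys are hypothetical choices:

```python
import json

result = json.loads('''{
  "detection_boxes": [[243.5308074951172, 197.69570922851562, 385.59625244140625, 247.7247772216797],
                      [292.1929931640625, 114.28043365478516, 571.2748413085938, 165.09771728515625]],
  "detection_scores": [0.9942291975021362, 0.9940272569656372],
  "detection_texts": ["这是示例", "这是示例"],
  "detection_texts_scores": [0.88, 0.88]
}''')

min_text_score = 0.5  # hypothetical threshold on recognition confidence
lines = [
    {"box": box, "det_score": det, "text": text, "text_score": rec}
    for box, det, text, rec in zip(result["detection_boxes"],
                                   result["detection_scores"],
                                   result["detection_texts"],
                                   result["detection_texts_scores"])
    if rec >= min_text_score
]
for line in lines:
    print(line["text"], line["text_score"], line["box"])
```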
3 Local Prediction
EasyVision provides a Python prediction API that loads a saved model exported by EasyVision and runs prediction; see the API documentation for details. Example usage:
import easy_vision as ev
import numpy as np

# Classification
saved_model_path = 'xxx/xxx'
classifier = ev.Classifier(saved_model_path)
# HxWx3 RGB image; the model's image input is uint8
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = classifier.predict([image])

# Detection
saved_model_path = 'xxx/xxx'
detector = ev.Detector(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = detector.predict([image])

# Text recognition
saved_model_path = 'xxx/xxx'
text_recognizer = ev.TextRecognizer(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_recognizer.predict([image])

# Text detection
saved_model_path = 'xxx/xxx'
text_detector = ev.TextDetector(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_detector.predict([image])

# End-to-end text recognition (text spotting)
saved_model_path = 'xxx/xxx'
text_spotter = ev.TextSpotter(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
output_dict = text_spotter.predict([image])

# Base predictor: feed the raw input placeholders yourself
saved_model_path = 'xxx/xxx'
predictor = ev.Predictor(saved_model_path)
image = np.zeros([640, 480, 3], dtype=np.uint8)
image_list = [image for _ in range(10)]
# batch() packs the images and returns each image's original shape
batched_images, origin_shapes = predictor.batch(image_list)
input_data = {
    'image': batched_images,
    'true_image_shape': origin_shapes
}
output_data_dict = predictor.predict(input_data)