diff --git a/ppocr/utils/export_model.py b/ppocr/utils/export_model.py index 3394c20c79..3f613f5a90 100644 --- a/ppocr/utils/export_model.py +++ b/ppocr/utils/export_model.py @@ -53,7 +53,7 @@ def dump_infer_config(config, path, logger): } elif arch_config["model_type"] == "det": common_dynamic_shapes = { - "x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]] + "x": [[1, 3, 160, 160], [1, 3, 640, 640], [1, 3, 1280, 1280]] } elif arch_config["algorithm"] == "SLANet": common_dynamic_shapes = { @@ -64,11 +64,17 @@ def dump_infer_config(config, path, logger): "x": [[1, 3, 224, 224], [1, 3, 448, 448], [8, 3, 1280, 1280]] } elif arch_config["algorithm"] == "UniMERNet": - common_dynamic_shapes = {"x": [[1, 3, 192, 672]]} + common_dynamic_shapes = { + "x": [[1, 3, 192, 672], [1, 3, 192, 672], [8, 3, 192, 672]] + } elif arch_config["algorithm"] == "PP-FormulaNet-L": - common_dynamic_shapes = {"x": [[1, 3, 768, 768]]} + common_dynamic_shapes = { + "x": [[1, 3, 768, 768], [1, 3, 768, 768], [8, 3, 768, 768]] + } elif arch_config["algorithm"] == "PP-FormulaNet-S": - common_dynamic_shapes = {"x": [[1, 3, 384, 384]]} + common_dynamic_shapes = { + "x": [[1, 3, 384, 384], [1, 3, 384, 384], [8, 3, 384, 384]] + } else: common_dynamic_shapes = None @@ -345,17 +351,22 @@ def export_single_model( ModuleNotFoundError ): # Encryption is not needed if the module cannot be imported print("Skipping import of the encryption module") + paddle_version = version.parse(paddle.__version__) if config["Global"].get("export_with_pir", False): - paddle_version = version.parse(paddle.__version__) assert ( paddle_version >= version.parse("3.0.0b2") or paddle_version == version.parse("0.0.0") ) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"] paddle.jit.save(model, save_path) else: - model.forward.rollback() - with paddle.pir_utils.OldIrGuard(): - model = dynamic_to_static(model, arch_config, logger, input_shape) + if paddle_version >= version.parse( + "3.0.0b2" + ) or paddle_version == version.parse("0.0.0"): + model.forward.rollback() + with paddle.pir_utils.OldIrGuard(): + model = dynamic_to_static(model, arch_config, logger, input_shape) + paddle.jit.save(model, save_path) + else: paddle.jit.save(model, save_path) else: quanter.save_quantized_model(model, save_path)