From e71097107fc63c4a4c77fcd7abe8a03637324149 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 5 Jan 2024 20:07:16 +0100 Subject: [PATCH 1/7] add function to translate functions --- .../custom_ops_type_inference_fails_0.onnx | Bin 0 -> 2086 bytes _unittests/ut_light_api/test_translate.py | 1 - .../ut_light_api/test_translate_classic.py | 10 +++- onnx_array_api/light_api/emitter.py | 50 +++++++++++++++++- onnx_array_api/light_api/inner_emitter.py | 41 ++++++++++++++ onnx_array_api/light_api/translate.py | 30 +++++++---- 6 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 _unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx diff --git a/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx b/_unittests/ut_light_api/_data/custom_ops_type_inference_fails_0.onnx new file mode 100644 index 0000000000000000000000000000000000000000..8116ec338064567cea06fafe45168567813071ed GIT binary patch literal 2086 zcmah}|4!pZ5cb+m;tai-=I<-p}o%9agw8mMoQutf3x%L%zUG1R9cX9>6uiM%wJS! z0y(DYvEAOFrDHpBoemS`byt71a}_#)?|$8LLhfI)eLrMQY?MLf(fsrckxkl(5MR9z zfT|Y-jvvAw1k%%> z@-wDDi=#IO&i7F~FA2va6nX4~$=3U(m73;A+O;e znpIl3Mn`+B7mIl>rYb~NCHzq9JorLfE=gh=2QYLi7P*=rUoJwSh^Y;@GjHg9#A#}oiS1){#G@YhNA+xC}+`7_? zIHpRCA=n*DHF^gr3o7^9df}Th7Bh1W&=6l*>L)18nCS{mmT5q4Q>EDp^ztF|dM<1A z0-=+0#=4##B$*PPWqe$!?6B}&(D99v=_0m8cW`dEqmqNn5_5kr<-+9eCyP+F-EH>t0 z;+$P2;`~paC-vz%tEE-B&3Av%#tkW& zV{>X&VJs6>Mb>*OvjiyyvLa9g1I9Y|WZ(zkr{e>jRj`VEhjBNIrizj){o(stc`p^_ z-YsPv-l4y@tTo2x=UVQ0rRD`(<*>wPOQmuo6 zv~kjhuWKNCxZHv@7`~%rToHD*BZ@e$uEUK9P@TR%(8rSK&Im-6esVS%&lM0hI(e*@ zhpQY_Hr(QMS?qCK7-^1-Gg7Dx*cX$I?=lZ>C;rV17&vaRoM`)@)47l56J)|;7zc{s M$%T|n&0SOSFP?~MApigX literal 0 HcmV?d00001 diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index c2b2c70..e2ed017 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -220,5 +220,4 @@ def test_aionnxml(self): if __name__ == "__main__": - TestTranslate().test_export_if() unittest.main(verbosity=2) diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index cb7d6a4..61b67a7 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -264,7 +264,6 @@ def test_aionnxml(self): .to_onnx() ) code = translate(onx, api="onnx") - print(code) expected = dedent( """ opset_imports = [ @@ -318,6 +317,15 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + def test_remove_nodes(self): + path = os.path.join( + os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" + ) + onx = load(path) + text = translate(onx, api="onnx") + with open("debug_test_remove_nodes.py", "w") as f: + f.write(text) + if __name__ == "__main__": # TestLightApi().test_topk() diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index a1b0e40..47134b0 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -18,6 +18,9 @@ class EventType(IntEnum): END_FUNCTION = 8 INITIALIZER = 9 SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 @classmethod def to_str(cls, self) -> str: @@ -63,6 +66,21 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: if event == EventType.END_GRAPH: return self._emit_end_graph(**kwargs) + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: @@ -104,11 +122,21 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: srows = ".".join(rows[:-1]) return [], f"g().{srows}" + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], f"(name={name!r}, ref_attr_name={ref!r}, dt={dt})" + + raise ValueError( f"Unable to render an attribute {type(v)}, " f"attribute type={value[0].type}, " f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, {value}." + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." ) def join(self, rows: List[str], single_line: bool = False) -> str: @@ -161,6 +189,26 @@ def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." ) + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + class Emitter(BaseEmitter): """ diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index f5d5e4d..9abba9b 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -140,3 +140,44 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: lines[-1] = lines[-1][:-1] lines.extend([" )", ")"]) return before_lines + lines + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "", + f"name_f = {kwargs['name']!r}", + f"domain_f = {kwargs['domain']!r}", + "nodes = []", + "inputs = []", + "outputs = []", + "atts = []", + ] + return lines + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"inputs.append({kwargs['name']!r})"] + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + return [f"outputs.append({kwargs['name']!r})"] + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + atts = kwargs["attributes"] + if isinstance(atts, list) and all(map(lambda t: isinstance(t, str), atts)): + return [f"atts.extend({atts!r})"] + raise NotImplementedError(f"Unable to process function attributes {atts!r}.") + + def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: + lines = [ + "functions.append(", + " make_function(", + " domain, ", + " name, ", + " inputs, ", + " outputs, ", + " nodes, ", + " attributes=atts, ", + " opset_imports=opset_imports,", + " )", + ")", + ] + return lines + diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index a61ce24..83bd4e5 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -38,6 +38,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: nodes = self.proto_.graph.node initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer + attributes = [] elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -48,19 +49,19 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: else: initializers = [] sparse_initializers = [] + attributes = ( + self.proto_.attribute if hasattr(self.proto_, "attribute") else [] + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - rows.extend( - self.emitter( - EventType.BEGIN_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.BEGIN_GRAPH - ) - ) + if isinstance(self.proto_, FunctionProto): + rows.extend(self.emitter(EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain)) + else: + rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( @@ -71,7 +72,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: for i in inputs: if isinstance(i, str): - rows.extend(self.emitter(EventType.INPUT, name=i)) + rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( self.emitter( @@ -85,6 +86,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) + if attributes: + rows.extend( + self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) + ) + for node in nodes: atts = self.extract_attributes(node) rows.extend( @@ -100,7 +106,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: for o in outputs: if isinstance(o, str): - rows.extend(self.emitter(EventType.INPUT, name=o)) + rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( self.emitter( @@ -127,7 +133,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: - raise NotImplementedError("Local functions are not yet implemented.") + for fu in self.proto_.functions: + + cl = self.__class__(fu, self.emitter) + text = cl.export(False, single_line=False) + rows.extend(text) rows.extend(self.emitter(EventType.TO_ONNX)) if as_str: From d6acd350904e866422c0d005fe67f32e5e0b74dd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 5 Jan 2024 20:08:57 +0100 Subject: [PATCH 2/7] doc --- CHANGELOGS.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3c667d..39aaea9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI 0.1.3 From 4b5934c072da1fa2168d3dc7bd71df52b9ba7024 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 01:36:14 +0100 Subject: [PATCH 3/7] fix translation of local functions --- _doc/api/light_api.rst | 11 +- .../ut_light_api/test_translate_classic.py | 120 +++++++++- onnx_array_api/light_api/base_emitter.py | 224 ++++++++++++++++++ onnx_array_api/light_api/emitter.py | 213 +---------------- onnx_array_api/light_api/inner_emitter.py | 40 +++- onnx_array_api/light_api/make_helper.py | 69 ++++++ onnx_array_api/light_api/translate.py | 23 +- 7 files changed, 464 insertions(+), 236 deletions(-) create mode 100644 onnx_array_api/light_api/base_emitter.py create mode 100644 onnx_array_api/light_api/make_helper.py diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 544b35f..5cf59e9 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -16,6 +16,13 @@ translate .. autofunction:: onnx_array_api.light_api.translate +make_helper ++++++++++++ + +.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended + +.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute + Classes for the Light API ========================= @@ -68,7 +75,7 @@ Classes for the Translater BaseEmitter +++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter +.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: Emitter @@ -80,7 +87,7 @@ Emitter EventType +++++++++ -.. autoclass:: onnx_array_api.light_api.translate.EventType +.. autoclass:: onnx_array_api.light_api.base_emitter.EventType :members: InnerEmitter diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index 61b67a7..4d52183 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -5,6 +5,7 @@ from onnx import ModelProto, TensorProto, load from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun from onnx.helper import ( make_tensor_value_info, make_node, @@ -68,7 +69,7 @@ def test_exp(self): functions = [] inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Exp', ['X'], ['Y'] @@ -144,14 +145,14 @@ def test_transpose(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['r0_0'] ) ) nodes.append( - make_node( + make_node_extended( 'Transpose', ['r0_0'], ['Y'], @@ -210,7 +211,7 @@ def test_topk_reverse(self): inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[])) nodes.append( - make_node( + make_node_extended( 'TopK', ['X', 'K'], ['Values', 'Indices'], @@ -284,14 +285,14 @@ def test_aionnxml(self): ) inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) nodes.append( - make_node( + make_node_extended( 'Reshape', ['X', 'r'], ['USE'] ) ) nodes.append( - make_node( + make_node_extended( 'Normalizer', ['USE'], ['Y'], @@ -317,16 +318,115 @@ def test_aionnxml(self): self.maxDiff = None self.assertEqual(expected, code) + @classmethod + def _code_line(cls, code): + lines = code.split("\n") + return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines)) + + @classmethod + def _run(cls, code): + try: + code_compiled = compile(code, "", mode="exec") + except Exception as e: + raise AssertionError( + f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + + import onnx + import onnx.helper + import onnx.numpy_helper + import onnx_array_api.light_api.make_helper + import onnx.reference.custom_element_types + + def from_array_extended(tensor, name=None): + dt = tensor.dtype + if ( + dt == onnx.reference.custom_element_types.float8e4m3fn + and dt.descr[0][0] == "e4m3fn" + ): + to = TensorProto.FLOAT8E4M3FN + dt_to = np.uint8 + elif ( + dt == onnx.reference.custom_element_types.bfloat16 + and dt.descr[0][0] == "bfloat16" + ): + to = TensorProto.BFLOAT16 + dt_to = np.uint16 + else: + return onnx.numpy_helper.from_array(tensor, name) + + t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name) + t.data_type = to + return t + + globs = onnx.__dict__.copy() + globs.update(onnx.helper.__dict__) + globs.update(onnx.numpy_helper.__dict__) + globs.update(onnx_array_api.light_api.make_helper.__dict__) + globs.update(onnx.reference.custom_element_types.__dict__) + globs["from_array_extended"] = from_array_extended + locs = {} + try: + exec(code_compiled, globs, locs) + except Exception as e: + raise AssertionError( + f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}" + ) from e + return globs, locs + def test_remove_nodes(self): path = os.path.join( os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx" ) onx = load(path) - text = translate(onx, api="onnx") - with open("debug_test_remove_nodes.py", "w") as f: - f.write(text) + code = translate(onx, api="onnx") + _, locs = self._run(code) + self.assertIn("model", locs) + model = locs["model"] + x = np.arange(4).reshape((-1, 2)).astype(np.float32) + feeds = {"X": x} + + class CustomGemmFloat8E4M3FN(OpRun): + op_domain = "onnx_extented.ortops.tutorial.cpu" + + def _run( + self, + x, + y, + bias=None, + scale_x=None, + scale_y=None, + scale_z=None, + transA=False, + transB=False, + dtype=None, + rowMajor=None, + computeType=None, + ): + if scale_x is not None: + x = x * scale_x + if transA: + x = x.T + if scale_y is not None: + y = y * scale_y + if transB: + y = y.T + z = x @ y + if bias is not None: + z += bias + if scale_z is not None: + z = z / scale_z + return (z,) + + ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN]) + expected = ref.run(None, feeds)[0] + ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN]) + got = ref2.run(None, feeds)[0] + self.assertEqualArray(expected, got) + + # with open("debug_test_remove_nodes.py", "w") as f: + # f.write(code) if __name__ == "__main__": - # TestLightApi().test_topk() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/base_emitter.py b/onnx_array_api/light_api/base_emitter.py new file mode 100644 index 0000000..3a0dfb6 --- /dev/null +++ b/onnx_array_api/light_api/base_emitter.py @@ -0,0 +1,224 @@ +import inspect +from typing import Any, Dict, List, Optional, Tuple +from enum import IntEnum +import numpy as np +from onnx import AttributeProto + + +class EventType(IntEnum): + START = 0 + INPUT = 1 + OUTPUT = 2 + NODE = 3 + TO_ONNX_MODEL = 4 + BEGIN_GRAPH = 5 + END_GRAPH = 6 + BEGIN_FUNCTION = 7 + END_FUNCTION = 8 + INITIALIZER = 9 + SPARSE_INITIALIZER = 10 + FUNCTION_INPUT = 11 + FUNCTION_OUTPUT = 12 + FUNCTION_ATTRIBUTES = 13 + TO_ONNX_FUNCTION = 14 + + @classmethod + def to_str(cls, self) -> str: + for k, v in EventType.__dict__.items(): + if self == v: + return f"{cls.__name__}.{k}" + + +class BaseEmitter: + def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: + """ + Converts an event into an instruction. + + :param event: event kind + :param kwargs: event parameters + :return: list of instructions + """ + + if event == EventType.NODE: + return self._emit_node(**kwargs) + + if event == EventType.INITIALIZER: + return self._emit_initializer(**kwargs) + + if event == EventType.SPARSE_INITIALIZER: + return self._emit_sparse_initializer(**kwargs) + + if event == EventType.INPUT: + return self._emit_input(**kwargs) + + if event == EventType.OUTPUT: + return self._emit_output(**kwargs) + + if event == EventType.START: + return self._emit_start(**kwargs) + + if event == EventType.TO_ONNX_MODEL: + return self._emit_to_onnx_model(**kwargs) + + if event == EventType.TO_ONNX_FUNCTION: + return self._emit_to_onnx_function(**kwargs) + + if event == EventType.BEGIN_GRAPH: + return self._emit_begin_graph(**kwargs) + + if event == EventType.END_GRAPH: + return self._emit_end_graph(**kwargs) + + if event == EventType.BEGIN_FUNCTION: + return self._emit_begin_function(**kwargs) + + if event == EventType.END_FUNCTION: + return self._emit_end_function(**kwargs) + + if event == EventType.FUNCTION_INPUT: + return self._emit_function_input(**kwargs) + + if event == EventType.FUNCTION_OUTPUT: + return self._emit_function_output(**kwargs) + + if event == EventType.FUNCTION_ATTRIBUTES: + return self._emit_function_attributes(**kwargs) + + raise ValueError(f"Unexpected event {EventType.to_str(event)}.") + + def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: + """ + Renders an attribute value into a string. + + :param value: value to converter + :return: rows to append before, actual value + """ + v = value[-1] + if value[0].type == AttributeProto.TENSOR: + repl = {"bool": "bool_", "object": "object_", "str": "str_"} + sdtype = repl.get(str(v.dtype), str(str(v.dtype))) + return [], ( + f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), " + f"name={value[0].name!r})" + ) + if isinstance(v, (int, float, list)): + return [], str(v) + if isinstance(v, str): + return [], f"{v!r}" + if isinstance(v, np.ndarray): + if not v.shape: + return [], str(v) + if len(v.shape) == 1: + if value[0].type in ( + AttributeProto.INTS, + AttributeProto.FLOATS, + AttributeProto.STRINGS, + ): + return [], str(v.tolist()) + + if value[0].type == AttributeProto.GRAPH: + from .translate import Translater + + tr = Translater(value[0].g, emitter=self) + rows = tr.export(as_str=False, single_line=False) + # last instruction is to_onnx, let's drop it. + srows = ".".join(rows[:-1]) + return [], f"g().{srows}" + + if isinstance(value, tuple) and len(value) == 2 and value[1] is None: + # in a function, an attribute receiving a value from an attribute + v = value[0] + name = v.name + ref = v.ref_attr_name + dt = v.type + return [], self._make_attribute(name=name, ref_attr_name=ref, attr_type=dt) + + raise ValueError( + f"Unable to render an attribute {type(v)}, " + f"attribute type={value[0].type}, " + f"dtype={getattr(v, 'dtype', '-')}, " + f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " + f"value={value!r}." + ) + + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def join(self, rows: List[str], single_line: bool = False) -> str: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) + + def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: + raise NotImplementedError( + f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." + ) diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index 47134b0..d4f6172 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -1,213 +1,6 @@ -import inspect -from typing import Any, Dict, List, Tuple -from enum import IntEnum -import numpy as np -from onnx import AttributeProto +from typing import Any, Dict, List from .annotations import ELEMENT_TYPE_NAME - - -class EventType(IntEnum): - START = 0 - INPUT = 1 - OUTPUT = 2 - NODE = 3 - TO_ONNX = 4 - BEGIN_GRAPH = 5 - END_GRAPH = 6 - BEGIN_FUNCTION = 7 - END_FUNCTION = 8 - INITIALIZER = 9 - SPARSE_INITIALIZER = 10 - FUNCTION_INPUT = 11 - FUNCTION_OUTPUT = 12 - FUNCTION_ATTRIBUTES = 13 - - @classmethod - def to_str(cls, self) -> str: - for k, v in EventType.__dict__.items(): - if self == v: - return f"{cls.__name__}.{k}" - - -class BaseEmitter: - def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]: - """ - Converts an event into an instruction. - - :param event: event kind - :param kwargs: event parameters - :return: list of instructions - """ - - if event == EventType.NODE: - return self._emit_node(**kwargs) - - if event == EventType.INITIALIZER: - return self._emit_initializer(**kwargs) - - if event == EventType.SPARSE_INITIALIZER: - return self._emit_sparse_initializer(**kwargs) - - if event == EventType.INPUT: - return self._emit_input(**kwargs) - - if event == EventType.OUTPUT: - return self._emit_output(**kwargs) - - if event == EventType.START: - return self._emit_start(**kwargs) - - if event == EventType.TO_ONNX: - return self._emit_to_onnx(**kwargs) - - if event == EventType.BEGIN_GRAPH: - return self._emit_begin_graph(**kwargs) - - if event == EventType.END_GRAPH: - return self._emit_end_graph(**kwargs) - - if event == EventType.BEGIN_FUNCTION: - return self._emit_begin_function(**kwargs) - - if event == EventType.END_FUNCTION: - return self._emit_end_function(**kwargs) - - if event == EventType.FUNCTION_INPUT: - return self._emit_function_input(**kwargs) - - if event == EventType.FUNCTION_OUTPUT: - return self._emit_function_output(**kwargs) - - if event == EventType.FUNCTION_ATTRIBUTES: - return self._emit_function_attributes(**kwargs) - - raise ValueError(f"Unexpected event {EventType.to_str(event)}.") - - def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: - """ - Renders an attribute value into a string. - - :param value: value to converter - :return: rows to append before, actual value - """ - v = value[-1] - if value[0].type == AttributeProto.TENSOR: - repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(v.dtype), str(str(v.dtype))) - return [], ( - f"from_array(np.array({v.tolist()}, dtype=np.{sdtype}), " - f"name={value[0].name!r})" - ) - if isinstance(v, (int, float, list)): - return [], str(v) - if isinstance(v, str): - return [], f"{v!r}" - if isinstance(v, np.ndarray): - if not v.shape: - return [], str(v) - if len(v.shape) == 1: - if value[0].type in ( - AttributeProto.INTS, - AttributeProto.FLOATS, - AttributeProto.STRINGS, - ): - return [], str(v.tolist()) - - if value[0].type == AttributeProto.GRAPH: - from .translate import Translater - - tr = Translater(value[0].g, emitter=self) - rows = tr.export(as_str=False, single_line=False) - # last instruction is to_onnx, let's drop it. - srows = ".".join(rows[:-1]) - return [], f"g().{srows}" - - if isinstance(value, tuple) and len(value) == 2 and value[1] is None: - # in a function, an attribute receiving a value from an attribute - v = value[0] - name = v.name - ref = v.ref_attr_name - dt = v.type - return [], f"(name={name!r}, ref_attr_name={ref!r}, dt={dt})" - - - raise ValueError( - f"Unable to render an attribute {type(v)}, " - f"attribute type={value[0].type}, " - f"dtype={getattr(v, 'dtype', '-')}, " - f"shape={getattr(v, 'shape', '-')}, type(value)={type(value)}, " - f"value={value!r}." - ) - - def join(self, rows: List[str], single_line: bool = False) -> str: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_sparse_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_output(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) - - def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]: - raise NotImplementedError( - f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded." - ) +from .base_emitter import BaseEmitter class Emitter(BaseEmitter): @@ -233,7 +26,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: args.append(f"opsets={opsets}") return [f"start({', '.join(args)})"] - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: return ["to_onnx()"] def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index 9abba9b..9484e74 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME from .emitter import BaseEmitter @@ -31,6 +31,15 @@ def render_attribute_value(self, value: Any) -> Tuple[List[str], str]: return super().render_attribute_value(value) + def _make_attribute( + self, name: str, attr_type: int, ref_attr_name: Optional[str] = None + ) -> str: + if ref_attr_name is None: + raise NotImplementedError( + f"Cannot create attribute with name={name!r}, attr_type={attr_type}." + ) + return f"make_ref_attribute(key={name!r}, attr_type={attr_type}, ref_attr_name={ref_attr_name!r})" + def join(self, rows: List[str], single_line: bool = False) -> str: "Returns the separators. `single_line` is unused." return "\n".join(rows) @@ -43,7 +52,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: lines.append("]") return lines - def _emit_to_onnx(self, **kwargs: Dict[str, Any]) -> List[str]: + def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "model = make_model(", " graph,", @@ -82,11 +91,22 @@ def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]: name = kwargs["name"] value = kwargs["value"] repl = {"bool": "bool_", "object": "object_", "str": "str_"} - sdtype = repl.get(str(value.dtype), str(str(value.dtype))) + fra = "from_array" + sdtype = repl.get(str(value.dtype), str(value.dtype)) + if sdtype.startswith("("): + from onnx.reference.custom_element_types import float8e4m3fn + + if sdtype == str(float8e4m3fn): + sdtype = "float8e4m3fn" + fra = "from_array_extended" + else: + raise NotImplementedError(f"Unexpected dtype={sdtype}.") + else: + sdtype = f"np.{sdtype}" return [ "initializers.append(", - " from_array(", - f" np.array({value.tolist()}, dtype=np.{sdtype}),", + f" {fra}(", + f" np.array({value.tolist()}, dtype={sdtype}),", f" name={name!r}", " )", ")", @@ -124,7 +144,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: before_lines = [] lines = [ "nodes.append(", - " make_node(", + " make_node_extended(", f" {op_type!r},", f" {inputs},", f" {outputs},", @@ -153,6 +173,9 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]: ] return lines + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]: return [f"inputs.append({kwargs['name']!r})"] @@ -169,8 +192,8 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: lines = [ "functions.append(", " make_function(", - " domain, ", - " name, ", + " domain_f, ", + " name_f, ", " inputs, ", " outputs, ", " nodes, ", @@ -180,4 +203,3 @@ def _emit_end_function(self, **kwargs: Dict[str, Any]) -> List[str]: ")", ] return lines - diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py new file mode 100644 index 0000000..2e1c092 --- /dev/null +++ b/onnx_array_api/light_api/make_helper.py @@ -0,0 +1,69 @@ +from typing import Any, Optional, Sequence +from onnx import AttributeProto, NodeProto +from onnx.helper import make_attribute + + +def make_ref_attribute( + key: str, attr_type: int, ref_attr_name: Optional[str] = None +) -> AttributeProto: + """ + Creates an attribute. + + :param key: atttribute name + :param attr_type: attribute type + :param ref_attr_name: if not None, link this attribute + to a function attribute + :return: attribute + """ + att = AttributeProto() + att.name = key + att.type = attr_type + att.ref_attr_name = ref_attr_name + return att + + +def make_node_extended( + op_type: str, + inputs: Sequence[str], + outputs: Sequence[str], + name: Optional[str] = None, + doc_string: Optional[str] = None, + domain: Optional[str] = None, + **kwargs: Any, +) -> NodeProto: + """ + Constructs a NodeProto. + + Args: + op_type: The name of the operator to construct + inputs: list of input names + outputs: list of output names + name: optional unique identifier for NodeProto + doc_string: optional documentation string for NodeProto + domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + **kwargs (dict): the attributes of the node. The acceptable values + are documented in :func:`make_attribute`. + + Returns: + NodeProto + """ + node = NodeProto() + node.op_type = op_type + node.input.extend(inputs) + node.output.extend(outputs) + if name: + node.name = name + if doc_string: + node.doc_string = doc_string + if domain is not None: + node.domain = domain + if kwargs: + for key, value in sorted(kwargs.items()): + if value is None: + continue + if isinstance(value, AttributeProto): + node.attribute.append(value) + else: + node.attribute.append(make_attribute(key, value)) + return node diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index 83bd4e5..7040f28 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -2,7 +2,9 @@ import numpy as np from onnx import AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto from onnx.numpy_helper import to_array -from .emitter import EventType, Emitter +from ..reference import to_array_extended +from .base_emitter import EventType +from .emitter import Emitter class Translater: @@ -30,6 +32,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: :return: list of instructions """ rows = [] + last_event = None if isinstance(self.proto_, ModelProto): opsets = {d.domain: d.version for d in self.proto_.opset_import} rows.extend(self.emitter(EventType.START, opsets=opsets)) @@ -39,6 +42,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: initializers = self.proto_.graph.initializer sparse_initializers = self.proto_.graph.sparse_initializer attributes = [] + last_event = EventType.TO_ONNX_MODEL elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -52,6 +56,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: attributes = ( self.proto_.attribute if hasattr(self.proto_, "attribute") else [] ) + last_event = EventType.TO_ONNX_FUNCTION else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") @@ -59,14 +64,23 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: raise NotImplementedError("Sparse initializer not supported yet.") if isinstance(self.proto_, FunctionProto): - rows.extend(self.emitter(EventType.BEGIN_FUNCTION, name=self.proto_.name, domain=self.proto_.domain)) + rows.extend( + self.emitter( + EventType.BEGIN_FUNCTION, + name=self.proto_.name, + domain=self.proto_.domain, + ) + ) else: rows.extend(self.emitter(EventType.BEGIN_GRAPH)) for i in initializers: rows.extend( self.emitter( - EventType.INITIALIZER, name=i.name, init=i, value=to_array(i) + EventType.INITIALIZER, + name=i.name, + init=i, + value=to_array_extended(i), ) ) @@ -134,12 +148,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: if isinstance(self.proto_, ModelProto) and len(self.proto_.functions) > 0: for fu in self.proto_.functions: - cl = self.__class__(fu, self.emitter) text = cl.export(False, single_line=False) rows.extend(text) - rows.extend(self.emitter(EventType.TO_ONNX)) + rows.extend(self.emitter(last_event)) if as_str: return self.emitter.join(rows, single_line=single_line) return rows From ccf07e7ac7588ca1bba884630bd95ce9d70e2eb8 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:30:38 +0100 Subject: [PATCH 4/7] refactoring --- _doc/api/light_api.rst | 12 ++++----- _unittests/ut_light_api/test_translate.py | 3 ++- onnx_array_api/light_api/__init__.py | 2 +- onnx_array_api/light_api/inner_emitter.py | 2 +- .../{emitter.py => light_emitter.py} | 5 +++- onnx_array_api/light_api/translate.py | 25 +++++++++++-------- 6 files changed, 28 insertions(+), 21 deletions(-) rename onnx_array_api/light_api/{emitter.py => light_emitter.py} (96%) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 5cf59e9..379af90 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -78,12 +78,6 @@ BaseEmitter .. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter :members: -Emitter -+++++++ - -.. autoclass:: onnx_array_api.light_api.emitter.Emitter - :members: - EventType +++++++++ @@ -96,6 +90,12 @@ InnerEmitter .. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter :members: +LightEmitter +++++++++++++ + +.. autoclass:: onnx_array_api.light_api.emitter.LightEmitter + :members: + Translater ++++++++++ diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index e2ed017..9974f81 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -6,7 +6,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.light_api import start, translate, g -from onnx_array_api.light_api.emitter import EventType +from onnx_array_api.light_api.base_emitter import EventType OPSET_API = min(19, onnx_opset_version() - 1) @@ -220,4 +220,5 @@ def test_aionnxml(self): if __name__ == "__main__": + TestTranslate().test_export_if() unittest.main(verbosity=2) diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index be6e9dd..558e626 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light") :param single_line: as a single line or not :param api: API to export into, default is `"light"` and this is handle by class - :class:`onnx_array_api.light_api.emitter.Emitter`, + :class:`onnx_array_api.light_api.light_emitter.LightEmitter`, another value is `"onnx"` which is the inner API implemented in onnx package. :return: code diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index 9484e74..72ee725 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Tuple from onnx import AttributeProto from .annotations import ELEMENT_TYPE_NAME -from .emitter import BaseEmitter +from .base_emitter import BaseEmitter from .translate import Translater diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/light_emitter.py similarity index 96% rename from onnx_array_api/light_api/emitter.py rename to onnx_array_api/light_api/light_emitter.py index d4f6172..c2925b5 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/light_emitter.py @@ -3,7 +3,7 @@ from .base_emitter import BaseEmitter -class Emitter(BaseEmitter): +class LightEmitter(BaseEmitter): """ Converts event into proper code. """ @@ -29,6 +29,9 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]: def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]: return ["to_onnx()"] + def _emit_to_onnx_function(self, **kwargs: Dict[str, Any]) -> List[str]: + return [] + def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]: return [] diff --git a/onnx_array_api/light_api/translate.py b/onnx_array_api/light_api/translate.py index 7040f28..31c1bce 100644 --- a/onnx_array_api/light_api/translate.py +++ b/onnx_array_api/light_api/translate.py @@ -4,7 +4,7 @@ from onnx.numpy_helper import to_array from ..reference import to_array_extended from .base_emitter import EventType -from .emitter import Emitter +from .light_emitter import LightEmitter class Translater: @@ -15,10 +15,10 @@ class Translater: def __init__( self, proto: Union[ModelProto, FunctionProto, GraphProto], - emitter: Optional[Emitter] = None, + emitter: Optional[LightEmitter] = None, ): self.proto_ = proto - self.emitter = emitter or Emitter() + self.emitter = emitter or LightEmitter() def __repr__(self) -> str: return f"{self.__class__.__name__}(<{type(self.proto_)})" @@ -43,6 +43,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: sparse_initializers = self.proto_.graph.sparse_initializer attributes = [] last_event = EventType.TO_ONNX_MODEL + is_function = False elif isinstance(self.proto_, (FunctionProto, GraphProto)): inputs = self.proto_.input outputs = self.proto_.output @@ -56,14 +57,17 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: attributes = ( self.proto_.attribute if hasattr(self.proto_, "attribute") else [] ) - last_event = EventType.TO_ONNX_FUNCTION + is_function = isinstance(self.proto_, FunctionProto) + last_event = ( + EventType.TO_ONNX_FUNCTION if is_function else EventType.TO_ONNX_MODEL + ) else: raise ValueError(f"Unexpected type {type(self.proto_)} for proto.") if sparse_initializers: raise NotImplementedError("Sparse initializer not supported yet.") - if isinstance(self.proto_, FunctionProto): + if is_function: rows.extend( self.emitter( EventType.BEGIN_FUNCTION, @@ -85,7 +89,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for i in inputs: - if isinstance(i, str): + if is_function: rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i)) else: rows.extend( @@ -100,7 +104,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) ) - if attributes: + if is_function and attributes: rows.extend( self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes)) ) @@ -119,7 +123,7 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: ) for o in outputs: - if isinstance(o, str): + if is_function: rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o)) else: rows.extend( @@ -137,11 +141,10 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]: name = self.proto_.name else: name = self.proto_.graph.name + rows.extend( self.emitter( - EventType.END_FUNCTION - if isinstance(self.proto_, FunctionProto) - else EventType.END_GRAPH, + EventType.END_FUNCTION if is_function else EventType.END_GRAPH, name=name, ) ) From 97ac1556d1f209c7907073fd29bb750f69c5f24d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:42:22 +0100 Subject: [PATCH 5/7] fix missing import --- _unittests/ut_light_api/test_backend_export.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index b0c1cbc..65f3690 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -17,9 +17,11 @@ make_opsetid, make_tensor_value_info, ) +from onnx.reference.op_run import to_array_extended from onnx.numpy_helper import from_array, to_array from onnx.backend.base import Device, DeviceType from onnx_array_api.reference import ExtendedReferenceEvaluator +from onnx_array_api.light_api.make_helper import make_node_extended from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot @@ -85,6 +87,7 @@ def run( locs = { "np": numpy, "to_array": to_array, + "to_array_extended": to_array_extended, "from_array": from_array, "TensorProto": TensorProto, "make_function": make_function, @@ -92,6 +95,7 @@ def run( "make_model": make_model, "make_graph": make_graph, "make_node": make_node, + "make_node_extended": make_node_extended, "make_tensor_value_info": make_tensor_value_info, } globs = locs.copy() From e531c13f1763c47ec2d58dfebeaddb53fa2d20b6 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 10:57:09 +0100 Subject: [PATCH 6/7] verbose --- .../ut_light_api/test_backend_export.py | 7 ++++-- onnx_array_api/light_api/make_helper.py | 22 ++++++++----------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_light_api/test_backend_export.py b/_unittests/ut_light_api/test_backend_export.py index 65f3690..f597d21 100644 --- a/_unittests/ut_light_api/test_backend_export.py +++ b/_unittests/ut_light_api/test_backend_export.py @@ -1,3 +1,4 @@ +import sys import unittest from typing import Any, Dict, List, Optional from difflib import unified_diff @@ -25,6 +26,8 @@ from onnx_array_api.light_api import translate from onnx_array_api.plotting.text_plot import onnx_simple_text_plot +verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0 + class ReferenceImplementationError(RuntimeError): "Fails, export cannot be compared." @@ -36,7 +39,7 @@ class ExportWrapper: def __init__(self, model): self.model = model - self.expected_sess = ExtendedReferenceEvaluator(self.model) + self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity) @property def input_names(self): @@ -109,7 +112,7 @@ def run( f"Unable to executed code for api {api!r}\n{new_code}" ) from e export_model = locs["model"] - ref = ExtendedReferenceEvaluator(export_model) + ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity) try: got = ref.run(names, feeds) except (TypeError, AttributeError) as e: diff --git a/onnx_array_api/light_api/make_helper.py b/onnx_array_api/light_api/make_helper.py index 2e1c092..8b2703c 100644 --- a/onnx_array_api/light_api/make_helper.py +++ b/onnx_array_api/light_api/make_helper.py @@ -34,19 +34,15 @@ def make_node_extended( """ Constructs a NodeProto. - Args: - op_type: The name of the operator to construct - inputs: list of input names - outputs: list of output names - name: optional unique identifier for NodeProto - doc_string: optional documentation string for NodeProto - domain: optional domain for NodeProto. - If it's None, we will just use default domain (which is empty) - **kwargs (dict): the attributes of the node. The acceptable values - are documented in :func:`make_attribute`. - - Returns: - NodeProto + :param op_type: The name of the operator to construct + :param inputs: list of input names + :param outputs: list of output names + :param name: optional unique identifier for NodeProto + :param doc_string: optional documentation string for NodeProto + :param domain: optional domain for NodeProto. + If it's None, we will just use default domain (which is empty) + :param kwargs: the attributes of the node. + :return: node proto """ node = NodeProto() node.op_type = op_type From ef821f52cb6f5841d090469d2f57632d3bb90d82 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 8 Jan 2024 11:13:12 +0100 Subject: [PATCH 7/7] link --- _doc/api/light_api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 379af90..15342c1 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -93,7 +93,7 @@ InnerEmitter LightEmitter ++++++++++++ -.. autoclass:: onnx_array_api.light_api.emitter.LightEmitter +.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter :members: Translater