Skip to content
Merged
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`75`: adds QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
Expand Down
82 changes: 82 additions & 0 deletions _unittests/ut_reference/test_reference_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,88 @@ def test_fused_matmul11(self):
got = ref.run(None, {"X": a, "Y": a})
self.assertEqualArray(a.T @ a.T, got[0])

def test_memcpy(self):
    """Checks that MemcpyToHost followed by MemcpyFromHost returns the input."""
    # Every node needs its own output name (ONNX graphs are in SSA form),
    # so the copies are chained X -> Y -> Z instead of both writing "Z".
    model = make_model(
        make_graph(
            [
                make_node("MemcpyToHost", ["X"], ["Y"]),
                make_node("MemcpyFromHost", ["Y"], ["Z"]),
            ],
            "name",
            [make_tensor_value_info("X", TensorProto.FLOAT, None)],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        ),
        opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        ir_version=9,
    )
    a = np.arange(4).reshape(-1, 2).astype(np.float32)
    ref = ExtendedReferenceEvaluator(model)
    got = ref.run(None, {"X": a})
    # Both operators are plain copies: the output must equal the input.
    self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
    """Compares QuickGelu from the reference evaluator with onnxruntime."""
    from onnxruntime import InferenceSession

    # The same input is fed to both runtimes for every value of alpha.
    feeds = {"X": np.arange(4).reshape(-1, 2).astype(np.float32)}

    for alpha in [0.0, 2.0]:
        node = make_node(
            "QuickGelu", ["X"], ["Z"], domain="com.microsoft", alpha=alpha
        )
        graph = make_graph(
            [node],
            "name",
            [make_tensor_value_info("X", TensorProto.FLOAT, None)],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        )
        model = make_model(
            graph,
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
            ir_version=9,
        )

        # onnxruntime gives the expected values.
        sess = InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        expected = sess.run(None, feeds)

        got = ExtendedReferenceEvaluator(model).run(None, feeds)
        self.assertEqualArray(expected[0], got[0])

def test_scatter_elements(self):
    """Runs ScatterElements (axis=3, reduction="add") on a 4D tensor."""
    node = make_node(
        "ScatterElements",
        ["data", "indices", "updates"],
        ["Z"],
        axis=3,
        reduction="add",
    )
    graph_inputs = [
        make_tensor_value_info("data", TensorProto.FLOAT, None),
        make_tensor_value_info("indices", TensorProto.INT64, None),
        make_tensor_value_info("updates", TensorProto.FLOAT, None),
    ]
    graph_outputs = [make_tensor_value_info("Z", TensorProto.FLOAT, None)]
    model = make_model(
        make_graph([node], "name", graph_inputs, graph_outputs),
        opset_imports=[make_opsetid("", 18)],
    )

    shape = (2, 2, 2, 2)
    feeds = {
        "data": np.zeros(shape, dtype=np.float32),
        "indices": np.array([[[[0]]]], dtype=np.int64),
        "updates": np.array([[[[1]]]], dtype=np.float32),
    }
    # Only the very first element receives the update.
    y = np.zeros(shape, dtype=np.float32)
    y[0, 0, 0, 0] = 1

    got = ExtendedReferenceEvaluator(model).run(None, feeds)
    self.assertEqualArray(y, got[0])


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletions onnx_array_api/reference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements


logger = getLogger("onnx-array-api-eval")
Expand All @@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_19,
ConstantOfShape,
FusedMatMul,
MemcpyFromHost,
MemcpyToHost,
QuickGelu,
ScatterElements,
]

@staticmethod
Expand Down
11 changes: 11 additions & 0 deletions onnx_array_api/reference/ops/op_memcpy_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from onnx.reference.op_run import OpRun


class MemcpyFromHost(OpRun):
    """Runtime for operator ``MemcpyFromHost``: the input is returned as is."""

    def _run(self, x):
        # No copy is made, the very same object is handed back.
        outputs = (x,)
        return outputs


class MemcpyToHost(OpRun):
    """Runtime for operator ``MemcpyToHost``: the input is returned as is."""

    def _run(self, x):
        # No copy is made, the very same object is handed back.
        outputs = (x,)
        return outputs
23 changes: 23 additions & 0 deletions onnx_array_api/reference/ops/op_quick_gelu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
from onnx.reference.op_run import OpRun


def sigmoid(x):  # type: ignore
    """
    Computes the logistic function ``1 / (1 + exp(-x))`` without overflow.

    The exponential is only ever evaluated on a non-positive value, so a large
    ``|x|`` cannot overflow.

    :param x: a scalar or an array (anything :func:`numpy.asarray` accepts)
    :return: a scalar for a scalar input, an array of the same shape otherwise
    """
    if np.ndim(x) == 0:
        # Scalar path, also used element by element through np.vectorize.
        if x > 0:
            return 1 / (1 + np.exp(-x))
        e = np.exp(x)
        return e / (1 + e)

    # Array path: exp(-|x|) equals exp(-x) when x > 0 and exp(x) otherwise,
    # so one exponential serves both branches of the scalar formula.
    x = np.asarray(x)
    e = np.exp(-np.abs(x))
    return np.where(x > 0, 1 / (1 + e), e / (1 + e))


class QuickGelu(OpRun):
    """Runtime for ``QuickGelu`` (domain ``com.microsoft``): ``X * sigmoid(alpha * X)``."""

    op_domain = "com.microsoft"

    def __init__(self, onnx_node, run_params):  # type: ignore
        super().__init__(onnx_node, run_params)
        # Applies ``sigmoid`` element by element on arrays.
        self.vf = np.vectorize(sigmoid)

    def _run(self, X, alpha=1.0):
        if X.ndim == 0:
            # A single value: the scalar function can be called directly.
            activation = sigmoid
        elif X.size == 0:
            # Nothing to compute, the empty input is returned as is.
            return (X,)
        else:
            activation = self.vf
        gated = X * activation(X * alpha)
        return (gated.astype(X.dtype),)
98 changes: 98 additions & 0 deletions onnx_array_api/reference/ops/op_scatter_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import numpy as np

from onnx.reference.op_run import OpRun


def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    """
    Scatters *updates* into a copy of *data* along *axis*.

    For every position ``p`` of *indices*, the destination is ``p`` with its
    coordinate along *axis* replaced by ``indices[p]``. The value written there
    is ``updates[p]``, combined with the current value according to
    *reduction*. Any rank and any axis are supported.

    :param data: array to update, it is not modified
    :param indices: integer array with the same rank as *data*
    :param updates: values to scatter, indexed like *indices*
    :param axis: axis along which *indices* apply, negative values count from
        the last axis, ``None`` means 0
    :param reduction: ``None`` or ``"none"`` to overwrite, or one of
        ``"add"``, ``"mul"``, ``"min"``, ``"max"``
    :return: a new array with the shape and dtype of *data*
    :raises ValueError: if *reduction* is not one of the values above
    :raises RuntimeError: if *data* has no dimension, if *indices* does not
        have the same rank as *data* or if *axis* is out of range
    """
    reducers = {
        "none": lambda current, update: update,
        "add": lambda current, update: current + update,
        "mul": lambda current, update: current * update,
        "min": min,
        "max": max,
    }
    kind = "none" if reduction is None else reduction
    if kind not in reducers:
        # An unknown reduction used to silently overwrite the values.
        raise ValueError(
            f"Unexpected reduction={reduction!r}, "
            f"it should be in {sorted(reducers)} or None"
        )
    f = reducers[kind]

    scattered = np.copy(data)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    if axis is None:
        axis = 0

    rank = scattered.ndim
    if rank == 0 or indices.ndim != rank or not -rank <= axis < rank:
        raise RuntimeError(
            f"Not implemented for indices.shape={indices.shape} and axis={axis}"
        )
    if axis < 0:
        axis += rank

    # Positions are visited in row-major order, so duplicated indices are
    # combined in a deterministic order.
    for position in np.ndindex(*indices.shape):
        target = position[:axis] + (indices[position],) + position[axis + 1 :]
        scattered[target] = f(scattered[target], updates[position])
    return scattered


class ScatterElements(OpRun):
    """Runtime for operator ``ScatterElements``, based on :func:`scatter_elements`."""

    def _run(self, data, indices, updates, axis=None, reduction=None):  # type: ignore
        # ``scatter_elements`` compares the axis with 0, which fails on None:
        # a missing axis falls back on 0, the default of ``scatter_elements``.
        if axis is None:
            axis = 0
        res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
        return (res,)