-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1227 from IntelPython/experimental/inteumliteral
Adds a new literal type to store IntEnum as Literal types.
- Loading branch information
Showing
13 changed files
with
332 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Provides a FlagEnum class to help distinguish IntEnum types that numba-dpex | ||
intends to use as Integer literal types inside the compiler type inferring | ||
infrastructure. | ||
""" | ||
from enum import IntEnum | ||
|
||
|
||
class FlagEnum(IntEnum): | ||
"""Helper class to distinguish IntEnum types that numba-dpex should consider | ||
as Numba Literal types. | ||
""" | ||
|
||
@classmethod | ||
def basetype(cls) -> int: | ||
"""Returns an dummy int object that helps numba-dpex infer the type of | ||
an instance of a FlagEnum class. | ||
Returns: | ||
int: Dummy int value | ||
""" | ||
return int(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Definition of a new Literal type in numba-dpex that allows treating IntEnum | ||
members as integer literals inside a JIT compiled function. | ||
""" | ||
from enum import IntEnum | ||
|
||
from numba.core.pythonapi import box | ||
from numba.core.typeconv import Conversion | ||
from numba.core.types import Integer, Literal | ||
from numba.core.typing.typeof import typeof | ||
|
||
from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError | ||
from numba_dpex.experimental.flag_enum import FlagEnum | ||
|
||
|
||
class IntEnumLiteral(Literal, Integer): | ||
"""A Literal type for IntEnum objects. The type contains the original Python | ||
value of the IntEnum class in it. | ||
""" | ||
|
||
# pylint: disable=W0231 | ||
def __init__(self, value): | ||
self._literal_init(value) | ||
self.name = f"Literal[IntEnum]({value})" | ||
if issubclass(value, FlagEnum): | ||
basetype = typeof(value.basetype()) | ||
Integer.__init__( | ||
self, | ||
name=self.name, | ||
bitwidth=basetype.bitwidth, | ||
signed=basetype.signed, | ||
) | ||
else: | ||
raise IllegalIntEnumLiteralValueError | ||
|
||
def can_convert_to(self, typingctx, other) -> bool: | ||
conv = typingctx.can_convert(self.literal_type, other) | ||
if conv is not None: | ||
return max(conv, Conversion.promote) | ||
return False | ||
|
||
|
||
Literal.ctor_map[IntEnum] = IntEnumLiteral | ||
|
||
|
||
@box(IntEnumLiteral) | ||
def box_literal_integer(typ, val, c): | ||
"""Defines how a Numba representation for an IntEnumLiteral object should | ||
be converted to a PyObject* object and returned back to Python. | ||
""" | ||
val = c.context.cast(c.builder, val, typ, typ.literal_type) | ||
return c.box(typ.literal_type, val) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
36 changes: 36 additions & 0 deletions
36
numba_dpex/tests/experimental/IntEnumLiteral/test_compilation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import dpnp | ||
|
||
import numba_dpex.experimental as exp_dpex | ||
from numba_dpex import Range | ||
from numba_dpex.experimental.flag_enum import FlagEnum | ||
|
||
|
||
class MockFlags(FlagEnum): | ||
FLAG1 = 100 | ||
FLAG2 = 200 | ||
|
||
|
||
@exp_dpex.kernel( | ||
release_gil=False, | ||
no_compile=True, | ||
no_cpython_wrapper=True, | ||
no_cfunc_wrapper=True, | ||
) | ||
def update_with_flag(a): | ||
a[0] = MockFlags.FLAG1 | ||
a[1] = MockFlags.FLAG2 | ||
|
||
|
||
def test_compilation_of_flag_enum(): | ||
"""Tests if a FlagEnum subclass can be used inside a kernel function.""" | ||
a = dpnp.ones(10, dtype=dpnp.int64) | ||
exp_dpex.call_kernel(update_with_flag, Range(10), a) | ||
|
||
assert a[0] == MockFlags.FLAG1 | ||
assert a[1] == MockFlags.FLAG2 | ||
for idx in range(2, 9): | ||
assert a[idx] == 1 |
30 changes: 30 additions & 0 deletions
30
numba_dpex/tests/experimental/IntEnumLiteral/test_type_creation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from enum import IntEnum | ||
|
||
import pytest | ||
|
||
from numba_dpex.core.exceptions import IllegalIntEnumLiteralValueError | ||
from numba_dpex.experimental import IntEnumLiteral | ||
from numba_dpex.experimental.flag_enum import FlagEnum | ||
|
||
|
||
def test_intenumliteral_creation(): | ||
"""Tests the creation of an IntEnumLiteral type.""" | ||
|
||
class DummyFlags(FlagEnum): | ||
DUMMY = 0 | ||
|
||
try: | ||
IntEnumLiteral(DummyFlags) | ||
except: | ||
pytest.fail("Unexpected failure in IntEnumLiteral initialization") | ||
|
||
with pytest.raises(IllegalIntEnumLiteralValueError): | ||
|
||
class SomeKindOfUnknownEnum(IntEnum): | ||
UNKNOWN_FLAG = 1 | ||
|
||
IntEnumLiteral(SomeKindOfUnknownEnum) |
36 changes: 36 additions & 0 deletions
36
numba_dpex/tests/experimental/IntEnumLiteral/test_type_registration.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# SPDX-FileCopyrightText: 2023 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
from numba.core.datamodel import default_manager | ||
|
||
from numba_dpex.core.datamodel.models import dpex_data_model_manager | ||
from numba_dpex.experimental import IntEnumLiteral | ||
from numba_dpex.experimental.flag_enum import FlagEnum | ||
from numba_dpex.experimental.models import exp_dmm | ||
|
||
|
||
def test_data_model_registration(): | ||
"""Tests that the IntEnumLiteral type is only registered with the | ||
DpexExpKernelTargetContext target. | ||
""" | ||
|
||
class DummyFlags(FlagEnum): | ||
DUMMY = 0 | ||
|
||
dummy = IntEnumLiteral(DummyFlags) | ||
|
||
with pytest.raises(KeyError): | ||
default_manager.lookup(dummy) | ||
|
||
with pytest.raises(KeyError): | ||
dpex_data_model_manager.lookup(dummy) | ||
|
||
try: | ||
exp_dmm.lookup(dummy) | ||
except: | ||
pytest.fail( | ||
"IntEnumLiteral type lookup failed in experimental " | ||
"data model manager" | ||
) |
Oops, something went wrong.