Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PJRT] Fix tensor element type for signed integers #19496

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

PragmaTwice
Copy link
Member

@PragmaTwice PragmaTwice commented Dec 17, 2024

In PJRT, MapElementTypeToMlirType is used in DeviceInstance::TransposeBroadcastDeviceBuffer to generate stablehlo programs for transpose and broadcast tensors.

In the current implementation, IREE_HAL_ELEMENT_TYPE_SINT_N is mapped to MLIR type siN, which is straightforward. But it can produce errors, because in StableHLO ranked tensors of signed integers are not supported (due to a historical reason, I think).

For example, stablehlo.tranpose requires a tensor with constraint HLO_TensorOrPerAxisQuantizedTensor. And HLO_TensorOrPerAxisQuantizedTensor is defined as ranked tensor of many types, including HLO_Int.

Here's the interesting thing: HLO_Int is defined as HLO_SInt or HLO_UInt, and HLO_SInt is defined as signless integers instead of signed integers:

https://github.com/openxla/stablehlo/blob/9e0c9e34d02af255bfe8f896ba7dc910758c6ecf/stablehlo/dialect/Base.td#L39-L43

And above this definition we can see a TODO that states it's due to legacy reasons.

In PJRT plugin MLIR type siN will be used to generate some code like stablehlo.transpose %x : tensor<2x2xsi32>, which is invalid in StableHLO and will lead to a verification error.

ci-exactly: build_packages, test_pjrt

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant