[PJRT] Fix tensor element type for signed integers #19496
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
In PJRT,
MapElementTypeToMlirType
is used inDeviceInstance::TransposeBroadcastDeviceBuffer
to generate stablehlo programs for transpose and broadcast tensors.In the current implementation,
IREE_HAL_ELEMENT_TYPE_SINT_N
is mapped to MLIR typesiN
, which is straightforward. But it can produce errors, because in StableHLO ranked tensors of signed integers are not supported (due to a historical reason, I think).For example,
stablehlo.tranpose
requires a tensor with constraintHLO_TensorOrPerAxisQuantizedTensor
. AndHLO_TensorOrPerAxisQuantizedTensor
is defined as ranked tensor of many types, includingHLO_Int
.Here's the interesting thing:
HLO_Int
is defined asHLO_SInt
orHLO_UInt
, andHLO_SInt
is defined as signless integers instead of signed integers:https://github.com/openxla/stablehlo/blob/9e0c9e34d02af255bfe8f896ba7dc910758c6ecf/stablehlo/dialect/Base.td#L39-L43
And above this definition we can see a TODO that states it's due to legacy reasons.
In PJRT plugin MLIR type
siN
will be used to generate some code likestablehlo.transpose %x : tensor<2x2xsi32>
, which is invalid in StableHLO and will lead to a verification error.ci-exactly: build_packages, test_pjrt