Support accumulating GEMMs in TileAndFuse with intrinsic without needing c promotion #19546

Open
nirvedhmeshram opened this issue Dec 20, 2024 · 4 comments

nirvedhmeshram commented Dec 20, 2024

Currently, for an accumulating GEMM we fail to bufferize if we don't do c promotion in the TileAndFuse pipeline when using intrinsics. See the dump here. I know there are some transforms that are still in development, but I wasn't sure whether they will serve this case as well.

nirvedhmeshram changed the title from "Support accumulating GEMMs in TileAndFuse without needing c promotion" to "Support accumulating GEMMs in TileAndFuse with intrinssic without needing c promotion" on Dec 20, 2024

nirvedhmeshram commented Dec 27, 2024

Here is what is causing this to fail to bufferize. After GPUFuseAndHoistParallelLoopsPass we have the following access pattern:

%read_write_input = flow.dispatch.tensor.load ... -> tensor<32x16x32x16xi32>
%workgroup_scf_forall = scf.forall ... shared_outs(%arg2 = %read_write_input) -> (tensor<32x16x32x16xi32>) {
  %11 = tensor.empty() : tensor<4x16x4x16xi32>
  %subgroup_scf_forall = scf.forall ... shared_outs(%arg5 = %11) -> (tensor<4x16x4x16xi32>) {
    %extracted_slice = tensor.extract_slice %arg2
    %extracted_slice_0 = tensor.extract_slice %arg5
    %thread_scf_forall = scf.forall ... shared_outs(%arg7 = %extracted_slice_0) -> (tensor<2x16x2x16xi32>) {
      %acc_input = tensor.extract_slice %extracted_slice
      %some_intrinsic_compute(..., %acc_input)
      %acc_output = tensor.extract_slice %arg7

The point being that the accumulator input slice was derived from %read_write_input, but the accumulator output is written back to a slice derived from an empty tensor.

Later, after EliminateEmptyTensorsPass, we are left with almost the same access pattern, except that we have

  %extracted_slice = tensor.extract_slice %arg2
  %subgroup_scf_forall = scf.forall ... shared_outs(%arg5 = %extracted_slice) -> (tensor<4x16x4x16xi32>) {

At this stage the acc_input is a slice of %arg2 while the acc_output is written to %arg7, both of which are slices derived from %read_write_input. This is perceived by OneShotAnalysis in hasReadAfterWriteInterference as a RaW conflict, following which it introduces a copy from which we can't recover. As a hack I bypassed the logic in the analysis to say there is no conflict, and then things work out with a few unnecessary subviews that I am assuming lower-level codegen will take care of; I verified that I got correct numerics with that.
So I see the following possible solutions, though I am not sure if any of them are good ones.

  1. During GPUFuseAndHoistParallelLoopsPass, add pattern(s) so that we directly pass %arg2 to %arg5 and make %acc_input take from %arg7 (see the sketch after this list).
  2. Do the same thing as 1, but after/during EliminateEmptyTensorsPass.
  3. Improve the hasReadAfterWriteInterference logic to make it understand that this case is not a RaW conflict.
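
For reference, here is a rough hand-written sketch (reusing the names from the dump above; this is not output from any existing pass) of the access pattern that options 1 and 2 aim for, where the slice of %arg2 is passed straight through the subgroup forall and the accumulator input is taken from the innermost shared_outs argument:

%read_write_input = flow.dispatch.tensor.load ... -> tensor<32x16x32x16xi32>
%workgroup_scf_forall = scf.forall ... shared_outs(%arg2 = %read_write_input) -> (tensor<32x16x32x16xi32>) {
  %extracted_slice = tensor.extract_slice %arg2
  %subgroup_scf_forall = scf.forall ... shared_outs(%arg5 = %extracted_slice) -> (tensor<4x16x4x16xi32>) {
    %extracted_slice_0 = tensor.extract_slice %arg5
    %thread_scf_forall = scf.forall ... shared_outs(%arg7 = %extracted_slice_0) -> (tensor<2x16x2x16xi32>) {
      %acc_input = tensor.extract_slice %arg7
      %some_intrinsic_compute(..., %acc_input)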

nirvedhmeshram changed the title from "Support accumulating GEMMs in TileAndFuse with intrinssic without needing c promotion" to "Support accumulating GEMMs in TileAndFuse with intrinsic without needing c promotion" on Dec 27, 2024

hanhanW commented Jan 6, 2025

I spot an issue in your dump: the output binding is ReadOnly. This is because your input program writes the result to a function argument. IREE is not smart enough to create a global buffer for the output tensor, and maybe that should not happen -- I can't interpret the meaning of writing the result into an input argument. It is a tensor, not a pointer.

  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>, %arg2: tensor<512x512xi32>) -> tensor<512x512xi32> {
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }

I think it is better to have a dump with the tensor.empty variant, e.g.,

  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>) -> tensor<512x512xi32> {
    %arg2 = tensor.empty() : tensor<512x512xi32>
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }

It is clearer because we explicitly ask IREE to create a global buffer for the tensor and output the result at the end. I did not run the example myself because I don't know what the compilation command is.


hanhanW commented Jan 6, 2025

If the issue comes from tests/e2e/matmul/, we should probably just fix the generated input programs to the tensor.empty form.


hanhanW commented Jan 6, 2025

The other solution might be running something like a RemoveArgOutsDependency pattern at the global level, similar to the RemoveCstOutsDependency pattern in the ConvertToDestinationPassingStylePass.

struct RemoveCstOutsDependency
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
  LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.startOpModification(op);
    bool modifiedOutput = false;
    Location loc = op.getLoc();
    for (OpOperand &opOperand : op.getDpsInitsMutable()) {
      ElementsAttr attr;
      if (!matchPattern(opOperand.get(), m_Constant(&attr)))
        continue;
      if (!attr.isSplat())
        continue;
      auto type = llvm::dyn_cast<RankedTensorType>(attr.getType());
      if (!type)
        continue;
      TypedAttr scalarAttr = attr.getValues<TypedAttr>()[0];
      modifiedOutput = true;
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, type.getShape(), type.getElementType());
      Value cstOp = rewriter.create<arith::ConstantOp>(loc, scalarAttr);
      Value fillOp =
          rewriter.create<linalg::FillOp>(loc, cstOp, emptyTensor).result();
      op->setOperand(opOperand.getOperandNumber(), fillOp);
    }
    if (!modifiedOutput) {
      rewriter.cancelOpModification(op);
      return failure();
    }
    rewriter.finalizeOpModification(op);
    return success();
  }
};
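
For illustration, here is a rough before/after sketch on the bmm example above of what such a hypothetical RemoveArgOutsDependency pattern could produce (my own assumption of its behavior, analogous to how RemoveCstOutsDependency replaces a splat-constant init with a linalg.fill of a tensor.empty): an init that comes from a function argument would be copied into a fresh tensor.empty, so the op result no longer depends on the argument.

  // Before: the matmul's init aliases the function argument %arg2.
  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>, %arg2: tensor<512x512xi32>) -> tensor<512x512xi32> {
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%arg2 : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }

  // After: the argument is copied into a tensor.empty first, so the init no longer aliases %arg2.
  func.func @bmm(%arg0: tensor<512x128xi8>, %arg1: tensor<512x128xi8>, %arg2: tensor<512x512xi32>) -> tensor<512x512xi32> {
    %empty = tensor.empty() : tensor<512x512xi32>
    %init = linalg.copy ins(%arg2 : tensor<512x512xi32>) outs(%empty : tensor<512x512xi32>) -> tensor<512x512xi32>
    %0 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%init : tensor<512x512xi32>) -> tensor<512x512xi32>
    return %0 : tensor<512x512xi32>
  }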
