Skip to content

Commit

Permalink
update e2e tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dixin Zhou committed Dec 23, 2024
1 parent a79a1fc commit cab423b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
19 changes: 19 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,14 @@
"AtenLinearMatVec_basic",
"AtenLinearVecMatBias_basic",
"AtenLinearVecMat_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModule_basic",
"ReduceAminSingleDim_basic",
"AtenDotModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
Expand Down Expand Up @@ -1764,6 +1772,9 @@
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModule_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"ElementwiseAddBoolModule_basic",
"Exp2StaticModule_basic",
"CosineSimilarityStaticBroadcastModule_basic",
Expand Down Expand Up @@ -3339,6 +3350,10 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleDynamic_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down Expand Up @@ -4098,6 +4113,10 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"Aten_BilinearModule1D_basic",
"Aten_BilinearModuleDynamic_basic",
"Aten_BilinearModuleND_basic",
"Aten_BilinearModule_basic",
"AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic",
Expand Down
28 changes: 14 additions & 14 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,7 +1950,7 @@ def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils):
# ==============================================================================


class Aten_BilinearModuleStaticShape(torch.nn.Module):
class Aten_BilinearModule(torch.nn.Module):
def __init__(self):
super().__init__()

Expand All @@ -1968,12 +1968,12 @@ def forward(self, input1, input2, weight, bias):
return torch.ops.aten.bilinear(input1, input2, weight, bias)


@register_test_case(module_factory=lambda: Aten_BilinearModuleStaticShape())
def Aten_BilinearModuleStaticShape_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Aten_BilinearModule())
def Aten_BilinearModule_basic(module, tu: TestUtils):
    """Run aten.bilinear with static-shaped 2-D inputs, 3-D weight, 1-D bias."""
    input1 = tu.rand(8, 2)
    input2 = tu.rand(8, 3)
    weight = tu.rand(4, 2, 3)
    bias = tu.rand(4)
    module.forward(input1, input2, weight, bias)


class Aten_BilinearModuleDynamicShape(torch.nn.Module):
class Aten_BilinearModuleDynamic(torch.nn.Module):
def __init__(self):
super().__init__()

Expand All @@ -1991,8 +1991,8 @@ def forward(self, input1, input2, weight, bias):
return torch.ops.aten.bilinear(input1, input2, weight, bias)


@register_test_case(module_factory=lambda: Aten_BilinearModuleDynamicShape())
def Aten_BilinearModuleDynamicShape_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Aten_BilinearModuleDynamic())
def Aten_BilinearModuleDynamic_basic(module, tu: TestUtils):
    """Run aten.bilinear through the dynamic-shape module variant.

    Same concrete tensor sizes as the static test; only the module's
    shape annotations differ (declared elsewhere in the file).
    """
    input1 = tu.rand(8, 2)
    input2 = tu.rand(8, 3)
    weight = tu.rand(4, 2, 3)
    bias = tu.rand(4)
    module.forward(input1, input2, weight, bias)


Expand All @@ -2004,10 +2004,10 @@ def __init__(self):
@annotate_args(
[
None,
([-1], torch.float32, True),
([-1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([2], torch.float32, True),
([3], torch.float32, True),
([4, 2, 3], torch.float32, True),
([4], torch.float32, True),
]
)
def forward(self, input1, input2, weight, bias):
Expand All @@ -2027,10 +2027,10 @@ def __init__(self):
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1], torch.float32, True),
([8, 6, 12, 2], torch.float32, True),
([8, 6, 12, 3], torch.float32, True),
([4, 2, 3], torch.float32, True),
([4], torch.float32, True),
]
)
def forward(self, input1, input2, weight, bias):
Expand Down

0 comments on commit cab423b

Please sign in to comment.