forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_utils.h
78 lines (65 loc) · 2.55 KB
/
test_utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#pragma once
#include <memory>
#include <vector>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
#include <torch/csrc/jit/testing/file_check.h>
namespace torch {
namespace jit {

// Makes names from torch::jit::tensorexpr visible to every file that includes
// this header. The macros below rely on this: they use to<>, Cast, Var,
// Intrinsics, NodePtr, ScalarType and kRand unqualified (presumably all
// declared in that namespace — confirm against fwd_decls.h).
using namespace torch::jit::tensorexpr;

// ---------------------------------------------------------------------------
// IR assertion helpers for tests.
//
// NOTE(review): ASSERT_NE / ASSERT_EQ presumably come from gtest via
// test_base.h. gtest's fatal assertions `return` from the enclosing function on
// failure, which would restrict these helpers to void functions — confirm.
//
// Macros documented as "declares `name`" introduce that variable into the
// caller's scope so later checks can use it. The remaining macros expand to a
// self-contained brace block and leak nothing.
// ---------------------------------------------------------------------------

// Asserts that to<T>(node) yields a non-null pointer (presumably: `node` is an
// IR node of type T).
#define IS_NODE(T, node)       \
  {                            \
    auto node_ = to<T>(node);  \
    ASSERT_NE(nullptr, node_); \
  }

// Declares `name` = to<T>(node) in the caller's scope and asserts that it is
// non-null.
#define IS_NODE_WITH_NAME(T, node, name) \
  auto name = to<T>(node);               \
  ASSERT_NE(nullptr, name);

// Declares `name` (a NodePtr<T>) in the caller's scope. Steps:
//   1. Asserts that `node` converts to a non-null Cast.
//   2. Asserts that the Cast's dtype has scalar type ScalarType::Type.
//   3. Binds `name` to to<T>() of the cast's src_value().
//   4. Asserts that `name` is non-null.
#define IS_NODE_WITH_NAME_AND_CAST(T, node, name, Type)        \
  NodePtr<T> name = nullptr;                                   \
  {                                                            \
    auto node_ = to<Cast>(node);                               \
    ASSERT_NE(nullptr, node_);                                 \
    ASSERT_EQ(node_->dtype().scalar_type(), ScalarType::Type); \
    name = to<T>(node_->src_value());                          \
  }                                                            \
  ASSERT_NE(nullptr, name);

// Asserts that `node` converts to a non-null T##Imm (token-pasted, e.g. IntImm
// when T is Int) whose value() equals `val`.
#define IS_IMM_WITH_VAL(T, node, val) \
  {                                   \
    auto node_ = to<T##Imm>(node);    \
    ASSERT_NE(nullptr, node_);        \
    ASSERT_EQ(node_->value(), val);   \
  }

// Asserts that `node` converts to a non-null Var whose name_hint() equals
// `name`.
#define IS_VAR_WITH_NAME(node, name)     \
  {                                      \
    auto node_ = to<Var>(node);          \
    ASSERT_NE(nullptr, node_);           \
    ASSERT_EQ(node_->name_hint(), name); \
  }

// Declares `name` (a NodePtr<T>) in the caller's scope, bound to to<T>(node).
// Asserts that it is non-null, that its lhs() is a Var named `v1`, and that
// its rhs() is a Var named `v2`.
#define IS_BINOP_W_VARS(T, node, name, v1, v2) \
  NodePtr<T> name = nullptr;                   \
  name = to<T>(node);                          \
  ASSERT_NE(nullptr, name);                    \
  IS_VAR_WITH_NAME(name->lhs(), v1);           \
  IS_VAR_WITH_NAME(name->rhs(), v2)

// Declares `name` (a NodePtr<T>) in the caller's scope, bound to to<T>(node).
// Asserts that it is non-null, that its lhs() is a Var named `v`, and that its
// rhs() is an IntImm with value `c`. The immediate type is fixed to Int.
#define IS_BINOP_W_CONST(T, node, name, v, c) \
  NodePtr<T> name = nullptr;                  \
  name = to<T>(node);                         \
  ASSERT_NE(nullptr, name);                   \
  IS_VAR_WITH_NAME(name->lhs(), v);           \
  IS_IMM_WITH_VAL(Int, name->rhs(), c)

// Asserts that `node` converts to a non-null Intrinsics whose op_type() is
// kRand.
#define IS_RAND(node)                   \
  {                                     \
    auto node_ = to<Intrinsics>(node);  \
    ASSERT_NE(nullptr, node_);          \
    ASSERT_EQ(node_->op_type(), kRand); \
  }

// Checks the IR of statement `s` against `pattern`. The definition lives out
// of line. Since this header includes testing/file_check.h, `pattern` is
// presumably a FileCheck directive string — confirm in the implementation.
void checkIR(StmtPtr s, const std::string& pattern);

// Expression counterparts of checkIR: one overload takes an ExprPtr, the other
// an ExprHandle.
void checkExprIR(ExprPtr e, const std::string& pattern);
void checkExprIR(const ExprHandle& e, const std::string& pattern);

} // namespace jit
} // namespace torch