forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgrad_mode.cpp
78 lines (66 loc) · 2.38 KB
/
grad_mode.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <gtest/gtest.h>
#include <test/cpp/api/support.h>
#include <torch/script.h>
using namespace torch::autograd;
using namespace torch::test;
// An out-of-place (functional) op evaluated while grad mode is disabled must
// yield a result that does not require grad and is a leaf, regardless of
// whether its input requires grad.
TEST(GradModeTest, TestRequiresGradFunctionalOp) {
  torch::AutoGradMode no_grad_guard(false);
  for (const bool input_requires_grad : {true, false}) {
    torch::Tensor input =
        torch::ones({1, 2, 3}).set_requires_grad(input_requires_grad);

    // Method form of `input * input`.
    torch::Tensor product = input.mul(input);

    ASSERT_FALSE(product.requires_grad());
    ASSERT_TRUE(product.is_leaf());
  }
}
// An in-place op executed while grad mode is disabled must leave the
// tensor's requires_grad flag unchanged, for both flag values.
TEST(GradModeTest, TestRequiresGradInplaceOp) {
  torch::AutoGradMode no_grad_guard(false);
  for (const bool expected_flag : {true, false}) {
    torch::Tensor target =
        torch::ones({1, 2, 3}).set_requires_grad(expected_flag);

    // Mutate in place while grad mode is off.
    target.mul_(2);

    ASSERT_EQ(target.requires_grad(), expected_flag);
  }
}
// A view taken while grad mode is disabled must report the same
// requires_grad flag as its base and must be a leaf.
TEST(GradModeTest, TestRequiresGradViewOp) {
  torch::AutoGradMode no_grad_guard(false);
  for (const bool base_requires_grad : {true, false}) {
    torch::Tensor base =
        torch::ones({1, 2, 3}).set_requires_grad(base_requires_grad);

    // Take a 2x3 view of the 1x2x3 base.
    torch::Tensor reshaped = base.view({2, 3});

    ASSERT_EQ(reshaped.requires_grad(), base_requires_grad);
    ASSERT_TRUE(reshaped.is_leaf());
  }
}
// A view created inside a grad-disabled scope and then used after that scope
// has ended. The test checks that the view:
//   - keeps its base's requires_grad flag and NO_GRAD_MODE creation meta;
//   - has no .grad after backward through a product of itself;
//   - rejects in-place modification when it requires grad;
//   - passes NO_GRAD_MODE creation meta on to a further view taken from it.
TEST(GradModeTest, TestRequiresGradViewOpExiting) {
  for (const bool needs_grad : {true, false}) {
    torch::Tensor leaf = torch::ones({1, 2, 3}).set_requires_grad(needs_grad);
    // Clone so the view's base is the result of an op, not `leaf` itself.
    torch::Tensor base = leaf.clone();
    torch::Tensor nograd_view;
    torch::Tensor derived;

    // Build the view while the no-grad guard is alive.
    {
      torch::AutoGradMode no_grad_guard(false);
      // Dispatch path: VariableType -> ADInplaceOrView -> CPU.
      nograd_view = base.view({2, 3});
      assert_tensor_creation_meta(
          nograd_view, torch::autograd::CreationMeta::NO_GRAD_MODE);
      ASSERT_EQ(nograd_view.requires_grad(), needs_grad);
      ASSERT_TRUE(nograd_view.is_leaf());
    }

    // From here on the guard's scope has ended.
    derived = nograd_view * nograd_view;
    ASSERT_EQ(derived.requires_grad(), needs_grad);

    if (needs_grad) {
      derived.backward(torch::ones_like(derived));
      // TODO: this behavior is a side effect of issue #11390.
      ASSERT_FALSE(nograd_view.grad().defined());

      // Dispatch path: VariableType -> ADInplaceOrView -> CPU.
      ASSERT_THROWS_WITH(
          nograd_view.mul_(2),
          "A view was created in no_grad mode and is being modified inplace");
    } else {
      nograd_view.mul_(2);
    }

    // A view of the no-grad view carries the same creation meta.
    derived = nograd_view.view({2, 3});
    ASSERT_EQ(derived.requires_grad(), needs_grad);
    assert_tensor_creation_meta(
        derived, torch::autograd::CreationMeta::NO_GRAD_MODE);
  }
}