Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] add simulation support for user-defined iterable objects #70620

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -788,8 +788,8 @@ def LOAD_ATTR(self, instr: Instruction):
return
else:
attr_name = self._code.co_names[instr.arg]
attr_name_var = ConstantVariable.wrap_literal(attr_name, self._graph)
obj = self.stack.pop()
attr_name_var = ConstantVariable.wrap_literal(attr_name, self._graph)
GoldenStain marked this conversation as resolved.
Show resolved Hide resolved
self.stack.push(
BuiltinVariable(
getattr, graph=self._graph, tracker=DanglingTracker()
Expand Down Expand Up @@ -860,11 +860,20 @@ def LOAD_GLOBAL(self, instr: Instruction):
if push_null and CALL_METHOD_LAYOUT_NULL_AFTER_VALUE:
self.stack.push(NullVariable())

def load_sequence(self, obj):
self.stack.push(obj.get_iter())
# skip call
while self._instructions[self._lasti].opname != "RETURN_VALUE":
self._lasti += 1

def load_method(self, method_name):
obj = self.stack.pop()
if isinstance(obj, ContainerVariable) and method_name == "__iter__":
self.load_sequence(obj)
return
method_name_var = ConstantVariable.wrap_literal(
method_name, self._graph
)
Comment on lines +863 to 876
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段逻辑对上下文有较大的修改,如何保证全场景的正确性?

每条字节码应该只做自己的事情,只看当前 instruction,不应该去访问 self._instructions

另外这里特判 ContainerVariable__iter__ 的原因是?

obj = self.stack.pop()

method = BuiltinVariable(
getattr, graph=self._graph, tracker=DanglingTracker()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,35 @@ def main_info(self) -> dict[str, Any]:
def get_py_value(self, allow_tensor=False) -> Any:
return self.value

def get_iter(self):
"""
To simplify the problem, we only support the case where the __iter__ method returns a list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不止 list,应该是现在全部已经支持的 builtin types

另外这个实现在 VariableBase 上会有什么问题么

"""
from . import (
BuiltinVariable,
ConstantVariable,
SequenceIterVariable,
UserDefinedFunctionVariable,
)

if not hasattr(self.value, "__iter__"):
return super().get_iter()
iter_name_var = ConstantVariable.wrap_literal("__iter__", self.graph)
iter_method = BuiltinVariable(
getattr, graph=self.graph, tracker=DanglingTracker()
)(self, iter_name_var)
# If the target object is a builtin object like list_iterator, the iter_method's fn will be a ObjectVariable instead of UserDefinedFunctionVariable.
if not isinstance(iter_method.fn, UserDefinedFunctionVariable):
return super().get_iter()
iter_result = iter_method()

if iter_result is None or not isinstance(
iter_result, SequenceIterVariable
):
return super().get_iter()

return iter_result


class SliceVariable(VariableBase):
"""
Expand Down
19 changes: 19 additions & 0 deletions test/sot/test_04_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,22 @@ def list_extend_dict():
return l1


class IterableWithList:
def __init__(self):
self._list = [1, 2, 3]

def __iter__(self):
return self._list.__iter__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议使用 iter(self._list),这是更加惯用的形式,不过当前写法也测一下吧



@check_no_breakgraph
def list_within_class(x: paddle.Tensor):
my_iterable = IterableWithList()
for i in my_iterable:
x += i
return x


class TestListBasic(TestCaseBase):
def test_list_basic(self):
self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
Expand Down Expand Up @@ -375,6 +391,9 @@ def test_list_extend_range(self):
def test_list_extend_dict(self):
self.assert_results(list_extend_dict)

def test_list_within_class(self):
self.assert_results(list_within_class, paddle.to_tensor(1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以新建一个文件用来测 iter,比如 test_iter.py



if __name__ == "__main__":
unittest.main()
Loading