mypyc patch: cheaper refcounting when unpacking fixed-length tuples.

Line structure and indentation restored from a newline-collapsed copy;
no hunk content was changed. Summary of the hunks below:

  * ops.py adds the Unborrow op, a keyword-only `borrow` flag on TupleGet,
    a keyword-only `steal` flag on KeepAlive, and stolen() on TupleSet and
    KeepAlive.
  * dataflow.py, ircheck.py, selfleaks.py, emitfunc.py and pprint.py each
    gain a visit_unborrow() method; emitfunc.py also skips the incref for
    a borrowed TupleGet and drops the incref at the end of visit_tuple_set.
  * ll_builder.py adds a keep_alive() helper that wraps KeepAlive.
  * statement.py special-cases 'x, y = expr' when expr is an RTuple of
    matching length containing at least one refcounted item: each item is
    read with a borrowed TupleGet + Unborrow, then the tuple is consumed
    with keep_alive(..., steal=True).
  * The two .test files update and add the expected IR output.

diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py
index ee2ff06b0f03..cade0c823962 100644
--- a/mypyc/analysis/dataflow.py
+++ b/mypyc/analysis/dataflow.py
@@ -46,6 +46,7 @@
     Truncate,
     TupleGet,
     TupleSet,
+    Unborrow,
     Unbox,
     Unreachable,
     Value,
@@ -272,6 +273,9 @@ def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
     def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]:
         return self.visit_register_op(op)
 
+    def visit_unborrow(self, op: Unborrow) -> GenAndKill[T]:
+        return self.visit_register_op(op)
+
 
 class DefinedVisitor(BaseAnalysisVisitor[Value]):
     """Visitor for finding defined registers.
diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py
index 2e6b7320e898..a31b1517b036 100644
--- a/mypyc/analysis/ircheck.py
+++ b/mypyc/analysis/ircheck.py
@@ -44,6 +44,7 @@
     Truncate,
     TupleGet,
     TupleSet,
+    Unborrow,
     Unbox,
     Unreachable,
     Value,
@@ -422,3 +423,6 @@ def visit_load_address(self, op: LoadAddress) -> None:
 
     def visit_keep_alive(self, op: KeepAlive) -> None:
         pass
+
+    def visit_unborrow(self, op: Unborrow) -> None:
+        pass
diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py
index 288c366e50e5..80c2bc348bc2 100644
--- a/mypyc/analysis/selfleaks.py
+++ b/mypyc/analysis/selfleaks.py
@@ -40,6 +40,7 @@
     Truncate,
     TupleGet,
     TupleSet,
+    Unborrow,
     Unbox,
     Unreachable,
 )
@@ -184,6 +185,9 @@ def visit_load_address(self, op: LoadAddress) -> GenAndKill:
     def visit_keep_alive(self, op: KeepAlive) -> GenAndKill:
         return CLEAN
 
+    def visit_unborrow(self, op: Unborrow) -> GenAndKill:
+        return CLEAN
+
     def check_register_op(self, op: RegisterOp) -> GenAndKill:
         if any(src is self.self_reg for src in op.sources()):
             return DIRTY
diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py
index b4d31544b196..3bce84d3ea59 100644
--- a/mypyc/codegen/emitfunc.py
+++ b/mypyc/codegen/emitfunc.py
@@ -55,6 +55,7 @@
     Truncate,
     TupleGet,
     TupleSet,
+    Unborrow,
     Unbox,
     Unreachable,
     Value,
@@ -260,7 +261,6 @@ def visit_tuple_set(self, op: TupleSet) -> None:
         else:
             for i, item in enumerate(op.items):
                 self.emit_line(f"{dest}.f{i} = {self.reg(item)};")
-        self.emit_inc_ref(dest, tuple_type)
 
     def visit_assign(self, op: Assign) -> None:
         dest = self.reg(op.dest)
@@ -499,7 +499,8 @@ def visit_tuple_get(self, op: TupleGet) -> None:
         dest = self.reg(op)
         src = self.reg(op.src)
         self.emit_line(f"{dest} = {src}.f{op.index};")
-        self.emit_inc_ref(dest, op.type)
+        if not op.is_borrowed:
+            self.emit_inc_ref(dest, op.type)
 
     def get_dest_assign(self, dest: Value) -> str:
         if not dest.is_void:
@@ -746,6 +747,12 @@ def visit_keep_alive(self, op: KeepAlive) -> None:
         # This is a no-op.
         pass
 
+    def visit_unborrow(self, op: Unborrow) -> None:
+        # This is a no-op that propagates the source value.
+        dest = self.reg(op)
+        src = self.reg(op.src)
+        self.emit_line(f"{dest} = {src};")
+
     # Helpers
 
     def label(self, label: BasicBlock) -> str:
diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py
index 2d64cc79d822..04c50d1e2841 100644
--- a/mypyc/ir/ops.py
+++ b/mypyc/ir/ops.py
@@ -792,6 +792,9 @@ def __init__(self, items: list[Value], line: int) -> None:
     def sources(self) -> list[Value]:
         return self.items.copy()
 
+    def stolen(self) -> list[Value]:
+        return self.items.copy()
+
     def accept(self, visitor: OpVisitor[T]) -> T:
         return visitor.visit_tuple_set(self)
 
@@ -801,13 +804,14 @@ class TupleGet(RegisterOp):
 
     error_kind = ERR_NEVER
 
-    def __init__(self, src: Value, index: int, line: int = -1) -> None:
+    def __init__(self, src: Value, index: int, line: int = -1, *, borrow: bool = False) -> None:
         super().__init__(line)
         self.src = src
         self.index = index
         assert isinstance(src.type, RTuple), "TupleGet only operates on tuples"
         assert index >= 0
         self.type = src.type.types[index]
+        self.is_borrowed = borrow
 
     def sources(self) -> list[Value]:
         return [self.src]
@@ -1387,21 +1391,76 @@ class KeepAlive(RegisterOp):
     If we didn't have "keep_alive x", x could be freed immediately
     after taking the address of 'item', resulting in a read after free
     on the second line.
+
+    If 'steal' is true, the value is considered to be stolen at
+    this op, i.e. it won't be decref'd. You need to ensure that
+    the value is freed otherwise, perhaps by using borrowing
+    followed by Unborrow.
+
+    Be careful with steal=True -- this can cause memory leaks.
     """
 
     error_kind = ERR_NEVER
 
-    def __init__(self, src: list[Value]) -> None:
+    def __init__(self, src: list[Value], *, steal: bool = False) -> None:
         assert src
         self.src = src
+        self.steal = steal
 
     def sources(self) -> list[Value]:
         return self.src.copy()
 
+    def stolen(self) -> list[Value]:
+        if self.steal:
+            return self.src.copy()
+        return []
+
     def accept(self, visitor: OpVisitor[T]) -> T:
         return visitor.visit_keep_alive(self)
 
 
+class Unborrow(RegisterOp):
+    """A no-op op to create a regular reference from a borrowed one.
+
+    Borrowed references can only be used temporarily and the reference
+    counts won't be managed. This value will be refcounted normally.
+
+    This is mainly useful if you split an aggregate value, such as
+    a tuple, into components using borrowed values (to avoid increfs),
+    and want to treat the components as sharing the original managed
+    reference. You'll also need to use KeepAlive with steal=True to
+    "consume" the original tuple reference:
+
+      # t is a 2-tuple
+      r0 = borrow t[0]
+      r1 = borrow t[1]
+      r2 = unborrow r0
+      r3 = unborrow r1
+      # now (r2, r3) represent the tuple as separate items, and the
+      # original tuple can be considered dead and available to be
+      # stolen
+      keep_alive steal t
+
+    Be careful with this -- this can easily cause double freeing.
+    """
+
+    error_kind = ERR_NEVER
+
+    def __init__(self, src: Value) -> None:
+        assert src.is_borrowed
+        self.src = src
+        self.type = src.type
+
+    def sources(self) -> list[Value]:
+        return [self.src]
+
+    def stolen(self) -> list[Value]:
+        return []
+
+    def accept(self, visitor: OpVisitor[T]) -> T:
+        return visitor.visit_unborrow(self)
+
+
 @trait
 class OpVisitor(Generic[T]):
     """Generic visitor over ops (uses the visitor design pattern)."""
@@ -1548,6 +1607,10 @@ def visit_load_address(self, op: LoadAddress) -> T:
     def visit_keep_alive(self, op: KeepAlive) -> T:
         raise NotImplementedError
 
+    @abstractmethod
+    def visit_unborrow(self, op: Unborrow) -> T:
+        raise NotImplementedError
+
 
 # TODO: Should the following definition live somewhere else?
 
diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py
index c86060c49594..5578049256f1 100644
--- a/mypyc/ir/pprint.py
+++ b/mypyc/ir/pprint.py
@@ -51,6 +51,7 @@
     Truncate,
     TupleGet,
     TupleSet,
+    Unborrow,
     Unbox,
     Unreachable,
     Value,
@@ -153,7 +154,7 @@ def visit_init_static(self, op: InitStatic) -> str:
         return self.format("%s = %r :: %s", name, op.value, op.namespace)
 
     def visit_tuple_get(self, op: TupleGet) -> str:
-        return self.format("%r = %r[%d]", op, op.src, op.index)
+        return self.format("%r = %s%r[%d]", op, self.borrow_prefix(op), op.src, op.index)
 
     def visit_tuple_set(self, op: TupleSet) -> str:
         item_str = ", ".join(self.format("%r", item) for item in op.items)
@@ -274,7 +275,16 @@ def visit_load_address(self, op: LoadAddress) -> str:
         return self.format("%r = load_address %s", op, op.src)
 
     def visit_keep_alive(self, op: KeepAlive) -> str:
-        return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src))
+        if op.steal:
+            steal = "steal "
+        else:
+            steal = ""
+        return self.format(
+            "keep_alive {}{}".format(steal, ", ".join(self.format("%r", v) for v in op.src))
+        )
+
+    def visit_unborrow(self, op: Unborrow) -> str:
+        return self.format("%r = unborrow %r", op, op.src)
 
     # Helpers
 
diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py
index 984b6a4deec0..d1ea91476a66 100644
--- a/mypyc/irbuild/ll_builder.py
+++ b/mypyc/irbuild/ll_builder.py
@@ -266,6 +266,9 @@ def goto_and_activate(self, block: BasicBlock) -> None:
         self.goto(block)
         self.activate_block(block)
 
+    def keep_alive(self, values: list[Value], *, steal: bool = False) -> None:
+        self.add(KeepAlive(values, steal=steal))
+
     def push_error_handler(self, handler: BasicBlock | None) -> None:
         self.error_handlers.append(handler)
 
diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py
index 63297618108c..d7e01456139d 100644
--- a/mypyc/irbuild/statement.py
+++ b/mypyc/irbuild/statement.py
@@ -59,11 +59,13 @@
     Register,
     Return,
     TupleGet,
+    Unborrow,
     Unreachable,
     Value,
 )
 from mypyc.ir.rtypes import (
     RInstance,
+    RTuple,
     c_pyssize_t_rprimitive,
     exc_rtuple,
     is_tagged,
@@ -183,8 +185,29 @@ def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None:
 
     line = stmt.rvalue.line
     rvalue_reg = builder.accept(stmt.rvalue)
+
     if builder.non_function_scope() and stmt.is_final_def:
         builder.init_final_static(first_lvalue, rvalue_reg)
+
+    # Special-case multiple assignments like 'x, y = expr' to reduce refcount ops.
+    if (
+        isinstance(first_lvalue, (TupleExpr, ListExpr))
+        and isinstance(rvalue_reg.type, RTuple)
+        and len(rvalue_reg.type.types) == len(first_lvalue.items)
+        and len(lvalues) == 1
+        and all(is_simple_lvalue(item) for item in first_lvalue.items)
+        and any(t.is_refcounted for t in rvalue_reg.type.types)
+    ):
+        n = len(first_lvalue.items)
+        for i in range(n):
+            target = builder.get_assignment_target(first_lvalue.items[i])
+            rvalue_item = builder.add(TupleGet(rvalue_reg, i, borrow=True))
+            rvalue_item = builder.add(Unborrow(rvalue_item))
+            builder.assign(target, rvalue_item, line)
+        builder.builder.keep_alive([rvalue_reg], steal=True)
+        builder.flush_keep_alives()
+        return
+
     for lvalue in lvalues:
         target = builder.get_assignment_target(lvalue)
         builder.assign(target, rvalue_reg, line)
diff --git a/mypyc/test-data/irbuild-statements.test b/mypyc/test-data/irbuild-statements.test
index 062abd47d163..490b41336e88 100644
--- a/mypyc/test-data/irbuild-statements.test
+++ b/mypyc/test-data/irbuild-statements.test
@@ -502,16 +502,16 @@ L0:
 [case testMultipleAssignmentBasicUnpacking]
 from typing import Tuple, Any
 
-def from_tuple(t: Tuple[int, str]) -> None:
+def from_tuple(t: Tuple[bool, None]) -> None:
     x, y = t
 
 def from_any(a: Any) -> None:
     x, y = a
 [out]
 def from_tuple(t):
-    t :: tuple[int, str]
-    r0, x :: int
-    r1, y :: str
+    t :: tuple[bool, None]
+    r0, x :: bool
+    r1, y :: None
 L0:
     r0 = t[0]
     x = r0
@@ -563,16 +563,19 @@ def from_any(a: Any) -> None:
 [out]
 def from_tuple(t):
     t :: tuple[int, object]
-    r0 :: int
-    r1, x, r2 :: object
-    r3, y :: int
+    r0, r1 :: int
+    r2, x, r3, r4 :: object
+    r5, y :: int
 L0:
-    r0 = t[0]
-    r1 = box(int, r0)
-    x = r1
-    r2 = t[1]
-    r3 = unbox(int, r2)
-    y = r3
+    r0 = borrow t[0]
+    r1 = unborrow r0
+    r2 = box(int, r1)
+    x = r2
+    r3 = borrow t[1]
+    r4 = unborrow r3
+    r5 = unbox(int, r4)
+    y = r5
+    keep_alive steal t
     return 1
 def from_any(a):
     a, r0, r1 :: object
diff --git a/mypyc/test-data/refcount.test b/mypyc/test-data/refcount.test
index 3db4caa39566..0f2c134ae21e 100644
--- a/mypyc/test-data/refcount.test
+++ b/mypyc/test-data/refcount.test
@@ -656,6 +656,66 @@ L1:
 L2:
     return 4
 
+[case testReturnTuple]
+from typing import Tuple
+
+class C: pass
+def f() -> Tuple[C, C]:
+    a = C()
+    b = C()
+    return a, b
+[out]
+def f():
+    r0, a, r1, b :: __main__.C
+    r2 :: tuple[__main__.C, __main__.C]
+L0:
+    r0 = C()
+    a = r0
+    r1 = C()
+    b = r1
+    r2 = (a, b)
+    return r2
+
+[case testDecomposeTuple]
+from typing import Tuple
+
+class C:
+    a: int
+
+def f() -> int:
+    x, y = g()
+    return x.a + y.a
+
+def g() -> Tuple[C, C]:
+    return C(), C()
+[out]
+def f():
+    r0 :: tuple[__main__.C, __main__.C]
+    r1, r2, x, r3, r4, y :: __main__.C
+    r5, r6, r7 :: int
+L0:
+    r0 = g()
+    r1 = borrow r0[0]
+    r2 = unborrow r1
+    x = r2
+    r3 = borrow r0[1]
+    r4 = unborrow r3
+    y = r4
+    r5 = borrow x.a
+    r6 = borrow y.a
+    r7 = CPyTagged_Add(r5, r6)
+    dec_ref x
+    dec_ref y
+    return r7
+def g():
+    r0, r1 :: __main__.C
+    r2 :: tuple[__main__.C, __main__.C]
+L0:
+    r0 = C()
+    r1 = C()
+    r2 = (r0, r1)
+    return r2
+
 [case testUnicodeLiteral]
 def f() -> str:
     return "some string"