diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 4f7281e15e..1603211d74 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -764,6 +764,17 @@ fn function_last_expressions<'a>( sys_info: &SysInfo, ) -> Option> { fn f<'a>(sys_info: &SysInfo, x: &'a [Stmt], res: &mut Vec<(LastStmt, &'a Expr)>) -> Option<()> { + fn loop_body_has_break_statement(statement: &Stmt, has_break: &mut bool) { + match statement { + Stmt::Break(_) => { + *has_break = true; + } + Stmt::While(_) | Stmt::For(_) => {} + _ => statement + .recurse(&mut |statement| loop_body_has_break_statement(statement, has_break)), + } + } + match x.last()? { Stmt::Expr(x) => res.push((LastStmt::Expr, &x.value)), Stmt::Return(_) | Stmt::Raise(_) => {} @@ -781,35 +792,20 @@ fn function_last_expressions<'a>( return None; } let mut has_break = false; - fn f(stmt: &Stmt, res: &mut bool) { - match stmt { - Stmt::Break(_) => { - *res = true; - } - Stmt::While(_) | Stmt::For(_) => {} - _ => stmt.recurse(&mut |stmt| f(stmt, res)), - } - } - x.body.visit(&mut |stmt| f(stmt, &mut has_break)); + x.body + .visit(&mut |stmt| loop_body_has_break_statement(stmt, &mut has_break)); if has_break { return None; } } Stmt::For(x) => { let mut has_break = false; - fn f(stmt: &Stmt, res: &mut bool) { - match stmt { - Stmt::Break(_) => { - *res = true; - } - Stmt::While(_) | Stmt::For(_) => {} - _ => stmt.recurse(&mut |stmt| f(stmt, res)), - } - } - x.body.visit(&mut |stmt| f(stmt, &mut has_break)); - if has_break { + x.body + .visit(&mut |stmt| loop_body_has_break_statement(stmt, &mut has_break)); + if has_break || x.orelse.is_empty() { return None; } + f(sys_info, &x.orelse, res)?; } Stmt::If(x) => { let mut last_test = None; diff --git a/pyrefly/lib/test/returns.rs b/pyrefly/lib/test/returns.rs index a40b9aafb3..1b181ee568 100644 --- a/pyrefly/lib/test/returns.rs +++ b/pyrefly/lib/test/returns.rs @@ -35,6 +35,21 @@ assert_type(f(), None) "#, ); +// Regression test for https://github.com/facebook/pyrefly/issues/1491 +testcase!( + test_infer_return_in_for_loop, + r#" +from typing import reveal_type + +class A: + def f(self, x): + for y in x: + pass + +reveal_type(A().f(0)) # E: revealed type: None +"#, +); + testcase!( test_return_unions, r#"