diff --git a/rikulox-parse/src/parse.rs b/rikulox-parse/src/parse.rs index b3b629f..663e724 100644 --- a/rikulox-parse/src/parse.rs +++ b/rikulox-parse/src/parse.rs @@ -955,3 +955,219 @@ where Ok(token) } } + +#[cfg(test)] +mod tests { + use super::*; + use rikulox_ast::{ + expr::{BinOp, ExprKind, Literal}, + stmt::{FunctionDecl, StmtKind}, + }; + use rikulox_lex::scan::Scanner; + + /// Helper to scan source code and run the parser on it. + fn parse_source(src: &str) -> Vec> { + let mut scanner = Scanner::new(src); + let scan_tokens = scanner.scan_tokens(); + assert!( + scan_tokens.errors.is_empty(), + "lex errors: {:?}", + scan_tokens.errors + ); + let mut parser = Parser::new( + scan_tokens.tokens.into_iter().peekable(), + scan_tokens.eof_span, + ); + parser.parse().expect("failed to parse") + } + + #[test] + fn parse_empty_source_returns_no_statements() { + let stmts = parse_source(""); + assert!(stmts.is_empty()); + } + + #[test] + fn parse_var_declaration_with_initializer() { + let stmts = parse_source("var x = 1;"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::Var { + name, + init: Some(expr), + } => { + assert_eq!(name.symbol, "x"); + match &expr.kind { + ExprKind::Literal(Literal::Number(n)) => { + assert_eq!(*n, 1.0) + } + _ => panic!("expected number literal"), + } + } + _ => panic!("expected var declaration"), + } + } + + #[test] + fn parse_function_declaration() { + let stmts = parse_source("fun add(a, b) { return a + b; }"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::Function(FunctionDecl { name, params, body }) => { + assert_eq!(name.symbol, "add"); + assert_eq!(params.len(), 2); + assert_eq!(params[0].symbol, "a"); + assert_eq!(params[1].symbol, "b"); + assert_eq!(body.len(), 1); + match &body[0].kind { + StmtKind::Return(Some(expr)) => match &expr.kind { + ExprKind::Binary { op: BinOp::Add, .. } => {} + _ => panic!("expected addition in return"), + }, + _ => panic!("expected return statement"), + } + } + _ => panic!("expected function declaration"), + } + } + + #[test] + fn parse_class_declaration_with_method() { + let stmts = parse_source("class Foo { bar() { return 1; } }"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::Class(class_decl) => { + assert_eq!(class_decl.name.symbol, "Foo"); + assert_eq!(class_decl.methods.len(), 1); + let method = &class_decl.methods[0]; + assert_eq!(method.name.symbol, "bar"); + assert!(method.params.is_empty()); + assert_eq!(method.body.len(), 1); + } + _ => panic!("expected class declaration"), + } + } + + #[test] + fn parse_if_else_statement() { + let stmts = parse_source("if (true) print 1; else print 2;"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::If { + condition, + then_branch, + else_branch, + } => { + match &condition.kind { + ExprKind::Literal(Literal::Bool(true)) => {} + _ => panic!("expected true literal"), + } + match &then_branch.kind { + StmtKind::Print(expr) => match &expr.kind { + ExprKind::Literal(Literal::Number(n)) => { + assert_eq!(*n, 1.0) + } + _ => panic!("expected number literal"), + }, + _ => panic!("expected print statement"), + } + match else_branch.as_deref().map(|stmt| &stmt.kind) { + Some(StmtKind::Print(expr)) => match &expr.kind { + ExprKind::Literal(Literal::Number(n)) => { + assert_eq!(*n, 2.0) + } + _ => panic!("expected number literal"), + }, + _ => panic!("expected else branch"), + } + } + _ => panic!("expected if statement"), + } + } + + #[test] + fn parse_while_statement() { + let stmts = parse_source("while (false) print 1;"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::While { condition, body } => { + match &condition.kind { + ExprKind::Literal(Literal::Bool(false)) => {} + _ => panic!("expected false literal"), + } + match &body.kind { + StmtKind::Print(expr) => match &expr.kind { + ExprKind::Literal(Literal::Number(n)) => { + assert_eq!(*n, 1.0) + } + _ => panic!("expected number literal"), + }, + _ => panic!("expected print statement"), + } + } + _ => panic!("expected while statement"), + } + } + + #[test] + fn parse_for_statement_desugars_to_while() { + let stmts = parse_source("for (var i = 0; i < 2; i = i + 1) print i;"); + assert_eq!(stmts.len(), 1); + match &stmts[0].kind { + StmtKind::Block(block) => { + assert_eq!(block.len(), 2); + match &block[0].kind { + StmtKind::Var { name, .. } => assert_eq!(name.symbol, "i"), + _ => panic!("expected loop initializer"), + } + match &block[1].kind { + StmtKind::While { condition, body } => { + match &condition.kind { + ExprKind::Binary { + op: BinOp::Less, .. + } => {} + _ => panic!("expected less-than condition"), + } + match &body.kind { + StmtKind::Block(inner) => { + assert_eq!(inner.len(), 2); + matches!(inner[0].kind, StmtKind::Print(_)) + .then_some(()) + .expect("expected print in body"); + match &inner[1].kind { + StmtKind::Expression(expr) => { + match &expr.kind { + ExprKind::Assign { .. } => {} + _ => panic!( + "expected increment expression" + ), + } + } + _ => { + panic!("expected expression statement") + } + } + } + _ => panic!("expected block body"), + } + } + _ => panic!("expected while loop"), + } + } + _ => panic!("expected desugared for loop"), + } + } + + #[test] + fn parse_var_declaration_missing_semicolon_errors() { + let mut scanner = Scanner::new("var a"); + let scan_tokens = scanner.scan_tokens(); + assert!(scan_tokens.errors.is_empty()); + let mut parser = Parser::new( + scan_tokens.tokens.into_iter().peekable(), + scan_tokens.eof_span, + ); + let err = parser.parse().expect_err("expected parse error"); + assert!(matches!(err.kind, ParseErrorKind::UnexpectedEof { .. })); + } +}