diff --git a/src/ast.rs b/src/ast.rs index 2d093070..f8d7ca42 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -396,6 +396,10 @@ pub struct VarPat { /// Inferred type of the binder. Filled in by the type checker. pub ty: Option, + + /// Only after type checking: when the binder type is refined by pattern matching, this holds + /// the refined type. + pub refined: Option, } #[derive(Debug, Clone)] @@ -419,6 +423,7 @@ pub struct RecordPat { pub struct VariantPat { pub pat: Box>, pub inferred_ty: Option, + pub inferred_pat_ty: Option, } #[derive(Debug, Clone)] diff --git a/src/ast/printer.rs b/src/ast/printer.rs index ad3e73f8..d65fc204 100644 --- a/src/ast/printer.rs +++ b/src/ast/printer.rs @@ -814,11 +814,14 @@ impl Expr { impl Pat { pub fn print(&self, buf: &mut String) { match self { - Pat::Var(VarPat { var, ty }) => { + Pat::Var(VarPat { var, ty, refined }) => { buf.push_str(var); if let Some(ty) = ty { write!(buf, ": {ty}").unwrap(); } + if let Some(refined) = refined { + write!(buf, " ~> {refined}").unwrap(); + } } Pat::Con(ConPat { @@ -901,7 +904,11 @@ impl Pat { buf.push(')'); } - Pat::Variant(VariantPat { pat, inferred_ty }) => { + Pat::Variant(VariantPat { + pat, + inferred_ty, + inferred_pat_ty: _, + }) => { buf.push('~'); pat.node.print(buf); if let Some(ty) = inferred_ty { diff --git a/src/interpreter.rs b/src/interpreter.rs index 9dad9e0d..a8c5bb87 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -59,6 +59,7 @@ impl Pgm { closures, type_objs: _, record_objs: _, + variant_objs: _, true_con_idx, false_con_idx, char_con_idx, @@ -77,7 +78,7 @@ impl Pgm { // Allocate singletons for constructors without fields. for (i, heap_obj) in heap_objs.iter().enumerate() { match heap_obj { - HeapObj::Builtin(_) => continue, + HeapObj::Builtin(_) | HeapObj::Variant(_) => continue, HeapObj::Source(source_con) => { if source_con.fields.is_empty() { @@ -664,7 +665,11 @@ fn eval( Expr::Do(stmts, _) => exec(w, pgm, heap, locals, stmts, call_stack), - Expr::Variant(expr) => { + Expr::Variant { + expr, + expr_ty: _, + variant_ty: _, + } => { // Note: the interpreter can only deal with variants of boxed types. If `expr` is an // unboxed type things will go wrong. // @@ -729,8 +734,11 @@ fn assign( /// compiled version `StrView`s will be allocated on stack. fn try_bind_pat(pgm: &Pgm, heap: &mut Heap, pat: &L, locals: &mut [u64], value: u64) -> bool { match &pat.node { - Pat::Var(var) => { - locals[var.as_usize()] = value; + Pat::Var(VarPat { + idx, + original_ty: _, + }) => { + locals[idx.as_usize()] = value; true } @@ -775,10 +783,14 @@ fn try_bind_pat(pgm: &Pgm, heap: &mut Heap, pat: &L, locals: &mut [u64], va true } - Pat::Variant(p) => { + Pat::Variant { + pat, + variant_ty: _, + pat_ty: _, + } => { // `p` needs to match a boxed type, but we can't check this here (e.g. in an `assert`). // See the documentation in `Expr::Variant` evaluator. - try_bind_pat(pgm, heap, p, locals, value) + try_bind_pat(pgm, heap, pat, locals, value) } } } diff --git a/src/lowering.rs b/src/lowering.rs index 14cfd417..2be01f60 100644 --- a/src/lowering.rs +++ b/src/lowering.rs @@ -6,8 +6,8 @@ pub mod printer; use crate::ast; use crate::collections::*; use crate::mono_ast::{self as mono, Id, L, Loc}; -pub(crate) use crate::type_collector::RecordType; use crate::type_collector::collect_anonymous_types; +pub(crate) use crate::type_collector::{RecordType, VariantType}; use crate::utils::loc_display; use smol_str::SmolStr; @@ -23,9 +23,16 @@ pub struct LoweredPgm { /// Product types will have one index per type. Sum types may have multiple. pub type_objs: HashMap, TypeObjs>>, - /// Maps record types to their heap object indices. + /// For C backend: maps record types to their heap object indices. pub record_objs: HashMap, + /// For C backend: maps variant types to their heap object indices. + /// + /// Note: variants don't have their own tags, they use the tags of the types in the variant + /// instead. These tags are to make it easy to refer to a variant type in AST nodes, dependency + /// analysis etc. + pub variant_objs: HashMap, + // Ids of some special cons that the interpreter needs to know. // // Note that for product types, type and con tags are the same. @@ -50,7 +57,10 @@ pub struct LoweredPgm { #[derive(Debug)] pub enum TypeObjs { Product(HeapObjIdx), - Sum(Vec), + Sum { + con_indices: Vec, + value: bool, + }, } pub const CON_CON_IDX: HeapObjIdx = HeapObjIdx(0); @@ -291,6 +301,7 @@ pub enum HeapObj { Builtin(BuiltinConDecl), Source(SourceConDecl), Record(RecordType), + Variant(VariantType), } #[derive(Debug)] @@ -429,7 +440,11 @@ pub enum Expr { mono::Type, ), - Variant(Box>), + Variant { + expr: Box>, + expr_ty: mono::Type, + variant_ty: OrdMap, + }, } #[derive(Debug, Clone)] @@ -486,13 +501,28 @@ pub struct IsExpr { #[derive(Debug, Clone)] pub enum Pat { - Var(LocalIdx), + Var(VarPat), Con(ConPat), Ignore, Str(String), Char(char), Or(Box>, Box>), - Variant(Box>), + Variant { + pat: Box>, + variant_ty: OrdMap, + pat_ty: mono::Type, + }, +} + +#[derive(Debug, Clone)] +pub struct VarPat { + pub idx: LocalIdx, + + /// When the binder was refined by pattern matching, the local in `SourceFunDecl` will have the + /// refined type, and this will be the original type. + /// + /// When pattern matching, we should convert the original type to the local's type. + pub original_ty: mono::Type, } #[derive(Debug, Clone)] @@ -693,6 +723,7 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { closures: vec![], type_objs: Default::default(), record_objs: Default::default(), + variant_objs: Default::default(), true_con_idx: *sum_con_nums .get("Bool") .unwrap() @@ -768,6 +799,9 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { match &con_decl.rhs { Some(rhs) => match rhs { mono::TypeDeclRhs::Sum(cons) => { + // For sum types, we generate an index representing the type itself (rather + // than its consturctors). This index is used in dependency anlaysis, and to + // get the type details during code generation. let mut con_indices: Vec = Vec::with_capacity(cons.len()); for mono::ConDecl { name, fields } in cons { let idx = HeapObjIdx(lowered_pgm.heap_objs.len() as u32); @@ -784,7 +818,13 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { .type_objs .entry(con_id.clone()) .or_default() - .insert(con_ty_args.clone(), TypeObjs::Sum(con_indices)); + .insert( + con_ty_args.clone(), + TypeObjs::Sum { + con_indices, + value: con_decl.value, + }, + ); assert!(old.is_none()); } @@ -871,9 +911,8 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { } // Assign indices to record shapes. - let (record_types, _variant_types) = collect_anonymous_types(mono_pgm); + let (record_types, variant_types) = collect_anonymous_types(mono_pgm); - // TODO: We could assign indices to records as we see them during lowering below. let mut record_indices: HashMap = Default::default(); for record_type in record_types { let idx = next_con_idx; @@ -882,6 +921,14 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { lowered_pgm.heap_objs.push(HeapObj::Record(record_type)); } + let mut variant_indices: HashMap = Default::default(); + for variant_type in variant_types { + let idx = next_con_idx; + next_con_idx = HeapObjIdx(next_con_idx.0 + 1); + variant_indices.insert(variant_type.clone(), idx); + lowered_pgm.heap_objs.push(HeapObj::Variant(variant_type)); + } + lowered_pgm.unit_con_idx = *record_indices .get(&RecordType::unit()) .unwrap_or_else(|| panic!("BUG: Unit record not defined {record_indices:#?}")); @@ -1425,6 +1472,7 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm { } lowered_pgm.record_objs = indices.records; + lowered_pgm.variant_objs = variant_indices; lowered_pgm } @@ -2094,11 +2142,19 @@ fn lower_expr( (expr, Default::default()) } - mono::Expr::Variant(mono::VariantExpr { expr, ty: _ }) => { + mono::Expr::Variant(mono::VariantExpr { expr, ty }) => { // Note: Type of the expr in the variant won't be a variant type. Use the // `VariantExpr`'s type. + let expr_ty = expr.node.ty(); let (expr, vars) = lower_bl_expr(expr, closures, indices, scope, mono_pgm); - (Expr::Variant(expr), vars) + ( + Expr::Variant { + expr, + expr_ty, + variant_ty: ty.clone(), + }, + vars, + ) } } } @@ -2147,22 +2203,28 @@ fn lower_pat( // This map is to map binders in alternatives of or patterns to the same local. // - // Only in or pattern alternatives we allow same binders, so if we see a binder for the second - // time, we must be checking another alternative of an or pattern. + // Only in or-pattern alternatives we allow same binders, so if we see a binder for the second + // time, we must be checking another alternative of an or-pattern. mapped_binders: &mut HashMap, ) -> Pat { match pat { - mono::Pat::Var(mono::VarPat { var, ty }) => match mapped_binders.get(var) { - Some(idx) => Pat::Var(*idx), + mono::Pat::Var(mono::VarPat { var, ty, refined }) => match mapped_binders.get(var) { + Some(idx) => Pat::Var(VarPat { + idx: *idx, + original_ty: ty.clone(), + }), None => { let var_idx = LocalIdx(scope.locals.len() as u32); scope.locals.push(LocalInfo { name: var.clone(), - ty: ty.clone(), + ty: refined.as_ref().unwrap_or(ty).clone(), }); scope.bounds.insert(var.clone(), var_idx); mapped_binders.insert(var.clone(), var_idx); - Pat::Var(var_idx) + Pat::Var(VarPat { + idx: var_idx, + original_ty: ty.clone(), + }) } }, @@ -2311,13 +2373,15 @@ fn lower_pat( lower_bl_pat(p2, indices, scope, mono_pgm, mapped_binders), ), - mono::Pat::Variant(mono::VariantPat { pat, ty: _ }) => Pat::Variant(Box::new(lower_l_pat( + mono::Pat::Variant(mono::VariantPat { pat, - indices, - scope, - mono_pgm, - mapped_binders, - ))), + variant_ty, + pat_ty, + }) => Pat::Variant { + pat: Box::new(lower_l_pat(pat, indices, scope, mono_pgm, mapped_binders)), + variant_ty: variant_ty.clone(), + pat_ty: pat_ty.clone(), + }, } } diff --git a/src/lowering/printer.rs b/src/lowering/printer.rs index 1f6a6dde..0359f7b0 100644 --- a/src/lowering/printer.rs +++ b/src/lowering/printer.rs @@ -37,6 +37,8 @@ impl LoweredPgm { } HeapObj::Record(record) => write!(buf, "{record:?}").unwrap(), + + HeapObj::Variant(variant) => write!(buf, "{variant:?}").unwrap(), } buf.push('\n'); } @@ -385,7 +387,11 @@ impl Expr { } } - Expr::Variant(expr) => { + Expr::Variant { + expr, + expr_ty: _, + variant_ty: _, + } => { buf.push('~'); expr.node.print(buf, indent); } @@ -396,7 +402,10 @@ impl Expr { impl Pat { pub fn print(&self, buf: &mut String) { match self { - Pat::Var(idx) => write!(buf, "local{}", idx.0).unwrap(), + Pat::Var(VarPat { + idx, + original_ty: _, + }) => write!(buf, "local{}", idx.0).unwrap(), Pat::Con(ConPat { con, fields }) => { write!(buf, "con{}(", con.0).unwrap(); @@ -431,9 +440,13 @@ impl Pat { p2.node.print(buf); } - Pat::Variant(p) => { + Pat::Variant { + pat, + variant_ty: _, + pat_ty: _, + } => { buf.push('~'); - p.node.print(buf); + pat.node.print(buf); } } } diff --git a/src/mono_ast.rs b/src/mono_ast.rs index bb798798..b8825ac5 100644 --- a/src/mono_ast.rs +++ b/src/mono_ast.rs @@ -23,6 +23,7 @@ pub struct MonoPgm { pub struct TypeDecl { pub name: Id, pub rhs: Option, + pub value: bool, } #[derive(Debug, Clone)] @@ -233,6 +234,7 @@ pub enum Pat { pub struct VarPat { pub var: Id, pub ty: Type, + pub refined: Option, } #[derive(Debug, Clone)] @@ -266,7 +268,8 @@ pub struct RecordPat { #[derive(Debug, Clone)] pub struct VariantPat { pub pat: Box>, - pub ty: OrdMap, // the variant type + pub variant_ty: OrdMap, + pub pat_ty: Type, } #[derive(Debug, Clone)] diff --git a/src/mono_ast/printer.rs b/src/mono_ast/printer.rs index a4376a35..eb2359b6 100644 --- a/src/mono_ast/printer.rs +++ b/src/mono_ast/printer.rs @@ -576,10 +576,14 @@ impl Expr { impl Pat { pub fn print(&self, buf: &mut String) { match self { - Pat::Var(VarPat { var, ty }) => { + Pat::Var(VarPat { var, ty, refined }) => { buf.push_str(var); buf.push_str(": "); ty.print(buf); + if let Some(refined) = refined { + buf.push_str(" ~> "); + refined.print(buf); + } } Pat::Con(ConPat { con, fields }) => { @@ -640,7 +644,11 @@ impl Pat { buf.push(')'); } - Pat::Variant(VariantPat { pat, ty: _ }) => { + Pat::Variant(VariantPat { + pat, + variant_ty: _, + pat_ty: _, + }) => { buf.push('~'); pat.node.print(buf); } diff --git a/src/monomorph.rs b/src/monomorph.rs index eba1ecc9..81430a9d 100644 --- a/src/monomorph.rs +++ b/src/monomorph.rs @@ -575,6 +575,7 @@ fn mono_expr( node: mono::Pat::Var(mono::VarPat { var: SmolStr::new_static("$receiver$"), ty: mono_object_ty, + refined: None, }), }, rhs: *mono_object, @@ -1263,12 +1264,16 @@ fn mono_pat( loc: &ast::Loc, ) -> mono::Pat { match pat { - ast::Pat::Var(ast::VarPat { var, ty }) => { + ast::Pat::Var(ast::VarPat { var, ty, refined }) => { let mono_ty = mono_tc_ty(ty.as_ref().unwrap(), ty_map, poly_pgm, mono_pgm); + let refined = refined + .as_ref() + .map(|refined| mono_tc_ty(refined, ty_map, poly_pgm, mono_pgm)); locals.insert(var.clone()); mono::Pat::Var(mono::VarPat { var: var.clone(), ty: mono_ty, + refined, }) } @@ -1342,15 +1347,23 @@ fn mono_pat( ), }), - ast::Pat::Variant(ast::VariantPat { pat, inferred_ty }) => { - mono::Pat::Variant(mono::VariantPat { - pat: mono_bl_pat(pat, ty_map, poly_pgm, mono_pgm, locals), - ty: get_variant_ty( - mono_tc_ty(inferred_ty.as_ref().unwrap(), ty_map, poly_pgm, mono_pgm), - loc, - ), - }) - } + ast::Pat::Variant(ast::VariantPat { + pat, + inferred_ty, + inferred_pat_ty, + }) => mono::Pat::Variant(mono::VariantPat { + pat: mono_bl_pat(pat, ty_map, poly_pgm, mono_pgm, locals), + variant_ty: get_variant_ty( + mono_tc_ty(inferred_ty.as_ref().unwrap(), ty_map, poly_pgm, mono_pgm), + loc, + ), + pat_ty: mono_tc_ty( + inferred_pat_ty.as_ref().unwrap(), + ty_map, + poly_pgm, + mono_pgm, + ), + }), } } @@ -1877,6 +1890,7 @@ fn mono_ty_decl( mono::TypeDecl { name: mono_ty_id.clone(), rhs: None, + value: ty_decl.value, }, ); @@ -1905,6 +1919,7 @@ fn mono_ty_decl( mono::TypeDecl { name: mono_ty_id.clone(), rhs, + value: ty_decl.value, }, ); diff --git a/src/parser.lalrpop b/src/parser.lalrpop index dda25d03..ab282478 100644 --- a/src/parser.lalrpop +++ b/src/parser.lalrpop @@ -951,7 +951,7 @@ LPat: L = { Pat: Pat = { #[precedence(level = "0")] - => Pat::Var(VarPat { var: id.smol_str(), ty: None }), + => Pat::Var(VarPat { var: id.smol_str(), ty: None, refined: None }), => Pat::Con(con), @@ -974,7 +974,7 @@ Pat: Pat = { => Pat::Char(parse_char_lit(&char.text)), "~" => - Pat::Variant(VariantPat { pat: Box::new(L::new(module, l, r, pat)), inferred_ty: None }), + Pat::Variant(VariantPat { pat: Box::new(L::new(module, l, r, pat)), inferred_ty: None, inferred_pat_ty: None }), #[precedence(level = "1")] #[assoc(side = "right")] diff --git a/src/parser.rs b/src/parser.rs index 2c4c032d..20469a29 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1,5 +1,5 @@ // auto-generated: "lalrpop 0.22.2" -// sha3: f7a8829affc1d86e00b24542123a0e5870ee952878606b525e1e96baeb8430d1 +// sha3: 7cce679c0a9108b8107aba89173064c3c44ce886f9d92d415bbe0b1ddb93c5c7 #![allow(clippy::all)] use crate::ast::*; use crate::interpolation::{copy_update_escapes, str_parts}; @@ -48799,6 +48799,7 @@ fn __action147<'a>(module: &'a Rc, (_, id, _): (Loc, Token, Loc)) -> Pat { Pat::Var(VarPat { var: id.smol_str(), ty: None, + refined: None, }) } @@ -48886,6 +48887,7 @@ fn __action153<'a>( Pat::Variant(VariantPat { pat: Box::new(L::new(module, l, r, pat)), inferred_ty: None, + inferred_pat_ty: None, }) } diff --git a/src/to_c.rs b/src/to_c.rs index fff36863..d9671d3a 100644 --- a/src/to_c.rs +++ b/src/to_c.rs @@ -65,10 +65,6 @@ pub(crate) fn to_c(pgm: &LoweredPgm, main: &str) -> String { #include #include - typedef struct {{ - uint64_t tag; - }} Variant; - " ); @@ -102,7 +98,12 @@ pub(crate) fn to_c(pgm: &LoweredPgm, main: &str) -> String { } } - let heap_objs_sorted = top_sort(&pgm.type_objs, &pgm.record_objs, &pgm.heap_objs); + let heap_objs_sorted = top_sort( + &pgm.type_objs, + &pgm.record_objs, + &pgm.variant_objs, + &pgm.heap_objs, + ); for scc in &heap_objs_sorted { // If SCC has more than one element, forward-declare the structs. if scc.len() > 1 { @@ -116,6 +117,10 @@ pub(crate) fn to_c(pgm: &LoweredPgm, main: &str) -> String { let struct_name = record_struct_name(record); wln!(p, "typedef struct {struct_name} {struct_name};"); } + HeapObj::Variant(variant) => { + let struct_name = variant_struct_name(variant); + wln!(p, "typedef struct {struct_name} {struct_name};"); + } HeapObj::Builtin(BuiltinConDecl::Array { t }) => { let struct_name = array_struct_name(t); wln!(p, "typedef struct {struct_name} {struct_name};"); @@ -148,12 +153,12 @@ pub(crate) fn to_c(pgm: &LoweredPgm, main: &str) -> String { typedef struct ExnHandler {{ jmp_buf buf; struct ExnHandler* prev; - uint64_t exn_value; + void* exn_value; // pointer to boxed exception value }} ExnHandler; static ExnHandler* current_exn_handler = NULL; - static void throw_exn(uint64_t exn) {{ + static void throw_exn(void* exn) {{ if (current_exn_handler == NULL) {{ fprintf(stderr, \"Uncaught exception\\n\"); exit(1); @@ -467,6 +472,7 @@ fn heap_obj_to_c(heap_obj: &HeapObj, tag: u32, p: &mut Printer) { HeapObj::Builtin(builtin) => builtin_con_decl_to_c(builtin, tag, p), HeapObj::Source(source_con) => source_con_decl_to_c(source_con, tag, p), HeapObj::Record(record) => record_decl_to_c(record, tag, p), + HeapObj::Variant(variant) => variant_decl_to_c(variant, tag, p), } } @@ -568,6 +574,15 @@ fn record_struct_name(record: &RecordType) -> String { name } +fn variant_struct_name(variant: &VariantType) -> String { + let mut name = String::from("Variant"); + for named_ty in variant.alts.values() { + name.push('_'); + named_ty_to_c(named_ty, &mut name); + } + name +} + fn array_struct_name(t: &mono::Type) -> String { let mut name = String::from("Array_"); let t = c_ty(t); @@ -580,6 +595,7 @@ fn heap_obj_struct_name(pgm: &LoweredPgm, idx: HeapObjIdx) -> String { match &pgm.heap_objs[idx.0 as usize] { HeapObj::Source(source_con) => source_con_struct_name(source_con), HeapObj::Record(record) => record_struct_name(record), + HeapObj::Variant(variant) => variant_struct_name(variant), HeapObj::Builtin(_) => panic!("Builtin in heap_obj_struct_name"), } } @@ -588,6 +604,7 @@ fn heap_obj_tag_name(pgm: &LoweredPgm, idx: HeapObjIdx) -> String { match &pgm.heap_objs[idx.0 as usize] { HeapObj::Source(source_con) => source_con_tag_name(source_con), HeapObj::Record(record) => format!("TAG_{}", record_struct_name(record)), + HeapObj::Variant(_) => panic!("Variants don't have runtime tags"), HeapObj::Builtin(_) => panic!("Builtin in heap_obj_tag_name"), } } @@ -617,6 +634,7 @@ fn heap_obj_singleton_name(pgm: &LoweredPgm, idx: HeapObjIdx) -> String { match &pgm.heap_objs[idx.0 as usize] { HeapObj::Source(source_con) => source_con_singleton_name(source_con), HeapObj::Record(record) => format!("_singleton_{}", record_struct_name(record)), + HeapObj::Variant(_) => panic!("Variants don't have singletons"), HeapObj::Builtin(_) => panic!("Builtin heap objects don't have singletons"), } } @@ -634,37 +652,58 @@ fn source_con_decl_to_c(source_con: &SourceConDecl, tag: u32, p: &mut Printer) { let tag_name = source_con_tag_name(source_con); let struct_name = source_con_struct_name(source_con); - wln!(p, "#define {} {}", tag_name, tag); + wln!(p, "#define {tag_name} {tag}"); - w!(p, "typedef struct {} {{", struct_name); + w!(p, "typedef struct {struct_name} {{"); p.indent(); p.nl(); w!(p, "uint64_t _tag;"); for (i, ty) in fields.iter().enumerate() { p.nl(); - w!(p, "{} _{};", c_ty(ty), i); + w!(p, "{} _{i};", c_ty(ty)); } p.dedent(); p.nl(); - wln!(p, "}} {};", struct_name); + wln!(p, "}} {struct_name};"); } fn record_decl_to_c(record: &RecordType, tag: u32, p: &mut Printer) { let struct_name = record_struct_name(record); - wln!(p, "#define TAG_{} {}", struct_name, tag); + wln!(p, "#define TAG_{struct_name} {tag}"); - w!(p, "typedef struct {} {{", struct_name); + w!(p, "typedef struct {struct_name} {{"); p.indent(); p.nl(); w!(p, "uint64_t _tag;"); for (i, (_field_name, field_ty)) in record.fields.iter().enumerate() { p.nl(); - w!(p, "{} _{};", c_ty(field_ty), i); + w!(p, "{} _{i};", c_ty(field_ty)); + } + p.dedent(); + p.nl(); + wln!(p, "}} {struct_name};"); +} + +fn variant_decl_to_c(variant: &VariantType, tag: u32, p: &mut Printer) { + let struct_name = variant_struct_name(variant); + wln!(p, "// tag = {tag}"); + w!(p, "typedef struct {} {{", struct_name); + p.indent(); + p.nl(); + wln!(p, "uint64_t _tag;"); + w!(p, "union {{"); + p.indent(); + for (i, alt) in variant.alts.values().enumerate() { + p.nl(); + w!(p, "{} _{i};", c_ty(&mono::Type::Named(alt.clone()))); } p.dedent(); p.nl(); - wln!(p, "}} {};", struct_name); + w!(p, "}} _alt;"); + p.dedent(); + p.nl(); + wln!(p, "}} {struct_name};"); } fn named_ty_to_c(named_ty: &mono::NamedType, out: &mut String) { @@ -700,12 +739,11 @@ fn c_ty(ty: &mono::Type) -> String { if let mono::Type::Fn(_) = ty { return "CLOSURE*".to_string(); } - if let mono::Type::Variant { .. } = ty { - return "Variant*".to_string(); - } let mut s = String::new(); ty_to_c(ty, &mut s); - s.push('*'); // make pointer + if !matches!(ty, mono::Type::Variant { .. }) { + s.push('*'); // make pointer + } s } @@ -988,16 +1026,13 @@ fn builtin_fun_to_c( BuiltinFunDecl::I32Neg => wln!(p, "static I32 _fun_{idx}(I32 a) {{ return -a; }}"), BuiltinFunDecl::ThrowUnchecked => { - w!( - p, - "static {} _fun_{}({} exn) {{", - c_ty(ret), - idx, - c_ty(¶ms[0]) - ); + let exn_ty = c_ty(¶ms[0]); + w!(p, "static {} _fun_{}({} exn) {{", c_ty(ret), idx, exn_ty); p.indent(); p.nl(); - wln!(p, "throw_exn((uint64_t)exn);"); + wln!(p, "{exn_ty}* boxed = malloc(sizeof({exn_ty}));"); + wln!(p, "*boxed = exn;"); + wln!(p, "throw_exn(boxed);"); w!(p, "__builtin_unreachable();"); p.dedent(); p.nl(); @@ -1052,7 +1087,8 @@ fn builtin_fun_to_c( "{err_struct_name}* err = malloc(sizeof({err_struct_name}));" ); wln!(p, "err->_tag = {};", err_tag_name); - wln!(p, "err->_0 = ({})handler.exn_value;", c_ty(&ty_args[1])); + let exn_ty = c_ty(&ty_args[1]); + wln!(p, "err->_0 = *({exn_ty}*)handler.exn_value;"); w!(p, "return ({})err;", c_ty(ret)); p.dedent(); p.nl(); @@ -1416,7 +1452,11 @@ fn stmt_to_c( w!(p, "{} {} = ", c_ty(rhs_ty), rhs_temp); expr_to_c(&rhs.node, &rhs.loc, locals, cg, p); wln!(p, "; // {}", loc_display(&rhs.loc)); - wln!(p, "{};", pat_to_cond(&lhs.node, &rhs_temp, cg)); + wln!( + p, + "{};", + pat_to_cond(&lhs.node, &rhs_temp, rhs_ty, None, locals, cg) + ); if let Some(result_var) = result_var { wln!( p, @@ -1759,7 +1799,7 @@ fn expr_to_c(expr: &Expr, loc: &Loc, locals: &[LocalInfo], cg: &mut Cg, p: &mut w!(p, " else "); } // Generate pattern match condition - let cond = pat_to_cond(&alt.pat.node, &scrut_temp, cg); + let cond = pat_to_cond(&alt.pat.node, &scrut_temp, scrut_ty, None, locals, cg); w!(p, "if ({}", cond); // Add guard if present @@ -1935,7 +1975,11 @@ fn expr_to_c(expr: &Expr, loc: &Loc, locals: &[LocalInfo], cg: &mut Cg, p: &mut expr_to_c(&expr.node, &expr.loc, locals, cg, p); wln!(p, "; // {}", loc_display(&expr.loc)); wln!(p, "Bool* _is_result;"); - w!(p, "if ({}) {{", pat_to_cond(&pat.node, &expr_temp, cg)); + w!( + p, + "if ({}) {{", + pat_to_cond(&pat.node, &expr_temp, expr_ty, None, locals, cg) + ); p.indent(); p.nl(); w!( @@ -1975,32 +2019,232 @@ fn expr_to_c(expr: &Expr, loc: &Loc, locals: &[LocalInfo], cg: &mut Cg, p: &mut w!(p, "}})"); } - Expr::Variant(expr) => { - // Variants are represented as their underlying type - w!(p, "((Variant*)"); + Expr::Variant { + expr, + expr_ty, + variant_ty, + } => { + /* + ~ + + ==> + + ({ temp1 = ; + uint64_t temp2 = get_tag(temp1); + temp3 = { .tag = temp2, .alt._N = temp1 }; + temp3; }) + + where: + + - `get_tag` is type-specific tag getter + - `N` is the index of the named type in the variant type + */ + + w!(p, "({{"); + p.indent(); + p.nl(); + + // TODO: Check that variant exprs are named types in an earlier pass. + let expr_named_ty = match expr_ty { + mono::Type::Named(named_ty) => named_ty, + _ => panic!(), + }; + + let alt_idx = variant_ty + .iter() + .enumerate() + .find_map(|(idx, (_, alt_ty))| { + if alt_ty == expr_named_ty { + Some(idx) + } else { + None + } + }) + .unwrap(); + + let variant_struct_name = variant_struct_name(&VariantType { + alts: variant_ty.clone(), + }); + + let expr_temp = cg.fresh_temp(); + w!(p, "{} {expr_temp} = ", c_ty(expr_ty)); expr_to_c(&expr.node, &expr.loc, locals, cg, p); - w!(p, ")"); + wln!(p, "; // {}", loc_display(&expr.loc)); + + let expr_tag_temp = cg.fresh_temp(); + wln!( + p, + "uint32_t {expr_tag_temp} = {};", + gen_get_tag(cg.pgm, &expr_temp, expr_ty) + ); + + let variant_temp = cg.fresh_temp(); + wln!( + p, + "{variant_struct_name} {variant_temp} = {{ ._tag = {expr_tag_temp}, ._alt._{alt_idx} = {expr_temp} }};" + ); + w!(p, "{variant_temp};"); + + p.dedent(); + p.nl(); + w!(p, "}})"); + } + } +} + +/// Given a pattern type inside a variant pattern, find which alternative in the variant it matches. +/// Returns the index of the alternative. +fn find_variant_alt_index(pat_ty: &mono::Type, variant_ty: &OrdMap) -> usize { + let type_name = match pat_ty { + mono::Type::Named(named_ty) => named_ty.name.as_str(), + _ => panic!("Non-named type in variant pattern: {:?}", pat_ty), + }; + + variant_ty + .iter() + .enumerate() + .find_map(|(idx, (name, _))| { + if name.as_str() == type_name { + Some(idx) + } else { + None + } + }) + .unwrap_or_else(|| panic!("Type {type_name} not found in variant alternatives")) +} + +/// Check if we need to convert between two variant types. Returns true if both types are variants +/// and they differ. +fn needs_variant_conversion(from_ty: &mono::Type, to_ty: &mono::Type) -> bool { + match (from_ty, to_ty) { + (mono::Type::Variant { alts: from_alts }, mono::Type::Variant { alts: to_alts }) => { + from_alts != to_alts + } + _ => false, + } +} + +/// Generate code to convert a value from one variant type to another: unpack the value from the +/// source variant and repack it into the target variant. +fn gen_variant_conversion( + scrutinee: &str, + from_ty: &mono::Type, + to_ty: &mono::Type, + cg: &mut Cg, +) -> String { + let (from_alts, to_alts) = match (from_ty, to_ty) { + (mono::Type::Variant { alts: from }, mono::Type::Variant { alts: to }) => (from, to), + _ => panic!("gen_variant_conversion called with non-variant types"), + }; + + let to_variant_ty = VariantType { + alts: to_alts.clone(), + }; + let to_struct_name = variant_struct_name(&to_variant_ty); + + // Handle empty target variant - this is an unreachable case at runtime, + // but we still need to generate valid C code. Just copy the tag. + if to_alts.is_empty() { + let temp = cg.fresh_temp(); + return format!( + "({{ {to_struct_name} {temp}; {temp}._tag = ({scrutinee})._tag; {temp}; }})" + ); + } + + // Find the mapping from source alternative index to target alternative index. + // The value's tag tells us which alternative is active in the source. + // We need to find the corresponding alternative in the target and repack. + + // Generate a compound expression that: + // 1. Reads the tag from the source + // 2. Based on the tag, copies the value to the appropriate field in the target + + let temp = cg.fresh_temp(); + let mut cases = String::new(); + + for (to_idx, (type_name, named_ty)) in to_alts.iter().enumerate() { + // Find this type in the source variant + let (from_idx, _) = from_alts + .iter() + .enumerate() + .find(|(_, (name, _))| *name == type_name) + .unwrap_or_else(|| { + panic!( + "Type {} not found in source variant during conversion", + type_name + ) + }); + + let alt_ty = mono::Type::Named(named_ty.clone()); + let expected_tag = gen_get_tag(cg.pgm, &format!("({scrutinee})._alt._{from_idx}"), &alt_ty); + + if !cases.is_empty() { + cases.push_str(" else "); } + cases.push_str(&format!( + "if (({scrutinee})._tag == {expected_tag}) {{ {temp}._tag = {expected_tag}; {temp}._alt._{to_idx} = ({scrutinee})._alt._{from_idx}; }}" + )); } + + // Add a fallback case (should never happen if types are correct) + cases.push_str(" else { fprintf(stderr, \"Invalid variant conversion\\n\"); exit(1); }"); + + format!("({{ {to_struct_name} {temp}; {cases} {temp}; }})") } /// Generate a C condition expression for pattern matching. -fn pat_to_cond(pat: &Pat, scrutinee: &str, cg: &mut Cg) -> String { +/// +/// - `scrutinee` is the expression being matched against. +/// +/// - `scrutinee_ty` is the type of the scrutinee (used for generating tag checks with `gen_get_tag`). +/// +/// - `tag_expr` is an optional override for how to get the tag. When `Some`, use that expression +/// directly (e.g., for variant patterns where we check the variant's `_tag` field). When `None`, +/// derive the tag from the scrutinee using `gen_get_tag`. +fn pat_to_cond( + pat: &Pat, + scrutinee: &str, + scrutinee_ty: &mono::Type, + tag_expr: Option<&str>, + locals: &[LocalInfo], + cg: &mut Cg, +) -> String { match pat { Pat::Ignore => "1".to_string(), - Pat::Var(idx) => { - format!("({{ _{} = {}; 1; }})", idx.as_usize(), scrutinee) + Pat::Var(VarPat { idx, original_ty }) => { + let refined_ty = &locals[idx.as_usize()].ty; + if needs_variant_conversion(original_ty, refined_ty) { + let conversion = gen_variant_conversion(scrutinee, original_ty, refined_ty, cg); + format!("({{ _{} = {}; 1; }})", idx.as_usize(), conversion) + } else { + format!("({{ _{} = {}; 1; }})", idx.as_usize(), scrutinee) + } } Pat::Con(ConPat { con, fields }) => { let struct_name = heap_obj_struct_name(cg.pgm, *con); let tag_name = heap_obj_tag_name(cg.pgm, *con); - let mut cond = format!("(get_tag({}) == {})", scrutinee, tag_name); + let tag_check = match tag_expr { + Some(expr) => format!("({expr} == {tag_name})"), + None => { + let get_tag = gen_get_tag(cg.pgm, scrutinee, scrutinee_ty); + format!("({get_tag} == {tag_name})") + } + }; + let field_tys: Vec = match &cg.pgm.heap_objs[con.as_usize()] { + HeapObj::Source(source_con) => source_con.fields.clone(), + HeapObj::Record(record) => record.fields.values().cloned().collect(), + HeapObj::Builtin(_) => panic!("Builtin constructor {:?} in Pat::Con", con), + HeapObj::Variant(_) => panic!("Variant in Pat::Con"), + }; + let mut cond = tag_check; for (i, field_pat) in fields.iter().enumerate() { let field_expr = format!("(({struct_name}*){scrutinee})->_{i}"); - let field_cond = pat_to_cond(&field_pat.node, &field_expr, cg); - cond = format!("({} && {})", cond, field_cond); + let field_ty = &field_tys[i]; + let field_cond = + pat_to_cond(&field_pat.node, &field_expr, field_ty, None, locals, cg); + cond = format!("({cond} && {field_cond})"); } cond } @@ -2009,39 +2253,59 @@ fn pat_to_cond(pat: &Pat, scrutinee: &str, cg: &mut Cg) -> String { let mut escaped = String::new(); for byte in s.bytes() { if byte == b'"' || byte == b'\\' || !(32..=126).contains(&byte) { - // Same as `Expr::Str`, use octal escape here instead of hex. escaped.push_str(&format!("\\{:03o}", byte)); } else { escaped.push(byte as char); } } - // Note: the type cast below is to handle strings in variants. Variants are currently - // `Variant*` so they need to be cast. + let tag_check = match tag_expr { + Some(expr) => format!("({expr} == {})", cg.pgm.str_con_idx.0), + None => { + let get_tag = gen_get_tag(cg.pgm, scrutinee, scrutinee_ty); + format!("({get_tag} == {})", cg.pgm.str_con_idx.0) + } + }; format!( - "(get_tag({}) == {} && str_eq((Str*){}, \"{}\", {}))", - scrutinee, - cg.pgm.str_con_idx.0, - scrutinee, - escaped, + "({tag_check} && str_eq((Str*){scrutinee}, \"{escaped}\", {}))", s.len() ) } Pat::Char(c) => { let tag_name = heap_obj_tag_name(cg.pgm, cg.pgm.char_con_idx); - format!( - "(get_tag({}) == {} && ((Char*){})->_0 == {})", - scrutinee, tag_name, scrutinee, *c as u32 - ) + let tag_check = match tag_expr { + Some(expr) => format!("({expr} == {tag_name})"), + None => { + let get_tag = gen_get_tag(cg.pgm, scrutinee, scrutinee_ty); + format!("({get_tag} == {tag_name})") + } + }; + format!("({tag_check} && ((Char*){scrutinee})->_0 == {})", *c as u32) } Pat::Or(p1, p2) => { - let c1 = pat_to_cond(&p1.node, scrutinee, cg); - let c2 = pat_to_cond(&p2.node, scrutinee, cg); - format!("({} || {})", c1, c2) + let c1 = pat_to_cond(&p1.node, scrutinee, scrutinee_ty, tag_expr, locals, cg); + let c2 = pat_to_cond(&p2.node, scrutinee, scrutinee_ty, tag_expr, locals, cg); + format!("({c1} || {c2})") } - Pat::Variant(inner) => pat_to_cond(&inner.node, scrutinee, cg), + Pat::Variant { + pat, + variant_ty, + pat_ty, + } => { + let alt_idx = find_variant_alt_index(pat_ty, variant_ty); + let inner_expr = format!("({scrutinee})._alt._{alt_idx}"); + let variant_tag_expr = format!("({scrutinee})._tag"); + pat_to_cond( + &pat.node, + &inner_expr, + pat_ty, + Some(&variant_tag_expr), + locals, + cg, + ) + } } } @@ -2069,6 +2333,59 @@ fn generate_main_fn(pgm: &LoweredPgm, main: &str, p: &mut Printer) { wln!(p, "}}"); } +/// Generate the C expression to get the tag of the expression `expr`, which should have the type `ty`. +/// +/// For product types: the tag will be the macro that defines the tag. +/// +/// For sum types: the tag will be extracted from the expression and the code will depend on whether +/// the sum type is a value type or not. +/// +/// - For boxed sum types: the generated code will read the tag word of the heap allocated object. +/// - For unboxed sum types: the generated code will read the tag from the struct of the sum type. +fn gen_get_tag(pgm: &LoweredPgm, expr: &str, ty: &mono::Type) -> String { + // For product types, use the tag macro. + match ty { + mono::Type::Named(mono::NamedType { name, args }) => { + match pgm.type_objs.get(name).unwrap().get(args).unwrap() { + TypeObjs::Product(heap_obj_idx) => heap_obj_tag_name(pgm, *heap_obj_idx), + + TypeObjs::Sum { + con_indices: _, + value: true, + } => { + format!("((uint32_t)({expr})._tag)") + } + + TypeObjs::Sum { + con_indices: _, + value: false, + } => { + format!("((uint32_t)*(uint64_t*)({expr}))") + } + } + } + + mono::Type::Record { fields } => { + let heap_obj_idx = *pgm + .record_objs + .get(&RecordType { + fields: fields.clone(), + }) + .unwrap(); + heap_obj_tag_name(pgm, heap_obj_idx) + } + + mono::Type::Variant { alts: _ } => { + format!("((uint32_t)({expr})._tag)") + } + + mono::Type::Fn(_) => "CLOSURE_TAG".to_string(), + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Printing utils + #[derive(Debug, Default)] struct Printer { lines: Vec, @@ -2120,6 +2437,7 @@ impl Write for Printer { fn top_sort( type_objs: &HashMap, TypeObjs>>, record_objs: &HashMap, + variant_objs: &HashMap, heap_objs: &[HeapObj], ) -> Vec> { let mut idx_gen = SccIdxGen::default(); @@ -2143,7 +2461,7 @@ fn top_sort( output.push(std::iter::once(HeapObjIdx(heap_obj_idx as u32)).collect()); Some(idx_gen.next()) } - HeapObj::Source(_) | HeapObj::Record(_) => None, + HeapObj::Source(_) | HeapObj::Record(_) | HeapObj::Variant(_) => None, }, low_link: None, on_stack: false, @@ -2157,6 +2475,7 @@ fn top_sort( _scc( type_objs, record_objs, + variant_objs, heap_objs, HeapObjIdx(heap_obj_idx as u32), &mut idx_gen, @@ -2196,6 +2515,7 @@ impl SccIdxGen { fn _scc( type_objs: &HashMap, TypeObjs>>, record_objs: &HashMap, + variant_objs: &HashMap, heap_objs: &[HeapObj], heap_obj_idx: HeapObjIdx, idx_gen: &mut SccIdxGen, @@ -2212,13 +2532,20 @@ fn _scc( stack.push(heap_obj_idx); // Add dependencies to the output. - let deps = heap_obj_deps(type_objs, record_objs, heap_objs, heap_obj_idx); + let deps = heap_obj_deps( + type_objs, + record_objs, + variant_objs, + heap_objs, + heap_obj_idx, + ); for dep_obj in deps { if nodes[dep_obj.as_usize()].idx.is_none() { // Dependency not visited yet. _scc( type_objs, record_objs, + variant_objs, heap_objs, dep_obj, idx_gen, @@ -2256,6 +2583,7 @@ fn _scc( fn heap_obj_deps( type_objs: &HashMap, TypeObjs>>, record_objs: &HashMap, + variant_objs: &HashMap, heap_objs: &[HeapObj], heap_obj_idx: HeapObjIdx, ) -> HashSet { @@ -2263,20 +2591,26 @@ fn heap_obj_deps( match &heap_objs[heap_obj_idx.as_usize()] { HeapObj::Builtin(BuiltinConDecl::Array { t }) => { - type_heap_obj_deps(type_objs, record_objs, t, &mut deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, t, &mut deps); } HeapObj::Builtin(_) => {} HeapObj::Source(source_decl) => { for field in source_decl.fields.iter() { - type_heap_obj_deps(type_objs, record_objs, field, &mut deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, field, &mut deps); } } HeapObj::Record(record_type) => { for field in record_type.fields.values() { - type_heap_obj_deps(type_objs, record_objs, field, &mut deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, field, &mut deps); + } + } + + HeapObj::Variant(variant_type) => { + for named_ty in variant_type.alts.values() { + named_type_heap_obj_deps(type_objs, named_ty, &mut deps); } } } @@ -2287,6 +2621,7 @@ fn heap_obj_deps( fn type_heap_obj_deps( type_objs: &HashMap, TypeObjs>>, record_objs: &HashMap, + variant_objs: &HashMap, ty: &mono::Type, deps: &mut HashSet, ) { @@ -2296,18 +2631,22 @@ fn type_heap_obj_deps( } mono::Type::Record { fields } => { - let record_idx = record_objs + let record_idx = *record_objs .get(&RecordType { fields: fields.clone(), }) .unwrap(); - deps.insert(*record_idx); + deps.insert(record_idx); for ty in fields.values() { - type_heap_obj_deps(type_objs, record_objs, ty, deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, ty, deps); } } mono::Type::Variant { alts } => { + let variant_idx = *variant_objs + .get(&VariantType { alts: alts.clone() }) + .unwrap(); + deps.insert(variant_idx); for ty in alts.values() { named_type_heap_obj_deps(type_objs, ty, deps); } @@ -2317,18 +2656,18 @@ fn type_heap_obj_deps( match args { mono::FunArgs::Positional(args) => { for arg in args { - type_heap_obj_deps(type_objs, record_objs, arg, deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, arg, deps); } } mono::FunArgs::Named(args) => { for arg in args.values() { - type_heap_obj_deps(type_objs, record_objs, arg, deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, arg, deps); } } } - type_heap_obj_deps(type_objs, record_objs, ret, deps); - type_heap_obj_deps(type_objs, record_objs, exn, deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, ret, deps); + type_heap_obj_deps(type_objs, record_objs, variant_objs, exn, deps); } } } @@ -2347,6 +2686,9 @@ fn named_type_heap_obj_deps( TypeObjs::Product(idx) => { deps.insert(*idx); } - TypeObjs::Sum(idxs) => deps.extend(idxs.iter().cloned()), + TypeObjs::Sum { + con_indices, + value: _, + } => deps.extend(con_indices.iter().cloned()), } } diff --git a/src/type_checker/expr.rs b/src/type_checker/expr.rs index c9bd0298..0d67f197 100644 --- a/src/type_checker/expr.rs +++ b/src/type_checker/expr.rs @@ -858,6 +858,7 @@ pub(super) fn check_expr( node: ast::Pat::Var(ast::VarPat { var: buf_id.clone(), ty: Some(str_buf_ty.clone()), + refined: None, }), }, ty: None, @@ -1777,7 +1778,22 @@ pub(super) fn check_match_expr( for (alt_idx, (ast::Alt { pat, guard, rhs }, mut alt_scope)) in alts.iter_mut().zip(alt_envs.into_iter()).enumerate() { - refine_binders(&mut alt_scope, &info.bound_vars[alt_idx], &pat.loc); + let refined_binders = refine_binders(&info.bound_vars[alt_idx], &pat.loc); + + if cfg!(debug_assertions) { + let scope_vars: HashSet<&Id> = alt_scope.keys().collect(); + let binders_vars: HashSet<&Id> = refined_binders.keys().collect(); + assert_eq!(scope_vars, binders_vars); + } + + add_coercions( + &mut pat.node, + &refined_binders, + tc_state.tys.tys.cons(), + &pat.loc, + ); + + alt_scope.extend(refined_binders); tc_state.env.push_scope(alt_scope); @@ -2149,12 +2165,8 @@ pub(crate) fn make_variant(tc_state: &mut TcFunState, ty: Ty, level: u32, loc: & } } -fn refine_binders(scope: &mut HashMap, binders: &HashMap>, loc: &ast::Loc) { - if cfg!(debug_assertions) { - let scope_vars: HashSet<&Id> = scope.keys().collect(); - let binders_vars: HashSet<&Id> = binders.keys().collect(); - assert_eq!(scope_vars, binders_vars); - } +fn refine_binders(binders: &HashMap>, loc: &ast::Loc) -> HashMap { + let mut refined_binders: HashMap = Default::default(); for (var, tys) in binders.iter() { // println!("{} --> {:?}", var, tys); @@ -2162,7 +2174,8 @@ fn refine_binders(scope: &mut HashMap, binders: &HashMap if tys.len() == 1 { // println!("{} --> {}", var, tys.iter().next().unwrap().clone()); - scope.insert(var.clone(), tys.iter().next().unwrap().clone()); + let old = refined_binders.insert(var.clone(), tys.iter().next().unwrap().clone()); + assert_eq!(old, None); } else { let mut labels: OrdMap = Default::default(); let mut extension: Option> = None; @@ -2208,7 +2221,58 @@ fn refine_binders(scope: &mut HashMap, binders: &HashMap // println!("{} --> {}", var, new_ty); - scope.insert(var.clone(), new_ty); + let old = refined_binders.insert(var.clone(), new_ty); + assert_eq!(old, None); + } + } + + refined_binders +} + +fn add_coercions( + pat: &mut ast::Pat, + refined_binders: &HashMap, + cons: &ScopeMap, + _loc: &ast::Loc, +) { + match pat { + ast::Pat::Var(ast::VarPat { var, ty, refined }) => { + assert_eq!(refined, &mut None); + let refined_ty = refined_binders.get(var).unwrap().deep_normalize(cons); + let ty = ty.as_ref().unwrap().deep_normalize(cons); + if refined_ty != ty { + *refined = Some(refined_ty); + } + } + + ast::Pat::Con(ast::ConPat { + con: _, + fields, + ignore_rest: _, + }) + | ast::Pat::Record(ast::RecordPat { + fields, + ignore_rest: _, + inferred_ty: _, + }) => { + for field in fields { + add_coercions(&mut field.node.node, refined_binders, cons, &field.node.loc); + } + } + + ast::Pat::Ignore | ast::Pat::Str(_) | ast::Pat::Char(_) => {} + + ast::Pat::Or(p1, p2) => { + add_coercions(&mut p1.node, refined_binders, cons, &p1.loc); + add_coercions(&mut p2.node, refined_binders, cons, &p2.loc); + } + + ast::Pat::Variant(ast::VariantPat { + pat, + inferred_ty: _, + inferred_pat_ty: _, + }) => { + add_coercions(&mut pat.node, refined_binders, cons, &pat.loc); } } } diff --git a/src/type_checker/normalization.rs b/src/type_checker/normalization.rs index 4c13028e..5fb52ad3 100644 --- a/src/type_checker/normalization.rs +++ b/src/type_checker/normalization.rs @@ -228,8 +228,15 @@ fn normalize_expr(expr: &mut ast::Expr, loc: &ast::Loc, cons: &ScopeMap) { match pat { - ast::Pat::Var(ast::VarPat { var: _, ty }) => { + ast::Pat::Var(ast::VarPat { + var: _, + ty, + refined, + }) => { *ty = Some(ty.as_ref().unwrap().deep_normalize(cons)); + if let Some(ty) = refined { + *ty = ty.deep_normalize(cons); + } } ast::Pat::Ignore | ast::Pat::Str(_) | ast::Pat::Char(_) => {} @@ -263,8 +270,13 @@ fn normalize_pat(pat: &mut ast::Pat, cons: &ScopeMap) { .for_each(|ast::Named { name: _, node }| normalize_pat(&mut node.node, cons)); } - ast::Pat::Variant(ast::VariantPat { pat, inferred_ty }) => { + ast::Pat::Variant(ast::VariantPat { + pat, + inferred_ty, + inferred_pat_ty, + }) => { *inferred_ty = Some(inferred_ty.as_mut().unwrap().deep_normalize(cons)); + *inferred_pat_ty = Some(inferred_pat_ty.as_mut().unwrap().deep_normalize(cons)); normalize_pat(&mut pat.node, cons); } } diff --git a/src/type_checker/pat.rs b/src/type_checker/pat.rs index bcd2c930..944eb953 100644 --- a/src/type_checker/pat.rs +++ b/src/type_checker/pat.rs @@ -10,7 +10,11 @@ use crate::type_checker::{TcFunState, loc_display}; /// `pat` is `mut` to be able to add types of variables and type arguments of constructors. pub(super) fn check_pat(tc_state: &mut TcFunState, pat: &mut ast::L, level: u32) -> Ty { match &mut pat.node { - ast::Pat::Var(ast::VarPat { var, ty }) => { + ast::Pat::Var(ast::VarPat { + var, + ty, + refined: _, + }) => { assert!(ty.is_none()); let fresh_ty = Ty::UVar(tc_state.var_gen.new_var(level, Kind::Star, pat.loc.clone())); *ty = Some(fresh_ty.clone()); @@ -252,12 +256,21 @@ pub(super) fn check_pat(tc_state: &mut TcFunState, pat: &mut ast::L, l ty } - ast::Pat::Variant(ast::VariantPat { pat, inferred_ty }) => { + ast::Pat::Variant(ast::VariantPat { + pat, + inferred_ty, + inferred_pat_ty, + }) => { assert!(inferred_ty.is_none()); + assert!(inferred_pat_ty.is_none()); + let pat_ty = check_pat(tc_state, pat, level); + *inferred_pat_ty = Some(pat_ty.clone()); + let variant_ty = crate::type_checker::expr::make_variant(tc_state, pat_ty, level, &pat.loc); *inferred_ty = Some(variant_ty.clone()); + variant_ty } } diff --git a/src/type_checker/pat_coverage.rs b/src/type_checker/pat_coverage.rs index 9dcd72bd..d29ba08c 100644 --- a/src/type_checker/pat_coverage.rs +++ b/src/type_checker/pat_coverage.rs @@ -42,8 +42,12 @@ struct PatMatrix { struct Row { /// `match` arm index the row is generated from. arm_index: ArmIndex, + pats: Vec>, + + /// Maps variables in the row to types they're bound. bound_vars: HashMap>, + guarded: bool, } @@ -52,7 +56,7 @@ pub(crate) struct CoverageInfo { /// Maps arm indices to variables bound in the arms. pub(crate) bound_vars: Vec>>, - /// Maps arm indices to whether its useful. + /// Maps arm indices to whether they're useful. pub(crate) usefulness: Vec, } @@ -502,7 +506,11 @@ impl PatMatrix { // pat should be variable with the same name as the // field. (type checker checks this) match &named_field.node.node { - ast::Pat::Var(ast::VarPat { var, ty: _ }) => var, + ast::Pat::Var(ast::VarPat { + var, + ty: _, + refined: _, + }) => var, _ => panic!(), } } @@ -602,6 +610,7 @@ impl PatMatrix { ast::Pat::Variant(ast::VariantPat { pat, inferred_ty: _, + inferred_pat_ty: _, }) => { work.push((*pat).clone()); } diff --git a/src/type_checker/stmt.rs b/src/type_checker/stmt.rs index cc2ecffd..0ecebf54 100644 --- a/src/type_checker/stmt.rs +++ b/src/type_checker/stmt.rs @@ -392,6 +392,7 @@ fn check_stmt( node: ast::Pat::Var(ast::VarPat { var: expr_local.clone(), ty: Some(iter_ty.clone()), + refined: None, }), }, ty: None, diff --git a/src/type_collector.rs b/src/type_collector.rs index 47cf4b56..069d7174 100644 --- a/src/type_collector.rs +++ b/src/type_collector.rs @@ -239,10 +239,18 @@ fn visit_pat( } } - mono::Pat::Variant(mono::VariantPat { pat, ty }) => { - ty.values() + mono::Pat::Variant(mono::VariantPat { + pat, + variant_ty, + pat_ty, + }) => { + variant_ty + .values() .for_each(|ty| visit_named_ty(ty, records, variants)); - variants.insert(VariantType { alts: ty.clone() }); + visit_ty(pat_ty, records, variants); + variants.insert(VariantType { + alts: variant_ty.clone(), + }); visit_pat(&pat.node, records, variants); } }