Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ pub struct VarPat {

/// Inferred type of the binder. Filled in by the type checker.
pub ty: Option<Ty>,

/// Only after type checking: when the binder type is refined by pattern matching, this holds
/// the refined type.
pub refined: Option<Ty>,
}

#[derive(Debug, Clone)]
Expand All @@ -419,6 +423,7 @@ pub struct RecordPat {
pub struct VariantPat {
pub pat: Box<L<Pat>>,
pub inferred_ty: Option<Ty>,
pub inferred_pat_ty: Option<Ty>,
}

#[derive(Debug, Clone)]
Expand Down
11 changes: 9 additions & 2 deletions src/ast/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,11 +814,14 @@ impl Expr {
impl Pat {
pub fn print(&self, buf: &mut String) {
match self {
Pat::Var(VarPat { var, ty }) => {
Pat::Var(VarPat { var, ty, refined }) => {
buf.push_str(var);
if let Some(ty) = ty {
write!(buf, ": {ty}").unwrap();
}
if let Some(refined) = refined {
write!(buf, " ~> {refined}").unwrap();
}
}

Pat::Con(ConPat {
Expand Down Expand Up @@ -901,7 +904,11 @@ impl Pat {
buf.push(')');
}

Pat::Variant(VariantPat { pat, inferred_ty }) => {
Pat::Variant(VariantPat {
pat,
inferred_ty,
inferred_pat_ty: _,
}) => {
buf.push('~');
pat.node.print(buf);
if let Some(ty) = inferred_ty {
Expand Down
24 changes: 18 additions & 6 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ impl Pgm {
closures,
type_objs: _,
record_objs: _,
variant_objs: _,
true_con_idx,
false_con_idx,
char_con_idx,
Expand All @@ -77,7 +78,7 @@ impl Pgm {
// Allocate singletons for constructors without fields.
for (i, heap_obj) in heap_objs.iter().enumerate() {
match heap_obj {
HeapObj::Builtin(_) => continue,
HeapObj::Builtin(_) | HeapObj::Variant(_) => continue,

HeapObj::Source(source_con) => {
if source_con.fields.is_empty() {
Expand Down Expand Up @@ -664,7 +665,11 @@ fn eval<W: Write>(

Expr::Do(stmts, _) => exec(w, pgm, heap, locals, stmts, call_stack),

Expr::Variant(expr) => {
Expr::Variant {
expr,
expr_ty: _,
variant_ty: _,
} => {
// Note: the interpreter can only deal with variants of boxed types. If `expr` is an
// unboxed type things will go wrong.
//
Expand Down Expand Up @@ -729,8 +734,11 @@ fn assign<W: Write>(
/// compiled version `StrView`s will be allocated on stack.
fn try_bind_pat(pgm: &Pgm, heap: &mut Heap, pat: &L<Pat>, locals: &mut [u64], value: u64) -> bool {
match &pat.node {
Pat::Var(var) => {
locals[var.as_usize()] = value;
Pat::Var(VarPat {
idx,
original_ty: _,
}) => {
locals[idx.as_usize()] = value;
true
}

Expand Down Expand Up @@ -775,10 +783,14 @@ fn try_bind_pat(pgm: &Pgm, heap: &mut Heap, pat: &L<Pat>, locals: &mut [u64], va
true
}

Pat::Variant(p) => {
Pat::Variant {
pat,
variant_ty: _,
pat_ty: _,
} => {
// `p` needs to match a boxed type, but we can't check this here (e.g. in an `assert`).
// See the documentation in `Expr::Variant` evaluator.
try_bind_pat(pgm, heap, p, locals, value)
try_bind_pat(pgm, heap, pat, locals, value)
}
}
}
Expand Down
110 changes: 87 additions & 23 deletions src/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ pub mod printer;
use crate::ast;
use crate::collections::*;
use crate::mono_ast::{self as mono, Id, L, Loc};
pub(crate) use crate::type_collector::RecordType;
use crate::type_collector::collect_anonymous_types;
pub(crate) use crate::type_collector::{RecordType, VariantType};
use crate::utils::loc_display;

use smol_str::SmolStr;
Expand All @@ -23,9 +23,16 @@ pub struct LoweredPgm {
/// Product types will have one index per type. Sum types may have multiple.
pub type_objs: HashMap<Id, HashMap<Vec<mono::Type>, TypeObjs>>,

/// Maps record types to their heap object indices.
/// For C backend: maps record types to their heap object indices.
pub record_objs: HashMap<RecordType, HeapObjIdx>,

/// For C backend: maps variant types to their heap object indices.
///
/// Note: variants don't have their own tags, they use the tags of the types in the variant
/// instead. These tags are to make it easy to refer to a variant type in AST nodes, dependency
/// analysis etc.
pub variant_objs: HashMap<VariantType, HeapObjIdx>,

// Ids of some special cons that the interpreter needs to know.
//
// Note that for product types, type and con tags are the same.
Expand All @@ -50,7 +57,10 @@ pub struct LoweredPgm {
#[derive(Debug)]
pub enum TypeObjs {
Product(HeapObjIdx),
Sum(Vec<HeapObjIdx>),
Sum {
con_indices: Vec<HeapObjIdx>,
value: bool,
},
}

pub const CON_CON_IDX: HeapObjIdx = HeapObjIdx(0);
Expand Down Expand Up @@ -291,6 +301,7 @@ pub enum HeapObj {
Builtin(BuiltinConDecl),
Source(SourceConDecl),
Record(RecordType),
Variant(VariantType),
}

#[derive(Debug)]
Expand Down Expand Up @@ -429,7 +440,11 @@ pub enum Expr {
mono::Type,
),

Variant(Box<L<Expr>>),
Variant {
expr: Box<L<Expr>>,
expr_ty: mono::Type,
variant_ty: OrdMap<Id, mono::NamedType>,
},
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -486,13 +501,28 @@ pub struct IsExpr {

#[derive(Debug, Clone)]
pub enum Pat {
Var(LocalIdx),
Var(VarPat),
Con(ConPat),
Ignore,
Str(String),
Char(char),
Or(Box<L<Pat>>, Box<L<Pat>>),
Variant(Box<L<Pat>>),
Variant {
pat: Box<L<Pat>>,
variant_ty: OrdMap<Id, mono::NamedType>,
pat_ty: mono::Type,
},
}

#[derive(Debug, Clone)]
pub struct VarPat {
pub idx: LocalIdx,

/// When the binder was refined by pattern matching, the local in `SourceFunDecl` will have the
/// refined type, and this will be the original type.
///
/// When pattern matching, we should convert the original type to the local's type.
pub original_ty: mono::Type,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -693,6 +723,7 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
closures: vec![],
type_objs: Default::default(),
record_objs: Default::default(),
variant_objs: Default::default(),
true_con_idx: *sum_con_nums
.get("Bool")
.unwrap()
Expand Down Expand Up @@ -768,6 +799,9 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
match &con_decl.rhs {
Some(rhs) => match rhs {
mono::TypeDeclRhs::Sum(cons) => {
// For sum types, we generate an index representing the type itself (rather
// than its consturctors). This index is used in dependency anlaysis, and to
// get the type details during code generation.
let mut con_indices: Vec<HeapObjIdx> = Vec::with_capacity(cons.len());
for mono::ConDecl { name, fields } in cons {
let idx = HeapObjIdx(lowered_pgm.heap_objs.len() as u32);
Expand All @@ -784,7 +818,13 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
.type_objs
.entry(con_id.clone())
.or_default()
.insert(con_ty_args.clone(), TypeObjs::Sum(con_indices));
.insert(
con_ty_args.clone(),
TypeObjs::Sum {
con_indices,
value: con_decl.value,
},
);
assert!(old.is_none());
}

Expand Down Expand Up @@ -871,9 +911,8 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
}

// Assign indices to record shapes.
let (record_types, _variant_types) = collect_anonymous_types(mono_pgm);
let (record_types, variant_types) = collect_anonymous_types(mono_pgm);

// TODO: We could assign indices to records as we see them during lowering below.
let mut record_indices: HashMap<RecordType, HeapObjIdx> = Default::default();
for record_type in record_types {
let idx = next_con_idx;
Expand All @@ -882,6 +921,14 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
lowered_pgm.heap_objs.push(HeapObj::Record(record_type));
}

let mut variant_indices: HashMap<VariantType, HeapObjIdx> = Default::default();
for variant_type in variant_types {
let idx = next_con_idx;
next_con_idx = HeapObjIdx(next_con_idx.0 + 1);
variant_indices.insert(variant_type.clone(), idx);
lowered_pgm.heap_objs.push(HeapObj::Variant(variant_type));
}

lowered_pgm.unit_con_idx = *record_indices
.get(&RecordType::unit())
.unwrap_or_else(|| panic!("BUG: Unit record not defined {record_indices:#?}"));
Expand Down Expand Up @@ -1425,6 +1472,7 @@ pub fn lower(mono_pgm: &mut mono::MonoPgm) -> LoweredPgm {
}

lowered_pgm.record_objs = indices.records;
lowered_pgm.variant_objs = variant_indices;
lowered_pgm
}

Expand Down Expand Up @@ -2094,11 +2142,19 @@ fn lower_expr(
(expr, Default::default())
}

mono::Expr::Variant(mono::VariantExpr { expr, ty: _ }) => {
mono::Expr::Variant(mono::VariantExpr { expr, ty }) => {
// Note: Type of the expr in the variant won't be a variant type. Use the
// `VariantExpr`'s type.
let expr_ty = expr.node.ty();
let (expr, vars) = lower_bl_expr(expr, closures, indices, scope, mono_pgm);
(Expr::Variant(expr), vars)
(
Expr::Variant {
expr,
expr_ty,
variant_ty: ty.clone(),
},
vars,
)
}
}
}
Expand Down Expand Up @@ -2147,22 +2203,28 @@ fn lower_pat(

// This map is to map binders in alternatives of or patterns to the same local.
//
// Only in or pattern alternatives we allow same binders, so if we see a binder for the second
// time, we must be checking another alternative of an or pattern.
// Only in or-pattern alternatives we allow same binders, so if we see a binder for the second
// time, we must be checking another alternative of an or-pattern.
mapped_binders: &mut HashMap<Id, LocalIdx>,
) -> Pat {
match pat {
mono::Pat::Var(mono::VarPat { var, ty }) => match mapped_binders.get(var) {
Some(idx) => Pat::Var(*idx),
mono::Pat::Var(mono::VarPat { var, ty, refined }) => match mapped_binders.get(var) {
Some(idx) => Pat::Var(VarPat {
idx: *idx,
original_ty: ty.clone(),
}),
None => {
let var_idx = LocalIdx(scope.locals.len() as u32);
scope.locals.push(LocalInfo {
name: var.clone(),
ty: ty.clone(),
ty: refined.as_ref().unwrap_or(ty).clone(),
});
scope.bounds.insert(var.clone(), var_idx);
mapped_binders.insert(var.clone(), var_idx);
Pat::Var(var_idx)
Pat::Var(VarPat {
idx: var_idx,
original_ty: ty.clone(),
})
}
},

Expand Down Expand Up @@ -2311,13 +2373,15 @@ fn lower_pat(
lower_bl_pat(p2, indices, scope, mono_pgm, mapped_binders),
),

mono::Pat::Variant(mono::VariantPat { pat, ty: _ }) => Pat::Variant(Box::new(lower_l_pat(
mono::Pat::Variant(mono::VariantPat {
pat,
indices,
scope,
mono_pgm,
mapped_binders,
))),
variant_ty,
pat_ty,
}) => Pat::Variant {
pat: Box::new(lower_l_pat(pat, indices, scope, mono_pgm, mapped_binders)),
variant_ty: variant_ty.clone(),
pat_ty: pat_ty.clone(),
},
}
}

Expand Down
21 changes: 17 additions & 4 deletions src/lowering/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ impl LoweredPgm {
}

HeapObj::Record(record) => write!(buf, "{record:?}").unwrap(),

HeapObj::Variant(variant) => write!(buf, "{variant:?}").unwrap(),
}
buf.push('\n');
}
Expand Down Expand Up @@ -385,7 +387,11 @@ impl Expr {
}
}

Expr::Variant(expr) => {
Expr::Variant {
expr,
expr_ty: _,
variant_ty: _,
} => {
buf.push('~');
expr.node.print(buf, indent);
}
Expand All @@ -396,7 +402,10 @@ impl Expr {
impl Pat {
pub fn print(&self, buf: &mut String) {
match self {
Pat::Var(idx) => write!(buf, "local{}", idx.0).unwrap(),
Pat::Var(VarPat {
idx,
original_ty: _,
}) => write!(buf, "local{}", idx.0).unwrap(),

Pat::Con(ConPat { con, fields }) => {
write!(buf, "con{}(", con.0).unwrap();
Expand Down Expand Up @@ -431,9 +440,13 @@ impl Pat {
p2.node.print(buf);
}

Pat::Variant(p) => {
Pat::Variant {
pat,
variant_ty: _,
pat_ty: _,
} => {
buf.push('~');
p.node.print(buf);
pat.node.print(buf);
}
}
}
Expand Down
Loading