diff --git a/README.md b/README.md index 6f2461b..09e7c13 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,8 @@ Arguments: Options: -e, --expression Expression to evaluate -l, --load Load files before entering the REPL - -p, --precision Decimal precision to use for calculations [default: 1024] + -p, --precision Binary precision (in bits) to use for calculations [default: 1024] + -d, --digits Decimal digits to use for calculations and output -r, --rounding-mode Rounding mode to use for calculations [default: to-even] --stack-size Stack size in MB for evaluations [default: 128] -h, --help Print help @@ -185,8 +186,12 @@ the hood): ``` You can specify the rounding mode, and what sort of precision you'd like to see -in the output by using the `--rounding-mode` and `--precision` options -respectively. +in the output by using the `--rounding-mode`, `--precision`, and `--digits` +options respectively. `--precision` controls the underlying binary precision +used during calculations, while `--digits` can be used when you care about the +number of decimal places in the result. The rounding behavior when trimming +to a fixed number of decimal places follows the selected `--rounding-mode` +(default: `to-even`). #### Boolean diff --git a/crates/val-wasm/src/lib.rs b/crates/val-wasm/src/lib.rs index 3f7754c..8a5d9a8 100644 --- a/crates/val-wasm/src/lib.rs +++ b/crates/val-wasm/src/lib.rs @@ -48,10 +48,14 @@ pub fn evaluate(input: &str) -> Result { let mut evaluator = Evaluator::from(Environment::new(val::Config { precision: 53, rounding_mode: RoundingMode::FromZero.into(), + digits: None, })); match evaluator.eval(&ast) { - Ok(value) => Ok(to_value(&value.to_string()).unwrap()), + Ok(value) => Ok( + to_value(&value.format_with_config(&evaluator.environment.config)) + .unwrap(), + ), Err(error) => Err( to_value(&[ValError { kind: ErrorKind::Evaluator, diff --git a/src/arguments.rs b/src/arguments.rs index 17a8738..cae4a02 100644 --- a/src/arguments.rs +++ b/src/arguments.rs @@ -39,10 +39,18 @@ pub struct Arguments { short, long, default_value = "1024", - help = "Decimal precision to use for calculations" + help = "Binary precision (in bits) to use for calculations" )] precision: usize, + #[clap( + long, + short = 'd', + conflicts_with = "precision", + help = "Decimal digits to use for calculations and output" + )] + digits: Option, + #[clap( short, long, @@ -83,10 +91,7 @@ impl Arguments { let filename = filename.to_string_lossy().to_string(); - let mut evaluator = Evaluator::from(Environment::new(Config { - precision: self.precision, - rounding_mode: self.rounding_mode.into(), - })); + let mut evaluator = Evaluator::from(Environment::new(self.to_config())); match parse(&content) { Ok(ast) => match evaluator.eval(&ast) { @@ -112,10 +117,7 @@ impl Arguments { } fn eval_expression(&self, value: String) -> Result { - let mut evaluator = Evaluator::from(Environment::new(Config { - precision: self.precision, - rounding_mode: self.rounding_mode.into(), - })); + let mut evaluator = Evaluator::from(Environment::new(self.to_config())); match parse(&value) { Ok(ast) => match evaluator.eval(&ast) { @@ -124,7 +126,10 @@ impl Arguments { return Ok(()); } - println!("{}", value); + println!( + "{}", + value.format_with_config(&evaluator.environment.config) + ); Ok(()) } @@ -166,10 +171,7 @@ impl Arguments { editor.set_helper(Some(Highlighter::new())); editor.load_history(&history).ok(); - let mut evaluator = Evaluator::from(Environment::new(Config { - precision: self.precision, - rounding_mode: self.rounding_mode.into(), - })); + let mut evaluator = Evaluator::from(Environment::new(self.to_config())); if let Some(filenames) = &self.load { for filename in filenames { @@ -212,7 +214,12 @@ impl Arguments { match parse(line) { Ok(ast) => match evaluator.eval(&ast) { - Ok(value) if !matches!(value, Value::Null) => println!("{value}"), + Ok(value) if !matches!(value, Value::Null) => { + println!( + "{}", + value.format_with_config(&evaluator.environment.config) + ); + } Ok(_) => {} Err(error) => error .report("") @@ -228,6 +235,29 @@ impl Arguments { } } } + + fn to_config(&self) -> Config { + Config { + precision: self.precision_bits(), + rounding_mode: self.rounding_mode.into(), + digits: self.digits, + } + } + + fn precision_bits(&self) -> usize { + self + .digits + .map(Self::digits_to_binary_precision) + .unwrap_or(self.precision) + } + + fn digits_to_binary_precision(digits: usize) -> usize { + if digits == 0 { + return 0; + } + + ((digits as f64) * f64::consts::LOG2_10).ceil() as usize + } } #[cfg(test)] @@ -325,4 +355,38 @@ mod tests { error ); } + + #[test] + fn digits_option_sets_decimal_precision() { + let arguments = Arguments::parse_from(vec![ + "program", + "--digits", + "25", + "--rounding-mode", + "to-zero", + ]); + + assert_eq!(arguments.digits, Some(25)); + + assert_eq!( + arguments.precision_bits(), + Arguments::digits_to_binary_precision(25) + ); + } + + #[test] + fn digits_conflicts_with_precision() { + let result = Arguments::try_parse_from(vec![ + "program", + "--digits", + "10", + "--precision", + "512", + ]); + + assert!( + result.is_err(), + "Parser should reject simultaneous --digits and --precision" + ); + } } diff --git a/src/config.rs b/src/config.rs index e573072..bed52de 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,6 +2,7 @@ pub struct Config { pub precision: usize, pub rounding_mode: astro_float::RoundingMode, + pub digits: Option, } impl Default for Config { @@ -9,6 +10,7 @@ impl Default for Config { Self { precision: 1024, rounding_mode: astro_float::RoundingMode::ToEven, + digits: None, } } } diff --git a/src/environment.rs b/src/environment.rs index 8eaeaf2..b94f531 100644 --- a/src/environment.rs +++ b/src/environment.rs @@ -751,7 +751,7 @@ impl<'src> Environment<'src> { let mut output_strings = Vec::with_capacity(payload.arguments.len()); for argument in &payload.arguments { - output_strings.push(format!("{}", argument)); + output_strings.push(argument.format_with_config(&payload.config)); } print!("{}", output_strings.join(" ")); @@ -766,7 +766,7 @@ impl<'src> Environment<'src> { let mut output_strings = Vec::with_capacity(payload.arguments.len()); for argument in &payload.arguments { - output_strings.push(format!("{}", argument)); + output_strings.push(argument.format_with_config(&payload.config)); } println!("{}", output_strings.join(" ")); diff --git a/src/float_ext.rs b/src/float_ext.rs index 8038361..94db427 100644 --- a/src/float_ext.rs +++ b/src/float_ext.rs @@ -1,102 +1,26 @@ use super::*; pub trait FloatExt { - fn display(&self) -> String; - fn to_f64(&self, rounding_mode: astro_float::RoundingMode) -> Option; -} - -impl FloatExt for Float { fn display(&self) -> String { - if self.is_nan() { - return "nan".into(); - } - - if self.is_inf_pos() { - return "inf".into(); - } - - if self.is_inf_neg() { - return "-inf".into(); - } - - if self.is_zero() { - return "0".into(); - } - - let formatted = with_consts(|consts| { - self.format(Radix::Dec, astro_float::RoundingMode::None, consts) - }) - .expect("failed to format Float as decimal"); - - let Some((mantissa_with_sign, exponent_str)) = formatted.split_once('e') - else { - return formatted; - }; - - let Ok(exponent) = exponent_str.parse::() else { - return formatted; - }; - - let (sign, mantissa) = - if let Some(rest) = mantissa_with_sign.strip_prefix('-') { - ("-", rest) - } else if let Some(rest) = mantissa_with_sign.strip_prefix('+') { - ("", rest) - } else { - ("", mantissa_with_sign) - }; - - let mut parts = mantissa.split('.'); - let int_part = parts.next().unwrap_or(""); - let frac_part = parts.next().unwrap_or(""); - - let mut digits = String::with_capacity(int_part.len() + frac_part.len()); - digits.push_str(int_part); - digits.push_str(frac_part); - - let length = int_part.len() as i32 + exponent; - let digits_len = digits.len() as i32; - - let mut result = if length <= 0 { - let zeros = (-length) as usize; - let mut out = - String::with_capacity(sign.len() + 2 + zeros + digits.len()); - out.push_str(sign); - out.push('0'); - out.push('.'); - out.extend(std::iter::repeat_n('0', zeros)); - out.push_str(&digits); - out - } else if length >= digits_len { - let zeros = (length - digits_len) as usize; - let mut out = String::with_capacity(sign.len() + digits.len() + zeros); - out.push_str(sign); - out.push_str(&digits); - out.extend(std::iter::repeat_n('0', zeros)); - out - } else { - let split_at = length as usize; - let (left, right) = digits.split_at(split_at); - let mut out = - String::with_capacity(sign.len() + left.len() + 1 + right.len()); - out.push_str(sign); - out.push_str(left); - out.push('.'); - out.push_str(right); - out - }; + self.display_with_digits(None, astro_float::RoundingMode::None) + } - if result.contains('.') { - while result.ends_with('0') { - result.pop(); - } + fn display_with_digits( + &self, + digits: Option, + rounding_mode: astro_float::RoundingMode, + ) -> String; - if result.ends_with('.') { - result.pop(); - } - } + fn to_f64(&self, rounding_mode: astro_float::RoundingMode) -> Option; +} - result +impl FloatExt for Float { + fn display_with_digits( + &self, + digits: Option, + rounding_mode: astro_float::RoundingMode, + ) -> String { + render_decimal(self, digits, rounding_mode) } fn to_f64(&self, rounding_mode: astro_float::RoundingMode) -> Option { @@ -170,6 +94,322 @@ impl FloatExt for Float { } } +fn render_decimal( + value: &Float, + digits: Option, + rounding_mode: astro_float::RoundingMode, +) -> String { + if value.is_nan() { + return "nan".into(); + } + + if value.is_inf_pos() { + return "inf".into(); + } + + if value.is_inf_neg() { + return "-inf".into(); + } + + if value.is_zero() { + return format_zero(digits); + } + + let normalized = normalize_decimal(value); + + match digits { + None => normalized, + Some(count) => adjust_decimal_digits(&normalized, count, rounding_mode), + } +} + +fn format_zero(digits: Option) -> String { + match digits { + None => "0".into(), + Some(0) => "0".into(), + Some(count) => { + let mut out = String::with_capacity(2 + count); + out.push('0'); + out.push('.'); + out.extend(std::iter::repeat_n('0', count)); + out + } + } +} + +fn normalize_decimal(value: &Float) -> String { + let formatted = with_consts(|consts| { + value.format(Radix::Dec, astro_float::RoundingMode::None, consts) + }) + .expect("failed to format Float as decimal"); + + let Some((mantissa_with_sign, exponent_str)) = formatted.split_once('e') + else { + return formatted; + }; + + let Ok(exponent) = exponent_str.parse::() else { + return formatted; + }; + + let (sign, mantissa) = + if let Some(rest) = mantissa_with_sign.strip_prefix('-') { + ("-", rest) + } else if let Some(rest) = mantissa_with_sign.strip_prefix('+') { + ("", rest) + } else { + ("", mantissa_with_sign) + }; + + let mut parts = mantissa.split('.'); + let int_part = parts.next().unwrap_or(""); + let frac_part = parts.next().unwrap_or(""); + + let mut digits = String::with_capacity(int_part.len() + frac_part.len()); + digits.push_str(int_part); + digits.push_str(frac_part); + + let length = int_part.len() as i32 + exponent; + let digits_len = digits.len() as i32; + + let mut result = if length <= 0 { + let zeros = (-length) as usize; + let mut out = String::with_capacity(sign.len() + 2 + zeros + digits.len()); + out.push_str(sign); + out.push('0'); + out.push('.'); + out.extend(std::iter::repeat_n('0', zeros)); + out.push_str(&digits); + out + } else if length >= digits_len { + let zeros = (length - digits_len) as usize; + let mut out = String::with_capacity(sign.len() + digits.len() + zeros); + out.push_str(sign); + out.push_str(&digits); + out.extend(std::iter::repeat_n('0', zeros)); + out + } else { + let split_at = length as usize; + let (left, right) = digits.split_at(split_at); + let mut out = + String::with_capacity(sign.len() + left.len() + 1 + right.len()); + out.push_str(sign); + out.push_str(left); + out.push('.'); + out.push_str(right); + out + }; + + if result.contains('.') { + while result.ends_with('0') { + result.pop(); + } + + if result.ends_with('.') { + result.pop(); + } + } + + result +} + +fn adjust_decimal_digits( + normalized: &str, + digits: usize, + rounding_mode: astro_float::RoundingMode, +) -> String { + if digits == 0 && !normalized.contains('.') { + return normalized.to_string(); + } + + let (negative, unsigned) = if let Some(rest) = normalized.strip_prefix('-') { + (true, rest) + } else { + (false, normalized) + }; + + let mut split = unsigned.split('.'); + let int_part_str = split.next().unwrap_or(""); + let frac_part_str = split.next().unwrap_or(""); + + let mut int_digits: Vec = if int_part_str.is_empty() { + vec![0] + } else { + int_part_str + .as_bytes() + .iter() + .map(|b| b - b'0') + .collect::>() + }; + + let mut frac_digits: Vec = frac_part_str + .as_bytes() + .iter() + .map(|b| b - b'0') + .collect::>(); + + if digits >= frac_digits.len() { + frac_digits.resize(digits, 0); + return build_decimal_string(negative, &int_digits, &frac_digits, digits); + } + + let next_digit = frac_digits[digits]; + let rest = &frac_digits[digits + 1..]; + + let last_kept = if digits > 0 { + frac_digits[digits - 1] + } else { + *int_digits.last().unwrap_or(&0) + }; + + let round_up = + should_round(rounding_mode, !negative, next_digit, rest, last_kept); + + frac_digits.truncate(digits); + + if round_up { + if digits > 0 { + if increment_digits(&mut frac_digits) && increment_digits(&mut int_digits) + { + int_digits.insert(0, 1); + } + } else if increment_digits(&mut int_digits) { + int_digits.insert(0, 1); + } + } + + let is_zero = int_digits.iter().all(|&d| d == 0) + && (digits == 0 || frac_digits.iter().all(|&d| d == 0)); + + let mut result = build_decimal_string( + negative && !is_zero, + &int_digits, + &frac_digits, + digits, + ); + + if digits == 0 { + result.truncate(result.find('.').unwrap_or(result.len())); + } + + result +} + +fn build_decimal_string( + negative: bool, + int_digits: &[u8], + frac_digits: &[u8], + digits: usize, +) -> String { + let mut out = String::with_capacity( + negative as usize + + int_digits.len() + + if digits > 0 { digits + 1 } else { 0 }, + ); + + if negative { + out.push('-'); + } + + if int_digits.is_empty() { + out.push('0'); + } else { + for &digit in int_digits { + out.push((digit + b'0') as char); + } + } + + if digits > 0 { + out.push('.'); + + if frac_digits.len() >= digits { + for &digit in frac_digits.iter().take(digits) { + out.push((digit + b'0') as char); + } + } else { + for &digit in frac_digits { + out.push((digit + b'0') as char); + } + out.extend(std::iter::repeat_n('0', digits - frac_digits.len())); + } + } + + out +} + +fn increment_digits(digits: &mut [u8]) -> bool { + for digit in digits.iter_mut().rev() { + if *digit < 9 { + *digit += 1; + return false; + } + *digit = 0; + } + + true +} + +fn should_round( + rounding_mode: astro_float::RoundingMode, + is_positive: bool, + next_digit: u8, + rest: &[u8], + last_kept: u8, +) -> bool { + if matches!(rounding_mode, astro_float::RoundingMode::None) { + return false; + } + + let rest_all_zero = rest.iter().all(|&d| d == 0); + let truncated_non_zero = next_digit != 0 || !rest_all_zero; + + match rounding_mode { + astro_float::RoundingMode::ToZero => false, + astro_float::RoundingMode::FromZero => truncated_non_zero, + astro_float::RoundingMode::Up => truncated_non_zero && is_positive, + astro_float::RoundingMode::Down => truncated_non_zero && !is_positive, + astro_float::RoundingMode::ToEven => { + if !truncated_non_zero { + return false; + } + + if next_digit > 5 { + return true; + } + + if next_digit < 5 { + return false; + } + + if !rest_all_zero { + return true; + } + + last_kept % 2 == 1 + } + astro_float::RoundingMode::ToOdd => { + if !truncated_non_zero { + return false; + } + + if next_digit > 5 { + return true; + } + + if next_digit < 5 { + return false; + } + + if !rest_all_zero { + return true; + } + + last_kept % 2 == 0 + } + astro_float::RoundingMode::None => false, + } +} + #[cfg(test)] mod tests { use super::*; @@ -233,6 +473,45 @@ mod tests { assert_eq!(float_from_str("1.23e15").display(), "1230000000000000"); } + #[test] + fn display_with_digits_rounds_to_even() { + let pi = float_from_str("3.1415926535897932384626"); + + assert_eq!( + pi.display_with_digits(Some(4), astro_float::RoundingMode::ToEven), + "3.1416" + ); + } + + #[test] + fn display_with_digits_respects_rounding_mode() { + let pi = float_from_str("3.1415926535897932384626"); + + assert_eq!( + pi.display_with_digits(Some(4), astro_float::RoundingMode::ToZero), + "3.1415" + ); + } + + #[test] + fn display_with_digits_zero_padding() { + assert_eq!( + Float::from(1) + .display_with_digits(Some(3), astro_float::RoundingMode::ToEven), + "1.000" + ); + } + + #[test] + fn display_with_digits_zero_places() { + let value = float_from_str("2.75"); + + assert_eq!( + value.display_with_digits(Some(0), astro_float::RoundingMode::FromZero), + "3" + ); + } + #[test] fn convert_to_double_precision() { assert_eq!( diff --git a/src/lib.rs b/src/lib.rs index c5aef67..59e9e2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub(crate) use { std::{ cell::RefCell, collections::HashMap, + f64, fmt::{self, Display, Formatter}, fs, ops::Range, diff --git a/src/value.rs b/src/value.rs index 61f9fb5..c468534 100644 --- a/src/value.rs +++ b/src/value.rs @@ -64,6 +64,32 @@ impl PartialEq for Value<'_> { } impl<'a> Value<'a> { + pub fn format_with_config(&self, config: &Config) -> String { + match self { + Value::Boolean(boolean) => boolean.to_string(), + Value::BuiltinFunction(name, _) | Value::Function(name, _, _, _) => { + format!("") + } + Value::List(list) => { + let items = list + .iter() + .map(|item| match item { + Value::String(string) => format!("\'{string}\'"), + _ => item.format_with_config(config), + }) + .collect::>() + .join(", "); + + format!("[{items}]") + } + Value::Null => "null".into(), + Value::Number(number) => { + number.display_with_digits(config.digits, config.rounding_mode) + } + Value::String(string) => (*string).into(), + } + } + pub fn boolean(&self, span: Span) -> Result { if let Value::Boolean(x) = self { Ok(*x) @@ -120,3 +146,51 @@ impl<'a> Value<'a> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn float_from_str(s: &str, precision: usize) -> Float { + with_consts(|consts| { + Float::parse( + s, + Radix::Dec, + precision, + astro_float::RoundingMode::FromZero, + consts, + ) + }) + } + + #[test] + fn number_format_respects_digits() { + let config = Config { + precision: 256, + rounding_mode: astro_float::RoundingMode::ToEven, + digits: Some(2), + }; + + let value = Value::Number(float_from_str("3.4567", config.precision)); + + assert_eq!(value.format_with_config(&config), "3.46"); + } + + #[test] + fn list_format_propagates_digits() { + let config = Config { + precision: 256, + rounding_mode: astro_float::RoundingMode::ToZero, + digits: Some(3), + }; + + let items = vec![ + Value::Number(float_from_str("1.234567", config.precision)), + Value::String("hello"), + ]; + + let value = Value::List(items); + + assert_eq!(value.format_with_config(&config), "[1.234, 'hello']"); + } +}