diff --git a/lib/decimal.ex b/lib/decimal.ex index 9f9b644..71386b8 100644 --- a/lib/decimal.ex +++ b/lib/decimal.ex @@ -1061,6 +1061,38 @@ defmodule Decimal do round(decimal(num), n, mode) end + @doc """ + Rounds the given number to the nearest number given as input with the given strategy + (default is to round to nearest one). + + See `Decimal.Context` for more information about rounding algorithms. + + ## Examples + + iex> Decimal.round_to_nearest("47.1", 5) + Decimal.new("45") + + iex> Decimal.round_to_nearest("47.51", 5) + Decimal.new("50") + + """ + @spec round_to_nearest(decimal, integer, rounding) :: t + def round_to_nearest(num, round_number, mode \\ :half_up) + + def round_to_nearest(%Decimal{coef: :NaN} = num, _, _), do: num + + def round_to_nearest(%Decimal{coef: :inf} = num, _, _), do: num + + def round_to_nearest(%Decimal{} = num, round_number, mode) do + div(num, round_number) + |> round(0, mode) + |> mult(round_number) + end + + def round_to_nearest(num, round_number, mode) do + round_to_nearest(decimal(num), round_number, mode) + end + @doc """ Finds the square root. diff --git a/test/decimal_test.exs b/test/decimal_test.exs index d79a15e..9a659f5 100644 --- a/test/decimal_test.exs +++ b/test/decimal_test.exs @@ -742,6 +742,123 @@ defmodule DecimalTest do assert roundneg.(~d"1099") == d(1, 11, 2) end + test "round_to_nearest/3: special" do + assert Decimal.round_to_nearest(~d"inf", 5, :down) == d(1, :inf, 0) + assert Decimal.round_to_nearest(~d"nan", 5, :down) == d(1, :NaN, 0) + end + + test "round_to_nearest/3: down" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :down) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 45, 0) + assert round_nearest.(~d"47.38") == d(1, 45, 0) + assert round_nearest.(~d"47.5") == d(1, 45, 0) + assert round_nearest.(~d"49.99") == d(1, 45, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 45, 0) + assert round_nearest.(~d"-47.38") == d(-1, 45, 0) + assert round_nearest.(~d"-47.5") == d(-1, 45, 0) + assert round_nearest.(~d"-49.99") == d(-1, 45, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: ceiling" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :ceiling) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 50, 0) + assert round_nearest.(~d"47.38") == d(1, 50, 0) + assert round_nearest.(~d"47.5") == d(1, 50, 0) + assert round_nearest.(~d"49.99") == d(1, 50, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 45, 0) + assert round_nearest.(~d"-47.38") == d(-1, 45, 0) + assert round_nearest.(~d"-47.5") == d(-1, 45, 0) + assert round_nearest.(~d"-49.99") == d(-1, 45, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: floor" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :floor) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 45, 0) + assert round_nearest.(~d"47.38") == d(1, 45, 0) + assert round_nearest.(~d"47.5") == d(1, 45, 0) + assert round_nearest.(~d"49.99") == d(1, 45, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 50, 0) + assert round_nearest.(~d"-47.38") == d(-1, 50, 0) + assert round_nearest.(~d"-47.5") == d(-1, 50, 0) + assert round_nearest.(~d"-49.99") == d(-1, 50, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: half up" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :half_up) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 45, 0) + assert round_nearest.(~d"47.38") == d(1, 45, 0) + assert round_nearest.(~d"47.5") == d(1, 50, 0) + assert round_nearest.(~d"49.99") == d(1, 50, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 45, 0) + assert round_nearest.(~d"-47.38") == d(-1, 45, 0) + assert round_nearest.(~d"-47.5") == d(-1, 50, 0) + assert round_nearest.(~d"-49.99") == d(-1, 50, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: half even" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :half_even) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 45, 0) + assert round_nearest.(~d"47.38") == d(1, 45, 0) + assert round_nearest.(~d"47.5") == d(1, 50, 0) + assert round_nearest.(~d"49.99") == d(1, 50, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 45, 0) + assert round_nearest.(~d"-47.38") == d(-1, 45, 0) + assert round_nearest.(~d"-47.5") == d(-1, 50, 0) + assert round_nearest.(~d"-49.99") == d(-1, 50, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: half down" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :half_down) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 45, 0) + assert round_nearest.(~d"47.38") == d(1, 45, 0) + assert round_nearest.(~d"47.5") == d(1, 45, 0) + assert round_nearest.(~d"49.99") == d(1, 50, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 45, 0) + assert round_nearest.(~d"-47.38") == d(-1, 45, 0) + assert round_nearest.(~d"-47.5") == d(-1, 45, 0) + assert round_nearest.(~d"-49.99") == d(-1, 50, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + + test "round_to_nearest/3: up" do + round_nearest = &Decimal.round_to_nearest(&1, 5, :up) + assert round_nearest.(~d"45") == d(1, 45, 0) + assert round_nearest.(~d"46.35") == d(1, 50, 0) + assert round_nearest.(~d"47.38") == d(1, 50, 0) + assert round_nearest.(~d"47.5") == d(1, 50, 0) + assert round_nearest.(~d"49.99") == d(1, 50, 0) + assert round_nearest.(~d"50") == d(1, 50, 0) + assert round_nearest.(~d"-45") == d(-1, 45, 0) + assert round_nearest.(~d"-46.35") == d(-1, 50, 0) + assert round_nearest.(~d"-47.38") == d(-1, 50, 0) + assert round_nearest.(~d"-47.5") == d(-1, 50, 0) + assert round_nearest.(~d"-49.99") == d(-1, 50, 0) + assert round_nearest.(~d"-50") == d(-1, 50, 0) + end + test "sqrt/1" do Context.with(%Context{precision: 9, rounding: :half_even}, fn -> assert Decimal.sqrt(~d"0") == d(1, 0, 0)