Mix.install([
{:math, "~> 0.7.0"}
])
import Math
defprotocol Expression do
  @moduledoc """
  Protocol for symbolic expressions that can have variables substituted,
  be evaluated against variable bindings, and report the variables they use.
  """

  @typedoc """
  Variable bindings keyed by variable name. Implementations in this file look
  values up with `bindings[var]`, so a keyword list or a map both work.
  """
  @type bindings :: keyword() | %{optional(atom()) => any()}

  @doc """
  Replaces variables in `expr` according to `bindings` and returns the
  resulting expression.
  """
  @spec substitute(t(), bindings()) :: t()
  def substitute(expr, bindings)

  @doc """
  Evaluates `expr` using the variable values given in `input`.
  """
  @spec evaluate(t(), bindings()) :: any()
  def evaluate(expr, input)

  @doc """
  Returns the variables (atoms) referenced by `expr`.
  """
  @spec get_variables(t()) :: [atom()]
  def get_variables(expr)
end
defimpl Expression, for: Atom do
  # A bare atom is treated as a variable whose name is the atom itself.

  # Replaces the variable with its binding. An unbound variable is returned
  # unchanged (rather than nil) so that partial substitution keeps the
  # variables it does not mention.
  def substitute(var, bindings), do: Access.get(bindings, var, var)

  # Looks the variable up in the bindings; yields nil when it is not bound.
  def evaluate(var, bindings), do: bindings[var]

  # A variable references exactly itself.
  def get_variables(var), do: [var]
end
defimpl Expression, for: List do
  # A list of expressions is handled element-wise. Every element receives the
  # same bindings, since bindings are keyed by variable name rather than by
  # position in the list.

  # Substitutes into each expression, returning the list of results in order.
  def substitute(exprs, bindings) do
    Enum.map(exprs, &Expression.substitute(&1, bindings))
  end

  # Evaluates each expression, returning the list of values in order.
  def evaluate(exprs, bindings) do
    Enum.map(exprs, &Expression.evaluate(&1, bindings))
  end

  # Collects the variables of all expressions into one flat list without
  # duplicates, matching the protocol's `[atom()]` return spec.
  def get_variables(exprs) do
    exprs
    |> Enum.flat_map(&Expression.get_variables/1)
    |> Enum.uniq()
  end
end
defmodule Operation do
  @moduledoc """
  Struct holding an operator (`op`) together with its operand expressions
  (`exprs`).
  """

  defstruct [:op, :exprs]

  @type t :: %__MODULE__{op: term(), exprs: term()}

  @doc """
  Placeholder: not implemented yet. Accepts any `Operation` struct and
  returns `nil`.
  """
  @spec substitute_exprs(t()) :: nil
  # The struct fields are deliberately not bound: the stub does not use them,
  # and binding them only produced unused-variable warnings.
  def substitute_exprs(%Operation{}), do: nil
end
defmodule ArithOp do
  @moduledoc """
  Arithmetic node: the operator `op` (`:plus` or `:minus`) applied to the
  operand expressions in `exprs`.

  `exprs` may be a list of expressions or a single expression.
  """

  defstruct [:op, :exprs]
end

defimpl Expression, for: ArithOp do
  # Substitutes into every operand and returns a new ArithOp with the same
  # operator, so the result is still an expression that can be evaluated.
  def substitute(%ArithOp{exprs: exprs} = arith, bindings) do
    substituted =
      exprs
      |> to_list()
      |> Enum.map(&Expression.substitute(&1, bindings))

    %{arith | exprs: substituted}
  end

  # Evaluates every operand with the same bindings, then combines the values
  # with the operator. Raises Enum.EmptyError when there are no operands and
  # CaseClauseError for an operator other than :plus / :minus.
  def evaluate(%ArithOp{op: op, exprs: exprs}, inputs) do
    values =
      exprs
      |> to_list()
      |> Enum.map(&Expression.evaluate(&1, inputs))

    case op do
      :plus -> Enum.reduce(values, fn value, acc -> acc + value end)
      :minus -> subtract_left_to_right(values)
    end
  end

  # Variables of all operands, flattened and without duplicates.
  def get_variables(%ArithOp{exprs: exprs}) do
    exprs
    |> to_list()
    |> Enum.flat_map(&Expression.get_variables/1)
    |> Enum.uniq()
  end

  # Normalizes the operands: a list is returned as is, anything else is
  # wrapped in a one-element list.
  def to_list(l) when is_list(l), do: l
  def to_list(v), do: [v]

  # A single operand means unary minus (negation).
  defp subtract_left_to_right([value]), do: -value

  # Enum.reduce/2 calls the function as fun.(element, acc), so the
  # accumulator must be on the left to get a - b - c for [a, b, c].
  defp subtract_left_to_right(values) do
    Enum.reduce(values, fn value, acc -> acc - value end)
  end
end
defprotocol Differentiable do
  @moduledoc """
  Protocol for expressions that can produce their own derivative.
  """

  @doc """
  Returns an expression representing the derivative of `expr`.
  """
  # The argument is typed with this protocol's own `t()`, because dispatch
  # happens on Differentiable implementations; the result is an Expression.
  @spec derivative(t()) :: Expression.t()
  def derivative(expr)
end
defmodule Trig do
  @moduledoc """
  A trigonometric function, named by `fn` (`:sin` or `:cos`), applied to the
  inner expression `expr`.
  """

  defstruct [:fn, :expr]

  @doc """
  Builds a `Trig` struct for the function `f` applied to `expr`.

  Only `:sin` and `:cos` are accepted; any other `f` matches no clause.
  """
  def new(f, expr) when f in [:sin, :cos], do: %__MODULE__{fn: f, expr: expr}

  @doc """
  Returns the one-argument function that computes the named trig function.
  """
  def trig_fn(:sin), do: &Math.sin/1
  def trig_fn(:cos), do: &Math.cos/1
end
defimpl Expression, for: Trig do
  # Substitutes into the inner expression and keeps the same trig function.
  # The Expression protocol declares substitute/2, so every implementation
  # has to provide it.
  def substitute(%Trig{expr: expr} = trig, bindings) do
    %{trig | expr: Expression.substitute(expr, bindings)}
  end

  # Evaluates the inner expression first, then applies the trig function
  # selected by the `fn` field.
  def evaluate(%Trig{fn: f, expr: expr}, input) do
    Trig.trig_fn(f).(Expression.evaluate(expr, input))
  end

  # The variables of a trig node are those of its inner expression.
  def get_variables(%Trig{expr: expr}) do
    Expression.get_variables(expr)
  end
end
defimpl Differentiable, for: Trig do
  # Derivative of the outer trig function only:
  #   cos(u) -> a :minus node wrapping sin(u)
  #   sin(u) -> cos(u)
  # NOTE(review): the derivative of the inner expression `u` is not applied
  # (no chain rule), so the result is only the full derivative when `u` is
  # the variable being differentiated.
  def derivative(%Trig{fn: f, expr: expr}) do
    case f do
      # Operands are given as a list, like the other ArithOp built in this
      # file; the single-operand :minus node is meant as -sin(u).
      :cos -> %ArithOp{op: :minus, exprs: [%Trig{fn: :sin, expr: expr}]}
      :sin -> %Trig{fn: :cos, expr: expr}
    end
  end
end
defmodule Foo do
  @moduledoc """
  Scratch module with a constructor that hands its options straight back.
  """

  @doc """
  Returns `opts` unchanged.
  """
  def new(opts), do: opts
end
# Keyword options returned by Foo.new/1 can be read with the access syntax.
Foo.new(x: 1)[:x]

# Build sin(x) and differentiate it.
expr1 = Trig.new(:sin, :x)
dexpr1 = Differentiable.derivative(expr1)

# Variables are looked up with bindings[var], so the bindings have to be
# keyed by variable name; a bare number such as 0 cannot be indexed that way.
Expression.evaluate(dexpr1, x: 0)
Expression.evaluate(%ArithOp{op: :plus, exprs: [dexpr1, expr1]}, x: 0)