Skip to content

Files

Latest commit

a351e89 · Aug 30, 2023

History

History
136 lines (106 loc) · 2.72 KB

symex.livemd

File metadata and controls

136 lines (106 loc) · 2.72 KB

Symbolic Elixir

Mix.install([
  {:math, "~> 0.7.0"}
])

Section

import Math

defprotocol Expression do
  @moduledoc """
  Protocol for symbolic expressions.

  An expression can have variables substituted, be evaluated against a set of
  variable bindings, and report the variables it mentions.
  """

  @typedoc "Variable bindings, e.g. `[x: 1, y: 2]`."
  @type bindings :: [{atom(), any()}]

  @doc """
  Replaces variables in `expr` using `bindings` and returns the resulting expression.
  """
  @spec substitute(t(), bindings()) :: t()
  def substitute(expr, bindings)

  @doc """
  Evaluates `expr` using the variable bindings given in `input`.
  """
  @spec evaluate(t(), bindings()) :: any()
  def evaluate(expr, input)

  @doc """
  Returns the variables (as atoms) that appear in `expr`.
  """
  @spec get_variables(t()) :: [atom()]
  def get_variables(expr)
end

# An atom is a variable: its meaning comes entirely from the bindings.
defimpl Expression, for: Atom do
  # Replaces the variable with its bound value. An unbound variable is returned
  # unchanged so that partial substitution keeps the expression symbolic instead
  # of turning the variable into `nil`.
  def substitute(var, bindings), do: Access.get(bindings, var, var)

  # Looks the variable up in `bindings` (keyword list or map); yields `nil`
  # when the variable is not bound.
  def evaluate(var, bindings), do: bindings[var]

  # A variable mentions exactly one variable: itself.
  def get_variables(var), do: [var]
end

# A list of expressions is handled element-wise; every element sees the same
# bindings.
defimpl Expression, for: List do
  # Substitutes `bindings` into each expression, preserving order.
  def substitute(exprs, bindings) do
    Enum.map(exprs, &Expression.substitute(&1, bindings))
  end

  # Evaluates each expression against the shared `bindings`, preserving order.
  def evaluate(exprs, bindings) do
    Enum.map(exprs, &Expression.evaluate(&1, bindings))
  end

  # Collects the variables of all expressions into one flat list without
  # duplicates, matching the protocol's `[atom()]` return type.
  def get_variables(exprs) do
    exprs
    |> Enum.flat_map(&Expression.get_variables/1)
    |> Enum.uniq()
  end
end

defmodule Operation do
  @moduledoc """
  An operation identified by `op`, applied to the expressions in `exprs`.
  """

  defstruct [:op, :exprs]

  @doc """
  Placeholder: accepts an `Operation` struct and currently returns `nil`.
  """
  def substitute_exprs(%__MODULE__{} = _operation) do
    nil
  end
end

# Arithmetic over one or more operand expressions.
# NOTE(review): no `ArithOp` struct definition is visible in this file; it is
# assumed to provide the fields `op` and `exprs` — confirm where it is defined.
defimpl Expression, for: ArithOp do
  # Substitutes `bindings` into every operand and returns a new operation with
  # the same operator. `exprs` may be a single expression or a list; the result
  # always stores a list.
  def substitute(%ArithOp{exprs: exprs} = operation, bindings) do
    substituted = exprs |> to_list() |> Enum.map(&Expression.substitute(&1, bindings))
    %{operation | exprs: substituted}
  end

  # Evaluates every operand against `inputs` and combines the values:
  #   * `:plus`  - sum of all operands (0 for no operands)
  #   * `:minus` - a single operand is negated (unary minus); otherwise the
  #                remaining operands are subtracted from the first, left to right.
  # Any other operator, or `:minus` without operands, raises `CaseClauseError`.
  def evaluate(%ArithOp{op: op, exprs: exprs}, inputs) do
    values = exprs |> to_list() |> Enum.map(&Expression.evaluate(&1, inputs))

    case {op, values} do
      {:plus, _} ->
        Enum.sum(values)

      {:minus, [value]} ->
        -value

      {:minus, [first | rest]} ->
        # The reducer receives (element, accumulator), so subtract the element
        # from the running result, not the other way round.
        Enum.reduce(rest, first, fn value, acc -> acc - value end)
    end
  end

  # Returns the distinct variables mentioned by any operand.
  def get_variables(%ArithOp{exprs: exprs}) do
    exprs
    |> to_list()
    |> Enum.flat_map(&Expression.get_variables/1)
    |> Enum.uniq()
  end

  # Wraps a non-list value in a single-element list; lists pass through as-is.
  def to_list(l) when is_list(l), do: l
  def to_list(v), do: [v]
end

# Protocol for expressions that can be differentiated symbolically.
defprotocol Differentiable do
  # Takes an expression and returns another expression representing its
  # derivative. The variable of differentiation is not a parameter here.
  @spec derivative(Expression.t()) :: Expression.t()
  def derivative(expr)
end

defmodule Trig do
  @moduledoc """
  A trigonometric function (`:sin` or `:cos`) applied to an inner expression.
  """

  defstruct [:fn, :expr]

  @doc """
  Returns the one-argument numeric function implementing the given trig name.
  """
  def trig_fn(:sin), do: &Math.sin/1
  def trig_fn(:cos), do: &Math.cos/1

  @doc """
  Builds a `Trig` struct for `f` applied to `expr`.

  Only `:sin` and `:cos` are accepted; any other `f` matches no clause.
  """
  def new(f, expr) when f in [:sin, :cos] do
    %__MODULE__{fn: f, expr: expr}
  end
end

# A trig expression delegates to its inner expression and then applies the
# trig function.
defimpl Expression, for: Trig do
  # Substitutes `bindings` into the inner expression, keeping the same trig
  # function. Required by the protocol; without it `Expression.substitute/2`
  # fails for any expression containing a `Trig`.
  def substitute(%Trig{expr: expr} = trig, bindings) do
    %{trig | expr: Expression.substitute(expr, bindings)}
  end

  # Evaluates the inner expression against `input`, then applies the numeric
  # function selected by `Trig.trig_fn/1`.
  def evaluate(%Trig{fn: f, expr: expr}, input) do
    trig_fn = Trig.trig_fn(f)
    trig_fn.(Expression.evaluate(expr, input))
  end

  # The variables of a trig expression are those of its inner expression.
  def get_variables(%Trig{expr: expr}) do
    Expression.get_variables(expr)
  end
end

defimpl Differentiable, for: Trig do
  # Derivative of the outer trig function only:
  #   sin(u) -> cos(u)
  #   cos(u) -> :minus applied to sin(u)
  # NOTE(review): the derivative of the inner expression is not multiplied in
  # (no chain rule), so the result is only the full derivative when the inner
  # expression is a plain variable — confirm this is intended.
  def derivative(%Trig{fn: kind, expr: inner}) do
    case kind do
      :sin ->
        %Trig{fn: :cos, expr: inner}

      :cos ->
        sine = %Trig{fn: :sin, expr: inner}
        # Single (non-list) operand under :minus.
        %ArithOp{op: :minus, exprs: sine}
    end
  end
end
defmodule Foo do
  @moduledoc """
  Scratch module: `new/1` hands back whatever options it receives.
  """

  @doc """
  Returns `opts` unchanged.
  """
  def new(opts), do: opts
end

# Keyword access on the options returned by Foo.new/1.
Foo.new(x: 1)[:x]

# Build sin(x) and differentiate it.
expr1 = Trig.new(:sin, :x)
dexpr1 = expr1 |> Differentiable.derivative()

# Variables are looked up with `bindings[var]`, so the evaluation point must be
# given as bindings for :x rather than as a bare number.
Expression.evaluate(dexpr1, x: 0)
Expression.evaluate(%ArithOp{op: :plus, exprs: [dexpr1, expr1]}, x: 0)