Skip to content

Commit a351e89

Browse files
committed
symbolic stuff in elixir
1 parent caad3e6 commit a351e89

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

symex.livemd

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Symbolic Elixir
2+
3+
```elixir
4+
Mix.install([
5+
{:math, "~> 0.7.0"}
6+
])
7+
```
8+
9+
## Section
10+
11+
```elixir
12+
import Math
13+
14+
defprotocol Expression do
15+
def substitute(expr, bindings)
16+
17+
@spec evaluate(__MODULE__.t(), [{Atom.t(), any()}]) :: any()
18+
def evaluate(expr, input)
19+
20+
# @spec substitute(__MODULE__.t, Map.t) :: __MODULE__.t
21+
# def substitute(expr, var_mapping)
22+
23+
@spec get_variables(__MODULE__.t()) :: [Atom.t()]
24+
def get_variables(expr)
25+
end
26+
27+
defimpl Expression, for: Atom do
28+
def substitute(var, bindings), do: bindings[var]
29+
def evaluate(var, bindings), do: bindings[var]
30+
def get_variables(var), do: [var]
31+
end
32+
33+
defimpl Expression, for: List do
34+
def substitute(exprs, bindings) do
35+
for {expr, input} <- Enum.zip(exprs, inputs) do
36+
expr |> Expression.substitute(input)
37+
end
38+
end
39+
40+
def evaluate(exprs, bindings) when length(exprs) >= length(inputs) do
41+
for {expr, input} <- Enum.zip(exprs, inputs) do
42+
expr |> Expression.evaluate(input)
43+
end
44+
end
45+
46+
def get_variables(exprs), do: exprs |> Enum.map(&Expression.get_variables/1)
47+
end
48+
49+
defmodule Operation do
50+
defstruct [:op, :exprs]
51+
52+
def substitute_exprs(%Operation{op: op, exprs: exprs}) do
53+
end
54+
end
55+
56+
defimpl Expression, for: ArithOp do
57+
def substitute(%ArithOp{op: op, exprs: exprs}, bindings) do
58+
exprs |> Expression.substitute(inputs) |> to_list()
59+
end
60+
61+
def evaluate(%ArithOp{op: op, exprs: exprs}, inputs) do
62+
input_values = exprs |> Expression.evaluate(inputs) |> to_list()
63+
64+
f =
65+
case op do
66+
:plus -> fn x, y -> x + y end
67+
:minus -> fn x, y -> x - y end
68+
end
69+
70+
input_values |> Enum.reduce(f)
71+
end
72+
73+
def to_list(l) when is_list(l) do
74+
l
75+
end
76+
77+
def to_list(v), do: [v]
78+
end
79+
80+
defprotocol Differentiable do
81+
@spec derivative(Expression.t()) :: Expression.t()
82+
def derivative(expr)
83+
end
84+
85+
defmodule Trig do
86+
defstruct [:fn, :expr]
87+
88+
def trig_fn(:sin) do
89+
&Math.sin/1
90+
end
91+
92+
def trig_fn(:cos) do
93+
&Math.cos/1
94+
end
95+
96+
def new(f, expr) when f == :sin or f == :cos do
97+
%Trig{fn: f, expr: expr}
98+
end
99+
end
100+
101+
defimpl Expression, for: Trig do
102+
def evaluate(%Trig{fn: f, expr: expr}, input) do
103+
Trig.trig_fn(f).(expr |> Expression.evaluate(input))
104+
end
105+
106+
def get_variables(%Trig{expr: expr}) do
107+
Expression.get_variables(expr)
108+
end
109+
end
110+
111+
defimpl Differentiable, for: Trig do
112+
def derivative(%Trig{fn: f, expr: expr}) do
113+
case f do
114+
:cos -> %ArithOp{op: :minus, exprs: %Trig{fn: :sin, expr: expr}}
115+
:sin -> %Trig{fn: :cos, expr: expr}
116+
end
117+
end
118+
end
119+
```
120+
121+
```elixir
122+
defmodule Foo do
123+
def new(opts) do
124+
opts
125+
end
126+
end
127+
128+
Foo.new(x: 1)[:x]
129+
```
130+
131+
```elixir
132+
expr1 = Trig.new(:sin, :x)
133+
dexpr1 = expr1 |> Differentiable.derivative()
134+
Expression.evaluate(dexpr1, 0)
135+
Expression.evaluate(%ArithOp{op: :plus, exprs: [dexpr1, expr1]}, 0)
136+
```

0 commit comments

Comments
 (0)