Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
version: "2.2"
services:
postgres:
image: postgres:13
# image: postgres:13
build:
dockerfile: ./postgres.Dockerfile
ports:
- "65432:5432"
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=memex
- POSTGRES_DB=memex
networks:
- memex1
volumes:
- /dev/urandom:/dev/random # Required to get non-blocking entropy source
- postgres-db:/var/lib/postgresql/data
Expand All @@ -23,7 +27,7 @@ services:
"postgres",
"-c",
"select 1",
"memex",
"memex"
]
interval: 10s
timeout: 10s
Expand All @@ -32,3 +36,7 @@ services:

volumes:
postgres-db:


networks:
memex1:
3 changes: 2 additions & 1 deletion memex/config/config.exs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ config :memex, ecto_repos: [Memex.Repo]
config :memex, Memex.Repo,
url: System.get_env("POSTGRES_DSN"),
pool_size: 5,
timeout: 60_000
timeout: 60_000,
types: Memex.PostgrexTypes

config :elixir, :time_zone_database, Tzdata.TimeZoneDatabase

Expand Down
114 changes: 114 additions & 0 deletions memex/lib/memex/ai/sentence_transformers.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
defmodule Memex.Ai.SentenceTransformers do
alias Bumblebee.Shared

def embed(text) when is_binary(text) do
embedding =
Nx.Serving.batched_run(__MODULE__, [text])
|> Nx.to_flat_list()
end

"""
{:ok, _} =
Supervisor.start_link(
[
PhoenixDemo.Endpoint,
{Nx.Serving, serving: Memex.Ai.SentenceTransformers.serving(), name: Memex.Ai.SentenceTransformers, batch_timeout: 100}
],
strategy: :one_for_one
)
"""

def serving() do
model_name = "sentence-transformers/all-MiniLM-L6-v2"
{:ok, model_info} = Bumblebee.load_model({:hf, model_name})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

Memex.Ai.SentenceTransformers.sentence_embeddings(model_info, tokenizer,
compile: [batch_size: 10, sequence_length: 128],
defn_options: [compiler: EXLA]
)
end

def test(query \\ "TicketSwap") do
embedding = Nx.Serving.batched_run(__MODULE__, [query])

embeddings =
Nx.Serving.batched_run(__MODULE__, [
"pgvector/pgvector: Open-source vector similarity search for Postgres MacBook Pro: https://github.com/pgvector/pgvector",
"Install Jupyter Notebook | Learn How to Install and Use Jupyter Notebook",
"A picture of London at night",
"Fastai on Apple M1 - Deep Learning - fast.ai Course Forums: https://forums.fast.ai/t/fastai-on-apple-m1/86059/50"
])

Bumblebee.Utils.Nx.cosine_similarity(embedding, embeddings)
end

def sentence_embeddings(model_info, tokenizer, opts \\ []) do
%{model: model, params: params, spec: spec} = model_info
Shared.validate_architecture!(spec, :base)
opts = Keyword.validate!(opts, [:compile, defn_options: []])

compile = opts[:compile]
defn_options = opts[:defn_options]

batch_size = compile[:batch_size]
sequence_length = compile[:sequence_length]

if compile != nil and (batch_size == nil or sequence_length == nil) do
raise ArgumentError,
"expected :compile to be a keyword list specifying :batch_size and :sequence_length, got: #{inspect(compile)}"
end

{_init_fun, predict_fun} = Axon.build(model)

scores_fun = fn params, input ->
outputs = predict_fun.(params, input)
outputs.pooled_state
end

Nx.Serving.new(
fn ->
scores_fun =
Shared.compile_or_jit(scores_fun, defn_options, compile != nil, fn ->
inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :s64),
"token_type_ids" => Nx.template({batch_size, sequence_length}, :s64),
"attention_mask" => Nx.template({batch_size, sequence_length}, :s64)
}

[params, inputs]
end)

fn inputs ->
inputs = Shared.maybe_pad(inputs, batch_size)
scores_fun.(params, inputs)
end
end,
batch_size: batch_size
)
|> Nx.Serving.client_preprocessing(fn input ->
{texts, multi?} = Shared.validate_serving_input!(input, &is_binary/1, "a string")

inputs = Bumblebee.apply_tokenizer(tokenizer, texts)

{Nx.Batch.concatenate([inputs]), multi?}
end)
|> Nx.Serving.client_postprocessing(fn scores, metadata, multi? ->
scores |> IO.inspect(label: "103")
multi? |> IO.inspect(label: "98")
metadata |> IO.inspect(label: "99")
# Mean Pooling - Take attention mask into account for correct averaging
# def mean_pooling(model_output, attention_mask):
# token_embeddings = model_output[0] #First element of model_output contains all token embeddings
# input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
# return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

token_embeddings = scores[0]
input_mask_expanded = Nx.unsqueece(token_embeddings, -1)
Nx.size(token_embeddings)

scores
|> Shared.normalize_output(multi?)
end)
end
end
6 changes: 5 additions & 1 deletion memex/lib/memex/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ defmodule Memex.Application do
{Memex.Repo, []},
# Start a worker by calling: Memex.Worker.start_link(arg)
# {Memex.Worker, arg}
{Memex.Scheduler, []}
{Memex.Scheduler, []},
{Nx.Serving,
serving: Memex.Ai.SentenceTransformers.serving(),
name: Memex.Ai.SentenceTransformers,
batch_timeout: 100}
]

# See https://hexdocs.pm/elixir/Supervisor.html
Expand Down
11 changes: 11 additions & 0 deletions memex/lib/memex/search/postgres.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ defmodule Memex.Search.Postgres do
from(d in Document)
|> add_search(query)
|> add_select(query)
# |> add_vector_search(query)
|> add_filters(prepare_filters(query))
|> add_limit(query)
|> add_order_by(query.order_by)
Expand All @@ -29,6 +30,12 @@ defmodule Memex.Search.Postgres do
from(q in q, where: fragment("? @@ to_tsquery('simple', ?)", q.search, ^to_tsquery(query)))
end

defp add_vector_search(q, %Query{query: ""} = _query), do: q

defp add_vector_search(q, %Query{} = query) do
from(q in q, where: fragment("? <=> ? > 0.6", q.search_embedding, ^to_vector(query)))
end

defp add_select(q, %Query{select: :hits_with_highlights} = _query) do
from(q in q, select: q)
end
Expand Down Expand Up @@ -154,6 +161,10 @@ defmodule Memex.Search.Postgres do
|> String.trim()
end

defp to_vector(%Query{} = query) do
Memex.Ai.SentenceTransformers.embed(query.query)
end

defp format_results(results, %Query{select: :hits_with_highlights} = query) do
results
|> Enum.map(&put_in(&1, ["hit", "_formatted"], format_hit(&1["hit"], query)))
Expand Down
5 changes: 5 additions & 0 deletions memex/lib/postgrex_types.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Postgrex.Types.define(
Memex.PostgrexTypes,
[Pgvector.Extensions.Vector] ++ Ecto.Adapters.Postgres.extensions(),
[]
)
9 changes: 9 additions & 0 deletions memex/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ defmodule Memex.MixProject do
# Type `mix help deps` for examples and options.
defp deps do
[
{:axon, "~> 0.3.0"},
{:axon_onnx, path: "./axon_onnx"},
{:bumblebee, "~> 0.1.0"},
{:exla, "~> 0.4.0"},
{:nx, "~> 0.4.0"},
{:con_cache, "~> 0.13"},
{:earmark, "~> 1.4.16"},
{:ecto_sql, "~> 3.4"},
Expand All @@ -58,6 +63,7 @@ defmodule Memex.MixProject do
{:gettext, "~> 0.11"},
{:git_diff, "~> 0.6.2"},
{:hackney, "~> 1.16.0"},
{:heroicons, "~> 0.2.2"},
{:jason, "~> 1.0"},
{:money, "~> 1.9"},
{:month, "~> 2.1"},
Expand All @@ -68,12 +74,15 @@ defmodule Memex.MixProject do
{:phoenix, "~> 1.6.0"},
{:plug_cowboy, "~> 2.0"},
{:postgrex, "~> 0.15.9"},
{:pgvector, "~> 0.1.0"},
{:rustler, ">= 0.0.0", optional: true},
{:surface, "~> 0.9.0"},
{:surface_catalogue, "~> 0.5.0"},
{:surface_formatter, "~> 0.7.5", only: :dev},
{:telemetry_metrics, "~> 0.6"},
{:telemetry_poller, "~> 0.5"},
{:tesla, "~> 1.4.0"},
{:tokenizers, "~> 0.2.0"},
{:tzdata, "~> 1.0"}
]
end
Expand Down
Loading