Skip to content

Commit db3402c

Browse files
committed
Add Jason encoding impl for non binary backends
1 parent 05d15f0 commit db3402c

File tree

4 files changed

+32
-24
lines changed

4 files changed

+32
-24
lines changed

.github/workflows/precompile.yml

+7-10
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ jobs:
1313
strategy:
1414
matrix:
1515
# Elixir 1.14.5 is first version compatible with OTP 26
16-
# NIF versionsss change according to
16+
# NIF versions change according to
1717
# https://github.com/erlang/otp/blob/dd57c853a324a9572a9e5ce227d8675ff004c6fe/erts/emulator/beam/erl_nif.h#L33
18-
otp: ["25.0", "26.0"]
18+
otp: ["25.0", "26.0", "27.0"]
1919
elixir: ["1.14.5"]
2020
steps:
2121
- uses: actions/checkout@v3
@@ -32,7 +32,7 @@ jobs:
3232
- name: Mix Test
3333
run: |
3434
mix deps.get
35-
MIX_ENV=test mix test
35+
MIX_ENV=test mix tests
3636
- name: Create precompiled library
3737
run: |
3838
export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache
@@ -49,13 +49,13 @@ jobs:
4949
# Homebrew supports versioned Erlang/OTP but not Elixir
5050
# It's a deliberate design decision from Homebrew to
5151
# only support versioned distrinutions for certin packages
52-
name: Mac (${{ matrix.runner == 'macos-latest' && 'Intel' || 'ARM' }}) Erlang/OTP ${{matrix.otp}} / Elixir
52+
name: Mac (${{ matrix.runner == 'macos-13' && 'Intel' || 'ARM' }}) Erlang/OTP ${{matrix.otp}} / Elixir
5353
env:
5454
MIX_ENV: "prod"
5555
strategy:
5656
matrix:
57-
runner: ["macos-latest", "exgboost-m2-runner"]
58-
otp: ["25.0", "26.0"]
57+
runner: ["macos-13", "macos-14"]
58+
otp: ["25.0", "26.0", "27.0"]
5959
elixir: ["1.14.5"]
6060
steps:
6161
- uses: actions/checkout@v3
@@ -82,7 +82,4 @@ jobs:
8282
if: startsWith(github.ref, 'refs/tags/')
8383
with:
8484
files: |
85-
${{ matrix.runner == 'macos-latest' && 'cache/*x86_64*.tar.gz' || 'cache/*aarch64*.tar.gz' }}
86-
- name: Cleanup # We need this since self-hosted runners are not ephemeral
87-
run: |
88-
make clean
85+
${{ matrix.runner == 'macos-13' && 'cache/*x86_64*.tar.gz' || 'cache/*aarch64*.tar.gz' }}

lib/exgboost/internal.ex

+22-11
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,31 @@ defmodule EXGBoost.Internal do
111111
# a string, so if we pass string NaN to XGBoost, it will fail.
112112
# This allows the user to use Nx.Constants.nan() and have it work as expected.
113113
defimpl Jason.Encoder, for: Nx.Tensor do
114-
def encode(%Nx.Tensor{data: %Nx.BinaryBackend{state: <<0x7FC0::16-native>>}}, _opts),
115-
do: "NaN"
114+
@binary_nans [
115+
<<0x7FC0::16-native>>,
116+
<<0x7E00::16-native>>,
117+
<<0x7FC00000::32-native>>,
118+
<<0x7FF8000000000000::64-native>>
119+
]
120+
def encode(%Nx.Tensor{data: %Nx.BinaryBackend{state: state}}, _opts)
121+
when state in @binary_nans,
122+
do: "NaN"
116123

117-
def encode(%Nx.Tensor{data: %Nx.BinaryBackend{state: <<0x7E00::16-native>>}}, _opts),
118-
do: "NaN"
124+
def encode(%Nx.Tensor{} = tensor, _opts) do
125+
case Nx.to_binary(tensor, limit: 1) do
126+
binary when binary in @binary_nans ->
127+
"NaN"
119128

120-
def encode(%Nx.Tensor{data: %Nx.BinaryBackend{state: <<0x7FC00000::32-native>>}}, _opts),
121-
do: "NaN"
129+
_ ->
130+
raise ArgumentError,
131+
"""
132+
JSON Encoding only implemented for NaN Tensors (Nx.Constants.nan())!
122133
123-
def encode(
124-
%Nx.Tensor{data: %Nx.BinaryBackend{state: <<0x7FF8000000000000::64-native>>}},
125-
_opts
126-
),
127-
do: "NaN"
134+
This normally is only used to map the `missing` parameter during EXGBoost
135+
training when `missing` is Nx.Constants.nan()
136+
"""
137+
end
138+
end
128139
end
129140

130141
def unwrap!({:ok, val}), do: val

mix.exs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
defmodule EXGBoost.MixProject do
22
use Mix.Project
3-
@version "0.5.0"
3+
@version "0.5.1"
44

55
def project do
66
[
@@ -47,7 +47,7 @@ defmodule EXGBoost.MixProject do
4747
[
4848
{:elixir_make, "~> 0.4", runtime: false},
4949
{:nimble_options, "~> 1.0"},
50-
{:nx, "~> 0.5"},
50+
{:nx, "~> 0.7"},
5151
{:jason, "~> 1.3"},
5252
{:ex_doc, "~> 0.31.0", only: :docs},
5353
{:cc_precompiler, "~> 0.1.0", runtime: false},

mix.lock

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"},
2525
"nimble_options": {:hex, :nimble_options, "1.1.0", "3b31a57ede9cb1502071fade751ab0c7b8dbe75a9a4c2b5bbb0943a690b63172", [:mix], [], "hexpm", "8bbbb3941af3ca9acc7835f5655ea062111c9c27bcac53e004460dfd19008a99"},
2626
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
27-
"nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
27+
"nx": {:hex, :nx, "0.7.2", "7f6f6584585e49ffbf81769e7ccc2d01c5639074e399c1f94adc2b509869673e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e2c0680066eec5af8b8ef00c99e9bf40a0d08d8b2bbba77f59f801ec54a3f90e"},
2828
"parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"},
2929
"scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"},
3030
"ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"},

0 commit comments

Comments
 (0)