Skip to content

Commit 3321105

Browse files
authored
Merge pull request #37 from acalejos/plotting
Draft: Plotting
2 parents 9649af0 + cf10353 commit 3321105

18 files changed

+2545
-32
lines changed

.github/workflows/precompile.yml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
name: precompile
22

3-
on: push
3+
on:
4+
- push
5+
- workflow_dispatch
46

57
jobs:
68
linux:
@@ -52,7 +54,7 @@ jobs:
5254
MIX_ENV: "prod"
5355
strategy:
5456
matrix:
55-
runner: ["macos-latest", "self-hosted"]
57+
runner: ["macos-latest", "exgboost-m2-runner"]
5658
otp: ["25.0", "26.0"]
5759
elixir: ["1.14.5"]
5860
steps:

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ erl_crash.dump
1111
.elixir_ls/
1212
.tool-versions
1313
.vscode/
14-
checksum.exs
14+
checksum.exs
15+
.DS_Store

Makefile

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ $(XGBOOST_LIB_DIR_FLAG):
6363
git fetch --depth 1 --recurse-submodules origin $(XGBOOST_GIT_REV) && \
6464
git checkout FETCH_HEAD && \
6565
git submodule update --init --recursive && \
66+
sed 's|learner_parameters\["generic_param"\] = ToJson(ctx_);|&\nlearner_parameters\["default_metric"\] = String(obj_->DefaultEvalMetric());|' src/learner.cc > src/learner.cc.tmp && mv src/learner.cc.tmp src/learner.cc && \
6667
cmake -DCMAKE_INSTALL_PREFIX=$(XGBOOST_LIB_DIR) -B build . $(CMAKE_FLAGS) && \
6768
make -C build -j1 install
6869
touch $(XGBOOST_LIB_DIR_FLAG)

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ billions of examples.
2222
```elixir
2323
def deps do
2424
[
25-
{:exgboost, "~> 0.3"}
25+
{:exgboost, "~> 0.5"}
2626
]
2727
end
2828
```

lib/exgboost.ex

+73-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ defmodule EXGBoost do
1616
```elixir
1717
def deps do
1818
[
19-
{:exgboost, "~> 0.4"}
19+
{:exgboost, "~> 0.5"}
2020
]
2121
end
2222
```
@@ -92,7 +92,7 @@ defmodule EXGBoost do
9292
preds = EXGBoost.train(X, y) |> EXGBoost.predict(X)
9393
```
9494
95-
## Serliaztion
95+
## Serialization
9696
9797
A Booster can be serialized to a file using `EXGBoost.write_*` and loaded from a file
9898
using `EXGBoost.read_*`. The file format can be specified using the `:format` option
@@ -113,6 +113,34 @@ defmodule EXGBoost do
113113
- `config` - Save the configuration only.
114114
- `weights` - Save the model parameters only. Use this when you want to save the model to a format that can be ingested by other XGBoost APIs.
115115
- `model` - Save both the model parameters and the configuration.
116+
117+
## Plotting
118+
119+
`EXGBoost.plot_tree/2` is the primary entry point for plotting a tree from a trained model.
120+
It accepts an `EXGBoost.Booster` struct (which is the output of `EXGBoost.train/2`).
121+
`EXGBoost.plot_tree/2` returns a VegaLite spec that can be rendered in a notebook or saved to a file.
122+
`EXGBoost.plot_tree/2` also accepts a keyword list of options that can be used to configure the plotting process.
123+
124+
See `EXGBoost.Plotting` for more detail on plotting.
125+
126+
You can see available styles by running `EXGBoost.Plotting.get_styles()` or refer to the `EXGBoost.Plotting.Styles`
127+
documentation for a gallery of the styles.
128+
129+
## Kino & Livebook Integration
130+
131+
`EXGBoost` integrates with [Kino](https://hexdocs.pm/kino/Kino.html) and [Livebook](https://livebook.dev/)
132+
to provide a rich interactive experience for data scientists.
133+
134+
EXGBoost implements the `Kino.Render` protocol for `EXGBoost.Booster` structs. This allows you to render
135+
a Booster in a Livebook notebook. Under the hood, `EXGBoost` uses [Vega-Lite](https://vega.github.io/vega-lite/)
136+
and [Kino Vega-Lite](https://hexdocs.pm/kino_vega_lite/Kino.VegaLite.html) to render the Booster.
137+
138+
See the [`Plotting in EXGBoost`](notebooks/plotting.livemd) Notebook for an example of how to use `EXGBoost` with `Kino` and `Livebook`.
139+
140+
## Examples
141+
142+
See the example Notebooks in the left sidebar (under the `Pages` tab) for more examples and tutorials
143+
on how to use EXGBoost.
116144
"""
117145

118146
alias EXGBoost.ArrayInterface
@@ -121,13 +149,15 @@ defmodule EXGBoost do
121149
alias EXGBoost.DMatrix
122150
alias EXGBoost.ProxyDMatrix
123151
alias EXGBoost.Training
152+
alias EXGBoost.Plotting
124153

125154
@doc """
126155
Check the build information of the xgboost library.
127156
128157
Returns a map containing information about the build.
129158
"""
130159
@spec xgboost_build_info() :: map()
160+
@doc type: :system
131161
def xgboost_build_info,
132162
do: EXGBoost.NIF.xgboost_build_info() |> Internal.unwrap!() |> Jason.decode!()
133163

@@ -137,6 +167,7 @@ defmodule EXGBoost do
137167
Returns a 3-tuple in the form of `{major, minor, patch}`.
138168
"""
139169
@spec xgboost_version() :: {integer(), integer(), integer()} | {:error, String.t()}
170+
@doc type: :system
140171
def xgboost_version, do: EXGBoost.NIF.xgboost_version() |> Internal.unwrap!()
141172

142173
@doc """
@@ -147,6 +178,7 @@ defmodule EXGBoost do
147178
for the full list of parameters supported in the global configuration.
148179
"""
149180
@spec set_config(map()) :: :ok | {:error, String.t()}
181+
@doc type: :system
150182
def set_config(%{} = config) do
151183
config = EXGBoost.Parameters.validate_global!(config)
152184
EXGBoost.NIF.set_global_config(Jason.encode!(config)) |> Internal.unwrap!()
@@ -160,6 +192,7 @@ defmodule EXGBoost do
160192
for the full list of parameters supported in the global configuration.
161193
"""
162194
@spec get_config() :: map()
195+
@doc type: :system
163196
def get_config do
164197
EXGBoost.NIF.get_global_config() |> Internal.unwrap!() |> Jason.decode!()
165198
end
@@ -208,10 +241,11 @@ defmodule EXGBoost do
208241
* `opts` - Refer to `EXGBoost.Parameters` for the full list of options.
209242
"""
210243
@spec train(Nx.Tensor.t(), Nx.Tensor.t(), Keyword.t()) :: EXGBoost.Booster.t()
244+
@doc type: :train_pred
211245
def train(x, y, opts \\ []) do
212246
x = Nx.concatenate(x)
213247
y = Nx.concatenate(y)
214-
{dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts())
248+
dmat_opts = Keyword.take(opts, Internal.dmatrix_feature_opts())
215249
dmat = DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense))
216250
Training.train(dmat, opts)
217251
end
@@ -272,6 +306,7 @@ defmodule EXGBoost do
272306
273307
Returns an Nx.Tensor containing the predictions.
274308
"""
309+
@doc type: :train_pred
275310
def predict(%Booster{} = bst, x, opts \\ []) do
276311
x = Nx.concatenate(x)
277312
{dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts())
@@ -302,6 +337,7 @@ defmodule EXGBoost do
302337
303338
Returns an Nx.Tensor containing the predictions.
304339
"""
340+
@doc type: :train_pred
305341
def inplace_predict(%Booster{} = boostr, data, opts \\ []) do
306342
opts =
307343
Keyword.validate!(opts,
@@ -428,6 +464,7 @@ defmodule EXGBoost do
428464
## Options
429465
#{NimbleOptions.docs(@write_schema)}
430466
"""
467+
@doc type: :serialization
431468
@spec write_model(Booster.t(), String.t()) :: :ok | {:error, String.t()}
432469
def write_model(%Booster{} = booster, path, opts \\ []) do
433470
opts = NimbleOptions.validate!(opts, @write_schema)
@@ -437,6 +474,7 @@ defmodule EXGBoost do
437474
@doc """
438475
Read a model from a file and return the Booster.
439476
"""
477+
@doc type: :serialization
440478
@spec read_model(String.t()) :: EXGBoost.Booster.t()
441479
def read_model(path) do
442480
EXGBoost.Booster.load(path, deserialize: :model)
@@ -449,6 +487,7 @@ defmodule EXGBoost do
449487
#{NimbleOptions.docs(@dump_schema)}
450488
"""
451489
@spec dump_model(Booster.t()) :: binary()
490+
@doc type: :serialization
452491
def dump_model(%Booster{} = booster, opts \\ []) do
453492
opts = NimbleOptions.validate!(opts, @dump_schema)
454493
EXGBoost.Booster.save(booster, opts ++ [serialize: :model, to: :buffer])
@@ -458,6 +497,7 @@ defmodule EXGBoost do
458497
Read a model from a buffer and return the Booster.
459498
"""
460499
@spec load_model(binary()) :: EXGBoost.Booster.t()
500+
@doc type: :serialization
461501
def load_model(buffer) do
462502
EXGBoost.Booster.load(buffer, deserialize: :model, from: :buffer)
463503
end
@@ -469,6 +509,7 @@ defmodule EXGBoost do
469509
#{NimbleOptions.docs(@write_schema)}
470510
"""
471511
@spec write_config(Booster.t(), String.t()) :: :ok | {:error, String.t()}
512+
@doc type: :serialization
472513
def write_config(%Booster{} = booster, path, opts \\ []) do
473514
opts = NimbleOptions.validate!(opts, @write_schema)
474515
EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :config])
@@ -481,6 +522,7 @@ defmodule EXGBoost do
481522
#{NimbleOptions.docs(@dump_schema)}
482523
"""
483524
@spec dump_config(Booster.t()) :: binary()
525+
@doc type: :serialization
484526
def dump_config(%Booster{} = booster, opts \\ []) do
485527
opts = NimbleOptions.validate!(opts, @dump_schema)
486528
EXGBoost.Booster.save(booster, opts ++ [serialize: :config, to: :buffer])
@@ -493,6 +535,7 @@ defmodule EXGBoost do
493535
#{NimbleOptions.docs(@load_schema)}
494536
"""
495537
@spec read_config(String.t()) :: EXGBoost.Booster.t()
538+
@doc type: :serialization
496539
def read_config(path, opts \\ []) do
497540
opts = NimbleOptions.validate!(opts, @load_schema)
498541
EXGBoost.Booster.load(path, opts ++ [deserialize: :config])
@@ -505,6 +548,7 @@ defmodule EXGBoost do
505548
#{NimbleOptions.docs(@load_schema)}
506549
"""
507550
@spec load_config(binary()) :: EXGBoost.Booster.t()
551+
@doc type: :serialization
508552
def load_config(buffer, opts \\ []) do
509553
opts = NimbleOptions.validate!(opts, @load_schema)
510554
EXGBoost.Booster.load(buffer, opts ++ [deserialize: :config, from: :buffer])
@@ -517,6 +561,7 @@ defmodule EXGBoost do
517561
#{NimbleOptions.docs(@write_schema)}
518562
"""
519563
@spec write_weights(Booster.t(), String.t()) :: :ok | {:error, String.t()}
564+
@doc type: :serialization
520565
def write_weights(%Booster{} = booster, path, opts \\ []) do
521566
opts = NimbleOptions.validate!(opts, @write_schema)
522567
EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :weights])
@@ -529,6 +574,7 @@ defmodule EXGBoost do
529574
#{NimbleOptions.docs(@dump_schema)}
530575
"""
531576
@spec dump_weights(Booster.t()) :: binary()
577+
@doc type: :serialization
532578
def dump_weights(%Booster{} = booster, opts \\ []) do
533579
opts = NimbleOptions.validate!(opts, @dump_schema)
534580
EXGBoost.Booster.save(booster, opts ++ [serialize: :weights, to: :buffer])
@@ -538,6 +584,7 @@ defmodule EXGBoost do
538584
Read a model's trained parameters from a file and return the Booster.
539585
"""
540586
@spec read_weights(String.t()) :: EXGBoost.Booster.t()
587+
@doc type: :serialization
541588
def read_weights(path) do
542589
EXGBoost.Booster.load(path, deserialize: :weights)
543590
end
@@ -546,7 +593,30 @@ defmodule EXGBoost do
546593
Read a model's trained parameters from a buffer and return the Booster.
547594
"""
548595
@spec load_weights(binary()) :: EXGBoost.Booster.t()
596+
@doc type: :serialization
549597
def load_weights(buffer) do
550598
EXGBoost.Booster.load(buffer, deserialize: :weights, from: :buffer)
551599
end
600+
601+
@doc """
602+
Plot a tree from a Booster model and, if `:path` is given, save it to a file.
603+
604+
## Options
605+
* `:format` - the format to export the graphic as, must be one of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension.
606+
* `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass `local_npm_prefix: "assets"`. By default the npm packages are searched for in the current directory and globally.
607+
* `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec.
608+
* Any other options are passed through to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information.
609+
"""
610+
@doc type: :plotting
611+
def plot_tree(booster, opts \\ []) do
612+
{path, opts} = Keyword.pop(opts, :path)
613+
{save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix])
614+
vega = Plotting.plot(booster, opts)
615+
616+
if path != nil do
617+
VegaLite.Export.save!(vega, path, save_opts)
618+
else
619+
vega
620+
end
621+
end
552622
end

lib/exgboost/booster.ex

+12
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,15 @@ defmodule EXGBoost.Booster do
163163
def booster(dmats, opts \\ [])
164164

165165
def booster(dmats, opts) when is_list(dmats) do
166+
{str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts())
166167
opts = EXGBoost.Parameters.validate!(opts)
167168
refs = Enum.map(dmats, & &1.ref)
168169
booster_ref = EXGBoost.NIF.booster_create(refs) |> Internal.unwrap!()
170+
171+
Enum.each(str_opts, fn {key, value} ->
172+
EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value)
173+
end)
174+
169175
set_params(%__MODULE__{ref: booster_ref}, opts)
170176
end
171177

@@ -174,9 +180,15 @@ defmodule EXGBoost.Booster do
174180
end
175181

176182
def booster(%__MODULE__{} = bst, opts) do
183+
{str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts())
177184
opts = EXGBoost.Parameters.validate!(opts)
178185
boostr_bytes = EXGBoost.NIF.booster_serialize_to_buffer(bst.ref) |> Internal.unwrap!()
179186
booster_ref = EXGBoost.NIF.booster_deserialize_from_buffer(boostr_bytes) |> Internal.unwrap!()
187+
188+
Enum.each(str_opts, fn {key, value} ->
189+
EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value)
190+
end)
191+
180192
set_params(%__MODULE__{ref: booster_ref}, opts)
181193
end
182194

lib/exgboost/nif.ex

-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ defmodule EXGBoost.NIF do
100100
def dmatrix_create_from_file(_file_uri, _silent),
101101
do: :erlang.nif_error(:not_implemented)
102102

103-
@since "0.4.0"
104103
def dmatrix_create_from_uri(_config), do: :erlang.nif_error(:not_implemented)
105104

106105
@spec dmatrix_create_from_mat(binary, integer(), integer(), float()) ::

0 commit comments

Comments
 (0)