@@ -16,7 +16,7 @@ defmodule EXGBoost do
16
16
```elixir
17
17
def deps do
18
18
[
19
- {:exgboost, "~> 0.4 "}
19
+ {:exgboost, "~> 0.5 "}
20
20
]
21
21
end
22
22
```
@@ -92,7 +92,7 @@ defmodule EXGBoost do
92
92
preds = EXGBoost.train(X, y) |> EXGBoost.predict(X)
93
93
```
94
94
95
- ## Serliaztion
95
+ ## Serialization
96
96
97
97
A Booster can be serialized to a file using `EXGBoost.write_*` and loaded from a file
98
98
using `EXGBoost.read_*`. The file format can be specified using the `:format` option
@@ -113,6 +113,34 @@ defmodule EXGBoost do
113
113
- `config` - Save the configuration only.
114
114
- `weights` - Save the model parameters only. Use this when you want to save the model to a format that can be ingested by other XGBoost APIs.
115
115
- `model` - Save both the model parameters and the configuration.
116
+
117
+ ## Plotting
118
+
119
+ `EXGBoost.plot_tree/2` is the primary entry point for plotting a tree from a trained model.
120
+ It accepts an `EXGBoost.Booster` struct (which is the output of `EXGBoost.train/2`).
121
+ `EXGBoost.plot_tree/2` returns a VegaLite spec that can be rendered in a notebook or saved to a file.
122
+ `EXGBoost.plot_tree/2` also accepts a keyword list of options that can be used to configure the plotting process.
123
+
124
+ See `EXGBoost.Plotting` for more detail on plotting.
125
+
126
+ You can see available styles by running `EXGBoost.Plotting.get_styles()` or refer to the `EXGBoost.Plotting.Styles`
127
+ documentation for a gallery of the styles.
128
+
129
+ ## Kino & Livebook Integration
130
+
131
+ `EXGBoost` integrates with [Kino](https://hexdocs.pm/kino/Kino.html) and [Livebook](https://livebook.dev/)
132
+ to provide a rich interactive experience for data scientists.
133
+
134
+ EXGBoost implements the `Kino.Render` protocol for `EXGBoost.Booster` structs. This allows you to render
135
+ a Booster in a Livebook notebook. Under the hood, `EXGBoost` uses [Vega-Lite](https://vega.github.io/vega-lite/)
136
+ and [Kino Vega-Lite](https://hexdocs.pm/kino_vega_lite/Kino.VegaLite.html) to render the Booster.
137
+
138
+ See the [`Plotting in EXGBoost`](notebooks/plotting.livemd) Notebook for an example of how to use `EXGBoost` with `Kino` and `Livebook`.
139
+
140
+ ## Examples
141
+
142
+ See the example Notebooks in the left sidebar (under the `Pages` tab) for more examples and tutorials
143
+ on how to use EXGBoost.
116
144
"""
117
145
118
146
alias EXGBoost.ArrayInterface
@@ -121,13 +149,15 @@ defmodule EXGBoost do
121
149
alias EXGBoost.DMatrix
122
150
alias EXGBoost.ProxyDMatrix
123
151
alias EXGBoost.Training
152
+ alias EXGBoost.Plotting
124
153
125
154
@ doc """
126
155
Check the build information of the xgboost library.
127
156
128
157
Returns a map containing information about the build.
129
158
"""
130
159
@ spec xgboost_build_info ( ) :: map ( )
160
+ @ doc type: :system
131
161
def xgboost_build_info ,
132
162
do: EXGBoost.NIF . xgboost_build_info ( ) |> Internal . unwrap! ( ) |> Jason . decode! ( )
133
163
@@ -137,6 +167,7 @@ defmodule EXGBoost do
137
167
Returns a 3-tuple in the form of `{major, minor, patch}`.
138
168
"""
139
169
@ spec xgboost_version ( ) :: { integer ( ) , integer ( ) , integer ( ) } | { :error , String . t ( ) }
170
+ @ doc type: :system
140
171
def xgboost_version , do: EXGBoost.NIF . xgboost_version ( ) |> Internal . unwrap! ( )
141
172
142
173
@ doc """
@@ -147,6 +178,7 @@ defmodule EXGBoost do
147
178
for the full list of parameters supported in the global configuration.
148
179
"""
149
180
@ spec set_config ( map ( ) ) :: :ok | { :error , String . t ( ) }
181
+ @ doc type: :system
150
182
def set_config ( % { } = config ) do
151
183
config = EXGBoost.Parameters . validate_global! ( config )
152
184
EXGBoost.NIF . set_global_config ( Jason . encode! ( config ) ) |> Internal . unwrap! ( )
@@ -160,6 +192,7 @@ defmodule EXGBoost do
160
192
for the full list of parameters supported in the global configuration.
161
193
"""
162
194
@ spec get_config ( ) :: map ( )
195
+ @ doc type: :system
163
196
def get_config do
164
197
EXGBoost.NIF . get_global_config ( ) |> Internal . unwrap! ( ) |> Jason . decode! ( )
165
198
end
@@ -208,10 +241,11 @@ defmodule EXGBoost do
208
241
* `opts` - Refer to `EXGBoost.Parameters` for the full list of options.
209
242
"""
210
243
@ spec train ( Nx.Tensor . t ( ) , Nx.Tensor . t ( ) , Keyword . t ( ) ) :: EXGBoost.Booster . t ( )
244
+ @ doc type: :train_pred
211
245
def train ( x , y , opts \\ [ ] ) do
212
246
x = Nx . concatenate ( x )
213
247
y = Nx . concatenate ( y )
214
- { dmat_opts , opts } = Keyword . split ( opts , Internal . dmatrix_feature_opts ( ) )
248
+ dmat_opts = Keyword . take ( opts , Internal . dmatrix_feature_opts ( ) )
215
249
dmat = DMatrix . from_tensor ( x , y , Keyword . put_new ( dmat_opts , :format , :dense ) )
216
250
Training . train ( dmat , opts )
217
251
end
@@ -272,6 +306,7 @@ defmodule EXGBoost do
272
306
273
307
Returns an Nx.Tensor containing the predictions.
274
308
"""
309
+ @ doc type: :train_pred
275
310
def predict ( % Booster { } = bst , x , opts \\ [ ] ) do
276
311
x = Nx . concatenate ( x )
277
312
{ dmat_opts , opts } = Keyword . split ( opts , Internal . dmatrix_feature_opts ( ) )
@@ -302,6 +337,7 @@ defmodule EXGBoost do
302
337
303
338
Returns an Nx.Tensor containing the predictions.
304
339
"""
340
+ @ doc type: :train_pred
305
341
def inplace_predict ( % Booster { } = boostr , data , opts \\ [ ] ) do
306
342
opts =
307
343
Keyword . validate! ( opts ,
@@ -428,6 +464,7 @@ defmodule EXGBoost do
428
464
## Options
429
465
#{ NimbleOptions . docs ( @ write_schema ) }
430
466
"""
467
+ @ doc type: :serialization
431
468
@ spec write_model ( Booster . t ( ) , String . t ( ) ) :: :ok | { :error , String . t ( ) }
432
469
def write_model ( % Booster { } = booster , path , opts \\ [ ] ) do
433
470
opts = NimbleOptions . validate! ( opts , @ write_schema )
@@ -437,6 +474,7 @@ defmodule EXGBoost do
437
474
@ doc """
438
475
Read a model from a file and return the Booster.
439
476
"""
477
+ @ doc type: :serialization
440
478
@ spec read_model ( String . t ( ) ) :: EXGBoost.Booster . t ( )
441
479
def read_model ( path ) do
442
480
EXGBoost.Booster . load ( path , deserialize: :model )
@@ -449,6 +487,7 @@ defmodule EXGBoost do
449
487
#{ NimbleOptions . docs ( @ dump_schema ) }
450
488
"""
451
489
@ spec dump_model ( Booster . t ( ) ) :: binary ( )
490
+ @ doc type: :serialization
452
491
def dump_model ( % Booster { } = booster , opts \\ [ ] ) do
453
492
opts = NimbleOptions . validate! ( opts , @ dump_schema )
454
493
EXGBoost.Booster . save ( booster , opts ++ [ serialize: :model , to: :buffer ] )
@@ -458,6 +497,7 @@ defmodule EXGBoost do
458
497
Read a model from a buffer and return the Booster.
459
498
"""
460
499
@ spec load_model ( binary ( ) ) :: EXGBoost.Booster . t ( )
500
+ @ doc type: :serialization
461
501
def load_model ( buffer ) do
462
502
EXGBoost.Booster . load ( buffer , deserialize: :model , from: :buffer )
463
503
end
@@ -469,6 +509,7 @@ defmodule EXGBoost do
469
509
#{ NimbleOptions . docs ( @ write_schema ) }
470
510
"""
471
511
@ spec write_config ( Booster . t ( ) , String . t ( ) ) :: :ok | { :error , String . t ( ) }
512
+ @ doc type: :serialization
472
513
def write_config ( % Booster { } = booster , path , opts \\ [ ] ) do
473
514
opts = NimbleOptions . validate! ( opts , @ write_schema )
474
515
EXGBoost.Booster . save ( booster , opts ++ [ path: path , serialize: :config ] )
@@ -481,6 +522,7 @@ defmodule EXGBoost do
481
522
#{ NimbleOptions . docs ( @ dump_schema ) }
482
523
"""
483
524
@ spec dump_config ( Booster . t ( ) ) :: binary ( )
525
+ @ doc type: :serialization
484
526
def dump_config ( % Booster { } = booster , opts \\ [ ] ) do
485
527
opts = NimbleOptions . validate! ( opts , @ dump_schema )
486
528
EXGBoost.Booster . save ( booster , opts ++ [ serialize: :config , to: :buffer ] )
@@ -493,6 +535,7 @@ defmodule EXGBoost do
493
535
#{ NimbleOptions . docs ( @ load_schema ) }
494
536
"""
495
537
@ spec read_config ( String . t ( ) ) :: EXGBoost.Booster . t ( )
538
+ @ doc type: :serialization
496
539
def read_config ( path , opts \\ [ ] ) do
497
540
opts = NimbleOptions . validate! ( opts , @ load_schema )
498
541
EXGBoost.Booster . load ( path , opts ++ [ deserialize: :config ] )
@@ -505,6 +548,7 @@ defmodule EXGBoost do
505
548
#{ NimbleOptions . docs ( @ load_schema ) }
506
549
"""
507
550
@ spec load_config ( binary ( ) ) :: EXGBoost.Booster . t ( )
551
+ @ doc type: :serialization
508
552
def load_config ( buffer , opts \\ [ ] ) do
509
553
opts = NimbleOptions . validate! ( opts , @ load_schema )
510
554
EXGBoost.Booster . load ( buffer , opts ++ [ deserialize: :config , from: :buffer ] )
@@ -517,6 +561,7 @@ defmodule EXGBoost do
517
561
#{ NimbleOptions . docs ( @ write_schema ) }
518
562
"""
519
563
@ spec write_weights ( Booster . t ( ) , String . t ( ) ) :: :ok | { :error , String . t ( ) }
564
+ @ doc type: :serialization
520
565
def write_weights ( % Booster { } = booster , path , opts \\ [ ] ) do
521
566
opts = NimbleOptions . validate! ( opts , @ write_schema )
522
567
EXGBoost.Booster . save ( booster , opts ++ [ path: path , serialize: :weights ] )
@@ -529,6 +574,7 @@ defmodule EXGBoost do
529
574
#{ NimbleOptions . docs ( @ dump_schema ) }
530
575
"""
531
576
@ spec dump_weights ( Booster . t ( ) ) :: binary ( )
577
+ @ doc type: :serialization
532
578
def dump_weights ( % Booster { } = booster , opts \\ [ ] ) do
533
579
opts = NimbleOptions . validate! ( opts , @ dump_schema )
534
580
EXGBoost.Booster . save ( booster , opts ++ [ serialize: :weights , to: :buffer ] )
@@ -538,6 +584,7 @@ defmodule EXGBoost do
538
584
Read a model's trained parameters from a file and return the Booster.
539
585
"""
540
586
@ spec read_weights ( String . t ( ) ) :: EXGBoost.Booster . t ( )
587
+ @ doc type: :serialization
541
588
def read_weights ( path ) do
542
589
EXGBoost.Booster . load ( path , deserialize: :weights )
543
590
end
@@ -546,7 +593,30 @@ defmodule EXGBoost do
546
593
Read a model's trained parameters from a buffer and return the Booster.
547
594
"""
548
595
@ spec load_weights ( binary ( ) ) :: EXGBoost.Booster . t ( )
596
+ @ doc type: :serialization
549
597
def load_weights ( buffer ) do
550
598
EXGBoost.Booster . load ( buffer , deserialize: :weights , from: :buffer )
551
599
end
600
+
601
+ @ doc """
602
+ Plot a tree from a Booster model and save it to a file.
603
+
604
+ ## Options
605
+ * `:format` - the format to export the graphic as, must be either of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension.
606
+ * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass local_npm_prefix: "assets". By default the npm packages are searched for in the current directory and globally.
607
+ * `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec.
608
+ * `:opts` - additional options to pass to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information.
609
+ """
610
+ @ doc type: :plotting
611
+ def plot_tree ( booster , opts \\ [ ] ) do
612
+ { path , opts } = Keyword . pop ( opts , :path )
613
+ { save_opts , opts } = Keyword . split ( opts , [ :format , :local_npm_prefix ] )
614
+ vega = Plotting . plot ( booster , opts )
615
+
616
+ if path != nil do
617
+ VegaLite.Export . save! ( vega , path , save_opts )
618
+ else
619
+ vega
620
+ end
621
+ end
552
622
end
0 commit comments