Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion onnxmltools/convert/lightgbm/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Int64Type,
)

from lightgbm import LGBMClassifier, LGBMRegressor
from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker

lightgbm_classifier_list = [LGBMClassifier]

Expand All @@ -21,6 +21,7 @@
lightgbm_operator_name_map = {
LGBMClassifier: "LgbmClassifier",
LGBMRegressor: "LgbmRegressor",
LGBMRanker: "LgbmRanker",
}


Expand All @@ -35,6 +36,8 @@ def __init__(self, booster):
elif self.objective_.startswith("multiclass"):
self.operator_name = "LgbmClassifier"
self.classes_ = self._generate_classes(booster)
elif self.objective_.startswith("lambdarank"):
self.operator_name = "LgbmRanker"
elif self.objective_.startswith(
("regression", "poisson", "gamma", "quantile", "huber", "tweedie")
):
Expand Down
5 changes: 5 additions & 0 deletions onnxmltools/convert/lightgbm/operator_converters/LightGbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,10 @@ def convert_lightgbm(scope, operator, container):
# so we need to add an 'Exp' post transform node to the model
attrs["post_transform"] = "NONE"
post_transform = "Exp"
elif gbm_text["objective"].startswith("lambdarank"):
n_classes = 1 # Ranker has only one output variable
attrs["post_transform"] = "NONE"
attrs["n_targets"] = n_classes
else:
raise RuntimeError(
"LightGBM objective should be cleaned already not '{}'.".format(
Expand Down Expand Up @@ -1026,3 +1030,4 @@ def convert_lgbm_zipmap(scope, operator, container):
register_converter("LgbmClassifier", convert_lightgbm)
register_converter("LgbmRegressor", convert_lightgbm)
register_converter("LgbmZipMap", convert_lgbm_zipmap)
register_converter("LgbmRanker", convert_lightgbm)
6 changes: 6 additions & 0 deletions onnxmltools/convert/lightgbm/shape_calculators/Ranker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from ...common._registration import register_shape_calculator
from ...common.shape_calculator import calculate_linear_regressor_output_shapes

register_shape_calculator("LgbmRanker", calculate_linear_regressor_output_shapes)
1 change: 1 addition & 0 deletions onnxmltools/convert/lightgbm/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
# To register shape calculators for lightgbm operators, import associated modules here.
from . import Classifier
from . import Regressor
from . import Ranker