Skip to content

Commit 8001777

Browse files
committed
Re-format onnxmltools/convert/lightgbm after merge
Signed-off-by: Jett Jackson <[email protected]>
1 parent 20ec503 commit 8001777

File tree

4 files changed

+16
-14
lines changed

4 files changed

+16
-14
lines changed

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
# Associate scikit-learn types with our operator names.
1919
# If two scikit-learn models share a single name, it means their
2020
# are equivalent in terms of conversion.
21-
lightgbm_operator_name_map = {LGBMClassifier: 'LgbmClassifier',
22-
LGBMRegressor: 'LgbmRegressor',
23-
LGBMRanker: 'LgbmRanker'}
21+
lightgbm_operator_name_map = {
22+
LGBMClassifier: "LgbmClassifier",
23+
LGBMRegressor: "LgbmRegressor",
24+
LGBMRanker: "LgbmRanker",
25+
}
2426

2527

2628
class WrappedBooster:
@@ -34,8 +36,8 @@ def __init__(self, booster):
3436
elif self.objective_.startswith("multiclass"):
3537
self.operator_name = "LgbmClassifier"
3638
self.classes_ = self._generate_classes(booster)
37-
elif self.objective_.startswith('lambdarank'):
38-
self.operator_name = 'LgbmRanker'
39+
elif self.objective_.startswith("lambdarank"):
40+
self.operator_name = "LgbmRanker"
3941
elif self.objective_.startswith(
4042
("regression", "poisson", "gamma", "quantile", "huber", "tweedie")
4143
):

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -566,10 +566,10 @@ def convert_lightgbm(scope, operator, container):
566566
# so we need to add an 'Exp' post transform node to the model
567567
attrs["post_transform"] = "NONE"
568568
post_transform = "Exp"
569-
elif gbm_text['objective'].startswith('lambdarank'):
569+
elif gbm_text["objective"].startswith("lambdarank"):
570570
n_classes = 1 # Ranker has only one output variable
571-
attrs['post_transform'] = 'NONE'
572-
attrs['n_targets'] = n_classes
571+
attrs["post_transform"] = "NONE"
572+
attrs["n_targets"] = n_classes
573573
else:
574574
raise RuntimeError(
575575
"LightGBM objective should be cleaned already not '{}'.".format(
@@ -1027,7 +1027,7 @@ def convert_lgbm_zipmap(scope, operator, container):
10271027
)
10281028

10291029

1030-
register_converter('LgbmClassifier', convert_lightgbm)
1031-
register_converter('LgbmRegressor', convert_lightgbm)
1032-
register_converter('LgbmZipMap', convert_lgbm_zipmap)
1033-
register_converter('LgbmRanker', convert_lightgbm)
1030+
register_converter("LgbmClassifier", convert_lightgbm)
1031+
register_converter("LgbmRegressor", convert_lightgbm)
1032+
register_converter("LgbmZipMap", convert_lgbm_zipmap)
1033+
register_converter("LgbmRanker", convert_lightgbm)

onnxmltools/convert/lightgbm/shape_calculators/Ranker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from ...common._registration import register_shape_calculator
44
from ...common.shape_calculator import calculate_linear_regressor_output_shapes
55

6-
register_shape_calculator('LgbmRanker', calculate_linear_regressor_output_shapes)
6+
register_shape_calculator("LgbmRanker", calculate_linear_regressor_output_shapes)

onnxmltools/convert/lightgbm/shape_calculators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
# To register shape calculators for lightgbm operators, import associated modules here.
44
from . import Classifier
55
from . import Regressor
6-
from . import Ranker
6+
from . import Ranker

0 commit comments

Comments
 (0)