-
Notifications
You must be signed in to change notification settings - Fork 305
Expand file tree
/
Copy path tuning_param.py
More file actions
133 lines (105 loc) · 4.53 KB
/
tuning_param.py
File metadata and controls
133 lines (105 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The tunable parameters module."""
import typing
from enum import Enum, auto
from typing import Any
from pydantic import BaseModel
from neural_compressor.common import logger
class ParamLevel(Enum):
    """Granularity at which a tuning parameter is applied.

    Attributes:
        OP_LEVEL: The tuning parameter applies to individual operations.
        OP_TYPE_LEVEL: The tuning parameter applies to all operations of a given type.
        MODEL_LEVEL: The tuning parameter applies to the model as a whole.
    """

    # `auto()` numbers the members 1, 2, 3 in declaration order.
    OP_LEVEL = auto()
    OP_TYPE_LEVEL = auto()
    MODEL_LEVEL = auto()
class TuningParam:
    """Define the tunable parameter for the algorithm.

    Example:
        class FakeAlgoConfig(BaseConfig):
            '''Fake algo config.'''

            params_list = [
                ...
                # For simple tunable types, like a list of int, giving
                # the param name is enough. `BaseConfig` class will
                # create the `TuningParam` implicitly.
                "simple_attr",
                # For complex tunable types, like a list of lists,
                # developers need to create the `TuningParam` explicitly.
                TuningParam("complex_attr", tunable_type=List[List[str]]),
                # The default parameter level is `ParamLevel.OP_LEVEL`.
                # If the parameter is at a different level, developers need
                # to specify it explicitly.
                TuningParam("model_attr", level=ParamLevel.MODEL_LEVEL),
                ...
            ]

    # TODO: more examples to explain the usage of `TuningParam`.
    """

    def __init__(
        self,
        name: str,
        default_val: Any = None,
        tunable_type: Any = None,
        options: Any = None,
        level: ParamLevel = ParamLevel.OP_LEVEL,
    ) -> None:
        """Initialize a TuningParam object.

        Args:
            name (str): The name of the tuning parameter.
            default_val (Any, optional): The default value of the tuning parameter. Defaults to None.
            tunable_type (optional): The type hint that a tunable value must satisfy
                (e.g. ``List[List[str]]``); required by `is_tunable`. Defaults to None.
            options (optional): The available options for the tuning parameter. Defaults to None.
            level (ParamLevel, optional): The level of the tuning parameter. Defaults to ParamLevel.OP_LEVEL.
        """
        self.name = name
        self.default_val = default_val
        self.tunable_type = tunable_type
        self.options = options
        self.level = level

    @staticmethod
    def create_input_args_model(expect_args_type: Any):
        """Dynamically create an InputArgsModel based on the provided type hint.

        A new pydantic model class is built on every call, with a single
        field ``input_args`` annotated with ``expect_args_type``.

        Args:
            expect_args_type (Any): The user-provided type hint for input_args.

        Returns:
            The dynamically created InputArgsModel class.
        """

        class DynamicInputArgsModel(BaseModel):
            """Pydantic model for validating dynamic input arguments."""

            # The annotation is evaluated when this class body runs, so it
            # captures the `expect_args_type` passed to this call.
            input_args: expect_args_type

        return DynamicInputArgsModel

    def is_tunable(self, value: Any) -> bool:
        """Checks if the given value is tunable based on the specified tunable type.

        Args:
            value (Any): The value to be checked for tunability.

        Returns:
            bool: True if pydantic accepts ``value`` for ``self.tunable_type``,
                False otherwise (the failure reason is logged at debug level).

        Raises:
            AssertionError: If ``self.tunable_type`` is not a parameterized type
                hint. This check is skipped when Python runs with ``-O``.
        """
        # Use `Pydantic` to validate the input_args.
        # TODO: refine the implementation in further.
        # `typing.get_origin` is the public way to recognize a parameterized hint.
        # Unlike the private `typing._GenericAlias`, it also accepts builtin
        # generics such as `list[int]` (PEP 585) and `int | str` unions (PEP 604).
        assert typing.get_origin(self.tunable_type) is not None, (
            f"Expected a type hint, got {self.tunable_type} instead."
        )
        try:
            # Model creation sits inside the `try` too. A hint for which building
            # the model raises is reported as "not tunable" instead of propagating.
            input_args_model = TuningParam.create_input_args_model(self.tunable_type)
            # Only whether validation succeeds matters; the instance is discarded.
            input_args_model(input_args=value)
        except Exception as e:  # best-effort check: any failure means "not tunable"
            logger.debug(f"Failed to validate the input_args: {e}")
            return False
        return True

    def __str__(self) -> str:
        """Return the name of the tuning parameter."""
        return self.name