Skip to content

Commit 3102ddc

Browse files
authored
Add settings for model averaging in the frontend (#620)
1 parent 21ea4c4 commit 3102ddc

File tree

25 files changed

+116
-8
lines changed

25 files changed

+116
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Released changes are shown in the
1111
## [Not released]
1212

1313
### Added
14+
* Prediction after BMA can now be displayed in the app.
1415

1516
### Changed
1617

azimuth/modules/base_classes/aggregation_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class FilterableModule(AggregationModule[ConfigScope], ExpirableMixin, ABC):
3131
"""Filterable Module are affected by filters in mod options."""
3232

3333
required_mod_options = {"pipeline_index"}
34-
optional_mod_options = {"filters", "without_postprocessing"}
34+
optional_mod_options = {"filters", "without_postprocessing", "use_bma"}
3535

3636
def get_dataset_split(self, name: DatasetSplitName = None) -> Dataset:
3737
"""Get the specified dataset_split, filtered according to mod_options.

azimuth/modules/model_contracts/text_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ def post_process(self, batch: Dataset) -> List[PredictionResponse]:
109109
mod_options=ModuleOptions(
110110
model_contract_method_name=SupportedMethod.Predictions,
111111
pipeline_index=self.mod_options.pipeline_index,
112+
use_bma=self.mod_options.use_bma,
112113
indices=cast(List[int], batch[DatasetColumn.row_idx]),
113114
),
114115
)

azimuth/modules/model_performance/confidence_binning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ class ConfidenceBinIndexModule(DatasetResultModule[ModelContractConfig]):
114114
"""Return confidence bin indices for the selected dataset split."""
115115

116116
required_mod_options = {"pipeline_index"}
117-
optional_mod_options = DatasetResultModule.optional_mod_options | {"threshold"}
117+
optional_mod_options = DatasetResultModule.optional_mod_options | {"threshold", "use_bma"}
118118

119119
def compute_on_dataset_split(self) -> List[int]: # type: ignore
120120
"""Get the bin indices for each utterance in the dataset split.

azimuth/modules/model_performance/outcome_count.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class OutcomeCountPerThresholdModule(AggregationModule[ModelContractConfig]):
178178
"""Compute the outcome count per threshold."""
179179

180180
required_mod_options = {"pipeline_index"}
181-
optional_mod_options = {"x_ticks_count"}
181+
optional_mod_options = {"x_ticks_count", "use_bma"}
182182

183183
def compute_on_dataset_split(self) -> List[OutcomeCountPerThresholdResponse]: # type: ignore
184184
if not postprocessing_editable(self.config, self.mod_options.pipeline_index):
@@ -199,6 +199,7 @@ def compute_on_dataset_split(self) -> List[OutcomeCountPerThresholdResponse]: #
199199
# Convert to float instead of numpy.float64
200200
threshold=float(th),
201201
pipeline_index=self.mod_options.pipeline_index,
202+
use_bma=self.mod_options.use_bma,
202203
),
203204
)
204205
outcomes = outcomes_mod.compute_on_dataset_split()

azimuth/modules/model_performance/outcomes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class OutcomesModule(DatasetResultModule[ModelContractConfig]):
2020
"""Computes the outcome for each utterance in the dataset split."""
2121

2222
required_mod_options = {"pipeline_index"}
23-
optional_mod_options = DatasetResultModule.optional_mod_options | {"threshold"}
23+
optional_mod_options = DatasetResultModule.optional_mod_options | {"threshold", "use_bma"}
2424

2525
def _get_predictions(self, without_postprocessing: bool) -> ndarray:
2626
mod_options = self.mod_options.copy(deep=True)

azimuth/routers/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def get_dataset_info(
104104
eval=eval_dm and eval_dm.num_rows, train=training_dm and training_dm.num_rows
105105
),
106106
similarity_available=similarity_available(config),
107+
model_averaging_available=config.uncertainty is not None
108+
and config.uncertainty.iterations > 1,
107109
postprocessing_editable=None
108110
if config.pipelines is None
109111
else [postprocessing_editable(config, idx) for idx in range(len(config.pipelines))],

azimuth/routers/model_performance/confidence_histogram.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ def get_confidence_histogram(
3131
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
3232
pipeline_index: int = Depends(require_pipeline_index),
3333
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
34+
use_bma: bool = Query(False, title="Use BMA"),
3435
) -> ConfidenceHistogramResponse:
3536
mod_options = ModuleOptions(
3637
filters=named_filters.to_dataset_filters(dataset_split_manager.get_class_names()),
3738
pipeline_index=pipeline_index,
39+
use_bma=use_bma,
3840
without_postprocessing=without_postprocessing,
3941
)
4042

azimuth/routers/model_performance/confusion_matrix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ def get_confusion_matrix(
3131
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
3232
pipeline_index: int = Depends(require_pipeline_index),
3333
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
34+
use_bma: bool = Query(False, title="Use BMA"),
3435
normalize: bool = Query(True, title="Normalize"),
3536
reorder_classes: bool = Query(True, title="Reorder Classes"),
3637
) -> ConfusionMatrixResponse:
3738
mod_options = ModuleOptions(
3839
filters=named_filters.to_dataset_filters(dataset_split_manager.get_class_names()),
3940
pipeline_index=pipeline_index,
41+
use_bma=use_bma,
4042
without_postprocessing=without_postprocessing,
4143
cf_normalize=normalize,
4244
cf_reorder_classes=reorder_classes,

azimuth/routers/model_performance/metrics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ def get_metrics(
3939
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
4040
pipeline_index: int = Depends(require_pipeline_index),
4141
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
42+
use_bma: bool = Query(False, title="Use BMA"),
4243
) -> MetricsAPIResponse:
4344
mod_options = ModuleOptions(
4445
filters=named_filters.to_dataset_filters(dataset_split_manager.get_class_names()),
4546
pipeline_index=pipeline_index,
47+
use_bma=use_bma,
4648
without_postprocessing=without_postprocessing,
4749
)
4850

0 commit comments

Comments
 (0)