Skip to content

Commit 89c1918

Browse files
SimonBartelsSimon Bartels
andauthored
Feature mlflow observer (#290)
* adds an mlflow observer * lints the file * lints file with pre-commit * allows the user to set the mlflow tracking uri --------- Co-authored-by: Simon Bartels <[email protected]>
1 parent 8320f69 commit 89c1918

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/poli/core/util/observers/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from pathlib import Path
2+
3+
import mlflow
4+
import numpy as np
5+
6+
from poli.core.black_box_information import BlackBoxInformation
7+
from poli.core.util.abstract_observer import AbstractObserver
8+
9+
TRACKING_URI = "tracking_uri"
10+
OBJECTIVE = "OBJECTIVE"
11+
SEQUENCE = "SEQUENCE"
12+
SEED = "SEED"
13+
14+
15+
class MLFlowObserver(AbstractObserver):
16+
"""
17+
This observer uses mlflow as a backend.
18+
"""
19+
20+
def __init__(self, tracking_uri: Path = None):
21+
self.step = 0
22+
if tracking_uri is not None:
23+
mlflow.set_tracking_uri(tracking_uri)
24+
25+
def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
26+
for n in range(y.shape[0]):
27+
self.step += 1
28+
mlflow.log_metrics(
29+
{OBJECTIVE + str(i): y[n, i].item() for i in range(y.shape[1])},
30+
step=self.step,
31+
)
32+
# with mlflow it's unfortunately not so easy to log sequences
33+
mlflow.log_param(str(self.step) + SEQUENCE, x[n, ...])
34+
35+
def log(self, algorithm_info: dict):
36+
mlflow.log_metrics(algorithm_info, step=self.step)
37+
38+
def initialize_observer(
39+
self,
40+
problem_setup_info: BlackBoxInformation,
41+
caller_info: dict,
42+
seed: int,
43+
) -> object:
44+
tracking_uri = caller_info.pop(TRACKING_URI, None)
45+
if tracking_uri is not None:
46+
mlflow.set_tracking_uri(tracking_uri)
47+
48+
experiment = mlflow.set_experiment(
49+
experiment_name=problem_setup_info.get_problem_name()
50+
)
51+
run = mlflow.start_run(experiment_id=experiment.experiment_id)
52+
mlflow.set_tag(SEED, str(seed))
53+
mlflow.set_tags(caller_info)
54+
return run.info.run_id
55+
56+
def finish(self) -> None:
57+
mlflow.end_run()

0 commit comments

Comments
 (0)