This repository has been archived by the owner on Feb 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlaunch_decision_support.py
152 lines (128 loc) · 4.89 KB
/
launch_decision_support.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python3
"""
MSFT Bonsai, decision support deployment template
Copyright 2020 Microsoft
Usage:
streamlit run launch_decision_support.py
"""
from typing import Dict, List, Union
import streamlit as st
import base64, requests, argparse
import pandas as pd
import SessionState # from https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
from exported_brain_interface import ExportedBrainPredictor
def get_state_action_list(get_request: Dict):
    """Pull the required state names and action names out of a brain schema.

    in: parsed JSON schema (as fetched from the brain's ``/validation.json``),
        holding ``["state"]["required"]`` and ``["action"]["required"]``
    out: tuple of (state names, action names)
    raises: KeyError if either section or its "required" entry is missing
    """
    # Look up "state" before "action" so a malformed schema fails on the same key.
    return tuple(get_request[section]["required"] for section in ("state", "action"))
def initialize_brain_interface(exported_brain_url="http://localhost:5000", timeout=10):
    """Initializes an interface to an exported brain.

    in: exported_brain_url: url of the running exported brain docker container
        timeout: seconds to wait for the validation.json request before giving up
    out: brain_interface, list of state names, list of action names
    raises: requests.RequestException (connection error, timeout, or a non-2xx
            HTTP status) if the brain schema cannot be fetched
    """
    brain = ExportedBrainPredictor(predictor_url=exported_brain_url)
    # Strip a trailing slash so "http://host:5000/" does not yield "//validation.json".
    validation_url = exported_brain_url.rstrip("/") + "/validation.json"
    # Without a timeout, requests waits indefinitely on an unresponsive container.
    response = requests.get(validation_url, timeout=timeout)
    # Surface HTTP errors directly rather than as a JSON/KeyError further down.
    response.raise_for_status()
    state_list, action_list = get_state_action_list(response.json())
    return brain, state_list, action_list
def get_table_download_link(df, filename="bonsai_state_action.csv"):
    """Generates a link allowing the data in a given pandas dataframe to be downloaded.

    in: df: dataframe to export (its index is not written)
        filename: file name the browser proposes for the download
    out: href string, an HTML <a> tag embedding the CSV as a base64 data URI
    """
    from html import escape

    csv_text = df.to_csv(index=False)
    # to_csv returns str: encode to bytes for base64, then decode the base64
    # output back to str so it can be embedded in the href attribute.
    b64 = base64.b64encode(csv_text.encode("utf-8")).decode("ascii")
    # "text/csv" is the registered MIME type for CSV (RFC 4180); "file/csv" is
    # not a valid media type.
    return (
        f'<a href="data:text/csv;base64,{b64}" '
        f'download="{escape(filename, quote=True)}">Download .csv file</a>'
    )
def main(exported_brain_url=None):
    """Render the Bonsai decision support page and query the exported brain.

    in: exported_brain_url: URL of the running exported brain. When None
        (default) it falls back to the module-level ``args.exported_brain_url``
        parsed in the ``__main__`` block.
    out: None; renders Streamlit sidebar and main-page elements
    """
    if exported_brain_url is None:
        exported_brain_url = args.exported_brain_url

    # Streamlit only accepts set_page_config before any other Streamlit
    # command, so configure the page first.
    st.set_page_config(
        page_title="Bonsai deployment",
        page_icon="./images/msft_icon.png",
        layout="centered",
        initial_sidebar_state="auto",
    )

    # Initialize exported brain interface
    brain, state_list, action_list = initialize_brain_interface(
        exported_brain_url=exported_brain_url
    )

    # Define state and action schema: one empty column per state/action name.
    state = {name: [] for name in state_list}
    action = {name: [] for name in action_list}
    state_action_log = dict(state, **action)

    # Initialize iteration count and state_action_log_df; SessionState keeps
    # these values on a per-user basis.
    session_state = SessionState.get(
        iteration_count=0, state_action_log_df=pd.DataFrame(state_action_log)
    )

    # ---------
    # Sidebar
    # ---------
    st.sidebar.markdown("Brain running at {}".format(exported_brain_url))
    st.sidebar.markdown("### Usage")
    st.sidebar.markdown(
        "- Enter value of each state below then click `Get Brain action` to obtain Brain's actions"
    )
    st.sidebar.markdown("- To reset chart click `Reset`")
    st.sidebar.markdown(" ")
    st.sidebar.markdown("### Brain states")
    # Replace each placeholder list with the number the user entered.
    for key in state:
        state[key] = st.sidebar.number_input("{} ".format(key))

    st.sidebar.markdown("## Reset")
    reset_button = st.sidebar.button("Reset")
    st.sidebar.markdown("## Iteration count")
    if reset_button:
        session_state.iteration_count = 0
        session_state.state_action_log_df = pd.DataFrame(state_action_log)
    # Written after a possible reset so the displayed count is current.
    st.sidebar.write("Iteration count: {}".format(session_state.iteration_count))

    # -----------
    # Main page
    # -----------
    st.image("./images/bonsai-logomark.png", width=70)
    st.markdown(
        "# Bonsai decision support\n"
        "A template to run a decision support interface with a locally running exported Brain.\n"
        "More info on [preview.bons.ai](https://preview.bons.ai)"
    )

    get_brain_action = st.button(label="Get Brain action")
    st.markdown("#### Brain Actions ")
    st.markdown(" ")
    if get_brain_action:
        with st.spinner("Wait for it..."):
            session_state.iteration_count += 1
            action = brain.get_action(state)
            new_row = pd.DataFrame([dict(state, **action)])
            log_df = session_state.state_action_log_df
            # DataFrame.append was removed in pandas 2.0; use pd.concat. Skip the
            # concat when the log is empty so column dtypes come from the new row.
            session_state.state_action_log_df = (
                new_row
                if log_df.empty
                else pd.concat([log_df, new_row], ignore_index=True)
            )
        for key, value in action.items():
            st.markdown("{}: {}".format(key, value))

    st.markdown("#### State and action table")
    st.table(session_state.state_action_log_df)
    st.markdown(
        get_table_download_link(session_state.state_action_log_df),
        unsafe_allow_html=True,
    )
    st.markdown("#### State and action chart vs iteration count ")
    st.markdown(" ")
    st.line_chart(session_state.state_action_log_df)
if __name__ == "__main__":
    # Command-line entry point. The brain URL is an optional positional
    # argument; `args` must stay a module-level name because main() reads
    # args.exported_brain_url.
    cli = argparse.ArgumentParser(
        description="launch a decision support interface to an exported brain"
    )
    cli.add_argument(
        "exported_brain_url",
        nargs="?",
        type=str,
        default="http://localhost:5000",
        help="brain url of a running exported brain",
    )
    args = cli.parse_args()
    print(f"Connecting to exported brain running at: {args.exported_brain_url}")
    main()