Skip to content

Commit bb96289

Browse files
authored
implement nested benchmarks with snapshots (#280)
BREAKING: added benchmark obj to Skill's execute_tool function
1 parent addfa07 commit bb96289

File tree

44 files changed

+3269
-2315
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+3269
-2315
lines changed

api/commands.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
from typing import Literal, Optional
22
from pydantic import BaseModel
3-
from api.enums import CommandTag, KeyboardRecordingType, LogSource, LogType, RecordingDevice, ToastType
4-
from api.interface import CommandActionConfig
3+
from api.enums import (
4+
CommandTag,
5+
KeyboardRecordingType,
6+
LogSource,
7+
LogType,
8+
RecordingDevice,
9+
ToastType,
10+
)
11+
from api.interface import CommandActionConfig, BenchmarkResult
512

613

714
# We use this Marker base class for reflection to "iterate all commands"
@@ -28,10 +35,11 @@ class RecordKeyboardActionsCommand(WebSocketCommandModel):
2835
command: Literal["record_keyboard_actions"] = "record_keyboard_actions"
2936
recording_type: KeyboardRecordingType
3037

38+
3139
class RecordMouseActionsCommand(WebSocketCommandModel):
3240
command: Literal["record_mouse_actions"] = "record_mouse_actions"
3341

34-
42+
3543
class RecordJoystickActionsCommand(WebSocketCommandModel):
3644
command: Literal["record_joystick_actions"] = "record_joystick_actions"
3745

@@ -55,6 +63,7 @@ class LogCommand(WebSocketCommandModel):
5563
tag: Optional[CommandTag] = None
5664
skill_name: Optional[str] = None
5765
additional_data: Optional[dict] = None
66+
benchmark_result: Optional[BenchmarkResult] = None
5867

5968

6069
class PromptSecretCommand(WebSocketCommandModel):

api/interface.py

+10
Original file line numberDiff line numberDiff line change
@@ -742,3 +742,13 @@ class SettingsConfig(BaseModel):
742742
wingman_pro: WingmanProSettings
743743
xvasynth: XVASynthSettings
744744
debug_mode: bool = False
745+
746+
747+
class BenchmarkResult(BaseModel):
748+
label: str
749+
execution_time_ms: float
750+
formatted_execution_time: str
751+
snapshots: Optional[list["BenchmarkResult"]] = None
752+
753+
754+
BenchmarkResult.model_rebuild()

main.py

+18
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from fastapi.middleware.cors import CORSMiddleware
1515
from fastapi.openapi.utils import get_openapi
1616
from api.commands import WebSocketCommandModel
17+
from api.interface import BenchmarkResult
1718
from api.enums import ENUM_TYPES, LogType, WingmanInitializationErrorType
1819
import keyboard.keyboard as keyboard
1920
from services.command_handler import CommandHandler
@@ -256,6 +257,23 @@ async def ping():
256257
return "Ok" if core.is_started else "Starting"
257258

258259

260+
# required to generate API specs for class BenchmarkResult that is only used internally
261+
@app.get("/dummy-benchmark", tags=["main"], response_model=BenchmarkResult)
262+
async def get_dummy_benchmark():
263+
return BenchmarkResult(
264+
label="Sample Benchmark",
265+
execution_time_ms=150.0,
266+
formatted_execution_time=0.15,
267+
snapshots=[
268+
BenchmarkResult(
269+
label="Sub Benchmark",
270+
execution_time_ms=75.0,
271+
formatted_execution_time=0.075,
272+
)
273+
],
274+
)
275+
276+
259277
async def async_main(host: str, port: int, sidecar: bool):
260278
await core.config_service.migrate_configs(system_manager)
261279
await core.config_service.load_config()

services/benchmark.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import time
2+
from api.enums import LogType
3+
from api.interface import BenchmarkResult
4+
from services.printr import Printr
5+
6+
7+
class Benchmark:
8+
def __init__(self, label: str):
9+
self.label = label
10+
self.snapshot_label: str = None
11+
self.start_time = time.perf_counter()
12+
self.snapshot_start_time: float = None
13+
self.snapshots: list[BenchmarkResult] = []
14+
self.printr = Printr()
15+
16+
def finish(self):
17+
if self.snapshot_label or self.snapshot_start_time:
18+
self.finish_snapshot()
19+
self.printr.print(
20+
f"Snapshot benchmark '{self.snapshot_label}' was still running when finishing '{self.label}'.",
21+
color=LogType.WARNING,
22+
server_only=True,
23+
)
24+
result = self._create_benchmark_result(self.label, self.start_time)
25+
if len(self.snapshots) > 0:
26+
result.snapshots = self.snapshots
27+
return result
28+
29+
def start_snapshot(self, label: str):
30+
if self.snapshot_label or self.snapshot_start_time:
31+
self.finish_snapshot()
32+
self.printr.print(
33+
f"Snapshot benchmark '{self.snapshot_label}' was still running when starting '{label}'.",
34+
color=LogType.WARNING,
35+
server_only=True,
36+
)
37+
self.snapshot_label = label
38+
self.snapshot_start_time = time.perf_counter()
39+
40+
def finish_snapshot(self):
41+
try:
42+
result = self._create_benchmark_result(
43+
label=self.snapshot_label, start_time=self.snapshot_start_time
44+
)
45+
self.snapshots.append(result)
46+
except Exception:
47+
pass
48+
self.snapshot_label = None
49+
self.snapshot_start_time = None
50+
51+
def _create_benchmark_result(self, label: str, start_time: float):
52+
end_time = time.perf_counter()
53+
execution_time = (end_time - start_time) * 1000 # Convert to milliseconds
54+
if execution_time >= 1000:
55+
formatted_execution_time = f"{execution_time/1000:.1f}s"
56+
else:
57+
formatted_execution_time = f"{int(execution_time)}ms"
58+
59+
return BenchmarkResult(
60+
label=label,
61+
execution_time_ms=execution_time,
62+
formatted_execution_time=formatted_execution_time,
63+
)

services/printr.py

+35-10
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,30 @@
77
from os import path
88
from api.commands import LogCommand, ToastCommand
99
from api.enums import CommandTag, LogSource, LogType, ToastType
10+
from api.interface import BenchmarkResult
1011
from services.file import get_writable_dir
1112
from services.websocket_user import WebSocketUser
1213

14+
1315
class StreamToLogger:
1416
def __init__(self, logger, log_level=logging.INFO, stream=sys.stdout):
1517
self.logger = logger
1618
self.log_level = log_level
17-
self.linebuf = ''
19+
self.linebuf = ""
1820
self.stream = stream
1921

2022
def write(self, buf):
2123
for line in buf.rstrip().splitlines():
2224
self.logger.log(self.log_level, line.rstrip())
23-
self.stream.write(line + '\n')
25+
self.stream.write(line + "\n")
2426

2527
def flush(self):
2628
self.stream.flush()
2729

2830
def isatty(self):
2931
return False
3032

33+
3134
class Printr(WebSocketUser):
3235
"""Singleton"""
3336

@@ -50,7 +53,7 @@ class Printr(WebSocketUser):
5053
def __new__(cls):
5154
if cls._instance is None:
5255
cls._instance = super(Printr, cls).__new__(cls)
53-
cls._instance.logger = logging.getLogger('file_logger')
56+
cls._instance.logger = logging.getLogger("file_logger")
5457
cls._instance.logger.setLevel(logging.INFO)
5558

5659
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -60,16 +63,18 @@ def __new__(cls):
6063
path.join(get_writable_dir("logs"), f"wingman-core.{timestamp}.log")
6164
)
6265
fh.setLevel(logging.DEBUG)
63-
file_formatter = Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
66+
file_formatter = Formatter(
67+
"%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
68+
)
6469
fh.setFormatter(file_formatter)
6570
cls._instance.logger.addHandler(fh)
6671

6772
# Console logger with color
68-
cls._instance.console_logger = logging.getLogger('console_logger')
73+
cls._instance.console_logger = logging.getLogger("console_logger")
6974
cls._instance.console_logger.setLevel(logging.INFO)
7075
ch = logging.StreamHandler()
7176
ch.setLevel(logging.INFO)
72-
console_formatter = Formatter('%(message)s')
77+
console_formatter = Formatter("%(message)s")
7378
ch.setFormatter(console_formatter)
7479
cls._instance.console_logger.addHandler(ch)
7580

@@ -88,6 +93,7 @@ async def __send_to_gui(
8893
command_tag: CommandTag = None,
8994
skill_name: str = "",
9095
additional_data: dict = None,
96+
benchmark_result: BenchmarkResult = None,
9197
):
9298
if self._connection_manager is None:
9399
raise ValueError("connection_manager has not been set.")
@@ -102,10 +108,13 @@ async def __send_to_gui(
102108
if current_frame is not None:
103109
while current_frame:
104110
# Check if the caller is a method of a class
105-
if 'self' in current_frame.f_locals:
106-
caller_instance = current_frame.f_locals['self']
111+
if "self" in current_frame.f_locals:
112+
caller_instance = current_frame.f_locals["self"]
107113
caller_instance_name = caller_instance.__class__.__name__
108-
if caller_instance_name == "Wingman" or caller_instance_name == "OpenAiWingman":
114+
if (
115+
caller_instance_name == "Wingman"
116+
or caller_instance_name == "OpenAiWingman"
117+
):
109118
wingman_name = caller_instance.name
110119
break
111120
# Move to the previous frame in the call stack
@@ -121,6 +130,7 @@ async def __send_to_gui(
121130
skill_name=skill_name,
122131
additional_data=additional_data,
123132
wingman_name=wingman_name,
133+
benchmark_result=benchmark_result,
124134
)
125135
)
126136

@@ -163,9 +173,23 @@ async def print_async(
163173
command_tag: CommandTag = None,
164174
skill_name: str = "",
165175
additional_data: dict = None,
176+
benchmark_result: BenchmarkResult = None,
166177
):
167178
# print to server (terminal)
168-
self.print_colored(text, color=self.get_terminal_color(color))
179+
self.print_colored(
180+
(
181+
text
182+
if not benchmark_result
183+
else f"{text} ({benchmark_result.formatted_execution_time})"
184+
),
185+
color=self.get_terminal_color(color),
186+
)
187+
if benchmark_result and benchmark_result.snapshots:
188+
for snapshot in benchmark_result.snapshots:
189+
self.print_colored(
190+
f" - {snapshot.label}: {snapshot.formatted_execution_time}",
191+
color=self.get_terminal_color(color),
192+
)
169193

170194
if not server_only and self._connection_manager is not None:
171195
await self.__send_to_gui(
@@ -177,6 +201,7 @@ async def print_async(
177201
command_tag=command_tag,
178202
skill_name=skill_name,
179203
additional_data=additional_data,
204+
benchmark_result=benchmark_result,
180205
)
181206

182207
def toast(self, text: str):

0 commit comments

Comments
 (0)