1+ import ast
2+ import inspect
3+ import sys
4+ import time
5+ import traceback
6+ from collections import defaultdict
7+ import textwrap
8+ import numpy as np
9+ from amadeusgpt .analysis_objects .event import Event
10+ from amadeusgpt .logger import AmadeusLogger
11+ from IPython .display import Markdown , Video , display , HTML
12+
13+ def filter_kwargs_for_function (func , kwargs ):
14+ sig = inspect .signature (func )
15+ return {k : v for k , v in kwargs .items () if k in sig .parameters }
16+
17+ def timer_decorator (func ):
18+ def wrapper (* args , ** kwargs ):
19+ start_time = time .time () # before calling the function
20+ result = func (* args , ** kwargs ) # call the function
21+ end_time = time .time () # after calling the function
22+ AmadeusLogger .debug (
23+ f"The function { func .__name__ } took { end_time - start_time } seconds to execute."
24+ )
25+ print (
26+ f"The function { func .__name__ } took { end_time - start_time } seconds to execute."
27+ )
28+ return result
29+ return wrapper
30+
31+ def parse_error_message_from_python ():
32+ exc_type , exc_value , exc_traceback = sys .exc_info ()
33+ traceback_str = "" .join (
34+ traceback .format_exception (exc_type , exc_value , exc_traceback )
35+ )
36+ return traceback_str
37+
38+ def validate_openai_api_key (key ):
39+ import openai
40+ openai .api_key = key
41+ try :
42+ openai .models .list ()
43+ return True
44+ except openai .AuthenticationError :
45+ return False
46+
47+ def flatten_tuple (t ):
48+ """
49+ Used to handle function returns
50+ """
51+ flattened = []
52+ for item in t :
53+ if isinstance (item , tuple ):
54+ flattened .extend (flatten_tuple (item ))
55+ else :
56+ flattened .append (item )
57+ return tuple (flattened )
58+
59+ def func2json (func ):
60+ if isinstance (func , str ):
61+ func_str = textwrap .dedent (func )
62+ parsed = ast .parse (func_str )
63+ func_def = parsed .body [0 ]
64+ func_name = func_def .name
65+ docstring = ast .get_docstring (func_def )
66+ if (
67+ func_def .body
68+ and isinstance (func_def .body [0 ], ast .Expr )
69+ and isinstance (func_def .body [0 ].value , (ast .Str , ast .Constant ))
70+ ):
71+ func_def .body .pop (0 )
72+ func_def .decorator_list = []
73+ if hasattr (ast , "unparse" ):
74+ source_without_docstring_or_decorators = ast .unparse (func_def )
75+ else :
76+ source_without_docstring_or_decorators = None
77+ return_annotation = "No return annotation"
78+ if func_def .returns :
79+ return_annotation = ast .unparse (func_def .returns )
80+ json_obj = {
81+ "name" : func_name ,
82+ "inputs" : "" ,
83+ "source_code" : source_without_docstring_or_decorators ,
84+ "docstring" : docstring ,
85+ "return" : return_annotation ,
86+ }
87+ return json_obj
88+ else :
89+ sig = inspect .signature (func )
90+ inputs = {name : str (param .annotation ) for name , param in sig .parameters .items ()}
91+ docstring = inspect .getdoc (func )
92+ if docstring :
93+ docstring = textwrap .dedent (docstring )
94+ full_source = inspect .getsource (func )
95+ parsed = ast .parse (textwrap .dedent (full_source ))
96+ func_def = parsed .body [0 ]
97+ if (
98+ func_def .body
99+ and isinstance (func_def .body [0 ], ast .Expr )
100+ and isinstance (func_def .body [0 ].value , (ast .Str , ast .Constant ))
101+ ):
102+ func_def .body .pop (0 )
103+ func_def .decorator_list = []
104+ if hasattr (ast , "unparse" ):
105+ source_without_docstring_or_decorators = ast .unparse (func_def )
106+ else :
107+ source_without_docstring_or_decorators = None
108+ json_obj = {
109+ "name" : func .__name__ ,
110+ "inputs" : inputs ,
111+ "source_code" : textwrap .dedent (source_without_docstring_or_decorators ),
112+ "docstring" : docstring ,
113+ "return" : str (sig .return_annotation ),
114+ }
115+ return json_obj
116+
117+ class QA_Message :
118+ def __init__ (self , query : str , video_file_paths : list [str ]):
119+ self .query = query
120+ self .video_file_paths = video_file_paths
121+ self .code = None
122+ self .chain_of_thought = None
123+ self .error_message = defaultdict (list )
124+ self .plots = defaultdict (list )
125+ self .out_videos = defaultdict (list )
126+ self .pose_video = defaultdict (list )
127+ self .function_rets = defaultdict (list )
128+ self .meta_info = {}
129+ def get_masks (self ) -> dict [str , np .ndarray ]:
130+ ret = {}
131+ function_rets = self .function_rets
132+ for video_path , rets in function_rets .items ():
133+ if isinstance (rets , list ) and len (rets ) > 0 and isinstance (rets [0 ], Event ):
134+ events = rets
135+ masks = []
136+ for event in events :
137+ masks .append (event .generate_mask ())
138+ ret [video_path ] = np .array (masks )
139+ else :
140+ ret [video_path ] = None
141+ return ret
142+ def serialize_qa_message (self ):
143+ return {
144+ "query" : self .query ,
145+ "video_file_paths" : self .video_file_paths ,
146+ "code" : self .code ,
147+ "chain_of_thought" : self .chain_of_thought ,
148+ "error_message" : self .error_message ,
149+ "plots" : None ,
150+ "out_videos" : self .out_videos ,
151+ "pose_video" : self .pose_video ,
152+ "function_rets" : self .function_rets ,
153+ "meta_info" : self .meta_info ,
154+ }
155+ def create_qa_message (query : str , video_file_paths : list [str ]) -> QA_Message :
156+ return QA_Message (query , video_file_paths )
157+ def parse_result (amadeus , qa_message , use_ipython = True , skip_code_execution = False ):
158+ if use_ipython :
159+ display (Markdown (qa_message .chain_of_thought ))
160+ else :
161+ print (qa_message .chain_of_thought )
162+ sandbox = amadeus .sandbox
163+ if not skip_code_execution :
164+ qa_message = sandbox .code_execution (qa_message )
165+ qa_message = sandbox .render_qa_message (qa_message )
166+ if len (qa_message .out_videos ) > 0 :
167+ print (f"videos generated to { qa_message .out_videos } " )
168+ print (
169+ "Open it with media player if it does not properly display in the notebook"
170+ )
171+ if use_ipython :
172+ if len (qa_message .out_videos ) > 0 :
173+ for identifier , event_videos in qa_message .out_videos .items ():
174+ for event_video in event_videos :
175+ display (Video (event_video , embed = True ))
176+ if use_ipython :
177+ from matplotlib .animation import FuncAnimation
178+ if len (qa_message .function_rets ) > 0 :
179+ for identifier , rets in qa_message .function_rets .items ():
180+ if not isinstance (rets , (tuple , list )):
181+ rets = [rets ]
182+ for ret in rets :
183+ if isinstance (ret , FuncAnimation ):
184+ display (HTML (ret .to_jshtml ()))
185+ else :
186+ display (Markdown (str (qa_message .function_rets [identifier ])))
187+ return qa_message
188+
189+ def patch_pytorch_weights_only ():
190+ """
191+ Patch for PyTorch 2.6 weights_only issue with DeepLabCut SuperAnimal models.
192+ This adds safe globals to allow loading of ruamel.yaml.scalarfloat.ScalarFloat objects.
193+ Only applies the patch if torch.serialization.add_safe_globals exists (PyTorch >=2.6).
194+ """
195+ try :
196+ import torch
197+ from ruamel .yaml .scalarfloat import ScalarFloat
198+ if hasattr (torch .serialization , "add_safe_globals" ):
199+ torch .serialization .add_safe_globals ([ScalarFloat ])
200+ except ImportError :
201+ pass # If ruamel.yaml is not available, continue without the patch
0 commit comments