diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..1b9c357 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,38 @@ +# Repository Guidelines + +These instructions apply to the entire repository. Follow them when modifying or adding files. + +## Environment Setup +- Use **Python 3.11+**. +- Install dependencies: + ```bash + pip install -r requirements.txt + pip install black flake8 # development tools + ``` +- Verify that all imported modules are listed in `requirements.txt`: + ```bash + python utils/verify_requirements.py + ``` + +## Formatting +- Format code with **Black** before committing: + ```bash + black . + ``` +- Use the default line length (88 characters) and 4‑space indentation. + +## Linting +- Run **flake8** and resolve all issues: + ```bash + flake8 --max-line-length=120 + ``` + +## Testing +- Execute the test suite with **pytest** and ensure it passes: + ```bash + pytest -v + ``` + +## Pull Requests +- Make concise, focused commits using present‑tense messages (e.g., "Add combo detection module"). +- Verify formatting, linting, and tests succeed before submitting a PR. diff --git a/README.md b/README.md index 1ff0ae3..c5420ef 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 🥊 PunchTracker - AI Boxing Assistant 🥊 ![License](https://img.shields.io/badge/license-MIT-blue.svg) -![Python](https://img.shields.io/badge/python-v3.8+-blue.svg) +![Python](https://img.shields.io/badge/python-v3.11+-blue.svg) ![TensorFlow](https://img.shields.io/badge/TensorFlow-v2.12+-orange.svg) ![OpenCV](https://img.shields.io/badge/OpenCV-v4.5+-green.svg) ![Status](https://img.shields.io/badge/status-active-success.svg) @@ -32,7 +32,7 @@ A real-time boxing assistant application that uses computer vision and machine l ### Prerequisites -- Python 3.8 or higher +- Python 3.11 or higher - Webcam - Required libraries (see requirements.txt) diff --git a/main.py b/main.py index fcae666..73e201c 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from utils.data_manager import DataManager from utils.calibration import Calibrator + class PunchTracker: def __init__(self): # Initialize components @@ -23,20 +24,20 @@ def __init__(self): self.ui_manager = UIManager() self.data_manager = DataManager() self.calibrator = Calibrator() - + # Application state self.is_running = False self.is_calibrating = False self.show_debug = False self.is_paused = False self.session_start_time = None - + # Camera settings self.camera_id = 0 self.frame_width = 640 self.frame_height = 480 self.cap = None - + def initialize_camera(self): """Initialize the webcam""" self.cap = cv2.VideoCapture(self.camera_id) @@ -44,33 +45,39 @@ def initialize_camera(self): self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.frame_height) if not self.cap.isOpened(): raise ValueError("Could not open camera. Check if it's connected properly.") - + def start_session(self): """Start a new punching session""" self.session_start_time = datetime.now() self.punch_counter.reset_counter() self.data_manager.create_new_session() print(f"New session started at {self.session_start_time}") - + def end_session(self): """End the current session and save data""" session_duration = (datetime.now() - self.session_start_time).total_seconds() session_data = { - 'date': self.session_start_time.strftime('%Y-%m-%d %H:%M:%S'), - 'duration': session_duration, - 'total_punches': self.punch_counter.total_count, - 'punch_types': self.punch_counter.get_punch_types_count(), - 'punches_per_minute': self.punch_counter.total_count / (session_duration / 60) if session_duration > 0 else 0 + "date": self.session_start_time.strftime("%Y-%m-%d %H:%M:%S"), + "duration": session_duration, + "total_punches": self.punch_counter.total_count, + "punch_types": self.punch_counter.get_punch_types_count(), + "punches_per_minute": ( + self.punch_counter.total_count / (session_duration / 60) + if session_duration > 0 + else 0 + ), } self.data_manager.save_session_data(session_data) - print(f"Session ended. {self.punch_counter.total_count} punches recorded over {session_duration:.1f} seconds.") - + print( + f"Session ended. {self.punch_counter.total_count} punches recorded over {session_duration:.1f} seconds." + ) + def start_calibration(self): """Start the calibration process""" self.is_calibrating = True self.calibrator.start_calibration() print("Calibration started. Follow the on-screen instructions.") - + def process_frame(self, frame): """Process a single frame from the webcam""" # Detect poses in the frame @@ -83,27 +90,33 @@ def process_frame(self, frame): self.punch_counter.get_punch_types_count(), self.session_start_time, self.punch_counter.velocity_threshold, - paused=True + paused=True, ) - + if self.is_calibrating: # Handle calibration mode - calibration_complete, frame = self.calibrator.process_calibration_frame(frame, poses) + calibration_complete, frame = self.calibrator.process_calibration_frame( + frame, poses + ) if calibration_complete: self.is_calibrating = False - self.punch_counter.apply_calibration(self.calibrator.get_calibration_data()) + self.punch_counter.apply_calibration( + self.calibrator.get_calibration_data() + ) print("Calibration completed!") else: # Normal processing - detect punches punches_detected = self.punch_counter.detect_punches(poses) - + # Add visual feedback if punches are detected if punches_detected: for punch_info in punches_detected: punch_type, coords = punch_info # Visual feedback for the punch - cv2.circle(frame, (int(coords[0]), int(coords[1])), 15, (0, 0, 255), -1) - + cv2.circle( + frame, (int(coords[0]), int(coords[1])), 15, (0, 0, 255), -1 + ) + # Update the UI with the latest data frame = self.ui_manager.update_display( frame, @@ -111,60 +124,60 @@ def process_frame(self, frame): self.punch_counter.get_punch_types_count(), self.session_start_time, self.punch_counter.velocity_threshold, - paused=False + paused=False, ) - + # Show debug visualization if enabled if self.show_debug: frame = self.pose_detector.draw_pose(frame, poses) - + return frame - + def run(self): """Main application loop""" self.initialize_camera() self.is_running = True self.start_session() - + try: while self.is_running: ret, frame = self.cap.read() if not ret: print("Failed to grab frame from camera") break - + # Mirror the frame for a more intuitive display frame = cv2.flip(frame, 1) - + # Process the current frame display_frame = self.process_frame(frame) - + # Display the resulting frame - cv2.imshow('PunchTracker', display_frame) - + cv2.imshow("PunchTracker", display_frame) + # Handle keyboard input key = cv2.waitKey(1) & 0xFF if key == 27: # ESC key break - elif key == ord('c'): + elif key == ord("c"): self.start_calibration() - elif key == ord('d'): + elif key == ord("d"): self.show_debug = not self.show_debug - elif key == ord('r'): + elif key == ord("r"): self.start_session() - elif key == ord('s'): + elif key == ord("s"): stats_image = self.ui_manager.generate_stats_graph( self.data_manager.get_historical_data(), - self.punch_counter.get_punch_types_count() + self.punch_counter.get_punch_types_count(), ) - cv2.imshow('Punch Statistics', stats_image) - elif key == ord('p'): + cv2.imshow("Punch Statistics", stats_image) + elif key == ord("p"): self.is_paused = not self.is_paused - elif key == ord('i'): + elif key == ord("i"): self.punch_counter.increase_sensitivity() - elif key == ord('k'): + elif key == ord("k"): self.punch_counter.decrease_sensitivity() - + finally: # Cleanup self.end_session() @@ -173,9 +186,10 @@ def run(self): cv2.destroyAllWindows() print("Application terminated") + if __name__ == "__main__": try: app = PunchTracker() app.run() except Exception as e: - print(f"Error: {e}") \ No newline at end of file + print(f"Error: {e}") diff --git a/tests/test_data_manager.py b/tests/test_data_manager.py index 4950d8f..2dd2c0d 100644 --- a/tests/test_data_manager.py +++ b/tests/test_data_manager.py @@ -4,21 +4,22 @@ from utils.data_manager import DataManager + def test_create_and_save_session(): with tempfile.TemporaryDirectory() as tmpdir: dm = DataManager(data_dir=tmpdir) dm.create_new_session() session_data = { - 'date': '2023-01-01 00:00:00', - 'duration': 10.0, - 'total_punches': 5, - 'punch_types': {'jab': 2, 'cross': 3}, - 'punches_per_minute': 30.0 + "date": "2023-01-01 00:00:00", + "duration": 10.0, + "total_punches": 5, + "punch_types": {"jab": 2, "cross": 3}, + "punches_per_minute": 30.0, } dm.save_session_data(session_data) hist = dm.get_historical_data() assert len(hist) == 1 - assert hist[0]['total_punches'] == 5 + assert hist[0]["total_punches"] == 5 def test_json_backup_write_failure(tmp_path, capsys): @@ -27,11 +28,11 @@ def test_json_backup_write_failure(tmp_path, capsys): dm = DataManager(data_dir=data_dir) dm.create_new_session() session_data = { - 'date': '2023-01-01 00:00:00', - 'duration': 5.0, - 'total_punches': 2, - 'punch_types': {'jab': 2}, - 'punches_per_minute': 24.0 + "date": "2023-01-01 00:00:00", + "duration": 5.0, + "total_punches": 2, + "punch_types": {"jab": 2}, + "punches_per_minute": 24.0, } os.chmod(data_dir, 0o500) @@ -40,4 +41,4 @@ def test_json_backup_write_failure(tmp_path, capsys): output = capsys.readouterr().out assert "Failed to write backup file" in output - assert len(list(data_dir.glob('session_*.json'))) == 0 + assert len(list(data_dir.glob("session_*.json"))) == 0 diff --git a/tests/test_punch_counter.py b/tests/test_punch_counter.py index 4015f68..ff1ed84 100644 --- a/tests/test_punch_counter.py +++ b/tests/test_punch_counter.py @@ -4,22 +4,22 @@ def test_classify_cross(): pc = PunchCounter() keypoints = { - 'right_wrist': (1.1, 0.0, 1.0), - 'right_elbow': (0.5, 0.0, 1.0), - 'right_shoulder': (0.0, 0.0, 1.0), - 'left_shoulder': (-0.2, 0.0, 1.0), + "right_wrist": (1.1, 0.0, 1.0), + "right_elbow": (0.5, 0.0, 1.0), + "right_shoulder": (0.0, 0.0, 1.0), + "left_shoulder": (-0.2, 0.0, 1.0), } - punch_type = pc._classify_punch_type(keypoints, 'right', velocity=60) + punch_type = pc._classify_punch_type(keypoints, "right", velocity=60) assert punch_type == pc.CROSS def test_classify_uppercut(): pc = PunchCounter() keypoints = { - 'left_wrist': (0.0, -1.1, 1.0), - 'left_elbow': (0.0, -0.5, 1.0), - 'left_shoulder': (0.0, 0.0, 1.0), - 'right_shoulder': (0.5, 0.0, 1.0), + "left_wrist": (0.0, -1.1, 1.0), + "left_elbow": (0.0, -0.5, 1.0), + "left_shoulder": (0.0, 0.0, 1.0), + "right_shoulder": (0.5, 0.0, 1.0), } - punch_type = pc._classify_punch_type(keypoints, 'left', velocity=60) + punch_type = pc._classify_punch_type(keypoints, "left", velocity=60) assert punch_type == pc.UPPERCUT diff --git a/utils/calibration.py b/utils/calibration.py index 89557e2..6db5d17 100644 --- a/utils/calibration.py +++ b/utils/calibration.py @@ -1,10 +1,12 @@ """ Calibration module for adjusting punch detection parameters """ + import cv2 import numpy as np import time + class Calibrator: def __init__(self): """Initialize calibration parameters""" @@ -15,22 +17,22 @@ def __init__(self): "Throw a few jabs and crosses", "Throw a few hooks", "Throw a few uppercuts", - "Calibration complete!" + "Calibration complete!", ] self.step_durations = [5, 10, 10, 10, 3] # seconds for each step self.step_start_time = None - + # Data collected during calibration self.punch_velocities = [] self.punch_distances = [] - + # Final calibration data self.calibration_data = { "velocity_multiplier": 1.0, "direction_adjust": 0.0, - "threshold_adjust": 0.0 + "threshold_adjust": 0.0, } - + def start_calibration(self): """Start the calibration process""" self.is_calibrating = True @@ -39,56 +41,58 @@ def start_calibration(self): self.punch_velocities = [] self.punch_distances = [] print("Calibration started") - + def process_calibration_frame(self, frame, keypoints): """ Process a frame during calibration - + Args: frame: Input video frame keypoints: Detected keypoints from pose detector - + Returns: Tuple of (calibration_complete, processed_frame) """ # Calculate time elapsed in current step current_time = time.time() elapsed_time = current_time - self.step_start_time - + # Check if current step is complete if elapsed_time >= self.step_durations[self.calibration_stage]: self.calibration_stage += 1 self.step_start_time = current_time - + # If all steps are complete, finish calibration if self.calibration_stage >= len(self.calibration_steps): self._process_calibration_data() return True, self._draw_calibration_status(frame) - + # Process keypoints for the current calibration stage if 1 <= self.calibration_stage <= 3: # Stages for collecting punch data self._collect_calibration_data(keypoints) - + # Draw calibration status on frame return False, self._draw_calibration_status(frame) - + def _collect_calibration_data(self, keypoints): """Collect data during the punch calibration stages""" from utils.pose_detector import PoseDetector - + # Extract wrist positions and calculate distances and velocities pose_detector = PoseDetector() hand_keypoints = pose_detector.get_hand_keypoints(keypoints) - + # Process key points relevant to the current calibration stage # For now, just store the positions to calculate velocities later for hand_key in ["left_wrist", "right_wrist"]: keypoint = hand_keypoints.get(hand_key) if keypoint is not None: - # In a real implementation, you would calculate velocities and + # In a real implementation, you would calculate velocities and # other metrics here based on sequential frames - self.punch_velocities.append(keypoint[2]) # Using confidence as a proxy for now - + self.punch_velocities.append( + keypoint[2] + ) # Using confidence as a proxy for now + def _process_calibration_data(self): """Process collected data to determine calibration parameters""" # Calculate velocity multiplier based on collected data @@ -100,47 +104,69 @@ def _process_calibration_data(self): self.calibration_data["velocity_multiplier"] = 1.5 / avg_velocity else: self.calibration_data["velocity_multiplier"] = 1.0 - + # Cap the multiplier to reasonable range - self.calibration_data["velocity_multiplier"] = max(0.5, min(2.0, self.calibration_data["velocity_multiplier"])) - + self.calibration_data["velocity_multiplier"] = max( + 0.5, min(2.0, self.calibration_data["velocity_multiplier"]) + ) + # In a real implementation, calculate direction_adjust and threshold_adjust # based on collected data print(f"Calibration completed: {self.calibration_data}") self.is_calibrating = False - + def _draw_calibration_status(self, frame): """Draw calibration status overlay on the frame""" h, w = frame.shape[:2] - + # Create overlay overlay = frame.copy() cv2.rectangle(overlay, (0, 0), (w, 80), (0, 0, 0), -1) - + # Add transparency alpha = 0.7 cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame) - + # Add calibration step text if self.calibration_stage < len(self.calibration_steps): current_step = self.calibration_steps[self.calibration_stage] - + # Show remaining time for current step elapsed_time = time.time() - self.step_start_time - remaining_time = max(0, self.step_durations[self.calibration_stage] - elapsed_time) - + remaining_time = max( + 0, self.step_durations[self.calibration_stage] - elapsed_time + ) + step_text = f"Step {self.calibration_stage + 1}/{len(self.calibration_steps)}: {current_step}" time_text = f"Time remaining: {int(remaining_time)}s" - - cv2.putText(frame, step_text, (20, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - cv2.putText(frame, time_text, (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) - + + cv2.putText( + frame, + step_text, + (20, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + cv2.putText( + frame, + time_text, + (20, 60), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + (255, 255, 255), + 2, + ) + # Add progress bar - progress_width = int(w * (elapsed_time / self.step_durations[self.calibration_stage])) + progress_width = int( + w * (elapsed_time / self.step_durations[self.calibration_stage]) + ) cv2.rectangle(frame, (0, 70), (progress_width, 80), (0, 255, 0), -1) - + return frame - + def get_calibration_data(self): """Get the calibration data""" - return self.calibration_data \ No newline at end of file + return self.calibration_data diff --git a/utils/data_manager.py b/utils/data_manager.py index 899c061..b41d3a8 100644 --- a/utils/data_manager.py +++ b/utils/data_manager.py @@ -1,37 +1,40 @@ """ Data Manager module for handling session data storage and retrieval """ + import os import json import time from datetime import datetime import sqlite3 + class DataManager: def __init__(self, data_dir="data"): """ Initialize the data manager - + Args: data_dir: Directory to store data files """ self.data_dir = os.path.abspath(data_dir) self.db_path = os.path.join(self.data_dir, "punch_sessions.db") self.current_session_id = None - + # Create data directory if it doesn't exist os.makedirs(self.data_dir, exist_ok=True) - + # Initialize database self._initialize_database() - + def _initialize_database(self): """Initialize the SQLite database for storing session data""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Create sessions table if it doesn't exist - cursor.execute(''' + cursor.execute( + """ CREATE TABLE IF NOT EXISTS sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, date TEXT, @@ -43,133 +46,143 @@ def _initialize_database(self): hook_count INTEGER, uppercut_count INTEGER ) - ''') - + """ + ) + conn.commit() conn.close() - + print(f"Database initialized at {self.db_path}") - + def create_new_session(self): """Create a new tracking session""" # Generate a unique session ID based on timestamp self.current_session_id = int(time.time()) print(f"New session created with ID: {self.current_session_id}") - + def save_session_data(self, session_data): """ Save session data to the database - + Args: session_data: Dictionary containing session information """ if not session_data: print("No session data to save") return - + conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Extract punch type counts - punch_types = session_data.get('punch_types', {}) - + punch_types = session_data.get("punch_types", {}) + # Insert session data into database - cursor.execute(''' + cursor.execute( + """ INSERT INTO sessions (date, duration, total_punches, punches_per_minute, jab_count, cross_count, hook_count, uppercut_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ''', ( - session_data['date'], - session_data['duration'], - session_data['total_punches'], - session_data['punches_per_minute'], - punch_types.get('jab', 0), - punch_types.get('cross', 0), - punch_types.get('hook', 0), - punch_types.get('uppercut', 0) - )) - + """, + ( + session_data["date"], + session_data["duration"], + session_data["total_punches"], + session_data["punches_per_minute"], + punch_types.get("jab", 0), + punch_types.get("cross", 0), + punch_types.get("hook", 0), + punch_types.get("uppercut", 0), + ), + ) + conn.commit() - + # Get the ID of the inserted row cursor.execute("SELECT last_insert_rowid()") session_id = cursor.fetchone()[0] - + conn.close() - + print(f"Session data saved with ID: {session_id}") - + # Also save a JSON backup for easy debugging and portability self._save_json_backup(session_id, session_data) - + def _save_json_backup(self, session_id, session_data): """Save a JSON backup of the session data""" backup_file = os.path.join(self.data_dir, f"session_{session_id}.json") - + try: - with open(backup_file, 'w') as f: + with open(backup_file, "w") as f: json.dump(session_data, f, indent=4) except IOError as e: print(f"Failed to write backup file {backup_file}: {e}") return print(f"Session backup saved to {backup_file}") - + def get_historical_data(self, limit=10): """ Retrieve historical session data from the database - + Args: limit: Maximum number of sessions to retrieve (default 10) - + Returns: List of session data dictionaries """ conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Get the most recent sessions - cursor.execute(''' + cursor.execute( + """ SELECT date, duration, total_punches, punches_per_minute, jab_count, cross_count, hook_count, uppercut_count FROM sessions ORDER BY date DESC LIMIT ? - ''', (limit,)) - + """, + (limit,), + ) + sessions = cursor.fetchall() conn.close() - + # Convert to list of dictionaries historical_data = [] for session in sessions: - historical_data.append({ - 'date': session[0], - 'duration': session[1], - 'total_punches': session[2], - 'punches_per_minute': session[3], - 'punch_types': { - 'jab': session[4], - 'cross': session[5], - 'hook': session[6], - 'uppercut': session[7] + historical_data.append( + { + "date": session[0], + "duration": session[1], + "total_punches": session[2], + "punches_per_minute": session[3], + "punch_types": { + "jab": session[4], + "cross": session[5], + "hook": session[6], + "uppercut": session[7], + }, } - }) - + ) + return historical_data - + def get_stats_summary(self): """ Generate a summary of all session statistics - + Returns: Dictionary with summary statistics """ conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - + # Get aggregate statistics - cursor.execute(''' + cursor.execute( + """ SELECT COUNT(*) as total_sessions, SUM(total_punches) as total_punches, @@ -181,41 +194,37 @@ def get_stats_summary(self): SUM(uppercut_count) as total_uppercuts, SUM(duration) as total_duration FROM sessions - ''') - + """ + ) + result = cursor.fetchone() conn.close() - + if not result or result[0] == 0: return { - 'total_sessions': 0, - 'total_punches': 0, - 'avg_ppm': 0, - 'max_ppm': 0, - 'total_minutes': 0, - 'punch_distribution': { - 'jab': 0, - 'cross': 0, - 'hook': 0, - 'uppercut': 0 - } + "total_sessions": 0, + "total_punches": 0, + "avg_ppm": 0, + "max_ppm": 0, + "total_minutes": 0, + "punch_distribution": {"jab": 0, "cross": 0, "hook": 0, "uppercut": 0}, } - + # Calculate the punch distribution percentages total_punches = result[1] or 1 # Avoid division by zero - + summary = { - 'total_sessions': result[0], - 'total_punches': result[1], - 'avg_ppm': result[2], - 'max_ppm': result[3], - 'total_minutes': result[8] / 60 if result[8] else 0, - 'punch_distribution': { - 'jab': (result[4] / total_punches) * 100 if result[4] else 0, - 'cross': (result[5] / total_punches) * 100 if result[5] else 0, - 'hook': (result[6] / total_punches) * 100 if result[6] else 0, - 'uppercut': (result[7] / total_punches) * 100 if result[7] else 0 - } + "total_sessions": result[0], + "total_punches": result[1], + "avg_ppm": result[2], + "max_ppm": result[3], + "total_minutes": result[8] / 60 if result[8] else 0, + "punch_distribution": { + "jab": (result[4] / total_punches) * 100 if result[4] else 0, + "cross": (result[5] / total_punches) * 100 if result[5] else 0, + "hook": (result[6] / total_punches) * 100 if result[6] else 0, + "uppercut": (result[7] / total_punches) * 100 if result[7] else 0, + }, } - - return summary \ No newline at end of file + + return summary diff --git a/utils/pose_detector.py b/utils/pose_detector.py index a86fc2f..1ef3654 100644 --- a/utils/pose_detector.py +++ b/utils/pose_detector.py @@ -1,12 +1,14 @@ """ Pose Detector module using TensorFlow's MoveNet for skeletal tracking """ + import os import numpy as np import tensorflow as tf import tensorflow_hub as tfhub import cv2 + class PoseDetector: # MoveNet keypoint indices KEYPOINT_NOSE = 0 @@ -26,11 +28,11 @@ class PoseDetector: KEYPOINT_RIGHT_KNEE = 14 KEYPOINT_LEFT_ANKLE = 15 KEYPOINT_RIGHT_ANKLE = 16 - + def __init__(self, model_type="movenet_lightning"): """ Initialize the pose detector with TensorFlow MoveNet model - + Args: model_type (str): Type of MoveNet model to use ('movenet_lightning' or 'movenet_thunder') """ @@ -39,7 +41,7 @@ def __init__(self, model_type="movenet_lightning"): self.model = self._load_model() self.input_size = 192 if model_type == "movenet_lightning" else 256 self.keypoint_threshold = 0.3 # Confidence threshold for keypoints - + def _get_model_path(self): """Get the TF Hub model path based on the selected model type""" if self.model_type == "movenet_lightning": @@ -48,78 +50,79 @@ def _get_model_path(self): return "https://tfhub.dev/google/movenet/singlepose/thunder/4" else: raise ValueError(f"Unsupported model type: {self.model_type}") - + def _load_model(self): """Load the TensorFlow model from TF Hub""" print(f"Loading MoveNet model: {self.model_type}") module = tfhub.load(self.model_path) - model = module.signatures['serving_default'] + model = module.signatures["serving_default"] print("Model loaded successfully") return model - + def _preprocess_image(self, image): """Preprocess the input image for the model""" # Resize and pad the image to the model's input dimensions input_img = tf.image.resize_with_pad( tf.expand_dims(tf.convert_to_tensor(image), axis=0), - self.input_size, self.input_size + self.input_size, + self.input_size, ) # Convert to float32 input_img = tf.cast(input_img, dtype=tf.float32) return input_img - + def detect_pose(self, image): """ Detect poses in the input image - + Args: image: Input RGB image - + Returns: List of detected keypoints with their coordinates and confidence scores """ # Preprocess the image input_img = self._preprocess_image(image) - + # Run inference outputs = self.model(input_img) - keypoints = outputs['output_0'].numpy() - + keypoints = outputs["output_0"].numpy() + # The model returns keypoints in format [1, 1, 17, 3] where the last dimension # consists of [y, x, confidence] keypoints = keypoints[0, 0, :, :] - + # Convert normalized coordinates to pixel coordinates h, w, _ = image.shape keypoints_with_scores = [] - + for i, keypoint in enumerate(keypoints): y, x, confidence = keypoint # Skip keypoints with low confidence if confidence < self.keypoint_threshold: keypoints_with_scores.append((i, None, None, 0.0)) continue - + # Convert to pixel coordinates px = int(x * w) py = int(y * h) keypoints_with_scores.append((i, px, py, confidence)) - + return keypoints_with_scores - + def draw_pose(self, image, keypoints): """ Draw the detected pose keypoints and connections on the image - + Args: image: The input image keypoints: List of detected keypoints - + Returns: Image with pose visualization """ output_img = image.copy() - + # Define connections between keypoints for visualization connections = [ (self.KEYPOINT_NOSE, self.KEYPOINT_LEFT_EYE), @@ -137,35 +140,38 @@ def draw_pose(self, image, keypoints): (self.KEYPOINT_LEFT_HIP, self.KEYPOINT_LEFT_KNEE), (self.KEYPOINT_RIGHT_HIP, self.KEYPOINT_RIGHT_KNEE), (self.KEYPOINT_LEFT_KNEE, self.KEYPOINT_LEFT_ANKLE), - (self.KEYPOINT_RIGHT_KNEE, self.KEYPOINT_RIGHT_ANKLE) + (self.KEYPOINT_RIGHT_KNEE, self.KEYPOINT_RIGHT_ANKLE), ] - + # Create a keypoint lookup dictionary for faster access keypoint_dict = {kp[0]: (kp[1], kp[2], kp[3]) for kp in keypoints} - + # Draw connections for connection in connections: start_idx, end_idx = connection if start_idx in keypoint_dict and end_idx in keypoint_dict: start_point = keypoint_dict[start_idx] end_point = keypoint_dict[end_idx] - + # Skip if either keypoint was not detected if None in (start_point[0], start_point[1], end_point[0], end_point[1]): continue - + # Draw the connection line - cv2.line(output_img, - (start_point[0], start_point[1]), - (end_point[0], end_point[1]), - (0, 255, 0), 2) - + cv2.line( + output_img, + (start_point[0], start_point[1]), + (end_point[0], end_point[1]), + (0, 255, 0), + 2, + ) + # Draw keypoints for idx, x, y, confidence in keypoints: # Skip keypoints with low confidence if x is None or y is None: continue - + # Color based on confidence: green for high confidence, yellow for medium, red for low if confidence > 0.7: color = (0, 255, 0) # Green @@ -173,24 +179,24 @@ def draw_pose(self, image, keypoints): color = (0, 255, 255) # Yellow else: color = (0, 0, 255) # Red - + cv2.circle(output_img, (x, y), 5, color, -1) - + return output_img - + def get_hand_keypoints(self, keypoints): """ Extract hand keypoints (wrists, elbows, shoulders) which are important for punch detection - + Args: keypoints: List of all detected keypoints - + Returns: Dictionary with hand keypoint coordinates """ hand_keypoints = {} keypoint_dict = {kp[0]: (kp[1], kp[2], kp[3]) for kp in keypoints} - + # Extract important keypoints for punch detection important_keypoints = [ ("left_wrist", self.KEYPOINT_LEFT_WRIST), @@ -198,13 +204,17 @@ def get_hand_keypoints(self, keypoints): ("left_elbow", self.KEYPOINT_LEFT_ELBOW), ("right_elbow", self.KEYPOINT_RIGHT_ELBOW), ("left_shoulder", self.KEYPOINT_LEFT_SHOULDER), - ("right_shoulder", self.KEYPOINT_RIGHT_SHOULDER) + ("right_shoulder", self.KEYPOINT_RIGHT_SHOULDER), ] - + for name, idx in important_keypoints: if idx in keypoint_dict and keypoint_dict[idx][0] is not None: - hand_keypoints[name] = (keypoint_dict[idx][0], keypoint_dict[idx][1], keypoint_dict[idx][2]) + hand_keypoints[name] = ( + keypoint_dict[idx][0], + keypoint_dict[idx][1], + keypoint_dict[idx][2], + ) else: hand_keypoints[name] = None - - return hand_keypoints \ No newline at end of file + + return hand_keypoints diff --git a/utils/punch_counter.py b/utils/punch_counter.py index 2ac53ce..01737e9 100644 --- a/utils/punch_counter.py +++ b/utils/punch_counter.py @@ -1,60 +1,51 @@ """ Punch Counter module for detecting and tracking punch movements """ + import numpy as np import time import math from collections import deque from utils.pose_detector import PoseDetector + class PunchCounter: # Punch types JAB = "jab" CROSS = "cross" HOOK = "hook" UPPERCUT = "uppercut" - + def __init__(self, pose_detector): # Store the pose detector self.pose_detector = pose_detector - + # Counters for different punch types self.total_count = 0 - self.punch_counts = { - self.JAB: 0, - self.CROSS: 0, - self.HOOK: 0, - self.UPPERCUT: 0 - } - + self.punch_counts = {self.JAB: 0, self.CROSS: 0, self.HOOK: 0, self.UPPERCUT: 0} + # Track position history for velocity calculation self.position_history = { "left_wrist": deque(maxlen=10), - "right_wrist": deque(maxlen=10) + "right_wrist": deque(maxlen=10), } - + # Timestamps for position history self.timestamp_history = { "left_wrist": deque(maxlen=10), - "right_wrist": deque(maxlen=10) + "right_wrist": deque(maxlen=10), } - + # Cooldown to prevent rapid punch detections - self.last_punch_time = { - "left": 0, - "right": 0 - } + self.last_punch_time = {"left": 0, "right": 0} self.punch_cooldown = 0.5 # seconds - + # Punch detection parameters self.velocity_threshold = 50 # pixels per frame self.direction_threshold = 0.7 # cosine similarity threshold - + # Calibration adjustments - self.calibration_data = { - "velocity_multiplier": 1.0, - "direction_adjust": 0.0 - } + self.calibration_data = {"velocity_multiplier": 1.0, "direction_adjust": 0.0} def increase_sensitivity(self, step=5): """Decrease velocity threshold to increase sensitivity""" @@ -63,91 +54,91 @@ def increase_sensitivity(self, step=5): def decrease_sensitivity(self, step=5): """Increase velocity threshold to decrease sensitivity""" self.velocity_threshold = self.velocity_threshold + step - + def reset_counter(self): """Reset all punch counters""" self.total_count = 0 for punch_type in self.punch_counts: self.punch_counts[punch_type] = 0 - + def get_punch_types_count(self): """Get the count of each punch type""" return self.punch_counts - + def apply_calibration(self, calibration_data): """Apply calibration adjustments""" self.calibration_data = calibration_data print(f"Applied calibration: {calibration_data}") - + def _calculate_velocity(self, positions, timestamps): """Calculate the velocity of a keypoint based on its position history""" if len(positions) < 2 or len(timestamps) < 2: return 0, None - + # Get the two most recent positions pos1 = positions[-2] pos2 = positions[-1] time1 = timestamps[-2] time2 = timestamps[-1] - + # Skip if positions are None if pos1 is None or pos2 is None: return 0, None - + # Calculate displacement dx = pos2[0] - pos1[0] dy = pos2[1] - pos1[1] - displacement = math.sqrt(dx*dx + dy*dy) - + displacement = math.sqrt(dx * dx + dy * dy) + # Calculate time difference dt = time2 - time1 if dt == 0: return 0, None - + # Calculate velocity and direction velocity = displacement / dt - direction = (dx/displacement, dy/displacement) if displacement > 0 else None - + direction = (dx / displacement, dy / displacement) if displacement > 0 else None + # Apply calibration velocity *= self.calibration_data["velocity_multiplier"] - + return velocity, direction - + def _is_punch_motion(self, velocity, direction, hand): """Determine if a hand motion qualifies as a punch""" # Check velocity threshold if velocity < self.velocity_threshold: return False - + # Punches generally move forward (x-axis in mirrored view) # For left hand (right side of screen), x direction should be negative # For right hand (left side of screen), x direction should be positive if direction is None: return False - + forward_direction = direction[0] < 0 if hand == "left" else direction[0] > 0 if not forward_direction: return False - + # Check cooldown to prevent multiple detections of the same punch current_time = time.time() if current_time - self.last_punch_time[hand] < self.punch_cooldown: return False - + # Update last punch time self.last_punch_time[hand] = current_time - + return True - + def _classify_punch_type(self, keypoints, hand, velocity): """ Classify the type of punch based on hand position relative to shoulders and head - + Args: keypoints: Dictionary of keypoints hand: 'left' or 'right' indicating which hand threw the punch velocity: The velocity of the punch - + Returns: String indicating punch type (jab, cross, hook, uppercut) """ @@ -155,33 +146,43 @@ def _classify_punch_type(self, keypoints, hand, velocity): elbow_key = f"{hand}_elbow" shoulder_key = f"{hand}_shoulder" opposite_shoulder_key = "right_shoulder" if hand == "left" else "left_shoulder" - + # Get keypoints wrist = keypoints.get(wrist_key) elbow = keypoints.get(elbow_key) shoulder = keypoints.get(shoulder_key) opposite_shoulder = keypoints.get(opposite_shoulder_key) - + # Return default if keypoints are missing if None in (wrist, elbow, shoulder, opposite_shoulder): # Return cross for right hand (dominant hand), jab for left return self.CROSS if hand == "right" else self.JAB - + # Calculate vertical position of wrist relative to elbow wrist_above_elbow = wrist[1] < elbow[1] - + # Calculate horizontal position of wrist relative to shoulder - wrist_outside_shoulder = wrist[0] < shoulder[0] if hand == "left" else wrist[0] > shoulder[0] - + wrist_outside_shoulder = ( + wrist[0] < shoulder[0] if hand == "left" else wrist[0] > shoulder[0] + ) + # Detect if arm is extended (distance from wrist to shoulder) - wrist_to_shoulder_dist = math.sqrt((wrist[0] - shoulder[0])**2 + (wrist[1] - shoulder[1])**2) - elbow_to_shoulder_dist = math.sqrt((elbow[0] - shoulder[0])**2 + (elbow[1] - shoulder[1])**2) - wrist_to_elbow_dist = math.sqrt((wrist[0] - elbow[0])**2 + (wrist[1] - elbow[1])**2) + wrist_to_shoulder_dist = math.sqrt( + (wrist[0] - shoulder[0]) ** 2 + (wrist[1] - shoulder[1]) ** 2 + ) + elbow_to_shoulder_dist = math.sqrt( + (elbow[0] - shoulder[0]) ** 2 + (elbow[1] - shoulder[1]) ** 2 + ) + wrist_to_elbow_dist = math.sqrt( + (wrist[0] - elbow[0]) ** 2 + (wrist[1] - elbow[1]) ** 2 + ) arm_length = elbow_to_shoulder_dist + wrist_to_elbow_dist - arm_extension_ratio = wrist_to_shoulder_dist / arm_length if arm_length > 0 else 0 - + arm_extension_ratio = ( + wrist_to_shoulder_dist / arm_length if arm_length > 0 else 0 + ) + is_extended = arm_extension_ratio > 0.8 - + # Classify based on arm position if wrist_above_elbow and is_extended: return self.UPPERCUT @@ -191,20 +192,20 @@ def _classify_punch_type(self, keypoints, hand, velocity): return self.CROSS else: return self.JAB - + def detect_punches(self, keypoints_list): """ Detect punches from pose keypoints - + Args: keypoints_list: List of keypoints from the pose detector - + Returns: List of detected punches with type and coordinates """ # Extract hand keypoints hand_keypoints = self.pose_detector.get_hand_keypoints(keypoints_list) - + # Update position history current_time = time.time() for hand_key in ["left_wrist", "right_wrist"]: @@ -212,32 +213,33 @@ def detect_punches(self, keypoints_list): if keypoint is not None: self.position_history[hand_key].append((keypoint[0], keypoint[1])) self.timestamp_history[hand_key].append(current_time) - + # Detect punches from both hands detected_punches = [] - + for hand, wrist_key in [("left", "left_wrist"), ("right", "right_wrist")]: # Calculate velocity and direction velocity, direction = self._calculate_velocity( - self.position_history[wrist_key], - self.timestamp_history[wrist_key] + self.position_history[wrist_key], self.timestamp_history[wrist_key] ) - + # Check if motion is a punch if self._is_punch_motion(velocity, direction, hand): # Classify punch type punch_type = self._classify_punch_type(hand_keypoints, hand, velocity) - + # Get wrist coordinates for visualization wrist_coords = self.position_history[wrist_key][-1] - + # Update counter self.punch_counts[punch_type] += 1 self.total_count += 1 - + # Add to detected punches detected_punches.append((punch_type, wrist_coords)) - - print(f"Detected {punch_type.upper()} - Total punches: {self.total_count}") - - return detected_punches \ No newline at end of file + + print( + f"Detected {punch_type.upper()} - Total punches: {self.total_count}" + ) + + return detected_punches diff --git a/utils/ui_manager.py b/utils/ui_manager.py index 78018b6..cb490a9 100644 --- a/utils/ui_manager.py +++ b/utils/ui_manager.py @@ -1,6 +1,7 @@ """ UI Manager module for handling the application's user interface """ + import cv2 import numpy as np import time @@ -8,6 +9,7 @@ import matplotlib.pyplot as plt from matplotlib.backends.backend_agg import FigureCanvasAgg + class UIManager: def __init__(self): # UI settings @@ -16,112 +18,162 @@ def __init__(self): self.font_color = (255, 255, 255) # White self.line_thickness = 2 self.bg_color = (0, 0, 0, 0.5) # Semi-transparent black - + # Stats panel dimensions self.panel_width = 200 self.panel_height = 200 self.panel_padding = 10 - + # Punch type colors for visualization self.punch_colors = { "jab": (46, 204, 113), # Green "cross": (231, 76, 60), # Red "hook": (52, 152, 219), # Blue - "uppercut": (155, 89, 182) # Purple + "uppercut": (155, 89, 182), # Purple } - - def update_display(self, frame, total_count, punch_counts, session_start_time, sensitivity, paused=False): + + def update_display( + self, + frame, + total_count, + punch_counts, + session_start_time, + sensitivity, + paused=False, + ): """ Update the UI elements on the frame - + Args: frame: The input video frame total_count: Total number of punches detected punch_counts: Dictionary with counts for each punch type session_start_time: Start time of the current session - + Returns: Frame with UI elements added """ # Create a copy of the frame to avoid modifying the original display_frame = frame.copy() - + # Add semi-transparent overlay for stats panel - self._add_stats_panel(display_frame, total_count, punch_counts, session_start_time, sensitivity) + self._add_stats_panel( + display_frame, total_count, punch_counts, session_start_time, sensitivity + ) if paused: self._add_paused_overlay(display_frame) - + # Add instructions self._add_instructions(display_frame) - + return display_frame - - def _add_stats_panel(self, frame, total_count, punch_counts, session_start_time, sensitivity): + + def _add_stats_panel( + self, frame, total_count, punch_counts, session_start_time, sensitivity + ): """Add the statistics panel to the frame""" h, w = frame.shape[:2] - + # Create semi-transparent overlay for stats panel overlay = frame.copy() panel_x = w - self.panel_width - self.panel_padding panel_y = self.panel_padding - cv2.rectangle(overlay, - (panel_x, panel_y), - (panel_x + self.panel_width, panel_y + self.panel_height), - (0, 0, 0), - -1) - + cv2.rectangle( + overlay, + (panel_x, panel_y), + (panel_x + self.panel_width, panel_y + self.panel_height), + (0, 0, 0), + -1, + ) + # Apply the overlay with transparency alpha = 0.7 cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame) - + # Add title title_y = panel_y + 30 - cv2.putText(frame, "PUNCH STATS", - (panel_x + 10, title_y), - self.font, 1, self.font_color, 2) - + cv2.putText( + frame, + "PUNCH STATS", + (panel_x + 10, title_y), + self.font, + 1, + self.font_color, + 2, + ) + # Add total count count_y = title_y + 30 - cv2.putText(frame, f"Total: {total_count}", - (panel_x + 10, count_y), - self.font, self.font_scale, self.font_color, self.line_thickness) - + cv2.putText( + frame, + f"Total: {total_count}", + (panel_x + 10, count_y), + self.font, + self.font_scale, + self.font_color, + self.line_thickness, + ) + # Add individual punch counts y_offset = count_y + 5 for punch_type, count in punch_counts.items(): y_offset += 25 color = self.punch_colors.get(punch_type, self.font_color) - cv2.putText(frame, f"{punch_type.capitalize()}: {count}", - (panel_x + 10, y_offset), - self.font, self.font_scale, color, self.line_thickness) - + cv2.putText( + frame, + f"{punch_type.capitalize()}: {count}", + (panel_x + 10, y_offset), + self.font, + self.font_scale, + color, + self.line_thickness, + ) + # Add session time if session_start_time: session_duration = datetime.now() - session_start_time minutes, seconds = divmod(session_duration.seconds, 60) time_str = f"Time: {minutes:02d}:{seconds:02d}" - cv2.putText(frame, time_str, - (panel_x + 10, y_offset + 30), - self.font, self.font_scale, self.font_color, self.line_thickness) + cv2.putText( + frame, + time_str, + (panel_x + 10, y_offset + 30), + self.font, + self.font_scale, + self.font_color, + self.line_thickness, + ) # Calculate and display punches per minute minutes_float = session_duration.total_seconds() / 60 if minutes_float > 0: ppm = total_count / minutes_float - cv2.putText(frame, f"Pace: {ppm:.1f} p/min", - (panel_x + 10, y_offset + 55), - self.font, self.font_scale, self.font_color, self.line_thickness) + cv2.putText( + frame, + f"Pace: {ppm:.1f} p/min", + (panel_x + 10, y_offset + 55), + self.font, + self.font_scale, + self.font_color, + self.line_thickness, + ) # Show current sensitivity - cv2.putText(frame, f"Sens.: {sensitivity}", - (panel_x + 10, panel_y + self.panel_height - 10), - self.font, self.font_scale, self.font_color, self.line_thickness) - + cv2.putText( + frame, + f"Sens.: {sensitivity}", + (panel_x + 10, panel_y + self.panel_height - 10), + self.font, + self.font_scale, + self.font_color, + self.line_thickness, + ) + def _add_instructions(self, frame): """Add instruction text to the frame""" h, w = frame.shape[:2] - + instructions = [ "ESC - Exit", "C - Calibrate", @@ -130,128 +182,166 @@ def _add_instructions(self, frame): "S - Show Stats", "P - Pause", "I - Sens. +", - "K - Sens. -" + "K - Sens. -", ] - + # Create semi-transparent overlay for instructions overlay = frame.copy() inst_x = self.panel_padding inst_y = h - (len(instructions) * 25 + 10) inst_width = 150 inst_height = len(instructions) * 25 + 10 - - cv2.rectangle(overlay, - (inst_x, inst_y), - (inst_x + inst_width, inst_y + inst_height), - (0, 0, 0), - -1) - + + cv2.rectangle( + overlay, + (inst_x, inst_y), + (inst_x + inst_width, inst_y + inst_height), + (0, 0, 0), + -1, + ) + # Apply the overlay with transparency alpha = 0.7 cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame) - + # Add instruction text for i, instruction in enumerate(instructions): y_pos = inst_y + 25 + (i * 25) - cv2.putText(frame, instruction, - (inst_x + 10, y_pos), - self.font, self.font_scale, self.font_color, 1) + cv2.putText( + frame, + instruction, + (inst_x + 10, y_pos), + self.font, + self.font_scale, + self.font_color, + 1, + ) def _add_paused_overlay(self, frame): """Display a paused overlay on the frame""" h, w = frame.shape[:2] overlay = frame.copy() - cv2.rectangle(overlay, (0, int(h/2 - 40)), (w, int(h/2 + 40)), (0, 0, 0), -1) + cv2.rectangle( + overlay, (0, int(h / 2 - 40)), (w, int(h / 2 + 40)), (0, 0, 0), -1 + ) cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame) - cv2.putText(frame, 'PAUSED', (int(w/2) - 60, int(h/2) + 10), - self.font, 1.2, (0, 0, 255), 3) - + cv2.putText( + frame, + "PAUSED", + (int(w / 2) - 60, int(h / 2) + 10), + self.font, + 1.2, + (0, 0, 255), + 3, + ) + def generate_stats_graph(self, historical_data, current_session_data): """ Generate a graph showing historical punch statistics - + Args: historical_data: List of session data from previous sessions current_session_data: Dictionary with current session punch counts - + Returns: Numpy array containing the graph image """ # Create a new figure with matplotlib fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8)) fig.subplots_adjust(hspace=0.3) - + # Extract data for plotting if historical_data: - dates = [session['date'] for session in historical_data[-5:]] # Last 5 sessions - total_punches = [session['total_punches'] for session in historical_data[-5:]] - ppm_values = [session['punches_per_minute'] for session in historical_data[-5:]] - + dates = [ + session["date"] for session in historical_data[-5:] + ] # Last 5 sessions + total_punches = [ + session["total_punches"] for session in historical_data[-5:] + ] + ppm_values = [ + session["punches_per_minute"] for session in historical_data[-5:] + ] + # Add current session if available if current_session_data: - dates.append('Current') + dates.append("Current") current_total = sum(current_session_data.values()) total_punches.append(current_total) # Estimate current PPM based on session duration # (this would be replaced with actual calculation in a real app) ppm_values.append(current_total / 1.0) # Assuming 1 minute - + # Plot total punches per session - ax1.bar(dates, total_punches, color='steelblue') - ax1.set_title('Total Punches per Session') - ax1.set_ylabel('Punch Count') - ax1.tick_params(axis='x', rotation=45) - + ax1.bar(dates, total_punches, color="steelblue") + ax1.set_title("Total Punches per Session") + ax1.set_ylabel("Punch Count") + ax1.tick_params(axis="x", rotation=45) + # Plot punches per minute - ax2.bar(dates, ppm_values, color='firebrick') - ax2.set_title('Punches per Minute') - ax2.set_ylabel('Punches/Min') - ax2.tick_params(axis='x', rotation=45) + ax2.bar(dates, ppm_values, color="firebrick") + ax2.set_title("Punches per Minute") + ax2.set_ylabel("Punches/Min") + ax2.tick_params(axis="x", rotation=45) else: # Show current session data only if current_session_data: punch_types = list(current_session_data.keys()) counts = list(current_session_data.values()) - + # Plot punch type distribution colors = [self.punch_colors.get(pt, (0, 0, 0)) for pt in punch_types] # Convert BGR to RGB for matplotlib - rgb_colors = [(r/255, g/255, b/255) for b, g, r in colors] - + rgb_colors = [(r / 255, g / 255, b / 255) for b, g, r in colors] + ax1.bar(punch_types, counts, color=rgb_colors) - ax1.set_title('Current Session Punch Distribution') - ax1.set_ylabel('Count') - + ax1.set_title("Current Session Punch Distribution") + ax1.set_ylabel("Count") + # Add a placeholder message in the second plot - ax2.text(0.5, 0.5, 'No historical data available yet', - horizontalalignment='center', verticalalignment='center', - transform=ax2.transAxes) - ax2.set_title('Historical Data') - ax2.axis('off') + ax2.text( + 0.5, + 0.5, + "No historical data available yet", + horizontalalignment="center", + verticalalignment="center", + transform=ax2.transAxes, + ) + ax2.set_title("Historical Data") + ax2.axis("off") else: # No data available - ax1.text(0.5, 0.5, 'No data available', - horizontalalignment='center', verticalalignment='center', - transform=ax1.transAxes) - ax1.set_title('Current Session') - ax1.axis('off') - - ax2.text(0.5, 0.5, 'No historical data available', - horizontalalignment='center', verticalalignment='center', - transform=ax2.transAxes) - ax2.set_title('Historical Data') - ax2.axis('off') - + ax1.text( + 0.5, + 0.5, + "No data available", + horizontalalignment="center", + verticalalignment="center", + transform=ax1.transAxes, + ) + ax1.set_title("Current Session") + ax1.axis("off") + + ax2.text( + 0.5, + 0.5, + "No historical data available", + horizontalalignment="center", + verticalalignment="center", + transform=ax2.transAxes, + ) + ax2.set_title("Historical Data") + ax2.axis("off") + # Adjust layout plt.tight_layout() - + # Convert matplotlib figure to OpenCV image canvas = FigureCanvasAgg(fig) canvas.draw() graph_image = np.array(canvas.renderer.buffer_rgba()) # Convert RGBA to BGR for OpenCV graph_image = cv2.cvtColor(graph_image, cv2.COLOR_RGBA2BGR) - + plt.close(fig) - - return graph_image \ No newline at end of file + + return graph_image diff --git a/utils/verify_requirements.py b/utils/verify_requirements.py new file mode 100644 index 0000000..227d503 --- /dev/null +++ b/utils/verify_requirements.py @@ -0,0 +1,49 @@ +import ast +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def gather_imports(file_path: Path): + with file_path.open("r") as f: + tree = ast.parse(f.read(), filename=str(file_path)) + packages = set() + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + packages.add(alias.name.split(".")[0]) + elif isinstance(node, ast.ImportFrom): + if node.level == 0 and node.module: + packages.add(node.module.split(".")[0]) + return packages + + +def load_requirements(req_file: Path): + pkgs = set() + with req_file.open() as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + pkgs.add(line.split("==")[0].split(">=")[0]) + return pkgs + + +def main(): + imports = set() + for py_file in REPO_ROOT.rglob("*.py"): + if py_file.name == "verify_requirements.py" or "venv" in py_file.parts: + continue + imports.update(gather_imports(py_file)) + + requirements = load_requirements(REPO_ROOT / "requirements.txt") + missing = sorted(pkg for pkg in imports if pkg not in requirements) + if missing: + print("Missing packages:", ", ".join(missing)) + sys.exit(1) + print("All imported packages are present in requirements.txt") + + +if __name__ == "__main__": + main()