From 4e01c576f771db6ef855afcd50917e4759db8928 Mon Sep 17 00:00:00 2001 From: dhritinaidu Date: Wed, 31 Jul 2024 14:05:59 -0400 Subject: [PATCH] added unit tests, make test, and auto formatted --- .gitignore | 1 + Makefile | 5 + requirements.txt | 1 + src/__init__.py | 6 +- src/__main__.py | 2 + src/test_verificationclassifier.py | 400 +++++++++++++++++++++++++++++ src/verificationclassifier.py | 197 ++++++++------ 7 files changed, 538 insertions(+), 74 deletions(-) create mode 100644 .gitignore create mode 100644 src/test_verificationclassifier.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/Makefile b/Makefile index 632ae72..ccbeebf 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,7 @@ +PYTHONPATH := ./verification-system:$(PYTHONPATH) + module.tar.gz: requirements.txt *.sh src/*.py tar czf module.tar.gz $^ + +test: + PYTHONPATH=$(PYTHONPATH) pytest diff --git a/requirements.txt b/requirements.txt index 4937b07..038d4ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ viam-sdk >= 0.21.0 pillow +pytest-asyncio diff --git a/src/__init__.py b/src/__init__.py index 7419afc..307c1ae 100755 --- a/src/__init__.py +++ b/src/__init__.py @@ -7,4 +7,8 @@ from .verificationclassifier import VerificationSystem -Registry.register_resource_creator(VisionClient.SUBTYPE, VerificationSystem.MODEL, ResourceCreatorRegistration(VerificationSystem.new, VerificationSystem.validate)) +Registry.register_resource_creator( + VisionClient.SUBTYPE, + VerificationSystem.MODEL, + ResourceCreatorRegistration(VerificationSystem.new, VerificationSystem.validate), +) diff --git a/src/__main__.py b/src/__main__.py index fd0db72..8e05636 100755 --- a/src/__main__.py +++ b/src/__main__.py @@ -4,10 +4,12 @@ from viam.module.module import Module from .verificationclassifier import VerificationSystem + async def main(): module = Module.from_args() module.add_model_from_registry(VisionClient.SUBTYPE, VerificationSystem.MODEL) await module.start() + if __name__ == "__main__": asyncio.run(main()) diff --git a/src/test_verificationclassifier.py b/src/test_verificationclassifier.py new file mode 100644 index 0000000..626d840 --- /dev/null +++ b/src/test_verificationclassifier.py @@ -0,0 +1,400 @@ +import time +import pytest +from google.protobuf.struct_pb2 import Struct +from unittest.mock import AsyncMock, MagicMock +from viam.proto.app.robot import ComponentConfig +from viam.services.vision import Vision +from viam.components.camera import Camera +from src.verificationclassifier import VerificationSystem, AlarmState + + +def make_component_config(dictionary): + struct = Struct() + struct.update(dictionary) + return ComponentConfig(attributes=struct) + + +class TestVerificationSystem: + @pytest.fixture + def mock_dependencies(self): + mock_camera = MagicMock(spec=Camera) + mock_vision = MagicMock(spec=Vision) + + mock_config = MagicMock(spec=ComponentConfig) + + mock_attributes = MagicMock() + mock_attributes.fields = { + "camera_name": MagicMock(string_value="mock_camera"), + "trigger_1_detector": MagicMock(string_value="mock_trigger1"), + "trigger_1_labels": MagicMock(list_value=["label1"]), + "trigger_1_confidence": MagicMock(number_value=0.8), + "trigger_2_detector": MagicMock(string_value="mock_trigger2"), + "trigger_2_labels": MagicMock(list_value=["label2"]), + "trigger_2_confidence": MagicMock(number_value=0.6), + "verification_detector": MagicMock(string_value="mock_verification"), + "verification_labels": MagicMock(list_value=["verify_label"]), + "verification_confidence": MagicMock(number_value=0.9), + "countdown_time_s": MagicMock(number_value=15), + "disarmed_time_s": MagicMock(number_value=20), + "alarm_time_s": MagicMock(number_value=25), + "disable_alarm": MagicMock(bool_value=True), + } + mock_config.attributes = mock_attributes + + mock_dependencies = { + Camera.get_resource_name("mock_camera"): mock_camera, + Vision.get_resource_name("mock_trigger1"): mock_vision, + Vision.get_resource_name("mock_trigger2"): mock_vision, + Vision.get_resource_name("mock_verification"): mock_vision, + } + return mock_config, mock_dependencies + + @pytest.fixture + def model(self, mock_dependencies): + mock_config, dependencies = mock_dependencies + return VerificationSystem.new(mock_config, dependencies) + + @pytest.mark.asyncio + async def test_validate_empty_config(self): + empty_config = make_component_config({}) + with pytest.raises(Exception) as excinfo: + VerificationSystem.validate(config=empty_config) + + assert str(excinfo.value) in [ + "attribute 'camera_name' is required and cannot be blank", + "attribute 'trigger_1_confidence' must be between 0.0 and 1.0", + "attribute 'trigger_2_confidence' must be between 0.0 and 1.0", + "attribute 'trigger_2_labels' cannot be empty", + "attribute 'trigger_2_detector' is required and cannot be blank", + "attribute 'verification_confidence' must be between 0.0 and 1.0", + "attribute 'verification_labels' cannot be empty", + "attribute 'verification_detector' is required and cannot be blank", + ] + + @pytest.mark.asyncio + async def test_validate(self, mock_dependencies): + valid_config, _ = mock_dependencies + response = VerificationSystem.validate(config=valid_config) + assert response == [ + "mock_trigger1", + "mock_trigger2", + "mock_camera", + "mock_verification", + ] + + @pytest.mark.asyncio + async def test_validation_with_invalid_camera_name(self, mock_dependencies): + invalid_config, _ = mock_dependencies + invalid_config.attributes.fields["camera_name"].string_value = "" + with pytest.raises(Exception) as excinfo: + VerificationSystem.validate(config=invalid_config) + assert ( + str(excinfo.value) + == "attribute 'camera_name' is required and cannot be blank" + ) + + @pytest.mark.asyncio + async def test_validation_with_invalid_confidence(self, mock_dependencies): + invalid_config, _ = mock_dependencies + invalid_config.attributes.fields["trigger_1_confidence"].number_value = 1.1 + with pytest.raises(Exception) as excinfo: + VerificationSystem.validate(config=invalid_config) + assert ( + str(excinfo.value) + == "attribute 'trigger_1_confidence' must be between 0.0 and 1.0" + ) + + def test_verification_system_initialization(self, model): + system = model + + assert system.camera_name == "mock_camera" + assert system.trigger_1_confidence == 0.8 + assert system.trigger_2_confidence == 0.6 + assert system.verification_confidence == 0.9 + assert system.countdown_time_s == 15 + assert system.disarmed_time_s == 20 + assert system.alarm_time_s == 25 + assert system.disable_alarm is True + + def test_reconfigure(self, model, mock_dependencies): + system = model + config, dependencies = mock_dependencies + system.reconfigure(config, dependencies) + + assert system.alarm_state == AlarmState.TRIGGER_1 + assert system.camera_name == "mock_camera" + assert system.camera == dependencies[Camera.get_resource_name("mock_camera")] + assert ( + system.trigger_1_detector + == dependencies[Vision.get_resource_name("mock_trigger1")] + ) + assert system.trigger_1_labels == ["label1"] + assert system.trigger_1_confidence == 0.8 + assert ( + system.trigger_2_detector + == dependencies[Vision.get_resource_name("mock_trigger2")] + ) + assert system.trigger_2_labels == ["label2"] + assert system.trigger_2_confidence == 0.6 + assert ( + system.verification_detector + == dependencies[Vision.get_resource_name("mock_verification")] + ) + assert system.verification_labels == ["verify_label"] + assert system.verification_confidence == 0.9 + assert system.countdown_time_s == 15 + assert system.disarmed_time_s == 20 + assert system.alarm_time_s == 25 + assert system.disable_alarm is True + + assert system.start_time is not None + assert system.last_disarmed_by == "" + assert system.detect_count == 0 + assert system.detect_limit == 10 + + def test_reconfigure_with_empty(self, mock_dependencies): + config, dependencies = mock_dependencies + + # make them empty + config.attributes.fields["trigger_1_confidence"].number_value = None + config.attributes.fields["trigger_2_confidence"].number_value = None + config.attributes.fields["verification_confidence"].number_value = None + config.attributes.fields["countdown_time_s"].number_value = None + config.attributes.fields["disarmed_time_s"].number_value = None + config.attributes.fields["alarm_time_s"].number_value = None + config.attributes.fields["disable_alarm"].bool_value = None + config.attributes.fields["trigger_1_detector"].string_value = "" + config.attributes.fields["trigger_1_labels"].list_value = [] + config.attributes.fields["trigger_2_detector"].string_value = "" + config.attributes.fields["trigger_2_labels"].list_value = [] + config.attributes.fields["verification_detector"].string_value = "" + config.attributes.fields["verification_labels"].list_value = [] + + dependencies[Vision.get_resource_name("")] = None + + system = VerificationSystem.new(config, dependencies) + system.reconfigure(config, dependencies) + + # check if defaults are set correctly + assert system.alarm_state == AlarmState.TRIGGER_1 + assert system.camera_name == "mock_camera" + assert system.camera == dependencies[Camera.get_resource_name("mock_camera")] + assert system.trigger_1_detector is None + assert system.trigger_1_labels == [] + assert system.trigger_1_confidence == 0.2 + assert system.trigger_2_detector is None + assert system.trigger_2_labels == [] + assert system.trigger_2_confidence == 0.5 + assert system.verification_detector is None + assert system.verification_labels == [] + assert system.verification_confidence == 0.8 + assert system.countdown_time_s == 20 + assert system.disarmed_time_s == 10 + assert system.alarm_time_s == 10 + assert system.disable_alarm is False + + assert system.start_time is not None + assert system.last_disarmed_by == "" + assert system.detect_count == 0 + assert system.detect_limit == 10 + + @pytest.mark.asyncio + async def test_get_properties(self, model): + system = model + + properties = await system.get_properties() + + assert properties.classifications_supported is True + assert properties.detections_supported is False + assert properties.object_point_clouds_supported is False + + @pytest.mark.asyncio + async def test_capture_all_from_camera(self, model): + system = model + + mock_camera = system.camera + mock_camera.get_image = AsyncMock(return_value=b"mock_image_data") + system.get_classifications = AsyncMock(return_value=["class1", "class2"]) + + result = await system.capture_all_from_camera( + camera_name="mock_camera", + return_image=True, + return_classifications=True, + return_detections=True, + return_object_point_clouds=True, + ) + + assert result.image == b"mock_image_data" + assert result.classifications == ["class1", "class2"] + assert result.detections == [] + assert result.objects == [] + + @pytest.mark.asyncio + async def test_get_classifications_countdown(self, model): + system = model + + system.process_image = AsyncMock(return_value="") + + system.alarm_state = AlarmState.COUNTDOWN + system.start_time = time.time() - 5 + system.countdown_time_s = 20 + + result = await system.get_classifications(b"mock_image_data", 1) + + assert result == [{"class_name": "COUNTDOWN: 15 s remain", "confidence": 1.0}] + + @pytest.mark.asyncio + async def test_get_classifications_disarmed(self, model): + system = model + + system.process_image = AsyncMock(return_value="user1") + + system.alarm_state = AlarmState.DISARMED + system.start_time = time.time() - 10 + system.disarmed_time_s = 20 + + result = await system.get_classifications(b"mock_image_data", 1) + + assert result == [ + {"class_name": "DISARMED by user1: 10 s remain", "confidence": 1.0} + ] + + @pytest.mark.asyncio + async def test_get_classifications_no_last_disarmed_by(self, model): + system = model + + system.process_image = AsyncMock(return_value="") + + system.alarm_state = AlarmState.TRIGGER_1 + + result = await system.get_classifications(b"mock_image_data", 1) + + assert result == [{"class_name": "TRIGGER_1", "confidence": 1.0}] + + @pytest.mark.asyncio + async def test_process_image_trigger_1_transitions(self, model): + system = model + + # case 1: with detector + mock_detector = system.trigger_1_detector + mock_detector.get_detections = AsyncMock( + return_value=[MagicMock(class_name="label1", confidence=0.9)] + ) + + system.alarm_state = AlarmState.TRIGGER_1 + system.trigger_1_labels = ["label1"] + system.trigger_1_confidence = 0.8 + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.TRIGGER_2 + assert result == "" + + # case 2: without detector + system.alarm_state = AlarmState.TRIGGER_1 + system.trigger_1_detector = None + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.TRIGGER_2 + assert result == "" + + @pytest.mark.asyncio + async def test_process_image_trigger_2_transitions(self, model): + + # case 1: TRIGGER_2 to COUNTDOWN + system = model + + mock_detector = system.trigger_2_detector + mock_detector.get_detections = AsyncMock( + return_value=[MagicMock(class_name="label2", confidence=0.7)] + ) + + system.alarm_state = AlarmState.TRIGGER_2 + system.trigger_2_labels = ["label2"] + system.trigger_2_confidence = 0.6 + system.disable_alarm = False + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.COUNTDOWN + assert system.detect_count == 0 + assert result == "" + + # case 2: TRIGGER_2 to TRIGGER_1 + system = model + + mock_detector.get_detections = AsyncMock( + return_value=[MagicMock(class_name="other_label", confidence=0.7)] + ) + + system.alarm_state = AlarmState.TRIGGER_2 + system.trigger_2_labels = ["label2"] + system.trigger_2_confidence = 0.6 + system.detect_count = 10 + system.detect_limit = 10 + system.disable_alarm = False + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.TRIGGER_1 + assert system.detect_count == 0 + assert result == "" + + @pytest.mark.asyncio + async def test_process_image_countdown_transitions(self, model): + system = model + + # case 1: COUNTDOWN to ALARM + system.alarm_state = AlarmState.COUNTDOWN + system.start_time = time.time() - 16 # elapsed time + system.countdown_time_s = 15 + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.ALARM + assert result == "" + + # reset + system.alarm_state = AlarmState.COUNTDOWN + system.start_time = time.time() - 5 # elapsed time + + # Scenario 2: COUNTDOWN to DISARMED + mock_verification_detector = system.verification_detector + mock_verification_detector.get_detections = AsyncMock( + return_value=[MagicMock(class_name="verify_label", confidence=1.0)] + ) + + system.verification_labels = ["verify_label"] + system.verification_confidence = 0.9 + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.DISARMED + assert result == "verify_label" + + @pytest.mark.asyncio + async def test_process_image_alarm_to_trigger_1(self, model): + system = model + + system.alarm_state = AlarmState.ALARM + system.start_time = time.time() - 26 # elapsed time + system.alarm_time_s = 25 + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.TRIGGER_1 + assert result == "" + + @pytest.mark.asyncio + async def test_process_image_disarmed_to_trigger_1(self, model): + system = model + + system.alarm_state = AlarmState.DISARMED + system.start_time = time.time() - 21 # elapsed time + system.disarmed_time_s = 20 + + result = await system.process_image(b"mock_image_data") + + assert system.alarm_state == AlarmState.TRIGGER_1 + assert result == "" diff --git a/src/verificationclassifier.py b/src/verificationclassifier.py index 982a056..6d503a9 100755 --- a/src/verificationclassifier.py +++ b/src/verificationclassifier.py @@ -32,9 +32,9 @@ class AlarmState(Enum): class VerificationSystem(Vision, Reconfigurable): - - MODEL: ClassVar[Model] = Model(ModelFamily( - "viam-labs", "classifier"), "verification-system") + MODEL: ClassVar[Model] = Model( + ModelFamily("viam-labs", "classifier"), "verification-system" + ) camera_name: str camera: Camera trigger_1_detector: str # @@ -56,7 +56,9 @@ class VerificationSystem(Vision, Reconfigurable): # Constructor @classmethod - def new(cls, config: ComponentConfig, dependencies: Mapping[ResourceName, ResourceBase]) -> Self: + def new( + cls, config: ComponentConfig, dependencies: Mapping[ResourceName, ResourceBase] + ) -> Self: my_class = cls(config.name) my_class.reconfigure(config, dependencies) return my_class @@ -65,40 +67,46 @@ def new(cls, config: ComponentConfig, dependencies: Mapping[ResourceName, Resour @classmethod def validate(cls, config: ComponentConfig): # verify camera - camera_name = config.attributes.fields["camera_name"].string_value.strip( - ) + camera_name = config.attributes.fields["camera_name"].string_value.strip() if camera_name == "": - raise Exception( - "attribute 'camera_name' is required and cannot be blank") + raise Exception("attribute 'camera_name' is required and cannot be blank") # verify trigger 1 detector if config.attributes.fields["trigger_1_confidence"].number_value > 1.0: raise Exception( - "attribute 'trigger_1_confidence' must be between 0.0 and 1.0") + "attribute 'trigger_1_confidence' must be between 0.0 and 1.0" + ) # verify trigger 2 detector if config.attributes.fields["trigger_2_confidence"].number_value > 1.0: raise Exception( - "attribute 'trigger_2_confidence' must be between 0.0 and 1.0") + "attribute 'trigger_2_confidence' must be between 0.0 and 1.0" + ) if len(config.attributes.fields["trigger_2_labels"].list_value) == 0: raise Exception("attribute 'trigger_2_labels' cannot be empty") - trigger_2_name = config.attributes.fields["trigger_2_detector"].string_value.strip( - ) + trigger_2_name = config.attributes.fields[ + "trigger_2_detector" + ].string_value.strip() if trigger_2_name == "": raise Exception( - "attribute 'trigger_2_detector' is required and cannot be blank") + "attribute 'trigger_2_detector' is required and cannot be blank" + ) # verify verification module if config.attributes.fields["verification_confidence"].number_value > 1.0: raise Exception( - "attribute 'verification_confidence' must be between 0.0 and 1.0") + "attribute 'verification_confidence' must be between 0.0 and 1.0" + ) if len(config.attributes.fields["verification_labels"].list_value) == 0: raise Exception("attribute 'verification_labels' cannot be empty") - verification_name = config.attributes.fields["verification_detector"].string_value.strip( - ) + verification_name = config.attributes.fields[ + "verification_detector" + ].string_value.strip() if verification_name == "": raise Exception( - "attribute 'verification_detector' is required and cannot be blank") + "attribute 'verification_detector' is required and cannot be blank" + ) # return dependencies - trigger_1_name = config.attributes.fields["trigger_1_detector"].string_value.strip( - ) + trigger_1_name = config.attributes.fields[ + "trigger_1_detector" + ].string_value.strip() if trigger_1_name == "": return [trigger_2_name, camera_name, verification_name] else: @@ -107,40 +115,66 @@ def validate(cls, config: ComponentConfig): return [trigger_1_name, trigger_2_name, camera_name, verification_name] # Handles attribute reconfiguration - def reconfigure(self, config: ComponentConfig, dependencies: Mapping[ResourceName, ResourceBase]): + def reconfigure( + self, config: ComponentConfig, dependencies: Mapping[ResourceName, ResourceBase] + ): self.alarm_state = AlarmState.TRIGGER_1 - self.camera_name = config.attributes.fields["camera_name"].string_value.strip( - ) + self.camera_name = config.attributes.fields["camera_name"].string_value.strip() self.camera = dependencies[Camera.get_resource_name(self.camera_name)] # the 1st trigger self.trigger_1_detector = None - trigger_1_name = config.attributes.fields["trigger_1_detector"].string_value.strip( - ) + trigger_1_name = config.attributes.fields[ + "trigger_1_detector" + ].string_value.strip() if trigger_1_name != "": - self.trigger_1_detector = dependencies[Vision.get_resource_name( - trigger_1_name)] + self.trigger_1_detector = dependencies[ + Vision.get_resource_name(trigger_1_name) + ] self.trigger_1_labels = config.attributes.fields["trigger_1_labels"].list_value - self.trigger_1_confidence = config.attributes.fields["trigger_1_confidence"].number_value or 0.2 - # the 2nd trigger - trigger_2_name = config.attributes.fields["trigger_2_detector"].string_value.strip( + self.trigger_1_confidence = ( + config.attributes.fields["trigger_1_confidence"].number_value or 0.2 ) - self.trigger_2_detector = dependencies[Vision.get_resource_name(trigger_2_name)] + # the 2nd trigger + self.trigger_2_detector = None + trigger_2_name = config.attributes.fields[ + "trigger_2_detector" + ].string_value.strip() + if trigger_2_name != "": + self.trigger_2_detector = dependencies[ + Vision.get_resource_name(trigger_2_name) + ] self.trigger_2_labels = config.attributes.fields["trigger_2_labels"].list_value - self.trigger_2_confidence = config.attributes.fields[ - "trigger_2_confidence"].number_value or 0.5 + self.trigger_2_confidence = ( + config.attributes.fields["trigger_2_confidence"].number_value or 0.5 + ) + # the verification module - verification_name = config.attributes.fields["verification_detector"].string_value.strip( + self.verification_detector = None + verification_name = config.attributes.fields[ + "verification_detector" + ].string_value.strip() + if verification_name != "": + self.verification_detector = dependencies[ + Vision.get_resource_name(verification_name) + ] + self.verification_labels = config.attributes.fields[ + "verification_labels" + ].list_value + self.verification_confidence = ( + config.attributes.fields["verification_confidence"].number_value or 0.8 ) - self.verification_detector = dependencies[Vision.get_resource_name( - verification_name)] - self.verification_labels = config.attributes.fields["verification_labels"].list_value - self.verification_confidence = config.attributes.fields[ - "verification_confidence"].number_value or 0.8 + # the timing - self.countdown_time_s = config.attributes.fields["countdown_time_s"].number_value or 20 - self.disarmed_time_s = config.attributes.fields["disarmed_time_s"].number_value or 10 + self.countdown_time_s = ( + config.attributes.fields["countdown_time_s"].number_value or 20 + ) + self.disarmed_time_s = ( + config.attributes.fields["disarmed_time_s"].number_value or 10 + ) self.alarm_time_s = config.attributes.fields["alarm_time_s"].number_value or 10 - self.disable_alarm = config.attributes.fields["disable_alarm"].bool_value or False + self.disable_alarm = ( + config.attributes.fields["disable_alarm"].bool_value or False + ) self.start_time = time.time() self.last_disarmed_by = "" self.detect_count = 0 @@ -161,14 +195,15 @@ async def get_object_point_clouds(self): return async def get_properties( - self, - *, - extra: Optional[Mapping[str, Any]] = None, - timeout: Optional[float] = None) -> Vision.Properties: + self, + *, + extra: Optional[Mapping[str, Any]] = None, + timeout: Optional[float] = None, + ) -> Vision.Properties: return Vision.Properties( - classifications_supported=True, - detections_supported=False, - object_point_clouds_supported=False, + classifications_supported=True, + detections_supported=False, + object_point_clouds_supported=False, ) async def capture_all_from_camera( @@ -185,7 +220,8 @@ async def capture_all_from_camera( result = CaptureAllResult() if camera_name != self.camera_name: raise Exception( - f"camera {camera_name} was not declared in the camera_name dependency") + f"camera {camera_name} was not declared in the camera_name dependency" + ) cam_image = await self.camera.get_image(mime_type="image/jpeg") if return_image: result.image = cam_image @@ -198,27 +234,31 @@ async def capture_all_from_camera( result.objects = [] return result - async def get_classifications_from_camera(self, - camera_name: str, - count: int, - *, - extra: Optional[Dict[str, - Any]] = None, - timeout: Optional[float] = None, - **kwargs) -> List[Classification]: + async def get_classifications_from_camera( + self, + camera_name: str, + count: int, + *, + extra: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> List[Classification]: if camera_name != self.camera_name: raise Exception( - f"camera {camera_name} was not declared in the camera_name dependency") + f"camera {camera_name} was not declared in the camera_name dependency" + ) cam_image = await self.camera.get_image(mime_type="image/jpeg") return await self.get_classifications(cam_image, 1) - async def get_classifications(self, - image: ViamImage, - count: int, - *, - extra: Optional[Dict[str, Any]] = None, - timeout: Optional[float] = None, - **kwargs) -> List[Classification]: + async def get_classifications( + self, + image: ViamImage, + count: int, + *, + extra: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + **kwargs, + ) -> List[Classification]: last_disarmed_by = await self.process_image(image) if last_disarmed_by != "": self.last_disarmed_by = last_disarmed_by @@ -226,12 +266,14 @@ async def get_classifications(self, if self.alarm_state is AlarmState.COUNTDOWN: elapsed_time = time.time() - self.start_time time_remaining = self.countdown_time_s - elapsed_time - class_name = class_name + \ - f": {time_remaining:.0f} s remain" + class_name = class_name + f": {time_remaining:.0f} s remain" if self.alarm_state is AlarmState.DISARMED: elapsed_time = time.time() - self.start_time time_remaining = self.disarmed_time_s - elapsed_time - class_name = class_name + f" by {self.last_disarmed_by}: {time_remaining:.0f} s remain" + class_name = ( + class_name + + f" by {self.last_disarmed_by}: {time_remaining:.0f} s remain" + ) classifications = [{"class_name": class_name, "confidence": 1.0}] return classifications @@ -240,15 +282,21 @@ async def process_image(self, image: ViamImage): if self.trigger_1_detector is None: self.alarm_state = AlarmState.TRIGGER_2 # go straight to trigger 2 else: - detections = await self.trigger_1_detector.get_detections( - image) + detections = await self.trigger_1_detector.get_detections(image) for detection in detections: - if detection.class_name in self.trigger_1_labels and detection.confidence > self.trigger_1_confidence: + if ( + detection.class_name in self.trigger_1_labels + and detection.confidence > self.trigger_1_confidence + ): self.alarm_state = AlarmState.TRIGGER_2 if self.alarm_state is AlarmState.TRIGGER_2: detections = await self.trigger_2_detector.get_detections(image) for detection in detections: - if detection.class_name in self.trigger_2_labels and detection.confidence > self.trigger_2_confidence and not self.disable_alarm: + if ( + detection.class_name in self.trigger_2_labels + and detection.confidence > self.trigger_2_confidence + and not self.disable_alarm + ): self.start_time = time.time() self.alarm_state = AlarmState.COUNTDOWN self.detect_count = 0 @@ -265,7 +313,10 @@ async def process_image(self, image: ViamImage): return "" detections = await self.verification_detector.get_detections(image) for detection in detections: - if detection.class_name in self.verification_labels and detection.confidence > self.verification_confidence: + if ( + detection.class_name in self.verification_labels + and detection.confidence > self.verification_confidence + ): self.start_time = time.time() self.alarm_state = AlarmState.DISARMED return detection.class_name