diff --git a/app/models/__init__.py b/app/models/__init__.py index b77f2b26df..6709e8831d 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -14,7 +14,7 @@ class JSONModelMeta(SerialisedModelMeta, ABCMeta): @total_ordering -class JSONModel(SerialisedModel, ABC, metaclass=JSONModelMeta): +class StrictJSONModel(SerialisedModel, ABC, metaclass=JSONModelMeta): @property @abstractmethod def __sort_attribute__(self): @@ -43,6 +43,14 @@ def __eq__(self, other): def __hash__(self): return hash(self.id) + def _get_by_id(self, things, id): + try: + return next(thing for thing in things if thing["id"] == str(id)) + except StopIteration: + abort(404) + + +class JSONModel(StrictJSONModel): def __init__(self, _dict): # in the case of a bad request _dict may be `None` self._dict = _dict or {} @@ -54,12 +62,6 @@ def __init__(self, _dict): def __bool__(self): return self._dict != {} - def _get_by_id(self, things, id): - try: - return next(thing for thing in things if thing["id"] == str(id)) - except StopIteration: - abort(404) - class ModelList(SerialisedModelCollection): @property diff --git a/app/models/api_key.py b/app/models/api_key.py index 8f279b60f8..83b9fac075 100644 --- a/app/models/api_key.py +++ b/app/models/api_key.py @@ -3,11 +3,11 @@ from flask import abort -from app.models import JSONModel, ModelList +from app.models import ModelList, StrictJSONModel from app.notify_client.api_key_api_client import api_key_api_client -class APIKey(JSONModel): +class APIKey(StrictJSONModel): created_at: datetime created_by: Any expiry_date: datetime diff --git a/app/models/contact_list.py b/app/models/contact_list.py index 5e29142c88..365f37e272 100644 --- a/app/models/contact_list.py +++ b/app/models/contact_list.py @@ -8,7 +8,7 @@ from notifications_utils.recipients import RecipientCSV from werkzeug.utils import cached_property -from app.models import JSONModel, ModelList +from app.models import ModelList, StrictJSONModel from app.models.job import PaginatedJobsAndScheduledJobs from app.notify_client.contact_list_api_client import contact_list_api_client from app.s3_client.s3_csv_client import ( @@ -20,7 +20,7 @@ from app.utils.templates import get_sample_template -class ContactList(JSONModel): +class ContactList(StrictJSONModel): id: Any created_at: datetime created_by: Any diff --git a/app/models/letter_rates.py b/app/models/letter_rates.py index c33d25b00f..39c85e2ea9 100644 --- a/app/models/letter_rates.py +++ b/app/models/letter_rates.py @@ -1,11 +1,11 @@ from datetime import datetime from typing import Any -from app.models import JSONModel, ModelList +from app.models import ModelList, StrictJSONModel from app.notify_client.letter_rate_api_client import letter_rate_api_client -class LetterRate(JSONModel): +class LetterRate(StrictJSONModel): sheet_count: int rate: float post_class: Any diff --git a/app/models/report_request.py b/app/models/report_request.py index c488e6bb73..55b58ba909 100644 --- a/app/models/report_request.py +++ b/app/models/report_request.py @@ -5,11 +5,11 @@ from notifications_utils.s3 import s3download from app import report_request_api_client -from app.models import JSONModel +from app.models import StrictJSONModel from app.s3_client import check_s3_object_exists -class ReportRequest(JSONModel): +class ReportRequest(StrictJSONModel): id: Any user_id: Any service_id: Any diff --git a/app/models/service.py b/app/models/service.py index 1913d525df..e6a00e3657 100644 --- a/app/models/service.py +++ b/app/models/service.py @@ -13,7 +13,7 @@ SIGN_IN_METHOD_TEXT, SIGN_IN_METHOD_TEXT_OR_EMAIL, ) -from app.models import JSONModel +from app.models import JSONModel, StrictJSONModel from app.models.api_key import APIKeys from app.models.branding import EmailBranding, LetterBranding from app.models.contact_list import ContactLists @@ -680,7 +680,7 @@ class Services(SerialisedModelCollection): model = Service -class ServiceJoinRequest(JSONModel): +class ServiceJoinRequest(StrictJSONModel): id: Any requester: Any service_id: Any @@ -690,8 +690,6 @@ class ServiceJoinRequest(JSONModel): reason: str status: str contacted_service_users: list[str] - requested_service: Any - permissions: list[str] __sort_attribute__ = "id" diff --git a/app/models/sms_rate.py b/app/models/sms_rate.py index 015955a7b2..3d4f67b471 100644 --- a/app/models/sms_rate.py +++ b/app/models/sms_rate.py @@ -1,11 +1,11 @@ from datetime import datetime from app.formatters import format_pennies_as_currency -from app.models import JSONModel +from app.models import StrictJSONModel from app.notify_client.sms_rate_client import sms_rate_api_client -class SMSRate(JSONModel): +class SMSRate(StrictJSONModel): rate: float valid_from: datetime diff --git a/tests/__init__.py b/tests/__init__.py index 18d6374a62..0c4465c41f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -375,8 +375,15 @@ def template_version_json(service_id, id_, created_by, version=1, created_at=Non return template -def api_key_json(id_, name, expiry_date=None, key_type="normal"): - return {"id": id_, "name": name, "expiry_date": expiry_date, "key_type": key_type} +def api_key_json(id_, name, expiry_date=None, created_at=None, created_by=None, key_type="normal"): + return { + "id": id_, + "name": name, + "expiry_date": expiry_date, + "key_type": key_type, + "created_at": created_at, + "created_by": created_by, + } def invite_json( diff --git a/tests/app/main/forms/test_create_key_form.py b/tests/app/main/forms/test_create_key_form.py index b7ea82b3af..3e9b412e36 100644 --- a/tests/app/main/forms/test_create_key_form.py +++ b/tests/app/main/forms/test_create_key_form.py @@ -1,8 +1,11 @@ +from uuid import uuid4 + import pytest from werkzeug.datastructures import MultiDict from app.main.forms import CreateKeyForm from app.models.api_key import APIKeys +from tests import api_key_json @pytest.mark.parametrize( @@ -22,14 +25,16 @@ def test_return_validation_error_when_key_name_exists( "app.models.api_key.api_key_api_client.get_api_keys", return_value={ "apiKeys": [ - { - "name": "some key", - "expiry_date": expiry_date, - }, - { - "name": "another key", - "expiry_date": None, - }, + api_key_json( + id_=str(uuid4()), + name="some key", + expiry_date=expiry_date, + ), + api_key_json( + id_=str(uuid4()), + name="another key", + expiry_date=None, + ), ] }, ) diff --git a/tests/app/models/test_base_model.py b/tests/app/models/test_base_model.py index 06eff72b67..8cff6876fb 100644 --- a/tests/app/models/test_base_model.py +++ b/tests/app/models/test_base_model.py @@ -1,6 +1,6 @@ import pytest -from app.models import JSONModel +from app.models import JSONModel, StrictJSONModel def test_looks_up_from_dict(): @@ -58,6 +58,17 @@ class Custom(JSONModel): assert str(e.value) == "'Custom' object has no attribute 'foo'" +def test_strict_model_raises_keyerror_if_item_missing_from_dict_on_instantiation(): + class Custom(StrictJSONModel): + foo: str + __sort_attribute__ = "foo" + + with pytest.raises(KeyError) as e: + Custom({}) + + assert str(e.value) == "'foo'" + + @pytest.mark.parametrize( "json_response", ( diff --git a/tests/app/models/test_contact_list.py b/tests/app/models/test_contact_list.py index 0b9d8ec79b..e3c2e3cfa1 100644 --- a/tests/app/models/test_contact_list.py +++ b/tests/app/models/test_contact_list.py @@ -1,9 +1,10 @@ from app.models.contact_list import ContactList from app.models.job import PaginatedJobs +from tests import contact_list_json def test_get_jobs(mock_get_jobs): - contact_list = ContactList({"id": "a", "service_id": "b"}) + contact_list = ContactList(contact_list_json(id_="a", service_id="b")) assert isinstance(contact_list.get_jobs(page=123), PaginatedJobs) # mock_get_jobs mocks the underlying API client method, not # contact_list.get_jobs diff --git a/tests/conftest.py b/tests/conftest.py index 9cd5cbc1cc..d313b66b06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1909,13 +1909,13 @@ def _create( row_count, template_type, ): - return { - "service_id": service_id, - "upload_id": upload_id, - "original_file_name": original_file_name, - "row_count": row_count, - "template_type": template_type, - } + return contact_list_json( + id_=upload_id, + service_id=service_id, + original_file_name=original_file_name, + row_count=row_count, + template_type=template_type, + ) return mocker.patch( "app.contact_list_api_client.create_contact_list",