From 2813e70f078f31cf6fbb759c2152e16b7e78071c Mon Sep 17 00:00:00 2001 From: Micah Denbraver Date: Sun, 5 May 2024 11:48:43 -0700 Subject: [PATCH 1/6] remove undocumented 'creds' parameter --- push_notifications/apns.py | 37 ++++++++++++++++++------------------ push_notifications/models.py | 8 ++++---- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/push_notifications/apns.py b/push_notifications/apns.py index 04064872..5c6cc58b 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -16,19 +16,18 @@ from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError -def _apns_create_socket(creds=None, application_id=None): - if creds is None: - if not get_manager().has_auth_token_creds(application_id): - cert = get_manager().get_apns_certificate(application_id) - creds = apns2_credentials.CertificateCredentials(cert) - else: - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - creds = creds or apns2_credentials.TokenCredentials(keyPath, keyId, teamId) +def _apns_create_socket(application_id=None): + if not get_manager().has_auth_token_creds(application_id): + cert = get_manager().get_apns_certificate(application_id) + creds = apns2_credentials.CertificateCredentials(cert) + else: + keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + creds = apns2_credentials.TokenCredentials(keyPath, keyId, teamId) client = apns2_client.APNsClient( creds, use_sandbox=get_manager().get_apns_use_sandbox(application_id), @@ -59,9 +58,9 @@ def _apns_prepare( def _apns_send( - registration_id, alert, batch=False, application_id=None, creds=None, **kwargs + registration_id, alert, batch=False, application_id=None, **kwargs ): - client = _apns_create_socket(creds=creds, application_id=application_id) + client = _apns_create_socket(application_id=application_id) notification_kwargs = {} @@ -97,7 +96,7 @@ def _apns_send( ) -def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs): +def apns_send_message(registration_id, alert, application_id=None, **kwargs): """ Sends an APNS notification to a single registration_id. This will send the notification as form data. @@ -112,7 +111,7 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, * try: _apns_send( registration_id, alert, application_id=application_id, - creds=creds, **kwargs + **kwargs ) except apns2_errors.APNsException as apns2_exception: if isinstance(apns2_exception, apns2_errors.Unregistered): @@ -124,7 +123,7 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, * def apns_send_bulk_message( - registration_ids, alert, application_id=None, creds=None, **kwargs + registration_ids, alert, application_id=None, **kwargs ): """ Sends an APNS notification to one or more registration_ids. @@ -137,7 +136,7 @@ def apns_send_bulk_message( results = _apns_send( registration_ids, alert, batch=True, application_id=application_id, - creds=creds, **kwargs + **kwargs ) inactive_tokens = [token for token, result in results.items() if result == "Unregistered"] models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(active=False) diff --git a/push_notifications/models.py b/push_notifications/models.py index 33f44205..4e6c3924 100644 --- a/push_notifications/models.py +++ b/push_notifications/models.py @@ -133,7 +133,7 @@ def get_queryset(self): class APNSDeviceQuerySet(models.query.QuerySet): - def send_message(self, message, creds=None, **kwargs): + def send_message(self, message, **kwargs): if self.exists(): from .apns import apns_send_bulk_message @@ -146,7 +146,7 @@ def send_message(self, message, creds=None, **kwargs): ) r = apns_send_bulk_message( registration_ids=reg_ids, alert=message, application_id=app_id, - creds=creds, **kwargs + **kwargs ) if hasattr(r, "keys"): res += [r] @@ -169,13 +169,13 @@ class APNSDevice(Device): class Meta: verbose_name = _("APNS device") - def send_message(self, message, creds=None, **kwargs): + def send_message(self, message, **kwargs): from .apns import apns_send_message return apns_send_message( registration_id=self.registration_id, alert=message, - application_id=self.application_id, creds=creds, + application_id=self.application_id, **kwargs ) From 6b4d226bd4cfd2be45ecccfb85b4daeabe22fcbe Mon Sep 17 00:00:00 2001 From: Micah Denbraver Date: Sun, 5 May 2024 11:51:10 -0700 Subject: [PATCH 2/6] convert connection creation to a context manager --- push_notifications/apns.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/push_notifications/apns.py b/push_notifications/apns.py index 5c6cc58b..a9eb3c53 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -4,6 +4,7 @@ https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html """ +import contextlib import time from apns2 import client as apns2_client @@ -16,6 +17,7 @@ from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError +@contextlib.contextmanager def _apns_create_socket(application_id=None): if not get_manager().has_auth_token_creds(application_id): cert = get_manager().get_apns_certificate(application_id) @@ -34,7 +36,7 @@ def _apns_create_socket(application_id=None): use_alternative_port=get_manager().get_apns_use_alternative_port(application_id) ) client.connect() - return client + yield client def _apns_prepare( @@ -60,8 +62,6 @@ def _apns_prepare( def _apns_send( registration_id, alert, batch=False, application_id=None, **kwargs ): - client = _apns_create_socket(application_id=application_id) - notification_kwargs = {} # if expiration isn"t specified use 1 month from now @@ -78,23 +78,24 @@ def _apns_send( notification_kwargs["collapse_id"] = kwargs.pop("collapse_id", None) - if batch: - data = [apns2_client.Notification( - token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id] - # returns a dictionary mapping each token to its result. That - # result is either "Success" or the reason for the failure. - return client.send_notification_batch( - data, get_manager().get_apns_topic(application_id=application_id), + with _apns_create_socket(application_id=application_id) as client: + if batch: + data = [apns2_client.Notification( + token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id] + # returns a dictionary mapping each token to its result. That + # result is either "Success" or the reason for the failure. + return client.send_notification_batch( + data, get_manager().get_apns_topic(application_id=application_id), + **notification_kwargs + ) + + data = _apns_prepare(registration_id, alert, **kwargs) + client.send_notification( + registration_id, data, + get_manager().get_apns_topic(application_id=application_id), **notification_kwargs ) - data = _apns_prepare(registration_id, alert, **kwargs) - client.send_notification( - registration_id, data, - get_manager().get_apns_topic(application_id=application_id), - **notification_kwargs - ) - def apns_send_message(registration_id, alert, application_id=None, **kwargs): """ From fef1bab110ae42574eff67844a15a1dc08ffee1d Mon Sep 17 00:00:00 2001 From: Micah Denbraver Date: Sun, 5 May 2024 17:35:41 -0700 Subject: [PATCH 3/6] use aioapns --- .pre-commit-config.yaml | 2 + push_notifications/apns.py | 310 +++++++++++++++++++------------- push_notifications/models.py | 1 + setup.cfg | 4 +- tests/test_apns_models.py | 225 ++++++++++++----------- tests/test_apns_push_payload.py | 217 ++++++++++++---------- tox.ini | 4 +- 7 files changed, 438 insertions(+), 325 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b0f0162..50575efa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,5 @@ repos: rev: v3.15.2 hooks: - id: pyupgrade + args: + - --keep-mock # for AsyncMock in 3.7 diff --git a/push_notifications/apns.py b/push_notifications/apns.py index a9eb3c53..0e3777b9 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -4,141 +4,207 @@ https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html """ +import asyncio import contextlib +import tempfile import time -from apns2 import client as apns2_client -from apns2 import credentials as apns2_credentials -from apns2 import errors as apns2_errors -from apns2 import payload as apns2_payload +import aioapns +from aioapns.common import APNS_RESPONSE_CODE, PRIORITY_HIGH, PRIORITY_NORMAL +from asgiref.sync import async_to_sync from . import models from .conf import get_manager -from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError +from .exceptions import APNSError, APNSServerError, APNSUnsupportedPriority + + +SUCCESS_RESULT = "Success" +UNREGISTERED_RESULT = "Unregistered" @contextlib.contextmanager -def _apns_create_socket(application_id=None): - if not get_manager().has_auth_token_creds(application_id): - cert = get_manager().get_apns_certificate(application_id) - creds = apns2_credentials.CertificateCredentials(cert) - else: - keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - creds = apns2_credentials.TokenCredentials(keyPath, keyId, teamId) - client = apns2_client.APNsClient( - creds, - use_sandbox=get_manager().get_apns_use_sandbox(application_id), - use_alternative_port=get_manager().get_apns_use_alternative_port(application_id) - ) - client.connect() - yield client +def _apns_path_for_cert(cert): + if cert is None: + yield None + with tempfile.NamedTemporaryFile("w") as cert_file: + cert_file.write(cert) + cert_file.flush() + yield cert_file.name + + +def _apns_create_client(application_id=None): + cert = None + key_path = None + key_id = None + team_id = None + + if not get_manager().has_auth_token_creds(application_id): + cert = get_manager().get_apns_certificate(application_id) + else: + key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + + with _apns_path_for_cert(cert) as cert_path: + client = aioapns.APNs( + client_cert=cert_path, + key=key_path, + key_id=key_id, + team_id=team_id, + use_sandbox=get_manager().get_apns_use_sandbox(application_id), + ) + + return client def _apns_prepare( - token, alert, application_id=None, badge=None, sound=None, category=None, - content_available=False, action_loc_key=None, loc_key=None, loc_args=[], - extra={}, mutable_content=False, thread_id=None, url_args=None): - if action_loc_key or loc_key or loc_args: - apns2_alert = apns2_payload.PayloadAlert( - body=alert if alert else {}, body_localized_key=loc_key, - body_localized_args=loc_args, action_localized_key=action_loc_key) - else: - apns2_alert = alert - - if callable(badge): - badge = badge(token) - - return apns2_payload.Payload( - alert=apns2_alert, badge=badge, sound=sound, category=category, - url_args=url_args, custom=extra, thread_id=thread_id, - content_available=content_available, mutable_content=mutable_content) - - -def _apns_send( - registration_id, alert, batch=False, application_id=None, **kwargs + token, + alert, + application_id=None, + badge=None, + sound=None, + category=None, + content_available=False, + action_loc_key=None, + loc_key=None, + loc_args=[], + extra={}, + mutable_content=False, + thread_id=None, + url_args=None, ): - notification_kwargs = {} - - # if expiration isn"t specified use 1 month from now - notification_kwargs["expiration"] = kwargs.pop("expiration", None) - if not notification_kwargs["expiration"]: - notification_kwargs["expiration"] = int(time.time()) + 2592000 - - priority = kwargs.pop("priority", None) - if priority: - try: - notification_kwargs["priority"] = apns2_client.NotificationPriority(str(priority)) - except ValueError: - raise APNSUnsupportedPriority("Unsupported priority %d" % (priority)) - - notification_kwargs["collapse_id"] = kwargs.pop("collapse_id", None) - - with _apns_create_socket(application_id=application_id) as client: - if batch: - data = [apns2_client.Notification( - token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id] - # returns a dictionary mapping each token to its result. That - # result is either "Success" or the reason for the failure. - return client.send_notification_batch( - data, get_manager().get_apns_topic(application_id=application_id), - **notification_kwargs - ) - - data = _apns_prepare(registration_id, alert, **kwargs) - client.send_notification( - registration_id, data, - get_manager().get_apns_topic(application_id=application_id), - **notification_kwargs - ) + if action_loc_key or loc_key or loc_args: + alert_payload = { + "body": alert if alert else {}, + "body_localized_key": loc_key, + "body_localized_args": loc_args, + "action_localized_key": action_loc_key, + } + else: + alert_payload = alert + + if callable(badge): + badge = badge(token) + + return { + "alert": alert_payload, + "badge": badge, + "sound": sound, + "category": category, + "url_args": url_args, + "custom": extra, + "thread_id": thread_id, + "content_available": content_available, + "mutable_content": mutable_content, + } + + +@async_to_sync +async def _apns_send( + registration_ids, + alert, + application_id=None, + *, + priority=None, + expiration=None, + collapse_id=None, + **kwargs, +): + """Make calls to APNs for each device token (registration_id) provided. + + Since the underlying library (aioapns) is asynchronous, we are + taking advantage of that here and making the requests in parallel. + """ + client = _apns_create_client(application_id=application_id) + + # if expiration isn't specified use 1 month from now + # converting to ttl for underlying library + if expiration: + time_to_live = expiration - int(time.time()) + else: + time_to_live = 2592000 + + if priority is not None: + if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]: + raise APNSUnsupportedPriority(f"Unsupported priority {priority}") + + # track which device token belongs to each coroutine. + # this allows us to stitch the results back together later + coro_registration_ids = {} + for registration_id in set(registration_ids): + coro = client.send_notification( + aioapns.NotificationRequest( + device_token=registration_id, + message={"aps": _apns_prepare(registration_id, alert, **kwargs)}, + time_to_live=time_to_live, + priority=priority, + collapse_key=collapse_id, + ) + ) + coro_registration_ids[asyncio.create_task(coro)] = registration_id + + # run all of the tasks. this will resolve once all requests are complete + done, _ = await asyncio.wait(coro_registration_ids.keys()) + + # recombine task results with their device tokens + results = {} + for coro in done: + registration_id = coro_registration_ids[coro] + result = await coro + if result.is_successful: + results[registration_id] = SUCCESS_RESULT + else: + results[registration_id] = result.description + + return results def apns_send_message(registration_id, alert, application_id=None, **kwargs): - """ - Sends an APNS notification to a single registration_id. - This will send the notification as form data. - If sending multiple notifications, it is more efficient to use - apns_send_bulk_message() - - Note that if set alert should always be a string. If it is not set, - it won"t be included in the notification. You will need to pass None - to this for silent notifications. - """ - - try: - _apns_send( - registration_id, alert, application_id=application_id, - **kwargs - ) - except apns2_errors.APNsException as apns2_exception: - if isinstance(apns2_exception, apns2_errors.Unregistered): - device = models.APNSDevice.objects.get(registration_id=registration_id) - device.active = False - device.save() - - raise APNSServerError(status=apns2_exception.__class__.__name__) - - -def apns_send_bulk_message( - registration_ids, alert, application_id=None, **kwargs -): - """ - Sends an APNS notification to one or more registration_ids. - The registration_ids argument needs to be a list. - - Note that if set alert should always be a string. If it is not set, - it won"t be included in the notification. You will need to pass None - to this for silent notifications. - """ - - results = _apns_send( - registration_ids, alert, batch=True, application_id=application_id, - **kwargs - ) - inactive_tokens = [token for token, result in results.items() if result == "Unregistered"] - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(active=False) - return results + """ + Sends an APNS notification to a single registration_id. + This will send the notification as form data. + If sending multiple notifications, it is more efficient to use + apns_send_bulk_message() + + Note that if set alert should always be a string. If it is not set, + it won"t be included in the notification. You will need to pass None + to this for silent notifications. + """ + + results = _apns_send( + [registration_id], alert, application_id=application_id, **kwargs + ) + result = results[registration_id] + + if result == SUCCESS_RESULT: + return + if result == UNREGISTERED_RESULT: + models.APNSDevice.objects.filter(registration_id=registration_id).update( + active=False + ) + raise APNSServerError(status=result) + + +def apns_send_bulk_message(registration_ids, alert, application_id=None, **kwargs): + """ + Sends an APNS notification to one or more registration_ids. + The registration_ids argument needs to be a list. + + Note that if set alert should always be a string. If it is not set, + it won"t be included in the notification. You will need to pass None + to this for silent notifications. + """ + + results = _apns_send( + registration_ids, alert, application_id=application_id, **kwargs + ) + inactive_tokens = [ + token for token, result in results.items() if result == UNREGISTERED_RESULT + ] + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + return results diff --git a/push_notifications/models.py b/push_notifications/models.py index 4e6c3924..2bfd83fc 100644 --- a/push_notifications/models.py +++ b/push_notifications/models.py @@ -1,6 +1,7 @@ from django.db import models from django.utils.translation import gettext_lazy as _ +from .apns import apns_send_bulk_message from .fields import HexIntegerField from .settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS diff --git a/setup.cfg b/setup.cfg index 99dfc8c8..180b18cb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,9 +35,9 @@ setup_requires = [options.extras_require] APNS = - apns2>=0.3.0 + aioapns + asgiref>=2.0 importlib-metadata;python_version < "3.8" - Django>=2.2 WP = pywebpush>=1.3.0 diff --git a/tests/test_apns_models.py b/tests/test_apns_models.py index bb1041a7..f8100618 100644 --- a/tests/test_apns_models.py +++ b/tests/test_apns_models.py @@ -1,7 +1,7 @@ -from unittest import mock - -from apns2.client import NotificationPriority -from apns2.errors import BadTopic, PayloadTooLarge, Unregistered +import aioapns +import mock +import pytest +from aioapns.common import APNS_RESPONSE_CODE, NotificationResult from django.conf import settings from django.test import TestCase, override_settings @@ -11,103 +11,120 @@ class APNSModelTestCase(TestCase): - def _create_devices(self, devices): - for device in devices: - APNSDevice.objects.create(registration_id=device) - - @override_settings() - def test_apns_send_bulk_message(self): - self._create_devices(["abc", "def"]) - - # legacy conf manager requires a value - settings.PUSH_NOTIFICATIONS_SETTINGS.update({ - "APNS_CERTIFICATE": "/path/to/apns/certificate.pem" - }) - - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification_batch") as s: - APNSDevice.objects.all().send_message("Hello world", expiration=1) - args, kargs = s.call_args - self.assertEqual(args[0][0].token, "abc") - self.assertEqual(args[0][1].token, "def") - self.assertEqual(args[0][0].payload.alert, "Hello world") - self.assertEqual(args[0][1].payload.alert, "Hello world") - self.assertEqual(kargs["expiration"], 1) - - def test_apns_send_message_extra(self): - self._create_devices(["abc"]) - - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - APNSDevice.objects.get().send_message( - "Hello world", expiration=2, priority=5, extra={"foo": "bar"}) - args, kargs = s.call_args - self.assertEqual(args[0], "abc") - self.assertEqual(args[1].alert, "Hello world") - self.assertEqual(args[1].custom, {"foo": "bar"}) - self.assertEqual(kargs["priority"], NotificationPriority.Delayed) - self.assertEqual(kargs["expiration"], 2) - - def test_apns_send_message(self): - self._create_devices(["abc"]) - - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - APNSDevice.objects.get().send_message("Hello world", expiration=1) - args, kargs = s.call_args - self.assertEqual(args[0], "abc") - self.assertEqual(args[1].alert, "Hello world") - self.assertEqual(kargs["expiration"], 1) - - def test_apns_send_message_to_single_device_with_error(self): - # these errors are device specific, device.active will be set false - devices = ["abc"] - self._create_devices(devices) - - with mock.patch("push_notifications.apns._apns_send") as s: - s.side_effect = Unregistered - device = APNSDevice.objects.get(registration_id="abc") - with self.assertRaises(APNSError) as ae: - device.send_message("Hello World!") - self.assertEqual(ae.exception.status, "Unregistered") - self.assertFalse(APNSDevice.objects.get(registration_id="abc").active) - - def test_apns_send_message_to_several_devices_with_error(self): - # these errors are device specific, device.active will be set false - devices = ["abc", "def", "ghi"] - expected_exceptions_statuses = ["PayloadTooLarge", "BadTopic", "Unregistered"] - self._create_devices(devices) - - with mock.patch("push_notifications.apns._apns_send") as s: - s.side_effect = [PayloadTooLarge, BadTopic, Unregistered] - - for idx, token in enumerate(devices): - device = APNSDevice.objects.get(registration_id=token) - with self.assertRaises(APNSError) as ae: - device.send_message("Hello World!") - self.assertEqual(ae.exception.status, expected_exceptions_statuses[idx]) - - if idx == 2: - self.assertFalse(APNSDevice.objects.get(registration_id=token).active) - else: - self.assertTrue(APNSDevice.objects.get(registration_id=token).active) - - def test_apns_send_message_to_bulk_devices_with_error(self): - # these errors are device specific, device.active will be set false - devices = ["abc", "def", "ghi"] - results = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} - self._create_devices(devices) - - with mock.patch("push_notifications.apns._apns_send") as s: - s.return_value = results - - results = APNSDevice.objects.all().send_message("Hello World!") - - for idx, token in enumerate(devices): - if idx == 2: - self.assertFalse(APNSDevice.objects.get(registration_id=token).active) - else: - self.assertTrue(APNSDevice.objects.get(registration_id=token).active) + def _create_devices(self, devices): + for device in devices: + print("created", device) + APNSDevice.objects.create(registration_id=device) + + @pytest.fixture(autouse=True) + def _apns_client(self): + with mock.patch( + "aioapns.APNs", + **{ + "return_value.send_notification": mock.AsyncMock( + return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) + ), + } + ) as mock_client_class: + self.apns_client = mock_client_class.return_value + yield + del self.apns_client + + @override_settings() + def test_apns_send_bulk_message(self): + self._create_devices(["abc", "def"]) + + # legacy conf manager requires a value + settings.PUSH_NOTIFICATIONS_SETTINGS.update( + {"APNS_CERTIFICATE": "/path/to/apns/certificate.pem"} + ) + + with mock.patch("time.time", return_value=0): + APNSDevice.objects.all().send_message("Hello world", expiration=1) + requests = {} + for args, kwargs in self.apns_client.send_notification.call_args_list: + assert not kwargs + [request] = args + requests[request.device_token] = request + + self.assertEqual(requests["abc"].message["aps"]["alert"], "Hello world") + self.assertEqual(requests["def"].message["aps"]["alert"], "Hello world") + self.assertEqual(requests["abc"].time_to_live, 1) + + def test_apns_send_message_extra(self): + self._create_devices(["abc"]) + + with mock.patch("time.time", return_value=0): + APNSDevice.objects.get().send_message( + "Hello world", expiration=2, priority=5, extra={"foo": "bar"} + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + self.assertEqual(request.device_token, "abc") + self.assertEqual(request.message["aps"]["alert"], "Hello world") + self.assertEqual(request.message["aps"]["custom"], {"foo": "bar"}) + self.assertEqual(str(request.priority), aioapns.PRIORITY_NORMAL) + self.assertEqual(request.time_to_live, 2) + + def test_apns_send_message(self): + self._create_devices(["abc"]) + with mock.patch("time.time", return_value=0): + APNSDevice.objects.get().send_message("Hello world", expiration=1) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "abc" + assert request.message["aps"]["alert"] == "Hello world" + assert request.time_to_live == 1 + + def test_apns_send_message_to_single_device_with_error(self): + # these errors are device specific, device.active will be set false + devices = ["abc"] + self._create_devices(devices) + + with mock.patch("push_notifications.apns._apns_send") as s: + s.return_value = {"abc": "Unregistered"} + device = APNSDevice.objects.get(registration_id="abc") + with self.assertRaises(APNSError) as ae: + device.send_message("Hello World!") + self.assertEqual(ae.exception.status, "Unregistered") + self.assertFalse(APNSDevice.objects.get(registration_id="abc").active) + + def test_apns_send_message_to_several_devices_with_error(self): + # these errors are device specific, device.active will be set false + devices = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} + self._create_devices(devices) + + with mock.patch("push_notifications.apns._apns_send") as s: + + for token, status in devices.items(): + s.return_value = {token: status} + device = APNSDevice.objects.get(registration_id=token) + with self.assertRaises(APNSError) as ae: + device.send_message("Hello World!") + + assert ae.exception.status == status + if status == "Unregistered": + assert not APNSDevice.objects.get(registration_id=token).active + else: + assert APNSDevice.objects.get(registration_id=token).active + + def test_apns_send_message_to_bulk_devices_with_error(self): + # these errors are device specific, device.active will be set false + results = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} + self._create_devices(results.keys()) + + with mock.patch("push_notifications.apns._apns_send") as s: + s.return_value = results + + APNSDevice.objects.all().send_message("Hello World!") + + for token, status in results.items(): + print(token) + if status == "Unregistered": + assert not APNSDevice.objects.get(registration_id=token).active + else: + assert APNSDevice.objects.get(registration_id=token).active diff --git a/tests/test_apns_push_payload.py b/tests/test_apns_push_payload.py index dba72b00..40ad6f2b 100644 --- a/tests/test_apns_push_payload.py +++ b/tests/test_apns_push_payload.py @@ -1,6 +1,7 @@ -from unittest import mock - -from apns2.client import NotificationPriority +import mock +import pytest +from aioapns import PRIORITY_HIGH +from aioapns.common import APNS_RESPONSE_CODE, NotificationResult from django.test import TestCase from push_notifications.apns import _apns_send @@ -8,97 +9,121 @@ class APNSPushPayloadTest(TestCase): + @pytest.fixture(autouse=True) + def _apns_client(self): + with mock.patch( + "aioapns.APNs", + **{ + "return_value.send_notification": mock.AsyncMock( + return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) + ), + } + ) as mock_client_class: + self.apns_client = mock_client_class.return_value + yield + del self.apns_client + + def test_push_payload(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + "Hello world", + badge=1, + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + + self.apns_client.send_notification.assert_called_once() + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "Hello world" + assert request.message["aps"]["badge"] == 1 + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_push_payload_with_thread_id(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + "Hello world", + thread_id="565", + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "Hello world" + assert request.message["aps"]["thread_id"] == "565" + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_push_payload_with_alert_dict(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + alert={"title": "t1", "body": "b1"}, + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"]["body"] == "b1" + assert request.message["aps"]["alert"]["title"] == "t1" + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_localised_push_with_empty_body(self): + with mock.patch("time.time", return_value=0): + _apns_send(["123"], None, loc_key="TEST_LOC_KEY", expiration=3) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"]["body_localized_key"] == "TEST_LOC_KEY" + assert request.time_to_live == 3 + + def test_using_extra(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], "sample", extra={"foo": "bar"}, expiration=30, priority=10 + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "sample" + assert request.message["aps"]["custom"] == {"foo": "bar"} + assert str(request.priority) == PRIORITY_HIGH + assert request.time_to_live == 30 + + def test_collapse_id(self): + _apns_send(["123"], "sample", collapse_id="456789") + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"], "sample" + assert request.collapse_key == "456789" - def test_push_payload(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send( - "123", "Hello world", badge=1, sound="chime", - extra={"custom_data": 12345}, expiration=3 - ) - - self.assertTrue(s.called) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert, "Hello world") - self.assertEqual(args[1].badge, 1) - self.assertEqual(args[1].sound, "chime") - self.assertEqual(args[1].custom, {"custom_data": 12345}) - self.assertEqual(kargs["expiration"], 3) - - def test_push_payload_with_thread_id(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send( - "123", "Hello world", thread_id="565", sound="chime", - extra={"custom_data": 12345}, expiration=3 - ) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert, "Hello world") - self.assertEqual(args[1].thread_id, "565") - self.assertEqual(args[1].sound, "chime") - self.assertEqual(args[1].custom, {"custom_data": 12345}) - self.assertEqual(kargs["expiration"], 3) - - def test_push_payload_with_alert_dict(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send( - "123", alert={"title": "t1", "body": "b1"}, sound="chime", - extra={"custom_data": 12345}, expiration=3 - ) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert["body"], "b1") - self.assertEqual(args[1].alert["title"], "t1") - self.assertEqual(args[1].sound, "chime") - self.assertEqual(args[1].custom, {"custom_data": 12345}) - self.assertEqual(kargs["expiration"], 3) - - def test_localised_push_with_empty_body(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send("123", None, loc_key="TEST_LOC_KEY", expiration=3) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert.body_localized_key, "TEST_LOC_KEY") - self.assertEqual(kargs["expiration"], 3) - - def test_using_extra(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send( - "123", "sample", extra={"foo": "bar"}, - expiration=30, priority=10 - ) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert, "sample") - self.assertEqual(args[1].custom, {"foo": "bar"}) - self.assertEqual(kargs["priority"], NotificationPriority.Immediate) - self.assertEqual(kargs["expiration"], 30) - - def test_collapse_id(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - _apns_send( - "123", "sample", collapse_id="456789" - ) - args, kargs = s.call_args - self.assertEqual(args[0], "123") - self.assertEqual(args[1].alert, "sample") - self.assertEqual(kargs["collapse_id"], "456789") - - def test_bad_priority(self): - with mock.patch("apns2.credentials.init_context"): - with mock.patch("apns2.client.APNsClient.connect"): - with mock.patch("apns2.client.APNsClient.send_notification") as s: - self.assertRaises(APNSUnsupportedPriority, _apns_send, "123", "_" * 2049, priority=24) - s.assert_has_calls([]) + def test_bad_priority(self): + with pytest.raises(APNSUnsupportedPriority): + _apns_send(["123"], "_" * 2049, priority=24) + self.apns_client.send_notification.assert_not_called() diff --git a/tox.ini b/tox.ini index 7f3f8826..5383d01e 100644 --- a/tox.ini +++ b/tox.ini @@ -29,7 +29,9 @@ commands = pytest pytest --ds=tests.settings_unique tests/tst_unique.py deps = - apns2 + aioapns + asgiref + mock pytest pytest-cov pytest-django From 8d79718676a0cc314306a180f6bbf9cc706bd638 Mon Sep 17 00:00:00 2001 From: Micah Denbraver Date: Sun, 5 May 2024 18:03:17 -0700 Subject: [PATCH 4/6] add python 3.10-3.11, and django 4.1-5.0 --- .github/workflows/test.yml | 2 +- tox.ini | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 75779548..ab63c42f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v2 diff --git a/tox.ini b/tox.ini index 5383d01e..8ad3cdb0 100644 --- a/tox.ini +++ b/tox.ini @@ -2,8 +2,11 @@ skipsdist = False usedevelop = true envlist = - py{37,38,39}-dj{22,32} - py{38,39}-dj{40,405} + py{37,38,39}-dj22 + py{38,39,310}-dj40 + py{38,39,310,311}-dj41 + py{38,39,310,311}-dj42 + py{310,311}-dj50 flake8 [gh-actions] @@ -11,13 +14,17 @@ python = 3.7: py37 3.8: py38 3.9: py39, flake8 + 3.10: py310 + 3.11: py311 [gh-actions:env] DJANGO = 2.2: dj22 3.2: dj32 4.0: dj40 - 4.0.5: dj405 + 4.1: dj41 + 4.2: dj42 + 5.0: dj50 [testenv] usedevelop = true @@ -40,8 +47,10 @@ deps = firebase-admin>=6.2 dj22: Django>=2.2,<3.0 dj32: Django>=3.2,<3.3 - dj40: Django>=4.0,<4.0.5 - dj405: Django>=4.0.5,<4.1 + dj40: Django>=4.0,<4.1 + dj41: Django>=4.1.3,<4.2 + dj42: Django>=4.2.8,<4.3 + dj50: Django>=5.0,<5.1 [testenv:flake8] commands = flake8 --exit-zero From 445a21640257a704a0c4adf27e4424aa6b41b5dd Mon Sep 17 00:00:00 2001 From: Micah Denbraver Date: Mon, 6 May 2024 21:46:41 -0700 Subject: [PATCH 5/6] swap back to tabs --- push_notifications/apns.py | 332 ++++++++++++++++---------------- tests/test_apns_models.py | 234 +++++++++++----------- tests/test_apns_push_payload.py | 236 +++++++++++------------ 3 files changed, 401 insertions(+), 401 deletions(-) diff --git a/push_notifications/apns.py b/push_notifications/apns.py index 0e3777b9..32d04eac 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -24,187 +24,187 @@ @contextlib.contextmanager def _apns_path_for_cert(cert): - if cert is None: - yield None - with tempfile.NamedTemporaryFile("w") as cert_file: - cert_file.write(cert) - cert_file.flush() - yield cert_file.name + if cert is None: + yield None + with tempfile.NamedTemporaryFile("w") as cert_file: + cert_file.write(cert) + cert_file.flush() + yield cert_file.name def _apns_create_client(application_id=None): - cert = None - key_path = None - key_id = None - team_id = None - - if not get_manager().has_auth_token_creds(application_id): - cert = get_manager().get_apns_certificate(application_id) - else: - key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - - with _apns_path_for_cert(cert) as cert_path: - client = aioapns.APNs( - client_cert=cert_path, - key=key_path, - key_id=key_id, - team_id=team_id, - use_sandbox=get_manager().get_apns_use_sandbox(application_id), - ) - - return client + cert = None + key_path = None + key_id = None + team_id = None + + if not get_manager().has_auth_token_creds(application_id): + cert = get_manager().get_apns_certificate(application_id) + else: + key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) + # No use getting a lifetime because this credential is + # ephemeral, but if you're looking at this to see how to + # create a credential, you could also pass the lifetime and + # algorithm. Neither of those settings are exposed in the + # settings API at the moment. + + with _apns_path_for_cert(cert) as cert_path: + client = aioapns.APNs( + client_cert=cert_path, + key=key_path, + key_id=key_id, + team_id=team_id, + use_sandbox=get_manager().get_apns_use_sandbox(application_id), + ) + + return client def _apns_prepare( - token, - alert, - application_id=None, - badge=None, - sound=None, - category=None, - content_available=False, - action_loc_key=None, - loc_key=None, - loc_args=[], - extra={}, - mutable_content=False, - thread_id=None, - url_args=None, + token, + alert, + application_id=None, + badge=None, + sound=None, + category=None, + content_available=False, + action_loc_key=None, + loc_key=None, + loc_args=[], + extra={}, + mutable_content=False, + thread_id=None, + url_args=None, ): - if action_loc_key or loc_key or loc_args: - alert_payload = { - "body": alert if alert else {}, - "body_localized_key": loc_key, - "body_localized_args": loc_args, - "action_localized_key": action_loc_key, - } - else: - alert_payload = alert - - if callable(badge): - badge = badge(token) - - return { - "alert": alert_payload, - "badge": badge, - "sound": sound, - "category": category, - "url_args": url_args, - "custom": extra, - "thread_id": thread_id, - "content_available": content_available, - "mutable_content": mutable_content, - } + if action_loc_key or loc_key or loc_args: + alert_payload = { + "body": alert if alert else {}, + "body_localized_key": loc_key, + "body_localized_args": loc_args, + "action_localized_key": action_loc_key, + } + else: + alert_payload = alert + + if callable(badge): + badge = badge(token) + + return { + "alert": alert_payload, + "badge": badge, + "sound": sound, + "category": category, + "url_args": url_args, + "custom": extra, + "thread_id": thread_id, + "content_available": content_available, + "mutable_content": mutable_content, + } @async_to_sync async def _apns_send( - registration_ids, - alert, - application_id=None, - *, - priority=None, - expiration=None, - collapse_id=None, - **kwargs, + registration_ids, + alert, + application_id=None, + *, + priority=None, + expiration=None, + collapse_id=None, + **kwargs, ): - """Make calls to APNs for each device token (registration_id) provided. - - Since the underlying library (aioapns) is asynchronous, we are - taking advantage of that here and making the requests in parallel. - """ - client = _apns_create_client(application_id=application_id) - - # if expiration isn't specified use 1 month from now - # converting to ttl for underlying library - if expiration: - time_to_live = expiration - int(time.time()) - else: - time_to_live = 2592000 - - if priority is not None: - if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]: - raise APNSUnsupportedPriority(f"Unsupported priority {priority}") - - # track which device token belongs to each coroutine. - # this allows us to stitch the results back together later - coro_registration_ids = {} - for registration_id in set(registration_ids): - coro = client.send_notification( - aioapns.NotificationRequest( - device_token=registration_id, - message={"aps": _apns_prepare(registration_id, alert, **kwargs)}, - time_to_live=time_to_live, - priority=priority, - collapse_key=collapse_id, - ) - ) - coro_registration_ids[asyncio.create_task(coro)] = registration_id - - # run all of the tasks. this will resolve once all requests are complete - done, _ = await asyncio.wait(coro_registration_ids.keys()) - - # recombine task results with their device tokens - results = {} - for coro in done: - registration_id = coro_registration_ids[coro] - result = await coro - if result.is_successful: - results[registration_id] = SUCCESS_RESULT - else: - results[registration_id] = result.description - - return results + """Make calls to APNs for each device token (registration_id) provided. + + Since the underlying library (aioapns) is asynchronous, we are + taking advantage of that here and making the requests in parallel. + """ + client = _apns_create_client(application_id=application_id) + + # if expiration isn't specified use 1 month from now + # converting to ttl for underlying library + if expiration: + time_to_live = expiration - int(time.time()) + else: + time_to_live = 2592000 + + if priority is not None: + if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]: + raise APNSUnsupportedPriority(f"Unsupported priority {priority}") + + # track which device token belongs to each coroutine. + # this allows us to stitch the results back together later + coro_registration_ids = {} + for registration_id in set(registration_ids): + coro = client.send_notification( + aioapns.NotificationRequest( + device_token=registration_id, + message={"aps": _apns_prepare(registration_id, alert, **kwargs)}, + time_to_live=time_to_live, + priority=priority, + collapse_key=collapse_id, + ) + ) + coro_registration_ids[asyncio.create_task(coro)] = registration_id + + # run all of the tasks. this will resolve once all requests are complete + done, _ = await asyncio.wait(coro_registration_ids.keys()) + + # recombine task results with their device tokens + results = {} + for coro in done: + registration_id = coro_registration_ids[coro] + result = await coro + if result.is_successful: + results[registration_id] = SUCCESS_RESULT + else: + results[registration_id] = result.description + + return results def apns_send_message(registration_id, alert, application_id=None, **kwargs): - """ - Sends an APNS notification to a single registration_id. - This will send the notification as form data. - If sending multiple notifications, it is more efficient to use - apns_send_bulk_message() - - Note that if set alert should always be a string. If it is not set, - it won"t be included in the notification. You will need to pass None - to this for silent notifications. - """ - - results = _apns_send( - [registration_id], alert, application_id=application_id, **kwargs - ) - result = results[registration_id] - - if result == SUCCESS_RESULT: - return - if result == UNREGISTERED_RESULT: - models.APNSDevice.objects.filter(registration_id=registration_id).update( - active=False - ) - raise APNSServerError(status=result) + """ + Sends an APNS notification to a single registration_id. + This will send the notification as form data. + If sending multiple notifications, it is more efficient to use + apns_send_bulk_message() + + Note that if set alert should always be a string. If it is not set, + it won"t be included in the notification. You will need to pass None + to this for silent notifications. + """ + + results = _apns_send( + [registration_id], alert, application_id=application_id, **kwargs + ) + result = results[registration_id] + + if result == SUCCESS_RESULT: + return + if result == UNREGISTERED_RESULT: + models.APNSDevice.objects.filter(registration_id=registration_id).update( + active=False + ) + raise APNSServerError(status=result) def apns_send_bulk_message(registration_ids, alert, application_id=None, **kwargs): - """ - Sends an APNS notification to one or more registration_ids. - The registration_ids argument needs to be a list. - - Note that if set alert should always be a string. If it is not set, - it won"t be included in the notification. You will need to pass None - to this for silent notifications. - """ - - results = _apns_send( - registration_ids, alert, application_id=application_id, **kwargs - ) - inactive_tokens = [ - token for token, result in results.items() if result == UNREGISTERED_RESULT - ] - models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( - active=False - ) - return results + """ + Sends an APNS notification to one or more registration_ids. + The registration_ids argument needs to be a list. + + Note that if set alert should always be a string. If it is not set, + it won"t be included in the notification. You will need to pass None + to this for silent notifications. + """ + + results = _apns_send( + registration_ids, alert, application_id=application_id, **kwargs + ) + inactive_tokens = [ + token for token, result in results.items() if result == UNREGISTERED_RESULT + ] + models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( + active=False + ) + return results diff --git a/tests/test_apns_models.py b/tests/test_apns_models.py index f8100618..3cf9c7e9 100644 --- a/tests/test_apns_models.py +++ b/tests/test_apns_models.py @@ -11,120 +11,120 @@ class APNSModelTestCase(TestCase): - def _create_devices(self, devices): - for device in devices: - print("created", device) - APNSDevice.objects.create(registration_id=device) - - @pytest.fixture(autouse=True) - def _apns_client(self): - with mock.patch( - "aioapns.APNs", - **{ - "return_value.send_notification": mock.AsyncMock( - return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) - ), - } - ) as mock_client_class: - self.apns_client = mock_client_class.return_value - yield - del self.apns_client - - @override_settings() - def test_apns_send_bulk_message(self): - self._create_devices(["abc", "def"]) - - # legacy conf manager requires a value - settings.PUSH_NOTIFICATIONS_SETTINGS.update( - {"APNS_CERTIFICATE": "/path/to/apns/certificate.pem"} - ) - - with mock.patch("time.time", return_value=0): - APNSDevice.objects.all().send_message("Hello world", expiration=1) - requests = {} - for args, kwargs in self.apns_client.send_notification.call_args_list: - assert not kwargs - [request] = args - requests[request.device_token] = request - - self.assertEqual(requests["abc"].message["aps"]["alert"], "Hello world") - self.assertEqual(requests["def"].message["aps"]["alert"], "Hello world") - self.assertEqual(requests["abc"].time_to_live, 1) - - def test_apns_send_message_extra(self): - self._create_devices(["abc"]) - - with mock.patch("time.time", return_value=0): - APNSDevice.objects.get().send_message( - "Hello world", expiration=2, priority=5, extra={"foo": "bar"} - ) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - self.assertEqual(request.device_token, "abc") - self.assertEqual(request.message["aps"]["alert"], "Hello world") - self.assertEqual(request.message["aps"]["custom"], {"foo": "bar"}) - self.assertEqual(str(request.priority), aioapns.PRIORITY_NORMAL) - self.assertEqual(request.time_to_live, 2) - - def test_apns_send_message(self): - self._create_devices(["abc"]) - with mock.patch("time.time", return_value=0): - APNSDevice.objects.get().send_message("Hello world", expiration=1) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "abc" - assert request.message["aps"]["alert"] == "Hello world" - assert request.time_to_live == 1 - - def test_apns_send_message_to_single_device_with_error(self): - # these errors are device specific, device.active will be set false - devices = ["abc"] - self._create_devices(devices) - - with mock.patch("push_notifications.apns._apns_send") as s: - s.return_value = {"abc": "Unregistered"} - device = APNSDevice.objects.get(registration_id="abc") - with self.assertRaises(APNSError) as ae: - device.send_message("Hello World!") - self.assertEqual(ae.exception.status, "Unregistered") - self.assertFalse(APNSDevice.objects.get(registration_id="abc").active) - - def test_apns_send_message_to_several_devices_with_error(self): - # these errors are device specific, device.active will be set false - devices = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} - self._create_devices(devices) - - with mock.patch("push_notifications.apns._apns_send") as s: - - for token, status in devices.items(): - s.return_value = {token: status} - device = APNSDevice.objects.get(registration_id=token) - with self.assertRaises(APNSError) as ae: - device.send_message("Hello World!") - - assert ae.exception.status == status - if status == "Unregistered": - assert not APNSDevice.objects.get(registration_id=token).active - else: - assert APNSDevice.objects.get(registration_id=token).active - - def test_apns_send_message_to_bulk_devices_with_error(self): - # these errors are device specific, device.active will be set false - results = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} - self._create_devices(results.keys()) - - with mock.patch("push_notifications.apns._apns_send") as s: - s.return_value = results - - APNSDevice.objects.all().send_message("Hello World!") - - for token, status in results.items(): - print(token) - if status == "Unregistered": - assert not APNSDevice.objects.get(registration_id=token).active - else: - assert APNSDevice.objects.get(registration_id=token).active + def _create_devices(self, devices): + for device in devices: + print("created", device) + APNSDevice.objects.create(registration_id=device) + + @pytest.fixture(autouse=True) + def _apns_client(self): + with mock.patch( + "aioapns.APNs", + **{ + "return_value.send_notification": mock.AsyncMock( + return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) + ), + } + ) as mock_client_class: + self.apns_client = mock_client_class.return_value + yield + del self.apns_client + + @override_settings() + def test_apns_send_bulk_message(self): + self._create_devices(["abc", "def"]) + + # legacy conf manager requires a value + settings.PUSH_NOTIFICATIONS_SETTINGS.update( + {"APNS_CERTIFICATE": "/path/to/apns/certificate.pem"} + ) + + with mock.patch("time.time", return_value=0): + APNSDevice.objects.all().send_message("Hello world", expiration=1) + requests = {} + for args, kwargs in self.apns_client.send_notification.call_args_list: + assert not kwargs + [request] = args + requests[request.device_token] = request + + self.assertEqual(requests["abc"].message["aps"]["alert"], "Hello world") + self.assertEqual(requests["def"].message["aps"]["alert"], "Hello world") + self.assertEqual(requests["abc"].time_to_live, 1) + + def test_apns_send_message_extra(self): + self._create_devices(["abc"]) + + with mock.patch("time.time", return_value=0): + APNSDevice.objects.get().send_message( + "Hello world", expiration=2, priority=5, extra={"foo": "bar"} + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + self.assertEqual(request.device_token, "abc") + self.assertEqual(request.message["aps"]["alert"], "Hello world") + self.assertEqual(request.message["aps"]["custom"], {"foo": "bar"}) + self.assertEqual(str(request.priority), aioapns.PRIORITY_NORMAL) + self.assertEqual(request.time_to_live, 2) + + def test_apns_send_message(self): + self._create_devices(["abc"]) + with mock.patch("time.time", return_value=0): + APNSDevice.objects.get().send_message("Hello world", expiration=1) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "abc" + assert request.message["aps"]["alert"] == "Hello world" + assert request.time_to_live == 1 + + def test_apns_send_message_to_single_device_with_error(self): + # these errors are device specific, device.active will be set false + devices = ["abc"] + self._create_devices(devices) + + with mock.patch("push_notifications.apns._apns_send") as s: + s.return_value = {"abc": "Unregistered"} + device = APNSDevice.objects.get(registration_id="abc") + with self.assertRaises(APNSError) as ae: + device.send_message("Hello World!") + self.assertEqual(ae.exception.status, "Unregistered") + self.assertFalse(APNSDevice.objects.get(registration_id="abc").active) + + def test_apns_send_message_to_several_devices_with_error(self): + # these errors are device specific, device.active will be set false + devices = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} + self._create_devices(devices) + + with mock.patch("push_notifications.apns._apns_send") as s: + + for token, status in devices.items(): + s.return_value = {token: status} + device = APNSDevice.objects.get(registration_id=token) + with self.assertRaises(APNSError) as ae: + device.send_message("Hello World!") + + assert ae.exception.status == status + if status == "Unregistered": + assert not APNSDevice.objects.get(registration_id=token).active + else: + assert APNSDevice.objects.get(registration_id=token).active + + def test_apns_send_message_to_bulk_devices_with_error(self): + # these errors are device specific, device.active will be set false + results = {"abc": "PayloadTooLarge", "def": "BadTopic", "ghi": "Unregistered"} + self._create_devices(results.keys()) + + with mock.patch("push_notifications.apns._apns_send") as s: + s.return_value = results + + APNSDevice.objects.all().send_message("Hello World!") + + for token, status in results.items(): + print(token) + if status == "Unregistered": + assert not APNSDevice.objects.get(registration_id=token).active + else: + assert APNSDevice.objects.get(registration_id=token).active diff --git a/tests/test_apns_push_payload.py b/tests/test_apns_push_payload.py index 40ad6f2b..fed960da 100644 --- a/tests/test_apns_push_payload.py +++ b/tests/test_apns_push_payload.py @@ -9,121 +9,121 @@ class APNSPushPayloadTest(TestCase): - @pytest.fixture(autouse=True) - def _apns_client(self): - with mock.patch( - "aioapns.APNs", - **{ - "return_value.send_notification": mock.AsyncMock( - return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) - ), - } - ) as mock_client_class: - self.apns_client = mock_client_class.return_value - yield - del self.apns_client - - def test_push_payload(self): - with mock.patch("time.time", return_value=0): - _apns_send( - ["123"], - "Hello world", - badge=1, - sound="chime", - extra={"custom_data": 12345}, - expiration=3, - ) - - self.apns_client.send_notification.assert_called_once() - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"] == "Hello world" - assert request.message["aps"]["badge"] == 1 - assert request.message["aps"]["sound"] == "chime" - assert request.message["aps"]["custom"] == {"custom_data": 12345} - assert request.time_to_live == 3 - - def test_push_payload_with_thread_id(self): - with mock.patch("time.time", return_value=0): - _apns_send( - ["123"], - "Hello world", - thread_id="565", - sound="chime", - extra={"custom_data": 12345}, - expiration=3, - ) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"] == "Hello world" - assert request.message["aps"]["thread_id"] == "565" - assert request.message["aps"]["sound"] == "chime" - assert request.message["aps"]["custom"] == {"custom_data": 12345} - assert request.time_to_live == 3 - - def test_push_payload_with_alert_dict(self): - with mock.patch("time.time", return_value=0): - _apns_send( - ["123"], - alert={"title": "t1", "body": "b1"}, - sound="chime", - extra={"custom_data": 12345}, - expiration=3, - ) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"]["body"] == "b1" - assert request.message["aps"]["alert"]["title"] == "t1" - assert request.message["aps"]["sound"] == "chime" - assert request.message["aps"]["custom"] == {"custom_data": 12345} - assert request.time_to_live == 3 - - def test_localised_push_with_empty_body(self): - with mock.patch("time.time", return_value=0): - _apns_send(["123"], None, loc_key="TEST_LOC_KEY", expiration=3) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"]["body_localized_key"] == "TEST_LOC_KEY" - assert request.time_to_live == 3 - - def test_using_extra(self): - with mock.patch("time.time", return_value=0): - _apns_send( - ["123"], "sample", extra={"foo": "bar"}, expiration=30, priority=10 - ) - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"] == "sample" - assert request.message["aps"]["custom"] == {"foo": "bar"} - assert str(request.priority) == PRIORITY_HIGH - assert request.time_to_live == 30 - - def test_collapse_id(self): - _apns_send(["123"], "sample", collapse_id="456789") - args, kwargs = self.apns_client.send_notification.call_args - [request] = args - - assert not kwargs - assert request.device_token == "123" - assert request.message["aps"]["alert"], "sample" - assert request.collapse_key == "456789" - - def test_bad_priority(self): - with pytest.raises(APNSUnsupportedPriority): - _apns_send(["123"], "_" * 2049, priority=24) - self.apns_client.send_notification.assert_not_called() + @pytest.fixture(autouse=True) + def _apns_client(self): + with mock.patch( + "aioapns.APNs", + **{ + "return_value.send_notification": mock.AsyncMock( + return_value=NotificationResult("xxx", APNS_RESPONSE_CODE.SUCCESS) + ), + } + ) as mock_client_class: + self.apns_client = mock_client_class.return_value + yield + del self.apns_client + + def test_push_payload(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + "Hello world", + badge=1, + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + + self.apns_client.send_notification.assert_called_once() + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "Hello world" + assert request.message["aps"]["badge"] == 1 + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_push_payload_with_thread_id(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + "Hello world", + thread_id="565", + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "Hello world" + assert request.message["aps"]["thread_id"] == "565" + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_push_payload_with_alert_dict(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], + alert={"title": "t1", "body": "b1"}, + sound="chime", + extra={"custom_data": 12345}, + expiration=3, + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"]["body"] == "b1" + assert request.message["aps"]["alert"]["title"] == "t1" + assert request.message["aps"]["sound"] == "chime" + assert request.message["aps"]["custom"] == {"custom_data": 12345} + assert request.time_to_live == 3 + + def test_localised_push_with_empty_body(self): + with mock.patch("time.time", return_value=0): + _apns_send(["123"], None, loc_key="TEST_LOC_KEY", expiration=3) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"]["body_localized_key"] == "TEST_LOC_KEY" + assert request.time_to_live == 3 + + def test_using_extra(self): + with mock.patch("time.time", return_value=0): + _apns_send( + ["123"], "sample", extra={"foo": "bar"}, expiration=30, priority=10 + ) + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"] == "sample" + assert request.message["aps"]["custom"] == {"foo": "bar"} + assert str(request.priority) == PRIORITY_HIGH + assert request.time_to_live == 30 + + def test_collapse_id(self): + _apns_send(["123"], "sample", collapse_id="456789") + args, kwargs = self.apns_client.send_notification.call_args + [request] = args + + assert not kwargs + assert request.device_token == "123" + assert request.message["aps"]["alert"], "sample" + assert request.collapse_key == "456789" + + def test_bad_priority(self): + with pytest.raises(APNSUnsupportedPriority): + _apns_send(["123"], "_" * 2049, priority=24) + self.apns_client.send_notification.assert_not_called() From e116a908e835c1b2a7d6ee14249bde948b4b923d Mon Sep 17 00:00:00 2001 From: aalbinati Date: Thu, 23 May 2024 18:11:32 +0000 Subject: [PATCH 6/6] Fixed token based auth --- push_notifications/apns.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/push_notifications/apns.py b/push_notifications/apns.py index 32d04eac..82067e2e 100644 --- a/push_notifications/apns.py +++ b/push_notifications/apns.py @@ -40,20 +40,20 @@ def _apns_create_client(application_id=None): if not get_manager().has_auth_token_creds(application_id): cert = get_manager().get_apns_certificate(application_id) + with _apns_path_for_cert(cert) as cert_path: + client = aioapns.APNs( + client_cert=cert_path, + team_id=team_id, + topic=get_manager().get_apns_topic(application_id), + use_sandbox=get_manager().get_apns_use_sandbox(application_id), + ) else: key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) - # No use getting a lifetime because this credential is - # ephemeral, but if you're looking at this to see how to - # create a credential, you could also pass the lifetime and - # algorithm. Neither of those settings are exposed in the - # settings API at the moment. - - with _apns_path_for_cert(cert) as cert_path: client = aioapns.APNs( - client_cert=cert_path, key=key_path, key_id=key_id, team_id=team_id, + topic=get_manager().get_apns_topic(application_id), use_sandbox=get_manager().get_apns_use_sandbox(application_id), )