diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..6a784ea4 --- /dev/null +++ b/.env.example @@ -0,0 +1,9 @@ +# example env file for local docker run + +# Stripe Webhook: +WEBHOOK_URL=http://host.docker.internal:8082/stripe/webhook +WEBHOOK_SIGNING_SECRET=your_secret + +# Stripe Connect Webhook: +CONNECT_WEBHOOK_URL=http://host.docker.internal:8082/stripe/connect-webhook +CONNECT_WEBHOOK_SIGNING_SECRET=your_connect_secret diff --git a/.gitignore b/.gitignore index 85b9fc4c..a65c98c6 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ __pycache__ *.py[cod] /dist /*.egg-info + +.env.* +!.env.example diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..fd4d1f24 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +FROM python:3 + +ENV VAULT_VERSION=1.3.1 + +RUN cd /tmp && \ + wget https://releases.hashicorp.com/vault/${VAULT_VERSION}/vault_${VAULT_VERSION}_linux_amd64.zip && \ + unzip vault_${VAULT_VERSION}_linux_amd64.zip && \ + mv vault /usr/local/bin/vault && \ + rm vault_${VAULT_VERSION}_linux_amd64.zip + +COPY docker/entrypoint.sh /usr/local/bin/entrypoint.sh +RUN chmod +x /usr/local/bin/entrypoint.sh + +COPY . /app + +RUN cd /app && \ + rm dist/localstripe-*.tar.gz && \ + python setup.py sdist && \ + pip install dist/localstripe-*.tar.gz && \ + rm -rf /app + +ENV PORT=8420 +EXPOSE 8420 + +CMD ["/usr/local/bin/entrypoint.sh"] diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..3b529d29 --- /dev/null +++ b/Makefile @@ -0,0 +1,65 @@ +.PHONY: all +all: help + +PROJECTNAME = localstripe +ENV ?= preview +ifeq ($(ENV),sandbox) +REGION = us-east-1 +else +REGION = us-west-1 +endif +LOCAL_TAG = $(ENV)-$(REGION)-localstripe:latest +ECR_URL = 819738237059.dkr.ecr.$(REGION).amazonaws.com +ECR_TAG = $(ECR_URL)/$(ENV)-$(REGION)-localstripe:latest + +SYSTEM = $(shell uname -s) +HOST_PORT ?= 8420 +HOST_NAME=host.docker.internal +RUN_ENV ?= development + +## docker-login: Login to ECR repository +.PHONY: docker-login +docker-login: + @echo " > Logging in to ECR repository for environment $(ENV)" + @aws ecr get-login-password --region $(REGION) | docker login --username AWS --password-stdin $(ECR_URL) + +## docker-build: Build docker image ENV=preview (default) +.PHONY: docker-build +docker-build: + @echo " > Building Docker image $(LOCAL_TAG)" + @docker build -t $(LOCAL_TAG) . + @echo " > Build Completed" + +## docker-push: Push docker image ENV=preview (default) +.PHONY: docker-push +docker-push: + @echo " > Tagging image: $(ECR_TAG)" + @docker tag $(LOCAL_TAG) $(ECR_TAG) + @echo " > Pushing docker image: $(ECR_TAG)" + @docker push $(ECR_TAG) + @echo " > Push Completed" + +## docker-run: Run docker image locally RUN_ENV=development (default) +.PHONY: docker-run +ifeq ($(SYSTEM),Darwin) +docker-run: + @echo " > Running docker image: $(LOCAL_TAG)" + @docker run --rm -p $(HOST_PORT):8420 --env-file .env.$(RUN_ENV) $(LOCAL_TAG) +else +HOST_IP ?= $(shell docker network inspect bridge -f "{{json (index .IPAM.Config 0).Gateway}}") +docker-run: + @echo " > Running docker image: $(LOCAL_TAG) with host $(HOST_NAME):$(HOST_IP)" + @docker run --rm -p $(HOST_PORT):8420 --add-host "$(HOST_NAME):$(HOST_IP)" --env-file .env.$(RUN_ENV) $(LOCAL_TAG) +endif + +## docker-image: Combine docker build and push +.PHONY: docker-image +docker-image: docker-build docker-push + +.PHONY: help +help: Makefile + @echo + @echo "Choose a command to run in: '$(PROJECTNAME)':" + @echo + @sed -n 's/^##//p' $< | column -t -s ':' | sed -e 's/^/ /' + @echo diff --git a/README.rst b/README.rst index a18c0b11..940c2615 100644 --- a/README.rst +++ b/README.rst @@ -51,21 +51,17 @@ Then simply run the command ``localstripe``. The fake Stripe server is now listening on port 8420. Or launch a container using `the Docker image -`_: +`_: .. code:: shell - docker run -p 8420:8420 adrienverge/localstripe:latest + make docker-run ENV=development Docker image can be rebuilt using: .. code:: - docker build --no-cache -t adrienverge/localstripe -< { this.value = { card: { - number: this._inputs.number.value, - exp_month: this._inputs.exp_month.value, - exp_year: '20' + this._inputs.exp_year.value, - cvc: this._inputs.cvc.value, + number: this._inputs.number && this._inputs.number.value, + exp_month: this._inputs.exp_month && this._inputs.exp_month.value, + exp_year: this._inputs.exp_year && '20' + this._inputs.exp_year.value, + cvc: this._inputs.cvc && this._inputs.cvc.value, }, - postal_code: this._inputs.postal_code.value, - } + postal_code: this._inputs.postal_code && this._inputs.postal_code.value, + }; + var evt = { + elementType: this._type, + empty: event.target.value.length == 0, + complete: false, + error: null, + brand: this._cardBrand(), + }; - if (event.target === this._inputs.number && - this.value.card.number.length >= 16) { - this._inputs.exp_month.focus(); - } else if (event.target === this._inputs.exp_month && - parseInt(this.value.card.exp_month) > 1) { - this._inputs.exp_year.focus(); - } else if (event.target === this._inputs.exp_year && - this.value.card.exp_year.length >= 4) { - this._inputs.cvc.focus(); - } else if (event.target === this._inputs.cvc && - this.value.card.cvc.length >= 3) { - this._inputs.postal_code.focus(); + switch (event.target) { + case this._inputs.number: + var numberLen = evt.brand == 'amex' ? 15 : 16; + if (this.value.card.number.length >= numberLen) { + evt.complete = true; + this._inputs.exp_month && this._inputs.exp_month.focus(); + } + break; + case this._inputs.exp_month: + if (parseInt(this.value.card.exp_month) > 1) { + evt.complete = true; + this._inputs.exp_year && this._inputs.exp_year.focus(); + } + break; + case this._inputs.exp_year: + if (this.value.card.exp_year.length >= 4) { + evt.complete = true; + this._inputs.cvc && this._inputs.cvc.focus(); + } + break; + case this._inputs.cvc: + if (this.value.card.cvc.length >= 3) { + evt.complete = true; + this._inputs.postal_code && this._inputs.postal_code.focus(); + } + break; } - (this.listeners['change'] || []).forEach(handler => handler()); + (this.listeners['change'] || []).forEach(handler => handler(evt)); }; Object.keys(this._inputs).forEach(field => { @@ -157,10 +214,39 @@ class Element { field === 'postal_code' ? 5 : field === 'cvc' ? 3 : 2); this._inputs[field].oninput = changed; + this._inputs[field].onblur = () => { + (this.listeners['blur'] || []).forEach(handler => handler()); + }; + this._inputs[field].onfocus = () => { + (this.listeners['focus'] || []).forEach(handler => handler()); + } this._domChildren.push(this._inputs[field]); }); this._domChildren.forEach((child) => domElement.appendChild(child)); + (this.listeners['ready'] || []).forEach(handler => handler()); + } + + _cardBrand() { + if (!this._inputs.number) { + return 'unknown'; + } + + const brands = { + 'visa': '^4', + 'mastercard': '^(?:2(?:22[1-9]|2[3-9]|[3-6]|7[01]|720)|5[1-5])', + 'amex': '^3[47]', + 'discover': '^6(?:011|22|4[4-9]|5)', + 'diners': '^36', + 'jcb': '^35(?:2[89]|[3-8])', + 'unionpay': '^62', + } + Object.keys(brands).forEach(brand => { + if (this._inputs.number.value.match(brands[brand])) { + return brand; + } + }); + return 'unknown'; } unmount() { @@ -172,8 +258,36 @@ class Element { destroy() { this.unmount(); - if (this._stripeElements._cardElement === this) { - this._stripeElements._cardElement = null; + if (this._stripeElements._elements[this._type] === this) { + this._stripeElements._elements[this._type] = null; + } + } + + blur() { + Object.keys(this._inputs).forEach(field => { + this._inputs[field].blur(); + }); + (this.listeners['blur'] || []).forEach(handler => handler()); + } + + focus() { + var field = Object.keys(this._inputs)[0]; + this._inputs[field].focus(); + (this.listeners['focus'] || []).forEach(handler => handler()); + } + + clear() { + Object.keys(this._inputs).forEach(field => { + this._inputs[field].value = ''; + }); + } + + update(options) { + if (!options) { + return; + } + if (options.value && options.value.postalCode && this._inputs.postal_code) { + this._inputs.postal_code.value = options.value.postalCode; } } @@ -181,22 +295,35 @@ class Element { this.listeners[event] = this.listeners[event] || []; this.listeners[event].push(handler); } + + off(event, handler) { + if (handler) { + var i = this.listeners[event].indexOf(handler); + this.listeners[event].splice(i, 1); + } else { + delete this.listeners[event]; + } + } } -Stripe = (apiKey) => { - return { +function Stripe(apiKey) { + var _elements = {}; + return window.stripe = { elements: () => { return { - _cardElement: null, + _elements: _elements, create: function(type, options) { - if (this._cardElement) { - throw new Error("Can only create one Element of type card"); + if (this._elements[type]) { + throw new Error('Can only create one Element of type ' + type); + } + if (!['card', 'cardNumber', 'cardExpiry', 'cardCvc'].includes(type)) { + throw new Error('Element type not supported: ' + type); } - this._cardElement = new Element(this); - return this._cardElement; + this._elements[type] = new Element(this, type); + return this._elements[type]; }, getElement: function(type) { - return this._cardElement; + return this._elements[type]; } }; }, @@ -334,7 +461,7 @@ Stripe = (apiKey) => { ...data.payment_method_data, }}); }, - confirmCardPayment: async (clientSecret, data) => { + confirmCardPayment: async (clientSecret, data, options) => { console.log('localstripe: Stripe().confirmCardPayment()'); try { const success = await openModal( @@ -401,9 +528,98 @@ Stripe = (apiKey) => { } }, - createPaymentMethod: async () => {}, + createPaymentMethod: async (dataOrType, dataOrElement, legacyData) => { + console.log('localstripe: Stripe().createPaymentMethod()'); + try { + let data, element; + let card = {}; + if (typeof dataOrType === 'string') { + if (dataOrElement && dataOrElement.constructor && dataOrElement.constructor.name === 'Element') { + data = legacyData; + element = dataOrElement; + } else { + data = dataOrElement; + } + if (data.type && data.type !== dataOrType) { + return {error: 'The type supplied in payment_method_data is not consistent.'}; + } + data.type = dataOrType; + } else { + data = dataOrType; + element = data.card; + } + + if (element) { + let types = ['card', 'cardNumber', 'cardExpiry', 'cardCvc']; + types.forEach(type => { + let elem = element._stripeElements.getElement(type); + if (elem) { + Object.keys(elem._inputs).forEach(field => { + card[field] = elem._inputs[field].value; + }); + } + }); + } + + const url = `${LOCALSTRIPE_SOURCE}/v1/payment_methods`; + let response = await fetch(url, { + method: 'POST', + body: JSON.stringify({ + key: apiKey, + type: data.type, + card: card, + billing_details: data.billing_details, + }), + }); + const body = await response.json().catch(() => ({})); + if (response.status !== 200 || body.error) { + return {error: body.error}; + } else { + return {paymentMethod: body}; + } + } catch (err) { + if (typeof err === 'object' && err.error) { + return err; + } else { + return {error: err}; + } + } + }, + + confirmPaymentIntent: // deprecated + async function(clientSecret, data) { + return this.confirmCardPayment(clientSecret, data, { + handleActions: false, + }); + }, + + paymentRequest: function() { + return { + listeners: [], + abort: () => {}, + canMakePayment: () => { + return new Promise(resolve => { + resolve(null); + }); + }, + show: () => {}, + update: () => {}, + on: (event, handler) => { + this.listeners[event] = this.listeners[event] || []; + this.listeners[event].push(handler); + }, + off: (event, handler) => { + if (handler) { + var i = this.listeners[event].indexOf(handler); + this.listeners[event].splice(i, 1); + } else { + delete this.listeners[event]; + } + } + }; + }, }; -}; +} console.log('localstripe: The Stripe object was just replaced in the page. ' + 'Stripe elements created from now on will be fake ones, ' + diff --git a/localstripe/resources.py b/localstripe/resources.py index fed2f444..7962fb4c 100644 --- a/localstripe/resources.py +++ b/localstripe/resources.py @@ -17,6 +17,7 @@ import asyncio from datetime import datetime, timedelta import hashlib +import json import pickle import random import re @@ -26,7 +27,8 @@ from dateutil.relativedelta import relativedelta from .errors import UserError -from .webhooks import schedule_webhook +from .webhooks import schedule_webhook, register_webhook, unregister_webhook, \ + list_webhooks # Save built-in keyword `type`, because some classes override it by using @@ -37,6 +39,7 @@ class Store(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.save_to_disk = True def try_load_from_disk(self): try: @@ -44,13 +47,35 @@ def try_load_from_disk(self): old = pickle.load(f) self.clear() self.update(old) + self._restore_webhooks() except FileNotFoundError: pass def dump_to_disk(self): + if not self.save_to_disk: + return with open('/tmp/localstripe.pickle', 'wb') as f: pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL) + def load_from_config(self, fp): + conf = json.load(fp) + if conf['WebhookEndpoints']: + self.load_config_webhooks(conf['WebhookEndpoints']) + self._restore_webhooks() + + def load_config_webhooks(self, config): + for name in config: + obj = config[name] + WebhookEndpoint(url=obj['url'], + _secret=obj['secret'], + enabled_events=obj['events']) + + def _restore_webhooks(self): + for key, obj in self.items(): + if not key.startswith('webhook_endpoint:'): + continue + register_webhook(obj.id, obj.url, obj._secret, obj.enabled_events) + def __setitem__(self, *args, **kwargs): super().__setitem__(*args, **kwargs) self.dump_to_disk() @@ -174,12 +199,13 @@ def _update(self, **data): # Do not modify object during checks -> do two loops for key, value in data.items(): if key.startswith('_') or not hasattr(self, key): - raise UserError(400, 'Bad request') + raise UserError( + 400, 'Bad request: field %s not recognised' % key) # Treat metadata differently: do not delete absent fields metadata = data.pop('metadata', None) if metadata: if type(metadata) is not dict: - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: invalid metadata') self.metadata = self.metadata or {} for key, value in metadata.items(): self.metadata[key] = value @@ -190,10 +216,10 @@ def _export(self, expand=None): try: if expand is None: expand = [] - assert type(expand) is list - assert all([type(e) is str for e in expand]) - except AssertionError: - raise UserError(400, 'Bad request') + assert type(expand) is list, 'invalid expand' + assert all([type(e) is str for e in expand]), 'invalid expand' + except AssertionError as ex: + raise UserError(400, 'Bad request') from ex if any(len(path.split('.')) > 4 for path in expand): raise UserError( @@ -244,6 +270,129 @@ def do_expand(path, obj): return obj +class Account(StripeObject): + object = 'account' + _id_prefix = 'acct_' + _default_account = None + + def __init__(self, + type=None, + country=None, + email=None, + capabilities=None, + business_type=None, + company=None, + individual=None, + metadata=None, + tos_acceptance=None, + business_profile=None, + default_currency=None, + documents=None, + external_accounts=None, + settings=None, + **kwargs): + if kwargs: + raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) + + try: + assert type in ('custom', 'express', 'standard'), 'invalid type' + assert country is None or _type(country) is str, 'invalid country' + assert email is None or _type(email) is str, 'invalid email' + if type == 'custom': + assert capabilities is not None, \ + 'custom account type must specify capabilities' + if capabilities is not None: + assert _type(capabilities) is dict, \ + 'invalid capabilities' + assert set(capabilities.values()).issubset({ + 'active', 'inactive', 'pending'}), \ + 'invalid capabilities' + assert business_type is None or \ + business_type in ('individual', 'company', 'non_profit', + 'government_entity'), \ + 'invalid business_type' + assert company is None or _type(company) is dict, 'invalid company' + assert individual is None or _type(individual) is dict, \ + 'invalid individual' + assert metadata is None or _type(metadata) is dict, \ + 'invalid metadata' + assert tos_acceptance is None or _type(tos_acceptance) is dict, \ + 'invalid tos_acceptance' + assert business_profile is None or \ + _type(business_profile) is dict, 'invalid business_profile' + assert default_currency is None or \ + _type(default_currency) is str, 'invalid default_currency' + assert documents is None or _type(documents) is dict, \ + 'invalid documents' + if external_accounts is not None: + assert _type(external_accounts) is list, \ + 'invalid external_accounts' + assert all(_type(v) is dict + for v in external_accounts.values()), \ + 'invalid external_accounts' + assert settings is None or _type(settings) is dict, \ + 'invalid settings' + except AssertionError as e: + raise UserError(400, 'Bad request') from e + + # All exceptions must be raised before this point. + super().__init__() + + self.type = type + self.country = country or 'US' + self.email = email + self.capabilities = capabilities or {} + self.business_type = business_type + self.company = company or {} + self.individual = individual or {} + self.metadata = metadata or {} + self.tos_acceptance = tos_acceptance or {} + self.business_profile = business_profile or {} + self.default_currency = default_currency or 'usd' + self.documents = documents or {} + self.external_accounts = List( + '/v1/accounts/' + self.id + '/external_accounts') + self.external_accounts._list = external_accounts or [] + self.settings = settings or {} + self.details_submitted = False + self.payouts_enabled = False + self.charges_enabled = False + + @classmethod + def _api_create(cls, external_account=None, **data): + external_accounts = None + if external_account is not None: + external_accounts = [external_account] + + return super()._api_create(external_accounts=external_accounts, **data) + + @classmethod + def _api_update(cls, id, external_account=None, **data): + obj = cls._api_retrieve(id) + if external_account is not None: + if obj.external_accounts is None: + obj.external_accounts = [external_account] + else: + obj.external_accounts.append(external_account) + + obj._update(**data) + return obj + + @classmethod + def _api_list_all(cls, url, created=None, ending_before=None, **kwargs): + return super()._api_list_all(url, **kwargs) + + @classmethod + def _api_default(cls, **kwargs): + if cls._default_account is None: + cls._default_account = Account(type='standard') + + return cls._default_account + + +extra_apis.append(('GET', '/v1/account', Account._api_default)) + + class Balance(object): object = 'balance' @@ -304,22 +453,25 @@ def __init__(self, amount=None, currency=None, description=None, amount = try_convert_to_int(amount) exchange_rate = try_convert_to_float(exchange_rate) try: - assert _type(amount) is int - assert _type(currency) is str and currency - assert description is None or _type(description) is str - assert exchange_rate is None or _type(exchange_rate) is float - assert reporting_category in ('charge', 'refund') - assert _type(source) is str - assert type in ('charge', 'refund') - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(amount) is int, 'invalid amount' + assert _type(currency) is str and currency, 'invalid currency' + assert description is None or _type(description) is str, \ + 'invalid description' + assert exchange_rate is None or _type(exchange_rate) is float, \ + 'invalid exchange_rate' + assert reporting_category in ('charge', 'refund'), \ + 'invalid reporting_category' + assert _type(source) is str, 'invalid source' + assert type in ('charge', 'refund'), 'invalid type' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if source.startswith('ch_'): Charge._api_retrieve(source) # to return 404 if not existent elif source.startswith('re_'): Refund._api_retrieve(source) # to return 404 if not existent else: - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: unknown source prefix') # All exceptions must be raised before this point super().__init__() @@ -376,8 +528,8 @@ def __init__(self, source=None, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(source) is dict - assert source.get('object') == 'card' + assert type(source) is dict, 'invalid source' + assert source.get('object') == 'card', 'invalid source.object' number = source.get('number') exp_month = try_convert_to_int(source.get('exp_month')) exp_year = try_convert_to_int(source.get('exp_year')) @@ -389,16 +541,19 @@ def __init__(self, source=None, **kwargs): address_state = source.get('address_state') address_zip = source.get('address_zip') name = source.get('name') - assert type(number) is str and len(number) == 16 - assert type(exp_month) is int - assert exp_month >= 1 and exp_month <= 12 - assert type(exp_year) is int + assert type(number) is str and len(number) == 16, \ + 'invalid source.number' + assert type(exp_month) is int, 'invalid source.exp_month' + assert exp_month >= 1 and exp_month <= 12, \ + 'invalid source.exp_month' + assert type(exp_year) is int, 'invalid source.exp_year' if exp_year > 0 and exp_year < 100: exp_year += 2000 - assert exp_year >= 2017 and exp_year <= 2100 - assert type(cvc) is str and len(cvc) == 3 - except AssertionError: - raise UserError(400, 'Bad request') + assert exp_year >= 2017 and exp_year <= 2100, \ + 'invalid source.exp_year' + assert type(cvc) is str and len(cvc) == 3, 'invalid source.cvc' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -456,23 +611,27 @@ def __init__(self, amount=None, currency=None, description=None, amount = try_convert_to_int(amount) capture = try_convert_to_bool(capture) try: - assert type(amount) is int and amount >= 0 - assert type(currency) is str and currency + assert type(amount) is int and amount >= 0, 'invalid amount' + assert type(currency) is str and currency, 'invalid currency' if description is not None: - assert type(description) is str + assert type(description) is str, 'invalid description' if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if source is not None: - assert type(source) is str + assert type(source) is str, 'invalid source' assert (source.startswith('pm_') or source.startswith('src_') - or source.startswith('card_')) - assert type(capture) is bool + or source.startswith('card_')), 'invalid source' + assert type(capture) is bool, 'invalid capture' if statement_descriptor is not None: - assert type(statement_descriptor) is str - assert len(statement_descriptor) <= 22 - assert re.search('[a-zA-Z]', statement_descriptor) - except AssertionError: - raise UserError(400, 'Bad request') + assert type(statement_descriptor) is str, \ + 'invalid statement_descriptor' + assert len(statement_descriptor) <= 22, \ + 'invalid statement_descriptor' + assert re.search('[a-zA-Z]', statement_descriptor), \ + 'invalid statement_descriptor' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if source is None: customer_obj = Customer._api_retrieve(customer) @@ -577,9 +736,9 @@ def _api_capture(cls, id, amount=None, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('ch_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('ch_'), 'invalid id' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) @@ -588,10 +747,11 @@ def _api_capture(cls, id, amount=None, **kwargs): amount = try_convert_to_int(amount) try: - assert type(amount) is int and 0 <= amount <= obj.amount - assert obj.captured is False - except AssertionError: - raise UserError(400, 'Bad request') + assert type(amount) is int and 0 <= amount <= obj.amount, \ + 'invalid amount' + assert obj.captured is False, 'invalid captured' + except AssertionError as e: + raise UserError(400, 'Bad request') from e def on_success(): obj.captured = True @@ -624,18 +784,21 @@ def _api_list_all(cls, url, customer=None, created=None, limit=10, starting_after=None): try: if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if created is not None: - assert type(created) in (dict, str) + assert type(created) in (dict, str), 'invalid created' if type(created) is dict: assert len(created.keys()) == 1 and \ - list(created.keys())[0] in ('gt', 'gte', 'lt', 'lte') + list(created.keys())[0] in \ + ('gt', 'gte', 'lt', 'lte'), 'invalid created' date = try_convert_to_int(list(created.values())[0]) elif type(created) is str: date = try_convert_to_int(created) - assert type(date) is int and date > 1500000000 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(date) is int and date > 1500000000, \ + 'invalid created' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if customer: Customer._api_retrieve(customer) # to return 404 if not existant @@ -671,21 +834,26 @@ def __init__(self, id=None, duration=None, amount_off=None, percent_off = try_convert_to_float(percent_off) duration_in_months = try_convert_to_int(duration_in_months) try: - assert type(id) is str and id - assert (amount_off is None) != (percent_off is None) + assert type(id) is str and id, 'invalid id' + assert (amount_off is None) != (percent_off is None), \ + 'invalid amount_off' if amount_off is not None: - assert type(amount_off) is int and amount_off >= 0 + assert type(amount_off) is int and amount_off >= 0, \ + 'invalid amount_off' if percent_off is not None: - assert type(percent_off) is float - assert percent_off >= 0 and percent_off <= 100 - assert duration in ('forever', 'once', 'repeating') + assert type(percent_off) is float, 'invalid percent_off' + assert percent_off >= 0 and percent_off <= 100, \ + 'invalid percent_off' + assert duration in ('forever', 'once', 'repeating'), \ + 'invalid duration' if amount_off is not None: - assert type(currency) is str and currency + assert type(currency) is str and currency, 'invalid currency' if duration == 'repeating': - assert type(duration_in_months) is int - assert duration_in_months > 0 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(duration_in_months) is int, \ + 'invalid duration_in_months' + assert duration_in_months > 0, 'invalid duration_in_months' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__(id) @@ -716,43 +884,51 @@ def __init__(self, name=None, description=None, email=None, try: if name is not None: - assert type(name) is str + assert type(name) is str, 'invalid name' if description is not None: - assert type(description) is str + assert type(description) is str, 'invalid description' if email is not None: - assert type(email) is str + assert type(email) is str, 'invalid email' if phone is not None: - assert type(phone) is str + assert type(phone) is str, 'invalid phone' if address is not None: - assert type(address) is dict + assert type(address) is dict, 'invalid address' assert set(address.keys()).issubset({ 'city', 'country', 'line1', 'line2', 'postal_code', - 'state'}) - assert all(type(f) is str for f in address.values()) + 'state'}), 'invalid address key' + assert all(type(f) is str for f in address.values()), \ + 'invalid address value' if invoice_settings is None: invoice_settings = {} - assert type(invoice_settings) is dict + assert type(invoice_settings) is dict, 'invalid invoice_settings' if 'default_payment_method' not in invoice_settings: invoice_settings['default_payment_method'] = None if invoice_settings['default_payment_method'] is not None: - assert type(invoice_settings['default_payment_method']) is str + assert type(invoice_settings['default_payment_method']) \ + is str, 'invalid invoice_settings default_payment_method' assert (invoice_settings['default_payment_method'] - .startswith('pm_')) + .startswith('pm_')), \ + 'invalid invoice_settings default_payment_method' if business_vat_id is not None: - assert type(business_vat_id) is str + assert type(business_vat_id) is str, 'invalid business_vat_id' if preferred_locales is not None: - assert type(preferred_locales) is list - assert all(type(lo) is str for lo in preferred_locales) + assert type(preferred_locales) is list, \ + 'invalid preferred_locales' + assert all(type(lo) is str for lo in preferred_locales), \ + 'invalid preferred_locales' if tax_id_data is None: tax_id_data = [] - assert type(tax_id_data) is list + assert type(tax_id_data) is list, 'invalid tax_id_data' for data in tax_id_data: - assert type(data) is dict - assert set(data.keys()) == {'type', 'value'} - assert data['type'] in ('eu_vat', 'nz_gst', 'au_abn') - assert type(data['value']) is str and len(data['value']) > 10 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(data) is dict, 'invalid tax_id_data' + assert set(data.keys()) == {'type', 'value'}, \ + 'invalid tax_id_data' + assert data['type'] in ('eu_vat', 'nz_gst', 'au_abn'), \ + 'invalid tax_id_data' + assert type(data['value']) is str and \ + len(data['value']) > 10, 'invalid tax_id_data' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -824,6 +1000,18 @@ def _api_delete(cls, id): schedule_webhook(Event('customer.deleted', obj)) return super()._api_delete(id) + @classmethod + def _api_list_all(cls, url, email=None, created=None, ending_before=None, + limit=None, starting_after=None, **kwargs): + if kwargs: + raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) + + li = super()._api_list_all(url, limit, starting_after, **kwargs) + + if email is not None: + li._list = [c for c in li._list if c.email == email] + return li + @classmethod def _api_retrieve_source(cls, id, source_id, **kwargs): if kwargs: @@ -839,7 +1027,7 @@ def _api_retrieve_source(cls, id, source_id, **kwargs): if source_obj.customer != id: raise UserError(404, 'This customer does not own this card') else: - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: unknown source prefix') return source_obj @@ -855,11 +1043,11 @@ def _api_add_source(cls, id, source=None, **kwargs): try: if type(source) is str: - assert source[:4] in ('src_', 'tok_') + assert source[:4] in ('src_', 'tok_'), 'invalid source' else: - assert type(source) is dict - except AssertionError: - raise UserError(400, 'Bad request') + assert type(source) is dict, 'invalid source' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) @@ -908,10 +1096,10 @@ def _api_add_tax_id(cls, id, type=None, value=None, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type in ('eu_vat', 'nz_gst', 'au_abn') - assert _type(value) is str and len(value) > 10 - except AssertionError: - raise UserError(400, 'Bad request') + assert type in ('eu_vat', 'nz_gst', 'au_abn'), 'invalid type' + assert _type(value) is str and len(value) > 10, 'invalid value' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) @@ -1029,27 +1217,31 @@ def __init__(self, customer=None, subscription=None, metadata=None, tax_percent = try_convert_to_float(tax_percent) date = try_convert_to_int(date) try: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if subscription is not None: - assert type(subscription) is str - assert subscription.startswith('sub_') + assert type(subscription) is str, 'invalid subscription' + assert subscription.startswith('sub_'), 'invalid subscription' if date is not None: - assert type(date) is int and date > 1500000000 + assert type(date) is int and date > 1500000000, 'invalid date' else: date = int(time.time()) if description is not None: - assert type(description) is str + assert type(description) is str, 'invalid description' if tax_percent is not None: - assert default_tax_rates is None - assert type(tax_percent) is float - assert tax_percent >= 0 and tax_percent <= 100 + assert default_tax_rates is None, 'invalid default_tax_rates' + assert type(tax_percent) is float, 'invalid tax_percent' + assert tax_percent >= 0 and tax_percent <= 100, \ + 'invalid tax_percent' if default_tax_rates is not None: - assert tax_percent is None - assert type(default_tax_rates) is list + assert tax_percent is None, 'invalid tax_percent' + assert type(default_tax_rates) is list, \ + 'invalid default_tax_rates' assert all(type(txr) is str and txr.startswith('txr_') - for txr in default_tax_rates) - except AssertionError: - raise UserError(400, 'Bad request') + for txr in default_tax_rates), \ + 'invalid default_tax_rates' + except AssertionError as e: + raise UserError(400, 'Bad request') from e Customer._api_retrieve(customer) # to return 404 if not existant @@ -1186,12 +1378,14 @@ def charge(self): return pi.charges._list[-1] def _finalize(self): - assert self.status == 'draft' + assert self.status == 'draft', \ + 'cannot transition from status %s' % self.status self._draft = False self.status_transitions['finalized_at'] = int(time.time()) def _on_payment_success(self): - assert self.status == 'paid' + assert self.status == 'paid', \ + 'cannot transition from status %s' % self.status self.status_transitions['paid_at'] = int(time.time()) schedule_webhook(Event('invoice.payment_succeeded', self)) if self.subscription: @@ -1199,7 +1393,8 @@ def _on_payment_success(self): sub._on_initial_payment_success(self) def _on_payment_failure_now(self): - assert self.status in ('open', 'void') + assert self.status in ('open', 'void'), \ + 'cannot transition from status %s' % self.status if self.status == 'void': self.status_transitions['voided_at'] = int(time.time()) schedule_webhook(Event('invoice.payment_failed', self)) @@ -1211,7 +1406,8 @@ def _on_payment_failure_now(self): sub._on_recurring_payment_failure(self) def _on_payment_failure_later(self): - assert self.status in ('open', 'void') + assert self.status in ('open', 'void'), \ + 'cannot transition from status %s' % self.status if self.status == 'void': self.status_transitions['voided_at'] = int(time.time()) schedule_webhook(Event('invoice.payment_failed', self)) @@ -1238,31 +1434,45 @@ def _get_next_invoice(cls, customer=None, subscription=None, subscription_proration_date = \ try_convert_to_int(subscription_proration_date) try: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if default_tax_rates is not None: - assert type(default_tax_rates) is list + assert type(default_tax_rates) is list, \ + 'invalid default_tax_rates' assert all(type(txr) is str and txr.startswith('txr_') - for txr in default_tax_rates) + for txr in default_tax_rates), \ + 'invalid default_tax_rates' if subscription_items is not None: - assert type(subscription_items) is list + assert type(subscription_items) is list, \ + 'invalid subscription_items' for si in subscription_items: - assert type(si.get('plan')) is str + assert type(si.get('plan')) is str, \ + 'invalid subscription_items plan' si['tax_rates'] = si.get('tax_rates') if si['tax_rates'] is not None: - assert type(si['tax_rates']) is list - assert all(type(tr) is str for tr in si['tax_rates']) + assert type(si['tax_rates']) is list, \ + 'invalid subscription_items tax_rates' + assert all(type(tr) is str + for tr in si['tax_rates']), \ + 'invalid subscription_items tax_rates' if subscription_default_tax_rates is not None: - assert subscription_tax_percent is None - assert type(subscription_default_tax_rates) is list + assert subscription_tax_percent is None, \ + 'invalid subscription_default_tax_rates' + assert type(subscription_default_tax_rates) is list, \ + 'invalid subscription_default_tax_rates' assert all(type(txr) is str and txr.startswith('txr_') - for txr in subscription_default_tax_rates) + for txr in subscription_default_tax_rates), \ + 'invalid subscription_default_tax_rates' assert all(type(tr) is str - for tr in subscription_default_tax_rates) + for tr in subscription_default_tax_rates), \ + 'invalid subscription_default_tax_rates' if subscription_proration_date is not None: - assert type(subscription_proration_date) is int - assert subscription_proration_date > 1500000000 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(subscription_proration_date) is int, \ + 'invalid subscription_proration_date' + assert subscription_proration_date > 1500000000, \ + 'invalid subscription_proration_date' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # return 404 if not existant customer_obj = Customer._api_retrieve(customer) @@ -1282,7 +1492,9 @@ def _get_next_invoice(cls, customer=None, subscription=None, if ii.invoice is None] if (not upcoming and not subscription and not subscription_items and not pending_items): - raise UserError(400, 'Bad request') + raise UserError( + 400, 'Bad request: need one of upcoming, subscription, ' + + 'subscription_items, pending_items') simulation = subscription_items is not None or \ subscription_prorate is not None or \ @@ -1399,7 +1611,8 @@ def _api_create(cls, customer=None, subscription=None, tax_percent=None, def _api_delete(cls, id): obj = cls._api_retrieve(id) if obj.status != 'draft': - raise UserError(400, 'Bad request') + raise UserError( + 400, 'Bad request: can only delete if status is draft.') return super()._api_delete(id) @@ -1408,12 +1621,13 @@ def _api_list_all(cls, url, customer=None, subscription=None, limit=None, starting_after=None): try: if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if subscription is not None: - assert type(subscription) is str - assert subscription.startswith('sub_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(subscription) is str, 'invalid subscription' + assert subscription.startswith('sub_'), 'invalid subscription' + except AssertionError as e: + raise UserError(400, 'Bad request') from e li = super(Invoice, cls)._api_list_all(url, limit=limit, starting_after=starting_after) @@ -1458,7 +1672,7 @@ def _api_pay_invoice(cls, id): if obj.status == 'paid': raise UserError(400, 'Invoice is already paid') elif obj.status not in ('draft', 'open'): - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: status must be draft or open') obj._draft = False @@ -1484,7 +1698,7 @@ def _api_void_invoice(cls, id): obj = Invoice._api_retrieve(id) if obj.status not in ('draft', 'open'): - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: status must be draft or open') PaymentIntent._api_cancel(obj.payment_intent) @@ -1535,30 +1749,35 @@ def __init__(self, invoice=None, subscription=None, plan=None, amount=None, proration = try_convert_to_bool(proration) try: if invoice is not None: - assert type(invoice) is str and invoice.startswith('in_') + assert type(invoice) is str and invoice.startswith('in_'), \ + 'invalid invoice' if subscription is not None: - assert type(subscription) is str - assert subscription.startswith('sub_') + assert type(subscription) is str, 'invalid subscription' + assert subscription.startswith('sub_'), 'invalid subscription' if plan is not None: - assert type(plan) is str and plan - assert type(amount) is int - assert type(currency) is str and currency - assert type(customer) is str and customer.startswith('cus_') + assert type(plan) is str and plan, 'invalid plan' + assert type(amount) is int, 'invalid amount' + assert type(currency) is str and currency, 'invalid currency' + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if period_start is not None: - assert type(period_start) is int and period_start > 1500000000 - assert type(period_end) is int and period_end > 1500000000 + assert type(period_start) is int and \ + period_start > 1500000000, 'invalid period_start' + assert type(period_end) is int and \ + period_end > 1500000000, 'invalid period_end' else: period_start = period_end = int(time.time()) - assert type(proration) is bool + assert type(proration) is bool, 'invalid proration' if description is not None: - assert type(description) is str + assert type(description) is str, 'invalid description' else: description = 'Invoice item' if tax_rates is not None: - assert type(tax_rates) is list - assert all(type(tr) is str for tr in tax_rates) - except AssertionError: - raise UserError(400, 'Bad request') + assert type(tax_rates) is list, 'invalid tax_rates' + assert all(type(tr) is str for tr in tax_rates), \ + 'invalid tax_rates' + except AssertionError as e: + raise UserError(400, 'Bad request') from e Customer._api_retrieve(customer) # to return 404 if not existant if invoice is not None: @@ -1591,9 +1810,10 @@ def _api_list_all(cls, url, customer=None, limit=None, starting_after=None): try: if customer is not None: - assert type(customer) is str and customer.startswith('cus_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' + except AssertionError as e: + raise UserError(400, 'Bad request') from e li = super(InvoiceItem, cls)._api_list_all(url, limit=limit, @@ -1612,9 +1832,10 @@ class InvoiceLineItem(StripeObject): def __init__(self, item): try: - assert isinstance(item, (InvoiceItem, SubscriptionItem)) - except AssertionError: - raise UserError(400, 'Bad request') + assert isinstance(item, (InvoiceItem, SubscriptionItem)), \ + 'invalid item' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -1672,11 +1893,13 @@ def __init__(self, url=None, limit=None, starting_after=None): limit = try_convert_to_int(limit) limit = 10 if limit is None else limit try: - assert type(limit) is int and limit > 0 + assert type(limit) is int, 'invalid limit type %s' % limit + assert limit > 0, 'invalid limit, must be greater than 0' if starting_after is not None: - assert type(starting_after) is str and len(starting_after) > 0 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(starting_after) is str and \ + len(starting_after) > 0, 'invalid starting_after' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -1730,17 +1953,19 @@ def __init__(self, amount=None, currency=None, customer=None, amount = try_convert_to_int(amount) try: # Invoices with amount == 0 don't create PaymentIntents: - assert type(amount) is int and amount > 0 - assert type(currency) is str and currency + assert type(amount) is int and amount > 0, 'invalid amount' + assert type(currency) is str and currency, 'invalid currency' if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if payment_method is not None: - assert type(payment_method) is str + assert type(payment_method) is str, 'invalid payment_method' assert (payment_method.startswith('pm_') or payment_method.startswith('src_') or - payment_method.startswith('card_')) - except AssertionError: - raise UserError(400, 'Bad request') + payment_method.startswith('card_')), \ + 'invalid payment_method' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if customer: Customer._api_retrieve(customer) # to return 404 if not existant @@ -1766,19 +1991,24 @@ def __init__(self, amount=None, currency=None, customer=None, def _trigger_payment(self): if self.status != 'requires_confirmation': - raise UserError(400, 'Bad request') + raise UserError( + 400, 'Bad request: status must be requires_confirmation') def on_success(): + schedule_webhook(Event('payment_intent.succeeded', self)) if self.invoice: invoice = Invoice._api_retrieve(self.invoice) invoice._on_payment_success() + self._api_delete(self.id) def on_failure_now(): + schedule_webhook(Event('payment_intent.payment_failed', self)) if self.invoice: invoice = Invoice._api_retrieve(self.invoice) invoice._on_payment_failure_now() def on_failure_later(): + schedule_webhook(Event('payment_intent.payment_failed', self)) if self.invoice: invoice = Invoice._api_retrieve(self.invoice) invoice._on_payment_failure_later() @@ -1831,14 +2061,15 @@ def _api_create(cls, confirm=None, off_session=None, **data): off_session = try_convert_to_bool(off_session) try: if confirm is not None: - assert type(confirm) is bool + assert type(confirm) is bool, 'invalid confirm' if off_session is not None: - assert type(off_session) is bool - assert confirm is True - except AssertionError: - raise UserError(400, 'Bad request') + assert type(off_session) is bool, 'invalid off_session' + assert confirm is True, 'invalid confirm' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = super()._api_create(**data) + schedule_webhook(Event('payment_intent.created', obj)) if confirm: cls._api_confirm(obj.id) @@ -1854,14 +2085,15 @@ def _api_confirm(cls, id, payment_method=None, **kwargs): raise UserError(500, 'Not implemented') try: - assert type(id) is str and id.startswith('pi_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('pi_'), 'invalid id' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) if obj.status != 'requires_confirmation': - raise UserError(400, 'Bad request') + raise UserError( + 400, 'Bad request: status must be requires_confirmation') obj._authentication_failed = False payment_method = PaymentMethod._api_retrieve(obj.payment_method) @@ -1882,17 +2114,18 @@ def _api_cancel(cls, id, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('pi_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('pi_'), 'invalid id' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) if obj.status not in ('requires_payment_method', 'requires_capture', 'requires_confirmation', 'requires_action'): - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: invalid status transition') obj._canceled = True obj.next_action = None + schedule_webhook(Event('payment_intent.canceled', obj)) return obj @classmethod @@ -1903,18 +2136,20 @@ def _api_authenticate(cls, id, client_secret=None, success=False, success = try_convert_to_bool(success) try: - assert type(id) is str and id.startswith('pi_') - assert type(client_secret) is str - assert type(success) is bool - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('pi_'), 'invalid id' + assert type(client_secret) is str, 'invalid client_secret' + assert type(success) is bool, 'invalid success' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) if client_secret != obj.client_secret: raise UserError(401, 'Unauthorized') - if obj.status != 'requires_action': - raise UserError(400, 'Bad request') + if obj.status not in ('requires_action', 'requires_confirmation'): + raise UserError(400, 'Bad request: status must be ' + + 'requires_action or requires_confirmation', + {'status': obj.status}) obj.next_action = None if success: @@ -1946,29 +2181,34 @@ def __init__(self, type=None, billing_details=None, card=None, raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type in ('card', 'sepa_debit') - assert billing_details is None or _type(billing_details) is dict + assert type in ('card', 'sepa_debit'), 'invalid type' + assert billing_details is None or _type(billing_details) is dict, \ + 'invalid billing_details' if type == 'card': assert _type(card) is dict and card.keys() == { - 'number', 'exp_month', 'exp_year', 'cvc'} + 'number', 'exp_month', 'exp_year', 'cvc'}, 'invalid card' card['exp_month'] = try_convert_to_int(card['exp_month']) card['exp_year'] = try_convert_to_int(card['exp_year']) - assert _type(card['number']) is str - assert _type(card['exp_month']) is int - assert _type(card['exp_year']) is int - assert _type(card['cvc']) is str - assert len(card['number']) == 16 - assert card['exp_month'] >= 1 and card['exp_month'] <= 12 + assert _type(card['number']) is str, 'invalid card.number' + assert _type(card['exp_month']) is int, \ + 'invalid card.exp_month' + assert _type(card['exp_year']) is int, 'invalid card.exp_year' + assert _type(card['cvc']) is str, 'invalid card.cvc' + assert len(card['number']) == 16, 'invalid card.number' + assert card['exp_month'] >= 1 and card['exp_month'] <= 12, \ + 'invalid card.exp_month' if card['exp_year'] > 0 and card['exp_year'] < 100: card['exp_year'] += 2000 - assert len(card['cvc']) == 3 + assert len(card['cvc']) == 3, 'invalid card.cvc' elif type == 'sepa_debit': - assert _type(sepa_debit) is dict - assert 'iban' in sepa_debit - assert _type(sepa_debit['iban']) is str - assert 14 <= len(sepa_debit['iban']) <= 34 - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(sepa_debit) is dict, 'invalid sepa_debit' + assert 'iban' in sepa_debit, 'invalid sepa_debit' + assert _type(sepa_debit['iban']) is str, \ + 'invalid sepa_debit.iban' + assert 14 <= len(sepa_debit['iban']) <= 34, \ + 'invalid sepa_debit.iban' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if type == 'card': if not (2019 <= card['exp_year'] < 2100): @@ -2045,10 +2285,11 @@ def _api_attach(cls, id, customer=None, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('pm_') - assert type(customer) is str and customer.startswith('cus_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('pm_'), 'invalid id' + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) Customer._api_retrieve(customer) # to return 404 if not existant @@ -2058,6 +2299,7 @@ def _api_attach(cls, id, customer=None, **kwargs): {'code': 'card_declined'}) obj.customer = customer + schedule_webhook(Event('payment_method.attached', obj)) return obj @classmethod @@ -2066,12 +2308,13 @@ def _api_detach(cls, id, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('pm_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(id) is str and id.startswith('pm_'), 'invalid id' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) obj.customer = None + schedule_webhook(Event('payment_method.detached', obj)) return obj @classmethod @@ -2090,10 +2333,11 @@ def _api_retrieve(cls, id): def _api_list_all(cls, url, customer=None, type=None, limit=None, starting_after=None): try: - assert _type(customer) is str and customer.startswith('cus_') - assert type in ('card', ) - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' + assert type in ('card', ), 'invalid type' + except AssertionError as e: + raise UserError(400, 'Bad request') from e Customer._api_retrieve(customer) # to return 404 if not existant @@ -2136,34 +2380,41 @@ def __init__(self, id=None, metadata=None, amount=None, product=None, trial_period_days = try_convert_to_int(trial_period_days) active = try_convert_to_bool(active) try: - assert id is None or type(id) is str and id - assert type(active) is bool - assert billing_scheme in ['per_unit', 'tiered'] + assert id is None or type(id) is str and id, 'invalid id' + assert type(active) is bool, 'invalid active' + assert billing_scheme in ['per_unit', 'tiered'], \ + 'invalid billing_scheme' if billing_scheme == 'per_unit': - assert type(amount) is int and amount >= 0 + assert type(amount) is int and amount >= 0, 'invalid amount' else: - assert tiers_mode in ['graduated', 'volume'] - assert type(tiers) is list and len(tiers) > 0 + assert tiers_mode in ['graduated', 'volume'], \ + 'invalid tiers_mode' + assert type(tiers) is list and len(tiers) > 0, 'invalid tiers' for t in tiers: assert \ type(t) is dict and 'up_to' in t and \ (t['up_to'] == 'inf' or - type(try_convert_to_int(t['up_to'])) is int) + type(try_convert_to_int(t['up_to'])) is int), \ + 'invalid tiers up_to' unit_amount = try_convert_to_int(t.get('unit_amount', 0)) - assert type(unit_amount) is int and unit_amount >= 0 + assert type(unit_amount) is int and unit_amount >= 0, \ + 'invalid tiers unit_amount' flat_amount = try_convert_to_int(t.get('flat_amount', 0)) - assert type(flat_amount) is int and flat_amount >= 0 - assert type(currency) is str and currency - assert type(interval) is str - assert interval in ('day', 'week', 'month', 'year') - assert type(interval_count) is int + assert type(flat_amount) is int and flat_amount >= 0, \ + 'invalid tiers flat_amount' + assert type(currency) is str and currency, 'invalid currency' + assert type(interval) is str, 'invalid interval' + assert interval in ('day', 'week', 'month', 'year'), \ + 'invalid interval' + assert type(interval_count) is int, 'invalid interval_count' if trial_period_days is not None: - assert type(trial_period_days) is int + assert type(trial_period_days) is int, \ + 'invalid trial_period_days' if nickname is not None: - assert type(nickname) is str - assert usage_type in ['licensed', 'metered'] - except AssertionError: - raise UserError(400, 'Bad request') + assert type(nickname) is str, 'invalid nickname' + assert usage_type in ['licensed', 'metered'], 'invalid usage_type' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if type(product) is str: Product._api_retrieve(product) # to return 404 if not existant @@ -2210,23 +2461,25 @@ def __init__(self, amount=None, currency=None, description=None, amount = try_convert_to_int(amount) try: - assert type(amount) is int and amount > 0 - assert currency in ('eur',) + assert type(amount) is int and amount > 0, 'invalid amount' + assert currency in ('eur',), 'invalid currency' if description is not None: - assert type(description) is str + assert type(description) is str, 'invalid description' if metadata is not None: - assert type(metadata) is dict + assert type(metadata) is dict, 'invalid metadata' if statement_descriptor is not None: assert type(statement_descriptor) is str \ - and len(statement_descriptor) <= 22 + and len(statement_descriptor) <= 22, \ + 'invalid statement_descriptor' if method is not None: - assert method in ('standard', 'instant') + assert method in ('standard', 'instant'), 'invalid method' if source_type is not None: - assert type(source_type) is str + assert type(source_type) is str, 'invalid source_type' if status is not None: - assert status in ('paid', 'pending', 'failed') - except AssertionError: - raise UserError(400, 'Bad request') + assert status in ('paid', 'pending', 'failed'), \ + 'invalid status' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -2308,24 +2561,26 @@ def __init__(self, id=None, name=None, type='service', active=True, active = try_convert_to_bool(active) try: - assert id is None or _type(id) is str and id - assert _type(name) is str and name - assert type in ('good', 'service') - assert _type(active) is bool + assert id is None or _type(id) is str and id, 'invalid id' + assert _type(name) is str and name, 'invalid name' + assert type in ('good', 'service'), 'invalid type' + assert _type(active) is bool, 'invalid active' if caption is not None: - assert _type(caption) is str + assert _type(caption) is str, 'invalid caption' if description is not None: - assert _type(description) is str + assert _type(description) is str, 'invalid description' if attributes is not None: - assert _type(attributes) is list - assert _type(shippable) is bool + assert _type(attributes) is list, 'invalid attributes' + assert _type(shippable) is bool, 'invalid shippable' if url is not None: - assert _type(url) is str + assert _type(url) is str, 'invalid url' if statement_descriptor is not None: - assert _type(statement_descriptor) is str - assert len(statement_descriptor) <= 22 - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(statement_descriptor) is str, \ + 'invalid statement_descriptor' + assert len(statement_descriptor) <= 22, \ + 'invalid statement_descriptor' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__(id) @@ -2354,11 +2609,12 @@ def __init__(self, charge=None, amount=None, metadata=None, **kwargs): amount = try_convert_to_int(amount) try: - assert type(charge) is str and charge.startswith('ch_') + assert type(charge) is str and charge.startswith('ch_'), \ + 'invalid charge' if amount is not None: - assert type(amount) is int and amount > 0 - except AssertionError: - raise UserError(400, 'Bad request') + assert type(amount) is int and amount > 0, 'invalid amount' + except AssertionError as e: + raise UserError(400, 'Bad request') from e charge_obj = Charge._api_retrieve(charge) @@ -2388,9 +2644,10 @@ def __init__(self, charge=None, amount=None, metadata=None, **kwargs): def _api_list_all(cls, url, charge=None, limit=None, starting_after=None): try: if charge is not None: - assert type(charge) is str and charge.startswith('ch_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(charge) is str and charge.startswith('ch_'), \ + 'invalid charge' + except AssertionError as e: + raise UserError(400, 'Bad request') from e li = super(Refund, cls)._api_list_all(url, limit=limit, starting_after=starting_after) @@ -2416,19 +2673,24 @@ def __init__(self, type=None, currency=None, owner=None, metadata=None, assert type in ( 'ach_credit_transfer', 'ach_debit', 'alipay', 'bancontact', 'bitcoin', 'card', 'eps', 'giropay', 'ideal', 'multibanco', - 'p24', 'sepa_debit', 'sofort', 'three_d_secure') - assert _type(currency) is str and currency + 'p24', 'sepa_debit', 'sofort', 'three_d_secure'), \ + 'invalid type' + assert _type(currency) is str and currency, 'invalid currency' if owner is not None: - assert _type(owner) is dict - assert _type(owner.get('name', '')) is str - assert _type(owner.get('email', '')) is str + assert _type(owner) is dict, 'invalid owner' + assert _type(owner.get('name', '')) is str, \ + 'invalid owner.name' + assert _type(owner.get('email', '')) is str, \ + 'invalid owner.email' if type == 'sepa_debit': - assert _type(sepa_debit) is dict - assert 'iban' in sepa_debit - assert _type(sepa_debit['iban']) is str - assert 14 <= len(sepa_debit['iban']) <= 34 - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(sepa_debit) is dict, 'invalid sepa_debit' + assert 'iban' in sepa_debit, 'invalid sepa_debit' + assert _type(sepa_debit['iban']) is str, \ + 'invalid sepa_debit.iban' + assert 14 <= len(sepa_debit['iban']) <= 34, \ + 'invalid sepa_debit.iban' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -2479,17 +2741,20 @@ def __init__(self, customer=None, usage=None, payment_method_types=None, try: if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if usage is None: usage = 'off_session' - assert usage in ('off_session', 'on_session') + assert usage in ('off_session', 'on_session'), 'invalid usage' if payment_method_types is None: payment_method_types = ['card'] - assert type(payment_method_types) is list + assert type(payment_method_types) is list, \ + 'invalid payment_method_types' assert all(t in ('card', 'sepa_debit', 'ideal') - for t in payment_method_types) - except AssertionError: - raise UserError(400, 'Bad request') + for t in payment_method_types), \ + 'invalid payment_method_types' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -2510,13 +2775,14 @@ def _api_confirm(cls, id, use_stripe_sdk=None, client_secret=None, raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('seti_') + assert type(id) is str and id.startswith('seti_'), 'invalid id' if client_secret is not None: - assert type(client_secret) is str + assert type(client_secret) is str, 'invalid client_secret' if payment_method_data is not None: - assert type(payment_method_data) is dict - except AssertionError: - raise UserError(400, 'Bad request') + assert type(payment_method_data) is dict, \ + 'invalid payment_method_data' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) @@ -2525,7 +2791,7 @@ def _api_confirm(cls, id, use_stripe_sdk=None, client_secret=None, if payment_method_data: if obj.payment_method is not None: - raise UserError(400, 'Bad request') + raise UserError(400, 'Bad request: missing payment_method') pm = PaymentMethod(**payment_method_data) obj.payment_method = pm.id @@ -2559,11 +2825,11 @@ def _api_cancel(cls, id, use_stripe_sdk=None, client_secret=None, raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(id) is str and id.startswith('seti_') + assert type(id) is str and id.startswith('seti_'), 'invalid id' if client_secret is not None: - assert type(client_secret) is str - except AssertionError: - raise UserError(400, 'Bad request') + assert type(client_secret) is str, 'invalid client_secret' + except AssertionError as e: + raise UserError(400, 'Bad request') from e obj = cls._api_retrieve(id) @@ -2610,52 +2876,68 @@ def __init__(self, customer=None, metadata=None, items=None, billing_cycle_anchor = try_convert_to_int(billing_cycle_anchor) try: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if trial_end is not None: if trial_end == 'now': trial_end = int(time.time()) - assert type(trial_end) is int - assert trial_end > 1500000000 + assert type(trial_end) is int, 'invalid trial_end' + assert trial_end > 1500000000, 'invalid trial_end' if tax_percent is not None: - assert default_tax_rates is None - assert type(tax_percent) is float - assert tax_percent >= 0 and tax_percent <= 100 + assert default_tax_rates is None, 'invalid default_tax_rates' + assert type(tax_percent) is float, 'invalid tax_percent' + assert tax_percent >= 0 and tax_percent <= 100, \ + 'invalid tax_percent' if default_tax_rates is not None: - assert tax_percent is None - assert type(default_tax_rates) is list + assert tax_percent is None, 'invalid tax_percent' + assert type(default_tax_rates) is list, \ + 'invalid default_tax_rates' assert all(type(txr) is str and txr.startswith('txr_') - for txr in default_tax_rates) + for txr in default_tax_rates), \ + 'invalid default_tax_rates' if trial_period_days is not None: - assert type(trial_period_days) is int + assert type(trial_period_days) is int, \ + 'invalid trial_period_days' if backdate_start_date is not None: - assert type(backdate_start_date) is int - assert backdate_start_date > 1500000000 + assert type(backdate_start_date) is int, \ + 'invalid backdate_start_date' + assert backdate_start_date > 1500000000, \ + 'invalid backdate_start_date' if billing_cycle_anchor is not None: - assert type(billing_cycle_anchor) is int - assert billing_cycle_anchor > int(time.time()) + assert type(billing_cycle_anchor) is int, \ + 'invalid billing_cycle_anchor' + assert billing_cycle_anchor > int(time.time()), \ + 'invalid billing_cycle_anchor' if proration_behavior is not None: - assert proration_behavior in ['create_prorations', 'none'] - assert type(items) is list + assert proration_behavior in ['create_prorations', 'none'], \ + 'invalid proration_behavior' + assert type(items) is list, 'invalid items' for item in items: - assert type(item.get('plan')) is str + assert type(item.get('plan')) is str, 'invalid items plan' if item.get('quantity') is not None: item['quantity'] = try_convert_to_int(item['quantity']) - assert type(item['quantity']) is int - assert item['quantity'] > 0 + assert type(item['quantity']) is int, \ + 'invalid items quantity' + assert item['quantity'] > 0, 'invalid items quantity' else: item['quantity'] = 1 item['tax_rates'] = item.get('tax_rates') if item['tax_rates'] is not None: - assert type(item['tax_rates']) is list - assert all(type(tr) is str for tr in item['tax_rates']) + assert type(item['tax_rates']) is list, \ + 'invalid items tax_rates' + assert all(type(tr) is str for tr in item['tax_rates']), \ + 'invalid items tax_rates' item['metadata'] = item.get('metadata') if item['metadata'] is not None: - assert type(item['metadata']) is dict - assert type(enable_incomplete_payments) is bool + assert type(item['metadata']) is dict, \ + 'invalid items metadata' + assert type(enable_incomplete_payments) is bool, \ + 'invalid enable_incomplete_payments' assert payment_behavior in ('allow_incomplete', - 'error_if_incomplete') - except AssertionError: - raise UserError(400, 'Bad request') + 'error_if_incomplete'), \ + 'invalid payment_behavior' + except AssertionError as e: + raise UserError(400, 'Bad request') from e if len(items) != 1: raise UserError(500, 'Not implemented') @@ -2810,48 +3092,58 @@ def _update(self, metadata=None, items=None, trial_end=None, if trial_end is not None: if trial_end == 'now': trial_end = int(time.time()) - assert type(trial_end) is int - assert trial_end > 1500000000 + assert type(trial_end) is int, 'invalid trial_end' + assert trial_end > 1500000000, 'invalid trial_end' if tax_percent is not None: - assert default_tax_rates is None - assert type(tax_percent) is float - assert tax_percent >= 0 and tax_percent <= 100 + assert default_tax_rates is None, 'invalid default_tax_rates' + assert type(tax_percent) is float, 'invalid tax_percent' + assert tax_percent >= 0 and tax_percent <= 100, \ + 'invalid tax_percent' if default_tax_rates is not None: - assert tax_percent is None - assert type(default_tax_rates) is list + assert tax_percent is None, 'invalid tax_percent' + assert type(default_tax_rates) is list, \ + 'invalid default_tax_rates' assert all(type(txr) is str and txr.startswith('txr_') - for txr in default_tax_rates) + for txr in default_tax_rates), \ + 'invalid default_tax_rates' if prorate is not None: - assert type(prorate) is bool + assert type(prorate) is bool, 'invalid prorate' if proration_date is not None: - assert type(proration_date) is int - assert proration_date > 1500000000 + assert type(proration_date) is int, 'invalid proration_date' + assert proration_date > 1500000000, 'invalid proration_date' if cancel_at_period_end is not None: - assert type(cancel_at_period_end) is bool + assert type(cancel_at_period_end) is bool, \ + 'invalid cancel_at_period_end' if cancel_at is not None: - assert type(cancel_at) is int - assert cancel_at > 1500000000 + assert type(cancel_at) is int, 'invalid cancel_at' + assert cancel_at > 1500000000, 'invalid cancel_at' if items is not None: - assert type(items) is list + assert type(items) is list, 'invalid items' for item in items: id = item.get('id') if id is not None: - assert type(id) is str and id.startswith('si_') + assert type(id) is str and id.startswith('si_'), \ + 'invalid items id' if item.get('quantity') is not None: item['quantity'] = try_convert_to_int(item['quantity']) - assert type(item['quantity']) is int - assert item['quantity'] > 0 + assert type(item['quantity']) is int, \ + 'invalid items quantity' + assert item['quantity'] > 0, 'invalid items quantity' else: item['quantity'] = 1 item['tax_rates'] = item.get('tax_rates') if item['tax_rates'] is not None: - assert type(item['tax_rates']) is list - assert all(type(tr) is str for tr in item['tax_rates']) + assert type(item['tax_rates']) is list, \ + 'invalid items tax_rates' + assert all(type(tr) is str + for tr in item['tax_rates']), \ + 'invalid items tax_rates' item['metadata'] = item.get('metadata') if item['metadata'] is not None: - assert type(item['metadata']) is dict - except AssertionError: - raise UserError(400, 'Bad request') + assert type(item['metadata']) is dict, \ + 'invalid items metadata' + except AssertionError as e: + raise UserError(400, 'Bad request') from e old_plan = self.plan if items is not None: @@ -2948,13 +3240,15 @@ def _api_list_all(cls, url, customer=None, status=None, limit=None, starting_after=None): try: if customer is not None: - assert type(customer) is str and customer.startswith('cus_') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' if status is not None: assert status in ('all', 'incomplete', 'incomplete_expired', 'trialing', 'active', 'past_due', 'unpaid', - 'canceled') - except AssertionError: - raise UserError(400, 'Bad request') + 'canceled'), \ + 'invalid status' + except AssertionError as e: + raise UserError(400, 'Bad request') from e li = super(Subscription, cls)._api_list_all(url, limit=limit, @@ -2982,15 +3276,16 @@ def __init__(self, subscription=None, plan=None, quantity=1, quantity = try_convert_to_int(quantity) try: if subscription is not None: - assert type(subscription) is str - assert subscription.startswith('sub_') - assert type(plan) is str - assert type(quantity) is int and quantity > 0 + assert type(subscription) is str, 'invalid subscription' + assert subscription.startswith('sub_'), 'invalid subscription' + assert type(plan) is str, 'invalid plan' + assert type(quantity) is int and quantity > 0, 'invalid quantity' if tax_rates is not None: - assert type(tax_rates) is list - assert all(type(tr) is str for tr in tax_rates) - except AssertionError: - raise UserError(400, 'Bad request') + assert type(tax_rates) is list, 'invalid tax_rates' + assert all(type(tr) is str for tr in tax_rates), \ + 'invalid tax_rates' + except AssertionError as e: + raise UserError(400, 'Bad request') from e plan = Plan._api_retrieve(plan) # to return 404 if not existant # To return 404 if not existant: @@ -3077,15 +3372,15 @@ def __init__(self, country=None, customer=None, type=None, value=None, raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert _type(customer) is str - assert customer.startswith('cus_') - assert type in ('eu_vat', 'nz_gst', 'au_abn') - assert _type(value) is str and len(value) > 10 + assert _type(customer) is str, 'invalid customer' + assert customer.startswith('cus_'), 'invalid customer' + assert type in ('eu_vat', 'nz_gst', 'au_abn'), 'invalid type' + assert _type(value) is str and len(value) > 10, 'invalid value' if country is None: country = value[0:2] - assert _type(country) is str - except AssertionError: - raise UserError(400, 'Bad request') + assert _type(country) is str, 'invalid country' + except AssertionError as e: + raise UserError(400, 'Bad request') from e Customer._api_retrieve(customer) # to return 404 if not existant @@ -3122,15 +3417,18 @@ def __init__(self, display_name=None, inclusive=None, percentage=None, percentage = try_convert_to_float(percentage) active = try_convert_to_bool(active) try: - assert type(display_name) is str and display_name - assert type(inclusive) is bool - assert type(percentage) is float - assert type(active) is bool - assert percentage >= 0 and percentage <= 100 - assert description is None or type(description) is str - assert jurisdiction is None or type(jurisdiction) is str - except AssertionError: - raise UserError(400, 'Bad request') + assert type(display_name) is str and display_name, \ + 'invalid display_name' + assert type(inclusive) is bool, 'invalid inclusive' + assert type(percentage) is float, 'invalid percentage' + assert type(active) is bool, 'invalid active' + assert percentage >= 0 and percentage <= 100, 'invalid percentage' + assert description is None or type(description) is str, \ + 'invalid description' + assert jurisdiction is None or type(jurisdiction) is str, \ + 'invalid jurisdiction' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # All exceptions must be raised before this point. super().__init__() @@ -3158,11 +3456,12 @@ def __init__(self, card=None, customer=None, **kwargs): raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) try: - assert type(card) is dict + assert type(card) is dict, 'invalid card' if customer is not None: - assert type(customer) is str and customer.startswith('cus_') - except AssertionError: - raise UserError(400, 'Bad request') + assert type(customer) is str and customer.startswith('cus_'), \ + 'invalid customer' + except AssertionError as e: + raise UserError(400, 'Bad request') from e # If this raises, abort and don't create the token card['object'] = 'card' @@ -3175,3 +3474,60 @@ def __init__(self, card=None, customer=None, **kwargs): self.type = 'card' self.card = card_obj + + +class WebhookEndpoint(StripeObject): + object = 'webhook_endpoint' + _id_prefix = 'we_' + _secret_prefix = 'whsec_' + + def __init__(self, id=None, url=None, enabled_events=None, + api_version=None, description=None, application=None, + status=None, _secret=None, **kwargs): + if kwargs: + raise UserError(400, 'Unexpected ' + ', '.join(kwargs.keys())) + + super().__init__() + + self.id = id + self.url = url + self.enabled_events = enabled_events or [] + self.api_version = api_version + self.description = description or '' + self.application = application + self.status = status or 'enabled' + self._secret = _secret or \ + getattr(self, '_secret_prefix') + random_id(14) + + @classmethod + def _api_create(cls, **data): + if '_secret' in data: + raise UserError(400, 'Unexpected _secret') + obj = super()._api_create(**data) + + register_webhook(obj.id, obj.url, obj._secret, obj.enabled_events) + return obj + + @classmethod + def _api_update(cls, **data): + if '_secret' in data: + raise UserError(400, 'Unexpected _secret') + obj = super()._api_update(**data) + + register_webhook(obj.id, obj.url, obj.secret, obj.enabled_events) + return obj + + @classmethod + def _api_retrieve(cls, id): + obj = list_webhooks().get(id) + + if obj is None: + raise UserError(404, 'Not Found') + + return WebhookEndpoint(id=id, url=obj.url, enabled_events=obj.events) + + @classmethod + def _api_delete(cls, id): + if id not in list_webhooks().keys(): + raise UserError(404, 'Not Found') + unregister_webhook(id) diff --git a/localstripe/server.py b/localstripe/server.py index 081d7b84..2f6b96a6 100644 --- a/localstripe/server.py +++ b/localstripe/server.py @@ -21,13 +21,15 @@ import os.path import re import socket +import sys from aiohttp import web +from aiohttp.abc import AbstractAccessLogger -from .resources import BalanceTransaction, Charge, Coupon, Customer, Event, \ - Invoice, InvoiceItem, PaymentIntent, PaymentMethod, Payout, Plan, \ +from .resources import Account, BalanceTransaction, Charge, Coupon, Customer, \ + Event, Invoice, InvoiceItem, PaymentIntent, PaymentMethod, Payout, Plan, \ Product, Refund, SetupIntent, Source, Subscription, SubscriptionItem, \ - TaxRate, Token, extra_apis, store + TaxRate, Token, WebhookEndpoint, extra_apis, store from .errors import UserError from .webhooks import register_webhook @@ -159,6 +161,9 @@ async def auth_middleware(request, handler): elif request.path.startswith('/_config/'): is_auth = True + elif request.path.startswith('/_status'): + is_auth = True + else: # There are exceptions (for example POST /v1/tokens, POST /v1/sources) # where authentication can be done using the public key (passed as @@ -169,6 +174,7 @@ async def auth_middleware(request, handler): r'^/v1/tokens$', r'^/v1/sources$', r'^/v1/payment_intents/\w+/_authenticate\b', + r'^/v1/payment_methods$', r'^/v1/setup_intents/\w+/confirm$', r'^/v1/setup_intents/\w+/cancel$', ))) @@ -200,7 +206,10 @@ async def save_store_middleware(request, handler): store.dump_to_disk() +norm_path_middleware = web.normalize_path_middleware(append_slash=False, + remove_slash=True) app = web.Application(middlewares=[error_middleware, auth_middleware, + norm_path_middleware, save_store_middleware]) app.on_response_prepare.append(add_cors_headers) @@ -273,10 +282,10 @@ async def f(request): app.router.add_route(method, url, api_extra(func, url)) -for cls in (BalanceTransaction, Charge, Coupon, Customer, Event, Invoice, - InvoiceItem, PaymentIntent, PaymentMethod, Payout, Plan, Product, - Refund, SetupIntent, Source, Subscription, SubscriptionItem, - TaxRate, Token): +for cls in (Account, BalanceTransaction, Charge, Coupon, Customer, Event, + Invoice, InvoiceItem, PaymentIntent, PaymentMethod, Payout, Plan, + Product, Refund, SetupIntent, Source, Subscription, + SubscriptionItem, TaxRate, Token, WebhookEndpoint): for method, url, func in ( ('POST', '/v1/' + cls.object + 's', api_create), ('GET', '/v1/' + cls.object + 's/{id}', api_retrieve), @@ -302,12 +311,14 @@ async def config_webhook(request): url = data.get('url', None) secret = data.get('secret', None) events = data.get('events', None) + expand = data.pop('expand', None) if not url or not secret or not url.startswith('http'): raise UserError(400, 'Bad request') if events is not None and type(events) is not list: raise UserError(400, 'Bad request') register_webhook(id, url, secret, events) - return web.Response() + wh = WebhookEndpoint(id=id, url=url, enabled_events=events, _secret=secret) + return json_response(wh._export(expand=expand)) async def flush_store(request): @@ -315,16 +326,37 @@ async def flush_store(request): return web.Response() +async def get_status(request): + return web.Response(text='{"status": "ok"}\n', + content_type='application/json') + + app.router.add_post('/_config/webhooks/{id}', config_webhook) app.router.add_delete('/_config/data', flush_store) +app.router.add_get('/_status', get_status) + + +class AccessLogger(AbstractAccessLogger): + def log(self, request, response, time): + if request.path.startswith('/_status'): + return + + self.logger.info(f'{request.remote} ' + f'"{request.method} {request.path}" ' + f'done in {time}s: {response.status}') def start(): parser = argparse.ArgumentParser() parser.add_argument('--port', type=int, default=8420) parser.add_argument('--from-scratch', action='store_true') + parser.add_argument('--config', action='store_true') + parser.add_argument('--no-save', default=False, action='store_true') args = parser.parse_args() + store.save_to_disk = not args.no_save + if args.config: + store.load_from_config(sys.stdin) if not args.from_scratch: store.try_load_from_disk() @@ -337,7 +369,10 @@ def start(): logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler()) - web.run_app(app, sock=sock, access_log=logger) + web.run_app(app, + sock=sock, + access_log=logger, + access_log_class=AccessLogger) if __name__ == '__main__': diff --git a/localstripe/webhooks.py b/localstripe/webhooks.py index 700a921b..72679857 100644 --- a/localstripe/webhooks.py +++ b/localstripe/webhooks.py @@ -37,6 +37,14 @@ def register_webhook(id, url, secret, events): _webhooks[id] = Webhook(url, secret, events) +def unregister_webhook(id): + del _webhooks[id] + + +def list_webhooks(): + return _webhooks + + async def _send_webhook(event): payload = json.dumps(event._export(), indent=2, sort_keys=True) payload = payload.encode('utf-8') @@ -45,6 +53,7 @@ async def _send_webhook(event): await asyncio.sleep(1) logger = logging.getLogger('aiohttp.access') + logger.debug('sending webhook %s if registered' % event.type) for webhook in _webhooks.values(): if webhook.events is not None and event.type not in webhook.events: diff --git a/setup.py b/setup.py index 8d0d0c52..2d0e6f26 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ name='localstripe', version=__version__, author=__author__, - url='https://github.com/adrienverge/localstripe', + url='https://github.com/Fatsoma/localstripe', description=('A fake but stateful Stripe server that you can run locally, ' 'for testing purposes.'),