diff --git a/octohook/events.py b/octohook/events.py index 3788ff0..db2ef2c 100644 --- a/octohook/events.py +++ b/octohook/events.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import Optional, List, Any +from typing import Optional, List, Any, Annotated from octohook.models import ( Repository, @@ -35,7 +35,6 @@ Sponsorship, Branch, StatusCommit, - RawDict, Commit, CommitUser, _optional, @@ -78,12 +77,12 @@ def __init__(self, payload: dict): class BranchProtectionRuleEvent(BaseWebhookEvent): payload: dict rule: Rule - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.rule = Rule(payload.get("rule")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class CheckRunEvent(BaseWebhookEvent): @@ -278,13 +277,13 @@ class IssueCommentEvent(BaseWebhookEvent): issue: Issue comment: Comment - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.issue = Issue(payload.get("issue")) self.comment = Comment(payload.get("comment")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class IssuesEvent(BaseWebhookEvent): @@ -293,7 +292,7 @@ class IssuesEvent(BaseWebhookEvent): """ issue: Issue - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] label: Optional[Label] assignee: Optional[User] milestone: Optional[Milestone] @@ -301,7 +300,7 @@ class IssuesEvent(BaseWebhookEvent): def __init__(self, payload: dict): super().__init__(payload) self.issue = Issue(payload.get("issue")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") self.label = _optional(payload, "label", Label) self.assignee = _optional(payload, "assignee", User) self.milestone = _optional(payload, "milestone", Milestone) @@ -313,12 +312,12 @@ class LabelEvent(BaseWebhookEvent): """ label: Label - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.label = Label(payload.get("label")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class MarketplacePurchaseEvent(BaseWebhookEvent): @@ -385,12 +384,12 @@ class MilestoneEvent(BaseWebhookEvent): """ milestone: Milestone - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.milestone = Milestone(payload.get("milestone")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class OrganizationEvent(BaseWebhookEvent): @@ -451,12 +450,12 @@ class ProjectCardEvent(BaseWebhookEvent): """ project_card: ProjectCard - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.project_card = ProjectCard(payload.get("project_card")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class ProjectColumnEvent(BaseWebhookEvent): @@ -465,12 +464,12 @@ class ProjectColumnEvent(BaseWebhookEvent): """ project_column: ProjectColumn - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.project_column = ProjectColumn(payload.get("project_column")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class ProjectEvent(BaseWebhookEvent): @@ -503,7 +502,7 @@ class PullRequestEvent(BaseWebhookEvent): pull_request: PullRequest assignee: Optional[User] label: Optional[Label] - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] before: Optional[str] after: Optional[str] requested_reviewer: Optional[User] @@ -514,7 +513,7 @@ def __init__(self, payload: dict): self.pull_request = PullRequest(payload.get("pull_request")) self.assignee = _optional(payload, "assignee", User) self.label = _optional(payload, "label", Label) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") self.before = payload.get("before") self.after = payload.get("after") self.requested_reviewer = _optional(payload, "requested_reviewer", User) @@ -527,13 +526,13 @@ class PullRequestReviewEvent(BaseWebhookEvent): review: Review pull_request: PullRequest - changes: RawDict + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.review = Review(payload.get("review")) self.pull_request = PullRequest(payload.get("pull_request")) - self.changes = RawDict(payload.get("pull_request")) + self.changes = payload.get("changes") class PullRequestReviewCommentEvent(BaseWebhookEvent): @@ -543,13 +542,13 @@ class PullRequestReviewCommentEvent(BaseWebhookEvent): comment: Comment pull_request: PullRequest - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.comment = Comment(payload.get("comment")) self.pull_request = PullRequest(payload.get("pull_request")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") class PullRequestReviewThreadEvent(BaseWebhookEvent): @@ -604,12 +603,12 @@ class ReleaseEvent(BaseWebhookEvent): """ release: Release - changes: RawDict + changes: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): super().__init__(payload) self.release = Release(payload.get("release")) - self.changes = RawDict(payload.get("release")) + self.changes = payload.get("changes") class RepositoryDispatchEvent(BaseWebhookEvent): @@ -618,13 +617,13 @@ class RepositoryDispatchEvent(BaseWebhookEvent): """ branch: str - client_payload: RawDict + client_payload: Annotated[dict, "unstructured"] installation: ShortInstallation def __init__(self, payload: dict): super().__init__(payload) self.branch = payload.get("branch") - self.client_payload = RawDict(payload.get("client_payload")) + self.client_payload = payload.get("client_payload") self.installation = ShortInstallation(payload.get("installation")) @@ -682,13 +681,13 @@ class SponsorshipEvent(BaseWebhookEvent): """ sponsorship: Sponsorship - changes: Optional[RawDict] + changes: Optional[Annotated[dict, "unstructured"]] effective_date: Optional[str] def __init__(self, payload: dict): super().__init__(payload) self.sponsorship = Sponsorship(payload.get("sponsorship")) - self.changes = _optional(payload, "changes", RawDict) + self.changes = payload.get("changes") self.effective_date = payload.get("effective_date", None) diff --git a/octohook/models.py b/octohook/models.py index 04c02de..9f91481 100644 --- a/octohook/models.py +++ b/octohook/models.py @@ -1,5 +1,25 @@ +""" +GitHub webhook model classes. + +This module uses Annotated[dict, "unstructured"] to mark intentionally unstructured +dictionary data in webhook payloads. This annotation serves two purposes: + +1. Documentation: Clearly indicates fields containing variable or user-defined data +2. Type safety: Tests enforce that all dicts are either annotated or proper model classes + +When to use Annotated[dict, "unstructured"]: +- GitHub's variable payload structures (e.g., 'changes' field format varies by event) +- User-defined data (e.g., deployment payloads, client_payload) +- Hypermedia links (_links fields) +- Configuration dictionaries (webhook config) +- Error details with varying structures + +When to create a model class instead: +- Structured GitHub data with consistent fields across events +- Data that benefits from type hints and IDE autocomplete +""" from abc import ABC -from typing import TypeVar, Optional, Type, List, Any +from typing import TypeVar, Optional, Type, List, Any, Annotated T = TypeVar("T") @@ -137,11 +157,6 @@ def __str__(self): return self.full_name -class RawDict(dict): - def __init__(self, payload: dict): - super().__init__(payload) - - class Permissions(BaseGithubModel): payload: dict metadata: str @@ -486,7 +501,8 @@ class Comment(BaseGithubModel): start_side: Optional[str] original_line: Optional[int] side: Optional[str] - reactions: Optional[RawDict] + reactions: Optional[Annotated[dict, "unstructured"]] + _links: Optional[Annotated[dict, "unstructured"]] def __init__(self, payload: dict): self.payload = payload @@ -509,13 +525,13 @@ def __init__(self, payload: dict): self.updated_at = payload.get("updated_at") self.author_association = payload.get("author_association") self.body = payload.get("body") - self._links = _optional(payload, "_links", RawDict) + self._links = payload.get("_links") self.start_line = payload.get("start_line") self.original_start_line = payload.get("original_start_line") self.start_side = payload.get("start_side") self.original_line = payload.get("original_line") self.side = payload.get("side") - self.reactions = _optional(payload, "reactions", RawDict) + self.reactions = payload.get("reactions") def __str__(self): return self.body @@ -566,16 +582,16 @@ class ChecksPullRequest(BaseGithubModel): url: str id: int number: int - head: RawDict - base: RawDict + head: Annotated[dict, "unstructured"] + base: Annotated[dict, "unstructured"] def __init__(self, payload: dict): self.payload = payload self.url = payload.get("url") self.id = payload.get("id") self.number = payload.get("number") - self.head = RawDict(payload.get("head")) - self.base = RawDict(payload.get("base")) + self.head = payload.get("head") + self.base = payload.get("base") class CommitUser(BaseGithubModel): @@ -796,7 +812,7 @@ class Deployment(BaseGithubModel): sha: str ref: str task: str - payload: RawDict + payload: Annotated[dict, "unstructured"] original_environment: str environment: str description: Optional[str] @@ -814,7 +830,7 @@ def __init__(self, payload: dict): self.sha = payload.get("sha") self.ref = payload.get("ref") self.task = payload.get("task") - self.payload = RawDict(payload.get("payload")) + self.payload = payload.get("payload") self.original_environment = payload.get("original_environment") self.environment = payload.get("environment") self.description = payload.get("description") @@ -1098,7 +1114,7 @@ class Hook(BaseGithubModel): name: str active: bool events: List[str] - config: RawDict + config: Annotated[dict, "unstructured"] updated_at: str created_at: str @@ -1109,7 +1125,7 @@ def __init__(self, payload: dict): self.name = payload.get("name") self.active = payload.get("active") self.events = payload.get("events") - self.config = RawDict(payload.get("config")) + self.config = payload.get("config") self.updated_at = payload.get("updated_at") self.created_at = payload.get("created_at") @@ -1329,7 +1345,7 @@ class PageBuild(BaseGithubModel): payload: dict url: str status: str - error: RawDict + error: Annotated[dict, "unstructured"] pusher: User commit: str duration: int @@ -1340,7 +1356,7 @@ def __init__(self, payload: dict): self.payload = payload self.url = payload.get("url") self.status = payload.get("status") - self.error = RawDict(payload.get("error")) + self.error = payload.get("error") self.pusher = User(payload.get("pusher")) self.commit = payload.get("commit") self.duration = payload.get("duration") @@ -1488,7 +1504,7 @@ class PullRequest(BaseGithubModel): statuses_url: str head: Ref base: Ref - _links: RawDict + _links: Annotated[dict, "unstructured"] author_association: str draft: bool merged: Optional[bool] @@ -1540,7 +1556,7 @@ def __init__(self, payload: dict): self.statuses_url = payload.get("statuses_url") self.head = Ref(payload.get("head")) self.base = Ref(payload.get("base")) - self._links = RawDict(payload.get("_links")) + self._links = payload.get("_links") self.author_association = payload.get("author_association") self.draft = payload.get("draft") self.merged = payload.get("merged") @@ -1577,7 +1593,7 @@ class Review(BaseGithubModel): html_url: str pull_request_url: str author_association: str - _links: RawDict + _links: Annotated[dict, "unstructured"] def __init__(self, payload: dict): self.payload = payload @@ -1591,7 +1607,7 @@ def __init__(self, payload: dict): self.html_url = payload.get("html_url") self.pull_request_url = payload.get("pull_request_url") self.author_association = payload.get("author_association") - self._links = RawDict(payload.get("_links")) + self._links = payload.get("_links") class VulnerabilityAlert(BaseGithubModel): diff --git a/tests/test_models.py b/tests/test_models.py index e166c40..d65281e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,11 +1,10 @@ import json import os -from typing import get_type_hints, get_origin, get_args +from typing import get_type_hints, get_origin, get_args, Annotated, Union import pytest from octohook.events import parse, WebhookEventAction, BaseWebhookEvent -from octohook.models import RawDict from tests.conftest import discover_fixtures paths = ["tests/fixtures/complete", "tests/fixtures/incomplete"] @@ -33,10 +32,17 @@ def test_model_loads(event_name, fixture_loader): def check_model(data, obj): """ - Checks if every key in the json is represented either as a RawDict or a nested object. + Checks if every key in the json is represented either as an Annotated[dict, "unstructured"] or a nested object. + + This enforces that all dicts are intentionally marked as unstructured, preventing accidental + use of plain dicts where model classes should exist. + :param data: The JSON dictionary :param obj: The Class Object for the dictionary """ + # Get type hints with annotations preserved + hints = get_type_hints(type(obj), include_extras=True) + for key in data: json_value = data[key] try: @@ -48,10 +54,20 @@ def check_model(data, obj): raise AttributeError(f"Couldn't find function or attribute for {key}") # When the nested object is another dictionary - if not isinstance(obj_value, RawDict) and isinstance(json_value, dict): - if isinstance(obj_value, dict): - raise AttributeError(f"Object is a plain dictionary for {key}") + if isinstance(json_value, dict): + type_hint = hints.get(key) + + if _is_unstructured_dict(type_hint): + # Intentionally unstructured - skip validation + continue + elif isinstance(obj_value, dict): + # Plain dict without "unstructured" annotation - ERROR! + raise AttributeError( + f"Plain dict for '{key}' in {type(obj).__name__} - " + f"should be a model class or Annotated[dict, 'unstructured']" + ) else: + # It's a model class - recursively validate check_model(json_value, obj_value) # When the nested object is a list of objects @@ -107,6 +123,71 @@ def test_all_event_actions_are_in_enum(path): WebhookEventAction(action) +def _unwrap_annotated(type_hint): + """ + Extract the actual type from Annotated and Optional wrappers. + + Handles: + - Annotated[T, ...] -> T + - Optional[Annotated[T, ...]] -> T + - Optional[T] -> T + - T -> T + + Args: + type_hint: Type hint to unwrap + + Returns: + The unwrapped type, or the original if not wrapped + """ + origin = get_origin(type_hint) + + # Handle Optional/Union - extract non-None type + if origin is Union: + args = get_args(type_hint) + for arg in args: + if arg is not type(None): + # Recursively unwrap in case it's Optional[Annotated[...]] + return _unwrap_annotated(arg) + + # Handle Annotated - extract the actual type + if origin is Annotated: + args = get_args(type_hint) + return args[0] + + return type_hint + + +def _is_unstructured_dict(type_hint) -> bool: + """ + Check if a type hint is Annotated[dict, "unstructured"]. + + Handles both direct annotation and Optional[Annotated[dict, "unstructured"]]. + + Args: + type_hint: Type hint to check + + Returns: + bool: True if it's an unstructured dict annotation + """ + origin = get_origin(type_hint) + + # Handle Optional[Annotated[dict, "unstructured"]] + if origin is Union: + args = get_args(type_hint) + # Check non-None args + for arg in args: + if arg is not type(None): + return _is_unstructured_dict(arg) + + # Check if it's Annotated[dict, "unstructured"] + if origin is Annotated: + args = get_args(type_hint) + # First arg is the actual type, rest are metadata + return args[0] is dict and "unstructured" in args[1:] + + return False + + def _is_primitive_type(type_hint): """ Check if a type hint represents a primitive type. @@ -115,10 +196,13 @@ def _is_primitive_type(type_hint): type_hint: Type hint to check Returns: - bool: True if the type is primitive (str, int, None, bool, RawDict) + bool: True if the type is primitive (str, int, None, bool, or Annotated[dict, "unstructured"]) """ - primitives = [str, int, type(None), bool, RawDict] - return type_hint in primitives + primitives = [str, int, type(None), bool] + if type_hint in primitives: + return True + # Also treat unstructured dicts as primitives + return _is_unstructured_dict(type_hint) def _validate_simple_type(obj, attr, type_hint, obj_value): @@ -134,7 +218,14 @@ def _validate_simple_type(obj, attr, type_hint, obj_value): Raises: TypeHintError: If the type doesn't match """ - if isinstance(obj_value, type_hint): + # Handle None values for Optional types + if obj_value is None and type(None) in get_args(type_hint): + return # None is valid for Optional types + + # Extract actual type from Annotated/Optional wrappers + check_type = _unwrap_annotated(type_hint) + + if isinstance(obj_value, check_type): # Recursively validate nested objects (non-primitives without None in Union) # Example: For PullRequest type, recursively validate its nested User objects # Skip primitives (str, int, etc.) and Optional types (which include None) @@ -217,7 +308,7 @@ def check_type_hints(obj): TypeHintError: If any type hint doesn't match the actual runtime type AssertionError: If the object doesn't have a 'payload' attribute """ - hints = get_type_hints(type(obj)) + hints = get_type_hints(type(obj), include_extras=True) # All webhook objects should have a payload attribute assert "payload" in hints.keys() @@ -262,3 +353,32 @@ def test_missing_models_return_basewebhookevent(): payload = json.load(file)[0] assert isinstance(parse("code_scanning_alert", payload), BaseWebhookEvent) + + +def test_unannotated_dict_enforcement(): + """ + Verify that check_model enforces Annotated[dict, "unstructured"] requirement. + + Tests that using a plain dict without the annotation raises an AttributeError, + preventing accidental use of unstructured data where model classes should exist. + """ + from octohook.models import BaseGithubModel + + # Create a test model with a plain dict field (missing annotation) + class BadModel(BaseGithubModel): + payload: dict + bad_field: dict # This should be Annotated[dict, "unstructured"] + + def __init__(self, payload: dict): + self.payload = payload + self.bad_field = payload.get("bad_field") + + test_payload = {"bad_field": {"key": "value"}} + obj = BadModel(test_payload) + + # This should raise an error because bad_field is a plain dict without annotation + with pytest.raises( + AttributeError, + match="Plain dict for 'bad_field' in BadModel - should be a model class or Annotated", + ): + check_model(test_payload, obj)