Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion example_plugins/src/register_plugins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Sequence, Type
from collections.abc import Sequence
from typing import Any, Type

from osprey.engine.udf.base import UDFBase
from osprey.worker.adaptor.plugin_manager import hookimpl_osprey
Expand Down
3 changes: 2 additions & 1 deletion example_plugins/src/services/labels_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Generator
from typing import Any

from osprey.engine.language_types.entities import EntityT
from osprey.worker.lib.osprey_shared.labels import EntityLabels
Expand Down
14 changes: 7 additions & 7 deletions example_plugins/src/udfs/ban_user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Self, cast
from typing import Self, cast

from osprey.engine.executor.custom_extracted_features import CustomExtractedFeature
from osprey.engine.executor.execution_context import ExecutionContext
Expand All @@ -16,7 +16,7 @@ class BanUserArguments(ArgumentsBase):


@dataclass
class BanUserEffect(EffectToCustomExtractedFeatureBase[List[str]]):
class BanUserEffect(EffectToCustomExtractedFeatureBase[list[str]]):
"""Adds a 'ban user' effect to the action."""

entity: str
Expand All @@ -29,20 +29,20 @@ def to_str(self) -> str:
return f'{self.entity}|{self.comment}'

@classmethod
def build_custom_extracted_feature_from_list(cls, values: List[Self]) -> CustomExtractedFeature[List[str]]:
return BanEffectsExtractedFeature(effects=cast(List[BanUserEffect], values))
def build_custom_extracted_feature_from_list(cls, values: list[Self]) -> CustomExtractedFeature[list[str]]:
return BanEffectsExtractedFeature(effects=cast(list[BanUserEffect], values))


@add_slots
@dataclass
class BanEffectsExtractedFeature(CustomExtractedFeature[List[str]]):
effects: List[BanUserEffect]
class BanEffectsExtractedFeature(CustomExtractedFeature[list[str]]):
effects: list[BanUserEffect]

@classmethod
def feature_name(cls) -> str:
return 'ban_user'

def get_serializable_feature(self) -> List[str] | None:
def get_serializable_feature(self) -> list[str] | None:
return [effect.to_str() for effect in self.effects]


Expand Down
13 changes: 7 additions & 6 deletions osprey_worker/src/osprey/engine/ast/ast_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Type, TypeVar, Union
from collections.abc import Callable, Iterator, Sequence
from typing import Any, Type, TypeVar

# from osprey.engine.utils.periodic_execution_yielder import maybe_periodic_yield
from .grammar import ASTNode, Root, Statement
Expand All @@ -26,7 +27,7 @@ def traverse_mro(klass: Any) -> None:

traverse_mro(node.__class__)

seen_fields: Set[str] = set()
seen_fields: set[str] = set()

for klass in ordered_mro:
for field in getattr(klass, '__annotations__', ()):
Expand All @@ -39,11 +40,11 @@ def traverse_mro(klass: Any) -> None:


def _make_memoized_field_values_iterator() -> Callable[
['ASTNode'], Iterator[Tuple[str, Union['ASTNode', Sequence['ASTNode']]]]
['ASTNode'], Iterator[tuple[str, 'ASTNode' | Sequence['ASTNode']]]
]:
_field_cache: Dict[Type['ASTNode'], List[str]] = {}
_field_cache: dict[Type['ASTNode'], list[str]] = {}

def _iter_field_values(node: ASTNode) -> Iterator[Tuple[str, Union['ASTNode', Sequence['ASTNode']]]]:
def _iter_field_values(node: ASTNode) -> Iterator[tuple[str, 'ASTNode' | Sequence['ASTNode']]]:
# To avoid the cost of iterating fields over known node classes,
# perform simple memoization.

Expand Down Expand Up @@ -87,7 +88,7 @@ def _iter_inner(node: 'ASTNode') -> Iterator['ASTNode']:
return _iter_inner(root)


def filter_nodes(root: 'ASTNode', ty: Type[T], filter_fn: Optional[Callable[[T], bool]] = None) -> Iterator[T]:
def filter_nodes(root: 'ASTNode', ty: Type[T], filter_fn: Callable[[T], bool] | None = None) -> Iterator[T]:
"""Given a root, iterate over nodes, filtering out those who's type do not
match the given `ty`.

Expand Down
6 changes: 3 additions & 3 deletions osprey_worker/src/osprey/engine/ast/error_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Sequence, Union

from osprey.engine.utils.types import add_slots

Expand Down Expand Up @@ -38,7 +38,7 @@ def render_span_context_with_message(
span: Span,
hint: str = '',
additional_spans_message: str = '',
additional_spans: Sequence[Union[Span, SpanWithHint]] = tuple(),
additional_spans: Sequence[Span | SpanWithHint] = tuple(),
message_type: str = 'error',
) -> str:
"""Given a span, a message, and a hint, print out a human readable error
Expand Down Expand Up @@ -71,7 +71,7 @@ def render_span_context_with_message(
# TODO: Colors?!
parts = [f'{message_type}: {_assert_valid_message(message)}']

def append_span(span_: Span, hint_: str, num_prefix: Optional[int] = None) -> None:
def append_span(span_: Span, hint_: str, num_prefix: int | None = None) -> None:
if num_prefix:
num_prefix_str = f'({num_prefix}) '
else:
Expand Down
33 changes: 17 additions & 16 deletions osprey_worker/src/osprey/engine/ast/grammar.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass, field, replace
from enum import Enum
from pathlib import Path
from typing import ClassVar, Dict, Optional, Sequence, TypeVar, Union
from typing import ClassVar, TypeVar

from gevent.lock import Semaphore

Expand All @@ -15,8 +16,8 @@
# Keep these outside `Source` so we don't break pickling/hashing/eq
# Will this leak memory? Maybe
# TODO(old): put this stuff back in cached_property
parsed_ast_root_cache: Dict['Source', 'Root'] = {}
ast_root_lock_cache: Dict['Source', Semaphore] = defaultdict(lambda: Semaphore())
parsed_ast_root_cache: dict['Source', 'Root'] = {}
ast_root_lock_cache: dict['Source', Semaphore] = defaultdict(lambda: Semaphore())

# logger = get_logger()

Expand All @@ -35,7 +36,7 @@ class Source:
"""The contents of the source file.
"""

actual_path: Optional[Path] = field(default=None, compare=False)
actual_path: Path | None = field(default=None, compare=False)
"""If the source file was loaded from disk, rather than a virtual source, this will be Some path, where
the file exists. Useful if we want to provide a jump to source functionality.
"""
Expand Down Expand Up @@ -88,7 +89,7 @@ class Span:
start_pos: int
"""The position in the start-line that the span starts."""

_ast_node: Optional['ASTNode'] = None
_ast_node: 'ASTNode' | None = None

def __repr__(self) -> str:
return f'<Span source={self.source.path} start_line={self.start_line} start_pos={self.start_pos}>'
Expand All @@ -102,14 +103,14 @@ def ast_node(self) -> 'ASTNode':
def copy(self, ast_node: 'ASTNode') -> 'Span':
return replace(self, _ast_node=ast_node)

def parent_ast_node(self, n: int = 1) -> Optional['ASTNode']:
def parent_ast_node(self, n: int = 1) -> 'ASTNode' | None:
"""Helper for node traversal, gets the parent value off the n^th generation ast_node.

>>> self.get_parent_ast_node(n=0) == self.ast_node
>>> self.get_parent_ast_node(n=1) == self.ast_node.parent
>>> self.get_parent_ast_node(n=2) == self.ast_node.parent.parent
"""
node: Optional[ASTNode] = self.ast_node
node: ASTNode | None = self.ast_node
while n > 0 and node is not None:
n -= 1
node = node.parent
Expand Down Expand Up @@ -157,7 +158,7 @@ def can_extract(self) -> bool:
class ASTNode:
"""This is the base-class of all AST nodes."""

parent: Optional['ASTNode'] = field(default=None, init=False, repr=False, compare=False)
parent: 'ASTNode' | None = field(default=None, init=False, repr=False, compare=False)
span: Span = field(repr=False, compare=False)
"""The location of this AST node, in a given source file."""

Expand Down Expand Up @@ -256,7 +257,7 @@ class Name(Expression, IsExtractable):

identifier: str
context: Context
_source_annotation: Union[None, _Sentinel, 'Annotation', 'AnnotationWithVariants'] = field(
_source_annotation: None | _Sentinel | 'Annotation' | 'AnnotationWithVariants' = field(
default=_SOURCE_ANNOTATION_UNSET,
repr=False,
compare=False,
Expand All @@ -275,7 +276,7 @@ def identifier_key(self) -> str:
else:
return self.identifier

def set_source_annotation(self, annotation: Union[None, 'Annotation', 'AnnotationWithVariants']) -> None:
def set_source_annotation(self, annotation: None | 'Annotation' | 'AnnotationWithVariants') -> None:
self._source_annotation = annotation

@property
Expand Down Expand Up @@ -315,7 +316,7 @@ class String(Literal):
class Number(Literal):
"""Represents a number parsed from ast."""

value: Union[int, float]
value: int | float


@add_slots
Expand Down Expand Up @@ -349,7 +350,7 @@ class Assign(Statement, IsConstant, IsExtractable):

target: Name
value: Expression
annotation: Optional[Union['Annotation', 'AnnotationWithVariants']] = None
annotation: 'Annotation' | 'AnnotationWithVariants' | None = None

@cached_property
def should_extract(self) -> bool:
Expand Down Expand Up @@ -407,17 +408,17 @@ class Call(Expression, Statement):
```
"""

func: Union[Name, 'Attribute']
func: Name | 'Attribute'
arguments: Sequence['Keyword']

def find_argument(self, name: str) -> Optional['Keyword']:
def find_argument(self, name: str) -> 'Keyword' | None:
for argument in self.arguments:
if argument.name == name:
return argument

return None

def argument_dict(self) -> Dict[str, Expression]:
def argument_dict(self) -> dict[str, Expression]:
return {arg.name: arg.value for arg in self.arguments}

@property
Expand Down Expand Up @@ -844,7 +845,7 @@ class AnnotationWithVariants(Expression):
"""

identifier: str
variants: Sequence[Union[Annotation, 'AnnotationWithVariants']]
variants: Sequence[Annotation | 'AnnotationWithVariants']

@property
def can_extract(self) -> bool:
Expand Down
20 changes: 10 additions & 10 deletions osprey_worker/src/osprey/engine/ast/py_ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This implements the Python AST -> Osprey AST transformer."""

import ast
from typing import Any, Optional, Type, TypeVar, Union
from typing import Any, Type, TypeVar
from typing import List as ListT

from .ast_utils import iter_field_values
Expand Down Expand Up @@ -98,7 +98,7 @@ class OspreyAstNodeTransformer:

def __init__(self, source: Source):
self.source = source
self.current_node: Optional[ast.AST] = None
self.current_node: ast.AST | None = None

def transform(self, node: ast.AST) -> ASTNode:
# Store current node here, so that invariant will work in given node's context.
Expand Down Expand Up @@ -160,7 +160,7 @@ def transform_Name(self, node: ast.Name) -> Name:
context = pyast_context_to_osprey_context(self, node.ctx, span=span)
return Name(span=span, identifier=node.id, context=context)

def transform_NameConstant(self, node: ast.NameConstant) -> Union[Boolean, None_]:
def transform_NameConstant(self, node: ast.NameConstant) -> Boolean | None_:
self.invariant(node.value in (True, False, None), 'unexpected name constant', node=node)
if node.value is None:
return None_(span=self.span_for(node))
Expand All @@ -169,7 +169,7 @@ def transform_NameConstant(self, node: ast.NameConstant) -> Union[Boolean, None_

# typeshed is missing ast.Constant as it is new in python 3.8
# so we are typing as Any until it gains that definition.
def transform_Constant(self, node: Any) -> Union[String, Number, Boolean, None_]:
def transform_Constant(self, node: Any) -> String | Number | Boolean | None_:
span = self.span_for(node)
if isinstance(node.value, str):
return String(span=span, value=node.value)
Expand Down Expand Up @@ -331,8 +331,8 @@ def transform_JoinedStr(self, node: ast.JoinedStr) -> FormatString:
return FormatString(span=self.span_for(node), format_string=format_string, names=names)

def span_for(self, node: ast.AST, offset_by: int = 0) -> Span:
start_line: Optional[int] = getattr(node, 'lineno', None)
start_pos: Optional[int] = getattr(node, 'col_offset', None)
start_line: int | None = getattr(node, 'lineno', None)
start_pos: int | None = getattr(node, 'col_offset', None)
assert start_line is not None, f'Every AST node should have a span, but {node!r} did not.'
assert start_pos is not None, f'Every AST node should have a span, but {node!r} did not.'
return Span(source=self.source, start_line=start_line, start_pos=start_pos + offset_by)
Expand Down Expand Up @@ -374,7 +374,7 @@ def expect_pyast_ty(self, ty: Type[T], expr: ast.AST, hint: str = '') -> T:

return expr

def expect_pyast_union(self, ty: Type[T], ty2: Type[V], node: ASTNode, hint: str = '') -> Union[T, V]:
def expect_pyast_union(self, ty: Type[T], ty2: Type[V], node: ASTNode, hint: str = '') -> T | V:
# This function is a bit hacky, but lets us convince mypy that we have a union type.
if not (isinstance(node, ty) or isinstance(node, ty2)):
if isinstance(node, ASTNode):
Expand Down Expand Up @@ -472,7 +472,7 @@ def pyast_context_to_osprey_context(
def fixup_parents(root: Root) -> None:
"""Traverses the AST, updating the nodes to point to their parent nodes."""

def recursively_insert_parents(parent_node: Optional[ASTNode], current_node: ASTNode) -> None:
def recursively_insert_parents(parent_node: ASTNode | None, current_node: ASTNode) -> None:
assert isinstance(current_node, ASTNode)
current_node.parent = parent_node

Expand All @@ -489,7 +489,7 @@ def recursively_insert_parents(parent_node: Optional[ASTNode], current_node: AST
# noinspection PyPep8Naming
def pyast_ann_assign_annotation_to_annotation(
transformer: OspreyAstNodeTransformer, annotation: ast.expr
) -> Union[Annotation, AnnotationWithVariants]:
) -> Annotation | AnnotationWithVariants:
annotation_hint: str = (
'annotation must be in the form of `Target: Annotation = Value`, where `Annotation` is a `Name` '
'or `Name[Name, ...]`'
Expand All @@ -501,7 +501,7 @@ def pyast_ann_assign_annotation_to_annotation(
return Annotation(identifier='None', span=transformer.span_for(annotation))

if isinstance(annotation, ast.Subscript):
variants: ListT[Union[Annotation, AnnotationWithVariants]] = []
variants: ListT[Annotation | AnnotationWithVariants] = []
value = transformer.expect_pyast_ty(ast.Name, annotation.value, hint=annotation_hint)
slice = annotation.slice

Expand Down
Loading