Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TBD: Ast serde #727

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 295 additions & 15 deletions datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,41 @@ def get_furthest_parent(
return curr_parent
curr_parent = curr_parent.parent

def flatten(self) -> Iterator["Node"]:
def flatten(
self,
obfuscated: bool = False,
named: bool = False,
) -> Iterator["Node"]:
"""
Flatten the sub-ast of the node as an iterator
"""
return self.filter(lambda _: True)
seen = set()

def _flatten(args):
parent_key = None
if named:
parent_key, node = args
else:
node = args
if id(node) not in seen:
yield (node if not named else (parent_key, node))
seen.add(id(node))
for child in chain(
*[
_flatten(child)
for child in node.fields(
nodes_only=True,
flat=True,
named=named,
nones=False,
obfuscated=obfuscated,
)
if id(child) not in seen
]
):
yield child

return _flatten((self.parent_key, self) if named else self)

# pylint: disable=R0913
def fields(
Expand Down Expand Up @@ -471,13 +501,13 @@ def is_compiled(self) -> bool:
return self._is_compiled


class DJEnum(Enum):
class DJEnum(str, Enum):
"""
A DJ AST enum
"""

def __repr__(self) -> str:
return str(self)
def __str__(self):
return self.value


@dataclass(eq=False)
Expand Down Expand Up @@ -1190,9 +1220,6 @@ class UnaryOpKind(DJEnum):
Exists = "EXISTS"
Not = "NOT"

def __str__(self):
return self.value


@dataclass(eq=False)
class UnaryOp(Operation):
Expand Down Expand Up @@ -1317,7 +1344,7 @@ def __str__(self) -> str:
right = self.right.copy().use_alias_as_name()
if isinstance(self.left, Column) and self.left.alias:
left = self.left.copy().use_alias_as_name()
ret = f"{left} {self.op.value} {right}"
ret = f"{left} {self.op} {right}"

if self.parenthesized:
return f"({ret})"
Expand Down Expand Up @@ -2235,12 +2262,6 @@ def add_aliases_to_unnamed_columns(self) -> None:
projection.append(expression)
self.projection = projection


class Select(SelectExpression):
"""
A single select statement type
"""

def __str__(self) -> str:
parts = ["SELECT "]
if self.quantifier:
Expand Down Expand Up @@ -2272,6 +2293,12 @@ def __str__(self) -> str:
return f"{select}{as_}{self.alias}"
return select


class Select(SelectExpression):
"""
A single select statement type
"""

@property
def type(self) -> ColumnType:
if len(self.projection) != 1:
Expand Down Expand Up @@ -2416,3 +2443,256 @@ def build( # pylint: disable=R0913,C0415
self.select.projection,
key=lambda x: str(x.alias_or_name),
)[:]


###################################
###SERIALIZATION/DESERIALIZATION###
###################################
def get_node_key(node: Node) -> int:
    """
    Compute the key under which a node is stored in a serialization.

    The key is simply the object's `id()`, so two references to the same
    node object always map to the same key.
    """
    node_key = id(node)
    return node_key


def serialize_value(
    value: Any,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited_nodes: Set[int],
) -> Any:
    """
    Turn a single attribute value into a tagged `{"kind", "value"}` dict.

    - AST nodes become `{"kind": "node", "value": <node key>}`; a node seen
      for the first time is also serialized into `serialization`.
    - Column types become `{"kind": "type", "value": str(value)}`.
    - DJ enums are stored by their `.value`, other primitives as-is, both
      under the kind "primitive".
    - Lists, tuples and sets are stored as a list of serialized items under
      the kinds "list", "tuple" and "set" respectively.

    Any other value yields `None`.
    """
    if isinstance(value, Node):
        key = get_node_key(value)
        # Only descend into a node the first time it is encountered; later
        # encounters (including cycles) just emit a reference to its key.
        if key not in visited_nodes:
            visited_nodes.add(key)
            _serialize_ast(value, serialization, visited_nodes)
        return {"kind": "node", "value": key}

    if isinstance(value, ColumnType):
        return {"kind": "type", "value": str(value)}

    if isinstance(value, DJEnum):
        return {"kind": "primitive", "value": value.value}

    # Exact type match on purpose: subclasses of primitives are not primitives.
    if type(value) in PRIMITIVES:
        return {"kind": "primitive", "value": value}

    for container_type, kind in ((list, "list"), (tuple, "tuple"), (set, "set")):
        if isinstance(value, container_type):
            items = []
            for item in value:
                items.append(serialize_value(item, serialization, visited_nodes))
            return {"kind": kind, "value": items}

    # Unsupported values are not serialized.
    return None


def _serialize_ast(
    node: Node,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited_nodes: Set[int],
):
    """
    Serialize `node` (and, through its attributes, every node it references)
    into `serialization`.

    The entry stored for a node is `(class name, {attribute: serialized value})`
    built from the node's instance `__dict__`. Nodes that already have an
    entry are left untouched.
    """
    key = get_node_key(node)
    if key not in serialization:
        # The entry is stored only after all attributes have been serialized.
        serialization[key] = (
            type(node).__name__,
            {
                attr: serialize_value(attr_value, serialization, visited_nodes)
                for attr, attr_value in node.__dict__.items()
            },
        )


def serialize_ast(node: Node) -> Dict[int, Tuple[str, Dict[str, Any]]]:
    """
    Serialize the AST rooted at `node`.

    Returns a mapping from node key (see `get_node_key`) to a
    `(class name, serialized attributes)` tuple, with one entry for every
    node reachable from `node`. The key of `node` itself is
    `get_node_key(node)`.
    """
    serialization: Dict[int, Tuple[str, Dict[str, Any]]] = {}
    # Mark the root as visited up front. Its entry is only stored once all of
    # its attributes are done, so without this the first back-reference to the
    # root (e.g. from a child) would serialize the root's attributes again.
    visited_nodes: Set[int] = {get_node_key(node)}
    _serialize_ast(node, serialization, visited_nodes)
    return serialization


def deserialize_value(
    parent_id: int,
    parent_key: str,
    value: Any,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited: Set[int],
    lazies: List["LazyNode"],
) -> Any:
    """
    Rebuild a value from its tagged `{"kind", "value"}` representation.

    `parent_id` and `parent_key` identify the node and attribute the value
    belongs to; they are passed through unchanged to nested values and to
    node deserialization. A falsy `value` deserializes to `None`.

    Raises `TypeError` when the kind is not one of "node", "type",
    "primitive", "list", "tuple" or "set".
    """
    if not value:
        return None

    def _rebuild(item: Any) -> Any:
        # Nested items belong to the same parent attribute as the container.
        return deserialize_value(
            parent_id,
            parent_key,
            item,
            serialization,
            visited,
            lazies,
        )

    kind = value["kind"]

    if kind == "node":
        node_key = value["value"]
        node = _deserialize_ast(
            parent_id,
            parent_key,
            node_key,
            serialization,
            visited,
            lazies,
        )
        # Fall back to the stored entry if the deserialized node is falsy.
        return node or serialization[node_key]

    if kind == "type":
        from datajunction_server.sql.parsing.backends.antlr4 import parse

        # The type string is parsed back by wrapping it in a CAST expression.
        query = parse(f"select CAST(x as {value['value']})")
        return query.select.projection[0].type  # type: ignore

    if kind == "primitive":
        return value["value"]

    if kind == "list":
        return [_rebuild(item) for item in value["value"]]

    if kind == "tuple":
        return tuple(_rebuild(item) for item in value["value"])

    if kind == "set":
        return {_rebuild(item) for item in value["value"]}

    raise TypeError(f"Cannot deserialize value `{value}`.")


@dataclass
class LazyNode(Node):
    """
    Placeholder node used while deserializing an AST.

    It stands in for a node that is referenced again before it has finished
    deserializing (a circular reference). Once deserialization is complete,
    `finalize` swaps the placeholder for the real node.
    """

    # Key in `refs` of the node that referenced this placeholder.
    parent_id: Optional[int] = None
    # Name of the parent attribute being deserialized when the placeholder
    # was created (as passed by `_deserialize_ast`).
    parent_key: Optional[str] = None
    # Key in `refs` of the node this placeholder stands in for.
    key: Optional[int] = None
    # The serialization mapping; by the time `finalize` runs its entries are
    # expected to be deserialized nodes rather than raw tuples.
    refs: Optional[Dict[int, Tuple[str, Dict[str, Any]]]] = None

    def finalize(self) -> Node:
        """
        Swap this placeholder for a copy of the node it refers to and return
        that copy.
        """
        resolved = self.refs[self.key]  # type: ignore
        # Attach the placeholder to its real parent before swapping it out.
        self.parent = self.refs[self.parent_id]  # type: ignore
        replacement = resolved.copy()
        self.swap(replacement)
        return replacement

    def __str__(self):
        # A placeholder has no SQL representation.
        raise NotImplementedError()


def _deserialize_ast(
    parent_id: Optional[int],
    parent_key: Optional[str],
    node_id: int,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
    visited: Set[int],
    lazies: List[LazyNode],
) -> Node:
    """
    Recursively deserialize the node stored under `node_id`.

    `serialization` is updated in place: once a node has been built, its
    `(class name, data)` entry is replaced by the node itself, so later
    references resolve to the same object.

    A node that is referenced again while it is still being built (a circular
    reference) is represented by a `LazyNode`, which is appended to `lazies`
    so the caller can finalize it afterwards.

    Raises `TypeError` when the serialized class name does not resolve to a
    `Node` subclass defined in this module.
    """
    entry = serialization[node_id]
    if isinstance(entry, Node):  # node already deserialized
        return entry

    if node_id in visited:  # circular reference
        # Create a lazy node to be swapped once deserialization is complete.
        # Keyword arguments are used so that dataclass fields inherited from
        # `Node` cannot shift which value lands in which field.
        lazy = LazyNode(
            parent_id=parent_id,
            parent_key=parent_key,
            key=node_id,
            refs=serialization,
        )
        lazies.append(lazy)
        return lazy

    # node not deserialized yet
    cls_name, data = entry
    visited.add(node_id)
    cls = globals().get(cls_name)
    # Only node classes from this module may be instantiated; this also turns
    # an unknown class name into a clear error instead of an opaque failure.
    if not (isinstance(cls, type) and issubclass(cls, Node)):
        raise TypeError(f"Cannot deserialize unknown node class `{cls_name}`.")

    # Names of the fields that the class constructor accepts.
    init_fields = {spec.name for spec in fields(cls) if spec.init}
    kwargs = {}
    deferred = []
    for key, raw in data.items():
        if key in init_fields:
            kwargs[key] = deserialize_value(
                node_id,
                key,
                raw,
                serialization,
                visited,
                lazies,
            )
        else:
            deferred.append(key)
    ret = cls(**kwargs)  # type: ignore

    # Attributes the constructor does not accept are set after construction.
    for key in deferred:
        deserialized_value = deserialize_value(
            node_id,
            key,
            data[key],
            serialization,
            visited,
            lazies,
        )
        setattr(ret, key, deserialized_value)

    serialization[node_id] = ret

    return ret


def deserialize_ast(
    node_id: int,
    serialization: Dict[int, Tuple[str, Dict[str, Any]]],
) -> Node:
    """
    Rebuild an AST from `serialization` and return the node stored under
    `node_id`.

    Note that `serialization` is modified in place during deserialization:
    entries are replaced by the nodes built from them.
    """
    pending: List[LazyNode] = []
    root = _deserialize_ast(None, None, node_id, serialization, set(), pending)
    # Placeholders created for circular references can only be resolved once
    # every node has been built.
    for placeholder in pending:
        placeholder.finalize()
    return root
Loading
Loading