Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 48 additions & 39 deletions scaaml/models/gpam.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
import networkx as nx
except ImportError:
nx = None # type: ignore[assignment]
import tensorflow as tf
import keras
from tensorflow.keras import layers
from tensorflow import Tensor
import numpy as np
from keras import layers
from keras.src.backend import KerasTensor


@keras.saving.register_keras_serializable()
class Rescale(layers.Layer): # type: ignore[type-arg]

Check failure on line 44 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Error code "no-any-unimported" not covered by "type: ignore" comment

Check failure on line 44 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Base type Layer becomes "Any" due to an unfollowed import [no-any-unimported]

Check failure on line 44 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Error code "misc" not covered by "type: ignore" comment

Check failure on line 44 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Class cannot subclass "Layer" (has type "Any") [misc]

Check failure on line 44 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Unused "type: ignore" comment [unused-ignore]
"""Rescale input to the interval [-1, 1].
"""

Expand All @@ -59,7 +59,7 @@
self.trace_min: float = trace_min
self.trace_delta: float = trace_delta

def call(self, inputs: Tensor, **kwargs: Any) -> Tensor:
def call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor:

Check failure on line 62 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Argument 2 to "call" becomes "Any" due to an unfollowed import [no-any-unimported]

Check failure on line 62 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Return type becomes "Any" due to an unfollowed import [no-any-unimported]
"""Rescale to the interval [-1, 1]."""
del kwargs # unused
x = inputs
Expand All @@ -74,11 +74,11 @@
"trace_min": self.trace_min,
"trace_delta": self.trace_delta,
})
return config

Check failure on line 77 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Returning Any from function declared to return "dict[str, Any]" [no-any-return]


@keras.saving.register_keras_serializable()
class ScaledNorm(layers.Layer): # type: ignore[type-arg]

Check failure on line 81 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Class cannot subclass "Layer" (has type "Any") [misc]

Check failure on line 81 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / mypy

Unused "type: ignore" comment [unused-ignore]
"""ScaledNorm layer.

Transformers without Tears: Improving the Normalization of Self-Attention
Expand All @@ -104,17 +104,17 @@
self._scale = self.add_weight(
name="norm_scale",
shape=(),
initializer=tf.constant_initializer(value=1.0),
initializer=keras.initializers.Constant(value=1.0),
trainable=True,
)

def call(self, inputs: Tensor) -> Tensor:
def call(self, inputs: KerasTensor) -> KerasTensor:
"""Return the output of this layer.
"""
x = inputs
axes = list(range(len(x.shape)))[self._begin_axis:]
mean_square = tf.reduce_mean(tf.math.square(x), axes, keepdims=True)
x = x * tf.math.rsqrt(mean_square + self._epsilon)
mean_square = keras.ops.mean(keras.ops.square(x), axes, keepdims=True)
x = x * keras.ops.rsqrt(mean_square + self._epsilon)
return x * self._scale

def get_config(self) -> dict[str, Any]:
Expand All @@ -128,19 +128,19 @@
return config


def clone_initializer(initializer: tf.keras.initializers.Initializer) -> Any:
def clone_initializer(initializer: keras.initializers.Initializer) -> Any:
"""Clone an initializer (if an initializer is reused the generated
weights are the same).
"""
if isinstance(initializer, tf.keras.initializers.Initializer):
if isinstance(initializer, keras.initializers.Initializer):
return initializer.__class__.from_config(initializer.get_config())
return initializer # type: ignore[unreachable]


def rope(
x: Tensor,
x: KerasTensor,
axis: Union[list[int], int],
) -> Tensor:
) -> KerasTensor:
"""RoPE positional encoding.

Implementation of the Rotary Position Embedding proposed in
Expand All @@ -153,7 +153,10 @@
Returns:
The input tensor with RoPE encodings.
"""
shape = x.shape.as_list()
# TensorFlow and JAX treat the shape differently. In the case of
# TensorFlow we need a list, otherwise there is a problem in
# toeplitz_matrix_rope.
shape = list(x.shape)

if isinstance(axis, int):
axis = [axis]
Expand All @@ -161,41 +164,41 @@
if isinstance(shape, (list, tuple)):
spatial_shape = [shape[i] for i in axis]
total_len = 1
for i in spatial_shape:

Check failure on line 167 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / yapf

reformatted
total_len *= i # type: ignore[operator]
position = tf.reshape(
tf.cast(tf.range(total_len, delta=1.0), tf.float32), spatial_shape)
position = keras.ops.reshape(
keras.ops.cast(keras.ops.arange(total_len), np.float32), spatial_shape)

Check failure on line 170 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / pylint (3.12)

Line too long (83/80) (line-too-long)

Check failure on line 170 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / pylint (3.13)

Line too long (83/80) (line-too-long)
else:
raise ValueError(f"Unsupported shape: {shape}")

# we assume that the axis can not be negative (e.g., -1)
if any(dim < 0 for dim in axis):
raise ValueError(f"Unsupported axis: {axis}")
for i in range(axis[-1] + 1, len(shape) - 1, 1):
position = tf.expand_dims(position, axis=-1)
position = keras.ops.expand_dims(position, axis=-1)

half_size = shape[-1] // 2 # type: ignore[operator]
freq_seq = tf.cast(tf.range(half_size), tf.float32) / float(half_size)
freq_seq = keras.ops.cast(keras.ops.arange(half_size), np.float32) / float(half_size)

Check failure on line 181 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / pylint (3.12)

Line too long (89/80) (line-too-long)

Check failure on line 181 in scaaml/models/gpam.py

View workflow job for this annotation

GitHub Actions / pylint (3.13)

Line too long (89/80) (line-too-long)
inv_freq = 10000**-freq_seq
sinusoid = tf.einsum("...,d->...d", position, inv_freq)
sin = tf.cast(tf.sin(sinusoid), dtype=x.dtype)
cos = tf.cast(tf.cos(sinusoid), dtype=x.dtype)
x1, x2 = tf.split(x, 2, axis=-1)
return tf.concat( # type: ignore[no-any-return]
sinusoid = keras.ops.einsum("...,d->...d", position, inv_freq)
sin = keras.ops.cast(keras.ops.sin(sinusoid), dtype=x.dtype)
cos = keras.ops.cast(keras.ops.cos(sinusoid), dtype=x.dtype)
x1, x2 = keras.ops.split(x, 2, axis=-1)
return keras.ops.concatenate( # type: ignore[no-any-return]
[x1 * cos - x2 * sin, x2 * cos + x1 * sin],
axis=-1,
)


def toeplitz_matrix_rope(
n: int,
a: Tensor,
b: Tensor,
) -> Tensor:
a: KerasTensor,
b: KerasTensor,
) -> KerasTensor:
"""Obtain Toeplitz matrix using rope."""
a = rope(tf.tile(a[None, :], [n, 1]), axis=[0])
b = rope(tf.tile(b[None, :], [n, 1]), axis=[0])
return tf.einsum("mk,nk->mn", a, b) # type: ignore[no-any-return]
a = rope(keras.ops.tile(a[None, :], [n, 1]), axis=[0])
b = rope(keras.ops.tile(b[None, :], [n, 1]), axis=[0])
return keras.ops.einsum("mk,nk->mn", a, b) # type: ignore[no-any-return]


@keras.saving.register_keras_serializable()
Expand Down Expand Up @@ -281,7 +284,7 @@
self.spatial_dropout_rate)

# attention activation function
self.attention_activation_layer = tf.keras.layers.Activation(
self.attention_activation_layer = keras.layers.Activation(
self.attention_activation)

def build(self, input_shape: tuple[int, ...]) -> None:
Expand Down Expand Up @@ -334,15 +337,21 @@
uv = self.proj1(x)
uv = self.dropout2(uv, training=training)

u, v, base = tf.split(
uv, [self.expand_dim, self.expand_dim, self.shared_dim], axis=-1)
u, v, base = keras.ops.split(
uv,
[self.expand_dim, self.expand_dim + self.expand_dim],
axis=-1,
)
Comment on lines +354 to +358
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While using keras.ops.split is functionally correct, it is less readable than the original tf.split because it takes split indices rather than chunk sizes: the list [self.expand_dim, self.expand_dim + self.expand_dim] marks the positions at which the last axis is cut, not the widths of the resulting chunks. For better clarity and maintainability, consider using tensor slicing, which is more explicit about the intended chunk boundaries and avoids confusion between the TensorFlow semantics (sizes) and the Keras/NumPy semantics (indices).

        u = uv[..., :self.expand_dim]
        v = uv[..., self.expand_dim:self.expand_dim + self.expand_dim]
        base = uv[..., self.expand_dim + self.expand_dim:]

assert u.shape[-1] == self.expand_dim
assert v.shape[-1] == self.expand_dim
assert base.shape[-1] == self.shared_dim

# generate q, k by scaled offset
base = tf.einsum("bnr,hr->bnhr", base, self.gamma) + self.beta
q, k = tf.unstack(base, axis=-2)
base = keras.ops.einsum("bnr,hr->bnhr", base, self.gamma) + self.beta
q, k = keras.ops.unstack(base, axis=-2)

# compute key-query scores
qk = tf.einsum("bnd,bmd->bnm", q, k)
qk = keras.ops.einsum("bnd,bmd->bnm", q, k)
qk = qk / self.max_len

# add relative position bias for attention
Expand All @@ -355,7 +364,7 @@
kernel = self.attention_dropout(kernel)

# apply values and project
x = u * tf.einsum("bnm,bme->bne", kernel, v)
x = u * keras.ops.einsum("bnm,bme->bne", kernel, v)

x = self.proj2(x)
return x + shortcut
Expand All @@ -377,11 +386,11 @@

@property
def weight_initializer(self) -> Any:
return clone_initializer(tf.random_normal_initializer(stddev=0.02))
return clone_initializer(keras.initializers.RandomNormal(stddev=0.02))

@property
def zeros_initializer(self) -> Any:
return clone_initializer(tf.initializers.zeros())
return clone_initializer(keras.initializers.Zeros())


@keras.saving.register_keras_serializable()
Expand Down Expand Up @@ -434,7 +443,7 @@

Args:

x (Tensor): Stem of the neural network.
x (KerasTensor): Stem of the neural network.

heads (dict[str, keras.layers.Layer]): A dictionary of previous heads
(those that are sooner in the topologically sorted outputs).
Expand Down Expand Up @@ -545,7 +554,7 @@


def create_heads_outputs( # type: ignore[no-any-unimported]
x: Tensor,
x: KerasTensor,
outputs: dict[str, dict[str, int]],
output_relations: list[tuple[str, str]],
) -> dict[str, keras.layers.Layer]:
Expand Down
Loading