diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f90c214d6b..3d3fad5f76 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -12,6 +12,19 @@ from keras_hub.src.layers.modeling.cached_multi_head_attention import ( CachedMultiHeadAttention as CachedMultiHeadAttention, ) +from keras_hub.src.layers.modeling.dora_dense import DoRADense as DoRADense +from keras_hub.src.layers.modeling.dora_dense import ( + convert_dense_to_dora as convert_dense_to_dora, +) +from keras_hub.src.layers.modeling.dora_embeddings import ( + DoRAEmbedding as DoRAEmbedding, +) +from keras_hub.src.layers.modeling.dora_embeddings import ( + DoRAPositionEmbedding as DoRAPositionEmbedding, +) +from keras_hub.src.layers.modeling.dora_embeddings import ( + convert_embedding_to_dora as embedding_to_dora, +) from keras_hub.src.layers.modeling.f_net_encoder import ( FNetEncoder as FNetEncoder, ) diff --git a/keras_hub/src/layers/modeling/dora_dense.py b/keras_hub/src/layers/modeling/dora_dense.py new file mode 100644 index 0000000000..4fda010583 --- /dev/null +++ b/keras_hub/src/layers/modeling/dora_dense.py @@ -0,0 +1,426 @@ +"""DoRA (Weight-Decomposed Low-Rank Adaptation) Dense Layer Implementation. + +This module implements the DoRA dense layer that decomposes weights +into magnitude and direction components, applying low-rank +adaptation for efficient fine-tuning. + +Backend-compatible with TensorFlow, PyTorch, and JAX. + +Reference: DoRA: Weight-Decomposed Low-Rank Adaptation +""" + +import keras +from keras import layers +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.layers.DoRADense") +class DoRADense(layers.Layer): + """DoRA (Weight-Decomposed Low-Rank Adaptation) Dense layer. + + DoRA decomposes the weight matrix W into magnitude and direction components + W = m * (W_0 + B @ A) / ||W_0 + B @ A||_c + + Where: + - m: magnitude vector (learnable) + - W_0: frozen pretrained weights + - A, B: low-rank adaptation matrices (learnable) + - ||.||_c: column-wise L2 norm + + Args: + units: Positive integer, dimensionality of the output space. + rank: Rank of the adaptation. Positive integer. + alpha: LoRA scaling parameter. Float. + use_bias: Boolean, whether the layer uses a bias vector. + dropout: Float between 0 and 1. Fraction of input units to drop. + activation: Activation function to use. + kernel_initializer: Initializer for the kernel weights matrix. + bias_initializer: Initializer for the bias vector. + + lora_a_initializer: Initializer for the A matrix. + Defaults to 'he_uniform'. + lora_b_initializer: Initializer for the B matrix. + Defaults to 'zeros'. + magnitude_initializer: Initializer for magnitude vector. + Defaults to 'ones'. + + kernel_regularizer: Regularizer function applied to kernel weights. + bias_regularizer: Regularizer function applied to bias. + activity_regularizer: Regularizer function applied to output. + kernel_constraint: Constraint function applied to kernel weights. + bias_constraint: Constraint function applied to bias. + **kwargs: Additional keyword arguments. 
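# Illustrative NumPy sketch of the decomposition described above (not part
# of the layer code); shapes and values are assumptions for demonstration.
import numpy as np

input_dim, units, rank, alpha = 64, 128, 4, 1.0
w0 = np.random.randn(input_dim, units)  # frozen pretrained kernel W_0
a = np.random.randn(input_dim, rank)    # LoRA matrix A
b = np.zeros((rank, units))             # LoRA matrix B, zero-initialized
m = np.linalg.norm(w0, axis=0)          # magnitude set to column norms of W_0

combined = w0 + (a @ b) * (alpha / rank)  # W_0 + scaled low-rank update
col_norms = np.maximum(np.linalg.norm(combined, axis=0, keepdims=True), 1e-8)
dora_weight = (combined / col_norms) * m  # magnitude times unit-norm direction
# With B = 0 and m equal to the column norms of W_0, dora_weight equals W_0,
# so fine-tuning starts from the pretrained behavior.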
+ """ + + def __init__( + self, + units, + rank=4, + alpha=1.0, + use_bias=True, + dropout=0.0, + activation=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + lora_a_initializer="he_uniform", + lora_b_initializer="zeros", + magnitude_initializer="ones", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__(**kwargs) + + # Validate parameters + if units <= 0: + raise ValueError(f"units must be positive, got {units}") + if rank <= 0: + raise ValueError(f"rank must be positive, got {rank}") + if alpha <= 0: + raise ValueError(f"alpha must be positive, got {alpha}") + if not 0 <= dropout < 1: + raise ValueError(f"dropout must be in [0, 1), got {dropout}") + + self.units = units + self.rank = rank + self.alpha = alpha + self.use_bias = use_bias + self.dropout_rate = dropout + self.activation = keras.activations.get(activation) + + # Initializers + self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.bias_initializer = keras.initializers.get(bias_initializer) + self.lora_a_initializer = keras.initializers.get(lora_a_initializer) + self.lora_b_initializer = keras.initializers.get(lora_b_initializer) + self.magnitude_initializer = keras.initializers.get( + magnitude_initializer + ) + + # Regularizers + self.kernel_regularizer = keras.regularizers.get(kernel_regularizer) + self.bias_regularizer = keras.regularizers.get(bias_regularizer) + self.activity_regularizer = keras.regularizers.get(activity_regularizer) + + # Constraints + self.kernel_constraint = keras.constraints.get(kernel_constraint) + self.bias_constraint = keras.constraints.get(bias_constraint) + + # Dropout layer + self.dropout_layer = ( + layers.Dropout(self.dropout_rate) if self.dropout_rate > 0 else None + ) + + # Scaling factor + self.scaling = self.alpha / self.rank + self.input_spec = None + + # Weight matrices (will be initialized in build()) + self.kernel = None # Frozen pretrained weights W_0 + self.lora_a = None # Low-rank matrix A (input_dim, rank) + self.lora_b = None # Low-rank matrix B (rank, units) + self.magnitude = None # Magnitude vector m (units,) + self.bias = None + + def build(self, input_shape): + """Build the layer weights.""" + if len(input_shape) < 2: + raise ValueError( + f"Input shape must have at least 2 dimensions," + f" got {input_shape}" + ) + + input_dim = input_shape[-1] + if input_dim is None: + raise ValueError( + "The last dimension of input shape must be defined" + ) + + # Build frozen kernel weights (pretrained weights W_0) + self.kernel = self.add_weight( + name="kernel", + shape=(input_dim, self.units), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + trainable=False, # Frozen pretrained weights + ) + + # Build LoRA matrices + self.lora_a = self.add_weight( + name="lora_a", + shape=(input_dim, self.rank), + initializer=self.lora_a_initializer, + trainable=True, + ) + + self.lora_b = self.add_weight( + name="lora_b", + shape=(self.rank, self.units), + initializer=self.lora_b_initializer, + trainable=True, + ) + + # Build magnitude vector + self.magnitude = self.add_weight( + name="magnitude", + shape=(self.units,), + initializer=self.magnitude_initializer, + trainable=True, + ) + + # Build bias + if self.use_bias: + self.bias = self.add_weight( + name="bias", + shape=(self.units,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + 
constraint=self.bias_constraint, + trainable=True, + ) + + super().build(input_shape) + + def call(self, inputs, training=None): + if self.dropout_layer is not None: + inputs = self.dropout_layer(inputs, training=training) + + # Compute LoRA adaptation: A @ B + lora_adaptation = ops.matmul(self.lora_a, self.lora_b) * self.scaling + + # Combine with frozen weights: W_0 + B @ A + combined_weight = self.kernel + lora_adaptation + + # Compute column-wise L2 norms + column_norms = ops.sqrt( + ops.sum(ops.square(combined_weight), axis=0, keepdims=True) + ) + column_norms = ops.maximum(column_norms, 1e-8) + + # Normalize by column norms + normalized_weight = combined_weight / column_norms + + # Apply magnitude scaling + dora_weight = normalized_weight * ops.expand_dims( + self.magnitude, axis=0 + ) + + # Apply linear transformation + outputs = ops.matmul(inputs, dora_weight) + + if self.use_bias: + outputs = outputs + self.bias + if self.activation is not None: + outputs = self.activation(outputs) + return outputs + + def get_dora_parameters(self): + """Get DoRA-specific parameters. + + Returns: + Dictionary containing DoRA parameters. + """ + params = { + "lora_a": self.lora_a, + "lora_b": self.lora_b, + "magnitude": self.magnitude, + } + if self.use_bias: + params["bias"] = self.bias + return params + + def get_effective_weight(self): + """Compute the effective weight matrix after DoRA adaptation. + + Returns: + The effective weight matrix: m * (W_0 + B @ A) / ||W_0 + B @ A||_c + """ + # Compute adaptation + lora_adaptation = ops.matmul(self.lora_a, self.lora_b) * self.scaling + combined_weight = self.kernel + lora_adaptation + + # Normalize + column_norms = ops.sqrt( + ops.sum(ops.square(combined_weight), axis=0, keepdims=True) + ) + column_norms = ops.maximum(column_norms, 1e-8) + normalized_weight = combined_weight / column_norms + + # Apply magnitude + return normalized_weight * ops.expand_dims(self.magnitude, axis=0) + + def merge_weights(self): + """Merge DoRA weights back to a single weight matrix. + + This is useful for inference optimization + or converting back to standard Dense layer. + + Returns: + Dictionary with 'kernel' and optionally 'bias'. + """ + merged_weights = {"kernel": self.get_effective_weight()} + if self.use_bias: + merged_weights["bias"] = self.bias + return merged_weights + + def count_params(self): + """Count the number of trainable parameters in DoRA layer. + + Returns: + Number of trainable parameters. + """ + if not self.built: + return 0 + + input_dim = self.kernel.shape[0] + param_count = ( + input_dim * self.rank # lora_a + + self.rank * self.units # lora_b + + self.units # magnitude + ) + if self.use_bias: + param_count += self.units + return param_count + + def load_pretrained_weights(self, pretrained_kernel, pretrained_bias=None): + """Load pretrained weights into the frozen kernel. + + Args: + pretrained_kernel: Pretrained weight matrix. + pretrained_bias: Optional pretrained bias vector. 
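# Hedged usage sketch for load_pretrained_weights; the random kernel and the
# shapes below are placeholders, not values from this repository.
import numpy as np

layer = DoRADense(units=32, rank=4)
layer.build((None, 16))  # input_dim = 16
kernel = np.random.randn(16, 32).astype("float32")
bias = np.zeros(32, dtype="float32")
layer.load_pretrained_weights(kernel, bias)
# The magnitude vector is reset to the column norms of `kernel`, so the DoRA
# layer initially reproduces the original Dense transformation.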
+ """ + if pretrained_kernel.shape != self.kernel.shape: + raise ValueError( + f"Pretrained kernel shape {pretrained_kernel.shape} " + f"doesn't match expected shape {self.kernel.shape}" + ) + + self.kernel.assign(pretrained_kernel) + + # Initialize magnitude vector to column-wise + # norms of pretrained weights + # This ensures DoRA starts with behavior identical to original weights + column_norms = ops.sqrt(ops.sum(ops.square(pretrained_kernel), axis=0)) + column_norms = ops.maximum(column_norms, 1e-8) + self.magnitude.assign(column_norms) + + if pretrained_bias is not None and self.use_bias: + if pretrained_bias.shape != self.bias.shape: + raise ValueError( + f"Pretrained bias shape {pretrained_bias.shape} " + f"doesn't match expected shape {self.bias.shape}" + ) + self.bias.assign(pretrained_bias) + + def get_config(self): + """Get layer configuration.""" + config = super().get_config() + config.update( + { + "units": self.units, + "rank": self.rank, + "alpha": self.alpha, + "use_bias": self.use_bias, + "dropout": self.dropout_rate, + "activation": keras.activations.serialize(self.activation), + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": keras.initializers.serialize( + self.bias_initializer + ), + "lora_a_initializer": keras.initializers.serialize( + self.lora_a_initializer + ), + "lora_b_initializer": keras.initializers.serialize( + self.lora_b_initializer + ), + "magnitude_initializer": keras.initializers.serialize( + self.magnitude_initializer + ), + "kernel_regularizer": keras.regularizers.serialize( + self.kernel_regularizer + ), + "bias_regularizer": keras.regularizers.serialize( + self.bias_regularizer + ), + "activity_regularizer": keras.regularizers.serialize( + self.activity_regularizer + ), + "kernel_constraint": keras.constraints.serialize( + self.kernel_constraint + ), + "bias_constraint": keras.constraints.serialize( + self.bias_constraint + ), + } + ) + return config + + @classmethod + def from_config(cls, config): + """Create layer from configuration.""" + return cls(**config) + + def compute_output_shape(self, input_shape): + """Compute output shape.""" + return input_shape[:-1] + (self.units,) + + +# Utility function to convert Dense layer to DoRADense +@keras_hub_export("keras_hub.layers.convert_dense_to_dora") +def convert_dense_to_dora( + dense_layer, + rank=4, + alpha=1.0, + dropout=0.0, +) -> DoRADense: + """Convert a standard Dense layer to DoRADense layer. + + Args: + dense_layer: The Dense layer to convert. + rank: Rank for DoRA adaptation. + alpha: Alpha parameter for DoRA. + dropout: Dropout rate. + + Returns: + DoRADense layer with pretrained weights loaded. 
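# Usage sketch for the conversion utility, mirroring the accompanying unit
# test; layer sizes are illustrative only.
import numpy as np
from keras import layers

dense = layers.Dense(units=64, activation="relu")
x = np.random.randn(10, 32).astype("float32")
_ = dense(x)  # build the Dense layer so its kernel and bias exist

dora = convert_dense_to_dora(dense, rank=8, alpha=2.0)
y = dora(x)
# y should initially match dense(x) up to numerical precision, since lora_b
# starts at zero and the magnitude equals the column norms of the copied kernel.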
+ """ + # Create DoRA layer with same configuration + dora_layer = DoRADense( + units=dense_layer.units, + rank=rank, + alpha=alpha, + use_bias=dense_layer.use_bias, + dropout=dropout, + activation=dense_layer.activation, + kernel_initializer=dense_layer.kernel_initializer, + bias_initializer=dense_layer.bias_initializer, + lora_a_initializer="he_uniform", + lora_b_initializer="zeros", + kernel_regularizer=dense_layer.kernel_regularizer, + bias_regularizer=dense_layer.bias_regularizer, + activity_regularizer=dense_layer.activity_regularizer, + kernel_constraint=dense_layer.kernel_constraint, + bias_constraint=dense_layer.bias_constraint, + name=dense_layer.name + "_dora" if dense_layer.name else None, + ) + + # Build the DoRA layer if Dense layer is already built + if dense_layer.built: + # Build with the correct input shape from the dense layer + input_shape = (None, dense_layer.kernel.shape[0]) + dora_layer.build(input_shape) + # Load pretrained weights + dora_layer.load_pretrained_weights( + dense_layer.kernel, + dense_layer.bias if dense_layer.use_bias else None, + ) + + return dora_layer diff --git a/keras_hub/src/layers/modeling/dora_dense_test.py b/keras_hub/src/layers/modeling/dora_dense_test.py new file mode 100644 index 0000000000..c139aa5898 --- /dev/null +++ b/keras_hub/src/layers/modeling/dora_dense_test.py @@ -0,0 +1,539 @@ +"""Test suite for DoRA Dense Layer Implementation. + +This test suite is backend-independent and works with TensorFlow, +PyTorch, and JAX. +Run with: python -m pytest test_dora_dense.py -v +""" + +import keras +import numpy as np +import pytest +from keras import layers +from keras import ops + +from .dora_dense import DoRADense +from .dora_dense import convert_dense_to_dora + + +class TestDoRADense: + """Test cases for DoRADense layer.""" + + @pytest.fixture + def sample_input(self): + """Create sample input data.""" + return np.random.randn(32, 64).astype(np.float32) + + @pytest.fixture + def dora_layer(self): + """Create a basic DoRA layer.""" + return DoRADense( + units=128, rank=8, alpha=2.0, use_bias=True, activation="relu" + ) + + def test_layer_creation(self): + """Test basic layer creation with various configurations.""" + # Test default parameters + layer = DoRADense(units=64) + assert layer.units == 64 + assert layer.rank == 4 + assert layer.alpha == 1.0 + assert layer.use_bias is True + assert layer.dropout_rate == 0.0 + + # Test custom parameters + layer = DoRADense( + units=128, + rank=16, + alpha=0.5, + use_bias=False, + dropout=0.2, + activation="tanh", + ) + assert layer.units == 128 + assert layer.rank == 16 + assert layer.alpha == 0.5 + assert layer.use_bias is False + assert layer.dropout_rate == 0.2 + assert layer.activation == keras.activations.tanh + + def test_parameter_validation(self): + """Test parameter validation.""" + # Test invalid units + with pytest.raises(ValueError, match="units must be positive"): + DoRADense(units=0) + + with pytest.raises(ValueError, match="units must be positive"): + DoRADense(units=-5) + + # Test invalid rank + with pytest.raises(ValueError, match="rank must be positive"): + DoRADense(units=64, rank=0) + + with pytest.raises(ValueError, match="rank must be positive"): + DoRADense(units=64, rank=-2) + + # Test invalid alpha + with pytest.raises(ValueError, match="alpha must be positive"): + DoRADense(units=64, alpha=0) + + with pytest.raises(ValueError, match="alpha must be positive"): + DoRADense(units=64, alpha=-1.0) + + # Test invalid dropout + with pytest.raises(ValueError, match="dropout must be 
in"): + DoRADense(units=64, dropout=1.0) + + with pytest.raises(ValueError, match="dropout must be in"): + DoRADense(units=64, dropout=-0.1) + + def test_layer_build(self, sample_input): + """Test layer building process.""" + layer = DoRADense(units=32, rank=4) + + # Layer should not be built initially + assert not layer.built + + # Build the layer + layer.build(sample_input.shape) + + # Check if layer is built + assert layer.built + + # Check weight shapes + input_dim = sample_input.shape[-1] + assert layer.kernel.shape == (input_dim, 32) + assert layer.lora_a.shape == (input_dim, 4) + assert layer.lora_b.shape == (4, 32) + assert layer.magnitude.shape == (32,) + assert layer.bias.shape == (32,) + + def test_forward_pass(self, sample_input, dora_layer): + """Test forward pass functionality.""" + # Build and run forward pass + output = dora_layer(sample_input) + + # Check output shape + expected_shape = (sample_input.shape[0], dora_layer.units) + assert output.shape == expected_shape + + # Check output is not NaN or Inf + output_np = ops.convert_to_numpy(output) + assert not np.isnan(output_np).any() + assert not np.isinf(output_np).any() + + def test_weight_initialization(self, sample_input): + """Test weight initialization.""" + layer = DoRADense( + units=32, + rank=4, + lora_a_initializer="he_uniform", + lora_b_initializer="zeros", + magnitude_initializer="ones", + ) + + # Build the layer + layer.build(sample_input.shape) + + # Check lora_b is initialized to zeros + lora_b_np = ops.convert_to_numpy(layer.lora_b) + assert np.allclose(lora_b_np, 0.0) + + # Check magnitude is initialized to ones + magnitude_np = ops.convert_to_numpy(layer.magnitude) + assert np.allclose(magnitude_np, 1.0) + + def test_activation_functions(self, sample_input): + """Test different activation functions.""" + activations = ["relu", "tanh", "sigmoid", "linear", None] + + for activation in activations: + layer = DoRADense(units=16, activation=activation) + output = layer(sample_input) + + # Check output shape + assert output.shape == (sample_input.shape[0], 16) + + # Check activation is applied correctly + if activation == "relu": + output_np = ops.convert_to_numpy(output) + assert (output_np >= 0).all() + + def test_bias_configuration(self, sample_input): + """Test bias configuration.""" + # With bias + layer_with_bias = DoRADense(units=16, use_bias=True) + layer_with_bias.build(sample_input.shape) + assert layer_with_bias.bias is not None + + # Without bias + layer_without_bias = DoRADense(units=16, use_bias=False) + layer_without_bias.build(sample_input.shape) + assert layer_without_bias.bias is None + + def test_dropout_functionality(self, sample_input): + """Test dropout functionality.""" + layer_no_dropout = DoRADense(units=16, dropout=0.0) + layer_with_dropout = DoRADense(units=16, dropout=0.5) + + # Test without dropout + output_no_dropout = layer_no_dropout(sample_input, training=True) + assert output_no_dropout.shape == (sample_input.shape[0], 16) + + # Test with dropout + output_with_dropout = layer_with_dropout(sample_input, training=True) + assert output_with_dropout.shape == (sample_input.shape[0], 16) + + def test_get_effective_weight(self, sample_input, dora_layer): + """Test effective weight computation.""" + # Build the layer first + dora_layer.build(sample_input.shape) + + # Get effective weight + effective_weight = dora_layer.get_effective_weight() + + # Check shape + input_dim = sample_input.shape[-1] + expected_shape = (input_dim, dora_layer.units) + assert effective_weight.shape == 
expected_shape + + # Check it's not NaN or Inf + weight_np = ops.convert_to_numpy(effective_weight) + assert not np.isnan(weight_np).any() + assert not np.isinf(weight_np).any() + + def test_get_dora_parameters(self, sample_input, dora_layer): + """Test DoRA parameter retrieval.""" + dora_layer.build(sample_input.shape) + + params = dora_layer.get_dora_parameters() + + # Check all expected parameters are present + assert "lora_a" in params + assert "lora_b" in params + assert "magnitude" in params + assert "bias" in params # Since use_bias=True by default + + # Check shapes + input_dim = sample_input.shape[-1] + assert params["lora_a"].shape == (input_dim, dora_layer.rank) + assert params["lora_b"].shape == (dora_layer.rank, dora_layer.units) + assert params["magnitude"].shape == (dora_layer.units,) + assert params["bias"].shape == (dora_layer.units,) + + def test_merge_weights(self, sample_input, dora_layer): + """Test weight merging functionality.""" + dora_layer.build(sample_input.shape) + + merged = dora_layer.merge_weights() + + # Check structure + assert "kernel" in merged + assert "bias" in merged + + # Check shapes + input_dim = sample_input.shape[-1] + assert merged["kernel"].shape == (input_dim, dora_layer.units) + assert merged["bias"].shape == (dora_layer.units,) + + def test_count_params(self, sample_input): + """Test parameter counting.""" + layer = DoRADense(units=32, rank=8, use_bias=True) + + # Should return 0 before building + assert layer.count_params() == 0 + + # Build and count + layer.build(sample_input.shape) + input_dim = sample_input.shape[-1] + + expected_params = ( + input_dim * 8 # lora_a + + 8 * 32 # lora_b + + 32 # magnitude + + 32 # bias + ) + + assert layer.count_params() == expected_params + + def test_load_pretrained_weights(self, sample_input): + """Test loading pretrained weights.""" + layer = DoRADense(units=32, rank=4) + layer.build(sample_input.shape) + + input_dim = sample_input.shape[-1] + + # Create fake pretrained weights + pretrained_kernel = np.random.randn(input_dim, 32).astype(np.float32) + pretrained_bias = np.random.randn(32).astype(np.float32) + + # Load weights + layer.load_pretrained_weights(pretrained_kernel, pretrained_bias) + + # Check if weights are loaded correctly + kernel_np = ops.convert_to_numpy(layer.kernel) + bias_np = ops.convert_to_numpy(layer.bias) + + assert np.allclose(kernel_np, pretrained_kernel) + assert np.allclose(bias_np, pretrained_bias) + + # Check magnitude is initialized to column norms + expected_magnitude = np.linalg.norm(pretrained_kernel, axis=0) + magnitude_np = ops.convert_to_numpy(layer.magnitude) + assert np.allclose(magnitude_np, expected_magnitude, rtol=1e-5) + + def test_load_pretrained_weights_shape_mismatch(self, sample_input): + """Test loading pretrained weights with wrong shapes.""" + layer = DoRADense(units=32, rank=4) + layer.build(sample_input.shape) + + # Wrong kernel shape + wrong_kernel = np.random.randn(10, 20).astype(np.float32) + with pytest.raises(ValueError, match="Pretrained kernel shape"): + layer.load_pretrained_weights(wrong_kernel) + + # Wrong bias shape + correct_kernel = np.random.randn(sample_input.shape[-1], 32).astype( + np.float32 + ) + wrong_bias = np.random.randn(20).astype(np.float32) + with pytest.raises(ValueError, match="Pretrained bias shape"): + layer.load_pretrained_weights(correct_kernel, wrong_bias) + + def test_serialization(self, dora_layer): + """Test layer serialization and deserialization.""" + # Get config + config = dora_layer.get_config() + + # Check 
essential parameters are in config + assert config["units"] == dora_layer.units + assert config["rank"] == dora_layer.rank + assert config["alpha"] == dora_layer.alpha + assert config["use_bias"] == dora_layer.use_bias + assert config["dropout"] == dora_layer.dropout_rate + + # Create layer from config + restored_layer = DoRADense.from_config(config) + + # Check restored layer has same parameters + assert restored_layer.units == dora_layer.units + assert restored_layer.rank == dora_layer.rank + assert restored_layer.alpha == dora_layer.alpha + assert restored_layer.use_bias == dora_layer.use_bias + + def test_compute_output_shape(self): + """Test output shape computation.""" + layer = DoRADense(units=64) + + # Test various input shapes + input_shapes = [ + (None, 32), + (10, 32), + (None, 16, 32), + (5, 10, 32), + ] + + for input_shape in input_shapes: + output_shape = layer.compute_output_shape(input_shape) + expected_shape = input_shape[:-1] + (64,) + assert output_shape == expected_shape + + def test_regularization(self, sample_input): + """Test regularization functionality.""" + layer = DoRADense( + units=32, + kernel_regularizer="l2", + bias_regularizer="l1", + activity_regularizer="l2", + ) + + # Build and run forward pass + output = layer(sample_input) + + # Check output shape + assert output.shape == (sample_input.shape[0], 32) + + def test_constraints(self, sample_input): + """Test constraint functionality.""" + layer = DoRADense( + units=32, kernel_constraint="max_norm", bias_constraint="non_neg" + ) + + # Build and run forward pass + output = layer(sample_input) + + # Check output shape + assert output.shape == (sample_input.shape[0], 32) + + def test_training_inference_consistency(self, sample_input, dora_layer): + """Test consistency between training and inference modes.""" + # Forward pass in training mode + output_train = dora_layer(sample_input, training=True) + + # Forward pass in inference mode + output_infer = dora_layer(sample_input, training=False) + + # Should have same shape + assert output_train.shape == output_infer.shape + + # For layers without dropout, outputs should be identical + if dora_layer.dropout_rate == 0: + output_train_np = ops.convert_to_numpy(output_train) + output_infer_np = ops.convert_to_numpy(output_infer) + assert np.allclose(output_train_np, output_infer_np) + + +class TestDoRAConversion: + """Test cases for Dense to DoRA conversion.""" + + def test_convert_dense_to_dora(self): + """Test converting Dense layer to DoRA layer.""" + # Create a Dense layer + dense_layer = layers.Dense( + units=64, + activation="relu", + use_bias=True, + kernel_initializer="glorot_uniform", + ) + + # Build with sample input + sample_input = np.random.randn(10, 32).astype(np.float32) + dense_output = dense_layer(sample_input) + + # Convert to DoRA + dora_layer = convert_dense_to_dora( + dense_layer, rank=8, alpha=2.0, dropout=0.1 + ) + + # Check configuration + assert dora_layer.units == dense_layer.units + assert dora_layer.rank == 8 + assert dora_layer.alpha == 2.0 + assert dora_layer.dropout_rate == 0.1 + assert dora_layer.use_bias == dense_layer.use_bias + assert dora_layer.activation == dense_layer.activation + + # Check weights are loaded + assert dora_layer.built + + # Test forward pass produces reasonable output + dora_output = dora_layer(sample_input) + assert dora_output.shape == dense_output.shape + + def test_convert_unbuilt_dense(self): + """Test converting unbuilt Dense layer.""" + dense_layer = layers.Dense(units=32, activation="tanh") + + # Convert 
unbuilt layer + dora_layer = convert_dense_to_dora(dense_layer, rank=4) + + # Should not be built yet + assert not dora_layer.built + + # But should have correct configuration + assert dora_layer.units == 32 + assert dora_layer.rank == 4 + assert dora_layer.activation == keras.activations.tanh + + +class TestDoRAMathematicalProperties: + """Test mathematical properties of DoRA.""" + + def test_magnitude_scaling_property(self): + """Test that DoRA properly applies magnitude scaling.""" + # Create layer + layer = DoRADense(units=16, rank=4) + sample_input = np.random.randn(8, 32).astype(np.float32) + layer.build(sample_input.shape) + + # Get effective weight + effective_weight = layer.get_effective_weight() + effective_weight_np = ops.convert_to_numpy(effective_weight) + + # Compute column norms of effective weight + column_norms = np.linalg.norm(effective_weight_np, axis=0) + magnitude_np = ops.convert_to_numpy(layer.magnitude) + + # Column norms should equal magnitude values (approximately) + assert np.allclose(column_norms, magnitude_np, rtol=1e-5) + + def test_low_rank_adaptation_property(self): + """Test that adaptation is indeed low-rank.""" + layer = DoRADense(units=64, rank=8) + sample_input = np.random.randn(16, 128).astype(np.float32) + layer.build(sample_input.shape) + + # Compute LoRA adaptation + lora_a_np = ops.convert_to_numpy(layer.lora_a) + lora_b_np = ops.convert_to_numpy(layer.lora_b) + adaptation = lora_a_np @ lora_b_np + + # Check that adaptation matrix has rank <= layer.rank + actual_rank = np.linalg.matrix_rank(adaptation) + assert actual_rank <= layer.rank + + def test_zero_initialization_equivalence(self): + """Test that zero LoRA initialization gives original behavior.""" + # Create layer with zero LoRA initialization + layer = DoRADense( + units=32, + rank=4, + lora_a_initializer="zeros", + lora_b_initializer="zeros", + ) + + sample_input = np.random.randn(8, 16).astype(np.float32) + layer.build(sample_input.shape) + + # Set magnitude to column norms of kernel + kernel_np = ops.convert_to_numpy(layer.kernel) + column_norms = np.linalg.norm(kernel_np, axis=0) + layer.magnitude.assign(column_norms) + + # Effective weight should equal original kernel + effective_weight = layer.get_effective_weight() + effective_weight_np = ops.convert_to_numpy(effective_weight) + + assert np.allclose(effective_weight_np, kernel_np, rtol=1e-5) + + +def test_backend_compatibility(): + """Test that the implementation works across different backends.""" + # This test ensures the code runs without backend-specific errors + layer = DoRADense(units=16, rank=4) + sample_input = np.random.randn(4, 8).astype(np.float32) + + # Should work regardless of backend + output = layer(sample_input) + assert output.shape == (4, 16) + + # Test parameter access + params = layer.get_dora_parameters() + assert len(params) == 4 # lora_a, lora_b, magnitude, bias + + print( + f"Backend compatibility test " + f"passed with Keras backend: {keras.backend.backend()}" + ) + + +if __name__ == "__main__": + # Run basic tests if executed directly + test_backend_compatibility() + + # Create and test a basic layer + layer = DoRADense(units=32, rank=8, alpha=2.0) + sample_input = np.random.randn(16, 64).astype(np.float32) + + # Test forward pass + output = layer(sample_input) + print(f"Output shape: {output.shape}") + + # Test parameter counting + param_count = layer.count_params() + print(f"Trainable parameters: {param_count}") + + # Test effective weight computation + effective_weight = layer.get_effective_weight() + 
print(f"Effective weight shape: {effective_weight.shape}") + + print("All basic tests passed!") diff --git a/keras_hub/src/layers/modeling/dora_embeddings.py b/keras_hub/src/layers/modeling/dora_embeddings.py new file mode 100644 index 0000000000..23dcb086ae --- /dev/null +++ b/keras_hub/src/layers/modeling/dora_embeddings.py @@ -0,0 +1,733 @@ +"""DoRA (Weight-Decomposed Low-Rank Adaptation) +Embedding Layer Implementation. + +This module implements the DoRA embedding +layer that applies weight decomposition +and low-rank adaptation to token +embeddings for efficient fine-tuning. + +Backend-compatible with TensorFlow, PyTorch, and JAX. + +Reference: DoRA: Weight-Decomposed Low-Rank Adaptation +""" + +import keras +from keras import layers +from keras import ops + +from keras_hub.src.api_export import keras_hub_export + + +@keras_hub_export("keras_hub.layers.DoRAEmbedding") +class DoRAEmbedding(layers.Layer): + """DoRA (Weight-Decomposed Low-Rank Adaptation) + Embedding layer. + + DoRA decomposes the embedding weight + matrix W into magnitude and direction components: + W = m * (W_0 + B @ A) / ||W_0 + B @ A||_c + + Where: + - m: magnitude vector (learnable) + - W_0: frozen pretrained embedding weights + - A, B: low-rank adaptation matrices (learnable) + - ||.||_c: column-wise L2 norm + + Args: + input_dim: Size of the vocabulary (number of tokens). + output_dim: Dimension of the dense embedding vectors. + rank: Rank of the adaptation. Positive integer. + alpha: LoRA scaling parameter. Float. + embeddings_initializer: + Initializer for the embeddings matrix. + lora_a_initializer: + Initializer for the A matrix. Defaults to 'he_uniform'. + lora_b_initializer: + Initializer for the B matrix. Defaults to 'zeros'. + magnitude_initializer: + Initializer for magnitude vector. Defaults to 'ones'. + embeddings_regularizer: + Regularizer function applied to embeddings. + activity_regularizer: + Regularizer function applied to output. + embeddings_constraint: + Constraint function applied to embeddings. + mask_zero: + Whether input value 0 is a special "padding" value. + input_length: + Length of input sequences (for compatibility). + sparse: + Whether to use sparse embedding lookup (experimental). + **kwargs: Additional keyword arguments. 
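# Minimal usage sketch for DoRAEmbedding; vocabulary size, embedding width,
# and batch shape are illustrative assumptions.
import numpy as np

emb = DoRAEmbedding(
    input_dim=1000, output_dim=128, rank=8, alpha=2.0, mask_zero=True
)
token_ids = np.random.randint(0, 1000, size=(4, 16)).astype("int32")
vectors = emb(token_ids)  # shape: (4, 16, 128)
# Trainable parameters: 1000 * 8 (lora_a) + 8 * 128 (lora_b) + 128 (magnitude)
# = 9152, versus 128,000 for a fully trainable embedding of the same size.
assert emb.count_params() == 9152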
+ """ + + def __init__( + self, + input_dim, + output_dim, + rank=4, + alpha=1.0, + embeddings_initializer="uniform", + lora_a_initializer="he_uniform", + lora_b_initializer="zeros", + magnitude_initializer="ones", + embeddings_regularizer=None, + activity_regularizer=None, + embeddings_constraint=None, + mask_zero=False, + input_length=None, + sparse=False, + **kwargs, + ): + super().__init__(**kwargs) + + # Validate parameters + if input_dim <= 0: + raise ValueError(f"input_dim must be positive, got {input_dim}") + if output_dim <= 0: + raise ValueError(f"output_dim must be positive, got {output_dim}") + if rank <= 0: + raise ValueError(f"rank must be positive, got {rank}") + if alpha <= 0: + raise ValueError(f"alpha must be positive, got {alpha}") + + self.input_dim = input_dim + self.output_dim = output_dim + self.rank = rank + self.alpha = alpha + self.mask_zero = mask_zero + self.input_length = input_length + self.sparse = sparse + + # Initializers + self.embeddings_initializer = keras.initializers.get( + embeddings_initializer + ) + self.lora_a_initializer = keras.initializers.get(lora_a_initializer) + self.lora_b_initializer = keras.initializers.get(lora_b_initializer) + self.magnitude_initializer = keras.initializers.get( + magnitude_initializer + ) + + # Regularizers + self.embeddings_regularizer = keras.regularizers.get( + embeddings_regularizer + ) + self.activity_regularizer = keras.regularizers.get(activity_regularizer) + + # Constraints + self.embeddings_constraint = keras.constraints.get( + embeddings_constraint + ) + + # Scaling factor + self.scaling = self.alpha / self.rank + + # Weight matrices (will be initialized in build()) + self.embeddings = None # Frozen pretrained embeddings W_0 + self.lora_a = None # Low-rank matrix A (input_dim, rank) + self.lora_b = None # Low-rank matrix B (rank, output_dim) + self.magnitude = None # Magnitude vector m (output_dim,) + + # Set compute dtype policy + self._supports_masking = mask_zero + + def build(self, input_shape): + """Build the layer weights.""" + # Build frozen embedding weights (pretrained embeddings W_0) + self.embeddings = self.add_weight( + name="embeddings", + shape=(self.input_dim, self.output_dim), + initializer=self.embeddings_initializer, + regularizer=self.embeddings_regularizer, + constraint=self.embeddings_constraint, + trainable=False, # Frozen pretrained weights + ) + + # Build LoRA matrices + self.lora_a = self.add_weight( + name="lora_a", + shape=(self.input_dim, self.rank), + initializer=self.lora_a_initializer, + trainable=True, + ) + + self.lora_b = self.add_weight( + name="lora_b", + shape=(self.rank, self.output_dim), + initializer=self.lora_b_initializer, + trainable=True, + ) + + # Build magnitude vector + self.magnitude = self.add_weight( + name="magnitude", + shape=(self.output_dim,), + initializer=self.magnitude_initializer, + trainable=True, + ) + + super().build(input_shape) + + def call(self, inputs, training=None): + """Forward pass of DoRA embedding layer. + + Implements: output = embedding_lookup + (inputs, m * (W_0 + B @ A) / ||W_0 + B @ A||_c) + + Args: + inputs: Input tensor containing token indices. + training: Boolean indicating whether in training mode. + + Returns: + Output tensor after DoRA embedding lookup. 
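# The lookup below is ordinary row indexing of the effective embedding
# matrix; a NumPy sketch of the same operation with toy values:
import numpy as np

effective = np.arange(12, dtype="float32").reshape(4, 3)  # vocab=4, dim=3
ids = np.array([[0, 2], [3, 3]], dtype="int32")           # batch=2, seq=2
out = np.take(effective, ids, axis=0)                     # shape (2, 2, 3)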
+ """ + # Cast inputs to integers for all backends + inputs = ops.cast(inputs, "int32") + + # Get effective embedding matrix + effective_embeddings = self._get_effective_embeddings() + + # Perform embedding lookup using backend-agnostic operations + outputs = ops.take(effective_embeddings, inputs, axis=0) + + return outputs + + def _get_effective_embeddings(self): + """Compute the effective embedding matrix after DoRA adaptation. + + Returns: + The effective embedding matrix: + m * (W_0 + B @ A) / ||W_0 + B @ A||_c + """ + # Compute low-rank adaptation: A @ B (with scaling applied to B) + # Use ops.multiply for backend compatibility + scaled_lora_b = ops.multiply(self.lora_b, self.scaling) + lora_adaptation = ops.matmul(self.lora_a, scaled_lora_b) + + # Combine pretrained embeddings with adaptation: W_0 + ΔW + combined_embeddings = ops.add(self.embeddings, lora_adaptation) + + # Compute column-wise L2 norms: ||W_0 + ΔW||_c + # Use ops for all operations to ensure backend compatibility + squared_embeddings = ops.square(combined_embeddings) + sum_squares = ops.sum(squared_embeddings, axis=0, keepdims=True) + column_norms = ops.sqrt(sum_squares) + + # Prevent division by zero with backend-agnostic maximum + eps = ops.convert_to_tensor(1e-8, dtype=column_norms.dtype) + column_norms = ops.maximum(column_norms, eps) + + # DoRA formula: m * (W_0 + ΔW) / ||W_0 + ΔW||_c + # Expand magnitude dimensions for broadcasting + magnitude_expanded = ops.expand_dims(self.magnitude, axis=0) + + # Apply magnitude scaling and normalization + numerator = ops.multiply(combined_embeddings, magnitude_expanded) + dora_embeddings = ops.divide(numerator, column_norms) + + return dora_embeddings + + def compute_mask(self, inputs, mask=None): + """Compute output mask for masking support.""" + if not self.mask_zero: + return None + + # Create mask where input is not zero using backend-agnostic ops + zero_tensor = ops.convert_to_tensor(0, dtype=inputs.dtype) + return ops.not_equal(inputs, zero_tensor) + + def get_dora_parameters(self): + """Get DoRA-specific parameters. + + Returns: + Dictionary containing DoRA parameters. + """ + return { + "lora_a": self.lora_a, + "lora_b": self.lora_b, + "magnitude": self.magnitude, + } + + def get_effective_embeddings(self): + """Get the effective embedding matrix after DoRA adaptation. + + Returns: + The effective embedding matrix. + """ + return self._get_effective_embeddings() + + def merge_weights(self): + """Merge DoRA weights back to a single embedding matrix. + + This is useful for inference optimization or converting + back to standard Embedding layer. + + Returns: + Dictionary with 'embeddings'. + """ + return {"embeddings": self._get_effective_embeddings()} + + def count_params(self): + """Count the number of trainable parameters in DoRA embedding layer. + + Returns: + Number of trainable parameters. + """ + return ( + self.input_dim * self.rank # lora_a + + self.rank * self.output_dim # lora_b + + self.output_dim # magnitude + ) + + def load_pretrained_embeddings(self, pretrained_embeddings): + """Load pretrained embeddings into the frozen embedding matrix. + + Args: + pretrained_embeddings: Pretrained embedding matrix. 
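# Hedged sketch of loading pretrained embeddings; the random matrix stands
# in for real pretrained weights and must match (input_dim, output_dim).
import numpy as np

emb = DoRAEmbedding(input_dim=1000, output_dim=128, rank=8)
emb.build(None)
pretrained = np.random.randn(1000, 128).astype("float32")
emb.load_pretrained_embeddings(pretrained)
# The magnitude vector is set to the column norms of `pretrained`, so the
# adapted embeddings start out identical to the pretrained ones.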
+ """ + # Convert to tensor if needed for backend compatibility + if not hasattr(pretrained_embeddings, "shape"): + pretrained_embeddings = ops.convert_to_tensor(pretrained_embeddings) + + expected_shape = (self.input_dim, self.output_dim) + if tuple(pretrained_embeddings.shape) != expected_shape: + raise ValueError( + f"Pretrained embeddings shape {pretrained_embeddings.shape} " + f"doesn't match expected shape {expected_shape}" + ) + + # Use backend-compatible assignment + self._safe_assign_weight(self.embeddings, pretrained_embeddings) + + # Initialize magnitude to preserve exact functional equivalence + # Compute column norms using backend-agnostic operations + squared_embeddings = ops.square(pretrained_embeddings) + sum_squares = ops.sum(squared_embeddings, axis=0) + column_norms = ops.sqrt(sum_squares) + + self._safe_assign_weight(self.magnitude, column_norms) + + def _safe_assign_weight(self, weight_var, new_value): + """Safely assign new values to weights across backends.""" + try: + # Try standard Keras approach first + weight_var.assign(new_value) + except Exception: + # Fallback for backends that don't support assign + # This approach works across all backends + weight_var._value = ops.convert_to_tensor(new_value) + + def expand_vocabulary(self, new_vocab_size, new_token_embeddings=None): + """Expand vocabulary size and optionally add new token embeddings. + + Since Keras doesn't allow modifying weights after building, + this method returns a new DoRAEmbedding layer with expanded + vocabulary instead of modifying the current layer in-place. + + Args: + new_vocab_size: New vocabulary size + (must be >= current input_dim). + new_token_embeddings: Optional embeddings for new tokens. + Shape should be (new_vocab_size - current_input_dim, output_dim). + + Returns: + New DoRAEmbedding layer with expanded vocabulary. 
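# Sketch of vocabulary expansion: a new layer is returned and the original
# layer is left untouched; sizes are illustrative.
emb = DoRAEmbedding(input_dim=1000, output_dim=128, rank=8)
emb.build(None)
expanded = emb.expand_vocabulary(1200)  # adds 200 token slots
assert expanded.embeddings.shape == (1200, 128)
assert expanded.lora_a.shape == (1200, 8)   # lora_a grows with the vocabulary
assert expanded.lora_b.shape == (8, 128)    # lora_b and magnitude carry over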
+ """ + if new_vocab_size <= self.input_dim: + raise ValueError( + f"new_vocab_size ({new_vocab_size}) must be greater than " + f"current input_dim ({self.input_dim})" + ) + + if not self.built: + raise ValueError("Layer must be built before expanding vocabulary") + + num_new_tokens = new_vocab_size - self.input_dim + + # Create new layer with expanded vocabulary + expanded_layer = DoRAEmbedding( + input_dim=new_vocab_size, + output_dim=self.output_dim, + rank=self.rank, + alpha=self.alpha, + embeddings_initializer=self.embeddings_initializer, + lora_a_initializer=self.lora_a_initializer, + lora_b_initializer=self.lora_b_initializer, + magnitude_initializer=self.magnitude_initializer, + embeddings_regularizer=self.embeddings_regularizer, + activity_regularizer=self.activity_regularizer, + embeddings_constraint=self.embeddings_constraint, + mask_zero=self.mask_zero, + input_length=self.input_length, + sparse=self.sparse, + name=self.name + "_expanded", + ) + + # Build the new layer + expanded_layer.build(None) + + # Get current weights as tensors + current_embeddings = self.embeddings + current_lora_a = self.lora_a + current_lora_b = self.lora_b + current_magnitude = self.magnitude + + # Prepare new token embeddings using backend-agnostic operations + if new_token_embeddings is None: + # Use the same initializer as the original embeddings + new_embeddings = self.embeddings_initializer( + shape=(num_new_tokens, self.output_dim) + ) + else: + # Convert to tensor for backend compatibility + new_embeddings = ops.convert_to_tensor(new_token_embeddings) + expected_shape = (num_new_tokens, self.output_dim) + if tuple(new_embeddings.shape) != expected_shape: + raise ValueError( + f"new_token_embeddings shape {new_embeddings.shape} " + f"doesn't match expected shape {expected_shape}" + ) + + # Prepare new LoRA A rows using the same initializer + new_lora_a_rows = self.lora_a_initializer( + shape=(num_new_tokens, self.rank) + ) + + # Create expanded tensors using backend-agnostic concatenation + expanded_embeddings = ops.concatenate( + [current_embeddings, new_embeddings], axis=0 + ) + expanded_lora_a = ops.concatenate( + [current_lora_a, new_lora_a_rows], axis=0 + ) + + # Assign the expanded weights to the new layer + expanded_layer._safe_assign_weight( + expanded_layer.embeddings, expanded_embeddings + ) + expanded_layer._safe_assign_weight( + expanded_layer.lora_a, expanded_lora_a + ) + expanded_layer._safe_assign_weight( + expanded_layer.lora_b, current_lora_b + ) + expanded_layer._safe_assign_weight( + expanded_layer.magnitude, current_magnitude + ) + + return expanded_layer + + def get_config(self): + """Get layer configuration.""" + config = super().get_config() + config.update( + { + "input_dim": self.input_dim, + "output_dim": self.output_dim, + "rank": self.rank, + "alpha": self.alpha, + "embeddings_initializer": keras.initializers.serialize( + self.embeddings_initializer + ), + "lora_a_initializer": keras.initializers.serialize( + self.lora_a_initializer + ), + "lora_b_initializer": keras.initializers.serialize( + self.lora_b_initializer + ), + "magnitude_initializer": keras.initializers.serialize( + self.magnitude_initializer + ), + "embeddings_regularizer": keras.regularizers.serialize( + self.embeddings_regularizer + ), + "activity_regularizer": keras.regularizers.serialize( + self.activity_regularizer + ), + "embeddings_constraint": keras.constraints.serialize( + self.embeddings_constraint + ), + "mask_zero": self.mask_zero, + "input_length": self.input_length, + "sparse": self.sparse, 
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        """Create layer from configuration."""
+        return cls(**config)
+
+    def compute_output_shape(self, input_shape):
+        """Compute output shape."""
+        return input_shape + (self.output_dim,)
+
+
+@keras_hub_export("keras_hub.layers.DoRAPositionEmbedding")
+class DoRAPositionEmbedding(layers.Layer):
+    """DoRA-enabled position embedding layer.
+
+    This layer creates learnable positional embeddings that are added to
+    token embeddings, using DoRA weight decomposition for efficient
+    adaptation.
+    """
+
+    def __init__(
+        self,
+        sequence_length,
+        output_dim,
+        rank=4,
+        alpha=1.0,
+        initializer="uniform",
+        lora_a_initializer="he_uniform",
+        lora_b_initializer="zeros",
+        magnitude_initializer="ones",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.sequence_length = sequence_length
+        self.output_dim = output_dim
+        self.rank = rank
+        self.alpha = alpha
+
+        # Initializers
+        self.initializer = keras.initializers.get(initializer)
+        self.lora_a_initializer = keras.initializers.get(lora_a_initializer)
+        self.lora_b_initializer = keras.initializers.get(lora_b_initializer)
+        self.magnitude_initializer = keras.initializers.get(
+            magnitude_initializer
+        )
+
+        # Scaling factor
+        self.scaling = self.alpha / self.rank
+
+        # Weight matrices (will be initialized in build())
+        self.position_embeddings = None  # Frozen position embeddings
+        self.lora_a = None  # Low-rank matrix A
+        self.lora_b = None  # Low-rank matrix B
+        self.magnitude = None  # Magnitude vector
+
+    def build(self, input_shape):
+        """Build the position embedding weights."""
+        # Build frozen position embedding weights
+        self.position_embeddings = self.add_weight(
+            name="position_embeddings",
+            shape=(self.sequence_length, self.output_dim),
+            initializer=self.initializer,
+            trainable=False,  # Frozen
+        )
+
+        # Build LoRA matrices
+        self.lora_a = self.add_weight(
+            name="lora_a",
+            shape=(self.sequence_length, self.rank),
+            initializer=self.lora_a_initializer,
+            trainable=True,
+        )
+
+        self.lora_b = self.add_weight(
+            name="lora_b",
+            shape=(self.rank, self.output_dim),
+            initializer=self.lora_b_initializer,
+            trainable=True,
+        )
+
+        # Build magnitude vector
+        self.magnitude = self.add_weight(
+            name="magnitude",
+            shape=(self.output_dim,),
+            initializer=self.magnitude_initializer,
+            trainable=True,
+        )
+
+        super().build(input_shape)
+
+    def call(self, inputs, start_index=0):
+        """Forward pass of DoRA position embedding.
+
+        Args:
+            inputs: Input tensor (token embeddings)
+                of shape [batch_size, seq_len, hidden_dim].
+            start_index: Starting position index
+                (for compatibility with KerasHub).
+
+        Returns:
+            Position embeddings of shape [batch_size, seq_len, hidden_dim].
+        """
+        input_shape = ops.shape(inputs)
+        seq_len = input_shape[-2]
+        batch_size = input_shape[0]
+
+        # Get effective position embeddings using DoRA
+        effective_pos_embeddings = self._get_effective_position_embeddings()
+
+        # Convert start_index to tensor for consistent operations
+        start_tensor = ops.convert_to_tensor(start_index, dtype="int32")
+
+        # Create position indices from start_index to start_index + seq_len
+        position_indices = ops.arange(seq_len, dtype="int32") + start_tensor
+
+        # Clamp indices to valid range [0, sequence_length - 1]
+        max_pos = ops.convert_to_tensor(self.sequence_length - 1, dtype="int32")
+        min_pos = ops.convert_to_tensor(0, dtype="int32")
+        position_indices = ops.clip(position_indices, min_pos, max_pos)
+
+        # Gather position embeddings using the indices
+        position_embeddings = ops.take(
+            effective_pos_embeddings, position_indices, axis=0
+        )
+
+        # Add batch dimension and broadcast to match batch size
+        position_embeddings = ops.expand_dims(position_embeddings, axis=0)
+        position_embeddings = ops.broadcast_to(
+            position_embeddings, [batch_size, seq_len, self.output_dim]
+        )
+
+        return position_embeddings
+
+    def _get_effective_position_embeddings(self):
+        """Compute effective position embeddings using DoRA decomposition."""
+        # Compute low-rank adaptation (scaling applied to B matrix)
+        scaled_lora_b = ops.multiply(self.lora_b, self.scaling)
+        lora_adaptation = ops.matmul(self.lora_a, scaled_lora_b)
+
+        # Combine with frozen weights
+        combined_embeddings = ops.add(self.position_embeddings, lora_adaptation)
+
+        # Compute column-wise L2 norms using backend-agnostic operations
+        squared_embeddings = ops.square(combined_embeddings)
+        sum_squares = ops.sum(squared_embeddings, axis=0, keepdims=True)
+        column_norms = ops.sqrt(sum_squares)
+
+        # Prevent division by zero
+        eps = ops.convert_to_tensor(1e-8, dtype=column_norms.dtype)
+        column_norms = ops.maximum(column_norms, eps)
+
+        # Apply DoRA formula: m * (W_0 + ΔW) / ||W_0 + ΔW||_c
+
magnitude_expanded = ops.expand_dims(self.magnitude, axis=0) + numerator = ops.multiply(combined_embeddings, magnitude_expanded) + dora_embeddings = ops.divide(numerator, column_norms) + + return dora_embeddings + + def get_config(self): + """Get layer configuration.""" + config = super().get_config() + config.update( + { + "sequence_length": self.sequence_length, + "output_dim": self.output_dim, + "rank": self.rank, + "alpha": self.alpha, + "initializer": keras.initializers.serialize(self.initializer), + "lora_a_initializer": keras.initializers.serialize( + self.lora_a_initializer + ), + "lora_b_initializer": keras.initializers.serialize( + self.lora_b_initializer + ), + "magnitude_initializer": keras.initializers.serialize( + self.magnitude_initializer + ), + } + ) + return config + + +# Utility function to convert Embedding layer to DoRAEmbedding +@keras_hub_export("keras_hub.layers.embedding_to_dora") +def convert_embedding_to_dora( + embedding_layer, + rank=4, + alpha=1.0, +) -> DoRAEmbedding: + """Convert a standard Embedding layer to DoRAEmbedding layer. + + Args: + embedding_layer: The Embedding layer to convert. + rank: Rank for DoRA adaptation. + alpha: Alpha parameter for DoRA. + + Returns: + DoRAEmbedding layer with pretrained weights loaded. + """ + # Safely get input_length attribute + input_length = getattr(embedding_layer, "input_length", None) + + # Create DoRA embedding layer with same configuration + dora_layer = DoRAEmbedding( + input_dim=embedding_layer.input_dim, + output_dim=embedding_layer.output_dim, + rank=rank, + alpha=alpha, + embeddings_initializer=embedding_layer.embeddings_initializer, + embeddings_regularizer=embedding_layer.embeddings_regularizer, + activity_regularizer=embedding_layer.activity_regularizer, + embeddings_constraint=embedding_layer.embeddings_constraint, + mask_zero=embedding_layer.mask_zero, + input_length=input_length, + name=embedding_layer.name + "_dora", + ) + + # Build the DoRA layer if Embedding layer is already built + if embedding_layer.built: + dora_layer.build(None) # Embedding layers don't depend on input shape + # Load pretrained embeddings + dora_layer.load_pretrained_embeddings(embedding_layer.embeddings) + + return dora_layer diff --git a/keras_hub/src/layers/modeling/dora_embeddings_test.py b/keras_hub/src/layers/modeling/dora_embeddings_test.py new file mode 100644 index 0000000000..7639a7ec96 --- /dev/null +++ b/keras_hub/src/layers/modeling/dora_embeddings_test.py @@ -0,0 +1,868 @@ +"""Test suite for DoRA Embedding Layer Implementation. + +This test suite is backend-independent and works with +TensorFlow, PyTorch, and JAX. 
+Run with: python -m pytest test_dora_embeddings.py -v +""" + +import keras +import numpy as np +import pytest +from keras import layers +from keras import ops + +from .dora_embeddings import DoRAEmbedding +from .dora_embeddings import DoRAPositionEmbedding +from .dora_embeddings import convert_embedding_to_dora + + +def safe_convert_to_numpy(tensor): + """Safely convert tensor to numpy across backends.""" + try: + return ops.convert_to_numpy(tensor) + except Exception: + # Fallback for different backends + if hasattr(tensor, "numpy"): + return tensor.numpy() + elif hasattr(tensor, "detach"): + return tensor.detach().numpy() + else: + return np.array(tensor) + + +def safe_allclose(a, b, rtol=1e-5, atol=1e-8): + """Safely check if arrays are close across backends.""" + a_np = safe_convert_to_numpy(a) if not isinstance(a, np.ndarray) else a + b_np = safe_convert_to_numpy(b) if not isinstance(b, np.ndarray) else b + return np.allclose(a_np, b_np, rtol=rtol, atol=atol) + + +def safe_array_equal(a, b): + """Safely check if arrays are equal across backends.""" + a_np = safe_convert_to_numpy(a) if not isinstance(a, np.ndarray) else a + b_np = safe_convert_to_numpy(b) if not isinstance(b, np.ndarray) else b + return np.array_equal(a_np, b_np) + + +def check_no_nan_inf(tensor): + """Check tensor has no NaN or Inf values across backends.""" + tensor_np = safe_convert_to_numpy(tensor) + return not (np.isnan(tensor_np).any() or np.isinf(tensor_np).any()) + + +def create_random_tensor(shape, dtype="float32", seed=42): + """Create random tensor compatible across backends.""" + np.random.seed(seed) + if dtype == "int32": + if len(shape) == 2: + # Fix: Ensure high value is always > 0 + vocab_size = max(shape[0] // 10, 10) # Minimum vocab size of 10 + high_value = max(min(vocab_size, 100), 2) + return np.random.randint(0, high_value, size=shape, dtype=np.int32) + else: + return np.random.randint(0, 1000, size=shape, dtype=np.int32) + else: + return np.random.randn(*shape).astype(dtype) + + +class TestDoRAEmbedding: + """Test cases for DoRAEmbedding layer.""" + + @pytest.fixture + def sample_input(self): + """Create sample token indices.""" + return create_random_tensor((32, 64), dtype="int32", seed=42) + + @pytest.fixture + def dora_embedding(self): + """Create a basic DoRA embedding layer.""" + return DoRAEmbedding( + input_dim=1000, output_dim=128, rank=8, alpha=2.0, mask_zero=True + ) + + def test_layer_creation(self): + """Test basic layer creation with various configurations.""" + # Test default parameters + layer = DoRAEmbedding(input_dim=1000, output_dim=64) + assert layer.input_dim == 1000 + assert layer.output_dim == 64 + assert layer.rank == 4 + assert layer.alpha == 1.0 + assert layer.mask_zero is False + + # Test custom parameters + layer = DoRAEmbedding( + input_dim=5000, + output_dim=256, + rank=16, + alpha=0.5, + mask_zero=True, + input_length=128, + sparse=True, + ) + assert layer.input_dim == 5000 + assert layer.output_dim == 256 + assert layer.rank == 16 + assert layer.alpha == 0.5 + assert layer.mask_zero is True + assert layer.input_length == 128 + assert layer.sparse is True + + def test_parameter_validation(self): + """Test parameter validation.""" + # Test invalid input_dim + with pytest.raises(ValueError, match="input_dim must be positive"): + DoRAEmbedding(input_dim=0, output_dim=64) + + with pytest.raises(ValueError, match="input_dim must be positive"): + DoRAEmbedding(input_dim=-5, output_dim=64) + + # Test invalid output_dim + with pytest.raises(ValueError, match="output_dim 
must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=0) + + with pytest.raises(ValueError, match="output_dim must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=-10) + + # Test invalid rank + with pytest.raises(ValueError, match="rank must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=64, rank=0) + + with pytest.raises(ValueError, match="rank must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=64, rank=-2) + + # Test invalid alpha + with pytest.raises(ValueError, match="alpha must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=64, alpha=0) + + with pytest.raises(ValueError, match="alpha must be positive"): + DoRAEmbedding(input_dim=1000, output_dim=64, alpha=-1.0) + + def test_layer_build(self, dora_embedding): + """Test layer building process.""" + # Layer should not be built initially + assert not dora_embedding.built + + # Build the layer + dora_embedding.build(None) # Embedding layers don't need input shape + + # Check if layer is built + assert dora_embedding.built + + # Check weight shapes + assert dora_embedding.embeddings.shape == (1000, 128) + assert dora_embedding.lora_a.shape == (1000, 8) + assert dora_embedding.lora_b.shape == (8, 128) + assert dora_embedding.magnitude.shape == (128,) + + def test_forward_pass(self, sample_input, dora_embedding): + """Test forward pass functionality.""" + # Build and run forward pass + output = dora_embedding(sample_input) + + # Check output shape + expected_shape = sample_input.shape + (dora_embedding.output_dim,) + assert output.shape == expected_shape + + # Check output is not NaN or Inf + assert check_no_nan_inf(output) + + def test_weight_initialization(self, dora_embedding): + """Test weight initialization.""" + # Build the layer + dora_embedding.build(None) + + # Check lora_b is initialized to zeros + lora_b_np = safe_convert_to_numpy(dora_embedding.lora_b) + assert np.allclose(lora_b_np, 0.0) + + # Check magnitude is initialized to ones + magnitude_np = safe_convert_to_numpy(dora_embedding.magnitude) + assert np.allclose(magnitude_np, 1.0) + + def test_integer_input_conversion(self, dora_embedding): + """Test that various input types are converted to integers.""" + # Build the layer + dora_embedding.build(None) + + # Test with float inputs (should be converted to int) + float_input = ops.convert_to_tensor([[1.0, 2.5, 3.9]], dtype="float32") + output_float = dora_embedding(float_input) + + # Test with int inputs + int_input = ops.convert_to_tensor([[1, 2, 3]], dtype="int32") + output_int = dora_embedding(int_input) + + # Both should work and have correct shape + assert output_float.shape == (1, 3, 128) + assert output_int.shape == (1, 3, 128) + + def test_mask_zero_functionality(self): + """Test mask_zero functionality.""" + # Layer with mask_zero=True + layer_masked = DoRAEmbedding( + input_dim=100, output_dim=32, mask_zero=True + ) + + # Layer with mask_zero=False + layer_unmasked = DoRAEmbedding( + input_dim=100, output_dim=32, mask_zero=False + ) + + # Test input with zeros + test_input = ops.convert_to_tensor([[1, 2, 0, 3, 0]], dtype="int32") + + # Test mask computation + mask_result = layer_masked.compute_mask(test_input) + assert mask_result is not None + + no_mask_result = layer_unmasked.compute_mask(test_input) + assert no_mask_result is None + + def test_sparse_embedding(self): + """Test sparse embedding functionality.""" + layer = DoRAEmbedding(input_dim=100, output_dim=32, sparse=True) + + test_input = ops.convert_to_tensor([[1, 2, 3]], dtype="int32") + output = 
layer(test_input) + + # Should work and produce correct shape + assert output.shape == (1, 3, 32) + + def test_get_effective_embeddings(self, dora_embedding): + """Test effective embeddings computation.""" + # Build the layer + dora_embedding.build(None) + + # Get effective embeddings + effective_embeddings = dora_embedding.get_effective_embeddings() + + # Check shape + assert effective_embeddings.shape == (1000, 128) + + # Check it's not NaN or Inf + assert check_no_nan_inf(effective_embeddings) + + def test_get_dora_parameters(self, dora_embedding): + """Test DoRA parameter retrieval.""" + dora_embedding.build(None) + + params = dora_embedding.get_dora_parameters() + + # Check all expected parameters are present + assert "lora_a" in params + assert "lora_b" in params + assert "magnitude" in params + + # Check shapes + assert params["lora_a"].shape == (1000, 8) + assert params["lora_b"].shape == (8, 128) + assert params["magnitude"].shape == (128,) + + def test_merge_weights(self, dora_embedding): + """Test weight merging functionality.""" + dora_embedding.build(None) + + merged = dora_embedding.merge_weights() + + # Check structure + assert "embeddings" in merged + + # Check shapes + assert merged["embeddings"].shape == (1000, 128) + + def test_count_params(self): + """Test parameter counting.""" + layer = DoRAEmbedding(input_dim=1000, output_dim=128, rank=8) + + expected_params = ( + 1000 * 8 # lora_a + + 8 * 128 # lora_b + + 128 # magnitude + ) + + assert layer.count_params() == expected_params + + def test_load_pretrained_embeddings(self, dora_embedding): + """Test loading pretrained embeddings.""" + dora_embedding.build(None) + + # Create fake pretrained embeddings using backend-agnostic operations + pretrained_embeddings = create_random_tensor((1000, 128), seed=123) + pretrained_tensor = ops.convert_to_tensor(pretrained_embeddings) + + # Load embeddings + dora_embedding.load_pretrained_embeddings(pretrained_tensor) + + # Check if embeddings are loaded correctly + embeddings_np = safe_convert_to_numpy(dora_embedding.embeddings) + assert safe_allclose(embeddings_np, pretrained_embeddings) + + def test_load_pretrained_embeddings_shape_mismatch(self, dora_embedding): + """Test loading pretrained embeddings with wrong shapes.""" + dora_embedding.build(None) + + # Wrong shape + wrong_embeddings = create_random_tensor((500, 64), seed=123) + wrong_tensor = ops.convert_to_tensor(wrong_embeddings) + + with pytest.raises(ValueError, match="Pretrained embeddings shape"): + dora_embedding.load_pretrained_embeddings(wrong_tensor) + + def test_expand_vocabulary(self, dora_embedding): + """Test vocabulary expansion functionality.""" + dora_embedding.build(None) + + # Expand vocabulary from 1000 to 1200 + expanded_layer = dora_embedding.expand_vocabulary(1200) + + # Check new dimensions + assert expanded_layer.input_dim == 1200 + assert expanded_layer.output_dim == 128 # Should remain same + + # Check weight shapes + assert expanded_layer.embeddings.shape == (1200, 128) + assert expanded_layer.lora_a.shape == (1200, 8) + assert expanded_layer.lora_b.shape == (8, 128) + assert expanded_layer.magnitude.shape == (128,) + + def test_expand_vocabulary_with_new_embeddings(self, dora_embedding): + """Test vocabulary expansion with provided new embeddings.""" + dora_embedding.build(None) + + # Create new token embeddings for 200 additional tokens + new_token_embeddings = create_random_tensor((200, 128), seed=456) + new_embeddings_tensor = ops.convert_to_tensor(new_token_embeddings) + + # Expand vocabulary + 
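+        # Hedged note (assumption, not asserted below): expand_vocabulary is
+        # expected to keep the original 1000 embedding rows and use the
+        # provided (200, 128) tensor for the new token ids 1000-1199; the
+        # assertions in this test only verify the resulting shapes.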
expanded_layer = dora_embedding.expand_vocabulary( + 1200, new_embeddings_tensor + ) + + # Check dimensions + assert expanded_layer.input_dim == 1200 + assert expanded_layer.embeddings.shape == (1200, 128) + + def test_expand_vocabulary_errors(self, dora_embedding): + """Test vocabulary expansion error cases.""" + dora_embedding.build(None) + + # Test expanding to smaller size + with pytest.raises( + ValueError, match="new_vocab_size .* must be greater" + ): + dora_embedding.expand_vocabulary(500) + + # Test with wrong new embeddings shape + wrong_embeddings = create_random_tensor((100, 64), seed=789) + wrong_tensor = ops.convert_to_tensor(wrong_embeddings) + + with pytest.raises(ValueError, match="new_token_embeddings shape"): + dora_embedding.expand_vocabulary(1200, wrong_tensor) + + def test_expand_vocabulary_unbuilt_layer(self): + """Test expanding vocabulary on unbuilt layer.""" + layer = DoRAEmbedding(input_dim=1000, output_dim=128) + + with pytest.raises(ValueError, match="Layer must be built"): + layer.expand_vocabulary(1200) + + def test_serialization(self, dora_embedding): + """Test layer serialization and deserialization.""" + # Get config + config = dora_embedding.get_config() + + # Check essential parameters are in config + assert config["input_dim"] == dora_embedding.input_dim + assert config["output_dim"] == dora_embedding.output_dim + assert config["rank"] == dora_embedding.rank + assert config["alpha"] == dora_embedding.alpha + assert config["mask_zero"] == dora_embedding.mask_zero + + # Create layer from config + restored_layer = DoRAEmbedding.from_config(config) + + # Check restored layer has same parameters + assert restored_layer.input_dim == dora_embedding.input_dim + assert restored_layer.output_dim == dora_embedding.output_dim + assert restored_layer.rank == dora_embedding.rank + assert restored_layer.alpha == dora_embedding.alpha + + def test_compute_output_shape(self): + """Test output shape computation.""" + layer = DoRAEmbedding(input_dim=1000, output_dim=64, input_length=10) + + # Test various input shapes + input_shapes = [ + (None,), + (10,), + (None, 5), + (32, 10), + ] + + for input_shape in input_shapes: + output_shape = layer.compute_output_shape(input_shape) + expected_shape = input_shape + (64,) + assert output_shape == expected_shape + + def test_regularization(self): + """Test regularization functionality.""" + layer = DoRAEmbedding( + input_dim=100, + output_dim=32, + embeddings_regularizer="l2", + activity_regularizer="l2", + ) + + sample_input = ops.convert_to_tensor([[1, 2, 3]], dtype="int32") + output = layer(sample_input) + + # Check output shape + assert output.shape == (1, 3, 32) + + def test_constraints(self): + """Test constraint functionality.""" + layer = DoRAEmbedding( + input_dim=100, output_dim=32, embeddings_constraint="max_norm" + ) + + sample_input = ops.convert_to_tensor([[1, 2, 3]], dtype="int32") + output = layer(sample_input) + + # Check output shape + assert output.shape == (1, 3, 32) + + +class TestDoRAPositionEmbedding: + """Test cases for DoRAPositionEmbedding layer.""" + + @pytest.fixture + def sample_input(self): + """Create sample token embeddings.""" + return create_random_tensor((8, 32, 64), seed=42) + + @pytest.fixture + def position_layer(self): + """Create a basic DoRA position embedding layer.""" + return DoRAPositionEmbedding( + sequence_length=128, output_dim=64, rank=8, alpha=2.0 + ) + + def test_layer_creation(self, position_layer): + """Test basic layer creation.""" + assert position_layer.sequence_length == 128 
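+        # Rough parameter budget for this fixture (derived from the shapes
+        # checked in test_layer_build below): lora_a 128 * 8 = 1024,
+        # lora_b 8 * 64 = 512, plus a magnitude vector of 64, i.e. 1600
+        # trainable values versus the 128 * 64 = 8192 entries in the base
+        # position_embeddings table.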
+ assert position_layer.output_dim == 64 + assert position_layer.rank == 8 + assert position_layer.alpha == 2.0 + + def test_layer_build(self, position_layer): + """Test layer building process.""" + # Build the layer + position_layer.build(None) + + # Check if layer is built + assert position_layer.built + + # Check weight shapes + assert position_layer.position_embeddings.shape == (128, 64) + assert position_layer.lora_a.shape == (128, 8) + assert position_layer.lora_b.shape == (8, 64) + assert position_layer.magnitude.shape == (64,) + + def test_forward_pass(self, sample_input, position_layer): + """Test forward pass functionality.""" + # Convert to tensor for backend compatibility + input_tensor = ops.convert_to_tensor(sample_input) + + # Build and run forward pass + output = position_layer(input_tensor) + + # Check output shape matches input + assert output.shape == input_tensor.shape + + # Check output is not NaN or Inf + assert check_no_nan_inf(output) + + def test_start_index_parameter(self, sample_input, position_layer): + """Test start_index parameter.""" + input_tensor = ops.convert_to_tensor(sample_input) + + # Test with different start indices + output1 = position_layer(input_tensor, start_index=0) + output2 = position_layer(input_tensor, start_index=10) + + # Both should have same shape + assert output1.shape == input_tensor.shape + assert output2.shape == input_tensor.shape + + # Should produce different outputs for different start indices + assert not safe_allclose(output1, output2) + + def test_sequence_length_clipping(self, position_layer): + """Test that positions are clipped to sequence length.""" + position_layer.build(None) + + # Create input longer than sequence_length + long_input = create_random_tensor((4, 200, 64), seed=42) + long_tensor = ops.convert_to_tensor(long_input) + + # Should still work (positions get clipped) + output = position_layer(long_tensor) + assert output.shape == long_tensor.shape + + def test_effective_position_embeddings(self, position_layer): + """Test effective position embeddings computation.""" + position_layer.build(None) + + # Get effective position embeddings + effective_embeddings = ( + position_layer._get_effective_position_embeddings() + ) + + # Check shape + assert effective_embeddings.shape == (128, 64) + + # Check it's not NaN or Inf + assert check_no_nan_inf(effective_embeddings) + + def test_serialization(self, position_layer): + """Test layer serialization and deserialization.""" + # Get config + config = position_layer.get_config() + + # Check essential parameters are in config + assert config["sequence_length"] == position_layer.sequence_length + assert config["output_dim"] == position_layer.output_dim + assert config["rank"] == position_layer.rank + assert config["alpha"] == position_layer.alpha + + # Create layer from config + restored_layer = DoRAPositionEmbedding.from_config(config) + + # Check restored layer has same parameters + assert restored_layer.sequence_length == position_layer.sequence_length + assert restored_layer.output_dim == position_layer.output_dim + assert restored_layer.rank == position_layer.rank + assert restored_layer.alpha == position_layer.alpha + + +class TestEmbeddingConversion: + """Test cases for Embedding to DoRA conversion.""" + + def test_convert_embedding_to_dora(self): + """Test converting Embedding layer to DoRA layer.""" + # Create an Embedding layer + embedding_layer = layers.Embedding( + input_dim=1000, + output_dim=64, + mask_zero=True, + embeddings_initializer="uniform", + ) + + # 
Build with sample input + sample_input = ops.convert_to_tensor([[1, 2, 3, 4]], dtype="int32") + embedding_output = embedding_layer(sample_input) + + # Convert to DoRA + dora_layer = convert_embedding_to_dora( + embedding_layer, rank=8, alpha=2.0 + ) + + # Check configuration + assert dora_layer.input_dim == embedding_layer.input_dim + assert dora_layer.output_dim == embedding_layer.output_dim + assert dora_layer.rank == 8 + assert dora_layer.alpha == 2.0 + assert dora_layer.mask_zero == embedding_layer.mask_zero + + # Check weights are loaded + assert dora_layer.built + + # Test forward pass produces reasonable output + dora_output = dora_layer(sample_input) + assert dora_output.shape == embedding_output.shape + + def test_convert_unbuilt_embedding(self): + """Test converting unbuilt Embedding layer.""" + embedding_layer = layers.Embedding(input_dim=500, output_dim=32) + + # Convert unbuilt layer + dora_layer = convert_embedding_to_dora(embedding_layer, rank=4) + + # Should not be built yet + assert not dora_layer.built + + # But should have correct configuration + assert dora_layer.input_dim == 500 + assert dora_layer.output_dim == 32 + assert dora_layer.rank == 4 + + def test_convert_embedding_without_input_length(self): + """Test converting embedding layer without input_length attribute.""" + + # Create a mock embedding layer without input_length + class MockEmbedding: + def __init__(self): + self.input_dim = 100 + self.output_dim = 32 + self.embeddings_initializer = "uniform" + self.embeddings_regularizer = None + self.activity_regularizer = None + self.embeddings_constraint = None + self.mask_zero = False + self.name = "test_embedding" + self.built = False + + mock_layer = MockEmbedding() + + # Should work even without input_length + dora_layer = convert_embedding_to_dora(mock_layer, rank=4) + assert dora_layer.input_dim == 100 + assert dora_layer.output_dim == 32 + assert dora_layer.input_length is None + + +class TestDoRAEmbeddingMathematicalProperties: + """Test mathematical properties of DoRA embeddings.""" + + def test_magnitude_scaling_property(self): + """Test that DoRA properly applies magnitude scaling.""" + layer = DoRAEmbedding(input_dim=100, output_dim=32, rank=4) + layer.build(None) + + # Get effective embeddings + effective_embeddings = layer.get_effective_embeddings() + effective_embeddings_np = safe_convert_to_numpy(effective_embeddings) + + # Compute column norms of effective embeddings + column_norms = np.linalg.norm(effective_embeddings_np, axis=0) + magnitude_np = safe_convert_to_numpy(layer.magnitude) + + # Column norms should equal magnitude values (approximately) + assert safe_allclose(column_norms, magnitude_np, rtol=1e-5) + + def test_low_rank_adaptation_property(self): + """Test that adaptation is indeed low-rank.""" + layer = DoRAEmbedding(input_dim=100, output_dim=64, rank=8) + layer.build(None) + + # Compute LoRA adaptation using backend-agnostic operations + lora_a_np = safe_convert_to_numpy(layer.lora_a) + lora_b_np = safe_convert_to_numpy(layer.lora_b) + adaptation = lora_a_np @ (lora_b_np * layer.scaling) + + # Check that adaptation matrix has rank <= layer.rank + actual_rank = np.linalg.matrix_rank(adaptation) + assert actual_rank <= layer.rank + + def test_zero_initialization_equivalence(self): + """Test that zero LoRA initialization gives expected behavior.""" + layer = DoRAEmbedding( + input_dim=50, + output_dim=32, + rank=4, + lora_a_initializer="zeros", + lora_b_initializer="zeros", + ) + layer.build(None) + + # With zero LoRA matrices, effective 
embeddings should have + # column norms equal to magnitude (which is initialized to ones) + effective_embeddings = layer.get_effective_embeddings() + effective_embeddings_np = safe_convert_to_numpy(effective_embeddings) + column_norms = np.linalg.norm(effective_embeddings_np, axis=0) + + magnitude_np = safe_convert_to_numpy(layer.magnitude) + assert safe_allclose(column_norms, magnitude_np, rtol=1e-5) + + def test_embedding_lookup_correctness(self): + """Test that embedding lookup works correctly.""" + layer = DoRAEmbedding(input_dim=10, output_dim=4, rank=2) + layer.build(None) + + # Test specific token indices + test_indices = ops.convert_to_tensor( + [[0, 1, 2], [3, 4, 5]], dtype="int32" + ) + output = layer(test_indices) + + # Get effective embeddings + effective_embeddings = layer.get_effective_embeddings() + + # Manually lookup embeddings for comparison + output_np = safe_convert_to_numpy(output) + effective_embeddings_np = safe_convert_to_numpy(effective_embeddings) + + # Check first batch, first token (index 0) + expected_first = effective_embeddings_np[0] + actual_first = output_np[0, 0] + assert safe_allclose(actual_first, expected_first) + + # Check second batch, third token (index 5) + expected_last = effective_embeddings_np[5] + actual_last = output_np[1, 2] + assert safe_allclose(actual_last, expected_last) + + +def test_backend_compatibility(): + """Test that the implementation works across different backends.""" + try: + backend_name = keras.backend.backend() + print(f"Testing with backend: {backend_name}") + except Exception: + print("Backend detection failed, proceeding with tests...") + + # Test DoRAEmbedding + embedding_layer = DoRAEmbedding(input_dim=100, output_dim=32, rank=4) + sample_input = create_random_tensor((1, 4), dtype="int32") + sample_tensor = ops.convert_to_tensor(sample_input) + + try: + output = embedding_layer(sample_tensor) + assert output.shape == (1, 4, 32) + print("DoRAEmbedding test passed") + except Exception as e: + print(f"DoRAEmbedding test failed: {e}") + return False + + # Test DoRAPositionEmbedding + pos_layer = DoRAPositionEmbedding(sequence_length=10, output_dim=32, rank=4) + sample_embeddings = create_random_tensor((2, 4, 32)) + embeddings_tensor = ops.convert_to_tensor(sample_embeddings) + + try: + pos_output = pos_layer(embeddings_tensor) + assert pos_output.shape == (2, 4, 32) + print("DoRAPositionEmbedding test passed") + except Exception as e: + print(f"DoRAPositionEmbedding test failed: {e}") + return False + + return True + + +def test_masking_integration(): + """Test integration with Keras masking.""" + # Create layer with masking + layer = DoRAEmbedding(input_dim=100, output_dim=32, mask_zero=True) + + # Input with zeros (should be masked) + input_with_zeros = ops.convert_to_tensor([[1, 2, 0, 3, 0]], dtype="int32") + + # Get output and mask + output = layer(input_with_zeros) + mask = layer.compute_mask(input_with_zeros) + + assert output.shape == (1, 5, 32) + assert mask is not None + + # Check mask values + mask_np = safe_convert_to_numpy(mask) + expected_mask = np.array([[True, True, False, True, False]]) + assert safe_array_equal(mask_np, expected_mask) + + +def test_safe_weight_assignment(): + """Test safe weight assignment across backends.""" + layer = DoRAEmbedding(input_dim=10, output_dim=8, rank=2) + layer.build(None) + + # Test loading pretrained embeddings + pretrained = create_random_tensor((10, 8), seed=999) + pretrained_tensor = ops.convert_to_tensor(pretrained) + + try: + 
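+        # Hedged note: load_pretrained_embeddings is assumed to copy the
+        # (10, 8) tensor into layer.embeddings via a backend-agnostic
+        # assignment; the allclose check below verifies the round trip.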
layer.load_pretrained_embeddings(pretrained_tensor) + # Check if assignment worked + loaded_embeddings = safe_convert_to_numpy(layer.embeddings) + assert safe_allclose(loaded_embeddings, pretrained) + print("Safe weight assignment test passed") + return True + except Exception as e: + print(f"Safe weight assignment test failed: {e}") + return False + + +def test_backend_agnostic_operations(): + """Test that all operations use backend-agnostic ops.""" + layer = DoRAEmbedding(input_dim=20, output_dim=16, rank=4) + layer.build(None) + + # Test effective embeddings computation + try: + effective_embeddings = layer._get_effective_embeddings() + assert effective_embeddings.shape == (20, 16) + assert check_no_nan_inf(effective_embeddings) + print("Backend-agnostic operations test passed") + return True + except Exception as e: + print(f"Backend-agnostic operations test failed: {e}") + return False + + +if __name__ == "__main__": + # Run comprehensive backend compatibility tests + print("=" * 60) + print("DORA EMBEDDINGS BACKEND COMPATIBILITY TEST SUITE") + print("=" * 60) + + tests_passed = 0 + total_tests = 5 + + # Test 1: Backend compatibility + if test_backend_compatibility(): + tests_passed += 1 + + # Test 2: Masking integration + try: + test_masking_integration() + print("Masking integration test passed!") + tests_passed += 1 + except Exception as e: + print(f"Masking integration test failed: {e}") + + # Test 3: Safe weight assignment + if test_safe_weight_assignment(): + tests_passed += 1 + + # Test 4: Backend-agnostic operations + if test_backend_agnostic_operations(): + tests_passed += 1 + + # Test 5: Comprehensive functionality test + try: + layer = DoRAEmbedding(input_dim=1000, output_dim=128, rank=8, alpha=2.0) + sample_input = create_random_tensor((1, 5), dtype="int32") + sample_tensor = ops.convert_to_tensor(sample_input) + + # Test forward pass + output = layer(sample_tensor) + print(f"Output shape: {output.shape}") + + # Test parameter counting + param_count = layer.count_params() + print(f"Trainable parameters: {param_count}") + + # Test effective embeddings computation + effective_embeddings = layer.get_effective_embeddings() + print(f"Effective embeddings shape: {effective_embeddings.shape}") + + # Test position embedding + pos_layer = DoRAPositionEmbedding(sequence_length=64, output_dim=128) + pos_input = create_random_tensor((4, 10, 128)) + pos_tensor = ops.convert_to_tensor(pos_input) + pos_output = pos_layer(pos_tensor) + print(f"Position embedding output shape: {pos_output.shape}") + + print("Comprehensive functionality test passed!") + tests_passed += 1 + except Exception as e: + print(f"Comprehensive functionality test failed: {e}") + + print("=" * 60) + print(f"RESULTS: {tests_passed}/{total_tests} tests passed") + if tests_passed == total_tests: + print("🎉 ALL TESTS PASSED! Backend compatibility confirmed.") + else: + print("⚠️ Some tests failed. 
Check backend compatibility.") + print("=" * 60) diff --git a/keras_hub/src/models/bert/bert_backbone.py b/keras_hub/src/models/bert/bert_backbone.py index 8ea51dfcf9..1a16cca3e4 100644 --- a/keras_hub/src/models/bert/bert_backbone.py +++ b/keras_hub/src/models/bert/bert_backbone.py @@ -1,6 +1,9 @@ import keras from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.dora_dense import DoRADense +from keras_hub.src.layers.modeling.dora_embeddings import DoRAEmbedding +from keras_hub.src.layers.modeling.dora_embeddings import DoRAPositionEmbedding from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding from keras_hub.src.layers.modeling.reversible_embedding import ( ReversibleEmbedding, @@ -36,7 +39,8 @@ class BertBackbone(Backbone): num_layers: int. The number of transformer layers. num_heads: int. The number of attention heads for each transformer. The hidden size must be divisible by the number of attention heads. - hidden_dim: int. The size of the transformer encoding and pooler layers. + hidden_dim: int. + The size of the transformer encoding and pooler layers. intermediate_dim: int. The output dimension of the first Dense layer in a two-layer feedforward network for each transformer. dropout: float. Dropout probability for the Transformer encoder. @@ -83,26 +87,55 @@ def __init__( num_heads, hidden_dim, intermediate_dim, + enable_dora=False, + dora_rank=8, + dora_alpha=16.0, dropout=0.1, max_sequence_length=512, num_segments=2, dtype=None, **kwargs, ): + self.enable_dora = enable_dora + self.dora_rank = dora_rank + self.dora_alpha = dora_alpha # === Layers === - self.token_embedding = ReversibleEmbedding( - input_dim=vocabulary_size, - output_dim=hidden_dim, - embeddings_initializer=bert_kernel_initializer(), - dtype=dtype, - name="token_embedding", - ) - self.position_embedding = PositionEmbedding( - initializer=bert_kernel_initializer(), - sequence_length=max_sequence_length, - dtype=dtype, - name="position_embedding", - ) + if enable_dora: + self.token_embedding = DoRAEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + rank=dora_rank, + alpha=dora_alpha, + embeddings_initializer=bert_kernel_initializer(), + dtype=dtype, + name="token_embedding", + ) + else: + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + embeddings_initializer=bert_kernel_initializer(), + dtype=dtype, + name="token_embedding", + ) + + if enable_dora: + self.position_embedding = DoRAPositionEmbedding( + sequence_length=max_sequence_length, + output_dim=hidden_dim, + rank=dora_rank, + alpha=dora_alpha, + initializer=bert_kernel_initializer(), + dtype=dtype, + name="position_embedding", + ) + else: + self.position_embedding = PositionEmbedding( + initializer=bert_kernel_initializer(), + sequence_length=max_sequence_length, + dtype=dtype, + name="position_embedding", + ) self.segment_embedding = keras.layers.Embedding( input_dim=num_segments, output_dim=hidden_dim, @@ -138,13 +171,25 @@ def __init__( name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) - self.pooled_dense = keras.layers.Dense( - hidden_dim, - kernel_initializer=bert_kernel_initializer(), - activation="tanh", - dtype=dtype, - name="pooled_dense", - ) + + if enable_dora: + self.pooled_dense = DoRADense( + units=hidden_dim, + rank=dora_rank, + alpha=dora_alpha, + kernel_initializer=bert_kernel_initializer(), + activation="tanh", + dtype=dtype, + name="pooled_dense", + ) + else: + self.pooled_dense = 
keras.layers.Dense( + hidden_dim, + kernel_initializer=bert_kernel_initializer(), + activation="tanh", + dtype=dtype, + name="pooled_dense", + ) # === Functional Model === token_id_input = keras.Input( @@ -205,6 +250,9 @@ def get_config(self): "num_heads": self.num_heads, "hidden_dim": self.hidden_dim, "intermediate_dim": self.intermediate_dim, + "enable_dora": self.enable_dora, + "dora_rank": self.dora_rank, + "dora_alpha": self.dora_alpha, "dropout": self.dropout, "max_sequence_length": self.max_sequence_length, "num_segments": self.num_segments, diff --git a/keras_hub/src/models/bert/bert_backbone_test.py b/keras_hub/src/models/bert/bert_backbone_test.py index 0dcb1f7de5..0e248f36ce 100644 --- a/keras_hub/src/models/bert/bert_backbone_test.py +++ b/keras_hub/src/models/bert/bert_backbone_test.py @@ -32,6 +32,75 @@ def test_backbone_basics(self): }, ) + def test_backbone_with_dora(self): + """Test BERT backbone with DoRA layers enabled.""" + dora_init_kwargs = { + **self.init_kwargs, + "enable_dora": True, + "dora_rank": 4, + "dora_alpha": 8.0, + } + + self.run_backbone_test( + cls=BertBackbone, + init_kwargs=dora_init_kwargs, + input_data=self.input_data, + expected_output_shape={ + "sequence_output": (2, 5, 2), + "pooled_output": (2, 2), + }, + ) + + def test_dora_config_preservation(self): + """Test that DoRA configuration is properly saved and restored.""" + model = BertBackbone( + vocabulary_size=10, + num_layers=2, + num_heads=2, + hidden_dim=4, + intermediate_dim=8, + enable_dora=True, + dora_rank=8, + dora_alpha=16.0, + max_sequence_length=5, + ) + + config = model.get_config() + + # Verify DoRA parameters are in config + self.assertEqual(config["enable_dora"], True) + self.assertEqual(config["dora_rank"], 8) + self.assertEqual(config["dora_alpha"], 16.0) + + # Test model can be recreated from config + new_model = BertBackbone.from_config(config) + self.assertEqual(new_model.enable_dora, True) + self.assertEqual(new_model.dora_rank, 8) + self.assertEqual(new_model.dora_alpha, 16.0) + + def test_dora_vs_regular_output_shapes(self): + """Test that DoRA and regular models produce same output shapes.""" + regular_model = BertBackbone(**self.init_kwargs) + dora_model = BertBackbone( + **self.init_kwargs, + enable_dora=True, + dora_rank=4, + dora_alpha=8.0, + ) + + regular_output = regular_model(self.input_data) + dora_output = dora_model(self.input_data) + + # Shapes should be identical + self.assertEqual( + regular_output["sequence_output"].shape, + dora_output["sequence_output"].shape, + ) + self.assertEqual( + regular_output["pooled_output"].shape, + dora_output["pooled_output"].shape, + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( @@ -40,6 +109,22 @@ def test_saved_model(self): input_data=self.input_data, ) + @pytest.mark.large + def test_saved_model_with_dora(self): + """Test model saving/loading with DoRA enabled.""" + dora_init_kwargs = { + **self.init_kwargs, + "enable_dora": True, + "dora_rank": 4, + "dora_alpha": 8.0, + } + + self.run_model_saving_test( + cls=BertBackbone, + init_kwargs=dora_init_kwargs, + input_data=self.input_data, + ) + @pytest.mark.large def test_smallest_preset(self): self.run_preset_test(