From 90e7c59aade7bfe41bb93f8d40fb21e4f1a2a6e9 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 2 Apr 2025 17:35:11 +0000 Subject: [PATCH 01/11] mistral init commit --- .../src/models/mixtral/mixtral_attention.py | 244 ++++++++++++++++++ .../src/models/mixtral/mixtral_backbone.py | 190 ++++++++++++++ .../src/models/mixtral/mixtral_causal_lm.py | 0 .../mixtral/mixtral_causal_lm_preprocessor.py | 0 .../src/models/mixtral/mixtral_decoder.py | 0 .../src/models/mixtral/mixtral_layer_norm.py | 35 +++ .../src/models/mixtral/mixtral_tokenizer.py | 0 7 files changed, 469 insertions(+) create mode 100644 keras_hub/src/models/mixtral/mixtral_attention.py create mode 100644 keras_hub/src/models/mixtral/mixtral_backbone.py create mode 100644 keras_hub/src/models/mixtral/mixtral_causal_lm.py create mode 100644 keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py create mode 100644 keras_hub/src/models/mixtral/mixtral_decoder.py create mode 100644 keras_hub/src/models/mixtral/mixtral_layer_norm.py create mode 100644 keras_hub/src/models/mixtral/mixtral_tokenizer.py diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py new file mode 100644 index 0000000000..d87a676de2 --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_attention.py @@ -0,0 +1,244 @@ +import math + +import keras +from keras import ops + +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +from keras_hub.src.utils.keras_utils import clone_initializer +from keras_hub.src.utils.keras_utils import has_flash_attention_support + + +# This is just a self-attention layer in Mistral. But it can be generalized +# to use the `keras_hub.layers.CachedMultiHeadAttention` API. Since this layer +# implements grouped-query attention and sliding window attention, it might be +# useful outside of Mistral itself. 
+# TODO(tirthasheshpatel): Generalize the attention layer +# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer +# TODO(tirthasheshpatel): Use flash attention +class CachedMistralAttention(keras.layers.Layer): + """A cached grounded query attention layer with sliding window.""" + + def __init__( + self, + num_query_heads, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + kernel_initializer="glorot_uniform", + sliding_window=512, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self._num_query_heads = num_query_heads + self._num_key_value_heads = num_key_value_heads + self._sliding_window = sliding_window + self._dropout = dropout + + self._num_key_value_groups = num_query_heads // num_key_value_heads + self._rope_max_wavelength = rope_max_wavelength + + self._kernel_initializer = keras.initializers.get( + clone_initializer(kernel_initializer) + ) + + self._rope_scaling_factor = rope_scaling_factor + + def build(self, inputs_shape): + # Einsum variables: + # b = batch size + # q = query length + # k = key/value length + # m = model dim + # u = num query heads + # v = num key/value heads + # h = head dim + self._hidden_dim = inputs_shape[-1] + self._head_dim = self._hidden_dim // self._num_query_heads + self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim) + + self._query_dense = keras.layers.EinsumDense( + equation="bqm,muh->bquh", + output_shape=(None, self._num_query_heads, self._head_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="query", + ) + self._query_dense.build(inputs_shape) + + self._key_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self._num_key_value_heads, + self._head_dim, + ), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="key", + ) + self._key_dense.build(inputs_shape) + + self._value_dense = keras.layers.EinsumDense( + equation="bkm,mvh->bkvh", + output_shape=( + None, + self._num_key_value_heads, + self._head_dim, + ), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="value", + ) + self._value_dense.build(inputs_shape) + + self._softmax = keras.layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + self._dropout_layer = keras.layers.Dropout( + rate=self._dropout, + dtype=self.dtype_policy, + ) + + self._output_dense = keras.layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, self._hidden_dim), + kernel_initializer=self._kernel_initializer, + dtype=self.dtype_policy, + name="attention_output", + ) + self._output_dense.build( + (None, None, self._num_query_heads, self._head_dim) + ) + + self.rotary_embedding_layer = RotaryEmbedding( + max_wavelength=self._rope_max_wavelength, + scaling_factor=self._rope_scaling_factor, + dtype=self.dtype_policy, + ) + + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.built = True + + def call( + self, + hidden_states, + attention_mask=None, + cache=None, + cache_update_index=None, + training=None, + ): + start_index = ( + cache_update_index if cache_update_index is not None else 0 + ) + + query = self._query_dense(hidden_states) + + # Compute RoPE for queries + query = self.rotary_embedding_layer(query, start_index=start_index) + + def _compute_key_value(x): + key, value = self._key_dense(x), self._value_dense(x) + # Compute RoPE for keys + key = self.rotary_embedding_layer(key, start_index=start_index) + return key, value + + if cache is 
not None: + key_cache = cache[:, 0, ...] + value_cache = cache[:, 1, ...] + if cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_key_value(hidden_states) + start = [0, cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update(value_cache, start, value_update) + cache = ops.stack((key, value), axis=1) + else: + if cache_update_index is not None: + raise ValueError( + "`cache_update_index` should not be set if `cache` is " + f"`None`. Received: cache={cache}, " + f"cache_update_index={cache_update_index}" + ) + key, value = _compute_key_value(hidden_states) + + # [batch_shape, seq_len, num_key_value_heads, head_dim] + # -> [batch_shape, seq_len, num_heads, head_dim] + key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2) + + attention_output = self._compute_attention( + query, key, value, attention_mask + ) + + attention_output = self._dropout_layer( + attention_output, training=training + ) + + attention_output = self._output_dense(attention_output) + + if cache is not None: + return attention_output, cache + return attention_output + + def _masked_softmax(self, attention_scores, attention_mask=None): + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention(self, query, key, value, attention_mask=None): + if has_flash_attention_support(): + # Use `dot_product_attention` with Flash Attention support if + # available. + if attention_mask is not None: + attention_mask = ops.expand_dims(attention_mask, axis=1) + attention_mask = ops.cast(attention_mask, dtype="bool") + attention_output = ops.dot_product_attention( + query, + key, + value, + mask=attention_mask, + scale=self._inv_norm_factor, + ) + return attention_output + + attention_scores = ops.einsum(self._dot_product_equation, query, key) + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + def get_config(self): + config = super().get_config() + config.update( + { + "num_query_heads": self._num_query_heads, + "num_key_value_heads": self._num_key_value_heads, + "rope_max_wavelength": self._rope_max_wavelength, + "rope_scaling_factor": self._rope_scaling_factor, + "kernel_initializer": keras.initializers.serialize( + self._kernel_initializer + ), + "sliding_window": self._sliding_window, + "dropout": self._dropout, + } + ) + return config diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py new file mode 100644 index 0000000000..09a5d38129 --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -0,0 +1,190 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.mistral.mistral_layer_norm import ( + MistralLayerNormalization, +) +from keras_hub.src.models.mistral.mistral_transformer_decoder import ( + 
MistralTransformerDecoder, +) + + +def _mistral_kernel_initializer(stddev=0.02): + return keras.initializers.RandomNormal(stddev=stddev) + + +@keras_hub_export("keras_hub.models.MistralBackbone") +class MistralBackbone(Backbone): + """ + The Mistral Transformer core architecture with hyperparameters. + + This network implements a Transformer-based decoder network, + Mistral, as described in + ["Mistral 7B"](https://arxiv.org/pdf/2310.06825.pdf). + It includes the embedding lookups and transformer layers. + + The default constructor gives a fully customizable, randomly initialized + Mistral model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + vocabulary_size (int): The size of the token vocabulary. + num_layers (int): The number of transformer layers. + num_query_heads (int): The number of query attention heads for + each transformer. + hidden_dim (int): The size of the transformer encoding and pooling + layers. + intermediate_dim (int): The output dimension of the first Dense layer + in a three-layer feedforward network for each transformer. + num_key_value_heads (int): The number of key and value attention heads + for each transformer. + rope_max_wavelength (int, optional): The maximum angular wavelength of + the sine/cosine curves, for rotary embeddings. Defaults to `10000`. + rope_scaling_factor (float, optional): The scaling factor for + calculation of roatary embedding. Defaults to `1.0`. + layer_norm_epsilon (float, optional): Epsilon for the layer + normalization layers in the transformer decoder. Defaults to `1e-6`. + sliding_window (int, optional): The sliding window for the mistral + attention layers. This controls the maximum cache size for the + attention layers in each transformer decoder. Only `sliding_window` + number of tokens are saved in the cache and used to generate the + next token. Defaults to `512`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for model computations and weights. Note that some computations, + such as softmax and layer normalization, will always be done at + float32 precision regardless of dtype. + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained Mistral decoder. + model = keras_hub.models.MistralBackbone.from_preset("mistral7b_base_en") + model(input_data) + + # Randomly initialized Mistral decoder with custom config. 
+ model = keras_hub.models.MistralBackbone( + vocabulary_size=10, + hidden_dim=512, + num_layers=2, + num_query_heads=32, + num_key_value_heads=8, + intermediate_dim=1024, + sliding_window=512, + layer_norm_epsilon=1e-6, + dtype="float32" + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + num_layers, + num_query_heads, + hidden_dim, + intermediate_dim, + num_key_value_heads, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + layer_norm_epsilon=1e-6, + sliding_window=512, + dropout=0, + dtype=None, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + tie_weights=False, + embeddings_initializer=_mistral_kernel_initializer(stddev=0.01), + dtype=dtype, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = MistralTransformerDecoder( + intermediate_dim=intermediate_dim, + num_query_heads=num_query_heads, + num_key_value_heads=num_key_value_heads, + rope_max_wavelength=rope_max_wavelength, + rope_scaling_factor=rope_scaling_factor, + layer_norm_epsilon=layer_norm_epsilon, + activation=ops.silu, + kernel_initializer=_mistral_kernel_initializer(stddev=0.02), + sliding_window=sliding_window, + dropout=dropout, + dtype=dtype, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + self.layer_norm = MistralLayerNormalization( + epsilon=layer_norm_epsilon, + dtype=dtype, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + x = self.token_embedding(token_id_input) + for transformer_layer in self.transformer_layers: + x = transformer_layer(x, decoder_padding_mask=padding_mask_input) + sequence_output = self.layer_norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.num_layers = num_layers + self.num_query_heads = num_query_heads + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.rope_max_wavelength = rope_max_wavelength + self.num_key_value_heads = num_key_value_heads + self.rope_scaling_factor = rope_scaling_factor + self.sliding_window = sliding_window + self.layer_norm_epsilon = layer_norm_epsilon + self.dropout = dropout + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "num_layers": self.num_layers, + "num_query_heads": self.num_query_heads, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "sliding_window": self.sliding_window, + "layer_norm_epsilon": self.layer_norm_epsilon, + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm.py b/keras_hub/src/models/mixtral/mixtral_causal_lm.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py b/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py 
b/keras_hub/src/models/mixtral/mixtral_decoder.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/mixtral/mixtral_layer_norm.py b/keras_hub/src/models/mixtral/mixtral_layer_norm.py new file mode 100644 index 0000000000..affca9c45f --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_layer_norm.py @@ -0,0 +1,35 @@ +import keras +from keras import ops + + +# TODO: Deprecate this in favor of +# `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is +# removed. +class MistralLayerNormalization(keras.layers.Layer): + """A normalization layer for Mistral that implements RMS normalization.""" + + def __init__(self, epsilon=1e-6, **kwargs): + super().__init__(**kwargs) + self.epsilon = epsilon + + def build(self, input_shape): + dim = input_shape[-1] + self.scale = self.add_weight( + name="scale", + trainable=True, + shape=(dim,), + initializer="ones", + dtype=self.variable_dtype, + ) + self.built = True + + def call(self, x): + x = ops.cast(x, "float32") + var = ops.mean(ops.power(x, 2), axis=-1, keepdims=True) + x = x * ops.rsqrt(var + self.epsilon) + return ops.cast(x * self.scale, self.compute_dtype) + + def get_config(self): + config = super().get_config() + config.update({"epsilon": self.epsilon}) + return config diff --git a/keras_hub/src/models/mixtral/mixtral_tokenizer.py b/keras_hub/src/models/mixtral/mixtral_tokenizer.py new file mode 100644 index 0000000000..e69de29bb2 From 43764fc8c35b807be2229369d50b13e97335f137 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 2 Apr 2025 18:52:32 +0000 Subject: [PATCH 02/11] wip mixtral --- .../src/models/mixtral/mixtral_attention.py | 9 +- .../src/models/mixtral/mixtral_decoder.py | 339 ++++++++++++++++++ .../src/models/mixtral/mixtral_layer_norm.py | 4 +- 3 files changed, 342 insertions(+), 10 deletions(-) diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py index d87a676de2..c38677151c 100644 --- a/keras_hub/src/models/mixtral/mixtral_attention.py +++ b/keras_hub/src/models/mixtral/mixtral_attention.py @@ -8,14 +8,7 @@ from keras_hub.src.utils.keras_utils import has_flash_attention_support -# This is just a self-attention layer in Mistral. But it can be generalized -# to use the `keras_hub.layers.CachedMultiHeadAttention` API. Since this layer -# implements grouped-query attention and sliding window attention, it might be -# useful outside of Mistral itself. 
-# TODO(tirthasheshpatel): Generalize the attention layer -# TODO(tirthasheshpatel): Merge `LlamaAttention` with this layer -# TODO(tirthasheshpatel): Use flash attention -class CachedMistralAttention(keras.layers.Layer): +class CachedMixtralAttention(keras.layers.Layer): """A cached grounded query attention layer with sliding window.""" def __init__( diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index e69de29bb2..00fdb3499e 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -0,0 +1,339 @@ +import keras +from keras import ops +from keras_hub.src.layers.modeling.transformer_layer_utils import compute_causal_mask, merge_padding_and_attention_mask +from keras_hub.src.models.mixtral.mixtral_attention import CachedMixtralAttention +from keras_hub.src.models.mixtral.mixtral_layer_norm import MixtralLayerNormalization +from keras_hub.src.utils.keras_utils import clone_initializer + +class MixtralMoeMLP(keras.layers.Layer): + def __init__( + self, + intermediate_dim, + hidden_dim, + activation_fn="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.hidden_dim = hidden_dim + self.activation_fn = activation_fn + self.kernel_initializer = kernel_initializer + self.layer_norm_epsilon = layer_norm_epsilon + + def build(self, decoder_sequence_shape): + # Feedforward layers. + self._feedforward_intermediate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_intermediate_dense", + ) + self._feedforward_intermediate_dense.build(decoder_sequence_shape) + + self._feedforward_gate_dense = keras.layers.Dense( + self.intermediate_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_gate_dense", + ) + self._feedforward_gate_dense.build(decoder_sequence_shape) + + self._feedforward_output_dense = keras.layers.Dense( + self.hidden_dim, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="feedforward_output_dense", + ) + + self._feedforward_output_dense.build( + self._feedforward_gate_dense.compute_output_shape( + decoder_sequence_shape + ) + ) + + self.activation = keras.activations.get(self.activation_fn) + self.built = True + + def call(self, x): + gate_output = self._feedforward_gate_dense(x) + + # Note that we run the activation function in full 32-bit + # precision since this is what `torch.nn.functional.silu` + # does. Internally, `torch.nn.functional.silu` converts the + # inputs to float32, computes SiLU, and converts the outputs + # back to compute dtype. 
+ # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 + # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 + gate_output = ops.cast(gate_output, "float32") + gate_output = self.activation(gate_output) + gate_output = ops.cast(gate_output, self.compute_dtype) + + x = self._feedforward_intermediate_dense(x) + + x = self._feedforward_output_dense(ops.multiply(x, gate_output)) + + return x + + + +class MixtralSparseMoeBlock(keras.layers.Layer): + + + def __init__(self, + hidden_dim, + intermediate_dim, + num_experts, + top_k, + router_jitter_noise, + **kwargs): + super().__init__(**kwargs) + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_experts = num_experts + self.top_k = top_k + + def build(self, decoder_sequence_shape): + + self._sparse_feedforward_gate_dense = keras.layers.Dense( + self.num_experts, + kernel_initializer=clone_initializer(self.kernel_initializer), + use_bias=False, + dtype=self.dtype_policy, + name="sparse_feedforward_gate_dense", + ) + self._sparse_feedforward_gate_dense.build(decoder_sequence_shape) + + self.experts = [ + MixtralMoeMLP( + intermediate_dim=self.moe_intermediate_dim, + hidden_dim=self.hidden_dim, + kernel_initializer=self.kernel_initializer, + layer_norm_epsilon=self.layer_norm_epsilon, + ) + for _ in range(self.num_experts) + ] + for expert in self.experts: + expert.build(decoder_sequence_shape) + + def call(self): + pass + + + + +class MixtralTransformerDecoder(keras.layers.Layer): + + def __init__( + self, + intermediate_dim, + num_query_heads, + num_key_value_heads, + num_experts, + top_k, + router_jitter_noise, + output_router_logits, + rope_max_wavelength=10000, + rope_scaling_factor=1.0, + activation="silu", + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", + sliding_window=512, + dropout=0, + **kwargs, + ): + super().__init__(**kwargs) + self.intermediate_dim = intermediate_dim + self.num_query_heads = num_query_heads + self.num_key_value_heads = num_key_value_heads + + self.num_experts = num_experts + self.top_k = top_k + self.router_jitter_noise = router_jitter_noise + + self.rope_max_wavelength = rope_max_wavelength + self.rope_scaling_factor = rope_scaling_factor + + self.dropout = dropout + + self.sliding_window = sliding_window + self.activation = keras.activations.get(activation) + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + + self.output_router_logits = output_router_logits + + self.supports_masking = True + + def build(self, decoder_sequence_shape): + self._decoder_sequence_shape = decoder_sequence_shape + self.hidden_dim = decoder_sequence_shape[-1] + + # Self attention layer. 
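+ # Grouped-query attention: `num_query_heads` query heads share + # `num_key_value_heads` key/value heads, and rotary embeddings are applied + # to the queries and keys inside the layer.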
+ self._self_attention_layer = CachedMixtralAttention( + num_query_heads=self.num_query_heads, + num_key_value_heads=self.num_key_value_heads, + rope_max_wavelength=self.rope_max_wavelength, + rope_scaling_factor=self.rope_scaling_factor, + sliding_window=self.sliding_window, + kernel_initializer=clone_initializer(self.kernel_initializer), + dropout=self.dropout, + dtype=self.dtype_policy, + name="self_attention", + ) + self._self_attention_layer.build(decoder_sequence_shape) + + self._self_attention_layernorm = MixtralLayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layernorm", + ) + self._self_attention_layernorm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + self._sparse_moe_block = MixtralSparseMoeBlock( + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + num_experts=self.num_experts, + top_k=self.top_k, + router_jitter_noise=self.router_jitter_noise + ) + self._sparse_moe_block.build(decoder_sequence_shape) + + self._feedforward_layernorm = MixtralLayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="feedforward_layernorm", + ) + self._feedforward_layernorm.build(decoder_sequence_shape) + + self.built = True + + def call(self, + decoder_sequence, + decoder_padding_mask=None, + decoder_attention_mask=None, + self_attention_cache=None, + self_attention_cache_update_index=None, + training=None, + + ): + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=decoder_sequence, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + residual = decoder_sequence + + x = self._self_attention_layernorm(decoder_sequence) + + # Self attention block. + x = self._self_attention_layer( + hidden_states=x, + attention_mask=self_attention_mask, + cache=self_attention_cache, + cache_update_index=self_attention_cache_update_index, + ) + + if self_attention_cache is not None: + x, self_attention_cache = x + + x = self._self_attention_dropout(x, training=training) + + x = x + residual + residual = x + + x = self._feedforward_layernorm(x) + x, router_logits = self._sparse_moe_block(x) + + decoder_output = x + residual + + output = (decoder_output,) + + if self_attention_cache is not None: + output += (self_attention_cache,) + + if self.output_router_logits: + output += (router_logits, ) + + return output + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. 
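+ # When a cache is passed, the key/value length (`input_length`) comes from + # the cache while the query length (`output_length`) stays at the incoming + # sequence length, so the causal mask below is rectangular.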
+ if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + # The lower triangular attention mask + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index ) + + # Use a banded attention mask if sliding window is not None + if self.sliding_window is not None: + # Below is a workaround for `ops.triu` for Keras 2. + # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is + # removed. + # causal_mask = ops.triu(causal_mask, k=-self.sliding_window) + i = ops.arange(output_length)[:, None] + cache_update_index + j = ops.arange(input_length)[None, :] + causal_mask_upper = ops.cast(i < j + self.sliding_window, "int32") + causal_mask = ops.minimum(causal_mask, causal_mask_upper) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def compute_output_shape(self, decoder_sequence_shape): + return decoder_sequence_shape + + def get_config(self): + config = super().get_config() + config.update( + { + "intermediate_dim": self.intermediate_dim, + "num_query_heads": self.num_query_heads, + "rope_max_wavelength": self.rope_max_wavelength, + "rope_scaling_factor": self.rope_scaling_factor, + "num_key_value_heads": self.num_key_value_heads, + "sliding_window": self.sliding_window, + "activation": keras.activations.serialize(self.activation), + "layer_norm_epsilon": self.layer_norm_epsilon, + "kernel_initializer": keras.initializers.serialize( + self.kernel_initializer + ), + "dropout": self.dropout, + } + ) + return config diff --git a/keras_hub/src/models/mixtral/mixtral_layer_norm.py b/keras_hub/src/models/mixtral/mixtral_layer_norm.py index affca9c45f..ee94e3fd37 100644 --- a/keras_hub/src/models/mixtral/mixtral_layer_norm.py +++ b/keras_hub/src/models/mixtral/mixtral_layer_norm.py @@ -5,8 +5,8 @@ # TODO: Deprecate this in favor of # `keras.layers.LayerNormalization(rms_scaling=True)` once Keras 2 support is # removed.
-class MistralLayerNormalization(keras.layers.Layer): - """A normalization layer for Mistral that implements RMS normalization.""" +class MixtralLayerNormalization(keras.layers.Layer): + """A normalization layer for Mixtral that implements RMS normalization.""" def __init__(self, epsilon=1e-6, **kwargs): super().__init__(**kwargs) From b509c487906b3002dd24cc008c12dc4299185ee2 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Thu, 3 Apr 2025 13:41:47 +0530 Subject: [PATCH 03/11] mixtral wip --- .../src/models/mixtral/mixtral_backbone.py | 57 ++-- .../src/models/mixtral/mixtral_causal_lm.py | 316 ++++++++++++++++++ .../mixtral/mixtral_causal_lm_preprocessor.py | 76 +++++ .../src/models/mixtral/mixtral_decoder.py | 71 +++- .../src/models/mixtral/mixtral_tokenizer.py | 57 ++++ .../src/utils/transformers/convert_mixtral.py | 116 +++++++ .../convert_mixtral_checkpoints.py | 285 ++++++++++++++++ 7 files changed, 955 insertions(+), 23 deletions(-) create mode 100644 keras_hub/src/utils/transformers/convert_mixtral.py create mode 100644 tools/checkpoint_conversion/convert_mixtral_checkpoints.py diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index 09a5d38129..cd5df71696 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -6,30 +6,30 @@ ReversibleEmbedding, ) from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.mistral.mistral_layer_norm import ( - MistralLayerNormalization, +from keras_hub.src.models.mixtral.mixtral_layer_norm import ( + MixtralLayerNormalization, ) -from keras_hub.src.models.mistral.mistral_transformer_decoder import ( - MistralTransformerDecoder, +from keras_hub.src.models.mixtral.mixtral_decoder import ( + MixtralTransformerDecoder, ) -def _mistral_kernel_initializer(stddev=0.02): +def _mixtral_kernel_initializer(stddev=0.02): return keras.initializers.RandomNormal(stddev=stddev) -@keras_hub_export("keras_hub.models.MistralBackbone") -class MistralBackbone(Backbone): +@keras_hub_export("keras_hub.models.MixtralBackbone") +class MixtralBackbone(Backbone): """ - The Mistral Transformer core architecture with hyperparameters. + The Mixtral Transformer core architecture with hyperparameters. This network implements a Transformer-based decoder network, - Mistral, as described in - ["Mistral 7B"](https://arxiv.org/pdf/2310.06825.pdf). + Mixtral, as described in + ["Mixtral 7B"](https://arxiv.org/pdf/2310.06825.pdf). It includes the embedding lookups and transformer layers. The default constructor gives a fully customizable, randomly initialized - Mistral model with any number of layers, heads, and embedding + Mixtral model with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the `from_preset` constructor. @@ -50,7 +50,7 @@ class MistralBackbone(Backbone): calculation of roatary embedding. Defaults to `1.0`. layer_norm_epsilon (float, optional): Epsilon for the layer normalization layers in the transformer decoder. Defaults to `1e-6`. - sliding_window (int, optional): The sliding window for the mistral + sliding_window (int, optional): The sliding window for the mixtral attention layers. This controls the maximum cache size for the attention layers in each transformer decoder. 
Only `sliding_window` number of tokens are saved in the cache and used to generate the @@ -68,12 +68,12 @@ class MistralBackbone(Backbone): "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), } - # Pretrained Mistral decoder. - model = keras_hub.models.MistralBackbone.from_preset("mistral7b_base_en") + # Pretrained Mixtral decoder. + model = keras_hub.models.MixtralBackbone.from_preset("mixtral7b_base_en") model(input_data) - # Randomly initialized Mistral decoder with custom config. - model = keras_hub.models.MistralBackbone( + # Randomly initialized Mixtral decoder with custom config. + model = keras_hub.models.MixtralBackbone( vocabulary_size=10, hidden_dim=512, num_layers=2, @@ -96,6 +96,10 @@ def __init__( hidden_dim, intermediate_dim, num_key_value_heads, + num_experts, + top_k, + router_jitter_noise, + output_router_logits, rope_max_wavelength=10000, rope_scaling_factor=1.0, layer_norm_epsilon=1e-6, @@ -109,28 +113,32 @@ def __init__( input_dim=vocabulary_size, output_dim=hidden_dim, tie_weights=False, - embeddings_initializer=_mistral_kernel_initializer(stddev=0.01), + embeddings_initializer=_mixtral_kernel_initializer(stddev=0.01), dtype=dtype, name="token_embedding", ) self.transformer_layers = [] for i in range(num_layers): - layer = MistralTransformerDecoder( + layer = MixtralTransformerDecoder( intermediate_dim=intermediate_dim, num_query_heads=num_query_heads, num_key_value_heads=num_key_value_heads, + num_experts=num_experts, + top_k=top_k, + router_jitter_noise=router_jitter_noise, + output_router_logits=output_router_logits, rope_max_wavelength=rope_max_wavelength, rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, - kernel_initializer=_mistral_kernel_initializer(stddev=0.02), + kernel_initializer=_mixtral_kernel_initializer(stddev=0.02), sliding_window=sliding_window, dropout=dropout, dtype=dtype, name=f"transformer_layer_{i}", ) self.transformer_layers.append(layer) - self.layer_norm = MistralLayerNormalization( + self.layer_norm = MixtralLayerNormalization( epsilon=layer_norm_epsilon, dtype=dtype, name="sequence_output_layernorm", @@ -163,8 +171,12 @@ def __init__( self.num_query_heads = num_query_heads self.hidden_dim = hidden_dim self.intermediate_dim = intermediate_dim - self.rope_max_wavelength = rope_max_wavelength self.num_key_value_heads = num_key_value_heads + self.num_experts = num_experts + self.top_k = top_k + self.router_jitter_noise = router_jitter_noise + + self.rope_max_wavelength = rope_max_wavelength self.rope_scaling_factor = rope_scaling_factor self.sliding_window = sliding_window self.layer_norm_epsilon = layer_norm_epsilon @@ -179,6 +191,9 @@ def get_config(self): "num_query_heads": self.num_query_heads, "hidden_dim": self.hidden_dim, "intermediate_dim": self.intermediate_dim, + "num_experts": self.num_experts, + "top_k": self.top_k, + "router_jitter_noise": self.router_jitter_noise, "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm.py b/keras_hub/src/models/mixtral/mixtral_causal_lm.py index e69de29bb2..7d3e8ee678 100644 --- a/keras_hub/src/models/mixtral/mixtral_causal_lm.py +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm.py @@ -0,0 +1,316 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from 
keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( + MixtralCausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export("keras_hub.models.MixtralCausalLM") +class MixtralCausalLM(CausalLM): + """An end-to-end Mixtral model for causal language modeling. + + A causal language model (LM) predicts the next token based on previous + tokens. This task setup can be used to train the model unsupervised on + plain text input, or to autoregressively generate plain text similar to + the data used for training. This task can be used for pre-training or + fine-tuning a GPT-NeoX model, simply by calling `fit()`. + + This model has a `generate()` method, which generates text based on a + prompt. The generation strategy used is controlled by an additional + `sampler` argument on `compile()`. You can recompile the model with + different `keras_hub.samplers` objects to control the generation. By + default, `"top_k"` sampling will be used. + + Args: + backbone: A `keras_hub.models.MixtralBackbone` instance. + preprocessor: A `keras_hub.models.MixtralCausalLMPreprocessor` or + `None`. If `None`, this model will not apply preprocessing, and + inputs should be preprocessed before calling the model. + """ + + backbone_cls = MixtralBackbone + preprocessor_cls = MixtralCausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `MixtralCausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, max_length)`. + cache: a dense float Tensor, the cache of key and value. + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] 
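+ # `current_cache` holds this layer's key/value cache with shape + # `(batch_size, 2, max_length, num_key_value_heads, head_dim)`.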
+ x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_query_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, 0) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: List of id's of end token's to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + # Create and seed cache with a single forward pass. + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop_tokens locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. 
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `MixtralCausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the MixtralBackbone and isn't influential + on the computation of this function. If omitted, this function + uses `keras.ops.ones()` to create a tensor of the appropriate + shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. + layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Examples: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + mixtral_lm = keras_hub.models.MixtralCausalLM.from_preset( + "mixtral_7b_en" + ) + generations = mixtral_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = mixtral_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = mixtral_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. 
Please provide target " + "token ids via the target_ids parameter." + ) + + batch_shape = ops.shape(token_ids)[:2] + assert len(batch_shape) == 2 + + if layer_intercept_fn is None: + + def default_layer_intercept_fn(x, unused_i): + return x + + layer_intercept_fn = default_layer_intercept_fn + + token_embeddings = self.backbone.token_embedding(token_ids) + x = layer_intercept_fn(token_embeddings, -1) + + for i, transformer_layer in enumerate(self.backbone.transformer_layers): + x = transformer_layer(x, decoder_padding_mask=padding_mask) + x = layer_intercept_fn(x, i) + + x = self.backbone.layer_norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + + if scoring_mode == "logits": + return logits + + per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy( + from_logits=True, reduction="none" + ) + per_token_loss = per_token_loss_fn(target_ids, logits) + return per_token_loss diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py b/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py index e69de29bb2..7a3ad8f1fd 100644 --- a/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor.py @@ -0,0 +1,76 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer + + +@keras_hub_export("keras_hub.models.MixtralCausalLMPreprocessor") +class MixtralCausalLMPreprocessor(CausalLMPreprocessor): + """Mixtral Causal LM preprocessor. + + This preprocessing layer is meant for use with + `keras_hub.models.MixtralCausalLM`. By default, it will take in batches of + strings, and return outputs in a `(x, y, sample_weight)` format, where the + `y` label is the next token id in the `x` sequence. + + For use with generation, the layer also exposes two methods + `generate_preprocess()` and `generate_postprocess()`. When this preprocessor + is attached to a `keras_hub.models.MixtralCausalLM` instance, these methods + will be called implicitly in `generate()`. They can also be called + standalone (e.g. to precompute preprocessing inputs for generation in a + separate process). + + Args: + tokenizer: A `keras_hub.models.MixtralTokenizer` instance. + sequence_length: The length of the packed inputs. + add_start_token: If `True`, the preprocessor will prepend the tokenizer + start token to each input sequence. Default is `True`. + add_end_token: If `True`, the preprocessor will append the tokenizer + end token to each input sequence. Default is `False`. + + Call arguments: + x: A string, `tf.Tensor` or list of python strings. + y: Label data. Should always be `None` as the layer generates labels. + sample_weight: Label weights. Should always be `None` as the layer + generates label weights. + sequence_length: Pass to override the configured `sequence_length` of + the layer. + + Examples: + ```python + # Load the preprocessor from a preset. + preprocessor = keras_hub.models.MixtralCausalLMPreprocessor.from_preset( + "mixtral_base_en" + ) + + # Tokenize and pack a single sentence. + sentence = tf.constant("League of legends") + preprocessor(sentence) + # Same output. + preprocessor("League of legends") + + # Tokenize a batch of sentences. + sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) + preprocessor(sentences) + # Same output. 
+ preprocessor(["Taco tuesday", "Fish taco please!"]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "Avatar 2 is amazing!", + "Well, I am not sure.", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + backbone_cls = MixtralBackbone + tokenizer_cls = MixtralTokenizer diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 00fdb3499e..5d15b79b84 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -96,6 +96,7 @@ def __init__(self, self.intermediate_dim = intermediate_dim self.num_experts = num_experts self.top_k = top_k + self.router_jitter_noise = router_jitter_noise def build(self, decoder_sequence_shape): @@ -120,10 +121,73 @@ def build(self, decoder_sequence_shape): for expert in self.experts: expert.build(decoder_sequence_shape) - def call(self): - pass + def call(self, hidden_states, training=False): + + batch_size, seq_len, hidden_dim = hidden_states.shape + + # Jitter noise augmentation (training only) + if training and self.router_jitter_noise > 0: + random_factors = ops.random.uniform( + shape=ops.shape(hidden_states), + minval=1.0 - self.router_jitter_noise, + maxval=1.0 + self.router_jitter_noise, + dtype=hidden_states.dtype, + ) + hidden_states = hidden_states * random_factors + + hidden_states_2d = ops.reshape(hidden_states, (-1, hidden_dim)) + + router_logits = self._sparse_feedforward_gate_dense(hidden_states_2d) + routing_weights = ops.softmax(router_logits, axis=1) + + routing_weights, selected_experts = ops.top_k( + routing_weights, + k=self.top_k + ) + sum_topk = ops.sum(routing_weights, axis=-1, keepdims=True) + routing_weights = routing_weights / sum_topk + + routing_weights = ops.cast(routing_weights, hidden_states.dtype) + + # Prepare final hidden states + final_hidden_states = ops.zeros( + (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype + ) + + expert_mask = ops.one_hot(selected_experts, num_classes=self.num_experts) + expert_mask = ops.transpose(expert_mask, axes=[2, 1, 0]) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + + idx, top_x = ops.where(expert_mask[expert_idx]) + + if ops.shape(top_x)[0] == 0: + continue + + # Gather hidden states belonging to this expert + current_state = ops.take(hidden_states_2d, top_x, axis=0) + expert_output = expert_layer(current_state) + + # Multiply by routing weights + # routing_weights is shape (batch_size*seq_len, top_k) + # We want routing_weights[top_x, idx] + factor = routing_weights[top_x, idx] + factor = ops.expand_dims(factor, axis=-1) # shape = (n_tokens, 1) + current_hidden_states = expert_output * factor + + existing_values = ops.take(final_hidden_states, top_x, axis=0) + updated_values = existing_values + current_hidden_states + final_hidden_states = ops.scatter_update( + final_hidden_states, + top_x[:, None], + updated_values + ) + final_hidden_states = ops.reshape( + final_hidden_states, (batch_size, seq_len, hidden_dim)) + return final_hidden_states, router_logits class MixtralTransformerDecoder(keras.layers.Layer): @@ -327,6 +391,9 @@ def get_config(self): "rope_max_wavelength": 
self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, + "num_experts": self.num_experts, + "top_k": self.top_k, + "router_jitter_noise": self.router_jitter_noise, "sliding_window": self.sliding_window, "activation": keras.activations.serialize(self.activation), "layer_norm_epsilon": self.layer_norm_epsilon, diff --git a/keras_hub/src/models/mixtral/mixtral_tokenizer.py b/keras_hub/src/models/mixtral/mixtral_tokenizer.py index e69de29bb2..03c35dbdca 100644 --- a/keras_hub/src/models/mixtral/mixtral_tokenizer.py +++ b/keras_hub/src/models/mixtral/mixtral_tokenizer.py @@ -0,0 +1,57 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( + SentencePieceTokenizer, +) + + +@keras_hub_export( + [ + "keras_hub.tokenizers.MixtralTokenizer", + "keras_hub.models.MixtralTokenizer", + ] +) +class MixtralTokenizer(SentencePieceTokenizer): + """Mixtral tokenizer layer based on SentencePiece. + + This tokenizer class will tokenize raw strings into integer sequences and + is based on `keras_hub.tokenizers.SentencePieceTokenizer`. Unlike the + underlying tokenizer, it will check for all special tokens needed by + Mixtral models and provides a `from_preset()` method to automatically + download a matching vocabulary for a Mixtral preset. + + If input is a batch of strings (rank > 0), the layer will output a + `tf.RaggedTensor` where the last dimension of the output is ragged. + + If input is a scalar string (rank == 0), the layer will output a dense + `tf.Tensor` with static shape `[None]`. + + Args: + proto: Either a `string` path to a SentencePiece proto file, or a + `bytes` object with a serialized SentencePiece proto. See the + [SentencePiece repository](https://github.com/google/sentencepiece) + for more details on the format. + + Examples: + ```python + # Unbatched input. + tokenizer = keras_hub.models.MixtralTokenizer.from_preset( + "mixtral_7b_en", + ) + tokenizer("The quick brown fox jumped.") + + # Batched input. + tokenizer(["The quick brown fox jumped.", "The fox slept."]) + + # Detokenization. 
+ tokenizer.detokenize(tokenizer("The quick brown fox jumped.")) + ``` + """ + + backbone_cls = MixtralBackbone + + def __init__(self, proto, **kwargs): + self._add_special_token("<s>", "start_token") + self._add_special_token("</s>", "end_token") + self.pad_token_id = 0 + super().__init__(proto=proto, **kwargs) diff --git a/keras_hub/src/utils/transformers/convert_mixtral.py b/keras_hub/src/utils/transformers/convert_mixtral.py new file mode 100644 index 0000000000..aa0489f89a --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_mixtral.py @@ -0,0 +1,116 @@ +import numpy as np + +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.utils.preset_utils import get_file + +backbone_cls = MixtralBackbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_query_heads": transformers_config["num_attention_heads"], + "hidden_dim": transformers_config["hidden_size"], + "intermediate_dim": transformers_config["intermediate_size"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "rope_max_wavelength": transformers_config["rope_theta"], + "layer_norm_epsilon": transformers_config["rms_norm_eps"], + "sliding_window": transformers_config["sliding_window"], + } + + +def convert_weights(backbone, loader, transformers_config): + # Embeddings + loader.port_weight( + keras_variable=backbone.token_embedding.embeddings, + hf_weight_key="model.embed_tokens.weight", + hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), + ) + loader.port_weight( + keras_variable=backbone.token_embedding.reverse_embeddings, + hf_weight_key="lm_head.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor.astype(np.float16), axes=(1, 0) + ), + ) + + # Attention blocks + for index in range(backbone.num_layers): + decoder_layer = backbone.transformer_layers[index] + + # Norm layers + loader.port_weight( + keras_variable=decoder_layer._self_attention_layernorm.scale, + hf_weight_key=f"model.layers.{index}.input_layernorm.weight", + hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight", + hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), + ) + + # Attention layers + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, + hf_weight_key=f"model.layers.{index}.self_attn.q_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.reshape( + np.transpose(hf_tensor.astype(np.float16)), keras_shape + ), + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, + hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.reshape( + np.transpose(hf_tensor.astype(np.float16)), keras_shape + ), + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, + hf_weight_key=f"model.layers.{index}.self_attn.v_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.reshape( + np.transpose(hf_tensor.astype(np.float16)), keras_shape + ), + ) + loader.port_weight( + keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, + hf_weight_key=f"model.layers.{index}.self_attn.o_proj.weight", + hook_fn=lambda hf_tensor, keras_shape: np.reshape( + np.transpose(hf_tensor.astype(np.float16)),
keras_shape + ), + ) + + # MLP layers + loader.port_weight( + keras_variable=decoder_layer._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{index}.mlp.gate_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor.astype(np.float16), axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{index}.mlp.up_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor.astype(np.float16), axes=(1, 0) + ), + ) + loader.port_weight( + keras_variable=decoder_layer._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{index}.mlp.down_proj.weight", + hook_fn=lambda hf_tensor, _: np.transpose( + hf_tensor.astype(np.float16), axes=(1, 0) + ), + ) + + # Normalization + loader.port_weight( + keras_variable=backbone.layer_norm.scale, + hf_weight_key="model.norm.weight", + hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), + ) + + +def convert_tokenizer(cls, preset, **kwargs): + return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py new file mode 100644 index 0000000000..071c5f3f6a --- /dev/null +++ b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py @@ -0,0 +1,285 @@ +import gc +import os +import shutil +import tempfile +import traceback + +import numpy as np +import requests +from absl import app +from absl import flags +from keras import ops +from transformers import AutoTokenizer +from transformers import MixtralForCausalLM + +from keras_hub.models import MixtralBackbone +from keras_hub.models import MixtralCausalLMPreprocessor +from keras_hub.models import MixtralTokenizer +from keras_hub.utils.preset_utils import save_to_preset + +PRESET_MAP = { +} + +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) + + +def convert_checkpoints(keras_hub_model, hf_model): + config = hf_model.config + + keras_hub_model.token_embedding.embeddings.assign( + hf_model.model.embed_tokens.weight.detach().cpu().numpy() + ) + + for i in range(keras_hub_model.num_layers): + keras_hub_model.transformer_layers[ + i + ]._self_attention_layer._key_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.k_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._self_attention_layer._query_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.q_proj.weight.T.reshape( + config.hidden_size, + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._self_attention_layer._value_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.v_proj.weight.T.reshape( + config.hidden_size, + config.num_key_value_heads, + config.hidden_size // config.num_attention_heads, + ) + .detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._self_attention_layer._output_dense.set_weights( + [ + hf_model.model.layers[i] + .self_attn.o_proj.weight.T.reshape( + config.num_attention_heads, + config.hidden_size // config.num_attention_heads, + config.hidden_size, + ) + .detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._self_attention_layernorm.set_weights( + [ + 
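Both conversion paths around this point do the same shape bookkeeping, which is easy to get wrong: Hugging Face stores each projection as a `torch.nn.Linear` weight of shape `[out_features, in_features]`, while the KerasHub attention layer keeps per-head kernels of shape `[hidden_dim, num_heads, head_dim]`. A rough NumPy sketch of that mapping, with dimension values taken from the 8x7B architecture dump added later in this series (the `port_weight`/`set_weights` plumbing itself is omitted):

```python
import numpy as np

hidden_size = 4096  # illustrative Mixtral-8x7B values
num_attention_heads = 32
num_key_value_heads = 8
head_dim = hidden_size // num_attention_heads  # 128

# q_proj: [4096, 4096] -> [4096, 32, 128]
hf_q = np.zeros((num_attention_heads * head_dim, hidden_size), dtype=np.float16)
keras_q = np.reshape(hf_q.T, (hidden_size, num_attention_heads, head_dim))

# k_proj / v_proj: [1024, 4096] -> [4096, 8, 128] (grouped-query attention)
hf_k = np.zeros((num_key_value_heads * head_dim, hidden_size), dtype=np.float16)
keras_k = np.reshape(hf_k.T, (hidden_size, num_key_value_heads, head_dim))

assert keras_q.shape == (4096, 32, 128)
assert keras_k.shape == (4096, 8, 128)
```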
hf_model.model.layers[i] + .input_layernorm.weight.detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._feedforward_intermediate_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.up_proj.weight.T.detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._feedforward_output_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.down_proj.weight.T.detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._feedforward_gate_dense.set_weights( + [ + hf_model.model.layers[i] + .mlp.gate_proj.weight.T.detach() + .cpu() + .numpy() + ] + ) + keras_hub_model.transformer_layers[ + i + ]._feedforward_layernorm.set_weights( + [ + hf_model.model.layers[i] + .post_attention_layernorm.weight.detach() + .cpu() + .numpy() + ] + ) + + keras_hub_model.layer_norm.set_weights( + [hf_model.model.norm.weight.detach().cpu().numpy()] + ) + keras_hub_model.token_embedding.reverse_embeddings.assign( + hf_model.lm_head.weight.T.detach().cpu().numpy() + ) + + +def test_model( + keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer +): + # First, test that the number of parameters match + keras_hub_params = keras_hub_model.count_params() + hf_params = hf_model.num_parameters() + assert keras_hub_params == hf_params + + # Test the outputs of both the models + hf_outputs = hf_model( + **hf_model_tokenizer(["What is Keras?"], return_tensors="pt") + ) + hf_output_logits = hf_outputs.logits.detach().cpu().numpy() + + keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) + keras_hub_output = keras_hub_model( + keras_hub_preprocessor(["What is Keras?"], sequence_length=6)[0] + ) + keras_hub_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) + + # High tolerence since bfloat16 is used as the default dtype for Mixtral + try: + np.testing.assert_allclose( + keras_hub_logits, hf_output_logits, atol=1e-4 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=6 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def main(_): + # === Get the preset name === + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. 
Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + + # === Create the temporary save directories === + temp_dir = tempfile.mkdtemp() + + try: + # === Load the Huggingface model === + hf_model = MixtralForCausalLM.from_pretrained(hf_preset) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) + hf_model.eval() + print("\n-> Huggingface model and tokenizer loaded") + + # === Load the KerasHub model === + backbone_kwargs = dict( + vocabulary_size=hf_model.config.vocab_size, + hidden_dim=hf_model.config.hidden_size, + num_layers=hf_model.config.num_hidden_layers, + num_query_heads=hf_model.config.num_attention_heads, + num_key_value_heads=hf_model.config.num_key_value_heads, + intermediate_dim=hf_model.config.intermediate_size, + sliding_window=hf_model.config.sliding_window, + layer_norm_epsilon=hf_model.config.rms_norm_eps, + rope_max_wavelength=hf_model.config.rope_theta, + dtype="float32", + ) + keras_hub_model = MixtralBackbone(**backbone_kwargs) + + # === Download the tokenizer from Huggingface model card === + spm_path = ( + f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" + ) + response = requests.get(spm_path) + if not response.ok: + raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") + tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") + with open(tokenizer_path, "wb") as tokenizer_file: + tokenizer_file.write(response.content) + keras_hub_tokenizer = MixtralTokenizer(tokenizer_path) + print("\n-> Keras 3 model and tokenizer loaded.") + + # === Port the weights === + convert_checkpoints(keras_hub_model, hf_model) + print("\n-> Weight transfer done.") + + # === Check that the models and tokenizers outputs match === + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) + test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) + print("\n-> Tests passed!") + + # === Save the model weights in float32 format === + keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) + print("\n-> Saved the model weights in float32") + + del keras_hub_model, hf_model + gc.collect() + + # === Save the weights again in float16 === + backbone_kwargs["dtype"] = "float16" + keras_hub_model = MixtralBackbone(**backbone_kwargs) + keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) + save_to_preset(keras_hub_model, preset) + print("\n-> Saved the model preset in float16") + + # === Save the tokenizer === + save_to_preset( + keras_hub_tokenizer, preset, config_filename="tokenizer.json" + ) + print("\n-> Saved the tokenizer") + finally: + shutil.rmtree(temp_dir) + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From b0160cb6eecce76ea432eb4fbda66bccd8a7f949 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sun, 6 Apr 2025 17:55:27 +0000 Subject: [PATCH 04/11] checkpoint conversion wip --- keras_hub/api/models/__init__.py | 6 + keras_hub/api/tokenizers/__init__.py | 1 + .../src/models/mixtral/mixtral_backbone.py | 6 +- .../src/models/mixtral/mixtral_decoder.py | 67 ++- .../src/utils/transformers/convert_mixtral.py | 4 + .../convert_mixtral_checkpoints.py | 553 +++++++++--------- 6 files changed, 331 insertions(+), 306 deletions(-) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 323a638121..fbaf3cd2e6 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -225,6 +225,12 @@ from keras_hub.src.models.mit.mit_image_classifier_preprocessor import ( 
MiTImageClassifierPreprocessor, ) +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM +from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( + MixtralCausalLMPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( MobileNetImageClassifier, diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 3db0f643e5..f774d87696 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -24,6 +24,7 @@ from keras_hub.src.models.llama.llama_tokenizer import LlamaTokenizer from keras_hub.src.models.llama3.llama3_tokenizer import Llama3Tokenizer from keras_hub.src.models.mistral.mistral_tokenizer import MistralTokenizer +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import ( PaliGemmaTokenizer, diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index cd5df71696..50afd76d41 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -6,12 +6,12 @@ ReversibleEmbedding, ) from keras_hub.src.models.backbone import Backbone -from keras_hub.src.models.mixtral.mixtral_layer_norm import ( - MixtralLayerNormalization, -) from keras_hub.src.models.mixtral.mixtral_decoder import ( MixtralTransformerDecoder, ) +from keras_hub.src.models.mixtral.mixtral_layer_norm import ( + MixtralLayerNormalization, +) def _mixtral_kernel_initializer(stddev=0.02): diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 5d15b79b84..29f50e2337 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -1,10 +1,21 @@ import keras from keras import ops -from keras_hub.src.layers.modeling.transformer_layer_utils import compute_causal_mask, merge_padding_and_attention_mask -from keras_hub.src.models.mixtral.mixtral_attention import CachedMixtralAttention -from keras_hub.src.models.mixtral.mixtral_layer_norm import MixtralLayerNormalization + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.mixtral.mixtral_attention import ( + CachedMixtralAttention, +) +from keras_hub.src.models.mixtral.mixtral_layer_norm import ( + MixtralLayerNormalization, +) from keras_hub.src.utils.keras_utils import clone_initializer + class MixtralMoeMLP(keras.layers.Layer): def __init__( self, @@ -80,17 +91,16 @@ def call(self, x): return x - class MixtralSparseMoeBlock(keras.layers.Layer): - - - def __init__(self, - hidden_dim, - intermediate_dim, - num_experts, - top_k, - router_jitter_noise, - **kwargs): + def __init__( + self, + hidden_dim, + intermediate_dim, + num_experts, + top_k, + router_jitter_noise, + **kwargs, + ): super().__init__(**kwargs) self.hidden_dim = hidden_dim self.intermediate_dim = intermediate_dim @@ -99,7 +109,6 @@ def __init__(self, self.router_jitter_noise = router_jitter_noise def build(self, 
decoder_sequence_shape): - self._sparse_feedforward_gate_dense = keras.layers.Dense( self.num_experts, kernel_initializer=clone_initializer(self.kernel_initializer), @@ -122,7 +131,6 @@ def build(self, decoder_sequence_shape): expert.build(decoder_sequence_shape) def call(self, hidden_states, training=False): - batch_size, seq_len, hidden_dim = hidden_states.shape # Jitter noise augmentation (training only) @@ -141,8 +149,7 @@ def call(self, hidden_states, training=False): routing_weights = ops.softmax(router_logits, axis=1) routing_weights, selected_experts = ops.top_k( - routing_weights, - k=self.top_k + routing_weights, k=self.top_k ) sum_topk = ops.sum(routing_weights, axis=-1, keepdims=True) routing_weights = routing_weights / sum_topk @@ -154,7 +161,9 @@ def call(self, hidden_states, training=False): (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype ) - expert_mask = ops.one_hot(selected_experts, num_classes=self.num_experts) + expert_mask = ops.one_hot( + selected_experts, num_classes=self.num_experts + ) expert_mask = ops.transpose(expert_mask, axes=[2, 1, 0]) for expert_idx in range(self.num_experts): @@ -179,19 +188,17 @@ def call(self, hidden_states, training=False): existing_values = ops.take(final_hidden_states, top_x, axis=0) updated_values = existing_values + current_hidden_states final_hidden_states = ops.scatter_update( - final_hidden_states, - top_x[:, None], - updated_values + final_hidden_states, top_x[:, None], updated_values ) final_hidden_states = ops.reshape( - final_hidden_states, (batch_size, seq_len, hidden_dim)) + final_hidden_states, (batch_size, seq_len, hidden_dim) + ) return final_hidden_states, router_logits - -class MixtralTransformerDecoder(keras.layers.Layer): +class MixtralTransformerDecoder(keras.layers.Layer): def __init__( self, intermediate_dim, @@ -268,7 +275,7 @@ def build(self, decoder_sequence_shape): intermediate_dim=self.intermediate_dim, num_experts=self.num_experts, top_k=self.top_k, - router_jitter_noise=self.router_jitter_noise + router_jitter_noise=self.router_jitter_noise, ) self._sparse_moe_block.build(decoder_sequence_shape) @@ -281,15 +288,15 @@ def build(self, decoder_sequence_shape): self.built = True - def call(self, + def call( + self, decoder_sequence, decoder_padding_mask=None, decoder_attention_mask=None, self_attention_cache=None, self_attention_cache_update_index=None, training=None, - - ): + ): self_attention_mask = self._compute_self_attention_mask( decoder_sequence=decoder_sequence, decoder_padding_mask=decoder_padding_mask, @@ -323,12 +330,12 @@ def call(self, decoder_output = x + residual output = (decoder_output,) - + if self_attention_cache is not None: output += (self_attention_cache,) if self.output_router_logits: - output += (router_logits, ) + output += (router_logits,) return output diff --git a/keras_hub/src/utils/transformers/convert_mixtral.py b/keras_hub/src/utils/transformers/convert_mixtral.py index aa0489f89a..8696b4db43 100644 --- a/keras_hub/src/utils/transformers/convert_mixtral.py +++ b/keras_hub/src/utils/transformers/convert_mixtral.py @@ -14,9 +14,13 @@ def convert_backbone_config(transformers_config): "hidden_dim": transformers_config["hidden_size"], "intermediate_dim": transformers_config["intermediate_size"], "num_key_value_heads": transformers_config["num_key_value_heads"], + "num_experts": transformers_config['num_local_experts'], + "top_k": transformers_config['num_experts_per_tok'], "rope_max_wavelength": transformers_config["rope_theta"], "layer_norm_epsilon": 
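The routing logic in `MixtralSparseMoeBlock.call` above is dense, so here is a small self-contained NumPy rendering of the same top-k idea for a single token with 8 experts and top-2 routing. The numbers are made up; the real layer does this with `keras.ops` over `[batch * seq, hidden]` activations:

```python
import numpy as np

router_logits = np.array([1.2, 0.3, -0.5, 2.0, 0.1, -1.0, 0.7, 0.4])
probs = np.exp(router_logits) / np.exp(router_logits).sum()  # softmax over 8 experts

top_k = 2
top_idx = np.argsort(probs)[-top_k:][::-1]  # experts 3 and 0 for these logits
top_p = probs[top_idx]
top_p = top_p / top_p.sum()  # renormalize so the two routing weights sum to 1

# The token's output is the weighted sum of just these two experts' MLP
# outputs; the other six experts are skipped entirely for this token.
```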
transformers_config["rms_norm_eps"], "sliding_window": transformers_config["sliding_window"], + "router_jitter_noise": transformers_config["router_jitter_noise"], + "output_router_logits": transformers_config['output_router_logits'], } diff --git a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py index 071c5f3f6a..63e0c8ac1d 100644 --- a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py @@ -1,285 +1,292 @@ -import gc -import os -import shutil -import tempfile -import traceback - -import numpy as np -import requests -from absl import app -from absl import flags -from keras import ops -from transformers import AutoTokenizer +# import gc +# import os +# import shutil +# import tempfile +# import traceback + +# import numpy as np +# import requests +# from absl import app +# from absl import flags +# from keras import ops +# from transformers import AutoTokenizer +import torch from transformers import MixtralForCausalLM -from keras_hub.models import MixtralBackbone -from keras_hub.models import MixtralCausalLMPreprocessor -from keras_hub.models import MixtralTokenizer -from keras_hub.utils.preset_utils import save_to_preset +# from keras_hub.models import MixtralBackbone +# from keras_hub.models import MixtralCausalLMPreprocessor +# from keras_hub.models import MixtralTokenizer +# from keras_hub.utils.preset_utils import save_to_preset -PRESET_MAP = { -} +mixtral_lm = MixtralForCausalLM.from_pretrained( + "mistralai/Mixtral-8x7B-v0.1", + torch_dtype=torch.bfloat16 + ) -FLAGS = flags.FLAGS -flags.DEFINE_string( - "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" -) +print(mixtral_lm) +# PRESET_MAP = {} -def convert_checkpoints(keras_hub_model, hf_model): - config = hf_model.config +# FLAGS = flags.FLAGS +# flags.DEFINE_string( +# "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +# ) - keras_hub_model.token_embedding.embeddings.assign( - hf_model.model.embed_tokens.weight.detach().cpu().numpy() - ) - for i in range(keras_hub_model.num_layers): - keras_hub_model.transformer_layers[ - i - ]._self_attention_layer._key_dense.set_weights( - [ - hf_model.model.layers[i] - .self_attn.k_proj.weight.T.reshape( - config.hidden_size, - config.num_key_value_heads, - config.hidden_size // config.num_attention_heads, - ) - .detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._self_attention_layer._query_dense.set_weights( - [ - hf_model.model.layers[i] - .self_attn.q_proj.weight.T.reshape( - config.hidden_size, - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - ) - .detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._self_attention_layer._value_dense.set_weights( - [ - hf_model.model.layers[i] - .self_attn.v_proj.weight.T.reshape( - config.hidden_size, - config.num_key_value_heads, - config.hidden_size // config.num_attention_heads, - ) - .detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._self_attention_layer._output_dense.set_weights( - [ - hf_model.model.layers[i] - .self_attn.o_proj.weight.T.reshape( - config.num_attention_heads, - config.hidden_size // config.num_attention_heads, - config.hidden_size, - ) - .detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._self_attention_layernorm.set_weights( - [ - hf_model.model.layers[i] - .input_layernorm.weight.detach() - .cpu() - 
.numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._feedforward_intermediate_dense.set_weights( - [ - hf_model.model.layers[i] - .mlp.up_proj.weight.T.detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._feedforward_output_dense.set_weights( - [ - hf_model.model.layers[i] - .mlp.down_proj.weight.T.detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._feedforward_gate_dense.set_weights( - [ - hf_model.model.layers[i] - .mlp.gate_proj.weight.T.detach() - .cpu() - .numpy() - ] - ) - keras_hub_model.transformer_layers[ - i - ]._feedforward_layernorm.set_weights( - [ - hf_model.model.layers[i] - .post_attention_layernorm.weight.detach() - .cpu() - .numpy() - ] - ) - - keras_hub_model.layer_norm.set_weights( - [hf_model.model.norm.weight.detach().cpu().numpy()] - ) - keras_hub_model.token_embedding.reverse_embeddings.assign( - hf_model.lm_head.weight.T.detach().cpu().numpy() - ) +# def convert_checkpoints(keras_hub_model, hf_model): +# config = hf_model.config +# keras_hub_model.token_embedding.embeddings.assign( +# hf_model.model.embed_tokens.weight.detach().cpu().numpy() +# ) -def test_model( - keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer -): - # First, test that the number of parameters match - keras_hub_params = keras_hub_model.count_params() - hf_params = hf_model.num_parameters() - assert keras_hub_params == hf_params +# for i in range(keras_hub_model.num_layers): +# keras_hub_model.transformer_layers[ +# i +# ]._self_attention_layer._key_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .self_attn.k_proj.weight.T.reshape( +# config.hidden_size, +# config.num_key_value_heads, +# config.hidden_size // config.num_attention_heads, +# ) +# .detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._self_attention_layer._query_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .self_attn.q_proj.weight.T.reshape( +# config.hidden_size, +# config.num_attention_heads, +# config.hidden_size // config.num_attention_heads, +# ) +# .detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._self_attention_layer._value_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .self_attn.v_proj.weight.T.reshape( +# config.hidden_size, +# config.num_key_value_heads, +# config.hidden_size // config.num_attention_heads, +# ) +# .detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._self_attention_layer._output_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .self_attn.o_proj.weight.T.reshape( +# config.num_attention_heads, +# config.hidden_size // config.num_attention_heads, +# config.hidden_size, +# ) +# .detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._self_attention_layernorm.set_weights( +# [ +# hf_model.model.layers[i] +# .input_layernorm.weight.detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._feedforward_intermediate_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .mlp.up_proj.weight.T.detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._feedforward_output_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .mlp.down_proj.weight.T.detach() +# .cpu() +# .numpy() +# ] +# ) +# keras_hub_model.transformer_layers[ +# i +# ]._feedforward_gate_dense.set_weights( +# [ +# hf_model.model.layers[i] +# .mlp.gate_proj.weight.T.detach() +# .cpu() +# .numpy() +# ] +# ) +# 
keras_hub_model.transformer_layers[ +# i +# ]._feedforward_layernorm.set_weights( +# [ +# hf_model.model.layers[i] +# .post_attention_layernorm.weight.detach() +# .cpu() +# .numpy() +# ] +# ) - # Test the outputs of both the models - hf_outputs = hf_model( - **hf_model_tokenizer(["What is Keras?"], return_tensors="pt") - ) - hf_output_logits = hf_outputs.logits.detach().cpu().numpy() +# keras_hub_model.layer_norm.set_weights( +# [hf_model.model.norm.weight.detach().cpu().numpy()] +# ) +# keras_hub_model.token_embedding.reverse_embeddings.assign( +# hf_model.lm_head.weight.T.detach().cpu().numpy() +# ) - keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) - keras_hub_output = keras_hub_model( - keras_hub_preprocessor(["What is Keras?"], sequence_length=6)[0] - ) - keras_hub_logits = keras_hub_model.token_embedding( - keras_hub_output, reverse=True - ) - keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) - - # High tolerence since bfloat16 is used as the default dtype for Mixtral - try: - np.testing.assert_allclose( - keras_hub_logits, hf_output_logits, atol=1e-4 - ) - except AssertionError as err: - print("\n") - print(traceback.format_exc()) - print(err.args[0]) - print("\n") - - -def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): - hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") - hf_output = hf_output["input_ids"].detach().cpu().numpy() - keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) - keras_hub_output = keras_hub_preprocessor( - ["What is Keras?"], sequence_length=6 - ) - keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) - - np.testing.assert_equal(keras_hub_output, hf_output) - - -def main(_): - # === Get the preset name === - if FLAGS.preset not in PRESET_MAP.keys(): - raise ValueError( - f"Invalid preset {FLAGS.preset}. 
Must be one " - f"of {','.join(PRESET_MAP.keys())}" - ) - preset = FLAGS.preset - hf_preset = PRESET_MAP[preset] - - # === Create the temporary save directories === - temp_dir = tempfile.mkdtemp() - - try: - # === Load the Huggingface model === - hf_model = MixtralForCausalLM.from_pretrained(hf_preset) - hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) - hf_model.eval() - print("\n-> Huggingface model and tokenizer loaded") - - # === Load the KerasHub model === - backbone_kwargs = dict( - vocabulary_size=hf_model.config.vocab_size, - hidden_dim=hf_model.config.hidden_size, - num_layers=hf_model.config.num_hidden_layers, - num_query_heads=hf_model.config.num_attention_heads, - num_key_value_heads=hf_model.config.num_key_value_heads, - intermediate_dim=hf_model.config.intermediate_size, - sliding_window=hf_model.config.sliding_window, - layer_norm_epsilon=hf_model.config.rms_norm_eps, - rope_max_wavelength=hf_model.config.rope_theta, - dtype="float32", - ) - keras_hub_model = MixtralBackbone(**backbone_kwargs) - - # === Download the tokenizer from Huggingface model card === - spm_path = ( - f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" - ) - response = requests.get(spm_path) - if not response.ok: - raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") - tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") - with open(tokenizer_path, "wb") as tokenizer_file: - tokenizer_file.write(response.content) - keras_hub_tokenizer = MixtralTokenizer(tokenizer_path) - print("\n-> Keras 3 model and tokenizer loaded.") - - # === Port the weights === - convert_checkpoints(keras_hub_model, hf_model) - print("\n-> Weight transfer done.") - - # === Check that the models and tokenizers outputs match === - test_tokenizer(keras_hub_tokenizer, hf_tokenizer) - test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) - print("\n-> Tests passed!") - - # === Save the model weights in float32 format === - keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) - print("\n-> Saved the model weights in float32") - - del keras_hub_model, hf_model - gc.collect() - - # === Save the weights again in float16 === - backbone_kwargs["dtype"] = "float16" - keras_hub_model = MixtralBackbone(**backbone_kwargs) - keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) - save_to_preset(keras_hub_model, preset) - print("\n-> Saved the model preset in float16") - - # === Save the tokenizer === - save_to_preset( - keras_hub_tokenizer, preset, config_filename="tokenizer.json" - ) - print("\n-> Saved the tokenizer") - finally: - shutil.rmtree(temp_dir) - - -if __name__ == "__main__": - flags.mark_flag_as_required("preset") - app.run(main) + +# def test_model( +# keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer +# ): +# # First, test that the number of parameters match +# keras_hub_params = keras_hub_model.count_params() +# hf_params = hf_model.num_parameters() +# assert keras_hub_params == hf_params + +# # Test the outputs of both the models +# hf_outputs = hf_model( +# **hf_model_tokenizer(["What is Keras?"], return_tensors="pt") +# ) +# hf_output_logits = hf_outputs.logits.detach().cpu().numpy() + +# keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) +# keras_hub_output = keras_hub_model( +# keras_hub_preprocessor(["What is Keras?"], sequence_length=6)[0] +# ) +# keras_hub_logits = keras_hub_model.token_embedding( +# keras_hub_output, reverse=True +# ) +# keras_hub_logits = 
ops.convert_to_numpy(keras_hub_logits) + +# # High tolerence since bfloat16 is used as the default dtype for Mixtral +# try: +# np.testing.assert_allclose( +# keras_hub_logits, hf_output_logits, atol=1e-4 +# ) +# except AssertionError as err: +# print("\n") +# print(traceback.format_exc()) +# print(err.args[0]) +# print("\n") + + +# def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): +# hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") +# hf_output = hf_output["input_ids"].detach().cpu().numpy() +# keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) +# keras_hub_output = keras_hub_preprocessor( +# ["What is Keras?"], sequence_length=6 +# ) +# keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + +# np.testing.assert_equal(keras_hub_output, hf_output) + + +# def main(_): +# # === Get the preset name === +# if FLAGS.preset not in PRESET_MAP.keys(): +# raise ValueError( +# f"Invalid preset {FLAGS.preset}. Must be one " +# f"of {','.join(PRESET_MAP.keys())}" +# ) +# preset = FLAGS.preset +# hf_preset = PRESET_MAP[preset] + +# # === Create the temporary save directories === +# temp_dir = tempfile.mkdtemp() + +# try: +# # === Load the Huggingface model === +# hf_model = MixtralForCausalLM.from_pretrained(hf_preset) +# hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) +# hf_model.eval() +# print("\n-> Huggingface model and tokenizer loaded") + +# # === Load the KerasHub model === +# backbone_kwargs = dict( +# vocabulary_size=hf_model.config.vocab_size, +# hidden_dim=hf_model.config.hidden_size, +# num_layers=hf_model.config.num_hidden_layers, +# num_query_heads=hf_model.config.num_attention_heads, +# num_key_value_heads=hf_model.config.num_key_value_heads, +# intermediate_dim=hf_model.config.intermediate_size, +# sliding_window=hf_model.config.sliding_window, +# layer_norm_epsilon=hf_model.config.rms_norm_eps, +# rope_max_wavelength=hf_model.config.rope_theta, +# dtype="float32", +# ) +# keras_hub_model = MixtralBackbone(**backbone_kwargs) + +# # === Download the tokenizer from Huggingface model card === +# spm_path = ( +# f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" +# ) +# response = requests.get(spm_path) +# if not response.ok: +# raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") +# tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") +# with open(tokenizer_path, "wb") as tokenizer_file: +# tokenizer_file.write(response.content) +# keras_hub_tokenizer = MixtralTokenizer(tokenizer_path) +# print("\n-> Keras 3 model and tokenizer loaded.") + +# # === Port the weights === +# convert_checkpoints(keras_hub_model, hf_model) +# print("\n-> Weight transfer done.") + +# # === Check that the models and tokenizers outputs match === +# test_tokenizer(keras_hub_tokenizer, hf_tokenizer) +# test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) +# print("\n-> Tests passed!") + +# # === Save the model weights in float32 format === +# keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) +# print("\n-> Saved the model weights in float32") + +# del keras_hub_model, hf_model +# gc.collect() + +# # === Save the weights again in float16 === +# backbone_kwargs["dtype"] = "float16" +# keras_hub_model = MixtralBackbone(**backbone_kwargs) +# keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) +# save_to_preset(keras_hub_model, preset) +# print("\n-> Saved the model preset in float16") + +# # === Save the tokenizer === +# save_to_preset( +# 
keras_hub_tokenizer, preset, config_filename="tokenizer.json" +# ) +# print("\n-> Saved the tokenizer") +# finally: +# shutil.rmtree(temp_dir) + + +# if __name__ == "__main__": +# flags.mark_flag_as_required("preset") +# app.run(main) From b9bc2e3008f55c067b88b1b06b0a710ab9243699 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 9 Apr 2025 18:31:07 +0000 Subject: [PATCH 05/11] mixtral weight matching complete --- keras_hub/src/models/mixtral/README.md | 36 ++ .../src/models/mixtral/mixtral_backbone.py | 2 +- .../src/models/mixtral/mixtral_decoder.py | 7 +- .../src/utils/transformers/convert_mixtral.py | 51 +-- .../src/utils/transformers/preset_loader.py | 4 +- .../convert_mixtral_checkpoints.py | 390 ++++++------------ 6 files changed, 190 insertions(+), 300 deletions(-) create mode 100644 keras_hub/src/models/mixtral/README.md diff --git a/keras_hub/src/models/mixtral/README.md b/keras_hub/src/models/mixtral/README.md new file mode 100644 index 0000000000..1516ba9226 --- /dev/null +++ b/keras_hub/src/models/mixtral/README.md @@ -0,0 +1,36 @@ +# Mixtral Model Architecture: + + +``` +MixtralForCausalLM( + (model): MixtralModel( + (embed_tokens): Embedding(32000, 4096) + (layers): ModuleList( + (0-31): 32 x MixtralDecoderLayer( + (self_attn): MixtralAttention( + (q_proj): Linear(in_features=4096, out_features=4096, bias=False) + (k_proj): Linear(in_features=4096, out_features=1024, bias=False) + (v_proj): Linear(in_features=4096, out_features=1024, bias=False) + (o_proj): Linear(in_features=4096, out_features=4096, bias=False) + ) + (block_sparse_moe): MixtralSparseMoeBlock( + (gate): Linear(in_features=4096, out_features=8, bias=False) + (experts): ModuleList( + (0-7): 8 x MixtralBlockSparseTop2MLP( + (w1): Linear(in_features=4096, out_features=14336, bias=False) + (w2): Linear(in_features=14336, out_features=4096, bias=False) + (w3): Linear(in_features=4096, out_features=14336, bias=False) + (act_fn): SiLU() + ) + ) + ) + (input_layernorm): MixtralRMSNorm((4096,), eps=1e-05) + (post_attention_layernorm): MixtralRMSNorm((4096,), eps=1e-05) + ) + ) + (norm): MixtralRMSNorm((4096,), eps=1e-05) + (rotary_emb): MixtralRotaryEmbedding() + ) + (lm_head): Linear(in_features=4096, out_features=32000, bias=False) +) +``` \ No newline at end of file diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index 50afd76d41..06a80127f6 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -98,8 +98,8 @@ def __init__( num_key_value_heads, num_experts, top_k, - router_jitter_noise, output_router_logits, + router_jitter_noise=0., rope_max_wavelength=10000, rope_scaling_factor=1.0, layer_norm_epsilon=1e-6, diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 29f50e2337..d26cc21f63 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -99,6 +99,8 @@ def __init__( num_experts, top_k, router_jitter_noise, + layer_norm_epsilon=1e-5, + kernel_initializer="glorot_uniform", **kwargs, ): super().__init__(**kwargs) @@ -108,6 +110,9 @@ def __init__( self.top_k = top_k self.router_jitter_noise = router_jitter_noise + self.layer_norm_epsilon = layer_norm_epsilon + self.kernel_initializer = keras.initializers.get(kernel_initializer) + def build(self, decoder_sequence_shape): self._sparse_feedforward_gate_dense = keras.layers.Dense( self.num_experts, @@ -120,7 +125,7 
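The converter changes that follow are driven directly by this module dump. The correspondence assumed between Hugging Face submodules and the (private) KerasHub layer attributes introduced in this PR is, per decoder layer:

```python
# Sketch of the per-layer name mapping used while porting weights. Note that
# the experts are still separate MixtralMoeMLP instances at this point in the
# series; a later commit batches them into a single expert bank.
HF_TO_KERAS_HUB = {
    "input_layernorm": "_self_attention_layernorm",
    "post_attention_layernorm": "_feedforward_layernorm",
    "block_sparse_moe.gate": "_sparse_moe_block._sparse_feedforward_gate_dense",
    "block_sparse_moe.experts.N.w1": "_sparse_moe_block.experts[N]._feedforward_gate_dense",
    "block_sparse_moe.experts.N.w3": "_sparse_moe_block.experts[N]._feedforward_intermediate_dense",
    "block_sparse_moe.experts.N.w2": "_sparse_moe_block.experts[N]._feedforward_output_dense",
}
```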
@@ def build(self, decoder_sequence_shape): self.experts = [ MixtralMoeMLP( - intermediate_dim=self.moe_intermediate_dim, + intermediate_dim=self.intermediate_dim, hidden_dim=self.hidden_dim, kernel_initializer=self.kernel_initializer, layer_norm_epsilon=self.layer_norm_epsilon, diff --git a/keras_hub/src/utils/transformers/convert_mixtral.py b/keras_hub/src/utils/transformers/convert_mixtral.py index 8696b4db43..07fc9dad10 100644 --- a/keras_hub/src/utils/transformers/convert_mixtral.py +++ b/keras_hub/src/utils/transformers/convert_mixtral.py @@ -19,11 +19,9 @@ def convert_backbone_config(transformers_config): "rope_max_wavelength": transformers_config["rope_theta"], "layer_norm_epsilon": transformers_config["rms_norm_eps"], "sliding_window": transformers_config["sliding_window"], - "router_jitter_noise": transformers_config["router_jitter_noise"], "output_router_logits": transformers_config['output_router_logits'], } - def convert_weights(backbone, loader, transformers_config): # Embeddings loader.port_weight( @@ -39,7 +37,7 @@ def convert_weights(backbone, loader, transformers_config): ), ) - # Attention blocks + # Attention blocks and MoE blocks for index in range(backbone.num_layers): decoder_layer = backbone.transformer_layers[index] @@ -85,29 +83,35 @@ def convert_weights(backbone, loader, transformers_config): ), ) - # MLP layers - loader.port_weight( - keras_variable=decoder_layer._feedforward_gate_dense.kernel, - hf_weight_key=f"model.layers.{index}.mlp.gate_proj.weight", - hook_fn=lambda hf_tensor, _: np.transpose( - hf_tensor.astype(np.float16), axes=(1, 0) - ), - ) - loader.port_weight( - keras_variable=decoder_layer._feedforward_intermediate_dense.kernel, - hf_weight_key=f"model.layers.{index}.mlp.up_proj.weight", - hook_fn=lambda hf_tensor, _: np.transpose( - hf_tensor.astype(np.float16), axes=(1, 0) - ), - ) + # MoE block - Router gate loader.port_weight( - keras_variable=decoder_layer._feedforward_output_dense.kernel, - hf_weight_key=f"model.layers.{index}.mlp.down_proj.weight", - hook_fn=lambda hf_tensor, _: np.transpose( - hf_tensor.astype(np.float16), axes=(1, 0) - ), + keras_variable=decoder_layer._sparse_moe_block._sparse_feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{index}.block_sparse_moe.gate.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), ) + # MoE block - Experts + for expert_index in range(backbone.num_experts): + expert = decoder_layer._sparse_moe_block.experts[expert_index] + # w1: Gate dense + loader.port_weight( + keras_variable=expert._feedforward_gate_dense.kernel, + hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w1.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + ) + # w3: Intermediate dense + loader.port_weight( + keras_variable=expert._feedforward_intermediate_dense.kernel, + hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w3.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + ) + # w2: Output dense + loader.port_weight( + keras_variable=expert._feedforward_output_dense.kernel, + hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w2.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + ) + # Normalization loader.port_weight( keras_variable=backbone.layer_norm.scale, @@ -115,6 +119,5 @@ def convert_weights(backbone, loader, transformers_config): hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), ) - def 
convert_tokenizer(cls, preset, **kwargs): return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 0d58747631..8a98fc0c83 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -3,7 +3,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup -from keras_hub.src.utils.transformers import convert_albert +from keras_hub.src.utils.transformers import convert_albert, convert_mixtral from keras_hub.src.utils.transformers import convert_bart from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_distilbert @@ -44,6 +44,8 @@ def __init__(self, preset, config): self.converter = convert_vit elif model_type == "qwen2": self.converter = convert_qwen + elif model_type == "mixtral": + self.converter = convert_mixtral else: raise ValueError( "KerasHub has no converter for huggingface/transformers models " diff --git a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py index 63e0c8ac1d..eaa2c4271f 100644 --- a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py @@ -1,292 +1,136 @@ -# import gc -# import os -# import shutil -# import tempfile -# import traceback - -# import numpy as np -# import requests -# from absl import app -# from absl import flags -# from keras import ops -# from transformers import AutoTokenizer -import torch -from transformers import MixtralForCausalLM - -# from keras_hub.models import MixtralBackbone -# from keras_hub.models import MixtralCausalLMPreprocessor -# from keras_hub.models import MixtralTokenizer -# from keras_hub.utils.preset_utils import save_to_preset - -mixtral_lm = MixtralForCausalLM.from_pretrained( - "mistralai/Mixtral-8x7B-v0.1", - torch_dtype=torch.bfloat16 - ) - -print(mixtral_lm) - -# PRESET_MAP = {} - -# FLAGS = flags.FLAGS -# flags.DEFINE_string( -# "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" -# ) - - -# def convert_checkpoints(keras_hub_model, hf_model): -# config = hf_model.config - -# keras_hub_model.token_embedding.embeddings.assign( -# hf_model.model.embed_tokens.weight.detach().cpu().numpy() -# ) - -# for i in range(keras_hub_model.num_layers): -# keras_hub_model.transformer_layers[ -# i -# ]._self_attention_layer._key_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .self_attn.k_proj.weight.T.reshape( -# config.hidden_size, -# config.num_key_value_heads, -# config.hidden_size // config.num_attention_heads, -# ) -# .detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._self_attention_layer._query_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .self_attn.q_proj.weight.T.reshape( -# config.hidden_size, -# config.num_attention_heads, -# config.hidden_size // config.num_attention_heads, -# ) -# .detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._self_attention_layer._value_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .self_attn.v_proj.weight.T.reshape( -# config.hidden_size, -# config.num_key_value_heads, -# config.hidden_size // config.num_attention_heads, -# ) -# .detach() -# .cpu() -# .numpy() -# ] -# ) -# 
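Registering `mixtral` in the transformers preset loader above is what makes the `hf://` path used by the rewritten conversion script work end to end. Assuming network access and the checkpoint name below, the intended user-facing flow is roughly:

```python
import keras_hub

# Both calls route through convert_mixtral.py via the preset loader.
tokenizer = keras_hub.models.MixtralTokenizer.from_preset(
    "hf://mistralai/Mixtral-8x7B-v0.1"
)
backbone = keras_hub.models.MixtralBackbone.from_preset(
    "hf://mistralai/Mixtral-8x7B-v0.1"
)
```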
keras_hub_model.transformer_layers[ -# i -# ]._self_attention_layer._output_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .self_attn.o_proj.weight.T.reshape( -# config.num_attention_heads, -# config.hidden_size // config.num_attention_heads, -# config.hidden_size, -# ) -# .detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._self_attention_layernorm.set_weights( -# [ -# hf_model.model.layers[i] -# .input_layernorm.weight.detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._feedforward_intermediate_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .mlp.up_proj.weight.T.detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._feedforward_output_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .mlp.down_proj.weight.T.detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._feedforward_gate_dense.set_weights( -# [ -# hf_model.model.layers[i] -# .mlp.gate_proj.weight.T.detach() -# .cpu() -# .numpy() -# ] -# ) -# keras_hub_model.transformer_layers[ -# i -# ]._feedforward_layernorm.set_weights( -# [ -# hf_model.model.layers[i] -# .post_attention_layernorm.weight.detach() -# .cpu() -# .numpy() -# ] -# ) - -# keras_hub_model.layer_norm.set_weights( -# [hf_model.model.norm.weight.detach().cpu().numpy()] -# ) -# keras_hub_model.token_embedding.reverse_embeddings.assign( -# hf_model.lm_head.weight.T.detach().cpu().numpy() -# ) +import os +import traceback +os.environ["KERAS_BACKEND"] = "torch" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Hide any CUDA devices -# def test_model( -# keras_hub_model, keras_hub_tokenizer, hf_model, hf_model_tokenizer -# ): -# # First, test that the number of parameters match -# keras_hub_params = keras_hub_model.count_params() -# hf_params = hf_model.num_parameters() -# assert keras_hub_params == hf_params - -# # Test the outputs of both the models -# hf_outputs = hf_model( -# **hf_model_tokenizer(["What is Keras?"], return_tensors="pt") -# ) -# hf_output_logits = hf_outputs.logits.detach().cpu().numpy() - -# keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) -# keras_hub_output = keras_hub_model( -# keras_hub_preprocessor(["What is Keras?"], sequence_length=6)[0] -# ) -# keras_hub_logits = keras_hub_model.token_embedding( -# keras_hub_output, reverse=True -# ) -# keras_hub_logits = ops.convert_to_numpy(keras_hub_logits) - -# # High tolerence since bfloat16 is used as the default dtype for Mixtral -# try: -# np.testing.assert_allclose( -# keras_hub_logits, hf_output_logits, atol=1e-4 -# ) -# except AssertionError as err: -# print("\n") -# print(traceback.format_exc()) -# print(err.args[0]) -# print("\n") +import numpy as np +import torch +from absl import app +from absl import flags +device = torch.device("cpu") +# Force PyTorch to use CPU +torch.set_default_device(device) -# def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): -# hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") -# hf_output = hf_output["input_ids"].detach().cpu().numpy() -# keras_hub_preprocessor = MixtralCausalLMPreprocessor(keras_hub_tokenizer) -# keras_hub_output = keras_hub_preprocessor( -# ["What is Keras?"], sequence_length=6 -# ) -# keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) +from keras import ops # noqa: E402 +from transformers import AutoModelForCausalLM # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 -# np.testing.assert_equal(keras_hub_output, 
hf_output) +import keras_hub # noqa: E402 -# def main(_): -# # === Get the preset name === -# if FLAGS.preset not in PRESET_MAP.keys(): -# raise ValueError( -# f"Invalid preset {FLAGS.preset}. Must be one " -# f"of {','.join(PRESET_MAP.keys())}" -# ) -# preset = FLAGS.preset -# hf_preset = PRESET_MAP[preset] +PRESET_MAP = { + "mixtral_8_7b_en":"mistralai/Mixtral-8x7B-v0.1" +} -# # === Create the temporary save directories === -# temp_dir = tempfile.mkdtemp() +FLAGS = flags.FLAGS +flags.DEFINE_string( + "preset", None, f"Must be one of {','.join(PRESET_MAP.keys())}" +) -# try: -# # === Load the Huggingface model === -# hf_model = MixtralForCausalLM.from_pretrained(hf_preset) -# hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset) -# hf_model.eval() -# print("\n-> Huggingface model and tokenizer loaded") -# # === Load the KerasHub model === -# backbone_kwargs = dict( -# vocabulary_size=hf_model.config.vocab_size, -# hidden_dim=hf_model.config.hidden_size, -# num_layers=hf_model.config.num_hidden_layers, -# num_query_heads=hf_model.config.num_attention_heads, -# num_key_value_heads=hf_model.config.num_key_value_heads, -# intermediate_dim=hf_model.config.intermediate_size, -# sliding_window=hf_model.config.sliding_window, -# layer_norm_epsilon=hf_model.config.rms_norm_eps, -# rope_max_wavelength=hf_model.config.rope_theta, -# dtype="float32", -# ) -# keras_hub_model = MixtralBackbone(**backbone_kwargs) +def compute_hf_output(hf_model, hf_model_tokenizer): + hf_inputs = hf_model_tokenizer(["What is Keras?"], return_tensors="pt").to( + device + ) + hf_outputs = hf_model(**hf_inputs) + hf_output_logits = hf_outputs.logits.detach().cpu().float().numpy() -# # === Download the tokenizer from Huggingface model card === -# spm_path = ( -# f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model" -# ) -# response = requests.get(spm_path) -# if not response.ok: -# raise ValueError(f"Couldn't fetch {preset}'s tokenizer.") -# tokenizer_path = os.path.join(temp_dir, "vocabulary.spm") -# with open(tokenizer_path, "wb") as tokenizer_file: -# tokenizer_file.write(response.content) -# keras_hub_tokenizer = MixtralTokenizer(tokenizer_path) -# print("\n-> Keras 3 model and tokenizer loaded.") + return hf_output_logits -# # === Port the weights === -# convert_checkpoints(keras_hub_model, hf_model) -# print("\n-> Weight transfer done.") +def compute_keras_output(keras_hub_model, keras_hub_tokenizer): + keras_hub_preprocessor = keras_hub.models.MixtralCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_inputs = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=6 + )[0] + keras_hub_inputs = {k: v.to(device) for k, v in keras_hub_inputs.items()} + + keras_hub_output = keras_hub_model(keras_hub_inputs) + keras_hub_output_logits = keras_hub_model.token_embedding( + keras_hub_output, reverse=True + ) + keras_hub_output_logits = ops.convert_to_numpy(keras_hub_output_logits) + return keras_hub_output_logits -# # === Check that the models and tokenizers outputs match === -# test_tokenizer(keras_hub_tokenizer, hf_tokenizer) -# test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer) -# print("\n-> Tests passed!") -# # === Save the model weights in float32 format === -# keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5")) -# print("\n-> Saved the model weights in float32") -# del keras_hub_model, hf_model -# gc.collect() +def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): + hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") + hf_output = 
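A detail of `compute_keras_output` above that is easy to miss: the backbone returns hidden states rather than logits, so the comparison against `hf_model(...).logits` goes through the tied embedding. In outline (a sketch of the check, not the full test):

```python
hidden_states = keras_hub_model(keras_hub_inputs)  # [batch, seq, hidden_dim]
keras_logits = keras_hub_model.token_embedding(
    hidden_states, reverse=True
)  # [batch, seq, vocab_size]
# Compared against hf_outputs.logits with a loose atol, since the HF model
# runs in reduced precision.
```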
hf_output["input_ids"].detach().cpu().numpy() + keras_hub_preprocessor = keras_hub.models.MixtralCausalLMPreprocessor( + keras_hub_tokenizer + ) + keras_hub_output = keras_hub_preprocessor( + ["What is Keras?"], sequence_length=6 + ) + keras_hub_output = ops.convert_to_numpy(keras_hub_output[0]["token_ids"]) + + np.testing.assert_equal(keras_hub_output, hf_output) + + +def main(_): + # === Get the preset name === + # if FLAGS.preset not in PRESET_MAP.keys(): + # raise ValueError( + # f"Invalid preset {FLAGS.preset}. Must be one " + # f"of {','.join(PRESET_MAP.keys())}" + # ) + # preset = FLAGS.preset + # hf_preset = PRESET_MAP[preset] + preset = "mixtral_8_7b_en" + hf_preset = "mistralai/Mixtral-8x7B-v0.1" + + # === Load the Huggingface model === + hf_model = AutoModelForCausalLM.from_pretrained( + hf_preset, + device_map=device, + ) + hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") + hf_model.eval() + print("\n-> Huggingface model and tokenizer loaded") + + keras_hub_tokenizer = keras_hub.models.MixtralTokenizer.from_preset( + f"hf://{hf_preset}" + ) + print("\n-> Keras tokenizer loaded") + test_tokenizer(keras_hub_tokenizer, hf_tokenizer) -# # === Save the weights again in float16 === -# backbone_kwargs["dtype"] = "float16" -# keras_hub_model = MixtralBackbone(**backbone_kwargs) -# keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5")) -# save_to_preset(keras_hub_model, preset) -# print("\n-> Saved the model preset in float16") + print(f"\n -> Keras tokenizer test successful") -# # === Save the tokenizer === -# save_to_preset( -# keras_hub_tokenizer, preset, config_filename="tokenizer.json" -# ) -# print("\n-> Saved the tokenizer") -# finally: -# shutil.rmtree(temp_dir) + hf_params = hf_model.num_parameters() + hf_output_logits = compute_hf_output(hf_model, hf_tokenizer) + print(f"\n -> Computed HF outputs successfully") + del hf_model, hf_tokenizer + keras_hub_model = keras_hub.models.MixtralBackbone.from_preset( + f"hf://{hf_preset}" + ) + print("\n-> Keras model loaded") + + keras_hub_params = keras_hub_model.count_params() + assert keras_hub_params == hf_params + + keras_hub_output_logits = compute_keras_output( + keras_hub_model, + keras_hub_tokenizer + ) -# if __name__ == "__main__": -# flags.mark_flag_as_required("preset") -# app.run(main) + try: + np.testing.assert_allclose( + keras_hub_output_logits, hf_output_logits, atol=1e-4 + ) + except AssertionError as err: + print("\n") + print(traceback.format_exc()) + print(err.args[0]) + print("\n") + + print("\n-> Tests passed!") + + +if __name__ == "__main__": + # flags.mark_flag_as_required("preset") + app.run(main) From d5aee618a40aad4ca3a8a0b02eae123d05b04a70 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Wed, 9 Apr 2025 19:02:13 +0000 Subject: [PATCH 06/11] batched moe impl --- .../src/models/mixtral/mixtral_decoder.py | 212 ++++++++---------- .../src/utils/transformers/convert_mixtral.py | 126 ++++++----- 2 files changed, 157 insertions(+), 181 deletions(-) diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index d26cc21f63..6892d714f0 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -16,82 +16,74 @@ from keras_hub.src.utils.keras_utils import clone_initializer -class MixtralMoeMLP(keras.layers.Layer): + +class MixtralMoeExperts(keras.layers.Layer): + """Batched feed-forward experts for Mixtral (pure keras.ops).""" + def __init__( self, - 
intermediate_dim, + num_experts, hidden_dim, + intermediate_dim, activation_fn="silu", - layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", **kwargs, ): super().__init__(**kwargs) - self.intermediate_dim = intermediate_dim + self.num_experts = num_experts self.hidden_dim = hidden_dim - self.activation_fn = activation_fn - self.kernel_initializer = kernel_initializer - self.layer_norm_epsilon = layer_norm_epsilon - - def build(self, decoder_sequence_shape): - # Feedforward layers. - self._feedforward_intermediate_dense = keras.layers.Dense( - self.intermediate_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_intermediate_dense", - ) - self._feedforward_intermediate_dense.build(decoder_sequence_shape) + self.intermediate_dim = intermediate_dim + self.activation = keras.activations.get(activation_fn) + self.kernel_initializer = keras.initializers.get(kernel_initializer) - self._feedforward_gate_dense = keras.layers.Dense( - self.intermediate_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_gate_dense", + def build(self, _): + # Weight for gate dense layer: [num_experts, hidden_dim, intermediate_dim] + self._expert_feedforward_gate_dense = self.add_weight( + shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), + initializer=self.kernel_initializer, + trainable=True, + name="expert_feedforward_gate_dense", ) - self._feedforward_gate_dense.build(decoder_sequence_shape) - - self._feedforward_output_dense = keras.layers.Dense( - self.hidden_dim, - kernel_initializer=clone_initializer(self.kernel_initializer), - use_bias=False, - dtype=self.dtype_policy, - name="feedforward_output_dense", + # Weight for intermediate dense layer: [num_experts, hidden_dim, intermediate_dim] + self._expert_feedforward_intermediate_dense = self.add_weight( + shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), + initializer=self.kernel_initializer, + trainable=True, + name="expert_feedforward_intermediate_dense", ) - - self._feedforward_output_dense.build( - self._feedforward_gate_dense.compute_output_shape( - decoder_sequence_shape - ) + # Weight for output dense layer: [num_experts, intermediate_dim, hidden_dim] + self._expert_feedforward_output_dense = self.add_weight( + shape=(self.num_experts, self.intermediate_dim, self.hidden_dim), + initializer=self.kernel_initializer, + trainable=True, + name="expert_feedforward_output_dense", ) - - self.activation = keras.activations.get(self.activation_fn) self.built = True - def call(self, x): - gate_output = self._feedforward_gate_dense(x) - - # Note that we run the activation function in full 32-bit - # precision since this is what `torch.nn.functional.silu` - # does. Internally, `torch.nn.functional.silu` converts the - # inputs to float32, computes SiLU, and converts the outputs - # back to compute dtype. 
- # CPU Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cpu/Activation.cpp#L1221-L1235 # noqa: E501 - # CUDA Kernel: https://github.com/pytorch/pytorch/blob/35c493f2cf9b623bfdc7e6b34dc1cb39690a7919/aten/src/ATen/native/cuda/ActivationSiluKernel.cu # noqa: E501 - gate_output = ops.cast(gate_output, "float32") - gate_output = self.activation(gate_output) - gate_output = ops.cast(gate_output, self.compute_dtype) - - x = self._feedforward_intermediate_dense(x) - - x = self._feedforward_output_dense(ops.multiply(x, gate_output)) - - return x + def call(self, hidden_states): + # Compute gate output for all experts: [num_experts, tokens, intermediate_dim] + gate = ops.einsum( + "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense + ) + gate = ops.cast(gate, "float32") # Match PyTorch SiLU precision + gate = self.activation(gate) + gate = ops.cast(gate, self.compute_dtype) + # Compute intermediate output for all experts: [num_experts, tokens, intermediate_dim] + intermediate = ops.einsum( + "th,ehm->etm", hidden_states, self._expert_feedforward_intermediate_dense + ) + hidden = intermediate * gate # Element-wise multiplication + # Compute final output: [num_experts, tokens, hidden_dim] + out = ops.einsum( + "eti,eih->eth", hidden, self._expert_feedforward_output_dense + ) + return out + class MixtralSparseMoeBlock(keras.layers.Layer): + """Mixtral sparse MoE block rewritten in batched style.""" + def __init__( self, hidden_dim, @@ -109,99 +101,71 @@ def __init__( self.num_experts = num_experts self.top_k = top_k self.router_jitter_noise = router_jitter_noise - self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) def build(self, decoder_sequence_shape): + # Router dense layer to compute logits for expert selection self._sparse_feedforward_gate_dense = keras.layers.Dense( self.num_experts, - kernel_initializer=clone_initializer(self.kernel_initializer), + kernel_initializer=self.kernel_initializer, use_bias=False, dtype=self.dtype_policy, name="sparse_feedforward_gate_dense", ) self._sparse_feedforward_gate_dense.build(decoder_sequence_shape) - self.experts = [ - MixtralMoeMLP( - intermediate_dim=self.intermediate_dim, - hidden_dim=self.hidden_dim, - kernel_initializer=self.kernel_initializer, - layer_norm_epsilon=self.layer_norm_epsilon, - ) - for _ in range(self.num_experts) - ] - for expert in self.experts: - expert.build(decoder_sequence_shape) + # Batched expert bank + self.expert_bank = MixtralMoeExperts( + num_experts=self.num_experts, + hidden_dim=self.hidden_dim, + intermediate_dim=self.intermediate_dim, + kernel_initializer=self.kernel_initializer, + name="experts", + ) + self.expert_bank.build(decoder_sequence_shape) + self.built = True def call(self, hidden_states, training=False): - batch_size, seq_len, hidden_dim = hidden_states.shape + batch_size, seq_len, _ = ops.shape(hidden_states) + hidden_states_flattened = ops.reshape(hidden_states, (-1, self.hidden_dim)) - # Jitter noise augmentation (training only) + # Apply jitter noise during training if specified if training and self.router_jitter_noise > 0: random_factors = ops.random.uniform( - shape=ops.shape(hidden_states), + shape=ops.shape(hidden_states_flattened), minval=1.0 - self.router_jitter_noise, maxval=1.0 + self.router_jitter_noise, - dtype=hidden_states.dtype, + dtype=hidden_states_flattened.dtype, ) - hidden_states = hidden_states * random_factors - - hidden_states_2d = 
ops.reshape(hidden_states, (-1, hidden_dim)) - - router_logits = self._sparse_feedforward_gate_dense(hidden_states_2d) - routing_weights = ops.softmax(router_logits, axis=1) - - routing_weights, selected_experts = ops.top_k( - routing_weights, k=self.top_k - ) - sum_topk = ops.sum(routing_weights, axis=-1, keepdims=True) - routing_weights = routing_weights / sum_topk - - routing_weights = ops.cast(routing_weights, hidden_states.dtype) - - # Prepare final hidden states - final_hidden_states = ops.zeros( - (batch_size * seq_len, hidden_dim), dtype=hidden_states.dtype - ) - - expert_mask = ops.one_hot( - selected_experts, num_classes=self.num_experts - ) - expert_mask = ops.transpose(expert_mask, axes=[2, 1, 0]) - - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] + hidden_states_flattened = hidden_states_flattened * random_factors - idx, top_x = ops.where(expert_mask[expert_idx]) + # Compute router logits and probabilities + router_logits = self._sparse_feedforward_gate_dense(hidden_states_flattened) + router_probs = ops.softmax(router_logits, axis=-1) - if ops.shape(top_x)[0] == 0: - continue + # Select top-k experts and their probabilities + top_p, top_i = ops.top_k(router_probs, k=self.top_k) + sum_topk = ops.sum(top_p, axis=-1, keepdims=True) + top_p = top_p / sum_topk # Normalize top-k probabilities - # Gather hidden states belonging to this expert - current_state = ops.take(hidden_states_2d, top_x, axis=0) - expert_output = expert_layer(current_state) + # Create routing weights for all experts + one_hot = ops.one_hot(top_i, self.num_experts) # [tokens, top_k, num_experts] + routing_full = ops.sum(one_hot * top_p[..., None], axis=1) # [tokens, num_experts] + routing_full = ops.transpose(routing_full, (1, 0)) # [num_experts, tokens] + routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) - # Multiply by routing weights - # routing_weights is shape (batch_size*seq_len, top_k) - # We want routing_weights[top_x, idx] - factor = routing_weights[top_x, idx] - factor = ops.expand_dims(factor, axis=-1) # shape = (n_tokens, 1) - current_hidden_states = expert_output * factor + # Compute expert outputs in a batched manner + expert_out = self.expert_bank(hidden_states_flattened) # [num_experts, tokens, hidden_dim] - existing_values = ops.take(final_hidden_states, top_x, axis=0) - updated_values = existing_values + current_hidden_states - final_hidden_states = ops.scatter_update( - final_hidden_states, top_x[:, None], updated_values - ) - - final_hidden_states = ops.reshape( - final_hidden_states, (batch_size, seq_len, hidden_dim) - ) + # Weight expert outputs by routing probabilities + weighted_out = expert_out * routing_full[:, :, None] # [num_experts, tokens, hidden_dim] + expert_contribution = ops.sum(weighted_out, axis=0) # [tokens, hidden_dim] - return final_hidden_states, router_logits + # Reshape back to original dimensions + out = ops.reshape(expert_contribution, (batch_size, seq_len, self.hidden_dim)) + return out, router_logits class MixtralTransformerDecoder(keras.layers.Layer): def __init__( diff --git a/keras_hub/src/utils/transformers/convert_mixtral.py b/keras_hub/src/utils/transformers/convert_mixtral.py index 07fc9dad10..7d289fb2fd 100644 --- a/keras_hub/src/utils/transformers/convert_mixtral.py +++ b/keras_hub/src/utils/transformers/convert_mixtral.py @@ -22,102 +22,114 @@ def convert_backbone_config(transformers_config): "output_router_logits": transformers_config['output_router_logits'], } + def convert_weights(backbone, 
loader, transformers_config): # Embeddings loader.port_weight( - keras_variable=backbone.token_embedding.embeddings, + keras_variable=backbone.get_layer("token_embedding").embeddings, hf_weight_key="model.embed_tokens.weight", - hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), ) loader.port_weight( - keras_variable=backbone.token_embedding.reverse_embeddings, + keras_variable=backbone.get_layer("token_embedding").reverse_embeddings, hf_weight_key="lm_head.weight", - hook_fn=lambda hf_tensor, _: np.transpose( - hf_tensor.astype(np.float16), axes=(1, 0) - ), + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) - # Attention blocks and MoE blocks - for index in range(backbone.num_layers): - decoder_layer = backbone.transformer_layers[index] + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") - # Norm layers + # Input layernorm loader.port_weight( keras_variable=decoder_layer._self_attention_layernorm.scale, - hf_weight_key=f"model.layers.{index}.input_layernorm.weight", - hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), - ) - loader.port_weight( - keras_variable=decoder_layer._feedforward_layernorm.scale, - hf_weight_key=f"model.layers.{index}.post_attention_layernorm.weight", - hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", ) # Attention layers + ## Query loader.port_weight( keras_variable=decoder_layer._self_attention_layer._query_dense.kernel, - hf_weight_key=f"model.layers.{index}.self_attn.q_proj.weight", - hook_fn=lambda hf_tensor, keras_shape: np.reshape( - np.transpose(hf_tensor.astype(np.float16)), keras_shape - ), + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, ) + ## Key loader.port_weight( keras_variable=decoder_layer._self_attention_layer._key_dense.kernel, - hf_weight_key=f"model.layers.{index}.self_attn.k_proj.weight", - hook_fn=lambda hf_tensor, keras_shape: np.reshape( - np.transpose(hf_tensor.astype(np.float16)), keras_shape - ), + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, ) + ## Value loader.port_weight( keras_variable=decoder_layer._self_attention_layer._value_dense.kernel, - hf_weight_key=f"model.layers.{index}.self_attn.v_proj.weight", - hook_fn=lambda hf_tensor, keras_shape: np.reshape( - np.transpose(hf_tensor.astype(np.float16)), keras_shape - ), + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, ) + ## Output loader.port_weight( keras_variable=decoder_layer._self_attention_layer._output_dense.kernel, - hf_weight_key=f"model.layers.{index}.self_attn.o_proj.weight", - hook_fn=lambda hf_tensor, keras_shape: np.reshape( - np.transpose(hf_tensor.astype(np.float16)), keras_shape - ), + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + hook_fn=transpose_and_reshape, ) - # MoE block - Router gate + # MoE layers + # Router gate loader.port_weight( keras_variable=decoder_layer._sparse_moe_block._sparse_feedforward_gate_dense.kernel, - hf_weight_key=f"model.layers.{index}.block_sparse_moe.gate.weight", - hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + hf_weight_key=f"model.layers.{i}.block_sparse_moe.gate.weight", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), ) - # MoE block - Experts - for expert_index in range(backbone.num_experts): - 
expert = decoder_layer._sparse_moe_block.experts[expert_index] - # w1: Gate dense - loader.port_weight( - keras_variable=expert._feedforward_gate_dense.kernel, - hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w1.weight", - hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + # Batched experts: w1 (gate), w3 (intermediate), and w2 (output) weights + gate_weights_list = [] + intermediate_weights_list = [] + output_weights_list = [] + for expert_idx in range(backbone.num_experts): + # Load w1 (gate dense) for each expert + w1 = loader.get_tensor( + f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w1.weight" ) - # w3: Intermediate dense - loader.port_weight( - keras_variable=expert._feedforward_intermediate_dense.kernel, - hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w3.weight", - hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + w1_transposed = np.transpose(w1, axes=(1, 0)) # [hidden_dim, intermediate_dim] + gate_weights_list.append(w1_transposed) + + # Load w3 (intermediate dense) for each expert + w3 = loader.get_tensor( + f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w3.weight" ) - # w2: Output dense - loader.port_weight( - keras_variable=expert._feedforward_output_dense.kernel, - hf_weight_key=f"model.layers.{index}.block_sparse_moe.experts.{expert_index}.w2.weight", - hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor.astype(np.float16)), + w3_transposed = np.transpose(w3, axes=(1, 0)) # [hidden_dim, intermediate_dim] + intermediate_weights_list.append(w3_transposed) + + # Load w2 (output dense) for each expert + w2 = loader.get_tensor( + f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w2.weight" ) + w2_transposed = np.transpose(w2, axes=(1, 0)) # [intermediate_dim, hidden_dim] + output_weights_list.append(w2_transposed) - # Normalization + # Stack the lists to create batched weights + gate_batched = np.stack(gate_weights_list, axis=0) # [num_experts, hidden_dim, intermediate_dim] + intermediate_batched = np.stack(intermediate_weights_list, axis=0) # [num_experts, hidden_dim, intermediate_dim] + output_batched = np.stack(output_weights_list, axis=0) # [num_experts, intermediate_dim, hidden_dim] + + # Assign batched weights to expert_bank + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_gate_dense.assign(gate_batched) + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_intermediate_dense.assign(intermediate_batched) + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_output_dense.assign(output_batched) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer._feedforward_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer loader.port_weight( - keras_variable=backbone.layer_norm.scale, + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, hf_weight_key="model.norm.weight", - hook_fn=lambda hf_tensor, _: hf_tensor.astype(np.float16), ) + return backbone + def convert_tokenizer(cls, preset, **kwargs): return cls(get_file(preset, "tokenizer.model"), **kwargs) From 3597d5312609e7a908fbf270437d570d3b905f8f Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Fri, 11 Apr 2025 16:11:02 +0000 Subject: [PATCH 07/11] output matching with batched moe complete --- .../src/models/mixtral/mixtral_backbone.py | 2 +- .../src/models/mixtral/mixtral_decoder.py | 60 +++++++++++++------ 
.../src/utils/transformers/convert_mixtral.py | 34 ++++++----- .../src/utils/transformers/preset_loader.py | 3 +- .../convert_mixtral_checkpoints.py | 36 +++++------ 5 files changed, 80 insertions(+), 55 deletions(-) diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index 06a80127f6..fef053055c 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -99,7 +99,7 @@ def __init__( num_experts, top_k, output_router_logits, - router_jitter_noise=0., + router_jitter_noise=0.0, rope_max_wavelength=10000, rope_scaling_factor=1.0, layer_norm_epsilon=1e-6, diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 6892d714f0..98681ca0fd 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -16,7 +16,6 @@ from keras_hub.src.utils.keras_utils import clone_initializer - class MixtralMoeExperts(keras.layers.Layer): """Batched feed-forward experts for Mixtral (pure keras.ops).""" @@ -37,21 +36,24 @@ def __init__( self.kernel_initializer = keras.initializers.get(kernel_initializer) def build(self, _): - # Weight for gate dense layer: [num_experts, hidden_dim, intermediate_dim] + # Weight for gate dense layer: + # [num_experts, hidden_dim, intermediate_dim] self._expert_feedforward_gate_dense = self.add_weight( shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), initializer=self.kernel_initializer, trainable=True, name="expert_feedforward_gate_dense", ) - # Weight for intermediate dense layer: [num_experts, hidden_dim, intermediate_dim] + # Weight for intermediate dense layer: + # [num_experts, hidden_dim, intermediate_dim] self._expert_feedforward_intermediate_dense = self.add_weight( shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), initializer=self.kernel_initializer, trainable=True, name="expert_feedforward_intermediate_dense", ) - # Weight for output dense layer: [num_experts, intermediate_dim, hidden_dim] + # Weight for output dense layer: + # [num_experts, intermediate_dim, hidden_dim] self._expert_feedforward_output_dense = self.add_weight( shape=(self.num_experts, self.intermediate_dim, self.hidden_dim), initializer=self.kernel_initializer, @@ -61,7 +63,8 @@ def build(self, _): self.built = True def call(self, hidden_states): - # Compute gate output for all experts: [num_experts, tokens, intermediate_dim] + # Compute gate output for all experts: + # [num_experts, tokens, intermediate_dim] gate = ops.einsum( "th,ehm->etm", hidden_states, self._expert_feedforward_gate_dense ) @@ -69,9 +72,12 @@ def call(self, hidden_states): gate = self.activation(gate) gate = ops.cast(gate, self.compute_dtype) - # Compute intermediate output for all experts: [num_experts, tokens, intermediate_dim] + # Compute intermediate output for all experts: + # [num_experts, tokens, intermediate_dim] intermediate = ops.einsum( - "th,ehm->etm", hidden_states, self._expert_feedforward_intermediate_dense + "th,ehm->etm", + hidden_states, + self._expert_feedforward_intermediate_dense, ) hidden = intermediate * gate # Element-wise multiplication @@ -80,7 +86,8 @@ def call(self, hidden_states): "eti,eih->eth", hidden, self._expert_feedforward_output_dense ) return out - + + class MixtralSparseMoeBlock(keras.layers.Layer): """Mixtral sparse MoE block rewritten in batched style.""" @@ -128,7 +135,9 @@ def build(self, decoder_sequence_shape): def call(self, 
hidden_states, training=False): batch_size, seq_len, _ = ops.shape(hidden_states) - hidden_states_flattened = ops.reshape(hidden_states, (-1, self.hidden_dim)) + hidden_states_flattened = ops.reshape( + hidden_states, (-1, self.hidden_dim) + ) # Apply jitter noise during training if specified if training and self.router_jitter_noise > 0: @@ -141,7 +150,9 @@ def call(self, hidden_states, training=False): hidden_states_flattened = hidden_states_flattened * random_factors # Compute router logits and probabilities - router_logits = self._sparse_feedforward_gate_dense(hidden_states_flattened) + router_logits = self._sparse_feedforward_gate_dense( + hidden_states_flattened + ) router_probs = ops.softmax(router_logits, axis=-1) # Select top-k experts and their probabilities @@ -150,23 +161,38 @@ def call(self, hidden_states, training=False): top_p = top_p / sum_topk # Normalize top-k probabilities # Create routing weights for all experts - one_hot = ops.one_hot(top_i, self.num_experts) # [tokens, top_k, num_experts] - routing_full = ops.sum(one_hot * top_p[..., None], axis=1) # [tokens, num_experts] - routing_full = ops.transpose(routing_full, (1, 0)) # [num_experts, tokens] + one_hot = ops.one_hot( + top_i, self.num_experts + ) # [tokens, top_k, num_experts] + routing_full = ops.sum( + one_hot * top_p[..., None], axis=1 + ) # [tokens, num_experts] + routing_full = ops.transpose( + routing_full, (1, 0) + ) # [num_experts, tokens] routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) # Compute expert outputs in a batched manner - expert_out = self.expert_bank(hidden_states_flattened) # [num_experts, tokens, hidden_dim] + expert_out = self.expert_bank( + hidden_states_flattened + ) # [num_experts, tokens, hidden_dim] # Weight expert outputs by routing probabilities - weighted_out = expert_out * routing_full[:, :, None] # [num_experts, tokens, hidden_dim] - expert_contribution = ops.sum(weighted_out, axis=0) # [tokens, hidden_dim] + weighted_out = ( + expert_out * routing_full[:, :, None] + ) # [num_experts, tokens, hidden_dim] + expert_contribution = ops.sum( + weighted_out, axis=0 + ) # [tokens, hidden_dim] # Reshape back to original dimensions - out = ops.reshape(expert_contribution, (batch_size, seq_len, self.hidden_dim)) + out = ops.reshape( + expert_contribution, (batch_size, seq_len, self.hidden_dim) + ) return out, router_logits + class MixtralTransformerDecoder(keras.layers.Layer): def __init__( self, diff --git a/keras_hub/src/utils/transformers/convert_mixtral.py b/keras_hub/src/utils/transformers/convert_mixtral.py index 7d289fb2fd..d64ebf5c92 100644 --- a/keras_hub/src/utils/transformers/convert_mixtral.py +++ b/keras_hub/src/utils/transformers/convert_mixtral.py @@ -14,12 +14,12 @@ def convert_backbone_config(transformers_config): "hidden_dim": transformers_config["hidden_size"], "intermediate_dim": transformers_config["intermediate_size"], "num_key_value_heads": transformers_config["num_key_value_heads"], - "num_experts": transformers_config['num_local_experts'], - "top_k": transformers_config['num_experts_per_tok'], + "num_experts": transformers_config["num_local_experts"], + "top_k": transformers_config["num_experts_per_tok"], "rope_max_wavelength": transformers_config["rope_theta"], "layer_norm_epsilon": transformers_config["rms_norm_eps"], "sliding_window": transformers_config["sliding_window"], - "output_router_logits": transformers_config['output_router_logits'], + "output_router_logits": transformers_config["output_router_logits"], } @@ -90,32 +90,35 @@ def 
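# ---------------------------------------------------------------------------
# Editor's note, not part of the patch: a small numpy sketch of the batched
# routing that `MixtralSparseMoeBlock.call` builds above. It checks that the
# dense `[num_experts, tokens]` routing matrix reproduces an explicit
# per-token sum over that token's top-k experts. All sizes below are toy
# values chosen for illustration only.
import numpy as np

tokens, num_experts, top_k, hidden_dim = 5, 4, 2, 8
rng = np.random.default_rng(0)
router_probs = rng.random((tokens, num_experts)).astype("float32")
expert_out = rng.random((num_experts, tokens, hidden_dim)).astype("float32")

# Top-k selection and renormalization, mirroring the layer.
top_i = np.argsort(router_probs, axis=-1)[:, ::-1][:, :top_k]
top_p = np.take_along_axis(router_probs, top_i, axis=-1)
top_p = top_p / top_p.sum(axis=-1, keepdims=True)

# Dense routing weights: [tokens, top_k, num_experts] -> [num_experts, tokens].
one_hot = np.eye(num_experts, dtype="float32")[top_i]
routing_full = (one_hot * top_p[..., None]).sum(axis=1).T

# The batched mixture equals a per-token loop over its selected experts.
batched = (expert_out * routing_full[:, :, None]).sum(axis=0)
looped = np.stack(
    [
        sum(top_p[t, j] * expert_out[top_i[t, j], t] for j in range(top_k))
        for t in range(tokens)
    ]
)
assert np.allclose(batched, looped, atol=1e-6)
# ---------------------------------------------------------------------------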
transpose_and_reshape(x, shape): w1 = loader.get_tensor( f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w1.weight" ) - w1_transposed = np.transpose(w1, axes=(1, 0)) # [hidden_dim, intermediate_dim] + w1_transposed = np.transpose(w1, axes=(1, 0)) gate_weights_list.append(w1_transposed) - # Load w3 (intermediate dense) for each expert w3 = loader.get_tensor( f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w3.weight" ) - w3_transposed = np.transpose(w3, axes=(1, 0)) # [hidden_dim, intermediate_dim] + w3_transposed = np.transpose(w3, axes=(1, 0)) intermediate_weights_list.append(w3_transposed) - # Load w2 (output dense) for each expert w2 = loader.get_tensor( f"model.layers.{i}.block_sparse_moe.experts.{expert_idx}.w2.weight" ) - w2_transposed = np.transpose(w2, axes=(1, 0)) # [intermediate_dim, hidden_dim] + w2_transposed = np.transpose(w2, axes=(1, 0)) output_weights_list.append(w2_transposed) - # Stack the lists to create batched weights - gate_batched = np.stack(gate_weights_list, axis=0) # [num_experts, hidden_dim, intermediate_dim] - intermediate_batched = np.stack(intermediate_weights_list, axis=0) # [num_experts, hidden_dim, intermediate_dim] - output_batched = np.stack(output_weights_list, axis=0) # [num_experts, intermediate_dim, hidden_dim] + gate_batched = np.stack(gate_weights_list, axis=0) + intermediate_batched = np.stack(intermediate_weights_list, axis=0) + output_batched = np.stack(output_weights_list, axis=0) # Assign batched weights to expert_bank - decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_gate_dense.assign(gate_batched) - decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_intermediate_dense.assign(intermediate_batched) - decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_output_dense.assign(output_batched) + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_gate_dense.assign( + gate_batched + ) + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_intermediate_dense.assign( + intermediate_batched + ) + decoder_layer._sparse_moe_block.expert_bank._expert_feedforward_output_dense.assign( + output_batched + ) # Feedforward layernorm loader.port_weight( @@ -131,5 +134,6 @@ def transpose_and_reshape(x, shape): return backbone + def convert_tokenizer(cls, preset, **kwargs): return cls(get_file(preset, "tokenizer.model"), **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index 8a98fc0c83..a9c0942aa1 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -3,7 +3,7 @@ from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup -from keras_hub.src.utils.transformers import convert_albert, convert_mixtral +from keras_hub.src.utils.transformers import convert_albert from keras_hub.src.utils.transformers import convert_bart from keras_hub.src.utils.transformers import convert_bert from keras_hub.src.utils.transformers import convert_distilbert @@ -11,6 +11,7 @@ from keras_hub.src.utils.transformers import convert_gpt2 from keras_hub.src.utils.transformers import convert_llama3 from keras_hub.src.utils.transformers import convert_mistral +from keras_hub.src.utils.transformers import convert_mixtral from keras_hub.src.utils.transformers import convert_pali_gemma from keras_hub.src.utils.transformers import convert_qwen from 
keras_hub.src.utils.transformers import convert_vit diff --git a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py index eaa2c4271f..79de34eb48 100644 --- a/tools/checkpoint_conversion/convert_mixtral_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mixtral_checkpoints.py @@ -19,10 +19,7 @@ import keras_hub # noqa: E402 - -PRESET_MAP = { - "mixtral_8_7b_en":"mistralai/Mixtral-8x7B-v0.1" -} +PRESET_MAP = {"mixtral_8_7b_en": "mistralai/Mixtral-8x7B-v0.1"} FLAGS = flags.FLAGS flags.DEFINE_string( @@ -39,6 +36,7 @@ def compute_hf_output(hf_model, hf_model_tokenizer): return hf_output_logits + def compute_keras_output(keras_hub_model, keras_hub_tokenizer): keras_hub_preprocessor = keras_hub.models.MixtralCausalLMPreprocessor( keras_hub_tokenizer @@ -56,7 +54,6 @@ def compute_keras_output(keras_hub_model, keras_hub_tokenizer): return keras_hub_output_logits - def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): hf_output = hf_tokenizer(["What is Keras?"], return_tensors="pt") hf_output = hf_output["input_ids"].detach().cpu().numpy() @@ -73,15 +70,13 @@ def test_tokenizer(keras_hub_tokenizer, hf_tokenizer): def main(_): # === Get the preset name === - # if FLAGS.preset not in PRESET_MAP.keys(): - # raise ValueError( - # f"Invalid preset {FLAGS.preset}. Must be one " - # f"of {','.join(PRESET_MAP.keys())}" - # ) - # preset = FLAGS.preset - # hf_preset = PRESET_MAP[preset] - preset = "mixtral_8_7b_en" - hf_preset = "mistralai/Mixtral-8x7B-v0.1" + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] # === Load the Huggingface model === hf_model = AutoModelForCausalLM.from_pretrained( @@ -91,31 +86,30 @@ def main(_): hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset, return_tensors="pt") hf_model.eval() print("\n-> Huggingface model and tokenizer loaded") - + keras_hub_tokenizer = keras_hub.models.MixtralTokenizer.from_preset( f"hf://{hf_preset}" ) print("\n-> Keras tokenizer loaded") test_tokenizer(keras_hub_tokenizer, hf_tokenizer) - print(f"\n -> Keras tokenizer test successful") + print("\n -> Keras tokenizer test successful") hf_params = hf_model.num_parameters() hf_output_logits = compute_hf_output(hf_model, hf_tokenizer) - print(f"\n -> Computed HF outputs successfully") + print("\n -> Computed HF outputs successfully") del hf_model, hf_tokenizer keras_hub_model = keras_hub.models.MixtralBackbone.from_preset( f"hf://{hf_preset}" ) print("\n-> Keras model loaded") - + keras_hub_params = keras_hub_model.count_params() assert keras_hub_params == hf_params keras_hub_output_logits = compute_keras_output( - keras_hub_model, - keras_hub_tokenizer + keras_hub_model, keras_hub_tokenizer ) try: @@ -127,7 +121,7 @@ def main(_): print(traceback.format_exc()) print(err.args[0]) print("\n") - + print("\n-> Tests passed!") From 0c73de0ff7d29fb25aed39db595457da3f50fce3 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Sun, 13 Apr 2025 16:19:04 +0000 Subject: [PATCH 08/11] update --- keras_hub/src/models/mixtral/README.md | 36 ---- .../src/models/mixtral/mixtral_backbone.py | 4 +- .../models/mixtral/mixtral_backbone_test.py | 71 ++++++ .../mixtral_causal_lm_preprocessor_test.py | 78 +++++++ .../models/mixtral/mixtral_causal_lm_test.py | 202 ++++++++++++++++++ .../src/models/mixtral/mixtral_decoder.py | 134 +++++++++--- .../tests/test_data/mixtral_test_vocab.spm | Bin 0 
-> 237763 bytes 7 files changed, 455 insertions(+), 70 deletions(-) delete mode 100644 keras_hub/src/models/mixtral/README.md create mode 100644 keras_hub/src/models/mixtral/mixtral_backbone_test.py create mode 100644 keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor_test.py create mode 100644 keras_hub/src/models/mixtral/mixtral_causal_lm_test.py create mode 100644 keras_hub/src/tests/test_data/mixtral_test_vocab.spm diff --git a/keras_hub/src/models/mixtral/README.md b/keras_hub/src/models/mixtral/README.md deleted file mode 100644 index 1516ba9226..0000000000 --- a/keras_hub/src/models/mixtral/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Mixtral Model Architecture: - - -``` -MixtralForCausalLM( - (model): MixtralModel( - (embed_tokens): Embedding(32000, 4096) - (layers): ModuleList( - (0-31): 32 x MixtralDecoderLayer( - (self_attn): MixtralAttention( - (q_proj): Linear(in_features=4096, out_features=4096, bias=False) - (k_proj): Linear(in_features=4096, out_features=1024, bias=False) - (v_proj): Linear(in_features=4096, out_features=1024, bias=False) - (o_proj): Linear(in_features=4096, out_features=4096, bias=False) - ) - (block_sparse_moe): MixtralSparseMoeBlock( - (gate): Linear(in_features=4096, out_features=8, bias=False) - (experts): ModuleList( - (0-7): 8 x MixtralBlockSparseTop2MLP( - (w1): Linear(in_features=4096, out_features=14336, bias=False) - (w2): Linear(in_features=14336, out_features=4096, bias=False) - (w3): Linear(in_features=4096, out_features=14336, bias=False) - (act_fn): SiLU() - ) - ) - ) - (input_layernorm): MixtralRMSNorm((4096,), eps=1e-05) - (post_attention_layernorm): MixtralRMSNorm((4096,), eps=1e-05) - ) - ) - (norm): MixtralRMSNorm((4096,), eps=1e-05) - (rotary_emb): MixtralRotaryEmbedding() - ) - (lm_head): Linear(in_features=4096, out_features=32000, bias=False) -) -``` \ No newline at end of file diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index fef053055c..39c68077e0 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -97,8 +97,7 @@ def __init__( intermediate_dim, num_key_value_heads, num_experts, - top_k, - output_router_logits, + top_k=2, router_jitter_noise=0.0, rope_max_wavelength=10000, rope_scaling_factor=1.0, @@ -106,6 +105,7 @@ def __init__( sliding_window=512, dropout=0, dtype=None, + output_router_logits=False, **kwargs, ): # === Layers === diff --git a/keras_hub/src/models/mixtral/mixtral_backbone_test.py b/keras_hub/src/models/mixtral/mixtral_backbone_test.py new file mode 100644 index 0000000000..b5546d8732 --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_backbone_test.py @@ -0,0 +1,71 @@ +import pytest +from keras import ops + +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.tests.test_case import TestCase + + +class MixtralBackboneTest(TestCase): + def setUp(self): + self.init_kwargs = { + "vocabulary_size": 10, + "num_layers": 2, + "num_query_heads": 8, + "num_key_value_heads": 4, + "hidden_dim": 16, + "intermediate_dim": 8, + "num_experts": 2, + "top_k": 2, + "sliding_window": 2, + } + self.input_data = { + "token_ids": ops.ones((2, 5), dtype="int32"), + "padding_mask": ops.ones((2, 5), dtype="int32"), + } + + def test_backbone_basics(self): + self.run_backbone_test( + cls=MixtralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 5, 16), + run_quantization_check=False, + ) + + 
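# ---------------------------------------------------------------------------
# Editor's note, not part of the patch: a minimal usage sketch of the tiny
# configuration exercised by `MixtralBackboneTest` above. It assumes only
# what the test's `init_kwargs` and `input_data` already encode: a forward
# pass returns a (batch, sequence, hidden_dim) = (2, 5, 16) tensor.
from keras import ops

from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone

backbone = MixtralBackbone(
    vocabulary_size=10,
    num_layers=2,
    num_query_heads=8,
    num_key_value_heads=4,
    hidden_dim=16,
    intermediate_dim=8,
    num_experts=2,
    top_k=2,
    sliding_window=2,
)
hidden = backbone(
    {
        "token_ids": ops.ones((2, 5), dtype="int32"),
        "padding_mask": ops.ones((2, 5), dtype="int32"),
    }
)
print(ops.shape(hidden))  # (2, 5, 16)
# ---------------------------------------------------------------------------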
@pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MixtralBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + def test_num_parameters(self): + model = MixtralBackbone(**self.init_kwargs) + # Calculated based on the model architecture: + # - Token embedding: vocabulary_size * hidden_dim + hidden_dim * + # vocabulary_size (tie_weights=False) + # - Transformer layers: 2 * (attention + MoE block + layer norms) + # - Attention: query + key + value + output + # - MoE: experts (gate + intermediate + output) + router + # - Layer norms: hidden_dim each + head_dim = 16 // 8 # hidden_dim / num_query_heads + expected_params = ( + 10 * 16 + + 16 * 10 # Token embedding (embedding + output projection) + + 2 + * ( # Two layers + ( # Attention + 16 * head_dim * 8 # Query + + 16 * head_dim * 4 # Key + + 16 * head_dim * 4 # Value + + 8 * head_dim * 16 # Output + ) + + ( # MoE + 2 * (16 * 8 + 16 * 8 + 8 * 16) + 16 * 2 + ) + + 2 * 16 # Two layer norms (self_attention + feedforward) + ) + + 16 # Final layer norm + ) + self.assertEqual(model.count_params(), expected_params) diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor_test.py b/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor_test.py new file mode 100644 index 0000000000..a9cb8a49d0 --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm_preprocessor_test.py @@ -0,0 +1,78 @@ +import os + +import pytest + +from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( + MixtralCausalLMPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class MixtralCausalLMPreprocessorTest(TestCase): + def setUp(self): + self.tokenizer = MixtralTokenizer( + # Generated using create_mixtral_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "mixtral_test_vocab.spm" + ) + ) + self.init_kwargs = { + "tokenizer": self.tokenizer, + "sequence_length": 8, + } + self.input_data = (["the quick brown fox"],) + + def test_preprocessor_basics(self): + self.run_preprocessor_test( + cls=MixtralCausalLMPreprocessor, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output=( + { + "token_ids": [[1, 3, 8, 4, 6, 2, 0, 0]], + "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]], + }, + [[3, 8, 4, 6, 2, 0, 0, 0]], # Pass through labels. + [[1, 1, 1, 1, 1, 0, 0, 0]], # Pass through sample_weights. 
+ ), + ) + + def test_no_start_end_token(self): + input_data = ["the quick brown fox"] * 4 + + preprocessor = MixtralCausalLMPreprocessor( + **self.init_kwargs, + add_start_token=False, + add_end_token=False, + ) + x, y, sw = preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[3, 8, 4, 6, 0, 0, 0, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4) + self.assertAllEqual(y, [[8, 4, 6, 0, 0, 0, 0, 0]] * 4) + self.assertAllEqual(sw, [[1, 1, 1, 0, 0, 0, 0, 0]] * 4) + + def test_generate_preprocess(self): + input_data = "the quick brown fox" + preprocessor = MixtralCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_preprocess(input_data) + self.assertAllEqual(x["token_ids"], [1, 3, 8, 4, 6, 0, 0, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0]) + + def test_generate_postprocess(self): + input_data = { + "token_ids": [1, 3, 8, 4, 6, 0, 0, 0], + "padding_mask": [1, 1, 1, 1, 1, 0, 0, 0], + } + preprocessor = MixtralCausalLMPreprocessor(**self.init_kwargs) + x = preprocessor.generate_postprocess(input_data) + self.assertAllEqual(x, "the quick brown fox") + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MixtralCausalLMPreprocessor.presets: + self.run_preset_test( + cls=MixtralCausalLMPreprocessor, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py b/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py new file mode 100644 index 0000000000..a711a06b0e --- /dev/null +++ b/keras_hub/src/models/mixtral/mixtral_causal_lm_test.py @@ -0,0 +1,202 @@ +import os +from unittest.mock import patch + +import pytest +from keras import ops + +from keras_hub.src.models.mixtral.mixtral_backbone import MixtralBackbone +from keras_hub.src.models.mixtral.mixtral_causal_lm import MixtralCausalLM +from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import ( + MixtralCausalLMPreprocessor, +) +from keras_hub.src.models.mixtral.mixtral_tokenizer import MixtralTokenizer +from keras_hub.src.tests.test_case import TestCase + + +class MixtralCausalLMTest(TestCase): + def setUp(self): + self.preprocessor = MixtralCausalLMPreprocessor( + MixtralTokenizer( + # Generated using create_mixtral_test_proto.py + proto=os.path.join( + self.get_test_data_dir(), "mixtral_test_vocab.spm" + ) + ), + sequence_length=8, + ) + self.backbone = MixtralBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + num_query_heads=4, + num_key_value_heads=2, + hidden_dim=8, + intermediate_dim=16, + num_experts=2, + ) + self.init_kwargs = { + "preprocessor": self.preprocessor, + "backbone": self.backbone, + } + self.train_data = (["the quick brown fox", "the earth is round"],) + self.input_data = self.preprocessor(*self.train_data)[0] + + def test_causal_lm_basics(self): + self.run_task_test( + cls=MixtralCausalLM, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 8, 10), + ) + + def test_generate(self): + causal_lm = MixtralCausalLM(**self.init_kwargs) + # String input. + prompt = "the quick brown fox" + output = causal_lm.generate(prompt) + self.assertTrue(prompt in output) + # Int tensor input. + prompt_ids = self.preprocessor.generate_preprocess([prompt]) + causal_lm.preprocessor = None + outputs = causal_lm.generate(prompt_ids, stop_token_ids=None) + # Assert prompt is in output in token id space. 
+ self.assertAllEqual( + outputs["token_ids"][:, :5], + prompt_ids["token_ids"][:, :5], + ) + self.assertAllEqual( + outputs["padding_mask"][:, :5], + prompt_ids["padding_mask"][:, :5], + ) + + def test_early_stopping(self): + causal_lm = MixtralCausalLM(**self.init_kwargs) + call_with_cache = causal_lm.call_with_cache + + def wrapper(*args, **kwargs): + """Modify output logits to always favor end_token_id""" + logits, hidden_states, cache = call_with_cache(*args, **kwargs) + index = self.preprocessor.tokenizer.end_token_id + update = ops.ones_like(logits)[:, :, index] * 1.0e9 + update = ops.expand_dims(update, axis=-1) + logits = ops.slice_update(logits, (0, 0, index), update) + return logits, hidden_states, cache + + with patch.object(causal_lm, "call_with_cache", wraps=wrapper): + prompt = ["the quick brown fox", "the earth"] + output = causal_lm.generate(prompt) + # We should immediately abort and output the prompt. + self.assertEqual(prompt, output) + + def test_generate_compilation(self): + causal_lm = MixtralCausalLM(**self.init_kwargs) + # Assert we do not recompile with successive calls. + causal_lm.generate("the quick brown fox") + first_fn = causal_lm.generate_function + causal_lm.generate("the quick brown fox") + second_fn = causal_lm.generate_function + self.assertEqual(first_fn, second_fn) + # Assert we do recompile after compile is called. + causal_lm.compile(sampler="greedy") + self.assertIsNone(causal_lm.generate_function) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MixtralCausalLM, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + for preset in MixtralCausalLM.presets: + self.run_preset_test( + cls=MixtralCausalLM, + preset=preset, + input_data=self.input_data, + ) + + def test_score_logits(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MixtralCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8, 10) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_loss(self): + # Setup prompts, models, and associated expected shapes. + prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MixtralCausalLM(**self.init_kwargs) + expected_score_shape = (2, 8) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + target_ids = ops.roll(token_ids, shift=-1, axis=1) + + # Get the scores and assert their shape. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="loss", + target_ids=target_ids, + ) + + self.assertEqual(ops.shape(scores), expected_score_shape) + + def test_score_layer_intercept_fn_exfiltration(self): + # Setup prompts, models, and associated expected shapes. 
+ prompts = ["the quick brown fox", "the quick brown fox"] + causal_lm = MixtralCausalLM(**self.init_kwargs) + expected_embedded_shape = (2, 8, 8) + expected_score_shape = (2, 8, 10) + + # Preprocess prompts to get tokenized representations and padding masks. + preprocessed_prompts = causal_lm.preprocessor.generate_preprocess( + prompts + ) + token_ids = preprocessed_prompts["token_ids"] + padding_mask = preprocessed_prompts["padding_mask"] + + # Setup a custom intercept function that extracts the embeddings to a + # a variable from the embeddings layer and otherwise asserts on shapes. + embedded_prompts = None + + def layer_intercept_fn_for_testing(x, i): + if i == -1: + nonlocal embedded_prompts + embedded_prompts = x + else: + nonlocal expected_embedded_shape + self.assertEqual(ops.shape(x), expected_embedded_shape) + return x + + # Get the scores. + scores = causal_lm.score( + token_ids=token_ids, + padding_mask=padding_mask, + scoring_mode="logits", + layer_intercept_fn=layer_intercept_fn_for_testing, + ) + + # Assert shapes for info exfiltrated into the parent context. + self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape) + self.assertEqual(ops.shape(scores), expected_score_shape) diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 98681ca0fd..7d0a8b8201 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -16,6 +16,77 @@ from keras_hub.src.utils.keras_utils import clone_initializer +def compute_load_balancing_loss( + router_logits, num_experts, top_k, attention_mask=None +): + """ + Compute the load balancing auxiliary loss for a single MoE layer. + Args: + router_logits: Tensor of shape (batch_size * seq_len, num_experts). + num_experts: Integer, total number of experts. + top_k: Integer, number of experts to select per token. + attention_mask: Tensor of shape (batch_size, seq_len), optional mask + for padding. + Returns: + Scalar tensor representing the auxiliary loss. 
+ """ + # Compute routing probabilities + routing_weights = ops.softmax( + router_logits, axis=-1 + ) # Shape: (batch_size * seq_len, num_experts) + + # Get top-k experts + _, selected_experts = ops.top_k( + routing_weights, k=top_k + ) # Shape: (batch_size * seq_len, top_k) + + # Create one-hot encoding for selected experts + expert_mask = ops.one_hot( + selected_experts, num_experts + ) # Shape: (batch_size * seq_len, top_k, num_experts) + + if attention_mask is not None: + # Flatten attention_mask to match router_logits + batch_size, seq_len = ops.shape(attention_mask) + flat_mask = ops.reshape( + attention_mask, (-1,) + ) # Shape: (batch_size * seq_len,) + # Expand mask for broadcasting + expert_attention_mask = ops.expand_dims( + flat_mask, axis=-1 + ) # Shape: (batch_size * seq_len, 1) + expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32") + + # Compute masked means + tokens_per_expert = ops.sum( + expert_mask * expert_attention_mask[:, None, :], axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.sum( + routing_weights * expert_attention_mask, axis=0 + ) / ops.maximum( + ops.sum(expert_attention_mask, axis=0), 1e-9 + ) # Shape: (num_experts,) + else: + # Unmasked means + tokens_per_expert = ops.mean( + expert_mask, axis=0 + ) # Shape: (top_k, num_experts) + router_prob_per_expert = ops.mean( + routing_weights, axis=0 + ) # Shape: (num_experts,) + + # Average over top_k dimension if necessary + tokens_per_expert = ops.mean( + tokens_per_expert, axis=0 + ) # Shape: (num_experts,) + + # Compute the loss + overall_loss = ops.sum(tokens_per_expert * router_prob_per_expert) + return overall_loss * num_experts + + class MixtralMoeExperts(keras.layers.Layer): """Batched feed-forward experts for Mixtral (pure keras.ops).""" @@ -42,6 +113,7 @@ def build(self, _): shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), initializer=self.kernel_initializer, trainable=True, + dtype=self.variable_dtype, name="expert_feedforward_gate_dense", ) # Weight for intermediate dense layer: @@ -50,6 +122,7 @@ def build(self, _): shape=(self.num_experts, self.hidden_dim, self.intermediate_dim), initializer=self.kernel_initializer, trainable=True, + dtype=self.variable_dtype, name="expert_feedforward_intermediate_dense", ) # Weight for output dense layer: @@ -96,8 +169,8 @@ def __init__( hidden_dim, intermediate_dim, num_experts, - top_k, - router_jitter_noise, + top_k=2, + router_jitter_noise=0.0, layer_norm_epsilon=1e-5, kernel_initializer="glorot_uniform", **kwargs, @@ -129,11 +202,12 @@ def build(self, decoder_sequence_shape): intermediate_dim=self.intermediate_dim, kernel_initializer=self.kernel_initializer, name="experts", + dtype=self.dtype_policy, ) self.expert_bank.build(decoder_sequence_shape) self.built = True - def call(self, hidden_states, training=False): + def call(self, hidden_states, attention_mask=None, training=False): batch_size, seq_len, _ = ops.shape(hidden_states) hidden_states_flattened = ops.reshape( hidden_states, (-1, self.hidden_dim) @@ -155,41 +229,34 @@ def call(self, hidden_states, training=False): ) router_probs = ops.softmax(router_logits, axis=-1) - # Select top-k experts and their probabilities top_p, top_i = ops.top_k(router_probs, k=self.top_k) sum_topk = ops.sum(top_p, axis=-1, keepdims=True) top_p = top_p / sum_topk # Normalize top-k probabilities - # Create routing weights for all experts - one_hot = ops.one_hot( - top_i, self.num_experts - ) # 
[tokens, top_k, num_experts] - routing_full = ops.sum( - one_hot * top_p[..., None], axis=1 - ) # [tokens, num_experts] - routing_full = ops.transpose( - routing_full, (1, 0) - ) # [num_experts, tokens] + one_hot = ops.one_hot(top_i, self.num_experts) + one_hot = ops.cast(one_hot, top_p.dtype) + routing_full = ops.sum(one_hot * top_p[..., None], axis=1) + routing_full = ops.transpose(routing_full, (1, 0)) routing_full = ops.cast(routing_full, hidden_states_flattened.dtype) - # Compute expert outputs in a batched manner - expert_out = self.expert_bank( - hidden_states_flattened - ) # [num_experts, tokens, hidden_dim] + expert_out = self.expert_bank(hidden_states_flattened) - # Weight expert outputs by routing probabilities - weighted_out = ( - expert_out * routing_full[:, :, None] - ) # [num_experts, tokens, hidden_dim] - expert_contribution = ops.sum( - weighted_out, axis=0 - ) # [tokens, hidden_dim] + weighted_out = expert_out * routing_full[:, :, None] + expert_contribution = ops.sum(weighted_out, axis=0) - # Reshape back to original dimensions out = ops.reshape( expert_contribution, (batch_size, seq_len, self.hidden_dim) ) + if training: + aux_loss = compute_load_balancing_loss( + router_logits=router_logits, + num_experts=self.num_experts, + top_k=self.top_k, + attention_mask=attention_mask, + ) + self.add_loss(self.router_aux_loss_coef * aux_loss) + return out, router_logits @@ -200,9 +267,9 @@ def __init__( num_query_heads, num_key_value_heads, num_experts, - top_k, - router_jitter_noise, - output_router_logits, + top_k=2, + router_jitter_noise=0.0, + output_router_logits=False, rope_max_wavelength=10000, rope_scaling_factor=1.0, activation="silu", @@ -271,6 +338,7 @@ def build(self, decoder_sequence_shape): num_experts=self.num_experts, top_k=self.top_k, router_jitter_noise=self.router_jitter_noise, + dtype=self.dtype_policy, ) self._sparse_moe_block.build(decoder_sequence_shape) @@ -320,7 +388,9 @@ def call( residual = x x = self._feedforward_layernorm(x) - x, router_logits = self._sparse_moe_block(x) + x, router_logits = self._sparse_moe_block( + x, attention_mask=self_attention_mask + ) decoder_output = x + residual @@ -332,7 +402,7 @@ def call( if self.output_router_logits: output += (router_logits,) - return output + return output[0] if len(output) == 1 else output def _compute_self_attention_mask( self, @@ -364,7 +434,7 @@ def _compute_self_attention_mask( batch_size, input_length, output_length, cache_update_index ) - # Mistral uses a banded attention mask if sliding window is not None + # Mixtral uses a banded attention mask if sliding window is not None if self.sliding_window is not None: # Below is a workaround for `ops.triu` for Keras 2. # TODO(tirthasheshpatel): Use `ops.triu` once Keras 2 support is diff --git a/keras_hub/src/tests/test_data/mixtral_test_vocab.spm b/keras_hub/src/tests/test_data/mixtral_test_vocab.spm new file mode 100644 index 0000000000000000000000000000000000000000..d753476f535c762c5646937ae2e1c783a225afe2 GIT binary patch literal 237763 zcmZUc4P2Dhdf=a#;X~z1ui+A-h}N*i8f&bvY6XlnUbTiAuc5{o)>vcJ60hMZC0>Kf zz`#Jz5`u&n=_?j689A@|!W+v8HW1XpS4Qp7#>XcYx)lOW)8rQJy|6r2d-R1ZD z|DNZZ_q^x(IbZMl4in-hM5g8aDkmdR_8U?-};* z+L8G9uRHw{5+;np`PaYRoBylCcM~H&{=s9%wqNhx^~fV`g1X^v_P+VZ89N~o=k9&K zA38U0_x7CM_xSbRykEUG=JB&BkLzYk_*g5SYBir}wO?o-{`jYP!t`5A`?%FOb&|t?5s%QfXHJPM3RO? 
z+C4&;aqonQo^g@1+?LMIa~0BWEq7ge^!{&S$2)Ou>-8RI-wCrJ`X%oa7sM&+Ugj@@Qfx^v;Xajbr%Vj5yY# zJG$SMcu=`ze=c%+ITi+8r<{v_Tc&%yz3wElkc|&Z^&17j_At|9i5{*QpZudT6y=r0+@f^v%RgzJSv*m@ zwK@v>O=+SI-6s8`ZxN;cKZa>*hr3_($Z+jt+I;34$eS>b`4;jv+<~5^Sh$#5;`%IkDK5t- zJdU1=pGvGHyIz`_*hqFivLLaA?7$E3L%jcg&lB2zWXnrA(HwO`Qdl}NIV>AP|CvVr zPnWiWUT-BmTi)#Cu!?!Lytbz7f7AX?H|ISwQU5nl|JS$4NYww82j(S)X#Zyswqg6! z#IU2@+JZLcO_IB;XWMwKhxa3C@Dsd%7x8QS25-n)-Q@4_ z8hWs3Yb^YUd<%cWKd|-OM?+EXM?<1-kT0K9ar+bw&ORPK6PdYxCi*^j7;oGk3!f*y zg#0zJ@L1$7+~efe@pP&X?fNh;{Va9~bN5649=nvh9HVeG%ITB-b)5bt`9suVmvLPS*|YoQ z*e zx_&*^Y>egi?-wp)R`lYRK{XKn5N^arpc;zbh5PVHJcxh6t%>H~(G`Eny{?FL5BXZG z>*7mdT{|v{srs#3yfD^tg)?{Rwf6KTaES-BnST@~%aP&`>R~pY0(?1{&tDuUEf7o+; zpIqzTp6mSQbG`Pvz@+eoUi+`fP3(oo`-RQphS0m9a{`G+nDYtCK})*ZhftUYNzM7lmcHuBqmT>Y74Sb^op!%D2iDr6y2P%kDI zAqUx5iX~_+QAW{%gLrRUdIxhmYEX?@968WVjy{C8o`M&WONTAz7R zc$4|9$bNKrV$YRJ6XW&6Le`OC=KltU#p4E=_qfo!hxPPW+B+;PJ86x3;bmd@>C3~4 zaU;XZ!MyMuO8=MY%v*K# z26gsEa+5mz{q;{77tsG%|CDCjmTufe?&x&R`El!?)UCV7-DmZGum^ildQ2U6{@s0C z#zz&#M3vazssD3W-Prcd8KVdEYoaqo>-2M?Ge#TOqccXEGWCDRgOk*;nbt?owqDRP zL}#dWOcVD|c})H&mp}KRQeLeZLoJbAT$5O3-PYnmidj56X8`&GkyM+5RW}yk+ z$42bMFOYRARRxf${cuJJa*&Os_VF(nH%!?|4V%Qdh1`syDXC#AxqV)$aYJgzx0hoL z|8?ZrTyv7-`odKIL`V%A*ncN~e8_#);b(XS7s{iz;C}p>J&+&3DBOmJ@I^d>2k{L& zi6vNz%_zew*oh`o;s^K@UPZra)yEi(>DYubky$=}43FdM__MILa=#Cs##qe29AslY zw&HpG1Qj@eZXCoF%F@L+%X|a*7cxdpz{l|*K8G*kDdb=!cB33kIEG*24ElMdAsB_R zn21T3fgh+7ewQSUWOK5(gOKIo+1P?^;SAs7LxsH#ci_{Qg(tBTn-Rou5w5^E z+=;LBr$2>mN&FxE)LALwyZw^zQF0vjn{g|y;dfda^rAVV-)VE6C;yvv>U!aNUb!Ts zkiV6+ekSY^j2GUtfBbFnrm}CuIAI3icf$P<|H*ADh6Og%8sAEoQ{1n| zc6=VgT>n?@uQFeXop=qG;V4SbgDdb8l;N)ku3I7AJmwi>BJSaKA0EI2T#mns`z>K^ zX1)l2U_Z{Dk4If|3%3v8zjOa3xAl08do}qLviqJ7(EnUQ=Yn6Ne<~daw+eju1IGHP z#<{qa`4I@;^K(3pD=!JJi^K1E;)BIE1ixhdIj&*u$M0*Tr4jK3?Y3~g^%Fh{S5m(cm(lFJS*sq zdm=Nxp8IeOhT-*#W$HUamc8nWdhZO2$!zwVq3L1C$n>z3T;~7jxoPRfzv*EGxspAP z`zrpc$u)ER`#mE)tj$ah>&W%&1>86A-$-r}zA!C4Y^=M#jMAHQ4Bj6x& z3tIWLp`Ez{@$=pvi>9Q9#mJtK9&)gRc`24H@cprzc?DMT%fl+>)mSsZ_eWp*KlK9t z*)c|%sr@H6vTw@J{%dzPlUvxgaxdb)joi+@qh9;pru`>(b!z{Q+dFmgf)MSY+RI*g z&i?3L?LS%0UctSR|9-NHy}I81CfE1vFQ00E({#Ggnb!N0jqFVs##ouwK$9)(t=!xA zx04<06~0|%iS(?Lo@6O|*=*@OS9+7Xee06RzQbgX>HMuK?NPZ2_MH}_%8DC zqV|niJ)R=$RY=4A_&oj%Sy+$1!ZNt{FkFj|VlqCBhw&)ppa9z*#tl;mhi zp(8}q0C5a7F68zZ{B97P6FEw~ztq_4@7ljA@^NFg9~t-kyRq2$aoX9N)57)SIlN+R z#&3+-wo|S=k!GLcHDkk8V}v4X!*=X7hu)1##dU)*;bLv-MdED1XI=XQzlS4ti7Q`6U7x5 zb~XQB^Zy6`Ev~7>i{kl*u(jg<3GOg%tu!!yn!H8apA}{W|IZ6oHk59_``+wc)nR0) zN%DWvFf7y#8XoFiy)4v+S3<+EQ=xIh??Y2~IW%9>6%Jnc+t3nT53RSJ32nFjclz<3 z&=Fn@uZh20{2ji_MV@E(gZl5z0oe11`TNj6ls;;G<(h$E-&co*@^4%kD$He8J~=q- zf93|Yi~Cb(jMBmTH!}Z9TKocgU3(k**S%W@ z|L=%rg!Flwd{Np@6W22Dc(>>3z~@~5Dc5}5wcp_0C`|m6@y{9KpK&*Z#a+ff$iWgU z?YhbM_ok3b@49?$wErn)|C4_rU={P~F|n{_p1D4(#X77XYz_b$uo0Wa)16~8wqWaE zb4%EU?bvbF`e*D!w3mPPV0#y_2YbDKbFA7GEh>aC~W8LPLSx7U)rDc4#a@*sOjnl&TX zi#;gD4(!5CRGy;`LlyQTT1!-lXf09nZ>+z`Yn$YmQS#A#`D3#@I8Yw>wmkJI`RIrI z|AZTbxq;gQWQ)8sP##)^pW!k716`lU|6$jBU0(c{xc-c5_bPHV`qtN_^vG^|Ro~Qo69_&TwG-Gt^%NP{Og{_!s{-fTvZ1YXY>P~s)xV&;w zo+0bl>(9wcz48><#NN#PApaJ!mA$Q={>SwlZ5M^lNRzMQY3aHJ-jmcMI3!JAQ$C1wQWjAMyV?|KEu-N0^_u$5q^)Kw=cP zIDdr4#o5obm*DrV`I_*b=l=)tMe=E3HhRvV@Y`5KSB?U#$2zP<{+Wr^w$hO!n!_$4 zw`TPVTkH$i%)bzuuynMwMZ>Kz8f>jm{WW26vbk<#MT11F?xL&bpT~bCxgx>ZBr=zM z8QIu7HZ&CYW;wncYMBopI(lDj|1^ zdl$Kry_noF%Nc^6VLvK8Qw7SI_o1x$8fTwalP>MMqwg#YZ<97JNV^w&m$z`ObgJU^ z?I0-9e+-Y~N4SIAi?~Z%|BCteh-bPB&*CHQ^=bSI z)?yp7Jl`hI^F2I|{8xsBm&ui$DW0bMr)&Sn#mZK;GL|z_`6rjMFFSHW$W;!PAG;x} z$WZ?0D*yaiUHb>D8Iu?umv(DMCx#oO!6#5KLHQr6{A1H_|Q_i4);<;((YVbQ42j^_O6%zAAld9cu4@#*%AqY(|L 
zK{aY|0F@)1iI2xU8lRA#KZD1x0>8#Z z^5l0gO1``eU&j-uMI$bfA78^w^5-Y8M&8=<2l?E!9qP*ekVltC`4r#6HhfrK{UolC zKR=95Vh_%u7sKV*u^1+H*(_>hNtG!!K zB;~IsumV+x{^gQ|=uFwg$ZnQqSi-y%%lPFU9}$*IzZK+4{f<1W!fLF^H%B;bgtJsz{j1mhwaNcv1$*E5V1@EOs+~Jsb4LC@uI|)ltwVJ7RRbE)bWB}Y zBL8PPL!(T&_08KxUmDtLqVsbS!VA*iMZAT6(&%zr8s!800Dr;_zT@X{w>v0LXKPX+%gL62IKjL>7eyKJJe`7A{ydeCA+)ktE7o9XGPDcL= zn903en|u}7#NLd9DcWY#opfIB80Q6!>lZ%GznZ<~u=9M)`S9B8b_bRUV=Ro6??H`G<{RT$HLAD_d5HWqN3RI;9hYrE-^RXc0u?w^O7=i^jKDC z{lO`7`<^?R<6qhP&K&z>=Om9!$gMuuwk!y)tmy#dIpT z4p(CwK7!$RSh(!17lzk`JBws)SKy0y3{M~j>rsf^a3)>+7)~sEFm{f-cCfJ_ZpFR$ zZ+IAA!XlKxezy2={0hIp@DI2Kx8p(l2TUcz{Zc%B7pCD~@gx>t2dMDlZ7_8cKaG2b z(5GQ4W?(kHhHv3nY(f!EeBr@Z3Aqnds6#W_aSYpy%eQqJ=O3CG+dAAlVzPO}k;4Ew5yn4+kKmK|48DL&?7%_%JAQ;0@Fp(w zT^t(|pNPBBVP8i30(&wt?8TU8kH#E(HKIKmv+Uhqj`nagOrs~5VQ+^$8?`vV9Gz)b z&97>Vy&$8F^T*m7GR_{6RC`5I-Z@jRlwP7wFEHMol5~243G@QO?sQF`o`8E?m@L;U zA{U>#C}h)NzEX6Y9VmVe|CGxOJ_|=|$4VmxywdA^ybSY#3`v!8Odu}2N**BA0 z_-`eP*te0}d+933V)mWnF8;g8681ghUe`zOP{zKGEazWARHu;H`%-dQraFLJ&c1?N z$v=-=#lD(c4L_Si(R*YlOm z@Q)7%&Gh}p+%s3?J9|*ww!bJfR0&gEW~_(<=ghIUssH%bvcpCu;Cxynr{A-DG8T5wq@gd<6=y9c8G+-_#>i zobds88gVSgYj^`A#W4YQ;eI@B0DmcnOKz`r`@VoWWHZb(6fg^ODT_B$@X~3M;23g}fTNg5z`sv+V=Sq%+VqZ&+a6-L!#W{UJI9 z=FM~Mf7VWKvWBAYEL{To_UQwy0~r{KW?NHHV{HX?Xy1#At@}A?ZAGUw7A4kNq#3K* zQ(ewpdTyX~umi(>`>3l>fl3^hs~<3RpmQzwmT5t4vC)#+u>DAs=f; z+y9UCD8Pm@+W#5ae{Fc-$#>8H$#{4Di*|h5VeS9C0q^ujN}PW%PWxZu-!$6t68GHW z+PxzOgi?KhGJS!4Wcdu|A50k#DlO>hJ7*_F``@MgU!eUTZm$69xYeU!g845rF*hUn zhj0s8nfuNMJg@(c@5=+v%R_zt4WIJ=@bxb6aaH%d@6WsVG*0YO3?fbVy0IXU;-&*$Ec^j-h8*IuvxT6^uiUV9JiB(zF#i0q)>u+_Xr zxk@8zXjjPafmIs#knareU4(w}5c+_89A+Nw2<>tedk-IDPW2FTtI$!#cMFh#eaz7u zpj{*VchJu9@8h@?>A`=qoBKc&_XT8=i~B^G)Bv>ySE>%G_w(KxRPLc%p&WY!lzu?J zoAbrU5-3{1zHFrZo>rQ18+{$~-c}AWcY=OB(#5d@>EwJSGK=GP@4cX7=-&t)n1}Ol zE$t-(9)~A^<%-NaXnV~`6kKnyYn^Ea>s{s9ia zBXAtH5SJU71ziw=+i-gT{sml61VMNars3b=5AZ&GjkfKC0$9s^;&0)v;CCGV9r9*y zKt9yLL3kSc&;u;dWbWsl@^$zCyO6&@F2FU|H^DIc0NqBQAF9?}V9nzKb3GU6?_E&Y z`U@(Bid7f*?%V~|`Y-U!iwmsvUr;mDK`rhL$a?VOuWii*=J75lu;PL|oNGOQfoI_b zd2tKwy1<^V3+ls9FLZ9XpbqHXctKsr@R17|f{|x0Xc!qfdVw{O3+x5BpaDp3y`Thn z8zr6;L^CcZviAae0WK&;yqU8fDGMCuKavx&-~ULi4?j}Qdmkwm@{WDP`2Qmnoc)OL z|3}QVd?fb{#(%rG?vV48!C5H3gXiDFJpYhY`*{8#Yo7oAc>Zy&89yz^)+FtK_ZYm# z_(qzW`ESPj{wBu%$aao9qKyBcE1&T{^wcr_hhFS`(EkDRUoe1u5QbKK#Pk0n4Ks%@ z0ujdfQHVht5}nI?j8nv6exv)7T$6tRRq!YjL)j&Nu3cPvuR$eznQLn$*T8OQp#DoB z38NgBac%uM_E+%p681N68-$OrkHHhf*9$+y?H@V*E@gE)alXSbpApRb5pH+kb}N+P zw#d0Z;D3m7{FGScyBz-xzK`3-kpb3}FW<~_4q6$$P4dhQBSRdAVRR$&udu-M^S>}x z;D9@E55ni*R@e+z!{34Z{WtDYj-%Wc_H&=`|3R+0KPcxg_l?I^=z9EYhnwL|;(QxU zkS6anW_tgOemT5_+ewa3L6W>2MXo_FUQx-+1#V}VmvP*H{2Qbn`B`Ki@+M>y`8nh% zp7-CR9Gbxc8R+~E5P(MhgW}?ow|7n zYTagh8&WRzH#vU}mf#&Yj$WlF7#~A9R6u2Cx%H>j*3Zy2 zJRk2EpnL}SJ`jpWqx5@Nv%uX6**r4~N&9KY&0ze$mht}w{VI6BUxn}W^ZswY-0!fa zF?aF%2lQ#mpq%s9lfTOR3aje;@|ucSTdAdwRCj7n^=AjwfW5Ko;`;~4mc5Msqm2K- zyK!)N|4Z5O9{s?c0c|I~jx^7I`Vw8p?hnZyWt0o0z~^T&`>S~VH=!T0fMfG*<^RF; zKg@+V66X0&{n+2b#7`VQ367ITD4*vkAEcwqXE`C0O^aDgJ}W2cvvS=&D;r#4pH)!h zvkEz%7xG#8bv~<@{(gy{_nZe_o!_b@6GSvyGLuxx2gFK63_Qe z=v~S81##JY-ql8Nk zR}!-S!0(N2;(1M(S;!CILny0cEuIf*w)~WL>K^5rRiFQuW?yc#UjDg@Z*RD)|B3(i zmw1lWzNY^|HvC_&=>jrz<5^vDp2tdMt2NO0d%ktZcWlgYcB_?pEo-H`(DUCpE?Uhv zYRJ;N;ApjUId0~+naX)5s)qOcqP#Ph%}YY>eu44TPc5zI{9l2;=rwIZx_KGn-K$vt zKSY=a>-Sxnd{?gsoEfX-;Jaa2J6Ch#`-F10tX9s()pBiJ%@|>|8V;?d zkH4C}+G_f`tJU&6-}5`l_xxC|ZG?(Ft5tr-YQB-Rn!OUtCmbM;`(giPu6PnfUj(|yj6Tdzizc!vsSC!#oiL6-@RIG5XU`+pJ)@m z#Q>wTdDgr9K12BbV6Ovy2Z?JSxLP64YW45xy?Anvb%-e8s3*SyS~@h2(W5<Y*B{jHr8aJ$73luCsN8TOUc!gQjR~R zTz5)&!7?j9ol?OySP6lVNiZdJV3VpX0>sS5wq 
zb1C-!m0Pusm0NYsms|DG6lK8o-#%piql$0F{Uz)F?=b&y1@j+#AI^K^WAgEF&bBQd zyEy;xn%bFX`TzZW1kU&K&x&%UM+h@5{-QHXu;n&Oba zo`h8NHH|{%-q)1{j)z{C6SA?pj=ZXz=U-JWGH>6j%0KX`3XZ(aIQVt;KfKQVhu7Kv z@H*q*SK0sgD*GQ_Rr&O*s({MbS5*bo*lVEngIC!z^D6JYzsflHRW;7NuBOAx|LkD? zXAb>O%rgIRhWU>(@|``Swuv+HA3MYQpJ&wm;Tih6XV{N;hTniY!yM=t_24In-(KiD zdWP@!o>2$})}3MecZTub8LsU&cz5_ft zH)`9!81{-c*~|JS?fp&K`k{RtEQxo8VTs z3l2ad;TsOW$^GC>eTU;-+3vD=%f-1{bDw>MRO=uQ2UbNSE%y=9E$>lo8# zG1kwQx|$g4GZ)Z-TRG`^ZH@4+*`Z6edfON&24<}R`3y@7SuAC zAP;JQ>84D63O3URpXT{+4EbxQ`oelu-p;qW@A?$qdH5@pUCTE=?`JIVb=GIu-`4c) z%hdSj$aJ3cK8B=@yF!O))Rww!EdN`zRf3?-Krojqk#MFzw zzLhcFt*N#`FH~Ct$iek74egKc`!X>NPYi41L`)IViJll%3>iNbQ(`uzB<`u{81tXx zVO5RQOuw$l6K8FskHRpE}S99b5PtMvW2 zO18#TT-9K?4~_6{+;J6tFrot7@{f(McRsG%RdMB{kIRM3c1P)B#^u-&S5^iuCGLuA z^qLW+K8&*`VMGaJ{DZh+R~%Oq8F@ac5#%uGgy-TKdMC~pFRlUd6>5sBzbdXi+$HkvHb7USZzq(-nhJX#N~NDuGT&ER%TtBvbN+| zj?OkY*L8@`on*R_ZOVD4Pq|m*S$Rz^E8o+m0{j(LwW(-h9{s*Ns~B1GPMbTBj{-HKC9B5NxMw^;;wXycorj{*jYCRZ` zC*8*W_gCaYw#~K4zoJ6{Wc!W|bzJj`I$a&=I+|^DztE>cdJe_AH_PMO&vu(@+wrTW^Hf65O zx3X5`TMnf2te@v!yIkx2x}9>!UF%m~hM&GsrwWivD`)blm`peB#jE@(S?Oo(%eebh zzRu74Z+_PQ+I0_ps%KBAW{01-4nO_gPSqnDCOX;wazgiVzPZz7wcyse*Ux@(zr4ib z%kZlW>0ei11s?LNy(!=7*y%_Ae5;G}x}%-U|GBIn`Rm=+%l@}~tN;0IE40VY?;rb_ z^YClvg?uY~BHtQDj*zd&!FJaF{Nk3B8K3hj0m&CSl|p_Ux6Eq-%Gwx^V`o6l9RX$U z59l9o&)FN$%lG73-^TtZ6m-7I{C_}2djfKMx>SrTsp?WG&$aJzzPv8Ssvw@qx`3)q z1yp?^pqjY={r&*oRmr#NI|Do$3arNS0W}>cu$pHBYB|)U)}uL=XKO&-=Y#UO3aqxR zxt4!zfW0ySwQmWiBUE5@dJ3#AWcTv{^-%A@WI(sl&N;)o(4d>UI@cwaDBL+m&;SYi7D#*{*i_k?nH0 za;>bR-3;hYs^myFdj;C%<{B+}zg>mT7Fq?Qm(Mkt_fES;VE90;74Gb2?S?)KX${~O zTGOq5V51a^P@a#xgNSZ+tuZ7r@xJUxMh~LDl5Ola!j}K{!5XSo!=_gjv~f^t=zkMnE%eS z^4+b>e|4*HOOaLN&a&L0Ru$u~#MP?O{Y6&Uo+7IpS<%$0$`fAt_N~l+^)PSSs@i8; zRky!K^~i>Ot-Sx+swUi<_qHSg6U#T70aiUe7 z89nMkcCYMVF7Q_hR`sa&P^Q&~TmPz7=093Bu&0&%&#f9lh94@ihS#;y|0%K}M_U!0 zD6(Rt7yqD#`R`0CN&Zrl@96tQR%SBO%F6P{5%tI!_E2ta%eB^R{Y>0 zFZ=(zDx%EY2fQk#?_7c`#c$b(UX>#&(!KP5UsKhQUfwNrST(rSc6wEp_Nso)tA-O^ z`oCWG|GBN^sF$%~vDLcHi~hxym-Kvfy=pt;u>8Bc3LNfa|5LHmvEFHQZuF`Pf8AGj z)#E9)f}vuo7uiR?`l;{GF)#a{digCmFZVw$&&pot+A1m$8pm z$)r~)@-^xw&sX%R-0xG_H6>PQm5*NF zhkbH$?G@oZi2K0G*A+tcQ;#D%`!tMucugPAdmnukpJGRxRy4m)ku9&YXVJ$zHg&ku zr^FGTe4O{r_A!R=sdXRM%zHisaP#y0ZbNp@`qXvG$J&`s9TPsT86WpfpMvPulk_P< z8;PcAAMXX(Qy)}udhb-q{K z(8IMH^b}dWt|F^1S!A`o&`aKWljt1sUT8ReL0;x@U{5`U6E) z!{H*U@la4r(7d)+El~PA_m?}&{UylQA*jloy{dxhb-mP8ukv0A%&R0s!o zMptrQE8ZGZ$u+cZ?ps+$u8_UI@8}iE{vctwRz|Iy4`NpCs;G4s|F!sk4%vugNS=8& z^N)o{rYkd3$X_xy`3mwckRPK@UW5E=B-4ADpGM|F)H=8o@SgBxyB@KyL4JPf7G zv&<3K6<0Agx`p-I>lpj|n0E`sZv$V;Jo{G0Q8yFx@0drvficWB_DSIW?d!R>X!`HyE=d_%`UKXf&l!eUc z>{agb+{civz1+8sv2SnZIQOyB^vStzd3bLDig0(oz#K%hPtD;zHM#rL$bF?@oztqH zqt6=ZQ*FA$s@dUFwX086tDIIP{wgw@R{8sV%==_orNKUxeAuVrV|{W@_c1Tj$2T|o zRNyMH@{iKqPn}WjH1`Ago~{pl%0@cZald%~jIxlKw24tR7^FHssifzVO1zM=;^9v! zmLIjESy3y(Jk7}A9t|h?O^5tXY6!oB_#Hro?zmF@LFVvgqgL-R_C>Mbt_Ru8M!c@| zJv=k6RL2^A8zGN+x+}}_ue_4IC08=n+9&TjPRsM5(`rR#u8LS$yVlCVoT>A}q?P@C zhg?^rteh=tmHQrj6c$DEp@4bR!V^iWC?jIIccrZ2&{~y1DRZx7p`=w#Jyx7u%kz(U zUlu#7_eZRnVA85Rn6m21l2-kph}D2>^jyL^@1<%Eb*m-+Qs(rS_kI3Sd6B-Vl+}jx zAG%b5rnSuVtyM>IE&boMe4n3r;Dg=jVIUq1rmS9n%Id3ItA2OX3ca&d1L?II{BSM( zKl=Z3Yc>1<^S~KND?+_QAvPDW;%vf7%rfu0GHIpOcWCs(l$E*bQf1LUb&wY)WN*Ag zuJZ{iXI+%vzZkXhIxkf|6s)~Og|i8(Xicx&{!7{W#y$w#OZP^tvT(vG-#cnmL=#r! 
z{-{;enXsz!nFkN9<@wj8y0W#ZKXi#2SY&G?T+`uEtC{1LthH*TEIduf-w+X0CJ8>PU`SomrQvD;=@A@!NynAhNg2&-&k`tp8ua9=L=xFvtBFhPWQW z?_8?k_b$=M)`S&V*{NuH)QX+DRB^^RiDxfW5>nHbXmmxw%3KlS_dgPrBkOWGA^RHM z<(M0#{)5c_T~51U-$&?j6>N-Ig=a^p|Agg!Vbm(#8M8`|rDd$|gfCNhR;MbWm#K0; z`#yq~QUB}%xg%lKKFj(?-DS)PUC#G^FK6G!<@EnAS2OXpggTj*;QKF43CmlRuza1D ztL<#u^0Uz`5WZaPEOd9|2h>@1Idh-9yRm1~>Z#-T0==P_)yIN(KWiJe{x-Y^X~@iv200tEh(+Nvq|(l-0N*ZZ&Z| z)bD4XCJXSjTY6N-Vth4wCTdQktV+*is)F*2xK*($X_c-*Z`y0|3zw+`ijMHRX!O+! z!t|35B(1z>SvTT3%{dgeav#1-*?W5A0{c5xbzFD*+nIN~jJDsSQAl*ME<;<1F&>GN zPJ}XzQjWv##H(=VYH3?W*6yy3ey6jQx3!GJP3s|1t$nUB(_s)`_m6 z+<16(P9&`W51#h*ll#$NN`X6C>*D~0_^ zXgGS6eu1nj+pOqs*)R3NR{cB2P1pR2{A}hsTzn^r=YwbcW_=vlp0QbNo~!gp>|Oa+ zu}_zEu;^9lOkc$wfX(U)UZoz~gD`;GSMb+={wl5m(%-jPBhPQv(ALeY^KDk_{j2Er zU!||%C%TqA?b)iIU{4&_svKnInk~wz+C=|glboJST>o3-f*kC*(M^p1x6uFHqH6pW zuGqpjY771UEh=_x;{DGpDm}PGWywt{$4|uvn^fuFq^e_^R6VhY`L9i?J+X=Fe-q#T z+oFbNx2TzTyzoCLFCWk4VeFnOHpzReX|iT0xF`n?9$@ zNWXtiCH?MI8YWn}x^#7~r|5wxhucqBq)Bms5 z1m9il+;_FQHeSu1lB?ANLFn~wV1Dsx{WWQL!W-nx-=Iy{+g4qzPa|)F_PP!Fdt@}a zL0>|~_g}5=A)kV;!8f3P&(+M;ZD4)=Y7KKd)$~iQR)Obg z6+#iXp?J#%#@AOXGkCSKc5IMi74Oup*r0zQ4p;uw%+GF6?z#<(uQxDPxk2@3uV#F8 zwHlwlPEA>_tNGA%%&%O>{M^;-$=RTC(x`Z71N&1q@cT3yRQ=cn*5Wp>;JJ(S?_F~2 z-7P0%KeSse$ibcqc}I84L7F-66DWoP@WLOsCQH}v=KYu5>bUC)rEq%@UV=Aa1^ryF zSi^gi?|gli_dK7J0?c{0|A4*-?}qz-`U$Q0SFR;kgWt7q1+0Sg&`!JWhAr5)!YA;z z7r74}f=f9*hB?W{G|;=gVs>)`RF^f?V$hj z&M5^Ro>KdPDRrDZrB34R+C%?m&zO4dSgGLADfa)f4qkN$-#_E|4+DHKcrbd2hB_}% z7=|C>`M>WH_Bvg{y7wi_e_o{^X(XdGt^Jl6TI~A4;I~$S)Z?qd&S-9AfYPQQiT7YV0*o`w;H{ zK>Zfp0buWd+56x09^b3rJ%AQyUBNp5-~}JFfgb|U4js@5UGS&h?_SB?|26b|5B^;J zq!ogJRlNTWL+j`(&!#oZ`v)WViM*d?&lm5%Lmd9!@4vVBcdHqis&=ap8VZ)Y|5SBwH+`Djs+?P7|ND8Bhq>nW?N;d>^E$zGGz({84o0}9uHt%&BQN2a z{1~i;-}2nM4)?_!j9Cu;URkSlbHVS{=Q%fu-~F81jGy>pyA^}zk=@Mi?bgV_c@2Bs z<(|G}C#W zw*s_h|Nh;~E$xgI@}A#N$W}E z58$Wp44lNR5_u1Pe~z>=xW6EOgZmkH9lGdG^umAR+`n@!PC4K9S?&QGUk9It#|VE8 zzhA_D6MPFY^VtUqjym=LKsI(4sf)FQ>i88+b z04W%SS~u@{(4OjPD-FoT^|Vi9Gsi8+(!=N3yTP~P?%*485A*G~5Z{ms^X)kHKhz-e z)~;cFWDVn^HM}FVMnxHX^8t#dkE>+I8g9hAUjYtqLN@1IkOR38IZyr|_5t~W1okAP z){#Hxe`r#n&g1m+&<{O_kl}auz6Ek*mixnz)9QSW@`3JMlj`AIkTU3{jQWtRgz-#p ze>ukeW#y#W2;+yqH02AKT%S*KT|B|H@MU-oegyvvKZ8Q}0p}}_zvbF#K_2DUhx`wY z`;c+?CHxESSAY|K%W)p^cgPFyAMgqKd!K^Ma6N1XH-4@}-VI-c|C2bsg)GPYHssHs z274nojU7thclc{Sc0)T1!t8n4G17zI?j6DW7W4{xAPh58+ zFCZU5J`T^p&*0BKOPa`0{0}0>k=MX+?9<2@I0=9LS@eUm*x!J+VG;gZO_khj5oa5KlB;oMgE9QMpLzfqQ1?`L166Pe9EN!JnD|MRqe z+GSqGZrd~09z3i1 z_h|pnc!Ks1_IK(}(f+s6{&&#+cd<^7Y~$E}fPM4i$LyQ$ILtRoj8|o}zu7rF|irnGb1!MreY1 zXnC+NqrQGtF>R>sfihma3QKoVk|+y@}??3AL&VK;ivj)#z87x%d!_W_Rk z_r0l*pLULpeb7T255B;C0G+zgtz!%O1(5;zR_*IA{;nd@%drpHx{iL=+Bfa?zm03* z33F{Ae}r_x=dtgB7VKX_wj&?-SLSKCW=64(!JXJAksEpL-;JMp;hXsRHvAU%KVzQt zZv0$%6Kw!K1)pZ@^tZ^(a5Y>H2MD(vc_-WrH^W!q5x5N=gm1#P;2ZdT6ut#tCGMZ$ z=7d7{F>WQuL&*C$_aHol{WX#@N;}O@<-r=T=*GOKn-+p-iLgSbi0v#9RCRWIqV_o|` z??i8oo!F<4nd^A|gX0?7KV*OSpK{Tba%fArw57Z&c#mW~&wpegZLjDE?f-e&KeB}5 z(p6(BgYrEesDgMZSJD2VdMD3+sKs6f^~BLY9F1#e|L@cOKjitpb&UQ$?H{~^^+6kU zKLiNd4jtG#p^J0fV83U>aqm&uKlEb{!2st5VF-H|hIg|68zR`F5aV1N64;ZFDjU;i z)fj!=^}LU~p7)X0vmfI{#w%y|w#$0fywlLMa$C}K2;OFe)eO=)L>orWc5oG^I>c7yjH$d}{+) z!g1;Q?5q8dJ-EmUjw^AmYWlsZ+4NuI`n_s{{7y-j-vLSTUK6sBb4{K6J`~sCF|NOF zB6GPGpFsWyerArL5Vu!3KZ1P>N;s~7N$juVUW5JD$T?Vm29E!nZzg;KE{9LSySS5e zzTw92XFv?d2EPH#?`J>?v_TtQ&H^W7g9~yX7xD}*XWFLmlXuL4SpbDlVX)6u*fCVt zXDe~7v|TH0*Gk*99oG&R)CO~Wj8tZcW46dJTZZX2*tTM9#kQ>kTZwHe#a3$DjNh`^ ze#~-hIOnJ{M!hy_BmcK*lQ!!rZPBN7H9wT~nT%`T24mQ^ZN|3U7&jVYhcWoT_GdC~ zHpVT+*ok4hZ8M>aSL4%oH2#b?{YAD)9B zz!CT%{0M#wN8u+hmz`;w4%>9vrpq>SY%|w3^KG-xHj8Yt*fvXSv(z@rY_lA5K4RjW 
zkC@ozBPQ&8w8@w;6LvmUY0S9Kn29XgG-1=OZrcnRvtY=WZgV#6HvZG4Zey15+h?Ci zH}DAhOuDhfn85(%LXF>;jd`|dQe0^4FlK9)G5bwyi&-ZA#r&u-8%*kpZ7E~=P2Lv$ zCf%i6<9{i4(3k@rV-A`UIG<~B_^<``8Mokci|ura?X*)FJ8{}+JMFZcP*E8eMll4LwUss| znZYD|Op?ST6-*+=BrZD|TOkwGHf_6BXu9nN{*5rFx7mfm7uvSjzPH&$vJJ0xacr{- zV%v6YvkPIHT?pH5)Hb^Sw%G-+ZHKnm%5S?#+iaD$-K=f4qT6oKHe0=IJGISLY8x3N zHcUd<>6`4?schHwYeji)zkz>qYI_F%wq={PbM{7U-y!;m+iw=Vr0qMkooH~hBQqwB zws#Yyozok2Z6_gi$! zwW8wfg1N;w+jYNFJFgX$Zx_=}<80Ue-|O%HewRMSPKjOGW&U?*w{F#~8+Gfg`hxao z&jIauJFn`nzNjxgtJ`$j?Yiyh>vS7#Pdu*Mb^Dif`!{s^Pjvf45>mQDU(%QMd|h|y z&M)fDyMCfO?>(qHkLoVnb+7JvKzBVZ9QSIk`G+rS?*Z+7KzqNgFY9jI{YBl4xli}# zo-ga3`)}7h-_t!Wj%c6ueNp>v)4toaj|lb=>%OmP-viqBp!Usl>MQ!nm-UtJ=_^m` ztGZYB-lco*-3Jfo-mmH2`|l6|_Umi<+H<;3_ua1h_US$>Kh*uY|BicfKXKl_PY3jX z9@xvj4Ti1zgL?2bJ@^eh_#J&+U*F6BUOl95=$rZnJ*Jbhf)gy=b-v-pVXs=`QHY({zQ*b zn2+hP2lUwYAJ;$WpUnRucv41S9@gV}oG8Dm@1omx59*K(nSc164n3tqKh&Wg>w9`a zPwdeX==TKPo_Ih{Jg6s3R;a%xFc0&;4YuovALt26@<~1UJw16?Pd=q5pMG3V{#Z{^ z{)dGxJRRPn!w=~2gF5^j9e!Mg4?U^F{BzozHy)nm@Hri(8Znd>N!2FrytNWdd3voGkf&Rm-GzB59*mm^vrX5R^Qk64?b$Jd>U)hm|tV<8tc$lr^dQ8)~&G~jRiHp+@YHUbjVT}!IEUK}X#^M@FXe_C*QH^J6JWJyajXO1-t#Oyeb2Of-@jQ*^YrH_? zg&Hr?xLf1J8ZXg!xyCCrUa9dajaO^DM&q>_uhV$F#v3%=sPQI^H*35_xgQVPcwFNN zjVHw>{t2fhvNhq-M2;qMHIb)@LQNED!mWv7O_XS&R1;;IsL({ECaN^S8exqlYBf=( ziF!>mXrfUQO`2%dM2jX`HQ~{OR}(%>v}wYxi4IM4YNAUM-J0mpL{JmGn&{I+za~PO z7|_I^CWbT-7H=p_jA$aFiKuuJVIr=Hgt$r3tbS3fUzF+>t@?#Wzv$L4`t*x_O}aH% ztjQ8hmTIz0ljWMM&}5}1t29}y$r?@8YO+p~^_pzZWTPgVHQA!cR!w>|>D8o9lL1Y( zYqCR=oto^@WVa@JG#S)nuO|C6NfJYv3~Mr?$*3k{nv82Qp~<8sQ<@yrRHmk~H098g zQ&ZWRa%n0@Q@NVT(^S5u3N%%ysUl6eHC3#s5>1tAs!UVmnyS!LrKYMhRjsKSP1S0u zPE+-oYS2`(rdl-R)l@)J?V9S)RHvr8G}W!C9!&)`)vKvKP4#Okq^SW-4Qgsc)1{g& z({#C}D>YrE={ij}Xu46;Et>Xd+NWv1rUPRC|8%FOyENUc>7dvKG~KW1kfsMTJ*ep+ zO@}od(R5VPalMqKmmGS@sh9HgQh{D7)JsKr$*q@4^-`H$D%VR@dZ|e-wdkd&UP?SA z204sy7$9L(XeLK9d78=BOo3*KG*hgZ63vuqrc5*Cn!!V*W~ww(t(h9l)M}9||Ri*>wF$D8y@re4X?D^9(Vt5=Hjid(M~ z>y}?WIjECEIvLi?zJIfV5>Jq1OK)(@ig*p4x03OvyrXF z@EF6Z*L_B|EgSd^7+bqBI`}uT^P=&3mTiCCfw1ky*2TX$KY%W8c=U!>Z}{{^o8Iv2 zjR5~|v>T(x7(rw78Y5(k0gPFfF>;KNYm7W&oDnenDjeL`W+^U4wH0;@zSX` zy7)J-U>ILUaMlRUcbTX$mOTWQJ&-~dDcGKSmp$9*b(ws2ne@8sld~>67{Ww?Fac*> zb_j%t-ejrEWT|V}6}N8wO}K6oTepd=+r-vwymlL}-NtLT368HG6E}u&Gs3veW^XWt zU1a#T)xxlSE+@~?vQHDr*oef|5yL(-DWa{3)d=Hb*^Uh@dQH;3CX`(ab|KhB(5E;0 z`Dgyba8`EC>_*;IiG$$)HIGW7*w~e9*)>XtIW3 zhhA2Sqh&{w9kl(M40C}D8BZAYp^1CQXtOMg50k`@iFC+zU9KVQ%SXgEWMUgK!NdHU z7%(mh+ttJnHZg=L@2Clf#RRg&2D=KlYBz8$d$s_}Y#v$6zUc8xx$$TRveOJNha|aS z{!QkGO=1}4gc0VHNfX0_8#YM~+cw*WIZyK(F`-6`GsZ>x#UtB?@jGH{5&q4IDE}ro z4Eu07vm7rU*_p#`!Wlux=&~JKjDHh0X8dCOQQ+4c86U=P4BKobZ5N~1HfGC=D6x%Y zd#SOTX*h&!lxsGNf1<|-8l%@XNMknZqG96AA_WX=4*rc#$80vGi@+zwvXjZ1IkG*P zAP!@5^8ZG=F-XQ~d^$}sPA*jpLSXz+81E*a(*)#vHupgcpE3N#Af9X!55`4df^uXB zFp+25*Ptzk$Q+rdOo|w`=Zl8zc{#J@#%Etwreb8yyG%-^5g5bxG=ew~%ct!GZM`v; z{h4bY4~RGVmMtc(PX#7gj6Vt!(&UNOYYZYRuwyW> z6OfW5WK_KOI~&(*O0%OnovG7KozB+j9G%Y9={%h-(rLF&7wdGHPS@(RN2l9#I-t{? 
zIvv#MKAjHf^q@}1bULonDV@pGnJk@g=!{D|9eFmM$t6J^k7x3Arch^!#8Z+h{7i-3 zYO8vyErF8&7f1c{@!Qg8JY0LOhChd4Kw3AteG zoEzj^7#TJaKO^;jeR9wz2YqtLe-7#85MK^)kx4bp>r6&!{j-F-w1vq_>15_f-Vu`P8t#N6d|1m zbr>P72z3}Cy$E#}A-MgJ)JN6|lu{!#RgqJI?qqv#(+|0w!L z(Laj*QS^_Ze-!O&=pRG>82ZQ1KZgD>^pBx`4EjG<=?J!9w@N6$EV#?dp5o^kYyqh}mFGmf5d^o*lt96jUc8As1Jdd8{eIQ1N-p5xSWoO+H^&vEpRqkkOzMeETyOQBl|-BRe5qV7`Ym!j@c=$JxBIw*4~^h{B2DRfPt>nJ*oqT?t!j-uZv`i-LB zD7uZJ*C@J-qRS|{jFSIR@;^%cN6G&v`5z^(qvUmzypEF3QSv!TK1a#tDES;EpQGe+ zlst}-pHcENN`6Mk&nWpBB|oF&Wi+9;GfSZYDq(>Ci{gK?Egyv6&J1(b2;1i9)Y~}( z=P~~_kNLNAh$g2V8lYEi=enU7O7wP~1F|6(@}UR_mq)m~HV8mB{};voW?Me!eLIgx z^Ng@<`B{29pD6Q*GM^~(TfqxKz-N9R^g{@SfXwGdAgZ?uGJ)^~gfDPG9uP-C4-me9 z@CAe~B+fz)5O1Lmx_~?u;;#^Ym8eyXpPE9w9Uz`Sj^6HWdYM$h9)31A3mlLK`A`6b zPy}u$h7u@+GAM^i!+q!YAUKlE4u7=8iZ0sHg-w67_5ShfLVm?Xw}q;jKiV35F4|bM z{-dqo;#m`WHPq-_7P-kHCt2hoiyUN;gDlQv)j~ZqLNm027yQr;ozM><7=$4hh7pKD z4C3dqP&g|ANu6`#opVqI4pLyljoEK=&Oyu$%F02^4(~Z<**Rz8oRb1G2i@nKRTs~m za}J(!*5kT-20z`{e9&e%=PWcn9OTCw;Fq#@kW)v%aL&<*BR2Bw=rEk4OdNf7V$=-B zB{jm&e3JCqeK;BBSi|pVN&h3(=a{=K*Y}F408}X=QLh9(_^q5OtRR>l2J8i zIOp{L$s-3B-GXG_j?biHoXovKl!%06SfRvXTdQ;O8lz8p&pZZu4P?dt(eK>vdu)q9mUBV#mOCo`;*B+J9I!N zn50RUY+tNT4vf1ANqJWi;j)baqpFOv$xhF51`(!wXlctao8mA^Q!*w-Y&2IVjnYX| zw9_6!(uSO-Moj(CY@Bg0sn9&iDcy41%WzBJWxV%SFHI#Ll7dG50gG)`+3Ip>>-4F zC;-x8{65c^e4a7(Jmc$m#?$kRpXV7f&of@0XS_Vm7KU zQ*jtonl3{+vs-B%plKeHY1%-VmYL>KOJ{chewfxtyIkOgN@xQ=;D@Q3G*dThe}o{6 zD4m1<9OBKvKYg`yP8C!`4Uk3-;dAiM)J&Qwm~;+da`T}G8lf3lfpdApmFI+9AbegE zkX|17&5J9Yp97?mKLEr-$0}XW06rLmA*BlmTZnrhc_}3BLeeiJ?!qV#PZ7Em5f9s> z(?z6Fg#JYlApRoscV|I1;LlAw?lPbp-1u{oUpM)2_XFvETJ4q`XB^^FC`zP_$@`JQtF}v|e^mvbM-}l`ljmx5ucpkaiKm*hs!5|7|233r4e`|wM=fdAp;I04)=^#! zltBY!*g%?%)LSFRjpVrz9h(c4ZYc)*wUWnH^4&_=dx+0Nd3eZ|mw3F?t(UZXwLn-O zY5RJCIDGi`QCDrG-$vSPluH|VZX-OmgmfEuYUA2!YhMVWMjOTJM-@M+`^l!CqW9y) zPxS3X-%e5;WVeGLow#=fl9Gk{Gq5KS}qOf!f~kD?|6 z#Ple|I~vsjgNTL9MlEE8wcyBtN-fa)S)d=Vz%y`xiRcAaHW*n1#n1{J5Ys{q$BgP0 zxFam&^ud4@X#ERJwl3rjYk}6kkVja`U?HyodSM7SmtP1K(4>WeDhO$TR=!YJ0=O5F zm%;?3v`|EzijrDz6OS9eZqj$hwNUH?@>ATYg%a{v;sWwgicY0oEtH{O8RyDIfOyL5 zf%GcyTS+NOcBBfMVz+!CtDdmrV4M ziC!|s^36h@8wk@!JhZxne(e22THx8g5Ng-L0E!NbYGIH98pQ7)#WzUW%&{*F5oQP_ zhwwi{{=?)wOtpo{XP9)u6c%&t3nQe%^3TEu?jz)dx$y<&vKN@sU0^VHb65(wU|`^-Ovvskkld% z)%Db=*ov)E#@?8k;#k2+)^!asV{P=FXnYX7Z6uo7=|GZqgrGZ zVlf|o`NWlvUuGhBGeC>XG%V5-7MWRCWJY1JkYk#_ViEE1U|Mv?v{;-CgfAxT;zIDl zfEG*0KM$Y9Qrx-p7t8QdM!vZG7t2CgEGNI^oGY)@Vg))@aIPYt#Y)0d;#S!LZ9u*& z(Y=y5D+hrxs3hN&lv5S?ts=cD&R3Imb-5O44vV$uQriT?QRe{CtV7>A%C4UBsYmAq z(&VAUyARL~xYH088?%7C^MG1xLf7 zA9-u*0?LL3Pjme*@}ybxqpP2M@wpY=u>jKWlXpLP36N&M4a6BJ0rC=v0AbsSm#5NV zJ85-LMxErT6Sq$Mb&}ss(&-|67iHQ_KD!CqP2JIM7JKp2hyOmxw2$=riH|4FVu<{Q z$mam%HGtb7a*%utQWk^gNjq4i@3}Zc*dg*fgw8|Mf0(iiQwCxD4HM5WVMnOf5%R&h z(;|J-Mb>H->1!^s_Oi&_?;`WNi_Ft5GIzR|Bwx&#E;2W|$Xw_mbDWDj5*Aa`3v+r) z%sDM(x}XTkVNgr-!Izl5S<32zen`QnmK^y|2tgRo66=;r*__L*1s`-k7jQn8cybe3 z$|HUz@|KvuTgt;9?P4jfMN7=MEivP^M5k|wS+*sv{iOopE+~U02mx^vB(+rNhk%xf za4#aAqG2t$$*a2pI4&l=V*HhmKRRejbkLSc8lfFXgO1r!Dd)>LUsj`~3ev8?t%CR~ z$YW)-ma5vcRO1H9gwE7bE%Db9Z!LP);a*n?#92q2T-!@^-CE-2zEqFB!KtHq2KVNr zVajosa^&8;G(vfg;LbYK65W|4)_Rs$&sbs&V~O>MCDtL9=&md=r@q8o`V#ZtORUW= zG538w(+3fq=l*-1*~0T#1CY>pM+uZcrOq?Ka-La$^NefHJCi!k8urfNL+A_j5Re=hb#J2oGx zRfWtQ>*dcMo&}Hp{7?TrKX!cn^2c}H{qwg!?sVI;^VjFXx4%9Y{dX?^?_ZxQn}2<7 ze*E>h^?fA690WB(m~-TOQ~@%ZoeYv1Sb zpT9o+@96X7yT`wN;?MiUzxU+de|`Fa{nPL7pXdHFeo+5B_s2PZ^!RhT@X!0{`_un! 
zKCk?Bxw`hu{`$QB>hag``v3m=y!F@2&j6qI{+Rdv^WKMl&MrOv7&CvKmK&e*J&%9< zAm!)_ottJy#4h4^7CK*Sl&~9{>z`E_lBR| z34VGf`04lLPw(44{XG;u{b>8?9owgOXrF$M{`BMR=YRaUdN=mze=DEfOMQ9|_353{ zr}sCX-rszB7xU>I%%^uS;=4*cua-J`qub#g>-#p(vKRiD@zdXM^ z6VKGM;92x6d6qpZo>kBP`(v$pHawf2Ezh=R$Fu9%^Xz*LJcphm&#}j!(|@0O&OGOy z3(uwJ%Hz-Vzi&LZo;%OIXXcrE9y~h!`@cO;o@dXC=hgG(dG~1aDQ5il|9bxA`M2jk zo_{?5=lS#3@PD82r2k3(ll~|DPx}4;`gqds|KG=x{wMuU`k(Ya>Gw0*c+&r*|4ILo z{wMuU`u+MaoG%;x?>?UNKk0W8Z9M6J(*LCYN&l06zt@c?{XSuhC;dKU4WF{clYTz~ zjwk(3`k(Ya>3`Dyq~E8o@udGrzfWW1Nx$DS$CG}4o0svV|4F}Jr^b{1C;fgu9sU+L z<4OOM{wMv<`k(bb>wnh&tp8d6v;JrOeg+-S`k(bb>wnholizsO|E&L6|FeF-PmgE) z&-$PBKkI+i|E&L6|FiyQ{m=USx;dWpKkI+i@7K`rtp8d6v;JrO&-$PBKkN7F>v-1h z*Vyr_|5^XD{%8Hq`k(bb>wnh&tp8d6v;JrO&-$PBKkN6oV?67B*6(ME;b-gN6YTKo z_VDZWc+vl&|3&|ce!q?nzmAU={bp6etZMjKeE3;>`2K6W=zr1wqTkOa!>{+_MgNO_ zKdX!v{V)1o^uOqT(f^|VMgNQb7yU2#U-ZA|_p{A-(eKxa@uJ_)Im6F6<3<0A{uli( z`hD*VMU5Gc;cHzv_S0|Em90|EvC2{jd68_4|2m zyy}0||Ek~jrNif;@v8q-|EvC2{jd68^}p(W)&HvBXQuJ0|4qNoPQ%ZQ!?zgYP5+zz zH~nw=-}JxfH>Vrsbi3`GzrvFX9pG(J^ z{x|(^`rq`w>Gv*kyy^Ed>+o~Jc+>x;-|rRUO~21|!{@r;+m-RA|4qMdS%x#m!_O4s zP5+zzH~nw=-}Jxf_p`?EIdHt`f75S=HGDHO-t@oef7Ac2|6TvP{&)TF`h9L3@A}{M zzw3Y3|E~XC|GWNo{qOqU^}p+X*Z;2nUH`lOcm2M%9`E|!^}p+X*Z;2nUH`lOcm41B z{fs`|^}p+X*Z;2nUH`lOcm41B-}S%if7k!6|6TvP{&)TF`rq~YwrTjeYWTf%yz770 z|E~XC|GR$QM2&a-e(oCY`h6}K@A^OVf9U_v|Dpdw|A+n${XRF05B(qdKlJ-~ZG7ne z(Ep+TL;r{V5B(qdKlFd-|Iq)T|3m+W{tx{=gAdE*yN4>(QlJC{OmLS zN54(pu*n-XdBY}e*yIhHyy1Ppu*n;~9UL}!!zORo`@NMI;$s67|44b^+v*57F8#Z~vCU5v$ zIBfEUP2RA{8#Z~vCU4l}4V%32zxr+RhE3kE$s0C#!zORo`u*nt!zORo`u*n-XdBY}e*yIhHyz#&KZSsaq-mu9V zHhIJMH^U}x*yIhHyy2VEVUss}Z94o+Ic)NVP2RA{8#Z~vSE<7$Z}>WO*yIhHyy550 z;X9yVlQ+C?88&&tCU1D}GQ4*gHhJS;`fc)tP2RA{8{WkXo4jF@H*E5TP2RA{8#Z~v zCU1C8Gi>sPP2TV`?XbxkHhJS;`v0Zh_f6wp`v0Zh&%47ronfmtZ1u*!^n1TE{0uyN z-8*dchOOSP)f@lP?=#ZyE^OH94ex-4t={mJ@UYbzwtB-?!^2i@*y;^m5f59v;l0_g z)f>L!8h)-GwtB->Z`kS$TfN~O(y-MVwtBJ3}H;j8Ikt2cbFHf;5Vt=_QJ z8@77GR&UtqjeqO6)f?Vr4O_iot2b=*#=rI3>J3}HVXHTM#XW5GhOOSP)f={Y!&YzD z>J3}H;p^{Vt2b=*hOOSP)f+xLjDPF@xBh?Y_c>zt95Mc_--d5^X*T|&--d7a+I;wY zF>LsT4d1Zg8~@R7!#8}LK5Y1g4d3wjWBB|rZ1{#3Zo`Ie_$)GfXE<#5hR-D9Kl*L? 
zhArQ)-}sMyo4(=m@UZC{ zHhsgUZ`kw=o4#SwH*ETbP2aHT8#aBzrf>LOXV~-&pP`3M->~T$J|7L6zG2fheCIQ4 z`i4#4u<08%eZyy`VbeEk`o=%{|Iz=Ce%rp`Gu5!|8@7GJwr~8S{~!HcD-7Qq4WGA$ z*9*hOZ`k;afArh~@` zHh;tBZ+KNPZ2pGV6~pFl*!&IO$BuvW+x(4x^!x5=c$G11|HeQ1?f=F<`fUJ*@3e;R zw8lUBZ2`wW`v1}YkA5#nhh5;X3mkTV!!B_6?rV5)I_v_6UEr__9KH)1UXu*3Nrumr z!{^Fj7dY$!hh5;X3mpHa-!5?Y&TQBP4zE*&UEr__9KJ&vzQY(kiw?WM;k%6CGwJw$ zpGoZkhh5;X3mkTV!!B^x1rEEwVHY^Ojv01=!!B^x1rD!fh8McSE^zohWY`4`yTD-= zIP3z4??;C3M}}?Sunioxfx|X%cmX`@1BdVO#{cQ}GI;#a|3|;g;PAR<*bI(8`hC|u z>;{M3;P|88c5wXBZ$CJ^5*mN>|Iu$tIBW@rE#dg1|BwDZ`t1sb*G9vxaM%?NyTb8D z{~!H-^#9TSN58G%ur(aEhU1TZd&6OGIP49Fz2UGo9QKC8-f;Nteb^k1Kl=aZ|D*qp z{y+Nd5XT?=c8J5Ps$q*bY!Qbo;_%vP`0jH2(QiL6{^<7|{`m9ful$#5j=s~3m*V-a zzdheP-#tG(KRv%ZzdaMr)U)7O^elOnJu9A7&zfi5v*FqFYR}6=iGDQx%6Clu01!NThE>6-ZS&eJrABo&wqQKJkOpN&#ULn^X~cZ{Ez2< zJ^%9j+w&jKKc4^d=>MYsi~cYAzvy?xCSUaXPB&llf6@O%zwg*$Y8Pw1e9>>wmoNIQ z`tn7;WnaGNx9-aq{T6=tqTk9dU-Vo0<%@o6zkJbe@s}_9eQ%sE`Yr$RMgJH57J&Jp z-wH5a^!sivU-Wy&k}vxG?RfG_Q5{#8#yqb)qV7!aT7yZ_PF?Pro z{k|WJ_cQSwb$myiull{l%vb$i^?%j>RsUE0U-f_0?|ban6~=ej`Kte`{;&Fd&zP_J zeb<<;`oHS;onyZ0w@JxY{a^Ka7Zh8>eAWL||5yEXiTSGktNyS0zv}<0-`p^Eh54%g ztNyS04F+Q{7~{Zv)&EuhSN&i0o7m;6e&4UhZZ5`gF^-Ga!n7qa0E#_^p zzlwQV%-iy}extVht^aTRMr|=_%isFF){Xa9F=LDUS`63nxBkEN`(8DF>o-!1U0N); zVttms^;?z2Tr7X<_gXmKf5pHgR$VdOiosTVCmZvsm`=rO<5)(;h9{O$v4V;fR18vL z`xG0e*gnOeDF#h3bC>s=`&p`-yN>`KF(%3Re}cv%^*8n|`h;TvfQLa8==|!d1n0%HgW=O+Q-| zwkqHBf7Aa>zgvA_t>Sy<@K)ii!dvB=e&#CP4d$DE<|^Oxf7Aa>KYx{P`oHNn%n64T z4l5j1IIP$^<(vL*`gyE;)BjEXH~nl@zUlv_|C|1A`oHP_rvICM?_px-6GNYT*Z*Dr zcl}=7$9u|r*UxpuYy9}GJm2+y*YA~nd~Y7F_48f-cm3b>f7kzAzp+rh>;JBw2MZ4t z9xUJWf7kzA|9Ac0^?%oI8W2t_oLD%q7#hXUD4bY~jbb(sW25+;JC*yZ-O`4VUsmzwuIj=;zDAm*t24 zANo17{LueH{}25?^cy+Fj3Gbtn=-^eJw{J4P!Ee1d^n1S$!}S=h$9v!W z(9f&ohyEY>xwZVz|3kl7M1JW1q5p^eANqgj=h|XG6$7fUZTX?!d?Jin7`K>FgmH`a z7%{NQ5B)|~`Jw-Z{vY~(=>MVrhyEY>f9U_I|EK<+`hV*GssE>b??S@Kg_Db+R(|UL zso(pN7~aRb>v&HR@2}%sNq*|*=<-wlPyIjjGj;Ki?i?<#r+)7k^Hcv%{Xg}ao5VZE*bT<$EZ#llm;PV+4bWnW5}+Wz^#9WDy<~pr z|E2$ze(&xB8U!@RFa5vt|I+_UzwuhU!^|)Jzw~>TnP2)1+VV@k_nP^o|Cjz>`hV$% z74l0zun;@O*aXD*Ep`F13&=11zx12M?H@2p}w5Zi(L((nCM+y;&D zTx<#QOaCwZzx3M`;J9aG$+6H|JMIo|8M=j_1i4uw|={Y{MK*2li&J(>;J9)w|;wu{MP?l|8M=j_5arY zTmNtUzxDst|64zS5$}WJeQHG$|3v>p|3v>pzjx4?=%47H=(oqn zME^wpME^wpM8A<@Ci=av&P4x2|3v>p|3tq5V7W6OZU(mmxe?h;kLKgHd=wHykpnpOC zf_{6REa+d*zo36X|APJn{R{dR^e^bQcg})-n};muH)_p-{ssNUty$1-7MBJ6hOSxE zZ|s^y{fqjIUSsr{Mg5EV7xf#!#{2Fp>R;5qsDDxaqJHzb7{|uGBK8$Al8w1t7WFUc zU(|1Z9mCly>NlQ^yTn=4zo>sv|Dygy{k9od)W4{IQU9WTv%MJGW>Np5{zd(Z`WN*t z>Nmp8qW(qw#<*G3Z}*f%{fqh+_50hI#||pye_7OTteZvs2D@3(zog%AH>QEHmCBO- zCH*#2S<-Jel_mYQQ(4l#q<=~OlKv(AOZu1eFX>;>@8)ur^e^dO(!ZpCN&k}mCH+hK zm-H{`H$}{n{w4iO`j_0i=s`;sO7OZu1fFY8~{Z?qhf$e2WCS^u(rBj(u2#7-v5`fX*h ztbbYmvi@cL%len~FY8~{Z)=}r{mc57_1oQNS^u(rJDV))w>^tdbe8om>tEJyARSZ8 z7)fVYzdcTt^)KsR)^9kSW&O+gm-XA~WLdwxPL}l>Q)gNKvi@cL%len~+i7H3|FZsN z{mc57^)KsR(Z8a9MgNNa75yvv?c%bc-!MD&aaqyN4rE2Yp>|gEujpUVzoLId|BC(< z{Wc+4(Z8a9MgNNa75yvv%}it9ofZ8n`d9R?=(oSiivAV-EBaUTujpUVzoLId|BC(< z{VV!c^sne&(Z8a9MgNNa75yvvSM;ywHzbcCc~R;8rs()4gs{U2|tNK^R;8rs()3#tzuU7uj*gb zzp8&#zX@+v^{?t*)xWBLRsX7f)84G=U(>&)e@*|I{x$t}k6F{drhiSp4P@5zujx0| z&6@r-{cHNy^snh()4!&FP5+wyHT`S)Z6~v)-;6kG`q%WY={M%jn*KHYYx)iQv!;Jd z|C;_a{cHNy^snh()4!&FP5+wyHT`S)*YvOHU(>&)e@*|I{x$t;`q%U`I$6`t>10hm zK#(>4Yx>vquj^mezpmfRIqUk@^{?w+*T1gc{xj?P*Y&UKU)R5`e_j8&{&oH9`q%ZZ z>tENuu76$sy8d-ud-v#x(#|GNHl{ptEMzo0oO{_IX*?zpfvU$h!V@{pfhAAsee=drv6R+FiOnhv#EbmzYTCU^>6Cm)X!aIQ~##^P5lg3HuZ1n-_&m} zoK5|k`Zx7&>fhAAso&&2oBB8PZ|dLFzo~yy|EB&;{hRtX_1ooTQ~##^P5qntH}!Ao 
z-_*aU-)=8k`nU9N=?7@CrGHDmS%0?lZ|UFCZ{L?K{agCC^l$0k(!ZtOt~p!!xAbr6 z-_nl+WJ^C5kS+aN`nU9hIoZ;`rGHERmi{gMTl(>VZ0X<9zomam|CatO{agCC^xGI_ zOaGRB)F4~>xAcQP+0wtIANYwKVz%_72-(uVrGHERmi{gMTl%;4Z|etxvaNqxKemu< z{oDGt^>6Fn*1xTPTmQEHZT;K&xAkx9-`2mae_Q{y{%!r+`nUCO>)+PDt$$npw*GDX z+xoZlZ|k=S46Fn*1xTPTmQEHZT;K&xAkx9-`2mae_Q{y{%!r+`nUCO z>)+PDt$$npw*GDX+xoZlZ|mRDzoUOg|Bn70{X6fhDBtAAHNoRVGrwyfFJZ_k=t{k!_@oCA_&SO2d5UH!ZIclGb; z-_^gXe^>vmew*m*>fhDBtAAJjuKr#9yZU$a@9N*xZ#x}8E%wvd)o(+cUH!ZIclGb; z-_^gXe^>vmeuyT!`dNbP>bJSguKr#9yZY^K!x&^&|E~T${d@ZN^zZ54)4!+RMmKx< zZML(ge^39O{yqJB`uFtj>EF{2>|{^>p8h@kd;0hE@9E#uzo&ms|DJxZCwuz$^zZ54 z)4!*GPye3&J^g$7_w?`S-_vh1o<040`uFtj>1QOer+-iXp8h@kd-|cF?CG~Hj%{&l zi(^}yJ^g$7_w?`SXI#TrWMBWj{(b%X`uFwk>j#gruYX_vzW#mv`}+6w@9W>!zpsB^ z|Gxfx{rmd&_3!K7*T1iSU;n=Tef|6T_x11V-`Bsde_#K;e!J(`J!fD4zW#mv%t!2? zv#)<&|Gxfx{rmb^knHP+qq47mU;n;-J|z44ZKt!ZADYU({(b%X`uFwk>)+SEuOFxi zPm%-u2l@~6ALu{Of1n?>3T%}F{RjFF^dIQA-_L>m1N{g35A+}CKhS@mAEFGPl>_|; z`VaKOS~<{vpr2m}zmfy}09Ow5+j-|e|AGDk{kFvaM1N{g35A+}C2Q+h_ z-#$DC`VaIU=s(bZp#MPsf&K&i2l@~6ALu{Of1v+R|Dpav{fGJw^&jd#)PJbo<~)b` z5A`4FKh%GyA2!RO{zLsX={eL7pXE^hp?(|n9O^&Rf2jXZ|Dpav{fGMP*K?@p#|itp8a5vHoNI$NCwC z9P2;Uf2{vlKeG_#CCpZi^&jg$)_<)3SpTv9WBteakM$qxKh}S&|5*RA{$u^e`niWd z$~o45tp8a5vHoNI$NG=;AL~EXf2{vl|A~H1A}9J!^q=TwCUTOa-bkLFbWss2;_ zr}|IzpXxu=4~pkh|Ec~{{R~P@^|MVm)qkp=amuOwQ~jsOa+gs{d5~ss2;_r}|IzpXxu=f2#jf|Ec~{{ipg*^`GhIUUR1ZO#hjF1~zB< z&-9<^KhuAv|4ctim^1y1Y|iwb=|9tdrvFU;nf^2VXZp|dpXoo-f2RLT|C#rWLbN%P~&-I_{KiAJr=3M`|{&W53`p@;B z>p$0juK!&Bx&Cwg=lVIdoa;Z=f3BZ5%DMh?{pb4GqnzvK*uo&?T>rWLbN%P~&-I_{ zKi7Y*|6KpM{&W53`kAFTE|7Em4h-a6|GEBi{pb1}90(tmbN!55xX&<7IoE%#|6Kot z{tNvM6XZhwh5ifujuhlV|Aqbw{TKQ#^k3+|(0`%-LO*wx3;h@RFZ46hao8Xi`Y-ff z=)chK&_OQrJ9dx@{TKQ#^k3+|(C+|3F7#jMztDf7-$8`%SmCk4tL8%gh5ifu7y2*s zU+BNkf1&?E|Aqbw{TKQ#^k3+|(9gQ&LO<`C3;h@RFZ5sP=LW-h=TiTr{!9Ir`Y-ig z>c7-~ssB>{rT$C(m-;XD^RjWFA(#3u^}E!NOZ^OOF7;pPztn%J|5E>@ezrE3`Y-ig z>c7c7-~ssB>{rT$C(jAXdpTA%v?&*e)0m41#cSNgB?U+KTnf2IFQ|CRnL{a5-q`&{Y2(to9&y~~yU zEB#mcuk^Ecxzf)c=Su&T{ww`g`mgj~>A%u{rTyuk~N+zt+z<=UV@@{%if$`mgn0>%Z2|nC4plwf<}U*ZQyZU+cft zf35#o|F!;W{nz@h^yuk~N+zt(@P|62dG{%if$`mgn0>*o@4 zt^Zp8wf<}U*ZQyZU+cftf35#o|F!;W{TySi^)s=#(eJiFZuH;iztMlA|3?3f{u})_ z`fv2#=)ci_qyI+#js6?`H~QU1$c_FR{Wtn=^xx>e(SM`=Mn4~(8~r!>Z}i{jztMlA z|3?3f{u})_`dv54js6?`H~Me%-{`;5?|xEl^xx>e(SM`=M*ofe8~r!>Z}i{jztMlA z|3<(2+quzyqyI)ftDal^xB74O-|D~Bf2;pi|E>O8{kQsW_225h)qku1R{yR3Tm85C zZ}s2mztw-M|5pF4{#*UG`fv5$>c7>0tN&L2t^Qm6xB74O-|D~B@1jd?_225h)z6&f zR{yR3Tm85CZ}s2mztw-M|5iV5pIiO6`fv5$>c7>`;OAEVt^Qm6xB74O-|4^8f2aRW z|DFCj{dfBB^xx^f(|@P`PXC?$JNA%x|r~gj> zo&G!hclz)2-|4^8f2aRW|DFCj{dfBB^xx^f(|@P`PXC?$JNA%x|r~gj>o&G!hclz)2-|N5Ef3N>u|GoZu{rCFs_228i*MG17UjM!R zd;RzN@Acp7zt?}S|6c#S{(JrR`tSAM>%Z53um4{Ez5aXs_xkVk-|N5Ef3N>u|GoZu z{rCFs_228i*MG17UjM!Rd;RzN@Acp7zt?}S|6c#S{(JrR`tS8K?Qv@%_xkVk-|N5E z&$;Jb|4cvco|*oce)mi=(?8Qc)9<2CX8PU#&P=}>;F;-n2Rt+VGyOCD40LAtXZmOQ zXZmOQXZqbX%S``F|4jc(|4jc(zZ(#7J0&yyGyOCDEOlo3XZl?;%1r-E|4jc(|4jc( zKX09x{+WL6I&PC?rhle?rk`=nOh3z?nSM9R!)C|b^0<(cnf|%{x&FESx&FESx&FES zx&FESx&FC+C-~z8f9Cq<`se!R`seyt>df`K0h77@x&FC+Hv}@*&x2>Kf3AP7f3AP7 zf3AP7-`$wZ_0RRs^}E`Xx&FESx&FESx&FESx&FC+zC3gNbNzGubNzGubNzGu?$2bd z-wm3~_0RRs_0RRs_0RRYZ;%K55BeYUKj?QUDi8V}^grl-(C-F99`rxxf6)J+-|db( z=zq}vp#MR?zdb=7^grl#Z7L7?AM`)yf6)J+-;KLG=yw+<5BeYUKj?qZ|DgXt|AYPq z{SW#d^grl-(Ep(SLBD%CdC>o$|3UwQe)o0qp#MStgZ>Bo5BeYUySo#2ck-bBLH~pP z2mNl3VMSl(pDbzKk9eqCy)9c^*`!=)c>geQU9a5C0$jKm33A|M36e|NT96J^K0o@c-fe!~ci>5C0$jKm33A z|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW z@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K 
z|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<# z;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e z|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe z!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0` z|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+` zhyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=> z{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci> z5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q% z{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@% zAO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk z{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$j zKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8 z{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5 zfB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG z`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A z|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW z@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K z|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<# z;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e z|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe z!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0` z|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+` zhyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=> z{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci> z5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q% z{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@% zAO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk z{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$j zKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8 z{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5 zfB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG z`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A z|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW z@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K z|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<# z;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36e z|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0`|A+q%{~!K8{D1iW@c-fe z!~ci>5C0$jKm33A|M36e|HJ=>{}2Bk{y+SG`2X<#;s3+`hyM@%AO1i5fB66K|Kb0` z|A+q%{~!K8{D1iW@c-fe!~ci>5C0$jKm33A|M36i|I7cE|1bYv{=fWx`Tz3&<^Rk7 zm;W#SU;e-RfBFCN|Kh;FaKZuzx;ps|9;ia|Cj$S|6l&U{D1lX z^8e-k%m0`EFaKZuzx;ps|MLIk|I7cE|1bYv{=fWx`Tz3&<^Rk7_iz3DfBFCN|KR;5qsDDxa zqW(qw{D1lXF6v*@zo>svKmT9;zx;ps|MLIk|I7cE|1bYv{=fWx`Tz3&<^Rk7m;W#S zU;e-RfBFCN|KI^`g z0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l) z8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}t zodKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW z)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9 zK%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$ z0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@ z3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS z&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG z>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4 zpw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H= 
z0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{ z2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(E zX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(o zbq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$i zP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb z0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&} z15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`o zGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?P zIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$ zs51a{2B6LW)ER&}15jrG>I^`g0qC9nJNo&G!hclz)2-|4^8 zf2aRW|DFCj{dfBB^xx^f(|@P`PQNn%bq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UI zfI0(EX8`I9K%D`oGXQl4pw0l)8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g z0jM(obq1i$0Mr?PIs;H=0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`oGXQl4pw0l) z8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1i$0Mr?PIs;H=0O|}t zodM{*{(JrR`tS8S15jrG>I^`g0qDK{d;RzN@Acp7zt?}S|6c#S{(JrR`tSAM>%Z53 zum4{EO#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*( zO#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*(O#e*3GXQl4pfmk5{WJYD{muZ? z8Gt$iP-g(@3_zU$s51a{2B6LW)ER&}15jrG>I^`g0jM(obq1hw{d4_u{d4_u{d4_u z{d4_u{d4`!0Mr?P&h^jr&-KssI|EQ>0O|}todKvb0CfhS&H&UIfI0(EX8`I9K%D`o zGXQl4pw0l)8Gz39&-Kss&-Kss&-Kss&-Kss&-Kss&-Kss&-Kss&-Kss&-Kss&-Kss z|9@57%Whm*0fteY9YhbBa|VDA3_v*XEeUNd%?N~eAoPFIegu|*hGvkRD$6SSv+Sx< z`?B8jzv+L||EB*<|C|0d{crl;^uOtU(~kid24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_ zh5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja50EPh=24EO~ zVE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0T>2g z7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja50EPh= z24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ z0T>2g7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja5 z0EPh=24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV% zfMEcJ0T>2g7=U2_h5;A`U>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0T>2g7=U2_h5;A` zU>Ja50EPh=24EO~VE~2!7zSV%fMEcJ0r2l1PFs%#FdD#U0HXnn1~3}HXaJ)Dj0P|o zz-R!Y0gMJP8o+1(qXCQtFdD#U0HXnn1~3}HXaJ)Dj0P|oz-R!Y0gMJP8o+1(qXCQt zFdD#U0HXnn1~3}HXaJ)Dj0P|oz-R!Y0gMJP8o+1(qXCQtFdD#U0HXnn1~3}HXaJ)D zj0P|oz-R!Y0gMJP8o+1(qXCQtFdD#U0HXnn1~3}HXaM_GKMi0sfYAU(0~ifpG=R|n zMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y z(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifp zG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C z4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(E#@T`$+>B4PZ2Y z(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifp zG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C z4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU( z0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy z07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=F zfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfP zU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR z7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|n zMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y z(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifp zG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C z4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU( z0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy z07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=F zfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfP zU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR 
z7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|n zMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y z(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifp zG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C z4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU( z0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy z07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=F zfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfP zU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR z7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|n zMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y z(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifp zG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C z4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU( z0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy z07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=F zfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfP zU^IZy07e5C4PZ2Y(EvsR7!6=FfYAU(0~ifpG=R|nMgtfPU^IZy07e5C4PZ2Y(EvsR z7!6=FfYAU(0~ifpG=R|nMgtfPU^IYe0MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1 zAR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ( z8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2 zKs1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4Immo zG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4 zfM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCF zXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks118 z0MP)V0Yn3c1`rJ(8bCCFXaN0CKMf!nKs1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$ zhz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<8tY`e^{s0HOgz1BeC?4Immo zG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4 zfM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCF zXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks118 z0MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT z(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G z0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLaw zq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V z0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?W zL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz z1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$ zhz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c z1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh z5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC? 
z4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1 zAR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ( z8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2 zKs1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4Immo zG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4 zfM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCF zXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<8t+{WO4R0MP)V z0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?W zL<5Kh5DlQ8=%)ch1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ( z8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2 zKs1180MP)V0d(oV^k4cf{WO4R0MP)V0d(oV^k4cf{g?hr|E2%Zf9b#UU-~com;Out zrT@}@>A&<}`Y-*L{!9O*|I&Zyzw}@FFa4MPOaG<+(tqi{^k4cf{g?hr|E2%Zf9b#U zU-~com;OutrT@}@>%aBi`fvTW{#*a8|JHx&zxChxZ~eFaTmP;9)_?22_22q${kQ&G z|E>Slf9t>X-}-O;xBgrIt^d|f1BeC?4ImmoG=OLT(Ey?WL<5Kh(5?U0PXmYs5Dg$2 zK)3!||E>Slf9t>X-}-O;xBgrIt^d}4>%aBi`fvR-fM@{G0HOgz1BeC?4ImmoG=OLT z(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G z0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLaw zq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V z0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?W zL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz z1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$ zhz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c z1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh z5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC? z4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1 zAR0hF(@z741`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT z(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G z0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLaw zq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP*Y zxqcczG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1 zAR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ( z8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2 zKs1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4Immo zG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4 zfM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCF zXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks118 z0MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT z(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G z0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLaw zq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V z0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?W zL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz z1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$ zhz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh5Dg$2Ks1180MP)V0Yn3c z1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC?4ImmoG=OLT(Ey?WL<5Kh z5Dg$2Ks1180MP)V0Yn3c1`rJ(8bCCFXaLawq5(t$hz1Z1AR0h4fM@{G0HOgz1BeC? 
From ff5f4b1fa25d90903622b36c9e48296da4d0f6c6 Mon Sep 17 00:00:00 2001
From: Anshuman Mishra
Date: Sun, 13 Apr 2025 16:29:06 +0000
Subject: [PATCH 09/11] flash attention fixes

---
 .../src/models/mixtral/mixtral_attention.py | 34 ++++++++++++++++---
 1 file changed, 30 insertions(+), 4 deletions(-)

diff --git a/keras_hub/src/models/mixtral/mixtral_attention.py b/keras_hub/src/models/mixtral/mixtral_attention.py
index c38677151c..95bd0a4e48 100644
--- a/keras_hub/src/models/mixtral/mixtral_attention.py
+++ b/keras_hub/src/models/mixtral/mixtral_attention.py
@@ -1,3 +1,4 @@
+import inspect
 import math
 
 import keras
@@ -5,7 +6,10 @@
 
 from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
 from keras_hub.src.utils.keras_utils import clone_initializer
-from keras_hub.src.utils.keras_utils import has_flash_attention_support
+from keras_hub.src.utils.keras_utils import fused_attention_op_available
+from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
+from keras_hub.src.utils.keras_utils import running_on_gpu
+from keras_hub.src.utils.keras_utils import running_on_tpu
 
 
 class CachedMixtralAttention(keras.layers.Layer):
@@ -188,19 +192,41 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
         )
         return self._softmax(attention_scores)
 
+    def _use_fused_attention_op(self):
+        if not fused_attention_op_available():
+            return False
+        if self.dropout > 0.0:
+            return False
+        if running_on_gpu():
+            # GPU never supports softcap in the fused op.
+            if self.logit_soft_cap is not None:
+                return False
+            return gpu_supports_fused_attention_op()
+        elif running_on_tpu():
+            # TPU supports softcap only on Keras >= 3.10.
+            sig = inspect.signature(ops.dot_product_attention)
+            return "attn_logits_soft_cap" in sig.parameters
+        else:
+            return False
+
     def _compute_attention(self, query, key, value, attention_mask=None):
-        if has_flash_attention_support():
-            # Use `dot_product_attention` with Flash Attention support if
-            # available.
+ if self._use_fused_attention_op(): if attention_mask is not None: attention_mask = ops.expand_dims(attention_mask, axis=1) attention_mask = ops.cast(attention_mask, dtype="bool") + + if self.logit_soft_cap: + kwargs = {"attn_logits_soft_cap": self.logit_soft_cap} + else: + kwargs = {} + attention_output = ops.dot_product_attention( query, key, value, mask=attention_mask, scale=self._inv_norm_factor, + **kwargs, ) return attention_output From 1dba1a3ba2d29e99107b0af0e01474ea76c0d06d Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Mon, 14 Apr 2025 05:28:08 +0000 Subject: [PATCH 10/11] bug fixes --- .../src/models/mixtral/mixtral_backbone.py | 4 ++ .../src/models/mixtral/mixtral_decoder.py | 44 +++++++++++++------ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/keras_hub/src/models/mixtral/mixtral_backbone.py b/keras_hub/src/models/mixtral/mixtral_backbone.py index 39c68077e0..1e6434a597 100644 --- a/keras_hub/src/models/mixtral/mixtral_backbone.py +++ b/keras_hub/src/models/mixtral/mixtral_backbone.py @@ -102,6 +102,7 @@ def __init__( rope_max_wavelength=10000, rope_scaling_factor=1.0, layer_norm_epsilon=1e-6, + router_aux_loss_coef=0.02, sliding_window=512, dropout=0, dtype=None, @@ -131,6 +132,7 @@ def __init__( rope_scaling_factor=rope_scaling_factor, layer_norm_epsilon=layer_norm_epsilon, activation=ops.silu, + router_aux_loss_coef=router_aux_loss_coef, kernel_initializer=_mixtral_kernel_initializer(stddev=0.02), sliding_window=sliding_window, dropout=dropout, @@ -177,6 +179,7 @@ def __init__( self.router_jitter_noise = router_jitter_noise self.rope_max_wavelength = rope_max_wavelength + self.router_aux_loss_coef = router_aux_loss_coef self.rope_scaling_factor = rope_scaling_factor self.sliding_window = sliding_window self.layer_norm_epsilon = layer_norm_epsilon @@ -197,6 +200,7 @@ def get_config(self): "rope_max_wavelength": self.rope_max_wavelength, "rope_scaling_factor": self.rope_scaling_factor, "num_key_value_heads": self.num_key_value_heads, + "router_aux_loss_coef": self.router_aux_loss_coef, "sliding_window": self.sliding_window, "layer_norm_epsilon": self.layer_norm_epsilon, "dropout": self.dropout, diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 7d0a8b8201..7cb76a851c 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -36,9 +36,9 @@ def compute_load_balancing_loss( ) # Shape: (batch_size * seq_len, num_experts) # Get top-k experts - _, selected_experts = ops.top_k( + top_k_weights, selected_experts = ops.top_k( routing_weights, k=top_k - ) # Shape: (batch_size * seq_len, top_k) + ) # Shape: (batch_size * seq_len, top_k) for both # Create one-hot encoding for selected experts expert_mask = ops.one_hot( @@ -47,27 +47,40 @@ def compute_load_balancing_loss( if attention_mask is not None: # Flatten attention_mask to match router_logits - batch_size, seq_len = ops.shape(attention_mask) + seq_len = (ops.shape(attention_mask)[1],) + batch_seq_len = ops.shape(router_logits)[0] + # Dynamically compute the batch size to match router_logits + target_batch_size = batch_seq_len // seq_len + # Slice attention_mask to match the expected batch size + attention_mask = ops.slice( + attention_mask, [0, 0], [target_batch_size, seq_len] + ) flat_mask = ops.reshape( attention_mask, (-1,) ) # Shape: (batch_size * seq_len,) + flat_mask = ops.cast(flat_mask, dtype="float32") # Expand mask for broadcasting expert_attention_mask = ops.expand_dims( 
flat_mask, axis=-1 ) # Shape: (batch_size * seq_len, 1) - expert_attention_mask = ops.cast(expert_attention_mask, dtype="float32") + expert_attention_mask = ops.expand_dims( + expert_attention_mask, axis=1 + ) # Shape: (batch_size * seq_len, 1, 1) - # Compute masked means + # Compute masked token counts tokens_per_expert = ops.sum( - expert_mask * expert_attention_mask[:, None, :], axis=0 - ) / ops.maximum( - ops.sum(expert_attention_mask[:, None, :], axis=0), 1e-9 + expert_mask * expert_attention_mask, axis=0 ) # Shape: (top_k, num_experts) + mask_sum = ops.sum(expert_attention_mask, axis=0) # Shape: (1, 1) + tokens_per_expert = tokens_per_expert / ops.maximum(mask_sum, 1e-9) + + # Compute masked router probabilities router_prob_per_expert = ops.sum( - routing_weights * expert_attention_mask, axis=0 - ) / ops.maximum( - ops.sum(expert_attention_mask, axis=0), 1e-9 + routing_weights * flat_mask[:, None], axis=0 ) # Shape: (num_experts,) + router_prob_per_expert = router_prob_per_expert / ops.maximum( + ops.sum(flat_mask), 1e-9 + ) else: # Unmasked means tokens_per_expert = ops.mean( @@ -77,7 +90,7 @@ def compute_load_balancing_loss( routing_weights, axis=0 ) # Shape: (num_experts,) - # Average over top_k dimension if necessary + # Average over top_k dimension tokens_per_expert = ops.mean( tokens_per_expert, axis=0 ) # Shape: (num_experts,) @@ -172,6 +185,7 @@ def __init__( top_k=2, router_jitter_noise=0.0, layer_norm_epsilon=1e-5, + router_aux_loss_coef=0.02, kernel_initializer="glorot_uniform", **kwargs, ): @@ -182,6 +196,7 @@ def __init__( self.top_k = top_k self.router_jitter_noise = router_jitter_noise self.layer_norm_epsilon = layer_norm_epsilon + self.router_aux_loss_coef = router_aux_loss_coef self.kernel_initializer = keras.initializers.get(kernel_initializer) def build(self, decoder_sequence_shape): @@ -274,6 +289,7 @@ def __init__( rope_scaling_factor=1.0, activation="silu", layer_norm_epsilon=1e-5, + router_aux_loss_coef=0.02, kernel_initializer="glorot_uniform", sliding_window=512, dropout=0, @@ -298,6 +314,7 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.kernel_initializer = keras.initializers.get(kernel_initializer) + self.router_aux_loss_coef = router_aux_loss_coef self.output_router_logits = output_router_logits self.supports_masking = True @@ -338,6 +355,7 @@ def build(self, decoder_sequence_shape): num_experts=self.num_experts, top_k=self.top_k, router_jitter_noise=self.router_jitter_noise, + router_aux_loss_coef=self.router_aux_loss_coef, dtype=self.dtype_policy, ) self._sparse_moe_block.build(decoder_sequence_shape) @@ -389,7 +407,7 @@ def call( x = self._feedforward_layernorm(x) x, router_logits = self._sparse_moe_block( - x, attention_mask=self_attention_mask + x, attention_mask=decoder_padding_mask ) decoder_output = x + residual From 71d7401d5c0b901329136c490ae620df179652e8 Mon Sep 17 00:00:00 2001 From: Anshuman Mishra Date: Mon, 14 Apr 2025 06:50:16 +0000 Subject: [PATCH 11/11] bug fix --- keras_hub/src/models/mixtral/mixtral_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras_hub/src/models/mixtral/mixtral_decoder.py b/keras_hub/src/models/mixtral/mixtral_decoder.py index 7cb76a851c..4bfbaa32a7 100644 --- a/keras_hub/src/models/mixtral/mixtral_decoder.py +++ b/keras_hub/src/models/mixtral/mixtral_decoder.py @@ -47,7 +47,7 @@ def compute_load_balancing_loss( if attention_mask is not None: # Flatten attention_mask to match router_logits - seq_len = (ops.shape(attention_mask)[1],) + seq_len = 
ops.shape(attention_mask)[1] batch_seq_len = ops.shape(router_logits)[0] # Dynamically compute the batch size to match router_logits target_batch_size = batch_seq_len // seq_len
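
For context on the fused path introduced in PATCH 09: the sketch below is not part of the patch series; it only illustrates the tensor layout that `ops.dot_product_attention` expects when `_compute_attention` takes the fused branch. It assumes a Keras 3 release that ships `keras.ops.dot_product_attention`; the shapes follow the `(batch, seq_len, num_heads, head_dim)` layout produced by the layer's einsum projections.

# Sketch only: mirrors the call made inside `_compute_attention` once
# `_use_fused_attention_op()` returns True. Assumes a Keras version that
# provides `ops.dot_product_attention`.
import math

import numpy as np
from keras import ops

batch, q_len, num_heads, head_dim = 2, 8, 4, 16
query = ops.convert_to_tensor(
    np.random.rand(batch, q_len, num_heads, head_dim).astype("float32")
)
key = ops.convert_to_tensor(
    np.random.rand(batch, q_len, num_heads, head_dim).astype("float32")
)
value = ops.convert_to_tensor(
    np.random.rand(batch, q_len, num_heads, head_dim).astype("float32")
)

# Boolean causal mask broadcast over the head axis; this mirrors the
# `expand_dims(attention_mask, axis=1)` + cast to "bool" done in the patch.
mask = ops.convert_to_tensor(
    np.tril(np.ones((q_len, q_len), dtype=bool))[None, None, :, :]
)

output = ops.dot_product_attention(
    query,
    key,
    value,
    mask=mask,
    scale=1.0 / math.sqrt(head_dim),
)
print(ops.shape(output))  # (2, 8, 4, 16)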
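
For context on the router changes in PATCH 10 and 11: the sketch below is likewise not part of the series. It shows one plausible way the per-layer auxiliary loss could be aggregated and scaled by the `router_aux_loss_coef` that these patches thread through the backbone and decoder. The call signature of `compute_load_balancing_loss` is inferred from the surrounding diff, and `aggregate_router_aux_loss` is a hypothetical helper for illustration only.

# Hypothetical aggregation of the router auxiliary loss. Only the helper
# `compute_load_balancing_loss` comes from mixtral_decoder.py; its exact
# signature and the wiring below are assumptions for illustration.
from keras import ops

from keras_hub.src.models.mixtral.mixtral_decoder import (
    compute_load_balancing_loss,
)


def aggregate_router_aux_loss(
    router_logits_per_layer,  # list of (batch * seq_len, num_experts) tensors
    num_experts,
    top_k,
    router_aux_loss_coef=0.02,
    attention_mask=None,
):
    per_layer = [
        compute_load_balancing_loss(
            router_logits=router_logits,
            num_experts=num_experts,
            top_k=top_k,
            attention_mask=attention_mask,
        )
        for router_logits in router_logits_per_layer
    ]
    # Average over decoder layers, then scale by the coefficient the
    # backbone now exposes as `router_aux_loss_coef`.
    return router_aux_loss_coef * ops.mean(ops.stack(per_layer))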