diff --git a/.gitignore b/.gitignore index 453e6de5..cab7f08b 100644 --- a/.gitignore +++ b/.gitignore @@ -66,4 +66,9 @@ test_dxt/* # UME HF model src/lobster/model/integrations/ume_huggingface/model/* -*.sh \ No newline at end of file +*.sh +# Dataset cache files +data/**/*.parquet.cnt +data/**/*.parquet.cnt.lock +data/**/index.json.lock +data/0c57441b132b458b45160a306ca82943/ diff --git a/src/lobster/hydra_config/model/classification.yaml b/src/lobster/hydra_config/model/classification.yaml new file mode 100644 index 00000000..7e915383 --- /dev/null +++ b/src/lobster/hydra_config/model/classification.yaml @@ -0,0 +1,19 @@ +_target_: lobster.model._property_classification.PropertyClassification + +# Instantiate base encoder from pretrained UME v1 (ModernBERT) +encoder: + _target_: lobster.model.UME.from_pretrained + model_name: ume-mini-base-12M + use_flash_attn: false + +config: + _target_: lobster.model._property_classification.PropertyClassificationConfig + task_name: property + num_classes: 2 # binary classification + loss_function: auto # will use BCE for binary + hidden_sizes: [512] + dropout: 0.1 + activation: auto + pooling: attn + lr: 1e-3 + weight_decay: 0.0 diff --git a/src/lobster/hydra_config/model/classification_ume2.yaml b/src/lobster/hydra_config/model/classification_ume2.yaml new file mode 100644 index 00000000..8713f8bf --- /dev/null +++ b/src/lobster/hydra_config/model/classification_ume2.yaml @@ -0,0 +1,26 @@ +_target_: lobster.model._property_classification.PropertyClassification + +# UME-2 encoder (UMESequenceEncoderModule) +# +# Default: Load from checkpoint (most common for finetuning) +# Override checkpoint_path in experiment config +# +# For random initialization, override _target_ in experiment config: +# _target_: lobster.model.ume2.UMESequenceEncoderModule +# # Must specify: model_size, pad_token_id, max_length, vocab_size +encoder: + _target_: lobster.model.ume2.UMESequenceEncoderModule.load_from_checkpoint + checkpoint_path: ??? 
# Required: set in experiment config + cache_dir: null + +config: + _target_: lobster.model._property_classification.PropertyClassificationConfig + task_name: property + num_classes: 2 # binary classification + loss_function: auto # will use BCE for binary, CE for multi-class + hidden_sizes: [512] + dropout: 0.1 + activation: auto + pooling: attn + lr: 1e-3 + weight_decay: 0.0 diff --git a/src/lobster/hydra_config/model/regression_ume2.yaml b/src/lobster/hydra_config/model/regression_ume2.yaml new file mode 100644 index 00000000..d7531b55 --- /dev/null +++ b/src/lobster/hydra_config/model/regression_ume2.yaml @@ -0,0 +1,25 @@ +_target_: lobster.model._property_regression.PropertyRegression + +# UME-2 encoder (UMESequenceEncoderModule) +# +# Default: Load from checkpoint (most common for finetuning) +# Override checkpoint_path in experiment config +# +# For random initialization, override _target_ in experiment config: +# _target_: lobster.model.ume2.UMESequenceEncoderModule +# # Must specify: model_size, pad_token_id, max_length, vocab_size +encoder: + _target_: lobster.model.ume2.UMESequenceEncoderModule.load_from_checkpoint + checkpoint_path: ??? 
# Required: set in experiment config + cache_dir: null + +config: + _target_: lobster.model._property_regression.PropertyRegressionConfig + task_name: property + loss_function: ${training.loss_function} # e.g., mse, huber, smooth_l1, gaussian, exponential + hidden_sizes: [512] + dropout: 0.1 + activation: auto + pooling: mean + lr: 1e-3 + weight_decay: 0.0 diff --git a/src/lobster/model/__init__.py b/src/lobster/model/__init__.py index b19a0624..419c0011 100644 --- a/src/lobster/model/__init__.py +++ b/src/lobster/model/__init__.py @@ -13,6 +13,8 @@ from ._mlp import LobsterMLP from ._peft_lightning_module import LobsterPEFT from ._ppi_clf import PPIClassifier +from ._property_classification import PropertyClassification, PropertyClassificationConfig +from ._property_regression import PropertyRegression, PropertyRegressionConfig from ._seq2seq import PrescientPT5 from ._ume import UME from ._heads import TaskConfig, TaskHead, MultiTaskHead, FlexibleEncoderWithHeads diff --git a/src/lobster/model/_property_classification.py b/src/lobster/model/_property_classification.py new file mode 100644 index 00000000..37674618 --- /dev/null +++ b/src/lobster/model/_property_classification.py @@ -0,0 +1,188 @@ +from dataclasses import dataclass +import logging +from typing import Literal + +import lightning as L +from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC +import torch +import torch.nn as nn + +from ._heads import TaskConfig, FlexibleEncoderWithHeads +from lobster.post_train.unfreezing import set_unfrozen_layers + + +@dataclass +class PropertyClassificationConfig: + """Configuration for training a property classification head on a generic encoder. + + Parameters + ---------- + task_name : str + Name of the task/head; used to route outputs and metrics. + num_classes : int + Number of classes. For binary classification, use 2. + loss_function : str + Classification loss to use. Supported examples include + 'auto', 'bce', 'cross_entropy', 'focal'. 
+ hidden_sizes : list[int] | None + Sizes of the MLP layers in the head. When None, a single + linear layer is used. + dropout : float + Dropout probability applied inside the head MLP. + activation : str + Activation function for the head MLP. 'auto' picks a + sensible default. + pooling : Literal["cls", "mean", "attn", "weighted_mean"] + How to pool token embeddings into a sequence embedding. + lr : float + Learning rate for the optimizer configured by this module. + weight_decay : float + Weight decay for the optimizer. + unfreeze_last_n_layers : int | None + Controls encoder layer unfreezing via `set_unfrozen_layers`: + - None: leave `requires_grad` as-is + - -1: unfreeze all encoder layers + - 0: freeze all encoder layers + - >0: unfreeze the last N encoder layers + """ + + task_name: str = "property" + num_classes: int = 2 + loss_function: str = "auto" + hidden_sizes: list[int] | None = None + dropout: float = 0.1 + activation: str = "auto" + pooling: Literal["cls", "mean", "attn", "weighted_mean"] = "mean" + lr: float = 1e-3 + weight_decay: float = 0.0 + unfreeze_last_n_layers: int | None = None + + +class PropertyClassification(L.LightningModule): + """LightningModule for training a classification head on top of any encoder. + + Args + ---- + encoder : nn.Module + The pretrained encoder used as the backbone. + config : PropertyClassificationConfig + Configuration controlling the head, loss, optimizer, pooling, + and encoder unfreezing policy. 
+ """ + + def __init__(self, encoder: nn.Module, *, config: PropertyClassificationConfig | None = None) -> None: + super().__init__() + self.save_hyperparameters(ignore=["encoder"]) + + self.encoder = encoder + cfg = config or PropertyClassificationConfig() + self.cfg = cfg + + # Determine task type and output dimension + if cfg.num_classes == 2: + task_type = "binary_classification" + head_output_dim = 1 + else: + task_type = "multiclass_classification" + head_output_dim = cfg.num_classes + + task = TaskConfig( + name=cfg.task_name, + output_dim=head_output_dim, + task_type=task_type, + pooling=cfg.pooling, + hidden_sizes=cfg.hidden_sizes, + dropout=cfg.dropout, + activation=cfg.activation, + loss_function=cfg.loss_function, + ) + + # Resolve encoder hidden size for head construction + hidden_size = None + if hasattr(self.encoder, "embedding_dim"): + hidden_size = self.encoder.embedding_dim + elif hasattr(self.encoder, "config") and hasattr(self.encoder.config, "hidden_size"): + hidden_size = self.encoder.config.hidden_size + elif hasattr(self.encoder, "hidden_size"): + hidden_size = self.encoder.hidden_size + + self.model = FlexibleEncoderWithHeads( + encoder=self.encoder, + task_configs=[task], + hidden_size=hidden_size, + ) + + # Apply unfreezing if requested via config + logging.getLogger(__name__).info(f"PropertyClassification: unfreeze_last_n_layers={cfg.unfreeze_last_n_layers}") + if cfg.unfreeze_last_n_layers is not None: + n = int(cfg.unfreeze_last_n_layers) + set_unfrozen_layers(self.encoder, n) + + self.loss_fns = self.model.get_loss_functions() + + # Metrics for binary or multiclass classification + task_metric = "binary" if cfg.num_classes == 2 else "multiclass" + metric_kwargs = {"task": task_metric, "num_classes": cfg.num_classes if task_metric == "multiclass" else None} + + self.train_acc = Accuracy(**metric_kwargs) + self.val_acc = Accuracy(**metric_kwargs) + self.train_precision = Precision(**metric_kwargs) + self.val_precision = 
Precision(**metric_kwargs) + self.train_recall = Recall(**metric_kwargs) + self.val_recall = Recall(**metric_kwargs) + self.train_f1 = F1Score(**metric_kwargs) + self.val_f1 = F1Score(**metric_kwargs) + self.train_auroc = AUROC(**metric_kwargs) + self.val_auroc = AUROC(**metric_kwargs) + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask) + return outputs[self.cfg.task_name] + + def _shared_step(self, batch: dict[str, torch.Tensor], stage: str) -> torch.Tensor: + logits = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + targets = batch["targets"].to(logits.device) + loss_fn = self.loss_fns[self.cfg.task_name] + + # For binary classification, logits are (B,) and need to be passed through sigmoid for metrics + if self.cfg.num_classes == 2: + loss = loss_fn(logits, targets.float()) + probs = torch.sigmoid(logits) + preds = (probs > 0.5).long() + else: + # For multiclass, logits are (B, C) + loss = loss_fn(logits, targets.long()) + probs = torch.softmax(logits, dim=-1) + preds = torch.argmax(probs, dim=-1) + + # Update metrics + acc = self.train_acc if stage == "train" else self.val_acc + precision = self.train_precision if stage == "train" else self.val_precision + recall = self.train_recall if stage == "train" else self.val_recall + f1 = self.train_f1 if stage == "train" else self.val_f1 + auroc = self.train_auroc if stage == "train" else self.val_auroc + + acc(preds, targets) + precision(preds, targets) + recall(preds, targets) + f1(preds, targets) + # AUROC consumes probabilities in both the binary (B,) and multiclass (B, C) cases + auroc(probs, targets) + + self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_acc", acc, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_precision", precision, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + 
self.log(f"{stage}_recall", recall, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_f1", f1, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_auroc", auroc, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + return loss + + def training_step(self, batch: dict[str, torch.Tensor], _: int) -> torch.Tensor: + return self._shared_step(batch, "train") + + def validation_step(self, batch: dict[str, torch.Tensor], _: int) -> torch.Tensor: + return self._shared_step(batch, "val") + + def configure_optimizers(self): + params = [p for p in self.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay) + return optimizer diff --git a/src/lobster/model/_property_regression.py b/src/lobster/model/_property_regression.py new file mode 100644 index 00000000..d2358ead --- /dev/null +++ b/src/lobster/model/_property_regression.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +import logging +from typing import Literal + +import lightning as L +from torchmetrics import MeanAbsoluteError, R2Score, SpearmanCorrCoef, PearsonCorrCoef +import torch +import torch.nn as nn + +from ._heads import TaskConfig, FlexibleEncoderWithHeads +from lobster.post_train.unfreezing import set_unfrozen_layers + + +@dataclass +class PropertyRegressionConfig: + """Configuration for training a property regression head on a generic encoder. + + Parameters + - task_name: Name of the task/head; used to route outputs and metrics. + - loss_function: Regression loss to use. Supported examples include + 'auto', 'l1', 'mse', 'huber', 'gaussian', 'mdn_gaussian'. For + 'gaussian', the head outputs two values per example (mean, log_scale). + For 'mdn_gaussian', the head outputs parameters for a K-component + Gaussian mixture; K is set by `mixture_components`. + - hidden_sizes: Sizes of the MLP layers in the head. When None, a single + linear layer is used. 
+ - dropout: Dropout probability applied inside the head MLP. + - activation: Activation function for the head MLP. 'auto' picks a + sensible default. + - pooling: How to pool token embeddings into a sequence embedding. + One of 'cls', 'mean', 'attn', 'weighted_mean'. + - lr: Learning rate for the optimizer configured by this module. + - weight_decay: Weight decay for the optimizer. + - unfreeze_last_n_layers: Controls encoder layer unfreezing via + `set_unfrozen_layers`: + - None: leave `requires_grad` as-is + - -1: unfreeze all encoder layers + - 0: freeze all encoder layers + - >0: unfreeze the last N encoder layers + - mixture_components: Number of mixture components K for 'mdn_gaussian'. + """ + + task_name: str = "property" + loss_function: str = "auto" + hidden_sizes: list[int] | None = None + dropout: float = 0.1 + activation: str = "auto" + pooling: Literal["cls", "mean", "attn", "weighted_mean"] = "mean" + lr: float = 1e-3 + weight_decay: float = 0.0 + unfreeze_last_n_layers: int | None = None + mixture_components: int | None = None + + +class PropertyRegression(L.LightningModule): + """LightningModule for training a regression head on top of any encoder. + + Args: + encoder: The pretrained encoder used as the backbone. + config: `PropertyRegressionConfig` controlling the head, loss, + optimizer, pooling, and encoder unfreezing policy. 
+ """ + + def __init__(self, encoder: nn.Module, *, config: PropertyRegressionConfig | None = None) -> None: + super().__init__() + # ignore must be the parameter *name* (string), not the module object, + # so the encoder weights are excluded from saved hyperparameters + self.save_hyperparameters(ignore=["encoder"]) + + self.encoder = encoder + cfg = config or PropertyRegressionConfig() + self.cfg = cfg + + # Determine head output dimension based on loss + head_output_dim = 1 + if cfg.loss_function == "gaussian": + head_output_dim = 2 # mean, log_scale + + task = TaskConfig( + name=cfg.task_name, + output_dim=head_output_dim, + task_type="regression", + pooling=cfg.pooling, + hidden_sizes=cfg.hidden_sizes, + dropout=cfg.dropout, + activation=cfg.activation, + loss_function=cfg.loss_function, + mixture_components=cfg.mixture_components, + ) + + # Resolve encoder hidden size for head construction + hidden_size = None + if hasattr(self.encoder, "embedding_dim"): + hidden_size = getattr(self.encoder, "embedding_dim") + elif hasattr(self.encoder, "config") and hasattr(self.encoder.config, "hidden_size"): + hidden_size = self.encoder.config.hidden_size + elif hasattr(self.encoder, "hidden_size"): + hidden_size = getattr(self.encoder, "hidden_size") + + self.model = FlexibleEncoderWithHeads( + encoder=self.encoder, + task_configs=[task], + hidden_size=hidden_size, + ) + + # Apply unfreezing if requested via config + logging.getLogger(__name__).info( + f"PropertyRegression: unfreeze_last_n_layers={cfg.unfreeze_last_n_layers}" + ) + if cfg.unfreeze_last_n_layers is not None: + n = int(cfg.unfreeze_last_n_layers) + set_unfrozen_layers(self.encoder, n) + + self.loss_fns = self.model.get_loss_functions() + self.train_mae = MeanAbsoluteError() + self.val_mae = MeanAbsoluteError() + self.train_r2 = R2Score() + self.val_r2 = R2Score() + self.train_spearman = SpearmanCorrCoef() + self.val_spearman = SpearmanCorrCoef() + self.train_pearson = PearsonCorrCoef() + self.val_pearson = PearsonCorrCoef() + + def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + outputs = 
self.model(input_ids=input_ids, attention_mask=attention_mask) + # Return raw head output. For MDN this is params vector; for standard regression it's (B,1) + return outputs[self.cfg.task_name] + + def _shared_step(self, batch: dict[str, torch.Tensor], stage: str) -> torch.Tensor: + preds = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) # (B, P) or (B,1) + targets = batch["targets"].to(preds.device) + loss_fn = self.loss_fns[self.cfg.task_name] + + # Compute scalar prediction for metrics + if self.cfg.loss_function == "mdn_gaussian": + # Parse MDN params for D=1 + P = preds.shape[-1] + if P % 3 != 0: + raise ValueError(f"Expected MDN param size divisible by 3 for D=1, got {P}") + K = P // 3 + logits = preds[:, :K] + means = preds[:, K : 2 * K] + # log_scales = preds[:, 2 * K : 3 * K] # not needed for metrics + weights = torch.softmax(logits, dim=-1) + y_hat = torch.sum(weights * means, dim=-1) # (B,) + preds_for_loss = preds + elif self.cfg.loss_function == "gaussian": + # Natural Gaussian: preds = [mean, log_scale] + if preds.shape[-1] != 2: + raise ValueError(f"Gaussian loss expects head output dim=2, got {preds.shape[-1]}") + y_hat = preds[..., 0] + preds_for_loss = preds + else: + y_hat = preds.squeeze(-1) # (B,) + preds_for_loss = y_hat + + loss = loss_fn(preds_for_loss, targets) + + mae = self.train_mae if stage == "train" else self.val_mae + r2 = self.train_r2 if stage == "train" else self.val_r2 + spearman = self.train_spearman if stage == "train" else self.val_spearman + pearson = self.train_pearson if stage == "train" else self.val_pearson + mae(y_hat, targets) + r2(y_hat, targets) + spearman(y_hat, targets) + pearson(y_hat, targets) + self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_mae", mae, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_r2", r2, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + 
self.log(f"{stage}_spearman", spearman, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"{stage}_pearson", pearson, prog_bar=False, on_step=False, on_epoch=True, sync_dist=True) + return loss + + def training_step(self, batch: dict[str, torch.Tensor], _: int) -> torch.Tensor: + return self._shared_step(batch, "train") + + def validation_step(self, batch: dict[str, torch.Tensor], _: int) -> torch.Tensor: + return self._shared_step(batch, "val") + + def configure_optimizers(self): + params = [p for p in self.parameters() if p.requires_grad] + optimizer = torch.optim.AdamW(params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay) + return optimizer + + diff --git a/src/lobster/model/ume2/_checkpoint_utils.py b/src/lobster/model/ume2/_checkpoint_utils.py index 303f4dd9..3f54c4d0 100644 --- a/src/lobster/model/ume2/_checkpoint_utils.py +++ b/src/lobster/model/ume2/_checkpoint_utils.py @@ -50,3 +50,38 @@ def map_checkpoint_keys( mapped_state_dict[key.replace(original_prefix, new_prefix, 1)] = value return mapped_state_dict + + +# TODO: Add config.json support for checkpoint parameters +# Once config.json files are added alongside checkpoints in S3, implement a +# load_config_json() function to load non-critical params (model_size, pad_token_id, max_length) + + +def infer_architecture_from_state_dict(state_dict: dict[str, torch.Tensor], prefix: str = "") -> dict: + """ + Infer critical architecture parameters from checkpoint state_dict. + + Infers vocab_size and hidden_size from encoder weight shapes. + These parameters MUST match checkpoint weights and cannot be overridden. 
+ """ + inferred_params = {} + + # Try different possible key formats + encoder_key_variants = [ + f"{prefix}model.encoder.weight", + "model.encoder.weight", + "encoder.neobert.model.encoder.weight", + ] + + for encoder_key in encoder_key_variants: + if encoder_key in state_dict: + vocab_size, hidden_size = state_dict[encoder_key].shape + inferred_params["vocab_size"] = vocab_size + inferred_params["hidden_size"] = hidden_size + logger.info(f"Inferred from checkpoint key '{encoder_key}': vocab_size={vocab_size}, hidden_size={hidden_size}") + break + + if not inferred_params: + logger.warning(f"Could not find encoder.weight in state_dict to infer vocab_size/hidden_size. Keys: {list(state_dict.keys())[:10]}") + + return inferred_params diff --git a/src/lobster/model/ume2/ume_sequence_encoder.py b/src/lobster/model/ume2/ume_sequence_encoder.py index 35b449bb..426d21fe 100644 --- a/src/lobster/model/ume2/ume_sequence_encoder.py +++ b/src/lobster/model/ume2/ume_sequence_encoder.py @@ -10,7 +10,7 @@ from lobster.tokenization import get_ume_tokenizer_transforms from ..neobert import NeoBERTModule -from ._checkpoint_utils import load_checkpoint_from_s3_uri_or_local_path, map_checkpoint_keys +from ._checkpoint_utils import infer_architecture_from_state_dict, load_checkpoint_from_s3_uri_or_local_path, map_checkpoint_keys from .auxiliary_tasks import AuxiliaryRegressionTaskHead, AuxiliaryTask logger = logging.getLogger(__name__) @@ -51,6 +51,11 @@ def __init__( } ) + @property + def config(self): + """Expose neobert config for compatibility with downstream tasks.""" + return self.neobert.config + @classmethod def load_from_checkpoint( cls, checkpoint_path: str, *, device: str | None = None, cache_dir: str | None = None, **kwargs @@ -70,6 +75,11 @@ def load_from_checkpoint( state_dict = checkpoint["state_dict"] encoder_kwargs = hyper_parameters.pop("encoder_kwargs", {}) + + # Infer critical architecture params from state_dict (vocab_size, hidden_size) + # These override any 
values in encoder_kwargs to ensure consistency with checkpoint weights + inferred_params = infer_architecture_from_state_dict(state_dict, prefix="encoder.neobert.") + encoder_kwargs.update(inferred_params) # Initialize encoder encoder = cls(**hyper_parameters, **encoder_kwargs) diff --git a/src/lobster/model/ume2/ume_sequence_encoder_lightning_module.py b/src/lobster/model/ume2/ume_sequence_encoder_lightning_module.py index 4d9428d5..6d3dc65b 100644 --- a/src/lobster/model/ume2/ume_sequence_encoder_lightning_module.py +++ b/src/lobster/model/ume2/ume_sequence_encoder_lightning_module.py @@ -34,7 +34,6 @@ def __init__( scheduler_kwargs: dict | None = None, encoder_kwargs: dict | None = None, use_shared_tokenizer: bool = False, - ckpt_path: str | None = None, ): self.save_hyperparameters()