import sys
import logging
import anndata as ad
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.decomposition import TruncatedSVD, PCA
from sklearn.preprocessing import StandardScaler
from scipy import sparse
import gc
import warnings
warnings.filterwarnings('ignore')

## VIASH START
par = {
    'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod1.h5ad',
    'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod2.h5ad',
    'input_test_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/test_mod1.h5ad',
    'output': 'output.h5ad',
    'task_type': 'auto',
    'inputs_n_components': 128,
    'targets_n_components': 128,
    'encoder_h_dim': 512,
    'decoder_h_dim': 512,
    'n_encoder_block': 3,
    'n_decoder_block': 3,
    'dropout_p': 0.1,
    'activation': 'relu',
    'norm': 'layer_norm',
    'use_skip_connections': True,
    'learning_rate': 0.0001,
    'weight_decay': 0.000001,
    'epochs': 40,
    'batch_size': 64,
    'use_residual_connections': True,
}
meta = {
    'name': 'suzuki_mlp'
}
## VIASH END
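
# The par and meta dictionaries above are development defaults; when the
# component is executed by viash, the block between VIASH START and VIASH END
# is replaced with the actual arguments and component metadata.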

# Import helper functions and model modules bundled with this component
sys.path.append(meta["resources_dir"])

from utils import (
    determine_task_type, preprocess_data, train_model,
    MLPBModule, HierarchicalMLPBModule, SuzukiEncoderDecoderModule
)
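
# Overall flow, as wired below: preprocess both modalities into reduced
# representations, train the encoder-decoder MLP on the reduced data, then
# project predictions back to the original modality-2 feature space. The
# helpers above are assumed to live in utils.py inside the resources dir.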

def main():
    # Enable logging
    logging.basicConfig(level=logging.INFO)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}", flush=True)

    # Read input files
    print("Reading input files", flush=True)
    adata_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
    adata_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
    adata_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

    # Determine task type
    if par['task_type'] == 'auto':
        task_type = determine_task_type(adata_train_mod1, adata_train_mod2)
        print(f"Auto-detected task type: {task_type}", flush=True)
    else:
        task_type = par['task_type']

    print(f"Task type: {task_type}", flush=True)
    print(f"Modality 1: {adata_train_mod1.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod1.n_vars}")
    print(f"Modality 2: {adata_train_mod2.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod2.n_vars}")

    # Preprocess data
    print("Preprocessing data", flush=True)
    data = preprocess_data(
        adata_train_mod1=adata_train_mod1,
        adata_train_mod2=adata_train_mod2,
        adata_test_mod1=adata_test_mod1,
        task_type=task_type,
        inputs_n_components=par['inputs_n_components'],
        targets_n_components=par['targets_n_components']
    )

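    # Unpack the preprocessed arrays. targets_decomposer_components and
    # targets_global_median are presumably fitted on the targets by
    # preprocess_data so that low-dimensional predictions can be projected back
    # to the full modality-2 feature space (see utils.py for the exact behaviour).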
    X_train = data['X_train']
    y_train = data['y_train']
    X_test = data['X_test']
    metadata_train = data['metadata_train']
    metadata_test = data['metadata_test']
    targets_decomposer_components = data['targets_decomposer_components']
    targets_global_median = data['targets_global_median']
    y_statistic = data['y_statistic']

    print(f"Training data shape: X={X_train.shape}, y={y_train.shape}")
    print(f"Test data shape: X={X_test.shape}")

    # Build model
    print("Building model", flush=True)
    input_dim = X_train.shape[1]
    output_dim = y_train.shape[1]

    # Create encoder
    encoder = MLPBModule(
        input_dim=None,  # Will be set in the main module
        output_dim=par['encoder_h_dim'],
        n_block=par['n_encoder_block'],
        h_dim=par['encoder_h_dim'],
        skip=par['use_skip_connections'],
        dropout_p=par['dropout_p'],
        activation=par['activation'],
        norm=par['norm']
    )

    # Create hierarchical decoder
    decoder = HierarchicalMLPBModule(
        input_dim=par['encoder_h_dim'],
        output_dim=None,  # Will create multiple outputs
        n_block=par['n_decoder_block'],
        h_dim=par['decoder_h_dim'],
        skip=par['use_skip_connections'],
        dropout_p=par['dropout_p'],
        activation=par['activation'],
        norm=par['norm']
    )

    # Create main model
    model = SuzukiEncoderDecoderModule(
        x_dim=input_dim,
        y_dim=output_dim,
        y_statistic=y_statistic,
        encoder_h_dim=par['encoder_h_dim'],
        decoder_h_dim=par['decoder_h_dim'],
        n_decoder_block=par['n_decoder_block'],
        targets_decomposer_components=targets_decomposer_components,
        targets_global_median=targets_global_median,
        encoder=encoder,
        decoder=decoder,
        task_type=task_type,
        use_residual_connections=par['use_residual_connections']
    ).to(device)
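    # SuzukiEncoderDecoderModule is expected to wire the encoder and the
    # hierarchical decoder together and to use targets_decomposer_components /
    # targets_global_median to reconstruct full-dimensional modality-2 values
    # from the reduced target representation (see utils.py).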

    # Train model
    print("Training model", flush=True)
    trained_model = train_model(
        model=model,
        X_train=X_train,
        y_train=y_train,
        metadata_train=metadata_train,
        device=device,
        lr=par['learning_rate'],
        weight_decay=par['weight_decay'],
        epochs=par['epochs'],
        batch_size=par['batch_size'],
        task_type=task_type
    )

    # Predict on test data
    print("Predicting on test data", flush=True)
    trained_model.eval()
    predictions = []

    with torch.no_grad():
        # Handle metadata safely for test data: the model's predict() takes a
        # gender code and an extra info vector per cell; fall back to zeros
        # when they are not available.
        if 'gender' in metadata_test.columns:
            gender_values = metadata_test['gender'].values
            if gender_values.dtype == object:
                # pd.to_numeric returns a plain ndarray here, so fill NaNs via a Series
                gender_values = (
                    pd.to_numeric(pd.Series(gender_values), errors='coerce')
                    .fillna(0).astype(int).values
                )
            gender_test = torch.LongTensor(gender_values)
        else:
            gender_test = torch.LongTensor(np.zeros(len(X_test), dtype=int))

        info_test = torch.FloatTensor(np.zeros((len(X_test), 1)))

        test_dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(X_test),
            gender_test,
            info_test
        )
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=par['batch_size'], shuffle=False)

        for batch_x, batch_gender, batch_info in test_loader:
            batch_x = batch_x.to(device)
            batch_gender = batch_gender.to(device)
            batch_info = batch_info.to(device)

            pred = trained_model.predict(batch_x, batch_gender, batch_info)
            predictions.append(pred.cpu().numpy())

    y_pred = np.vstack(predictions)

    # Create output AnnData object
    print("Creating output", flush=True)
    adata_pred = ad.AnnData(
        obs=adata_test_mod1.obs.copy(),
        var=adata_train_mod2.var.copy(),
        layers={
            'normalized': y_pred
        },
        uns={
            'dataset_id': adata_train_mod1.uns.get('dataset_id', 'unknown'),
            'method_id': meta['name']
        }
    )

    # Write output
    print("Writing output to file", flush=True)
    adata_pred.write_h5ad(par['output'], compression='gzip')

    print("Done!", flush=True)

if __name__ == '__main__':
    main()