Skip to content

Commit 5f6c0d7

Browse files
authored
Add shuji suzuki method (#11)
* Add suzuki method * change description * add method to wf * remove old scripts * `metrics/mse`: Allow matrices to be dense or sparse
1 parent 76ce922 commit 5f6c0d7

File tree

9 files changed

+888
-60
lines changed

9 files changed

+888
-60
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@
77
* Added Novel method (PR #2).
88

99
* Added Simple MLP method (PR #3).
10+
11+
* `methods/suzuki_mlp`: Ported NeurIPS2022 top method (PR #11).
1012

1113
## MINOR CHANGES
1214

1315
* Bump image version for `openproblems/base_*` images to 1 -- a sliding release (PR #9).
1416

1517
* Bump Viash version to 0.9.4 (PR #12).
1618

19+
## BUG FIXES
20+
21+
* `metrics/mse`: Allow matrices to be dense or sparse (PR #11).
22+
1723
# task_predict_modality 0.1.0
1824

1925
Initial release after migrating the codebase.
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
__merge__: ../../api/comp_method.yaml
2+
name: suzuki_mlp
3+
label: Suzuki MLP
4+
summary: Hierarchical encoder-decoder neural network with task-specific preprocessing and residual connections for cross-modal prediction.
5+
description: |
6+
A hierarchical neural network encoder-decoder model based on Shuji Suzuki's 1st place solution
7+
in the Open Problems Multimodal Single-Cell Integration competition. The model uses task-specific
8+
preprocessing, SVD dimensionality reduction, and hierarchical MLP blocks with residual connections
9+
for learning cross-modal mappings.
10+
11+
The original author's code was adapted by GitHub Copilot
12+
(using Claude Sonnet) to integrate with this repository's framework and standards.
13+
links:
14+
documentation: https://www.kaggle.com/competitions/open-problems-multimodal/discussion/348468
15+
repository: https://github.com/shu65/open-problems-multimodal
16+
info:
17+
preferred_normalization: log_cp10k
18+
arguments:
19+
# Task configuration
20+
- name: "--task_type"
21+
type: "string"
22+
default: "auto"
23+
description: Task type - 'auto' for automatic detection, 'cite' for CITE-seq, 'multi' for multiome.
24+
25+
# Preprocessing arguments
26+
- name: "--inputs_n_components"
27+
type: "integer"
28+
default: 128
29+
description: Number of SVD components for input modality dimensionality reduction.
30+
- name: "--targets_n_components"
31+
type: "integer"
32+
default: 128
33+
description: Number of SVD components for target modality dimensionality reduction.
34+
35+
# Model architecture arguments
36+
- name: "--encoder_h_dim"
37+
type: "integer"
38+
default: 512
39+
description: Hidden dimension size for the encoder.
40+
- name: "--decoder_h_dim"
41+
type: "integer"
42+
default: 512
43+
description: Hidden dimension size for the decoder.
44+
- name: "--n_encoder_block"
45+
type: "integer"
46+
default: 3
47+
description: Number of encoder blocks.
48+
- name: "--n_decoder_block"
49+
type: "integer"
50+
default: 3
51+
description: Number of decoder blocks.
52+
- name: "--dropout_p"
53+
type: "double"
54+
default: 0.1
55+
description: Dropout probability.
56+
- name: "--activation"
57+
type: "string"
58+
default: "relu"
59+
description: Activation function (relu or gelu).
60+
- name: "--norm"
61+
type: "string"
62+
default: "layer_norm"
63+
description: Normalization type (layer_norm or batch_norm).
64+
- name: "--use_skip_connections"
65+
type: "boolean"
66+
default: true
67+
description: Whether to use skip connections in blocks.
68+
69+
# Training arguments
70+
- name: "--learning_rate"
71+
type: "double"
72+
default: 1e-4
73+
description: Learning rate for training.
74+
- name: "--weight_decay"
75+
type: "double"
76+
default: 1e-6
77+
description: Weight decay for regularization.
78+
- name: "--epochs"
79+
type: "integer"
80+
default: 40
81+
description: Number of training epochs.
82+
- name: "--batch_size"
83+
type: "integer"
84+
default: 64
85+
description: Batch size for training.
86+
- name: "--use_residual_connections"
87+
type: "boolean"
88+
default: true
89+
description: Whether to use residual connections for multi task.
90+
91+
resources:
92+
- type: python_script
93+
path: script.py
94+
- path: utils.py
95+
engines:
96+
- type: docker
97+
image: openproblems/base_pytorch_nvidia:1
98+
runners:
99+
- type: executable
100+
- type: nextflow
101+
directives:
102+
label: [hightime, highmem, midcpu, gpu]

src/methods/suzuki_mlp/script.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
import sys
2+
import logging
3+
import anndata as ad
4+
import numpy as np
5+
import pandas as pd
6+
import torch
7+
import torch.nn as nn
8+
from sklearn.decomposition import TruncatedSVD, PCA
9+
from sklearn.preprocessing import StandardScaler
10+
from scipy import sparse
11+
import gc
12+
import warnings
13+
warnings.filterwarnings('ignore')
14+
15+
## VIASH START
# Placeholder inputs used when this script is run directly during development.
# NOTE(review): Viash is expected to replace everything between the
# VIASH START / VIASH END markers when the component is built -- confirm.
# Paths are written relative to the repository root.
par = {
    # Input / output files
    'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod1.h5ad',
    'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/train_mod2.h5ad',
    'input_test_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/test_mod1.h5ad',
    'output': 'output.h5ad',
    # Task configuration
    'task_type': 'auto',
    # Preprocessing
    'inputs_n_components': 128,
    'targets_n_components': 128,
    # Model architecture
    'encoder_h_dim': 512,
    'decoder_h_dim': 512,
    'n_encoder_block': 3,
    'n_decoder_block': 3,
    'dropout_p': 0.1,
    'activation': 'relu',
    'norm': 'layer_norm',
    'use_skip_connections': True,
    # Training
    'learning_rate': 0.0001,
    'weight_decay': 0.000001,
    'epochs': 40,
    'batch_size': 64,
    'use_residual_connections': True,
}
meta = {
    'name': 'suzuki_mlp',
    # The script appends meta["resources_dir"] to sys.path before importing
    # utils; without this key a direct run fails with KeyError.
    'resources_dir': 'src/methods/suzuki_mlp',
}
## VIASH END
42+
43+
# Import utils functions
44+
import sys
45+
import os
46+
sys.path.append(meta["resources_dir"])
47+
48+
from utils import (
49+
determine_task_type, preprocess_data, train_model,
50+
MLPBModule, HierarchicalMLPBModule, SuzukiEncoderDecoderModule
51+
)
52+
53+
def _gender_tensor(metadata, n_obs):
    """Return a LongTensor of per-cell gender codes for the model input.

    Uses the ``gender`` column of ``metadata`` when it exists; entries that
    cannot be parsed as numbers, and missing entries, become 0. When there
    is no such column, every one of the ``n_obs`` cells gets code 0.
    """
    if 'gender' not in metadata.columns:
        return torch.zeros(n_obs, dtype=torch.long)
    # Coerce on the Series rather than on ``.values``: ``pd.to_numeric``
    # returns a plain ndarray for ndarray input, and ndarrays have no
    # ``fillna``. Casting to object first also covers non-object,
    # non-numeric columns (e.g. categoricals).
    codes = pd.to_numeric(metadata['gender'].astype(object), errors='coerce')
    codes = codes.fillna(0).astype(np.int64).to_numpy()
    return torch.as_tensor(codes, dtype=torch.long)


def _predict(model, X_test, metadata_test, batch_size, device):
    """Run ``model.predict`` over ``X_test`` in batches and stack the results.

    Returns a single NumPy array with one row per row of ``X_test``, in the
    original order (the loader does not shuffle).
    """
    model.eval()
    n_obs = len(X_test)
    dataset = torch.utils.data.TensorDataset(
        torch.FloatTensor(X_test),
        _gender_tensor(metadata_test, n_obs),
        # The third model input is a constant all-zero column at test time.
        torch.zeros((n_obs, 1), dtype=torch.float32),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    batches = []
    with torch.no_grad():
        for batch_x, batch_gender, batch_info in loader:
            pred = model.predict(
                batch_x.to(device),
                batch_gender.to(device),
                batch_info.to(device),
            )
            batches.append(pred.cpu().numpy())
    return np.vstack(batches)


def main():
    """Train the Suzuki MLP on the training modalities and write predictions.

    Reads the three input AnnData files named in ``par``, preprocesses them,
    builds and trains the encoder/decoder model, predicts modality 2 for the
    test cells and writes the result to ``par['output']`` as a gzip-compressed
    h5ad file with the predictions stored in the ``normalized`` layer.
    """
    # Enable logging
    logging.basicConfig(level=logging.INFO)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}", flush=True)

    # Read input files
    print("Reading input files", flush=True)
    adata_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
    adata_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
    adata_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

    # Determine task type
    if par['task_type'] == 'auto':
        task_type = determine_task_type(adata_train_mod1, adata_train_mod2)
        print(f"Auto-detected task type: {task_type}", flush=True)
    else:
        task_type = par['task_type']

    print(f"Task type: {task_type}", flush=True)
    print(f"Modality 1: {adata_train_mod1.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod1.n_vars}", flush=True)
    print(f"Modality 2: {adata_train_mod2.uns.get('modality', 'Unknown')}, n_features: {adata_train_mod2.n_vars}", flush=True)

    # Preprocess data
    print("Preprocessing data", flush=True)
    data = preprocess_data(
        adata_train_mod1=adata_train_mod1,
        adata_train_mod2=adata_train_mod2,
        adata_test_mod1=adata_test_mod1,
        task_type=task_type,
        inputs_n_components=par['inputs_n_components'],
        targets_n_components=par['targets_n_components'],
    )

    X_train = data['X_train']
    y_train = data['y_train']
    X_test = data['X_test']

    print(f"Training data shape: X={X_train.shape}, y={y_train.shape}", flush=True)
    print(f"Test data shape: X={X_test.shape}", flush=True)

    # Build model
    print("Building model", flush=True)

    # Settings shared by the encoder and the decoder. ``norm`` comes from the
    # ``--norm`` argument instead of being pinned to "layer_norm", so the
    # declared option actually takes effect (its default is "layer_norm").
    block_kwargs = dict(
        skip=par['use_skip_connections'],
        dropout_p=par['dropout_p'],
        activation=par['activation'],
        norm=par['norm'],
    )

    # Create encoder
    encoder = MLPBModule(
        input_dim=None,  # Will be set in the main module
        output_dim=par['encoder_h_dim'],
        n_block=par['n_encoder_block'],
        h_dim=par['encoder_h_dim'],
        **block_kwargs,
    )

    # Create hierarchical decoder
    decoder = HierarchicalMLPBModule(
        input_dim=par['encoder_h_dim'],
        output_dim=None,  # Will create multiple outputs
        n_block=par['n_decoder_block'],
        h_dim=par['decoder_h_dim'],
        **block_kwargs,
    )

    # Create main model
    model = SuzukiEncoderDecoderModule(
        x_dim=X_train.shape[1],
        y_dim=y_train.shape[1],
        y_statistic=data['y_statistic'],
        encoder_h_dim=par['encoder_h_dim'],
        decoder_h_dim=par['decoder_h_dim'],
        n_decoder_block=par['n_decoder_block'],
        targets_decomposer_components=data['targets_decomposer_components'],
        targets_global_median=data['targets_global_median'],
        encoder=encoder,
        decoder=decoder,
        task_type=task_type,
        use_residual_connections=par['use_residual_connections'],
    ).to(device)

    # Train model
    print("Training model", flush=True)
    trained_model = train_model(
        model=model,
        X_train=X_train,
        y_train=y_train,
        metadata_train=data['metadata_train'],
        device=device,
        lr=par['learning_rate'],
        weight_decay=par['weight_decay'],
        epochs=par['epochs'],
        batch_size=par['batch_size'],
        task_type=task_type,
    )

    # Predict on test data
    print("Predicting on test data", flush=True)
    y_pred = _predict(
        trained_model,
        X_test,
        data['metadata_test'],
        batch_size=par['batch_size'],
        device=device,
    )

    # Create output AnnData object: test cells as rows, modality-2 features
    # as columns, predictions in the 'normalized' layer.
    print("Creating output", flush=True)
    adata_pred = ad.AnnData(
        obs=adata_test_mod1.obs.copy(),
        var=adata_train_mod2.var.copy(),
        layers={
            'normalized': y_pred,
        },
        uns={
            'dataset_id': adata_train_mod1.uns.get('dataset_id', 'unknown'),
            'method_id': meta['name'],
        },
    )

    # Write output
    print("Writing output to file", flush=True)
    adata_pred.write_h5ad(par['output'], compression='gzip')

    print("Done!", flush=True)


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)