|
1 | 1 | import os |
2 | 2 | import sys |
3 | | -import tarfile |
4 | | -import tempfile |
5 | | -import zipfile |
6 | 3 |
|
7 | 4 | import anndata as ad |
8 | 5 | import mlflow.pyfunc |
9 | | -import pandas as pd |
| 6 | +import numpy as np |
10 | 7 |
|
11 | 8 | ## VIASH START |
12 | 9 | # Note: this section is auto-generated by viash at runtime. To edit it, make changes |
|
20 | 17 | ## VIASH END |
21 | 18 |
|
22 | 19 | sys.path.append(meta["resources_dir"]) |
23 | | -from exit_codes import exit_non_applicable |
24 | | -from read_anndata_partial import read_anndata |
25 | | -from unpack import unpack_directory |
| 20 | +from exit_codes import exit_non_applicable # noqa: E402 |
| 21 | +from mlflow import embed # noqa: E402 |
| 22 | +from read_anndata_partial import read_anndata # noqa: E402 |
| 23 | +from unpack import unpack_directory # noqa: E402 |
26 | 24 |
|
27 | 25 | print("====== Geneformer (MLflow model) ======", flush=True) |
28 | 26 |
|
|
45 | 43 | model = mlflow.pyfunc.load_model(model_dir) |
46 | 44 | print(model, flush=True) |
47 | 45 |
|
48 | | -print("\n>>> Writing temporary input H5AD file...", flush=True) |
49 | | -input_adata = ad.AnnData( |
50 | | - X=adata.X.copy(), |
51 | | - var=adata.var.filter(items=["feature_id"]).rename( |
52 | | - columns={"feature_id": "ensembl_id"} |
53 | | - ), |
54 | | -) |
55 | | -print(input_adata, flush=True) |
| 46 | +n_processors = meta.get("cpus") or os.cpu_count() |
| 47 | +print(f"Available processors: {n_processors}", flush=True) |
| 48 | + |
56 | 49 |
|
57 | | -h5ad_file = tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) |
58 | | -print(f"Temporary H5AD file: '{h5ad_file.name}'", flush=True) |
59 | | -input_adata.write(h5ad_file.name) |
60 | | -del input_adata |
def process_geneformer_input(input_adata):
    """Add Geneformer-specific fields to input AnnData.

    Two columns are written to ``input_adata.obs`` in place:

    - ``cell_idx``: positional index of each observation, ``0..n_obs - 1``.
    - ``n_counts``: sum of ``input_adata.X`` along axis 1, stored as a 1-D
      array with one value per observation.

    Args:
        input_adata: AnnData-like object exposing ``obs``, ``n_obs`` and ``X``.

    Returns:
        None. ``input_adata`` is modified in place.
    """
    input_adata.obs["cell_idx"] = np.arange(input_adata.n_obs)
    # SciPy sparse matrices and np.matrix return an (n_obs, 1) result from
    # sum(axis=1); flatten it so the column is always a plain 1-D array.
    row_totals = np.asarray(input_adata.X.sum(axis=1)).ravel()
    input_adata.obs["n_counts"] = row_totals
61 | 54 |
|
62 | | -print("\n>>> Running model...", flush=True) |
63 | | -input_df = pd.DataFrame({"input_uri": [h5ad_file.name]}) |
64 | | -embedding = model.predict(input_df) |
| 55 | + |
| 56 | +print("\n>>> Embedding data...", flush=True) |
| 57 | +embedding = embed( |
| 58 | + adata, |
| 59 | + model, |
| 60 | + layers=["counts"], |
| 61 | + var={"feature_id": "ensembl_id"}, |
| 62 | + model_params={"nproc": n_processors}, |
| 63 | + process_adata=process_geneformer_input, |
| 64 | +) |
65 | 65 |
|
66 | 66 | print("\n>>> Storing output...", flush=True) |
67 | 67 | output = ad.AnnData( |
|
85 | 85 | print("\n>>> Cleaning up temporary files...", flush=True) |
86 | 86 | if model_temp is not None: |
87 | 87 | model_temp.cleanup() |
88 | | -h5ad_file.close() |
89 | | -os.unlink(h5ad_file.name) |
90 | 88 |
|
91 | 89 | print("\n>>> Done!", flush=True) |
0 commit comments