diff --git a/.codecov.yaml b/.codecov.yaml new file mode 100644 index 0000000..d0c0e29 --- /dev/null +++ b/.codecov.yaml @@ -0,0 +1,17 @@ +# Based on pydata/xarray +codecov: + require_ci_to_pass: no + +coverage: + status: + project: + default: + # Require 1% coverage, i.e., always succeed + target: 1 + patch: false + changes: false + +comment: + layout: diff, flags, files + behavior: once + require_base: no diff --git a/.cruft.json b/.cruft.json new file mode 100644 index 0000000..1092989 --- /dev/null +++ b/.cruft.json @@ -0,0 +1,43 @@ +{ + "template": "https://github.com/scverse/cookiecutter-scverse", + "commit": "00e962a93d725184f3c56a9c923680fc69965d08", + "checkout": null, + "context": { + "cookiecutter": { + "project_name": "arrayloaders", + "package_name": "arrayloaders", + "project_description": "A minibatch loader for anndata store", + "author_full_name": "Ilan Gold", + "author_email": "ilan.gold@scverse.org", + "github_user": "scverse", + "github_repo": "arrayloaders", + "license": "MIT License", + "ide_integration": true, + "_copy_without_render": [ + ".github/workflows/build.yaml", + ".github/workflows/test.yaml", + "docs/_templates/autosummary/**.rst" + ], + "_exclude_on_template_update": [ + "CHANGELOG.md", + "LICENSE", + "README.md", + "docs/api.md", + "docs/index.md", + "docs/notebooks/example.ipynb", + "docs/references.bib", + "docs/references.md", + "src/**", + "tests/**" + ], + "_render_devdocs": false, + "_jinja2_env_vars": { + "lstrip_blocks": true, + "trim_blocks": true + }, + "_template": "https://github.com/scverse/cookiecutter-scverse", + "_commit": "00e962a93d725184f3c56a9c923680fc69965d08" + } + }, + "directory": null +} diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..66678e3 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[{*.{yml,yaml,toml},.cruft.json}] +indent_size = 2 + +[Makefile] +indent_style = tab diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..3ca1ccb --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,94 @@ +name: Bug report +description: Report something that is broken or incorrect +labels: bug +body: + - type: markdown + attributes: + value: | + **Note**: Please read [this guide](https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports) + detailing how to provide the necessary information for us to reproduce your bug. In brief: + * Please provide exact steps how to reproduce the bug in a clean Python environment. + * In case it's not clear what's causing this bug, please provide the data or the data generation procedure. + * Sometimes it is not possible to share the data, but usually it is possible to replicate problems on publicly + available datasets or to share a subset of your data. + + - type: textarea + id: report + attributes: + label: Report + description: A clear and concise description of what the bug is. + validations: + required: true + + - type: textarea + id: versions + attributes: + label: Versions + description: | + Which version of packages. + + Please install `session-info2`, run the following command in a notebook, + click the “Copy as Markdown” button, then paste the results into the text box below. 
+ + ```python + In[1]: import session_info2; session_info2.session_info(dependencies=True) + ``` + + Alternatively, run this in a console: + + ```python + >>> import session_info2; print(session_info2.session_info(dependencies=True)._repr_mimebundle_()["text/markdown"]) + ``` + render: python + placeholder: | + anndata 0.11.3 + ---- ---- + charset-normalizer 3.4.1 + coverage 7.7.0 + psutil 7.0.0 + dask 2024.7.1 + jaraco.context 5.3.0 + numcodecs 0.15.1 + jaraco.functools 4.0.1 + Jinja2 3.1.6 + sphinxcontrib-jsmath 1.0.1 + sphinxcontrib-htmlhelp 2.1.0 + toolz 1.0.0 + session-info2 0.1.2 + PyYAML 6.0.2 + llvmlite 0.44.0 + scipy 1.15.2 + pandas 2.2.3 + sphinxcontrib-devhelp 2.0.0 + h5py 3.13.0 + tblib 3.0.0 + setuptools-scm 8.2.0 + more-itertools 10.3.0 + msgpack 1.1.0 + sparse 0.15.5 + wrapt 1.17.2 + jaraco.collections 5.1.0 + numba 0.61.0 + pyarrow 19.0.1 + pytz 2025.1 + MarkupSafe 3.0.2 + crc32c 2.7.1 + sphinxcontrib-qthelp 2.0.0 + sphinxcontrib-serializinghtml 2.0.0 + zarr 2.18.4 + asciitree 0.3.3 + six 1.17.0 + sphinxcontrib-applehelp 2.0.0 + numpy 2.1.3 + cloudpickle 3.1.1 + sphinxcontrib-bibtex 2.6.3 + natsort 8.4.0 + jaraco.text 3.12.1 + setuptools 76.1.0 + Deprecated 1.2.18 + packaging 24.2 + python-dateutil 2.9.0.post0 + ---- ---- + Python 3.13.2 | packaged by conda-forge | (main, Feb 17 2025, 14:10:22) [GCC 13.3.0] + OS Linux-6.11.0-109019-tuxedo-x86_64-with-glibc2.39 + Updated 2025-03-18 15:47 diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..5b62547 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Scverse Community Forum + url: https://discourse.scverse.org/ + about: If you have questions about “How to do X”, please ask them here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..11d9c9d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,11 @@ +name: Feature request +description: Propose a new feature for arrayloaders +labels: enhancement +body: + - type: textarea + id: description + attributes: + label: Description of feature + description: Please describe your suggestion for a new feature. It might help to describe a problem or use case, plus any alternatives that you have considered. + validations: + required: true diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 0000000..83e01a1 --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,33 @@ +name: Check Build + +on: + push: + branches: [main] + pull_request: + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u). 
+ shell: bash -euo pipefail {0} + +jobs: + package: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + filter: blob:none + fetch-depth: 0 + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + cache-dependency-glob: pyproject.toml + - name: Build package + run: uv build + - name: Check package + run: uvx twine check --strict dist/*.whl diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index ab44b0a..0000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,64 +0,0 @@ -name: build - -on: - push: - branches: [main] - pull_request: - -jobs: - build: - runs-on: ubuntu-latest - env: - GITHUB_EVENT_NAME: ${{ github.event_name }} - strategy: - fail-fast: false - matrix: - include: - - os: ubuntu-latest - python: "3.12" - timeout-minutes: 15 - - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - uses: actions/checkout@v4 - with: - repository: laminlabs/lndocs - path: lndocs - ref: main - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - - uses: actions/cache@v4 - with: - path: ~/.cache/pre-commit - key: pre-commit-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }} - - run: pip install git+https://github.com/laminlabs/laminci - - uses: aws-actions/configure-aws-credentials@v4 - with: - aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} - aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} - aws-region: us-east-1 - - run: nox -s lint - - run: nox -s test - - uses: codecov/codecov-action@v5 - with: - token: ${{ secrets.CODECOV_TOKEN }} - - run: nox -s docs - - uses: cloudflare/wrangler-action@v3 - id: cloudflare - with: - apiToken: ${{ secrets.CLOUDFLARE_API_TOKEN }} - accountId: 472bdad691b4483dea759eadb37110bd - command: pages deploy "_build/html" --project-name=arrayloaders - gitHubToken: ${{ secrets.GITHUB_TOKEN }} - - uses: edumserrano/find-create-or-update-comment@v2 - if: github.event_name == 'pull_request' - with: - issue-number: ${{ github.event.pull_request.number }} - body-includes: "Deployment URL" - comment-author: "github-actions[bot]" - body: | - Deployment URL: ${{ steps.cloudflare.outputs.deployment-url }} - edit-mode: replace diff --git a/.github/workflows/doc-changes.yml b/.github/workflows/doc-changes.yml deleted file mode 100644 index 967c6c7..0000000 --- a/.github/workflows/doc-changes.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: doc-changes - -on: - pull_request_target: - branches: - - main - types: - - closed - -jobs: - latest-changes: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - run: pip install "laminci[doc-changes]@git+https://x-access-token:${{ secrets.LAMIN_BUILD_DOCS }}@github.com/laminlabs/laminci" - - run: laminci doc-changes - env: - repo_token: ${{ secrets.GITHUB_TOKEN }} - docs_token: ${{ secrets.LAMIN_BUILD_DOCS }} - changelog_file: lamin-docs/docs/changelog/soon/arrayloaders.md diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..017fefa --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,34 @@ +name: Release + +on: + release: + types: [published] + +defaults: + run: + # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u). 
+ shell: bash -euo pipefail {0} + +# Use "trusted publishing", see https://docs.pypi.org/trusted-publishers/ +jobs: + release: + name: Upload release to PyPI + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/arrayloaders + permissions: + id-token: write # IMPORTANT: this permission is mandatory for trusted publishing + steps: + - uses: actions/checkout@v4 + with: + filter: blob:none + fetch-depth: 0 + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + cache-dependency-glob: pyproject.toml + - name: Build package + run: uv build + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..a4cea89 --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,99 @@ +name: Test + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + - cron: "0 5 1,15 * *" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +defaults: + run: + # to fail on error in multiline statements (-e), in pipes (-o pipefail), and on unset variables (-u). + shell: bash -euo pipefail {0} + +jobs: + # Get the test environment from hatch as defined in pyproject.toml. + # This ensures that the pyproject.toml is the single point of truth for test definitions and the same tests are + # run locally and on continuous integration. + # Check [[tool.hatch.envs.hatch-test.matrix]] in pyproject.toml and https://hatch.pypa.io/latest/environment/ for + # more details. + get-environments: + runs-on: ubuntu-latest + outputs: + envs: ${{ steps.get-envs.outputs.envs }} + steps: + - uses: actions/checkout@v4 + with: + filter: blob:none + fetch-depth: 0 + - name: Install uv + uses: astral-sh/setup-uv@v5 + - name: Get test environments + id: get-envs + run: | + ENVS_JSON=$(uvx hatch env show --json | jq -c 'to_entries + | map( + select(.key | startswith("hatch-test")) + | { + name: .key, + label: (if (.key | contains("pre")) then .key + " (PRE-RELEASE DEPENDENCIES)" else .key end), + python: .value.python + } + )') + echo "envs=${ENVS_JSON}" | tee $GITHUB_OUTPUT + + # Run tests through hatch. Spawns a separate runner for each environment defined in the hatch matrix obtained above. + test: + needs: get-environments + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + env: ${{ fromJSON(needs.get-environments.outputs.envs) }} + + name: ${{ matrix.env.label }} + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v4 + with: + filter: blob:none + fetch-depth: 0 + - name: Install uv + uses: astral-sh/setup-uv@v5 + with: + python-version: ${{ matrix.env.python }} + cache-dependency-glob: pyproject.toml + - name: create hatch environment + run: uvx hatch env create ${{ matrix.env.name }} + - name: run tests using hatch + env: + MPLBACKEND: agg + PLATFORM: ${{ matrix.os }} + DISPLAY: :42 + run: uvx hatch run ${{ matrix.env.name }}:run-cov + - name: generate coverage report + run: uvx hatch run ${{ matrix.env.name }}:coverage xml + - name: Upload coverage + uses: codecov/codecov-action@v4 + + # Check that all tests defined above pass. This makes it easy to set a single "required" test in branch + # protection instead of having to update it frequently. See https://github.com/re-actors/alls-green#why. 
+ check: + name: Tests pass in all hatch environments + if: always() + needs: + - get-environments + - test + runs-on: ubuntu-latest + steps: + - uses: re-actors/alls-green@release/v1 + with: + jobs: ${{ toJSON(needs) }} diff --git a/.gitignore b/.gitignore index 3f2358d..31e10b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,111 +1,20 @@ -# macOS +# Temp files .DS_Store -.AppleDouble -.LSOverride +*~ +buck-out/ -# local files -scripts/ - -# Byte-compiled / optimized / DLL files +# Compiled files +.venv/ __pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so +.*cache/ # Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# dotenv -.env - -# virtualenv -.venv -venv/ -ENV/ - -# mypy -.mypy_cache/ +/dist/ -# IDE settings -.vscode/ -.idea/ +# Tests and coverage +/data/ +/node_modules/ -# Lamin -_build -docs/arrayloaders.* -lamin_sphinx -docs/conf.py -_docs_tmp* +# docs +/docs/generated/ +/docs/_build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84d60c2..8c94113 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,48 +4,44 @@ default_language_version: default_stages: - pre-commit - pre-push -minimum_pre_commit_version: 2.12.0 +minimum_pre_commit_version: 2.16.0 repos: - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 + - repo: https://github.com/biomejs/pre-commit + rev: v2.2.2 hooks: - - id: prettier - exclude: | - (?x)( - docs/changelog.md - ) - - repo: https://github.com/kynan/nbstripout - rev: 0.8.1 + - id: biome-format + exclude: ^\.cruft\.json$ # inconsistent indentation with cruft - file never to be modified manually. 
+ - repo: https://github.com/tox-dev/pyproject-fmt + rev: v2.6.0 hooks: - - id: nbstripout - exclude: | - (?x)( - docs/examples/| - docs/notes/ - ) + - id: pyproject-fmt - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.0 + rev: v0.12.10 hooks: - - id: ruff - args: [--fix, --exit-non-zero-on-fix, --unsafe-fixes] + - id: ruff-check + types_or: [python, pyi, jupyter] + args: [--fix, --exit-non-zero-on-fix] - id: ruff-format + types_or: [python, pyi, jupyter] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: detect-private-key - id: check-ast - id: end-of-file-fixer - exclude: | - (?x)( - .github/workflows/latest-changes.jinja2 - ) - id: mixed-line-ending args: [--fix=lf] - id: trailing-whitespace - id: check-case-conflict - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.16.1 + # Check that there are no merge conflicts (could be generated by template sync) + - id: check-merge-conflict + args: [--assume-in-merge] + - repo: local hooks: - - id: mypy - args: [--no-strict-optional, --ignore-missing-imports] - additional_dependencies: ["types-requests", "types-attrs"] + - id: forbid-to-commit + name: Don't commit rej files + entry: | + Cannot commit .rej files. These indicate merge conflicts that arise during automated template updates. + Fix the merge conflicts manually and remove the .rej files. + language: fail + files: '.*\.rej$' diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..c3f3f96 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,15 @@ +# https://docs.readthedocs.io/en/stable/config-file/v2.html +version: 2 +build: + os: ubuntu-24.04 + tools: + python: "3.12" + jobs: + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + build: + html: + - uvx hatch run docs:build + - mv docs/_build $READTHEDOCS_OUTPUT diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..caaeb4f --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,18 @@ +{ + "recommendations": [ + // GitHub integration + "github.vscode-github-actions", + "github.vscode-pull-request-github", + // Language support + "ms-python.python", + "ms-python.vscode-pylance", + "ms-toolsai.jupyter", + "tamasfe.even-better-toml", + // Dependency management + "ninoseki.vscode-mogami", + // Linting and formatting + "editorconfig.editorconfig", + "charliermarsh.ruff", + "biomejs.biome", + ], +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..36d1874 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,33 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Build Documentation", + "type": "debugpy", + "request": "launch", + "module": "sphinx", + "args": ["-M", "html", ".", "_build"], + "cwd": "${workspaceFolder}/docs", + "console": "internalConsole", + "justMyCode": false, + }, + { + "name": "Python: Debug Test", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "internalConsole", + "justMyCode": false, + "env": { + "PYTEST_ADDOPTS": "--color=yes", + }, + "presentation": { + "hidden": true, + }, + }, + ], +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..e034b91 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,18 @@ +{ + "[python][json][jsonc]": { + "editor.formatOnSave": true, + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.codeActionsOnSave": { + "source.fixAll": "always", + "source.organizeImports": "always", + }, + }, + "[json][jsonc]": { + "editor.defaultFormatter": "biomejs.biome", + }, + "python.analysis.typeCheckingMode": "basic", + "python.testing.pytestEnabled": true, + "python.testing.pytestArgs": ["-vv", "--color=yes"], +} diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..c185628 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,15 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog][], +and this project adheres to [Semantic Versioning][]. + +[keep a changelog]: https://keepachangelog.com/en/1.0.0/ +[semantic versioning]: https://semver.org/spec/v2.0.0.html + +## [Unreleased] + +### Added + +- Basic tool, preprocessing and plotting functions diff --git a/LICENSE b/LICENSE index b09cd78..beb0297 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ -Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2025, Ilan Gold + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index c4d8ad4..269bd00 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,57 @@ -# `arrayloaders` +# arrayloaders -This package still is in an early version. To use it, see: +[![Tests][badge-tests]][tests] +[![Documentation][badge-docs]][documentation] -- [Training example](https://lamin.ai/laminlabs/arrayloader-benchmarks/transform/UMQFXo0vs0Z6) -- [Benchmark](https://lamin.ai/laminlabs/arrayloader-benchmarks/transform/2C0ghWpz0auc) +[badge-tests]: https://img.shields.io/github/actions/workflow/status/laminlabs/arrayloaders/test.yaml?branch=main +[badge-docs]: https://img.shields.io/readthedocs/arrayloaders -## Contributing +A minibatch loader for anndata store -Please run `pre-commit install` and `gitmoji -i` on the CLI before starting to work on this repository! +## Getting started + +Please refer to the [documentation][], +in particular, the [API documentation][]. + +## Installation + +You need to have Python 3.10 or newer installed on your system. +If you don't have Python installed, we recommend installing [uv][]. + +There are several alternative options to install arrayloaders: + + + +1. Install the latest development version: + +```bash +pip install git+https://github.com/laminlabs/arrayloaders.git@main +``` + +## Release notes + +See the [changelog][]. + +## Contact + +For questions and help requests, you can reach out in the [scverse discourse][]. +If you found a bug, please use the [issue tracker][]. + +## Citation + +> t.b.a + +[uv]: https://github.com/astral-sh/uv +[scverse discourse]: https://discourse.scverse.org/ +[issue tracker]: https://github.com/laminlabs/arrayloaders/issues +[tests]: https://github.com/laminlabs/arrayloaders/actions/workflows/test.yaml +[documentation]: https://arrayloaders.readthedocs.io +[changelog]: https://arrayloaders.readthedocs.io/en/latest/changelog.html +[api documentation]: https://arrayloaders.readthedocs.io/en/latest/api.html +[pypi]: https://pypi.org/project/arrayloaders diff --git a/arrayloaders/__init__.py b/arrayloaders/__init__.py deleted file mode 100644 index 3b49363..0000000 --- a/arrayloaders/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Modlyn. - -Import the package:: - - import arrayloaders - -The package has two modules. - -.. autosummary:: - :toctree: . - - io - -""" - -from __future__ import annotations - -__version__ = "0.0.3" # denote a pre-release for 0.1.0 with 0.1rc1 - -from . import io diff --git a/arrayloaders/io/__init__.py b/arrayloaders/io/__init__.py deleted file mode 100644 index 8196cc3..0000000 --- a/arrayloaders/io/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -"""IO. - -Lightning data modules. - -.. autosummary:: - :toctree: . - - ClassificationDataModule - -Torch data loaders. - -.. autosummary:: - :toctree: . - - DaskDataset - read_lazy_store - read_lazy - -Array store creation. - -.. autosummary:: - :toctree: . - - create_store_from_h5ads - add_h5ads_to_store - shuffle_and_shard_h5ads - -""" - -from __future__ import annotations - -from .dask_loader import ( # TODO: clean up exported functions - do we need both read_lazy and read_lazy_store? 
- DaskDataset, - read_lazy, - read_lazy_store, -) -from .datamodules import ClassificationDataModule -from .store_creation import add_h5ads_to_store, create_store_from_h5ads -from .zarr_loader import ZarrDenseDataset, ZarrSparseDataset diff --git a/arrayloaders/io/dask_loader.py b/arrayloaders/io/dask_loader.py deleted file mode 100644 index 51faca5..0000000 --- a/arrayloaders/io/dask_loader.py +++ /dev/null @@ -1,157 +0,0 @@ -from __future__ import annotations - -import pathlib -import warnings -from typing import TYPE_CHECKING - -import anndata as ad -import dask -import numpy as np -import pandas as pd -import zarr -from torch.utils.data import IterableDataset - -from .utils import WorkerHandle, check_lt_1, sample_rows - -if TYPE_CHECKING: - from typing import Literal - - -# TODO: refactor to read full lazy and then simply pick out the needed columns into memory instead of having `read_obs_lazy` as a separate arg -def read_lazy(path, obs_columns: list[str] | None = None, read_obs_lazy: bool = False): - """Reads an individual shard of a Zarr store into an AnnData object. - - Args: - path: Path to individual Zarr-based AnnData shard. - obs_columns: List of observation columns to read. If None, all columns are read. - read_obs_lazy: If True, reads the obs DataFrame lazily. Useful for large obs DataFrames. - - Returns: - AnnData object loaded from the specified shard. - """ - g = zarr.open(path, mode="r") - - adata = ad.experimental.read_lazy(g) - # TODO: Adapt dask code below to just handle an in-memory xarray data array - if not read_obs_lazy: - if obs_columns is None: - adata.obs = ad.io.read_elem(g["obs"]) - else: - adata.obs = pd.DataFrame( - {col: ad.io.read_elem(g[f"obs/{col}"]) for col in obs_columns} - ) - - return adata - - -def read_lazy_store( - path, obs_columns: list[str] | None = None, read_obs_lazy: bool = False -): - """Reads a Zarr store containing multiple shards into a single AnnData object. - - Args: - path: Path to the Zarr store containing multiple shards. - obs_columns: List of observation columns to read. If None, all columns are read. - read_obs_lazy: If True, reads the obs DataFrame lazily. Useful for large obs DataFrames. - - Returns: - AnnData: The concatenated AnnData object loaded from all shards. - """ - path = pathlib.Path(path) - - with warnings.catch_warnings(): - # Ignore zarr v3 warnings - warnings.simplefilter("ignore") - adata = ad.concat( - [ - read_lazy(path / shard, obs_columns, read_obs_lazy) - for shard in path.iterdir() - if str(shard).endswith(".zarr") - ] - ) - - return adata - - -def _combine_chunks(lst, chunk_size): - return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] - - -class DaskDataset(IterableDataset): - """Dask-based IterableDataset for loading AnnData objects in chunks. - - Args: - adata: The AnnData object to yield samples from. - label_column: The name of the column in `adata.obs` that contains the labels. - n_chunks: Number of chunks of the underlying dask.array to load at a time. - Loading more chunks at a time can improve performance and randomness, but increases memory usage. - Defaults to 8. - shuffle: Whether to yield samples in a random order. Defaults to True. - dask_scheduler: The Dask scheduler to use for parallel computation. - "synchronous" for single-threaded execution, "threads" for multithreaded execution. Defaults to "threads". - n_workers: Number of Dask workers to use. If None, the number of workers is determined by Dask. 
- - Examples: - >>> from arrayloaders.io.dask_loader import DaskDataset, read_lazy_store - >>> from torch.utils.data import DataLoader - >>> label_column = "y" - >>> adata = read_lazy_store("path/to/zarr/store", obs_columns=[label_column]) - >>> dataset = DaskDataset(adata, label_column=label_column, n_chunks=8, shuffle=True) - >>> dataloader = DataLoader(dataset, batch_size=2048, num_workers=4, drop_last=True) - >>> for batch in dataloader: - ... x, y = batch - ... # Process the batch - """ - - def __init__( - self, - adata: ad.AnnData, - label_column: str, - n_chunks: int = 8, - shuffle: bool = True, - dask_scheduler: Literal["synchronous", "threads"] = "threads", - n_workers: int | None = None, - ): - check_lt_1( - [adata.shape[0], n_chunks], - ["Size of anndata obs dimension", "Number of chunks"], - ) - self.adata = adata - self.label_column = label_column - self.n_chunks = n_chunks - self.shuffle = shuffle - self.dask_scheduler = dask_scheduler - self.n_workers = n_workers - - self.worker_handle = WorkerHandle() - - def _get_chunks(self): - chunk_boundaries = np.cumsum([0] + list(self.adata.X.chunks[0])) - slices = [ - slice(int(start), int(end)) - for start, end in zip( - chunk_boundaries[:-1], chunk_boundaries[1:], strict=True - ) - ] - blocks_idxs = np.arange(len(self.adata.X.chunks[0])) - chunks = list(zip(blocks_idxs, slices, strict=True)) - - if self.shuffle: - self.worker_handle.shuffle(chunks) - - return self.worker_handle.get_part_for_worker(chunks) - - def __iter__(self): - for chunks in _combine_chunks(self._get_chunks(), self.n_chunks): - block_idxs, slices = zip(*chunks, strict=True) - x_list = dask.compute( - [self.adata.X.blocks[i] for i in block_idxs], - scheduler=self.dask_scheduler, - )[0] - obs_list = [ - self.adata.obs[self.label_column].iloc[s].to_numpy() for s in slices - ] - yield from sample_rows(x_list, obs_list, shuffle=self.shuffle) - - def __len__(self): - return len(self.adata) diff --git a/arrayloaders/io/datamodules.py b/arrayloaders/io/datamodules.py deleted file mode 100644 index 2721122..0000000 --- a/arrayloaders/io/datamodules.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import lightning as L -from torch.utils.data import DataLoader - -from .dask_loader import DaskDataset - -if TYPE_CHECKING: - from typing import Literal - - import anndata as ad - - -class ClassificationDataModule(L.LightningDataModule): - """A LightningDataModule for classification tasks using arrayloaders.io.DaskDataset. - - Args: - adata_train: anndata.AnnData object containing the training data. - adata_val: anndata.AnnData object containing the validation data. - label_column: Name of the column in `obs` that contains the target values. - train_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the training dataset. - val_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the validation dataset. - n_chunks: Number of chunks of the underlying dask.array to load at a time. Loading more chunks at a time can improve performance and randomness, but increases memory usage. - dask_scheduler: The Dask scheduler to use for parallel computation. Use "synchronous" for single-threaded execution or "threads" for multithreaded execution. 
- - Examples: - >>> from arrayloaders.io.datamodules import ClassificationDataModule - >>> from arrayloaders.io.dask_loader import read_lazy_store - >>> adata_train = read_lazy_store("path/to/train/store", obs_columns=["label"]) - >>> adata_train.obs["y"] = adata_train.obs["label"].cat.codes.to_numpy().astype("i8") - >>> datamodule = ClassificationDataModule( - ... adata_train=adata_train, - ... adata_val=None, - ... label_column="label", - ... train_dataloader_kwargs={ - ... "batch_size": 2048, - ... "drop_last": True, - ... "num_workers": 4 - ... }, - ... ) - >>> train_loader = datamodule.train_dataloader() - """ - - def __init__( - self, - adata_train: ad.AnnData | None, - adata_val: ad.AnnData | None, - label_column: str, - train_dataloader_kwargs=None, - val_dataloader_kwargs=None, - n_chunks: int = 8, - dask_scheduler: Literal["synchronous", "threads"] = "threads", - ): - super().__init__() - if train_dataloader_kwargs is None: - train_dataloader_kwargs = {} - if val_dataloader_kwargs is None: - val_dataloader_kwargs = {} - - self.adata_train = adata_train - self.adata_val = adata_val - self.label_col = label_column - self.train_dataloader_kwargs = train_dataloader_kwargs - self.val_dataloader_kwargs = val_dataloader_kwargs - self.n_chunks = n_chunks - self.dask_scheduler = dask_scheduler - - def train_dataloader(self): - train_dataset = DaskDataset( - self.adata_train, - label_column=self.label_col, - n_chunks=self.n_chunks, - dask_scheduler=self.dask_scheduler, - ) - - return DataLoader(train_dataset, **self.train_dataloader_kwargs) - - def val_dataloader(self): - val_dataset = DaskDataset( - self.adata_val, - label_column=self.label_col, - shuffle=False, - n_chunks=self.n_chunks, - dask_scheduler=self.dask_scheduler, - ) - - return DataLoader(val_dataset, **self.val_dataloader_kwargs) diff --git a/arrayloaders/io/utils.py b/arrayloaders/io/utils.py deleted file mode 100644 index dbb8bf3..0000000 --- a/arrayloaders/io/utils.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -from functools import cached_property -from typing import Protocol - -import numpy as np -from torch.utils.data import get_worker_info - - -def sample_rows( - x_list: list[np.ndarray], - obs_list: list[np.ndarray] | None, - indices: list[np.ndarray] | None = None, - *, - shuffle: bool = True, -): - """Samples rows from multiple arrays and their corresponding observation arrays. - - Args: - x_list: A list of numpy arrays containing the data to sample from. - obs_list: A list of numpy arrays containing the corresponding observations. - indices: the list of indexes for each element in x_list/ - shuffle: Whether to shuffle the rows before sampling. Defaults to True. - - Yields: - tuple: A tuple containing a row from `x_list` and the corresponding row from `obs_list`. 
- """ - lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int) - cum = np.concatenate(([0], np.cumsum(lengths))) - total = cum[-1] - idxs = np.arange(total) - if shuffle: - np.random.default_rng().shuffle(idxs) - arr_idxs = np.searchsorted(cum, idxs, side="right") - 1 - row_idxs = idxs - cum[arr_idxs] - for ai, ri in zip(arr_idxs, row_idxs, strict=True): - res = [ - x_list[ai][ri], - obs_list[ai][ri] if obs_list is not None else None, - ] - if indices is not None: - yield (*res, indices[ai][ri]) - else: - yield tuple(res) - - -class WorkerHandle: - @cached_property - def _worker_info(self): - return get_worker_info() - - @cached_property - def _rng(self): - if self._worker_info is None: - return np.random.default_rng() - else: - # This is used for the _get_chunks function - # Use the same seed for all workers that the resulting splits are the same across workers - # torch default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed - return np.random.default_rng(self._worker_info.seed - self._worker_info.id) - - def shuffle(self, obj: np.typing.ArrayLike) -> None: - """Perform in-place shuffle. - - Args: - obj: The object to be shuffled - """ - self._rng.shuffle(obj) - - def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: - if self._worker_info is None: - return obj - num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id - chunks_split = np.array_split(obj, num_workers) - return chunks_split[worker_id] - - -def check_lt_1(vals: list[int], labels: list[str]): - """Raise a ValueError if any of the values are less than one. - - The format of the error is "{labels[i]} must be greater than 1, got {values[i]}" - and is raised based on the first found less than one value. - - Args: - vals: The values to check < 1 - labels: The label for the value in the error if the value is less than one. - - Raises: - ValueError: _description_ - """ - if any(is_lt_1 := [v < 1 for v in vals]): - label, value = next( - (label, value) - for label, value, check in zip( - labels, - vals, - is_lt_1, - strict=True, - ) - if check - ) - raise ValueError(f"{label} must be greater than 1, got {value}") - - -class SupportsShape(Protocol): - @property - def shape(self) -> tuple[int, int] | list[int]: ... 
- - -def check_var_shapes(objs: list[SupportsShape]): - if not all(objs[0].shape[1] == d.shape[1] for d in objs): - raise ValueError("TODO: All datasets must have same shape along the var axis.") diff --git a/arrayloaders/io/zarr_loader.py b/arrayloaders/io/zarr_loader.py deleted file mode 100644 index 26ee044..0000000 --- a/arrayloaders/io/zarr_loader.py +++ /dev/null @@ -1,718 +0,0 @@ -from __future__ import annotations - -import asyncio -import math -from abc import ABCMeta, abstractmethod -from collections import OrderedDict, defaultdict -from dataclasses import dataclass -from itertools import accumulate, chain, islice, pairwise -from types import NoneType -from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar, cast - -import anndata as ad -import numpy as np -import zarr -import zarr.core.sync as zsync -from scipy import sparse as sp -from torch.utils.data import IterableDataset - -from .utils import WorkerHandle, check_lt_1, check_var_shapes - -try: - from cupy import ndarray as CupyArray - from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix # pragma: no cover -except ImportError: - CupyCSRMatrix = NoneType - CupyArray = NoneType - -if TYPE_CHECKING: - from collections.abc import Awaitable, Callable, Iterator - from types import ModuleType - from typing import Self - - -def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: - return np.split(a, np.arange(size, len(a), size)) - - -OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array) -accepted_on_disk_types = OnDiskArray.__constraints__ - - -@dataclass -class CSRContainer: - elems: tuple[np.ndarray, np.ndarray, np.ndarray] - shape: tuple[int, int] - - -OutputInMemoryArray = TypeVar( - "OutputInMemoryArray", sp.csr_matrix, np.ndarray, CupyCSRMatrix, CupyArray -) -InputInMemoryArray = TypeVar("InputInMemoryArray", CSRContainer, np.ndarray) - - -def _batched(iterable, n): - if n < 1: - raise ValueError("n must be >= 1") - it = iter(iterable) - while batch := list(islice(it, n)): - yield batch - - -async def index_datasets( - dataset_index_to_slices: OrderedDict[int, list[slice]], - fetch_data: Callable[[list[slice], int], Awaitable[CSRContainer | np.ndarray]], -) -> list[InputInMemoryArray]: - """Helper function meant to encapsulate asynchronous calls so that we can use the same event loop as zarr. - - Args: - dataset_index_to_slices: A lookup of the list-placement index of a dataset to the request slices. - fetch_data: The function to do the fetching for a given slice-dataset index pair. - """ - tasks = [] - for dataset_idx in dataset_index_to_slices.keys(): - tasks.append( - fetch_data( - dataset_index_to_slices[dataset_idx], - dataset_idx, - ) - ) - return await asyncio.gather(*tasks) - - -add_dataset_docstring = """\ -Append datasets to this loader. - -Args: - datasets: List of :class:`anndata.abc.CSRDataset` or :class:`zarr.Array` objects, generally from :attr:`anndata.AnnData.X`. - obs: List of `numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`. -""" - -add_dataset_docstring = """\ -Append a dataset to this loader. - -Args: - dataset: :class:`anndata.abc.CSRDataset` or :class:`zarr.Array` object, generally from :attr:`anndata.AnnData.X`. - obs: `numpy.ndarray` labels for the dataset, generally from :attr:`anndata.AnnData.obs`. 
-""" - - -class AnnDataManager(Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray]): - train_datasets: list[OnDiskArray] = [] - labels: list[np.ndarray] | None = None - _return_index: bool = False - _on_add: Callable | None = None - _batch_size: int = 1 - _shapes: list[tuple[int, int]] = [] - sp_module: ModuleType - np_module: ModuleType - - def __init__( - self, - *, - on_add: Callable | None = None, - return_index: bool = False, - batch_size: int = 1, - preload_to_gpu: bool = False, - ): - self._on_add = on_add - self._return_index = return_index - self._batch_size = batch_size - if preload_to_gpu: - try: - import cupy as cp - import cupyx.scipy.sparse as cpx # pragma: no cover - - self.sp_module = cpx # pragma: no cover - self.np_module = cp # pragma: no cover - except ImportError: - raise ImportError( - "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`" - ) from None - else: - self.sp_module = sp - self.np_module = np - - @property - def dataset_type(self) -> type[OnDiskArray]: - return type(self.train_datasets[0]) - - @property - def n_obs(self) -> int: - return sum(shape[0] for shape in self._shapes) - - @property - def n_var(self) -> int: - return self._shapes[0][1] - - def add_anndatas( - self, - adatas: list[ad.AnnData], - layer_keys: list[str | None] | str | None = None, - obs_keys: list[str] | str | None = None, - ) -> None: - raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") - - def add_anndata( - self, - adata: ad.AnnData, - layer_key: str | None = None, - obs_key: str | None = None, - ) -> None: - raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") - - def add_datasets( - self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None - ) -> None: - if obs is None: - obs = [None] * len(datasets) - for ds, o in zip(datasets, obs, strict=True): - self.add_dataset(ds, o) - - def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> None: - if len(self.train_datasets) > 0: - if self.labels is None and obs is not None: - raise ValueError( - f"Cannot add a dataset with obs label {obs} when training datasets have already been added without labels" - ) - if self.labels is not None and obs is None: - raise ValueError( - "Cannot add a dataset with no obs label when training datasets have already been added without labels" - ) - if not isinstance(dataset, accepted_types := accepted_on_disk_types): - raise TypeError( - f"Cannot add a dataset of type {type(dataset)}, only {accepted_types} are allowed" - ) - if len(self.train_datasets) > 0 and not isinstance(dataset, self.dataset_type): - raise TypeError( - f"Cannot add a dataset whose data of type {type(dataset)} was not an instance of expected type {self.dataset_type}" - ) - datasets = self.train_datasets + [dataset] - check_var_shapes(datasets) - self._shapes = self._shapes + [dataset.shape] - self.train_datasets = datasets - if self.labels is not None: # labels exist - self.labels += [obs] - elif ( - obs is not None - ): # labels dont exist yet, but are being added for the first time - self.labels = [obs] - if self._on_add is not None: - self._on_add() - - def _get_relative_obs_indices( - self, index: slice, *, use_original_space: bool = False - ) -> list[tuple[slice, int]]: - """Generate a slice relative to a dataset given a global slice index over all datasets. 
- - For a given slice indexer of axis 0, return a new slice relative to the on-disk - data it represents given the number of total observations as well as the index of - the underlying data on disk from the argument `sparse_datasets` to the initializer. - - For example, given slice index (10, 15), for 4 datasets each with size 5 on axis zero, - this function returns ((0,5), 2) representing slice (0,5) along axis zero of sparse dataset 2. - - Args: - index: The queried slice. - use_original_space: Whether or not the slices should be reindexed against the anndata objects. - - Returns: - A slice relative to the dataset it represents as well as the index of said dataset in `sparse_datasets`. - """ - min_idx = index.start - max_idx = index.stop - curr_pos = 0 - slices = [] - for idx, (n_obs, _) in enumerate(self._shapes): - array_start = curr_pos - array_end = curr_pos + n_obs - - start = max(min_idx, array_start) - stop = min(max_idx, array_end) - if start < stop: - if use_original_space: - slices.append((slice(start, stop), idx)) - else: - relative_start = start - array_start - relative_stop = stop - array_start - slices.append((slice(relative_start, relative_stop), idx)) - curr_pos += n_obs - return slices - - def _slices_to_slices_with_array_index( - self, slices: list[slice], *, use_original_space: bool = False - ) -> OrderedDict[int, list[slice]]: - """Given a list of slices, give the lookup between on-disk datasets and slices relative to that dataset. - - Args: - slices: Slices to relative to the on-disk datasets. - use_original_space: Whether or not the slices should be reindexed against the anndata objects. - - Returns: - A lookup between the dataset and its indexing slices, ordered by keys. - """ - dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list) - for slice in slices: - for relative_obs_indices in self._get_relative_obs_indices( - slice, use_original_space=use_original_space - ): - dataset_index_to_slices[relative_obs_indices[1]] += [ - relative_obs_indices[0] - ] - keys = sorted(dataset_index_to_slices.keys()) - dataset_index_to_slices_sorted = OrderedDict() - for k in keys: - dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] - return dataset_index_to_slices_sorted - - def _get_chunks( - self, chunk_size: int, worker_handle: WorkerHandle, shuffle: bool - ) -> np.ndarray: - """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker. - - Returns: - A :class:`numpy.ndarray` of chunk ids. - """ - chunks = np.arange(math.ceil(self.n_obs / chunk_size)) - if shuffle: - worker_handle.shuffle(chunks) - - return worker_handle.get_part_for_worker(chunks) - - def iter( - self, - chunk_size: int, - worker_handle: WorkerHandle, - preload_nchunks: int, - shuffle: bool, - fetch_data: Callable[[list[slice], int], Awaitable[np.ndarray | CSRContainer]], - ) -> Iterator[ - tuple[InputInMemoryArray, None | np.ndarray] - | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] - ]: - """Iterate over the on-disk csr datasets. - - Yields: - A one-row sparse matrix. - """ - check_lt_1( - [len(self.train_datasets), self.n_obs], - ["Number of datasets", "Number of observations"], - ) - # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 - # we must keep track of the leftover data. 
- in_memory_data = None - in_memory_labels = None - in_memory_indices = None - mod = self.sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np - for chunk_indices in _batched( - self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks - ): - slices = [ - slice( - index * chunk_size, - min(self.n_obs, (index + 1) * chunk_size), - ) - for index in chunk_indices - ] - dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) - # Fetch the data over slices - chunks: list[InputInMemoryArray] = zsync.sync( - index_datasets(dataset_index_to_slices, fetch_data) - ) - if any(isinstance(c, CSRContainer) for c in chunks): - chunks_converted: list[OutputInMemoryArray] = [ - self.sp_module.csr_matrix( - tuple(self.np_module.asarray(e) for e in c.elems), shape=c.shape - ) - for c in chunks - ] - else: - chunks_converted = [self.np_module.asarray(c) for c in chunks] - # Accumulate labels - labels: None | list[np.ndarray] = None - if self.labels is not None: - labels = [] - for dataset_idx in dataset_index_to_slices.keys(): - labels += [ - self.labels[dataset_idx][ - np.concatenate( - [ - np.arange(s.start, s.stop) - for s in dataset_index_to_slices[dataset_idx] - ] - ) - ] - ] - # Accumulate indices if necessary - indices: None | list[np.ndarray] = None - if self._return_index: - dataset_index_to_slices = self._slices_to_slices_with_array_index( - slices, use_original_space=True - ) - dataset_indices = dataset_index_to_slices.keys() - indices = [ - np.concatenate( - [ - np.arange( - s.start, - s.stop, - ) - for s in dataset_index_to_slices[index] - ] - ) - for index in dataset_indices - ] - # Do batch returns, handling leftover data as necessary - in_memory_data = ( - mod.vstack(chunks_converted) - if in_memory_data is None - else mod.vstack([in_memory_data, *chunks_converted]) - ) - if self.labels is not None: - in_memory_labels = ( - np.concatenate(labels) - if in_memory_labels is None - else np.concatenate([in_memory_labels, *labels]) - ) - if self._return_index: - in_memory_indices = ( - np.concatenate(indices) - if in_memory_indices is None - else np.concatenate([in_memory_indices, *indices]) - ) - # Create random indices into in_memory_data and then index into it - # If there is "leftover" at the end (see the modulo op), - # save it for the next iteration. 
- batch_indices = np.arange(in_memory_data.shape[0]) - if shuffle: - np.random.default_rng().shuffle(batch_indices) - splits = split_given_size(batch_indices, self._batch_size) - for i, s in enumerate(splits): - if s.shape[0] == self._batch_size: - res = [ - in_memory_data[s], - in_memory_labels[s] if self.labels is not None else None, - ] - if self._return_index: - res += [in_memory_indices[s]] - yield tuple(res) - if i == ( - len(splits) - 1 - ): # end of iteration, leftover data needs be kept - if (s.shape[0] % self._batch_size) != 0: - in_memory_data = in_memory_data[s] - if in_memory_labels is not None: - in_memory_labels = in_memory_labels[s] - if in_memory_indices is not None: - in_memory_indices = in_memory_indices[s] - else: - in_memory_data = None - in_memory_labels = None - in_memory_indices = None - if in_memory_data is not None: # handle any leftover data - res = [ - in_memory_data, - in_memory_labels if self.labels is not None else None, - ] - if self._return_index: - res += [in_memory_indices] - yield tuple(res) - - -AnnDataManager.add_datasets.__doc__ = add_dataset_docstring -AnnDataManager.add_dataset.__doc__ = add_dataset_docstring - -__init_docstring__ = """A loader for on-disk {array_type} data. - -This loader batches together slice requests to the underlying {array_type} stores to acheive higher performance. -This custom code to do this task will be upstreamed into anndata at some point and no longer rely on private zarr apis. -The loader is agnostic to the on-disk chunking/sharding, but it may be advisable to align with the in-memory chunk size. - -Args: - chunk_size: The obs size (i.e., axis 0) of contiguous array data to fetch, by default 512 - preload_nchunks: The number of chunks of contiguous array data to fetch, by default 32 - shuffle: Whether or not to shuffle the data, by default True - return_index: Whether or not to return the index on each iteration, by default False - preload_to_gpu: Whether or not to use cupy for non-io array operations like vstack and indexing. This option entails greater GPU memory usage. -""" - - -# TODO: make this part of the public zarr or zarrs-python API. -# We can do chunk coalescing in zarrs based on integer arrays, so I think -# there would make sense with ezclump or similar. -# Another "solution" would be for zarrs to support integer indexing properly, if that pipeline works, -# or make this an "experimental setting" and to use integer indexing for the zarr-python pipeline. -# See: https://github.com/zarr-developers/zarr-python/issues/3175 for why this is better than simpler alternatives. -class MultiBasicIndexer(zarr.core.indexing.Indexer): - def __init__(self, indexers: list[zarr.core.indexing.Indexer]): - self.shape = (sum(i.shape[0] for i in indexers), *indexers[0].shape[1:]) - self.drop_axes = indexers[0].drop_axes # maybe? 
- self.indexers = indexers - - def __iter__(self): - total = 0 - for i in self.indexers: - for c in i: - out_selection = c[2] - gap = out_selection[0].stop - out_selection[0].start - yield type(c)( - c[0], c[1], (slice(total, total + gap), *out_selection[1:]), c[3] - ) - total += gap - - -class AbstractIterableDataset( - Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray], metaclass=ABCMeta -): - _shuffle: bool - _preload_nchunks: int - _worker_handle: WorkerHandle - _chunk_size: int - _dataset_manager: AnnDataManager[ - OnDiskArray, InputInMemoryArray, OutputInMemoryArray - ] - - def __init__( - self, - *, - chunk_size: int = 512, - preload_nchunks: int = 32, - shuffle: bool = True, - return_index: bool = False, - batch_size: int = 1, - preload_to_gpu: bool = False, - ): - check_lt_1( - [ - chunk_size, - preload_nchunks, - ], - ["Chunk size", "Preload chunks"], - ) - if batch_size > (chunk_size * preload_nchunks): - raise NotImplementedError( - "If you need batch loading that is bigger than the iterated in-memory size, please open an issue." - ) - self._dataset_manager = AnnDataManager( - # TODO: https://github.com/scverse/anndata/issues/2021 - # on_add=self._cache_update_callback, - return_index=return_index, - batch_size=batch_size, - preload_to_gpu=preload_to_gpu, - ) - self._chunk_size = chunk_size - self._preload_nchunks = preload_nchunks - self._shuffle = shuffle - self._worker_handle = WorkerHandle() - - async def _cache_update_callback(self): - pass - - @abstractmethod - async def _fetch_data( - self, slices: list[slice], dataset_idx: int - ) -> InputInMemoryArray: - """Fetch the data for given slices and the arrays representing a dataset on-disk. - - Args: - slices: The indexing slices to fetch. - dataset_idx: The index of the dataset to fetch from. - - Returns: - The in-memory array data. - """ - ... - - def add_anndatas( - self, - adatas: list[ad.AnnData], - layer_keys: list[str | None] | str | None = None, - obs_keys: list[str] | str | None = None, - ) -> Self: - raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") - - def add_anndata( - self, - adata: ad.AnnData, - layer_key: str | None = None, - obs_key: str | None = None, - ) -> Self: - raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") - - def add_datasets( - self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None - ) -> Self: - self._dataset_manager.add_datasets(datasets, obs) - return self - - def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> Self: - self._dataset_manager.add_dataset(dataset, obs) - return self - - def __len__(self) -> int: - return self._dataset_manager.n_obs - - def __iter__( - self, - ) -> Iterator[ - tuple[InputInMemoryArray, None | np.ndarray] - | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] - ]: - """Iterate over the on-disk datasets. - - Yields: - A one-row in-memory array optionally with its label. 
- """ - yield from self._dataset_manager.iter( - self._chunk_size, - self._worker_handle, - self._preload_nchunks, - self._shuffle, - self._fetch_data, - ) - - -AbstractIterableDataset.add_dataset.__doc__ = add_dataset_docstring -AbstractIterableDataset.add_datasets.__doc__ = add_dataset_docstring - - -class ZarrDenseDataset(AbstractIterableDataset, IterableDataset): - async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> np.ndarray: - dataset = self._dataset_manager.train_datasets[dataset_idx] - indexer = MultiBasicIndexer( - [ - zarr.core.indexing.BasicIndexer( - (s, Ellipsis), - shape=dataset.metadata.shape, - chunk_grid=dataset.metadata.chunk_grid, - ) - for s in slices - ] - ) - res = cast( - "np.ndarray", - await dataset._async_array._get_selection( - indexer, prototype=zarr.core.buffer.default_buffer_prototype() - ), - ) - return res - - -ZarrDenseDataset.__init__.__doc__ = __init_docstring__.format(array_type="dense") - - -class CSRDatasetElems(NamedTuple): - indptr: np.ndarray - indices: zarr.AsyncArray - data: zarr.AsyncArray - - -class ZarrSparseDataset(AbstractIterableDataset, IterableDataset): - _dataset_elem_cache: dict[int, CSRDatasetElems] = {} - - def _cache_update_callback(self): - """Callback for when datasets are added to ensure the cache is updated.""" - return zsync.sync(self._ensure_cache()) - - async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems: - """Fetch the in-memory indptr, and backed indices and data for a given dataset index. - - Args: - idx: The index - - Returns: - The constituent elems of the CSR dataset. - """ - indptr = await self._dataset_manager.train_datasets[ - idx - ].group._async_group.getitem("indptr") - return CSRDatasetElems( - *( - await asyncio.gather( - indptr.getitem(Ellipsis), - self._dataset_manager.train_datasets[ - idx - ].group._async_group.getitem("indices"), - self._dataset_manager.train_datasets[ - idx - ].group._async_group.getitem("data"), - ) - ) - ) - - async def _ensure_cache(self): - """Build up the cache of datasets i.e., in-memory indptr, and backed indices and data.""" - arr_idxs = [ - idx - for idx in range(len(self._dataset_manager.train_datasets)) - if idx not in self._dataset_elem_cache - ] - all_elems = await asyncio.gather( - *( - self._create_sparse_elems(idx) - for idx in range(len(self._dataset_manager.train_datasets)) - if idx not in self._dataset_elem_cache - ) - ) - for idx, elems in zip(arr_idxs, all_elems, strict=True): - self._dataset_elem_cache[idx] = elems - - async def _get_sparse_elems(self, dataset_idx: int) -> CSRDatasetElems: - """Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index. - - Args: - dataset_idx: The index of the dataset whose arrays are sought. - - Returns: - The arrays representing the sparse data. - """ - if dataset_idx not in self._dataset_elem_cache: - await self._ensure_cache() - return self._dataset_elem_cache[dataset_idx] - - async def _fetch_data( - self, - slices: list[slice], - dataset_idx: int, - ) -> CSRContainer: - # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295 - # for the inspiration of this function. 
- indptr, indices, data = await self._get_sparse_elems(dataset_idx) - indptr_indices = [indptr[slice(s.start, s.stop + 1)] for s in slices] - indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices] - indexer = MultiBasicIndexer( - [ - zarr.core.indexing.BasicIndexer( - (l,), shape=data.metadata.shape, chunk_grid=data.metadata.chunk_grid - ) - for l in indptr_limits - ] - ) - data_np, indices_np = await asyncio.gather( - data._get_selection( - indexer, prototype=zarr.core.buffer.default_buffer_prototype() - ), - indices._get_selection( - indexer, prototype=zarr.core.buffer.default_buffer_prototype() - ), - ) - gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) - offsets = accumulate(chain([indptr_limits[0].start], gaps)) - start_indptr = indptr_indices[0] - next(offsets) - if len(slices) < 2: # there is only one slice so no need to concatenate - return CSRContainer( - elems=(data_np, indices_np, start_indptr), - shape=(start_indptr.shape[0] - 1, self._dataset_manager.n_var), - ) - end_indptr = np.concatenate( - [s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)] - ) - indptr_np = np.concatenate([start_indptr, end_indptr]) - return CSRContainer( - elems=(data_np, indices_np, indptr_np), - shape=(indptr_np.shape[0] - 1, self._dataset_manager.n_var), - ) - - -ZarrSparseDataset.__init__.__doc__ = __init_docstring__.format(array_type="sparse") diff --git a/biome.jsonc b/biome.jsonc new file mode 100644 index 0000000..9f8f220 --- /dev/null +++ b/biome.jsonc @@ -0,0 +1,17 @@ +{ + "$schema": "https://biomejs.dev/schemas/2.2.0/schema.json", + "vcs": { "enabled": true, "clientKind": "git", "useIgnoreFile": true }, + "formatter": { "useEditorconfig": true }, + "overrides": [ + { + "includes": ["./.vscode/*.json", "**/*.jsonc"], + "json": { + "formatter": { "trailingCommas": "all" }, + "parser": { + "allowComments": true, + "allowTrailingCommas": true, + }, + }, + }, + ], +} diff --git a/tests/test_notebooks.py b/docs/_static/.gitkeep similarity index 100% rename from tests/test_notebooks.py rename to docs/_static/.gitkeep diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css new file mode 100644 index 0000000..b8c8d47 --- /dev/null +++ b/docs/_static/css/custom.css @@ -0,0 +1,4 @@ +/* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ +div.cell_output table.dataframe { + font-size: 0.8em; +} diff --git a/docs/_templates/.gitkeep b/docs/_templates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 0000000..7b4a0cf --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,61 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. add toctree option to make autodoc generate the pages + +.. autoclass:: {{ objname }} + +{% block attributes %} +{% if attributes %} +Attributes table +~~~~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in attributes %} + ~{{ name }}.{{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block methods %} +{% if methods %} +Methods table +~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in methods %} + {%- if item != '__init__' %} + ~{{ name }}.{{ item }} + {%- endif -%} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block attributes_documentation %} +{% if attributes %} +Attributes +~~~~~~~~~~ + +{% for item in attributes %} + +.. 
autoattribute:: {{ [objname, item] | join(".") }} +{%- endfor %} + +{% endif %} +{% endblock %} + +{% block methods_documentation %} +{% if methods %} +Methods +~~~~~~~ + +{% for item in methods %} +{%- if item != '__init__' %} + +.. automethod:: {{ [objname, item] | join(".") }} +{%- endif -%} +{%- endfor %} + +{% endif %} +{% endblock %} diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000..5932792 --- /dev/null +++ b/docs/api.md @@ -0,0 +1 @@ +# API diff --git a/docs/changelog.md b/docs/changelog.md index 35dec83..d9e79ba 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,3 @@ -# Changelog +```{include} ../CHANGELOG.md - -Name | PR | Developer | Date | Version ---- | --- | --- | --- | --- +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..52834d0 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,131 @@ +# Configuration file for the Sphinx documentation builder. + +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- +import sys +from datetime import datetime +from importlib.metadata import metadata +from pathlib import Path + +HERE = Path(__file__).parent +sys.path.insert(0, str(HERE / "extensions")) + + +# -- Project information ----------------------------------------------------- + +# NOTE: If you installed your project in editable mode, this might be stale. +# If this is the case, reinstall it to refresh the metadata +info = metadata("arrayloaders") +project_name = info["Name"] +author = info["Author"] +copyright = f"{datetime.now():%Y}, {author}." +version = info["Version"] +urls = dict(pu.split(", ") for pu in info.get_all("Project-URL")) +repository_url = urls["Source"] + +# The full version, including alpha/beta/rc tags +release = info["Version"] + +bibtex_bibfiles = ["references.bib"] +templates_path = ["_templates"] +nitpicky = True # Warn about broken links +needs_sphinx = "4.0" + +html_context = { + "display_github": True, # Integrate GitHub + "github_user": "scverse", + "github_repo": project_name, + "github_version": "main", + "conf_py_path": "/docs/", +} + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. +# They can be extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
+extensions = [ + "myst_nb", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinxcontrib.bibtex", + "sphinx_autodoc_typehints", + "sphinx_tabs.tabs", + "sphinx.ext.mathjax", + "IPython.sphinxext.ipython_console_highlighting", + "sphinxext.opengraph", + *[p.stem for p in (HERE / "extensions").glob("*.py")], +] + +autosummary_generate = True +autodoc_member_order = "groupwise" +default_role = "literal" +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_use_rtype = True # having a separate entry generally helps readability +napoleon_use_param = True +myst_heading_anchors = 6 # create anchors for h1-h6 +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", + "html_admonition", +] +myst_url_schemes = ("http", "https", "mailto") +nb_output_stderr = "remove" +nb_execution_mode = "off" +nb_merge_streams = True +typehints_defaults = "braces" + +source_suffix = { + ".rst": "restructuredtext", + ".ipynb": "myst-nb", + ".myst": "myst-nb", +} + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), +} + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_book_theme" +html_static_path = ["_static"] +html_css_files = ["css/custom.css"] + +html_title = project_name + +html_theme_options = { + "repository_url": repository_url, + "use_repository_button": True, + "path_to_docs": "docs/", + "navigation_with_keys": False, +} + +pygments_style = "default" + +nitpick_ignore = [ + # If building the documentation fails because of a missing link that is outside your control, + # you can add an exception to this list. + # ("py:class", "igraph.Graph"), +] diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..3816d29 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,214 @@ +# Contributing guide + +Scanpy provides extensive [developer documentation][scanpy developer guide], most of which applies to this project, too. +This document will not reproduce the entire content from there. +Instead, it aims at summarizing the most important information to get you started on contributing. + +We assume that you are already familiar with git and with making pull requests on GitHub. +If not, please refer to the [scanpy developer guide][]. + +[scanpy developer guide]: https://scanpy.readthedocs.io/en/latest/dev/index.html + +## Installing dev dependencies + +In addition to the packages needed to _use_ this package, +you need additional python packages to [run tests](#writing-tests) and [build the documentation](#docs-building). 
+ +:::::{tabs} +::::{group-tab} Hatch +The easiest way is to get familiar with [hatch environments][], with which these tasks are simply: + +```bash +hatch test # defined in the table [tool.hatch.envs.hatch-test] in pyproject.toml +hatch run docs:build # defined in the table [tool.hatch.envs.docs] +``` + +:::: + +::::{group-tab} Pip +If you prefer managing environments manually, you can use `pip`: + +```bash +cd arrayloaders +python3 -m venv .venv +source .venv/bin/activate +pip install -e ".[dev,test,doc]" +``` + +:::: +::::: + +[hatch environments]: https://hatch.pypa.io/latest/tutorials/environment/basic-usage/ + +## Code-style + +This package uses [pre-commit][] to enforce consistent code-styles. +On every commit, pre-commit checks will either automatically fix issues with the code, or raise an error message. + +To enable pre-commit locally, simply run + +```bash +pre-commit install +``` + +in the root of the repository. +Pre-commit will automatically download all dependencies when it is run for the first time. + +Alternatively, you can rely on the [pre-commit.ci][] service enabled on GitHub. +If you didn't run `pre-commit` before pushing changes to GitHub it will automatically commit fixes to your pull request, or show an error message. + +If pre-commit.ci added a commit on a branch you still have been working on locally, simply use + +```bash +git pull --rebase +``` + +to integrate the changes into yours. +While the [pre-commit.ci][] is useful, we strongly encourage installing and running pre-commit locally first to understand its usage. + +Finally, most editors have an _autoformat on save_ feature. +Consider enabling this option for [ruff][ruff-editors] and [biome][biome-editors]. + +[pre-commit]: https://pre-commit.com/ +[pre-commit.ci]: https://pre-commit.ci/ +[ruff-editors]: https://docs.astral.sh/ruff/integrations/ +[biome-editors]: https://biomejs.dev/guides/integrate-in-editor/ + +(writing-tests)= + +## Writing tests + +This package uses [pytest][] for automated testing. +Please write {doc}`scanpy:dev/testing` for every function added to the package. + +Most IDEs integrate with pytest and provide a GUI to run tests. +Just point yours to one of the environments returned by + +```bash +hatch env create hatch-test # create test environments for all supported versions +hatch env find hatch-test # list all possible test environment paths +``` + +Alternatively, you can run all tests from the command line by executing + +:::::{tabs} +::::{group-tab} Hatch + +```bash +hatch test # test with the highest supported Python version +# or +hatch test --all # test with all supported Python versions +``` + +:::: + +::::{group-tab} Pip + +```bash +source .venv/bin/activate +pytest +``` + +:::: +::::: + +in the root of the repository. + +[pytest]: https://docs.pytest.org/ + +### Continuous integration + +Continuous integration will automatically run the tests on all pull requests and test +against the minimum and maximum supported Python version. + +Additionally, there's a CI job that tests against pre-releases of all dependencies (if there are any). +The purpose of this check is to detect incompatibilities of new package versions early on and +gives you time to fix the issue or reach out to the developers of the dependency before the package is released to a wider audience. + +## Publishing a release + +### Updating the version number + +Before making a release, you need to update the version number in the `pyproject.toml` file. 
+Please adhere to [Semantic Versioning][semver], in brief + +> Given a version number MAJOR.MINOR.PATCH, increment the: +> +> 1. MAJOR version when you make incompatible API changes, +> 2. MINOR version when you add functionality in a backwards compatible manner, and +> 3. PATCH version when you make backwards compatible bug fixes. +> +> Additional labels for pre-release and build metadata are available as extensions to the MAJOR.MINOR.PATCH format. + +Once you are done, commit and push your changes and navigate to the "Releases" page of this project on GitHub. +Specify `vX.X.X` as a tag name and create a release. +For more information, see [managing GitHub releases][]. +This will automatically create a git tag and trigger a Github workflow that creates a release on [PyPI][]. + +[semver]: https://semver.org/ +[managing GitHub releases]: https://docs.github.com/en/repositories/releasing-projects-on-github/managing-releases-in-a-repository +[pypi]: https://pypi.org/ + +## Writing documentation + +Please write documentation for new or changed features and use-cases. +This project uses [sphinx][] with the following features: + +- The [myst][] extension allows to write documentation in markdown/Markedly Structured Text +- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension). +- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks)) +- [sphinx-autodoc-typehints][], to automatically reference annotated input and output types +- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/) + +See scanpy’s {doc}`scanpy:dev/documentation` for more information on how to write your own. + +[sphinx]: https://www.sphinx-doc.org/en/master/ +[myst]: https://myst-parser.readthedocs.io/en/latest/intro.html +[myst-nb]: https://myst-nb.readthedocs.io/en/latest/ +[numpydoc-napoleon]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html +[numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html +[sphinx-autodoc-typehints]: https://github.com/tox-dev/sphinx-autodoc-typehints + +### Tutorials with myst-nb and jupyter notebooks + +The documentation is set-up to render jupyter notebooks stored in the `docs/notebooks` directory using [myst-nb][]. +Currently, only notebooks in `.ipynb` format are supported that will be included with both their input and output cells. +It is your responsibility to update and re-run the notebook whenever necessary. + +If you are interested in automatically running notebooks as part of the continuous integration, +please check out [this feature request][issue-render-notebooks] in the `cookiecutter-scverse` repository. + +[issue-render-notebooks]: https://github.com/scverse/cookiecutter-scverse/issues/40 + +#### Hints + +- If you refer to objects from other packages, please add an entry to `intersphinx_mapping` in `docs/conf.py`. + Only if you do so can sphinx automatically create a link to the external documentation. +- If building the documentation fails because of a missing link that is outside your control, + you can add an entry to the `nitpick_ignore` list in `docs/conf.py` + +(docs-building)= + +#### Building the docs locally + +:::::{tabs} +::::{group-tab} Hatch + +```bash +hatch run docs:build +hatch run docs:open +``` + +:::: + +::::{group-tab} Pip + +```bash +source .venv/bin/activate +cd docs +sphinx-build -M html . 
_build -W +(xdg-)open _build/html/index.html +``` + +:::: +::::: diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py new file mode 100644 index 0000000..7f637d7 --- /dev/null +++ b/docs/extensions/typed_returns.py @@ -0,0 +1,35 @@ +# code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py +# with some minor adjustment +from __future__ import annotations + +import re + +from sphinx.ext.napoleon import NumpyDocstring +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sphinx.application import Sphinx + from collections.abc import Generator, Iterable + + +def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: + for line in lines: + if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): + yield f"-{m['param']} (:class:`~{m['type']}`)" + else: + yield line + + +def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: + lines_raw = self._dedent(self._consume_to_next_section()) + if lines_raw[0] == ":": + del lines_raw[0] + lines = self._format_block(":returns: ", list(_process_return(lines_raw))) + if lines and lines[-1]: + lines.append("") + return lines + + +def setup(app: Sphinx): + """Set app.""" + NumpyDocstring._parse_returns_section = _parse_returns_section diff --git a/docs/guide.md b/docs/guide.md deleted file mode 100644 index f1fbf33..0000000 --- a/docs/guide.md +++ /dev/null @@ -1,7 +0,0 @@ -# Guide - -```{toctree} -:maxdepth: 1 - -quickstart -``` diff --git a/docs/index.md b/docs/index.md index f08b2cb..8b5f298 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,9 +3,13 @@ ``` ```{toctree} +:hidden: true :maxdepth: 1 -:hidden: -reference -changelog +api.md +changelog.md +contributing.md +references.md + +notebooks/example ``` diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb new file mode 100644 index 0000000..43fb8d6 --- /dev/null +++ b/docs/notebooks/example.ipynb @@ -0,0 +1,171 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example notebook" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from anndata import AnnData\n", + "import pandas as pd\n", + "import arrayloaders" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "adata = AnnData(np.random.normal(size=(20, 10)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With myst it is possible to link in the text cell of a notebook such as this one the documentation of a function or a class.\n", + "\n", + "Let's take as an example the function {func}`arrayloaders.pp.basic_preproc`. \n", + "You can see that by clicking on the text, the link redirects to the API documentation of the function. \n", + "Check the raw markdown of this cell to understand how this is specified.\n", + "\n", + "This works also for any package listed by `intersphinx`. Go to `docs/conf.py` and look for the `intersphinx_mapping` variable. \n", + "There, you will see a list of packages (that this package is dependent on) for which this functionality is supported. 
\n", + "\n", + "For instance, we can link to the class {class}`anndata.AnnData`, to the attribute {attr}`anndata.AnnData.obs` or the method {meth}`anndata.AnnData.write`.\n", + "\n", + "Again, check the raw markdown of this cell to see how each of these links are specified.\n", + "\n", + "You can read more about this in the [intersphinx page](https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html) and the [myst page](https://myst-parser.readthedocs.io/en/v0.15.1/syntax/syntax.html#roles-an-in-line-extension-point)." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Implement a preprocessing function here." + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "arrayloaders.pp.basic_preproc(adata)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
AB
0a1
1b2
2c3
\n", + "
" + ], + "text/plain": [ + " A B\n", + "0 a 1\n", + "1 b 2\n", + "2 c 3" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame().assign(A=[\"a\", \"b\", \"c\"], B=[1, 2, 3])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.12 ('squidpy39')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + }, + "vscode": { + "interpreter": { + "hash": "ae6466e8d4f517858789b5c9e8f0ed238fb8964458a36305fca7bddc149e9c64" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/reference.md b/docs/reference.md deleted file mode 100644 index f41f481..0000000 --- a/docs/reference.md +++ /dev/null @@ -1,5 +0,0 @@ -# Reference - -```{eval-rst} -.. automodule:: arrayloaders -``` diff --git a/docs/references.bib b/docs/references.bib new file mode 100644 index 0000000..9f5bed4 --- /dev/null +++ b/docs/references.bib @@ -0,0 +1,10 @@ +@article{Virshup_2023, + doi = {10.1038/s41587-023-01733-8}, + url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, + year = 2023, + month = {apr}, + publisher = {Springer Science and Business Media {LLC}}, + author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. 
Theis and}, + title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, + journal = {Nature Biotechnology} +} diff --git a/docs/references.md b/docs/references.md new file mode 100644 index 0000000..00ad6a6 --- /dev/null +++ b/docs/references.md @@ -0,0 +1,5 @@ +# References + +```{bibliography} +:cited: +``` diff --git a/noxfile.py b/noxfile.py deleted file mode 100644 index bc7289e..0000000 --- a/noxfile.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -import nox -from laminci.nox import build_docs, run, run_pre_commit, run_pytest - -# we'd like to aggregate coverage information across sessions -# and for this the code needs to be located in the same -# directory in every github action runner -# this also allows to break out an installation section -nox.options.default_venv_backend = "none" - - -@nox.session -def lint(session: nox.Session) -> None: - run_pre_commit(session) - - -@nox.session() -def test(session): - run(session, "uv pip install --system -e .[dev]") - run_pytest(session) - - -@nox.session() -def docs(session): - build_docs(session, strict=False) diff --git a/pyproject.toml b/pyproject.toml index 2d82f46..694a065 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,157 +1,163 @@ [build-system] -requires = ["flit_core >=3.2,<4"] -build-backend = "flit_core.buildapi" +build-backend = "hatchling.build" +requires = [ "hatchling" ] [project] name = "arrayloaders" -requires-python = ">=3.11,<3.13" -authors = [{ name = "Lamin Labs", email = "open-source@lamin.ai" }] +version = "0.0.1" +description = "A minibatch loader for anndata store" readme = "README.md" -dynamic = ["version", "description"] +license = { file = "LICENSE" } +maintainers = [ + { name = "Ilan Gold", email = "ilan.gold@scverse.org" }, + { name = "Felix Fischer", email = "felix.fischer@lamin.ai" }, +] +authors = [ + { name = "Ilan Gold" }, + { name = "Felix Fishcer" }, +] +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] dependencies = [ - "anndata[lazy]>=0.12.0rc1", - "zarr>=3", - "lightning", - "torch", - "dask[array]", - "tqdm", - "universal_pathlib>=0.2.6" + "anndata[dask]", + "dask", + "scipy>1.15", + # for debug logging (referenced from the issue template) + "session-info2", + "tqdm", + "zarr>=3", +] +optional-dependencies.dev = [ + "pre-commit", + "twine>=4.0.2", +] +optional-dependencies.doc = [ + "docutils>=0.8,!=0.18.*,!=0.19.*", + "ipykernel", + "ipython", + "myst-nb>=1.1", + "pandas", + # Until pybtex >0.24.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/ + "setuptools", + "sphinx>=8.1", + "sphinx-autodoc-typehints", + "sphinx-book-theme>=1", + "sphinx-copybutton", + "sphinx-tabs", + "sphinxcontrib-bibtex>=1", + "sphinxext-opengraph", ] +optional-dependencies.gpu = [ + "cupy-cuda12x", +] +optional-dependencies.test = [ + "coverage", + "pytest", + "torch", + "zarrs", +] +# https://docs.pypi.org/project_metadata/#project-urls +urls.Documentation = "https://arrayloaders.readthedocs.io/" +urls.Homepage = "https://github.com/laminlabs/arrayloaders" +urls.Source = "https://github.com/laminlabs/arrayloaders" -[project.urls] -Home = "https://github.com/laminlabs/arrayloaders" +[tool.hatch.envs.default] +installer = "uv" +features = [ "dev" ] -[project.optional-dependencies] -dev = [ - "zarrs", - "scipy>1.15.0", # rng argument to scipy.sparse.random - 
"lamindb", - "pre-commit", - "nox", - "pytest>=6.0", - "pytest-cov", - "nbproject_test", -] -gpu = ["cupy-cuda12x"] +[tool.hatch.envs.docs] +features = [ "doc" ] +scripts.build = "sphinx-build -M html docs docs/_build -W {args}" +scripts.open = "python -m webbrowser -t docs/_build/html/index.html" +scripts.clean = "git clean -fdX -- {args:docs}" -[tool.pytest.ini_options] -testpaths = [ - "tests", -] -filterwarnings = [ - "ignore:Jupyter is migrating its paths to use standard platformdirs:DeprecationWarning" +# Test the lowest and highest supported Python versions with normal deps +[[tool.hatch.envs.hatch-test.matrix]] +deps = [ "stable" ] +python = [ "3.11", "3.13" ] + +# Test the newest supported Python version also with pre-release deps +[[tool.hatch.envs.hatch-test.matrix]] +deps = [ "pre" ] +python = [ "3.13" ] + +[tool.hatch.envs.hatch-test] +features = [ "test" ] + +[tool.hatch.envs.hatch-test.overrides] +# If the matrix variable `deps` is set to "pre", +# set the environment variable `UV_PRERELEASE` to "allow". +matrix.deps.env-vars = [ + { key = "UV_PRERELEASE", value = "allow", if = [ "pre" ] }, ] [tool.ruff] -line-length = 88 -src = ["src"] +line-length = 120 +src = [ "src" ] +extend-include = [ "*.ipynb" ] + +format.docstring-code-format = true -[tool.ruff.lint] -select = [ - "F", # Errors detected by Pyflakes - "E", # Error detected by Pycodestyle - "W", # Warning detected by Pycodestyle - "I", # isort - "D", # pydocstyle - "B", # flake8-bugbear - "TID", # flake8-tidy-imports - "C4", # flake8-comprehensions - "BLE", # flake8-blind-except - "UP", # pyupgrade - "RUF100", # Report unused noqa directives - "TC", # Typing imports - "NPY", # Numpy specific rules - "PTH", # Use pathlib - "S" # Security +lint.select = [ + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C4", # flake8-comprehensions + "D", # pydocstyle + "E", # Error detected by Pycodestyle + "F", # Errors detected by Pyflakes + "I", # isort + "RUF100", # Report unused noqa directives + "TC", # Typing imports + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # Warning detected by Pycodestyle ] -ignore = [ - # Do not catch blind exception: `Exception` - "BLE001", - # Errors from function calls in argument defaults. These are fine when the result is immutable. 
- "B008", - # line too long -> we accept long comment lines; black gets rid of long code lines - "E501", - # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient - "E731", - # allow I, O, l as variable names -> I is the identity matrix - "E741", - # Missing docstring in public module - "D100", - # undocumented-public-class - "D101", - # Missing docstring in public method - "D102", - # Missing docstring in public function - "D103", - # Missing docstring in public package - "D104", - # __magic__ methods are are often self-explanatory, allow missing docstrings - "D105", - # Missing docstring in public nested class - "D106", - # Missing docstring in __init__ - "D107", - ## Disable one in each pair of mutually incompatible rules - # We don’t want a blank line before a class docstring - "D203", - # 1 blank line required after class docstring - "D204", - # first line should end with a period [Bug: doesn't work with single-line docstrings] - # We want docstrings to start immediately after the opening triple quote - "D213", - # Section underline is over-indented ("{name}") - "D215", - # First line should end with a period - "D400", - # First line should be in imperative mood; try rephrasing - "D401", - # First word of the first line should be capitalized: {} -> {} - "D403", - # First word of the docstring should not be "This" - "D404", - # Section name should end with a newline ("{name}") - "D406", - # Missing dashed underline after section ("{name}") - "D407", - # Section underline should be in the line following the section's name ("{name}") - "D408", - # Section underline should match the length of its name ("{name}") - "D409", - # No blank lines allowed between a section header and its content ("{name}") - "D412", - # Missing blank line after last section ("{name}") - "D413", - # camcelcase imported as lowercase - "N813", - # module import not at top level of file - "E402", - # open()` should be replaced by `Path.open() - "PTH123", - # subprocess` call: check for execution of untrusted input - https://github.com/PyCQA/bandit/issues/333 - "S603", - # Starting a process with a partial executable path - "S607" +lint.ignore = [ + "B008", # Errors from function calls in argument defaults. These are fine when the result is immutable. 
+ "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # __magic__ methods are often self-explanatory, allow missing docstrings + "D107", # Missing docstring in __init__ + # Disable one in each pair of mutually incompatible rules + "D203", # We don’t want a blank line before a class docstring + "D213", # <> We want docstrings to start immediately after the opening triple quote + "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] + "D401", # First line should be in imperative mood; try rephrasing + "E501", # line too long -> we accept long comment lines; formatter gets rid of long code lines + "E731", # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient + "E741", # allow I, O, l as variable names -> I is the identity matrix ] +lint.per-file-ignores."*/__init__.py" = [ "F401" ] +lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."tests/*" = [ "D" ] +lint.pydocstyle.convention = "numpy" -[tool.ruff.lint.pydocstyle] -convention = "google" - -[tool.ruff.lint.per-file-ignores] -"docs/*" = ["I"] -"tests/**/*.py" = [ - "D", # docstrings are allowed to look a bit off - "S101", # asserts allowed in tests... - "ARG", # Unused function args -> fixtures nevertheless are functionally relevant... - "FBT", # Don't care about booleans as positional arguments in tests, e.g. via @pytest.mark.parametrize() - "PLR2004", # Magic value used in comparison, ... - "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes +[tool.pytest.ini_options] +testpaths = [ "tests" ] +xfail_strict = true +addopts = [ + "--import-mode=importlib", # allow using test files with same name ] -"*/__init__.py" = ["F401"] -[tool.ruff.lint.isort] -known-first-party = [ "arrayloaders" ] -required-imports = [ "from __future__ import annotations" ] +[tool.coverage.run] +source = [ "arrayloaders" ] +omit = [ + "**/test_*.py", +] -[tool.ruff.lint.flake8-type-checking] -exempt-modules = [ ] -strict = true +[tool.cruft] +skip = [ + "tests", + "src/**/__init__.py", + "src/**/basic.py", + "docs/api.md", + "docs/changelog.md", + "docs/references.bib", + "docs/references.md", + "docs/notebooks/example.ipynb", +] diff --git a/src/arrayloaders/__init__.py b/src/arrayloaders/__init__.py new file mode 100644 index 0000000..6fa9d1e --- /dev/null +++ b/src/arrayloaders/__init__.py @@ -0,0 +1,9 @@ +from importlib.metadata import version + +from .dense import ZarrDenseDataset +from .io import add_h5ads_to_store, create_store_from_h5ads, write_sharded +from .sparse import ZarrSparseDataset + +__version__ = version("arrayloaders") + +__all__ = ["ZarrSparseDataset", "ZarrDenseDataset", "write_sharded", "add_h5ads_to_store", "create_store_from_h5ads"] diff --git a/src/arrayloaders/abc.py b/src/arrayloaders/abc.py new file mode 100644 index 0000000..a420153 --- /dev/null +++ b/src/arrayloaders/abc.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Generic + +from arrayloaders.anndata_manager import AnnDataManager +from arrayloaders.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray +from arrayloaders.utils import WorkerHandle, add_dataset_docstring, check_lt_1 + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Self + + import anndata as ad + import numpy as np + + +class AbstractIterableDataset(Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray], 
metaclass=ABCMeta): # noqa: D101 + _shuffle: bool + _preload_nchunks: int + _worker_handle: WorkerHandle + _chunk_size: int + _dataset_manager: AnnDataManager[OnDiskArray, InputInMemoryArray, OutputInMemoryArray] + + def __init__( + self, + *, + chunk_size: int = 512, + preload_nchunks: int = 32, + shuffle: bool = True, + return_index: bool = False, + batch_size: int = 1, + preload_to_gpu: bool = False, + ): + check_lt_1( + [ + chunk_size, + preload_nchunks, + ], + ["Chunk size", "Preload chunks"], + ) + if batch_size > (chunk_size * preload_nchunks): + raise NotImplementedError( + "If you need batch loading that is bigger than the iterated in-memory size, please open an issue." + ) + self._dataset_manager = AnnDataManager( + # TODO: https://github.com/scverse/anndata/issues/2021 + # on_add=self._cache_update_callback, + return_index=return_index, + batch_size=batch_size, + preload_to_gpu=preload_to_gpu, + ) + self._chunk_size = chunk_size + self._preload_nchunks = preload_nchunks + self._shuffle = shuffle + self._worker_handle = WorkerHandle() + + async def _cache_update_callback(self): + pass + + @abstractmethod + async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> InputInMemoryArray: + """Fetch the data for given slices and the arrays representing a dataset on-disk. + + Args: + slices: The indexing slices to fetch. + dataset_idx: The index of the dataset to fetch from. + + Returns + ------- + The in-memory array data. + """ + ... + + def add_anndatas( # noqa: D102 + self, + adatas: list[ad.AnnData], + layer_keys: list[str | None] | str | None = None, + obs_keys: list[str] | str | None = None, + ) -> Self: + raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") + + def add_anndata( # noqa: D102 + self, + adata: ad.AnnData, + layer_key: str | None = None, + obs_key: str | None = None, + ) -> Self: + raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") + + @abstractmethod + def _validate(self, datasets: list[OnDiskArray]) -> None: ... + + def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> Self: # noqa: D102 + self._validate(datasets) + self._dataset_manager.add_datasets(datasets, obs) + return self + + def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> Self: # noqa: D102 + self._validate([dataset]) + self._dataset_manager.add_dataset(dataset, obs) + return self + + def __len__(self) -> int: + return self._dataset_manager.n_obs + + def __iter__( + self, + ) -> Iterator[ + tuple[InputInMemoryArray, None | np.ndarray] | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] + ]: + """Iterate over the on-disk datasets. + + Yields + ------ + A one-row in-memory array optionally with its label. 
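Editor's note: as a rough, hypothetical usage sketch of the iteration contract described here. The paths, shapes and labels below are invented for illustration, the dense loader is used because its input type (`zarr.Array`) is the easiest to fabricate, and the example assumes zarr's in-memory `zarr.array` creation helper; this is a sketch, not the project's documented API usage.

```python
import numpy as np
import zarr
from arrayloaders import ZarrDenseDataset

rng = np.random.default_rng(0)
z = zarr.array(rng.normal(size=(1_000, 32)).astype("float32"), chunks=(256, 32))  # in-memory zarr.Array
labels = np.repeat(np.arange(10), 100)  # one made-up label per observation

ds = ZarrDenseDataset(chunk_size=256, preload_nchunks=4, batch_size=64, shuffle=True)
ds.add_dataset(z, labels)

for x, y in ds:  # x: (batch_size, n_var) array, y: the matching labels
    print(x.shape, y.shape)
    break
```

Because these classes also subclass `torch.utils.data.IterableDataset`, the same object can typically be handed to a torch `DataLoader` (usually with `batch_size=None`, since batching already happens inside the dataset).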
+ """ + yield from self._dataset_manager.iter( + self._chunk_size, + self._worker_handle, + self._preload_nchunks, + self._shuffle, + self._fetch_data, + ) + + +AbstractIterableDataset.add_dataset.__doc__ = add_dataset_docstring +AbstractIterableDataset.add_datasets.__doc__ = add_dataset_docstring diff --git a/src/arrayloaders/anndata_manager.py b/src/arrayloaders/anndata_manager.py new file mode 100644 index 0000000..0663c70 --- /dev/null +++ b/src/arrayloaders/anndata_manager.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import math +from collections import OrderedDict, defaultdict +from types import NoneType +from typing import TYPE_CHECKING, Generic + +import anndata as ad +import numpy as np +import zarr.core.sync as zsync +from scipy import sparse as sp + +from arrayloaders.types import InputInMemoryArray, OnDiskArray, OutputInMemoryArray +from arrayloaders.utils import ( + CSRContainer, + WorkerHandle, + _batched, + add_dataset_docstring, + check_lt_1, + check_var_shapes, + index_datasets, + split_given_size, +) + +try: + from cupy import ndarray as CupyArray + from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix # pragma: no cover +except ImportError: + CupyCSRMatrix = NoneType + CupyArray = NoneType + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterator + from types import ModuleType + +accepted_on_disk_types = OnDiskArray.__constraints__ + + +class AnnDataManager(Generic[OnDiskArray, InputInMemoryArray, OutputInMemoryArray]): # noqa: D101 + train_datasets: list[OnDiskArray] = [] + labels: list[np.ndarray] | None = None + _return_index: bool = False + _on_add: Callable | None = None + _batch_size: int = 1 + _shapes: list[tuple[int, int]] = [] + _preload_to_gpu: bool = False + + def __init__( + self, + *, + on_add: Callable | None = None, + return_index: bool = False, + batch_size: int = 1, + preload_to_gpu: bool = False, + ): + self._on_add = on_add + self._return_index = return_index + self._batch_size = batch_size + self._preload_to_gpu = preload_to_gpu + + @property + def _sp_module(self) -> ModuleType: + if self._preload_to_gpu: + try: + import cupyx.scipy.sparse as cpx # pragma: no cover + + return cpx + except ImportError: + raise ImportError( + "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`" + ) from None + return sp + + @property + def _np_module(self) -> ModuleType: + if self._preload_to_gpu: + try: + import cupy as cp + + return cp + except ImportError: + raise ImportError( + "Cannot find cupy module even though `preload_to_gpu` argument was set to `True`" + ) from None + + return np + + @property + def dataset_type(self) -> type[OnDiskArray]: # noqa: D102 + return type(self.train_datasets[0]) + + @property + def n_obs(self) -> int: # noqa: D102 + return sum(shape[0] for shape in self._shapes) + + @property + def n_var(self) -> int: # noqa: D102 + return self._shapes[0][1] + + def add_anndatas( # noqa: D102 + self, + adatas: list[ad.AnnData], + layer_keys: list[str | None] | str | None = None, + obs_keys: list[str] | str | None = None, + ) -> None: + raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") + + def add_anndata( # noqa: D102 + self, + adata: ad.AnnData, + layer_key: str | None = None, + obs_key: str | None = None, + ) -> None: + raise NotImplementedError("See https://github.com/scverse/anndata/issues/2021") + + def add_datasets(self, datasets: list[OnDiskArray], obs: list[np.ndarray] | None = None) -> None: # noqa: D102 + if obs is None: + obs = [None] * 
len(datasets) + for ds, o in zip(datasets, obs, strict=True): + self.add_dataset(ds, o) + + def add_dataset(self, dataset: OnDiskArray, obs: np.ndarray | None = None) -> None: # noqa: D102 + if len(self.train_datasets) > 0: + if self.labels is None and obs is not None: + raise ValueError( + f"Cannot add a dataset with obs label {obs} when training datasets have already been added without labels" + ) + if self.labels is not None and obs is None: + raise ValueError( + "Cannot add a dataset with no obs label when training datasets have already been added without labels" + ) + if not isinstance(dataset, accepted_types := accepted_on_disk_types): + raise TypeError(f"Cannot add a dataset of type {type(dataset)}, only {accepted_types} are allowed") + if len(self.train_datasets) > 0 and not isinstance(dataset, self.dataset_type): + raise TypeError( + f"Cannot add a dataset whose data of type {type(dataset)} was not an instance of expected type {self.dataset_type}" + ) + datasets = self.train_datasets + [dataset] + check_var_shapes(datasets) + self._shapes = self._shapes + [dataset.shape] + self.train_datasets = datasets + if self.labels is not None: # labels exist + self.labels += [obs] + elif obs is not None: # labels dont exist yet, but are being added for the first time + self.labels = [obs] + if self._on_add is not None: + self._on_add() + + def _get_relative_obs_indices(self, index: slice, *, use_original_space: bool = False) -> list[tuple[slice, int]]: + """Generate a slice relative to a dataset given a global slice index over all datasets. + + For a given slice indexer of axis 0, return a new slice relative to the on-disk + data it represents given the number of total observations as well as the index of + the underlying data on disk from the argument `sparse_datasets` to the initializer. + + For example, given slice index (10, 15), for 4 datasets each with size 5 on axis zero, + this function returns ((0,5), 2) representing slice (0,5) along axis zero of sparse dataset 2. + + Args: + index: The queried slice. + use_original_space: Whether or not the slices should be reindexed against the anndata objects. + + Returns + ------- + A slice relative to the dataset it represents as well as the index of said dataset in `sparse_datasets`. + """ + min_idx = index.start + max_idx = index.stop + curr_pos = 0 + slices = [] + for idx, (n_obs, _) in enumerate(self._shapes): + array_start = curr_pos + array_end = curr_pos + n_obs + + start = max(min_idx, array_start) + stop = min(max_idx, array_end) + if start < stop: + if use_original_space: + slices.append((slice(start, stop), idx)) + else: + relative_start = start - array_start + relative_stop = stop - array_start + slices.append((slice(relative_start, relative_stop), idx)) + curr_pos += n_obs + return slices + + def _slices_to_slices_with_array_index( + self, slices: list[slice], *, use_original_space: bool = False + ) -> OrderedDict[int, list[slice]]: + """Given a list of slices, give the lookup between on-disk datasets and slices relative to that dataset. + + Args: + slices: Slices to relative to the on-disk datasets. + use_original_space: Whether or not the slices should be reindexed against the anndata objects. + + Returns + ------- + A lookup between the dataset and its indexing slices, ordered by keys. 
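Editor's note: to make the slice bookkeeping concrete, here is a standalone sketch of the mapping that `_get_relative_obs_indices` describes, a simplified re-implementation written only for illustration and not the package's code.

```python
# Four datasets with 5 observations each, (n_obs, n_var) per dataset.
shapes = [(5, 20), (5, 20), (5, 20), (5, 20)]

def relative_slices(index: slice, shapes: list[tuple[int, int]]) -> list[tuple[slice, int]]:
    out, offset = [], 0
    for ds_idx, (n_obs, _) in enumerate(shapes):
        start, stop = max(index.start, offset), min(index.stop, offset + n_obs)
        if start < stop:
            out.append((slice(start - offset, stop - offset), ds_idx))
        offset += n_obs
    return out

print(relative_slices(slice(10, 15), shapes))  # [(slice(0, 5), 2)] -> all of dataset 2
print(relative_slices(slice(3, 12), shapes))   # spans datasets 0, 1 and 2
```

The manager then groups these (slice, dataset) pairs by dataset index into the ordered lookup built just below, so that each backed store receives all of its slices together.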
+ """ + dataset_index_to_slices: defaultdict[int, list[slice]] = defaultdict(list) + for slice in slices: + for relative_obs_indices in self._get_relative_obs_indices(slice, use_original_space=use_original_space): + dataset_index_to_slices[relative_obs_indices[1]] += [relative_obs_indices[0]] + keys = sorted(dataset_index_to_slices.keys()) + dataset_index_to_slices_sorted = OrderedDict() + for k in keys: + dataset_index_to_slices_sorted[k] = dataset_index_to_slices[k] + return dataset_index_to_slices_sorted + + def _get_chunks(self, chunk_size: int, worker_handle: WorkerHandle, shuffle: bool) -> np.ndarray: + """Get a potentially shuffled list of chunk ids, accounting for the fact that this dataset might be inside a worker. + + Returns + ------- + A :class:`numpy.ndarray` of chunk ids. + """ + chunks = np.arange(math.ceil(self.n_obs / chunk_size)) + if shuffle: + worker_handle.shuffle(chunks) + + return worker_handle.get_part_for_worker(chunks) + + def iter( + self, + chunk_size: int, + worker_handle: WorkerHandle, + preload_nchunks: int, + shuffle: bool, + fetch_data: Callable[[list[slice], int], Awaitable[np.ndarray | CSRContainer]], + ) -> Iterator[ + tuple[InputInMemoryArray, None | np.ndarray] | tuple[InputInMemoryArray, None | np.ndarray, np.ndarray] + ]: + """Iterate over the on-disk csr datasets. + + Yields + ------ + A one-row sparse matrix. + """ + check_lt_1( + [len(self.train_datasets), self.n_obs], + ["Number of datasets", "Number of observations"], + ) + # In order to handle data returned where (chunk_size * preload_nchunks) mod batch_size != 0 + # we must keep track of the leftover data. + in_memory_data = None + in_memory_labels = None + in_memory_indices = None + mod = self._sp_module if issubclass(self.dataset_type, ad.abc.CSRDataset) else np + for chunk_indices in _batched(self._get_chunks(chunk_size, worker_handle, shuffle), preload_nchunks): + slices = [ + slice( + index * chunk_size, + min(self.n_obs, (index + 1) * chunk_size), + ) + for index in chunk_indices + ] + dataset_index_to_slices = self._slices_to_slices_with_array_index(slices) + # Fetch the data over slices + chunks: list[InputInMemoryArray] = zsync.sync(index_datasets(dataset_index_to_slices, fetch_data)) + if any(isinstance(c, CSRContainer) for c in chunks): + chunks_converted: list[OutputInMemoryArray] = [ + self._sp_module.csr_matrix(tuple(self._np_module.asarray(e) for e in c.elems), shape=c.shape) + for c in chunks + ] + else: + chunks_converted = [self._np_module.asarray(c) for c in chunks] + # Accumulate labels + labels: None | list[np.ndarray] = None + if self.labels is not None: + labels = [] + for dataset_idx in dataset_index_to_slices.keys(): + labels += [ + self.labels[dataset_idx][ + np.concatenate([np.arange(s.start, s.stop) for s in dataset_index_to_slices[dataset_idx]]) + ] + ] + # Accumulate indices if necessary + indices: None | list[np.ndarray] = None + if self._return_index: + dataset_index_to_slices = self._slices_to_slices_with_array_index(slices, use_original_space=True) + dataset_indices = dataset_index_to_slices.keys() + indices = [ + np.concatenate( + [ + np.arange( + s.start, + s.stop, + ) + for s in dataset_index_to_slices[index] + ] + ) + for index in dataset_indices + ] + # Do batch returns, handling leftover data as necessary + in_memory_data = ( + mod.vstack(chunks_converted) + if in_memory_data is None + else mod.vstack([in_memory_data, *chunks_converted]) + ) + if self.labels is not None: + in_memory_labels = ( + np.concatenate(labels) if in_memory_labels is None else 
np.concatenate([in_memory_labels, *labels]) + ) + if self._return_index: + in_memory_indices = ( + np.concatenate(indices) + if in_memory_indices is None + else np.concatenate([in_memory_indices, *indices]) + ) + # Create random indices into in_memory_data and then index into it + # If there is "leftover" at the end (see the modulo op), + # save it for the next iteration. + batch_indices = np.arange(in_memory_data.shape[0]) + if shuffle: + np.random.default_rng().shuffle(batch_indices) + splits = split_given_size(batch_indices, self._batch_size) + for i, s in enumerate(splits): + if s.shape[0] == self._batch_size: + res = [ + in_memory_data[s], + in_memory_labels[s] if self.labels is not None else None, + ] + if self._return_index: + res += [in_memory_indices[s]] + yield tuple(res) + if i == (len(splits) - 1): # end of iteration, leftover data needs be kept + if (s.shape[0] % self._batch_size) != 0: + in_memory_data = in_memory_data[s] + if in_memory_labels is not None: + in_memory_labels = in_memory_labels[s] + if in_memory_indices is not None: + in_memory_indices = in_memory_indices[s] + else: + in_memory_data = None + in_memory_labels = None + in_memory_indices = None + if in_memory_data is not None: # handle any leftover data + res = [ + in_memory_data, + in_memory_labels if self.labels is not None else None, + ] + if self._return_index: + res += [in_memory_indices] + yield tuple(res) + + +AnnDataManager.add_datasets.__doc__ = add_dataset_docstring +AnnDataManager.add_dataset.__doc__ = add_dataset_docstring diff --git a/src/arrayloaders/dense.py b/src/arrayloaders/dense.py new file mode 100644 index 0000000..6ef4e36 --- /dev/null +++ b/src/arrayloaders/dense.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +import zarr +from torch.utils.data import IterableDataset + +from arrayloaders.abc import AbstractIterableDataset +from arrayloaders.utils import MultiBasicIndexer + +if TYPE_CHECKING: + import numpy as np + + +class ZarrDenseDataset(AbstractIterableDataset, IterableDataset): # noqa: D101 + async def _fetch_data(self, slices: list[slice], dataset_idx: int) -> np.ndarray: + dataset = self._dataset_manager.train_datasets[dataset_idx] + indexer = MultiBasicIndexer( + [ + zarr.core.indexing.BasicIndexer( + (s, Ellipsis), + shape=dataset.metadata.shape, + chunk_grid=dataset.metadata.chunk_grid, + ) + for s in slices + ] + ) + res = cast( + "np.ndarray", + await dataset._async_array._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), + ) + return res + + def _validate(self, datasets: list[zarr.Array]): + if not all(isinstance(d, zarr.Array) for d in datasets): + raise TypeError("Cannot create dense dataset without using a zarr.Array") diff --git a/arrayloaders/io/store_creation.py b/src/arrayloaders/io.py similarity index 80% rename from arrayloaders/io/store_creation.py rename to src/arrayloaders/io.py index d0846d7..3c21d9d 100644 --- a/arrayloaders/io/store_creation.py +++ b/src/arrayloaders/io.py @@ -22,23 +22,31 @@ from zarr.abc.codec import BytesBytesCodec -def _write_sharded( +def write_sharded( group: zarr.Group, adata: ad.AnnData, chunk_size: int = 4096, shard_size: int = 65536, - compressors: Iterable[BytesBytesCodec] = ( - BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle), - ), + compressors: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), ): + """Write a sharded zarr store from anndata object + + Args: + group: The destination group + 
adata: The source anndata object + chunk_size: Chunk size inside a shard. Defaults to 4096. + shard_size: Shard size i.e., number of elements in a single file. Defaults to 65536. + compressors: The compressors to pass to `zarr`. Defaults to (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),). + """ ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr def callback( - func: ad.experimental.Write, - g: zarr.Group, - k: str, + write_func: ad.experimental.Write, + store: zarr.Group, + elem_name: str, elem: ad.typing.RWAble, dataset_kwargs: Mapping[str, Any], + *, iospec: ad.experimental.IOSpec, ): if iospec.encoding_type in {"array"}: @@ -56,15 +64,13 @@ def callback( **dataset_kwargs, } - func(g, k, elem, dataset_kwargs=dataset_kwargs) + write_func(store, elem_name, elem, dataset_kwargs=dataset_kwargs) ad.experimental.write_dispatched(group, "/", adata, callback=callback) zarr.consolidate_metadata(group.store) -def _lazy_load_h5ads( - paths: Iterable[PathLike[str]] | Iterable[str], chunk_size: int = 4096 -): +def _lazy_load_h5ads(paths: Iterable[PathLike[str]] | Iterable[str], chunk_size: int = 4096): adatas = [] for path in paths: with h5py.File(path) as f: @@ -92,13 +98,10 @@ def _load_h5ads(paths: Iterable[PathLike[str]] | Iterable[str]): return ad.concat(adatas, join="outer") -def _create_chunks_for_shuffling( - adata: ad.AnnData, shuffle_buffer_size: int = 1_048_576, shuffle: bool = True -): +def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_buffer_size: int = 1_048_576, shuffle: bool = True): chunk_boundaries = np.cumsum([0] + list(adata.X.chunks[0])) slices = [ - slice(int(start), int(end)) - for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) + slice(int(start), int(end)) for start, end in zip(chunk_boundaries[:-1], chunk_boundaries[1:], strict=True) ] if shuffle: random.shuffle(slices) @@ -114,9 +117,7 @@ def create_store_from_h5ads( var_subset: Iterable[str] | None = None, chunk_size: int = 4096, shard_size: int = 65536, - zarr_compressor: Iterable[BytesBytesCodec] = ( - BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle), - ), + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), h5ad_compressor: Literal["gzip", "lzf"] | None = "gzip", buffer_size: int = 1_048_576, shuffle: bool = True, @@ -142,7 +143,8 @@ def create_store_from_h5ads( should_denseify: Whether or not to write as dense on disk. output_format: Format of the output store. Can be either "zarr" or "h5ad". - Examples: + Examples + -------- >>> from arrayloaders.io.store_creation import create_store_from_h5ads >>> datasets = [ ... 
"path/to/first_adata.h5ad", @@ -175,13 +177,11 @@ def create_store_from_h5ads( adata_chunk.obs = adata_chunk.obs.iloc[idxs] # convert to dense format before writing to disk if should_denseify: - adata_chunk.X = adata_chunk.X.map_blocks( - lambda xx: xx.toarray().astype("f4"), dtype="f4" - ) + adata_chunk.X = adata_chunk.X.map_blocks(lambda xx: xx.toarray().astype("f4"), dtype="f4") if output_format == "zarr": f = zarr.open_group(Path(output_path) / f"chunk_{i}.zarr", mode="w") - _write_sharded( + write_sharded( f, adata_chunk, chunk_size=chunk_size, @@ -189,13 +189,9 @@ def create_store_from_h5ads( compressors=zarr_compressor, ) elif output_format == "h5ad": - adata_chunk.write_h5ad( - Path(output_path) / f"chunk_{i}.h5ad", compression=h5ad_compressor - ) + adata_chunk.write_h5ad(Path(output_path) / f"chunk_{i}.h5ad", compression=h5ad_compressor) else: - raise ValueError( - f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported." - ) + raise ValueError(f"Unrecognized output_format: {output_format}. Only 'zarr' and 'h5ad' are supported.") def _get_array_encoding_type(path: PathLike[str] | str): @@ -210,9 +206,7 @@ def add_h5ads_to_store( output_path: PathLike[str] | str, chunk_size: int = 4096, shard_size: int = 65536, - zarr_compressor: Iterable[BytesBytesCodec] = ( - BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle), - ), + zarr_compressor: Iterable[BytesBytesCodec] = (BloscCodec(cname="lz4", clevel=3, shuffle=BloscShuffle.shuffle),), cache_h5ads: bool = True, ): """Add h5ad files to an existing Zarr store. @@ -225,7 +219,8 @@ def add_h5ads_to_store( zarr_compressor: Compressors to use to compress the data in the zarr store. cache_h5ads: Whether to cache the h5ad files into memory before writing them to the store. - Examples: + Examples + -------- >>> from arrayloaders.io.store_creation import add_h5ads_to_store >>> datasets = [ ... "path/to/first_adata.h5ad", @@ -241,20 +236,14 @@ def add_h5ads_to_store( ) encoding = _get_array_encoding_type(output_path) if encoding == "array": - print( - "Detected array encoding type. Will convert to dense format before writing." - ) + print("Detected array encoding type. 
Will convert to dense format before writing.") if cache_h5ads: adata_concat = _load_h5ads(adata_paths) - chunks = np.array_split( - np.random.default_rng().permutation(len(adata_concat)), len(shards) - ) + chunks = np.array_split(np.random.default_rng().permutation(len(adata_concat)), len(shards)) else: adata_concat = _lazy_load_h5ads(adata_paths, chunk_size=chunk_size) - chunks = _create_chunks_for_shuffling( - adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True - ) + chunks = _create_chunks_for_shuffling(adata_concat, np.ceil(len(adata_concat) / len(shards)), shuffle=True) var_mask = adata_concat.var_names.isin(adata_concat.var_names) adata_concat.obs_names_make_unique() @@ -262,9 +251,7 @@ def add_h5ads_to_store( if encoding == "array": f = zarr.open_group(shard) adata_shard = ad.AnnData( - X=ad.experimental.read_elem_lazy(f["X"]) - .map_blocks(sp.csr_matrix) - .compute(), + X=ad.experimental.read_elem_lazy(f["X"]).map_blocks(sp.csr_matrix).compute(), obs=ad.io.read_elem(f["obs"]), var=ad.io.read_elem(f["var"]), ) @@ -281,9 +268,7 @@ def add_h5ads_to_store( ] ) idxs_shuffled = np.random.default_rng().permutation(len(adata)) - adata = adata[ - idxs_shuffled, : - ].copy() # this significantly speeds up writing to disk + adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk if encoding == "array": adata.X = da.from_array(adata.X, chunks=(chunk_size, -1)).map_blocks( @@ -291,7 +276,7 @@ def add_h5ads_to_store( ) f = zarr.open_group(shard, mode="w") - _write_sharded( + write_sharded( f, adata, chunk_size=chunk_size, diff --git a/src/arrayloaders/sparse.py b/src/arrayloaders/sparse.py new file mode 100644 index 0000000..a0105f9 --- /dev/null +++ b/src/arrayloaders/sparse.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import asyncio +from itertools import accumulate, chain, pairwise +from typing import NamedTuple + +import anndata as ad +import numpy as np +import zarr +import zarr.core.sync as zsync +from torch.utils.data import IterableDataset + +from arrayloaders.abc import AbstractIterableDataset +from arrayloaders.utils import CSRContainer, MultiBasicIndexer, __init_docstring__ + + +class CSRDatasetElems(NamedTuple): + """Container for cached objects that will be indexed into to generate CSR matrices""" + + indptr: np.ndarray + indices: zarr.AsyncArray + data: zarr.AsyncArray + + +class ZarrSparseDataset(AbstractIterableDataset, IterableDataset): # noqa: D101 + _dataset_elem_cache: dict[int, CSRDatasetElems] = {} + + def _cache_update_callback(self): + """Callback for when datasets are added to ensure the cache is updated.""" + return zsync.sync(self._ensure_cache()) + + def _validate(self, datasets: list[ad.abc.CSRDataset]): + if not all(isinstance(d, ad.abc.CSRDataset) for d in datasets): + raise TypeError("Cannot create sparse dataset using CSRDataset data") + + async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems: + """Fetch the in-memory indptr, and backed indices and data for a given dataset index. + + Args: + idx: The index + + Returns + ------- + The constituent elems of the CSR dataset. 
+ """ + indptr = await self._dataset_manager.train_datasets[idx].group._async_group.getitem("indptr") + return CSRDatasetElems( + *( + await asyncio.gather( + indptr.getitem(Ellipsis), + self._dataset_manager.train_datasets[idx].group._async_group.getitem("indices"), + self._dataset_manager.train_datasets[idx].group._async_group.getitem("data"), + ) + ) + ) + + async def _ensure_cache(self): + """Build up the cache of datasets i.e., in-memory indptr, and backed indices and data.""" + arr_idxs = [ + idx for idx in range(len(self._dataset_manager.train_datasets)) if idx not in self._dataset_elem_cache + ] + all_elems = await asyncio.gather( + *( + self._create_sparse_elems(idx) + for idx in range(len(self._dataset_manager.train_datasets)) + if idx not in self._dataset_elem_cache + ) + ) + for idx, elems in zip(arr_idxs, all_elems, strict=True): + self._dataset_elem_cache[idx] = elems + + async def _get_sparse_elems(self, dataset_idx: int) -> CSRDatasetElems: + """Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index. + + Args: + dataset_idx: The index of the dataset whose arrays are sought. + + Returns + ------- + The arrays representing the sparse data. + """ + if dataset_idx not in self._dataset_elem_cache: + await self._ensure_cache() + return self._dataset_elem_cache[dataset_idx] + + async def _fetch_data( + self, + slices: list[slice], + dataset_idx: int, + ) -> CSRContainer: + # See https://github.com/scverse/anndata/blob/361325fc621887bf4f381e9412b150fcff599ff7/src/anndata/_core/sparse_dataset.py#L272-L295 + # for the inspiration of this function. + indptr, indices, data = await self._get_sparse_elems(dataset_idx) + indptr_indices = [indptr[slice(s.start, s.stop + 1)] for s in slices] + indptr_limits = [slice(i[0], i[-1]) for i in indptr_indices] + indexer = MultiBasicIndexer( + [ + zarr.core.indexing.BasicIndexer((l,), shape=data.metadata.shape, chunk_grid=data.metadata.chunk_grid) + for l in indptr_limits + ] + ) + data_np, indices_np = await asyncio.gather( + data._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), + indices._get_selection(indexer, prototype=zarr.core.buffer.default_buffer_prototype()), + ) + gaps = (s1.start - s0.stop for s0, s1 in pairwise(indptr_limits)) + offsets = accumulate(chain([indptr_limits[0].start], gaps)) + start_indptr = indptr_indices[0] - next(offsets) + if len(slices) < 2: # there is only one slice so no need to concatenate + return CSRContainer( + elems=(data_np, indices_np, start_indptr), + shape=(start_indptr.shape[0] - 1, self._dataset_manager.n_var), + ) + end_indptr = np.concatenate([s[1:] - o for s, o in zip(indptr_indices[1:], offsets, strict=True)]) + indptr_np = np.concatenate([start_indptr, end_indptr]) + return CSRContainer( + elems=(data_np, indices_np, indptr_np), + shape=(indptr_np.shape[0] - 1, self._dataset_manager.n_var), + ) + + +ZarrSparseDataset.__init__.__doc__ = __init_docstring__.format(array_type="sparse") diff --git a/src/arrayloaders/types.py b/src/arrayloaders/types.py new file mode 100644 index 0000000..528dd9d --- /dev/null +++ b/src/arrayloaders/types.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from types import NoneType +from typing import TypeVar + +import anndata as ad +import numpy as np +import zarr +from scipy import sparse as sp + +from arrayloaders.utils import CSRContainer + +try: + from cupy import ndarray as CupyArray + from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix # pragma: no cover +except ImportError: + 
CupyCSRMatrix = NoneType + CupyArray = NoneType + + +OnDiskArray = TypeVar("OnDiskArray", ad.abc.CSRDataset, zarr.Array) + + +OutputInMemoryArray = TypeVar("OutputInMemoryArray", sp.csr_matrix, np.ndarray, CupyCSRMatrix, CupyArray) +InputInMemoryArray = TypeVar("InputInMemoryArray", CSRContainer, np.ndarray) diff --git a/src/arrayloaders/utils.py b/src/arrayloaders/utils.py new file mode 100644 index 0000000..4e5bf4d --- /dev/null +++ b/src/arrayloaders/utils.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from functools import cached_property +from itertools import islice +from typing import TYPE_CHECKING, Protocol + +import numpy as np +import zarr +from torch.utils.data import get_worker_info + +if TYPE_CHECKING: + from collections import OrderedDict + from collections.abc import Awaitable, Callable + + from arrayloaders.types import InputInMemoryArray + + +def split_given_size(a: np.ndarray, size: int) -> list[np.ndarray]: + """Wrapper around `np.split` to split up an array into `size` chunks""" + return np.split(a, np.arange(size, len(a), size)) + + +@dataclass +class CSRContainer: + """A low-cost container for moving around the buffers of a CSR object""" + + elems: tuple[np.ndarray, np.ndarray, np.ndarray] + shape: tuple[int, int] + + +def _batched(iterable, n): + if n < 1: + raise ValueError("n must be >= 1") + it = iter(iterable) + while batch := list(islice(it, n)): + yield batch + + +async def index_datasets( + dataset_index_to_slices: OrderedDict[int, list[slice]], + fetch_data: Callable[[list[slice], int], Awaitable[CSRContainer | np.ndarray]], +) -> list[InputInMemoryArray]: + """Helper function meant to encapsulate asynchronous calls so that we can use the same event loop as zarr. + + Args: + dataset_index_to_slices: A lookup of the list-placement index of a dataset to the request slices. + fetch_data: The function to do the fetching for a given slice-dataset index pair. + """ + tasks = [] + for dataset_idx in dataset_index_to_slices.keys(): + tasks.append( + fetch_data( + dataset_index_to_slices[dataset_idx], + dataset_idx, + ) + ) + return await asyncio.gather(*tasks) + + +add_dataset_docstring = """\ +Append datasets to this loader. + +Args: + datasets: List of :class:`anndata.abc.CSRDataset` or :class:`zarr.Array` objects, generally from :attr:`anndata.AnnData.X`. + obs: List of `numpy.ndarray` labels, generally from :attr:`anndata.AnnData.obs`. +""" + +add_dataset_docstring = """\ +Append a dataset to this loader. + +Args: + dataset: :class:`anndata.abc.CSRDataset` or :class:`zarr.Array` object, generally from :attr:`anndata.AnnData.X`. + obs: `numpy.ndarray` labels for the dataset, generally from :attr:`anndata.AnnData.obs`. +""" + + +__init_docstring__ = """A loader for on-disk {array_type} data. + +This loader batches together slice requests to the underlying {array_type} stores to acheive higher performance. +This custom code to do this task will be upstreamed into anndata at some point and no longer rely on private zarr apis. +The loader is agnostic to the on-disk chunking/sharding, but it may be advisable to align with the in-memory chunk size. 
+ +Args: + chunk_size: The obs size (i.e., axis 0) of contiguous array data to fetch, by default 512 + preload_nchunks: The number of chunks of contiguous array data to fetch, by default 32 + shuffle: Whether or not to shuffle the data, by default True + return_index: Whether or not to return the index on each iteration, by default False + preload_to_gpu: Whether or not to use cupy for non-io array operations like vstack and indexing. This option entails greater GPU memory usage. +""" + + +# TODO: make this part of the public zarr or zarrs-python API. +# We can do chunk coalescing in zarrs based on integer arrays, so I think +# there would make sense with ezclump or similar. +# Another "solution" would be for zarrs to support integer indexing properly, if that pipeline works, +# or make this an "experimental setting" and to use integer indexing for the zarr-python pipeline. +# See: https://github.com/zarr-developers/zarr-python/issues/3175 for why this is better than simpler alternatives. +class MultiBasicIndexer(zarr.core.indexing.Indexer): + """Custom indexer to enable joint fetching of disparate slices""" + + def __init__(self, indexers: list[zarr.core.indexing.Indexer]): + self.shape = (sum(i.shape[0] for i in indexers), *indexers[0].shape[1:]) + self.drop_axes = indexers[0].drop_axes # maybe? + self.indexers = indexers + + def __iter__(self): + total = 0 + for i in self.indexers: + for c in i: + out_selection = c[2] + gap = out_selection[0].stop - out_selection[0].start + yield type(c)(c[0], c[1], (slice(total, total + gap), *out_selection[1:]), c[3]) + total += gap + + +def sample_rows( + x_list: list[np.ndarray], + obs_list: list[np.ndarray] | None, + indices: list[np.ndarray] | None = None, + *, + shuffle: bool = True, +): + """Samples rows from multiple arrays and their corresponding observation arrays. + + Args: + x_list: A list of numpy arrays containing the data to sample from. + obs_list: A list of numpy arrays containing the corresponding observations. + indices: the list of indexes for each element in x_list/ + shuffle: Whether to shuffle the rows before sampling. Defaults to True. + + Yields + ------ + tuple: A tuple containing a row from `x_list` and the corresponding row from `obs_list`. + """ + lengths = np.fromiter((x.shape[0] for x in x_list), dtype=int) + cum = np.concatenate(([0], np.cumsum(lengths))) + total = cum[-1] + idxs = np.arange(total) + if shuffle: + np.random.default_rng().shuffle(idxs) + arr_idxs = np.searchsorted(cum, idxs, side="right") - 1 + row_idxs = idxs - cum[arr_idxs] + for ai, ri in zip(arr_idxs, row_idxs, strict=True): + res = [ + x_list[ai][ri], + obs_list[ai][ri] if obs_list is not None else None, + ] + if indices is not None: + yield (*res, indices[ai][ri]) + else: + yield tuple(res) + + +class WorkerHandle: # noqa: D101 + @cached_property + def _worker_info(self): + return get_worker_info() + + @cached_property + def _rng(self): + if self._worker_info is None: + return np.random.default_rng() + else: + # This is used for the _get_chunks function + # Use the same seed for all workers that the resulting splits are the same across workers + # torch default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed + return np.random.default_rng(self._worker_info.seed - self._worker_info.id) + + def shuffle(self, obj: np.typing.ArrayLike) -> None: + """Perform in-place shuffle. 
+ + Args: + obj: The object to be shuffled + """ + self._rng.shuffle(obj) + + def get_part_for_worker(self, obj: np.ndarray) -> np.ndarray: + """Get a chunk of an incoming array accordnig to the current worker id. + + Args: + obj: Incoming array + + Returns + ------- + A evenly split part of the ray corresponding to how many workers there are. + """ + if self._worker_info is None: + return obj + num_workers, worker_id = self._worker_info.num_workers, self._worker_info.id + chunks_split = np.array_split(obj, num_workers) + return chunks_split[worker_id] + + +def check_lt_1(vals: list[int], labels: list[str]): + """Raise a ValueError if any of the values are less than one. + + The format of the error is "{labels[i]} must be greater than 1, got {values[i]}" + and is raised based on the first found less than one value. + + Args: + vals: The values to check < 1 + labels: The label for the value in the error if the value is less than one. + + Raises + ------ + ValueError: _description_ + """ + if any(is_lt_1 := [v < 1 for v in vals]): + label, value = next( + (label, value) + for label, value, check in zip( + labels, + vals, + is_lt_1, + strict=True, + ) + if check + ) + raise ValueError(f"{label} must be greater than 1, got {value}") + + +class SupportsShape(Protocol): # noqa: D101 + @property + def shape(self) -> tuple[int, int] | list[int]: ... # noqa: D102 + + +def check_var_shapes(objs: list[SupportsShape]): + """Small utility function to check that all objects have the same shape along the second axis""" + if not all(objs[0].shape[1] == d.shape[1] for d in objs): + raise ValueError("TODO: All datasets must have same shape along the var axis.") diff --git a/tests/conftest.py b/tests/conftest.py index 2cd08f7..b579961 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,18 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING import anndata as ad -import dask.array as da import numpy as np import pandas as pd import pytest import scipy.sparse as sp import zarr +from arrayloaders import write_sharded -from arrayloaders.io.store_creation import _write_sharded +if TYPE_CHECKING: + from collections.abc import Generator @pytest.fixture(autouse=True) @@ -24,22 +26,17 @@ def use_zarrs(request): @pytest.fixture(scope="session") -def mock_store(tmpdir_factory, n_shards: int = 3): +def adata_with_path(tmpdir_factory, n_shards: int = 3) -> Generator[tuple[ad.AnnData, Path]]: """Create a mock Zarr store for testing.""" feature_dim = 100 n_cells_per_shard = 200 tmp_path = Path(tmpdir_factory.mktemp("stores")) + adata_lst = [] for shard in range(n_shards): adata = ad.AnnData( - X=da.random.random( - (n_cells_per_shard, feature_dim), chunks=(10, -1) - ).astype("f4"), + X=np.random.random((n_cells_per_shard, feature_dim)).astype("f4"), obs=pd.DataFrame( - { - "label": np.random.default_rng().integers( - 0, 5, size=n_cells_per_shard - ) - }, + {"label": np.random.default_rng().integers(0, 5, size=n_cells_per_shard)}, index=np.arange(n_cells_per_shard).astype(str), ), layers={ @@ -51,12 +48,16 @@ def mock_store(tmpdir_factory, n_shards: int = 3): ) }, ) - - f = zarr.open(tmp_path / f"chunk_{shard}.zarr", mode="w", zarr_format=3) - _write_sharded( + adata_lst += [adata] + f = zarr.open_group(tmp_path / f"chunk_{shard}.zarr", mode="w", zarr_format=3) + write_sharded( f, adata, chunk_size=10, shard_size=20, ) - yield tmp_path + yield ( + # need to match directory iteration order for correctness so can't just concatenate + ad.concat([ad.read_zarr(tmp_path / 
shard) for shard in tmp_path.iterdir() if str(shard).endswith(".zarr")]), + tmp_path, + ) diff --git a/tests/test_data_module.py b/tests/test_data_module.py deleted file mode 100644 index aa083cb..0000000 --- a/tests/test_data_module.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import numpy as np - -from arrayloaders.io import ClassificationDataModule -from arrayloaders.io.dask_loader import read_lazy_store - -if TYPE_CHECKING: - from pathlib import Path - - -def test_datamodule(mock_store: Path): - """ - This test verifies that the dataloader for training works correctly: - 1. The training dataloader correctly initializes with training data - 2. The train_dataloader produces batches with expected dimensions - 3. The batch size matches the configured value - """ - adata = read_lazy_store(mock_store, obs_columns=["label"]) - dm = ClassificationDataModule( - adata_train=adata, - adata_val=None, - label_column="label", - train_dataloader_kwargs={ - "batch_size": 15, - "drop_last": True, - }, - ) - - for batch in dm.train_dataloader(): - x, y = batch - assert x.shape[1] == 100 - assert x.shape[0] == 15 # Check batch size - assert y.shape[0] == 15 - - -def test_datamodule_inference(mock_store: Path): - """ - This test verifies that the dataloader for inference (validation) works correctly: - 1. The validation dataloader correctly loads data from the mock store - 2. The batches have the expected feature dimension - 3. All data points from the original dataset are correctly included - 4. The order of the samples are correctly preserved during loading - """ - adata = read_lazy_store(mock_store, obs_columns=["label"]) - dm = ClassificationDataModule( - adata_train=None, - adata_val=adata, - label_column="label", - train_dataloader_kwargs={ - "batch_size": 15, - }, - ) - - x_list, y_list = [], [] - for batch in dm.val_dataloader(): - x, y = batch - x_list.append(x.detach().numpy()) - y_list.append(y.detach().numpy()) - assert x.shape[1] == 100 - - assert np.array_equal(np.vstack(x_list), adata.X.compute()) - assert np.array_equal(np.hstack(y_list), adata.obs["label"].to_numpy()) diff --git a/tests/test_dataset.py b/tests/test_dataset.py new file mode 100644 index 0000000..d134934 --- /dev/null +++ b/tests/test_dataset.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import platform +from importlib.util import find_spec +from types import NoneType +from typing import TYPE_CHECKING, TypedDict + +import anndata as ad +import numpy as np +import pytest +import scipy.sparse as sp +import zarr +import zarrs # noqa: F401 +from arrayloaders import ZarrDenseDataset, ZarrSparseDataset + +try: + from cupy import ndarray as CupyArray + from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix +except ImportError: + CupyCSRMatrix = NoneType + CupyArray = NoneType + +if TYPE_CHECKING: + from pathlib import Path + + +class Data(TypedDict): + dataset: ad.abc.CSRDataset | zarr.Array + obs: np.ndarray + + +class ListData(TypedDict): + datasets: list[ad.abc.CSRDataset | zarr.Array] + obs: list[np.ndarray] + + +def open_sparse(path: Path, *, use_zarrs: bool = False) -> Data: + old_pipeline = zarr.config.get("codec_pipeline.path") + + with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + return { + "dataset": ad.io.sparse_dataset(zarr.open(path)["layers"]["sparse"]), + "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), + } + + +def open_dense(path: Path, *, use_zarrs: bool = 
False) -> Data: + old_pipeline = zarr.config.get("codec_pipeline.path") + + with zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline" if use_zarrs else old_pipeline}): + return { + "dataset": zarr.open(path)["X"], + "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), + } + + +def concat(dicts: list[Data]) -> ListData: + return { + "datasets": [d["dataset"] for d in dicts], + "obs": [d["obs"] for d in dicts], + } + + +@pytest.mark.parametrize("shuffle", [True, False], ids=["shuffled", "unshuffled"]) +@pytest.mark.parametrize( + "gen_loader", + [ + pytest.param( + lambda path, + shuffle, + use_zarrs, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + dataset_class=dataset_class, + batch_size=batch_size, + preload_to_gpu=preload_to_gpu: dataset_class( + shuffle=shuffle, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + return_index=True, + batch_size=batch_size, + preload_to_gpu=preload_to_gpu, + ).add_datasets( + **concat( + [ + (open_sparse if issubclass(dataset_class, ZarrSparseDataset) else open_dense)( + p, use_zarrs=use_zarrs + ) + for p in path.glob("*.zarr") + ] + ) + ), + id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] + marks=pytest.mark.skipif( + find_spec("cupy") is None and preload_to_gpu, + reason="need cupy installed", + ), + ) + for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size, preload_to_gpu in [ + elem + for preload_to_gpu in [True, False] + for dataset_class in [ZarrDenseDataset, ZarrSparseDataset] # type: ignore[list-item] + for elem in [ + [ + 1, + 5, + None, + dataset_class, + None, + 1, + preload_to_gpu, + ], # singleton chunk size + [ + 5, + 1, + None, + dataset_class, + None, + 1, + preload_to_gpu, + ], # singleton preload + [ + 10, + 5, + None, + dataset_class, + None, + 5, + preload_to_gpu, + ], # batch size divides total in memory size evenly + [ + 10, + 5, + None, + dataset_class, + None, + 50, + preload_to_gpu, + ], # batch size equal to in-memory size loading + [ + 10, + 5, + None, + dataset_class, + None, + 14, + preload_to_gpu, + ], # batch size does not divide in memory size evenly + ] + ] + ], +) +def test_store_load_dataset(adata_with_path: tuple[ad.AnnData, Path], *, shuffle: bool, gen_loader, use_zarrs): + """ + This test verifies that the DaskDataset works correctly: + 1. The DaskDataset correctly loads data from the mock store + 2. Each sample has the expected feature dimension + 3. All samples from the dataset are processed + 4. 
If the dataset is not shuffled, it returns the correct data + """ + loader = gen_loader(adata_with_path[1], shuffle, use_zarrs) + adata = adata_with_path[0] + is_dense = isinstance(loader, ZarrDenseDataset) + n_elems = 0 + batches = [] + labels = [] + indices = [] + expected_data = adata.X if is_dense else adata.layers["sparse"].toarray() + for batch in loader: + x, label, index = batch + n_elems += x.shape[0] + # Check feature dimension + assert x.shape[1] == 100 + batches += [x.get() if isinstance(x, CupyCSRMatrix | CupyArray) else x] + if label is not None: + labels += [label] + if index is not None: + indices += [index] + # check that we yield all samples from the dataset + # np.array for sparse + stacked = (np if is_dense else sp).vstack(batches) + if not is_dense: + stacked = stacked.toarray() + if not shuffle: + np.testing.assert_allclose(stacked, expected_data) + if len(labels) > 0: + expected_labels = adata.obs["label"] + np.testing.assert_allclose( + np.concatenate(labels).ravel(), + expected_labels, + ) + else: + if len(indices) > 0: + indices = np.concatenate(indices).ravel() + np.testing.assert_allclose(stacked, expected_data[indices]) + assert n_elems == adata.shape[0] + + +@pytest.mark.parametrize( + "gen_loader", + [ + ( + lambda path, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + dataset_class=dataset_class: dataset_class( + shuffle=True, + chunk_size=chunk_size, + preload_nchunks=preload_nchunks, + ) + ) + for chunk_size, preload_nchunks in [[0, 10], [10, 0]] + for dataset_class in [ZarrSparseDataset, ZarrDenseDataset] + ], +) +def test_zarr_store_errors_lt_1(gen_loader, adata_with_path: tuple[ad.AnnData, Path]): + with pytest.raises(ValueError, match="must be greater than 1"): + gen_loader(adata_with_path[1]) + + +def test_bad_adata_X_type(adata_with_path: tuple[ad.AnnData, Path]): + data = open_dense(next(adata_with_path[1].glob("*.zarr"))) + data["dataset"] = data["dataset"][...] + ds = ZarrDenseDataset( + shuffle=True, + chunk_size=10, + preload_nchunks=10, + ) + with pytest.raises(TypeError, match="Cannot create"): + ds.add_dataset(**data) + + +def _custom_collate_fn(elems): + if isinstance(elems[0][0], sp.csr_matrix): + x = sp.vstack([v[0] for v in elems]).toarray() + else: + x = np.vstack([v[0] for v in elems]) + + if len(elems[0]) == 2: + y = np.array([v[1] for v in elems]) + else: + y = np.array([v[2] for v in elems]) + + return x, y + + +@pytest.mark.parametrize("loader", [ZarrDenseDataset, ZarrSparseDataset]) +@pytest.mark.skipif( + platform.system() == "Linux", + reason="See: https://github.com/scverse/anndata/issues/2021 potentially", +) +def test_torch_multiprocess_dataloading_zarr(adata_with_path: tuple[ad.AnnData, Path], loader, use_zarrs): + """ + Test that the ZarrDatasets can be used with PyTorch's DataLoader in a multiprocess context and that each element of + the dataset gets yielded once. 
+ """ + from torch.utils.data import DataLoader + + if issubclass(loader, ZarrSparseDataset): + ds = ZarrSparseDataset(chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True) + ds.add_datasets(**concat([open_sparse(p, use_zarrs=use_zarrs) for p in adata_with_path[1].glob("*.zarr")])) + x_ref = adata_with_path[0].layers["sparse"].toarray() + elif issubclass(loader, ZarrDenseDataset): + ds = ZarrDenseDataset(chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True) + ds.add_datasets(**concat([open_dense(p, use_zarrs=use_zarrs) for p in adata_with_path[1].glob("*.zarr")])) + x_ref = adata_with_path[0].X + else: + raise ValueError("Unknown loader type") + + dataloader = DataLoader( + ds, + batch_size=32, + num_workers=4, + collate_fn=_custom_collate_fn, + ) + x_list, idx_list = [], [] + for batch in dataloader: + x, idxs = batch + x_list.append(x) + idx_list.append(idxs.ravel()) + + x = np.vstack(x_list) + idxs = np.concatenate(idx_list) + + assert np.array_equal(x[np.argsort(idxs)], x_ref) + + +@pytest.mark.skipif(find_spec("cupy") is not None, reason="Can't test for no cupy if cupy is there") +def test_no_cupy(adata_with_path: tuple[ad.AnnData, Path]): + ds = ZarrDenseDataset( + chunk_size=10, + preload_nchunks=4, + shuffle=True, + return_index=True, + preload_to_gpu=True, + ).add_dataset(**open_dense(list(adata_with_path[1].iterdir())[0])) + with pytest.raises(ImportError, match=r"even though `preload_to_gpu` argument"): + next(iter(ds)) diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py deleted file mode 100644 index 56d48ca..0000000 --- a/tests/test_dataset_loading.py +++ /dev/null @@ -1,381 +0,0 @@ -from __future__ import annotations - -import platform -from importlib.util import find_spec -from types import NoneType -from typing import TYPE_CHECKING, TypedDict - -import anndata as ad -import numpy as np -import pytest -import scipy.sparse as sp -import zarr -import zarrs # noqa: F401 - -from arrayloaders.io import ( - DaskDataset, - ZarrDenseDataset, - ZarrSparseDataset, - read_lazy_store, -) - -try: - from cupy import ndarray as CupyArray - from cupyx.scipy.sparse import csr_matrix as CupyCSRMatrix -except ImportError: - CupyCSRMatrix = NoneType - CupyArray = NoneType - -if TYPE_CHECKING: - from pathlib import Path - - -class Data(TypedDict): - dataset: ad.abc.CSRDataset | zarr.Array - obs: np.ndarray - - -class ListData(TypedDict): - datasets: list[ad.abc.CSRDataset | zarr.Array] - obs: list[np.ndarray] - - -def open_sparse(path: Path, *, use_zarrs: bool = False) -> Data: - old_pipeline = zarr.config.get("codec_pipeline.path") - - with zarr.config.set( - { - "codec_pipeline.path": "zarrs.ZarrsCodecPipeline" - if use_zarrs - else old_pipeline - } - ): - return { - "dataset": ad.io.sparse_dataset(zarr.open(path)["layers"]["sparse"]), - "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), - } - - -def open_dense(path: Path, *, use_zarrs: bool = False) -> Data: - old_pipeline = zarr.config.get("codec_pipeline.path") - - with zarr.config.set( - { - "codec_pipeline.path": "zarrs.ZarrsCodecPipeline" - if use_zarrs - else old_pipeline - } - ): - return { - "dataset": zarr.open(path)["X"], - "obs": ad.io.read_elem(zarr.open(path)["obs"])["label"].to_numpy(), - } - - -def concat(dicts: list[Data]) -> ListData: - return { - "datasets": [d["dataset"] for d in dicts], - "obs": [d["obs"] for d in dicts], - } - - -@pytest.mark.parametrize("shuffle", [True, False], ids=["shuffled", "unshuffled"]) -@pytest.mark.parametrize( - "gen_loader", - [ 
- pytest.param( - lambda path, shuffle, use_zarrs: DaskDataset( - read_lazy_store(path, obs_columns=["label"]), - label_column="label", - n_chunks=4, - shuffle=shuffle, - ), - id="dask", - ), - *( - pytest.param( - lambda path, - shuffle, - use_zarrs, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - dataset_class=dataset_class, - batch_size=batch_size, - preload_to_gpu=preload_to_gpu: dataset_class( - shuffle=shuffle, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - return_index=True, - batch_size=batch_size, - preload_to_gpu=preload_to_gpu, - ).add_datasets( - **concat( - [ - ( - open_sparse - if issubclass(dataset_class, ZarrSparseDataset) - else open_dense - )(p, use_zarrs=use_zarrs) - for p in path.glob("*.zarr") - ] - ) - ), - id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-obs_keys={obs_keys}-dataset_class={dataset_class.__name__}-layer_keys={layer_keys}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] - marks=pytest.mark.skipif( - find_spec("cupy") is None and preload_to_gpu, - reason="need cupy installed", - ), - ) - for chunk_size, preload_nchunks, obs_keys, dataset_class, layer_keys, batch_size, preload_to_gpu in [ - elem - for preload_to_gpu in [True, False] - for dataset_class in [ZarrDenseDataset, ZarrSparseDataset] # type: ignore[list-item] - for elem in [ - [ - 1, - 5, - None, - dataset_class, - None, - 1, - preload_to_gpu, - ], # singleton chunk size - [ - 5, - 1, - None, - dataset_class, - None, - 1, - preload_to_gpu, - ], # singleton preload - [ - 10, - 5, - None, - dataset_class, - None, - 5, - preload_to_gpu, - ], # batch size divides total in memory size evenly - [ - 10, - 5, - None, - dataset_class, - None, - 50, - preload_to_gpu, - ], # batch size equal to in-memory size loading - [ - 10, - 5, - None, - dataset_class, - None, - 14, - preload_to_gpu, - ], # batch size does not divide in memory size evenly - ] - ] - ), - ], -) -def test_store_load_dataset(mock_store: Path, *, shuffle: bool, gen_loader, use_zarrs): - """ - This test verifies that the DaskDataset works correctly: - 1. The DaskDataset correctly loads data from the mock store - 2. Each sample has the expected feature dimension - 3. All samples from the dataset are processed - 4. 
If the dataset is not shuffled, it returns the correct data - """ - adata = read_lazy_store(mock_store, obs_columns=["label"]) - - loader = gen_loader(mock_store, shuffle, use_zarrs) - is_dask = isinstance(loader, DaskDataset) - is_dense = isinstance(loader, ZarrDenseDataset) or is_dask - n_elems = 0 - batches = [] - labels = [] - indices = [] - expected_data = ( - adata.X.compute() if is_dense else adata.layers["sparse"].compute().toarray() - ) - for batch in loader: - if isinstance(loader, DaskDataset): - x, label = batch - index = None - else: - x, label, index = batch - n_elems += 1 if (is_dask := isinstance(loader, DaskDataset)) else x.shape[0] - # Check feature dimension - assert x.shape[0 if is_dask else 1] == 100 - batches += [x.get() if isinstance(x, CupyCSRMatrix | CupyArray) else x] - if label is not None: - labels += [label] - if index is not None: - indices += [index] - # check that we yield all samples from the dataset - # np.array for sparse - stacked = (np if is_dense else sp).vstack(batches) - if not is_dense: - stacked = stacked.toarray() - if not shuffle: - np.testing.assert_allclose(stacked, expected_data) - if len(labels) > 0: - expected_labels = adata.obs["label"] - np.testing.assert_allclose( - (np.array(labels) if is_dask else np.concatenate(labels)).ravel(), - expected_labels, - ) - else: - if len(indices) > 0: - indices = np.concatenate(indices).ravel() - np.testing.assert_allclose(stacked, expected_data[indices]) - assert n_elems == adata.shape[0] - - -@pytest.mark.parametrize( - "gen_loader", - [ - lambda path: DaskDataset( - read_lazy_store(path, obs_columns=["label"])[:0], - label_column="label", - n_chunks=4, - shuffle=True, - ), - lambda path: DaskDataset( - read_lazy_store(path, obs_columns=["label"]), - label_column="label", - n_chunks=0, - shuffle=True, - ), - *( - ( - lambda path, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - dataset_class=dataset_class: dataset_class( - shuffle=True, - chunk_size=chunk_size, - preload_nchunks=preload_nchunks, - ) - ) - for chunk_size, preload_nchunks in [[0, 10], [10, 0]] - for dataset_class in [ZarrSparseDataset, ZarrDenseDataset] - ), - ], -) -def test_zarr_store_errors_lt_1(gen_loader, mock_store): - with pytest.raises(ValueError, match="must be greater than 1"): - gen_loader(mock_store) - - -def test_bad_adata_X_type(mock_store): - data = open_dense(next(mock_store.glob("*.zarr"))) - data["dataset"] = data["dataset"][...] - ds = ZarrDenseDataset( - shuffle=True, - chunk_size=10, - preload_nchunks=10, - ) - with pytest.raises(TypeError, match="Cannot add a dataset"): - ds.add_dataset(**data) - - -def _custom_collate_fn(elems): - if isinstance(elems[0][0], sp.csr_matrix): - x = sp.vstack([v[0] for v in elems]).toarray() - else: - x = np.vstack([v[0] for v in elems]) - - if len(elems[0]) == 2: - y = np.array([v[1] for v in elems]) - else: - y = np.array([v[2] for v in elems]) - - return x, y - - -@pytest.mark.parametrize("loader", [DaskDataset, ZarrDenseDataset, ZarrSparseDataset]) -@pytest.mark.skipif( - platform.system() == "Linux", - reason="See: https://github.com/scverse/anndata/issues/2021 potentially", -) -def test_torch_multiprocess_dataloading_zarr(mock_store, loader, use_zarrs): - """ - Test that the ZarrDatasets can be used with PyTorch's DataLoader in a multiprocess context and that each element of - the dataset gets yielded once. 
- """ - from torch.utils.data import DataLoader - - if issubclass(loader, ZarrSparseDataset): - ds = ZarrSparseDataset( - chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True - ) - ds.add_datasets( - **concat( - [open_sparse(p, use_zarrs=use_zarrs) for p in mock_store.glob("*.zarr")] - ) - ) - x_ref = ( - read_lazy_store(mock_store, obs_columns=["label"]) - .layers["sparse"] - .compute() - .toarray() - ) - elif issubclass(loader, ZarrDenseDataset): - ds = ZarrDenseDataset( - chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True - ) - ds.add_datasets( - **concat( - [open_dense(p, use_zarrs=use_zarrs) for p in mock_store.glob("*.zarr")] - ) - ) - x_ref = read_lazy_store(mock_store, obs_columns=["label"]).X.compute( - scheduler="single-threaded" - ) - elif issubclass(loader, DaskDataset): - adata = read_lazy_store(mock_store, obs_columns=["label"]) - adata.obs["order"] = np.arange(adata.shape[0]) - ds = DaskDataset( - adata, - label_column="order", - n_chunks=4, - shuffle=True, - ) - x_ref = adata.X.compute(scheduler="single-threaded") - else: - raise ValueError("Unknown loader type") - - dataloader = DataLoader( - ds, - batch_size=32, - num_workers=4, - collate_fn=_custom_collate_fn, - ) - x_list, idx_list = [], [] - for batch in dataloader: - x, idxs = batch - x_list.append(x) - idx_list.append(idxs.ravel()) - - x = np.vstack(x_list) - idxs = np.concatenate(idx_list) - - assert np.array_equal(x[np.argsort(idxs)], x_ref) - - -@pytest.mark.skipif( - find_spec("cupy") is not None, reason="Can't test for no cupy if cupy is there" -) -def test_no_cupy(): - with pytest.raises(ImportError, match=r"even though `preload_to_gpu` argument"): - ZarrSparseDataset( - chunk_size=10, - preload_nchunks=4, - shuffle=True, - return_index=True, - preload_to_gpu=True, - ) diff --git a/tests/test_store_creation.py b/tests/test_store_creation.py index 979e54c..cf7c808 100644 --- a/tests/test_store_creation.py +++ b/tests/test_store_creation.py @@ -7,11 +7,9 @@ import numpy as np import pandas as pd import pytest +from arrayloaders import add_h5ads_to_store, create_store_from_h5ads from scipy.sparse import random as sparse_random -from arrayloaders.io.dask_loader import read_lazy_store -from arrayloaders.io.store_creation import add_h5ads_to_store, create_store_from_h5ads - if TYPE_CHECKING: from pathlib import Path @@ -22,13 +20,13 @@ def anndata_settings(): @pytest.fixture -def mock_anndatas_path(tmp_path: Path, n_adatas: int = 4): +def mock_anndatas_path(tmp_path: Path, n_adatas: int = 4) -> tuple[ad.AnnData, Path]: """Create mock anndata objects for testing.""" tmp_path = tmp_path / "adatas" tmp_path.mkdir(parents=True, exist_ok=True) n_features = [random.randint(50, 100) for _ in range(n_adatas)] n_cells = [random.randint(50, 100) for _ in range(n_adatas)] - + adatas = [] for i, (m, n) in enumerate(zip(n_cells, n_features, strict=True)): adata = ad.AnnData( X=sparse_random(m, n, density=0.1, format="csr", dtype="f4"), @@ -40,27 +38,23 @@ def mock_anndatas_path(tmp_path: Path, n_adatas: int = 4): ) adata.write_h5ad(tmp_path / f"adata_{i}.h5ad", compression="gzip") - - return tmp_path + adatas += [adata] + return ad.concat(adatas, join="outer"), tmp_path @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("densify", [True, False]) def test_store_creation( - mock_anndatas_path, + mock_anndatas_path: tuple[ad.AnnData, Path], shuffle: bool, densify: bool, ): var_subset = [f"gene_{i}" for i in range(100)] - (mock_anndatas_path / "zarr_store").mkdir(parents=True, 
exist_ok=True) + (mock_anndatas_path[1] / "zarr_store").mkdir(parents=True, exist_ok=True) create_store_from_h5ads( - [ - mock_anndatas_path / f - for f in mock_anndatas_path.iterdir() - if str(f).endswith(".h5ad") - ], - mock_anndatas_path / "zarr_store", + [mock_anndatas_path[1] / f for f in mock_anndatas_path[1].iterdir() if str(f).endswith(".h5ad")], + mock_anndatas_path[1] / "zarr_store", var_subset, chunk_size=10, shard_size=20, @@ -69,33 +63,25 @@ def test_store_creation( should_denseify=densify, ) - adatas = [ - ad.read_h5ad(mock_anndatas_path / f) - for f in mock_anndatas_path.iterdir() - if str(f).endswith(".h5ad") - ] - adata = read_lazy_store(mock_anndatas_path / "zarr_store") - assert adata.X.shape[0] == sum([adata.shape[0] for adata in adatas]) - assert adata.X.shape[1] == len( - [gene for gene in var_subset if gene in adata.var.index] + adata_orig = mock_anndatas_path[0] + adata = ad.concat( + [ad.read_zarr(zarr_path) for zarr_path in (mock_anndatas_path[1] / "zarr_store").iterdir()], join="outer" ) + assert adata.X.shape[0] == adata_orig.X.shape[0] + assert adata.X.shape[1] == adata_orig.X.shape[1] assert np.array_equal( sorted(adata.var.index), - sorted([gene for gene in var_subset if gene in adata.var.index]), + sorted(adata_orig.var.index), ) @pytest.mark.parametrize("densify", [True, False]) @pytest.mark.parametrize("cache_h5ads", [True, False]) -def test_store_extension(mock_anndatas_path, densify: bool, cache_h5ads: bool): - store_path = mock_anndatas_path / "zarr_store" +def test_store_extension(mock_anndatas_path: tuple[ad.AnnData, Path], densify: bool, cache_h5ads: bool): + store_path = mock_anndatas_path[1] / "zarr_store" # create new store create_store_from_h5ads( - [ - mock_anndatas_path / f - for f in mock_anndatas_path.iterdir() - if str(f).endswith(".h5ad") - ], + [mock_anndatas_path[1] / f for f in mock_anndatas_path[1].iterdir() if str(f).endswith(".h5ad")], store_path, chunk_size=10, shard_size=20, @@ -105,21 +91,13 @@ def test_store_extension(mock_anndatas_path, densify: bool, cache_h5ads: bool): ) # add h5ads to existing store add_h5ads_to_store( - [ - mock_anndatas_path / f - for f in mock_anndatas_path.iterdir() - if str(f).endswith(".h5ad") - ], + [mock_anndatas_path[1] / f for f in mock_anndatas_path[1].iterdir() if str(f).endswith(".h5ad")], store_path, cache_h5ads=cache_h5ads, chunk_size=10, shard_size=20, ) - adata = read_lazy_store(store_path) - adatas = [ - ad.read_h5ad(mock_anndatas_path / f) - for f in mock_anndatas_path.iterdir() - if str(f).endswith(".h5ad") - ] - assert adata.X.shape[0] == (2 * sum([adata.shape[0] for adata in adatas])) + adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in (mock_anndatas_path[1] / "zarr_store").iterdir()]) + adata_orig = ad.concat([mock_anndatas_path[0], mock_anndatas_path[0]], join="outer") + assert adata.X.shape[0] == adata_orig.X.shape[0] diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 750d882..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -import numpy as np - -from arrayloaders.io.utils import sample_rows - - -def test_sample_rows_basic(): - """ - This test checks the sample_rows function without shuffling. - - Verifies that the function yields the expected (x, y) pairs - when given lists of arrays and labels, and shuffle is set to False. 
- """ - x_list = [np.arange(6).reshape(3, 2), np.arange(8, 16).reshape(4, 2)] - y_list = [np.array([0, 1, 2]), np.array([3, 4, 5, 6])] - # Test without shuffling - result = list(sample_rows(x_list, y_list, shuffle=False)) - expected = [ - (np.array([0, 1]), 0), - (np.array([2, 3]), 1), - (np.array([4, 5]), 2), - (np.array([8, 9]), 3), - (np.array([10, 11]), 4), - (np.array([12, 13]), 5), - (np.array([14, 15]), 6), - ] - for (x, y), (ex, ey) in zip(result, expected, strict=True): - np.testing.assert_array_equal(x, ex) - assert y == ey - - -def test_sample_rows_shuffle(): - """ - This test checks the sample_rows function with shuffling enabled. - - Ensures that all unique (x, y) pairs are present in the result, - regardless of order, when shuffle is set to True. - """ - x_list = [np.arange(6).reshape(3, 2), np.arange(8, 16).reshape(4, 2)] - y_list = [np.array([0, 1, 2]), np.array([3, 4, 5, 6])] - result = list(sample_rows(x_list, y_list, shuffle=True)) - # Should have all unique pairs, order may differ - assert sorted([tuple(x) + (y,) for x, y in result]) == [ - (0, 1, 0), - (2, 3, 1), - (4, 5, 2), - (8, 9, 3), - (10, 11, 4), - (12, 13, 5), - (14, 15, 6), - ] diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 0000000..2ccb618 --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,5 @@ +import arrayloaders + + +def test_package_has_version(): + assert arrayloaders.__version__ is not None
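For reference, here is a minimal usage sketch of the new loader API (not part of the patch), pieced together from the patterns in `tests/test_dataset.py` and `tests/conftest.py` above. The store directory and the `"label"` obs column are placeholders; defaults for `chunk_size` and `preload_nchunks` follow the docstring in `src/arrayloaders/utils.py`.

```python
# Sketch only: assumes a directory of sharded zarr stores such as those written by
# create_store_from_h5ads, each holding a CSR matrix under layers["sparse"] and a
# "label" column in obs.
from pathlib import Path

import anndata as ad
import zarr
from arrayloaders import ZarrSparseDataset

store_dir = Path("path/to/zarr_store")  # placeholder path
paths = sorted(store_dir.glob("*.zarr"))

ds = ZarrSparseDataset(
    chunk_size=512,        # contiguous rows fetched per request
    preload_nchunks=32,    # chunks buffered in memory before batching
    batch_size=64,
    shuffle=True,
    return_index=True,
)
ds.add_datasets(
    datasets=[ad.io.sparse_dataset(zarr.open(p)["layers"]["sparse"]) for p in paths],
    obs=[ad.io.read_elem(zarr.open(p)["obs"])["label"].to_numpy() for p in paths],
)

# With return_index=True each iteration yields (X batch, labels, original indices);
# X is a scipy CSR matrix unless preload_to_gpu is enabled.
for x, labels, idx in ds:
    ...
```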