diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..ede54aaa --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM rust:1.72.0 as builder + + +RUN apt update && apt full-upgrade -y && apt install python3 python3-pip -y && apt autoremove -y +# Install libtorch +RUN pip3 install torch==2.0.0 --break-system-packages + +WORKDIR /app + +COPY src src +COPY benches benches +COPY build.rs build.rs +COPY Cargo.toml Cargo.toml +COPY Cargo.lock Cargo.lock +ENV LIBTORCH_USE_PYTORCH=1 +RUN cargo build --release --bin convert-tensor + +# # ============ + +FROM python:3.11-slim + +COPY --from=builder /app/target/release/convert-tensor . + +RUN apt update && apt full-upgrade -y && apt install libgomp1 && apt autoremove -y + +RUN pip3 install torch==2.0.0 numpy==1.26.0 --break-system-packages + +COPY utils utils + +ENV ON_DOCKER=1 +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib + +ENTRYPOINT [ "python3", "./utils/convert_model.py"] + diff --git a/utils/convert_model.py b/utils/convert_model.py index b5e5f089..7d517380 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -48,6 +48,8 @@ import subprocess import sys import zipfile +import os + from pathlib import Path from typing import Dict @@ -178,16 +180,25 @@ def append_to_zipf( source = str(target_folder / "model.npz") target = str(target_folder / "rust_model.ot") - toml_location = (Path(__file__).resolve() / ".." / ".." / "Cargo.toml").resolve() - cargo_args = [ - "cargo", - "run", - "--bin=convert-tensor", - "--manifest-path=%s" % toml_location, - "--", - source, - target, - ] - if args.download_libtorch: - cargo_args += ["--features", "download-libtorch"] + if os.getenv("ON_DOCKER") is not None: + cargo_args = [ + "/convert-tensor", + source, + target, + ] + else: + toml_location = ( + Path(__file__).resolve() / ".." / ".." / "Cargo.toml" + ).resolve() + cargo_args = [ + "cargo", + "run", + "--bin=convert-tensor", + "--manifest-path=%s" % toml_location, + "--", + source, + target, + ] + if args.download_libtorch: + cargo_args += ["--features", "download-libtorch"] subprocess.run(cargo_args)