Skip to content

Commit

Permalink
feat(pkg)!: generate __version__ and build_meta
Browse files Browse the repository at this point in the history
chore: adopt conventionalcommits

chore(pkg): use `import x as x` for re-export.

BREAKING-CHANGE: do not re-export `punica.ops.*` to `punica`.
  • Loading branch information
abcdabcd987 committed Dec 22, 2023
1 parent e523846 commit d415814
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
csrc/flashinfer_adapter/generated/
src/punica/_build_meta.py
data/
build/
tmp/
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ dev = [
]

[tool.ruff]
exclude = ["third_party"]
exclude = ["third_party", "src/punica/_build_meta.py"]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["punica"]
combine-as-imports = true

[tool.ruff.lint]
select = [
Expand All @@ -48,7 +49,6 @@ ignore = [
]

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]

[tool.pytest.ini_options]
testpaths = ["tests"]
Expand Down
117 changes: 81 additions & 36 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import contextlib
import datetime
import itertools
import os
import pathlib
import platform
import re
import subprocess

import setuptools
import torch
import torch.utils.cpp_extension as torch_cpp_ext

root = pathlib.Path(__name__).parent
Expand All @@ -13,14 +18,6 @@ def glob(pattern):
return [str(p) for p in root.glob(pattern)]


def get_version(path):
    """Extract the ``__version__`` value from the Python file at *path*.

    Scans the file line by line for an assignment that starts with
    ``__version__`` and returns its right-hand side with quotes stripped.

    Raises:
        ValueError: if no ``__version__`` line is found.
    """
    with open(path) as fp:
        for raw in fp:
            if not raw.startswith("__version__"):
                continue
            value = raw.split("=", maxsplit=1)[1]
            return value.replace('"', "").strip()
    raise ValueError("Version not found")


def remove_unwanted_pytorch_nvcc_flags():
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
Expand Down Expand Up @@ -85,33 +82,81 @@ def generate_flashinfer_cu() -> list[str]:
return files


remove_unwanted_pytorch_nvcc_flags()
ext_modules = []
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="punica.ops._kernels",
sources=[
"csrc/punica_ops.cc",
"csrc/bgmv/bgmv_all.cu",
"csrc/flashinfer_adapter/flashinfer_all.cu",
"csrc/rms_norm/rms_norm_cutlass.cu",
"csrc/sgmv/sgmv_cutlass.cu",
"csrc/sgmv_flashinfer/sgmv_all.cu",
]
+ generate_flashinfer_cu(),
include_dirs=[
str(root.resolve() / "third_party/cutlass/include"),
str(root.resolve() / "third_party/flashinfer/include"),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": ["-O3"],
},
def get_local_version_suffix() -> str:
    """Build a PEP 440 local version suffix from the git checkout state.

    Returns:
        "" when not building from a git checkout (no ``.git`` directory),
        otherwise ``+c<commit-count>.d<YYYYMMDD>.<short-hash>`` with a
        trailing ``.dirty`` if the working tree has uncommitted changes.
    """
    if not (root / ".git").is_dir():
        return ""
    now = datetime.datetime.now()
    git_hash = subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"], cwd=root, text=True
    ).strip()
    commit_number = subprocess.check_output(
        ["git", "rev-list", "HEAD", "--count"], cwd=root, text=True
    ).strip()
    # Fix: the dirty check must run in the repository root like the other
    # git invocations above; without cwd=root it inspected whatever the
    # build process's current working directory happened to be.
    dirty_rc = subprocess.run(["git", "diff", "--quiet"], cwd=root).returncode
    dirty = ".dirty" if dirty_rc else ""
    return f"+c{commit_number}.d{now:%Y%m%d}.{git_hash}{dirty}"


def get_version() -> str:
    """Return the full package version string.

    The ``PUNICA_BUILD_VERSION`` environment variable, when set, wins
    outright; otherwise the base version is read from ``version.txt``
    and a git-derived local suffix is appended.
    """
    override = os.getenv("PUNICA_BUILD_VERSION")
    if override is not None:
        return override
    base = (root / "version.txt").read_text().strip()
    return base + get_local_version_suffix()


def get_cuda_version() -> tuple[int, int]:
    """Return the (major, minor) CUDA toolkit version reported by nvcc.

    Uses the nvcc from ``CUDA_HOME`` when PyTorch located one, otherwise
    falls back to whatever ``nvcc`` resolves to on PATH.

    Raises:
        RuntimeError: if nvcc's output contains no parseable
            ``release X.Y,`` string (the previous implementation crashed
            here with an opaque IndexError from ``findall(...)[0]``).
    """
    if torch_cpp_ext.CUDA_HOME is None:
        nvcc = "nvcc"
    else:
        nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc")
    txt = subprocess.check_output([nvcc, "--version"], text=True)
    match = re.search(r"release (\d+)\.(\d+),", txt)
    if match is None:
        raise RuntimeError(f"Cannot parse CUDA version from nvcc output:\n{txt}")
    return int(match.group(1)), int(match.group(2))


def generate_build_meta() -> None:
    """Write ``src/punica/_build_meta.py`` with version and build info.

    The generated module defines ``__version__`` and a ``build_meta``
    dict recording the CUDA, torch, and Python versions plus the
    ``TORCH_CUDA_ARCH_LIST`` environment setting at build time.
    """
    cuda_major, cuda_minor = get_cuda_version()
    meta = {
        "cuda_major": cuda_major,
        "cuda_minor": cuda_minor,
        "torch": torch.__version__,
        "python": platform.python_version(),
        "TORCH_CUDA_ARCH_LIST": os.environ.get("TORCH_CUDA_ARCH_LIST", None),
    }
    content = f"__version__ = {get_version()!r}\n" + f"build_meta = {meta!r}"
    (root / "src/punica/_build_meta.py").write_text(content)


if __name__ == "__main__":
remove_unwanted_pytorch_nvcc_flags()
generate_build_meta()

ext_modules = []
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="punica.ops._kernels",
sources=[
"csrc/punica_ops.cc",
"csrc/bgmv/bgmv_all.cu",
"csrc/flashinfer_adapter/flashinfer_all.cu",
"csrc/rms_norm/rms_norm_cutlass.cu",
"csrc/sgmv/sgmv_cutlass.cu",
"csrc/sgmv_flashinfer/sgmv_all.cu",
]
+ generate_flashinfer_cu(),
include_dirs=[
str(root.resolve() / "third_party/cutlass/include"),
str(root.resolve() / "third_party/flashinfer/include"),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": ["-O3"],
},
)
)
)

setuptools.setup(
version=get_version(root / "src/punica/__init__.py"),
ext_modules=ext_modules,
cmdclass={"build_ext": torch_cpp_ext.BuildExtension},
)
setuptools.setup(
version=get_version(),
ext_modules=ext_modules,
cmdclass={"build_ext": torch_cpp_ext.BuildExtension},
)
38 changes: 21 additions & 17 deletions src/punica/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
__version__ = "0.2.0"

from .models.llama import LlamaForCausalLM, LlamaModel
from . import ops as ops
from ._build_meta import __version__ as __version__
from .models.llama import (
LlamaForCausalLM as LlamaForCausalLM,
LlamaModel as LlamaModel,
)
from .models.llama_lora import (
BatchedLlamaLoraWeight,
LlamaForCausalLMWithLora,
LlamaLoraWeight,
LlamaModelWithLora,
BatchedLlamaLoraWeight as BatchedLlamaLoraWeight,
LlamaForCausalLMWithLora as LlamaForCausalLMWithLora,
LlamaLoraWeight as LlamaLoraWeight,
LlamaModelWithLora as LlamaModelWithLora,
)
from .utils.cat_tensor import (
BatchLenInfo as BatchLenInfo,
)
from .utils.kvcache import (
BatchedKvCache as BatchedKvCache,
KvCache as KvCache,
KvPool as KvPool,
)
from .ops import (
add_lora_sgmv_custom_cutlass,
append_kv,
batch_decode,
init_kv,
rms_norm,
sgmv,
from .utils.lora import (
BatchedLoraWeight as BatchedLoraWeight,
LoraWeight as LoraWeight,
)
from .utils import BatchedLoraWeight, LoraWeight
from .utils.cat_tensor import BatchLenInfo
from .utils.kvcache import BatchedKvCache, KvCache, KvPool
6 changes: 5 additions & 1 deletion src/punica/models/llama_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

from punica.ops import (
add_lora_sgmv_custom_cutlass as add_lora,
append_kv,
batch_decode,
batch_prefill,
init_kv,
rms_norm,
)
from punica.ops import append_kv, batch_decode, batch_prefill, init_kv, rms_norm
from punica.utils import BatchedKvCache, BatchedLoraWeight, BatchLenInfo, LoraWeight


Expand Down
1 change: 1 addition & 0 deletions version.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0.3.0

0 comments on commit d415814

Please sign in to comment.