diff --git a/.github/workflows/perf-bench.yml b/.github/workflows/perf-bench.yml index 3657862f2f..1ec03708f9 100644 --- a/.github/workflows/perf-bench.yml +++ b/.github/workflows/perf-bench.yml @@ -32,7 +32,7 @@ jobs: - name: Install data-juicer working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' + docker compose exec ray-head bash -c 'uv pip install --system -e .\[all\]' - name: Clean dataset cache working-directory: dj-${{ github.run_id }}/.github/workflows/docker diff --git a/.github/workflows/unit-test-partial.yml b/.github/workflows/unit-test-partial.yml index 89fb8a732b..58a0497677 100644 --- a/.github/workflows/unit-test-partial.yml +++ b/.github/workflows/unit-test-partial.yml @@ -31,12 +31,12 @@ jobs: - name: Install data-juicer working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' + docker compose exec ray-head bash -c 'UV_HTTP_TIMEOUT=3600 uv pip install --system -e .\[all\]' - name: Print Pip Dependency Tree working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system pipdeptree' + docker compose exec ray-head bash -c 'uv pip install --system pipdeptree' docker compose exec ray-head bash -c 'pipdeptree' - name: Clean dataset cache @@ -90,8 +90,8 @@ jobs: - name: Install data-juicer working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' - docker compose exec ray-worker bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' + docker compose exec ray-head bash -c 'UV_HTTP_TIMEOUT=3600 uv pip install --system -e .\[all\]' + docker compose exec ray-worker bash -c 'UV_HTTP_TIMEOUT=3600 uv pip install --system -e .\[all\]' - name: Clean dataset cache working-directory: dj-${{ github.run_id }}/.github/workflows/docker @@ -140,7 +140,7 @@ jobs: - name: Install coverage working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system coverage' + docker compose exec ray-head bash -c 'uv pip install --system coverage' - name: Download Coverage Report Standalone uses: actions/download-artifact@v4 diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index cb3526c9d7..f6bd879547 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -32,13 +32,13 @@ jobs: - name: Install data-juicer working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' - docker compose exec ray-worker bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' + docker compose exec ray-head bash -c 'uv pip install --system -e .\[all\]' + docker compose exec ray-worker bash -c 'uv pip install --system -e .\[all\]' - name: Print Pip Dependency Tree working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system pipdeptree' + docker compose exec ray-head bash -c 'uv pip install --system pipdeptree' docker compose exec ray-head bash -c 'pipdeptree' - name: Clean dataset cache @@ -87,8 +87,8 @@ jobs: - name: Install data-juicer working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' - docker compose exec ray-worker bash -c '/root/.local/bin/uv pip install --system -e .\[all\]' + docker compose exec ray-head bash -c 'uv pip install --system -e .\[all\]' + docker compose exec ray-worker bash -c 'uv pip install --system -e .\[all\]' - name: Clean dataset cache working-directory: dj-${{ github.run_id }}/.github/workflows/docker @@ -139,7 +139,7 @@ jobs: - name: Install coverage working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - docker compose exec ray-head bash -c '/root/.local/bin/uv pip install --system coverage' + docker compose exec ray-head bash -c 'uv pip install --system coverage' - name: Download Coverage Report Standalone uses: actions/download-artifact@v4 diff --git a/.gitignore b/.gitignore index 8ac3ff5547..08eddd1397 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,8 @@ perf_bench_data/ # env file .env + +# cython outputs +/data_juicer/ops/deduplicator/minhash.cpython-* +/data_juicer/ops/deduplicator/tokenize.c +/data_juicer/ops/deduplicator/tokenize.cpython-* diff --git a/.pre-commit-hooks/build_op_doc.py b/.pre-commit-hooks/build_op_doc.py index 46c2fc1b96..9663c9ef21 100644 --- a/.pre-commit-hooks/build_op_doc.py +++ b/.pre-commit-hooks/build_op_doc.py @@ -340,12 +340,14 @@ def get_op_list_from_code(): # get docs for formatters first op_record_list = get_op_list_from_code_for_formatter() # get docs for other ops + op_num_dict = {} for type in os.listdir(OP_CODE_PREFIX): if type in OP_EXCLUDE: continue type_dir = os.path.join(OP_CODE_PREFIX, type) if os.path.isfile(type_dir): continue + op_num_dict[type] = 0 for op in os.listdir(type_dir): if op in OP_EXCLUDE: continue @@ -369,8 +371,9 @@ def get_op_list_from_code(): ref=ref_link(op.replace(".py", "")), ) ) + op_num_dict[type] += 1 op_record_list.sort(key=lambda record: (record.type, record.name)) - return op_record_list + return op_record_list, op_num_dict def generate_new_doc(op_record_list, old_op_record_list): @@ -516,6 +519,22 @@ def get_op_desc_in_en_zh_batched(descs): return zhs +def parse_op_num_from_doc(doc_content): + pattern = r"\| +(.*?) +\| +(.*?) +\| +(.*?) +\|" + link_pattern = r"\[(.*?)\]\(.*\)" + overview_section = doc_content.split("## Overview 概览")[1].split("##")[0] + res = re.findall(pattern, overview_section) + num_dict = {} + for type, num, desc in res: + if type == "Type 类型": + continue + type = re.findall(link_pattern, type)[0] + if type == "formatter": + continue + num_dict[type] = int(num) + return num_dict + + def parse_op_record_from_current_doc(): """ Parse the old-version OP records from the existing OP doc. @@ -527,6 +546,7 @@ def parse_op_record_from_current_doc(): op_record_list = [] with open(DOC_PATH, "r", encoding="utf-8") as fin: content = fin.read() + op_num_dict = parse_op_num_from_doc(content) res = re.findall(tab_pattern, content) for name, tags, desc, info, ref in res: # skip table header @@ -553,9 +573,9 @@ def parse_op_record_from_current_doc(): ) ) op_record_list.sort(key=lambda record: (record.type, record.name)) - return op_record_list + return op_record_list, op_num_dict else: - return [] + return [], {} def check_and_update_op_record(old_op_record_list, new_op_record_list): @@ -620,11 +640,11 @@ def check_and_update_op_record(old_op_record_list, new_op_record_list): def main(): - old_op_record_list = parse_op_record_from_current_doc() - new_op_record_list = get_op_list_from_code() + old_op_record_list, old_op_num_dict = parse_op_record_from_current_doc() + new_op_record_list, new_op_num_dict = get_op_list_from_code() updated_op_record_list = check_and_update_op_record(old_op_record_list, new_op_record_list) # if the doc is changed, exit with non-zero value - if old_op_record_list == updated_op_record_list: + if new_op_num_dict == old_op_num_dict and old_op_record_list == updated_op_record_list: exit(0) else: generate_new_doc(updated_op_record_list, old_op_record_list) diff --git a/Dockerfile b/Dockerfile index eae1995c1e..c775832176 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # The data-juicer image includes all open-source contents of data-juicer, # and it will be installed in editable mode. -FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 +FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu24.04 # change to aliyun source RUN sed -i 's/archive.ubuntu.com/mirrors.aliyun.com/g' /etc/apt/sources.list \ diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 1d9df16ff4..2825ed1c01 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -47,6 +47,7 @@ from .video_duration_filter import VideoDurationFilter from .video_frames_text_similarity_filter import VideoFramesTextSimilarityFilter from .video_motion_score_filter import VideoMotionScoreFilter +from .video_motion_score_ptlflow_filter import VideoMotionScorePtlflowFilter from .video_motion_score_raft_filter import VideoMotionScoreRaftFilter from .video_nsfw_filter import VideoNSFWFilter from .video_ocr_area_ratio_filter import VideoOcrAreaRatioFilter @@ -103,6 +104,7 @@ "VideoDurationFilter", "VideoFramesTextSimilarityFilter", "VideoMotionScoreFilter", + "VideoMotionScorePtlflowFilter", "VideoMotionScoreRaftFilter", "VideoNSFWFilter", "VideoOcrAreaRatioFilter", diff --git a/data_juicer/ops/filter/llm_perplexity_filter.py b/data_juicer/ops/filter/llm_perplexity_filter.py index 7848893462..e2ea14bc28 100644 --- a/data_juicer/ops/filter/llm_perplexity_filter.py +++ b/data_juicer/ops/filter/llm_perplexity_filter.py @@ -65,34 +65,34 @@ def __init__( self.model_params = model_params self.model_key = prepare_model(model_type="huggingface", pretrained_model_name_or_path=hf_model, **model_params) - @torch.no_grad() def _loss(self, example, pre_example=None, rank=None): - model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) - model.eval() - tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token - tokenizer.padding_side = "left" - tokenizer.truncation_side = "left" - - pre_msgs = pre_example["messages"] if pre_example is not None else [] - msgs = pre_msgs + example["messages"] - # TODO: chat template - full_text = " ".join([msg["content"] for msg in msgs]).strip() - response_text = msgs[-1]["content"].strip() - max_length = self.model_params.get("max_length", None) - full_tokenized = tokenizer(full_text, max_length=max_length, truncation=True, return_tensors="pt") - input_ids = full_tokenized["input_ids"] - response_ids = tokenizer(response_text, max_length=max_length, truncation=True, return_tensors="pt")[ - "input_ids" - ][0] - response_len = len(response_ids) - int(tokenizer.bos_token_id is not None) - labels = input_ids.clone() - labels[0, :-response_len] = -100 - - input_ids = input_ids.to(model.device) - labels = labels.to(model.device) - loss = model(input_ids=input_ids, labels=labels).loss.item() - - return loss + with torch.no_grad(): + model, tokenizer = get_model(self.model_key, rank, self.use_cuda()) + model.eval() + tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token + tokenizer.padding_side = "left" + tokenizer.truncation_side = "left" + + pre_msgs = pre_example["messages"] if pre_example is not None else [] + msgs = pre_msgs + example["messages"] + # TODO: chat template + full_text = " ".join([msg["content"] for msg in msgs]).strip() + response_text = msgs[-1]["content"].strip() + max_length = self.model_params.get("max_length", None) + full_tokenized = tokenizer(full_text, max_length=max_length, truncation=True, return_tensors="pt") + input_ids = full_tokenized["input_ids"] + response_ids = tokenizer(response_text, max_length=max_length, truncation=True, return_tensors="pt")[ + "input_ids" + ][0] + response_len = len(response_ids) - int(tokenizer.bos_token_id is not None) + labels = input_ids.clone() + labels[0, :-response_len] = -100 + + input_ids = input_ids.to(model.device) + labels = labels.to(model.device) + loss = model(input_ids=input_ids, labels=labels).loss.item() + + return loss def sample_with_messages(self, sample, system_prompt=None): if "messages" in sample: diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py index b210f249bc..5aa002bd0e 100644 --- a/data_juicer/ops/filter/video_motion_score_filter.py +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -5,7 +5,7 @@ import numpy as np from pydantic import PositiveFloat, PositiveInt -from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.constant import Fields, MetaKeys, StatsKeys from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.mm_utils import calculate_resized_dimensions @@ -28,7 +28,7 @@ def VideoCapture(*args, **kwargs): @UNFORKABLE.register_module(OP_NAME) @OPERATORS.register_module(OP_NAME) class VideoMotionScoreFilter(Filter): - """Filter to keep samples with video motion scores within a specific range. + """Filter to keep samples with video motion scores from OpenCV within a specific range. The operator uses Farneback's algorithm from OpenCV to compute dense optical flow. It calculates the average motion score for each video and retains samples based on the @@ -58,6 +58,8 @@ def __init__( divisible: PositiveInt = 1, relative: bool = False, any_or_all: str = "any", + if_output_optical_flow: bool = False, + optical_flow_key: str = MetaKeys.video_optical_flow, *args, **kwargs, ): @@ -87,6 +89,11 @@ def __init__( all videos. 'any': keep this sample if any videos meet the condition. 'all': keep this sample only if all videos meet the condition. + :param if_output_optical_flow: Determines whether to output + the computed optical flows into the metas. The optical flows for each + video will be stored in the shape of (num_frame, H, W, 2) + :param optical_flow_key: The field name to store the optical flows. It's + "video_optical_flow" in default. :param args: extra args :param kwargs: extra args """ @@ -117,6 +124,12 @@ def __init__( raise ValueError(f"Keep strategy [{any_or_all}] is not supported. " f'Can only be one of ["any", "all"].') self.any = any_or_all == "any" + self.if_output_optical_flow = if_output_optical_flow + self.optical_flow_key = optical_flow_key + + # setup model + self.model = None + def setup_model(self, rank=None): self.model = cv2.calcOpticalFlowFarneback @@ -150,29 +163,42 @@ def compute_stats_single(self, sample, rank=None, context=False): all_videos_frames = sample[self.frame_field] num_videos = len(all_videos_frames) unique_motion_scores = {} + video_optical_flows = {} for video_idx in range(num_videos): - unique_motion_scores[video_idx] = self._compute_motion_scores_from_frames(all_videos_frames[video_idx]) + unique_motion_scores[video_idx], video_optical_flows[video_idx] = ( + self._compute_motion_scores_from_frames(all_videos_frames[video_idx]) + ) sample[Fields.stats][StatsKeys.video_motion_score] = [ unique_motion_scores.get(idx, -1) for idx in range(num_videos) ] + if self.if_output_optical_flow: + sample[Fields.meta][self.optical_flow_key] = [ + video_optical_flows.get(idx, -1) for idx in range(num_videos) + ] else: # Read videos and compute motion scores loaded_video_keys = sample[self.video_key] unique_motion_scores = {} + video_optical_flows = {} for video_key in loaded_video_keys: # skip duplicate videos if video_key in unique_motion_scores: continue - unique_motion_scores[video_key] = self._compute_motion_scores_from_video(video_key) + unique_motion_scores[video_key], video_optical_flows[video_key] = ( + self._compute_motion_scores_from_video(video_key) + ) sample[Fields.stats][StatsKeys.video_motion_score] = [ unique_motion_scores.get(key, -1) for key in sample[self.video_key] ] + if self.if_output_optical_flow: + sample[Fields.meta][self.optical_flow_key] = [video_optical_flows[key] for key in loaded_video_keys] return sample def _compute_motion_scores_from_frames(self, frames): video_motion_scores = [] + optical_flows = [] prev_frame = None for frame in frames: if isinstance(frame, bytes): @@ -190,16 +216,20 @@ def _compute_motion_scores_from_frames(self, frames): flow, prev_frame = self.compute_flow(prev_frame, frame) if flow is None: continue + optical_flows.append(flow) mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) frame_motion_score = np.mean(mag) if self.relative: frame_motion_score /= np.hypot(*frame.shape[:2]) video_motion_scores.append(float(frame_motion_score)) - return np.mean(video_motion_scores or [-1]) + res_optical_flow = np.stack(optical_flows).tolist() if optical_flows else [] + + return np.mean(video_motion_scores or [-1]), res_optical_flow def _compute_motion_scores_from_video(self, video_key): video_motion_scores = [] + optical_flows = [] with VideoCapture(video_key) as cap: if cap.isOpened(): fps = cap.get(cv2.CAP_PROP_FPS) @@ -229,6 +259,7 @@ def _compute_motion_scores_from_video(self, video_key): flow, prev_frame = self.compute_flow(prev_frame, frame) if flow is None: continue + optical_flows.append(flow) mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) frame_motion_score = np.mean(mag) if self.relative: @@ -239,7 +270,9 @@ def _compute_motion_scores_from_video(self, video_key): frame_count += sampling_step cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count) - return np.mean(video_motion_scores or [-1]) + res_optical_flow = np.stack(optical_flows).tolist() if optical_flows else [] + + return np.mean(video_motion_scores or [-1]), res_optical_flow def process_single(self, sample): video_motion_scores = sample[Fields.stats][StatsKeys.video_motion_score] diff --git a/data_juicer/ops/filter/video_motion_score_ptlflow_filter.py b/data_juicer/ops/filter/video_motion_score_ptlflow_filter.py new file mode 100644 index 0000000000..514cbb8e37 --- /dev/null +++ b/data_juicer/ops/filter/video_motion_score_ptlflow_filter.py @@ -0,0 +1,113 @@ +import sys +from typing import Optional, Tuple, Union + +from jsonargparse import dict_to_namespace +from pydantic import PositiveFloat, PositiveInt + +from data_juicer.ops.filter.video_motion_score_filter import VideoMotionScoreFilter +from data_juicer.utils.constant import MetaKeys +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.resource_utils import cuda_device_count + +from ..base_op import OPERATORS, UNFORKABLE + +torch = LazyLoader("torch") +tvm = LazyLoader("torchvision.models") +tvt = LazyLoader("torchvision.transforms") +ptlflow = LazyLoader("ptlflow") +ptlflow_io_adapter = LazyLoader("ptlflow.utils.io_adapter") + +OP_NAME = "video_motion_score_ptlflow_filter" + + +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +class VideoMotionScorePtlflowFilter(VideoMotionScoreFilter): + """Filter to keep samples with video motion scores from ptlflow within a specified range. + + This operator utilizes the ptlflow library (https://github.com/hmorimitsu/ptlflow) to + predict optical flow between video frames. It keeps samples where the + video motion score is within the given min and max score range. The motion score is + computed based on the optical flow between frames, which is estimated using the models + supported in ptlflow. The operator can sample frames at a specified FPS and apply + transformations to the frames before computing the flow. + + - The models in ptlflow is used to estimate the optical flow. + - Frames are preprocessed using a series of transformations including normalization and + color channel flipping. + - The motion score is calculated from the optical flow data. + - The operator can be configured to filter based on any or all frames in the video. + - The device for model inference (CPU or CUDA) is automatically detected and set. + + For further details, refer to the official documentation: + https://ptlflow.readthedocs.io/ + """ + + _accelerator = "cuda" + _default_kwargs = {} + + def __init__( + self, + min_score: float = 1.0, + max_score: float = sys.float_info.max, + frame_field: Optional[str] = None, + model_name: str = "dpflow", + ckpt_path: Optional[str] = "things", + get_model_args: Optional[dict] = None, + sampling_fps: PositiveFloat = 2, + size: Union[PositiveInt, Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt], None] = None, + max_size: Optional[PositiveInt] = None, + divisible: PositiveInt = 8, + relative: bool = False, + any_or_all: str = "any", + if_output_optical_flow: bool = False, + optical_flow_key: str = MetaKeys.video_optical_flow, + *args, + **kwargs, + ): + super().__init__( + min_score, + max_score, + frame_field, + sampling_fps, + size, + max_size, + divisible, + relative, + any_or_all, + if_output_optical_flow, + optical_flow_key, + *args, + **kwargs, + ) + + self.model_name = model_name + self.ckpt_path = ckpt_path + if get_model_args is not None: + get_model_args = dict_to_namespace(get_model_args) + self.get_model_args = get_model_args + + def setup_model(self, rank=None): + self.model = ptlflow.get_model(self.model_name, ckpt_path=self.ckpt_path, args=self.get_model_args) + if self.use_cuda(): + rank = rank if rank is not None else 0 + rank = rank % cuda_device_count() + self.device = f"cuda:{rank}" + else: + self.device = "cpu" + self.model.to(self.device) + self.model.eval() + + def compute_flow(self, prev_frame, curr_frame): + if prev_frame is None: + flow = None + else: + io_adapter = ptlflow_io_adapter.IOAdapter(self.model, prev_frame.shape[:2]) + frames = [prev_frame, curr_frame] + inputs = io_adapter.prepare_inputs(frames) + inputs = {key: value.to(self.device) for key, value in inputs.items()} + with torch.no_grad(): + predictions = self.model(inputs) + flows = predictions.get("flows") # shape: (1, 1, 2, H, W) + flow = flows[-1][0].detach().cpu().numpy().transpose((1, 2, 0)) # 2, H, W -> H, W, 2 + return flow, curr_frame diff --git a/data_juicer/ops/filter/video_motion_score_raft_filter.py b/data_juicer/ops/filter/video_motion_score_raft_filter.py index da78a6e168..1bfa6f9cc4 100644 --- a/data_juicer/ops/filter/video_motion_score_raft_filter.py +++ b/data_juicer/ops/filter/video_motion_score_raft_filter.py @@ -4,6 +4,7 @@ from pydantic import PositiveFloat, PositiveInt from data_juicer.ops.filter.video_motion_score_filter import VideoMotionScoreFilter +from data_juicer.utils.constant import MetaKeys from data_juicer.utils.lazy_loader import LazyLoader from data_juicer.utils.resource_utils import cuda_device_count @@ -49,17 +50,32 @@ def __init__( self, min_score: float = 1.0, max_score: float = sys.float_info.max, + frame_field: Optional[str] = None, sampling_fps: PositiveFloat = 2, size: Union[PositiveInt, Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt], None] = None, max_size: Optional[PositiveInt] = None, divisible: PositiveInt = 8, relative: bool = False, any_or_all: str = "any", + if_output_optical_flow: bool = False, + optical_flow_key: str = MetaKeys.video_optical_flow, *args, **kwargs, ): super().__init__( - min_score, max_score, sampling_fps, size, max_size, divisible, relative, any_or_all, *args, **kwargs + min_score, + max_score, + frame_field, + sampling_fps, + size, + max_size, + divisible, + relative, + any_or_all, + if_output_optical_flow, + optical_flow_key, + *args, + **kwargs, ) def setup_model(self, rank=None): diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index db1f531997..7ee08e75fc 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -68,6 +68,8 @@ class MetaKeys(object): video_object_segment_tags = "video_object_segment_tags" # # depth info in video video_depth_tags = "video_depth_tags" + # # video optical flow + video_optical_flow = "video_optical_flow" # # info extracted by VGGT vggt_tags = "vggt_tags" # # image tags diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 95c86a5b56..547e1370fa 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -38,7 +38,7 @@ aes_pred = LazyLoader("aesthetics_predictor", "simple-aesthetics-predictor") vllm = LazyLoader("vllm") diffusers = LazyLoader("diffusers") -ram = LazyLoader("ram", "git+https://github.com/HYLcool/recognize-anything.git") +ram = LazyLoader("ram", "git+https://github.com/datajuicer/recognize-anything.git") cv2 = LazyLoader("cv2", "opencv-python") openai = LazyLoader("openai") ultralytics = LazyLoader("ultralytics") @@ -46,7 +46,7 @@ dashscope = LazyLoader("dashscope") qwen_vl_utils = LazyLoader("qwen_vl_utils", "qwen-vl-utils") transformers_stream_generator = LazyLoader( - "transformers_stream_generator", "git+https://github.com/HYLcool/transformers-stream-generator.git" + "transformers_stream_generator", "git+https://github.com/datajuicer/transformers-stream-generator.git" ) MODEL_ZOO = {} diff --git a/docs/Operators.md b/docs/Operators.md index 43d4c14664..090337be92 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -40,15 +40,15 @@ The operators in Data-Juicer are categorized into 8 types. Data-Juicer 中的算子分为以下 8 种类型。 | Type 类型 | Number 数量 | Description 描述 | -|------|:---------:|-------------| -| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | -| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | -| [filter](#filter) | 55 | Filters out low-quality samples. 过滤低质量样本。 | -| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | -| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | -| [mapper](#mapper) | 96 | Edits and transforms samples. 对数据样本进行编辑和转换。 | -| [pipeline](#pipeline) | 3 | Combines multiple operators into a data processing pipeline. 将多个算子组合成数据处理流水线。 | -| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | +|------|:------:|-------------| +| [aggregator](#aggregator) | 4 | Aggregate for batched samples, such as summary or conclusion. 对批量样本进行汇总,如得出总结或结论。 | +| [deduplicator](#deduplicator) | 10 | Detects and removes duplicate samples. 识别、删除重复样本。 | +| [filter](#filter) | 56 | Filters out low-quality samples. 过滤低质量样本。 | +| [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | +| [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | +| [mapper](#mapper) | 98 | Edits and transforms samples. 对数据样本进行编辑和转换。 | +| [pipeline](#pipeline) | 3 | Applies dataset-level processing; both input and output are datasets. 执行数据集级别的操作,输入和输出均为完整数据集。 | +| [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | All the specific operators are listed below, each featured with several capability tags. 下面列出所有具体算子,每种算子都通过多个标签来注明其主要功能。 @@ -144,7 +144,8 @@ All the specific operators are listed below, each featured with several capabili | video_aspect_ratio_filter | 🎬Video 💻CPU 🟢Stable | Filter to keep samples with video aspect ratio within a specific range. 过滤器将视频纵横比的样本保持在特定范围内。 | [info](operators/filter/video_aspect_ratio_filter.md) | - | | video_duration_filter | 🎬Video 💻CPU 🟢Stable | Keep data samples whose videos' durations are within a specified range. 保留视频持续时间在指定范围内的数据样本。 | [info](operators/filter/video_duration_filter.md) | - | | video_frames_text_similarity_filter | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Filter to keep samples based on the similarity between video frame images and text within a specific range. 根据视频帧图像和文本之间的相似性进行过滤,以保持样本在特定范围内。 | [info](operators/filter/video_frames_text_similarity_filter.md) | - | -| video_motion_score_filter | 🎬Video 💻CPU 🟢Stable | Filter to keep samples with video motion scores within a specific range. 过滤器将视频运动分数的样本保持在特定范围内。 | [info](operators/filter/video_motion_score_filter.md) | - | +| video_motion_score_filter | 🎬Video 💻CPU 🟢Stable | Filter to keep samples with video motion scores from OpenCV within a specific range. 过滤器将来自OpenCV的视频运动分数的样本保持在特定范围内。 | [info](operators/filter/video_motion_score_filter.md) | - | +| video_motion_score_ptlflow_filter | 🎬Video 🚀GPU 🟡Beta | Filter to keep samples with video motion scores from ptlflow within a specified range. 过滤器以将来自ptlflow的具有视频运动分数的样本保持在指定范围内。 | - | - | | video_motion_score_raft_filter | 🎬Video 🚀GPU 🟢Stable | Filter to keep samples with video motion scores within a specified range. 过滤器将视频运动分数的样本保持在指定范围内。 | [info](operators/filter/video_motion_score_raft_filter.md) | [RAFT](https://arxiv.org/abs/2003.12039) | | video_nsfw_filter | 🎬Video 🚀GPU 🧩HF 🟢Stable | Filter to keep samples whose videos have nsfw scores in a specified range. 过滤器以保留其视频的nsfw分数在指定范围内的样本。 | [info](operators/filter/video_nsfw_filter.md) | - | | video_ocr_area_ratio_filter | 🎬Video 🚀GPU 🟢Stable | Keep data samples whose detected text area ratios for specified frames in the video are within a specified range. 保留检测到的视频中指定帧的文本面积比率在指定范围内的数据样本。 | [info](operators/filter/video_ocr_area_ratio_filter.md) | - | diff --git a/hatch_build.py b/hatch_build.py index c597089862..311a7d4324 100644 --- a/hatch_build.py +++ b/hatch_build.py @@ -36,7 +36,8 @@ def build_extensions(self): opts.append("/std:c++11") for ext in self.extensions: - ext.extra_compile_args = opts + if ext.language == "c++": + ext.extra_compile_args = opts build_ext.build_extensions(self) diff --git a/pyproject.toml b/pyproject.toml index f32dec2c35..fcb6ede4e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,8 @@ vision = [ "rembg", # Background removal "decord", # video/audio handling "qwen-vl-utils==0.0.14", # for Qwen-VL + "ptlflow==0.4.1", # optical flow models collection + "timm==1.0.22", # avoid importing issue in the latest v1.0.23 ] # Natural Language Processing diff --git a/tests/ops/filter/test_video_motion_score_filter.py b/tests/ops/filter/test_video_motion_score_filter.py index 2478a1da11..69d4833d73 100644 --- a/tests/ops/filter/test_video_motion_score_filter.py +++ b/tests/ops/filter/test_video_motion_score_filter.py @@ -5,7 +5,7 @@ from data_juicer.ops.filter.video_motion_score_filter import \ VideoMotionScoreFilter -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -19,14 +19,12 @@ class VideoMotionScoreFilterTest(DataJuicerTestCaseBase): img1_path = os.path.join(data_path, 'img6.jpg') - def _run_helper(self, op, source_list, target_list, np=1, select_field=None): + def _run_helper(self, op, source_list, target_list, select_field=None): dataset = Dataset.from_list(source_list) if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats, num_proc=np) - dataset = dataset.filter(op.process, num_proc=np) - + dataset = op.run(dataset) if select_field is not None: dataset = dataset.select_columns(column_names=select_field) else: @@ -180,8 +178,40 @@ def test_parallel(self): 'videos': [self.vid3_path] }] tgt_list = [{'videos': [self.vid1_path]}] - op = VideoMotionScoreFilter(min_score=1.5, max_score=3.0) - self._run_helper(op, ds_list, tgt_list, np=2) + op = VideoMotionScoreFilter(min_score=1.5, max_score=3.0, num_proc=2) + self._run_helper(op, ds_list, tgt_list) + + def test_output_optical_flow(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScoreFilter(if_output_optical_flow=True) + dataset = Dataset.from_list(ds_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=1) + dataset = dataset.filter(op.process, num_proc=1) + metas = dataset.select_columns(column_names=[Fields.meta]) + self.assertIn(MetaKeys.video_optical_flow, metas.features[Fields.meta]) + + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) def test_frame_field(self): ds_list = [{ @@ -196,8 +226,8 @@ def test_frame_field(self): }, { 'frames': [[self.img1_path, self.img1_path]], }] - op = VideoMotionScoreFilter(min_score=0, max_score=3.0, frame_field='frames') - self._run_helper(op, ds_list, tgt_list, np=2, select_field=['frames']) + op = VideoMotionScoreFilter(min_score=0, max_score=3.0, frame_field='frames', num_proc=2) + self._run_helper(op, ds_list, tgt_list, select_field=['frames']) if __name__ == '__main__': diff --git a/tests/ops/filter/test_video_motion_score_ptlflow_filter.py b/tests/ops/filter/test_video_motion_score_ptlflow_filter.py new file mode 100644 index 0000000000..fff501012f --- /dev/null +++ b/tests/ops/filter/test_video_motion_score_ptlflow_filter.py @@ -0,0 +1,238 @@ +import os +import unittest + +from datasets import Dataset + +from data_juicer.ops.filter.video_motion_score_ptlflow_filter import \ + VideoMotionScorePtlflowFilter +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + +class VideoMotionScorePtlflowFilterTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') # 22.65 + vid2_path = os.path.join(data_path, 'video2.mp4') # 11,97 + vid3_path = os.path.join(data_path, 'video3.mp4') # 2.27 + + def _run_helper(self, op, source_list, target_list): + dataset = Dataset.from_list(source_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + dataset = op.run(dataset) + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test_default(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter() + self._run_helper(op, ds_list, tgt_list) + + def test_different_model(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter(model_name='raft') + self._run_helper(op, ds_list, tgt_list) + + def test_downscale(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=5, size=128) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_max(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=0.0, size=256, max_size=256) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_relative(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=0.005, size=(128, 160), relative=True) + self._run_helper(op, ds_list, tgt_list) + + def test_high(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=20) + self._run_helper(op, ds_list, tgt_list) + + def test_low(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=0.0, max_score=5) + self._run_helper(op, ds_list, tgt_list) + + def test_middle(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid2_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=10, max_score=20) + self._run_helper(op, ds_list, tgt_list) + + def test_any(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=10, + max_score=20, + any_or_all='any') + self._run_helper(op, ds_list, tgt_list) + + def test_all(self): + ds_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }, { + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'videos': [self.vid1_path, self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path, self.vid2_path] + }] + op = VideoMotionScorePtlflowFilter(min_score=10, + max_score=30, + any_or_all='all') + self._run_helper(op, ds_list, tgt_list) + + def test_parallel(self): + import multiprocess as mp + mp.set_start_method('spawn', force=True) + + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{'videos': [self.vid3_path]}] + op = VideoMotionScorePtlflowFilter(min_score=0, max_score=10, num_proc=2) + self._run_helper(op, ds_list, tgt_list) + + def test_output_optical_flow(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScorePtlflowFilter(if_output_optical_flow=True) + dataset = Dataset.from_list(ds_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=1) + dataset = dataset.filter(op.process, num_proc=1) + metas = dataset.select_columns(column_names=[Fields.meta]) + self.assertIn(MetaKeys.video_optical_flow, metas.features[Fields.meta]) + + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/filter/test_video_motion_score_raft_filter.py b/tests/ops/filter/test_video_motion_score_raft_filter.py index c41efa3b4e..1bcdeb9c01 100644 --- a/tests/ops/filter/test_video_motion_score_raft_filter.py +++ b/tests/ops/filter/test_video_motion_score_raft_filter.py @@ -5,7 +5,7 @@ from data_juicer.ops.filter.video_motion_score_raft_filter import \ VideoMotionScoreRaftFilter -from data_juicer.utils.constant import Fields +from data_juicer.utils.constant import Fields, MetaKeys from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase class VideoMotionScoreRaftFilterTest(DataJuicerTestCaseBase): @@ -15,13 +15,12 @@ class VideoMotionScoreRaftFilterTest(DataJuicerTestCaseBase): vid2_path = os.path.join(data_path, 'video2.mp4') # 10.098914 vid3_path = os.path.join(data_path, 'video3.mp4') # 2.0731936 - def _run_helper(self, op, source_list, target_list, np=1): + def _run_helper(self, op, source_list, target_list): dataset = Dataset.from_list(source_list) if Fields.stats not in dataset.features: dataset = dataset.add_column(name=Fields.stats, column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats, num_proc=np) - dataset = dataset.filter(op.process, num_proc=np) + dataset = op.run(dataset) dataset = dataset.select_columns(column_names=[op.video_key]) res_list = dataset.to_list() self.assertEqual(res_list, target_list) @@ -176,8 +175,40 @@ def test_parallel(self): 'videos': [self.vid3_path] }] tgt_list = [{'videos': [self.vid2_path]}] - op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.2) - self._run_helper(op, ds_list, tgt_list, np=2) + op = VideoMotionScoreRaftFilter(min_score=3, max_score=10.2, num_proc=2) + self._run_helper(op, ds_list, tgt_list) + + def test_output_optical_flow(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + op = VideoMotionScoreRaftFilter(if_output_optical_flow=True) + dataset = Dataset.from_list(ds_list) + if Fields.stats not in dataset.features: + dataset = dataset.add_column(name=Fields.stats, + column=[{}] * dataset.num_rows) + if Fields.meta not in dataset.features: + dataset = dataset.add_column(name=Fields.meta, + column=[{}] * dataset.num_rows) + dataset = dataset.map(op.compute_stats, num_proc=1) + dataset = dataset.filter(op.process, num_proc=1) + metas = dataset.select_columns(column_names=[Fields.meta]) + self.assertIn(MetaKeys.video_optical_flow, metas.features[Fields.meta]) + + dataset = dataset.select_columns(column_names=[op.video_key]) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) if __name__ == '__main__': diff --git a/uv.lock b/uv.lock index 9ceebc8404..8b2830e837 100644 --- a/uv.lock +++ b/uv.lock @@ -64,6 +64,15 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", ] +[[package]] +name = "absl-py" +version = "2.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/10/2a/c93173ffa1b39c1d0395b7e842bbdc62e556ca9d8d3b5572926f3e4ca752/absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9", size = 116588, upload-time = "2025-07-03T09:31:44.05Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" }, +] + [[package]] name = "accelerate" version = "1.8.1" @@ -344,6 +353,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.9.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b", size = 117034, upload-time = "2021-11-06T17:52:23.524Z" } + [[package]] name = "anyio" version = "4.9.0" @@ -432,8 +447,7 @@ dependencies = [ { name = "numpy-minmax" }, { name = "numpy-rms" }, { name = "python-stretch" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "soxr" }, ] sdist = { url = "https://files.pythonhosted.org/packages/06/1a/89d90284278f540825d40b727661aaf7ae0e6e8ba04a8f6200fba8b9b51b/audiomentations-0.41.0.tar.gz", hash = "sha256:e83e2de91393e2fdc80e4713f01f3eb5f085c55fbe77e60c9e9e3c35d7930aa7", size = 83265, upload-time = "2025-05-05T08:57:36.909Z" } @@ -625,6 +639,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/40/be3858ffed004e47e48a2cefecdbf9b950d41098b780f9dc3aa609a88351/bitarray-3.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2a3d1b05ffdd3e95687942ae7b13c63689f85d3f15c39b33329e3cb9ce6c015f", size = 147015, upload-time = "2025-11-02T21:40:35.064Z" }, ] +[[package]] +name = "bitsandbytes" +version = "0.49.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "sys_platform != 'darwin'" }, + { name = "packaging", marker = "sys_platform != 'darwin'" }, + { name = "torch", marker = "sys_platform != 'darwin'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/4f/9f6d161e9ea68cdd6b85585dee9b383748ca07431e31c4c134111f87489e/bitsandbytes-0.49.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:7e69951b4d207a676986fce967544d9599f23518d0f09d478295996aeff377c2", size = 31065242, upload-time = "2025-12-11T20:50:41.903Z" }, + { url = "https://files.pythonhosted.org/packages/a5/a8/26f7815b376b1d3dae615263471cb6d0d9f9792a472d5dab529502deac67/bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl", hash = "sha256:0c46cdef50b3174463b6bdf13715c9f1f00b360be3626e3c5d2f8d226af2cf3f", size = 59053880, upload-time = "2025-12-11T20:50:45.422Z" }, + { url = "https://files.pythonhosted.org/packages/69/76/bc6460b1618322258e7d251cd0c9d11d98d5232bb37cd507451e40127f8e/bitsandbytes-0.49.0-py3-none-win_amd64.whl", hash = "sha256:57a327c6d65f7eda32eb8d416ef8e44d2415c2e7b4fdb735896abd04171ae696", size = 54700284, upload-time = "2025-12-11T20:50:49.373Z" }, +] + [[package]] name = "black" version = "25.1.0" @@ -2102,8 +2131,7 @@ dependencies = [ { name = "python-bidi" }, { name = "pyyaml" }, { name = "scikit-image" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "shapely" }, { name = "torch" }, { name = "torchvision" }, @@ -2908,6 +2936,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, ] +[[package]] +name = "h5py" +version = "3.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cc/0c/5c2b0a88158682aeafb10c1c2b735df5bc31f165bfe192f2ee9f2a23b5f1/h5py-3.12.1.tar.gz", hash = "sha256:326d70b53d31baa61f00b8aa5f95c2fcb9621a3ee8365d770c551a13dbbcbfdf", size = 411457, upload-time = "2024-09-26T16:41:39.883Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/7d/b21045fbb004ad8bb6fb3be4e6ca903841722706f7130b9bba31ef2f88e3/h5py-3.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2f0f1a382cbf494679c07b4371f90c70391dedb027d517ac94fa2c05299dacda", size = 3402133, upload-time = "2024-09-26T16:39:27.937Z" }, + { url = "https://files.pythonhosted.org/packages/29/a7/3c2a33fba1da64a0846744726fd067a92fb8abb887875a0dd8e3bac8b45d/h5py-3.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cb65f619dfbdd15e662423e8d257780f9a66677eae5b4b3fc9dca70b5fd2d2a3", size = 2866436, upload-time = "2024-09-26T16:39:32.495Z" }, + { url = "https://files.pythonhosted.org/packages/1e/d0/4bf67c3937a2437c20844165766ddd1a1817ae6b9544c3743050d8e0f403/h5py-3.12.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b15d8dbd912c97541312c0e07438864d27dbca857c5ad634de68110c6beb1c2", size = 5168596, upload-time = "2024-09-26T16:39:39.107Z" }, + { url = "https://files.pythonhosted.org/packages/85/bc/e76f4b2096e0859225f5441d1b7f5e2041fffa19fc2c16756c67078417aa/h5py-3.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59685fe40d8c1fbbee088c88cd4da415a2f8bee5c270337dc5a1c4aa634e3307", size = 5341537, upload-time = "2024-09-26T16:39:46.037Z" }, + { url = "https://files.pythonhosted.org/packages/99/bd/fb8ed45308bb97e04c02bd7aed324ba11e6a4bf9ed73967ca2a168e9cf92/h5py-3.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:577d618d6b6dea3da07d13cc903ef9634cde5596b13e832476dd861aaf651f3e", size = 2990575, upload-time = "2024-09-26T16:39:50.903Z" }, + { url = "https://files.pythonhosted.org/packages/33/61/c463dc5fc02fbe019566d067a9d18746cd3c664f29c9b8b3c3f9ed025365/h5py-3.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ccd9006d92232727d23f784795191bfd02294a4f2ba68708825cb1da39511a93", size = 3410828, upload-time = "2024-09-26T16:39:56.19Z" }, + { url = "https://files.pythonhosted.org/packages/95/9d/eb91a9076aa998bb2179d6b1788055ea09cdf9d6619cd967f1d3321ed056/h5py-3.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad8a76557880aed5234cfe7279805f4ab5ce16b17954606cca90d578d3e713ef", size = 2872586, upload-time = "2024-09-26T16:40:00.204Z" }, + { url = "https://files.pythonhosted.org/packages/b0/62/e2b1f9723ff713e3bd3c16dfeceec7017eadc21ef063d8b7080c0fcdc58a/h5py-3.12.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1473348139b885393125126258ae2d70753ef7e9cec8e7848434f385ae72069e", size = 5273038, upload-time = "2024-09-26T16:40:06.444Z" }, + { url = "https://files.pythonhosted.org/packages/e1/89/118c3255d6ff2db33b062ec996a762d99ae50c21f54a8a6047ae8eda1b9f/h5py-3.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:018a4597f35092ae3fb28ee851fdc756d2b88c96336b8480e124ce1ac6fb9166", size = 5452688, upload-time = "2024-09-26T16:40:13.054Z" }, + { url = "https://files.pythonhosted.org/packages/1d/4d/cbd3014eb78d1e449b29beba1f3293a841aa8086c6f7968c383c2c7ff076/h5py-3.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fdf95092d60e8130ba6ae0ef7a9bd4ade8edbe3569c13ebbaf39baefffc5ba4", size = 3006095, upload-time = "2024-09-26T16:40:17.822Z" }, + { url = "https://files.pythonhosted.org/packages/d4/e1/ea9bfe18a3075cdc873f0588ff26ce394726047653557876d7101bf0c74e/h5py-3.12.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:06a903a4e4e9e3ebbc8b548959c3c2552ca2d70dac14fcfa650d9261c66939ed", size = 3372538, upload-time = "2024-09-26T16:40:22.796Z" }, + { url = "https://files.pythonhosted.org/packages/0d/74/1009b663387c025e8fa5f3ee3cf3cd0d99b1ad5c72eeb70e75366b1ce878/h5py-3.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b3b8f3b48717e46c6a790e3128d39c61ab595ae0a7237f06dfad6a3b51d5351", size = 2868104, upload-time = "2024-09-26T16:40:26.817Z" }, + { url = "https://files.pythonhosted.org/packages/af/52/c604adc06280c15a29037d4aa79a24fe54d8d0b51085e81ed24b2fa995f7/h5py-3.12.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:050a4f2c9126054515169c49cb900949814987f0c7ae74c341b0c9f9b5056834", size = 5194606, upload-time = "2024-09-26T16:40:32.847Z" }, + { url = "https://files.pythonhosted.org/packages/fa/63/eeaacff417b393491beebabb8a3dc5342950409eb6d7b39d437289abdbae/h5py-3.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c4b41d1019322a5afc5082864dfd6359f8935ecd37c11ac0029be78c5d112c9", size = 5413256, upload-time = "2024-09-26T16:40:39.188Z" }, + { url = "https://files.pythonhosted.org/packages/86/f7/bb465dcb92ca3521a15cbe1031f6d18234dbf1fb52a6796a00bfaa846ebf/h5py-3.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:e4d51919110a030913201422fb07987db4338eba5ec8c5a15d6fab8e03d443fc", size = 2993055, upload-time = "2024-09-26T16:40:44.278Z" }, + { url = "https://files.pythonhosted.org/packages/23/1c/ecdd0efab52c24f2a9bf2324289828b860e8dd1e3c5ada3cf0889e14fdc1/h5py-3.12.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:513171e90ed92236fc2ca363ce7a2fc6f2827375efcbb0cc7fbdd7fe11fecafc", size = 3346239, upload-time = "2024-09-26T16:40:48.735Z" }, + { url = "https://files.pythonhosted.org/packages/93/cd/5b6f574bf3e318bbe305bc93ba45181676550eb44ba35e006d2e98004eaa/h5py-3.12.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59400f88343b79655a242068a9c900001a34b63e3afb040bd7cdf717e440f653", size = 2843416, upload-time = "2024-09-26T16:40:53.424Z" }, + { url = "https://files.pythonhosted.org/packages/8a/4f/b74332f313bfbe94ba03fff784219b9db385e6139708e55b11490149f90a/h5py-3.12.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3e465aee0ec353949f0f46bf6c6f9790a2006af896cee7c178a8c3e5090aa32", size = 5154390, upload-time = "2024-09-26T16:40:59.787Z" }, + { url = "https://files.pythonhosted.org/packages/1a/57/93ea9e10a6457ea8d3b867207deb29a527e966a08a84c57ffd954e32152a/h5py-3.12.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba51c0c5e029bb5420a343586ff79d56e7455d496d18a30309616fdbeed1068f", size = 5378244, upload-time = "2024-09-26T16:41:06.22Z" }, + { url = "https://files.pythonhosted.org/packages/50/51/0bbf3663062b2eeee78aa51da71e065f8a0a6e3cb950cc7020b4444999e6/h5py-3.12.1-cp313-cp313-win_amd64.whl", hash = "sha256:52ab036c6c97055b85b2a242cb540ff9590bacfda0c03dd0cf0661b311f522f8", size = 2979760, upload-time = "2024-09-26T16:41:10.425Z" }, +] + [[package]] name = "hf-xet" version = "1.1.5" @@ -3045,6 +3104,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/97/dde1ba4ffc9923c9d97069c76d0469eb17b886513cdd6380a21e4290134d/humansignal_drf_yasg-1.21.10.post1-py3-none-any.whl", hash = "sha256:aa6fdd504b727bcc6ef8c05540505b8c53922156058cf64da269cd90bdeed2d8", size = 4289795, upload-time = "2024-11-12T16:44:29.311Z" }, ] +[[package]] +name = "hydra-core" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "omegaconf" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/8e/07e42bc434a847154083b315779b0a81d567154504624e181caf2c71cd98/hydra-core-1.3.2.tar.gz", hash = "sha256:8a878ed67216997c3e9d88a8e72e7b4767e81af37afb4ea3334b269a4390a824", size = 3263494, upload-time = "2023-02-23T18:33:43.03Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl", hash = "sha256:fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b", size = 154547, upload-time = "2023-02-23T18:33:40.801Z" }, +] + [[package]] name = "identify" version = "2.6.12" @@ -3402,6 +3475,9 @@ wheels = [ ] [package.optional-dependencies] +jsonnet = [ + { name = "jsonnet" }, +] signatures = [ { name = "docstring-parser" }, { name = "typeshed-client" }, @@ -3419,6 +3495,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/62/d9ba6323b9202dd2fe166beab8a86d29465c41a0288cbe229fac60c1ab8d/jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55", size = 8701, upload-time = "2023-09-01T12:34:42.563Z" }, ] +[[package]] +name = "jsonnet" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/bd/e4a77ccb757a3060f30eefbd090b9593fe6ad15e5ef8ff0c3fc4aa5237cf/jsonnet-0.21.0.tar.gz", hash = "sha256:7fe2865e6e1dc2b9791d880fea3eba7e72334b256d85f027da3ae1f56a55b1da", size = 461207, upload-time = "2025-05-07T13:20:51.321Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/8e/7658eccf7b1c76da3d65016f2000c10118e9f406268592d61e4e9b13ee84/jsonnet-0.21.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4717d83a15144adc9ae7d3d0a0d0ff54d7fe18349346130bd9b9bb7f8c9b0db", size = 473029, upload-time = "2025-05-07T13:20:01.359Z" }, + { url = "https://files.pythonhosted.org/packages/fb/e8/46ba8d6ac206429c3d6f64b453b034e743316ccb281d1c6d36b663ed926a/jsonnet-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:121a24583fe6980705b8f775f2b66e2b01c4006dbd258d047d54f60b76b98681", size = 438844, upload-time = "2025-05-07T13:20:03.959Z" }, + { url = "https://files.pythonhosted.org/packages/13/1b/a77b8922d3e0dc90baba2a3bb783267acd76becd125d91144312d865b908/jsonnet-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c87bbf37e2f118e75de30ec4d3d1d2a5eedd7fe213f00042e3a2fe0e7026bbc", size = 6527221, upload-time = "2025-05-07T13:20:06.276Z" }, + { url = "https://files.pythonhosted.org/packages/a6/48/cd23105784731f94beecc53c8d7e966fde9a5efd0276b8765b453d97afea/jsonnet-0.21.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:902cb1a9bb7916f3e8041a2936e6ba4deea7312843927360c698d1092144d49c", size = 6777703, upload-time = "2025-05-07T13:20:08.812Z" }, + { url = "https://files.pythonhosted.org/packages/c0/5c/323f52ee8284c9c37690a625fee8f1e3ededc8a7b79e4ff1adf8a16da02d/jsonnet-0.21.0-cp310-cp310-win_amd64.whl", hash = "sha256:ad896e2d70bc6ea4c5503b9587703e75a233506a57c33fa3192922e49b97a90a", size = 318250, upload-time = "2025-05-07T13:20:10.586Z" }, + { url = "https://files.pythonhosted.org/packages/6c/da/2f359a0d29811f7f1c9be3f6beb5cd1f5c2f571cf815511316854bce6ed0/jsonnet-0.21.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bc2c8b35122884dcb63431a831e81d6ab494e37148704a781ef88bb7e12fb36b", size = 473030, upload-time = "2025-05-07T13:20:12.681Z" }, + { url = "https://files.pythonhosted.org/packages/89/39/70062f4f57d03d5fee91b2eaccaae49504a3623ad5fea42527007170aa4d/jsonnet-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f837389c6b384070b870c98f12c05847fdd801bb7752ab7893beaeac662f4b54", size = 438840, upload-time = "2025-05-07T13:20:14.501Z" }, + { url = "https://files.pythonhosted.org/packages/55/0b/601cbdaddf6c0fad50ed823b8d2dbb7f10e428034c251fb2f5355869838e/jsonnet-0.21.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85a2089fb77d6db86ef84d9403654d710ba3e41dcf4ad21d0cba2635497ba852", size = 6528833, upload-time = "2025-05-07T13:20:16.313Z" }, + { url = "https://files.pythonhosted.org/packages/03/45/30b1cf590e56fa2ee082c6abb8cc5410fcb13a2e944fc49305f15f4e6e22/jsonnet-0.21.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:559d59e8984b804f60a97d72e7aeaa2a2572fc0a5bf7ef1109eb21b91dbc166c", size = 6779026, upload-time = "2025-05-07T13:20:18.584Z" }, + { url = "https://files.pythonhosted.org/packages/98/1b/70cb03ad7299008798878e146d1ffea67579ab0c53b9372f438eddd7987d/jsonnet-0.21.0-cp311-cp311-win_amd64.whl", hash = "sha256:6018365037491e91b5d3f0eccfdf78812d84e25aa9ccbba097bd3ba6ce70709a", size = 318250, upload-time = "2025-05-07T13:20:20.377Z" }, + { url = "https://files.pythonhosted.org/packages/5c/42/2bf7da089e6b5ca75f7a7c3bb2e9c39e1783d4359ab17c5083b0698dfbfa/jsonnet-0.21.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ba35051103bed81ddcb446db52c31bba00391c52069107498eb44952feac8a30", size = 473523, upload-time = "2025-05-07T13:20:22.449Z" }, + { url = "https://files.pythonhosted.org/packages/8d/e0/f3ef97fa0535b435fbde76df9da63b78602692cca0d4b8ddf2d8439830fc/jsonnet-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:71afa464a74dcbec30b39d8f28cad091ce27497a8620c0ef7859814e173ce454", size = 438912, upload-time = "2025-05-07T13:20:24.425Z" }, + { url = "https://files.pythonhosted.org/packages/97/ee/3613b2f2216d4a53c13bb081f7b77d6a7977a4169039efa7eb77bf9d71da/jsonnet-0.21.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fccebb019917004cf860490a80d17189bad01c9d425b7a1cb138a14745488cf0", size = 6529912, upload-time = "2025-05-07T13:20:26.671Z" }, + { url = "https://files.pythonhosted.org/packages/3b/90/dee03ee550737b913f64428ac392e8970d807b02c938b766bf7e40fa3cfc/jsonnet-0.21.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ba913bb650b2b5dac29e65fd6963dff7cad960580523c0ccdd66e23e22e3b772", size = 6782491, upload-time = "2025-05-07T13:20:28.576Z" }, + { url = "https://files.pythonhosted.org/packages/aa/a5/c3a2592383ec68e02d9d31740764f144dfb8df28d4f2d003c40d05f73478/jsonnet-0.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a39b5a3195bb6ec16050d14f8aa9378cf862ff2dd54ca0973cbbfbc9cec6e89", size = 318260, upload-time = "2025-05-07T13:20:30.159Z" }, + { url = "https://files.pythonhosted.org/packages/87/9a/b7825f91d889fbe47125911a34f56f0cb94f01afdffb0bc6390f1573bb1c/jsonnet-0.21.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:95d0e0e59ed29f7e424066c05c4585fd255e288fd6050686e1d5bb54bd719896", size = 473524, upload-time = "2025-05-07T13:20:31.601Z" }, + { url = "https://files.pythonhosted.org/packages/d0/66/fe05afdcf269a8be7a99aa33741ad31cf083a505acb46010a90781b00106/jsonnet-0.21.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:eb926cae6ea157e2e0851e6ec8f6a2949e926f67754a87980bbcb2698a211dc5", size = 438913, upload-time = "2025-05-07T13:20:33.281Z" }, + { url = "https://files.pythonhosted.org/packages/c4/2c/c4760c07b3506312f37c237c9a0840f3db44e476da5af2c0b883bb5b1070/jsonnet-0.21.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb642fe864e41a432957f71bfa57ae4eaab904886f06dec183c9e40d6ce4e24b", size = 6529866, upload-time = "2025-05-07T13:20:35.359Z" }, + { url = "https://files.pythonhosted.org/packages/1c/56/33a2eb1d263952603f9b16f3789ecac0c7a3b9bb5a6410d69173a6ed3bd9/jsonnet-0.21.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:22a87070c1c50ecf6c0c8df252a4984a89275ceb18fe059dfa99eeaf548be71f", size = 6782518, upload-time = "2025-05-07T13:20:37.777Z" }, + { url = "https://files.pythonhosted.org/packages/bb/d9/2c68a80f9cbda8e4b4721032b7def236109bd8991d1670b60beb5cfb505c/jsonnet-0.21.0-cp313-cp313-win_amd64.whl", hash = "sha256:6e23e55e0a0811b899398aaa03a5b46eea01ffcafc697a705fe7b07eb8cd0ce7", size = 318264, upload-time = "2025-05-07T13:20:39.842Z" }, +] + [[package]] name = "jsonschema" version = "4.24.0" @@ -3446,6 +3550,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/0e/b27cdbaccf30b890c40ed1da9fd4a3593a5cf94dae54fb34f8a4b74fcd3f/jsonschema_specifications-2025.4.1-py3-none-any.whl", hash = "sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af", size = 18437, upload-time = "2025-04-23T12:34:05.422Z" }, ] +[[package]] +name = "kaleido" +version = "0.2.1.post1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/4b/d668e288b694661d2fbfc2b972db69cf1f30f8b8a91be14dcf9f000cab16/kaleido-0.2.1.post1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:d313940896c24447fc12c74f60d46ea826195fc991f58569a6e73864d53e5c20", size = 71653137, upload-time = "2021-04-14T11:26:19.068Z" }, +] + [[package]] name = "kenlm" version = "0.3.0" @@ -3751,8 +3863,7 @@ dependencies = [ { name = "numpy" }, { name = "pooch" }, { name = "scikit-learn" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "soundfile" }, { name = "soxr" }, { name = "typing-extensions" }, @@ -3762,6 +3873,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8c/8a/2d231b35456506b7c98b3ab9bbf07917b205fed8615d2e59e976ab497fff/librosa-0.10.2.post1-py3-none-any.whl", hash = "sha256:dc882750e8b577a63039f25661b7e39ec4cfbacc99c1cffba666cd664fb0a7a0", size = 260089, upload-time = "2024-05-14T15:49:38.919Z" }, ] +[[package]] +name = "lightning" +version = "2.5.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fsspec", extra = ["http"] }, + { name = "lightning-utilities" }, + { name = "packaging" }, + { name = "pytorch-lightning" }, + { name = "pyyaml" }, + { name = "torch" }, + { name = "torchmetrics" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/da/289e17b2d4631b885771ce10ab7fe19c6c0ab2b1208d1dda418818ffbbfd/lightning-2.5.6.tar.gz", hash = "sha256:57b6abe87080895bc237fb7f36b7b4abaa2793760cbca00e3907e56607e0ed27", size = 640106, upload-time = "2025-11-05T20:53:06.823Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/dc/d7804f13928b6a81a0e948cfecbf0071e8cc74e3f341c704e23e75e504ad/lightning-2.5.6-py3-none-any.whl", hash = "sha256:25bb2053078c2efc57c082fda89dfbd975dfa76beb08def191947c2b571a8c8a", size = 827915, upload-time = "2025-11-05T20:53:03.169Z" }, +] + +[package.optional-dependencies] +pytorch-extra = [ + { name = "bitsandbytes", marker = "sys_platform != 'darwin'" }, + { name = "hydra-core" }, + { name = "jsonargparse", extra = ["jsonnet", "signatures"] }, + { name = "matplotlib" }, + { name = "omegaconf" }, + { name = "rich" }, + { name = "tensorboardx" }, +] + +[[package]] +name = "lightning-utilities" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "setuptools", version = "80.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b8/39/6fc58ca81492db047149b4b8fd385aa1bfb8c28cd7cacb0c7eb0c44d842f/lightning_utilities-0.15.2.tar.gz", hash = "sha256:cdf12f530214a63dacefd713f180d1ecf5d165338101617b4742e8f22c032e24", size = 31090, upload-time = "2025-08-06T13:57:39.242Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/73/3d757cb3fc16f0f9794dd289bcd0c4a031d9cf54d8137d6b984b2d02edf3/lightning_utilities-0.15.2-py3-none-any.whl", hash = "sha256:ad3ab1703775044bbf880dbf7ddaaac899396c96315f3aa1779cec9d618a9841", size = 29431, upload-time = "2025-08-06T13:57:38.046Z" }, +] + [[package]] name = "linkify-it-py" version = "2.0.3" @@ -4039,6 +4196,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/40/5f9eb8b73030cc4b0d6817176e66079a62a2ddd9d5530da54f8011473428/marisa_trie-1.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:aa7cd17e1c690ce96c538b2f4aae003d9a498e65067dd433c52dd069009951d4", size = 149035, upload-time = "2024-10-12T11:29:31.332Z" }, ] +[[package]] +name = "markdown" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/7dd27d9d863b3376fcf23a5a13cb5d024aed1db46f963f1b5735ae43b3be/markdown-3.10.tar.gz", hash = "sha256:37062d4f2aa4b2b6b32aefb80faa300f82cc790cb949a35b8caede34f2b68c0e", size = 364931, upload-time = "2025-11-03T19:51:15.007Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/81/54e3ce63502cd085a0c556652a4e1b919c45a446bd1e5300e10c44c8c521/markdown-3.10-py3-none-any.whl", hash = "sha256:b5b99d6951e2e4948d939255596523444c0e677c669700b1d17aa4a8a464cb7c", size = 107678, upload-time = "2025-11-03T19:51:13.887Z" }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -5045,6 +5211,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/eb/41/e74ec826e1585ad6d31f41de96f6faae8ffc712a45c2b880baca4ae87a64/nvtx-0.2.12-cp313-cp313-win_amd64.whl", hash = "sha256:a37e063c3c745a4c6b561993a2dae2f67fcc26f2a2c2653f24eeae5810a2180d", size = 97070, upload-time = "2025-05-25T08:43:41.323Z" }, ] +[[package]] +name = "omegaconf" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "antlr4-python3-runtime" }, + { name = "pyyaml" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/48/6388f1bb9da707110532cb70ec4d2822858ddfb44f1cdf1233c20a80ea4b/omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7", size = 3298120, upload-time = "2022-12-08T20:59:22.753Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b", size = 79500, upload-time = "2022-12-08T20:59:19.686Z" }, +] + [[package]] name = "onnxruntime" version = "1.22.0" @@ -5162,19 +5341,19 @@ wheels = [ [[package]] name = "opencv-python" -version = "4.11.0.86" +version = "4.10.0.84" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/06/68c27a523103dad5837dc5b87e71285280c4f098c60e4fe8a8db6486ab09/opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4", size = 95171956, upload-time = "2025-01-16T13:52:24.737Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/b70a2d9ab205110d715906fc8ec83fbb00404aeb3a37a0654fdb68eb0c8c/opencv-python-4.10.0.84.tar.gz", hash = "sha256:72d234e4582e9658ffea8e9cae5b63d488ad06994ef12d81dc303b17472f3526", size = 95103981, upload-time = "2024-06-17T18:29:56.757Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/4d/53b30a2a3ac1f75f65a59eb29cf2ee7207ce64867db47036ad61743d5a23/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:432f67c223f1dc2824f5e73cdfcd9db0efc8710647d4e813012195dc9122a52a", size = 37326322, upload-time = "2025-01-16T13:52:25.887Z" }, - { url = "https://files.pythonhosted.org/packages/3b/84/0a67490741867eacdfa37bc18df96e08a9d579583b419010d7f3da8ff503/opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:9d05ef13d23fe97f575153558653e2d6e87103995d54e6a35db3f282fe1f9c66", size = 56723197, upload-time = "2025-01-16T13:55:21.222Z" }, - { url = "https://files.pythonhosted.org/packages/f3/bd/29c126788da65c1fb2b5fb621b7fed0ed5f9122aa22a0868c5e2c15c6d23/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b92ae2c8852208817e6776ba1ea0d6b1e0a1b5431e971a2a0ddd2a8cc398202", size = 42230439, upload-time = "2025-01-16T13:51:35.822Z" }, - { url = "https://files.pythonhosted.org/packages/2c/8b/90eb44a40476fa0e71e05a0283947cfd74a5d36121a11d926ad6f3193cc4/opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b02611523803495003bd87362db3e1d2a0454a6a63025dc6658a9830570aa0d", size = 62986597, upload-time = "2025-01-16T13:52:08.836Z" }, - { url = "https://files.pythonhosted.org/packages/fb/d7/1d5941a9dde095468b288d989ff6539dd69cd429dbf1b9e839013d21b6f0/opencv_python-4.11.0.86-cp37-abi3-win32.whl", hash = "sha256:810549cb2a4aedaa84ad9a1c92fbfdfc14090e2749cedf2c1589ad8359aa169b", size = 29384337, upload-time = "2025-01-16T13:52:13.549Z" }, - { url = "https://files.pythonhosted.org/packages/a4/7d/f1c30a92854540bf789e9cd5dde7ef49bbe63f855b85a2e6b3db8135c591/opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl", hash = "sha256:085ad9b77c18853ea66283e98affefe2de8cc4c1f43eda4c100cf9b2721142ec", size = 39488044, upload-time = "2025-01-16T13:52:21.928Z" }, + { url = "https://files.pythonhosted.org/packages/66/82/564168a349148298aca281e342551404ef5521f33fba17b388ead0a84dc5/opencv_python-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:fc182f8f4cda51b45f01c64e4cbedfc2f00aff799debebc305d8d0210c43f251", size = 54835524, upload-time = "2024-06-18T04:57:32.973Z" }, + { url = "https://files.pythonhosted.org/packages/64/4a/016cda9ad7cf18c58ba074628a4eaae8aa55f3fd06a266398cef8831a5b9/opencv_python-4.10.0.84-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:71e575744f1d23f79741450254660442785f45a0797212852ee5199ef12eed98", size = 56475426, upload-time = "2024-06-17T19:34:10.927Z" }, + { url = "https://files.pythonhosted.org/packages/81/e4/7a987ebecfe5ceaf32db413b67ff18eb3092c598408862fff4d7cc3fd19b/opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09a332b50488e2dda866a6c5573ee192fe3583239fb26ff2f7f9ceb0bc119ea6", size = 41746971, upload-time = "2024-06-17T20:00:25.211Z" }, + { url = "https://files.pythonhosted.org/packages/3f/a4/d2537f47fd7fcfba966bd806e3ec18e7ee1681056d4b0a9c8d983983e4d5/opencv_python-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ace140fc6d647fbe1c692bcb2abce768973491222c067c131d80957c595b71f", size = 62548253, upload-time = "2024-06-17T18:29:43.659Z" }, + { url = "https://files.pythonhosted.org/packages/1e/39/bbf57e7b9dab623e8773f6ff36385456b7ae7fa9357a5e53db732c347eac/opencv_python-4.10.0.84-cp37-abi3-win32.whl", hash = "sha256:2db02bb7e50b703f0a2d50c50ced72e95c574e1e5a0bb35a8a86d0b35c98c236", size = 28737688, upload-time = "2024-06-17T18:28:13.177Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6c/fab8113424af5049f85717e8e527ca3773299a3c6b02506e66436e19874f/opencv_python-4.10.0.84-cp37-abi3-win_amd64.whl", hash = "sha256:32dbbd94c26f611dc5cc6979e6b7aa1f55a64d6b463cc1dcd3c95505a63e48fe", size = 38842521, upload-time = "2024-06-17T18:28:21.813Z" }, ] [[package]] @@ -5496,15 +5675,15 @@ wheels = [ [[package]] name = "plotly" -version = "6.2.0" +version = "5.24.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "narwhals" }, { name = "packaging" }, + { name = "tenacity" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6e/5c/0efc297df362b88b74957a230af61cd6929f531f72f48063e8408702ffba/plotly-6.2.0.tar.gz", hash = "sha256:9dfa23c328000f16c928beb68927444c1ab9eae837d1fe648dbcda5360c7953d", size = 6801941, upload-time = "2025-06-26T16:20:45.765Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/4f/428f6d959818d7425a94c190a6b26fbc58035cbef40bf249be0b62a9aedd/plotly-5.24.1.tar.gz", hash = "sha256:dbc8ac8339d248a4bcc36e08a5659bacfe1b079390b8953533f4eb22169b4bae", size = 9479398, upload-time = "2024-09-12T15:36:31.068Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/20/f2b7ac96a91cc5f70d81320adad24cc41bf52013508d649b1481db225780/plotly-6.2.0-py3-none-any.whl", hash = "sha256:32c444d4c940887219cb80738317040363deefdfee4f354498cc0b6dab8978bd", size = 9635469, upload-time = "2025-06-26T16:20:40.76Z" }, + { url = "https://files.pythonhosted.org/packages/e5/ae/580600f441f6fc05218bd6c9d5794f4aef072a7d9093b291f1c50a9db8bc/plotly-5.24.1-py3-none-any.whl", hash = "sha256:f67073a1e637eb0dc3e46324d9d51e2fe76e9727c892dde64ddf1e1b51f29089", size = 19054220, upload-time = "2024-09-12T15:36:24.08Z" }, ] [[package]] @@ -5793,6 +5972,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/50/d13ea0a054189ae1bc21af1d85b6f8bb9bbc5572991055d70ad9006fe2d6/psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142", size = 2569224, upload-time = "2025-01-04T20:09:19.234Z" }, ] +[[package]] +name = "ptlflow" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "h5py" }, + { name = "kaleido" }, + { name = "lightning", extra = ["pytorch-extra"] }, + { name = "loguru" }, + { name = "opencv-python" }, + { name = "pandas" }, + { name = "plotly" }, + { name = "pypng" }, + { name = "scipy" }, + { name = "tabulate" }, + { name = "tensorboard" }, + { name = "timm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/b4/6c5d5276ccfa82883bcb9ceba81ce9a242be4eba6e1391a8e92292d47959/ptlflow-0.4.1.tar.gz", hash = "sha256:02e0bbc6ca55e939150f68ab1598e18ca3bd0c467baf4e3226dfdaa8763f3afe", size = 781083, upload-time = "2025-03-20T08:40:43.472Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/f6/d92d24e5cc10862ac7b908324b44173f86736cc8d902dd50f21efabe5435/ptlflow-0.4.1-py3-none-any.whl", hash = "sha256:8dab7c5fcece03b6da4174bbcb4a8b42666a315cc9feb7338d5f74b34e753b94", size = 974476, upload-time = "2025-03-20T08:40:41.901Z" }, +] + [[package]] name = "py-cpuinfo" version = "9.0.0" @@ -5890,6 +6093,7 @@ all = [ { name = "opencc" }, { name = "opencv-python" }, { name = "pre-commit" }, + { name = "ptlflow" }, { name = "pyspark" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -5911,6 +6115,7 @@ all = [ { name = "sphinx-autobuild" }, { name = "sphinx-copybutton" }, { name = "tiktoken" }, + { name = "timm" }, { name = "toml" }, { name = "torch" }, { name = "torchaudio" }, @@ -5988,10 +6193,12 @@ vision = [ { name = "diffusers" }, { name = "imagededup" }, { name = "opencv-python" }, + { name = "ptlflow" }, { name = "qwen-vl-utils" }, { name = "rembg" }, { name = "scenedetect", extra = ["opencv"] }, { name = "simple-aesthetics-predictor" }, + { name = "timm" }, { name = "ultralytics" }, ] @@ -6087,6 +6294,8 @@ requires-dist = [ { name = "pre-commit", marker = "extra == 'all'" }, { name = "pre-commit", marker = "extra == 'dev'" }, { name = "psutil" }, + { name = "ptlflow", marker = "extra == 'all'", specifier = "==0.4.1" }, + { name = "ptlflow", marker = "extra == 'vision'", specifier = "==0.4.1" }, { name = "pydantic", specifier = ">=2.0" }, { name = "pylance" }, { name = "pyspark", marker = "extra == 'all'", specifier = "==3.5.5" }, @@ -6138,6 +6347,8 @@ requires-dist = [ { name = "tabulate" }, { name = "tiktoken", marker = "extra == 'all'" }, { name = "tiktoken", marker = "extra == 'nlp'" }, + { name = "timm", marker = "extra == 'all'", specifier = "==1.0.22" }, + { name = "timm", marker = "extra == 'vision'", specifier = "==1.0.22" }, { name = "toml", marker = "extra == 'all'" }, { name = "toml", marker = "extra == 'dev'" }, { name = "tomli" }, @@ -6727,8 +6938,7 @@ dependencies = [ { name = "numba" }, { name = "numpy" }, { name = "pillow" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/35/43/cd7a82913dfde95dfb653efd09c7b394a76b3865570050b674a36fc0078c/pymatting-1.1.14.tar.gz", hash = "sha256:75e2ec1e346dbd564c9a2cc8229b134ec939f49008fa570025db30003d0c46fc", size = 44165, upload-time = "2025-05-16T20:21:28.27Z" } wheels = [ @@ -6779,6 +6989,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/6b/2706497c86e8d69fb76afe5ea857fe1794621aa0f3b1d863feb953fe0f22/pypdfium2-4.30.1-py3-none-win_arm64.whl", hash = "sha256:c2b6d63f6d425d9416c08d2511822b54b8e3ac38e639fc41164b1d75584b3a8c", size = 2814810, upload-time = "2024-12-19T19:28:09.857Z" }, ] +[[package]] +name = "pypng" +version = "0.20220715.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/93/cd/112f092ec27cca83e0516de0a3368dbd9128c187fb6b52aaaa7cde39c96d/pypng-0.20220715.0.tar.gz", hash = "sha256:739c433ba96f078315de54c0db975aee537cbc3e1d0ae4ed9aab0ca1e427e2c1", size = 128992, upload-time = "2022-07-15T14:11:05.301Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3e/b9/3766cc361d93edb2ce81e2e1f87dd98f314d7d513877a342d31b30741680/pypng-0.20220715.0-py3-none-any.whl", hash = "sha256:4a43e969b8f5aaafb2a415536c1a8ec7e341cd6a3f957fd5b5f32a4cfeed902c", size = 58057, upload-time = "2022-07-15T14:11:03.713Z" }, +] + [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -7021,6 +7240,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/00/4409ffd333045bd533556babe4012048bcc09e521409fc2b894670bebee7/python_stretch-0.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:6a3d19b5f067a35cf9bfd4e3270ba281163b1608da94f7e813fa50e225a621bf", size = 96478, upload-time = "2025-02-14T14:36:51.13Z" }, ] +[[package]] +name = "pytorch-lightning" +version = "2.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fsspec", extra = ["http"] }, + { name = "lightning-utilities" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "torch" }, + { name = "torchmetrics" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/d7/e3963d9669758f93b07941f4e2e82a394eb3d0980e29baa4764f3bad6689/pytorch_lightning-2.6.0.tar.gz", hash = "sha256:25b0d4f05e1f33b72be0920c34d0465777fe5f623228f9d6252b4b0f685d7037", size = 658853, upload-time = "2025-11-28T09:34:13.098Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/eb/cc6dbfe70d15318dbce82674b1e8057cef2634ca9f9121a16b8a06c630db/pytorch_lightning-2.6.0-py3-none-any.whl", hash = "sha256:ee72cff4b8c983ecfaae8599382544bd5236d9eb300adc7dd305f359195f4e79", size = 849476, upload-time = "2025-11-28T09:34:11.271Z" }, +] + [[package]] name = "pytz" version = "2022.7.1" @@ -7400,8 +7638,7 @@ dependencies = [ { name = "pooch" }, { name = "pymatting" }, { name = "scikit-image" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "tqdm" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ec/c6/6c6f1f068d80c39856d25752003185afe4da0254ef9372f35e6bc8ae299e/rembg-2.0.66.tar.gz", hash = "sha256:851068ab4ccc63a6abb49b3af73fdc667b71e35b6658633d13fd0748558fbd99", size = 49842, upload-time = "2025-05-14T20:25:32.754Z" } @@ -7844,8 +8081,7 @@ dependencies = [ { name = "numpy" }, { name = "packaging" }, { name = "pillow" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "tifffile", version = "2025.5.10", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "tifffile", version = "2025.6.11", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] @@ -7881,8 +8117,7 @@ source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "joblib" }, { name = "numpy" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "threadpoolctl" }, ] sdist = { url = "https://files.pythonhosted.org/packages/df/3b/29fa87e76b1d7b3b77cc1fcbe82e6e6b8cd704410705b008822de530277c/scikit_learn-1.7.0.tar.gz", hash = "sha256:c01e869b15aec88e2cdb73d27f15bdbe03bce8e2fb43afbe77c45d399e73a5a3", size = 7178217, upload-time = "2025-06-05T22:02:46.703Z" } @@ -7915,169 +8150,45 @@ wheels = [ [[package]] name = "scipy" -version = "1.15.3" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'win32'", - "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'win32'", - "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'win32'", - "python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux'", - "python_full_version < '3.11' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", -] -dependencies = [ - { name = "numpy", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0f/37/6964b830433e654ec7485e45a00fc9a27cf868d622838f6b6d9c5ec0d532/scipy-1.15.3.tar.gz", hash = "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf", size = 59419214, upload-time = "2025-05-08T16:13:05.955Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/78/2f/4966032c5f8cc7e6a60f1b2e0ad686293b9474b65246b0c642e3ef3badd0/scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c", size = 38702770, upload-time = "2025-05-08T16:04:20.849Z" }, - { url = "https://files.pythonhosted.org/packages/a0/6e/0c3bf90fae0e910c274db43304ebe25a6b391327f3f10b5dcc638c090795/scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253", size = 30094511, upload-time = "2025-05-08T16:04:27.103Z" }, - { url = "https://files.pythonhosted.org/packages/ea/b1/4deb37252311c1acff7f101f6453f0440794f51b6eacb1aad4459a134081/scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f", size = 22368151, upload-time = "2025-05-08T16:04:31.731Z" }, - { url = "https://files.pythonhosted.org/packages/38/7d/f457626e3cd3c29b3a49ca115a304cebb8cc6f31b04678f03b216899d3c6/scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92", size = 25121732, upload-time = "2025-05-08T16:04:36.596Z" }, - { url = "https://files.pythonhosted.org/packages/db/0a/92b1de4a7adc7a15dcf5bddc6e191f6f29ee663b30511ce20467ef9b82e4/scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82", size = 35547617, upload-time = "2025-05-08T16:04:43.546Z" }, - { url = "https://files.pythonhosted.org/packages/8e/6d/41991e503e51fc1134502694c5fa7a1671501a17ffa12716a4a9151af3df/scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40", size = 37662964, upload-time = "2025-05-08T16:04:49.431Z" }, - { url = "https://files.pythonhosted.org/packages/25/e1/3df8f83cb15f3500478c889be8fb18700813b95e9e087328230b98d547ff/scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e", size = 37238749, upload-time = "2025-05-08T16:04:55.215Z" }, - { url = "https://files.pythonhosted.org/packages/93/3e/b3257cf446f2a3533ed7809757039016b74cd6f38271de91682aa844cfc5/scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c", size = 40022383, upload-time = "2025-05-08T16:05:01.914Z" }, - { url = "https://files.pythonhosted.org/packages/d1/84/55bc4881973d3f79b479a5a2e2df61c8c9a04fcb986a213ac9c02cfb659b/scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13", size = 41259201, upload-time = "2025-05-08T16:05:08.166Z" }, - { url = "https://files.pythonhosted.org/packages/96/ab/5cc9f80f28f6a7dff646c5756e559823614a42b1939d86dd0ed550470210/scipy-1.15.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:993439ce220d25e3696d1b23b233dd010169b62f6456488567e830654ee37a6b", size = 38714255, upload-time = "2025-05-08T16:05:14.596Z" }, - { url = "https://files.pythonhosted.org/packages/4a/4a/66ba30abe5ad1a3ad15bfb0b59d22174012e8056ff448cb1644deccbfed2/scipy-1.15.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34716e281f181a02341ddeaad584205bd2fd3c242063bd3423d61ac259ca7eba", size = 30111035, upload-time = "2025-05-08T16:05:20.152Z" }, - { url = "https://files.pythonhosted.org/packages/4b/fa/a7e5b95afd80d24313307f03624acc65801846fa75599034f8ceb9e2cbf6/scipy-1.15.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b0334816afb8b91dab859281b1b9786934392aa3d527cd847e41bb6f45bee65", size = 22384499, upload-time = "2025-05-08T16:05:24.494Z" }, - { url = "https://files.pythonhosted.org/packages/17/99/f3aaddccf3588bb4aea70ba35328c204cadd89517a1612ecfda5b2dd9d7a/scipy-1.15.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6db907c7368e3092e24919b5e31c76998b0ce1684d51a90943cb0ed1b4ffd6c1", size = 25152602, upload-time = "2025-05-08T16:05:29.313Z" }, - { url = "https://files.pythonhosted.org/packages/56/c5/1032cdb565f146109212153339f9cb8b993701e9fe56b1c97699eee12586/scipy-1.15.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:721d6b4ef5dc82ca8968c25b111e307083d7ca9091bc38163fb89243e85e3889", size = 35503415, upload-time = "2025-05-08T16:05:34.699Z" }, - { url = "https://files.pythonhosted.org/packages/bd/37/89f19c8c05505d0601ed5650156e50eb881ae3918786c8fd7262b4ee66d3/scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39cb9c62e471b1bb3750066ecc3a3f3052b37751c7c3dfd0fd7e48900ed52982", size = 37652622, upload-time = "2025-05-08T16:05:40.762Z" }, - { url = "https://files.pythonhosted.org/packages/7e/31/be59513aa9695519b18e1851bb9e487de66f2d31f835201f1b42f5d4d475/scipy-1.15.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:795c46999bae845966368a3c013e0e00947932d68e235702b5c3f6ea799aa8c9", size = 37244796, upload-time = "2025-05-08T16:05:48.119Z" }, - { url = "https://files.pythonhosted.org/packages/10/c0/4f5f3eeccc235632aab79b27a74a9130c6c35df358129f7ac8b29f562ac7/scipy-1.15.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:18aaacb735ab38b38db42cb01f6b92a2d0d4b6aabefeb07f02849e47f8fb3594", size = 40047684, upload-time = "2025-05-08T16:05:54.22Z" }, - { url = "https://files.pythonhosted.org/packages/ab/a7/0ddaf514ce8a8714f6ed243a2b391b41dbb65251affe21ee3077ec45ea9a/scipy-1.15.3-cp311-cp311-win_amd64.whl", hash = "sha256:ae48a786a28412d744c62fd7816a4118ef97e5be0bee968ce8f0a2fba7acf3bb", size = 41246504, upload-time = "2025-05-08T16:06:00.437Z" }, - { url = "https://files.pythonhosted.org/packages/37/4b/683aa044c4162e10ed7a7ea30527f2cbd92e6999c10a8ed8edb253836e9c/scipy-1.15.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac6310fdbfb7aa6612408bd2f07295bcbd3fda00d2d702178434751fe48e019", size = 38766735, upload-time = "2025-05-08T16:06:06.471Z" }, - { url = "https://files.pythonhosted.org/packages/7b/7e/f30be3d03de07f25dc0ec926d1681fed5c732d759ac8f51079708c79e680/scipy-1.15.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:185cd3d6d05ca4b44a8f1595af87f9c372bb6acf9c808e99aa3e9aa03bd98cf6", size = 30173284, upload-time = "2025-05-08T16:06:11.686Z" }, - { url = "https://files.pythonhosted.org/packages/07/9c/0ddb0d0abdabe0d181c1793db51f02cd59e4901da6f9f7848e1f96759f0d/scipy-1.15.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:05dc6abcd105e1a29f95eada46d4a3f251743cfd7d3ae8ddb4088047f24ea477", size = 22446958, upload-time = "2025-05-08T16:06:15.97Z" }, - { url = "https://files.pythonhosted.org/packages/af/43/0bce905a965f36c58ff80d8bea33f1f9351b05fad4beaad4eae34699b7a1/scipy-1.15.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:06efcba926324df1696931a57a176c80848ccd67ce6ad020c810736bfd58eb1c", size = 25242454, upload-time = "2025-05-08T16:06:20.394Z" }, - { url = "https://files.pythonhosted.org/packages/56/30/a6f08f84ee5b7b28b4c597aca4cbe545535c39fe911845a96414700b64ba/scipy-1.15.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05045d8b9bfd807ee1b9f38761993297b10b245f012b11b13b91ba8945f7e45", size = 35210199, upload-time = "2025-05-08T16:06:26.159Z" }, - { url = "https://files.pythonhosted.org/packages/0b/1f/03f52c282437a168ee2c7c14a1a0d0781a9a4a8962d84ac05c06b4c5b555/scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271e3713e645149ea5ea3e97b57fdab61ce61333f97cfae392c28ba786f9bb49", size = 37309455, upload-time = "2025-05-08T16:06:32.778Z" }, - { url = "https://files.pythonhosted.org/packages/89/b1/fbb53137f42c4bf630b1ffdfc2151a62d1d1b903b249f030d2b1c0280af8/scipy-1.15.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6cfd56fc1a8e53f6e89ba3a7a7251f7396412d655bca2aa5611c8ec9a6784a1e", size = 36885140, upload-time = "2025-05-08T16:06:39.249Z" }, - { url = "https://files.pythonhosted.org/packages/2e/2e/025e39e339f5090df1ff266d021892694dbb7e63568edcfe43f892fa381d/scipy-1.15.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0ff17c0bb1cb32952c09217d8d1eed9b53d1463e5f1dd6052c7857f83127d539", size = 39710549, upload-time = "2025-05-08T16:06:45.729Z" }, - { url = "https://files.pythonhosted.org/packages/e6/eb/3bf6ea8ab7f1503dca3a10df2e4b9c3f6b3316df07f6c0ded94b281c7101/scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed", size = 40966184, upload-time = "2025-05-08T16:06:52.623Z" }, - { url = "https://files.pythonhosted.org/packages/73/18/ec27848c9baae6e0d6573eda6e01a602e5649ee72c27c3a8aad673ebecfd/scipy-1.15.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2c620736bcc334782e24d173c0fdbb7590a0a436d2fdf39310a8902505008759", size = 38728256, upload-time = "2025-05-08T16:06:58.696Z" }, - { url = "https://files.pythonhosted.org/packages/74/cd/1aef2184948728b4b6e21267d53b3339762c285a46a274ebb7863c9e4742/scipy-1.15.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:7e11270a000969409d37ed399585ee530b9ef6aa99d50c019de4cb01e8e54e62", size = 30109540, upload-time = "2025-05-08T16:07:04.209Z" }, - { url = "https://files.pythonhosted.org/packages/5b/d8/59e452c0a255ec352bd0a833537a3bc1bfb679944c4938ab375b0a6b3a3e/scipy-1.15.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8c9ed3ba2c8a2ce098163a9bdb26f891746d02136995df25227a20e71c396ebb", size = 22383115, upload-time = "2025-05-08T16:07:08.998Z" }, - { url = "https://files.pythonhosted.org/packages/08/f5/456f56bbbfccf696263b47095291040655e3cbaf05d063bdc7c7517f32ac/scipy-1.15.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0bdd905264c0c9cfa74a4772cdb2070171790381a5c4d312c973382fc6eaf730", size = 25163884, upload-time = "2025-05-08T16:07:14.091Z" }, - { url = "https://files.pythonhosted.org/packages/a2/66/a9618b6a435a0f0c0b8a6d0a2efb32d4ec5a85f023c2b79d39512040355b/scipy-1.15.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79167bba085c31f38603e11a267d862957cbb3ce018d8b38f79ac043bc92d825", size = 35174018, upload-time = "2025-05-08T16:07:19.427Z" }, - { url = "https://files.pythonhosted.org/packages/b5/09/c5b6734a50ad4882432b6bb7c02baf757f5b2f256041da5df242e2d7e6b6/scipy-1.15.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9deabd6d547aee2c9a81dee6cc96c6d7e9a9b1953f74850c179f91fdc729cb7", size = 37269716, upload-time = "2025-05-08T16:07:25.712Z" }, - { url = "https://files.pythonhosted.org/packages/77/0a/eac00ff741f23bcabd352731ed9b8995a0a60ef57f5fd788d611d43d69a1/scipy-1.15.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:dde4fc32993071ac0c7dd2d82569e544f0bdaff66269cb475e0f369adad13f11", size = 36872342, upload-time = "2025-05-08T16:07:31.468Z" }, - { url = "https://files.pythonhosted.org/packages/fe/54/4379be86dd74b6ad81551689107360d9a3e18f24d20767a2d5b9253a3f0a/scipy-1.15.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f77f853d584e72e874d87357ad70f44b437331507d1c311457bed8ed2b956126", size = 39670869, upload-time = "2025-05-08T16:07:38.002Z" }, - { url = "https://files.pythonhosted.org/packages/87/2e/892ad2862ba54f084ffe8cc4a22667eaf9c2bcec6d2bff1d15713c6c0703/scipy-1.15.3-cp313-cp313-win_amd64.whl", hash = "sha256:b90ab29d0c37ec9bf55424c064312930ca5f4bde15ee8619ee44e69319aab163", size = 40988851, upload-time = "2025-05-08T16:08:33.671Z" }, - { url = "https://files.pythonhosted.org/packages/1b/e9/7a879c137f7e55b30d75d90ce3eb468197646bc7b443ac036ae3fe109055/scipy-1.15.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3ac07623267feb3ae308487c260ac684b32ea35fd81e12845039952f558047b8", size = 38863011, upload-time = "2025-05-08T16:07:44.039Z" }, - { url = "https://files.pythonhosted.org/packages/51/d1/226a806bbd69f62ce5ef5f3ffadc35286e9fbc802f606a07eb83bf2359de/scipy-1.15.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6487aa99c2a3d509a5227d9a5e889ff05830a06b2ce08ec30df6d79db5fcd5c5", size = 30266407, upload-time = "2025-05-08T16:07:49.891Z" }, - { url = "https://files.pythonhosted.org/packages/e5/9b/f32d1d6093ab9eeabbd839b0f7619c62e46cc4b7b6dbf05b6e615bbd4400/scipy-1.15.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:50f9e62461c95d933d5c5ef4a1f2ebf9a2b4e83b0db374cb3f1de104d935922e", size = 22540030, upload-time = "2025-05-08T16:07:54.121Z" }, - { url = "https://files.pythonhosted.org/packages/e7/29/c278f699b095c1a884f29fda126340fcc201461ee8bfea5c8bdb1c7c958b/scipy-1.15.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:14ed70039d182f411ffc74789a16df3835e05dc469b898233a245cdfd7f162cb", size = 25218709, upload-time = "2025-05-08T16:07:58.506Z" }, - { url = "https://files.pythonhosted.org/packages/24/18/9e5374b617aba742a990581373cd6b68a2945d65cc588482749ef2e64467/scipy-1.15.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a769105537aa07a69468a0eefcd121be52006db61cdd8cac8a0e68980bbb723", size = 34809045, upload-time = "2025-05-08T16:08:03.929Z" }, - { url = "https://files.pythonhosted.org/packages/e1/fe/9c4361e7ba2927074360856db6135ef4904d505e9b3afbbcb073c4008328/scipy-1.15.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9db984639887e3dffb3928d118145ffe40eff2fa40cb241a306ec57c219ebbbb", size = 36703062, upload-time = "2025-05-08T16:08:09.558Z" }, - { url = "https://files.pythonhosted.org/packages/b7/8e/038ccfe29d272b30086b25a4960f757f97122cb2ec42e62b460d02fe98e9/scipy-1.15.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:40e54d5c7e7ebf1aa596c374c49fa3135f04648a0caabcb66c52884b943f02b4", size = 36393132, upload-time = "2025-05-08T16:08:15.34Z" }, - { url = "https://files.pythonhosted.org/packages/10/7e/5c12285452970be5bdbe8352c619250b97ebf7917d7a9a9e96b8a8140f17/scipy-1.15.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5e721fed53187e71d0ccf382b6bf977644c533e506c4d33c3fb24de89f5c3ed5", size = 38979503, upload-time = "2025-05-08T16:08:21.513Z" }, - { url = "https://files.pythonhosted.org/packages/81/06/0a5e5349474e1cbc5757975b21bd4fad0e72ebf138c5592f191646154e06/scipy-1.15.3-cp313-cp313t-win_amd64.whl", hash = "sha256:76ad1fb5f8752eabf0fa02e4cc0336b4e8f021e2d5f061ed37d6d264db35e3ca", size = 40308097, upload-time = "2025-05-08T16:08:27.627Z" }, -] - -[[package]] -name = "scipy" -version = "1.16.0" +version = "1.14.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '4' and platform_machine == 'aarch64' and sys_platform == 'darwin'", - "python_full_version >= '4' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '4' and platform_machine == 'aarch64' and sys_platform == 'win32'", - "python_full_version >= '4' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'aarch64' and sys_platform == 'darwin'", - "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'aarch64' and sys_platform == 'win32'", - "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '4' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '4' and platform_machine == 'x86_64' and sys_platform == 'win32'", - "python_full_version >= '4' and platform_machine == 'x86_64' and sys_platform == 'linux'", - "python_full_version >= '4' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform == 'win32'", - "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform == 'linux'", - "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'win32'", - "python_full_version >= '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux'", - "python_full_version >= '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'win32'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'win32'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and python_full_version < '4' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.12.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'darwin'", - "python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'win32'", - "python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux'", - "python_full_version == '3.11.*' and platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", -] dependencies = [ - { name = "numpy", marker = "python_full_version >= '3.11'" }, + { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/81/18/b06a83f0c5ee8cddbde5e3f3d0bb9b702abfa5136ef6d4620ff67df7eee5/scipy-1.16.0.tar.gz", hash = "sha256:b5ef54021e832869c8cfb03bc3bf20366cbcd426e02a58e8a58d7584dfbb8f62", size = 30581216, upload-time = "2025-06-22T16:27:55.782Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d9/f8/53fc4884df6b88afd5f5f00240bdc49fee2999c7eff3acf5953eb15bc6f8/scipy-1.16.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:deec06d831b8f6b5fb0b652433be6a09db29e996368ce5911faf673e78d20085", size = 36447362, upload-time = "2025-06-22T16:18:17.817Z" }, - { url = "https://files.pythonhosted.org/packages/c9/25/fad8aa228fa828705142a275fc593d701b1817c98361a2d6b526167d07bc/scipy-1.16.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d30c0fe579bb901c61ab4bb7f3eeb7281f0d4c4a7b52dbf563c89da4fd2949be", size = 28547120, upload-time = "2025-06-22T16:18:24.117Z" }, - { url = "https://files.pythonhosted.org/packages/8d/be/d324ddf6b89fd1c32fecc307f04d095ce84abb52d2e88fab29d0cd8dc7a8/scipy-1.16.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:b2243561b45257f7391d0f49972fca90d46b79b8dbcb9b2cb0f9df928d370ad4", size = 20818922, upload-time = "2025-06-22T16:18:28.035Z" }, - { url = "https://files.pythonhosted.org/packages/cd/e0/cf3f39e399ac83fd0f3ba81ccc5438baba7cfe02176be0da55ff3396f126/scipy-1.16.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:e6d7dfc148135e9712d87c5f7e4f2ddc1304d1582cb3a7d698bbadedb61c7afd", size = 23409695, upload-time = "2025-06-22T16:18:32.497Z" }, - { url = "https://files.pythonhosted.org/packages/5b/61/d92714489c511d3ffd6830ac0eb7f74f243679119eed8b9048e56b9525a1/scipy-1.16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:90452f6a9f3fe5a2cf3748e7be14f9cc7d9b124dce19667b54f5b429d680d539", size = 33444586, upload-time = "2025-06-22T16:18:37.992Z" }, - { url = "https://files.pythonhosted.org/packages/af/2c/40108915fd340c830aee332bb85a9160f99e90893e58008b659b9f3dddc0/scipy-1.16.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a2f0bf2f58031c8701a8b601df41701d2a7be17c7ffac0a4816aeba89c4cdac8", size = 35284126, upload-time = "2025-06-22T16:18:43.605Z" }, - { url = "https://files.pythonhosted.org/packages/d3/30/e9eb0ad3d0858df35d6c703cba0a7e16a18a56a9e6b211d861fc6f261c5f/scipy-1.16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6c4abb4c11fc0b857474241b812ce69ffa6464b4bd8f4ecb786cf240367a36a7", size = 35608257, upload-time = "2025-06-22T16:18:49.09Z" }, - { url = "https://files.pythonhosted.org/packages/c8/ff/950ee3e0d612b375110d8cda211c1f787764b4c75e418a4b71f4a5b1e07f/scipy-1.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b370f8f6ac6ef99815b0d5c9f02e7ade77b33007d74802efc8316c8db98fd11e", size = 38040541, upload-time = "2025-06-22T16:18:55.077Z" }, - { url = "https://files.pythonhosted.org/packages/8b/c9/750d34788288d64ffbc94fdb4562f40f609d3f5ef27ab4f3a4ad00c9033e/scipy-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:a16ba90847249bedce8aa404a83fb8334b825ec4a8e742ce6012a7a5e639f95c", size = 38570814, upload-time = "2025-06-22T16:19:00.912Z" }, - { url = "https://files.pythonhosted.org/packages/01/c0/c943bc8d2bbd28123ad0f4f1eef62525fa1723e84d136b32965dcb6bad3a/scipy-1.16.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7eb6bd33cef4afb9fa5f1fb25df8feeb1e52d94f21a44f1d17805b41b1da3180", size = 36459071, upload-time = "2025-06-22T16:19:06.605Z" }, - { url = "https://files.pythonhosted.org/packages/99/0d/270e2e9f1a4db6ffbf84c9a0b648499842046e4e0d9b2275d150711b3aba/scipy-1.16.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:1dbc8fdba23e4d80394ddfab7a56808e3e6489176d559c6c71935b11a2d59db1", size = 28490500, upload-time = "2025-06-22T16:19:11.775Z" }, - { url = "https://files.pythonhosted.org/packages/1c/22/01d7ddb07cff937d4326198ec8d10831367a708c3da72dfd9b7ceaf13028/scipy-1.16.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:7dcf42c380e1e3737b343dec21095c9a9ad3f9cbe06f9c05830b44b1786c9e90", size = 20762345, upload-time = "2025-06-22T16:19:15.813Z" }, - { url = "https://files.pythonhosted.org/packages/34/7f/87fd69856569ccdd2a5873fe5d7b5bbf2ad9289d7311d6a3605ebde3a94b/scipy-1.16.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:26ec28675f4a9d41587266084c626b02899db373717d9312fa96ab17ca1ae94d", size = 23418563, upload-time = "2025-06-22T16:19:20.746Z" }, - { url = "https://files.pythonhosted.org/packages/f6/f1/e4f4324fef7f54160ab749efbab6a4bf43678a9eb2e9817ed71a0a2fd8de/scipy-1.16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:952358b7e58bd3197cfbd2f2f2ba829f258404bdf5db59514b515a8fe7a36c52", size = 33203951, upload-time = "2025-06-22T16:19:25.813Z" }, - { url = "https://files.pythonhosted.org/packages/6d/f0/b6ac354a956384fd8abee2debbb624648125b298f2c4a7b4f0d6248048a5/scipy-1.16.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:03931b4e870c6fef5b5c0970d52c9f6ddd8c8d3e934a98f09308377eba6f3824", size = 35070225, upload-time = "2025-06-22T16:19:31.416Z" }, - { url = "https://files.pythonhosted.org/packages/e5/73/5cbe4a3fd4bc3e2d67ffad02c88b83edc88f381b73ab982f48f3df1a7790/scipy-1.16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:512c4f4f85912767c351a0306824ccca6fd91307a9f4318efe8fdbd9d30562ef", size = 35389070, upload-time = "2025-06-22T16:19:37.387Z" }, - { url = "https://files.pythonhosted.org/packages/86/e8/a60da80ab9ed68b31ea5a9c6dfd3c2f199347429f229bf7f939a90d96383/scipy-1.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e69f798847e9add03d512eaf5081a9a5c9a98757d12e52e6186ed9681247a1ac", size = 37825287, upload-time = "2025-06-22T16:19:43.375Z" }, - { url = "https://files.pythonhosted.org/packages/ea/b5/29fece1a74c6a94247f8a6fb93f5b28b533338e9c34fdcc9cfe7a939a767/scipy-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:adf9b1999323ba335adc5d1dc7add4781cb5a4b0ef1e98b79768c05c796c4e49", size = 38431929, upload-time = "2025-06-22T16:19:49.385Z" }, - { url = "https://files.pythonhosted.org/packages/46/95/0746417bc24be0c2a7b7563946d61f670a3b491b76adede420e9d173841f/scipy-1.16.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:e9f414cbe9ca289a73e0cc92e33a6a791469b6619c240aa32ee18abdce8ab451", size = 36418162, upload-time = "2025-06-22T16:19:56.3Z" }, - { url = "https://files.pythonhosted.org/packages/19/5a/914355a74481b8e4bbccf67259bbde171348a3f160b67b4945fbc5f5c1e5/scipy-1.16.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:bbba55fb97ba3cdef9b1ee973f06b09d518c0c7c66a009c729c7d1592be1935e", size = 28465985, upload-time = "2025-06-22T16:20:01.238Z" }, - { url = "https://files.pythonhosted.org/packages/58/46/63477fc1246063855969cbefdcee8c648ba4b17f67370bd542ba56368d0b/scipy-1.16.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:58e0d4354eacb6004e7aa1cd350e5514bd0270acaa8d5b36c0627bb3bb486974", size = 20737961, upload-time = "2025-06-22T16:20:05.913Z" }, - { url = "https://files.pythonhosted.org/packages/93/86/0fbb5588b73555e40f9d3d6dde24ee6fac7d8e301a27f6f0cab9d8f66ff2/scipy-1.16.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:75b2094ec975c80efc273567436e16bb794660509c12c6a31eb5c195cbf4b6dc", size = 23377941, upload-time = "2025-06-22T16:20:10.668Z" }, - { url = "https://files.pythonhosted.org/packages/ca/80/a561f2bf4c2da89fa631b3cbf31d120e21ea95db71fd9ec00cb0247c7a93/scipy-1.16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b65d232157a380fdd11a560e7e21cde34fdb69d65c09cb87f6cc024ee376351", size = 33196703, upload-time = "2025-06-22T16:20:16.097Z" }, - { url = "https://files.pythonhosted.org/packages/11/6b/3443abcd0707d52e48eb315e33cc669a95e29fc102229919646f5a501171/scipy-1.16.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d8747f7736accd39289943f7fe53a8333be7f15a82eea08e4afe47d79568c32", size = 35083410, upload-time = "2025-06-22T16:20:21.734Z" }, - { url = "https://files.pythonhosted.org/packages/20/ab/eb0fc00e1e48961f1bd69b7ad7e7266896fe5bad4ead91b5fc6b3561bba4/scipy-1.16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eb9f147a1b8529bb7fec2a85cf4cf42bdfadf9e83535c309a11fdae598c88e8b", size = 35387829, upload-time = "2025-06-22T16:20:27.548Z" }, - { url = "https://files.pythonhosted.org/packages/57/9e/d6fc64e41fad5d481c029ee5a49eefc17f0b8071d636a02ceee44d4a0de2/scipy-1.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d2b83c37edbfa837a8923d19c749c1935ad3d41cf196006a24ed44dba2ec4358", size = 37841356, upload-time = "2025-06-22T16:20:35.112Z" }, - { url = "https://files.pythonhosted.org/packages/7c/a7/4c94bbe91f12126b8bf6709b2471900577b7373a4fd1f431f28ba6f81115/scipy-1.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:79a3c13d43c95aa80b87328a46031cf52508cf5f4df2767602c984ed1d3c6bbe", size = 38403710, upload-time = "2025-06-22T16:21:54.473Z" }, - { url = "https://files.pythonhosted.org/packages/47/20/965da8497f6226e8fa90ad3447b82ed0e28d942532e92dd8b91b43f100d4/scipy-1.16.0-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:f91b87e1689f0370690e8470916fe1b2308e5b2061317ff76977c8f836452a47", size = 36813833, upload-time = "2025-06-22T16:20:43.925Z" }, - { url = "https://files.pythonhosted.org/packages/28/f4/197580c3dac2d234e948806e164601c2df6f0078ed9f5ad4a62685b7c331/scipy-1.16.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:88a6ca658fb94640079e7a50b2ad3b67e33ef0f40e70bdb7dc22017dae73ac08", size = 28974431, upload-time = "2025-06-22T16:20:51.302Z" }, - { url = "https://files.pythonhosted.org/packages/8a/fc/e18b8550048d9224426e76906694c60028dbdb65d28b1372b5503914b89d/scipy-1.16.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:ae902626972f1bd7e4e86f58fd72322d7f4ec7b0cfc17b15d4b7006efc385176", size = 21246454, upload-time = "2025-06-22T16:20:57.276Z" }, - { url = "https://files.pythonhosted.org/packages/8c/48/07b97d167e0d6a324bfd7484cd0c209cc27338b67e5deadae578cf48e809/scipy-1.16.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:8cb824c1fc75ef29893bc32b3ddd7b11cf9ab13c1127fe26413a05953b8c32ed", size = 23772979, upload-time = "2025-06-22T16:21:03.363Z" }, - { url = "https://files.pythonhosted.org/packages/4c/4f/9efbd3f70baf9582edf271db3002b7882c875ddd37dc97f0f675ad68679f/scipy-1.16.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:de2db7250ff6514366a9709c2cba35cb6d08498e961cba20d7cff98a7ee88938", size = 33341972, upload-time = "2025-06-22T16:21:11.14Z" }, - { url = "https://files.pythonhosted.org/packages/3f/dc/9e496a3c5dbe24e76ee24525155ab7f659c20180bab058ef2c5fa7d9119c/scipy-1.16.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e85800274edf4db8dd2e4e93034f92d1b05c9421220e7ded9988b16976f849c1", size = 35185476, upload-time = "2025-06-22T16:21:19.156Z" }, - { url = "https://files.pythonhosted.org/packages/ce/b3/21001cff985a122ba434c33f2c9d7d1dc3b669827e94f4fc4e1fe8b9dfd8/scipy-1.16.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4f720300a3024c237ace1cb11f9a84c38beb19616ba7c4cdcd771047a10a1706", size = 35570990, upload-time = "2025-06-22T16:21:27.797Z" }, - { url = "https://files.pythonhosted.org/packages/e5/d3/7ba42647d6709251cdf97043d0c107e0317e152fa2f76873b656b509ff55/scipy-1.16.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:aad603e9339ddb676409b104c48a027e9916ce0d2838830691f39552b38a352e", size = 37950262, upload-time = "2025-06-22T16:21:36.976Z" }, - { url = "https://files.pythonhosted.org/packages/eb/c4/231cac7a8385394ebbbb4f1ca662203e9d8c332825ab4f36ffc3ead09a42/scipy-1.16.0-cp313-cp313t-win_amd64.whl", hash = "sha256:f56296fefca67ba605fd74d12f7bd23636267731a72cb3947963e76b8c0a25db", size = 38515076, upload-time = "2025-06-22T16:21:45.694Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/62/11/4d44a1f274e002784e4dbdb81e0ea96d2de2d1045b2132d5af62cc31fd28/scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417", size = 58620554, upload-time = "2024-08-21T00:09:20.662Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/68/3bc0cfaf64ff507d82b1e5d5b64521df4c8bf7e22bc0b897827cbee9872c/scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389", size = 39069598, upload-time = "2024-08-21T00:03:32.896Z" }, + { url = "https://files.pythonhosted.org/packages/43/a5/8d02f9c372790326ad405d94f04d4339482ec082455b9e6e288f7100513b/scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3", size = 29879676, upload-time = "2024-08-21T00:03:38.844Z" }, + { url = "https://files.pythonhosted.org/packages/07/42/0e0bea9666fcbf2cb6ea0205db42c81b1f34d7b729ba251010edf9c80ebd/scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0", size = 23088696, upload-time = "2024-08-21T00:03:43.583Z" }, + { url = "https://files.pythonhosted.org/packages/15/47/298ab6fef5ebf31b426560e978b8b8548421d4ed0bf99263e1eb44532306/scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3", size = 25470699, upload-time = "2024-08-21T00:03:48.466Z" }, + { url = "https://files.pythonhosted.org/packages/d8/df/cdb6be5274bc694c4c22862ac3438cb04f360ed9df0aecee02ce0b798380/scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d", size = 35606631, upload-time = "2024-08-21T00:03:54.532Z" }, + { url = "https://files.pythonhosted.org/packages/47/78/b0c2c23880dd1e99e938ad49ccfb011ae353758a2dc5ed7ee59baff684c3/scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69", size = 41178528, upload-time = "2024-08-21T00:04:00.862Z" }, + { url = "https://files.pythonhosted.org/packages/5d/aa/994b45c34b897637b853ec04334afa55a85650a0d11dacfa67232260fb0a/scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad", size = 42784535, upload-time = "2024-08-21T00:04:12.65Z" }, + { url = "https://files.pythonhosted.org/packages/e7/1c/8daa6df17a945cb1a2a1e3bae3c49643f7b3b94017ff01a4787064f03f84/scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5", size = 44772117, upload-time = "2024-08-21T00:04:20.613Z" }, + { url = "https://files.pythonhosted.org/packages/b2/ab/070ccfabe870d9f105b04aee1e2860520460ef7ca0213172abfe871463b9/scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675", size = 39076999, upload-time = "2024-08-21T00:04:32.61Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c5/02ac82f9bb8f70818099df7e86c3ad28dae64e1347b421d8e3adf26acab6/scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2", size = 29894570, upload-time = "2024-08-21T00:04:37.938Z" }, + { url = "https://files.pythonhosted.org/packages/ed/05/7f03e680cc5249c4f96c9e4e845acde08eb1aee5bc216eff8a089baa4ddb/scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617", size = 23103567, upload-time = "2024-08-21T00:04:42.582Z" }, + { url = "https://files.pythonhosted.org/packages/5e/fc/9f1413bef53171f379d786aabc104d4abeea48ee84c553a3e3d8c9f96a9c/scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8", size = 25499102, upload-time = "2024-08-21T00:04:47.467Z" }, + { url = "https://files.pythonhosted.org/packages/c2/4b/b44bee3c2ddc316b0159b3d87a3d467ef8d7edfd525e6f7364a62cd87d90/scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37", size = 35586346, upload-time = "2024-08-21T00:04:53.872Z" }, + { url = "https://files.pythonhosted.org/packages/93/6b/701776d4bd6bdd9b629c387b5140f006185bd8ddea16788a44434376b98f/scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2", size = 41165244, upload-time = "2024-08-21T00:05:00.489Z" }, + { url = "https://files.pythonhosted.org/packages/06/57/e6aa6f55729a8f245d8a6984f2855696c5992113a5dc789065020f8be753/scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2", size = 42817917, upload-time = "2024-08-21T00:05:07.533Z" }, + { url = "https://files.pythonhosted.org/packages/ea/c2/5ecadc5fcccefaece775feadcd795060adf5c3b29a883bff0e678cfe89af/scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94", size = 44781033, upload-time = "2024-08-21T00:05:14.297Z" }, + { url = "https://files.pythonhosted.org/packages/c0/04/2bdacc8ac6387b15db6faa40295f8bd25eccf33f1f13e68a72dc3c60a99e/scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d", size = 39128781, upload-time = "2024-08-21T04:08:04.15Z" }, + { url = "https://files.pythonhosted.org/packages/c8/53/35b4d41f5fd42f5781dbd0dd6c05d35ba8aa75c84ecddc7d44756cd8da2e/scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07", size = 29939542, upload-time = "2024-08-21T00:05:25.758Z" }, + { url = "https://files.pythonhosted.org/packages/66/67/6ef192e0e4d77b20cc33a01e743b00bc9e68fb83b88e06e636d2619a8767/scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5", size = 23148375, upload-time = "2024-08-21T00:05:30.359Z" }, + { url = "https://files.pythonhosted.org/packages/f6/32/3a6dedd51d68eb7b8e7dc7947d5d841bcb699f1bf4463639554986f4d782/scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc", size = 25578573, upload-time = "2024-08-21T00:05:35.274Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5a/efa92a58dc3a2898705f1dc9dbaf390ca7d4fba26d6ab8cfffb0c72f656f/scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310", size = 35319299, upload-time = "2024-08-21T00:05:40.956Z" }, + { url = "https://files.pythonhosted.org/packages/8e/ee/8a26858ca517e9c64f84b4c7734b89bda8e63bec85c3d2f432d225bb1886/scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066", size = 40849331, upload-time = "2024-08-21T00:05:47.53Z" }, + { url = "https://files.pythonhosted.org/packages/a5/cd/06f72bc9187840f1c99e1a8750aad4216fc7dfdd7df46e6280add14b4822/scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1", size = 42544049, upload-time = "2024-08-21T00:05:59.294Z" }, + { url = "https://files.pythonhosted.org/packages/aa/7d/43ab67228ef98c6b5dd42ab386eae2d7877036970a0d7e3dd3eb47a0d530/scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f", size = 44521212, upload-time = "2024-08-21T00:06:06.521Z" }, + { url = "https://files.pythonhosted.org/packages/50/ef/ac98346db016ff18a6ad7626a35808f37074d25796fd0234c2bb0ed1e054/scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79", size = 39091068, upload-time = "2024-08-21T00:06:13.671Z" }, + { url = "https://files.pythonhosted.org/packages/b9/cc/70948fe9f393b911b4251e96b55bbdeaa8cca41f37c26fd1df0232933b9e/scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e", size = 29875417, upload-time = "2024-08-21T00:06:21.482Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2e/35f549b7d231c1c9f9639f9ef49b815d816bf54dd050da5da1c11517a218/scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73", size = 23084508, upload-time = "2024-08-21T00:06:28.064Z" }, + { url = "https://files.pythonhosted.org/packages/3f/d6/b028e3f3e59fae61fb8c0f450db732c43dd1d836223a589a8be9f6377203/scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e", size = 25503364, upload-time = "2024-08-21T00:06:35.25Z" }, + { url = "https://files.pythonhosted.org/packages/a7/2f/6c142b352ac15967744d62b165537a965e95d557085db4beab2a11f7943b/scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d", size = 35292639, upload-time = "2024-08-21T00:06:44.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/46/2449e6e51e0d7c3575f289f6acb7f828938eaab8874dbccfeb0cd2b71a27/scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e", size = 40798288, upload-time = "2024-08-21T00:06:54.182Z" }, + { url = "https://files.pythonhosted.org/packages/32/cd/9d86f7ed7f4497c9fd3e39f8918dd93d9f647ba80d7e34e4946c0c2d1a7c/scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06", size = 42524647, upload-time = "2024-08-21T00:07:04.649Z" }, + { url = "https://files.pythonhosted.org/packages/f5/1b/6ee032251bf4cdb0cc50059374e86a9f076308c1512b61c4e003e241efb7/scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84", size = 44469524, upload-time = "2024-08-21T00:07:15.381Z" }, ] [[package]] @@ -8976,6 +9087,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, ] +[[package]] +name = "tensorboard" +version = "2.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, + { name = "setuptools", version = "80.9.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "six" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/de/021c1d407befb505791764ad2cbd56ceaaa53a746baed01d2e2143f05f18/tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab", size = 5503036, upload-time = "2024-09-25T21:21:50.169Z" }, +] + +[[package]] +name = "tensorboard-data-server" +version = "0.7.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356, upload-time = "2023-10-23T21:23:32.16Z" }, + { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598, upload-time = "2023-10-23T21:23:33.714Z" }, + { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, +] + +[[package]] +name = "tensorboardx" +version = "2.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/c5/d4cc6e293fb837aaf9f76dd7745476aeba8ef7ef5146c3b3f9ee375fe7a5/tensorboardx-2.6.4.tar.gz", hash = "sha256:b163ccb7798b31100b9f5fa4d6bc22dad362d7065c2f24b51e50731adde86828", size = 4769801, upload-time = "2025-06-10T22:37:07.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/1d/b5d63f1a6b824282b57f7b581810d20b7a28ca951f2d5b59f1eb0782c12b/tensorboardx-2.6.4-py3-none-any.whl", hash = "sha256:5970cf3a1f0a6a6e8b180ccf46f3fe832b8a25a70b86e5a237048a7c0beb18e2", size = 87201, upload-time = "2025-06-10T22:37:05.44Z" }, +] + [[package]] name = "termcolor" version = "3.1.0" @@ -9156,6 +9312,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/de/a8/8f499c179ec900783ffe133e9aab10044481679bb9aad78436d239eee716/tiktoken-0.9.0-cp313-cp313-win_amd64.whl", hash = "sha256:5ea0edb6f83dc56d794723286215918c1cde03712cbbafa0348b33448faf5b95", size = 894669, upload-time = "2025-02-14T06:02:47.341Z" }, ] +[[package]] +name = "timm" +version = "1.0.22" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, + { name = "torchvision" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/9d/e4670765d1c033f97096c760b3b907eeb659cf80f3678640e5f060b04c6c/timm-1.0.22.tar.gz", hash = "sha256:14fd74bcc17db3856b1a47d26fb305576c98579ab9d02b36714a5e6b25cde422", size = 2382998, upload-time = "2025-11-05T04:06:09.377Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d6/14/fc04d491527b774ec7479897f5861959209de1480e4c4cd32ed098ff8bea/timm-1.0.22-py3-none-any.whl", hash = "sha256:888981753e65cbaacfc07494370138b1700a27b1f0af587f4f9b47bc024161d0", size = 2530238, upload-time = "2025-11-05T04:06:06.823Z" }, +] + [[package]] name = "tldextract" version = "5.3.0" @@ -9354,6 +9526,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/a9/e2b6301fbf4590d352e183bef64927f74ef4d4f660cca3ed7a32dda60484/torchcodec-0.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:31b402c9ae3c6e9f33c41fddf7058f9492c443ad55d02f022395f8fa196b58f6", size = 1565405, upload-time = "2025-09-08T14:18:04.217Z" }, ] +[[package]] +name = "torchmetrics" +version = "1.8.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lightning-utilities" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/2e/48a887a59ecc4a10ce9e8b35b3e3c5cef29d902c4eac143378526e7485cb/torchmetrics-1.8.2.tar.gz", hash = "sha256:cf64a901036bf107f17a524009eea7781c9c5315d130713aeca5747a686fe7a5", size = 580679, upload-time = "2025-09-03T14:00:54.077Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/21/aa0f434434c48490f91b65962b1ce863fdcce63febc166ca9fe9d706c2b6/torchmetrics-1.8.2-py3-none-any.whl", hash = "sha256:08382fd96b923e39e904c4d570f3d49e2cc71ccabd2a94e0f895d1f0dac86242", size = 983161, upload-time = "2025-09-03T14:00:51.921Z" }, +] + [[package]] name = "torchvision" version = "0.23.0" @@ -9609,8 +9796,7 @@ dependencies = [ { name = "py-cpuinfo" }, { name = "pyyaml" }, { name = "requests" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "torch" }, { name = "torchvision" }, { name = "tqdm" }, @@ -9804,8 +9990,7 @@ dependencies = [ { name = "ray", extra = ["cgraph"] }, { name = "regex" }, { name = "requests" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy" }, { name = "sentencepiece" }, { name = "setproctitle" }, { name = "setuptools", version = "79.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, @@ -10097,6 +10282,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] +[[package]] +name = "werkzeug" +version = "3.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, +] + [[package]] name = "wget" version = "3.2"