diff --git a/align_system/algorithms/abstracts.py b/align_system/algorithms/abstracts.py index e10310c9..5681a444 100644 --- a/align_system/algorithms/abstracts.py +++ b/align_system/algorithms/abstracts.py @@ -16,11 +16,12 @@ def choose_action(self, class StructuredInferenceEngine(ABC): @abstractmethod - def dialog_to_prompt(dialog: list[dict]) -> str: + def dialog_to_prompt(self, dialog: list[dict]) -> str: pass @abstractmethod - def run_inference(prompts: Union[str, list[str]], + def run_inference(self, + prompts: Union[str, list[str]], schema: str) -> Union[dict, list[dict]]: pass diff --git a/align_system/algorithms/outlines_adm.py b/align_system/algorithms/outlines_adm.py index 56a9e2ab..e44dffd2 100644 --- a/align_system/algorithms/outlines_adm.py +++ b/align_system/algorithms/outlines_adm.py @@ -7,7 +7,7 @@ from functools import partial import outlines -from outlines.samplers import MultinomialSampler +from outlines.types import JsonSchema import jinja2 from rich.highlighter import JSONHighlighter from align_system.data_models.compat.ta3_ph1_client_models import ( @@ -16,6 +16,7 @@ CharacterTagEnum, KDMAValue ) +import transformers from align_system.utils import logging from align_system.utils import adm_utils @@ -67,7 +68,7 @@ def __init__(self, model_name, device='auto', baseline=False, - sampler=MultinomialSampler(), + generation_kwargs=None, scenario_description_template=scenario_state_description_1, action_selection_prompt_template=action_selection_prompt, baseline_system_prompt=baseline_system_prompt, @@ -86,19 +87,21 @@ def __init__(self, f"Unexpected value for 'precision' ({kwargs['precision']})" ", expecting either 'half' or 'full'") - model_kwargs['torch_dtype'] = torch_dtype + model_kwargs['dtype'] = torch_dtype - self.model = outlines.models.transformers( - model_name, - device=device, - model_kwargs=model_kwargs, - tokenizer_kwargs=kwargs.get('tokenizer_kwargs', {})) - # NOTE: In cases where we want multiple samples, we're passing - # in a list of prompts (this allows us to shuffle answers in - # each prompt), rather than setting the number of samples in - # the sampler itself (which defaults to 1); setting the number - # of samples in the sampler may result in unexpected behavior - self.sampler = sampler + self.model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs, device_map=device), + transformers.AutoTokenizer.from_pretrained(model_name, **kwargs.get('tokenizer_kwargs', {})), + device_dtype=torch_dtype) + + if generation_kwargs is None: + generation_kwargs = {'temperature': 0.7} + self.generation_kwargs = generation_kwargs + + # Sometimes the internal default for outlines/transformers is 20, + # leading to very short (and often invalid JSON) outputs. Setting a + # somewhat generous default. + self.generation_kwargs.setdefault('max_new_tokens', 8192) self.outlines_seed = outlines_seed if self.outlines_seed is None: @@ -240,15 +243,11 @@ def batched(cls, iterable, n): yield batch @classmethod - def run_in_batches(cls, inference_function, inputs, batch_size, rng=None): + def run_in_batches(cls, inference_function, inputs, batch_size, **generation_kwargs): ''' Batch inference to avoid out of memory error''' outputs = [] for batch in cls.batched(inputs, batch_size): - if rng is None: - output = inference_function(list(batch)) - else: - output = inference_function(list(batch), rng=rng) - + output = inference_function(list(batch), **generation_kwargs) if not isinstance(output, list): output = [output] outputs.extend(output) @@ -432,12 +431,14 @@ def top_level_choose_action(self, # Need to set the whitespace_pattern to prevent the state # machine from looping indefinitely in some cases, see: # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934 - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( action_choice_json_schema(json.dumps(choices), reasoning_max_length), - sampler=self.sampler, whitespace_pattern=r"[ ]?") + generator = outlines.Generator( + self.model, + json_schema) + if max_generator_tokens >= 0: generator = partial(generator, max_tokens=max_generator_tokens) @@ -454,7 +455,13 @@ def top_level_choose_action(self, extra={"markup": True}) log.info(dialog_texts[0]) - responses = self.run_in_batches(generator, dialog_texts, generator_batch_size, rng=self.outlines_rng) + responses = self.run_in_batches(generator.batch, + dialog_texts, + generator_batch_size, + rng=self.outlines_rng, + **self.generation_kwargs) + responses = [json.loads(r) for r in responses] + positive_responses_choices =\ [r['action_choice'] for r in responses[0:num_positive_samples]] @@ -657,17 +664,19 @@ def ensure_character_id_is_populated(self, character_names = [c.name for c in characters] - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( character_choice_json_schema(json.dumps(character_names)), - sampler=self.sampler, whitespace_pattern=r"[ ]?") + generator = outlines.Generator( + self.model, + json_schema) + log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True}) log.info(dialog_text) - selected_character = generator(dialog_text) + selected_character = json.loads(generator(dialog_text, **self.generation_kwargs)) selected_character_idx = character_names.index(selected_character['character_choice']) log.info("[bold]*STRUCTURED RESPONSE*[/bold]", @@ -727,19 +736,21 @@ def populate_treatment_parameters(self, dialog_text = self.dialog_to_prompt(dialog) - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( treatment_choice_json_schema( json.dumps([s.type for s in available_supplies]), json.dumps(valid_treatment_locations)), - sampler=self.sampler, whitespace_pattern=r"[ ]?") + generator = outlines.Generator( + self.model, + json_schema) + log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True}) log.info(dialog_text) - selected_treatment = generator(dialog_text) + selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs)) log.info("[bold]*STRUCTURED RESPONSE*[/bold]", extra={"markup": True}) @@ -799,14 +810,16 @@ def select_treatment_parameters(self, extra={"markup": True}) log.info(dialog_text) - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( treatment_choice_from_list_json_schema( json.dumps(possible_treatments)), - sampler=self.sampler, whitespace_pattern=r"[ ]?") - selected_treatment = generator(dialog_text) + generator = outlines.Generator( + self.model, + json_schema) + + selected_treatment = json.loads(generator(dialog_text, **self.generation_kwargs)) log.info("[bold]*STRUCTURED RESPONSE*[/bold]", extra={"markup": True}) log.info(selected_treatment, extra={"highlighter": JSON_HIGHLIGHTER}) @@ -843,18 +856,20 @@ def populate_tagging_parameters(self, dialog_text = self.dialog_to_prompt(dialog) - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( tag_choice_json_schema( json.dumps(valid_tags)), - sampler=self.sampler, whitespace_pattern=r"[ ]?") + generator = outlines.Generator( + self.model, + json_schema) + log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True}) log.info(dialog_text) - selected_tag = generator(dialog_text) + selected_tag = json.loads(generator(dialog_text, **self.generation_kwargs)) log.info("[bold]*STRUCTURED RESPONSE*[/bold]", extra={"markup": True}) @@ -906,18 +921,20 @@ def populate_aid_parameters(self, dialog_text = self.dialog_to_prompt(dialog) - generator = outlines.generate.json( - self.model, + json_schema = JsonSchema( aid_choice_json_schema( json.dumps([aid.id for aid in available_aids])), - sampler=self.sampler, whitespace_pattern=r"[ ]?") + generator = outlines.Generator( + self.model, + json_schema) + log.info("[bold]*DIALOG PROMPT*[/bold]", extra={"markup": True}) log.info(dialog_text) - selected_aid = generator(dialog_text) + selected_aid = json.loads(generator(dialog_text, **self.generation_kwargs)) log.info("[bold]*STRUCTURED RESPONSE*[/bold]", extra={"markup": True}) diff --git a/align_system/algorithms/outlines_inference_engine.py b/align_system/algorithms/outlines_inference_engine.py index 3fc3db35..b8c0bb11 100644 --- a/align_system/algorithms/outlines_inference_engine.py +++ b/align_system/algorithms/outlines_inference_engine.py @@ -1,30 +1,46 @@ import itertools from collections.abc import Iterable from textwrap import dedent +import json +import transformers import outlines -from outlines.samplers import MultinomialSampler +from outlines.types import JsonSchema import jinja2 import torch from align_system.algorithms.abstracts import StructuredInferenceEngine +# Sometimes the internal default for outlines/transformers is 20, +# leading to very short (and often invalid JSON) outputs. Setting a +# somewhat generous default. +DEFAULT_MAX_GENERATOR_TOKENS=8192 class OutlinesTransformersInferenceEngine(StructuredInferenceEngine): def __init__(self, model_name, - device='auto', precision='full', - max_generator_tokens=None, - sampler=MultinomialSampler(), + max_generator_tokens=DEFAULT_MAX_GENERATOR_TOKENS, inference_batch_size=5, - model_kwargs={}, - tokenizer_kwargs={}): + generation_kwargs=None, + model_kwargs=None, + tokenizer_kwargs=None): self.model_name = model_name self.precision = precision self.inference_batch_size = inference_batch_size + + if model_kwargs is None: + model_kwargs = {} self.model_kwargs = model_kwargs + + if tokenizer_kwargs is None: + tokenizer_kwargs = {} self.tokenizer_kwargs = tokenizer_kwargs + + if generation_kwargs is None: + generation_kwargs = {} + self.generation_kwargs = generation_kwargs + self.max_generator_tokens = max_generator_tokens if self.precision == 'half': @@ -36,19 +52,12 @@ def __init__(self, f"Unexpected value for 'precision' ({precision})" ", expecting either 'half' or 'full'") - self.model_kwargs['torch_dtype'] = torch_dtype + self.model_kwargs['dtype'] = torch_dtype - self.model = outlines.models.transformers( - self.model_name, - device=device, - model_kwargs=self.model_kwargs, - tokenizer_kwargs=self.tokenizer_kwargs) - # NOTE: In cases where we want multiple samples, we're passing - # in a list of prompts (this allows us to shuffle answers in - # each prompt), rather than setting the number of samples in - # the sampler itself (which defaults to 1); setting the number - # of samples in the sampler may result in unexpected behavior - self.sampler = sampler + self.model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(model_name, **self.model_kwargs, device_map='auto'), + transformers.AutoTokenizer.from_pretrained(model_name, **self.tokenizer_kwargs), + device_dtype=torch_dtype) def dialog_to_prompt(self, dialog): tokenizer = self.model.tokenizer.tokenizer @@ -85,29 +94,36 @@ def batched(cls, iterable, n): yield batch @classmethod - def run_in_batches(cls, inference_function, inputs, batch_size, max_generator_tokens=None): + def run_in_batches(cls, + inference_function, + inputs, + batch_size, + max_generator_tokens=DEFAULT_MAX_GENERATOR_TOKENS, + **generation_kwargs): ''' Batch inference to avoid out of memory error''' outputs = [] for batch in cls.batched(inputs, batch_size): - output = inference_function(list(batch), max_tokens=max_generator_tokens) + output = inference_function(list(batch), max_new_tokens=max_generator_tokens, **generation_kwargs) if not isinstance(output, list): output = [output] outputs.extend(output) return outputs def run_inference(self, prompts, schema): - generator = outlines.generate.json( + json_schema = JsonSchema(schema, whitespace_pattern=r"[ ]?") + + generator = outlines.Generator( self.model, - schema, - sampler=self.sampler, - whitespace_pattern=r"[ ]?") + json_schema) if isinstance(prompts, str): - return generator(prompts, max_tokens=self.max_generator_tokens) + output = generator(prompts, max_new_tokens=self.max_generator_tokens, **self.generation_kwargs) + return json.loads(output) elif isinstance(prompts, Iterable): - return self.run_in_batches( - generator, prompts, self.inference_batch_size, self.max_generator_tokens + output = self.run_in_batches( + generator.batch, prompts, self.inference_batch_size, self.max_generator_tokens, **self.generation_kwargs ) + return [json.loads(r) for r in output] else: raise TypeError("Don't know how to run inference on provided " "`prompts` object") @@ -116,7 +132,7 @@ def run_inference_unstructured(self, prompts): generator = outlines.generate.regex( self.model, r'.*', # "allow anything" regex - sampler=self.sampler) + **self.generation_kwargs) if isinstance(prompts, str): return generator(prompts, self.max_generator_tokens) @@ -135,18 +151,12 @@ def cache_repr(self): object instances, it's assumed that inference output will be the same ''' - def _sampler_repr(sampler): - return "{}.{}({})".format( - sampler.__class__.__module__, - sampler.__class__.__name__, - ", ".join([f"{k}={v}" for k, v in vars(sampler).items()])) - return dedent(f""" {self.__class__.__module__}.{self.__class__.__name__}( model_name="{self.model_name}", precision="{self.precision}", - sampler={_sampler_repr(self.sampler)}, inference_batch_size={self.inference_batch_size}, model_kwargs={self.model_kwargs}, tokenizer_kwargs={self.tokenizer_kwargs}, + generation_kwargs={self.generation_kwargs}, )""").strip() diff --git a/align_system/algorithms/outlines_regression_adm_comparative.py b/align_system/algorithms/outlines_regression_adm_comparative.py index bef8e178..80e6e6cb 100644 --- a/align_system/algorithms/outlines_regression_adm_comparative.py +++ b/align_system/algorithms/outlines_regression_adm_comparative.py @@ -5,9 +5,10 @@ import numpy as np import outlines -from outlines.samplers import MultinomialSampler +from outlines.types import JsonSchema from rich.highlighter import JSONHighlighter from swagger_client.models import kdma_value +import transformers from align_system.utils import logging from align_system.utils import adm_utils @@ -41,7 +42,7 @@ def __init__(self, model_name, device='auto', baseline=False, - sampler=MultinomialSampler(), + generation_kwargs=None, probabilistic=False, **kwargs): self.baseline = baseline @@ -60,19 +61,22 @@ def __init__(self, f"Unexpected value for 'precision' ({kwargs['precision']})" ", expecting either 'half' or 'full'") - model_kwargs['torch_dtype'] = torch_dtype + model_kwargs['dtype'] = torch_dtype + + self.model = outlines.from_transformers( + transformers.AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs, device_map=device), + transformers.AutoTokenizer.from_pretrained(model_name, **kwargs.get('tokenizer_kwargs', {})), + device_dtype=torch_dtype) + + if generation_kwargs is None: + generation_kwargs = {'temperature': 0.7} + self.generation_kwargs = generation_kwargs + + # Sometimes the internal default for outlines/transformers is 20, + # leading to very short (and often invalid JSON) outputs. Setting a + # somewhat generous default. + self.generation_kwargs.setdefault('max_new_tokens', 8192) - self.model = outlines.models.transformers( - model_name, - device=device, - model_kwargs=model_kwargs, - tokenizer_kwargs=kwargs.get('tokenizer_kwargs', {})) - # NOTE: In cases where we want multiple samples, we're passing - # in a list of prompts (this allows us to shuffle answers in - # each prompt), rather than setting the number of samples in - # the sampler itself (which defaults to 1); setting the number - # of samples in the sampler may result in unexpected behavior - self.sampler = sampler def reset_history(self): self.choice_history = {} @@ -98,11 +102,12 @@ def sample_outcome_predictions(self, # Need to set the whitespace_pattern to prevent the state # machine from looping indefinitely in some cases, see: # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934 - outcome_generator = outlines.generate.json( + json_schema = JsonSchema(comparative_outcome_prediction_json_schema(choices), + whitespace_pattern=r"[ ]?") + + outcome_generator = outlines.Generator( self.model, - comparative_outcome_prediction_json_schema(choices), - sampler=self.sampler, - whitespace_pattern=r"[ ]?") + json_schema) outcome_dialog_texts = [self.dialog_to_prompt(d) for d in outcome_dialogs] @@ -111,7 +116,12 @@ def sample_outcome_predictions(self, log.info(outcome_dialog_texts[0]) # List of {choice: {predicted_outcomes:str}, ...} with length = num_samples - predicted_outcomes = self.run_in_batches(outcome_generator, outcome_dialog_texts, batch_size) + predicted_outcomes = self.run_in_batches(outcome_generator.batch, + outcome_dialog_texts, + batch_size, + **self.generation_kwargs) + # Newer outlines doesn't automatically JSON load output + predicted_outcomes = [json.loads(o) for o in predicted_outcomes] log.info("[bold]*OUTCOME PREDICTION RESPONSE*[/bold]", extra={"markup": True}) @@ -175,13 +185,14 @@ def sample_relevance_predictions(self, # Need to set the whitespace_pattern to prevent the state # machine from looping indefinitely in some cases, see: # https://github.com/outlines-dev/outlines/issues/690#issuecomment-2102291934 - relevance_schema = relevance_classification_json_schema(choices, target_kdma['factor']) - relevance_generator = outlines.generate.json( - self.model, - relevance_schema, - sampler=self.sampler, + relevance_schema = JsonSchema( + relevance_classification_json_schema(choices, target_kdma['factor']), whitespace_pattern=r"[ ]?") + relevance_generator = outlines.Generator( + self.model, + relevance_schema) + relevance_dialog_texts = [self.dialog_to_prompt(d) for d in relevance_dialogs] log.info("[bold]*KDMA SCORE PREDICTION DIALOG PROMPT*[/bold]", @@ -189,7 +200,11 @@ def sample_relevance_predictions(self, log.info(relevance_dialog_texts[0]) # List of {choice: {score:int, reasoning:str}, ...} with length = num_samples*len(target_kdmas) - relevance_score_responses = self.run_in_batches(relevance_generator, relevance_dialog_texts, batch_size) + relevance_score_responses = self.run_in_batches(relevance_generator.batch, + relevance_dialog_texts, + batch_size, + **self.generation_kwargs) + relevance_score_responses = [json.loads(r) for r in relevance_score_responses] # Reshape to matrix of num_samples x len(target_kdmas) relevance_responses = [relevance_score_responses[i:i+len(target_kdmas)] \ for i in range(0,len(relevance_score_responses),len(target_kdmas))] @@ -303,11 +318,12 @@ def sample_kdma_score_predictions(self, score_schema = enum_comparative_kdma_score_prediction_json_schema(choices, target_kdma['valid_scores']) else: score_schema = comparative_kdma_score_prediction_json_schema(choices, target_kdma['factor']) - kdma_score_generator = outlines.generate.json( + + score_json_schema = JsonSchema(score_schema, whitespace_pattern=r"[ ]?") + + kdma_score_generator = outlines.Generator( self.model, - score_schema, - sampler=self.sampler, - whitespace_pattern=r"[ ]?") + score_json_schema) kdma_dialog_texts = [self.dialog_to_prompt(d) for d in kdma_dialogs] @@ -316,7 +332,12 @@ def sample_kdma_score_predictions(self, log.info(kdma_dialog_texts[0]) # List of {choice: {score:int, reasoning:str}, ...} with length = num_samples*len(target_kdmas) - kdma_score_responses = self.run_in_batches(kdma_score_generator, kdma_dialog_texts, batch_size) + kdma_score_responses = self.run_in_batches(kdma_score_generator.batch, + kdma_dialog_texts, + batch_size, + **self.generation_kwargs) + kdma_score_responses = [json.loads(r) for r in kdma_score_responses] + # Reshape to matrix of num_samples x len(target_kdmas) kdma_score_responses = [kdma_score_responses[i:i+len(target_kdmas)] \ for i in range(0,len(kdma_score_responses),len(target_kdmas))] diff --git a/align_system/algorithms/vllm_inference_engine.py b/align_system/algorithms/vllm_inference_engine.py new file mode 100644 index 00000000..b5c4f98e --- /dev/null +++ b/align_system/algorithms/vllm_inference_engine.py @@ -0,0 +1,70 @@ +from typing import Union +import json + +import jinja2 +from vllm import LLM, SamplingParams +from vllm.sampling_params import StructuredOutputsParams + +from align_system.algorithms.abstracts import StructuredInferenceEngine + +# Sometimes the internal default for VLLM is 50, +# leading to very short (and often invalid JSON) outputs. Setting a +# somewhat generous default. +DEFAULT_MAX_TOKENS = 8192 + +class VLLMInferenceEngine(StructuredInferenceEngine): + def __init__(self, + model_name, + sampling_params=None): + self.llm = LLM(model=model_name) + + self.sampling_params = sampling_params + if self.sampling_params is None: + self.sampling_params = {} + + if 'max_tokens' not in self.sampling_params: + self.sampling_params['max_tokens'] = DEFAULT_MAX_TOKENS + + def dialog_to_prompt(self, dialog: list[dict]) -> str: + tokenizer = self.llm.get_tokenizer() + + try: + encoded_dialog = tokenizer.apply_chat_template(dialog) + except jinja2.exceptions.TemplateError: + # Assume that the tokenizer chat template doesn't accept + # system messages; combine system message first user + # message + # Ensure each dialog element is a dict + system_msg, user_msg, *rest = [dict(d) for d in dialog] + + assert user_msg['role'] == 'user' + + updated_content = system_msg['content'] + '\n' + user_msg['content'] + + dialog = [{'role': 'user', 'content': updated_content}, *rest] + + encoded_dialog = tokenizer.apply_chat_template(dialog) + + return tokenizer.decode(encoded_dialog) + + def run_inference(self, + prompts: Union[str, list[str]], + schema: str) -> Union[dict, list[dict]]: + json_schema = json.loads(schema) + schema_params = StructuredOutputsParams(json=json_schema) + + structured_sampling_params = SamplingParams( + **self.sampling_params, + structured_outputs=schema_params) + + outputs = self.llm.generate( + prompts, + sampling_params=structured_sampling_params) + + parsed_outputs = [json.loads(o.outputs[0].text) for o in outputs] + + if isinstance(prompts, str): + # Return single instance if single prompt provided as a string + return parsed_outputs[0] + else: + return parsed_outputs diff --git a/align_system/configs/experiment/integration_tests/comp_reg_icl_adept_1.yaml b/align_system/configs/experiment/integration_tests/comp_reg_icl_adept_1.yaml index f32e25b3..a4a1bc7d 100644 --- a/align_system/configs/experiment/integration_tests/comp_reg_icl_adept_1.yaml +++ b/align_system/configs/experiment/integration_tests/comp_reg_icl_adept_1.yaml @@ -10,8 +10,8 @@ interface: adm: instance: precision: half - sampler: - _target_: outlines.samplers.GreedySampler + generation_kwargs: + do_sample: false model_name: mistralai/Mistral-7B-Instruct-v0.3 inference_kwargs: kdma_score_examples: true diff --git a/align_system/configs/experiment/integration_tests/comp_reg_icl_soartech_1.yaml b/align_system/configs/experiment/integration_tests/comp_reg_icl_soartech_1.yaml index f4ed95fd..a6871c79 100644 --- a/align_system/configs/experiment/integration_tests/comp_reg_icl_soartech_1.yaml +++ b/align_system/configs/experiment/integration_tests/comp_reg_icl_soartech_1.yaml @@ -10,8 +10,8 @@ interface: adm: instance: precision: half - sampler: - _target_: outlines.samplers.GreedySampler + generation_kwargs: + do_sample: false model_name: meta-llama/Llama-3.2-3B-Instruct inference_kwargs: distribution_matching: cumulative_kde diff --git a/align_system/configs/inference_engine/outlines_structured_greedy.yaml b/align_system/configs/inference_engine/outlines_structured_greedy.yaml index c6a800df..c08c2796 100644 --- a/align_system/configs/inference_engine/outlines_structured_greedy.yaml +++ b/align_system/configs/inference_engine/outlines_structured_greedy.yaml @@ -2,5 +2,7 @@ _target_: align_system.algorithms.outlines_inference_engine.OutlinesTransformers model_name: mistralai/Mistral-7B-Instruct-v0.3 precision: half -sampler: - _target_: outlines.samplers.GreedySampler +max_generator_tokens: 8192 +generation_kwargs: + do_sample: false + temperature: 0.0 diff --git a/align_system/configs/inference_engine/outlines_structured_multinomial.yaml b/align_system/configs/inference_engine/outlines_structured_multinomial.yaml index 37a97606..ffafdebe 100644 --- a/align_system/configs/inference_engine/outlines_structured_multinomial.yaml +++ b/align_system/configs/inference_engine/outlines_structured_multinomial.yaml @@ -2,6 +2,7 @@ _target_: align_system.algorithms.outlines_inference_engine.OutlinesTransformers model_name: mistralai/Mistral-7B-Instruct-v0.3 precision: half -sampler: - _target_: outlines.samplers.MultinomialSampler +max_generator_tokens: 8192 +generation_kwargs: + do_sample: true temperature: 0.7 diff --git a/align_system/configs/inference_engine/outlines_structured_multinomial_constrained.yaml b/align_system/configs/inference_engine/outlines_structured_multinomial_constrained.yaml index a57cf713..56dbcf27 100644 --- a/align_system/configs/inference_engine/outlines_structured_multinomial_constrained.yaml +++ b/align_system/configs/inference_engine/outlines_structured_multinomial_constrained.yaml @@ -3,7 +3,7 @@ _target_: align_system.algorithms.outlines_inference_engine.OutlinesTransformers model_name: meta-llama/Llama-3.1-8B-Instruct precision: half max_generator_tokens: 8192 -sampler: - _target_: outlines.samplers.MultinomialSampler +generation_kwargs: + do_sample: true temperature: 0.3 # Low temp = nearly deterministic, but faster than greedy top_p: 0.75 diff --git a/align_system/prompt_engineering/compat/outlines/templates.py b/align_system/prompt_engineering/compat/outlines/templates.py new file mode 100644 index 00000000..29013f66 --- /dev/null +++ b/align_system/prompt_engineering/compat/outlines/templates.py @@ -0,0 +1,369 @@ +# COMPATIBILITY NOTICE: This file copied from version 0.2.1 of +# outlines (https://github.com/dottxt-ai/outlines/blob/0.2.1/outlines/templates.py), +# as we still make heavy use of the outlines.prompt +# decorator. Future prompt setups should use Template.from_string or +# Template.from_file methods recommended in newer outlines versions. + +import functools +import inspect +import json +import os +import re +import textwrap +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, cast +import warnings + +import jinja2 +import pydantic + + +@dataclass +class Template: + """Represents a prompt template. + + We return a `Template` class instead of a simple function so the + template can be accessed by callers. + + """ + + template: jinja2.Template + signature: Optional[inspect.Signature] + + def __call__(self, *args, **kwargs) -> str: + """Render and return the template. + + Returns + ------- + The rendered template as a Python ``str``. + + """ + if self.signature is not None: + bound_arguments = self.signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return self.template.render(**bound_arguments.arguments) + else: + return self.template.render(**kwargs) + + @classmethod + def from_str(cls, content: str, filters: Dict[str, Callable] = {}): + """Create a `Template` instance from a string containing a Jinja template. + + Parameters + ---------- + content : str + The string content to be converted into a template. + + Returns + ------- + An instance of the class with the provided content as a template. + """ + return cls(build_template_from_str(content, filters), None) + + @classmethod + def from_file(cls, path: Path, filters: Dict[str, Callable] = {}): + """Create a `Template` instance from a file containing a Jinja template. + + Note: This method does not allow to include and inheritance to reference files + that are outside the folder or subfolders of the file given to `from_file`. + + Parameters + ---------- + path : Path + The path to the file containing the Jinja template. + + Returns + ------- + Template + An instance of the Template class with the template loaded from the file. + """ + # We don't use a `Signature` here because it seems not feasible to infer one from a Jinja2 environment that is + # split across multiple files (since e.g. we support features like Jinja2 includes and template inheritance) + return cls(build_template_from_file(path, filters), None) + + +def build_template_from_str( + content: str, filters: Dict[str, Callable] = {} +) -> jinja2.Template: + # Dedent, and remove extra linebreak + cleaned_template = inspect.cleandoc(content) + + # Add linebreak if there were any extra linebreaks that + # `cleandoc` would have removed + ends_with_linebreak = content.replace(" ", "").endswith("\n\n") + if ends_with_linebreak: + cleaned_template += "\n" + + # Remove extra whitespaces, except those that immediately follow a newline symbol. + # This is necessary to avoid introducing whitespaces after backslash `\` characters + # used to continue to the next line without linebreak. + cleaned_template = re.sub(r"(?![\r\n])(\b\s+)", " ", cleaned_template) + + env = create_jinja_env(None, filters) + + return env.from_string(cleaned_template) + + +def build_template_from_file( + path: Path, filters: Dict[str, Callable] = {} +) -> jinja2.Template: + file_directory = os.path.dirname(os.path.abspath(path)) + env = create_jinja_env(jinja2.FileSystemLoader(file_directory), filters) + + return env.get_template(os.path.basename(path)) + + +def prompt( + fn: Optional[Callable] = None, + filters: Dict[str, Callable] = {}, +) -> Callable: + """Decorate a function that contains a prompt template. + + This allows to define prompts in the docstring of a function and simplify their + manipulation by providing some degree of encapsulation. It uses the `render` + function internally to render templates. + + ```pycon + >>> import outlines + >>> + >>> @outlines.prompt + >>> def build_prompt(question): + ... "I have a ${question}" + ... + >>> prompt = build_prompt("How are you?") + ``` + + This API can also be helpful in an "agent" context where parts of the prompt + are set when the agent is initialized and never modified later. In this situation + we can partially apply the prompt function at initialization. + + ```pycon + >>> import outlines + >>> import functools as ft + ... + >>> @outlines.prompt + ... def solve_task(name: str, objective: str, task: str): + ... \"""Your name is {{name}}. + ... Your overall objective is to {{objective}}. + ... Please solve the following task: {{task}} + ... \""" + ... + >>> hal = ft.partial(solve_task, "HAL", "Travel to Jupiter") + ``` + + Additional Jinja2 filters can be provided as keyword arguments to the decorator. + + ```pycon + >>> def reverse(s: str) -> str: + ... return s[::-1] + ... + >>> @outlines.prompt(filters={ 'reverse': reverse }) + ... def reverse_prompt(text): + ... \"""{{ text | reverse }}\""" + ... + >>> prompt = reverse_prompt("Hello") + >>> print(prompt) + ... "olleH" + ``` + + Returns + ------- + A `Template` callable class which will render the template when called. + + """ + warnings.warn( + "The @prompt decorator is deprecated and will be removed in outlines 1.1.0. " + "Instead of using docstring templates, please use Template.from_file() to " + "load your prompts from separate template files, or a simple Python function " + "that returns text. This helps keep prompt content separate from code and is " + "more maintainable.", + DeprecationWarning, + stacklevel=2, + ) + + if fn is None: + return lambda fn: prompt(fn, cast(Dict[str, Callable], filters)) + + signature = inspect.signature(fn) + + # The docstring contains the template that will be rendered to be used + # as a prompt to the language model. + docstring = fn.__doc__ + if docstring is None: + raise TypeError("Could not find a template in the function's docstring.") + + template = build_template_from_str(cast(str, docstring), filters) + + return Template(template, signature) + + +def create_jinja_env( + loader: Optional[jinja2.BaseLoader], filters: Dict[str, Callable] +) -> jinja2.Environment: + """Create a new Jinja environment. + + The Jinja environment is loaded with a set of pre-defined filters: + - `name`: get the name of a function + - `description`: get a function's docstring + - `source`: get a function's source code + - `signature`: get a function's signature + - `args`: get a function's arguments + - `schema`: isplay a JSON Schema + + Users may pass additional filters, and/or override existing ones. + + Arguments + --------- + loader + An optional `BaseLoader` instance + filters + A dictionary of filters, map between the filter's name and the + corresponding function. + + """ + env = jinja2.Environment( + loader=loader, + trim_blocks=True, + lstrip_blocks=True, + keep_trailing_newline=True, + undefined=jinja2.StrictUndefined, + ) + + env.filters["name"] = get_fn_name + env.filters["description"] = get_fn_description + env.filters["source"] = get_fn_source + env.filters["signature"] = get_fn_signature + env.filters["schema"] = get_schema + env.filters["args"] = get_fn_args + + # The filters passed by the user may override the + # pre-defined filters. + for name, filter_fn in filters.items(): + env.filters[name] = filter_fn + + return env + + +def get_fn_name(fn: Callable): + """Returns the name of a callable.""" + if not callable(fn): + raise TypeError("The `name` filter only applies to callables.") + + if not hasattr(fn, "__name__"): + name = type(fn).__name__ + else: + name = fn.__name__ + + return name + + +def get_fn_args(fn: Callable): + """Returns the arguments of a function with annotations and default values if provided.""" + if not callable(fn): + raise TypeError("The `args` filter only applies to callables.") + + arg_str_list = [] + signature = inspect.signature(fn) + arg_str_list = [str(param) for param in signature.parameters.values()] + arg_str = ", ".join(arg_str_list) + return arg_str + + +def get_fn_description(fn: Callable): + """Returns the first line of a callable's docstring.""" + if not callable(fn): + raise TypeError("The `description` filter only applies to callables.") + + docstring = inspect.getdoc(fn) + if docstring is None: + description = "" + else: + description = docstring.split("\n")[0].strip() + + return description + + +def get_fn_source(fn: Callable): + """Return the source code of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"(\bdef\b.*)", re.DOTALL), source) + if re_search is not None: + source = re_search.group(0) + else: + raise TypeError("Could not read the function's source code") + + return source + + +def get_fn_signature(fn: Callable): + """Return the signature of a callable.""" + if not callable(fn): + raise TypeError("The `source` filter only applies to callables.") + + source = textwrap.dedent(inspect.getsource(fn)) + re_search = re.search(re.compile(r"\(([^)]+)\)"), source) + if re_search is None: + signature = "" + else: + signature = re_search.group(1) + + return signature + + +@functools.singledispatch +def get_schema(model: Any): + raise NotImplementedError( + f"No schema rendering function defined for type {type(model)}." + ) + + +@get_schema.register(dict) +def get_schema_dict(model: Dict): + """Return a pretty-printed dictionary""" + return json.dumps(model, indent=2) + + +@get_schema.register(type(pydantic.BaseModel)) +def get_schema_pydantic(model: Type[pydantic.BaseModel]): + """Return the schema of a Pydantic model.""" + if not isinstance(model, type(pydantic.BaseModel)): + raise TypeError("The `schema` filter only applies to Pydantic models.") + + if hasattr(model, "model_json_schema"): + def_key = "$defs" + raw_schema = model.model_json_schema() + else: # pragma: no cover + def_key = "definitions" + raw_schema = model.schema() + + definitions = raw_schema.get(def_key, None) + schema = parse_pydantic_schema(raw_schema, definitions) + + return json.dumps(schema, indent=2) + + +def parse_pydantic_schema(raw_schema, definitions): + """Parse the output of `Basemodel.[schema|model_json_schema]()`. + + This recursively follows the references to other schemas in case + of nested models. Other schemas are stored under the "definitions" + key in the schema of the top-level model. + + """ + simple_schema = {} + for name, value in raw_schema["properties"].items(): + if "description" in value: + simple_schema[name] = value["description"] + elif "$ref" in value: + refs = value["$ref"].split("/") + simple_schema[name] = parse_pydantic_schema( + definitions[refs[2]], definitions + ) + else: + simple_schema[name] = f"<{name}>" + + return simple_schema diff --git a/align_system/prompt_engineering/outlines_prompts.py b/align_system/prompt_engineering/outlines_prompts.py index c52b024f..7d584b50 100644 --- a/align_system/prompt_engineering/outlines_prompts.py +++ b/align_system/prompt_engineering/outlines_prompts.py @@ -1,8 +1,8 @@ import jinja2 -import outlines import json import re +from align_system.prompt_engineering.compat.outlines.templates import prompt as compat_outlines_prompt from align_system.utils.outlines_prompts_utils import ( get_unique_structured_character_info, new_get_relevant_structured_character_info) @@ -10,7 +10,7 @@ InjuryStatusEnum) -@outlines.prompt +@compat_outlines_prompt def baseline_system_prompt(): """ You are an assistant specialized in answering multiple-choice \ @@ -24,7 +24,7 @@ def baseline_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_moral_deservingness_system_prompt(): """ You are a specialized algorithm tasked with aiding medical triage \ @@ -44,7 +44,7 @@ def high_moral_deservingness_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_moral_deservingness_system_prompt(): """ As an algorithm designed to support medical triage in a military \ @@ -62,7 +62,7 @@ def low_moral_deservingness_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_maximization_system_prompt(): """ You are an assistant specialized in answering multiple-choice \ @@ -85,7 +85,7 @@ def high_maximization_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_maximization_system_prompt(): """ You are an assistant specialized in answering multiple-choice \ @@ -105,7 +105,7 @@ def low_maximization_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_protocol_focus_system_prompt(): """ You are a specialized algorithm built to assist in the decision-making process \ @@ -124,7 +124,7 @@ def high_protocol_focus_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_protocol_focus_system_prompt(): """ You are a specialized algorithm built to assist in the decision-making process \ @@ -141,7 +141,7 @@ def low_protocol_focus_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_fairness_system_prompt(): """ You are a specialized algorithm designed to help in the decision-making \ @@ -158,7 +158,7 @@ def high_fairness_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_fairness_system_prompt(): """ You are a specialized algorithm built to assist in medical triage \ @@ -174,7 +174,7 @@ def low_fairness_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_risk_aversion_system_prompt(): """ You are a specialized algorithm built to assist in the decision-making \ @@ -192,7 +192,7 @@ def high_risk_aversion_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_risk_aversion_system_prompt(): """ You are a specialized algorithm built to assist in the decision-making process \ @@ -207,7 +207,7 @@ def low_risk_aversion_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_continuing_care_system_prompt(): """ You are an algorithm programmed to assist with medical triage in a \ @@ -226,7 +226,7 @@ def high_continuing_care_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_continuing_care_system_prompt(): """ As an algorithm purposed for supporting medical triage in a military \ @@ -244,7 +244,7 @@ def low_continuing_care_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def high_utilitarianism_care_system_prompt(): """ You are a specialized algorithm designed to aid in the decision-making \ @@ -263,7 +263,7 @@ def high_utilitarianism_care_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def low_utilitarianism_system_prompt(): """ You are a specialized algorithm designed to assist in decision-making \ @@ -282,7 +282,7 @@ def low_utilitarianism_system_prompt(): """ -@outlines.prompt +@compat_outlines_prompt def action_selection_prompt(situation, choices): """ {{ situation }} @@ -295,14 +295,14 @@ def action_selection_prompt(situation, choices): """ -@outlines.prompt +@compat_outlines_prompt def detailed_unstructured_generic_action_text(action, character_id_to_name): """ {{ action.unstructured }}{% if action.character_id is not none %} -- {{ character_id_to_name[action.character_id] }}{% endif %}{% if action.parameters is not none %} -- {{ action.parameters }}{% endif %} ({{ action.unstructured }}) """ -@outlines.prompt +@compat_outlines_prompt def detailed_unstructured_treatment_action_text(action, character_id_to_name): """ {% if 'location' not in action.parameters or action.parameters['location'] == 'internal' -%} @@ -311,7 +311,7 @@ def detailed_unstructured_treatment_action_text(action, character_id_to_name): """ -@outlines.prompt +@compat_outlines_prompt def detailed_unstructured_tagging_action_text(action, character_id_to_name): """ {% if action.character_id is none -%} @@ -319,13 +319,13 @@ def detailed_unstructured_tagging_action_text(action, character_id_to_name): Tag {{ character_id_to_name[action.character_id] }} as {{ action.parameters['category'] }} ({{ action.unstructured }}){% endif -%} """ -@outlines.prompt +@compat_outlines_prompt def scenario_state_unstructured(scenario_state): """ {{ scenario_state.unstructured.rstrip() }} """ -@outlines.prompt +@compat_outlines_prompt def scenario_state_description_1(scenario_state): """ CHARACTERS: @@ -343,7 +343,7 @@ def scenario_state_description_1(scenario_state): # Alternative scenario state prompt that uses events when present # (untested / speculative) -@outlines.prompt +@compat_outlines_prompt def scenario_state_description_2(scenario_state): """ CHARACTERS: @@ -364,7 +364,7 @@ def scenario_state_description_2(scenario_state): {% endif %} """ -@outlines.prompt +@compat_outlines_prompt def scenario_state_description_dre(scenario_state): """ CHARACTERS: @@ -388,7 +388,7 @@ def scenario_state_description_dre(scenario_state): # Prompt that includes provided relevant structured character info # characters is a list of dicts with keys: name, unstructured, relevant_structured # the relevant_structured field is expected to already be serialized to a string (JSON.stringify(strucutured_json)) -@outlines.prompt +@compat_outlines_prompt def scenario_state_description_with_relevant_char_info(scenario_state, characters): """ CHARACTERS: @@ -403,7 +403,7 @@ def scenario_state_description_with_relevant_char_info(scenario_state, character {{ scenario_state.unstructured.rstrip() }} """ -@outlines.prompt +@compat_outlines_prompt def followup_clarify_character(characters): """ Please clarify which character should receive the action. @@ -418,7 +418,7 @@ def followup_clarify_character(characters): """ -@outlines.prompt +@compat_outlines_prompt def followup_clarify_treatment(character, supplies, treated_value=InjuryStatusEnum.TREATED): @@ -441,7 +441,7 @@ def followup_clarify_treatment(character, {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def followup_clarify_treatment_from_list(character, supplies, options_list, @@ -470,7 +470,7 @@ def followup_clarify_treatment_from_list(character, {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def followup_clarify_tag(character, treated_value=InjuryStatusEnum.TREATED): """ @@ -493,7 +493,7 @@ def followup_clarify_tag(character, {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def followup_clarify_aid(character, available_aids): """ Given the casualty description below, please clarify which aid option \ @@ -520,7 +520,7 @@ def followup_clarify_aid(character, available_aids): """ -@outlines.prompt +@compat_outlines_prompt def action_choice_json_schema(choices_json_str, reasoning_max_length=512): ''' {"$defs": {"ActionChoice": {"enum": {{ choices_json_str }}, @@ -538,7 +538,7 @@ def action_choice_json_schema(choices_json_str, reasoning_max_length=512): ''' -@outlines.prompt +@compat_outlines_prompt def character_choice_json_schema(choices_json_str): ''' {"$defs": {"CharacterChoice": {"enum": {{ choices_json_str }}, @@ -553,7 +553,7 @@ def character_choice_json_schema(choices_json_str): ''' -@outlines.prompt +@compat_outlines_prompt def tag_choice_json_schema(tags_json_str, reasoning_max_length=512): ''' {"$defs": {"TriageTag": {"enum": {{ tags_json_str }}, @@ -570,7 +570,7 @@ def tag_choice_json_schema(tags_json_str, reasoning_max_length=512): "type": "object"} ''' -@outlines.prompt +@compat_outlines_prompt def aid_choice_json_schema(choices_json_str): ''' {"$defs": {"AidChoice": {"enum": {{ choices_json_str }}, @@ -585,7 +585,7 @@ def aid_choice_json_schema(choices_json_str): ''' -@outlines.prompt +@compat_outlines_prompt def treatment_choice_json_schema(supplies_json_str, locations_json_str): ''' {"$defs": {"SupplyChoice": {"enum": {{ supplies_json_str }}, @@ -604,7 +604,7 @@ def treatment_choice_json_schema(supplies_json_str, locations_json_str): ''' -@outlines.prompt +@compat_outlines_prompt def treatment_choice_from_list_json_schema(treatment_choices_json_str): ''' {"$defs": {"TreatmentChoice": {"enum": {{ treatment_choices_json_str }}, @@ -619,7 +619,7 @@ def treatment_choice_from_list_json_schema(treatment_choices_json_str): ''' -@outlines.prompt +@compat_outlines_prompt def outcomes_system_prompt(): """ In a clear and consice way describe the outcomes of answering the \ @@ -629,7 +629,7 @@ def outcomes_system_prompt(): Limit your analysis to a single sentence. """ -@outlines.prompt +@compat_outlines_prompt def comparative_outcomes_system_prompt(): """ In a clear and consice way describe the outcomes of answering the \ @@ -639,7 +639,7 @@ def comparative_outcomes_system_prompt(): Limit your analysis to a single sentence per response. """ -@outlines.prompt +@compat_outlines_prompt def relevance_classification_system_prompt(kdma, kdma_description, kdma_scale_factor): """ You are a specialized algorithm tasked with considering medical triage decisions \ @@ -652,7 +652,7 @@ def relevance_classification_system_prompt(kdma, kdma_description, kdma_scale_fa yes or no, where yes means {{kdma}} is relevant and no means {{kdma}} is irrelevant. """ -@outlines.prompt +@compat_outlines_prompt def outcome_prediction_prompt(situation, choices, choice): """ Consider the following scenario: @@ -667,7 +667,7 @@ def outcome_prediction_prompt(situation, choices, choice): {{ choice }} """ -@outlines.prompt +@compat_outlines_prompt def comparative_outcome_prediction_prompt(situation, choices): """ Consider the following scenario: @@ -679,7 +679,7 @@ def comparative_outcome_prediction_prompt(situation, choices): {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def outcome_prediction_json_schema(): ''' {"properties": { @@ -717,7 +717,7 @@ def comparative_outcome_prediction_json_schema(choices): } return json.dumps(json_schema) -@outlines.prompt +@compat_outlines_prompt def kdma_score_prediction_system_prompt(kdma, kdma_description): """ You are a specialized algorithm tasked with judging medical triage decisions \ @@ -734,7 +734,7 @@ def kdma_score_prediction_system_prompt(kdma, kdma_description): Provide a score and one sentence of reasoning explaining your score. """ -@outlines.prompt +@compat_outlines_prompt def comparative_kdma_score_prediction_system_prompt(kdma, kdma_description, kdma_scale_factor): """ You are a specialized algorithm tasked with judging medical triage decisions \ @@ -753,7 +753,7 @@ def comparative_kdma_score_prediction_system_prompt(kdma, kdma_description, kdma For each response, provide a score and one sentence of reasoning explaining your score. """ -@outlines.prompt +@compat_outlines_prompt def kdma_score_prediction_system_prompt_with_examples(kdma, kdma_description, kdma_score_examples): """ You are a specialized algorithm tasked with judging medical triage decisions \ @@ -773,7 +773,7 @@ def kdma_score_prediction_system_prompt_with_examples(kdma, kdma_description, kd Provide a score and one sentence of reasoning explaining your score. """ -@outlines.prompt +@compat_outlines_prompt def comparative_kdma_score_prediction_system_prompt_with_examples( kdma, kdma_description, kdma_score_examples, kdma_scale_factor ): @@ -797,7 +797,7 @@ def comparative_kdma_score_prediction_system_prompt_with_examples( For each response, provide a score and one sentence of reasoning explaining your score. """ -@outlines.prompt +@compat_outlines_prompt def kdma_score_prediction_prompt(situation, choices, choice, outcome, kdma): """ Consider the following situation and selected response: @@ -821,7 +821,7 @@ def kdma_score_prediction_prompt(situation, choices, choice, outcome, kdma): making behavior of someone who is acting with high {{kdma}} with one sentence of reasoning. """ -@outlines.prompt +@compat_outlines_prompt def comparative_kdma_score_prediction_prompt(situation, choices, kdma): """ Consider the following scenario and responses: @@ -839,7 +839,7 @@ def comparative_kdma_score_prediction_prompt(situation, choices, kdma): behavior of someone who is acting with high {{kdma}} with one sentence of reasoning. """ -@outlines.prompt +@compat_outlines_prompt def comparative_kdma_score_prediction_prompt_no_outcomes(situation, choices, kdma): """ Consider the following scenario and responses: @@ -854,7 +854,7 @@ def comparative_kdma_score_prediction_prompt_no_outcomes(situation, choices, kdm behavior of someone who is acting with high {{kdma}} with one sentence of reasoning. """ -@outlines.prompt +@compat_outlines_prompt def relevance_classification_prompt(situation, choices, kdma): """ Consider the following scenario and responses: @@ -869,7 +869,7 @@ def relevance_classification_prompt(situation, choices, kdma): or no (irrelevant) and provide one sentence of reasoning. """ -@outlines.prompt +@compat_outlines_prompt def kdma_score_prediction_json_schema(): ''' {"properties": { @@ -967,14 +967,14 @@ def enum_comparative_kdma_score_prediction_json_schema(choices, valid_scores): return json.dumps(json_schema) -@outlines.prompt +@compat_outlines_prompt def scenario_description_hybrid_regression(scenario_state): """ {{ scenario_state.unstructured.rstrip() }} {% for character in scenario_state.characters %}{{ character.name }} - {{ character.unstructured.rstrip()}} {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def kaleido_default_itm_prompt(scenario_state, choice, other_choices): @@ -985,7 +985,7 @@ def kaleido_default_itm_prompt(scenario_state, {% endfor %} """ -@outlines.prompt +@compat_outlines_prompt def kaleido_default_itm_ph2_prompt(scenario_state, choice, other_choices): @@ -1203,7 +1203,7 @@ def __call__(self, target_kdma, target_value): return high_utilitarianism_care_system_prompt() -@outlines.prompt +@compat_outlines_prompt def phase2_scenario_state_description(scenario_state): """ {{ scenario_state.unstructured.rstrip() }} @@ -1215,7 +1215,7 @@ def __call__(self, scenario_state): return phase2_scenario_state_description(scenario_state) -@outlines.prompt +@compat_outlines_prompt def phase2_baseline_prompt(scenario_description, choices): """ Scenario: @@ -1233,7 +1233,7 @@ def __call__(self, scenario_description, choices): return phase2_baseline_prompt(scenario_description, choices) -@outlines.prompt +@compat_outlines_prompt def comparative_regression_system_prompt( kdma, kdma_description, kdma_scale_factor ): @@ -1264,7 +1264,7 @@ def __call__(self, target_attribute): target_attribute.factor) -@outlines.prompt +@compat_outlines_prompt def comparative_regression_system_prompt_with_examples( kdma, kdma_description, kdma_score_examples, kdma_scale_factor ): @@ -1299,7 +1299,7 @@ def __call__(self, target_attribute): target_attribute.factor) -@outlines.prompt +@compat_outlines_prompt def comparative_regression_prompt(situation, choices, kdma): """ Scenario: @@ -1367,7 +1367,7 @@ def __call__(self, choices, attribute): self.reasoning_max_length) -@outlines.prompt +@compat_outlines_prompt def probe_most_relevant_system_prompt( kdmas ): @@ -1394,7 +1394,7 @@ def __call__(self, target_attributes): target_attributes) -@outlines.prompt +@compat_outlines_prompt def probe_most_relevant_prompt(situation, choices, kdmas): """ Scenario: @@ -1444,7 +1444,7 @@ def __call__(self, target_attribute_names): target_attribute_names) -@outlines.prompt +@compat_outlines_prompt def variables_extraction_stage_prompt(situation, choices): """ You are given a piece of text and a problem. Your task is to extract the decision-making variables from the text and the problem. A decision-making variable is either: @@ -1545,7 +1545,7 @@ class VariablesOutputSchema(): def __call__(self): return variables_extraction_output_schema() -@outlines.prompt +@compat_outlines_prompt def extraction_stage_prompt(situation, choices, variables): """ You are an expert in information extraction and summarization. Your task is to analyze the given text and extract all key pieces of information that might be valuable while ensuring that every extracted sentence explicitly includes the subject from the provided variables instead of using pronouns. @@ -1654,7 +1654,7 @@ class ExtractionOutputSchema(): def __call__(self): return extraction_output_schema() -@outlines.prompt +@compat_outlines_prompt def attribute_stage_prompt(scenario_description, choices, extraction, variables, target_bias): """ Given the extracted information from a scenario, your task is to summarize attributes and values for each variable with a focus on characteristics that are most relevant to a specified target bias. @@ -1825,7 +1825,7 @@ class AttributeOutputSchema(): def __call__(self): return attribute_output_schema() -@outlines.prompt +@compat_outlines_prompt def filter_stage_prompt(scenario_description, choices, attribute_information, target_bias): """ You are tasked with evaluating the relationship between a given attribute, and a specified target bias. The goal is to determine how relevant the given attribute are to the target bias and to assign a weight that reflects the strength of this relationship. @@ -1921,7 +1921,7 @@ class FilterOutputSchema(): def __call__(self): return filter_output_schema() -@outlines.prompt +@compat_outlines_prompt def objective_stage_prompt(scenario_description, choices, objective_components, objective_function_text, weight_threshold): """ You are tasked with creating and refining an objective function based on filtered attributes that exceeded the weight threshold. The objective function should mathematically represent the decision-making criteria based on the most relevant attributes. @@ -2013,7 +2013,7 @@ class ObjectiveOutputSchema(): def __call__(self): return objective_output_schema() -@outlines.prompt +@compat_outlines_prompt def express_stage_prompt(scenario_description, choices, structure): """ You are an expert in optimization modeling and ethical decision-making. Your task is to construct an optimization model based on the following information. @@ -2149,7 +2149,7 @@ def __call__(self): return express_output_schema() -@outlines.prompt +@compat_outlines_prompt def math_reason_prompt(objective, attribute, variable, constraints, choice, target_bias): """ You will receive a mathematical model structure along with a dataset containing variables, attributes, and values. Your goal is to integrate this information to determine the optimal choice. @@ -2211,7 +2211,7 @@ def math_reason_prompt(objective, attribute, variable, constraints, choice, targ {% endraw %} """ -@outlines.prompt +@compat_outlines_prompt def math_reason_error_prompt(structure, choice, target_bias): """ You will receive a mathematical model structure along with a dataset containing variables, attributes, and values. Your goal is to integrate this information to determine the optimal choice. diff --git a/poetry.lock b/poetry.lock index 0729c9eb..759aa5a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -156,18 +156,6 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" -[[package]] -name = "airportsdata" -version = "20250224" -description = "Extensive database of location and timezone data for nearly every airport and landing strip in the world." -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "airportsdata-20250224-py3-none-any.whl", hash = "sha256:006128bca2cc1983dc5ed4fb1227e8df2289b5e95b8ab30d9bdd4eb7c6d2160d"}, - {file = "airportsdata-20250224.tar.gz", hash = "sha256:7f4538a613504444a13149be701aac5f9599b86da476d26b46aa24fd54714eda"}, -] - [[package]] name = "annotated-types" version = "0.7.0" @@ -302,18 +290,6 @@ files = [ {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] -[[package]] -name = "cfgv" -version = "3.4.0" -description = "Validate configuration and produce human readable error messages." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, - {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, -] - [[package]] name = "charset-normalizer" version = "3.3.2" @@ -621,18 +597,6 @@ files = [ {file = "diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc"}, ] -[[package]] -name = "distlib" -version = "0.3.9" -description = "Distribution utilities" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87"}, - {file = "distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403"}, -] - [[package]] name = "distro" version = "1.9.0" @@ -983,6 +947,42 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "hf-xet" +version = "1.2.0" +description = "Fast transfer of large files with the Hugging Face Hub." +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\"" +files = [ + {file = "hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649"}, + {file = "hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813"}, + {file = "hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc"}, + {file = "hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5"}, + {file = "hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f"}, + {file = "hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832"}, + {file = "hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382"}, + {file = "hf_xet-1.2.0-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:10bfab528b968c70e062607f663e21e34e2bba349e8038db546646875495179e"}, + {file = "hf_xet-1.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2a212e842647b02eb6a911187dc878e79c4aa0aa397e88dd3b26761676e8c1f8"}, + {file = "hf_xet-1.2.0-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e06daccb3a7d4c065f34fc26c14c74f4653069bb2b194e7f18f17cbe9939c0"}, + {file = "hf_xet-1.2.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:29c8fc913a529ec0a91867ce3d119ac1aac966e098cf49501800c870328cc090"}, + {file = "hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:66e159cbfcfbb29f920db2c09ed8b660eb894640d284f102ada929b6e3dc410a"}, + {file = "hf_xet-1.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:9c91d5ae931510107f148874e9e2de8a16052b6f1b3ca3c1b12f15ccb491390f"}, + {file = "hf_xet-1.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:210d577732b519ac6ede149d2f2f34049d44e8622bf14eb3d63bbcd2d4b332dc"}, + {file = "hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848"}, + {file = "hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4"}, + {file = "hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd"}, + {file = "hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c"}, + {file = "hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737"}, + {file = "hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865"}, + {file = "hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69"}, + {file = "hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f"}, +] + +[package.extras] +tests = ["pytest"] + [[package]] name = "httpcore" version = "1.0.5" @@ -1032,19 +1032,20 @@ socks = ["socksio (==1.*)"] [[package]] name = "huggingface-hub" -version = "0.29.3" +version = "0.36.0" description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false python-versions = ">=3.8.0" groups = ["main"] files = [ - {file = "huggingface_hub-0.29.3-py3-none-any.whl", hash = "sha256:0b25710932ac649c08cdbefa6c6ccb8e88eef82927cacdb048efb726429453aa"}, - {file = "huggingface_hub-0.29.3.tar.gz", hash = "sha256:64519a25716e0ba382ba2d3fb3ca082e7c7eb4a2fc634d200e8380006e0760e5"}, + {file = "huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d"}, + {file = "huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25"}, ] [package.dependencies] filelock = "*" fsspec = ">=2023.5.0" +hf-xet = {version = ">=1.1.3,<2.0.0", markers = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\""} packaging = ">=20.9" pyyaml = ">=5.1" requests = "*" @@ -1052,16 +1053,19 @@ tqdm = ">=4.42.1" typing-extensions = ">=3.7.4.3" [package.extras] -all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0) ; python_version >= \"3.9\"", "mypy (>=1.14.1,<1.15.0) ; python_version == \"3.8\"", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "ty", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] cli = ["InquirerPy (==0.3.4)"] -dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "libcst (>=1.4.0)", "mypy (==1.15.0) ; python_version >= \"3.9\"", "mypy (>=1.14.1,<1.15.0) ; python_version == \"3.8\"", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "ty", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] hf-transfer = ["hf-transfer (>=0.1.4)"] +hf-xet = ["hf-xet (>=1.1.2,<2.0.0)"] inference = ["aiohttp"] -quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"] +mcp = ["aiohttp", "mcp (>=1.8.0)", "typer"] +oauth = ["authlib (>=1.3.2)", "fastapi", "httpx", "itsdangerous"] +quality = ["libcst (>=1.4.0)", "mypy (==1.15.0) ; python_version >= \"3.9\"", "mypy (>=1.14.1,<1.15.0) ; python_version == \"3.8\"", "ruff (>=0.9.0)", "ty"] tensorflow = ["graphviz", "pydot", "tensorflow"] tensorflow-testing = ["keras (<3.0)", "tensorflow"] -testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "authlib (>=1.3.2)", "fastapi", "gradio (>=4.0.0)", "httpx", "itsdangerous", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures (<16.0)", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] torch = ["safetensors[torch]", "torch"] typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] @@ -1082,21 +1086,6 @@ antlr4-python3-runtime = "==4.9.*" omegaconf = ">=2.2,<2.4" packaging = "*" -[[package]] -name = "identify" -version = "2.6.9" -description = "File identification library for Python" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "identify-2.6.9-py2.py3-none-any.whl", hash = "sha256:c98b4322da415a8e5a70ff6e51fbc2d2932c015532d77e9f8537b4ba7813b150"}, - {file = "identify-2.6.9.tar.gz", hash = "sha256:d40dfe3142a1421d8518e3d3985ef5ac42890683e32306ad614a29490abeb6bf"}, -] - -[package.extras] -license = ["ukkonen"] - [[package]] name = "idna" version = "3.7" @@ -1129,46 +1118,6 @@ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy ; platform_python_implementation != \"PyPy\"", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] -[[package]] -name = "intel-openmp" -version = "2021.4.0" -description = "Intel OpenMP* Runtime Library" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"}, - {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"}, -] - -[[package]] -name = "interegular" -version = "0.3.3" -description = "a regex intersection checker" -optional = false -python-versions = ">=3.7" -groups = ["main"] -files = [ - {file = "interegular-0.3.3-py37-none-any.whl", hash = "sha256:b0c07007d48c89d6d19f7204972d369b2a77222722e126b6aa63aa721dc3b19c"}, - {file = "interegular-0.3.3.tar.gz", hash = "sha256:d9b697b21b34884711399ba0f0376914b81899ce670032486d0d048344a76600"}, -] - -[[package]] -name = "iso3166" -version = "2.1.1" -description = "Self-contained ISO 3166-1 country definitions." -optional = false -python-versions = ">=3.6" -groups = ["main"] -files = [ - {file = "iso3166-2.1.1-py3-none-any.whl", hash = "sha256:263660b36f8471c42acd1ff673d28a3715edbce7d24b1550d0cf010f6816c47f"}, - {file = "iso3166-2.1.1.tar.gz", hash = "sha256:fcd551b8dda66b44e9f9e6d6bbbee3a1145a22447c0a556e5d0fb1ad1e491719"}, -] - [[package]] name = "jinja2" version = "3.1.4" @@ -1214,6 +1163,22 @@ files = [ [package.dependencies] jsonpointer = ">=1.9" +[[package]] +name = "jsonpath-ng" +version = "1.7.0" +description = "A final implementation of JSONPath for Python that aims to be standard compliant, including arithmetic and binary comparison operators and providing clear AST for metaprogramming." +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "jsonpath-ng-1.7.0.tar.gz", hash = "sha256:f6f5f7fd4e5ff79c785f1573b394043b39849fb2bb47bcead935d12b00beab3c"}, + {file = "jsonpath_ng-1.7.0-py2-none-any.whl", hash = "sha256:898c93fc173f0c336784a3fa63d7434297544b7198124a68f9a3ef9597b0ae6e"}, + {file = "jsonpath_ng-1.7.0-py3-none-any.whl", hash = "sha256:f3d7f9e848cba1b6da28c55b1c26ff915dc9e0b1ba7e752a53d6da8d5cbd00b6"}, +] + +[package.dependencies] +ply = "*" + [[package]] name = "jsonpointer" version = "2.4" @@ -1433,24 +1398,6 @@ files = [ pydantic = ">=1,<3" requests = ">=2,<3" -[[package]] -name = "lark" -version = "1.1.9" -description = "a modern parsing library" -optional = false -python-versions = ">=3.6" -groups = ["main"] -files = [ - {file = "lark-1.1.9-py3-none-any.whl", hash = "sha256:a0dd3a87289f8ccbb325901e4222e723e7d745dbfc1803eaf5f3d2ace19cf2db"}, - {file = "lark-1.1.9.tar.gz", hash = "sha256:15fa5236490824c2c4aba0e22d2d6d823575dcaf4cdd1848e34b6ad836240fba"}, -] - -[package.extras] -atomic-cache = ["atomicwrites"] -interegular = ["interegular (>=0.3.1,<0.4.0)"] -nearley = ["js2py"] -regex = ["regex"] - [[package]] name = "llama-index" version = "0.8.42" @@ -1661,26 +1608,6 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] -[[package]] -name = "mkl" -version = "2021.4.0" -description = "IntelĀ® oneAPI Math Kernel Library" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"}, - {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"}, - {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"}, - {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"}, -] - -[package.dependencies] -intel-openmp = "==2021.*" -tbb = "==2021.*" - [[package]] name = "mpmath" version = "1.3.0" @@ -1893,18 +1820,6 @@ plot = ["matplotlib"] tgrep = ["pyparsing"] twitter = ["twython"] -[[package]] -name = "nodeenv" -version = "1.9.1" -description = "Node.js virtual environment builder" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main"] -files = [ - {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, - {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, -] - [[package]] name = "numpy" version = "1.26.4" @@ -1952,168 +1867,224 @@ files = [ ] [[package]] -name = "nvidia-cublas-cu11" -version = "11.11.3.6" +name = "nvidia-cublas" +version = "13.0.0.19" description = "CUBLAS native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux1_x86_64.whl", hash = "sha256:39fb40e8f486dd8a2ddb8fdeefe1d5b28f5b99df01c87ab3676f057a74a5a6f3"}, - {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5ccae9e069a2c6be9af9cb5a0b0c6928c19c7915e390d15f598a1eead2a01a7a"}, - {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_x86_64.whl", hash = "sha256:60252822adea5d0b10cd990a7dc7bedf7435f30ae40083c7a624a85a43225abc"}, - {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-win_amd64.whl", hash = "sha256:6ab12b1302bef8ac1ff4414edd1c059e57f4833abef9151683fb8f4de25900be"}, + {file = "nvidia_cublas-13.0.0.19-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:381b1a0ca636fdcb6920a871e8fc89dbfd1f6157f421ed0a6f2673e14cffd3bd"}, + {file = "nvidia_cublas-13.0.0.19-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:f6723af2e8e2600a11dc384037d90d9bf93070e346c24ef2e8f9001658c99896"}, + {file = "nvidia_cublas-13.0.0.19-py3-none-win_amd64.whl", hash = "sha256:e6ecde441aaf0bb74ed538cfb3b18aa374f452aebf0162088bcb10942f7bbc33"}, ] [[package]] -name = "nvidia-cuda-cupti-cu11" -version = "11.8.87" +name = "nvidia-cuda-cupti" +version = "13.0.48" description = "CUDA profiling tools runtime libs." optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux1_x86_64.whl", hash = "sha256:0e50c707df56c75a2c0703dc6b886f3c97a22f37d6f63839f75b7418ba672a8d"}, - {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9aaa638963a8271df26b6ee0d93b781730b7acc6581ff700bd023d7934e4385e"}, - {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux2014_x86_64.whl", hash = "sha256:4191a17913a706b5098681280cd089cd7d8d3df209a6f5cb79384974a96d24f2"}, - {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-win_amd64.whl", hash = "sha256:4332d8550ad5f5b673f98d08e4e4f82030cb604c66d8d5ee919399ea01312e58"}, + {file = "nvidia_cuda_cupti-13.0.48-py3-none-manylinux_2_25_aarch64.whl", hash = "sha256:67c22627ef436afcf080b48e4ad17b3f83d9e7c0d990ad0c6c0627b01fb92ccc"}, + {file = "nvidia_cuda_cupti-13.0.48-py3-none-manylinux_2_25_x86_64.whl", hash = "sha256:417699e216b23d81bc0bbcb7032352f81b9c5372ef73c097a01abb83125a3d09"}, + {file = "nvidia_cuda_cupti-13.0.48-py3-none-win_amd64.whl", hash = "sha256:c0f0266d5674afad541888d4383bd172b7f90ff6df62df83ef9f5431a3c2c3b1"}, ] [[package]] -name = "nvidia-cuda-nvrtc-cu11" -version = "11.8.89" +name = "nvidia-cuda-nvrtc" +version = "13.0.48" description = "NVRTC native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl", hash = "sha256:1f27d67b0f72902e9065ae568b4f6268dfe49ba3ed269c9a3da99bb86d1d2008"}, - {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8ab17ed51e7c4928f7170a0551e3e3b42f89d973bd267ced9688c238b3e10aef"}, - {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a8d02f3cba345be56b1ffc3e74d8f61f02bb758dd31b0f20e12277a5a244f756"}, - {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-win_amd64.whl", hash = "sha256:e18a23a8f4064664a6f1c4a64f38c581cbebfb5935280e94a4943ea8ae3791b1"}, + {file = "nvidia_cuda_nvrtc-13.0.48-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:87e13d186905a35e7c04ad553a2abded0fba22f93b43d02e5da6f6cf73fb4d0a"}, + {file = "nvidia_cuda_nvrtc-13.0.48-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6ccf1ef1b90a0763ac7536f3c17046659d89869d76b98ac358efc2e09b348365"}, + {file = "nvidia_cuda_nvrtc-13.0.48-py3-none-win_amd64.whl", hash = "sha256:9f10c41c3822a9d44a19e9150a05c99425514b691b342c6db6729072c5b88edd"}, ] [[package]] -name = "nvidia-cuda-runtime-cu11" -version = "11.8.89" +name = "nvidia-cuda-runtime" +version = "13.0.48" description = "CUDA Runtime native Libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl", hash = "sha256:f587bd726eb2f7612cf77ce38a2c1e65cf23251ff49437f6161ce0d647f64f7c"}, - {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_aarch64.whl", hash = "sha256:e53bf160b6b660819cb6e4a9da2cc89e6aa2329858299780a2459780a2b8d012"}, - {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl", hash = "sha256:92d04069a987e1fbc9213f8376d265df0f7bb42617d44f5eda1f496acea7f2d1"}, - {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-win_amd64.whl", hash = "sha256:f60c9fdaed8065b38de8097867240efc5556a8a710007146daeb9082334a6e63"}, + {file = "nvidia_cuda_runtime-13.0.48-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b807c0bb925a307bfa667a24f24d253aef8eda3ac4be66b333f2c9d357557008"}, + {file = "nvidia_cuda_runtime-13.0.48-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b54d12087a1abff81a4cbfa6556876e3afea1fc60da2e0816da374619810c89"}, + {file = "nvidia_cuda_runtime-13.0.48-py3-none-win_amd64.whl", hash = "sha256:03e581c7584b13e42ce175c774f46e1219e9c574f27fe88c2ccc75dd3f926ed7"}, ] [[package]] -name = "nvidia-cudnn-cu11" -version = "8.7.0.84" +name = "nvidia-cudnn-cu13" +version = "9.13.0.50" description = "cuDNN runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cudnn_cu11-8.7.0.84-py3-none-manylinux1_x86_64.whl", hash = "sha256:b3e062498fbbb1c1930435a6a454c1b41c903e1e65b7063bd2b4021e8285408e"}, + {file = "nvidia_cudnn_cu13-9.13.0.50-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:33f0aa0b64230101b348648fd0693342188071d3f8a137c0cf50051c24b3584b"}, + {file = "nvidia_cudnn_cu13-9.13.0.50-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:2150b4850725d30653ec3e365f0732e3e2e3eb8633cf3bd2d3117628dea8b4f9"}, + {file = "nvidia_cudnn_cu13-9.13.0.50-py3-none-win_amd64.whl", hash = "sha256:216f6af23842823dfa84a32c3b91c9180aca9c036eebce0b8e05489c6bfc4b5c"}, ] [package.dependencies] -nvidia-cublas-cu11 = "*" +nvidia-cublas = "*" [[package]] -name = "nvidia-cufft-cu11" -version = "10.9.0.58" +name = "nvidia-cufft" +version = "12.0.0.15" description = "CUFFT native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" +files = [ + {file = "nvidia_cufft-12.0.0.15-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1885731254835797572ff075f3daf43a2a0a2801210dea26971940dae7e1a367"}, + {file = "nvidia_cufft-12.0.0.15-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9f160b1f018e80bcb0d7c0fa50564b042fa26b13edc1b1ff14b6375a9edd2812"}, + {file = "nvidia_cufft-12.0.0.15-py3-none-win_amd64.whl", hash = "sha256:ff2083e7c4f5bc063d37ec8399277e514ca97b554e069aa6f7eb7ce6d727dc7b"}, +] + +[package.dependencies] +nvidia-nvjitlink = "*" + +[[package]] +name = "nvidia-cufile" +version = "1.15.0.42" +description = "cuFile GPUDirect libraries" +optional = false +python-versions = ">=3" +groups = ["main"] +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"}, - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_aarch64.whl", hash = "sha256:34b7315104e615b230dc3c2d1861f13bff9ec465c5d3b4bb65b4986d03a1d8d4"}, - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e21037259995243cc370dd63c430d77ae9280bedb68d5b5a18226bfc92e5d748"}, - {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"}, + {file = "nvidia_cufile-1.15.0.42-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8f9813eff24d61586699c615e39817e2b4e4f642cace32733c2ab6f663a7eab"}, + {file = "nvidia_cufile-1.15.0.42-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:bced4036b5a8dbf57e4d78cd4fafefec58ad754b784a9eaa272b011896754c62"}, ] [[package]] -name = "nvidia-curand-cu11" -version = "10.3.0.86" +name = "nvidia-curand" +version = "10.4.0.35" description = "CURAND native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:ac439548c88580269a1eb6aeb602a5aed32f0dbb20809a31d9ed7d01d77f6bf5"}, - {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:64defc3016d8c1de351a764617818c2961210430f12476faee10084b269b188c"}, - {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:cd4cffbf78bb06580206b4814d5dc696d1161c902aae37b2bba00056832379e6"}, - {file = "nvidia_curand_cu11-10.3.0.86-py3-none-win_amd64.whl", hash = "sha256:8fa8365065fc3e3760d7437b08f164a6bcf8f7124f3b544d2463ded01e6bdc70"}, + {file = "nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:133df5a7509c3e292aaa2b477afd0194f06ce4ea24d714d616ff36439cee349a"}, + {file = "nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:1aee33a5da6e1db083fe2b90082def8915f30f3248d5896bcec36a579d941bfc"}, + {file = "nvidia_curand-10.4.0.35-py3-none-win_amd64.whl", hash = "sha256:65b1710aa6961d326b411e314b374290904c5ddf41dc3f766ebc3f1d7d4ca69f"}, ] [[package]] -name = "nvidia-cusolver-cu11" -version = "11.4.1.48" +name = "nvidia-cusolver" +version = "12.0.3.29" description = "CUDA solver native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux1_x86_64.whl", hash = "sha256:ca538f545645b7e6629140786d3127fe067b3d5a085bd794cde5bfe877c8926f"}, - {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1a96acb05768610bc414dbef5b25ebd2d820fc8a1e8c72097f41f53d80934d61"}, - {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea9fb1ad8c644ca9ed55af13cc39af3b7ba4c3eb5aef18471fe1fe77d94383cb"}, - {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-win_amd64.whl", hash = "sha256:7efe43b113495a64e2cf9a0b4365bd53b0a82afb2e2cf91e9f993c9ef5e69ee8"}, + {file = "nvidia_cusolver-12.0.3.29-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:3bb6e65ce0beaeafdd069b320246e8f17c1cd30ddb27a0539143a3706733a4d8"}, + {file = "nvidia_cusolver-12.0.3.29-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:6f54c2eed5edab54c224dd1852dde80ba76b2b78e6d3ce7344fef5dfc66d16ab"}, + {file = "nvidia_cusolver-12.0.3.29-py3-none-win_amd64.whl", hash = "sha256:103fa9e99d63e4be4d04e4a9cb7dfaec5d86f486bb59838cb72803cabb1690a4"}, ] [package.dependencies] -nvidia-cublas-cu11 = "*" +nvidia-cublas = "*" +nvidia-cusparse = "*" +nvidia-nvjitlink = "*" [[package]] -name = "nvidia-cusparse-cu11" -version = "11.7.5.86" +name = "nvidia-cusparse" +version = "12.6.2.49" description = "CUSPARSE native runtime libraries" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" +files = [ + {file = "nvidia_cusparse-12.6.2.49-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d3269c19283a0057fb5ebfb003ae2a10c97a28a6958f4238354826b055827c7"}, + {file = "nvidia_cusparse-12.6.2.49-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:efcf0b01e3a0827c144feff5391456b8a06e9ce63dcd51c0943e32e605251952"}, + {file = "nvidia_cusparse-12.6.2.49-py3-none-win_amd64.whl", hash = "sha256:b48237614131dedf9cd00d99ce950d8e1b2945ab9d29337fbdc1e014f0ee9830"}, +] + +[package.dependencies] +nvidia-nvjitlink = "*" + +[[package]] +name = "nvidia-cusparselt-cu13" +version = "0.8.0" +description = "NVIDIA cuSPARSELt" +optional = false +python-versions = "*" +groups = ["main"] +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:4ae709fe78d3f23f60acaba8c54b8ad556cf16ca486e0cc1aa92dca7555d2d2b"}, - {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6c7da46abee7567e619d4aa2e90a1b032cfcbd1211d429853b1a6e87514a14b2"}, - {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8d7cf1628fd8d462b5d2ba6678fae34733a48ecb80495b9c68672ec6a6dde5ef"}, - {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-win_amd64.whl", hash = "sha256:a0f6ee81cd91be606fc2f55992d06b09cd4e86d74b6ae5e8dd1631cf7f5a8706"}, + {file = "nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:400c6ed1cf6780fc6efedd64ec9f1345871767e6a1a0a552a1ea0578117ea77c"}, + {file = "nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25e30a8a7323935d4ad0340b95a0b69926eee755767e8e0b1cf8dd85b197d3fd"}, + {file = "nvidia_cusparselt_cu13-0.8.0-py3-none-win_amd64.whl", hash = "sha256:e80212ed7b1afc97102fbb2b5c82487aa73f6a0edfa6d26c5a152593e520bb8f"}, ] [[package]] -name = "nvidia-nccl-cu11" -version = "2.20.5" +name = "nvidia-nccl-cu13" +version = "2.27.7" description = "NVIDIA Collective Communication Library (NCCL) Runtime" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" +files = [ + {file = "nvidia_nccl_cu13-2.27.7-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5e3cc863e52bf9dd1e3ab1941bddb414098f489ae7342f6b3a274602303da123"}, + {file = "nvidia_nccl_cu13-2.27.7-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b28a524abd8389b76a4a3f133c76a7aaa7005e47fcaa9d9603b90103927a3f93"}, +] + +[[package]] +name = "nvidia-nvjitlink" +version = "13.0.39" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +groups = ["main"] +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_nccl_cu11-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:3619e25dfb0c8f4c554561c3459ee7dfe1250eed05e9aa4d147a75c45cc6ae0d"}, + {file = "nvidia_nvjitlink-13.0.39-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:bc3179be558329ef9687884c6faa27cdc0659bdbc642432ec8cc6cc00d182627"}, + {file = "nvidia_nvjitlink-13.0.39-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ce0d63fa5ebedf542056e7491c49feed2297c900980aa6269b6a55f478056ad7"}, + {file = "nvidia_nvjitlink-13.0.39-py3-none-win_amd64.whl", hash = "sha256:478d06d3783b1c26a9bea308b972737a9fb2bb832d2254aa51fe753713b4a583"}, ] [[package]] -name = "nvidia-nvtx-cu11" -version = "11.8.86" +name = "nvidia-nvshmem-cu13" +version = "3.3.24" +description = "NVSHMEM creates a global address space that provides efficient and scalable communication for NVIDIA GPU clusters." +optional = false +python-versions = ">=3" +groups = ["main"] +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" +files = [ + {file = "nvidia_nvshmem_cu13-3.3.24-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:28ae82a4d14b322b93409535de62df6b7b83f4f7672ca97fc89107c2d40ce2c2"}, + {file = "nvidia_nvshmem_cu13-3.3.24-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c14d09571697d2e57cb079c8daec88ab1c68cb3586532bfbd4886125a08339b7"}, +] + +[[package]] +name = "nvidia-nvtx" +version = "13.0.39" description = "NVIDIA Tools Extension" optional = false python-versions = ">=3" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:890656d8bd9b4e280231c832e1f0d03459200ba4824ddda3dcb59b1e1989b9f5"}, - {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5e84b97062eb102b45a8a9172a06cfe28b239b1635075a13d6474e91295e0468"}, - {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:979f5b2aef5da164c5c53c64c85c3dfa61b8b4704f4f963bb568bf98fa8472e8"}, - {file = "nvidia_nvtx_cu11-11.8.86-py3-none-win_amd64.whl", hash = "sha256:54031010ee38d774b2991004d88f90bbd7bbc1458a96bbc4b42662756508c252"}, + {file = "nvidia_nvtx-13.0.39-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:cc113127785c96db8a0fe715df92db9788777b4b3d1bd713d42f75969201b5ce"}, + {file = "nvidia_nvtx-13.0.39-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cddd2e08b35144f1000631c3880c9ebbcb8a2863d762e76f92d47d30ecaf87cc"}, + {file = "nvidia_nvtx-13.0.39-py3-none-win_amd64.whl", hash = "sha256:14e4e4cf8976ce9544ec5e70e39dca8a7cc62af4692f20d8dc85266709d2e641"}, ] [[package]] @@ -2158,89 +2129,101 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "outlines" -version = "0.2.1" +version = "1.2.7" description = "Probabilistic Generative Model Programming" optional = false python-versions = "<3.13,>=3.9" groups = ["main"] files = [ - {file = "outlines-0.2.1-py3-none-any.whl", hash = "sha256:bc3ff720d0a3117c70bf77236b619e02824d7ffb205ea7e8f701690ab42d3707"}, - {file = "outlines-0.2.1.tar.gz", hash = "sha256:43391e18665245eeb141b09cc53896fad610a6c587698577de196ba4758cc420"}, + {file = "outlines-1.2.7-py3-none-any.whl", hash = "sha256:5d1cb695cb14213e64e632b742090880094877440bc292c5fd4ebb4a912d8c02"}, + {file = "outlines-1.2.7.tar.gz", hash = "sha256:1b588e7a6c789deae29dc212037089f4fc9f954b1d7d23f223d5451db45bb5b7"}, ] [package.dependencies] -airportsdata = "*" cloudpickle = "*" diskcache = "*" genson = "*" -interegular = "*" -iso3166 = "*" jinja2 = "*" +jsonpath_ng = "*" jsonschema = "*" -lark = "*" -nest_asyncio = "*" -numpy = "*" -outlines_core = "0.1.26" -pre-commit = ">=4.0.1" +outlines_core = "0.2.11" +pillow = "*" pydantic = ">=2.0" -referencing = "*" -requests = "*" -torch = "*" -tqdm = "*" typing_extensions = "*" [package.extras] -exllamav2 = ["exllamav2"] -llamacpp = ["datasets", "llama-cpp-python", "numpy (<2)", "transformers"] -mlxlm = ["datasets", "mlx-lm"] +airports = ["airportsdata"] +anthropic = ["anthropic"] +countries = ["iso3166"] +dottxt = ["dottxt"] +gemini = ["google-genai"] +llamacpp = ["huggingface-hub", "llama-cpp-python", "numba"] +llguidance = ["llguidance"] +mistral = ["mistralai"] +mlxlm = ["datasets", "mlx", "mlx-lm"] +ollama = ["ollama"] openai = ["openai"] -serve = ["fastapi", "pydantic (>=2.0)", "uvicorn", "vllm (>=0.3.0)"] -test = ["accelerate", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "exllamav2", "huggingface_hub", "jax", "llama-cpp-python", "mlx-lm (>=0.19.2) ; platform_machine == \"arm64\" and sys_platform == \"darwin\"", "openai (>=1.0.0)", "pillow", "pre-commit", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "responses", "transformers"] +sglang = ["openai"] +test = ["accelerate", "airportsdata", "anthropic", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "dottxt", "flax", "google-genai", "huggingface_hub", "iso3166", "jax", "llama-cpp-python", "llguidance", "mistralai", "mkdocs_gen_files", "mlx-lm (>=0.19.2) ; platform_machine == \"arm64\" and sys_platform == \"darwin\"", "numba", "numpy (>=2.0.0,<2.2.0)", "ollama", "openai (>=1.0.0)", "pillow", "pre-commit", "pytest", "pytest-asyncio", "pytest-benchmark", "pytest-cov", "pytest-mock", "requests", "responses", "sentencepiece", "tensorflow", "tf-keras", "torch (>2.3.0)", "transformers", "xgrammar"] test-gpu = ["outlines[test]", "vllm ; sys_platform == \"linux\""] -transformers = ["accelerate", "datasets", "numpy (<2)", "transformers"] -vllm = ["numpy (<2)", "transformers", "vllm"] +tgi = ["huggingface_hub"] +transformers = ["accelerate", "datasets", "setuptools", "transformers"] +vllm = ["openai"] +xgrammar = ["xgrammar"] [[package]] name = "outlines-core" -version = "0.1.26" +version = "0.2.11" description = "Structured Text Generation in Rust" optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "outlines_core-0.1.26-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6a962a7452e7ac170fa04d405342cadae2d28fafa5b1830cef7aa610257ed32f"}, - {file = "outlines_core-0.1.26-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:15a3684fa29564da2db03934cf0097bef3e871f70d3af0ef2b52fdb886da2e09"}, - {file = "outlines_core-0.1.26-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64e01c0cfa9ba371634d7c3f6ea1862397cef98e4509fe98e3f57faa721a72d6"}, - {file = "outlines_core-0.1.26-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3c4196148e47f455f1ace78e329d5b97e531cbc406456d681592952adae7e17"}, - {file = "outlines_core-0.1.26-cp310-cp310-win32.whl", hash = "sha256:f38d290a7f6e5e12cbfcaee03269dfc0dbda49b360024b4279d1aba251fdc346"}, - {file = "outlines_core-0.1.26-cp310-cp310-win_amd64.whl", hash = "sha256:11ff56af56cb54c563b7f25d86cd9ee77f3fed825f1d4dccd9449bb1e4e89538"}, - {file = "outlines_core-0.1.26-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b6787b07b7c673fc3087d2b537719ecac8e03b10a47d032dd1926985c32885b0"}, - {file = "outlines_core-0.1.26-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e0ea28a76da31d25b6f53242bf13e1b59a0241badf82353c88f55e1cf81b128"}, - {file = "outlines_core-0.1.26-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8932044a3d9329be53a226118850638f85b4d7842f9b863d0a123f23de220cd"}, - {file = "outlines_core-0.1.26-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a84b7cd2fb6268bf990dd3d479ffb4fa0bace6f571cb85b15b6cdb44b84f5b69"}, - {file = "outlines_core-0.1.26-cp311-cp311-win32.whl", hash = "sha256:f19765c151abfc970996368080aeea6d2a19e927817fe4e2af6726e639be3de4"}, - {file = "outlines_core-0.1.26-cp311-cp311-win_amd64.whl", hash = "sha256:3f59aeccea21ed6ff3cf52102fd163f26d279821c20e5127ddd18d4ea4d0c8d2"}, - {file = "outlines_core-0.1.26-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f54633bca50055d42ea4d94ae06dcbe52d3d76a9b621b75723b1177d0d952953"}, - {file = "outlines_core-0.1.26-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9525321b48700dcaaabf60bcdc951e45f9357ba3fb3e1bfc81b662d7d4170e7c"}, - {file = "outlines_core-0.1.26-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00f409f72c11f6ffadb57066950dd384d5388015028c1a1a615c9a64988dae3e"}, - {file = "outlines_core-0.1.26-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e86a1bb46adc5cbf6dfd7a7fe4105e0e2a4c6e041732a053126b41c521a1f223"}, - {file = "outlines_core-0.1.26-cp312-cp312-win32.whl", hash = "sha256:19f462f6b00935708677ad27cb4df55e0e17f6ffe713ab750f5f2683b090f95d"}, - {file = "outlines_core-0.1.26-cp312-cp312-win_amd64.whl", hash = "sha256:9b36bff12779e58883747116893a17b3551bbd10865878b951b03a44d112229a"}, - {file = "outlines_core-0.1.26-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:7b7849cf40028319ebb9d8ba0fe4c590ef5888eebe524a81b3af30aaa06ea21c"}, - {file = "outlines_core-0.1.26-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2f8641aab4a6bd84516907492ce82099503129da01b3c29c1dc9ad50320bae77"}, - {file = "outlines_core-0.1.26-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bba56604efdbc5932c7a8a88c2b8b0d0c740ab883b0012fb5464a9736796802b"}, - {file = "outlines_core-0.1.26-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc8c87d89bd267356f8149c9066cbb98970425ec162997fbf195c3f1feb7009"}, - {file = "outlines_core-0.1.26-cp39-cp39-win32.whl", hash = "sha256:9d792a43ed9d8a4e1b38f4d83fe99db442d57aad4404c2edf98b710892eda47e"}, - {file = "outlines_core-0.1.26-cp39-cp39-win_amd64.whl", hash = "sha256:ad8564ecd7b64bcb840596c5049ff1c1a96346de494302ffcc0f2b188c15675e"}, - {file = "outlines_core-0.1.26.tar.gz", hash = "sha256:481c4301341e77cc8f1832d616784adb4d461b4fec65878e7c0d2cba7163a189"}, + {file = "outlines_core-0.2.11-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:89d79d8454b321f60047541a896d410ca9db631d241960266c4fe839cf5cd1b1"}, + {file = "outlines_core-0.2.11-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:44d581893f8644da02db7be11887229a40d26077cbdd22072ad1ed1db0ad0b2d"}, + {file = "outlines_core-0.2.11-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:e88b7f717915d91136d915adb65c2603d2aa6457ec3fc336884bdb0b28d3188a"}, + {file = "outlines_core-0.2.11-cp310-cp310-macosx_15_0_x86_64.whl", hash = "sha256:8c7ecdba2162e9b30b837251387c26b1a23f80f58d01d02e7600e4b1962c5333"}, + {file = "outlines_core-0.2.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd5fcefd221c10c95ce74838869450c6fdbbe2f581f0ba27e57a95232bd88c3a"}, + {file = "outlines_core-0.2.11-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a3c7774b112106f3afe931c65637fb3e0725d43707ceff1d34d6899cf0fa8200"}, + {file = "outlines_core-0.2.11-cp310-cp310-win32.whl", hash = "sha256:1cfbb4cdcf34be5c6b08d279928b2b1050ed4c5e96e6e8405e3e624305c6799e"}, + {file = "outlines_core-0.2.11-cp310-cp310-win_amd64.whl", hash = "sha256:670c1c1fca26fb5c7f00dbb11d1f81cca4204863c3dfdeee82017a6846397bf9"}, + {file = "outlines_core-0.2.11-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:e96b8d0b56afcd3b86f4efca466c578f3725da1148ef62423249c92993841762"}, + {file = "outlines_core-0.2.11-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:d108ee8cd5e2fe71c2b0720b949d004901fec8bdb64bcd0c01b8abe38ab7ae1c"}, + {file = "outlines_core-0.2.11-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:ebf42ab5b7ae38235d3c3333b5cacd6e91449b87b8a48a85094ea28ad9de9878"}, + {file = "outlines_core-0.2.11-cp311-cp311-macosx_15_0_x86_64.whl", hash = "sha256:fd4305ff8418d14059d95dc3276ca96ba1b5aa499908e1af8bb3c7207aa7ac68"}, + {file = "outlines_core-0.2.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:132605b8dd1e3d1369da6a851992dd357f6376068292f6bd47caa7a28b794d19"}, + {file = "outlines_core-0.2.11-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:b31d5fc83b78aad282dd667b8d6e684614481fe08a7609ce0ce45dee64cd2991"}, + {file = "outlines_core-0.2.11-cp311-cp311-win32.whl", hash = "sha256:3e316a79f3ecfa12c17746edebcbd66538ee22a43986982f6b96166fb94ee6b1"}, + {file = "outlines_core-0.2.11-cp311-cp311-win_amd64.whl", hash = "sha256:c260a042b5854ff69291649cfd112066e6bab0dad0bb9cec8a6c3705ef3a59cd"}, + {file = "outlines_core-0.2.11-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:4a9db4872bae083631d720994f4cee603bce0536b33d5a988814576863b657cf"}, + {file = "outlines_core-0.2.11-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:8359a45c59f6a8f2eb717245806501a59044c75f6ea8bd08faaa131cc8cdec45"}, + {file = "outlines_core-0.2.11-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:5d26a46591377340e0b870b8a96ea8341058341a62ee0bded9098e0c88dd24f4"}, + {file = "outlines_core-0.2.11-cp312-cp312-macosx_15_0_x86_64.whl", hash = "sha256:ae460a34675fb11d92a5c605a480fbae4cd6c1b2d11b3698da64a7fcaba64dcf"}, + {file = "outlines_core-0.2.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86df9740368866295077346440d911df4972da2b3f1f54b8125e6f329e8a8891"}, + {file = "outlines_core-0.2.11-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:96ce4dd78f106799be4a0a5795cefd1352806162973756a4b6fce4bb6eddd7e4"}, + {file = "outlines_core-0.2.11-cp312-cp312-win32.whl", hash = "sha256:358db161cce3650ba822e118dcf0a1efa571c7deb4864ab9d64ca2c9cca7425d"}, + {file = "outlines_core-0.2.11-cp312-cp312-win_amd64.whl", hash = "sha256:231f9d20d2630c70665345821780d7808b29539620a75c99f65113b518c51032"}, + {file = "outlines_core-0.2.11-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:0907ff25d79edbf8650268028de85a1b41b38696f147059e007da4626a1031f1"}, + {file = "outlines_core-0.2.11-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:f4146da5957f97550eebd19e80635e48035886fd10f03e9735cc111caaf74e93"}, + {file = "outlines_core-0.2.11-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:8776a6db8843187c90e4c54bf94510cda68ca7a11c9b48d90587179fd3224bc2"}, + {file = "outlines_core-0.2.11-cp313-cp313-macosx_15_0_x86_64.whl", hash = "sha256:d44f38a89028bed50494420b47d08ebefa78f34b129e2ea6383c801e5ba62c26"}, + {file = "outlines_core-0.2.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:daef6eaaf8c3403455ab5cbf265cb5c6838df571eb7c4b23cddac19cfc701726"}, + {file = "outlines_core-0.2.11-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:76b2512417c68863f8f227a080e87f755682dfd895e23b021121318be11da579"}, + {file = "outlines_core-0.2.11-cp313-cp313-win32.whl", hash = "sha256:707eeb3d190485f55a27ad9a6ad70df86688fa2bf405894a118283be7f59bd55"}, + {file = "outlines_core-0.2.11-cp313-cp313-win_amd64.whl", hash = "sha256:ad46698564c9b13cbfbc744067de12be73bd740d7b2de20ec6b979ad7511f7c9"}, + {file = "outlines_core-0.2.11-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:defe30707d2c7718e6572b222028de1973c150ce3ec29ecf3f16dc5309a313ee"}, + {file = "outlines_core-0.2.11-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:576fefbf50ff09ad3b42e3d5bd344d8668fc650188fcc06b9a0356fdc6a89b84"}, + {file = "outlines_core-0.2.11-cp39-cp39-macosx_15_0_arm64.whl", hash = "sha256:63a2f1d54929421ac8af715921a67b6da1f52cfe7c3ca6cddb194268bbc99140"}, + {file = "outlines_core-0.2.11-cp39-cp39-macosx_15_0_x86_64.whl", hash = "sha256:90f43cc83a109bfe72f4862d34b1d29e28c76477bbdf58b091ec34aa7f795ff1"}, + {file = "outlines_core-0.2.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dae17b09f6f08d01fa0c228ab282197379ea10aa46b27f40b80c2014331af217"}, + {file = "outlines_core-0.2.11-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:3a9db6831346ec4e683022c05b45403ec1c5f4a3fe52a2a7ebcc1d7d9dc3a5fb"}, + {file = "outlines_core-0.2.11-cp39-cp39-win32.whl", hash = "sha256:a41c2d518367a4628bca3e4f509b268642c2cdec70b631c64f07d5158d029e0d"}, + {file = "outlines_core-0.2.11-cp39-cp39-win_amd64.whl", hash = "sha256:bc173be0f5c089c23fdb1df0dc4b9075140be2f4928748fefc58ea46a2bd36bd"}, + {file = "outlines_core-0.2.11.tar.gz", hash = "sha256:dfce56f717ff5083e54cbcfdb66cad243365437fccbb5509adaa7e31e030f1d8"}, ] -[package.dependencies] -interegular = "*" -jsonschema = "*" - [package.extras] -test = ["accelerate", "asv", "beartype (<0.16.0)", "coverage[toml] (>=5.1)", "datasets", "diff-cover", "huggingface_hub", "numpy", "pillow", "pre-commit", "psutil", "pydantic", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "scipy", "setuptools-rust", "torch", "transformers"] +test = ["asv", "coverage[toml] (>=5.1)", "diff-cover", "maturin[patchelf]", "numba", "numpy", "pre-commit", "psutil", "pydantic", "pytest", "pytest-benchmark", "pytest-cov", "pytest-mock", "scipy", "torch"] [[package]] name = "packaging" @@ -2416,41 +2399,17 @@ typing = ["typing-extensions ; python_version < \"3.10\""] xmp = ["defusedxml"] [[package]] -name = "platformdirs" -version = "4.3.6" -description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, - {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, -] - -[package.extras] -docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2.4)"] -test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.2)", "pytest-cov (>=5)", "pytest-mock (>=3.14)"] -type = ["mypy (>=1.11.2)"] - -[[package]] -name = "pre-commit" -version = "4.1.0" -description = "A framework for managing and maintaining multi-language pre-commit hooks." +name = "ply" +version = "3.11" +description = "Python Lex & Yacc" optional = false -python-versions = ">=3.9" +python-versions = "*" groups = ["main"] files = [ - {file = "pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b"}, - {file = "pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4"}, + {file = "ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce"}, + {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"}, ] -[package.dependencies] -cfgv = ">=2.0.0" -identify = ">=1.0.0" -nodeenv = ">=0.11.1" -pyyaml = ">=5.1" -virtualenv = ">=20.10.0" - [[package]] name = "protobuf" version = "5.28.3" @@ -3586,33 +3545,21 @@ resolved_reference = "eae104809fe492eb2c2750329572a5ee0d4587d9" [[package]] name = "sympy" -version = "1.12.1" +version = "1.14.0" description = "Computer algebra system (CAS) in Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"}, - {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"}, + {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"}, + {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"}, ] [package.dependencies] -mpmath = ">=1.1.0,<1.4.0" +mpmath = ">=1.1.0,<1.4" -[[package]] -name = "tbb" -version = "2021.12.0" -description = "IntelĀ® oneAPI Threading Building Blocks (oneTBB)" -optional = false -python-versions = "*" -groups = ["main"] -markers = "platform_system == \"Windows\"" -files = [ - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"}, - {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"}, - {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"}, - {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"}, -] +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] [[package]] name = "tenacity" @@ -3697,85 +3644,132 @@ blobfile = ["blobfile (>=2)"] [[package]] name = "tokenizers" -version = "0.21.1" +version = "0.22.1" description = "" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41"}, - {file = "tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c"}, - {file = "tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a"}, - {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf"}, - {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6"}, - {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d"}, - {file = "tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f"}, - {file = "tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3"}, - {file = "tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382"}, - {file = "tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab"}, + {file = "tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73"}, + {file = "tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f"}, + {file = "tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4"}, + {file = "tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879"}, + {file = "tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446"}, + {file = "tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a"}, + {file = "tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390"}, + {file = "tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82"}, + {file = "tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138"}, + {file = "tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9"}, ] [package.dependencies] -huggingface-hub = ">=0.16.4,<1.0" +huggingface-hub = ">=0.16.4,<2.0" [package.extras] dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] -testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "pytest-asyncio", "requests", "ruff"] [[package]] name = "torch" -version = "2.3.1+cu118" +version = "2.0.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" groups = ["main"] +markers = "python_version < \"3.11\"" files = [ - {file = "torch-2.3.1+cu118-cp310-cp310-linux_x86_64.whl", hash = "sha256:fb4c9249b29f58e066ef1d259410de49a2c23c8727883f69065f61244bb780b9"}, - {file = "torch-2.3.1+cu118-cp310-cp310-win_amd64.whl", hash = "sha256:c8248eb98425573e496a7ee9d77b2329bb2ef70e3af7eb51fad5438a12b30b8e"}, - {file = "torch-2.3.1+cu118-cp311-cp311-linux_x86_64.whl", hash = "sha256:5b0d531814886573cbe8c8ca91d17676f96bbaa33b569dd37ea235f124314e97"}, - {file = "torch-2.3.1+cu118-cp311-cp311-win_amd64.whl", hash = "sha256:a697df4337d6f18a204b7603c06bec9c81ed393ceae71432c4a4e2902bc20270"}, - {file = "torch-2.3.1+cu118-cp312-cp312-linux_x86_64.whl", hash = "sha256:6c03ff41879674cbd661b598cec80fb5e6f7faa225624732a2a197b5471dbd38"}, - {file = "torch-2.3.1+cu118-cp312-cp312-win_amd64.whl", hash = "sha256:f44c7b64d990a6b1a382d1cd63c359806153974e7db8d16f6780645a8a9c9fe0"}, - {file = "torch-2.3.1+cu118-cp38-cp38-linux_x86_64.whl", hash = "sha256:5669916fed356a6a4034aeaf9c78184bd1b4467b06d75d95f6540dd16969ad31"}, - {file = "torch-2.3.1+cu118-cp38-cp38-win_amd64.whl", hash = "sha256:2345d7a880c29123744d74719ebbaf04aba170d45dd8c9a86e876e81493408fc"}, - {file = "torch-2.3.1+cu118-cp39-cp39-linux_x86_64.whl", hash = "sha256:815090508144030b54b8c34af9abe45168332d513b3e0e35971afbca5813c2ed"}, - {file = "torch-2.3.1+cu118-cp39-cp39-win_amd64.whl", hash = "sha256:78c9e0206f40a9f12c0369b2737d78d1998244238384286fd5492b39299059a7"}, + {file = "torch-2.0.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:359bfaad94d1cda02ab775dc1cc386d585712329bb47b8741607ef6ef4950747"}, + {file = "torch-2.0.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b6019b1de4978e96daa21d6a3ebb41e88a0b474898fe251fd96189587408873e"}, + {file = "torch-2.0.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:0882243755ff28895e8e6dc6bc26ebcf5aa0911ed81b2a12f241fc4b09075b13"}, + {file = "torch-2.0.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:423e0ae257b756bb45a4b49072046772d1ad0c592265c5080070e0767da4e490"}, ] [package.dependencies] filelock = "*" -fsspec = "*" jinja2 = "*" -mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""} networkx = "*" -nvidia-cublas-cu11 = {version = "11.11.3.6", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-cupti-cu11 = {version = "11.8.87", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-nvrtc-cu11 = {version = "11.8.89", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cuda-runtime-cu11 = {version = "11.8.89", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cudnn-cu11 = {version = "8.7.0.84", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-curand-cu11 = {version = "10.3.0.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusolver-cu11 = {version = "11.4.1.48", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-cusparse-cu11 = {version = "11.7.5.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nccl-cu11 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} -nvidia-nvtx-cu11 = {version = "11.8.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} -typing-extensions = ">=4.8.0" +typing-extensions = "*" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] + +[package.source] +type = "legacy" +url = "https://download.pytorch.org/whl/cu130" +reference = "pytorch" + +[[package]] +name = "torch" +version = "2.9.0+cu130" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version >= \"3.11\"" +files = [ + {file = "torch-2.9.0+cu130-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:46004a346db6bfd69ecd2e42dce48e0fce2ad0e5a910f8203db5206f5515387e"}, + {file = "torch-2.9.0+cu130-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:856c15eff328534bf6be54349b1684f524abbd521c704ab8b3e077de87810966"}, + {file = "torch-2.9.0+cu130-cp310-cp310-win_amd64.whl", hash = "sha256:ec07f6494f5c9925bd2e3d76d05a7d50464ddb6295998084073469f50f9e80ef"}, + {file = "torch-2.9.0+cu130-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c7e0205f110b6b057820d4d2128d97bfb536526d35c48969935bb27a9ee9218"}, + {file = "torch-2.9.0+cu130-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e748b84e700634fbf08f62319e18e228c861895fd41ea7c73043c81d6d0968c4"}, + {file = "torch-2.9.0+cu130-cp311-cp311-win_amd64.whl", hash = "sha256:906d163dcce05cf095c3ce076f872c4469175e00e24399900bf9708c54a44cce"}, + {file = "torch-2.9.0+cu130-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:3aef05b6247261f4a7c440be9a052c4be36c673c6721920181a4ac9a66d6c2a2"}, + {file = "torch-2.9.0+cu130-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:cc241ffb20428f6a44c299ca06b934445606cf1fa48f3b68ef3af0a04c86bc3b"}, + {file = "torch-2.9.0+cu130-cp312-cp312-win_amd64.whl", hash = "sha256:b9979a7c0a1c9544a857fc2390ebc89938f116eaaf6a359a0d46597402ca51da"}, + {file = "torch-2.9.0+cu130-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ecf3d24bd4c0e6e425bd778a6de99b52279e0021a60d7eb11ab0c2d669f3f9b0"}, + {file = "torch-2.9.0+cu130-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:92a92db6cde38d05220c1f7de91ceacff020546386c5b7a0a268dcaae17b5c18"}, + {file = "torch-2.9.0+cu130-cp313-cp313-win_amd64.whl", hash = "sha256:7d83c2439d01aefc8ffea61cae2b8288cded5a90f60e034bc9830a7dc8029d84"}, + {file = "torch-2.9.0+cu130-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e5a45f68dd2c93e18d62d8ed5d2ba4243865d32a049b654ad3ee6527bda5b437"}, + {file = "torch-2.9.0+cu130-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:bd7331780bd444077792b699a535b20a7f1275e3bca99f6bec3c88d324bb0bee"}, + {file = "torch-2.9.0+cu130-cp313-cp313t-win_amd64.whl", hash = "sha256:5899d5becbec8ecf33edaadc0cfed6a26cf5143ae63ce138988eeb8081b45d81"}, + {file = "torch-2.9.0+cu130-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:cb0db232eb9edaad9d2ae4e18f9f0a7763ff5c1774bacd2d6eb4a92a8ba28678"}, + {file = "torch-2.9.0+cu130-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:a656d92ec2c8305a00b061f0cac1da4df54bf491fd937e10754c76518a5ce87b"}, + {file = "torch-2.9.0+cu130-cp314-cp314-win_amd64.whl", hash = "sha256:3c9c96b4168020e91d90756070a793af1ff511cab8090ea487acd12b7419d861"}, + {file = "torch-2.9.0+cu130-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:6fb83834a825d4dfe6cd55cc2b370337ab369110ead6aecda98dcefacc8f3b24"}, + {file = "torch-2.9.0+cu130-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:f3f3cce8e6c13887bedf0354de3a2f4ca8989e9c3d9cb8dc3bc77f7eddf6ea97"}, + {file = "torch-2.9.0+cu130-cp314-cp314t-win_amd64.whl", hash = "sha256:cdc189be3f216661353486e678199d4102f281804ebddd1c4d0f91b10a30963b"}, +] + +[package.dependencies] +filelock = "*" +fsspec = ">=0.8.5" +jinja2 = "*" +networkx = ">=2.5.1" +nvidia-cublas = {version = "13.0.0.19", markers = "platform_system == \"Linux\""} +nvidia-cuda-cupti = {version = "13.0.48", markers = "platform_system == \"Linux\""} +nvidia-cuda-nvrtc = {version = "13.0.48", markers = "platform_system == \"Linux\""} +nvidia-cuda-runtime = {version = "13.0.48", markers = "platform_system == \"Linux\""} +nvidia-cudnn-cu13 = {version = "9.13.0.50", markers = "platform_system == \"Linux\""} +nvidia-cufft = {version = "12.0.0.15", markers = "platform_system == \"Linux\""} +nvidia-cufile = {version = "1.15.0.42", markers = "platform_system == \"Linux\""} +nvidia-curand = {version = "10.4.0.35", markers = "platform_system == \"Linux\""} +nvidia-cusolver = {version = "12.0.3.29", markers = "platform_system == \"Linux\""} +nvidia-cusparse = {version = "12.6.2.49", markers = "platform_system == \"Linux\""} +nvidia-cusparselt-cu13 = {version = "0.8.0", markers = "platform_system == \"Linux\""} +nvidia-nccl-cu13 = {version = "2.27.7", markers = "platform_system == \"Linux\""} +nvidia-nvjitlink = {version = "13.0.39", markers = "platform_system == \"Linux\""} +nvidia-nvshmem-cu13 = {version = "3.3.24", markers = "platform_system == \"Linux\""} +nvidia-nvtx = {version = "13.0.39", markers = "platform_system == \"Linux\""} +setuptools = {version = "*", markers = "python_version >= \"3.12\""} +sympy = ">=1.13.3" +triton = {version = "3.5.0", markers = "platform_system == \"Linux\""} +typing-extensions = ">=4.10.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -optree = ["optree (>=0.9.1)"] +optree = ["optree (>=0.13.0)"] +pyyaml = ["pyyaml"] [package.source] type = "legacy" -url = "https://download.pytorch.org/whl/cu118" +url = "https://download.pytorch.org/whl/cu130" reference = "pytorch" [[package]] @@ -3801,98 +3795,108 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.49.0" +version = "4.57.1" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.9.0" groups = ["main"] files = [ - {file = "transformers-4.49.0-py3-none-any.whl", hash = "sha256:6b4fded1c5fee04d384b1014495b4235a2b53c87503d7d592423c06128cbbe03"}, - {file = "transformers-4.49.0.tar.gz", hash = "sha256:7e40e640b5b8dc3f48743f5f5adbdce3660c82baafbd3afdfc04143cdbd2089e"}, + {file = "transformers-4.57.1-py3-none-any.whl", hash = "sha256:b10d05da8fa67dc41644dbbf9bc45a44cb86ae33da6f9295f5fbf5b7890bd267"}, + {file = "transformers-4.57.1.tar.gz", hash = "sha256:f06c837959196c75039809636cd964b959f6604b75b8eeec6fdfc0440b89cc55"}, ] [package.dependencies] filelock = "*" -huggingface-hub = ">=0.26.0,<1.0" +huggingface-hub = ">=0.34.0,<1.0" numpy = ">=1.17" packaging = ">=20.0" pyyaml = ">=5.1" regex = "!=2019.12.17" requests = "*" -safetensors = ">=0.4.1" -tokenizers = ">=0.21,<0.22" +safetensors = ">=0.4.3" +tokenizers = ">=0.22.0,<=0.23.0" tqdm = ">=4.27" [package.extras] accelerate = ["accelerate (>=0.26.0)"] -agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=2.0)"] -all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision"] +all = ["Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "codecarbon (>=2.8.1)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "jinja2 (>=3.1.0)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<=0.9)", "librosa", "mistral-common[opencv] (>=1.6.3)", "num2words", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.22.0,<=0.23.0)", "torch (>=2.2)", "torchaudio", "torchvision"] audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] benchmark = ["optimum-benchmark (>=0.3.0)"] +chat-template = ["jinja2 (>=3.1.0)"] codecarbon = ["codecarbon (>=2.8.1)"] deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fastapi", "libcst", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "openai (>=1.98.0)", "optuna", "parameterized (>=0.9)", "protobuf", "psutil", "pydantic (>=2)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.13.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "tensorboard", "timeout-decorator", "torch (>=2.2)", "uvicorn"] +dev = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "av", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fastapi", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "jinja2 (>=3.1.0)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "kernels (>=0.6.1,<=0.9)", "libcst", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "num2words", "onnxconverter-common", "openai (>=1.98.0)", "optax (>=0.0.8,<=0.1.4)", "optuna", "pandas (<2.3.0)", "parameterized (>=0.9)", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.13.1)", "ruff (==0.13.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.22.0,<=0.23.0)", "torch (>=2.2)", "torch (>=2.2)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)", "uvicorn"] +dev-tensorflow = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fastapi", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "onnxconverter-common", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "openai (>=1.98.0)", "pandas (<2.3.0)", "parameterized (>=0.9)", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.13.1)", "ruff (==0.13.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "tf2onnx", "timeout-decorator", "tokenizers (>=0.22.0,<=0.23.0)", "torch (>=2.2)", "urllib3 (<2.0.0)", "uvicorn"] +dev-torch = ["GitPython (<3.1.19)", "GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (>=2.8.1)", "cookiecutter (==1.7.3)", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fastapi", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "kenlm", "kernels (>=0.6.1,<=0.9)", "libcst", "libcst", "librosa", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "num2words", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "openai (>=1.98.0)", "optuna", "pandas (<2.3.0)", "parameterized (>=0.9)", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic (>=2)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.13.1)", "ruff (==0.13.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (!=1.0.18,<=1.0.19)", "tokenizers (>=0.22.0,<=0.23.0)", "torch (>=2.2)", "torch (>=2.2)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)", "urllib3 (<2.0.0)", "uvicorn"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] -integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] -ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +hf-xet = ["hf_xet"] +hub-kernels = ["kernels (>=0.6.1,<=0.9)"] +integrations = ["kernels (>=0.6.1,<=0.9)", "optuna", "ray[tune] (>=2.7.0)"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict_core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic_lite (>=1.0.7)"] +mistral-common = ["mistral-common[opencv] (>=1.6.3)"] modelcreation = ["cookiecutter (==1.7.3)"] natten = ["natten (>=0.14.6,<0.15.0)"] +num2words = ["num2words"] onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +open-telemetry = ["opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk"] optuna = ["optuna"] -quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"] +quality = ["GitPython (<3.1.19)", "datasets (>=2.15.0)", "libcst", "pandas (<2.3.0)", "rich", "ruff (==0.13.1)", "urllib3 (<2.0.0)"] ray = ["ray[tune] (>=2.7.0)"] -retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] -ruff = ["ruff (==0.5.1)"] +retrieval = ["datasets (>=2.15.0)", "faiss-cpu"] +ruff = ["ruff (==0.13.1)"] sagemaker = ["sagemaker (>=2.31.0)"] sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] -serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +serving = ["accelerate (>=0.26.0)", "fastapi", "openai (>=1.98.0)", "pydantic (>=2)", "starlette", "torch (>=2.2)", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-asyncio", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (>=2.15.0)", "datasets (>=2.15.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fastapi", "libcst", "mistral-common[opencv] (>=1.6.3)", "nltk (<=3.8.1)", "openai (>=1.98.0)", "parameterized (>=0.9)", "psutil", "pydantic (>=2)", "pydantic (>=2)", "pytest (>=7.2.0)", "pytest-asyncio", "pytest-order", "pytest-rerunfailures (<16.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.13.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "starlette", "tensorboard", "timeout-decorator", "torch (>=2.2)", "uvicorn"] tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] tiktoken = ["blobfile", "tiktoken"] -timm = ["timm (<=1.0.11)"] -tokenizers = ["tokenizers (>=0.21,<0.22)"] -torch = ["accelerate (>=0.26.0)", "torch (>=2.0)"] +timm = ["timm (!=1.0.18,<=1.0.19)"] +tokenizers = ["tokenizers (>=0.22.0,<=0.23.0)"] +torch = ["accelerate (>=0.26.0)", "torch (>=2.2)"] torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] -torchhub = ["filelock", "huggingface-hub (>=0.26.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch (>=2.0)", "tqdm (>=4.27)"] +torchhub = ["filelock", "huggingface-hub (>=0.34.0,<1.0)", "importlib_metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.22.0,<=0.23.0)", "torch (>=2.2)", "tqdm (>=4.27)"] video = ["av"] vision = ["Pillow (>=10.0.1,<=15.0)"] [[package]] name = "triton" -version = "2.3.1" +version = "3.5.0" description = "A language and compiler for custom Deep Learning operations" optional = false -python-versions = "*" +python-versions = "<3.15,>=3.10" groups = ["main"] -markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version <= \"3.11\"" +markers = "python_version >= \"3.11\" and platform_system == \"Linux\"" files = [ - {file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"}, - {file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"}, - {file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"}, - {file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"}, - {file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"}, - {file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"}, + {file = "triton-3.5.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6f90de6a6566bb619b4c0adc9855729e1b1b5e26533fca1bf6206e96b6d277a3"}, + {file = "triton-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d5d3b3d480debf24eaa739623c9a42446b0b77f95593d30eb1f64cd2278cc1f0"}, + {file = "triton-3.5.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8457b22148defefdcb7fa8144b05ce211b9faefad650a1ce85b23df488d5549c"}, + {file = "triton-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f34bfa21c5b3a203c0f0eab28dcc1e49bd1f67d22724e77fb6665a659200a4ec"}, + {file = "triton-3.5.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7da21fccceafc163e3a5e857abe34351ef76345af06cabf9637a914742671f0b"}, + {file = "triton-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9e71db82261c4ffa3921cd050cd5faa18322d2d405c30eb56084afaff3b0833"}, + {file = "triton-3.5.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:188da5b81fa2f8322c27fec1627703eac24cb9bb7ab0dfbe9925973bc1b070d3"}, + {file = "triton-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e6bb9aa5519c084a333acdba443789e50012a4b851cd486c54f0b8dc2a8d3a12"}, + {file = "triton-3.5.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03127d9b33aaf979c856676b394bc059ec1d68cb6da68ae03f62dd8ad77a04ae"}, + {file = "triton-3.5.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c83f2343e1a220a716c7b3ab9fccfcbe3ad4020d189549200e2d2e8d5868bed9"}, + {file = "triton-3.5.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:468936651d383f4a6d10068d34a627505e13af55be5d002b9f27b987e7a5f0ac"}, + {file = "triton-3.5.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da0fa67ccd76c3dcfb0bffe1b1c57c685136a6bd33d141c24d9655d4185b1289"}, + {file = "triton-3.5.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c7ceef21410229ac23173a28eee5cfc0e37c1dfdb8b4bc11ecda2e3ecec7c686"}, + {file = "triton-3.5.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:317fe477ea8fd4524a6a8c499fb0a36984a56d0b75bf9c9cb6133a1c56d5a6e7"}, ] -[package.dependencies] -filelock = "*" - [package.extras] -build = ["cmake (>=3.20)", "lit"] -tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] -tutorials = ["matplotlib", "pandas", "tabulate", "torch"] +build = ["cmake (>=3.20,<4.0)", "lit"] +tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typing-extensions" @@ -3975,27 +3979,6 @@ brotli = ["brotli (==1.0.9) ; os_name != \"nt\" and python_version < \"3\" and p secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress ; python_version == \"2.7\"", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] -[[package]] -name = "virtualenv" -version = "20.29.3" -description = "Virtual Python Environment builder" -optional = false -python-versions = ">=3.8" -groups = ["main"] -files = [ - {file = "virtualenv-20.29.3-py3-none-any.whl", hash = "sha256:3e3d00f5807e83b234dfb6122bf37cfadf4be216c53a49ac059d02414f819170"}, - {file = "virtualenv-20.29.3.tar.gz", hash = "sha256:95e39403fcf3940ac45bc717597dba16110b74506131845d9b687d5e73d947ac"}, -] - -[package.dependencies] -distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" -platformdirs = ">=3.9.1,<5" - -[package.extras] -docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] - [[package]] name = "xxhash" version = "3.5.0" @@ -4253,4 +4236,4 @@ test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.funct [metadata] lock-version = "2.1" python-versions = ">=3.9,<3.13" -content-hash = "5a469023bf886c3ba9aea5ae6e2594090465413a6395da2eabfb7550870c0c30" +content-hash = "a5afd2ba5fcf06c0f5bc8c7aaa5b0bdc4e3a8b2f4adda494edb30870b4cf76bb" diff --git a/pyproject.toml b/pyproject.toml index 0ffe0044..ed1436b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,13 +8,13 @@ packages = [{include = "align_system"}] [[tool.poetry.source]] name = "pytorch" -url = "https://download.pytorch.org/whl/cu118" +url = "https://download.pytorch.org/whl/cu130" priority = "supplemental" [tool.poetry.dependencies] -python = ">=3.9,<3.13" +python = ">=3.10,<3.13" torch = { version = "^2.0.1", source = "pytorch" } -transformers = "^4.49.0" +transformers = "^4.57.1" llama-index = "^0.8.21" langchain = "^0.0.308" sentence-transformers = "^2.2.2" @@ -28,12 +28,13 @@ rich = "^13.6.0" rouge-score = "^0.1.2" swagger-client = {git = "https://github.com/NextCenturyCorporation/itm-evaluation-client.git", rev = "0.5.0"} hydra-core = "^1.3.2" -outlines = "^0.2.1" -setuptools = "^70.1.1" +outlines = "^1.2.7" +setuptools = "^77.0.3" sentencepiece = "^0.2.0" protobuf = "^5.28.3" datasets = "^3.3.2" ubelt = "1.3.6" +vllm = "^0.11.1" [tool.poetry.scripts] run_align_system = 'align_system.cli.run_align_system:main'