From 4fa5a7e003be2c7b3ffc27b96f7151a900c6bc86 Mon Sep 17 00:00:00 2001 From: WeetHet Date: Thu, 19 Sep 2024 13:19:35 +0300 Subject: [PATCH] chore: typing pass --- .gitignore | 2 +- pyrightconfig.json | 3 ++ verified_cogen/args.py | 49 ++++++++++++++++++-- verified_cogen/experiments/use_houdini.py | 32 ++++++++++--- verified_cogen/llm/llm.py | 21 +++++---- verified_cogen/llm/prompts.py | 8 ++-- verified_cogen/main.py | 21 +++++---- verified_cogen/runners/languages/dafny.py | 2 +- verified_cogen/runners/languages/language.py | 12 ++--- verified_cogen/runners/languages/nagini.py | 2 +- verified_cogen/tools/__init__.py | 10 ++-- 11 files changed, 115 insertions(+), 47 deletions(-) create mode 100644 pyrightconfig.json diff --git a/.gitignore b/.gitignore index c756e5e..85a8bb1 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ data llm-generated .envrc **/.DS_Store -pyrightconfig.json benches/DafnyBench/ benches/HumanEval-Dafny-Mini/ break_assert.rs @@ -13,6 +12,7 @@ break_assert.rs log .ruff_cache .vscode +.zed run.sh /dist/ **/.pytest_cache diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 0000000..0102dcd --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "typeCheckingMode": "strict" +} diff --git a/verified_cogen/args.py b/verified_cogen/args.py index 29d57b9..64ba3fe 100644 --- a/verified_cogen/args.py +++ b/verified_cogen/args.py @@ -2,6 +2,47 @@ import os from verified_cogen.tools.modes import VALID_MODES +from typing import no_type_check, Optional + + +class ProgramArgs: + input: Optional[str] + dir: Optional[str] + runs: int + insert_conditions_mode: str + bench_type: str + temperature: int + shell: str + verifier_command: str + verifier_timeout: int + prompts_directory: str + grazie_token: str + llm_profile: str + tries: int + retries: int + output_style: str + filter_by_ext: Optional[str] + log_tries: Optional[str] + + @no_type_check + def __init__(self, args): + self.input = args.input + self.dir = args.dir + self.runs = args.runs + self.insert_conditions_mode = args.insert_conditions_mode + self.bench_type = args.bench_type + self.temperature = args.temperature + self.shell = args.shell + self.verifier_command = args.verifier_command + self.verifier_timeout = args.verifier_timeout + self.prompts_directory = args.prompts_directory + self.grazie_token = args.grazie_token + self.llm_profile = args.llm_profile + self.tries = args.tries + self.retries = args.retries + self.output_style = args.output_style + self.filter_by_ext = args.filter_by_ext + self.log_tries = args.log_tries def get_default_parser(): @@ -50,12 +91,12 @@ def get_default_parser(): parser.add_argument( "-s", "--output-style", choices=["stats", "full"], default="full" ) - parser.add_argument("--filter-by-ext", help="filter by extension", default=None) + parser.add_argument("--filter-by-ext", help="filter by extension", required=False) parser.add_argument( - "--log-tries", help="Save output of every try to given dir", default=None + "--log-tries", help="Save output of every try to given dir", required=False ) return parser -def get_args(): - return get_default_parser().parse_args() +def get_args() -> ProgramArgs: + return ProgramArgs(get_default_parser().parse_args()) diff --git a/verified_cogen/experiments/use_houdini.py b/verified_cogen/experiments/use_houdini.py index 1465ae8..e955d8c 100644 --- a/verified_cogen/experiments/use_houdini.py +++ b/verified_cogen/experiments/use_houdini.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Optional +from typing import Optional, no_type_check from verified_cogen.llm import LLM from verified_cogen.runners import LLM_GENERATED_DIR @@ -12,6 +12,24 @@ log = logging.getLogger(__name__) +class ProgramArgs: + grazie_token: str + profile: str + prompt_dir: str + program: str + verifier_command: str + + @no_type_check + def __init__(self, *args): + ( + self.grazie_token, + self.profile, + self.prompt_dir, + self.program, + self.verifier_command, + ) = args + + INVARIANTS_JSON_PROMPT = """Given the following Rust program, output Verus invariants that should go into the `while` loop in the function {function}. Ensure that the invariants are as comprehensive as they can be. @@ -96,9 +114,9 @@ """ -def collect_invariants(args, prg: str): +def collect_invariants(args: ProgramArgs, prg: str) -> list[str]: func = basename(args.program)[:-3] - result_invariants = [] + result_invariants: list[str] = [] for temperature in [0.0, 0.1, 0.3, 0.4, 0.5, 0.7, 1.0]: llm = LLM( grazie_token=args.grazie_token, @@ -110,7 +128,7 @@ def collect_invariants(args, prg: str): llm.user_prompts.append( INVARIANTS_JSON_PROMPT.replace("{program}", prg).replace("{function}", func) ) - response = llm._make_request() + response = llm._make_request() # type: ignore try: invariants = json.loads(response) result_invariants.extend(invariants) @@ -126,7 +144,7 @@ def remove_failed_invariants( llm: LLM, invariants: list[str], err: str ) -> Optional[list[str]]: llm.user_prompts.append(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err)) - response = llm._make_request() + response = llm._make_request() # type: ignore try: new_invariants = json.loads(response) log.debug("REMOVED: {}".format(set(invariants).difference(set(new_invariants)))) @@ -138,7 +156,7 @@ def remove_failed_invariants( def houdini( - args, verifier: Verifier, prg: str, invariants: list[str] + args: ProgramArgs, verifier: Verifier, prg: str, invariants: list[str] ) -> Optional[list[str]]: func = basename(args.program).strip(".rs") log.info(f"Starting Houdini for {func} in file {args.program}") @@ -201,7 +219,7 @@ def main(): parser.add_argument("--program", required=True) parser.add_argument("--verifier-command", required=True) - args = parser.parse_args() + args = ProgramArgs(*parser.parse_args()) log.info("Running on program: {}".format(args.program)) diff --git a/verified_cogen/llm/llm.py b/verified_cogen/llm/llm.py index b7dbe21..c5cc99b 100644 --- a/verified_cogen/llm/llm.py +++ b/verified_cogen/llm/llm.py @@ -3,6 +3,7 @@ from typing import Optional from grazie.api.client.chat.prompt import ChatPrompt +from grazie.api.client.chat.response import ChatResponse from grazie.api.client.endpoints import GrazieApiGatewayUrls from grazie.api.client.gateway import AuthType, GrazieApiGatewayClient from grazie.api.client.llm_parameters import LLMParameters @@ -32,15 +33,17 @@ def __init__( self.profile = Profile.get_by_name(profile) self.prompt_dir = prompt_dir self.is_gpt = "gpt" in self.profile.name - self.user_prompts = [] - self.responses = [] + self.user_prompts: list[str] = [] + self.responses: list[str] = [] self.had_errors = False self.temperature = temperature self.system_prompt = ( system_prompt if system_prompt else prompts.sys_prompt(self.prompt_dir) ) - def _request(self, temperature: Optional[float] = None, tries: int = 5): + def _request( + self, temperature: Optional[float] = None, tries: int = 5 + ) -> ChatResponse: if tries == 0: raise Exception("Exhausted tries to get response from Grazie API") if temperature is None: @@ -70,31 +73,31 @@ def _request(self, temperature: Optional[float] = None, tries: int = 5): logger.warning("Grazie API is down, retrying...") return self._request(temperature, tries - 1) - def _make_request(self): + def _make_request(self) -> str: response = self._request().content self.responses.append(response) return extract_code_from_llm_output(response) - def produce(self, prg: str): + def produce(self, prg: str) -> str: self.user_prompts.append( prompts.produce_prompt(self.prompt_dir).format(program=prg) ) return self._make_request() - def add(self, prg: str, checks: str, function: Optional[str] = None): + def add(self, prg: str, checks: str, function: Optional[str] = None) -> str: prompt = prompts.add_prompt(self.prompt_dir).format(program=prg, checks=checks) if "{function}" in prompt and function is not None: prompt = prompt.replace("{function}", function) self.user_prompts.append(prompt) return self._make_request() - def rewrite(self, prg: str): + def rewrite(self, prg: str) -> str: self.user_prompts.append( prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg) ) return self._make_request() - def ask_for_fixed(self, err: str): + def ask_for_fixed(self, err: str) -> str: prompt = ( prompts.ask_for_fixed_had_errors_prompt(self.prompt_dir) if self.had_errors @@ -103,6 +106,6 @@ def ask_for_fixed(self, err: str): self.user_prompts.append(prompt.format(error=err)) return self._make_request() - def ask_for_timeout(self): + def ask_for_timeout(self) -> str: self.user_prompts.append(prompts.ask_for_timeout_prompt(self.prompt_dir)) return self._make_request() diff --git a/verified_cogen/llm/prompts.py b/verified_cogen/llm/prompts.py index bb43a3c..19650e1 100644 --- a/verified_cogen/llm/prompts.py +++ b/verified_cogen/llm/prompts.py @@ -1,19 +1,19 @@ -from typing import Optional +from typing import Any, Optional class Singleton(object): _instance = None - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]): if not isinstance(cls._instance, cls): cls._instance = super().__new__(cls, *args, **kwargs) return cls._instance class PromptCache(Singleton): - cache: dict = {} + cache: dict[str, str] = {} - def __init__(self): + def __init__(self, *args: list[Any], **kwargs: dict[str, Any]): self.cache = {} def get(self, key: str) -> Optional[str]: diff --git a/verified_cogen/main.py b/verified_cogen/main.py index 94030a6..0709b63 100644 --- a/verified_cogen/main.py +++ b/verified_cogen/main.py @@ -1,7 +1,7 @@ import logging import pathlib -from verified_cogen.args import get_args +from verified_cogen.args import ProgramArgs, get_args from verified_cogen.llm import LLM from verified_cogen.runners.generate import GenerateRunner from verified_cogen.runners.generic import GenericRunner @@ -29,26 +29,28 @@ def run_once( files: list[Path], - args, + args: ProgramArgs, runner_cls: Callable[[LLM, Logger, Verifier], Runner], verifier: Verifier, mode: Mode, is_once: bool, ) -> tuple[int, int, int, dict[str, int]]: - success, success_zero_tries, failed = [], [], [] - cnt = dict() + _init: tuple[list[str], list[str], list[str]] = ([], [], []) + success, success_zero_tries, failed = _init + + cnt: dict[str, int] = dict() for file in files: llm = LLM( - args.grazie_token, - args.llm_profile, - args.prompts_directory, - args.temperature, + args.grazie_token, # type: ignore + args.llm_profile, # type: ignore + args.prompts_directory, # type: ignore + args.temperature, # type: ignore ) runner = runner_cls(llm, logger, verifier) - retries = args.retries + 1 + retries = args.retries + 1 # type: ignore tries = None while retries > 0 and tries is None: tries = runner.run_on_file(mode, args.tries, str(file)) @@ -181,6 +183,7 @@ def main(): json.dump({k: v / args.runs for k, v in total_cnt.items()}, f) else: + assert args.input is not None, "input file must be specified" llm = LLM( args.grazie_token, args.llm_profile, diff --git a/verified_cogen/runners/languages/dafny.py b/verified_cogen/runners/languages/dafny.py index b9c12e8..d6a58c4 100644 --- a/verified_cogen/runners/languages/dafny.py +++ b/verified_cogen/runners/languages/dafny.py @@ -12,7 +12,7 @@ class DafnyLanguage(GenericLanguage): method_regex: Pattern[str] - def __init__(self): + def __init__(self): # type: ignore super().__init__( re.compile( r"method\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{", re.DOTALL diff --git a/verified_cogen/runners/languages/language.py b/verified_cogen/runners/languages/language.py index 72be824..388feff 100644 --- a/verified_cogen/runners/languages/language.py +++ b/verified_cogen/runners/languages/language.py @@ -1,18 +1,18 @@ from abc import abstractmethod -from typing import Pattern +from typing import Pattern, Any import re class Language: _instance = None - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]): if not isinstance(cls._instance, cls): cls._instance = super().__new__(cls, *args, **kwargs) return cls._instance @abstractmethod - def __init__(self): ... + def __init__(self, *args: list[Any], **kwargs: dict[str, Any]): ... @abstractmethod def generate_validators(self, code: str) -> str: ... @@ -27,7 +27,7 @@ class GenericLanguage(Language): assert_invariant_patterns: list[str] inline_assert_comment: str - def __init__( + def __init__( # type: ignore self, method_regex: Pattern[str], validator_template: str, @@ -43,7 +43,7 @@ def generate_validators(self, code: str) -> str: code = re.sub(r"^ *#.*(\r\n|\r|\n)?", "", code, flags=re.MULTILINE) methods = self.method_regex.finditer(code) - validators = [] + validators: list[str] = [] for match in methods: method_name, parameters, returns, specs = ( @@ -86,7 +86,7 @@ class LanguageDatabase: languages: dict[str, Language] = dict() regularise: dict[str, str] = dict() - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]): if not isinstance(cls._instance, cls): cls._instance = super().__new__(cls, *args, **kwargs) return cls._instance diff --git a/verified_cogen/runners/languages/nagini.py b/verified_cogen/runners/languages/nagini.py index 0a76df9..b5b7455 100644 --- a/verified_cogen/runners/languages/nagini.py +++ b/verified_cogen/runners/languages/nagini.py @@ -13,7 +13,7 @@ def {method_name}_valid({parameters}) -> {returns}:{specs}\ class NaginiLanguage(GenericLanguage): method_regex: Pattern[str] - def __init__(self): + def __init__(self): # type: ignore super().__init__( re.compile( r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(:?(?:\r\n|\r|\n)?( *(?:Requires|Ensures)\([^\r\n]*\)(?:\r\n|\r|\n)?)*)", diff --git a/verified_cogen/tools/__init__.py b/verified_cogen/tools/__init__.py index 3ecf268..8470633 100644 --- a/verified_cogen/tools/__init__.py +++ b/verified_cogen/tools/__init__.py @@ -2,11 +2,11 @@ import re from typing import Optional -import appdirs +import appdirs # type: ignore -def get_cache_dir(): - return appdirs.user_cache_dir("verified-cogen", "jetbrains.research") +def get_cache_dir() -> str: + return appdirs.user_cache_dir("verified-cogen", "jetbrains.research") # type: ignore def basename(path: str): @@ -34,11 +34,11 @@ def extension_from_file_list(files: list[pathlib.Path]) -> str: return extension -def pprint_stat(name: str, stat: int, total: int, runs=1): +def pprint_stat(name: str, stat: int, total: int, runs: int = 1): print(f"{name}: {stat / runs} ({stat / (total * runs) * 100:.2f}%)") -def tabulate_list(lst: list): +def tabulate_list(lst: list[str]) -> str: return "\n\t - " + "\n\t - ".join(lst)