Skip to content

Commit

Permalink
chore: typing pass
Browse files Browse the repository at this point in the history
  • Loading branch information
WeetHet committed Sep 19, 2024
1 parent 784dc90 commit 4fa5a7e
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ data
llm-generated
.envrc
**/.DS_Store
pyrightconfig.json
benches/DafnyBench/
benches/HumanEval-Dafny-Mini/
break_assert.rs
Expand All @@ -13,6 +12,7 @@ break_assert.rs
log
.ruff_cache
.vscode
.zed
run.sh
/dist/
**/.pytest_cache
Expand Down
3 changes: 3 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"typeCheckingMode": "strict"
}
49 changes: 45 additions & 4 deletions verified_cogen/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,47 @@
import os

from verified_cogen.tools.modes import VALID_MODES
from typing import no_type_check, Optional


class ProgramArgs:
    """Typed, attribute-per-field view over the argparse namespace.

    The explicit class-level annotations give strict type checkers the
    field types that ``argparse.Namespace`` cannot provide; ``__init__``
    itself is excluded from checking because the incoming namespace is
    untyped.
    """

    input: Optional[str]
    dir: Optional[str]
    runs: int
    insert_conditions_mode: str
    bench_type: str
    temperature: int
    shell: str
    verifier_command: str
    verifier_timeout: int
    prompts_directory: str
    grazie_token: str
    llm_profile: str
    tries: int
    retries: int
    output_style: str
    filter_by_ext: Optional[str]
    log_tries: Optional[str]

    @no_type_check
    def __init__(self, args):
        # Copy every declared field from the parsed namespace onto self,
        # in declaration order. An AttributeError here means the parser
        # and this class have drifted out of sync.
        for field in (
            "input",
            "dir",
            "runs",
            "insert_conditions_mode",
            "bench_type",
            "temperature",
            "shell",
            "verifier_command",
            "verifier_timeout",
            "prompts_directory",
            "grazie_token",
            "llm_profile",
            "tries",
            "retries",
            "output_style",
            "filter_by_ext",
            "log_tries",
        ):
            setattr(self, field, getattr(args, field))


def get_default_parser():
Expand Down Expand Up @@ -50,12 +91,12 @@ def get_default_parser():
parser.add_argument(
"-s", "--output-style", choices=["stats", "full"], default="full"
)
parser.add_argument("--filter-by-ext", help="filter by extension", default=None)
parser.add_argument("--filter-by-ext", help="filter by extension", required=False)
parser.add_argument(
"--log-tries", help="Save output of every try to given dir", default=None
"--log-tries", help="Save output of every try to given dir", required=False
)
return parser


def get_args():
return get_default_parser().parse_args()
def get_args() -> ProgramArgs:
return ProgramArgs(get_default_parser().parse_args())
32 changes: 25 additions & 7 deletions verified_cogen/experiments/use_houdini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import os
from typing import Optional
from typing import Optional, no_type_check

from verified_cogen.llm import LLM
from verified_cogen.runners import LLM_GENERATED_DIR
Expand All @@ -12,6 +12,24 @@
log = logging.getLogger(__name__)


class ProgramArgs:
    """Positional, typed bundle of the houdini experiment's CLI options.

    Fields must be supplied in declaration order; the annotations exist
    so strict type checking sees concrete types instead of an untyped
    ``argparse.Namespace``.
    """

    grazie_token: str
    profile: str
    prompt_dir: str
    program: str
    verifier_command: str

    @no_type_check
    def __init__(self, *args):
        # Exactly five positional values, matching the field order above;
        # tuple unpacking raises ValueError on any other arity.
        (self.grazie_token, self.profile, self.prompt_dir,
         self.program, self.verifier_command) = args


INVARIANTS_JSON_PROMPT = """Given the following Rust program, output Verus invariants that should go into the `while` loop
in the function {function}.
Ensure that the invariants are as comprehensive as they can be.
Expand Down Expand Up @@ -96,9 +114,9 @@
"""


def collect_invariants(args, prg: str):
def collect_invariants(args: ProgramArgs, prg: str) -> list[str]:
func = basename(args.program)[:-3]
result_invariants = []
result_invariants: list[str] = []
for temperature in [0.0, 0.1, 0.3, 0.4, 0.5, 0.7, 1.0]:
llm = LLM(
grazie_token=args.grazie_token,
Expand All @@ -110,7 +128,7 @@ def collect_invariants(args, prg: str):
llm.user_prompts.append(
INVARIANTS_JSON_PROMPT.replace("{program}", prg).replace("{function}", func)
)
response = llm._make_request()
response = llm._make_request() # type: ignore
try:
invariants = json.loads(response)
result_invariants.extend(invariants)
Expand All @@ -126,7 +144,7 @@ def remove_failed_invariants(
llm: LLM, invariants: list[str], err: str
) -> Optional[list[str]]:
llm.user_prompts.append(REMOVE_FAILED_INVARIANTS_PROMPT.format(error=err))
response = llm._make_request()
response = llm._make_request() # type: ignore
try:
new_invariants = json.loads(response)
log.debug("REMOVED: {}".format(set(invariants).difference(set(new_invariants))))
Expand All @@ -138,7 +156,7 @@ def remove_failed_invariants(


def houdini(
args, verifier: Verifier, prg: str, invariants: list[str]
args: ProgramArgs, verifier: Verifier, prg: str, invariants: list[str]
) -> Optional[list[str]]:
func = basename(args.program).strip(".rs")
log.info(f"Starting Houdini for {func} in file {args.program}")
Expand Down Expand Up @@ -201,7 +219,7 @@ def main():
parser.add_argument("--program", required=True)
parser.add_argument("--verifier-command", required=True)

args = parser.parse_args()
args = ProgramArgs(*parser.parse_args())

log.info("Running on program: {}".format(args.program))

Expand Down
21 changes: 12 additions & 9 deletions verified_cogen/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.chat.response import ChatResponse
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import AuthType, GrazieApiGatewayClient
from grazie.api.client.llm_parameters import LLMParameters
Expand Down Expand Up @@ -32,15 +33,17 @@ def __init__(
self.profile = Profile.get_by_name(profile)
self.prompt_dir = prompt_dir
self.is_gpt = "gpt" in self.profile.name
self.user_prompts = []
self.responses = []
self.user_prompts: list[str] = []
self.responses: list[str] = []
self.had_errors = False
self.temperature = temperature
self.system_prompt = (
system_prompt if system_prompt else prompts.sys_prompt(self.prompt_dir)
)

def _request(self, temperature: Optional[float] = None, tries: int = 5):
def _request(
self, temperature: Optional[float] = None, tries: int = 5
) -> ChatResponse:
if tries == 0:
raise Exception("Exhausted tries to get response from Grazie API")
if temperature is None:
Expand Down Expand Up @@ -70,31 +73,31 @@ def _request(self, temperature: Optional[float] = None, tries: int = 5):
logger.warning("Grazie API is down, retrying...")
return self._request(temperature, tries - 1)

def _make_request(self):
def _make_request(self) -> str:
response = self._request().content
self.responses.append(response)
return extract_code_from_llm_output(response)

def produce(self, prg: str):
def produce(self, prg: str) -> str:
self.user_prompts.append(
prompts.produce_prompt(self.prompt_dir).format(program=prg)
)
return self._make_request()

def add(self, prg: str, checks: str, function: Optional[str] = None):
def add(self, prg: str, checks: str, function: Optional[str] = None) -> str:
prompt = prompts.add_prompt(self.prompt_dir).format(program=prg, checks=checks)
if "{function}" in prompt and function is not None:
prompt = prompt.replace("{function}", function)
self.user_prompts.append(prompt)
return self._make_request()

def rewrite(self, prg: str):
def rewrite(self, prg: str) -> str:
self.user_prompts.append(
prompts.rewrite_prompt(self.prompt_dir).replace("{program}", prg)
)
return self._make_request()

def ask_for_fixed(self, err: str):
def ask_for_fixed(self, err: str) -> str:
prompt = (
prompts.ask_for_fixed_had_errors_prompt(self.prompt_dir)
if self.had_errors
Expand All @@ -103,6 +106,6 @@ def ask_for_fixed(self, err: str):
self.user_prompts.append(prompt.format(error=err))
return self._make_request()

def ask_for_timeout(self):
def ask_for_timeout(self) -> str:
self.user_prompts.append(prompts.ask_for_timeout_prompt(self.prompt_dir))
return self._make_request()
8 changes: 4 additions & 4 deletions verified_cogen/llm/prompts.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from typing import Optional
from typing import Any, Optional


class Singleton(object):
_instance = None

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]):
if not isinstance(cls._instance, cls):
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance


class PromptCache(Singleton):
cache: dict = {}
cache: dict[str, str] = {}

def __init__(self):
def __init__(self, *args: list[Any], **kwargs: dict[str, Any]):
self.cache = {}

def get(self, key: str) -> Optional[str]:
Expand Down
21 changes: 12 additions & 9 deletions verified_cogen/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import pathlib

from verified_cogen.args import get_args
from verified_cogen.args import ProgramArgs, get_args
from verified_cogen.llm import LLM
from verified_cogen.runners.generate import GenerateRunner
from verified_cogen.runners.generic import GenericRunner
Expand Down Expand Up @@ -29,26 +29,28 @@

def run_once(
files: list[Path],
args,
args: ProgramArgs,
runner_cls: Callable[[LLM, Logger, Verifier], Runner],
verifier: Verifier,
mode: Mode,
is_once: bool,
) -> tuple[int, int, int, dict[str, int]]:
success, success_zero_tries, failed = [], [], []
cnt = dict()
_init: tuple[list[str], list[str], list[str]] = ([], [], [])
success, success_zero_tries, failed = _init

cnt: dict[str, int] = dict()

for file in files:
llm = LLM(
args.grazie_token,
args.llm_profile,
args.prompts_directory,
args.temperature,
args.grazie_token, # type: ignore
args.llm_profile, # type: ignore
args.prompts_directory, # type: ignore
args.temperature, # type: ignore
)

runner = runner_cls(llm, logger, verifier)

retries = args.retries + 1
retries = args.retries + 1 # type: ignore
tries = None
while retries > 0 and tries is None:
tries = runner.run_on_file(mode, args.tries, str(file))
Expand Down Expand Up @@ -181,6 +183,7 @@ def main():

json.dump({k: v / args.runs for k, v in total_cnt.items()}, f)
else:
assert args.input is not None, "input file must be specified"
llm = LLM(
args.grazie_token,
args.llm_profile,
Expand Down
2 changes: 1 addition & 1 deletion verified_cogen/runners/languages/dafny.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class DafnyLanguage(GenericLanguage):
method_regex: Pattern[str]

def __init__(self):
def __init__(self): # type: ignore
super().__init__(
re.compile(
r"method\s+(\w+)\s*\((.*?)\)\s*returns\s*\((.*?)\)(.*?)\{", re.DOTALL
Expand Down
12 changes: 6 additions & 6 deletions verified_cogen/runners/languages/language.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from abc import abstractmethod
from typing import Pattern
from typing import Pattern, Any
import re


class Language:
_instance = None

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]):
if not isinstance(cls._instance, cls):
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance

@abstractmethod
def __init__(self): ...
def __init__(self, *args: list[Any], **kwargs: dict[str, Any]): ...

@abstractmethod
def generate_validators(self, code: str) -> str: ...
Expand All @@ -27,7 +27,7 @@ class GenericLanguage(Language):
assert_invariant_patterns: list[str]
inline_assert_comment: str

def __init__(
def __init__( # type: ignore
self,
method_regex: Pattern[str],
validator_template: str,
Expand All @@ -43,7 +43,7 @@ def generate_validators(self, code: str) -> str:
code = re.sub(r"^ *#.*(\r\n|\r|\n)?", "", code, flags=re.MULTILINE)
methods = self.method_regex.finditer(code)

validators = []
validators: list[str] = []

for match in methods:
method_name, parameters, returns, specs = (
Expand Down Expand Up @@ -86,7 +86,7 @@ class LanguageDatabase:
languages: dict[str, Language] = dict()
regularise: dict[str, str] = dict()

def __new__(cls, *args, **kwargs):
def __new__(cls, *args: list[Any], **kwargs: dict[str, Any]):
if not isinstance(cls._instance, cls):
cls._instance = super().__new__(cls, *args, **kwargs)
return cls._instance
Expand Down
2 changes: 1 addition & 1 deletion verified_cogen/runners/languages/nagini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def {method_name}_valid({parameters}) -> {returns}:{specs}\
class NaginiLanguage(GenericLanguage):
method_regex: Pattern[str]

def __init__(self):
def __init__(self): # type: ignore
super().__init__(
re.compile(
r"def\s+(\w+)\s*\((.*?)\)\s*->\s*(.*?):(:?(?:\r\n|\r|\n)?( *(?:Requires|Ensures)\([^\r\n]*\)(?:\r\n|\r|\n)?)*)",
Expand Down
Loading

0 comments on commit 4fa5a7e

Please sign in to comment.