diff --git a/agent/one_prompt_prototyper.py b/agent/one_prompt_prototyper.py index 67261868ec..d533ad10c5 100644 --- a/agent/one_prompt_prototyper.py +++ b/agent/one_prompt_prototyper.py @@ -55,6 +55,12 @@ def _prompt_builder(self, # For Python projects return prompt_builder.DefaultPythonTemplateBuilder( self.llm, benchmark, self.args.template_directory) + + if benchmark.language == 'go': + # For golang projects + return prompt_builder.DefaultGoTemplateBuilder( + self.llm, benchmark, self.args.template_directory) + if benchmark.language == 'rust': # For Rust projects return prompt_builder.DefaultRustTemplateBuilder( diff --git a/benchmark-sets/go-small/atomic.yaml b/benchmark-sets/go-small/atomic.yaml new file mode 100644 index 0000000000..40a7606e8e --- /dev/null +++ b/benchmark-sets/go-small/atomic.yaml @@ -0,0 +1,45 @@ +"functions": +- "name": "run" + "params": + - "name": "args" + "type": "[]string" + "return_type": "error" + "signature": "run([]string) error" +- "name": "BenchmarkStress" + "params": + - "name": "b" + "type": "B" + - "name": "b" + "type": "B" + - "name": "b" + "type": "B" + - "name": "b" + "type": "B" + - "name": "pb" + "type": "PB" + "return_type": "void" + "signature": "BenchmarkStress(*testing.B,*testing.B,*testing.B,*testing.B,*testing.PB)" +- "name": "*stringList.Set" + "params": + - "name": "s" + "type": "string" + "return_type": "error" + "signature": "(*stringList) *stringList.Set(string) error" +- "name": "*Bool.Toggle" + "params": + - "name": "old" + "type": "bool" + "return_type": "(old bool)" + "signature": "(*Bool) *Bool.Toggle(bool) (old bool)" +- "name": "*String.CompareAndSwap" + "params": + - "name": "swapped" + "type": "bool" + - "name": "old" + "type": "string" + "return_type": "(swapped bool)" + "signature": "(*String) *String.CompareAndSwap(bool,string) (swapped bool)" +"language": "go" +"project": "atomic" +"target_name": "fuzz_test" +"target_path": "/src/atomic/fuzz_test.go" diff --git a/benchmark-sets/go-small/clock.yaml b/benchmark-sets/go-small/clock.yaml new file mode 100644 index 0000000000..1b3f79a6f1 --- /dev/null +++ b/benchmark-sets/go-small/clock.yaml @@ -0,0 +1,49 @@ +"functions": +- "name": "*clock.WithDeadline" + "params": + - "name": "parent" + "type": "Context" + - "name": "d" + "type": "Time" + - "name": "context.Context" + "type": "Context" + - "name": "context.CancelFunc" + "type": "CancelFunc" + "return_type": "(context.Context, context.CancelFunc)" + "signature": "(*clock) *clock.WithDeadline(context.Context,time.Time,context.Context,context.CancelFunc) (context.Context, context.CancelFunc)" +- "name": "*clock.WithTimeout" + "params": + - "name": "context.Context" + "type": "Context" + - "name": "context.CancelFunc" + "type": "CancelFunc" + - "name": "parent" + "type": "Context" + - "name": "t" + "type": "Duration" + "return_type": "(context.Context, context.CancelFunc)" + "signature": "(*clock) *clock.WithTimeout(context.Context,context.CancelFunc,context.Context,time.Duration) (context.Context, context.CancelFunc)" +- "name": "warnf" + "params": + - "name": "msg" + "type": "string" + - "name": "v" + "type": "interface{}" + "return_type": "void" + "signature": "warnf(string,interface{})" +- "name": "*clock.Until" + "params": + - "name": "t" + "type": "Time" + "return_type": "time.Duration" + "signature": "(*clock) *clock.Until(time.Time) time.Duration" +- "name": "*clock.Since" + "params": + - "name": "t" + "type": "Time" + "return_type": "time.Duration" + "signature": "(*clock) *clock.Since(time.Time) time.Duration" +"language": "go" +"project": "clock" +"target_name": "fuzz_test" +"target_path": "/src/clock/fuzz_test.go" diff --git a/benchmark-sets/go-small/nats.yaml b/benchmark-sets/go-small/nats.yaml new file mode 100644 index 0000000000..2b1b91b7dd --- /dev/null +++ b/benchmark-sets/go-small/nats.yaml @@ -0,0 +1,51 @@ +"functions": +- "name": "*cluster.removeJetStream" + "params": + - "name": "s" + "type": "*Server" + "return_type": "void" + "signature": "(*cluster) *cluster.removeJetStream(*Server)" +- "name": "reloadUpdateConfig" + "params": + - "name": "t" + "type": "T" + - "name": "s" + "type": "*Server" + - "name": "conf" + "type": "string" + "return_type": "void" + "signature": "reloadUpdateConfig(*testing.T,*Server,string)" +- "name": "*Server.reloadConfig" + "params": + - "name": "sub" + "type": "*subscription" + - "name": "c" + "type": "*client" + - "name": "_" + "type": "*Account" + - "name": "subject" + "type": "string" + - "name": "hdr" + "type": "[]byte" + - "name": "any" + "type": "any" + - "name": "error" + "type": "error" + "return_type": "void" + "signature": "(*Server) *Server.reloadConfig(*subscription,*client,*Account,string,[]byte,any,error)" +- "name": "*Server.ReloadOptions" + "params": + - "name": "newOpts" + "type": "*Options" + "return_type": "error" + "signature": "(*Server) *Server.ReloadOptions(*Options) error" +- "name": "*Server.reloadOptions" + "params": + - "name": "curOpts" + "type": "*Options" + "return_type": "error" + "signature": "(*Server) *Server.reloadOptions(*Options) error" +"language": "go" +"project": "nats" +"target_name": "fuzz" +"target_path": "/src/nats-server/server/fuzz.go" diff --git a/benchmark-sets/go-small/roughtime.yaml b/benchmark-sets/go-small/roughtime.yaml new file mode 100644 index 0000000000..5afe8f9535 --- /dev/null +++ b/benchmark-sets/go-small/roughtime.yaml @@ -0,0 +1,69 @@ +"functions": +- "name": "FuzzParseRequest" + "params": + - "name": "f" + "type": "F" + - "name": "t" + "type": "T" + - "name": "data" + "type": "[]byte" + "return_type": "void" + "signature": "FuzzParseRequest(*testing.F,*testing.T,[]byte)" +- "name": "DoFromFile" + "params": + - "name": "configFile" + "type": "string" + - "name": "attempts" + "type": "int" + - "name": "timeout" + "type": "Duration" + - "name": "prev" + "type": "*Roughtime" + - "name": "[]Result" + "type": "[]Result" + - "name": "error" + "type": "error" + "return_type": "([]Result, error)" + "signature": "DoFromFile(string,int,time.Duration,*Roughtime,[]Result,error) ([]Result, error)" +- "name": "Do" + "params": + - "name": "servers" + "type": "Server" + - "name": "attempts" + "type": "int" + - "name": "timeout" + "type": "Duration" + - "name": "prev" + "type": "*Roughtime" + "return_type": "[]Result" + "signature": "Do([]config.Server,int,time.Duration,*Roughtime) []Result" +- "name": "Get" + "params": + - "name": "server" + "type": "Server" + - "name": "attempts" + "type": "int" + - "name": "timeout" + "type": "Duration" + - "name": "prev" + "type": "*Roughtime" + - "name": "*Roughtime" + "type": "*Roughtime" + - "name": "error" + "type": "error" + "return_type": "(*Roughtime, error)" + "signature": "Get(*config.Server,int,time.Duration,*Roughtime,*Roughtime,error) (*Roughtime, error)" +- "name": "createServerIdentity" + "params": + - "name": "t" + "type": "T" + - "name": "cert" + "type": "*Certificate" + - "name": "rootPublicKey" + "type": "[]byte" + "return_type": "(cert *Certificate, rootPublicKey []byte)" + "signature": "createServerIdentity(*testing.T,*Certificate,[]byte) (cert *Certificate, rootPublicKey []byte)" +"language": "go" +"project": "roughtime" +"target_name": "protocol_test" +"target_path": "/src/roughtime/protocol/protocol_test.go" diff --git a/benchmark-sets/go-small/smt.yaml b/benchmark-sets/go-small/smt.yaml new file mode 100644 index 0000000000..ede815288e --- /dev/null +++ b/benchmark-sets/go-small/smt.yaml @@ -0,0 +1,51 @@ +"functions": +- "name": "bulkOperations" + "params": + - "name": "t" + "type": "T" + - "name": "operations" + "type": "int" + - "name": "insert" + "type": "int" + - "name": "update" + "type": "int" + - "name": "delete" + "type": "int" + "return_type": "void" + "signature": "bulkOperations(*testing.T,int,int,int,int)" +- "name": "BenchmarkSparseMerkleTree_Delete" + "params": + - "name": "b" + "type": "B" + "return_type": "void" + "signature": "BenchmarkSparseMerkleTree_Delete(*testing.B)" +- "name": "bulkCheckAll" + "params": + - "name": "t" + "type": "T" + - "name": "smt" + "type": "*SparseMerkleTree" + - "name": "kv" + "type": "*map[string]string" + "return_type": "void" + "signature": "bulkCheckAll(*testing.T,*SparseMerkleTree,*map[string]string)" +- "name": "BenchmarkSparseMerkleTree_Update" + "params": + - "name": "b" + "type": "B" + "return_type": "void" + "signature": "BenchmarkSparseMerkleTree_Update(*testing.B)" +- "name": "*SparseMerkleTree.DeleteForRoot" + "params": + - "name": "[]byte" + "type": "[]byte" + - "name": "error" + "type": "error" + - "name": "key" + "type": "[]byte" + "return_type": "([]byte, error)" + "signature": "(*SparseMerkleTree) *SparseMerkleTree.DeleteForRoot([]byte,error,[]byte) ([]byte, error)" +"language": "go" +"project": "smt" +"target_name": "fuzz" +"target_path": "/src/smt/fuzz/delete/fuzz.go" diff --git a/benchmark-sets/go-small/time.yaml b/benchmark-sets/go-small/time.yaml new file mode 100644 index 0000000000..483911e0da --- /dev/null +++ b/benchmark-sets/go-small/time.yaml @@ -0,0 +1,41 @@ +"functions": +- "name": "*SPTP.Run" + "params": + - "name": "ctx" + "type": "Context" + "return_type": "error" + "signature": "(*SPTP) *SPTP.Run(context.Context) error" +- "name": "*Server.Start" + "params": + - "name": "i" + "type": "int" + "return_type": "error" + "signature": "(*Server) *Server.Start(int) error" +- "name": "*Daemon.Run" + "params": + - "name": "ctx" + "type": "Context" + "return_type": "error" + "signature": "(*Daemon) *Daemon.Run(context.Context) error" +- "name": "runTrace" + "params": + - "name": "cfg" + "type": "Config" + - "name": "m" + "type": "MeasurementResult" + "return_type": "error" + "signature": "runTrace(*client.Config,*client.MeasurementResult) error" +- "name": "*Sender.Start" + "params": + - "name": "t" + "type": "traceTask" + - "name": "[]*PathInfo" + "type": "[]*PathInfo" + - "name": "error" + "type": "error" + "return_type": "([]*PathInfo, error)" + "signature": "(*Sender) *Sender.Start(traceTask,[]*PathInfo,error) ([]*PathInfo, error)" +"language": "go" +"project": "time" +"target_name": "ntp_test" +"target_path": "/src/time/ntp/protocol/ntp_test.go" diff --git a/data_prep/introspector.py b/data_prep/introspector.py index 17d279250c..f9d7b70364 100755 --- a/data_prep/introspector.py +++ b/data_prep/introspector.py @@ -887,7 +887,7 @@ def populate_benchmarks_using_introspector(project: str, language: str, logger.error('error: %s %s', filename, interesting.keys()) continue - elif (language not in ['rust'] and interesting and + elif (language not in ['rust', 'go'] and interesting and filename not in [os.path.basename(i) for i in interesting.keys()]): # TODO: Bazel messes up paths to include "/proc/self/cwd/..." logger.error('error: %s %s', filename, interesting.keys()) diff --git a/data_prep/project_src.py b/data_prep/project_src.py index 22d89531d4..cdb226b869 100755 --- a/data_prep/project_src.py +++ b/data_prep/project_src.py @@ -99,6 +99,10 @@ def _get_harness(src_file: str, out: str, language: str) -> tuple[str, str]: if language.lower() == 'rust' and 'fuzz_target!' not in content: return '', '' + if language.lower() == 'go' and any( + target not in content for target in ['testing.F', 'testing.T', '.Fuzz']): + return '', '' + short_path = src_file[len(out):] return short_path, content @@ -309,6 +313,12 @@ def _identify_fuzz_targets(out: str, interesting_filenames: list[str], interesting_filepaths.append(path) if path.endswith('.py'): potential_harnesses.append(path) + elif language == 'go': + # For Rust + if path.endswith(tuple(interesting_filenames)): + interesting_filepaths.append(path) + if path.endswith(('.go', '.cgo')): + potential_harnesses.append(path) elif language == 'rust': # For Rust if path.endswith(tuple(interesting_filenames)): diff --git a/experiment/benchmark.py b/experiment/benchmark.py index d100b37bdb..3407824cf3 100644 --- a/experiment/benchmark.py +++ b/experiment/benchmark.py @@ -204,6 +204,15 @@ def __init__(self, # zipp-zipp.difference. self.id = self.id.replace('._', '.') + if self.language == 'go': + # For golang projects, full signature of functions/methods can contains + # special characters that result in confusion in the directory name of + # benchmarks. + self.id = self.id.replace('*', '').replace('&', '') + self.id = self.id.replace('<', '').replace('>', '') + self.id = self.id.replace('[', '').replace(']', '') + self.id = self.id.replace('(', '_').replace(')', '').replace(',', '_') + if self.language == 'rust': # For rust projects, double colon (::) is sometime used to identify # crate, impl or trait name of a function. This could affect the diff --git a/experiment/builder_runner.py b/experiment/builder_runner.py index 7c310f93da..6fae8fdaab 100644 --- a/experiment/builder_runner.py +++ b/experiment/builder_runner.py @@ -195,9 +195,26 @@ def _contains_target_python_function(self, target_path: str) -> bool: return min_func_name in generated_code + def _contains_target_go_function(self, target_path: str) -> bool: + """Validates if the LLM-generated code contains the target function for + go projects.""" + + with open(target_path) as generated_code_file: + generated_code = generated_code_file.read() + + min_func_name = self._get_minimum_func_name( + self.benchmark.function_signature) + + # Retrieve function name only without packages + min_func_name = min_func_name.rsplit('.', 1)[-1] + min_func_name = min_func_name.split('(', 1)[0] + + return min_func_name in generated_code + def _contains_target_rust_function(self, target_path: str) -> bool: """Validates if the LLM-generated code contains the target function for rust projects.""" + with open(target_path) as generated_code_file: generated_code = generated_code_file.read() @@ -219,6 +236,8 @@ def _pre_build_check(self, target_path: str, result = self._contains_target_jvm_method(target_path) elif self.benchmark.language == 'python': result = self._contains_target_python_function(target_path) + elif self.benchmark.language == 'go': + result = self._contains_target_go_function(target_path) elif self.benchmark.language == 'rust': result = self._contains_target_rust_function(target_path) else: @@ -500,8 +519,8 @@ def build_and_run_local( build_result.succeeded = self.build_target_local(generated_project, benchmark_log_path) - # Copy err.log into work dir (Ignored for JVM/Rust projects) - if language not in ['jvm', 'rust']: + # Copy err.log into work dir (Ignored for JVM projects) + if language not in ['jvm', 'rust', 'go']: try: shutil.copyfile( os.path.join(get_build_artifact_dir(generated_project, "workspace"), @@ -532,8 +551,8 @@ def build_and_run_local( # In many case JVM/python projects won't have much cov # difference in short running. Adding the flag for JVM/python # projects to temporary skip the checking of coverage change. - # Also skipping for rust projects in initial implementation. - flag = not self.benchmark.language in ['jvm', 'python', 'rust'] + # Also skipping for rust/golang projects in initial implementation. + flag = not self.benchmark.language in ['jvm', 'python', 'rust', 'go'] run_result.cov_pcs, run_result.total_pcs, \ run_result.crashes, run_result.crash_info, \ run_result.semantic_check = \ @@ -703,6 +722,7 @@ def _get_coverage_text_filename(self, project_name: str) -> str: 'python': 'all_cov.json', 'c++': f'{self.benchmark.target_name}.covreport', 'c': f'{self.benchmark.target_name}.covreport', + 'go': 'fuzz.cov', 'rust': f'{self.benchmark.target_name}.covreport', } @@ -719,6 +739,7 @@ def _extract_local_textcoverage_data(self, 'python': 'r', 'c': 'rb', 'c++': 'rb', + 'go': 'r', 'rust': 'rb', } with open(local_textcov_location, @@ -727,6 +748,8 @@ def _extract_local_textcoverage_data(self, new_textcov = textcov.Textcov.from_jvm_file(f) elif self.benchmark.language == 'python': new_textcov = textcov.Textcov.from_python_file(f) + elif self.benchmark.language == 'go': + new_textcov = textcov.Textcov.from_go_file(f) else: target_basename = os.path.basename(self.benchmark.target_path) new_textcov = textcov.Textcov.from_file( @@ -1102,6 +1125,8 @@ def _get_cloud_textcov_path(self, coverage_name: str) -> str: return f'{coverage_name}/textcov_reports/jacoco.xml' if self.benchmark.language == 'python': return f'{coverage_name}/textcov_reports/all_cov.json' + if self.benchmark.language == 'go': + return f'{coverage_name}/textcov_reports/fuzz.cov' # For C/C++/Rust return (f'{coverage_name}/textcov_reports/{self.benchmark.target_name}' diff --git a/experiment/evaluator.py b/experiment/evaluator.py index 0047ff7463..870f7ac5b1 100644 --- a/experiment/evaluator.py +++ b/experiment/evaluator.py @@ -152,6 +152,29 @@ def load_existing_python_textcov(project: str) -> textcov.Textcov: return textcov.Textcov.from_python_file(f) +def load_existing_go_textcov(project: str) -> textcov.Textcov: + """Loads existing textcovs for go project.""" + storage_client = storage.Client.create_anonymous_client() + bucket = storage_client.bucket(OSS_FUZZ_INTROSPECTOR_BUCKET) + blobs = storage_client.list_blobs(bucket, + prefix=f'{project}/inspector-report/', + delimiter='/') + # Iterate through all blobs first to get the prefixes (i.e. "subdirectories"). + for blob in blobs: + continue + + if not blobs.prefixes: # type: ignore + # No existing coverage reports. + logger.info('No existing coverage report. Using empty.') + return textcov.Textcov() + + latest_dir = sorted(blobs.prefixes)[-1] # type: ignore + blob = bucket.blob(f'{latest_dir}fuzz.cov') + logger.info('Loading existing fuzz.cov textcov from %s', blob.name) + with blob.open() as f: + return textcov.Textcov.from_go_file(f) + + def load_existing_rust_textcov(project: str) -> textcov.Textcov: """Loads existing textcovs for rust project.""" storage_client = storage.Client.create_anonymous_client() @@ -618,6 +641,9 @@ def load_existing_textcov(self) -> textcov.Textcov: if self.benchmark.language == 'python': return load_existing_python_textcov(self.benchmark.project) + if self.benchmark.language == 'go': + return load_existing_go_textcov(self.benchmark.project) + if self.benchmark.language == 'rust': return load_existing_rust_textcov(self.benchmark.project) diff --git a/experiment/textcov.py b/experiment/textcov.py index 0fc2f531ab..3bccc9fdf4 100644 --- a/experiment/textcov.py +++ b/experiment/textcov.py @@ -143,7 +143,7 @@ def subtract_covered_lines(self, other: Function, language: str = 'c++'): @dataclasses.dataclass class File: - """Represents a file in a textcov, only for Python.""" + """Represents a file in a textcov, only for Python / Golang.""" name: str = '' # Line contents -> Line object. We key on line contents to account for # potential line number movements. @@ -177,7 +177,7 @@ class Textcov: # For JVM / C / C++ / Rust functions: dict[str, Function] = dataclasses.field(default_factory=dict) # File name -> File object. - # For Python + # For Python / Go files: dict[str, File] = dataclasses.field(default_factory=dict) language: str = 'c++' @@ -293,6 +293,46 @@ def from_python_file(cls, file_handle) -> Textcov: return textcov + @classmethod + def from_go_file(cls, file_handle) -> Textcov: + """Read a textcov from a fuzz.cov file for golang project.""" + textcov = cls() + textcov.language = 'go' + line_coverage = {} + + # Extract the fuzz.cov coverage line information + cov_line = file_handle.readlines()[1:] + + # Process line coverage from fuzz.cov + # Line format + # :.,. + for line in cov_line: + file_name, data = line.split(':', 1) + line_split = re.split('[:., ]', data) + start_line = int(line_split[0]) + end_line = int(line_split[2]) + hit_count = int(line_split[5]) + + # Process line coverage information + line_dict = line_coverage.get(file_name, {}) + for count in range(start_line, end_line + 1): + hit_count = max(hit_count, line_dict.get(count, -1)) + line_dict[count] = hit_count + + line_coverage[file_name] = line_dict + + # Process coverage per file + for file_name, line_dict in line_coverage.items(): + current_file = File(name=file_name) + + for line_no, hit_count in line_dict.items(): + line = f'Line{line_no}' + current_file.lines[line] = Line(contents=line, hit_count=hit_count) + + textcov.files[file_name] = current_file + + return textcov + @classmethod def from_jvm_file(cls, file_handle) -> Textcov: """Read a textcov from a jacoco.xml file.""" @@ -451,7 +491,7 @@ def to_file(self, filename: str) -> None: """Writes covered functions/files and lines to |filename|.""" file_content = '' - if self.language == 'python': + if self.language in ['python', 'go']: target = self.files else: target = self.functions @@ -471,7 +511,7 @@ def merge(self, other: Textcov): if self.language != other.language and self.language == 'c++': self.language = other.language - if self.language == 'python': + if self.language in ['python', 'go']: for file in other.files.values(): if file.name not in self.files: self.files[file.name] = File(name=file.name) @@ -484,7 +524,7 @@ def merge(self, other: Textcov): def subtract_covered_lines(self, other: Textcov): """Diff another textcov""" - if self.language == 'python': + if self.language in ['python', 'go']: for file in other.files.values(): if file.name in self.files: self.files[file.name].subtract_covered_lines(file) @@ -496,14 +536,14 @@ def subtract_covered_lines(self, other: Textcov): @property def covered_lines(self): - if self.language == 'python': + if self.language in ['python', 'go']: return sum(f.covered_lines for f in self.files.values()) return sum(f.covered_lines for f in self.functions.values()) @property def total_lines(self): - if self.language == 'python': + if self.language in ['python', 'go']: return sum(len(f.lines) for f in self.files.values()) return sum(len(f.lines) for f in self.functions.values()) diff --git a/experimental/from_scratch/generate.py b/experimental/from_scratch/generate.py index 6be3305f15..3717568eba 100644 --- a/experimental/from_scratch/generate.py +++ b/experimental/from_scratch/generate.py @@ -104,7 +104,7 @@ def check_args(args) -> bool: not args.source_line): return True - print('You must include either:\n (1) target function name by --function;\n' + print('You must include either:\n (1) target function name by --function;\n ' '(2) target source file and line number by --source-file and ' '--source-line;\n (3) --far-reach') return False @@ -146,6 +146,7 @@ def get_target_benchmark( # Get target function if target_function_name: + print(type(project)) function = project.find_function_by_name(target_function_name, only_exact_match) @@ -194,6 +195,10 @@ def construct_fuzz_prompt(model, benchmark, context, """Local benchmarker""" if language in ['c', 'c++']: builder = prompt_builder.DefaultTemplateBuilder(model, benchmark=benchmark) + + elif language == 'go': + builder = prompt_builder.DefaultGoTemplateBuilder(model, + benchmark=benchmark) elif language == 'rust': builder = prompt_builder.DefaultRustTemplateBuilder(model, benchmark=benchmark) @@ -236,10 +241,11 @@ def introspector_lang_to_entrypoint(language: str) -> str: return 'LLVMFuzzerTestOneInput' if language == 'jvm': return 'fuzzerTestOneInput' + if language == 'rust': return 'fuzz_target' - # Not supporting other language yet + # Other supported languages have no fixed entry point return '' @@ -343,6 +349,8 @@ def get_introspector_language(args) -> str: return 'c++' if args.language in ['jvm', 'java']: return 'jvm' + if args.language in ['go', 'cgo']: + return 'go' if args.language in ['rs', 'rust']: return 'rust' diff --git a/llm_toolkit/output_parser.py b/llm_toolkit/output_parser.py index 023054999f..696c98d389 100755 --- a/llm_toolkit/output_parser.py +++ b/llm_toolkit/output_parser.py @@ -79,6 +79,7 @@ def parse_code(response_path: str) -> str: lines = _parse_code_block_by_marker(lines, '```python', '```') lines = _parse_code_block_by_marker(lines, '```rust', '```') lines = _parse_code_block_by_marker(lines, '```java_code', '```') + lines = _parse_code_block_by_marker(lines, '```go', '```') lines = _parse_code_block_by_marker(lines, '', '') lines = _parse_code_block_by_marker(lines, '', '') diff --git a/llm_toolkit/prompt_builder.py b/llm_toolkit/prompt_builder.py index 1e51dbe2bd..c65cf03d2b 100644 --- a/llm_toolkit/prompt_builder.py +++ b/llm_toolkit/prompt_builder.py @@ -1733,3 +1733,95 @@ def post_process_generated_code(self, generated_code: str) -> str: 'int LLVMFuzzerTestOneInput') return generated_code + + +class DefaultGoTemplateBuilder(PromptBuilder): + """Default builder for Go projects.""" + + def __init__(self, + model: models.LLM, + benchmark: Benchmark, + template_dir: str = DEFAULT_TEMPLATE_DIR): + super().__init__(model) + self._template_dir = template_dir + self.benchmark = benchmark + self.project_url = oss_fuzz_checkout.get_project_repository( + self.benchmark.project) + + # Load templates. + self.base_template_file = self._find_template(template_dir, 'go_base.txt') + self.problem_template_file = self._find_template(template_dir, + 'go_problem.txt') + + def _find_template(self, template_dir: str, template_name: str) -> str: + """Finds template file based on |template_dir|.""" + preferred_template = os.path.join(template_dir, template_name) + # Use the preferred template if it exists. + if os.path.isfile(preferred_template): + return preferred_template + + # Fall back to the default template. + default_template = os.path.join(DEFAULT_TEMPLATE_DIR, template_name) + return default_template + + def _get_template(self, template_file: str) -> str: + """Reads the template for prompts.""" + with open(template_file) as file: + return file.read() + + def _format_target(self, signature: str) -> str: + """Format the target function for the prompts creation.""" + target = self._get_template(self.problem_template_file) + arg_count = len(self.benchmark.params) + arg_type = [arg_dict['type'] for arg_dict in self.benchmark.params] + + target = target.replace('{FUNCTION_SIGNATURE}', signature) + target = target.replace('{ARG_COUNT}', str(arg_count)) + target = target.replace('{ARG_TYPE}', ','.join(arg_type)) + + return target + + def _format_problem(self, signature: str) -> str: + """Formats a problem based on the prompt template.""" + base = self._get_template(self.base_template_file) + target_str = self._format_target(signature) + + problem = base + target_str + problem = problem.replace("{PROJECT_NAME}", self.benchmark.project) + problem = problem.replace("{PROJECT_URL}", self.project_url) + + return problem + + def _prepare_prompt(self, prompt_str: str): + """Constructs a prompt using the parameters and saves it.""" + self._prompt.add_priming(prompt_str) + + def build(self, + example_pair: list[list[str]], + project_example_content: Optional[list[list[str]]] = None, + project_context_content: Optional[dict] = None) -> prompts.Prompt: + """Constructs a prompt using the templates in |self| and saves it. + Ignore target_file_type, project_example_content + and project_context_content parameters. + """ + final_problem = self._format_problem(self.benchmark.function_signature) + self._prepare_prompt(final_problem) + return self._prompt + + def build_fixer_prompt(self, benchmark: Benchmark, raw_code: str, + error_desc: Optional[str], + errors: list[str]) -> prompts.Prompt: + """Builds a fixer prompt.""" + # Do nothing for go project now. + return self._prompt + + def build_triager_prompt(self, benchmark: Benchmark, driver_code: str, + crash_info: str, crash_func: dict) -> prompts.Prompt: + """Builds a triager prompt.""" + # Do nothing for go project now. + return self._prompt + + def post_process_generated_code(self, generated_code: str) -> str: + """Allows prompt builder to adjust the generated code.""" + # Do nothing for go project now. + return generated_code diff --git a/prompts/template_xml/go_base.txt b/prompts/template_xml/go_base.txt new file mode 100644 index 0000000000..cdb7a4dfa9 --- /dev/null +++ b/prompts/template_xml/go_base.txt @@ -0,0 +1,3 @@ +You are a security testing engineer who wants to write a program in golang to execute all lines in a given method by defining and initialising its parameters and necessary objects in a suitable way before fuzzing the method. +The tag contains information of the target method to invoke. +The tag contains additional requirements that you MUST follow for this code generation. diff --git a/prompts/template_xml/go_problem.txt b/prompts/template_xml/go_problem.txt new file mode 100644 index 0000000000..203442db08 --- /dev/null +++ b/prompts/template_xml/go_problem.txt @@ -0,0 +1,34 @@ + +Your goal is to write a fuzzing harness for the provided method signature to fuzz the method with random data. It is important that the provided solution compiles and actually calls the function specified by the method signature: + + +{FUNCTION_SIGNATURE} + +The target function is belonging to the Rust project {PROJECT_NAME} ({PROJECT_URL}). +You MUST call to this target function in the original project, NOT creating a dummy function. +This function requires {ARG_COUNT} arguments. You must prepare them with random seeded data. +Here is a list of types for all arguments in order, separated by comma. You MUST preserve the modifiers. +{ARG_TYPE} + + +Try as many variations of these inputs as possible. +Try creating the harness as complex as possible. +Try adding some nested loop to invoke the target method for multiple times. +The generated fuzzing harness should be wrapped with the tag. +Please avoid using any multithreading or multi-processing approach. +You MUST create the fuzzing harness using libfuzzer approach. +You MUST import the testing module. +You MUST create the fuzzing function with name start with Fuzz with *testing.F as parameter. Then create a lambda function call as parameter to *testing.F.Fuzz function call. +You MUST include the use of the necessary functions and crate for calling the target function. +The following is a sample of the fuzzing harness. + +import "testing" + +func FuzzOFG(f *testing.F) { + f.Fuzz(func(t *testing.T /* Other params */) { + // Fuzzing logic here + }) +} + + + diff --git a/report/common.py b/report/common.py index 4f445f1d7b..256ca73808 100644 --- a/report/common.py +++ b/report/common.py @@ -32,7 +32,7 @@ MAX_RUN_LOGS_LEN = 16 * 1024 -TARGET_EXTS = project_src.SEARCH_EXTS + ['.java', '.py', '.rs' +TARGET_EXTS = project_src.SEARCH_EXTS + ['.java', '.py', '.go', '.cgo', '.rs' ] + ['.fuzz_target'] _CHAT_PROMPT_START_MARKER = re.compile(r'')