diff --git a/.github/workflows/binary-decompilation.yml b/.github/workflows/binary-decompilation.yml
new file mode 100644
index 00000000..a3931afd
--- /dev/null
+++ b/.github/workflows/binary-decompilation.yml
@@ -0,0 +1,157 @@
+# Copyright 2022 Kry10 Limited
+#
+# SPDX-License-Identifier: BSD-2-Clause
+
+name: Binary decompilation
+
+on:
+  repository_dispatch:
+    types:
+      - binary-verification
+  workflow_dispatch:
+    inputs:
+      repo:
+        description: 'Repository'
+        required: true
+        type: string
+        default: 'seL4/l4v'
+      run_id:
+        description: 'Workflow run ID'
+        required: true
+      tag:
+        description: |
+          A brief description of the source of the event,
+          e.g. a workflow identifier. This is used when
+          reporting the results of a binary verification
+          run, to help users identify the proof run that
+          triggered the binary verification run.
+        required: true
+        type: string
+        default: 'workflow-dispatch'
+
+jobs:
+  targets:
+    # Fetch artifacts from the remote workflow that triggered this one,
+    # and store them locally in this workflow for easier access during the
+    # matrix job. Also identify which targets to run in the matrix.
+    name: Prepare decompilation targets
+    runs-on: ubuntu-latest
+    outputs:
+      targets_enabled: ${{ steps.prepare.outputs.targets_enabled }}
+    steps:
+      - name: Identify trigger
+        id: id_trigger
+        # Different event types use different context variables for the inputs,
+        # so here we figure out which variables to use.
+        run: |
+          # Identify source workflow
+          set -euo pipefail
+          case "${{ github.event_name }}" in
+            repository_dispatch)
+              echo "trigger_repo=${{ github.event.client_payload.repo }}" >> "${GITHUB_OUTPUT}"
+              echo "trigger_run=${{ github.event.client_payload.run_id }}" >> "${GITHUB_OUTPUT}"
+              echo "trigger_tag=${{ github.event.client_payload.tag }}" >> "${GITHUB_OUTPUT}"
+              ;;
+            workflow_dispatch)
+              echo "trigger_repo=${{ github.event.inputs.repo }}" >> "${GITHUB_OUTPUT}"
+              echo "trigger_run=${{ github.event.inputs.run_id }}" >> "${GITHUB_OUTPUT}"
+              echo "trigger_tag=${{ github.event.inputs.tag }}" >> "${GITHUB_OUTPUT}"
+              ;;
+            *)
+              echo "Unexpected github.event_name: ${{ github.event_name }}"
+              exit 1
+              ;;
+          esac
+
+      - name: Download kernel builds from source workflow
+        uses: seL4/ci-actions/await-remote-artifacts@master
+        with:
+          repo: ${{ steps.id_trigger.outputs.trigger_repo }}
+          run-id: ${{ steps.id_trigger.outputs.trigger_run }}
+          artifact-names: kernel-builds
+          token: ${{ secrets.PRIV_REPO_TOKEN }}
+          download-dir: artifacts
+
+      - name: Checkout graph-refine
+        uses: actions/checkout@v3
+        with:
+          path: graph-refine
+
+      - name: Install Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+      - run: pip install lxml
+
+      - name: Prepare job
+        id: prepare
+        run: graph-refine/ci/github-prepare-decompile.py artifacts/kernel-builds job
+        env:
+          TAG: ${{ steps.id_trigger.outputs.trigger_tag }}
+          PROOF_REPO: ${{ steps.id_trigger.outputs.trigger_repo }}
+          PROOF_RUN: ${{ steps.id_trigger.outputs.trigger_run }}
+          DECOMPILE_REPO: ${{ github.repository }}
+          DECOMPILE_RUN: ${{ github.run_id }}
+
+      - name: Upload job
+        uses: actions/upload-artifact@v3
+        with:
+          name: graph-refine-job
+          path: job
+          if-no-files-found: ignore
+
+  decompilation:
+    name: Decompile
+    needs: targets
+    if: needs.targets.outputs.targets_enabled != '[]'
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        target: ${{ fromJSON(needs.targets.outputs.targets_enabled) }}
+    steps:
+      - name: Download targets
+        uses: actions/download-artifact@v3
+        with:
+          name: graph-refine-job
+          path: job
+      - 
name: Decompile + run: | + # Decompile + docker run --rm -i \ + --mount "type=bind,src=${PWD}/job/targets/${{matrix.target}}/target,dst=/target" \ + --mount "type=tmpfs,dst=/tmp" \ + ghcr.io/sel4/sel4-decompiler /target + # Isolate current target for re-upload, + # to avoid interference between matrix jobs. + mkdir -p "my-job/targets/${{matrix.target}}" + mv "job/targets/${{matrix.target}}/target" "my-job/targets/${{matrix.target}}/target" + - name: Re-upload target + uses: actions/upload-artifact@v3 + with: + name: graph-refine-job + path: my-job + + submission: + name: Submit graph-refine job + needs: decompilation + runs-on: ubuntu-latest + steps: + - name: Download targets + uses: actions/download-artifact@v3 + with: + name: graph-refine-job + path: job + - name: Checkout graph-refine + uses: actions/checkout@v3 + with: + path: graph-refine + - name: Submit graph-refine job + env: + BV_BACKEND_WORK_DIR: bv + BV_SSH_CONFIG: "${{ secrets.BV_SSH_CONFIG }}" + BV_SSH_KEY: "${{ secrets.BV_SSH_KEY }}" + BV_SSH_KNOWN_HOSTS: "${{ secrets.BV_SSH_KNOWN_HOSTS }}" + DOCKER_RUN_COMMAND: "${{ secrets.DOCKER_RUN_COMMAND }}" + JOB_DIR: job + run: graph-refine/ci/submit-graph-refine diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 00000000..3b9638c2 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,122 @@ +# Copyright (c) 2022, Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +name: Build + +on: + # FIXME: When RISC-V and ARM branches are merged to master: + # - switch the push trigger to master, + # - add schedule and pull_request triggers. + push: + branches: + - ci-riscv64 + +# Ensure Docker image builds are fully serialised, +# so there is not a race to set the `latest` tag. +concurrency: + group: graph-refine-docker-builds + cancel-in-progress: true + +jobs: + build: + name: Build + runs-on: ubuntu-latest + steps: + - name: Checkout graph-refine + uses: actions/checkout@v3 + - name: Login to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: seL4 + password: ${{ secrets.GITHUB_TOKEN }} + - name: Set up Git + env: + CI_SSH: ${{ secrets.CI_SSH }} + shell: bash + run: | + # Configure SSH access and Git identity + eval $(ssh-agent) + ssh-add -q - <<< "${CI_SSH}" + echo "SSH_AUTH_SOCK=${SSH_AUTH_SOCK}" >> "${GITHUB_ENV}" + git config --global user.name "seL4 CI" + git config --global user.email "ci@sel4.systems" + # We use the Nix package manager to build Docker images. + - name: Install Nix + uses: cachix/install-nix-action@v20 + with: + nix_path: nixpkgs=channel:nixos-unstable + # We cache Nix builds using cachix.org. 
+ - name: Install Cachix + uses: cachix/cachix-action@v12 + with: + name: sel4-bv + authToken: "${{ secrets.BV_CACHIX_AUTH_TOKEN }}" + - name: Build graph-refine + shell: bash + run: | + # Build graph-refine + build_image() { nix-build -A "$1" -o "$1" nix/graph-refine.nix; } + build_image graph-refine-image + build_image graph-refine-runner-image + - name: Load and tag graph-refine images + id: image_tags + run: | + # Load and tag graph-refine images + ./graph-refine-image | docker load + ./graph-refine-runner-image | docker load + TAG="$(docker image ls --format '{{.Tag}}' graph-refine)" + RUN_TAG="$(docker image ls --format '{{.Tag}}' graph-refine-runner)" + echo "runner_tag=$RUN_TAG" >> "$GITHUB_OUTPUT" + docker tag "graph-refine:$TAG" ghcr.io/sel4/graph-refine:latest + docker tag "graph-refine:$TAG" ghcr.io/sel4/graph-refine:"$TAG" + docker tag "graph-refine-runner:$RUN_TAG" ghcr.io/sel4/graph-refine-runner:latest + docker tag "graph-refine-runner:$RUN_TAG" ghcr.io/sel4/graph-refine-runner:"$RUN_TAG" + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Checkout HOL4 and polyml + run: decompiler/setup-decompiler.py checkout --upstream --ssh + - name: Build decompiler + run: | + # Build decompiler + nix-build -A decompiler-image -o decompiler-image decompiler + nix-build -A sel4-decompiler-image -o sel4-decompiler-image seL4-example + - name: Load and tag decompiler image + run: | + # Load and tag decompiler image + ./decompiler-image | docker load + ./sel4-decompiler-image | docker load + IMAGE_TAG="$(docker image ls --format '{{.Tag}}' decompiler)" + SEL4_IMAGE_TAG="$(docker image ls --format '{{.Tag}}' sel4-decompiler)" + docker tag "decompiler:$IMAGE_TAG" ghcr.io/sel4/decompiler:latest + docker tag "decompiler:$IMAGE_TAG" ghcr.io/sel4/decompiler:"$IMAGE_TAG" + docker tag "sel4-decompiler:$SEL4_IMAGE_TAG" ghcr.io/sel4/sel4-decompiler:latest + docker tag "sel4-decompiler:$SEL4_IMAGE_TAG" ghcr.io/sel4/sel4-decompiler:"$SEL4_IMAGE_TAG" + - name: Push upstream branches + if: github.event_name == 'push' || github.event_name == 'schedule' + run: | + # Push upstream branches + (cd decompiler/src/HOL4 && git push) + (cd decompiler/src/polyml && git push) + - name: Push Docker images + if: github.event_name == 'push' || github.event_name == 'schedule' + run: | + # Push Docker images + docker push --all-tags ghcr.io/sel4/graph-refine + docker push --all-tags ghcr.io/sel4/graph-refine-runner + docker push --all-tags ghcr.io/sel4/decompiler + docker push --all-tags ghcr.io/sel4/sel4-decompiler + # Ensure the graph-refine backend is using an up-to-date runner + - name: Install graph-refine-runner + if: github.event_name == 'push' || github.event_name == 'schedule' + env: + BV_BACKEND_WORK_DIR: bv + BV_BACKEND_CONCURRENCY: 64 + BV_SSH_CONFIG: "${{ secrets.BV_SSH_CONFIG }}" + BV_SSH_KEY: "${{ secrets.BV_SSH_KEY }}" + BV_SSH_KNOWN_HOSTS: "${{ secrets.BV_SSH_KNOWN_HOSTS }}" + DOCKER_RUN_COMMAND: "${{ secrets.DOCKER_RUN_COMMAND }}" + RUNNER_TAG: "${{ steps.image_tags.outputs.runner_tag }}" + run: ci/install-runner diff --git a/.gitignore b/.gitignore index 64a28e95..cdd92ea4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,12 @@ # .pyc -/internal/ \ No newline at end of file +__pycache__/ + +# Nix +/result* + +# Decompiler setup +/decompiler/decompile +/decompiler/src/ +/decompiler/install/ diff --git a/README.md b/README.md index 383b3d11..0f30c767 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ use of SMT solvers. 
The design and theory of this tool are described in the paper [Translation Validation for a Verified OS Kernel][1] by Sewell, Myreen & Klein. - [1]: https://ts.data61.csiro.au/publications/nictaabstracts/Sewell_MK_13.abstract "Translation Validation for a Verified OS Kernel" + [1]: https://trustworthy.systems/publications/nictaabstracts/Sewell_MK_13.abstract "Translation Validation for a Verified OS Kernel" Repository Setup ---------------- diff --git a/c_rodata.py b/c_rodata.py index f933f681..40e4afab 100644 --- a/c_rodata.py +++ b/c_rodata.py @@ -8,35 +8,35 @@ import target_objects def get_cache (p): - k = 'c_rodata_hook_cache' - if k not in p.cached_analysis: - p.cached_analysis[k] = {} - return p.cached_analysis[k] + k = 'c_rodata_hook_cache' + if k not in p.cached_analysis: + p.cached_analysis[k] = {} + return p.cached_analysis[k] def hook (rep, (n, vc)): - p = rep.p - tag = p.node_tags[n][0] - is_C = tag == 'C' or p.hook_tag_hints.get (tag, None) == 'C' - if not is_C: - return - upd_ps = [rep.to_smt_expr (ptr, (n, vc)) - for (kind, ptr, v, m) in p.nodes[n].get_mem_accesses () - if kind == 'MemUpdate'] - if not upd_ps: - return - cache = get_cache (p) - for ptr in set (upd_ps): - pc = rep.get_pc ((n, vc)) - eq_rodata = rep.solv.get_eq_rodata_witness (ptr) - hyp = rep.to_smt_expr (syntax.mk_implies (pc, - syntax.mk_not (eq_rodata)), (n, vc)) - if ((n, vc), ptr) in cache: - res = cache[((n, vc), ptr)] - else: - res = rep.test_hyp_whyps (hyp, [], cache = cache) - cache[((n, vc), ptr)] = res - if res: - rep.solv.assert_fact (hyp, {}) + p = rep.p + tag = p.node_tags[n][0] + is_C = tag == 'C' or p.hook_tag_hints.get (tag, None) == 'C' + if not is_C: + return + upd_ps = [rep.to_smt_expr (ptr, (n, vc)) + for (kind, ptr, v, m) in p.nodes[n].get_mem_accesses () + if kind == 'MemUpdate'] + if not upd_ps: + return + cache = get_cache (p) + for ptr in set (upd_ps): + pc = rep.get_pc ((n, vc)) + eq_rodata = rep.solv.get_eq_rodata_witness (ptr) + hyp = rep.to_smt_expr (syntax.mk_implies (pc, + syntax.mk_not (eq_rodata)), (n, vc)) + if ((n, vc), ptr) in cache: + res = cache[((n, vc), ptr)] + else: + res = rep.test_hyp_whyps (hyp, [], cache = cache) + cache[((n, vc), ptr)] = res + if res: + rep.solv.assert_fact (hyp, {}) module_hook_k = 'c_rodata' target_objects.add_hook ('post_emit_node', module_hook_k, hook) diff --git a/check.py b/check.py index ab009072..cff418ae 100644 --- a/check.py +++ b/check.py @@ -15,946 +15,960 @@ from target_objects import functions, pairings, trace, printout import target_objects from rep_graph import (vc_num, vc_offs, vc_double_range, vc_upto, mk_vc_opts, - VisitCount) + VisitCount) import logic -from syntax import (true_term, false_term, boolT, mk_var, mk_word32, mk_word8, - mk_plus, mk_minus, word32T, word8T, mk_and, mk_eq, mk_implies, mk_not, - rename_expr) +from syntax import (true_term, false_term, boolT, mk_var, + mk_plus, mk_minus, mk_and, mk_eq, mk_implies, mk_not, + rename_expr) import syntax def build_problem (pairing, force_inline = None, avoid_abort = False): - p = Problem (pairing) + p = Problem (pairing) - for (tag, fname) in pairing.funs.items (): - p.add_entry_function (functions[fname], tag) + for (tag, fname) in pairing.funs.items (): + p.add_entry_function (functions[fname], tag) - p.do_analysis () + p.do_analysis () - # FIXME: the inlining is heuristic, and arguably belongs in 'search' - inline_completely_unmatched (p, skip_underspec = avoid_abort) - - # now do any C inlining - inline_reachable_unmatched_C (p, force_inline, - skip_underspec = 
avoid_abort) + # FIXME: the inlining is heuristic, and arguably belongs in 'search' + inline_completely_unmatched (p, skip_underspec = avoid_abort) - trace ('Done inlining.') + # now do any C inlining + assert avoid_abort == False + inline_reachable_unmatched_C (p, force_inline, + skip_underspec = avoid_abort) - p.pad_merge_points () - p.do_analysis () + trace ('Done inlining.') - if not avoid_abort: - p.check_no_inner_loops () + p.pad_merge_points () + p.do_analysis () - return p + if not avoid_abort: + p.check_no_inner_loops () + + return p def inline_completely_unmatched (p, ref_tags = None, skip_underspec = False): - if ref_tags == None: - ref_tags = p.pairing.tags - while True: - ns = [(n, skip_underspec - and not functions[p.nodes[n].fname].entry) - for n in p.nodes - if p.nodes[n].kind == 'Call' - if not [pair for pair - in pairings.get (p.nodes[n].fname, []) - if pair.tags == ref_tags]] - [trace ('Skipped inlining underspecified %s.' - % p.nodes[n].fname) for (n, skip) in ns if skip] - ns = [n for (n, skip) in ns if not skip] - for n in ns: - trace ('Function %s at %d - %s - completely unmatched.' - % (p.nodes[n].fname, n, p.node_tags[n][0])) - inline_at_point (p, n, do_analysis = False) - if not ns: - p.do_analysis () - return + + if ref_tags == None: + ref_tags = p.pairing.tags + while True: + ns = [(n, skip_underspec + and not functions[p.nodes[n].fname].entry) + for n in p.nodes + if p.nodes[n].kind == 'Call' + if not [pair for pair + in pairings.get (p.nodes[n].fname, []) + if pair.tags == ref_tags]] + [trace ('Skipped inlining underspecified %s.' + % p.nodes[n].fname) for (n, skip) in ns if skip] + + ns = [n for (n, skip) in ns if not skip] + + for n in ns: + trace ('Function %s at %d - %s - completely unmatched.' + % (p.nodes[n].fname, n, p.node_tags[n][0])) + inline_at_point (p, n, do_analysis = False) + if not ns: + p.do_analysis () + return def inline_reachable_unmatched_C (p, force_inline = None, - skip_underspec = False): - if 'C' not in p.pairing.tags: - return - [compare_tag] = [tag for tag in p.pairing.tags if tag != 'C'] - inline_reachable_unmatched (p, 'C', compare_tag, force_inline, - skip_underspec = skip_underspec) + skip_underspec = False): + if 'C' not in p.pairing.tags: + return + [compare_tag] = [tag for tag in p.pairing.tags if tag != 'C'] + inline_reachable_unmatched (p, 'C', compare_tag, force_inline, + skip_underspec = skip_underspec) def inline_reachable_unmatched (p, inline_tag, compare_tag, - force_inline = None, skip_underspec = False): - funs = [pair.funs[inline_tag] - for n in p.nodes - if p.nodes[n].kind == 'Call' - if p.node_tags[n][0] == compare_tag - for pair in pairings.get (p.nodes[n].fname, []) - if inline_tag in pair.tags] - - rep = mk_graph_slice (p, - consider_inline (funs, inline_tag, force_inline, - skip_underspec)) - opts = vc_double_range (3, 3) - while True: - try: - heads = problem.loop_heads_including_inner (p) - limits = [(n, opts) for n in heads] - - for n in p.nodes.keys (): - try: - r = rep.get_node_pc_env ((n, limits)) - except rep.TooGeneral: - pass - - rep.get_node_pc_env (('Ret', limits), inline_tag) - rep.get_node_pc_env (('Err', limits), inline_tag) - break - except rep_graph.InlineEvent: - continue + force_inline = None, skip_underspec = False): + funs = [pair.funs[inline_tag] + for n in p.nodes + if p.nodes[n].kind == 'Call' + #if p.node_tags[n][0] == compare_tag or (inline_tag == 'C' and p.node_tags[n][0] == 'C') + if p.node_tags[n][0] == compare_tag + for pair in pairings.get (p.nodes[n].fname, []) + if inline_tag in 
pair.tags] + + rep = mk_graph_slice (p, + consider_inline (funs, inline_tag, force_inline, + skip_underspec)) + opts = vc_double_range (3, 3) + while True: + try: + heads = problem.loop_heads_including_inner (p) + limits = [(n, opts) for n in heads] + for n in p.nodes.keys (): + try: + r = rep.get_node_pc_env ((n, limits)) + except rep.TooGeneral: + pass + + rep.get_node_pc_env (('Ret', limits), inline_tag) + rep.get_node_pc_env (('Err', limits), inline_tag) + break + except rep_graph.InlineEvent: + continue def consider_inline1 (p, n, matched_funs, inline_tag, - force_inline, skip_underspec): - node = p.nodes[n] - assert node.kind == 'Call' - - if p.node_tags[n][0] != inline_tag: - return False - - f_nm = node.fname - if skip_underspec and not functions[f_nm].entry: - trace ('Skipping inlining underspecified %s' % f_nm) - return False - if f_nm not in matched_funs or (force_inline and force_inline (f_nm)): - return lambda: inline_at_point (p, n) - else: - return False + force_inline, skip_underspec): + node = p.nodes[n] + assert node.kind == 'Call' + + if p.node_tags[n][0] != inline_tag: + return False + + f_nm = node.fname + + if skip_underspec and not functions[f_nm].entry: + trace ('Skipping inlining underspecified %s' % f_nm) + return False + if f_nm not in matched_funs or (force_inline and force_inline (f_nm)): + return lambda: inline_at_point (p, n) + else: + return False def consider_inline (matched_funs, tag, force_inline, skip_underspec = False): - return lambda (p, n): consider_inline1 (p, n, matched_funs, tag, - force_inline, skip_underspec) + return lambda (p, n): consider_inline1 (p, n, matched_funs, tag, + force_inline, skip_underspec) def inst_eqs (p, restrs, eqs, tag_map = {}): - addr_map = {} - if not tag_map: - tag_map = dict ([(tag, tag) for tag in p.tags ()]) - for (pair_tag, p_tag) in tag_map.iteritems (): - addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag) - addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag) - renames = p.entry_exit_renames (tag_map.values ()) - for (pair_tag, p_tag) in tag_map.iteritems (): - renames[pair_tag + '_IN'] = renames[p_tag + '_IN'] - renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT'] - hyps = [] - for (lhs, rhs) in eqs: - vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr]) - for (x, x_addr) in (lhs, rhs)] - hyps.append (eq_hyp (vals[0], vals[1])) - return hyps + addr_map = {} + if not tag_map: + tag_map = dict ([(tag, tag) for tag in p.tags ()]) + for (pair_tag, p_tag) in tag_map.iteritems (): + addr_map[pair_tag + '_IN'] = ((p.get_entry (p_tag), ()), p_tag) + addr_map[pair_tag + '_OUT'] = (('Ret', restrs), p_tag) + renames = p.entry_exit_renames (tag_map.values ()) + + for (pair_tag, p_tag) in tag_map.iteritems (): + renames[pair_tag + '_IN'] = renames[p_tag + '_IN'] + renames[pair_tag + '_OUT'] = renames[p_tag + '_OUT'] + hyps = [] + for (lhs, rhs) in eqs: + vals = [(rename_expr (x, renames[x_addr]), addr_map[x_addr]) + for (x, x_addr) in (lhs, rhs)] + hyps.append (eq_hyp (vals[0], vals[1])) + return hyps def init_point_hyps (p): - (inp_eqs, _) = p.pairing.eqs - return inst_eqs (p, (), inp_eqs) + (inp_eqs, _) = p.pairing.eqs + return inst_eqs (p, (), inp_eqs) class ProofNode: - def __init__ (self, kind, args = None, subproofs = []): - self.kind = kind - self.args = args - self.subproofs = tuple (subproofs) - if self.kind == 'Leaf': - assert args == None - assert list (subproofs) == [] - elif self.kind == 'Restr': - (self.point, self.restr_range) = args - assert len (subproofs) == 1 - elif self.kind == 
'SingleRevInduct': - (self.point, self.eqs_proof, self.rev_proof) = args - assert len (subproofs) == 1 - elif self.kind == 'Split': - self.split = args - (l_details, r_details, eqs, n, loop_r_max) = args - assert len (subproofs) == 2 - elif self.kind == 'CaseSplit': - (self.point, self.tag) = args - assert len (subproofs) == 2 - else: - assert not 'proof node kind understood', kind - - def __repr__ (self): - return 'ProofNode (%r, %r, %r)' % (self.kind, - self.args, self.subproofs) - - def serialise (self, p, ss): - if self.kind == 'Leaf': - ss.append ('Leaf') - elif self.kind == 'Restr': - (kind, (x, y)) = self.restr_range - tag = p.node_tags[self.point][0] - ss.extend (['Restr', '%d' % self.point, - tag, kind, '%d' % x, '%d' % y]) - elif self.kind == 'SingleRevInduct': - tag = p.node_tags[self.point][0] - (eqs, n) = self.eqs_proof - ss.extend (['SingleRevInduct', '%d' % self.point, - tag, '%d' % n, '%d' % len (eqs)]) - for (x, y) in eqs: - serialise_lambda (x, ss) - serialise_lambda (y, ss) - (pred, n_bound) = self.rev_proof - pred.serialise (ss) - ss.append ('%d' % n_bound) - elif self.kind == 'Split': - (l_details, r_details, eqs, n, loop_r_max) = self.args - ss.extend (['Split', '%d' % n, '%d' % loop_r_max]) - serialise_details (l_details, ss) - serialise_details (r_details, ss) - ss.append ('%d' % len (eqs)) - for (x, y) in eqs: - serialise_lambda (x, ss) - serialise_lambda (y, ss) - elif self.kind == 'CaseSplit': - ss.extend (['CaseSplit', '%d' % self.point, self.tag]) - else: - assert not 'proof node kind understood' - for proof in self.subproofs: - proof.serialise (p, ss) - - def all_subproofs (self): - return [self] + [proof for proof1 in self.subproofs - for proof in proof1.all_subproofs ()] - - def all_subproblems (self, p, restrs, hyps, name): - subproblems = proof_subproblems (p, self.kind, - self.args, restrs, hyps, name) - subproofs = logic.azip (subproblems, self.subproofs) - return [(self, restrs, hyps)] + [problem - for ((restrs2, hyps2, name2), proof) in subproofs - for problem in proof.all_subproblems (p, restrs2, - hyps2, name2)] - - def save_serialise (self, p, fname): - f = open (fname, 'w') - ss = [] - self.serialise (p, ss) - f.write (' '.join (ss) + '\n') - f.close () - - def __hash__ (self): - return syntax.hash_tuplify (self.kind, self.args, - self.subproofs) + def __init__ (self, kind, args = None, subproofs = []): + self.kind = kind + self.args = args + self.subproofs = tuple (subproofs) + if self.kind == 'Leaf': + assert args == None + assert list (subproofs) == [] + elif self.kind == 'Restr': + (self.point, self.restr_range) = args + assert len (subproofs) == 1 + elif self.kind == 'SingleRevInduct': + (self.point, self.eqs_proof, self.rev_proof) = args + assert len (subproofs) == 1 + elif self.kind == 'Split': + self.split = args + (l_details, r_details, eqs, n, loop_r_max) = args + assert len (subproofs) == 2 + elif self.kind == 'CaseSplit': + (self.point, self.tag) = args + assert len (subproofs) == 2 + else: + assert not 'proof node kind understood', kind + + def __repr__ (self): + return 'ProofNode (%r, %r, %r)' % (self.kind, + self.args, self.subproofs) + + def serialise (self, p, ss): + if self.kind == 'Leaf': + ss.append ('Leaf') + elif self.kind == 'Restr': + (kind, (x, y)) = self.restr_range + tag = p.node_tags[self.point][0] + ss.extend (['Restr', '%d' % self.point, + tag, kind, '%d' % x, '%d' % y]) + elif self.kind == 'SingleRevInduct': + tag = p.node_tags[self.point][0] + (eqs, n) = self.eqs_proof + ss.extend (['SingleRevInduct', '%d' % 
self.point, + tag, '%d' % n, '%d' % len (eqs)]) + for (x, y) in eqs: + serialise_lambda (x, ss) + serialise_lambda (y, ss) + (pred, n_bound) = self.rev_proof + pred.serialise (ss) + ss.append ('%d' % n_bound) + elif self.kind == 'Split': + (l_details, r_details, eqs, n, loop_r_max) = self.args + ss.extend (['Split', '%d' % n, '%d' % loop_r_max]) + serialise_details (l_details, ss) + serialise_details (r_details, ss) + ss.append ('%d' % len (eqs)) + for (x, y) in eqs: + serialise_lambda (x, ss) + serialise_lambda (y, ss) + elif self.kind == 'CaseSplit': + ss.extend (['CaseSplit', '%d' % self.point, self.tag]) + else: + assert not 'proof node kind understood' + for proof in self.subproofs: + proof.serialise (p, ss) + + def all_subproofs (self): + return [self] + [proof for proof1 in self.subproofs + for proof in proof1.all_subproofs ()] + + def all_subproblems (self, p, restrs, hyps, name): + subproblems = proof_subproblems (p, self.kind, + self.args, restrs, hyps, name) + subproofs = logic.azip (subproblems, self.subproofs) + return [(self, restrs, hyps)] + [problem + for ((restrs2, hyps2, name2), proof) in subproofs + for problem in proof.all_subproblems (p, restrs2, + hyps2, name2)] + + def save_serialise (self, p, fname): + f = open (fname, 'w') + ss = [] + self.serialise (p, ss) + f.write (' '.join (ss) + '\n') + f.close () + + def __hash__ (self): + return syntax.hash_tuplify (self.kind, self.args, + self.subproofs) def serialise_details (details, ss): - (split, (seq_start, step), eqs) = details - ss.extend (['%d' % split, '%d' % seq_start, '%d' % step]) - ss.append ('%d' % len (eqs)) - for eq in eqs: - serialise_lambda (eq, ss) + (split, (seq_start, step), eqs) = details + ss.extend (['%d' % split, '%d' % seq_start, '%d' % step]) + ss.append ('%d' % len (eqs)) + for eq in eqs: + serialise_lambda (eq, ss) def serialise_lambda (eq_term, ss): - ss.extend (['Lambda', '%i']) - word32T.serialise (ss) - eq_term.serialise (ss) + ss.extend (['Lambda', '%i']) + syntax.arch.word_type.serialise(ss) + eq_term.serialise (ss) def deserialise_details (ss, i): - (split, seq_start, step) = [int (x) for x in ss[i : i + 3]] - (i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3) - return (i, (split, (seq_start, step), eqs)) + (split, seq_start, step) = [int (x) for x in ss[i : i + 3]] + (i, eqs) = syntax.parse_list (deserialise_lambda, ss, i + 3) + return (i, (split, (seq_start, step), eqs)) def deserialise_lambda (ss, i): - assert ss[i : i + 2] == ['Lambda', '%i'], (ss, i) - (i, typ) = syntax.parse_typ (ss, i + 2) - assert typ == word32T, typ - (i, eq_term) = syntax.parse_expr (ss, i) - return (i, eq_term) + assert ss[i : i + 2] == ['Lambda', '%i'], (ss, i) + (i, typ) = syntax.parse_typ (ss, i + 2) + assert typ == syntax.arch.word_type, typ + (i, eq_term) = syntax.parse_expr (ss, i) + return (i, eq_term) def deserialise_double_lambda (ss, i): - (i, x) = deserialise_lambda (ss, i) - (i, y) = deserialise_lambda (ss, i) - return (i, (x, y)) + (i, x) = deserialise_lambda (ss, i) + (i, y) = deserialise_lambda (ss, i) + return (i, (x, y)) def deserialise_inner (ss, i): - if ss[i] == 'Leaf': - return (i + 1, ProofNode ('Leaf')) - elif ss[i] == 'Restr': - point = int (ss[i + 1]) - tag = ss[i + 2] - kind = ss[i + 3] - assert kind in ['Number', 'Offset'], (kind, i) - x = int (ss[i + 4]) - y = int (ss[i + 5]) - (i, p1) = deserialise_inner (ss, i + 6) - return (i, ProofNode ('Restr', (point, (kind, (x, y))), [p1])) - elif ss[i] == 'SingleRevInduct': - point = int (ss[i + 1]) - tag = ss[i + 2] - n = int (ss[i 
+ 3]) - (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i + 4) - (i, pred) = syntax.parse_term (ss, i) - n_bound = int (ss[i]) - (i, p1) = deserialise_inner (ss, i + 1) - return (i, ProofNode ('SingleRevInduct', (point, (eqs, n), - (pred, n_bound)), [p1])) - elif ss[i] == 'Split': - n = int (ss[i + 1]) - loop_r_max = int (ss[i + 2]) - (i, l_details) = deserialise_details (ss, i + 3) - (i, r_details) = deserialise_details (ss, i) - (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i) - (i, p1) = deserialise_inner (ss, i) - (i, p2) = deserialise_inner (ss, i) - return (i, ProofNode ('Split', (l_details, r_details, eqs, - n, loop_r_max), [p1, p2])) - elif ss[i] == 'CaseSplit': - n = int (ss[i + 1]) - tag = ss[i + 2] - (i, p1) = deserialise_inner (ss, i + 3) - (i, p2) = deserialise_inner (ss, i) - return (i, ProofNode ('CaseSplit', (n, tag), [p1, p2])) - else: - assert not 'proof node type understood', (ss, i) + if ss[i] == 'Leaf': + return (i + 1, ProofNode ('Leaf')) + elif ss[i] == 'Restr': + point = int (ss[i + 1]) + tag = ss[i + 2] + kind = ss[i + 3] + assert kind in ['Number', 'Offset'], (kind, i) + x = int (ss[i + 4]) + y = int (ss[i + 5]) + (i, p1) = deserialise_inner (ss, i + 6) + return (i, ProofNode ('Restr', (point, (kind, (x, y))), [p1])) + elif ss[i] == 'SingleRevInduct': + point = int (ss[i + 1]) + tag = ss[i + 2] + n = int (ss[i + 3]) + (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i + 4) + (i, pred) = syntax.parse_term (ss, i) + n_bound = int (ss[i]) + (i, p1) = deserialise_inner (ss, i + 1) + return (i, ProofNode ('SingleRevInduct', (point, (eqs, n), + (pred, n_bound)), [p1])) + elif ss[i] == 'Split': + n = int (ss[i + 1]) + loop_r_max = int (ss[i + 2]) + (i, l_details) = deserialise_details (ss, i + 3) + (i, r_details) = deserialise_details (ss, i) + (i, eqs) = syntax.parse_list (deserialise_double_lambda, ss, i) + (i, p1) = deserialise_inner (ss, i) + (i, p2) = deserialise_inner (ss, i) + return (i, ProofNode ('Split', (l_details, r_details, eqs, + n, loop_r_max), [p1, p2])) + elif ss[i] == 'CaseSplit': + n = int (ss[i + 1]) + tag = ss[i + 2] + (i, p1) = deserialise_inner (ss, i + 3) + (i, p2) = deserialise_inner (ss, i) + return (i, ProofNode ('CaseSplit', (n, tag), [p1, p2])) + else: + assert not 'proof node type understood', (ss, i) def deserialise (line): - ss = line.split () - (i, proof) = deserialise_inner (ss, 0) - assert i == len (ss), (ss, i) - return proof + ss = line.split () + (i, proof) = deserialise_inner (ss, 0) + assert i == len (ss), (ss, i) + return proof def proof_subproblems (p, kind, args, restrs, hyps, path): - tags = p.pairing.tags - if kind == 'Leaf': - return [] - elif kind == 'Restr': - restr = get_proof_restr (args[0], args[1]) - hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)] - return [((restr,) + restrs, hyps, - '%s (%d limited)' % (path, args[0]))] - elif kind == 'SingleRevInduct': - hyp = single_induct_resulting_hyp (p, restrs, args) - return [(restrs, hyps + [hyp], path)] - elif kind == 'Split': - split = args - return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs), - '%d init case in %s' % (split[0][0], path)), - (restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True), - '%d loop case in %s' % (split[0][0], path))] - elif kind == 'CaseSplit': - (point, tag) = args - visit = ((point, restrs), tag) - true_hyps = hyps + [pc_true_hyp (visit)] - false_hyps = hyps + [pc_false_hyp (visit)] - return [(restrs, true_hyps, - 'true case (%d visited) in %s' % (point, path)), - 
(restrs, false_hyps, - 'false case (%d not visited) in %s' % (point, path))] - else: - assert not 'proof node kind understood', proof.kind + tags = p.pairing.tags + if kind == 'Leaf': + return [] + elif kind == 'Restr': + restr = get_proof_restr (args[0], args[1]) + hyps = hyps + [restr_trivial_hyp (p, args[0], args[1], restrs)] + return [((restr,) + restrs, hyps, + '%s (%d limited)' % (path, args[0]))] + elif kind == 'SingleRevInduct': + hyp = single_induct_resulting_hyp (p, restrs, args) + return [(restrs, hyps + [hyp], path)] + elif kind == 'Split': + split = args + return [(restrs, hyps + split_no_loop_hyps (tags, split, restrs), + '%d init case in %s' % (split[0][0], path)), + (restrs, hyps + split_loop_hyps (tags, split, restrs, exit = True), + '%d loop case in %s' % (split[0][0], path))] + elif kind == 'CaseSplit': + (point, tag) = args + visit = ((point, restrs), tag) + true_hyps = hyps + [pc_true_hyp (visit)] + false_hyps = hyps + [pc_false_hyp (visit)] + return [(restrs, true_hyps, + 'true case (%d visited) in %s' % (point, path)), + (restrs, false_hyps, + 'false case (%d not visited) in %s' % (point, path))] + else: + assert not 'proof node kind understood', proof.kind def split_heads ((l_details, r_details, eqs, n, _)): - (l_split, _, _) = l_details - (r_split, _, _) = r_details - return [l_split, r_split] + (l_split, _, _) = l_details + (r_split, _, _) = r_details + return [l_split, r_split] def split_no_loop_hyps (tags, split, restrs): - ((_, (l_seq_start, l_step), _), _, _, n, _) = split + ((_, (l_seq_start, l_step), _), _, _, n, _) = split - (l_visit, _) = split_visit_visits (tags, split, restrs, vc_num (n)) + (l_visit, _) = split_visit_visits (tags, split, restrs, vc_num (n)) - return [pc_false_hyp (l_visit)] + return [pc_false_hyp (l_visit)] def split_visit_one_visit (tag, details, restrs, visit): - if details == None: - return None - (split, (seq_start, step), eqs) = details - - # the split point sequence at low numbers ('Number') is offset - # by the point the sequence starts. At symbolic offsets we ignore - # that, instead having the loop counter for the two sequences - # be the same number of iterations after the sequence start. - if visit.kind == 'Offset': - visit = vc_offs (visit.n * step) - else: - visit = vc_num (seq_start + (visit.n * step)) - - visit = ((split, ((split, visit), ) + restrs), tag) - return visit + if details == None: + return None + (split, (seq_start, step), eqs) = details + + # the split point sequence at low numbers ('Number') is offset + # by the point the sequence starts. At symbolic offsets we ignore + # that, instead having the loop counter for the two sequences + # be the same number of iterations after the sequence start. 
+ if visit.kind == 'Offset': + visit = vc_offs (visit.n * step) + else: + visit = vc_num (seq_start + (visit.n * step)) + + visit = ((split, ((split, visit), ) + restrs), tag) + return visit def split_visit_visits (tags, split, restrs, visit): - (ltag, rtag) = tags - (l_details, r_details, eqs, _, _) = split + (ltag, rtag) = tags + (l_details, r_details, eqs, _, _) = split - l_visit = split_visit_one_visit (ltag, l_details, restrs, visit) - r_visit = split_visit_one_visit (rtag, r_details, restrs, visit) + l_visit = split_visit_one_visit (ltag, l_details, restrs, visit) + r_visit = split_visit_one_visit (rtag, r_details, restrs, visit) - return (l_visit, r_visit) + return (l_visit, r_visit) def split_hyps_at_visit (tags, split, restrs, visit): - (l_details, r_details, eqs, _, _) = split - (l_split, (l_seq_start, l_step), l_eqs) = l_details - (r_split, (r_seq_start, r_step), r_eqs) = r_details - - (l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit) - (l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0)) - (l_tag, r_tag) = tags - - def mksub (v): - return lambda exp: logic.var_subst (exp, {('%i', word32T) : v}, - must_subst = False) - def inst (exp): - return logic.inst_eq_at_visit (exp, visit) - zsub = mksub (mk_word32 (0)) - if visit.kind == 'Number': - lsub = mksub (mk_word32 (visit.n)) - else: - lsub = mksub (mk_plus (mk_var ('%n', word32T), - mk_word32 (visit.n))) - - hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'), - (Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag), - (Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)] - hyps += [(eq_hyp ((zsub (l_exp), l_start), (lsub (l_exp), l_visit), - (l_split, r_split)), '%s const' % l_tag) - for l_exp in l_eqs if inst (l_exp)] - hyps += [(eq_hyp ((zsub (r_exp), r_start), (lsub (r_exp), r_visit), - (l_split, r_split)), '%s const' % r_tag) - for r_exp in r_eqs if inst (r_exp)] - hyps += [(eq_hyp ((lsub (l_exp), l_visit), (lsub (r_exp), r_visit), - (l_split, r_split)), 'eq') - for (l_exp, r_exp) in eqs - if inst (l_exp) and inst (r_exp)] - return hyps + (l_details, r_details, eqs, _, _) = split + (l_split, (l_seq_start, l_step), l_eqs) = l_details + (r_split, (r_seq_start, r_step), r_eqs) = r_details + + (l_visit, r_visit) = split_visit_visits (tags, split, restrs, visit) + (l_start, r_start) = split_visit_visits (tags, split, restrs, vc_num (0)) + (l_tag, r_tag) = tags + + def mksub (v): + return lambda exp: logic.var_subst (exp, {('%i', syntax.arch.word_type) : v}, + must_subst = False) + def inst (exp): + return logic.inst_eq_at_visit (exp, visit) + zsub = mksub (syntax.arch.mk_word(0)) + if visit.kind == 'Number': + lsub = mksub (syntax.arch.mk_word(visit.n)) + else: + lsub = mksub (mk_plus (mk_var ('%n', syntax.arch.word_type), + syntax.arch.mk_word(visit.n))) + + hyps = [(Hyp ('PCImp', l_visit, r_visit), 'pc imp'), + (Hyp ('PCImp', l_visit, l_start), '%s pc imp' % l_tag), + (Hyp ('PCImp', r_visit, r_start), '%s pc imp' % r_tag)] + hyps += [(eq_hyp ((zsub (l_exp), l_start), (lsub (l_exp), l_visit), + (l_split, r_split)), '%s const' % l_tag) + for l_exp in l_eqs if inst (l_exp)] + hyps += [(eq_hyp ((zsub (r_exp), r_start), (lsub (r_exp), r_visit), + (l_split, r_split)), '%s const' % r_tag) + for r_exp in r_eqs if inst (r_exp)] + hyps += [(eq_hyp ((lsub (l_exp), l_visit), (lsub (r_exp), r_visit), + (l_split, r_split)), 'eq') + for (l_exp, r_exp) in eqs + if inst (l_exp) and inst (r_exp)] + + return hyps def split_loop_hyps (tags, split, restrs, exit): - ((r_split, _, _), _, _, n, _) = split - 
(l_visit, _) = split_visit_visits (tags, split, restrs, vc_offs (n - 1)) - (l_cont, _) = split_visit_visits (tags, split, restrs, vc_offs (n)) - (l_tag, r_tag) = tags - - l_enter = pc_true_hyp (l_visit) - l_exit = pc_false_hyp (l_cont) - if exit: - hyps = [l_enter, l_exit] - else: - hyps = [l_enter] - return hyps + [hyp for offs in map (vc_offs, range (n)) - for (hyp, _) in split_hyps_at_visit (tags, split, restrs, offs)] + ((r_split, _, _), _, _, n, _) = split + (l_visit, _) = split_visit_visits (tags, split, restrs, vc_offs (n - 1)) + (l_cont, _) = split_visit_visits (tags, split, restrs, vc_offs (n)) + (l_tag, r_tag) = tags + + l_enter = pc_true_hyp (l_visit) + l_exit = pc_false_hyp (l_cont) + if exit: + hyps = [l_enter, l_exit] + else: + hyps = [l_enter] + return hyps + [hyp for offs in map (vc_offs, range (n)) + for (hyp, _) in split_hyps_at_visit (tags, split, restrs, offs)] def loops_to_split (p, restrs): - loop_heads_with_split = set ([p.loop_id (n) - for (n, visit_set) in restrs]) - rem_loop_heads = set (p.loop_heads ()) - loop_heads_with_split - for (n, visit_set) in restrs: - if not visit_set.has_zero (): - # n must be visited, so loop heads must be - # reachable from n (or on another tag) - rem_loop_heads = [lh for lh in rem_loop_heads - if p.is_reachable_from (n, lh) - or p.node_tags[n][0] != p.node_tags[lh][0]] - return rem_loop_heads + loop_heads_with_split = set ([p.loop_id (n) + for (n, visit_set) in restrs]) + rem_loop_heads = set (p.loop_heads ()) - loop_heads_with_split + for (n, visit_set) in restrs: + if not visit_set.has_zero (): + # n must be visited, so loop heads must be + # reachable from n (or on another tag) + rem_loop_heads = [lh for lh in rem_loop_heads + if p.is_reachable_from (n, lh) + or p.node_tags[n][0] != p.node_tags[lh][0]] + return rem_loop_heads def restr_others (p, restrs, n): - extras = [(sp, vc_upto (n)) for sp in loops_to_split (p, restrs)] - return restrs + tuple (extras) + extras = [(sp, vc_upto (n)) for sp in loops_to_split (p, restrs)] + return restrs + tuple (extras) def non_r_err_pc_hyp (tags, restrs): - return pc_false_hyp ((('Err', restrs), tags[1])) + return pc_false_hyp ((('Err', restrs), tags[1])) def split_r_err_pc_hyp (p, split, restrs, tags = None): - (_, r_details, _, n, loop_r_max) = split - (r_split, (r_seq_start, r_step), r_eqs) = r_details + (_, r_details, _, n, loop_r_max) = split + (r_split, (r_seq_start, r_step), r_eqs) = r_details - nc = n * r_step - vc = vc_double_range (r_seq_start + nc, loop_r_max + 2) + nc = n * r_step + vc = vc_double_range (r_seq_start + nc, loop_r_max + 2) - restrs = restr_others (p, ((r_split, vc), ) + restrs, 2) + restrs = restr_others (p, ((r_split, vc), ) + restrs, 2) - if tags == None: - tags = p.pairing.tags + if tags == None: + tags = p.pairing.tags - return non_r_err_pc_hyp (tags, restrs) + return non_r_err_pc_hyp (tags, restrs) restr_bump = 0 def get_proof_restr (n, (kind, (x, y))): - return (n, mk_vc_opts ([VisitCount (kind, i) - for i in range (x, y + restr_bump)])) + return (n, mk_vc_opts ([VisitCount (kind, i) + for i in range (x, y + restr_bump)])) def restr_trivial_hyp (p, n, (kind, (x, y)), restrs): - restr = (n, VisitCount (kind, y - 1)) - return rep_graph.pc_triv_hyp (((n, (restr, ) + restrs), - p.node_tags[n][0])) + restr = (n, VisitCount (kind, y - 1)) + return rep_graph.pc_triv_hyp (((n, (restr, ) + restrs), + p.node_tags[n][0])) def proof_restr_checks (n, (kind, (x, y)), p, restrs, hyps): - restr = get_proof_restr (n, (kind, (x, y))) - ncerr_hyp = non_r_err_pc_hyp 
(p.pairing.tags, - restr_others (p, (restr, ) + restrs, 2)) - hyps = [ncerr_hyp] + hyps - def visit (vc): - return ((n, ((n, vc), ) + restrs), p.node_tags[n][0]) - - # this cannot be more uniform because the representation of visit - # at offset 0 is all a bit odd, with n being the only node so visited: - if kind == 'Offset': - min_vc = vc_offs (max (0, x - 1)) - elif x > 1: - min_vc = vc_num (x - 1) - else: - min_vc = None - if min_vc: - init_check = [(hyps, pc_true_hyp (visit (min_vc)), - 'Check of restr min %d %s for %d' % (x, kind, n))] - else: - init_check = [] - - # if we can reach node n with (y - 1) visits to n, then the next - # node will have y visits to n, which we are disallowing - # thus we show that this visit is impossible - top_vc = VisitCount (kind, y - 1) - top_check = (hyps, pc_false_hyp (visit (top_vc)), - 'Check of restr max %d %s for %d' % (y, kind, n)) - return init_check + [top_check] + restr = get_proof_restr (n, (kind, (x, y))) + ncerr_hyp = non_r_err_pc_hyp (p.pairing.tags, + restr_others (p, (restr, ) + restrs, 2)) + hyps = [ncerr_hyp] + hyps + def visit (vc): + return ((n, ((n, vc), ) + restrs), p.node_tags[n][0]) + + # this cannot be more uniform because the representation of visit + # at offset 0 is all a bit odd, with n being the only node so visited: + if kind == 'Offset': + min_vc = vc_offs (max (0, x - 1)) + elif x > 1: + min_vc = vc_num (x - 1) + else: + min_vc = None + if min_vc: + init_check = [(hyps, pc_true_hyp (visit (min_vc)), + 'Check of restr min %d %s for %d' % (x, kind, n))] + else: + init_check = [] + + # if we can reach node n with (y - 1) visits to n, then the next + # node will have y visits to n, which we are disallowing + # thus we show that this visit is impossible + top_vc = VisitCount (kind, y - 1) + top_check = (hyps, pc_false_hyp (visit (top_vc)), + 'Check of restr max %d %s for %d' % (y, kind, n)) + return init_check + [top_check] def split_init_step_checks (p, restrs, hyps, split, tags = None): - (_, _, _, n, _) = split - if tags == None: - tags = p.pairing.tags - - err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags) - hyps = [err_hyp] + hyps - checks = [] - for i in range (n): - (l_visit, r_visit) = split_visit_visits (tags, split, - restrs, vc_num (i)) - lpc_hyp = pc_true_hyp (l_visit) - # this trivial 'hyp' ensures the rep is built to include - # the matching rhs visits when checking lhs consts - rpc_triv_hyp = rep_graph.pc_triv_hyp (r_visit) - vis_hyps = split_hyps_at_visit (tags, split, restrs, vc_num (i)) - - for (hyp, desc) in vis_hyps: - checks.append ((hyps + [lpc_hyp, rpc_triv_hyp], hyp, - 'Induct check at visit %d: %s' % (i, desc))) - return checks + (_, _, _, n, _) = split + if tags == None: + tags = p.pairing.tags + + err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags) + hyps = [err_hyp] + hyps + checks = [] + for i in range (n): + (l_visit, r_visit) = split_visit_visits (tags, split, + restrs, vc_num (i)) + lpc_hyp = pc_true_hyp (l_visit) + # this trivial 'hyp' ensures the rep is built to include + # the matching rhs visits when checking lhs consts + rpc_triv_hyp = rep_graph.pc_triv_hyp (r_visit) + vis_hyps = split_hyps_at_visit (tags, split, restrs, vc_num (i)) + + for (hyp, desc) in vis_hyps: + checks.append ((hyps + [lpc_hyp, rpc_triv_hyp], hyp, + 'Induct check at visit %d: %s' % (i, desc))) + return checks def split_induct_step_checks (p, restrs, hyps, split, tags = None): - ((l_split, _, _), _, _, n, _) = split - if tags == None: - tags = p.pairing.tags - - err_hyp = split_r_err_pc_hyp (p, 
split, restrs, tags = tags) - (cont, r_cont) = split_visit_visits (tags, split, restrs, vc_offs (n)) - # the 'trivial' hyp here ensures the representation includes a loop - # of the rhs when proving const equations on the lhs - hyps = ([err_hyp, pc_true_hyp (cont), - rep_graph.pc_triv_hyp (r_cont)] + hyps - + split_loop_hyps (tags, split, restrs, exit = False)) - - return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' - % (desc, l_split)) - for (hyp, desc) in split_hyps_at_visit (tags, split, - restrs, vc_offs (n))] + ((l_split, _, _), _, _, n, _) = split + if tags == None: + tags = p.pairing.tags + + err_hyp = split_r_err_pc_hyp (p, split, restrs, tags = tags) + (cont, r_cont) = split_visit_visits (tags, split, restrs, vc_offs (n)) + # the 'trivial' hyp here ensures the representation includes a loop + # of the rhs when proving const equations on the lhs + hyps = ([err_hyp, pc_true_hyp (cont), + rep_graph.pc_triv_hyp (r_cont)] + hyps + + split_loop_hyps (tags, split, restrs, exit = False)) + + return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' + % (desc, l_split)) + for (hyp, desc) in split_hyps_at_visit (tags, split, + restrs, vc_offs (n))] def check_split_induct_step_group (rep, restrs, hyps, split, tags = None): - checks = split_induct_step_checks (rep.p, restrs, hyps, split, - tags = tags) - groups = proof_check_groups (checks) - for group in groups: - (verdict, _) = test_hyp_group (rep, group) - if not verdict: - return False - return True + checks = split_induct_step_checks (rep.p, restrs, hyps, split, + tags = tags) + groups = proof_check_groups (checks) + for group in groups: + (verdict, _) = test_hyp_group (rep, group) + if not verdict: + return False + return True def split_checks (p, restrs, hyps, split, tags = None): - return (split_init_step_checks (p, restrs, hyps, split, tags = tags) - + split_induct_step_checks (p, restrs, hyps, split, tags = tags)) + return (split_init_step_checks (p, restrs, hyps, split, tags = tags) + + split_induct_step_checks (p, restrs, hyps, split, tags = tags)) def loop_eq_hyps_at_visit (tag, split, eqs, restrs, visit_num, - use_if_at = False): - details = (split, (0, 1), eqs) - visit = split_visit_one_visit (tag, details, restrs, visit_num) - start = split_visit_one_visit (tag, details, restrs, vc_num (0)) - - def mksub (v): - return lambda exp: logic.var_subst (exp, {('%i', word32T) : v}, - must_subst = False) - zsub = mksub (mk_word32 (0)) - if visit_num.kind == 'Number': - isub = mksub (mk_word32 (visit_num.n)) - else: - isub = mksub (mk_plus (mk_var ('%n', word32T), - mk_word32 (visit_num.n))) - - hyps = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)] - hyps += [(eq_hyp ((zsub (exp), start), (isub (exp), visit), - (split, 0), use_if_at = use_if_at), '%s const' % tag) - for exp in eqs if logic.inst_eq_at_visit (exp, visit_num)] - - return hyps + use_if_at = False): + details = (split, (0, 1), eqs) + visit = split_visit_one_visit (tag, details, restrs, visit_num) + start = split_visit_one_visit (tag, details, restrs, vc_num (0)) + + def mksub (v): + return lambda exp: logic.var_subst (exp, {('%i', syntax.arch.word_type) : v}, + must_subst = False) + zsub = mksub(syntax.arch.mk_word(0)) + if visit_num.kind == 'Number': + isub = mksub(syntax.arch.mk_word(visit_num.n)) + else: + isub = mksub (mk_plus (mk_var ('%n', syntax.arch.word_type), + syntax.arch.mk_word(visit_num.n))) + + hyps = [(Hyp ('PCImp', visit, start), '%s pc imp' % tag)] + hyps += [(eq_hyp ((zsub (exp), start), (isub (exp), visit), + (split, 0), use_if_at 
= use_if_at), '%s const' % tag) + for exp in eqs if logic.inst_eq_at_visit (exp, visit_num)] + + return hyps def single_induct_resulting_hyp (p, restrs, rev_induct_args): - (point, _, (pred, _)) = rev_induct_args - (tag, _) = p.node_tags[point] - vis = ((point, restrs + tuple ([(point, vc_num (0))])), tag) - return rep_graph.true_if_at_hyp (pred, vis) + (point, _, (pred, _)) = rev_induct_args + (tag, _) = p.node_tags[point] + vis = ((point, restrs + tuple ([(point, vc_num (0))])), tag) + return rep_graph.true_if_at_hyp (pred, vis) def single_loop_induct_base_checks (p, restrs, hyps, tag, split, n, eqs): - tests = [] - details = (split, (0, 1), eqs) - for i in range (n + 1): - reach = split_visit_one_visit (tag, details, restrs, vc_num (i)) - nhyps = [pc_true_hyp (reach)] - tests.extend ([(hyps + nhyps, hyp, - 'Base check (%s, %d) at induct step for %d' - % (desc, i, split)) - for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, - eqs, restrs, vc_num (i))]) - return tests + tests = [] + details = (split, (0, 1), eqs) + for i in range (n + 1): + reach = split_visit_one_visit (tag, details, restrs, vc_num (i)) + nhyps = [pc_true_hyp (reach)] + tests.extend ([(hyps + nhyps, hyp, + 'Base check (%s, %d) at induct step for %d' + % (desc, i, split)) + for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, + eqs, restrs, vc_num (i))]) + return tests def single_loop_induct_step_checks (p, restrs, hyps, tag, split, n, - eqs, eqs_assume = None): - if eqs_assume == None: - eqs_assume = [] - details = (split, (0, 1), eqs_assume + eqs) - cont = split_visit_one_visit (tag, details, restrs, vc_offs (n)) - hyps = ([pc_true_hyp (cont)] + hyps - + [h for i in range (n) - for (h, _) in loop_eq_hyps_at_visit (tag, split, - eqs_assume + eqs, restrs, vc_offs (i))]) - - return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' - % (desc, split)) - for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs, - restrs, vc_offs (n))] + eqs, eqs_assume = None): + if eqs_assume == None: + eqs_assume = [] + details = (split, (0, 1), eqs_assume + eqs) + cont = split_visit_one_visit (tag, details, restrs, vc_offs (n)) + hyps = ([pc_true_hyp (cont)] + hyps + + [h for i in range (n) + for (h, _) in loop_eq_hyps_at_visit (tag, split, + eqs_assume + eqs, restrs, vc_offs (i))]) + + return [(hyps, hyp, 'Induct check (%s) at inductive step for %d' + % (desc, split)) + for (hyp, desc) in loop_eq_hyps_at_visit (tag, split, eqs, + restrs, vc_offs (n))] def mk_loop_counter_eq_hyp (p, split, restrs, n): - details = (split, (0, 1), []) - (tag, _) = p.node_tags[split] - visit = split_visit_one_visit (tag, details, restrs, vc_offs (0)) - return eq_hyp ((mk_var ('%n', word32T), visit), - (mk_word32 (n), visit), (split, 0)) + details = (split, (0, 1), []) + (tag, _) = p.node_tags[split] + visit = split_visit_one_visit (tag, details, restrs, vc_offs (0)) + return eq_hyp ((mk_var ('%n', syntax.arch.word_type), visit), + (syntax.arch.mk_word(n), visit), (split, 0)) def single_loop_rev_induct_base_checks (p, restrs, hyps, tag, split, - n_bound, eqs_assume, pred): - details = (split, (0, 1), eqs_assume) - cont = split_visit_one_visit (tag, details, restrs, vc_offs (1)) - n_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound) + n_bound, eqs_assume, pred): + details = (split, (0, 1), eqs_assume) + cont = split_visit_one_visit (tag, details, restrs, vc_offs (1)) + n_hyp = mk_loop_counter_eq_hyp (p, split, restrs, n_bound) - split_details = (None, details, None, 1, 1) - non_err = split_r_err_pc_hyp (p, split_details, restrs) + 
split_details = (None, details, None, 1, 1) + non_err = split_r_err_pc_hyp (p, split_details, restrs) - hyps = (hyps + [n_hyp, pc_true_hyp (cont), non_err] - + [h for (h, _) in loop_eq_hyps_at_visit (tag, - split, eqs_assume, restrs, vc_offs (0))]) - goal = rep_graph.true_if_at_hyp (pred, cont) + hyps = (hyps + [n_hyp, pc_true_hyp (cont), non_err] + + [h for (h, _) in loop_eq_hyps_at_visit (tag, + split, eqs_assume, restrs, vc_offs (0))]) + goal = rep_graph.true_if_at_hyp (pred, cont) - return [(hyps, goal, 'Pred true at %d check.' % n_bound)] + return [(hyps, goal, 'Pred true at %d check.' % n_bound)] def single_loop_rev_induct_checks (p, restrs, hyps, tag, split, - eqs_assume, pred): - details = (split, (0, 1), eqs_assume) - curr = split_visit_one_visit (tag, details, restrs, vc_offs (1)) - cont = split_visit_one_visit (tag, details, restrs, vc_offs (2)) + eqs_assume, pred): + details = (split, (0, 1), eqs_assume) + curr = split_visit_one_visit (tag, details, restrs, vc_offs (1)) + cont = split_visit_one_visit (tag, details, restrs, vc_offs (2)) - split_details = (None, details, None, 1, 1) - non_err = split_r_err_pc_hyp (p, split_details, restrs) - true_next = rep_graph.true_if_at_hyp (pred, cont) + split_details = (None, details, None, 1, 1) + non_err = split_r_err_pc_hyp (p, split_details, restrs) + true_next = rep_graph.true_if_at_hyp (pred, cont) - hyps = (hyps + [pc_true_hyp (curr), true_next, non_err] - + [h for (h, _) in loop_eq_hyps_at_visit (tag, split, - eqs_assume, restrs, vc_offs (1), use_if_at = True)]) - goal = rep_graph.true_if_at_hyp (pred, curr) + hyps = (hyps + [pc_true_hyp (curr), true_next, non_err] + + [h for (h, _) in loop_eq_hyps_at_visit (tag, split, + eqs_assume, restrs, vc_offs (1), use_if_at = True)]) + goal = rep_graph.true_if_at_hyp (pred, curr) - return [(hyps, goal, 'Pred reverse step.')] + return [(hyps, goal, 'Pred reverse step.')] def all_rev_induct_checks (p, restrs, hyps, point, (eqs, n), (pred, n_bound)): - (tag, _) = p.node_tags[point] - checks = (single_loop_induct_step_checks (p, restrs, hyps, tag, - point, n, eqs) - + single_loop_induct_base_checks (p, restrs, hyps, tag, - point, n, eqs) - + single_loop_rev_induct_checks (p, restrs, hyps, tag, - point, eqs, pred) - + single_loop_rev_induct_base_checks (p, restrs, hyps, - tag, point, n_bound, eqs, pred)) - return checks + (tag, _) = p.node_tags[point] + checks = (single_loop_induct_step_checks (p, restrs, hyps, tag, + point, n, eqs) + + single_loop_induct_base_checks (p, restrs, hyps, tag, + point, n, eqs) + + single_loop_rev_induct_checks (p, restrs, hyps, tag, + point, eqs, pred) + + single_loop_rev_induct_base_checks (p, restrs, hyps, + tag, point, n_bound, eqs, pred)) + return checks def leaf_condition_checks (p, restrs, hyps): - '''checks of the final refinement conditions''' - nrerr_pc_hyp = non_r_err_pc_hyp (p.pairing.tags, restrs) - hyps = [nrerr_pc_hyp] + hyps - [l_tag, r_tag] = p.pairing.tags - - nlerr_pc = pc_false_hyp ((('Err', restrs), l_tag)) - # this 'hypothesis' ensures that the representation is built all - # the way to Ret. in particular this ensures that function relations - # are available to use in proving single-side equalities - ret_eq = eq_hyp ((true_term, (('Ret', restrs), l_tag)), - (true_term, (('Ret', restrs), r_tag))) - - ### TODO: previously we considered the case where 'Ret' was unreachable - ### (as a result of unsatisfiable hyps) and proved a simpler property. 
- ### we might want to restore this - (_, out_eqs) = p.pairing.eqs - checks = [(hyps + [nlerr_pc, ret_eq], hyp, 'Leaf eq check') for hyp in - inst_eqs (p, restrs, out_eqs)] - return [(hyps + [ret_eq], nlerr_pc, 'Leaf path-cond imp')] + checks + '''checks of the final refinement conditions''' + nrerr_pc_hyp = non_r_err_pc_hyp (p.pairing.tags, restrs) + hyps = [nrerr_pc_hyp] + hyps + [l_tag, r_tag] = p.pairing.tags + + nlerr_pc = pc_false_hyp ((('Err', restrs), l_tag)) + # this 'hypothesis' ensures that the representation is built all + # the way to Ret. in particular this ensures that function relations + # are available to use in proving single-side equalities + ret_eq = eq_hyp ((true_term, (('Ret', restrs), l_tag)), + (true_term, (('Ret', restrs), r_tag))) + + ### TODO: previously we considered the case where 'Ret' was unreachable + ### (as a result of unsatisfiable hyps) and proved a simpler property. + ### we might want to restore this + (_, out_eqs) = p.pairing.eqs + checks = [(hyps + [nlerr_pc, ret_eq], hyp, 'Leaf eq check') for hyp in + inst_eqs (p, restrs, out_eqs)] + return [(hyps + [ret_eq], nlerr_pc, 'Leaf path-cond imp')] + checks def proof_checks (p, proof): - return proof_checks_rec (p, (), init_point_hyps (p), proof, 'root') + return proof_checks_rec (p, (), init_point_hyps (p), proof, 'root') def proof_checks_imm (p, restrs, hyps, proof, path): - if proof.kind == 'Restr': - checks = proof_restr_checks (proof.point, proof.restr_range, - p, restrs, hyps) - elif proof.kind == 'SingleRevInduct': - checks = all_rev_induct_checks (p, restrs, hyps, proof.point, - proof.eqs_proof, proof.rev_proof) - elif proof.kind == 'Split': - checks = split_checks (p, restrs, hyps, proof.split) - elif proof.kind == 'Leaf': - checks = leaf_condition_checks (p, restrs, hyps) - elif proof.kind == 'CaseSplit': - checks = [] - - return [(hs, hyp, '%s on %s' % (name, path)) - for (hs, hyp, name) in checks] + if proof.kind == 'Restr': + checks = proof_restr_checks (proof.point, proof.restr_range, + p, restrs, hyps) + elif proof.kind == 'SingleRevInduct': + checks = all_rev_induct_checks (p, restrs, hyps, proof.point, + proof.eqs_proof, proof.rev_proof) + elif proof.kind == 'Split': + checks = split_checks (p, restrs, hyps, proof.split) + elif proof.kind == 'Leaf': + checks = leaf_condition_checks (p, restrs, hyps) + elif proof.kind == 'CaseSplit': + checks = [] + + return [(hs, hyp, '%s on %s' % (name, path)) + for (hs, hyp, name) in checks] def proof_checks_rec (p, restrs, hyps, proof, path): - checks = proof_checks_imm (p, restrs, hyps, proof, path) + checks = proof_checks_imm (p, restrs, hyps, proof, path) - subproblems = proof_subproblems (p, proof.kind, - proof.args, restrs, hyps, path) - for (subprob, subproof) in logic.azip (subproblems, proof.subproofs): - (restrs, hyps, path) = subprob - checks.extend (proof_checks_rec (p, restrs, hyps, subproof, path)) - return checks + subproblems = proof_subproblems (p, proof.kind, + proof.args, restrs, hyps, path) + for (subprob, subproof) in logic.azip (subproblems, proof.subproofs): + (restrs, hyps, path) = subprob + checks.extend (proof_checks_rec (p, restrs, hyps, subproof, path)) + return checks last_failed_check = [None] + def proof_check_groups (checks): - groups = {} - for (hyps, hyp, name) in checks: - n_vcs = set ([n_vc for hyp2 in [hyp] + hyps - for n_vc in hyp2.visits ()]) - k = (tuple (sorted (list (n_vcs)))) - groups.setdefault (k, []).append ((hyps, hyp, name)) - return groups.values () + groups = {} + group_keys = [] + for (hyps, hyp, name) 
in checks: + n_vcs = set ([n_vc for hyp2 in [hyp] + hyps + for n_vc in hyp2.visits ()]) + k = (tuple (sorted (list (n_vcs)))) + group_values = groups.setdefault(k, []) + if not group_values: + group_keys.append(k) + group_values.append((hyps, hyp, name)) + return [groups[k] for k in group_keys] + def test_hyp_group (rep, group, detail = None): - imps = [(hyps, hyp) for (hyps, hyp, _) in group] - names = set ([name for (_, _, name) in group]) - - trace ('Testing group of hyps: %s' % list (names), push = 1) - (res, i, res_kind) = rep.test_hyp_imps (imps) - trace ('Group result: %r' % res, push = -1) - if res: - return (res, None) - else: - if detail: - detail[0] = res_kind - return (res, group[i]) + imps = [(hyps, hyp) for (hyps, hyp, _) in group] + names = set ([name for (_, _, name) in group]) + + trace ('Testing group of hyps: %s' % list (names), push = 1) + (res, i, res_kind) = rep.test_hyp_imps (imps) + trace ('Group result: %r' % res, push = -1) + if res: + return (res, None) + else: + if detail: + detail[0] = res_kind + return (res, group[i]) def failed_test_sets (p, checks): - failed = [] - sets = {} - for (hyps, hyp, name) in checks: - sets.setdefault (name, []) - sets[name].append ((hyps, hyp)) - for name in sets: - rep = rep_graph.mk_graph_slice (p) - (res, _, _) = rep.test_hyp_imps (sets[name]) - if not res: - failed.append (name) - return failed + failed = [] + sets = {} + for (hyps, hyp, name) in checks: + sets.setdefault (name, []) + sets[name].append ((hyps, hyp)) + for name in sets: + rep = rep_graph.mk_graph_slice (p) + (res, _, _) = rep.test_hyp_imps (sets[name]) + if not res: + failed.append (name) + return failed save_checked_proofs = [None] def check_proof (p, proof, use_rep = None): - checks = proof_checks (p, proof) - groups = proof_check_groups (checks) - - for group in groups: - if use_rep == None: - rep = rep_graph.mk_graph_slice (p) - else: - rep = use_rep - - detail = [0] - (verdict, elt) = test_hyp_group (rep, group, detail) - if verdict: - continue - (hyps, hyp, name) = elt - last_failed_check[0] = elt - trace ('%s: proof failed!' % name) - trace (' (failure kind: %r)' % detail[0]) - return False - if save_checked_proofs[0]: - save = save_checked_proofs[0] - save (p, proof) - return True + checks = proof_checks (p, proof) + groups = proof_check_groups (checks) + + for group in groups: + if use_rep == None: + rep = rep_graph.mk_graph_slice (p) + else: + rep = use_rep + + detail = [0] + (verdict, elt) = test_hyp_group (rep, group, detail) + if verdict: + continue + (hyps, hyp, name) = elt + last_failed_check[0] = elt + trace ('%s: proof failed!' % name) + trace (' (failure kind: %r)' % detail[0]) + return False + if save_checked_proofs[0]: + save = save_checked_proofs[0] + save (p, proof) + return True def pretty_vseq ((split, (seq_start, seq_step), _)): - if (seq_start, seq_step) == (0, 1): - return 'visits to %d' % split - else: - i = seq_start + 1 - j = i + seq_step - k = j + seq_step - return 'visits [%d, %d, %d ...] to %d' % (i, j, k, split) + if (seq_start, seq_step) == (0, 1): + return 'visits to %d' % split + else: + i = seq_start + 1 + j = i + seq_step + k = j + seq_step + return 'visits [%d, %d, %d ...] 
to %d' % (i, j, k, split) def next_induct_var (n): - s = 'ijkabc' - v = s[n % 6] - if n >= 6: - v += str ((n / 6) + 1) - return v + s = 'ijkabc' + v = s[n % 6] + if n >= 6: + v += str ((n / 6) + 1) + return v def pretty_lambda (t): - v = syntax.mk_var ('#seq-visits', word32T) - t = logic.var_subst (t, {('%i', word32T) : v}, must_subst = False) - return syntax.pretty_expr (t, print_type = True) + v = syntax.mk_var ('#seq-visits', syntax.arch.word_type) + t = logic.var_subst (t, {('%i', syntax.arch.word_type) : v}, must_subst = False) + return syntax.pretty_expr (t, print_type = True) def check_proof_report_rec (p, restrs, hyps, proof, step_num, ctxt, inducts, - do_check = True): - printout ('Step %d: %s' % (step_num, ctxt)) - if proof.kind == 'Restr': - (kind, (x, y)) = proof.restr_range - if kind == 'Offset': - v = inducts[1][proof.point] - rexpr = '{%s + %s ..< %s + %s}' % (v, x, v, y) - else: - rexpr = '{%s ..< %s}' % (x, y) - printout (' Prove the number of visits to %d is in %s' - % (proof.point, rexpr)) - - checks = proof_restr_checks (proof.point, proof.restr_range, - p, restrs, hyps) - cases = [''] - elif proof.kind == 'SingleRevInduct': - printout (' Proving a predicate by future induction.') - (eqs, n) = proof.eqs_proof - point = proof.point - printout (' proving these invariants by %d-induction' % n) - for x in eqs: - printout (' %s (@ addr %s)' - % (pretty_lambda (x), point)) - printout (' then establishing this predicate') - (pred, n_bound) = proof.rev_proof - printout (' %s (@ addr %s)' - % (pretty_lambda (pred), point)) - printout (' at large iterations (%d) and by back induction.' - % n_bound) - cases = [''] - checks = all_rev_induct_checks (p, restrs, hyps, point, - proof.eqs_proof, proof.rev_proof) - elif proof.kind == 'Split': - (l_dts, r_dts, eqs, n, lrmx) = proof.split - v = next_induct_var (inducts[0]) - inducts = (inducts[0] + 1, dict (inducts[1])) - inducts[1][l_dts[0]] = v - inducts[1][r_dts[0]] = v - printout (' prove %s related to %s' % (pretty_vseq (l_dts), - pretty_vseq (r_dts))) - printout (' with equalities') - for (x, y) in eqs: - printout (' %s (@ addr %s)' % (pretty_lambda (x), - l_dts[0])) - printout (' = %s (@ addr %s)' % (pretty_lambda (y), - r_dts[0])) - printout (' and with invariants') - for x in l_dts[2]: - printout (' %s (@ addr %s)' - % (pretty_lambda (x), l_dts[0])) - for x in r_dts[2]: - printout (' %s (@ addr %s)' - % (pretty_lambda (x), r_dts[0])) - checks = split_checks (p, restrs, hyps, proof.split) - cases = ['case in (%d) where the length of the sequence < %d' - % (step_num, n), - 'case in (%d) where the length of the sequence is %s + %s' - % (step_num, v, n)] - elif proof.kind == 'Leaf': - printout (' prove all verification conditions') - checks = leaf_condition_checks (p, restrs, hyps) - cases = [] - elif proof.kind == 'CaseSplit': - printout (' case split on whether %d is visited' % proof.point) - checks = [] - cases = ['case in (%d) where %d is visited' % (step_num, proof.point), - 'case in (%d) where %d is not visited' % (step_num, proof.point)] - - if checks and do_check: - groups = proof_check_groups (checks) - for group in groups: - rep = rep_graph.mk_graph_slice (p) - detail = [0] - (res, _) = test_hyp_group (rep, group, detail) - if not res: - printout (' .. failed to prove this.') - printout (' (failure kind: %r)' % detail[0]) - return - - printout (' .. 
proven.') - - subproblems = proof_subproblems (p, proof.kind, - proof.args, restrs, hyps, '') - xs = logic.azip (subproblems, proof.subproofs) - xs = logic.azip (xs, cases) - step_num += 1 - for ((subprob, subproof), case) in xs: - (restrs, hyps, _) = subprob - res = check_proof_report_rec (p, restrs, hyps, subproof, - step_num, case, inducts, do_check = do_check) - if not res: - return - (step_num, induct_var_num) = res - inducts = (induct_var_num, inducts[1]) - return (step_num, inducts[0]) + do_check = True): + printout ('Step %d: %s' % (step_num, ctxt)) + if proof.kind == 'Restr': + (kind, (x, y)) = proof.restr_range + if kind == 'Offset': + v = inducts[1][proof.point] + rexpr = '{%s + %s ..< %s + %s}' % (v, x, v, y) + else: + rexpr = '{%s ..< %s}' % (x, y) + printout (' Prove the number of visits to %d is in %s' + % (proof.point, rexpr)) + + checks = proof_restr_checks (proof.point, proof.restr_range, + p, restrs, hyps) + cases = [''] + elif proof.kind == 'SingleRevInduct': + printout (' Proving a predicate by future induction.') + (eqs, n) = proof.eqs_proof + point = proof.point + printout (' proving these invariants by %d-induction' % n) + for x in eqs: + printout (' %s (@ addr %s)' + % (pretty_lambda (x), point)) + printout (' then establishing this predicate') + (pred, n_bound) = proof.rev_proof + printout (' %s (@ addr %s)' + % (pretty_lambda (pred), point)) + printout (' at large iterations (%d) and by back induction.' + % n_bound) + cases = [''] + checks = all_rev_induct_checks (p, restrs, hyps, point, + proof.eqs_proof, proof.rev_proof) + elif proof.kind == 'Split': + (l_dts, r_dts, eqs, n, lrmx) = proof.split + v = next_induct_var (inducts[0]) + inducts = (inducts[0] + 1, dict (inducts[1])) + inducts[1][l_dts[0]] = v + inducts[1][r_dts[0]] = v + printout (' prove %s related to %s' % (pretty_vseq (l_dts), + pretty_vseq (r_dts))) + printout (' with equalities') + for (x, y) in eqs: + printout (' %s (@ addr %s)' % (pretty_lambda (x), + l_dts[0])) + printout (' = %s (@ addr %s)' % (pretty_lambda (y), + r_dts[0])) + printout (' and with invariants') + for x in l_dts[2]: + printout (' %s (@ addr %s)' + % (pretty_lambda (x), l_dts[0])) + for x in r_dts[2]: + printout (' %s (@ addr %s)' + % (pretty_lambda (x), r_dts[0])) + checks = split_checks (p, restrs, hyps, proof.split) + cases = ['case in (%d) where the length of the sequence < %d' + % (step_num, n), + 'case in (%d) where the length of the sequence is %s + %s' + % (step_num, v, n)] + elif proof.kind == 'Leaf': + printout (' prove all verification conditions') + checks = leaf_condition_checks (p, restrs, hyps) + cases = [] + elif proof.kind == 'CaseSplit': + printout (' case split on whether %d is visited' % proof.point) + checks = [] + cases = ['case in (%d) where %d is visited' % (step_num, proof.point), + 'case in (%d) where %d is not visited' % (step_num, proof.point)] + + if checks and do_check: + groups = proof_check_groups (checks) + + for group in groups: + rep = rep_graph.mk_graph_slice (p) + detail = [0] + (res, _) = test_hyp_group (rep, group, detail) + if not res: + printout (' .. failed to prove this.') + printout (' (failure kind: %r)' % detail[0]) + return + + printout (' .. 
proven.') + + subproblems = proof_subproblems (p, proof.kind, + proof.args, restrs, hyps, '') + xs = logic.azip (subproblems, proof.subproofs) + xs = logic.azip (xs, cases) + step_num += 1 + for ((subprob, subproof), case) in xs: + (restrs, hyps, _) = subprob + res = check_proof_report_rec (p, restrs, hyps, subproof, + step_num, case, inducts, do_check = do_check) + if not res: + return + (step_num, induct_var_num) = res + inducts = (induct_var_num, inducts[1]) + return (step_num, inducts[0]) def check_proof_report (p, proof, do_check = True): - res = check_proof_report_rec (p, (), init_point_hyps (p), proof, - 1, '', (0, {}), do_check = do_check) - return bool (res) + res = check_proof_report_rec (p, (), init_point_hyps (p), proof, + 1, '', (0, {}), do_check = do_check) + return bool (res) def save_proofs_to_file (fname, mode = 'w'): - assert mode in ['w', 'a'] - f = open (fname, mode) - - def save (p, proof): - f.write ('ProblemProof (%s) {\n' % p.name) - for s in p.serialise (): - f.write (s + '\n') - ss = [] - proof.serialise (p, ss) - f.write (' '.join (ss)) - f.write ('\n}\n') - f.flush () - return save + assert mode in ['w', 'a'] + f = open (fname, mode) + + def save (p, proof): + f.write ('ProblemProof (%s) {\n' % p.name) + for s in p.serialise (): + f.write (s + '\n') + ss = [] + proof.serialise (p, ss) + f.write (' '.join (ss)) + f.write ('\n}\n') + f.flush () + return save def load_proofs_from_file (fname): - f = open (fname) - - proofs = {} - lines = None - for line in f: - line = line.strip () - if line.startswith ('ProblemProof'): - assert line.endswith ('{'), line - name_bit = line[len ('ProblemProof') : -1].strip () - assert name_bit.startswith ('('), name_bit - assert name_bit.endswith (')'), name_bit - name = name_bit[1:-1] - lines = [] - elif line == '}': - assert lines[0] == 'Problem' - assert lines[-2] == 'EndProblem' - import problem - trace ('loading proof from %d lines' % len (lines)) - p = problem.deserialise (name, lines[:-1]) - proof = deserialise (lines[-1]) - proofs.setdefault (name, []) - proofs[name].append ((p, proof)) - trace ('loaded proof %s' % name) - lines = None - elif line.startswith ('#'): - pass - elif line: - lines.append (line) - assert not lines - return proofs + f = open (fname) + + proofs = {} + lines = None + for line in f: + line = line.strip () + if line.startswith ('ProblemProof'): + assert line.endswith ('{'), line + name_bit = line[len ('ProblemProof') : -1].strip () + assert name_bit.startswith ('('), name_bit + assert name_bit.endswith (')'), name_bit + name = name_bit[1:-1] + lines = [] + elif line == '}': + assert lines[0] == 'Problem' + assert lines[-2] == 'EndProblem' + import problem + trace ('loading proof from %d lines' % len (lines)) + p = problem.deserialise (name, lines[:-1]) + proof = deserialise (lines[-1]) + proofs.setdefault (name, []) + proofs[name].append ((p, proof)) + trace ('loaded proof %s' % name) + lines = None + elif line.startswith ('#'): + pass + elif line: + lines.append (line) + assert not lines + return proofs diff --git a/ci/dir_hash.py b/ci/dir_hash.py new file mode 100755 index 00000000..a7b91db2 --- /dev/null +++ b/ci/dir_hash.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +# Copyright 2022, Kry10 Limited. +# SPDX-License-Identifier: BSD-2-Clause + +# Calculate a hash of a directory tree, using file and directory names and +# contents. Used to uniquely identify graph-refine jobs in CI. 
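+#
+# A usage sketch (the path is illustrative; the output is a single line):
+#
+#   $ ci/dir_hash.py path/to/some/tree
+#   ...prints a 32-character lowercase base32 digest of the tree...
+#
+# The hash covers entry names, regular file contents and symlink targets,
+# so renaming, adding or editing any entry changes the result.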
+ +import os +import sys + +from base64 import b32encode +from hashlib import blake2b +from pathlib import Path +from typing import Sequence + + +def dir_entry_name(entry: os.DirEntry) -> str: + return entry.name + + +def hash_dir(path: Path) -> bytes: + hasher = blake2b(person=b'dir contents') + for entry in sorted(os.scandir(path), key=dir_entry_name): + hasher.update(blake2b(entry.name.encode(), person=b'dir entry').digest()) + hasher.update(hash_dir_entry(path, entry)) + return hasher.digest() + + +def hash_file(path: Path) -> bytes: + hasher = blake2b(person=b'file contents') + with open(path, 'rb') as file: + while chunk := file.read(2 ** 10): + hasher.update(chunk) + return hasher.digest() + + +def hash_symlink(path: Path) -> bytes: + return blake2b(os.readlink(path).encode(), person=b'symlink target').digest() + + +class UnknownEntryType(Exception): + pass + + +def hash_dir_entry(parent_path: Path, entry: os.DirEntry) -> bytes: + path = parent_path / entry.name + if entry.is_symlink(): + return hash_symlink(path) + if entry.is_dir(): + return hash_dir(path) + if entry.is_file(): + return hash_file(path) + raise UnknownEntryType(path) + + +def hash_path(path: Path) -> bytes: + if path.is_symlink(): + return hash_symlink(path) + if path.is_dir(): + return hash_dir(path) + if path.is_file(): + return hash_file(path) + raise FileNotFoundError(path) + + +def hash_path_b32(path: str) -> str: + path_hash = blake2b(hash_path(Path(path)), digest_size=20).digest() + return b32encode(path_hash).decode().lower().rstrip('=') + + +def main(args: Sequence[str]) -> None: + [_, path] = args + print(hash_path_b32(path)) + + +if __name__ == '__main__': + main(sys.argv) diff --git a/ci/github-prepare-decompile.py b/ci/github-prepare-decompile.py new file mode 100755 index 00000000..c40a2e27 --- /dev/null +++ b/ci/github-prepare-decompile.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +import json +import os +import shutil +import sys + +from lxml import etree # type: ignore +from pathlib import Path +from typing import Any + + +supported_targets = ( + 'ARM-O1', + 'ARM-O2', + 'ARM-MCS-O1', + 'ARM-MCS-O2', + 'RISCV64-O1', + 'RISCV64-MCS-O1', +) + + +def workflow_url(repo: str, run: str) -> str: + return f'https://github.com/{repo}/actions/runs/{run}' + + +artifacts_input_dir = Path(sys.argv[1]).resolve() +job_output_dir = Path(sys.argv[2]).resolve() + +targets: dict[str, Any] = {} + +for target in supported_targets: + artifact_path = artifacts_input_dir / target + print(f'Checking artifact_path {artifact_path}') + if (artifact_path / 'CFunctions.txt').is_file(): + print('Found CFunctions.txt') + target_path = job_output_dir / 'targets' / target + target_path.mkdir(parents=True) + print(f'Target path {target_path}') + shutil.move(artifact_path, target_path / 'target') + + def get_var(line: str) -> tuple[str, str]: + parts = [s.strip() for s in line.split(sep='=', maxsplit=1)] + assert len(parts) == 2 + return parts[0], parts[1] + + with open(target_path / 'target' / 'config.env') as config_env: + config = dict(get_var(line) for line in config_env) + + with open(target_path / 'target' / 'manifest.xml') as manifest_xml: + manifest = etree.parse(manifest_xml) + versions = {'seL4': manifest.xpath(f'string(//project[@name="seL4"]/@revision)'), + 'l4v': manifest.xpath(f'string(//project[@name="l4v"]/@revision)')} + + targets[target] = {'config': config, 'versions': versions} + +if targets: + with open(job_output_dir / 'job_info.json', 'w') as job_info_json: + github_info = {'tag': os.environ['TAG'], + 'proof': 
{'repo': os.environ['PROOF_REPO'], + 'run': os.environ['PROOF_RUN']}, + 'decompile': {'repo': os.environ['DECOMPILE_REPO'], + 'run': os.environ['DECOMPILE_RUN']}} + job_info = {'targets': targets, 'github': github_info} + json.dump(job_info, job_info_json, separators=(',', ':')) + +with open(os.environ['GITHUB_OUTPUT'], 'a') as github_output: + enabled_json = json.dumps(list(targets), separators=(',', ':')) + print(f'targets_enabled={enabled_json}', file=github_output) diff --git a/ci/install-runner b/ci/install-runner new file mode 100755 index 00000000..4994a1d2 --- /dev/null +++ b/ci/install-runner @@ -0,0 +1,144 @@ +#!/bin/bash + +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Install the latest graph-refine runner on the graph-refine back end. +# +# This script is intended to be run in the GitHub CI workflow +# that builds graph-refine Docker images. When it builds a new +# graph-refine-runner image, the workflow uses this script to +# tell the backend about the new version, and to start a runner +# with the new version. The new version will take over from any +# currently running version. +# +# It requires arguments: +# +# It requires environment variables: +# - BV_BACKEND_WORK_DIR: Path on the remote (back-end) host of the +# graph-refine work directory, relative to the SSH home directory. +# - BV_BACKEND_CONCURRENCY: Number of concurrent function analyses to run +# on the back-end. +# - BV_SSH_CONFIG: Contents of an SSH config file that uses the name +# `graph-refine` for the remote (back-end) host. +# - BV_SSH_KEY: Private key with access to a user on the `graph-refine` host. +# - BV_SSH_KNOWN_HOSTS: Contents of an SSH known hosts file suitable for +# accessing the `graph-refine` host. +# - DOCKER_RUN_COMMAND: Command to use in place of `docker run`, +# e.g. `podman run --memory 20g`. +# - RUNNER_TAG: Unique tag of the new image. +# - RUNNER_TEMP: Path to a local temporary directory. +# +# The BV_BACKEND_WORK_DIR is assumed to follow the same structure as used +# in the parallel job runner (see runner.py). + + +set -euo pipefail + +if [ $# -ne 0 ]; then + echo "install-runner: error: unexpected arguments" >&2 + exit 1 +fi + +if [ -z ${RUNNER_TAG:+x} ]; then + echo "install-runner: error: RUNNER_TAG not set" >&2 + exit 1 +fi + +CI_TMP="$(mktemp -d -p "${RUNNER_TEMP}")" +cleanup() { rm -rf "${CI_TMP}"; } +trap cleanup EXIT + +# Build an SSH config for reaching the back end. +SSH_DIR="${CI_TMP}/ssh" +mkdir "${SSH_DIR}" +touch "${SSH_DIR}/ssh_key" +chmod 0700 "${SSH_DIR}/ssh_key" +cat > "${SSH_DIR}/ssh_key" <<< "${BV_SSH_KEY}" +cat > "${SSH_DIR}/ssh_known_hosts" <<< "${BV_SSH_KNOWN_HOSTS}" + +# BV_SSH_CONFIG should define how to connect to a host named `graph-refine`. +# For example, it might contain something like: +# Host graph-refine +# Hostname real-hostname.example.org +# User bv +# It may also contain configuration for any required jump hosts. +cat > "${SSH_DIR}/ssh_config" < "${CI_TMP}/ci-install" <> "${CI_TMP}/ci-install" <<'EOF' +mkdir -p "${WORK_DIR}/private" + +# We follow the back end's locking protocol. See the +# comments in runner.py for details. + +# Set the allowed back-end concurrency. + +CPUS_ALLOWED_FILE="${WORK_DIR}/private/cpus_allowed.txt" + +exec 15<> "${CPUS_ALLOWED_FILE}.lock" +if ! 
flock -w 30 15; then + echo "Failed to lock cpus_allowed.txt" >&2 + exit 1 +fi + +echo "${CONCURRENCY}" > "${CPUS_ALLOWED_FILE}.tmp" +mv "${CPUS_ALLOWED_FILE}.tmp" "${CPUS_ALLOWED_FILE}" + +flock -u 15 + +# Installing the new version, by writing the image tag to a file +# which determines which image is used on the back end. +# This is effectively a no-op if the build is identical to the +# previously installed version. + +ID_FILE="${WORK_DIR}/private/active_runner_id.txt" + +exec 15<> "${ID_FILE}.lock" +if ! flock -w 30 15; then + echo "Failed to lock active_runner_id.txt" >&2 + exit 1 +fi + +echo "${RUNNER_TAG}" > "${ID_FILE}.tmp" +mv "${ID_FILE}.tmp" "${ID_FILE}" + +flock -u 15 + +# Start a runner, whether or not this is really a new version. +# This helps maintain runner uptime, without having to configure +# the runner as an operating system service on the back end. + +${DOCKER_RUN_COMMAND} --init -d \ + --mount "type=bind,src=${HOME}/${WORK_DIR},dst=/work" \ + "ghcr.io/sel4/graph-refine-runner:${RUNNER_TAG}" \ + --id "${RUNNER_TAG}" \ + --work /work +EOF + +bv_ssh() { ssh -F "${SSH_DIR}/ssh_config" graph-refine "$@"; } +bv_ssh "$(cat "${CI_TMP}/ci-install")" diff --git a/ci/runner.py b/ci/runner.py new file mode 100755 index 00000000..ddda2fcc --- /dev/null +++ b/ci/runner.py @@ -0,0 +1,1562 @@ +#!/usr/bin/env python3 + +import argparse +import enum +import fcntl +import json +import os +import queue +import random +import re +import shutil +import signal +import sqlite3 +import string +import subprocess +import sys +import threading +import time + +from base64 import b32encode +from contextlib import contextmanager, redirect_stdout, redirect_stderr +from dataclasses import dataclass +from datetime import datetime, timezone +from hashlib import blake2b +from pathlib import Path +from typing import (Any, Callable, ContextManager, Iterable, Iterator, Mapping, NamedTuple, + Optional, Protocol, Sequence, TypeVar) + + +K = TypeVar('K') +T = TypeVar('T') +U = TypeVar('U') +V = TypeVar('V') + + +def now_utc() -> datetime: + return datetime.now(tz=timezone.utc) + + +# Date/time format used by SQLite. +def sqlite_timestamp(timestamp: datetime) -> str: + return f'{timestamp:%Y-%m-%dT%H:%M:%SZ}' + + +dev_random = random.SystemRandom() +rand_chars = string.ascii_lowercase + string.digits + + +# Unique IDs for naming runner instances, and function analysis result directories. +def mk_unique_id(timestamp: Optional[datetime] = None) -> str: + if timestamp is None: + timestamp = now_utc() + extra = ''.join(dev_random.choice(rand_chars) for i in range(20)) + return f'{timestamp:%Y-%m-%d-%H-%M-%S}-{extra}' + + +def write_log(msg: str, *msgs: str) -> None: + date_fmt = "%Y-%m-%d %H:%M:%S" + timestamp = f'{datetime.now(tz=timezone.utc):%Y-%m-%d %H:%M:%S}' + print(f'{timestamp}: {msg}') + for msg in msgs: + print(f'{timestamp}: {msg}') + sys.stdout.flush() + + +# Filesyste layout +# ---------------- +# +# WORK_DIR/ +# +# public/ +# jobs.json - list of running, waiting and recently completed jobs. +# jobs/ +# JOB_ID/ +# targets.json - summary status for each target. +# targets/ +# TARGET_ID/ +# target/ +# functions-list.txt, ... +# functions.json - status of each task in the target. 
+# functions/ +# FUNCTION_NAME/ +# UNIQUE_RUN_ID/ +# report.txt +# log.txt +# exit_code.txt +# +# runners/ +# RUNNER_ID/ +# INSTANCE_ID.log +# +# smt/ +# +# private/ +# new/ +# jobs.db +# +# cpus_allowed.txt +# active_runner_id.txt +# active_instance.txt +# +# cpus/ +# CPU_ID/ +# .lock +# +# runners/ +# RUNNER_ID/ +# .lock + + +# A task is an individual graph-refine function analysis, +# identified by a triple of job ID, target name and function name. +class Task(NamedTuple): + job_id: str + target: str + function: str + + +# A simple SQLite database for tracking jobs and tasks. +# +# A job is a collection of targets. +# A target specifies a combination of architecture, seL4 features and optimisation level. +# Each function within a target gives rise to a task. + +class WaitingJob(NamedTuple): + job_id: str + time_job_submitted: str + + +class EnqueueTask(NamedTuple): + task: Task + priority: int + + +GetTasksForJob = Callable[[str], Iterable[EnqueueTask]] +EnqueuePriority = Callable[[Task], int] + + +class TaskRun(NamedTuple): + job_id: str + target: str + function: str + run_id: str + + +class JobData(NamedTuple): + time_job_submitted: str + time_job_started: Optional[str] + time_job_finished: Optional[str] + + +class TaskAssigned(NamedTuple): + runner_id: str + run_id: str + + +class TaskUnassigned(NamedTuple): + pass + + +TaskAssignment = TaskAssigned | TaskUnassigned + + +class RunData(NamedTuple): + time_run_started: str + time_run_finished: Optional[str] + result: Optional[str] + + +class TaskData(NamedTuple): + assignment: Optional[TaskAssignment] + runs: Mapping[TaskAssigned, RunData] + + +# job_id -> target -> T +TargetMap = Mapping[str, Mapping[str, T]] + +# job_id -> JobData +DirtyJobs = Mapping[str, JobData] + +# job_id -> target -> status -> count +DirtyTargets = TargetMap[Mapping[str, int]] + +# job_id -> target -> function -> TaskData +DirtyTasks = TargetMap[Mapping[str, TaskData]] + + +class DirtyDbState(NamedTuple): + jobs: DirtyJobs + targets: DirtyTargets + tasks: DirtyTasks + clean: Callable[[], None] + + +@enum.unique +class TaskResult(enum.Enum): + PASSED = enum.auto() + FAILED = enum.auto() + NO_SPLIT = enum.auto() + EXCEPT = enum.auto() + MALFORMED = enum.auto() + UNDERSPECIFIED = enum.auto() + COMPLEX_LOOP = enum.auto() + NO_RESULT = enum.auto() + IMPOSSIBLE = enum.auto() + TIMEOUT = enum.auto() + KILLED = enum.auto() + + +passing_task_results = { + TaskResult.PASSED, + TaskResult.UNDERSPECIFIED, + TaskResult.COMPLEX_LOOP, + TaskResult.IMPOSSIBLE, +} + + +class JobsDB(NamedTuple): + conn: sqlite3.Connection + get_tasks: GetTasksForJob + + def query(self, sql: str, parameters=()) -> ContextManager[sqlite3.Cursor]: + @contextmanager + def impl() -> Iterator[sqlite3.Cursor]: + cursor = self.conn.execute(sql, parameters) + try: + yield cursor + finally: + cursor.close() + + return impl() + + def execute(self, sql: str, parameters=()) -> None: + self.conn.execute(sql, parameters).close() + + def executemany(self, sql: str, parameters) -> None: + self.conn.executemany(sql, parameters).close() + + # Schema versioning and migration is a future problem. 
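+
+    # A sketch of how a JobsDB is typically constructed and used, mirroring
+    # Runner.jobs_db and Runner.start_tasks later in this file (the variable
+    # names here are illustrative):
+    #
+    #     conn = sqlite3.connect(work_dir / 'private' / 'jobs.db')
+    #     conn.row_factory = sqlite3.Row
+    #     db = JobsDB(conn=conn, get_tasks=get_tasks_for_job(jobs_path))
+    #     db.initialise()
+    #     tasks = db.assign_tasks(runner_id, wanted)
+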
+ def initialise(self) -> None: + self.conn.executescript(""" + BEGIN; + + CREATE TABLE IF NOT EXISTS + jobs (job_id TEXT PRIMARY KEY NOT NULL, + time_job_submitted TEXT NOT NULL, + time_job_started TEXT, + time_job_finished TEXT, + is_job_dirty INTEGER NOT NULL) + STRICT; + + CREATE TABLE IF NOT EXISTS + tasks (job_id TEXT NOT NULL, + target TEXT NOT NULL, + function TEXT NOT NULL, + priority INTEGER NOT NULL, + runner_id TEXT, + run_id TEXT, + is_task_dirty INTEGER NOT NULL, + PRIMARY KEY (job_id, target, function)) + STRICT; + + CREATE TABLE IF NOT EXISTS + runs (job_id TEXT NOT NULL, + target TEXT NOT NULL, + function TEXT NOT NULL, + runner_id TEXT NOT NULL, + run_id TEXT NOT NULL, + time_run_started TEXT NOT NULL, + time_run_finished TEXT, + result TEXT, + is_run_dirty INTEGER NOT NULL, + PRIMARY KEY (job_id, target, function, runner_id, run_id)) + STRICT; + + -- Used by enqueue_tasks + CREATE INDEX IF NOT EXISTS + jobs_waiting_by_submit_time + ON jobs (unixepoch(time_job_submitted)) + WHERE time_job_started IS NULL + AND time_job_finished IS NULL; + + -- Used by try_assign_tasks + CREATE INDEX IF NOT EXISTS + tasks_unassigned_by_priority + ON tasks (priority) + WHERE runner_id IS NULL + OR run_id IS NULL; + + -- Used by unassign_tasks + CREATE INDEX IF NOT EXISTS + tasks_assigned_by_runner_id + ON tasks (runner_id) + WHERE runner_id IS NOT NULL + AND run_id IS NOT NULL; + + -- USED by export_dirty_state + CREATE INDEX IF NOT EXISTS + jobs_dirty + ON jobs (job_id) + WHERE is_job_dirty = 1; + + -- USED by export_dirty_state + CREATE INDEX IF NOT EXISTS + tasks_dirty + ON tasks (job_id, target, function) + WHERE is_task_dirty = 1; + + -- USED by export_dirty_state + CREATE INDEX IF NOT EXISTS + runs_dirty + ON runs (job_id, target, function, runner_id, run_id) + WHERE is_run_dirty = 1; + + COMMIT; + """) + + # Move jobs from the `new` directory into the database. + def add_waiting_jobs(self, jobs: Sequence[WaitingJob]) -> None: + with self.conn: + self.execute("BEGIN") + insert_job = """ + INSERT INTO jobs(job_id, time_job_submitted, is_job_dirty) + VALUES (:job_id, :time_job_submitted, 1) + ON CONFLICT DO NOTHING + """ + self.executemany(insert_job, (job._asdict() for job in jobs)) + self.execute("COMMIT") + + # Enqueue tasks from waiting jobs, until either `wanted` tasks + # have been enqueued, or there are no more waiting jobs. 
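+    # Note that jobs are enqueued whole: if the next waiting job contains
+    # more tasks than are still wanted, all of its tasks are enqueued anyway,
+    # so the queue may end up slightly longer than requested.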
+ def enqueue_tasks(self, wanted: int, timestamp: str) -> None: + tasks_enqueued = 0 + + def count_tasks(iter: Iterable[T]) -> Iterator[T]: + nonlocal tasks_enqueued + for i in iter: + tasks_enqueued += 1 + yield i + + with self.conn: + while tasks_enqueued < wanted: + self.execute("BEGIN") + next_job_query = """ + SELECT job_id + FROM jobs + WHERE time_job_started IS NULL + AND time_job_finished IS NULL + ORDER BY unixepoch(time_job_submitted) ASC + LIMIT 1 + """ + with self.query(next_job_query) as query: + job_rows = query.fetchall() + if not job_rows: + break + job_id = job_rows[0]["job_id"] + + insert_task = """ + INSERT INTO tasks(job_id, target, function, priority, is_task_dirty) + VALUES (:job_id, :target, :function, :priority, 1) + ON CONFLICT DO NOTHING + """ + task_rows = ({**enqueue.task._asdict(), "priority": enqueue.priority} + for enqueue in count_tasks(self.get_tasks(job_id))) + self.executemany(insert_task, task_rows) + + update_job_status = """ + UPDATE jobs + SET time_job_started = :time_job_started, + is_job_dirty = 1 + WHERE job_id = :job_id + """ + args = {"job_id": job_id, "time_job_started": timestamp} + self.execute(update_job_status, args) + self.execute("COMMIT") + + # Assign tasks to a runner. + def try_assign_tasks(self, *, runner_id: str, run_id: str, wanted: int, + timestamp: str) -> list[TaskRun]: + + def mk_task(row) -> TaskRun: + return TaskRun(job_id=row["job_id"], + target=row["target"], + function=row["function"], + run_id=run_id) + + def mk_run_row(task: TaskRun) -> dict[str, str]: + return {**task._asdict(), "runner_id": runner_id, "time_run_started": timestamp} + + with self.conn: + self.execute("BEGIN") + assign_tasks_query = """ + UPDATE tasks + SET runner_id = :runner_id, + run_id = :run_id, + is_task_dirty = 1 + WHERE rowid IN (SELECT rowid + FROM tasks AS t + WHERE runner_id IS NULL + OR run_id IS NULL + OR NOT EXISTS (SELECT * FROM runs + WHERE runs.job_id = t.job_id + AND runs.target = t.target + AND runs.function = t.function + AND runs.runner_id = t.runner_id + AND runs.run_id = t.run_id) + ORDER BY priority ASC, random() + LIMIT :wanted) + RETURNING job_id, target, function + """ + args = {"runner_id": runner_id, "run_id": run_id, "wanted": wanted} + with self.query(assign_tasks_query, args) as query: + tasks = [mk_task(row) for row in query] + + insert_run = """ + INSERT INTO runs(job_id, target, function, runner_id, run_id, + time_run_started, is_run_dirty) + VALUES (:job_id, :target, :function, :runner_id, :run_id, :time_run_started, 1) + ON CONFLICT DO NOTHING + """ + self.executemany(insert_run, (mk_run_row(task) for task in tasks)) + self.execute("COMMIT") + return tasks + + # Assign tasks to a runner, enqueueing if necessary. + def assign_tasks(self, runner_id: str, wanted: int) -> Sequence[TaskRun]: + now = now_utc() + timestamp = sqlite_timestamp(now) + run_id = mk_unique_id(now) + + tasks = self.try_assign_tasks(runner_id=runner_id, + run_id=run_id, + wanted=wanted, + timestamp=timestamp) + + wanted -= len(tasks) + if wanted > 0: + self.enqueue_tasks(wanted, timestamp) + tasks += self.try_assign_tasks(runner_id=runner_id, + run_id=run_id, + wanted=wanted, + timestamp=timestamp) + + return tasks + + # Take back unfinished jobs previously assigned to a runner, + # for example, if the runner died before completing them. 
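+    # A task is only taken back if its current run has not recorded a result;
+    # runs that already finished keep their assignment and their result.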
+ def unassign_tasks(self, runner_id: str) -> None: + with self.conn: + self.execute("BEGIN") + unassign_tasks_update = """ + UPDATE tasks + SET runner_id = NULL, + run_id = NULL, + is_task_dirty = 1 + WHERE runner_id = :runner_id + AND NOT EXISTS (SELECT * FROM runs + WHERE runs.job_id = tasks.job_id + AND runs.target = tasks.target + AND runs.function = tasks.function + AND runs.runner_id = :runner_id + AND runs.run_id = tasks.run_id + AND runs.result IS NOT NULL) + """ + self.execute(unassign_tasks_update, {"runner_id": runner_id}) + self.execute("COMMIT") + + # Mark a task run as complete. + def finish_run(self, *, runner_id: str, task: TaskRun, result: TaskResult) -> None: + timestamp = sqlite_timestamp(now_utc()) + with self.conn: + self.execute("BEGIN") + finish_run_update = """ + UPDATE runs + SET time_run_finished = :time_run_finished, + result = :result, + is_run_dirty = 1 + WHERE job_id = :job_id + AND target = :target + AND function = :function + AND runner_id = :runner_id + AND run_id = :run_id + """ + args = {**task._asdict(), + "runner_id": runner_id, + "time_run_finished": timestamp, + "result": result.name} + self.execute(finish_run_update, args) + + finish_job_update = """ + UPDATE jobs + SET time_job_finished = :time_job_finished, + is_job_dirty = 1 + WHERE job_id = :job_id + AND time_job_started IS NOT NULL + AND NOT EXISTS (SELECT * + FROM tasks NATURAL LEFT JOIN runs + WHERE result IS NULL) + """ + args = {"job_id": task.job_id, + "time_job_finished": timestamp} + self.execute(finish_job_update, args) + self.execute("COMMIT") + + # Get the jobs and tasks that have been touched since the last export. + # Used to create and efficiently update JSON files showing the state of + # queued and running jobs and tasks, so that they can be made visible + # via a static website. 
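+    # A usage sketch, mirroring Runner.export_json later in this file
+    # (`publish` stands in for whatever writes the JSON files):
+    #
+    #     with db.export_dirty_state() as dirty:
+    #         publish(dirty.jobs, dirty.targets, dirty.tasks)
+    #         dirty.clean()  # only mark state clean once publishing succeeded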
+ def export_dirty_state(self) -> ContextManager[DirtyDbState]: + @contextmanager + def impl() -> Iterator[DirtyDbState]: + with self.conn: + self.execute("BEGIN") + + jobs = self.export_dirty_jobs() + targets = self.export_dirty_targets() + tasks = self.export_dirty_tasks() + + do_clean = False + + def clean() -> None: + nonlocal do_clean + do_clean = True + + yield DirtyDbState(jobs=jobs, targets=targets, tasks=tasks, clean=clean) + + if do_clean: + self.execute("UPDATE jobs SET is_job_dirty = 0 WHERE is_job_dirty = 1") + self.execute("UPDATE tasks SET is_task_dirty = 0 WHERE is_task_dirty = 1") + self.execute("UPDATE runs SET is_run_dirty = 0 WHERE is_run_dirty = 1") + + self.execute("COMMIT") + + return impl() + + def export_dirty_jobs(self) -> DirtyJobs: + jobs_query = """ + SELECT job_id, time_job_submitted, time_job_started, time_job_finished + FROM jobs + WHERE is_job_dirty = 1 + OR job_id IN (SELECT job_id FROM tasks WHERE is_task_dirty = 1 + UNION SELECT job_id FROM runs WHERE is_run_dirty = 1) + """ + with self.query(jobs_query) as query: + return {row["job_id"]: JobData(time_job_submitted=row["time_job_submitted"], + time_job_started=row["time_job_started"], + time_job_finished=row["time_job_finished"]) + for row in query} + + def export_dirty_targets(self) -> DirtyTargets: + targets_query = """ + SELECT job_id, target, + IFNULL((SELECT IFNULL(result, 'RUNNING') + FROM runs + WHERE runs.job_id = tasks.job_id + AND runs.target = tasks.target + AND runs.function = tasks.function + AND runs.runner_id = tasks.runner_id + AND runs.run_id = tasks.run_id), + 'WAITING') + AS status, + count(*) AS task_count + FROM tasks + WHERE job_id in (SELECT job_id FROM jobs WHERE is_job_dirty = 1 + UNION SELECT job_id FROM tasks AS t WHERE t.is_task_dirty = 1 + UNION SELECT job_id FROM runs WHERE is_run_dirty = 1) + GROUP BY job_id, target, status + """ + with self.query(targets_query) as query: + targets: dict[str, dict[str, dict[str, int]]] = {} + for row in query: + job = targets.setdefault(row["job_id"], {}) + target = job.setdefault(row["target"], {}) + target[row["status"]] = row["task_count"] + + return targets + + def export_dirty_tasks(self) -> DirtyTasks: + assignments = self.export_dirty_assignments() + runs = self.export_dirty_runs() + tasks: dict[str, dict[str, dict[str, TaskData]]] = {} + + for job_id in assignments.keys() | runs.keys(): + assignments_targets = assignments.get(job_id, {}) + runs_targets = runs.get(job_id, {}) + tasks_targets = tasks.setdefault(job_id, {}) + + for target in assignments_targets.keys() | runs_targets.keys(): + assignments_functions = assignments_targets.get(target, {}) + runs_functions = runs_targets.get(target, {}) + tasks_functions = tasks_targets.setdefault(target, {}) + + for function in assignments_functions.keys() | runs_functions.keys(): + assignment = assignments_functions.get(function, None) + runs_data = runs_functions.get(function, {}) + tasks_functions[function] = TaskData(assignment=assignment, runs=runs_data) + + return tasks + + def export_dirty_assignments(self) -> TargetMap[Mapping[str, TaskAssignment]]: + tasks_query = """ + SELECT job_id, target, function, + runs.runner_id AS runner_id, + runs.run_id AS run_id + FROM tasks NATURAL LEFT JOIN runs + WHERE is_task_dirty = 1 + """ + with self.query(tasks_query) as query: + assignments: dict[str, dict[str, dict[str, TaskAssignment]]] = {} + for row in query: + job = assignments.setdefault(row["job_id"], {}) + target = job.setdefault(row["target"], {}) + target[row["function"]] = \ + 
TaskUnassigned() if row["runner_id"] is None or row["run_id"] is None \ + else TaskAssigned(runner_id=row["runner_id"], run_id=row["run_id"]) + + return assignments + + def export_dirty_runs(self) -> TargetMap[Mapping[str, Mapping[TaskAssigned, RunData]]]: + runs_query = """ + SELECT job_id, target, function, runner_id, run_id, + time_run_started, time_run_finished, result + FROM runs + WHERE is_run_dirty = 1 + """ + with self.query(runs_query) as query: + runs: dict[str, dict[str, dict[str, dict[TaskAssigned, RunData]]]] = {} + for row in query: + job = runs.setdefault(row["job_id"], {}) + target = job.setdefault(row["target"], {}) + function = target.setdefault(row["function"], {}) + assigned = TaskAssigned(runner_id=row["runner_id"], run_id=row["run_id"]) + function[assigned] = RunData(time_run_started=row["time_run_started"], + time_run_finished=row["time_run_finished"], + result=row["result"]) + + return runs + + +# File locking utilities. +# +# To allow job runners to execute concurrently with each other, and with job +# submission, we use `flock` advisory file locks to control access to shared +# resources. +# +# We use file locks in two ways: +# +# - To ensure uniqueness of certain processes. Each unique process is represented +# by a lock file that is not otherwise used. On startup, the process makes a +# non-blocking attempt to lock the file. If successful, the process should hold +# onto the lock as long as it continues executing. Otherwise, it should exit, +# since there is already a process running. +# +# - To ensure atomic updates to shared files. For this, we need three files: the +# file we want to read and update, a separate lock file, and a temporary file to +# stage the update. We first make a blocking attempt to lock the lock file. When +# we get the lock, we can read the file and act on its contents. To update it, +# we write a new version to a temporary file, and then atomically move it to the +# original file's location. Finally, we can release the lock. +# +# The main differences are whether we block on the attempt to lock the file, and +# whether we hold onto the lock for an extended period. +# +# For updateable files, we use JSON-encoded data, so we build JSON handling into +# the locking protocol. We try not to let these files get too big. + + +class LockFile(NamedTuple): + path: Path + is_locked: Callable[[], bool] + + +def flock(lock_path: Path, *, block: bool) -> ContextManager[LockFile]: + @contextmanager + def impl() -> Iterator[LockFile]: + with open(lock_path, 'w+') as lock_file: + + def is_locked() -> bool: + return not lock_file.closed + + def is_not_locked() -> bool: + return False + + try: + block_op = 0 if block else fcntl.LOCK_NB + fcntl.flock(lock_file, fcntl.LOCK_EX | block_op) + yield LockFile(path=lock_path, is_locked=is_locked) + + except BlockingIOError: + # `flock` is intended to be used at the head of a `with` block. + # Therefore, if we propagate the exception, the caller will be + # unable to tell whether the exception was raised by the `flock` + # call or by the body of the `with` block, unless they also + # wrap the body in an exception handler. It's more convenient + # to have the context manager return a value we can test. 
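+                # For example, a process that must be unique can guard itself
+                # with:
+                #
+                #     with flock(lock_path, block=False) as lock:
+                #         if not lock.is_locked():
+                #             return  # another instance already holds the lock
+                #         ...  # keep working while holding the lock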
+ yield LockFile(path=lock_path, is_locked=is_not_locked) + + return impl() + + +def file_lock(path: Path, *, block: bool) -> ContextManager[LockFile]: + assert path.name, 'file_lock: empty basename: {path}' + assert path.parent.is_dir(), 'file_lock: parent does not exist: {path}' + return flock(path.parent / f'{path.name}.lock', block=block) + + +def dir_lock(path: Path, *, block: bool) -> ContextManager[LockFile]: + assert path.is_dir(), 'dir_lock: not a directory' + return flock(path / '.lock', block=block) + + +class LockedJson(Protocol): + def get_data(self, default: T, get: Callable[[Any], T]) -> T: + ... + + def put_data(self, data: Any) -> None: + ... + + +def json_file_lock(path: Path) -> ContextManager[LockedJson]: + + @contextmanager + def impl() -> Iterator[LockedJson]: + + with file_lock(path, block=True) as lock: + assert lock.is_locked(), f'json_file_lock: failed to lock: {path}' + tmp_path = path.parent / f'{path.name}.tmp' + + class LockedJsonImpl(NamedTuple): + + def get_data(self, default: T, get: Callable[[Any], T]) -> T: + assert lock.is_locked(), f'json_file_lock: get_data without a lock: {path}' + try: + with open(path, 'r') as json_file: + return get(json.load(json_file)) + except FileNotFoundError: + return default + + def put_data(self, data: Any) -> None: + assert lock.is_locked(), f'json_file_lock: put_data without a lock: {path}' + with open(tmp_path, 'w') as tmp_file: + json.dump(data, tmp_file, separators=(',', ':')) + shutil.move(tmp_path, path) + + yield LockedJsonImpl() + + return impl() + + +def iter_non_empty(iter: Iterable[T]) -> bool: + return any(True for i in iter) + + +def sets_disjoint(*sets: set[T]) -> bool: + return all(sets[i].isdisjoint(sets[j]) + for i in range(len(sets)) + for j in range(i + 1, len(sets))) + + +# A CpuSet represents a collection of computational resources +# that may be available for running tasks. A "CPU" is represented +# by a lock file which is held while a task is being performed. +class CpuSet(NamedTuple): + # The number of CPUs we have lock files for. + present: int + # The format string to translate CPU number to lock directory name. + fmt: str + # The path containing lock directories. + path: Path + + def cpu_name(self, cpu_id) -> str: + assert 0 <= cpu_id < self.present, f'cpu_name: invalid cpu_id: {cpu_id}' + return f'{cpu_id:{self.fmt}}' + + # The directory containing the lock file. + def cpu_dir(self, cpu_id) -> Path: + return self.path / self.cpu_name(cpu_id) + + def set(self) -> set[int]: + return set(range(self.present)) + + +# We need a lock per CPU, so we can wait on CPU availability. +# We create lock directories up to the allowed number of CPUs. +# We don't remove existing directories exceeding the allowance, +# since other runners may be using them. +# This should only be used when we have a lock on cpus_allowed.txt. 
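+# For example, if eight CPU directories ('0' .. '7') already exist and the
+# allowance grows to twelve, the existing directories are renamed to
+# '00' .. '07' (their numbers are preserved) and '08' .. '11' are created.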
+def build_cpu_set(cpus_path: Path, cpus_allowed: int) -> CpuSet: + cpus_path.mkdir(parents=True, exist_ok=True) + + existing_cpus = list(enumerate(sorted(os.listdir(cpus_path)))) + cpus_present = max(cpus_allowed, len(existing_cpus)) + + cpu_digits = len(str(cpus_present - 1)) + cpu_fmt = f'0{cpu_digits}d' + + assert all((cpus_path / j).is_dir() for i, j in existing_cpus), \ + 'initialise_cpus error: cpus entries are not all directories' + + assert all(j.isdigit() and i == int(j) for i, j in existing_cpus), \ + 'initialise_cpus error: cpu lock directories have unexpected names' + + cpu_set = CpuSet(present=cpus_present, fmt=cpu_fmt, path=cpus_path) + + for cpu_num, existing_cpu in existing_cpus: + new_cpu_name = f'{cpu_num:{cpu_fmt}}' + if new_cpu_name != existing_cpu: + assert int(new_cpu_name) == int(existing_cpu) + new_cpu_path = cpus_path / new_cpu_name + assert not new_cpu_path.exists() + # It's ok for us to rename these, since we are the only runner + # allowed to allocate CPUs to new tasks, and we preserve numbers. + os.rename(cpus_path / existing_cpu, new_cpu_path) + + for cpu_num in range(len(existing_cpus), cpus_allowed): + cpu_set.cpu_dir(cpu_num).mkdir() + + return cpu_set + + +class ActiveRunnerStatus(NamedTuple): + is_active: bool + is_cache_valid: bool + + +@dataclass(kw_only=True) +class RunnerState: + cond_var: threading.Condition + + tasks_running: set[TaskRun] + runners_waiting: set[str] + + cpu_set: CpuSet + cpus_allowed: int + cpus_waiting: set[int] + cpus_working: set[int] + cpus_idle: set[int] + + +def default_priority(task: Task) -> int: + return 200 + + +targets_admitted: set[str] = { + 'RISCV64-O1', + 'RISCV64-MCS-O1', +} + +functions_rejected: dict[str, set[str]] = { + 'RISCV64-O1': { + 'create_untypeds_for_region', + }, + 'RISCV64-MCS-O1': { + 'create_untypeds_for_region', + }, +} + + +# Enumerate the tasks for a job. +# Used to populate the tasks table in the database when starting a job. 
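+# For example, a job whose 'RISCV64-O1' target lists functions f and g in
+# functions-list.txt yields one EnqueueTask per function, at priority 200 by
+# default, except for functions listed in functions_rejected; targets not
+# named in targets_admitted are skipped entirely.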
+def get_tasks_for_job(jobs_path: Path, + priority: EnqueuePriority = default_priority) -> GetTasksForJob: + + def get_tasks(job_id: str) -> Sequence[EnqueueTask]: + job_path = jobs_path / job_id + + enqueue_tasks: list[EnqueueTask] = [] + targets: list[str] = [] + + for target in os.listdir(job_path / 'targets'): + if target not in targets_admitted: + continue + + target_path = job_path / 'targets' / target + functions_list_path = target_path / 'target' / 'functions-list.txt' + if not functions_list_path.is_file(): + continue + targets.append(target) + + def enq_task(function: str) -> EnqueueTask: + task = Task(job_id=job_id, target=target, function=function) + return EnqueueTask(task=task, priority=priority(task)) + + with open(functions_list_path) as functions_list: + enqueue_tasks.extend(enq_task(function) + for function in (line.strip() for line in functions_list) + if function not in functions_rejected.get(target, set())) + + return enqueue_tasks + + return get_tasks + + +result_re = re.compile( + r'Result (?P\w+) for pair Pairing \((?P\w+) \(ASM\) <= \S+ \(C\)\), time taken: .*') +underspecified_fn_re = re.compile( + r'Aborting Problem \(Pairing \((?P\S+) \(ASM\) <= \S+ \(C\)\)\): underspecified \S+') +complex_loop_re = re.compile( + r'Aborting Problem \(Pairing \((?P\S+) \(ASM\) <= \S+ \(C\)\)\), complex loop') +impossible_re = re.compile( + r"Possibilities for '(?P\S+)': \[\]") +split_limit_assertion_re = re.compile( + r"assert not 'split limit found'") + + +def graph_refine_result(name: str, report_path: Path) -> TaskResult: + if report_path.is_file(): + with open(report_path) as report: + for line in report: + line = line.strip() + + match = result_re.fullmatch(line) + if match: + assert match['name'] == name + if match['result'] == 'True': + return TaskResult.PASSED + if match['result'] == 'False': + return TaskResult.FAILED + if match['result'] == 'ProofNoSplit': + return TaskResult.NO_SPLIT + if match['result'] == 'ProofEXCEPT': + return TaskResult.EXCEPT + else: + return TaskResult.MALFORMED + + match = underspecified_fn_re.fullmatch(line) + if match: + assert match['name'] == name + return TaskResult.UNDERSPECIFIED + + match = complex_loop_re.fullmatch(line) + if match: + assert match['name'] == name + return TaskResult.COMPLEX_LOOP + + return TaskResult.NO_RESULT + + +def graph_refine_impossible(name: str, log_path: Path) -> bool: + if log_path.is_file(): + with open(log_path) as log: + for line in log: + line = line.strip() + match = impossible_re.fullmatch(line) + if match: + assert match['name'] == name + return True + return False + + +def split_limit_assertion(log_path: Path) -> bool: + if log_path.is_file(): + with open(log_path) as log: + for line in log: + line = line.strip() + if line == "assert not 'split limit found'": + return True + return False + + +def ensure_dict(data: Any) -> dict[str, Any]: + assert isinstance(data, dict) + assert all(isinstance(k, str) for k in data) + return data + + +class Runner(NamedTuple): + work_dir: Path + runner_id: str + instance_id: str + graph_refine: Path + + def jobs_db(self) -> ContextManager[JobsDB]: + @contextmanager + def impl() -> Iterator[JobsDB]: + conn = sqlite3.connect(self.work_dir / 'private' / 'jobs.db') + try: + conn.row_factory = sqlite3.Row + yield JobsDB(conn=conn, get_tasks=get_tasks_for_job(jobs_path=self.jobs_dir())) + finally: + conn.close() + + return impl() + + def jobs_dir(self) -> Path: + return self.work_dir / 'public' / 'jobs' + + def targets_dir(self, job_id: str) -> Path: + return self.jobs_dir() 
/ job_id / 'targets' + + def runner_lock_dir(self, runner_id: str) -> Path: + return self.work_dir / 'private' / 'runners' / runner_id + + def runner_lock(self, runner_id: str, *, block: bool) -> ContextManager[LockFile]: + return dir_lock(self.runner_lock_dir(runner_id), block=block) + + def new_jobs(self) -> Sequence[WaitingJob]: + new_dir = self.work_dir / 'private' / 'new' + if not new_dir.is_dir(): + return [] + + def get_job(entry: os.DirEntry) -> Iterator[WaitingJob]: + job_id = entry.name + job_path = new_dir / job_id + if job_path.is_file(): + with open(job_path) as job_file: + time_job_submitted = job_file.read().strip() + yield WaitingJob(job_id=job_id, time_job_submitted=time_job_submitted) + + return [job for entry in os.scandir(new_dir) for job in get_job(entry)] + + def add_new_jobs(self) -> None: + new_jobs = self.new_jobs() + with self.jobs_db() as db: + db.add_waiting_jobs(new_jobs) + for job in new_jobs: + write_log(f'add_new_jobs: added job {job.job_id}') + (self.work_dir / 'private' / 'new' / job.job_id).unlink() + + # Get the allowed number of CPUs, and hold a lock on cpus_allowed.txt. + def cpus_allowed_lock(self) -> ContextManager[int]: + cpus_allowed_path = self.work_dir / 'private' / 'cpus_allowed.txt' + + @contextmanager + def impl() -> Iterator[int]: + with file_lock(cpus_allowed_path, block=True) as lock: + assert lock.is_locked(), 'cpus_allowed_lock: failed to lock cpus_allowed.txt' + + # Let it crash if the file doesn't exist, or has invalid contents. + with open(cpus_allowed_path) as cpus_allowed_file: + cpus_allowed = int(cpus_allowed_file.read().strip()) + + yield cpus_allowed + + return impl() + + # Take a lock on the active_runner_id.txt file, and return flags indicating: + # - whether our runner_id is the active runner ID, + # - if so, whether the most recently set instance ID matches ours. + # The latter is used to determine cache validity. + def active_runner_lock(self) -> ContextManager[ActiveRunnerStatus]: + runner_id_path = self.work_dir / 'private' / 'active_runner_id.txt' + instance_id_path = self.work_dir / 'private' / 'active_instance.txt' + + @contextmanager + def impl() -> Iterator[ActiveRunnerStatus]: + with file_lock(runner_id_path, block=True) as lock: + assert lock.is_locked(), f'active_runner_lock: failed to acquire lock' + + try: + with open(runner_id_path, 'r') as runner_id_file: + is_active = runner_id_file.read().strip() == self.runner_id + except FileNotFoundError: + is_active = False + + if not is_active: + yield ActiveRunnerStatus(is_active=False, is_cache_valid=False) + return + + try: + with open(instance_id_path, 'r') as instance_id_file: + is_cache_valid = instance_id_file.read().strip() == self.instance_id + except FileNotFoundError: + is_cache_valid = False + + yield ActiveRunnerStatus(is_active=True, is_cache_valid=is_cache_valid) + + return impl() + + # Invalidates caches of any other runners. + # This should only be called when we have locks on both: + # - our runner_id directory, + # - the active_runner_id.txt file. 
+ def set_active_instance_id(self) -> None: + with open(self.work_dir / 'private' / 'active_instance.txt', 'w') as instance_id_file: + print(self.instance_id, file=instance_id_file) + + def other_runners(self) -> set[str]: + return set(r for r in os.listdir(self.work_dir / 'private' / 'runners') + if r != self.runner_id) + + def wait_for_runners(self, runners: Iterable[str], state: RunnerState) -> None: + for runner in runners: + write_log(f'wait_for_runners: waiting for {runner}') + state.runners_waiting.add(runner) + + def thread_fn(runner: str) -> None: + with self.runner_lock(runner, block=True) as lock: + assert lock.is_locked(), f'wait_for_runners: failed to lock runner {runner}' + with state.cond_var: + write_log(f'wait_for_runners: got {runner}') + state.runners_waiting.remove(runner) + # Only the active runner should archive another runner. + with self.active_runner_lock() as active_runner: + if active_runner.is_active: + with self.jobs_db() as db: + db.unassign_tasks(runner) + shutil.rmtree(self.runner_lock_dir(runner)) + state.cond_var.notify() + + thread = threading.Thread(target=thread_fn, args=(runner, ), daemon=True) + thread.start() + + def cpus_dir(self) -> Path: + return self.work_dir / 'private' / 'cpus' + + # Called from the main thread while holding the active_runner_id lock + # and the shared state condition variable. + # Initially adds each CPU to `cpus_waiting`, and starts a thread to wait + # for the CPU lock. When the thread acquires the lock, it moves the CPU to + # `cpus_idle`, and notifies the main thread. + def wait_for_cpus(self, cpus: Iterable[int], state: RunnerState) -> None: + for cpu in cpus: + write_log(f'wait_for_cpus: waiting for CPU {state.cpu_set.cpu_name(cpu)}') + state.cpus_waiting.add(cpu) + + def thread_fn(cpu: int) -> None: + with dir_lock(state.cpu_set.cpu_dir(cpu), block=True) as lock: + assert lock.is_locked(), f'wait_for_cpus: failed to lock CPU {cpu}' + with state.cond_var: + write_log(f'wait_for_cpus: got CPU {state.cpu_set.cpu_name(cpu)}') + state.cpus_waiting.remove(cpu) + state.cpus_idle.add(cpu) + state.cond_var.notify() + + thread = threading.Thread(target=thread_fn, args=(cpu, ), daemon=True) + thread.start() + + def initialise_state(self, cond_var: threading.Condition) -> Optional[RunnerState]: + with self.active_runner_lock() as active_runner: + if not active_runner.is_active: + return None + + with self.cpus_allowed_lock() as cpus_allowed: + cpu_set = build_cpu_set(self.cpus_dir(), cpus_allowed) + + state = RunnerState(cond_var=cond_var, + tasks_running=set(), + runners_waiting=set(), + cpu_set=cpu_set, + cpus_allowed=cpus_allowed, + cpus_waiting=set(), + cpus_working=set(), + cpus_idle=set()) + + self.wait_for_cpus(cpu_set.set(), state) + self.wait_for_runners(self.other_runners(), state) + + with self.jobs_db() as db: + db.initialise() + # In case we are recovering from a crash/reboot. + db.unassign_tasks(self.runner_id) + + self.add_new_jobs() + + self.set_active_instance_id() + return state + + # Assumes we have the runner_id lock, and the active_runner_id lock. + def refresh_state(self, state: RunnerState, is_cache_valid: bool) -> None: + if is_cache_valid: + write_log('refresh_state: cache is valid, minimally refreshing CPU set') + with self.cpus_allowed_lock() as cpus_allowed: + if state.cpus_allowed < cpus_allowed: + state.cpu_set = build_cpu_set(self.cpus_dir(), cpus_allowed) + # Force generation of `new_cpus` before calling `wait_for_cpus`, + # because`wait_for_cpus` can modify `cpus_waiting` and `cpus_idle`. 
+ new_cpus = state.cpu_set.set() \ + - state.cpus_working - state.cpus_waiting - state.cpus_idle + self.wait_for_cpus(new_cpus, state) + state.cpus_allowed = cpus_allowed + else: + write_log('refresh_state: cache is not valid, refreshing everything') + with self.cpus_allowed_lock() as cpus_allowed: + state.cpu_set = build_cpu_set(self.cpus_dir(), cpus_allowed) + # CPUs we thought were idle may have become occupied, + # so we need to wait for them again. + state.cpus_idle.clear() + # Force generation of `wait_cpus` before calling `wait_for_cpus`, + # because`wait_for_cpus` can modify `cpus_waiting` and `cpus_idle`. + wait_cpus = state.cpu_set.set() - state.cpus_working - state.cpus_waiting + self.wait_for_cpus(wait_cpus, state) + state.cpus_allowed = cpus_allowed + runners = self.other_runners() - state.runners_waiting + self.wait_for_runners(runners, state) + self.set_active_instance_id() + + def run_graph_refine(self, task: TaskRun, state: RunnerState) -> None: + if task in state.tasks_running: + return + + cpu = state.cpus_idle.pop() + state.cpus_working.add(cpu) + + log_info = f'{task.function} {task.target} {task.job_id} {task.run_id}' + write_log(f'run_graph_refine: begin {log_info}') + + smt_dir = self.jobs_dir() / task.job_id / 'smt' + target_path = self.jobs_dir() / task.job_id / 'targets' / task.target + target_inputs = target_path / 'target' + task_dir = target_path / 'functions' / task.function / task.run_id + report_path = task_dir / 'report.txt' + log_path = task_dir / 'log.txt' + + l4v_arch: Optional[str] = None + with open(target_inputs / 'config.env') as config_env: + for line in config_env: + var, val = line.strip().split('=', maxsplit=1) + if var == "L4V_ARCH": + l4v_arch = val + + assert l4v_arch is not None + + for new_dir in [smt_dir, task_dir]: + new_dir.mkdir(parents=True, exist_ok=True) + + # Runs in the thread started below. 
+ def thread_finished(proc_result: Optional[int]) -> None: + result = graph_refine_result(task.function, report_path) + + if result is TaskResult.EXCEPT: + if split_limit_assertion(log_path): + result = TaskResult.NO_SPLIT + + elif result is TaskResult.NO_RESULT: + if graph_refine_impossible(task.function, log_path): + result = TaskResult.IMPOSSIBLE + elif proc_result is None: + result = TaskResult.TIMEOUT + elif proc_result < 0: + result = TaskResult.KILLED + + with open(task_dir / 'exit_code.txt', 'w') as exit_code_txt: + print(proc_result, file=exit_code_txt) + + write_log(f'run_graph_refine: result {result.name} {log_info}') + + with state.cond_var: + state.tasks_running.remove(task) + state.cpus_working.remove(cpu) + state.cpus_idle.add(cpu) + + with self.jobs_db() as db: + db.finish_run(runner_id=self.runner_id, task=task, result=result) + + state.cond_var.notify() + + class StartOk(NamedTuple): + proc: subprocess.Popen + + class StartExcept(NamedTuple): + ex: Exception + + subproc_queue: queue.Queue[StartOk | StartExcept] = queue.Queue() + + def thread_fn() -> None: + try: + cmd: list[str | Path] = \ + [self.graph_refine, target_inputs, f'trace-to:{report_path}', task.function] + + assert l4v_arch is not None + env: dict[str, str | Path] = \ + {**os.environ, "L4V_ARCH": l4v_arch, "GRAPH_REFINE_SMT2_DIR": smt_dir} + + with dir_lock(state.cpu_set.cpu_dir(cpu), block=True): + with open(log_path, 'w') as log_file: + proc = subprocess.Popen(cmd, cwd=task_dir, env=env, + stdin=subprocess.DEVNULL, stdout=log_file, + stderr=subprocess.STDOUT) + + subproc_queue.put(StartOk(proc)) + + try: + result: Optional[int] = proc.wait(timeout=259200) + except subprocess.TimeoutExpired: + result = None + finally: + try: + os.killpg(proc.pid, signal.SIGTERM) + except ProcessLookupError: + pass + finally: + thread_finished(result) + + except Exception as ex: + subproc_queue.put(StartExcept(ex)) + raise + + thread = threading.Thread(target=thread_fn) + thread.start() + + thread_start_result = subproc_queue.get() + if isinstance(thread_start_result, StartExcept): + raise thread_start_result.ex + + state.tasks_running.add(task) + + def start_tasks(self, state: RunnerState) -> None: + assert sets_disjoint(state.cpus_waiting, state.cpus_working, state.cpus_idle), \ + 'start_tasks: CPU sets not disjoint' + assert state.cpu_set.set() == state.cpus_waiting | state.cpus_working | state.cpus_idle, \ + 'start_tasks: CPU sets incomplete' + assert state.cpus_allowed <= state.cpu_set.present, \ + 'start_tasks: allowed CPUs not present' + cpus_busy = len(state.cpus_waiting) + len(state.cpus_working) + cpus_available = state.cpus_allowed - cpus_busy + if cpus_available <= 0: + return + write_log(f'start_tasks: {cpus_available} CPUs available') + with self.jobs_db() as db: + tasks = db.assign_tasks(self.runner_id, cpus_available) + for task in tasks: + self.run_graph_refine(task, state) + + def export_json_tasks(self, dirty_tasks: DirtyTasks) -> None: + for job_id, dirty_targets in dirty_tasks.items(): + for target, dirty_functions in dirty_targets.items(): + functions_json_path = self.targets_dir(job_id) / target / 'functions.json' + with json_file_lock(functions_json_path) as functions_json: + functions = ensure_dict(functions_json.get_data({}, lambda fs: fs)) + for function, task_data in dirty_functions.items(): + function_data = ensure_dict(functions.setdefault(function, {})) + if isinstance(task_data.assignment, TaskUnassigned): + function_data['assignment'] = None + elif isinstance(task_data.assignment, 
TaskAssigned): + function_data['assignment'] = task_data.assignment._asdict() + runs = ensure_dict(function_data.setdefault('runs', {})) + for assignment, run_data in task_data.runs.items(): + runner = ensure_dict(runs.setdefault(assignment.runner_id, {})) + runner[assignment.run_id] = run_data._asdict() + functions_json.put_data(functions) + + def export_json_jobs(self, *, dirty_jobs: DirtyJobs, dirty_targets: DirtyTargets) -> None: + # Maps from job_id to time submitted. + dirty_waiting: dict[str, str] = {} + dirty_running: dict[str, str] = {} + dirty_finished: dict[str, str] = {} + + for job_id, job_data in dirty_jobs.items(): + job_status = {"timestamps": job_data._asdict(), + "targets": dirty_targets.get(job_id, {})} + + with json_file_lock(self.jobs_dir() / job_id / 'job_status.json') as job_status_json: + job_status_json.put_data(job_status) + + dirty_set = dirty_finished if job_data.time_job_finished is not None \ + else dirty_running if job_data.time_job_started is not None \ + else dirty_waiting + + dirty_set[job_id] = job_data.time_job_submitted + + with json_file_lock(self.work_dir / 'public' / 'jobs.json') as jobs_json: + jobs = ensure_dict(jobs_json.get_data({}, lambda js: js)) + + class ListedJob(NamedTuple): + job_id: str + submitted: str + + def time_job_submitted_key(job: ListedJob) -> str: + return job.submitted + + def ensure_job(job_data: Any) -> ListedJob: + assert isinstance(job_data, dict) + assert isinstance(job_data["job_id"], str) + assert isinstance(job_data["submitted"], str) + return ListedJob(**job_data) + + def ensure_jobs(jobs_data: Any) -> list[ListedJob]: + assert isinstance(jobs_data, list) + return [ensure_job(job) for job in jobs_data] + + def build_jobs(key: str, new: dict[str, str], exclude: Sequence[dict[str, str]], + limit: Optional[int] = None) -> list[dict[str, str]]: + + jobs_list = [j for j in ensure_jobs(jobs.get(key, [])) + if j.job_id not in new and not any(j.job_id in e for e in exclude)] + + jobs_list.extend(ListedJob(job_id, submitted) for job_id, submitted in new.items()) + jobs_list.sort(key=time_job_submitted_key, reverse=True) + + if limit is not None: + jobs_list = jobs_list[:limit] + + return [job._asdict() for job in jobs_list] + + waiting = build_jobs(key='waiting', new=dirty_waiting, + exclude=(dirty_running, dirty_finished)) + running = build_jobs(key='running', new=dirty_running, + exclude=(dirty_waiting, dirty_finished)) + finished = build_jobs(key='finished', new=dirty_finished, + exclude=(dirty_waiting, dirty_running), limit=500) + + jobs_json.put_data({"waiting": waiting, "running": running, "finished": finished}) + + def export_json(self, state: RunnerState) -> None: + with self.jobs_db() as db, db.export_dirty_state() as dirty_state: + self.export_json_tasks(dirty_state.tasks) + self.export_json_jobs(dirty_jobs=dirty_state.jobs, dirty_targets=dirty_state.targets) + dirty_state.clean() + + # Assumes we are holding our runner_id lock. + def do_work(self) -> None: + # We use a condition variable to allow the main thread to receive + # notifications from other threads, while also controlling access + # to shared resources. To ensure the main thread doesn't miss any + # notifications, it must always be holding the condition variable, + # except when it pauses to wait for notifications. 
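+        # Worker threads notify the condition variable when a waited-for CPU
+        # lock or runner lock is acquired (see `wait_for_cpus` and
+        # `wait_for_runners`), and when a graph-refine task finishes, so the
+        # loop below wakes promptly rather than relying only on its timeout.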
+ cond_var = threading.Condition() + with cond_var: + state = self.initialise_state(cond_var) + + if state is None: + return + + while True: + with self.active_runner_lock() as active_runner: + + if active_runner.is_active: + write_log('do_work: acquired active runner lock') + self.refresh_state(state, active_runner.is_cache_valid) + self.add_new_jobs() + self.start_tasks(state) + self.export_json(state) + else: + write_log('do_work: not the active runner') + + if not state.tasks_running: + if not active_runner.is_active: + write_log('do_work: no running tasks, exiting') + return + if not state.runners_waiting and not state.cpus_waiting: + write_log('do_work: not waiting on CPUs or runners, exiting') + return + write_log('do_work: no tasks, but waiting on' + + f' {len(state.cpus_waiting)} CPUs and' + + f' {len(state.runners_waiting)} runners') + else: + write_log(f'do_work: {len(state.tasks_running)} tasks still running') + + # A long timeout is ok, since cond_var will be notified by + # other threads as soon as there is a finished task or idle CPU. + # The timeout limits how long it takes us to: + # - Start using an increased CPU allowance. + # - Add new jobs to the database and its exported JSON. + write_log('do_work: waiting for events') + cond_var.wait(timeout=50) + + def have_new_jobs(self) -> bool: + with self.active_runner_lock() as active_runner: + return active_runner.is_active and iter_non_empty(self.new_jobs()) + + def run(self) -> None: + self.runner_lock_dir(self.runner_id).mkdir(parents=True, exist_ok=True) + + # The outer loop here avoids a race condition during shutdown. + # Most runners will only iterate the outer loop once, since `do_work` should + # normally continue until the runner's work is exhausted. + + while True: + with self.runner_lock(self.runner_id, block=False) as lock: + if not lock.is_locked(): + write_log('run: failed to acquire runner lock, exiting') + return + + # We are the runner authorised to use our runner_id, + # so we need to do the work. + write_log('run: acquired runner lock') + self.do_work() + + # If `do_work` returned, then it has completed all the work it could see. + # Normally, that means we want to shut down. + + # However, there is a potential race with new job submissions. + + # After submitting a new job, the submitter will attempt to start a new + # runner. The new runner will exit immediately if there is already a + # runner holding runner_lock (see above). If the existing runner + # has already decided to shut down, but has not yet released runner_lock, + # then we could be left without a runner for the new job. + + # To avoid the race, ensuring that we're left with exactly one runner for + # the new job, the existing runner (but not the new runner) must perform an + # additional check for new work after it has released runner_lock. + + write_log('run: released runner lock, checking for new jobs') + if not self.have_new_jobs(): + write_log('run: no new jobs, exiting') + return + + +# Use a low-level redirect to make sure we get everything. 
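+# Reassigning sys.stdout/sys.stderr alone would not capture output written
+# directly to file descriptors 1 and 2 (e.g. by child processes or native
+# code), so the log file is dup2'd over those descriptors instead.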
+def redirect_to_log(path: Path) -> None: + sys.stdout.flush() + sys.stderr.flush() + log_fd = os.open(path, os.O_WRONLY | os.O_CREAT) + os.set_inheritable(log_fd, True) + os.dup2(log_fd, 1) + os.dup2(log_fd, 2) + os.close(log_fd) + + +class RunnerConfig(NamedTuple): + work_dir: Path + runner_id: str + graph_refine: Path + + def run(self) -> None: + runner = Runner(work_dir=self.work_dir, + runner_id=self.runner_id, + instance_id=mk_unique_id(), + graph_refine=self.graph_refine) + + log_dir = self.work_dir / 'public' / 'runners' / self.runner_id / 'logs' + log_dir.mkdir(parents=True, exist_ok=True) + + redirect_to_log(log_dir / f'{runner.instance_id}.log') + runner.run() + + +runner_id_re = re.compile(r'\w+') + + +def parse_args() -> RunnerConfig: + def runner_id(arg: str) -> str: + if runner_id_re.fullmatch(arg): + return arg + raise argparse.ArgumentTypeError(f'{arg} is not a valid runner ID') + + def work_dir_path(arg: str) -> Path: + if os.path.isdir(arg) and os.access(arg, os.R_OK | os.W_OK | os.X_OK): + return Path(arg).resolve() + raise argparse.ArgumentTypeError(f'{arg} is not a valid work directory') + + def graph_refine_script(arg: str) -> Path: + if os.path.isfile(arg) and os.access(arg, os.R_OK | os.X_OK): + return Path(arg).resolve() + raise argparse.ArgumentTypeError(f'{arg} is not a valid work directory') + + parser = argparse.ArgumentParser(description='Run graph-refine jobs') + parser.add_argument('--id', metavar='ID', required=True, type=runner_id, + help='Runner ID') + parser.add_argument('--work', metavar='DIR', required=True, type=work_dir_path, + help='Root of the graph-refine work directory') + parser.add_argument('--graph-refine-py', metavar='PATH', type=graph_refine_script, + default=os.environ.get('GRAPH_REFINE_SCRIPT'), + help='Path to graph-refine script') + + parser.set_defaults(redirect_output=False) + args = parser.parse_args() + + return RunnerConfig(work_dir=args.work, + runner_id=args.id, + graph_refine=args.graph_refine_py) + + +def main(runner_config: RunnerConfig) -> int: + runner_config.run() + return 0 + + +if __name__ == '__main__': + exit(main(parse_args())) diff --git a/ci/submit-graph-refine b/ci/submit-graph-refine new file mode 100755 index 00000000..8551bbe3 --- /dev/null +++ b/ci/submit-graph-refine @@ -0,0 +1,148 @@ +#!/bin/bash + +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Submit a job from GitHub CI to the graph-refine back end. +# +# This script is intended to be run in the GitHub CI workflow +# that prepares graph-refine inputs. It sends those inputs over +# SSH to a back-end server that is suitable for running +# graph-refine. +# +# It requires environment variables: +# - BV_BACKEND_WORK_DIR: Path on the remote (back-end) host of the +# graph-refine work directory, relative to the SSH home directory. +# - BV_SSH_CONFIG: Contents of an SSH config file that uses the name +# `graph-refine` for the remote (back-end) host. +# - BV_SSH_KEY: Private key with access to a user on the `graph-refine` host. +# - BV_SSH_KNOWN_HOSTS: Contents of an SSH known hosts file suitable for +# accessing the `graph-refine` host. +# - DOCKER_RUN_COMMAND: Command to use in place of `docker run`, +# e.g. `podman run --memory 20g`. +# - JOB_DIR: Path of a local directory (on GitHub CI) containing the job +# to submit. +# - RUNNER_TEMP: Path to a local temporary directory. +# +# It assumes that this script is running from a `graph-refine` checkout, +# where it is able to find the `dir_hash.py` script. 
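+#
+# Purely illustrative example settings (all values are hypothetical):
+#   BV_BACKEND_WORK_DIR=bv-work
+#   JOB_DIR=/tmp/graph-refine-job
+#   DOCKER_RUN_COMMAND="docker run"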
+# +# The BV_BACKEND_WORK_DIR is assumed to follow the same structure as used +# in the parallel job runner (see runner.py). + +set -euo pipefail + +if [ $# -ne 0 ]; then + echo "submit-graph-refine: error: unexpected arguments" >&2 + exit 1 +fi + +if [ ! -d "${JOB_DIR}" ]; then + echo "submit-graph-refine: error: JOB_DIR does not exist: ${JOB_DIR}" >&2 + exit 1 +fi + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) + +JOB_ID=$("${SCRIPT_DIR}/dir_hash.py" "${JOB_DIR}") +if [ -z "${JOB_ID}" ]; then + echo "submit-graph-refine: error: failed to get job ID" >&2 + exit 1 +fi + +CI_TMP=$(mktemp -d -p "${RUNNER_TEMP}") +cleanup() { rm -rf "${CI_TMP}"; } +trap cleanup EXIT + +# Build an SSH config for reaching the back end. +SSH_DIR="${CI_TMP}/ssh" +mkdir "${SSH_DIR}" +touch "${SSH_DIR}/ssh_key" +chmod 0700 "${SSH_DIR}/ssh_key" +cat > "${SSH_DIR}/ssh_key" <<< "${BV_SSH_KEY}" +cat > "${SSH_DIR}/ssh_known_hosts" <<< "${BV_SSH_KNOWN_HOSTS}" + +# BV_SSH_CONFIG should define how to connect to a host named `graph-refine`. +# For example, it might contain something like: +# ``` +# Host graph-refine +# Hostname real-hostname.example.org +# User bv +# ``` +# It may also contain configuration for any required jump hosts. +cat > "${SSH_DIR}/ssh_config" < "${CI_TMP}/ci-receive" <> "${CI_TMP}/ci-receive" <<'EOF' +mkdir -p "${WORK_DIR}/private/tmp" "${WORK_DIR}/public/jobs" "${WORK_DIR}/private/new" + +INCOMING_TMP="$(mktemp -d -p "${WORK_DIR}/private/tmp")" +cleanup() { rm -rf "${INCOMING_TMP}"; } +trap cleanup EXIT +chmod +rx "${INCOMING_TMP}" + +# Unpack the job into a temp directory, so we can atomically +# move it to work directory. +mkdir -p "${INCOMING_TMP}/jobs/${JOB_ID}" +tar -xJf - -C "${INCOMING_TMP}/jobs/${JOB_ID}" + +if ! mv "${INCOMING_TMP}/jobs/${JOB_ID}" "${WORK_DIR}/public/jobs" 2>/dev/null; then + # If the job directory already exists, then someone else has + # submitted a job with the same hash. In that case, we assume the + # existing job is identical to the one we're trying to submit, so + # ignore the error. + echo "Job already exists" >&2 +fi + +# Whether or not someone else got here first, the job directory should +# now exist, so it's a real error if it doesn't. +if [ ! -d "${WORK_DIR}/public/jobs/${JOB_ID}" ]; then + echo "Failed to submit job" >&2 + exit 1 +fi + +# Flag the job as new, unless it has already been flagged. +mkdir -p "${INCOMING_TMP}/new" +TZ=UTC date +'%Y-%m-%d %H:%M:%S+00:00' > "${INCOMING_TMP}/new/${JOB_ID}" +mv -n "${INCOMING_TMP}/new/${JOB_ID}" "${WORK_DIR}/private/new" + +# Start a runner, in case there is not one already running. +# We use the Docker image tag as the runner ID. +ID_FILE="${WORK_DIR}/private/active_runner_id.txt" +if [ -f "${ID_FILE}" ]; then + RUNNER_TAG="$(< "${ID_FILE}")" + if [ -n "${RUNNER_TAG}" ]; then + ${DOCKER_RUN_COMMAND} --init -d \ + --mount "type=bind,src=${HOME}/${WORK_DIR},dst=/work" \ + "ghcr.io/sel4/graph-refine-runner:${RUNNER_TAG}" \ + --id "${RUNNER_TAG}" \ + --work /work + fi +fi +EOF + +bv_ssh() { ssh -F "${SSH_DIR}/ssh_config" graph-refine "$@"; } +tar -cJf - -C "${JOB_DIR}" . 
| bv_ssh "$(cat "${CI_TMP}/ci-receive")" diff --git a/ci/web/.gitignore b/ci/web/.gitignore new file mode 100644 index 00000000..a6d2bdfe --- /dev/null +++ b/ci/web/.gitignore @@ -0,0 +1,6 @@ +/dist/style.css +/node_modules/ + +# Useful for development +/dist/jobs/ +/dist/jobs.json diff --git a/ci/web/dist/index.html b/ci/web/dist/index.html new file mode 100644 index 00000000..5a129e47 --- /dev/null +++ b/ci/web/dist/index.html @@ -0,0 +1,17 @@ + + + + + + seL4 binary verification + + + + +
+

seL4 binary verification

+

Continuous testing with the seL4 binary verification toolchain

+
+
+ + diff --git a/ci/web/dist/main.js b/ci/web/dist/main.js new file mode 100644 index 00000000..648165f2 --- /dev/null +++ b/ci/web/dist/main.js @@ -0,0 +1,354 @@ +const jobs_index_sections = [ + { label: "running", title: "Running jobs" }, + { label: "waiting", title: "Waiting jobs" }, + { label: "finished", title: "Recently finished jobs" }, +]; + +const job_status_classification = new Map([ + ["WAITING", "waiting"], + ["RUNNING", "running"], + ["PASSED", "passed"], + ["UNDERSPECIFIED", "passed"], + ["COMPLEX_LOOP", "passed"], + ["IMPOSSIBLE", "passed"], + ["FAILED", "failed"], + ["NO_SPLIT", "failed"], + ["EXCEPT", "failed"], + ["TIMEOUT", "failed"], + ["MALFORMED", "failed"], + ["NO_RESULT", "failed"], + ["KILLED", "failed"], +]); + +const function_result_classification = new Map([ + ["PASSED", { group: "passed" }], + ["UNDERSPECIFIED", { group: "skipped", detail: "underspecified function" }], + ["COMPLEX_LOOP", { group: "skipped", detail: "complex loop" }], + ["IMPOSSIBLE", { group: "skipped", detail: "underspecified function" }], + ["FAILED", { group: "failed", detail: "failed refinement" }], + ["NO_SPLIT", { group: "failed", detail: "failed to split loop" }], + ["EXCEPT", { group: "failed", detail: "exception" }], + ["TIMEOUT", { group: "failed", detail: "timeout" }], + ["MALFORMED", { group: "failed", detail: "malformed report" }], + ["NO_RESULT", { group: "failed", detail: "no group found" }], + ["KILLED", { group: "failed", detail: "killed" }], +]); + +function on(f) { + return (a, b) => f(a) < f(b) ? -1 : f(b) < f(a) ? 1 : 0; +} + +function esc_html(s) { + return String(s).replaceAll('&', '&') + .replaceAll('<', '<') + .replaceAll('>', '>') + .replaceAll('"', '"'); +} + +async function fetch_json(url) { + return await fetch(url).then(r => r.json()); +} + +function zpad(n, len) { + return String(n).padStart(len, '0'); +} + +function date_time(date_str) { + const d = new Date(date_str); + return `${zpad(d.getFullYear(), 2)}-${zpad(d.getMonth() + 1, 2)}-${zpad(d.getDate(), 2)}` + + ` ${zpad(d.getHours(), 2)}:${zpad(d.getMinutes(), 2)}`; +} + +function time_diff_ms(t1, t2) { + const d1 = new Date(t1); + const d2 = (t2 === undefined || t2 === null) ? new Date() : new Date(t2); + return d2 - d1; +} + +function time_diff_str(t1, t2) { + const seconds = time_diff_ms(t1, t2) / 1000; + if (seconds < 90) { + return `${seconds.toFixed(0)}s`; + } + const minutes = seconds / 60; + if (minutes < 90) { + return `${minutes.toFixed(0)}m`; + } + const hours = minutes / 60; + if (hours < 36) { + return `${hours.toFixed(0)}h`; + } + const days = hours / 24; + return `${days.toFixed(0)}d`; +} + +function date_markup(title, date) { + return date === null ? '' : ` +
+
${title}
+
${date_time(date)}
+
+ `; +} + +function github_run(title, {repo, run}) { + return ` +
+
${title}
+ +
+ `; +} + +async function job_list_item(job_id) { + const [job_info, job_status] = await Promise.all([ + fetch_json(`jobs/${job_id}/job_info.json`), + fetch_json(`jobs/${job_id}/job_status.json`) + ]); + + const targets = Object.entries(job_status.targets).map(([target, status]) => { + const status_map = new Map(); + Object.entries(status).forEach(([status, count]) => { + const group = job_status_classification.get(status) ?? "failed"; + status_map.set(group, (status_map.get(group) ?? 0) + count); + }); + function render_status(s) { + return ` +
+
${s}
+
${status_map.get(s) ?? 0}
+
+ `; + } + return { target, content: ` +
+
+
target
+ +
+
+
+ ${render_status("waiting")} + ${render_status("running")} +
+
+ ${render_status("passed")} + ${render_status("failed")} +
+
+
+ `}; + }); + + targets.sort(on(s => s.target)); + + return ` +
+
+
+
job id
+
${esc_html(job_id.substring(0,10))}
+
+
+
+ ${date_markup("submitted", job_status.timestamps.time_job_submitted)} + ${date_markup("started", job_status.timestamps.time_job_started)} + ${date_markup("finished", job_status.timestamps.time_job_finished)} +
+
+ ${github_run("proof run", job_info.github.proof)} + ${github_run("decompile run", job_info.github.decompile)} +
+
+
+ ${targets.map(s => s.content).join('')} +
+ `; +} + +async function job_list({label, title, job_ids}) { + const job_items = await Promise.all(job_ids.map(job_list_item)); + + return job_ids.length === 0 ? '' : ` +
+

${title}

+ ${job_items.join('')} +
+ `; +} + +async function render_index() { + const jobs_json = await fetch_json('jobs.json'); + + const section_data = jobs_index_sections.flatMap( + ({label, title}) => ({label, title, job_ids: jobs_json[label].map(job => job.job_id)}) + ); + + const jobs_lists = await Promise.all(section_data.map(job_list)); + document.getElementById("content-container").innerHTML = jobs_lists.join(''); +} + + + +async function render_target(job_id, target) { + const [job_info, job_status, functions_status] = await Promise.all([ + fetch_json(`jobs/${job_id}/job_info.json`), + fetch_json(`jobs/${job_id}/job_status.json`), + fetch_json(`jobs/${job_id}/targets/${target}/functions.json`) + ]); + + const versions = job_info.targets[target].versions; + + const waiting_functions = []; + const running_functions = []; + + const passed_functions = []; + const failed_functions = []; + const skipped_functions = []; + + Object.entries(functions_status).forEach(([name, info]) => { + const runner_id = info.assignment?.runner_id; + const run_id = info.assignment?.run_id; + const run_info = info.runs?.[runner_id]?.[run_id]; + const started = run_info?.time_run_started; + const finished = run_info?.time_run_finished; + const result = run_info?.result; + + function fun_name_cell({indicator_colour}) { + const indicator = ``; + return indicator + esc_html(name); + } + + if (started === undefined) { + waiting_functions.push({name, content: ` + + ${fun_name_cell({indicator_colour: "bg-gray-500"})} + + `}); + return; + } + + function report_log_cell() { + const fun_url = `jobs/${esc_html(job_id)}/targets/${esc_html(target)}/functions/${esc_html(name)}/${esc_html(run_id)}`; + const link = s => `${s}`; + return `${link("report")} ${link("log")}`; + } + + function fun_row(config) { + return {name, content: ` + + ${fun_name_cell(config)} + ${config.detail === undefined ? '' : config.detail} + ${report_log_cell()} + ${time_diff_str(started, finished)} + + `}; + } + + if (result === null && finished === null) { + running_functions.push(fun_row({indicator_colour: "bg-gray-500"})); + return; + } + + const {group, detail} = function_result_classification.get(result) ?? { + group: "failed", detail: "internal error" + }; + + if (group === "passed") { + passed_functions.push(fun_row({indicator_colour: "bg-green-500"})); + return; + } + + if (group === "failed") { + failed_functions.push(fun_row({detail, indicator_colour: "bg-red-500"})); + return; + } + + if (group === "skipped") { + skipped_functions.push(fun_row({detail, indicator_colour: "bg-amber-500"})); + return; + } + }); + + waiting_functions.sort(on(f => f.name)); + running_functions.sort(on(f => f.name)); + passed_functions.sort(on(f => f.name)); + failed_functions.sort(on(f => f.name)); + skipped_functions.sort(on(f => f.name)); + + function function_count(arr) { + return arr.length === 1 ? "1 function" : `${arr.length} functions`; + } + + function render_functions(group, arr) { + return arr.length === 0 ? '' : ` +
+ ${function_count(arr)} ${group} +
+ + + ${arr.map(f => f.content).join('')} + +
+
+
+ `; + } + + document.getElementById("content-container").innerHTML = ` +
+

Target detail

+
+
+
+
job id
+
${esc_html(job_id.substring(0,10))}
+
+
+
target
+
${target}
+
+ ${!job_info?.github?.tag ? '' :` +
+
job tag
+
${esc_html(job_info.github.tag)}
+
+ `} +
+
+ ${date_markup("submitted", job_status.timestamps.time_job_submitted)} + ${date_markup("started", job_status.timestamps.time_job_started)} + ${date_markup("finished", job_status.timestamps.time_job_finished)} +
+
+ ${!job_info?.github?.proof ? '' : github_run("proof run", job_info.github.proof)} + ${!job_info?.github?.decompile ? '' : github_run("decompile run", job_info.github.decompile)} +
+
seL4 commit
+
${esc_html(versions["seL4"].substring(0,10))}
+
+
+
l4v commit
+
${esc_html(versions["l4v"].substring(0,10))}
+
+
+
+ ${render_functions("failed", failed_functions)} + ${render_functions("passed", passed_functions)} + ${render_functions("running", running_functions)} + ${render_functions("waiting", waiting_functions)} + ${render_functions("skipped", skipped_functions)} +
+ `; +} + +async function render_content() { + const params = new URLSearchParams(window.location.search); + const job_id = params.get("job_id"); + const target = params.get("target"); + await (job_id === null || target === null) ? render_index() : render_target(job_id, target); +} + +render_content(); diff --git a/ci/web/package-lock.json b/ci/web/package-lock.json new file mode 100644 index 00000000..b4f6ca26 --- /dev/null +++ b/ci/web/package-lock.json @@ -0,0 +1,848 @@ +{ + "name": "web", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "devDependencies": { + "@tailwindcss/typography": "^0.5.9", + "tailwindcss": "^3.2.7" + } + }, + "node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "dev": true, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "dev": true, + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/@tailwindcss/typography": { + "version": "0.5.9", + "resolved": "https://registry.npmjs.org/@tailwindcss/typography/-/typography-0.5.9.tgz", + "integrity": "sha512-t8Sg3DyynFysV9f4JDOVISGsjazNb48AeIYQwcL+Bsq5uf4RYL75C1giZ43KISjeDGBaTN3Kxh7Xj/vRSMJUUg==", + "dev": true, + "dependencies": { + "lodash.castarray": "^4.4.0", + "lodash.isplainobject": "^4.0.6", + "lodash.merge": "^4.6.2", + "postcss-selector-parser": "6.0.10" + }, + "peerDependencies": { + "tailwindcss": ">=3.0.0 || insiders" + } + }, + "node_modules/@tailwindcss/typography/node_modules/postcss-selector-parser": { + "version": "6.0.10", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz", + "integrity": "sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w==", + "dev": true, + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/acorn": { + "version": "7.4.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-7.4.1.tgz", + "integrity": "sha512-nQyp0o1/mNdbTO1PO6kHkwSrmgZ0MT/jCCpNiwbUjGoRN4dlBhqJtoQuCnEOKzgTVwg0ZWiCoQy6SxMebQVh8A==", + "dev": true, + "bin": { + "acorn": "bin/acorn" + }, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/acorn-node": { + "version": "1.8.2", + "resolved": "https://registry.npmjs.org/acorn-node/-/acorn-node-1.8.2.tgz", + "integrity": "sha512-8mt+fslDufLYntIoPAaIMUe/lrbrehIiwmR3t2k9LljIzoigEPF27eLk2hy8zSGzmR/ogr7zbRKINMo1u0yh5A==", + "dev": true, + "dependencies": { + "acorn": "^7.0.0", + "acorn-walk": "^7.0.0", + "xtend": "^4.0.2" + } + }, + "node_modules/acorn-walk": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-7.2.0.tgz", + 
"integrity": "sha512-OPdCF6GsMIP+Az+aWfAAOEt2/+iVDKE7oy6lJ098aoe59oAmK76qV6Gw60SbZ8jHuG2wH058GF4pLFbYamYrVA==", + "dev": true, + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "dev": true, + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/arg": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", + "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", + "dev": true + }, + "node_modules/binary-extensions": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", + "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", + "dev": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/braces": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", + "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "dev": true, + "dependencies": { + "fill-range": "^7.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/camelcase-css": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", + "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/chokidar": { + "version": "3.5.3", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", + "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://paulmillr.com/funding/" + } + ], + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chokidar/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "dev": true + }, + "node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "dev": true, + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/defined": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/defined/-/defined-1.0.1.tgz", + "integrity": 
"sha512-hsBd2qSVCRE+5PmNdHt1uzyrFu5d3RwmFDKzyNZMFq/EwDNJF7Ee5+D5oEKF0hU6LhtoUF1macFvOe4AskQC1Q==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/detective": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/detective/-/detective-5.2.1.tgz", + "integrity": "sha512-v9XE1zRnz1wRtgurGu0Bs8uHKFSTdteYZNbIPFVhUZ39L/S79ppMpdmVOZAnoz1jfEFodc48n6MX483Xo3t1yw==", + "dev": true, + "dependencies": { + "acorn-node": "^1.8.2", + "defined": "^1.0.0", + "minimist": "^1.2.6" + }, + "bin": { + "detective": "bin/detective.js" + }, + "engines": { + "node": ">=0.8.0" + } + }, + "node_modules/didyoumean": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", + "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", + "dev": true + }, + "node_modules/dlv": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", + "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==", + "dev": true + }, + "node_modules/fast-glob": { + "version": "3.2.12", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.2.12.tgz", + "integrity": "sha512-DVj4CQIYYow0BlaelwK1pHl5n5cRSJfM60UA0zK891sVInoPri2Ekj7+e1CT3/3qxXenpI+nBBmQAcJPJgaj4w==", + "dev": true, + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.4" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "dev": true, + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fastq": { + "version": "1.15.0", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.15.0.tgz", + "integrity": "sha512-wBrocU2LCXXa+lWBt8RoIRD89Fi8OdABODa/kEnyeyjS5aZO5/GNvI5sEINADqP/h8M29UHTHUb53sUu5Ihqdw==", + "dev": true, + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/fill-range": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", + "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "dev": true, + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/function-bind": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", + "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "dev": true + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + 
"dev": true, + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/has": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", + "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.1" + }, + "engines": { + "node": ">= 0.4.0" + } + }, + "node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "dev": true, + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/is-core-module": { + "version": "2.11.0", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.11.0.tgz", + "integrity": "sha512-RRjxlvLDkD1YJwDbroBHMb+cukurkDWNyHx7D3oNB5x9rb5ogcksMC5wHCadcXoo67gVr/+3GFySh3134zi6rw==", + "dev": true, + "dependencies": { + "has": "^1.0.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "dev": true, + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "dev": true, + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/lilconfig": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-2.1.0.tgz", + "integrity": "sha512-utWOt/GHzuUxnLKxB6dk81RoOeoNeHgbrXiuGk4yyF5qlRz+iIVWu56E2fqGHFrXz0QNUhLB/8nKqvRH66JKGQ==", + "dev": true, + "engines": { + "node": ">=10" + } + }, + "node_modules/lodash.castarray": { + "version": "4.4.0", + "resolved": "https://registry.npmjs.org/lodash.castarray/-/lodash.castarray-4.4.0.tgz", + "integrity": "sha512-aVx8ztPv7/2ULbArGJ2Y42bG1mEQ5mGjpdvrbJcJFU3TbYybe+QlLS4pst9zV52ymy2in1KpFPiZnAOATxD4+Q==", + "dev": true + }, + "node_modules/lodash.isplainobject": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/lodash.isplainobject/-/lodash.isplainobject-4.0.6.tgz", + "integrity": "sha512-oSXzaWypCMHkPC3NvBEaPHf0KsA5mvPrOPgQWDsbg8n7orZ290M0BmC/jgRZ4vcJ6DTAhjrsSYgdsW/F+MFOBA==", + "dev": true + }, + "node_modules/lodash.merge": { + "version": "4.6.2", + "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", + "integrity": "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==", + "dev": true + }, + "node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "dev": true, + 
"engines": { + "node": ">= 8" + } + }, + "node_modules/micromatch": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.5.tgz", + "integrity": "sha512-DMy+ERcEW2q8Z2Po+WNXuw3c5YaUSFjAO5GsJqfEl7UjvtIuFKO6ZrKvcItdy98dwFI2N1tg3zNIdKaQT+aNdA==", + "dev": true, + "dependencies": { + "braces": "^3.0.2", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/minimist": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", + "integrity": "sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/nanoid": { + "version": "3.3.4", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.4.tgz", + "integrity": "sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw==", + "dev": true, + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-hash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", + "integrity": "sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", + "dev": true, + "engines": { + "node": ">= 6" + } + }, + "node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "dev": true + }, + "node_modules/picocolors": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.0.0.tgz", + "integrity": "sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==", + "dev": true + }, + "node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "dev": true, + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postcss": { + "version": "8.4.21", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.21.tgz", + "integrity": "sha512-tP7u/Sn/dVxK2NnruI4H9BG+x+Wxz6oeZ1cJ8P6G/PZY0IKk4k/63TDsQf2kQq3+qoJeLm2kIBUNlZe3zgb4Zg==", + "dev": true, + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + } + ], + "dependencies": { + "nanoid": "^3.3.4", + "picocolors": "^1.0.0", + "source-map-js": "^1.0.2" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-import": { + 
"version": "14.1.0", + "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-14.1.0.tgz", + "integrity": "sha512-flwI+Vgm4SElObFVPpTIT7SU7R3qk2L7PyduMcokiaVKuWv9d/U+Gm/QAd8NDLuykTWTkcrjOeD2Pp1rMeBTGw==", + "dev": true, + "dependencies": { + "postcss-value-parser": "^4.0.0", + "read-cache": "^1.0.0", + "resolve": "^1.1.7" + }, + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "postcss": "^8.0.0" + } + }, + "node_modules/postcss-js": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.0.1.tgz", + "integrity": "sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==", + "dev": true, + "dependencies": { + "camelcase-css": "^2.0.1" + }, + "engines": { + "node": "^12 || ^14 || >= 16" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.4.21" + } + }, + "node_modules/postcss-load-config": { + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-3.1.4.tgz", + "integrity": "sha512-6DiM4E7v4coTE4uzA8U//WhtPwyhiim3eyjEMFCnUpzbrkK9wJHgKDT2mR+HbtSrd/NubVaYTOpSpjUl8NQeRg==", + "dev": true, + "dependencies": { + "lilconfig": "^2.0.5", + "yaml": "^1.10.2" + }, + "engines": { + "node": ">= 10" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": ">=8.0.9", + "ts-node": ">=9.0.0" + }, + "peerDependenciesMeta": { + "postcss": { + "optional": true + }, + "ts-node": { + "optional": true + } + } + }, + "node_modules/postcss-nested": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.0.0.tgz", + "integrity": "sha512-0DkamqrPcmkBDsLn+vQDIrtkSbNkv5AD/M322ySo9kqFkCIYklym2xEmWkwo+Y3/qZo34tzEPNUw4y7yMCdv5w==", + "dev": true, + "dependencies": { + "postcss-selector-parser": "^6.0.10" + }, + "engines": { + "node": ">=12.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.2.14" + } + }, + "node_modules/postcss-selector-parser": { + "version": "6.0.11", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.11.tgz", + "integrity": "sha512-zbARubNdogI9j7WY4nQJBiNqQf3sLS3wCP4WfOidu+p28LofJqDH1tcXypGrcmMHhDk2t9wGhCsYe/+szLTy1g==", + "dev": true, + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", + "dev": true + }, + "node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ] + }, + "node_modules/quick-lru": { + "version": "5.1.1", + "resolved": "https://registry.npmjs.org/quick-lru/-/quick-lru-5.1.1.tgz", + "integrity": 
"sha512-WuyALRjWPDGtt/wzJiadO5AXY+8hZ80hVpe6MyivgraREW751X3SbhRvG3eLKOYN+8VEvqLcf3wdnt44Z4S4SA==", + "dev": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/read-cache": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", + "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", + "dev": true, + "dependencies": { + "pify": "^2.3.0" + } + }, + "node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "dev": true, + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/resolve": { + "version": "1.22.1", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.1.tgz", + "integrity": "sha512-nBpuuYuY5jFsli/JIs1oldw6fOQCBioohqWZg/2hiaOybXOft4lonv85uDOKXdf8rhyK159cxU5cDcK/NKk8zw==", + "dev": true, + "dependencies": { + "is-core-module": "^2.9.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "dev": true, + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "dev": true, + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/source-map-js": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", + "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "dev": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "dev": true, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/tailwindcss": { + "version": "3.2.7", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.2.7.tgz", + "integrity": "sha512-B6DLqJzc21x7wntlH/GsZwEXTBttVSl1FtCzC8WP4oBc/NKef7kaax5jeihkkCEWc831/5NDJ9gRNDK6NEioQQ==", + "dev": true, + "dependencies": { + "arg": "^5.0.2", + "chokidar": "^3.5.3", + "color-name": "^1.1.4", + "detective": "^5.2.1", + "didyoumean": "^1.2.2", + "dlv": "^1.1.3", + "fast-glob": "^3.2.12", + "glob-parent": "^6.0.2", + "is-glob": "^4.0.3", + "lilconfig": "^2.0.6", + "micromatch": "^4.0.5", + 
"normalize-path": "^3.0.0", + "object-hash": "^3.0.0", + "picocolors": "^1.0.0", + "postcss": "^8.0.9", + "postcss-import": "^14.1.0", + "postcss-js": "^4.0.0", + "postcss-load-config": "^3.1.4", + "postcss-nested": "6.0.0", + "postcss-selector-parser": "^6.0.11", + "postcss-value-parser": "^4.2.0", + "quick-lru": "^5.1.1", + "resolve": "^1.22.1" + }, + "bin": { + "tailwind": "lib/cli.js", + "tailwindcss": "lib/cli.js" + }, + "engines": { + "node": ">=12.13.0" + }, + "peerDependencies": { + "postcss": "^8.0.9" + } + }, + "node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "dev": true, + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "dev": true + }, + "node_modules/xtend": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz", + "integrity": "sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==", + "dev": true, + "engines": { + "node": ">=0.4" + } + }, + "node_modules/yaml": { + "version": "1.10.2", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-1.10.2.tgz", + "integrity": "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==", + "dev": true, + "engines": { + "node": ">= 6" + } + } + } +} diff --git a/ci/web/package.json b/ci/web/package.json new file mode 100644 index 00000000..c5ffc265 --- /dev/null +++ b/ci/web/package.json @@ -0,0 +1,6 @@ +{ + "devDependencies": { + "@tailwindcss/typography": "^0.5.9", + "tailwindcss": "^3.2.7" + } +} diff --git a/ci/web/src/style.css b/ci/web/src/style.css new file mode 100644 index 00000000..bd6213e1 --- /dev/null +++ b/ci/web/src/style.css @@ -0,0 +1,3 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; \ No newline at end of file diff --git a/ci/web/tailwind.config.js b/ci/web/tailwind.config.js new file mode 100644 index 00000000..c7a75c9b --- /dev/null +++ b/ci/web/tailwind.config.js @@ -0,0 +1,14 @@ +module.exports = { + content: { + relative: true, + files: [ + './dist/**/*.{html,js}', + ], + }, + theme: { + extend: {}, + }, + plugins: [ + require('@tailwindcss/typography'), + ], +} diff --git a/debug.py b/debug.py index 22636bf2..69f9e605 100644 --- a/debug.py +++ b/debug.py @@ -18,849 +18,849 @@ import random def check_entry_var_deps (f): - if not f.entry: - return set () - p = f.as_problem (Problem) - diff = check_problem_entry_var_deps (p) + if not f.entry: + return set () + p = f.as_problem (Problem) + diff = check_problem_entry_var_deps (p) - return diff + return diff def check_problem_entry_var_deps (p, var_deps = None): - if var_deps == None: - var_deps = p.compute_var_dependencies () - for (entry, tag, _, inputs) in p.entries: - if entry not in var_deps: - print 'Entry missing from var_deps: %d' % entry - continue - diff = set (var_deps[entry]) - set (inputs) - if diff: - print 'Vars deps escaped in %s in %s: %s' % (tag, - p.name, diff) - return diff - return set () + if var_deps == None: + var_deps = p.compute_var_dependencies () + for (entry, tag, _, inputs) in p.entries: + if entry not in 
var_deps: + print 'Entry missing from var_deps: %d' % entry + continue + diff = set (var_deps[entry]) - set (inputs) + if diff: + print 'Vars deps escaped in %s in %s: %s' % (tag, + p.name, diff) + return diff + return set () def check_all_var_deps (): - return [f for f in functions if check_entry_var_deps(functions[f])] + return [f for f in functions if check_entry_var_deps(functions[f])] def walk_var_deps (p, n, v, var_deps = None, - interest = set (), symmetric = False): - if var_deps == None: - var_deps = p.compute_var_dependencies () - while True: - if n == 'Ret' or n == 'Err': - print n - return n - if symmetric: - opts = set ([n2 for n2 in p.preds[n] if n2 in p.nodes]) - else: - opts = set ([n2 for n2 in p.nodes[n].get_conts () - if n2 in p.nodes]) - choices = [n2 for n2 in opts if v in var_deps[n2]] - if not choices: - print 'Walk ends at %d.' % n - return - if len (choices) > 1: - print 'choices %s, gambling' % choices - random.shuffle (choices) - print ' ... rolled a %s' % choices[0] - elif len (opts) > 1: - print 'picked %s from %s' % (choices[0], opts) - n = choices[0] - if n in interest: - print '** %d' % n - else: - print n + interest = set (), symmetric = False): + if var_deps == None: + var_deps = p.compute_var_dependencies () + while True: + if n == 'Ret' or n == 'Err': + print n + return n + if symmetric: + opts = set ([n2 for n2 in p.preds[n] if n2 in p.nodes]) + else: + opts = set ([n2 for n2 in p.nodes[n].get_conts () + if n2 in p.nodes]) + choices = [n2 for n2 in opts if v in var_deps[n2]] + if not choices: + print 'Walk ends at %d.' % n + return + if len (choices) > 1: + print 'choices %s, gambling' % choices + random.shuffle (choices) + print ' ... rolled a %s' % choices[0] + elif len (opts) > 1: + print 'picked %s from %s' % (choices[0], opts) + n = choices[0] + if n in interest: + print '** %d' % n + else: + print n def diagram_var_deps (p, fname, v, var_deps = None): - if var_deps == None: - var_deps = p.compute_var_dependencies () - cols = {} - for n in p.nodes: - if n not in var_deps: - cols[n] = 'darkgrey' - elif v not in var_deps[n]: - cols[n] = 'darkblue' - else: - cols[n] = 'orange' - problem.save_graph (p.nodes, fname, cols = cols) + if var_deps == None: + var_deps = p.compute_var_dependencies () + cols = {} + for n in p.nodes: + if n not in var_deps: + cols[n] = 'darkgrey' + elif v not in var_deps[n]: + cols[n] = 'darkblue' + else: + cols[n] = 'orange' + problem.save_graph (p.nodes, fname, cols = cols) def trace_model (rep, m, simplify = True): - p = rep.p - tags = set ([tag for (tag, n, vc) in rep.node_pc_env_order]) - if p.pairing and tags == set (p.pairing.tags): - tags = reversed (p.pairing.tags) - for tag in tags: - print "Walking %s in model" % tag - n_vcs = walk_model (rep, tag, m) - prev_era = None - for (i, (n, vc)) in enumerate (n_vcs): - era = n_vc_era (p, (n, vc)) - if era != prev_era: - print 'now in era %s' % era - prev_era = era - if n in ['Ret', 'Err']: - print 'ends at %s' % n - break - node = logic.simplify_node_elementary (p.nodes[n]) - if node.kind != 'Cond': - continue - name = rep.cond_name ((n, vc)) - cond = m[name] == syntax.true_term - print '%s: %s (%s, %s)' % (name, cond, - node.left, node.right) - investigate_cond (rep, m, name, simplify) + p = rep.p + tags = set ([tag for (tag, n, vc) in rep.node_pc_env_order]) + if p.pairing and tags == set (p.pairing.tags): + tags = reversed (p.pairing.tags) + for tag in tags: + print "Walking %s in model" % tag + n_vcs = walk_model (rep, tag, m) + prev_era = None + for (i, (n, vc)) in 
enumerate (n_vcs): + era = n_vc_era (p, (n, vc)) + if era != prev_era: + print 'now in era %s' % era + prev_era = era + if n in ['Ret', 'Err']: + print 'ends at %s' % n + break + node = logic.simplify_node_elementary (p.nodes[n]) + if node.kind != 'Cond': + continue + name = rep.cond_name ((n, vc)) + cond = m[name] == syntax.true_term + print '%s: %s (%s, %s)' % (name, cond, + node.left, node.right) + investigate_cond (rep, m, name, simplify) def walk_model (rep, tag, m): - n_vcs = [(n, vc) for (tag2, n, vc) in rep.node_pc_env_order - if tag2 == tag - if search.eval_model_expr (m, rep.solv, - rep.get_pc ((n, vc), tag)) - == syntax.true_term] + n_vcs = [(n, vc) for (tag2, n, vc) in rep.node_pc_env_order + if tag2 == tag + if search.eval_model_expr (m, rep.solv, + rep.get_pc ((n, vc), tag)) + == syntax.true_term] - n_vcs = era_sort (rep, n_vcs) + n_vcs = era_sort (rep, n_vcs) - return n_vcs + return n_vcs def investigate_cond (rep, m, cond, simplify = True, rec = True): - cond_def = rep.solv.defs[cond] - while rec and type (cond_def) == str and cond_def in rep.solv.defs: - cond_def = rep.solv.defs[cond_def] - def do_bit (bit): - if bit == 'true': - return True - valid = eval_model_bool (m, bit) - if simplify: - # looks a bit strange to do this now but some pointer - # lookups have to be done with unmodified s-exprs - bit = simplify_sexp (bit, rep, m, flatten = False) - print ' %s: %s' % (valid, solver.flat_s_expression (bit)) - return valid - while cond_def[0] == '=>': - valid = do_bit (cond_def[1]) - if not valid: - break - cond_def = cond_def[2] - bits = solver.split_hyp_sexpr (cond_def, []) - for bit in bits: - do_bit (bit) + cond_def = rep.solv.defs[cond] + while rec and type (cond_def) == str and cond_def in rep.solv.defs: + cond_def = rep.solv.defs[cond_def] + def do_bit (bit): + if bit == 'true': + return True + valid = eval_model_bool (m, bit) + if simplify: + # looks a bit strange to do this now but some pointer + # lookups have to be done with unmodified s-exprs + bit = simplify_sexp (bit, rep, m, flatten = False) + print ' %s: %s' % (valid, solver.flat_s_expression (bit)) + return valid + while cond_def[0] == '=>': + valid = do_bit (cond_def[1]) + if not valid: + break + cond_def = cond_def[2] + bits = solver.split_hyp_sexpr (cond_def, []) + for bit in bits: + do_bit (bit) def eval_model_bool (m, x): - if hasattr (x, 'typ'): - x = solver.smt_expr (x, {}, None) - x = solver.parse_s_expression (x) - try: - r = search.eval_model (m, x) - assert r in [syntax.true_term, syntax.false_term], r - return r == syntax.true_term - except: - return 'EXCEPT' + if hasattr (x, 'typ'): + x = solver.smt_expr (x, {}, None) + x = solver.parse_s_expression (x) + try: + r = search.eval_model (m, x) + assert r in [syntax.true_term, syntax.false_term], r + return r == syntax.true_term + except: + return 'EXCEPT' def funcall_name (rep): - return lambda n_vc: "%s @%s" % (rep.p.nodes[n_vc[0]].fname, - rep.node_count_name (n_vc)) + return lambda n_vc: "%s @%s" % (rep.p.nodes[n_vc[0]].fname, + rep.node_count_name (n_vc)) def n_vc_era (p, (n, vc)): - era = 0 - for (split, vcount) in vc: - if not p.loop_id (split): - continue - (ns, os) = vcount.get_opts () - if len (ns + os) > 1: - era += 3 - elif ns: - era += 1 - elif os: - era += 2 - return era + era = 0 + for (split, vcount) in vc: + if not p.loop_id (split): + continue + (ns, os) = vcount.get_opts () + if len (ns + os) > 1: + era += 3 + elif ns: + era += 1 + elif os: + era += 2 + return era def era_merge (era): - # fold onramp to loops into pre-loop era - 
if era % 3 == 1: - era -= 1 - return era + # fold onramp to loops into pre-loop era + if era % 3 == 1: + era -= 1 + return era def do_era_merge (do_merge, era): - if do_merge: - return era_merge (era) - else: - return era + if do_merge: + return era_merge (era) + else: + return era def era_sort (rep, n_vcs): - with_eras = [(n_vc_era (rep.p, n_vc), n_vc) for n_vc in n_vcs] - with_eras.sort (key = lambda x: x[0]) - for i in range (len (with_eras) - 1): - (e1, n_vc1) = with_eras[i] - (e2, n_vc2) = with_eras[i + 1] - if e1 != e2: - continue - if n_vc1[0] in ['Ret', 'Err']: - assert not 'Era issues', n_vcs - assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2] - return [n_vc for (_, n_vc) in with_eras] + with_eras = [(n_vc_era (rep.p, n_vc), n_vc) for n_vc in n_vcs] + with_eras.sort (key = lambda x: x[0]) + for i in range (len (with_eras) - 1): + (e1, n_vc1) = with_eras[i] + (e2, n_vc2) = with_eras[i + 1] + if e1 != e2: + continue + if n_vc1[0] in ['Ret', 'Err']: + assert not 'Era issues', n_vcs + assert rep.is_cont (n_vc1, n_vc2), [n_vc1, n_vc2] + return [n_vc for (_, n_vc) in with_eras] def investigate_funcalls (rep, m, verbose = False, verbose_imp = False, - simplify = True, pairing = 'Args', era_merge = True): - l_tag, r_tag = rep.p.pairing.tags - l_ns = walk_model (rep, l_tag, m) - r_ns = walk_model (rep, r_tag, m) - nodes = rep.p.nodes - - l_calls = [n_vc for n_vc in l_ns if n_vc in rep.funcs] - r_calls = [n_vc for n_vc in r_ns if n_vc in rep.funcs] - print '%s calls: %s' % (l_tag, map (funcall_name (rep), l_calls)) - print '%s calls: %s' % (r_tag, map (funcall_name (rep), r_calls)) - - if pairing == 'Eras': - fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls, - era_m = era_merge) - elif pairing == 'Seq': - fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls) - elif pairing == 'Args': - fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls, - era_m = era_merge) - elif pairing == 'All': - fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls] - else: - assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing - - for (l_n_vc, r_n_vc) in fc_pairs: - if not rep.get_func_pairing (l_n_vc, r_n_vc): - print 'call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc) - continue - investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, - verbose, verbose_imp, simplify) + simplify = True, pairing = 'Args', era_merge = True): + l_tag, r_tag = rep.p.pairing.tags + l_ns = walk_model (rep, l_tag, m) + r_ns = walk_model (rep, r_tag, m) + nodes = rep.p.nodes + + l_calls = [n_vc for n_vc in l_ns if n_vc in rep.funcs] + r_calls = [n_vc for n_vc in r_ns if n_vc in rep.funcs] + print '%s calls: %s' % (l_tag, map (funcall_name (rep), l_calls)) + print '%s calls: %s' % (r_tag, map (funcall_name (rep), r_calls)) + + if pairing == 'Eras': + fc_pairs = pair_funcalls_by_era (rep, l_calls, r_calls, + era_m = era_merge) + elif pairing == 'Seq': + fc_pairs = pair_funcalls_sequential (rep, l_calls, r_calls) + elif pairing == 'Args': + fc_pairs = pair_funcalls_by_match (rep, m, l_calls, r_calls, + era_m = era_merge) + elif pairing == 'All': + fc_pairs = [(lc, rc) for lc in l_calls for rc in r_calls] + else: + assert pairing in ['Eras', 'Seq', 'Args', 'All'], pairing + + for (l_n_vc, r_n_vc) in fc_pairs: + if not rep.get_func_pairing (l_n_vc, r_n_vc): + print 'call seq mismatch: (%s, %s)' % (l_n_vc, r_n_vc) + continue + investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, + verbose, verbose_imp, simplify) def pair_funcalls_by_era (rep, l_calls, r_calls, era_m = True): - eras = set ([n_vc_era (rep.p, n_vc) for n_vc in 
l_calls + r_calls]) - eras = sorted (eras + set (map (era_merge, eras))) - pairs = [] - for era in eras: - ls = [n_vc for n_vc in l_calls - if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] - rs = [n_vc for n_vc in r_calls - if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] - if len (ls) != len (rs): - print 'call seq length mismatch in era %d:' % era - print map (funcall_name (rep), ls) - print map (funcall_name (rep), rs) - pairs.extend (zip (ls, rs)) - return pairs + eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) + eras = sorted (eras + set (map (era_merge, eras))) + pairs = [] + for era in eras: + ls = [n_vc for n_vc in l_calls + if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] + rs = [n_vc for n_vc in r_calls + if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] + if len (ls) != len (rs): + print 'call seq length mismatch in era %d:' % era + print map (funcall_name (rep), ls) + print map (funcall_name (rep), rs) + pairs.extend (zip (ls, rs)) + return pairs def pair_funcalls_sequential (rep, l_calls, r_calls): - if len (l_calls) != len (r_calls): - print 'call seq tail mismatch' - if len (l_calls) > len (r_calls): - print 'dropping lhs: %s' % map (funcall_name (rep), - l_calls[len (r_calls):]) - else: - print 'dropping rhs: %s' % map (funcall_name (rep), - r_calls[len (l_calls):]) - # really should add some smarts to this to 'recover' from upsets or - # reorders, but maybe not worth it. - return zip (l_calls, r_calls) + if len (l_calls) != len (r_calls): + print 'call seq tail mismatch' + if len (l_calls) > len (r_calls): + print 'dropping lhs: %s' % map (funcall_name (rep), + l_calls[len (r_calls):]) + else: + print 'dropping rhs: %s' % map (funcall_name (rep), + r_calls[len (l_calls):]) + # really should add some smarts to this to 'recover' from upsets or + # reorders, but maybe not worth it. 
+ return zip (l_calls, r_calls) def pair_funcalls_by_match (rep, m, l_calls, r_calls, era_m = True): - eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) - eras = sorted (set.union (eras, set (map (era_merge, eras)))) - pairs = [] - for era in eras: - ls = [n_vc for n_vc in l_calls - if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] - rs = [n_vc for n_vc in r_calls - if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] - res = None - matches = [(1 - func_assert_premise_strength (rep, m, - n_vc, n_vc2), i, j) - for (i, n_vc) in enumerate (ls) - for (j, n_vc2) in enumerate (rs) - if rep.get_func_pairing (n_vc, n_vc2)] - matches.sort () - if not matches: - print 'Cannot match any (%d, %d) at era %d' % (len (ls), - len (rs), era) - continue - (_, i, j) = matches[0] - if i > j: - pairs.extend ((zip (ls[i - j:], rs))) - else: - pairs.extend ((zip (ls, rs[j - i:]))) - return pairs + eras = set ([n_vc_era (rep.p, n_vc) for n_vc in l_calls + r_calls]) + eras = sorted (set.union (eras, set (map (era_merge, eras)))) + pairs = [] + for era in eras: + ls = [n_vc for n_vc in l_calls + if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] + rs = [n_vc for n_vc in r_calls + if do_era_merge (era_m, n_vc_era (rep.p, n_vc)) == era] + res = None + matches = [(1 - func_assert_premise_strength (rep, m, + n_vc, n_vc2), i, j) + for (i, n_vc) in enumerate (ls) + for (j, n_vc2) in enumerate (rs) + if rep.get_func_pairing (n_vc, n_vc2)] + matches.sort () + if not matches: + print 'Cannot match any (%d, %d) at era %d' % (len (ls), + len (rs), era) + continue + (_, i, j) = matches[0] + if i > j: + pairs.extend ((zip (ls[i - j:], rs))) + else: + pairs.extend ((zip (ls, rs[j - i:]))) + return pairs def func_assert_premise_strength (rep, m, l_n_vc, r_n_vc): - imp = rep.get_func_assert (l_n_vc, r_n_vc) - assert imp.is_op ('Implies'), imp - [pred, concl] = imp.vals - pred = solver.smt_expr (pred, {}, rep.solv) - pred = solver.parse_s_expression (pred) - bits = solver.split_hyp_sexpr (pred, []) - assert bits, bits - scores = [] - for bit in bits: - try: - res = eval_model_bool (m, bit) - if res: - scores.append (1.0) - else: - scores.append (0.0) - except solver.EnvMiss, e: - scores.append (0.5) - except AssertionError, e: - scores.append (0.5) - return sum (scores) / len (scores) - return all ([eval_model_bool (m, v) for v in bits]) + imp = rep.get_func_assert (l_n_vc, r_n_vc) + assert imp.is_op ('Implies'), imp + [pred, concl] = imp.vals + pred = solver.smt_expr (pred, {}, rep.solv) + pred = solver.parse_s_expression (pred) + bits = solver.split_hyp_sexpr (pred, []) + assert bits, bits + scores = [] + for bit in bits: + try: + res = eval_model_bool (m, bit) + if res: + scores.append (1.0) + else: + scores.append (0.0) + except solver.EnvMiss, e: + scores.append (0.5) + except AssertionError, e: + scores.append (0.5) + return sum (scores) / len (scores) + return all ([eval_model_bool (m, v) for v in bits]) def investigate_funcall_pair (rep, m, l_n_vc, r_n_vc, - verbose = False, verbose_imp = False, simplify = True): - - l_nm = "%s @ %s" % (rep.p.nodes[l_n_vc[0]].fname, rep.node_count_name (l_n_vc)) - r_nm = "%s @ %s" % (rep.p.nodes[r_n_vc[0]].fname, rep.node_count_name (r_n_vc)) - print 'Attempt match %s -> %s' % (l_nm, r_nm) - imp = rep.get_func_assert (l_n_vc, r_n_vc) - imp = logic.weaken_assert (imp) - if verbose_imp: - imp2 = solver.smt_expr (imp, {}, rep.solv) - if simplify: - imp2 = simplify_sexp (imp2, rep, m) - print imp2 - assert imp.is_op ('Implies'), imp - [pred, concl] = imp.vals - 
pred = solver.smt_expr (pred, {}, rep.solv) - pred = solver.parse_s_expression (pred) - bits = solver.split_hyp_sexpr (pred, []) - xs = [eval_model_bool (m, v) for v in bits] - print ' %s' % xs - for (v, bit) in zip (xs, bits): - if v != True or verbose: - print ' %s: %s' % (v, bit) - if bit[0] == 'word32-eq': - vs = [model_sx_word (m, x) - for x in bit[1:]] - print ' (%s = %s)' % tuple (vs) + verbose = False, verbose_imp = False, simplify = True): + + l_nm = "%s @ %s" % (rep.p.nodes[l_n_vc[0]].fname, rep.node_count_name (l_n_vc)) + r_nm = "%s @ %s" % (rep.p.nodes[r_n_vc[0]].fname, rep.node_count_name (r_n_vc)) + print 'Attempt match %s -> %s' % (l_nm, r_nm) + imp = rep.get_func_assert (l_n_vc, r_n_vc) + imp = logic.weaken_assert (imp) + if verbose_imp: + imp2 = solver.smt_expr (imp, {}, rep.solv) + if simplify: + imp2 = simplify_sexp (imp2, rep, m) + print imp2 + assert imp.is_op ('Implies'), imp + [pred, concl] = imp.vals + pred = solver.smt_expr (pred, {}, rep.solv) + pred = solver.parse_s_expression (pred) + bits = solver.split_hyp_sexpr (pred, []) + xs = [eval_model_bool (m, v) for v in bits] + print ' %s' % xs + for (v, bit) in zip (xs, bits): + if v != True or verbose: + print ' %s: %s' % (v, bit) + if bit[0] == 'word32-eq': + vs = [model_sx_word (m, x) + for x in bit[1:]] + print ' (%s = %s)' % tuple (vs) def model_sx_word (m, sx): - v = search.eval_model (m, sx) - x = expr_num (v) - return solver.smt_num_t (x, v.typ) + v = search.eval_model (m, sx) + x = expr_num (v) + return solver.smt_num_t (x, v.typ) def expr_num (expr): - assert expr.typ.kind == 'Word' - return expr.val & ((1 << expr.typ.num) - 1) + assert expr.typ.kind == 'Word' + return expr.val & ((1 << expr.typ.num) - 1) def str_to_num (smt_str): - v = solver.smt_to_val(smt_str) - return expr_num (v) + v = solver.smt_to_val(smt_str) + return expr_num (v) def m_var_name (expr): - while expr.is_op ('MemUpdate'): - [expr, p, v] = expr.vals - if expr.kind == 'Var': - return expr.name - elif expr.kind == 'Op': - return '' % op.name - else: - return '' % expr.kind + while expr.is_op ('MemUpdate'): + [expr, p, v] = expr.vals + if expr.kind == 'Var': + return expr.name + elif expr.kind == 'Op': + return '' % op.name + else: + return '' % expr.kind def eval_str (expr, env, solv, m): - expr = solver.to_smt_expr (expr, env, solv) - v = search.eval_model_expr (m, solv, expr) - if v.typ == syntax.boolT: - assert v in [syntax.true_term, syntax.false_term] - return v.name - elif v.typ.kind == 'Word': - return solver.smt_num_t (v.val, v.typ) - else: - assert not 'type printable', v + expr = solver.to_smt_expr (expr, env, solv) + v = search.eval_model_expr (m, solv, expr) + if v.typ == syntax.boolT: + assert v in [syntax.true_term, syntax.false_term] + return v.name + elif v.typ.kind == 'Word': + return solver.smt_num_t (v.val, v.typ) + else: + assert not 'type printable', v def trace_mem (rep, tag, m, verbose = False, simplify = True, symbs = True, - resolve_addrs = False): - p = rep.p - ns = walk_model (rep, tag, m) - trace = [] - for (n, vc) in ns: - if (n, vc) not in rep.arc_pc_envs: - # this n_vc has a pre-state, but has not been emitted. - # no point trying to evaluate its expressions, the - # solve won't have seen them yet. 
- continue - n_nm = rep.node_count_name ((n, vc)) - node = p.nodes[n] - if node.kind == 'Call': - exprs = list (node.args) - elif node.kind == 'Basic': - exprs = [expr for (_, expr) in node.upds] - elif node.kind == 'Cond': - exprs = [node.cond] - env = rep.node_pc_envs[(tag, n, vc)][1] - accs = list (set ([acc for expr in exprs - for acc in expr.get_mem_accesses ()])) - for (kind, addr, v, mem) in accs: - addr_s = solver.smt_expr (addr, env, rep.solv) - v_s = solver.smt_expr (v, env, rep.solv) - addr = eval_str (addr, env, rep.solv, m) - v = eval_str (v, env, rep.solv, m) - m_nm = m_var_name (mem) - print '%s: %s @ <%s> -- %s -- %s' % (kind, m_nm, addr, v, n_nm) - if simplify: - addr_s = simplify_sexp (addr_s, rep, m) - v_s = simplify_sexp (v_s, rep, m) - if verbose: - print '\t %s -- %s' % (addr_s, v_s) - if symbs: - addr_n = str_to_num (addr) - (hit_symbs, secs) = find_symbol (addr_n, output = False) - ss = hit_symbs + secs - if ss: - print '\t [%s]' % ', '.join (ss) - if resolve_addrs: - accs = [(kind, solver.to_smt_expr (addr, env, rep.solv), - solver.to_smt_expr (v, env, rep.solv), mem) - for (kind, addr, v, mem) in accs] - trace.extend ([(kind, addr, v, mem, n, vc) - for (kind, addr, v, mem) in accs]) - if node.kind == 'Call': - msg = '' % (node.fname, n_nm) - print msg - trace.append (msg) - return trace + resolve_addrs = False): + p = rep.p + ns = walk_model (rep, tag, m) + trace = [] + for (n, vc) in ns: + if (n, vc) not in rep.arc_pc_envs: + # this n_vc has a pre-state, but has not been emitted. + # no point trying to evaluate its expressions, the + # solve won't have seen them yet. + continue + n_nm = rep.node_count_name ((n, vc)) + node = p.nodes[n] + if node.kind == 'Call': + exprs = list (node.args) + elif node.kind == 'Basic': + exprs = [expr for (_, expr) in node.upds] + elif node.kind == 'Cond': + exprs = [node.cond] + env = rep.node_pc_envs[(tag, n, vc)][1] + accs = list (set ([acc for expr in exprs + for acc in expr.get_mem_accesses ()])) + for (kind, addr, v, mem) in accs: + addr_s = solver.smt_expr (addr, env, rep.solv) + v_s = solver.smt_expr (v, env, rep.solv) + addr = eval_str (addr, env, rep.solv, m) + v = eval_str (v, env, rep.solv, m) + m_nm = m_var_name (mem) + print '%s: %s @ <%s> -- %s -- %s' % (kind, m_nm, addr, v, n_nm) + if simplify: + addr_s = simplify_sexp (addr_s, rep, m) + v_s = simplify_sexp (v_s, rep, m) + if verbose: + print '\t %s -- %s' % (addr_s, v_s) + if symbs: + addr_n = str_to_num (addr) + (hit_symbs, secs) = find_symbol (addr_n, output = False) + ss = hit_symbs + secs + if ss: + print '\t [%s]' % ', '.join (ss) + if resolve_addrs: + accs = [(kind, solver.to_smt_expr (addr, env, rep.solv), + solver.to_smt_expr (v, env, rep.solv), mem) + for (kind, addr, v, mem) in accs] + trace.extend ([(kind, addr, v, mem, n, vc) + for (kind, addr, v, mem) in accs]) + if node.kind == 'Call': + msg = '' % (node.fname, n_nm) + print msg + trace.append (msg) + return trace def simplify_sexp (smt_xp, rep, m, flatten = True): - if type (smt_xp) == str: - smt_xp = solver.parse_s_expression (smt_xp) - if smt_xp[0] == 'ite': - (_, c, x, y) = smt_xp - if eval_model_bool (m, c): - return simplify_sexp (x, rep, m, flatten) - else: - return simplify_sexp (y, rep, m, flatten) - if type (smt_xp) == tuple: - smt_xp = tuple ([simplify_sexp (x, rep, m, False) - for x in smt_xp]) - if flatten: - return solver.flat_s_expression (smt_xp) - else: - return smt_xp + if type (smt_xp) == str: + smt_xp = solver.parse_s_expression (smt_xp) + if smt_xp[0] == 'ite': + (_, c, x, y) = 
smt_xp + if eval_model_bool (m, c): + return simplify_sexp (x, rep, m, flatten) + else: + return simplify_sexp (y, rep, m, flatten) + if type (smt_xp) == tuple: + smt_xp = tuple ([simplify_sexp (x, rep, m, False) + for x in smt_xp]) + if flatten: + return solver.flat_s_expression (smt_xp) + else: + return smt_xp def trace_mems (rep, m, verbose = False, symbs = True, tags = None): - if tags == None: - if rep.p.pairing: - tags = reversed (rep.p.pairing.tags) - else: - tags = rep.p.tags () - for tag in tags: - print '%s mem trace:' % tag - trace_mem (rep, tag, m, verbose = verbose, symbs = symbs) + if tags == None: + if rep.p.pairing: + tags = reversed (rep.p.pairing.tags) + else: + tags = rep.p.tags () + for tag in tags: + print '%s mem trace:' % tag + trace_mem (rep, tag, m, verbose = verbose, symbs = symbs) def trace_mems_diff (rep, m, tags = ['ASM', 'C']): - asms = trace_mem (rep, tags[0], m, resolve_addrs = True) - cs = trace_mem (rep, tags[1], m, resolve_addrs = True) - ev = lambda expr: eval_str (expr, {}, None, m) - c_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in cs - if kind == 'MemUpdate'] - asm_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in asms - if kind == 'MemUpdate' and 'mem' in m_var_name (mem)] - c_upd_d = dict (c_upds) - asm_upd_d = dict (asm_upds) - addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds - if addr not in asm_upd_d] - mism = [addr for addr in addr_ord - if c_upd_d.get (addr) != asm_upd_d.get (addr)] - return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds) + asms = trace_mem (rep, tags[0], m, resolve_addrs = True) + cs = trace_mem (rep, tags[1], m, resolve_addrs = True) + ev = lambda expr: eval_str (expr, {}, None, m) + c_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in cs + if kind == 'MemUpdate'] + asm_upds = [(ev (addr), ev (v)) for (kind, addr, v, mem, _, _) in asms + if kind == 'MemUpdate' and 'mem' in m_var_name (mem)] + c_upd_d = dict (c_upds) + asm_upd_d = dict (asm_upds) + addr_ord = [addr for (addr, _) in asm_upds] + [addr for (addr, _) in c_upds + if addr not in asm_upd_d] + mism = [addr for addr in addr_ord + if c_upd_d.get (addr) != asm_upd_d.get (addr)] + return (c_upd_d == asm_upd_d, mism, c_upds, asm_upds) def get_pv_type (pv): - assert pv.is_op (['PValid', 'PArrayValid']) - typ_v = pv.vals[1] - assert typ_v.kind == 'Type' - typ = typ_v.val - if pv.is_op ('PArrayValid'): - return ('PArrayValid', typ, pv.vals[3]) - else: - return ('PValid', typ, None) + assert pv.is_op (['PValid', 'PArrayValid']) + typ_v = pv.vals[1] + assert typ_v.kind == 'Type' + typ = typ_v.val + if pv.is_op ('PArrayValid'): + return ('PArrayValid', typ, pv.vals[3]) + else: + return ('PValid', typ, None) def guess_pv (p, n, addr_expr): - vs = syntax.get_expr_var_set (addr_expr) - [pred] = p.preds[n] - pvs = [] - def vis (expr): - if expr.is_op (['PValid', 'PArrayValid']): - pvs.append (expr) - p.nodes[pred].cond.visit (vis) - match_pvs = [pv for pv in pvs - if set.union (* [syntax.get_expr_var_set (v) for v in pv.vals[2:]]) - == vs] - if len (match_pvs) > 1: - match_pvs = [pv for pv in match_pvs if pv.is_op ('PArrayValid')] - pv = match_pvs[0] - return pv + vs = syntax.get_expr_var_set (addr_expr) + [pred] = p.preds[n] + pvs = [] + def vis (expr): + if expr.is_op (['PValid', 'PArrayValid']): + pvs.append (expr) + p.nodes[pred].cond.visit (vis) + match_pvs = [pv for pv in pvs + if set.union (* [syntax.get_expr_var_set (v) for v in pv.vals[2:]]) + == vs] + if len (match_pvs) > 1: + match_pvs = [pv for pv in 
match_pvs if pv.is_op ('PArrayValid')] + pv = match_pvs[0] + return pv def eval_pv_type (rep, (n, vc), m, data): - if data[0] == 'PValid': - return data - else: - (nm, typ, offs) = data - offs = rep.to_smt_expr (offs, (n, vc)) - offs = search.eval_model_expr (m, rep.solv, offs) - return (nm, typ, offs) + if data[0] == 'PValid': + return data + else: + (nm, typ, offs) = data + offs = rep.to_smt_expr (offs, (n, vc)) + offs = search.eval_model_expr (m, rep.solv, offs) + return (nm, typ, offs) def trace_suspicious_mem (rep, m, tag = 'C'): - cs = trace_mem (rep, tag, m) - data = [(addr, search.eval_model_expr (m, rep.solv, - rep.to_smt_expr (addr, (n, vc))), (n, vc)) - for (kind, addr, v, mem, n, vc) in cs] - addr_sets = {} - for (addr, addr_v, _) in data: - addr_sets.setdefault (addr_v, set ()) - addr_sets[addr_v].add (addr) - dup_addrs = set ([addr_v for addr_v in addr_sets - if len (addr_sets[addr_v]) > 1]) - data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc)) - for (addr, addr_v, (n, vc)) in data - if addr_v in dup_addrs] - data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m, - get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n) - for (addr, addr_v, pv, (n, vc)) in data] - dup_addr_types = set ([addr_v for addr_v in dup_addrs - if len (set ([t for (_, addr_v2, t, _, _) in data - if addr_v2 == addr_v])) > 1]) - res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data - if addr_v2 == addr_v]) - for addr_v in dup_addr_types] - for (addr_v, insts) in res: - print 'Address %s' % addr_v - for (t, pv, n) in insts: - print ' -- accessed with type %s at %s' % (t, n) - print ' (covered by %s)' % pv - return res + cs = trace_mem (rep, tag, m) + data = [(addr, search.eval_model_expr (m, rep.solv, + rep.to_smt_expr (addr, (n, vc))), (n, vc)) + for (kind, addr, v, mem, n, vc) in cs] + addr_sets = {} + for (addr, addr_v, _) in data: + addr_sets.setdefault (addr_v, set ()) + addr_sets[addr_v].add (addr) + dup_addrs = set ([addr_v for addr_v in addr_sets + if len (addr_sets[addr_v]) > 1]) + data = [(addr, addr_v, guess_pv (rep.p, n, addr), (n, vc)) + for (addr, addr_v, (n, vc)) in data + if addr_v in dup_addrs] + data = [(addr, addr_v, eval_pv_type (rep, (n, vc), m, + get_pv_type (pv)), rep.to_smt_expr (pv, (n, vc)), n) + for (addr, addr_v, pv, (n, vc)) in data] + dup_addr_types = set ([addr_v for addr_v in dup_addrs + if len (set ([t for (_, addr_v2, t, _, _) in data + if addr_v2 == addr_v])) > 1]) + res = [(addr_v, [(t, pv, n) for (_, addr_v2, t, pv, n) in data + if addr_v2 == addr_v]) + for addr_v in dup_addr_types] + for (addr_v, insts) in res: + print 'Address %s' % addr_v + for (t, pv, n) in insts: + print ' -- accessed with type %s at %s' % (t, n) + print ' (covered by %s)' % pv + return res def trace_var (rep, tag, m, v): - p = rep.p - ns = walk_model (rep, tag, m) - vds = rep.p.compute_var_dependencies () - trace = [] - vs = syntax.get_expr_var_set (v) - def fetch ((n, vc)): - if n in vds and [(nm, typ) for (nm, typ) in vs - if (nm, typ) not in vds[n]]: - return None - try: - (_, env) = rep.get_node_pc_env ((n, vc), tag) - s = solver.smt_expr (v, env, rep.solv) - s_x = solver.parse_s_expression (s) - ev = search.eval_model (m, s_x) - return (s, solver.smt_expr (ev, {}, None)) - except solver.EnvMiss, e: - return None - except AssertionError, e: - return None - val = None - for (n, vc) in ns: - n_nm = rep.node_count_name ((n, vc)) - val2 = fetch ((n, vc)) - if val2 != val: - if val2 == None: - print 'at %s: undefined' % n_nm - else: - print 'at %s:\t\t%s:\t\t%s' % (n_nm, - val2[0], 
val2[1]) - val = val2 - trace.append (((n, vc), val)) - if n not in p.nodes: - break - node = p.nodes[n] - if node.kind == 'Call': - msg = '' % (node.fname, - rep.node_count_name ((n, vc))) - print msg - trace.append (msg) - return trace + p = rep.p + ns = walk_model (rep, tag, m) + vds = rep.p.compute_var_dependencies () + trace = [] + vs = syntax.get_expr_var_set (v) + def fetch ((n, vc)): + if n in vds and [(nm, typ) for (nm, typ) in vs + if (nm, typ) not in vds[n]]: + return None + try: + (_, env) = rep.get_node_pc_env ((n, vc), tag) + s = solver.smt_expr (v, env, rep.solv) + s_x = solver.parse_s_expression (s) + ev = search.eval_model (m, s_x) + return (s, solver.smt_expr (ev, {}, None)) + except solver.EnvMiss, e: + return None + except AssertionError, e: + return None + val = None + for (n, vc) in ns: + n_nm = rep.node_count_name ((n, vc)) + val2 = fetch ((n, vc)) + if val2 != val: + if val2 == None: + print 'at %s: undefined' % n_nm + else: + print 'at %s:\t\t%s:\t\t%s' % (n_nm, + val2[0], val2[1]) + val = val2 + trace.append (((n, vc), val)) + if n not in p.nodes: + break + node = p.nodes[n] + if node.kind == 'Call': + msg = '' % (node.fname, + rep.node_count_name ((n, vc))) + print msg + trace.append (msg) + return trace def trace_deriv_ops (rep, m, tag): - n_vcs = walk_model (rep, tag, m) - derivs = set (('CountTrailingZeroes', 'CountLeadingZeroes', - 'WordReverse')) - def get_derivs (node): - dvs = set () - def visit (expr): - if expr.is_op (derivs): - dvs.add (expr) - node.visit (lambda x: (), visit) - return dvs - for (n, vc) in n_vcs: - if n not in rep.p.nodes: - continue - dvs = get_derivs (rep.p.nodes[n]) - if not dvs: - continue - print '%s:' % (rep.node_count_name ((n, vc))) - for dv in dvs: - [x] = dv.vals - x = rep.to_smt_expr (x, (n, vc)) - x = eval_str (x, {}, rep.solv, m) - print '\t%s: %s' % (dv.name, x) + n_vcs = walk_model (rep, tag, m) + derivs = set (('CountTrailingZeroes', 'CountLeadingZeroes', + 'WordReverse')) + def get_derivs (node): + dvs = set () + def visit (expr): + if expr.is_op (derivs): + dvs.add (expr) + node.visit (lambda x: (), visit) + return dvs + for (n, vc) in n_vcs: + if n not in rep.p.nodes: + continue + dvs = get_derivs (rep.p.nodes[n]) + if not dvs: + continue + print '%s:' % (rep.node_count_name ((n, vc))) + for dv in dvs: + [x] = dv.vals + x = rep.to_smt_expr (x, (n, vc)) + x = eval_str (x, {}, rep.solv, m) + print '\t%s: %s' % (dv.name, x) def check_pairings (): - for p in pairings.itervalues (): - print p['C'], p['ASM'] - as_args = functions[p['ASM']].inputs - c_args = functions[p['C']].inputs - print as_args, c_args - logic.mk_fun_inp_eqs (as_args, c_args, True) + for p in pairings.itervalues (): + print p['C'], p['ASM'] + as_args = functions[p['ASM']].inputs + c_args = functions[p['C']].inputs + print as_args, c_args + logic.mk_fun_inp_eqs (as_args, c_args, True) def loop_var_deps (p): - return [(n, [v for v in p.var_deps[n] - if p.var_deps[n][v] == 'LoopVariable']) - for n in p.loop_data] + return [(n, [v for v in p.var_deps[n] + if p.var_deps[n][v] == 'LoopVariable']) + for n in p.loop_data] def find_symbol (n, output = True): - from target_objects import symbols, sections - symbs = [] - secs = [] - if output: - def p (s): - print s - else: - p = lambda s: () - for (s, (addr, size, _)) in symbols.iteritems (): - if addr <= n and n < addr + size: - symbs.append (s) - p ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1)) - for (s, (start, end)) in sections.iteritems (): - if start <= n and n <= end: - secs.append (s) - p ('%x in 
section %s (%x - %x)' % (n, s, start, end)) - return (symbs, secs) + from target_objects import symbols, sections + symbs = [] + secs = [] + if output: + def p (s): + print s + else: + p = lambda s: () + for (s, (addr, size, _)) in symbols.iteritems (): + if addr <= n and n < addr + size: + symbs.append (s) + p ('%x in %s (%x - %x)' % (n, s, addr, addr + size - 1)) + for (s, (start, end)) in sections.iteritems (): + if start <= n and n <= end: + secs.append (s) + p ('%x in section %s (%x - %x)' % (n, s, start, end)) + return (symbs, secs) def assembly_point (p, n): - (_, hints) = p.node_tags[n] - if type (hints) != tuple or not logic.is_int (hints[1]): - return None - while p.node_tags[n][1][1] % 4 != 0: - [n] = p.preds[n] - return p.node_tags[n][1][1] + (_, hints) = p.node_tags[n] + if type (hints) != tuple or not logic.is_int (hints[1]): + return None + while p.node_tags[n][1][1] % 4 != 0: + [n] = p.preds[n] + return p.node_tags[n][1][1] def assembly_points (p, ns): - ns = [assembly_point (p, n) for n in ns] - ns = [n for n in ns if n != None] - return ns + ns = [assembly_point (p, n) for n in ns] + ns = [n for n in ns if n != None] + return ns def disassembly_lines (addrs): - f = open ('%s/kernel.elf.txt' % target_objects.target_dir) - addr_set = set (['%x' % addr for addr in addrs]) - ss = [l.strip () - for l in f if ':' in l and l.split(':', 1)[0] in addr_set] - return ss + f = open ('%s/kernel.elf.txt' % target_objects.target_dir) + addr_set = set (['%x' % addr for addr in addrs]) + ss = [l.strip () + for l in f if ':' in l and l.split(':', 1)[0] in addr_set] + return ss def disassembly (p, n): - if hasattr (n, '__iter__'): - ns = set (n) - else: - ns = [n] - addrs = sorted (set ([assembly_point (p, n) for n in ns]) - - set ([None])) - print 'asm %s' % ', '.join (['0x%x' % addr for addr in addrs]) - for s in disassembly_lines (addrs): - print s + if hasattr (n, '__iter__'): + ns = set (n) + else: + ns = [n] + addrs = sorted (set ([assembly_point (p, n) for n in ns]) + - set ([None])) + print 'asm %s' % ', '.join (['0x%x' % addr for addr in addrs]) + for s in disassembly_lines (addrs): + print s def disassembly_loop (p, n): - head = p.loop_id (n) - loop = p.loop_body (n) - ns = sorted (set (assembly_points (p, loop))) - entries = assembly_points (p, [n for n in p.preds[head] - if n not in loop]) - print 'Loop: [%s]' % ', '.join (['%x' % addr for addr in ns]) - for s in disassembly_lines (ns): - print s - print 'entry from %s' % ', '.join (['%x' % addr for addr in entries]) - for s in disassembly_lines (entries): - print s + head = p.loop_id (n) + loop = p.loop_body (n) + ns = sorted (set (assembly_points (p, loop))) + entries = assembly_points (p, [n for n in p.preds[head] + if n not in loop]) + print 'Loop: [%s]' % ', '.join (['%x' % addr for addr in ns]) + for s in disassembly_lines (ns): + print s + print 'entry from %s' % ', '.join (['%x' % addr for addr in entries]) + for s in disassembly_lines (entries): + print s def try_interpret_hyp (rep, hyp): - try: - expr = rep.interpret_hyp (hyp) - solver.smt_expr (expr, {}, rep.solv) - return None - except: - return ('Broken Hyp', hyp) + try: + expr = rep.interpret_hyp (hyp) + solver.smt_expr (expr, {}, rep.solv) + return None + except: + return ('Broken Hyp', hyp) def check_checks (): - p = problem.last_problem[0] - rep = rep_graph.mk_graph_slice (p) - proof = search.last_proof[0] - checks = check.proof_checks (p, proof) - all_hyps = set ([hyp for (_, hyp, _) in checks] - + [hyp for (hyps, _, _) in checks for hyp in hyps]) - results = 
[try_interpret_hyp (rep, hyp) for hyp in all_hyps] - return [r[1] for r in results if r] + p = problem.last_problem[0] + rep = rep_graph.mk_graph_slice (p) + proof = search.last_proof[0] + checks = check.proof_checks (p, proof) + all_hyps = set ([hyp for (_, hyp, _) in checks] + + [hyp for (hyps, _, _) in checks for hyp in hyps]) + results = [try_interpret_hyp (rep, hyp) for hyp in all_hyps] + return [r[1] for r in results if r] def proof_failed_groups (p = None, proof = None): - if p == None: - p = problem.last_problem[0] - if proof == None: - proof = search.last_proof[0] - checks = check.proof_checks (p, proof) - groups = check.proof_check_groups (checks) - failed = [] - for group in groups: - rep = rep_graph.mk_graph_slice (p) - (res, el) = check.test_hyp_group (rep, group) - if not res: - failed.append (group) - print 'Failed element: %s' % el - failed_nms = set ([s for group in failed for (_, _, s) in group]) - print 'Failed: %s' % failed_nms - return failed + if p == None: + p = problem.last_problem[0] + if proof == None: + proof = search.last_proof[0] + checks = check.proof_checks (p, proof) + groups = check.proof_check_groups (checks) + failed = [] + for group in groups: + rep = rep_graph.mk_graph_slice (p) + (res, el) = check.test_hyp_group (rep, group) + if not res: + failed.append (group) + print 'Failed element: %s' % el + failed_nms = set ([s for group in failed for (_, _, s) in group]) + print 'Failed: %s' % failed_nms + return failed def read_summary (f): - results = {} - times = {} - for line in f: - if not line.startswith ('Time taken to'): - continue - bits = line.split () - assert bits[:4] == ['Time', 'taken', 'to', 'check'] - res = bits[4] - [ref] = [i for (i, b) in enumerate (bits) if b == '<='] - f = bits[ref + 1] - [pair] = [pair for pair in pairings[f] - if pair.name in line] - time = float (bits[-1]) - results[pair] = res - times[pair] = time - return (results, times) + results = {} + times = {} + for line in f: + if not line.startswith ('Time taken to'): + continue + bits = line.split () + assert bits[:4] == ['Time', 'taken', 'to', 'check'] + res = bits[4] + [ref] = [i for (i, b) in enumerate (bits) if b == '<='] + f = bits[ref + 1] + [pair] = [pair for pair in pairings[f] + if pair.name in line] + time = float (bits[-1]) + results[pair] = res + times[pair] = time + return (results, times) def unfold_defs_sexpr (defs, sexpr, depthlimit = -1): - if type (sexpr) == str: - sexpr = defs.get (sexpr, sexpr) - print sexpr - return sexpr - elif depthlimit == 0: - return sexpr - return tuple ([sexpr[0]] + [unfold_defs_sexpr (defs, s, depthlimit - 1) - for s in sexpr[1:]]) + if type (sexpr) == str: + sexpr = defs.get (sexpr, sexpr) + print sexpr + return sexpr + elif depthlimit == 0: + return sexpr + return tuple ([sexpr[0]] + [unfold_defs_sexpr (defs, s, depthlimit - 1) + for s in sexpr[1:]]) def unfold_defs (defs, hyp, depthlimit = -1): - return solver.flat_s_expression (unfold_defs_sexpr (defs, - solver.parse_s_expression (hyp), depthlimit)) + return solver.flat_s_expression (unfold_defs_sexpr (defs, + solver.parse_s_expression (hyp), depthlimit)) def investigate_unsat (solv, hyps = None): - if hyps == None: - hyps = list (solver.last_hyps[0]) - assert solv.hyps_sat_raw (hyps) == 'unsat', hyps - kept_hyps = [] - while hyps: - h = hyps.pop () - if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat': - kept_hyps.append (h) - assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps - split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps - for hyp2 in 
solver.split_hyp (hyp)])) - if len (split_hyps) > len (kept_hyps): - return investigate_unsat (solv, split_hyps) - def_hyps = [(unfold_defs (solv.defs, h, 2), tag) - for (h, tag) in kept_hyps] - if def_hyps != kept_hyps: - return investigate_unsat (solv, def_hyps) - return kept_hyps + if hyps == None: + hyps = list (solver.last_hyps[0]) + assert solv.hyps_sat_raw (hyps) == 'unsat', hyps + kept_hyps = [] + while hyps: + h = hyps.pop () + if solv.hyps_sat_raw (hyps + kept_hyps) != 'unsat': + kept_hyps.append (h) + assert solv.hyps_sat_raw (kept_hyps) == 'unsat', kept_hyps + split_hyps = sorted (set ([(hyp2, tag) for (hyp, tag) in kept_hyps + for hyp2 in solver.split_hyp (hyp)])) + if len (split_hyps) > len (kept_hyps): + return investigate_unsat (solv, split_hyps) + def_hyps = [(unfold_defs (solv.defs, h, 2), tag) + for (h, tag) in kept_hyps] + if def_hyps != kept_hyps: + return investigate_unsat (solv, def_hyps) + return kept_hyps def test_interesting_linear_series_exprs (): - pairs = set ([pair for f in pairings for pair in pairings[f]]) - notes = {} - for pair in pairs: - p = check.build_problem (pair) - for n in search.init_loops_to_split (p, ()): - intr = logic.interesting_linear_series_exprs (p, n, - search.get_loop_var_analysis_at (p, n)) - if intr: - notes[pair.name] = True - if 'Call' in str (intr): - notes[pair.name] = 'Call!' - return notes + pairs = set ([pair for f in pairings for pair in pairings[f]]) + notes = {} + for pair in pairs: + p = check.build_problem (pair) + for n in search.init_loops_to_split (p, ()): + intr = logic.interesting_linear_series_exprs (p, n, + search.get_loop_var_analysis_at (p, n)) + if intr: + notes[pair.name] = True + if 'Call' in str (intr): + notes[pair.name] = 'Call!' + return notes def var_analysis (p, n): - va = search.get_loop_var_analysis_at (p, n) - cats = {} - for (v, kind) in va: - if kind[0] == 'LoopLinearSeries': - offs = kind[2] - kind = kind[0] - else: - offs = None - cats.setdefault (kind, []) - cats[kind].append ((v, offs)) - for kind in cats: - print '%s:' % kind - for (v, offs) in cats[kind]: - print ' %s (%s)' % (syntax.pretty_expr (v), - syntax.pretty_type (v.typ)) - if offs: - print ' ++ %s' % syntax.pretty_expr (offs) + va = search.get_loop_var_analysis_at (p, n) + cats = {} + for (v, kind) in va: + if kind[0] == 'LoopLinearSeries': + offs = kind[2] + kind = kind[0] + else: + offs = None + cats.setdefault (kind, []) + cats[kind].append ((v, offs)) + for kind in cats: + print '%s:' % kind + for (v, offs) in cats[kind]: + print ' %s (%s)' % (syntax.pretty_expr (v), + syntax.pretty_type (v.typ)) + if offs: + print ' ++ %s' % syntax.pretty_expr (offs) def var_value_sites (rep, v): - if type (v) == str: - matches = lambda (nm, _): v in nm - elif type (v) == tuple: - matches = lambda (nm, typ): v == (nm, typ) - v_ord = [] - d = {} - for (tag, n, vc) in rep.node_pc_env_order: - (pc, env) = rep.get_node_pc_env ((n, vc), tag = tag) - for (v2, smt_exp) in env.iteritems (): - if matches (v2): - if smt_exp not in d: - v_ord.append (smt_exp) - d[smt_exp] = [] - d[smt_exp].append ((n, vc)) - for smt_exp in v_ord: - print smt_exp - if smt_exp in rep.solv.defs: - print (' = %s' % repr (rep.solv.defs[smt_exp])) - print (' - at: %s' % d[smt_exp]) - if v_ord: - print ('') - return (v_ord, d) + if type (v) == str: + matches = lambda (nm, _): v in nm + elif type (v) == tuple: + matches = lambda (nm, typ): v == (nm, typ) + v_ord = [] + d = {} + for (tag, n, vc) in rep.node_pc_env_order: + (pc, env) = rep.get_node_pc_env ((n, vc), tag = tag) + for 
(v2, smt_exp) in env.iteritems (): + if matches (v2): + if smt_exp not in d: + v_ord.append (smt_exp) + d[smt_exp] = [] + d[smt_exp].append ((n, vc)) + for smt_exp in v_ord: + print smt_exp + if smt_exp in rep.solv.defs: + print (' = %s' % repr (rep.solv.defs[smt_exp])) + print (' - at: %s' % d[smt_exp]) + if v_ord: + print ('') + return (v_ord, d) def loop_num_leaves (p, n): - for n in p.loop_body (n): - va = search.get_loop_var_analysis_at (p, n) - n_leaf = len ([1 for (v, kind) in va if kind == 'LoopLeaf']) - print (n, n_leaf) + for n in p.loop_body (n): + va = search.get_loop_var_analysis_at (p, n) + n_leaf = len ([1 for (v, kind) in va if kind == 'LoopLeaf']) + print (n, n_leaf) def try_pairing_at_funcall (p, name, head = None, restrs = None, hyps = None, - at = 'At'): - pairs = set (pairings[name]) - addrs = [n for (n, name2) in p.function_call_addrs () - if [pair for pair in pairings[name2] if pair in pairs]] - assert at in ['At', 'After'] - if at == 'After': - addrs = [p.nodes[n].cont for n in addrs] - if head == None: - tags = p.pairing.tags - [head] = [n for n in search.init_loops_to_split (p, ()) - if p.node_tags[n][0] == tags[0]] - if restrs == None: - restrs = () - if hyps == None: - hyps = check.init_point_hyps (p) - while True: - res = search.find_split_loop (p, head, restrs, hyps, - node_restrs = set (addrs)) - if res[0] == 'CaseSplit': - (_, ((n, tag), _)) = res - hyp = rep_graph.pc_true_hyp (((n, restrs), tag)) - hyps = hyps + [hyp] - else: - return res + at = 'At'): + pairs = set (pairings[name]) + addrs = [n for (n, name2) in p.function_call_addrs () + if [pair for pair in pairings[name2] if pair in pairs]] + assert at in ['At', 'After'] + if at == 'After': + addrs = [p.nodes[n].cont for n in addrs] + if head == None: + tags = p.pairing.tags + [head] = [n for n in search.init_loops_to_split (p, ()) + if p.node_tags[n][0] == tags[0]] + if restrs == None: + restrs = () + if hyps == None: + hyps = check.init_point_hyps (p) + while True: + res = search.find_split_loop (p, head, restrs, hyps, + node_restrs = set (addrs)) + if res[0] == 'CaseSplit': + (_, ((n, tag), _)) = res + hyp = rep_graph.pc_true_hyp (((n, restrs), tag)) + hyps = hyps + [hyp] + else: + return res def init_true_hyp (p, tag, expr): - n = p.get_entry (tag) - vis = ((n, ()), tag) - assert expr.typ == syntax.boolT, expr - return rep_graph.eq_hyp ((expr, vis), (syntax.true_term, vis)) + n = p.get_entry (tag) + vis = ((n, ()), tag) + assert expr.typ == syntax.boolT, expr + return rep_graph.eq_hyp ((expr, vis), (syntax.true_term, vis)) def smt_print (expr): - env = {} - while True: - try: - return solver.smt_expr (expr, env, None) - except solver.EnvMiss, e: - env[(e.name, e.typ)] = e.name + env = {} + while True: + try: + return solver.smt_expr (expr, env, None) + except solver.EnvMiss, e: + env[(e.name, e.typ)] = e.name diff --git a/decompiler/decompile.py b/decompiler/decompile.py new file mode 100755 index 00000000..db77f6c9 --- /dev/null +++ b/decompiler/decompile.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022, Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +import argparse +import itertools +import mmap +import os +import subprocess +import sys +import textwrap + +from pathlib import Path +from typing import NamedTuple, Protocol + + +def unlines(*lines: str) -> str: + return ''.join(f'{line}\n' for line in lines) + + +# We need to see this exact text printed near the end of the output. 
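+# output_ok() below mmaps the captured decompiler log and searches it for this
+# summary block; if the block is missing or truncated, the run is treated as
+# having failed.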
+check_text = textwrap.dedent('''\ + Proving correctness of call offsets + =================================== + + Offsets proved correct. + + Summary + ======= + + No stack intro failures. + No graph spec failures. + No export failures. + No call offset failures. +''').encode() + + +def output_ok(output_filename: str) -> bool: + with open(output_filename, 'rb') as output: + with mmap.mmap(output.fileno(), 0, prot=mmap.PROT_READ) as buff: + return buff.find(check_text) > 0 + + +class Decompile(Protocol): + def decompile(self) -> None: + ... + + +class Args(NamedTuple): + filename_prefix: Path + fast: bool + ignore: str + + +class DecompileLocal(NamedTuple): + hol4_src: str + args: Args + + def hol_input(self) -> bytes: + args = self.args + fast_str = 'true' if args.fast else 'false' + text = unlines(f'val _ = load "decompileLib";', + f'val _ = decompileLib.decomp "{args.filename_prefix}" {fast_str} "{args.ignore}";') + return text.encode() + + def decompile(self) -> None: + decompiler_src = self.hol4_src / 'examples' / 'machine-code' / 'graph' + hol4_bin = self.hol4_src / 'bin' + hol4_exe = hol4_bin / 'hol' + PATH = os.environ['PATH'] + env = {**os.environ, 'PATH': f'{hol4_bin}:{PATH}'} + output_file = f'{self.args.filename_prefix}_output.txt' + subprocess.run(f'{hol4_exe} 2>&1 | tee {output_file}', shell=True, cwd=decompiler_src, env=env, + input=self.hol_input(), check=True) + assert output_ok(output_file) + + +class DecompileDocker(NamedTuple): + command: str + image: str + args: Args + + def decompile(self) -> None: + target_dir = self.args.filename_prefix.parent + cmd = [self.command, 'run', '--rm', '-i', + '--mount', f'type=bind,source={target_dir},dst=/target', + '--mount', 'type=tmpfs,dst=/tmp', + self.image, f'/target/{self.args.filename_prefix.name}'] + if self.args.fast: + cmd.append('--fast') + if self.args.ignore: + cmd.extend(['--ignore', self.args.ignore]) + subprocess.run(cmd, cwd=target_dir, stdin=subprocess.DEVNULL, check=True) + + +def parse_args() -> Decompile: + parser = argparse.ArgumentParser(description='Run the decompiler.') + parser.add_argument('filename', metavar='FILENAME', type=str, + help='input filename prefix, e.g. /path/to/example for /path/to/example.elf.txt') + parser.add_argument('--fast', action='store_true', dest='fast', + help='skip some proofs') + parser.add_argument('--ignore', metavar='NAMES', type=str, action='append', + help='functions to ignore (comma-separated list)') + # For internal use only. + parser.add_argument('--docker', metavar='COMMAND', help=argparse.SUPPRESS) + + parsed_args = parser.parse_args() + + # Combine multiple ignore options + ignore_gen = (i for j in (parsed_args.ignore or []) for i in j.split(',') if i) + ignore = ','.join({i: None for i in ignore_gen}.keys()) + + args = Args(filename_prefix=Path(parsed_args.filename).resolve(), + fast=parsed_args.fast, + ignore=ignore) + + if parsed_args.docker: + image = os.environ.get('DECOMPILER_DOCKER_IMAGE', 'ghcr.io/sel4/decompiler:latest') + return DecompileDocker(command=parsed_args.docker, image=image, args=args) + + hol4_dir = os.environ.get('HOL4_DIR') + if hol4_dir is None: + hol4_path = Path(__file__).resolve().parent / 'src' / 'HOL4' + else: + hol4_path = Path(hol4_dir).resolve() + + return DecompileLocal(hol4_src=hol4_path, args=args) + + +def main(decompile: Decompile): + try: + decompile.decompile() + return 0 + except KeyboardInterrupt: + return STATUS_SIGNAL_OFFSET + signal.SIGINT + except BrokenPipeError: + # Our output was closed early, e.g. 
we were piped to `less` and the user quit + # before we finished. If sys.stdout still has unwritten characters its buffer + # that it can't write to the closed file descriptor, then the interpreter prints + # an ugly warning to sys.stderr as it shuts down. We assume we don't care, so + # close sys.stderr to suppress the warning. + sys.stderr.close() + return STATUS_SIGNAL_OFFSET + signal.SIGPIPE + + +if __name__ == '__main__': + exit(main(parse_args())) diff --git a/decompiler/default.nix b/decompiler/default.nix new file mode 100644 index 00000000..df2acb42 --- /dev/null +++ b/decompiler/default.nix @@ -0,0 +1,86 @@ +# Copyright (c) 2022, Kry10 Limited. +# SPDX-License-Identifier: BSD-2-Clause + +# Packages the decompiler as a Nix derivation, +# and also produces a Docker image. + +# This assumes that PolyML and HOL4 sources have been checked out. +# These can be checked out using: +# ./setup-decompiler checkout --upstream + +{ + polyml_src ? ./src/polyml, + hol4_src ? ./src/HOL4, +}: + +let + + pins = import ../nix/pins.nix; + inherit (pins) pkgs lib stdenv; + inherit (pins.herculesGitignore) gitignoreFilter; + + polyml = pkgs.polyml.overrideAttrs (_: { + name = "polyml"; + src = lib.cleanSourceWith { + name = "polyml-source"; + src = polyml_src; + filter = gitignoreFilter polyml_src; + }; + }); + + hol4-graph-decompiler = stdenv.mkDerivation { + name = "hol4-graph-decompiler"; + + src = lib.cleanSourceWith { + name = "hol4-source"; + src = hol4_src; + filter = gitignoreFilter hol4_src; + }; + + buildInputs = [ pkgs.fontconfig pkgs.graphviz polyml ]; + + buildCommand = '' + set -eu + + mkdir fonts + cat ${pkgs.fontconfig.out}/etc/fonts/fonts.conf > fonts/fonts.conf + export FONTCONFIG_FILE=$PWD/fonts/fonts.conf + + cp -r "$src" "$out" + chmod -R +w "$out" + cd "$out" + + poly < tools/smart-configure.sml + bin/build + + PATH="$out/bin:$PATH" + cd examples/machine-code/graph + Holmake + ''; + }; + + decompile-py = pkgs.runCommand "decompile-py" {} '' + mkdir -p "$out" + cp --preserve=all "${./decompile.py}" "$out/decompile.py" + (cd $out && ${pkgs.python3.interpreter} -m compileall "decompile.py") + ''; + + decompile_script = '' + #!${pkgs.runtimeShell} + export HOL4_DIR="${hol4-graph-decompiler}" + exec "${pkgs.python3.interpreter}" "${decompile-py}/decompile.py" "$@" + ''; + + # For the Docker customisation layer, the script should be in a `bin` directory. + decompile-bin = pkgs.writeScriptBin "decompile" decompile_script; + + # But for non-Docker use, a plain script is more convenient. + decompile = pkgs.writeScript "decompile" decompile_script; + + decompiler-image = pkgs.dockerTools.streamLayeredImage { + name = "decompiler"; + contents = with pkgs; [ bashInteractive coreutils polyml python3 decompile-bin ]; + config = { Entrypoint = [ "${decompile-bin}/bin/decompile" ]; }; + }; + +in { inherit decompile decompile-bin decompiler-image; } diff --git a/decompiler/setup-decompiler.py b/decompiler/setup-decompiler.py new file mode 100755 index 00000000..560b9145 --- /dev/null +++ b/decompiler/setup-decompiler.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2022, Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# A toolkit for building, installing and distributing the decompiler. 
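+#
+# Illustrative invocations (based on the subcommands defined below):
+#
+#   ./setup-decompiler.py checkout --upstream   # clone HOL4/PolyML sources and merge upstream changes
+#   ./setup-decompiler.py local --checkout      # clone sources, then build and install the decompiler locally
+#   ./setup-decompiler.py nix                   # build and install the decompiler via the Nix derivation
+#   ./setup-decompiler.py docker                # pull the container image and install a wrapper script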
+ +import argparse +import os +import shutil +import signal +import subprocess +import sys +import textwrap + +from pathlib import Path +from typing import Callable, NamedTuple, Protocol, TypeVar + +STATUS_SIGNAL_OFFSET = 128 +STATUS_USAGE = 1 + +# We currently always install decompiler artifcats in a fixed location +# relative to the graph-refine repository root. +decompiler_dir = Path(__file__).resolve().parent +decompile = decompiler_dir / 'decompile' + +hol4_source_dir = decompiler_dir / 'src' / 'HOL4' +polyml_source_dir = decompiler_dir / 'src' / 'polyml' +polyml_install_dir = decompiler_dir / 'install' / 'polyml' + +# File/link names relative to `decompiler_dir`: +decompile_local = Path('decompile.py') + +seL4_branch = 'seL4-decompiler' + +T = TypeVar('T') + + +class Repo(Protocol): + def url(self, ssh: bool) -> str: + ... + + +class GitHub(NamedTuple): + repo: str + + def url(self, ssh: bool) -> str: + return f'git@github.com:{self.repo}.git' if ssh else f'https://github.com/{self.repo}.git' + + +class Branch(NamedTuple): + repo: Repo + branch: str + + +class Source(NamedTuple): + name: str + upstream: Branch + fork: Branch + checkout_dir: Path + + +hol4_source = Source(name='HOL4', + upstream=Branch(repo=GitHub('HOL-Theorem-Prover/HOL'), branch='master'), + fork=Branch(repo=GitHub('seL4/HOL'), branch=seL4_branch), + checkout_dir=hol4_source_dir) + +polyml_source = Source(name='PolyML', + upstream=Branch(repo=GitHub('polyml/polyml'), branch='master'), + fork=Branch(repo=GitHub('seL4/polyml'), branch=seL4_branch), + checkout_dir=polyml_source_dir) + + +class CheckoutPhase(Protocol): + def do_checkout(self) -> None: + ... + + +def rm_path(*paths: Path) -> None: + for p in paths: + if p.is_symlink(): + os.unlink(p) + if p.is_dir(): + shutil.rmtree(p) + elif p.exists(): + os.unlink(p) + + +def check_empty(*paths: Path) -> None: + non_empty_paths = [p for p in paths if p.exists()] + if non_empty_paths: + sys.stderr.write( + 'setup-decompiler: paths should not exist when using --checkout without --force:\n') + for p in non_empty_paths: + sys.stderr.write(f' {p}\n') + sys.stderr.flush() + exit(STATUS_USAGE) + + +def check_exists(*paths: Path) -> None: + paths_not_exist = [p for p in paths if not p.is_dir()] + if paths_not_exist: + sys.stderr.write( + 'setup-decompiler: paths should already exist when not using --checkout:\n') + for p in paths_not_exist: + sys.stderr.write(f' {p}\n') + sys.stderr.flush() + exit(STATUS_USAGE) + + +def checkout_source(source: Source, ssh: bool) -> None: + print(f'setup-decompiler: Checking out {source.name}...') + source.checkout_dir.parent.mkdir(parents=True, exist_ok=True) + subprocess.run(['git', 'clone', '-q', '-b', source.fork.branch, + source.fork.repo.url(ssh), source.checkout_dir], + cwd=decompiler_dir, stdin=subprocess.DEVNULL, check=True) + subprocess.run(['git', 'remote', 'add', 'upstream', source.upstream.repo.url(ssh)], + cwd=source.checkout_dir, stdin=subprocess.DEVNULL, check=True) + subprocess.run(['git', 'fetch', '-q', 'upstream'], + cwd=source.checkout_dir, stdin=subprocess.DEVNULL, check=True) + + +def upstream_source(source: Source) -> None: + print(f'setup-decompiler: Merging upstream {source.name}...') + subprocess.run(['git', 'merge', '-q', '-m', 'merge upstream', f'upstream/{source.upstream.branch}'], + cwd=source.checkout_dir, stdin=subprocess.DEVNULL, check=True) + + +class CheckoutSources(NamedTuple): + checkout: bool + upstream: bool + force: bool + ssh: bool + + def github_url(self, repo) -> str: + return 
f'git@github.com:{repo}.git' if self.ssh else f'https://github.com/{repo}.git' + + def do_checkout(self) -> None: + if (self.force or self.upstream) and not self.checkout: + sys.stderr.write( + 'setup-decompiler: --upstream and --force only makes sense with --checkout\n') + sys.stderr.flush() + exit(STATUS_USAGE) + if not self.checkout: + check_exists(hol4_source_dir, polyml_source_dir) + return + if self.force: + rm_path(hol4_source_dir, polyml_source_dir) + if self.checkout: + check_empty(hol4_source_dir, polyml_source_dir) + for source in [hol4_source, polyml_source]: + checkout_source(source, self.ssh) + if self.upstream: + for source in [hol4_source, polyml_source]: + upstream_source(source) + + +class CheckoutNone(NamedTuple): + def do_checkout(self) -> None: + # Nothing to do here. + pass + + +class InstallPhase(Protocol): + def do_install(self) -> None: + ... + + +class InstallNone(NamedTuple): + def do_install(self) -> None: + # Nothing to do here. + pass + + +def rm_decompiler_setup() -> None: + rm_path(decompile) + + +# This is mainly intended for initial installation of the decompiler, +# and not for incremental development on HOL4 or the decompiler. +# If you're working on the decompiler, we assume you know how to use Holmake. +class InstallLocal(NamedTuple): + def do_install(self) -> None: + rm_decompiler_setup() + rm_path(polyml_install_dir) + # PolyML + print('setup-decompiler: Building Polyml...') + subprocess.run(['git', 'clean', '-fdX'], + cwd=polyml_source_dir, stdin=subprocess.DEVNULL, check=True) + subprocess.run(['./configure', f'--prefix={polyml_install_dir}'], + cwd=polyml_source_dir, stdin=subprocess.DEVNULL, check=True) + subprocess.run(['make'], cwd=polyml_source_dir, stdin=subprocess.DEVNULL, check=True) + subprocess.run(['make', 'install'], cwd=polyml_source_dir, + stdin=subprocess.DEVNULL, check=True) + poly = polyml_install_dir / 'bin' / 'poly' + # HOL4 + print('setup-decompiler: Building HOL4...') + subprocess.run(['git', 'clean', '-fdX'], + cwd=hol4_source_dir, stdin=subprocess.DEVNULL, check=True) + with open(hol4_source_dir / 'tools' / 'smart-configure.sml') as configure: + subprocess.run([poly], cwd=hol4_source_dir, stdin=configure, check=True) + subprocess.run([hol4_source_dir / 'bin' / 'build'], + cwd=hol4_source_dir, stdin=subprocess.DEVNULL, check=True) + # Decompiler + print('setup-decompiler: Building the decompiler...') + PATH = os.environ['PATH'] + decompiler_src = hol4_source_dir / 'examples' / 'machine-code' / 'graph' + env = {**os.environ, 'PATH': f'{hol4_source_dir}/bin:{PATH}'} + subprocess.run(['Holmake'], env=env, cwd=decompiler_src, + stdin=subprocess.DEVNULL, check=True) + # Script + os.symlink(decompile_local, decompile) + + +class InstallNix(NamedTuple): + def do_install(self) -> None: + rm_decompiler_setup() + print('setup-decompiler: Building the decompiler using Nix...') + subprocess.run(['nix-build', '-A', 'decompiler', '-o', decompile], + cwd=decompiler_dir, stdin=subprocess.DEVNULL, check=True) + + +class InstallDocker(NamedTuple): + command: str + + def do_install(self) -> None: + rm_decompiler_setup() + print('setup-decompiler: Installing a container-based decompiler...') + subprocess.run([self.command, 'pull', 'ghcr.io/sel4/decompiler:latest'], + stdin=subprocess.DEVNULL, check=True) + wrapper_script = textwrap.dedent(f'''\ + #!/usr/bin/env python3 + + # Copyright (c) 2022, Kry10 Limited + # SPDX-License-Identifier: BSD-2-Clause + + # Use {self.command} to run the decompiler. 
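+            # It forwards all command-line arguments to decompile.py with the
+            # '--docker' flag, so the decompiler itself runs inside the container.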
+ + import subprocess + import sys + from pathlib import Path + + # The actual functionality is in decompile.py, so we don't + # have to duplicate argument parsing here. + decompile = Path(__file__).resolve().parent / 'decompile.py' + p = subprocess.run([decompile, '--docker', '{self.command}'] + sys.argv[1:]) + sys.exit(p.returncode) + ''') + with open(decompile, 'w') as script_file: + script_file.write(wrapper_script) + # Copy read permissions to execute permissions. + mode = decompile.stat().st_mode + decompile.chmod(mode | mode >> 2 & 0o111) + + +class SetupCommand(NamedTuple): + checkout: CheckoutPhase + install: InstallPhase + + def do_setup(self) -> None: + self.checkout.do_checkout() + self.install.do_install() + + +def parse_args() -> SetupCommand: + parser = argparse.ArgumentParser(description='Set up the decompiler.') + + checkout_opt = argparse.ArgumentParser(add_help=False) + checkout_opt.add_argument('--checkout', action='store_true', dest='checkout', + help='Clone source repositories.') + + checkout_extra = argparse.ArgumentParser(add_help=False) + checkout_extra.add_argument('--upstream', action='store_true', dest='upstream', + help='Merge upstream changes.') + checkout_extra.add_argument('--force', action='store_true', dest='force', + help='Replace an existing checkout with a new one.') + checkout_extra.add_argument('--ssh', action='store_true', dest='ssh', + help='Use SSH when cloning from GitHub.') + + subparsers = parser.add_subparsers(required=True, dest='command') + checkout_cmd = subparsers.add_parser('checkout', parents=[checkout_extra], + help='Clone source repositories, but do not install.') + local_cmd = subparsers.add_parser('local', parents=[checkout_opt, checkout_extra], + help='Install decompiler built locally.') + nix_cmd = subparsers.add_parser('nix', parents=[checkout_opt, checkout_extra], + help='Install decompiler built using Nix.') + docker_cmd = subparsers.add_parser('docker', help='Install decompiler using Docker.') + podman_cmd = subparsers.add_parser('podman', help='Install decompiler using podman.') + + def require_checkout(args: argparse.Namespace) -> CheckoutPhase: + return CheckoutSources(checkout=args.checkout, + upstream=args.upstream, + force=args.force, + ssh=args.ssh) + + def install_docker(command: str) -> Callable[[argparse.Namespace], InstallPhase]: + def do_it(args: argparse.Namespace) -> InstallPhase: + return InstallDocker(command=command) + return do_it + + def just(cons: Callable[[], T]) -> Callable[[argparse.Namespace], T]: + def do_it(args: argparse.Namespace) -> T: + return cons() + return do_it + + checkout_cmd.set_defaults(checkout=True, + checkout_phase=require_checkout, install_phase=just(InstallNone)) + local_cmd.set_defaults(checkout_phase=require_checkout, install_phase=just(InstallLocal)) + nix_cmd.set_defaults(checkout_phase=require_checkout, install_phase=just(InstallNix)) + docker_cmd.set_defaults(checkout_phase=just(CheckoutNone), install_phase=install_docker('docker')) + podman_cmd.set_defaults(checkout_phase=just(CheckoutNone), install_phase=install_docker('podman')) + + args = parser.parse_args() + + return SetupCommand(checkout=args.checkout_phase(args), + install=args.install_phase(args)) + + +def main(setup_command: SetupCommand) -> int: + try: + setup_command.do_setup() + return 0 + except KeyboardInterrupt: + return STATUS_SIGNAL_OFFSET + signal.SIGINT + except BrokenPipeError: + # Our output was closed early, e.g. we were piped to `less` and the user quit + # before we finished. 
If sys.stdout still has unwritten characters its buffer + # that it can't write to the closed file descriptor, then the interpreter prints + # an ugly warning to sys.stderr as it shuts down. We assume we don't care, so + # close sys.stderr to suppress the warning. + sys.stderr.close() + return STATUS_SIGNAL_OFFSET + signal.SIGPIPE + + +if __name__ == '__main__': + exit(main(parse_args())) diff --git a/graph-refine.py b/graph-refine.py index 676aa889..cec419dc 100644 --- a/graph-refine.py +++ b/graph-refine.py @@ -23,335 +23,411 @@ import random import traceback import time -#import diagnostic import sys +import os +import os.path + + +def toplevel_check(pair, check_loops=True, report=False, count=None, + only_build_problem=False): + if not only_build_problem: + printout('Testing Function pair %s' % pair) + if count and not only_build_problem: + (i, n) = count + printout(' (function pairing %d of %d)' % (i + 1, n)) + + for (tag, fname) in pair.funs.iteritems(): + if not functions[fname].entry: + printout('Skipping %s, underspecified %s' % (pair, tag)) + return 'None' + prev_tracer = tracer[0] + if report: + tracer[0] = lambda s, n: () + + exception = None + + trace(time.asctime()) + start_time = time.time() + sys.stdout.flush() + try: + p = check.build_problem(pair) + if only_build_problem: + tracer[0] = prev_tracer + return 'True' + if report: + printout(' .. built problem, finding proof') + if not check_loops and p.loop_data: + printout('Problem has loop!') + tracer[0] = prev_tracer + return 'Loop' + if check_loops == 'only' and not p.loop_data: + printout('No loop in problem.') + tracer[0] = prev_tracer + return 'NoLoop' + proof = search.build_proof(p) + if report: + printout(' .. proof found.') + #print 'report\n' + #print report + #print proof + + try: + if report: + result = check.check_proof_report(p, proof) + else: + result = check.check_proof(p, proof) + if result: + printout('Refinement proven.') + else: + printout('Refinement NOT proven.') + except solver.SolverFailure, e: + printout('Solver timeout/failure in proof check.') + result = 'CheckSolverFailure' + except Exception, e: + trace('EXCEPTION in checking %s:' % p.name) + exception = sys.exc_info() + result = 'CheckEXCEPT' + + except problem.Abort: + result = 'ProblemAbort' + except search.NoSplit: + result = 'ProofNoSplit' + except solver.SolverFailure, e: + printout('Solver timeout/failure in proof search.') + result = 'ProofSolverFailure' + + except Exception, e: + trace('EXCEPTION in handling %s:' % pair) + exception = sys.exc_info() + result = 'ProofEXCEPT' + + end_time = time.time() + tracer[0] = prev_tracer + if exception: + (etype, evalue, tb) = exception + traceback.print_exception(etype, evalue, tb, + file=sys.stdout) + + if not only_build_problem: + printout('Result %s for pair %s, time taken: %.2fs' + % (result, pair, end_time - start_time)) + sys.stdout.flush() + + return str(result) -if __name__ == '__main__': - args = target_objects.load_target_args () - -def toplevel_check (pair, check_loops = True, report = False, count = None, - only_build_problem = False): - if not only_build_problem: - printout ('Testing Function pair %s' % pair) - if count and not only_build_problem: - (i, n) = count - printout (' (function pairing %d of %d)' % (i + 1, n)) - - for (tag, fname) in pair.funs.iteritems (): - if not functions[fname].entry: - printout ('Skipping %s, underspecified %s' % (pair, tag)) - return 'None' - prev_tracer = tracer[0] - if report: - tracer[0] = lambda s, n: () - - exception = None - - trace (time.asctime 
()) - start_time = time.time() - sys.stdout.flush () - try: - p = check.build_problem (pair) - if only_build_problem: - tracer[0] = prev_tracer - return 'True' - if report: - printout (' .. built problem, finding proof') - if not check_loops and p.loop_data: - printout ('Problem has loop!') - tracer[0] = prev_tracer - return 'Loop' - if check_loops == 'only' and not p.loop_data: - printout ('No loop in problem.') - tracer[0] = prev_tracer - return 'NoLoop' - proof = search.build_proof (p) - if report: - printout (' .. proof found.') - - try: - if report: - result = check.check_proof_report (p, proof) - else: - result = check.check_proof (p, proof) - if result: - printout ('Refinement proven.') - else: - printout ('Refinement NOT proven.') - except solver.SolverFailure, e: - printout ('Solver timeout/failure in proof check.') - result = 'CheckSolverFailure' - except Exception, e: - trace ('EXCEPTION in checking %s:' % p.name) - exception = sys.exc_info () - result = 'CheckEXCEPT' - - except problem.Abort: - result = 'ProblemAbort' - except search.NoSplit: - result = 'ProofNoSplit' - except solver.SolverFailure, e: - printout ('Solver timeout/failure in proof search.') - result = 'ProofSolverFailure' - - except Exception, e: - trace ('EXCEPTION in handling %s:' % pair) - exception = sys.exc_info () - result = 'ProofEXCEPT' - - end_time = time.time () - tracer[0] = prev_tracer - if exception: - (etype, evalue, tb) = exception - traceback.print_exception (etype, evalue, tb, - file = sys.stdout) - - if not only_build_problem: - printout ('Result %s for pair %s, time taken: %.2fs' - % (result, pair, end_time - start_time)) - sys.stdout.flush () - - return str (result) word_re = re.compile('\\w+') -def name_search (s, tags = None): - if s in pairings and len (pairings[s]) == 1: - return pairings[s][0] - pairs = list (set ([pair for f in pairings for pair in pairings[f] - if s in pair.name - if not tags or tags.issubset (set (pair.tags))])) - if len (pairs) == 1: - return pairs[0] - word_pairs = [p for p in pairs if s in word_re.findall (str (p))] - if len (word_pairs) == 1: - return word_pairs[0] - print 'Possibilities for %r: %s' % (s, [str (p) for p in pairs]) - return None - -def check_search (s, tags = None, report_mode = False, - check_loops = True): - pair = name_search (s, tags = tags) - if not pair: - return 'None' - else: - return toplevel_check (pair, report = report_mode, - check_loops = check_loops) - -def problem_search (s): - pair = name_search (s) - print pair.name - return check.build_problem (pair) + +def name_search(s, tags=None): + if s in pairings and len(pairings[s]) == 1: + return pairings[s][0] + pairs = list(set([pair for f in pairings for pair in pairings[f] + if s in pair.name + if not tags or tags.issubset(set(pair.tags))])) + if len(pairs) == 1: + return pairs[0] + word_pairs = [p for p in pairs if s in word_re.findall(str(p))] + if len(word_pairs) == 1: + return word_pairs[0] + print 'Possibilities for %r: %s' % (s, [str(p) for p in pairs]) + return None + + +def check_search(s, tags=None, report_mode=False, + check_loops=True): + pair = name_search(s, tags=tags) + if not pair: + return 'None' + else: + return toplevel_check(pair, report=report_mode, + check_loops=check_loops) + + +def problem_search(s): + pair = name_search(s) + #print pair.name + return check.build_problem(pair) + # somewhat arbitrary assignment of return codes to outcomes. # larger numbers are (roughly) worse outcomes. 
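# As a worked example of how these codes are used (all names below are the ones
# defined in this table): comb_results, defined just below, keeps whichever of
# two results has the larger rank, so comb_results('True', 'ProofNoSplit')
# yields 'ProofNoSplit' (rank 6 beats rank 0), and a batch run therefore only
# folds down to 'True' when every function pair succeeded.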
# key categories are success, skipped (not in covered cases), and failure result_codes = { - 'True' : (0, 'Success'), - 'Loop' : (1, 'Skipped'), - 'NoLoop' : (2, 'Skipped'), - 'None' : (3, 'Skipped'), - 'ProblemAbort' : (4, 'Skipped'), - 'False': (5, 'Failed'), - 'ProofNoSplit' : (6, 'Failed'), - 'ProofSolverFailure' : (7, 'Failed'), - 'ProofEXCEPT' : (8, 'Failed'), - 'CheckSolverFailure' : (9, 'Failed'), - 'CheckEXCEPT' : (10, 'Failed'), + 'True': (0, 'Success'), + 'Loop': (1, 'Skipped'), + 'NoLoop': (2, 'Skipped'), + 'None': (3, 'Skipped'), + 'ProblemAbort': (4, 'Skipped'), + 'False': (5, 'Failed'), + 'ProofNoSplit': (6, 'Failed'), + 'ProofSolverFailure': (7, 'Failed'), + 'ProofEXCEPT': (8, 'Failed'), + 'CheckSolverFailure': (9, 'Failed'), + 'CheckEXCEPT': (10, 'Failed'), } -def comb_results (r1, r2): - (_, r) = max ([(result_codes[r], r) for r in [r1, r2]]) - return r - -def check_pairs (pairs, loops = True, report_mode = False, - only_build_problem = False): - num_pairs = len (pairs) - printout ('Checking %d function pair problems' % len (pairs)) - results = [(pair, toplevel_check (pair, check_loops = loops, - report = report_mode, count = (i, num_pairs), - only_build_problem = only_build_problem)) - for (i, pair) in enumerate (pairs)] - result_dict = logic.dict_list ([(result_codes[r][1], pair) - for (pair, r) in results]) - if not only_build_problem: - printout ('Results: %s' - % [(pair.name, r) for (pair, r) in results]) - printout ('Result summary:') - success = result_dict.get ('Success', []) - if only_build_problem: - printout (' - %d problems build' % len (success)) - else: - printout (' - %d proofs checked' % len (success)) - skipped = result_dict.get ('Skipped', []) - printout (' - %d proofs skipped' % len (skipped)) - fails = [pair.name for pair in result_dict.get ('Failed', [])] - print_coverage_report (set (skipped), set (success + fails)) - printout (' - failures: %s' % fails) - return syntax.foldr1 (comb_results, ['True'] - + [r for (_, r) in results]) - -def print_coverage_report (skipped_pairs, covered_pairs): - try: - from trace_refute import addrs_covered, funs_sort_by_num_addrs - covered_fs = set ([f for pair in covered_pairs - for f in [pair.l_f, pair.r_f]]) - coverage = addrs_covered (covered_fs) - printout (' - %.2f%% instructions covered' % (coverage * 100)) - skipped_fs = set ([f for pair in skipped_pairs - for f in [pair.l_f, pair.r_f]]) - fs = funs_sort_by_num_addrs (set (skipped_fs)) - if not fs: - return - lrg_msgs = ['%s (%.2f%%)' % (f, addrs_covered ([f]) * 100) - for f in reversed (fs[-3:])] - printout (' - largest skipped functions:') - printout (' %s' % ', '.join (lrg_msgs)) - except Exception, e: - pass - -def check_all (omit_set = set (), loops = True, tags = None, - report_mode = False, only_build_problem = False): - pairs = list (set ([pair for f in pairings for pair in pairings[f] - if omit_set.isdisjoint (pair.funs.values ()) - if not tags or tags.issubset (set (pair.tags))])) - omitted = list (set ([pair for f in pairings for pair in pairings[f] - if not omit_set.isdisjoint (pair.funs.values())])) - random.shuffle (pairs) - r = check_pairs (pairs, loops = loops, report_mode = report_mode, - only_build_problem = only_build_problem) - if omitted: - printout (' - %d pairings omitted: %s' - % (len (omitted), [pair.name for pair in omitted])) - return r - -def check_division_pairs (num, denom, loops = True, report_mode = False): - pairs = list (set ([pair for f in pairings for pair in pairings[f]])) - def pair_size (pair): - return (len 
(functions[pair.l_f].nodes) - + len (functions[pair.r_f].nodes)) - pairs = sorted ([(pair_size (pair), pair) for pair in pairs]) - division = [pairs[i][1] for i in range (num, len (pairs), denom)] - return check_pairs (division, loops = loops, report_mode = report_mode) - -def check_deps (fname, report_mode = False): - frontier = set ([fname]) - funs = set () - while frontier: - fname = frontier.pop () - if fname in funs: - continue - funs.add (fname) - frontier.update (functions[fname].function_calls ()) - funs = sorted (funs) - funs = [fun for fun in funs if fun in pairings] - printout ('Testing functions: %s' % funs) - pairs = [pair for f in funs for pair in pairings[f]] - return check_pairs (pairs, report_mode = report_mode) - -def save_compiled_funcs (fname): - out = open (fname, 'w') - for (f, func) in functions.iteritems (): - trace ('Saving %s' % f) - for s in func.serialise (): - out.write (s + '\n') - out.close () - -def rerun_set (vs): - def get_strs (vs): - return [v for v in vs if type (v) == str] + [v2 - for v in vs if type (v) != str for v2 in get_strs (v)] - vs = set (get_strs (vs)) - pairs = map (name_search, vs) - strs = [pair.funs[pair.tags[0]] for pair in pairs if pair] - return ' '.join (strs) - -def main (args): - excluding = False - excludes = set () - loops = True - tags = set () - report = True - result = 'True' - pairs_to_check = [] - for arg in args: - r = 'True' - try: - if arg == 'verbose': - report = False - elif arg.startswith ('trace-to:'): - (_, s) = arg.split (':', 1) - f = open (s, 'w') - target_objects.trace_files.append (f) - elif arg == 'all': - r = check_all (excludes, loops = loops, - tags = tags, report_mode = report) - elif arg == 'all_safe': - ex = set.union (excludes, - target_objects.danger_set) - r = check_all (ex, loops = loops, - tags = tags, report_mode = report) - elif arg == 'coverage': - r = check_all (excludes, loops = loops, - tags = tags, report_mode = report, - only_build_problem = True) - elif arg.startswith ('div:'): - [_, num, denom] = arg.split (':') - num = int (num) - denom = int (denom) - r = check_division_pairs (num, denom, - loops = loops, report_mode = report) - elif arg == 'no_loops': - loops = False - elif arg == 'only_loops': - loops = 'only' - elif arg.startswith('save:'): - save_compiled_funcs (arg[5:]) - elif arg.startswith('save-proofs:'): - fname = arg[len ('save-proofs:') :] - save = check.save_proofs_to_file (fname, 'a') - check.save_checked_proofs[0] = save - elif arg == '-exclude': - excluding = True - elif arg == '-end-exclude': - excluding = False - elif arg.startswith ('t:'): - tags.add (arg[2:]) - elif arg.startswith ('target:'): - pass - elif arg.startswith ('skip-proofs-of:'): - (_, fname) = arg.split(':', 1) - import stats - prev_proofs = stats.scan_proofs (open (fname)) - prev_fs = [f for pair in prev_proofs - for f in pair.funs.values ()] - excludes.update (prev_fs) - elif excluding: - excludes.add (arg) - elif arg.startswith ('deps:'): - r = check_deps (arg[5:], - report_mode = report) - else: - r = name_search (arg, tags = tags) - if r != None: - pairs_to_check.append (r) - r = 'True' - else: - r = 'None' - except Exception, e: - print 'EXCEPTION in syscall arg %s:' % arg - print traceback.format_exc () - r = 'ProofEXCEPT' - result = comb_results (r, result) - if pairs_to_check: - r = check_pairs (pairs_to_check, loops = loops, - report_mode = report) - result = comb_results (r, result) - return result -if __name__ == '__main__': - result = main (args) - (code, category) = result_codes[result] - 
sys.exit (0) +def comb_results(r1, r2): + (_, r) = max([(result_codes[r], r) for r in [r1, r2]]) + return r + + +def check_pairs(pairs, loops=True, report_mode=False, + only_build_problem=False): + num_pairs = len(pairs) + printout('Checking %d function pair problems' % len(pairs)) + results = [(pair, toplevel_check(pair, check_loops=loops, + report=report_mode, count=(i, num_pairs), + only_build_problem=only_build_problem)) + for (i, pair) in enumerate(pairs)] + result_dict = logic.dict_list([(result_codes[r][1], pair) + for (pair, r) in results]) + if not only_build_problem: + printout('Results: %s' + % [(pair.name, r) for (pair, r) in results]) + printout('Result summary:') + success = result_dict.get('Success', []) + if only_build_problem: + printout(' - %d problems build' % len(success)) + else: + printout(' - %d proofs checked' % len(success)) + skipped = result_dict.get('Skipped', []) + printout(' - %d proofs skipped' % len(skipped)) + fails = [pair.name for pair in result_dict.get('Failed', [])] + print_coverage_report(set(skipped), set(success + fails)) + printout(' - failures: %s' % fails) + return syntax.foldr1(comb_results, ['True'] + + [r for (_, r) in results]) + +def print_coverage_report(skipped_pairs, covered_pairs): + try: + from trace_refute import addrs_covered, funs_sort_by_num_addrs + covered_fs = set([f for pair in covered_pairs + for f in [pair.l_f, pair.r_f]]) + coverage = addrs_covered(covered_fs) + printout(' - %.2f%% instructions covered' % (coverage * 100)) + skipped_fs = set([f for pair in skipped_pairs + for f in [pair.l_f, pair.r_f]]) + fs = funs_sort_by_num_addrs(set(skipped_fs)) + if not fs: + return + lrg_msgs = ['%s (%.2f%%)' % (f, addrs_covered([f]) * 100) + for f in reversed(fs[-3:])] + printout(' - largest skipped functions:') + printout(' %s' % ', '.join(lrg_msgs)) + except Exception, e: + pass + + +def check_all(omit_set=set(), loops=True, tags=None, + report_mode=False, only_build_problem=False): + pairs = list(set([pair for f in pairings for pair in pairings[f] + if omit_set.isdisjoint(pair.funs.values()) + if not tags or tags.issubset(set(pair.tags))])) + omitted = list(set([pair for f in pairings for pair in pairings[f] + if not omit_set.isdisjoint(pair.funs.values())])) + random.shuffle(pairs) + r = check_pairs(pairs, loops=loops, report_mode=report_mode, + only_build_problem=only_build_problem) + if omitted: + printout(' - %d pairings omitted' % (len(omitted)) ) + return r + + +def check_division_pairs(num, denom, loops=True, report_mode=False): + pairs = list(set([pair for f in pairings for pair in pairings[f]])) + + def pair_size(pair): + return (len(functions[pair.l_f].nodes) + + len(functions[pair.r_f].nodes)) + + pairs = sorted([(pair_size(pair), pair) for pair in pairs]) + division = [pairs[i][1] for i in range(num, len(pairs), denom)] + return check_pairs(division, loops=loops, report_mode=report_mode) + + +def check_deps(fname, report_mode=False): + frontier = set([fname]) + funs = set() + while frontier: + fname = frontier.pop() + if fname in funs: + continue + funs.add(fname) + frontier.update(functions[fname].function_calls()) + funs = sorted(funs) + funs = [fun for fun in funs if fun in pairings] + printout('Testing functions: %s' % funs) + pairs = [pair for f in funs for pair in pairings[f]] + return check_pairs(pairs, report_mode=report_mode) + + +def save_compiled_funcs(fname): + out = open(fname, 'w') + for (f, func) in functions.iteritems(): + trace('Saving %s' % f) + for s in func.serialise(): + out.write(s + '\n') + 
out.close() + + +def rerun_set(vs): + def get_strs(vs): + return [v for v in vs if type(v) == str] + [v2 + for v in vs if type(v) != str for v2 in get_strs(v)] + + vs = set(get_strs(vs)) + pairs = map(name_search, vs) + strs = [pair.funs[pair.tags[0]] for pair in pairs if pair] + return ' '.join(strs) + +def exitWithUsage(args): + objname = os.path.basename (args[0]) + dirname = os.path.dirname (args[0]) + exname = os.path.join (dirname, 'example') + print 'Usage: python %s ' % objname + print 'Target should be a directory.' + if os.path.isdir (exname): + print 'See example target (in %s)' % exname + else: + print 'See example target in graph-refine dir.' + assert not 'Target specified' + +def main(): + args = list (sys.argv) + # we have to set up tracing (logging) so we can write reports even if the rest fails + trace_to_arguments = [ arg[9:] for arg in args if arg.startswith('trace-to:') ] + for filename in trace_to_arguments: + f = open(filename, 'w') + target_objects.trace_files.append(f) + + # Write version info to the log, obtained either from the environment or git. + if 'GRAPH_REFINE_VERSION_INFO' in os.environ: + printout('VERSION_INFO %s' % os.environ['GRAPH_REFINE_VERSION_INFO']) + else: + graph_refine_dir = os.path.abspath(os.path.dirname(sys.argv[0])) + import subprocess + git_process = subprocess.Popen(['git', 'rev-parse', 'HEAD'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=graph_refine_dir) + git_out , git_err = git_process.communicate() + if len(git_err) > 0: + printout( 'VERSION_INFO GITSTATUS error' ) + printout( 'VERSION_INFO GITCOMMIT error - %s' % git_err ) + else: + # we refresh the index in case we have files with new timestamps but no changes + subprocess.call(['git','update-index','--refresh'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=graph_refine_dir) + git_status = subprocess.call(['git','diff-index','--quiet','HEAD','--'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=graph_refine_dir) + if git_status: + printout( 'VERSION_INFO GITSTATUS dirty - There are uncommitted changes!' ) + else: + printout( 'VERSION_INFO GITSTATUS clean - There are no uncommitted changes.' 
) + printout( 'VERSION_INFO GITCOMMIT %s' % git_out ) + + # then we need to load all syntax from target + if len(args) <= 1: + exitWithUsage(args) + target = args[1] + target_arguments = [ arg[7:] for arg in args if arg.startswith('target:') ] + target_objects.load_target(target, target_arguments) + + # finally we can parse and execute all other arguments + args = args[2:] + excluding = False + excludes = set() + loops = True + tags = set() + report = True + result = 'True' + pairs_to_check = [] + for arg in args: + r = 'True' + try: + if arg == 'verbose': + report = False + elif arg.startswith('trace-to:'): + pass + elif arg == 'all': + r = check_all(excludes, loops=loops, + tags=tags, report_mode=report) + elif arg == 'all_safe': + ex = set.union(excludes, + target_objects.danger_set) + r = check_all(ex, loops=loops, + tags=tags, report_mode=report) + elif arg == 'coverage': + r = check_all(excludes, loops=loops, + tags=tags, report_mode=report, + only_build_problem=True) + elif arg.startswith('div:'): + [_, num, denom] = arg.split(':') + num = int(num) + denom = int(denom) + r = check_division_pairs(num, denom, + loops=loops, report_mode=report) + elif arg == 'no_loops': + loops = False + elif arg == 'only_loops': + loops = 'only' + elif arg.startswith('save:'): + save_compiled_funcs(arg[5:]) + elif arg.startswith('save-proofs:'): + fname = arg[len('save-proofs:'):] + save = check.save_proofs_to_file(fname, 'a') + check.save_checked_proofs[0] = save + elif arg == '-exclude': + excluding = True + elif arg == '-end-exclude': + excluding = False + elif arg == '-include-only': + excluding = True + excludes = set () + elif arg == '-end-include-only': # simulate include using exclude + excluding = False + includes = set ([f for e in excludes for p in pairings[e] for f in p.funs.values()]) + excludes = set ([f for f in pairings]).difference(includes) + elif arg.startswith('t:'): + tags.add(arg[2:]) + elif arg.startswith('target:'): + pass + elif arg.startswith('skip-proofs-of:'): + (_, fname) = arg.split(':', 1) + import stats + prev_proofs = stats.scan_proofs(open(fname)) + prev_fs = [f for pair in prev_proofs + for f in pair.funs.values()] + excludes.update(prev_fs) + elif excluding: + excludes.add(arg) + elif arg.startswith('deps:'): + r = check_deps(arg[5:], + report_mode=report) + else: + r = name_search(arg, tags=tags) + if r != None: + pairs_to_check.append(r) + r = 'True' + else: + r = 'None' + except Exception, e: + print 'EXCEPTION in syscall arg %s:' % arg + print traceback.format_exc() + r = 'ProofEXCEPT' + sys.exit(1) + result = comb_results(r, result) + if pairs_to_check: + r = check_pairs(pairs_to_check, loops=loops, + report_mode=report) + result = comb_results(r, result) + return result + + +if __name__ == '__main__': + sys.setrecursionlimit(5000) + result = main() + (code, category) = result_codes[result] + sys.exit (0) diff --git a/graph-to-graph/bench.py b/graph-to-graph/bench.py index 4e4271fa..7fb28e93 100644 --- a/graph-to-graph/bench.py +++ b/graph-to-graph/bench.py @@ -39,6 +39,9 @@ 'strncmp' ] +# Global arch variable + +bench_arch = 'armv7' def makeGraph(f_name,fs): p = fs[f_name].as_problem(problem.Problem) @@ -104,7 +107,7 @@ def analyseFunction(f,asm_fs,dir_name,gen_heads,load_counts,emit_graphs, stopAtI #toDot(imm_fun) #toDotBB(imm_fun) - emitter = chronos.emitter.ChronosEmitter(dir_name, f, imm_fun) + emitter = chronos.emitter.ChronosEmitter(dir_name, f, imm_fun, None, bench_arch) emitter.emitTopLevelFunction() imm_file_name = emitter.imm_file_name @@ -118,7 
+121,9 @@ def analyseFunction(f,asm_fs,dir_name,gen_heads,load_counts,emit_graphs, stopAtI return wcet return None -def init(dir_name): +def init(dir_name, arch='armv7'): + global bench_arch + bench_arch = arch '''setup the target and initialise the elfFile''' target_objects.load_target(dir_name) sys.setrecursionlimit(2000) @@ -137,7 +142,7 @@ def silent_tracer (s,v): print s target_objects.tracer[0] = silent_tracer - elf_parser.parseElf(dir_name) + elf_parser.parseElf(dir_name, arch) asm_fs = elfFile().asm_fs tran_call_graph = call_graph_utils.transitiveCallGraph(asm_fs,dir_name,dummy_funs) @@ -146,8 +151,10 @@ def silent_tracer (s,v): elfFile().immed = None return asm_fs -def bench(dir_name, entry_point_function, gen_heads,load_counts, interactive, parse_only=False, conflict_file=None): - asm_fs = init(dir_name) +def bench(dir_name, entry_point_function, gen_heads,load_counts, interactive, parse_only=False, conflict_file=None, arch='armv7'): + global bench_arch + bench_arch = arch + asm_fs = init(dir_name, bench_arch) functions = target_objects.functions if parse_only or interactive: t = entry_point_function diff --git a/graph-to-graph/chronos/emitter.py b/graph-to-graph/chronos/emitter.py index d4142488..d452dabd 100644 --- a/graph-to-graph/chronos/emitter.py +++ b/graph-to-graph/chronos/emitter.py @@ -6,21 +6,44 @@ import re import parser +import riscv_parser import subprocess from addr_utils import phyAddrP + verbose = False from elf_file import elfFile, rawVals +''' +By default, the emitter handles armv7 instructions. +As we add RISC-V support in Chronos, we add "rv64", 64-bit RISC-V, +as another supported arch. +''' + +valid_arch = ['armv7', 'rv64'] + + class ChronosEmitter: - def __init__(self, dir_name, function_name, imm_fun, emit_as_dummy=None): + def __init__(self, dir_name, function_name, imm_fun, emit_as_dummy=None, arch='armv7'): self.function_name = function_name self.imm_fun = imm_fun self.imm_file_name = '%s/%s.imm' % (dir_name, function_name) self.imm_f = open(self.imm_file_name, 'w') - self.debug_f = open('%s/d_%s.imm' % (dir_name, function_name),'w') - self.emitted_loop_counts_file = open('%s/%s_emittedLoopCounts' % (dir_name, function_name),'w') + self.debug_f = open('%s/d_%s.imm' % (dir_name, function_name), 'w') + self.emitted_loop_counts_file = open('%s/%s_emittedLoopCounts' % (dir_name, function_name), 'w') self.emit_as_dummy = emit_as_dummy - self.elf_fun_to_skip = elfFile().funcs['clean_D_PoU'] + + if arch in valid_arch: + self.arch = arch + else: + self.arch = 'armv7' + + if self.arch == 'armv7': + # self.elf_fun_to_skip = elfFile().funcs['clean_D_PoU'] + # hack for running test code + self.elf_fun_to_skip = [] + else: + self.elf_fun_to_skip = [] + self.skip_fun = False def emitTopLevelFunction(self): @@ -37,15 +60,15 @@ def emitTopLevelFunction(self): self.debug_f.close() self.emitted_loop_counts_file.close() - def emitSyms (self): + def emitSyms(self): ef = elfFile() - for name in sorted(ef.syms.keys(),key=lambda x: ef.syms[x].addr): + for name in sorted(ef.syms.keys(), key=lambda x: ef.syms[x].addr): flag_str = '' sym = ef.syms[name] - #objects(O) in objdump is data + # objects(O) in objdump is data if 'O' in sym.flags: - flag_str += 'd' - #functions are text + flag_str += 'd' + # functions are text if 'F' in sym.flags: flag_str += 't' self.imm_f.write('s %s 0x%s %s %s\n' % (name, sym.addr, sym.ali_size, flag_str)) @@ -56,7 +79,7 @@ def emitEntry(self): self.emitString(s) def emitFunction(self): - #def emitFunction(imm_fun,imm_f,debug_f=None): + # def 
emitFunction(imm_fun,imm_f,debug_f=None): imm_fun = self.imm_fun imm_f = self.imm_f debug_f = self.debug_f @@ -65,87 +88,87 @@ def emitFunction(self): i_nodes = imm_fun.imm_nodes imm_loopheads = imm_fun.imm_loopheads - #locate the first and last addresses - first_addr,last_addr = self.firstAndLastAddr() - print 'first - last addrs : %x-%x' % (first_addr,last_addr) + # locate the first and last addresses + first_addr, last_addr = self.firstAndLastAddr() + print 'first - last addrs : %x-%x' % (first_addr, last_addr) size = 4 to_emit = {} - #dict of complex loop "head"s to ( addrs in the loop, its bound) + # dict of complex loop "head"s to ( addrs in the loop, its bound) complex_loops = {} - #we need to emit instructions in the order of addresses - #firstly, put all the lines in a dict + # we need to emit instructions in the order of addresses + # firstly, put all the lines in a dict for bb_start_addr in imm_fun.bbs: if self.skip_fun and bb_start_addr in self.elf_fun_to_skip.lines: continue for addr in imm_fun.bbs[bb_start_addr]: if addr in imm_loopheads: p_head, f = imm_loopheads[addr] - bin_head = phyAddrP(p_head,imm_fun.f_problems[f]) + bin_head = phyAddrP(p_head, imm_fun.f_problems[f]) import graph_refine.loop_bounds if imm_fun.loaded_loop_counts and bin_head in imm_fun.bin_loops_by_fs[f]: - #The user specified a manual loop-count override - loop_count,desc,_ = imm_fun.bin_loops_by_fs[f][bin_head] + # The user specified a manual loop-count override + loop_count, desc, _ = imm_fun.bin_loops_by_fs[f][bin_head] else: - print "imm_fun.loaded_loop_counts: %s, bin_loops_by_fs[f].keys: %s, function: %s" % (imm_fun.loaded_loop_counts, str(imm_fun.loops_by_fs[f]), f ) + print "imm_fun.loaded_loop_counts: %s, bin_loops_by_fs[f].keys: %s, function: %s" % ( + imm_fun.loaded_loop_counts, str(imm_fun.loops_by_fs[f]), f) assert False - loop_count,desc = graph_refine.loop_bounds.get_bound_super_ctxt(bin_head, []) + loop_count, desc = graph_refine.loop_bounds.get_bound_super_ctxt(bin_head, []) if graph_refine.loop_bounds.is_complex_loop(addr): body_addrs = graph_refine.loop_bounds.get_loop_addrs(addr) complex_loops[addr] = (body_addrs, loop_count) emitted_loop_counts[bin_head] = (loop_count, desc) - print '%x: bound %d/0x%x, %s' % (addr, loop_count, loop_count,desc) + print '%x: bound %d/0x%x, %s' % (addr, loop_count, loop_count, desc) else: loop_count = None - to_emit[addr] = (addr,addr == bb_start_addr,loop_count) + to_emit[addr] = (addr, addr == bb_start_addr, loop_count) for loop_addr in complex_loops.keys(): print "complex loop at 0x%x" % (addr) print "body: %s" % str(map(hex, body_addrs)) - #apply the loopcounts to all the instructions in this complex loop + # apply the loopcounts to all the instructions in this complex loop body_addrs, loop_bound = complex_loops[loop_addr] for addr in body_addrs: if addr not in to_emit: - #dodge the halt case + # dodge the halt case continue addr, is_start_bb, _ = to_emit[addr] - to_emit[addr] = (addr,is_start_bb, loop_bound) + to_emit[addr] = (addr, is_start_bb, loop_bound) emitted_loop_counts[addr] = (loop_bound, "complex_body") - - #then emit them in order - for addr in xrange (first_addr, last_addr + size, size): + # then emit them in order + for addr in xrange(first_addr, last_addr + size, size): if addr in to_emit: - addr,is_start_bb, loop_count = to_emit[addr] - self.emitImm(addr,i_nodes,is_start_bb,loop_count) + addr, is_start_bb, loop_count = to_emit[addr] + self.emitImm(addr, i_nodes, is_start_bb, loop_count) else: - #pad with nop - self.emitNop(addr, 
size) + # pad with nop + self.emitNop(addr, size) for bin_head in emitted_loop_counts: count, desc = emitted_loop_counts[bin_head] - self.emitted_loop_counts_file.write("0x%x : count %d, desc: %s\n" % ( bin_head, count, desc)) + self.emitted_loop_counts_file.write("0x%x : count %d, desc: %s\n" % (bin_head, count, desc)) def firstAndLastAddr(self): i_addrs = [] bbs = self.imm_fun.bbs for bb_n in bbs: i_addrs += bbs[bb_n] - #print 'chronos_emit i_addrs %s' % i_addrs - return min(i_addrs,key=int), max(i_addrs,key = int) + # print 'chronos_emit i_addrs %s' % i_addrs + return min(i_addrs, key=int), max(i_addrs, key=int) - def emitLiterals (self): + def emitLiterals(self): ef = elfFile() - for addr in sorted(ef.literals,key=int): - (size,value) = ef.literals[addr] - self.imm_f.write('v %s %s %d\n'% (hex(addr),value,size)) + for addr in sorted(ef.literals, key=int): + (size, value) = ef.literals[addr] + self.imm_f.write('v %s %s %d\n' % (hex(addr), value, size)) - def emitLoopcount (self,addr,loop_count): - self.imm_f.write('l 0x%x %s\n'% (addr,loop_count)) - print 'l 0x%x %s\n'% (addr,loop_count) + def emitLoopcount(self, addr, loop_count): + self.imm_f.write('l 0x%x %s\n' % (addr, loop_count)) + print 'l 0x%x %s\n' % (addr, loop_count) if self.debug_f: - self.debug_f.write('l 0x%x %s\n'% (addr,loop_count)) + self.debug_f.write('l 0x%x %s\n' % (addr, loop_count)) def emitString(self, s): self.imm_f.write(s) @@ -160,41 +183,86 @@ def emitNop(self, addr, size): s += '\n' self.emitString(s) - def emitImm(self,addr,nodes,is_startbb,loop_count): + + def emitArmImm(self, s, inst, value, txt): + i = inst + s += ' ' + i.mnemonic + ' ' + + if i.condition: + s += i.condition + ' ' + else: + s += '_ ' + + if i.setcc: + s += 's ' + else: + s += '_ ' + + for reg in i.input_registers: + s += 'input ' + reg + ' ' + for reg in i.output_registers: + s += 'output ' + reg + ' ' + if hasattr(i, 'shift_val'): + s += 'shift #' + i.shift_val + ' ' + i.shift_mode + ' ' + if hasattr(i, 'shift_reg'): + s += 'shift ' + i.shift_reg + ' ' + i.shift_mode + ' ' + # finally the raw inst and the text + + s += '%s ' % hexSansX(value) + s += '"%s"' % txt + s += '\n' + + self.emitString(s) + + + def emitRVImm(self, s, inst, value, txt): + s += ' ' + inst.mnemonic + ' ' + for r in inst.input_registers: + s += 'input ' + r + ' ' + if inst.has_imm: + s += 'input #' + inst.imm + ' ' + for r in inst.output_registers: + s += 'output ' + r + ' ' + + s += '%s ' % hexSansX(value) + s += '"%s"' % txt + s += '\n' + self.emitString(s) + + def emitImm(self, addr, nodes, is_startbb, loop_count): ''' Emit a single line of imm instruction ''' s = '' node = nodes[addr] - #if this is a loop head, emit its loop count + # if this is a loop head, emit its loop count if loop_count != None: - self.emitLoopcount (addr,loop_count) + self.emitLoopcount(addr, loop_count) if verbose: - print 'emitting %s: %s' % (addr,node.text) + print 'emitting %s: %s' % (addr, node.text) - #is this the start of a basic block ? + # is this the start of a basic block ? if is_startbb: bb = 'startbb' else: bb = 'contbb' - #all insts are of size 4 - s += ('i %s 4 %s' % (hex(addr),bb)) + # all insts are of size 4 + s += ('i %s 4 %s' % (hex(addr), bb)) - #output edges + # output edges s += ' edges' - #types of edges : next, call, callret,tailcall,return - #call: function call, callret : where to go when call returns ? 
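# (Rough sketch of the line this method builds, with made-up addresses and
# encoding: an ARM instruction at 0x8000 that calls 0x9000 and resumes at
# 0x8004 comes out roughly as
#   i 0x8000 4 startbb edges call 0x9000 callret 0x8004 end bl _ _ eb000400 "bl 9000"
# i.e. address, size, basic-block marker, the edge list closed by 'end', then
# the decoded fields appended by emitArmImm or emitRVImm.)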
- #return: this edge returns - #tailcall: namesake - + # types of edges : next, call, callret,tailcall,return + # call: function call, callret : where to go when call returns ? + # return: this edge returns + # tailcall: namesake - for e in sorted(node.edges, key = lambda x: x.targ): + for e in sorted(node.edges, key=lambda x: x.targ): if type(e.targ) != int: print 'e.targ %s' % e.targ if e.emit: - s += ' next '+ hex(e.targ) + s += ' next ' + hex(e.targ) dummy_call = False if node.call_edge: assert node.call_ret_edge != None @@ -207,7 +275,7 @@ def emitImm(self,addr,nodes,is_startbb,loop_count): s += ' tailcall ' + hex(node.call_edge.targ) else: s += ' call ' + hex(node.call_edge.targ) - #print 'call_ret_edge %s ' % node.call_ret_edge.targ + # print 'call_ret_edge %s ' % node.call_ret_edge.targ s += ' callret ' + hex(node.call_ret_edge.targ) if node.ret_edge: s += ' return' @@ -220,46 +288,26 @@ def emitImm(self,addr,nodes,is_startbb,loop_count): s += (' end') txt = node.text - #mnenomic condition setcc, input output etc - #print '%s: %s' % (addr,txt) + # mnenomic condition setcc, input output etc + # print '%s: %s' % (addr,txt) value = node.raw_inst if dummy_call: s += ' nop _ _ %s\n' % hexSansX(value) self.emitString(s) return - i = parser.decode_instruction(addr, value, txt) - s += ' ' + i.mnemonic + ' ' - if i.condition: - s += i.condition + ' ' + if self.arch == 'armv7': + i = parser.decode_instruction(addr, value, txt) + self.emitArmImm(s, i, value, txt) + elif self.arch == 'rv64': + i = riscv_parser.decode_instruction(addr, value, txt) + self.emitRVImm(s, i, value, txt) else: - s += '_ ' - - if i.setcc: - s += 's ' - else: - s += '_ ' - - for reg in i.input_registers: - s += 'input ' + reg + ' ' - for reg in i.output_registers: - s += 'output ' + reg + ' ' - if hasattr(i, 'shift_val'): - s += 'shift #' + i.shift_val + ' ' + i.shift_mode + ' ' - if hasattr(i,'shift_reg'): - s += 'shift ' + i.shift_reg + ' ' + i.shift_mode + ' ' - #finally the raw inst and the text - - s += '%s ' % hexSansX(value) - s += '"%s"' % txt - s += '\n' - - self.emitString(s) + print 'unsupported arch %s ' % self.arch + return def hexSansX(n): '''translate the input to a hex without the 0x prefix''' s = hex(n) return s[2:] - - diff --git a/graph-to-graph/chronos/parser.py b/graph-to-graph/chronos/parser.py index 6e8269f5..549bcae1 100644 --- a/graph-to-graph/chronos/parser.py +++ b/graph-to-graph/chronos/parser.py @@ -62,23 +62,23 @@ valid_conditions = ( '', 'ne', 'eq', - 'cs', 'hs', - 'cc', 'lo', - 'mi', 'pl', 'vs', 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le', + 'cs', 'hs', + 'cc', 'lo', + 'mi', 'pl', 'vs', 'vc', 'hi', 'ls', 'ge', 'lt', 'gt', 'le', ) valid_instruction_re = re.compile( r'''^(?: - (?P%(arith_instructions)s) - (?Ps?) - (?P%(conditions)s) | - (?P%(other_instructions)s) - (?P%(conditions)s) - )$''' % { - 'arith_instructions': '|'.join(valid_arith_instructions), - 'other_instructions': '|'.join(valid_other_instructions), - 'conditions': '|'.join(valid_conditions) - }, re.X) + (?P%(arith_instructions)s) + (?Ps?) + (?P%(conditions)s) | + (?P%(other_instructions)s) + (?P%(conditions)s) + )$''' % { + 'arith_instructions': '|'.join(valid_arith_instructions), + 'other_instructions': '|'.join(valid_other_instructions), + 'conditions': '|'.join(valid_conditions) + }, re.X) # # The following regexes take the arguments of a specific instruction (whose @@ -103,52 +103,52 @@ any_register = r'%s' % ('|'.join(list(all_registers) + aliases.keys())) ldrstr_args_re = re.compile( r'''(?:(?:%(any_register)s),\s*)? 
- (?P%(any_register)s),\s* - \[ - (?P%(any_register)s)\s* - (?:,\s* - (?: - \#(?P-?[0-9]+) | - (?P%(any_register)s)\s* - (?:,\s* - (?Plsl|lsr|asr|ror|rrx)\s+ - \#(?P[0-9]+) - )? - ) - )? - \] - (?: - (?P !) | - ,\s* (?P%(any_register)s) | - ,\s* \#(?P-?[0-9]+) - )?\s*(;.*)? - $''' % {'any_register' : any_register}, + (?P%(any_register)s),\s* + \[ + (?P%(any_register)s)\s* + (?:,\s* + (?: + \#(?P-?[0-9]+) | + (?P%(any_register)s)\s* + (?:,\s* + (?Plsl|lsr|asr|ror|rrx)\s+ + \#(?P[0-9]+) + )? + ) + )? + \] + (?: + (?P !) | + ,\s* (?P%(any_register)s) | + ,\s* \#(?P-?[0-9]+) + )?\s*(;.*)? + $''' % {'any_register' : any_register}, re.X) operand2 = r'''(?: - \#(?P-?[0-9]+) | - (?: - (?P%(any_register)s - ) - (?:,\s* - (?Plsl|lsr|asr|ror|rrx)\s+ - (?: - \#(?P[0-9]+) | - (?P%(any_register)s) - ) - )? - ) - )''' + \#(?P-?[0-9]+) | + (?: + (?P%(any_register)s + ) + (?:,\s* + (?Plsl|lsr|asr|ror|rrx)\s+ + (?: + \#(?P[0-9]+) | + (?P%(any_register)s) + ) + )? + ) + )''' onereg_and_operand2_re = re.compile( (r'''(?P%(any_register)s),\s*''' + operand2 + '(\s*;.*)?$') % { - 'any_register' : any_register}, + 'any_register' : any_register}, re.X) tworegs_and_operand2_re = re.compile( (r'''(?P%(any_register)s),\s* - (?P%(any_register)s),\s*''' + operand2 + '(\s*;.*)?$') % { - 'any_register' : any_register}, + (?P%(any_register)s),\s*''' + operand2 + '(\s*;.*)?$') % { + 'any_register' : any_register}, re.X) @@ -156,7 +156,7 @@ #just used for decoding for us class ARMInstruction: def __init__(self, addr, value, disassembly, - mnemonic, condition, dirflags, cpsflags, setcc, args): + mnemonic, condition, dirflags, cpsflags, setcc, args): self.addr = addr self.value = value @@ -234,10 +234,10 @@ def decode(self): # Record input and output registers. # if load: self.output_registers.append(args['target_reg']) - #self.input_registers.append('memory') + #self.input_registers.append('memory') # else: # self.input_registers.append(args['target_reg']) - #self.output_registers.append('memory') + #self.output_registers.append('memory') self.input_registers.append(args['base_addr_reg']) if args['incr_reg']: self.input_registers.append(args['incr_reg']) @@ -269,8 +269,8 @@ def decode(self): addr_reg, reg_list = [x.strip() for x in self.args.split(',', 1)] writeback = addr_reg[-1] == '!' if writeback: - self.output_registers.append('writeback') - addr_reg = addr_reg.rstrip('!') + self.output_registers.append('writeback') + addr_reg = addr_reg.rstrip('!') # self.output_registers.append(addr_reg) # self.input_registers.append(addr_reg) @@ -556,17 +556,21 @@ def decode(self): # Convert above into mnemonic -> class map. 
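# (Illustrative only: if the groups table held {('ldr', 'str'): LoadStoreInstruction},
# the comprehension below would flatten it to
# {'ldr': LoadStoreInstruction, 'str': LoadStoreInstruction}; the actual group
# tuples and classes are the ones defined above, not these hypothetical names.)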
mnemonic_to_class_map = dict([(m, c) - for ms, c in mnemonic_groups_to_class_map.iteritems() - for m in ms]) + for ms, c in mnemonic_groups_to_class_map.iteritems() + for m in ms]) def decode_instruction(addr, value, decoding): + print(decoding) decoding = decoding.strip() bits = decoding.split(None, 1) if len(bits) == 1: instruction, args = bits[0], [] else: instruction, args = bits - + print(decoding) + print(instruction) + print(args) + print(value) g = valid_instruction_re.match(instruction) if g is None: raise FatalError("Unknown instruction %s at address %#x" % (instruction, addr)) @@ -594,7 +598,7 @@ def decode_instruction(addr, value, decoding): #print '%s: %s \n instruction %s \n condition %s\n dirflags %s\n cpsflags %s\n setcc %s\n args %s\n' % (addr,decoding, instruction,condition,dirflags,cpsflags,setcc,args) arm_inst = cls(addr, value, decoding, - instruction, condition, dirflags, cpsflags, setcc, args) + instruction, condition, dirflags, cpsflags, setcc, args) arm_inst.decode() mnemonic = arm_inst.mnemonic diff --git a/graph-to-graph/chronos/riscv_parser.py b/graph-to-graph/chronos/riscv_parser.py new file mode 100644 index 00000000..957ab2b1 --- /dev/null +++ b/graph-to-graph/chronos/riscv_parser.py @@ -0,0 +1,534 @@ +# +# Copyright 2020, Cog Systems +# +# SPDX-License-Identifier: BSD-2-Clause +# + +import re + +''' +Generate the input file for RISC-V Chronos. + +Work in progress. + +Regular-expression-based single-instruction parsing utility. +Appropriated from previous work on interfacing with Chronos. + +Please do not blame the author for the madness if reading this gives you +a headache; the Chronos code is no better. + +''' + +''' +The following constants group instruction mnemonics by operand format; +decode_instruction uses them to pick the class that knows how to parse +a given instruction's arguments.
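For example, a disassembled line such as 'add a0,a1,a2' is dispatched by its
mnemonic to the RdRs1Rs2 group; a rough usage sketch (the address and raw
word below are made up for illustration) is:

    inst = decode_instruction(0xffffffc000002000, 0x00c58533, 'add a0,a1,a2')
    # inst.input_registers == ['a1', 'a2'], inst.output_registers == ['a0']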
+''' + +# set rd = imm +rd_imm = ( + 'lui', + 'li', + 'addpc', + 'jal', # rd = pc + 4; pc = pc + offset; but use the absolute offset + 'auipc', +) + +zero_oprand = ( + 'fence.i', + 'wfi', + 'nop', + 'ret', # ret == jalr x0, 0(x1) +) + +imm_only = ( + 'j', # j #offset +) +# rd = rs1 some are pseudo instructions, but that does not matter +rd_rs1 = ( + 'mv', + 'sfence.mva', +) + +# rd = op rs2 +rd_rs2 = ( + 'neg', + 'negw', +) + +rs1_imm = { + 'beqz', + 'bltz', + 'bgez', + 'bnez', +} + +rs2_imm = { + 'blez', + 'bgtz', +} + +# rd = some op using rs1 and imm +rd_rs1_imm = ( + 'lb', + 'lh', + 'lw', + 'lwu', + 'ld', + 'lbu', + 'lhu', + 'addi', # rd = rs1 + imm + 'slti', + 'xori', + 'ori', + 'andi', + 'addiw', + 'slli', + 'slliw', + 'srli', + 'srliw', + 'srai', + 'sraiw', + 'not', # not rd, rs1 == xori rd, rs1, #-1 + 'sext.w', # sext.w rd, rs1 == addiw rd, rs1, #0 + 'seqz', # seqz rd, rs1 == sltiu rd, rs1, #1 + 'jalr', # rd = pc + 4; pc = (rs1 + imm) * 2 +) + +# ops using rs1, rs2, and imm +rs1_rs2_imm = ( + 'sb', # u8[rs1 + imm] = rs2 + 'sh', + 'sw', + 'sd', + 'st', + 'beq', # fi rs1 == rs2 pc = pc + imm + 'bne', + 'blt', + 'bltu', + 'bge', + 'bgeu', +) + +# ops using rd = rs1 op rs2 +rd_rs1_rs2 = [ + 'add', + 'addw', + 'sub', + 'subw', + 'sll', + 'sllw', + 'slt', + 'sltu', + 'xor', + 'srl', + 'srlw', + 'sra', + 'sraw', + 'or', + 'mul', + 'mulw', + 'mulh', + 'mulhsu', + 'mulhu', + 'div', + 'divw', + 'divu', + 'divuw', + 'rem', + 'remu', + 'remw', + 'remuw', + 'sltz', # sltz rd, rs1 == slt rd, rs1, x0 + 'sgtz', # sgtz rd, rs2 == slt rd, x0, rs2 +] + +csr = ( + 'csrrw', + 'csrrs', + 'csrrc', + 'csrrwi', + 'csrrsi', + 'csrrci', + 'csrr', + 'csrw', + 'csrs', + 'csrc', +) + +all_registers = ( + 'x0', 'x1', + 'x2', 'x3', + 'x4', 'x5', + 'x6', 'x7', + 'x8', 'x9', + 'x10','x11', + 'x12','x13', + 'x14','x15', + 'x16','x17', + 'x18','x19', + 'x20','x21', + 'x22','x23', + 'x24','x25', + 'x26','x27', + 'x28','x29', + 'x30','x31', + 'pc' +) + +aliases = { + 'zero': 'x0', + 'ra': 'x1', + 'sp': 'x2', + 'gp': 'x3', + 'tp': 'x4', + 't0': 'x5', + 't1': 'x6', + 't2': 'x7', + 'fp': 'x8', + 's0': 'x8', + 's1': 'x9', + 'a0': 'x10', + 'a1': 'x11', + 'a2': 'x12', + 'a3': 'x13', + 'a4': 'x14', + 'a5': 'x15', + 'a6': 'x16', + 'a7': 'x17', + 's2': 'x18', + 's3': 'x19', + 's4': 'x20', + 's5': 'x21', + 's6': 'x22', + 's7': 'x23', + 's8': 'x24', + 's9': 'x25', + 's10': 'x26', + 's11': 'x27', + 't3': 'x28', + 't4': 'x29', + 't5': 'x30', + 't6': 'x31' +} + +csrs = ( + 'sstatus', + 'stvec', + 'sip', + 'sie', + 'scounteren', + 'sscratch', + 'sepc', + 'scause', + 'stval', + 'satp', +) + +any_register = r'%s' % ('|'.join(list(all_registers) + aliases.keys())) + +def is_hex(imm): + if len(imm) >= 2 and imm[0] == '0' and imm[1] == 'x': + return True + for c in imm: + if c in ['a', 'A', 'b', 'B', 'c', 'C', 'd', 'D', 'e', 'E', 'f', 'F']: + return True + + return False + + +def to_int(imm): + fs = imm.split() + imm = fs[0] + try: + if is_hex(imm): + return int(imm, base = 16) + else: + return int(imm) + except Exception as e: + print e + print 'fail to convert %s' % imm + raise + +def valid_gp_reg(reg): + return reg in all_registers or reg in list(aliases.keys()) + +class RVInstruction: + def __init__(self, addr, value, disassembly, mnemonic, args): + self.rd = '' + self.rs1 = '' + self.rs2 = '' + self.imm = '' + self.imm_val = 0 + self.has_imm = False + self.rd_csr = '' + self.rs_csr = '' + + self.addr = addr + self.value = value + self.disassembly = disassembly + + self.mnemonic = mnemonic + self.args = args + self.is_loop_cond = 
False + print "%s %s %s" % (addr, mnemonic, args) + print type(args) + self.output_registers = [] + self.input_registers = [] + + def decode(self): + raise NotImplementedError + + +class RdImm(RVInstruction): + def decode(self): + print 'args %s' % self.args + fields = self.args.split(',') + assert len(fields) == 2 + self.rd = fields[0].strip() + assert valid_gp_reg(self.rd) + self.imm = fields[1].strip().split()[0] + self.imm_val = to_int(self.imm) + self.output_registers.append(self.rd) + self.has_imm = True + + +class ZeroOprand(RVInstruction): + def decode(self): + if self.mnemonic == 'ret': + self.input_registers.append('x1') + +class ImmOnly(RVInstruction): + def decode(self): + self.imm = self.args.strip().split()[0] + self.imm_val = to_int(self.imm) + self.has_imm = True + +class RdRs1(RVInstruction): + def decode(self): + fs = self.args.split(',') + assert len(fs) == 2 + self.rd = fs[0].strip() + self.rs1 = fs[1].strip() + assert valid_gp_reg(self.rd) and valid_gp_reg(self.rs1) + self.output_registers.append(self.rd) + self.input_registers.append(self.rs1) + +class RdRs2(RVInstruction): + def decode(self): + fs = self.args.split(',') + assert len(fs) == 2 + self.rd = fs[0].strip() + self.rs2 = fs[1].strip() + assert valid_gp_reg(self.rd) and valid_gp_reg(self.rs2) + self.output_registers.append(self.rd) + self.input_registers.append(self.rs2) + +class RdRs1Imm(RVInstruction): + def decode(self): + fs = self.args.split(',') + l = len(fs) + self.has_imm = True + assert l == 2 or l == 3 + + ''' + if length is 2 we have the form: + ld a1,64(a4) + if the length is 3, the form: + addi s0,sp,640 + ''' + + if l == 2: + self.rd = fs[0].strip() + fs[1] = fs[1].strip() + left = fs[1].find('(') + right = fs[1].find(')') + + if self.mnemonic == 'sext.w': + self.imm = '0' + self.imm_val = 0 + self.rs1 = fs[1] + else: + assert left != -1 and right != -1 + self.imm = fs[1][0:left] + self.imm_val = to_int(self.imm) + self.rs1 = fs[1][left + 1 : right] + + self.output_registers.append(self.rd) + assert valid_gp_reg((self.rs1)) + + if l == 3: + self.rd = fs[0].strip() + self.rs1 = fs[1].strip() + self.imm = fs[2].strip() + self.imm_val = to_int(self.imm) + self.input_registers.append(self.rs1) + self.output_registers.append(self.rd) + +class Rs1Imm(RVInstruction): + def decode(self): + fs = self.args.split(',') + assert len(fs) == 2 + self.rs1 = fs[0].strip() + self.imm = fs[1].strip() + self.imm_val = to_int(self.imm) + self.input_registers.append(self.rs1) + self.has_imm = True + assert valid_gp_reg(self.rs1) + +class Rs2Imm(RVInstruction): + def decode(self): + fs = self.args.split(',') + assert len(fs) == 2 + self.rs2 = fs[0].strip() + self.imm = fs[1].strip() + self.imm_val = to_int(self.imm) + self.input_registers.append(self.rs2) + self.has_imm = True + assert valid_gp_reg(self.rs2) + +class Rs1Rs2Imm(RVInstruction): + def decode(self): + fs = self.args.split(',') + l = len(fs) + self.has_imm = True + assert l == 2 or l == 3 + + if l == 2: + self.rs2 = fs[0] + fs[1] = fs[1].strip() + left = fs[1].find('(') + right = fs[1].find(')') + assert left != -1 and right != -1 + self.imm = fs[1][0:left] + self.imm_val = to_int(self.imm) + self.rs1 = fs[1][left + 1 : right] + self.input_registers.append(self.rs1) + self.input_registers.append(self.rs2) + + if l == 3: + self.rs1 = fs[0].strip() + self.rs2 = fs[1].strip() + self.imm = fs[2].strip().split()[0] + self.imm_val = to_int(self.imm) + self.input_registers.append(self.rs1) + self.input_registers.append(self.rs2) + + assert valid_gp_reg(self.rs1) 
and valid_gp_reg(self.rs2) + +class RdRs1Rs2(RVInstruction): + def decode(self): + fs = self.args.split(',') + assert len(fs) == 3 + self.rd = fs[0].strip() + self.rs1 = fs[1].strip() + self.rs2 = fs[2].strip() + assert valid_gp_reg(self.rd) + assert valid_gp_reg(self.rs1) + assert valid_gp_reg(self.rs2) + self.input_registers.append(self.rs1) + self.input_registers.append(self.rs2) + self.output_registers.append(self.rd) + + +class CSR(RVInstruction): + def decode(self): + fs = self.args.split(',') + if self.mnemonic in ['csrrw', 'csrrs', 'csrrc']: + self.rd = fs[0].strip() + self.rd_csr = self.rs_csr = fs[1].strip() + self.rs1 = fs[2].strip() + pass + if self.mnemonic in ['csrwi']: + self.rd_csr = fs[0].strip() + self.imm = fs[1].strip() + self.imm_val = to_int(self.imm) + self.output_registers.append(self.rd_csr) + assert self.rd_csr in csrs + if self.mnemonic in ['csrrsi', 'csrrci']: + self.rd = fs[0].strip() + self.rs_csr = self.rd_csr = fs[1].strip() + self.imm = fs[2].strip() + self.imm_val = to_int(self.imm) + self.output_registers.append(self.rd) + self.output_registers.append(self.rd_csr) + self.input_registers.append(self.rs_csr) + assert valid_gp_reg(self.rd) + assert self.rd_csr in csrs + if self.mnemonic in ['csrr']: + self.rd = fs[0].strip() + self.rd_csr = fs[1].strip() + assert valid_gp_reg(self.rd) + assert self.rd_csr in csrs + if self.mnemonic in ['csrw', 'csrc', 'csrs']: + self.rd = fs[0].strip() + self.rd_csr = self.rs_csr = fs[1].strip() + self.rs1 = fs[2].strip() + assert valid_gp_reg(self.rd) + assert valid_gp_reg(self.rs1) + assert self.rd_csr in csrs + assert self.rs_csr in csrs + self.output_registers.append(self.rd) + self.output_registers.append(self.rd_csr) + self.input_registers.append(self.rs1) + self.input_registers.append(self.rs_csr) + if self.mnemonic in ['csrwi', 'csrsi', 'csrci']: + self.rd_csr = fs[0].strip() + self.imm = fs[1].strip() + self.imm_value = to_int(self.imm) + assert self.rd_csr in csrs + self.output_registers.append(self.rd_csr) + +class UnhandledInstruction(RVInstruction): + def decode(self): + NopInstruction.decode(self) + print 'Unhandled instruction "%s" at %#x' % (self.mnemonic, self.addr) + + +def decode_instruction(addr, value, decoding): + decoding = decoding.strip() + print decoding + bits = decoding.split(None, 1) + if len(bits) == 1: + instruction, args = bits[0], [] + else: + instruction, args = bits + + oi = instruction + print instruction + + if oi in rd_imm: + cls = RdImm + elif oi in zero_oprand: + cls = ZeroOprand + elif oi in imm_only: + cls = ImmOnly + elif oi in rd_rs1: + cls = RdRs1 + elif oi in rd_rs2: + cls = RdRs2 + elif oi in rs1_imm: + cls = Rs1Imm + elif oi in rs2_imm: + cls = Rs2Imm + elif oi in rd_rs1_imm: + cls = RdRs1Imm + elif oi in rs1_rs2_imm: + cls = Rs1Rs2Imm + elif oi in rd_rs1_rs2: + cls = RdRs1Rs2 + elif oi in csr: + cls = CSR + else: + print "unsopported instructions %s" % oi + assert False + + print '%s %s' % (instruction, cls) + inst = cls(addr, value, decoding, instruction,args) + + inst.decode() + + mnemonic = inst.mnemonic + output_registers = inst.output_registers + input_registers = inst.input_registers + return inst diff --git a/graph-to-graph/dot_utils.py b/graph-to-graph/dot_utils.py index eed35adf..ffce2b20 100644 --- a/graph-to-graph/dot_utils.py +++ b/graph-to-graph/dot_utils.py @@ -5,8 +5,10 @@ # import pydot +import subprocess from elf_file import * + ''' All functions in this file will be updated to reflect changes made, they are currently outdated and produce misleading outputs if 
called. This file was used for debugging the toolchains. diff --git a/graph-to-graph/elf_parser.py b/graph-to-graph/elf_parser.py index 2dfc66ac..8a63a6d8 100644 --- a/graph-to-graph/elf_parser.py +++ b/graph-to-graph/elf_parser.py @@ -7,6 +7,8 @@ from elf_file import * import re +elf_arch = 'armv7' + def parseSymTab(): ef = elfFile() #parse the symbol table @@ -36,6 +38,7 @@ def parseSymTab(): break def parseTxt (): + print elf_arch ef = elfFile() curr_func_name = None skip_next_line = False @@ -47,8 +50,16 @@ def parseTxt (): #ingore empty lines and the header if line in ['\n','\r\n']: continue - header = re.search('kernel\.elf:\s*file\s*format\s*elf32-littlearm',line) - header2 = re.search('Disassembly of section \..*:',line) + if elf_arch == 'armv7': + header = re.search('kernel\.elf:\s*file\s*format\s*elf32-littlearm',line) + header2 = re.search('Disassembly of section \..*:',line) + elif elf_arch == 'rv64': + header = re.search('kernel\.o:\s*file\s*format\s*elf64-littleriscv', line) + header2 = re.search('Disassembly of section \..*:', line) + else: + print 'Unsupported arch %s' % elf_arch + assert False + if header != None or header2 != None: continue #ndsk_boot is a strange function only used in bootstrapping @@ -76,6 +87,8 @@ def parseTxt (): else: #check if this is a literal line literal = re.search('(?P.*):\s*[a-f0-9]+\s*(?P(\.word)|(\.short)|(\.byte))\s*(?P0x[a-f0-9]+)$',line) + print line + print literal if literal != None: if literal.group('size') == '.word': size = 4 @@ -105,8 +118,10 @@ def isDirectBranch(addr): match = re.search(r'[a-f0-9]+:\s*[a-f0-9]+\s+(b|bx)\s+.*',inst) return match is not None -def parseElf(dir_name): +def parseElf(dir_name, arch='armv7'): ef = elfFile(dir_name) + global elf_arch + elf_arch = arch parseTxt() parseSymTab() return ef diff --git a/graph-to-graph/graph_to_graph.py b/graph-to-graph/graph_to_graph.py index 0c909ec9..62cea16e 100644 --- a/graph-to-graph/graph_to_graph.py +++ b/graph-to-graph/graph_to_graph.py @@ -27,6 +27,7 @@ def printHelp(): --i interactive mode (for debugging) --x automated WCET estimating, firstly generate the loop heads, then automatically deduce the loop bounds, and finally use the automatically determined loopbounds to estimate teh WCET. 
A conflict file specifying additional preemption points --xL same as --x but do not generate (and thus overwrite) loop_counts.py + --a architecture (armv7, rv64) armv7 by default ''' @@ -45,10 +46,19 @@ def printHelp(): dir_name = sys.argv[1] print 'dir_name: %s' % dir_name flag = sys.argv[3] - assert flag in ['--l','--L','--i', '--x', '--xL'] + # hack fo now, assume the last two parameters are arch + # and we use armv7 by default + if len(sys.argv) == 6: + arch = sys.argv[5] + else: + arch = 'armv7' + print len(sys.argv) + print "arch is %s" % arch + assert arch in ['armv7', 'rv64'] + assert flag in ['--l','--L','--i', '--x', '--xL', '--a'] if flag == '--l': gen_heads = True - bench.bench(dir_name, entry_point_function, gen_heads, load_counts,interactive) + bench.bench(dir_name, entry_point_function, gen_heads, load_counts,interactive, False, None, arch) sys.exit(0) if flag == '--L': load_counts = True @@ -58,7 +68,7 @@ def printHelp(): if len(sys.argv) < 4: printHelp() sys.exit(-1) - asm_fs = bench.init(dir_name) + asm_fs = bench.init(dir_name, arch) if flag == '--x': import convert_loop_bounds analyseFunction(entry_point_function,asm_fs, dir_name, True, False, False) @@ -79,5 +89,5 @@ def printHelp(): preemption_limit = 5 conflict.conflict(entry_point_function, tcfg_map_file_name, [], stripped_ilp, ilp_to_generate, dir_name, sol_to_generate, emit_conflicts=True, do_cplex=True, preempt_limit= preemption_limit,default_phantom_preempt=True) sys.exit(0) - bench_ret = bench.bench(dir_name, entry_point_function, gen_heads,load_counts,interactive) + bench_ret = bench.bench(dir_name, entry_point_function, gen_heads,load_counts,interactive, False, None, arch) print 'bench returned: ' + str(bench_ret) diff --git a/inst_logic.py b/inst_logic.py index f1c106ce..cfc4c3ae 100644 --- a/inst_logic.py +++ b/inst_logic.py @@ -19,176 +19,416 @@ import re -reg_aliases = {'sb': 'r9', 'sl': 'r10', 'fp': 'r11', 'ip': 'r12', - 'sp': 'r13', 'lr': 'r14', 'pc': 'r15'} +reg_aliases_armv7 = {'sb': 'r9', 'sl': 'r10', 'fp': 'r11', 'ip': 'r12', + 'sp': 'r13', 'lr': 'r14', 'pc': 'r15'} + +reg_aliases_rv64 = { + 'zero': 'r0', + 'ra': 'r1', + 'sp': 'r2', + 'gp': 'r3', + 'tp': 'r4', + 't0': 'r5', + 't1': 'r6', + 't2': 'r7', + 'fp': 'r8', + 's0': 'r8', + 's1': 'r9', + 'a0': 'r10', + 'a1': 'r11', + 'a2': 'r12', + 'a3': 'r13', + 'a4': 'r14', + 'a5': 'r15', + 'a6': 'r16', + 'a7': 'r17', + 's2': 'r18', + 's3': 'r19', + 's4': 'r20', + 's5': 'r21', + 's6': 'r22', + 's7': 'r23', + 's8': 'r24', + 's9': 'r25', + 's10': 'r26', + 's11': 'r27', + 't3': 'r28', + 't4': 'r29', + 't5': 'r30', + 't6': 'r31' +} + +csrs_rv64 = ( + 'sstatus', + 'stvec', + 'sip', + 'sie', + 'scounteren', + 'sscratch', + 'sepc', + 'scause', + 'stval', + 'satp', +) reg_set = set (['r%d' % i for i in range (16)]) +reg_set_rv64 = set(['x%d' % i for i in range(32)]) + inst_split_re = re.compile('[_,]*') + +def split_inst_name_regs_rv64(nm): + reg_aliases = reg_aliases_rv64 + bits = inst_split_re.split(nm) + fin_bits = [] + regs = [] + for i in range(len(bits)): + if bits[i] in reg_aliases_rv64.keys(): + if bits[i] == 'zero' and bits[0] == 'sfence.vma': + fin_bits.append('x0') + else: + regs.append(reg_aliases_rv64.get(bits[i])) + fin_bits.append('-argv%d' % len(regs)) + + elif bits[0] == 'ecall': + fin_bits.append(bits[0]) + regs.append('r10') + + fin_bits.append('-argv%d' % len(regs)) + + regs.append('r11') + + fin_bits.append('-argv%d' % len(regs)) + + regs.append('r12') + + fin_bits.append('-argv%d' % len(regs)) + + regs.append('r17') + + 
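+            # Clarifying note (added here, not part of the original change):
+            # a bare 'ecall' carries no register operands in the disassembly,
+            # so its calling convention is modelled explicitly. At this point
+            # a0-a2 (r10-r12) and a7 (r17) have been recorded as inputs; the
+            # next lines add a7's '-argv' placeholder and then mark a0 (r10)
+            # as the ecall result via a '-ret' placeholder. The commented-out
+            # lines below suggest a4/a5 (r14/r15) were once inputs as well.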
fin_bits.append('-argv%d' % len(regs)) + + regs.append('r10') + fin_bits.append('-ret%d' % len(regs)) +# regs.append('r14') + +# fin_bits.append('-argv%d' % len(regs)) + +# regs.append('r15') + +# fin_bits.append('-argv%d' % len(regs)) + + elif bits[i] == 'x0' and bits[0] == 'sfence.vma': + fin_bits.append(bits[i]) + elif bits[i] in reg_set_rv64: + regs.append('r' + bits[i][1:]) + fin_bits.append('-argv%d' % len(regs)) + elif bits[i] in csrs_rv64: + #regs.append(bits[i]) + fin_bits.append(bits[i]) + elif bits[i].startswith('%'): + regs.append(bits[i]) + fin_bits.append('-argv%d' % len(regs)) + else: + fin_bits.append(bits[i]) + for f in fin_bits: + if f.startswith('-argv'): + fin_bits.remove(f) + fin_bits.append(f) + return (regs, '_'.join (fin_bits)) + + crn_re = re.compile('cr[0123456789][0123456789]*') pn_re = re.compile('p[0123456789][0123456789]*') -def split_inst_name_regs (nm): - bits = inst_split_re.split (nm) - fin_bits = [] - regs = [] - if len (bits) > 1 and pn_re.match (bits[1]): - bits[1] = bits[1][1:] - for bit in bits: - bit2 = bit.lower () - bit2 = reg_aliases.get (bit, bit) - if crn_re.match (bit2): - assert bit2.startswith ('cr') - bit2 = 'c' + bit2[2:] - if bit2 in reg_set or bit2.startswith ('%'): - regs.append (bit2) - fin_bits.append ('argv%d' % (len (regs))) - else: - fin_bits.append (bit2) - return (regs, '_'.join (fin_bits)) + +def split_inst_name_regs_armv7 (nm): + reg_aliases = reg_aliases_armv7 + bits = inst_split_re.split (nm) + fin_bits = [] + regs = [] + if len (bits) > 1 and pn_re.match (bits[1]): + bits[1] = bits[1][1:] + for bit in bits: + bit2 = bit.lower () + bit2 = reg_aliases.get (bit, bit) + if crn_re.match (bit2): + assert bit2.startswith ('cr') + bit2 = 'c' + bit2[2:] + if bit2 in reg_set or bit2.startswith ('%'): + regs.append (bit2) + fin_bits.append ('argv%d' % (len (regs))) + else: + fin_bits.append (bit2) + return (regs, '_'.join (fin_bits)) + + +def split_inst_name_regs(nm): + if syntax.arch.name == 'armv7': + return split_inst_name_regs_armv7(nm) + elif syntax.arch.name == 'rv64': + return split_inst_name_regs_rv64(nm) + else: + assert False bin_globs = [('mem', syntax.builtinTs['Mem'])] asm_globs = [('Mem', syntax.builtinTs['Mem'])] def mk_fun (nm, word_args, ex_args, word_rets, ex_rets, globs): - """wrapper for making a syntax.Function with standard args/rets.""" - return syntax.Function (nm, - [(nm, syntax.word32T) for nm in word_args] + ex_args + globs, - [(nm, syntax.word32T) for nm in word_rets] + ex_rets + globs) - -instruction_fun_specs = { - 'mcr' : ("impl'mcr", ["I"]), - 'mcr2' : ("impl'mcr", ["I"]), - 'mcrr' : ("impl'mcrr", ["I", "I"]), - 'mcrr2' : ("impl'mcrr", ["I", "I"]), - 'mrc': ("impl'mrc", ["O"]), - 'mrc2': ("impl'mrc", ["O"]), - 'mrrc': ("impl'mrrc", ["O", "O"]), - 'mrrc2': ("impl'mrrc", ["O", "O"]), - 'dsb': ("impl'dsb", []), - 'dmb': ("impl'dmb", []), - 'isb': ("impl'isb", []), - 'wfi': ("impl'wfi", []), + """wrapper for making a syntax.Function with standard args/rets.""" + #assert False + #rv64_hack + if syntax.arch.is_64bit: + wordT = syntax.word64T + else: + wordT = syntax.word32T + + return syntax.Function (nm, + [(nm, wordT) for nm in word_args] + ex_args + globs, + [(nm, wordT) for nm in word_rets] + ex_rets + globs) + +instruction_fun_specs_armv7 = { + 'mcr' : ("impl'mcr", ["I"]), + 'mcr2' : ("impl'mcr", ["I"]), + 'mcrr' : ("impl'mcrr", ["I", "I"]), + 'mcrr2' : ("impl'mcrr", ["I", "I"]), + 'mrc': ("impl'mrc", ["O"]), + 'mrc2': ("impl'mrc", ["O"]), + 'mrrc': ("impl'mrrc", ["O", "O"]), + 'mrrc2': ("impl'mrrc", ["O", 
"O"]), + 'dsb': ("impl'dsb", []), + 'dmb': ("impl'dmb", []), + 'isb': ("impl'isb", []), + 'wfi': ("impl'wfi", []), +} + +instruction_fun_specs_rv64 = { + + 'fence_i': ("impl'fence_i", []), + 'sfence.vma': ("impl'sfence_vma", []), + 'sfence.vma_x0': ("impl'sfence.vma_x0", ["I"]), + 'csrc_sstatus': ("impl'csrc_sstatus", ["I"]), + 'csrr_sip': ("impl'csrr_sip", ["O"]), + 'csrr_sepc': ("impl'csrr_sepc", ["O"]), + 'csrr_scause': ("impl'csrr_scause", ["O"]), + 'csrr_sstatus': ("impl'csrr_sstatus", ["O"]), + 'csrr_sscratch': ("impl'csrr_sscratch", ["O"]), + 'csrr_stval': ("impl'csrr_stval", ["O"]), + 'csrr_sbadaddr': ("impl'csrr_sbadaddr", ["O"]), + 'csrr_sie': ("impl'csrr_sie", ["O"]), + 'csrw_sptbr': ("impl'csrw_sptbr", ["I"]), + 'csrw_sepc': ("impl'csrw_sepc", ["I"]), + 'csrw_stvec': ("impl'csrw_stvec", ["I"]), + 'csrw_satp': ("impl'csrw_satp", ["I"]), + 'csrw_sscratch': ("impl'csrw_sscratch", ["I"]), + 'csrw_sstatus': ("impl'sstatus", ["I"]), + 'csrw_sie': ("impl'csrw_sie", ["I"]), + 'csrrc_sie': ("impl'csrrc_sie", ["O", "I"]), + 'csrrs_sie': ("impl'csrrs_sie", ["O", "I"]), + 'csrrw_sscratch': ("impl'csrrw_sscratch", ["O", "I"]), + 'sc_w': ("impl'sc_w", ["I","O","I"]), + 'wfi': ("impl'wfi", []), + 'sret': ("sret", []), + 'ebreak': ("impl'ebreak", []), + 'ecall': ("impl'ecall", ["I", "I", "I", "I", "O"]), + 'rdtime': ("impl'rdtime", ["O"]), + 'rdcycle': ("impl'rdcycle", ["O"]), + 'unimp': ("unimp", []), } instruction_name_aliases = { - 'isb_sy': 'isb', - 'dmb_sy': 'dmb', - 'dsb_sy': 'dsb', + 'isb_sy': 'isb', + 'dmb_sy': 'dmb', + 'dsb_sy': 'dsb', } def add_impl_fun (impl_fname, regspecs): - l_fname = 'l_' + impl_fname - r_fname = 'r_' + impl_fname - if l_fname in functions: - return - assert r_fname not in functions - - ident_v = ("inst_ident", syntax.builtinTs['Token']) - - inps = [s for s in regspecs if s == 'I'] - inps = ['reg_val%d' % (i + 1) for (i, s) in enumerate (inps)] - rets = [s for s in regspecs if s == 'O'] - rets = ['ret_val%d' % (i + 1) for (i, s) in enumerate (rets)] - l_fun = mk_fun (l_fname, inps, [ident_v], rets, [], bin_globs) - r_fun = mk_fun (r_fname, inps, [ident_v], rets, [], bin_globs) - inp_eqs = [((mk_var (nm, typ), 'ASM_IN'), (mk_var (nm, typ), 'C_IN')) - for (nm, typ) in l_fun.inputs] - inp_eqs += [((logic.mk_rodata (mk_var (nm, typ)), 'ASM_IN'), - (syntax.true_term, 'C_IN')) for (nm, typ) in bin_globs] - out_eqs = [((mk_var (nm, typ), 'ASM_OUT'), (mk_var (nm, typ), 'C_OUT')) - for (nm, typ) in l_fun.outputs] - out_eqs += [((logic.mk_rodata (mk_var (nm, typ)), 'ASM_OUT'), - (syntax.true_term, 'C_OUT')) for (nm, typ) in bin_globs] - pair = logic.Pairing (['ASM', 'C'], {'ASM': l_fname, 'C': r_fname}, - (inp_eqs, out_eqs)) - assert l_fname not in pairings - assert r_fname not in pairings - functions[l_fname] = l_fun - functions[r_fname] = r_fun - pairings[l_fname] = [pair] - pairings[r_fname] = [pair] - -inst_addr_re = re.compile('E[0123456789][0123456789]*') + #assert False + l_fname = 'l_' + impl_fname + r_fname = 'r_' + impl_fname + + if l_fname in functions: + #print 'skip_add %s' % l_fname + return + + assert r_fname not in functions + + ident_v = ("inst_ident", syntax.builtinTs['Token']) + + inps = [s for s in regspecs if s == 'I'] + inps = ['reg_val%d' % (i + 1) for (i, s) in enumerate (inps)] + rets = [s for s in regspecs if s == 'O'] + rets = ['ret_val%d' % (i + 1) for (i, s) in enumerate (rets)] + + l_fun = mk_fun (l_fname, inps, [ident_v], rets, [], bin_globs) + r_fun = mk_fun (r_fname, inps, [ident_v], rets, [], bin_globs) + + inp_eqs = [((mk_var (nm, typ), 
'ASM_IN'), (mk_var (nm, typ), 'C_IN')) + for (nm, typ) in l_fun.inputs] + inp_eqs += [((logic.mk_rodata (mk_var (nm, typ)), 'ASM_IN'), + (syntax.true_term, 'C_IN')) for (nm, typ) in bin_globs] + out_eqs = [((mk_var (nm, typ), 'ASM_OUT'), (mk_var (nm, typ), 'C_OUT')) + for (nm, typ) in l_fun.outputs] + out_eqs += [((logic.mk_rodata (mk_var (nm, typ)), 'ASM_OUT'), + (syntax.true_term, 'C_OUT')) for (nm, typ) in bin_globs] + pair = logic.Pairing (['ASM', 'C'], {'ASM': l_fname, 'C': r_fname}, + (inp_eqs, out_eqs)) + assert l_fname not in pairings + assert r_fname not in pairings + functions[l_fname] = l_fun + functions[r_fname] = r_fun + pairings[l_fname] = [pair] + pairings[r_fname] = [pair] + #print 'addpairing %s %s' % (l_fname, r_fname) + +inst_addr_re_armv7 = re.compile('E[0123456789][0123456789]*') +# Note that the decompiler seems ignore the top 4 bytes which +# are 0xfffffff for rv64. So we just use the least significant +# 4 bytes at the moment. We lock the 0x84xxxxxx addresses. +inst_addr_re_rv64 = re.compile('84[0123456789ABCDEF]*') + def split_inst_name_addr (instname): - bits = instname.split('_') - assert bits, instname - addr = bits[-1] - assert inst_addr_re.match (addr), instname - addr = int (addr[1:], 16) - return ('_'.join (bits[:-1]), addr) + bits = instname.split('_') + assert bits, instname + addr = bits[-1] + if syntax.arch.name == 'armv7': + inst_addr_re = inst_addr_re_armv7 + elif syntax.arch.name == 'rv64': + inst_addr_re = inst_addr_re_rv64 + else: + assert False + + assert inst_addr_re.match(addr), instname + + #addr = int (addr[1:], 16) + #rv64_hack + addr = int(addr, 16) + return ('_'.join (bits[:-1]), addr) def mk_bin_inst_spec (fname): - if not fname.startswith ("instruction'"): - return - if functions[fname].entry: - return - (_, ident) = fname.split ("'", 1) - (ident, addr) = split_inst_name_addr (ident) - (regs, ident) = split_inst_name_regs (ident) - ident = instruction_name_aliases.get (ident, ident) - base_ident = ident.split ("_")[0] - if base_ident not in instruction_fun_specs: - return - (impl_fname, regspecs) = instruction_fun_specs[base_ident] - add_impl_fun (impl_fname, regspecs) - assert len (regspecs) == len (regs), (fname, regs, regspecs) - inp_regs = [reg for (reg, d) in zip (regs, regspecs) if d == 'I'] - out_regs = [reg for (reg, d) in zip (regs, regspecs) if d == 'O'] - call = syntax.Node ('Call', 'Ret', ('l_' + impl_fname, - [syntax.mk_var (reg, syntax.word32T) for reg in inp_regs] - + [syntax.mk_token (ident)] - + [syntax.mk_var (nm, typ) for (nm, typ) in bin_globs], - [(reg, syntax.word32T) for reg in out_regs] + bin_globs)) - assert not functions[fname].nodes - functions[fname].nodes[1] = call - functions[fname].entry = 1 + #rv64_hack + if syntax.arch.name == 'armv7': + instruction_fun_specs = instruction_fun_specs_armv7 + wordT = syntax.word32T + elif syntax.arch.name == 'rv64': + instruction_fun_specs = instruction_fun_specs_rv64 + wordT = syntax.word64T + else: + assert False + + if not fname.startswith ("instruction'"): + return + + if functions[fname].entry: + return + (_, ident) = fname.split ("'", 1) + (ident, addr) = split_inst_name_addr (ident) + (regs, ident) = split_inst_name_regs (ident) + ident = instruction_name_aliases.get (ident, ident) + if syntax.arch.name == 'armv7': + base_ident = ident.split ("_")[0] + else: + tmp = ident.split('-') + if len(tmp) > 1: + base_ident = ident.split('-')[0][:-1] + else: + base_ident = tmp[0] + if base_ident not in instruction_fun_specs: + print base_ident + assert False + return + + 
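+    # Worked example (illustrative only; this particular stub name is hypothetical):
+    # for an rv64 stub fname = "instruction'csrr_a0_sscratch_84001234",
+    # split_inst_name_addr returns ident = 'csrr_a0_sscratch' and
+    # addr = 0x84001234, split_inst_name_regs_rv64 returns regs = ['r10'] (a0)
+    # with ident = 'csrr_sscratch_-argv1', and the split('-') above gives
+    # base_ident = 'csrr_sscratch'. That key maps to
+    # ("impl'csrr_sscratch", ["O"]) in instruction_fun_specs_rv64, so the
+    # Call node built below treats r10 as the single output register.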
(impl_fname, regspecs) = instruction_fun_specs[base_ident] + #impl_fname = impl_fname + '@' + str(hex(addr)) + add_impl_fun (impl_fname, regspecs) + assert len (regspecs) == len (regs), (fname, regs, regspecs) + inp_regs = [reg for (reg, d) in zip (regs, regspecs) if d == 'I'] + out_regs = [reg for (reg, d) in zip (regs, regspecs) if d == 'O'] + call = syntax.Node ('Call', 'Ret', ('l_' + impl_fname, + [syntax.mk_var (reg, wordT) for reg in inp_regs] + + [syntax.mk_token (ident)] + + [syntax.mk_var (nm, typ) for (nm, typ) in bin_globs], + [(reg, wordT) for reg in out_regs] + bin_globs)) + assert not functions[fname].nodes + functions[fname].nodes[1] = call + functions[fname].entry = 1 + +# inline assembly from C-refine def mk_asm_inst_spec (fname): - if not fname.startswith ("asm_instruction'"): - return - if functions[fname].entry: - return - (_, ident) = fname.split ("'", 1) - (args, ident) = split_inst_name_regs (ident) - if not all ([arg.startswith ('%') for arg in args]): - printout ('Warning: asm instruction name: formatting: %r' - % fname) - return - base_ident = ident.split ("_")[0] - if base_ident not in instruction_fun_specs: - return - (impl_fname, regspecs) = instruction_fun_specs[base_ident] - add_impl_fun (impl_fname, regspecs) - (iscs, imems, _) = logic.split_scalar_pairs (functions[fname].inputs) - (oscs, omems, _) = logic.split_scalar_pairs (functions[fname].outputs) - call = syntax.Node ('Call', 'Ret', ('r_' + impl_fname, - iscs + [syntax.mk_token (ident)] + imems, - [(v.name, v.typ) for v in oscs + omems])) - assert not functions[fname].nodes - functions[fname].nodes[1] = call - functions[fname].entry = 1 + + if not fname.startswith ("asm_instruction'"): + return + + if syntax.arch.name == 'armv7': + instruction_fun_specs = instruction_fun_specs_armv7 + elif syntax.arch.name == 'rv64': + instruction_fun_specs = instruction_fun_specs_rv64 + else: + assert False + + if functions[fname].entry: + # print 'already %s %s' % (fname, functions[fname].entry) + return + + (_, ident) = fname.split ("'", 1) + (args, ident) = split_inst_name_regs (ident) + + if syntax.arch.name == 'armv7': + if not all ([arg.startswith ('%') for arg in args]): + printout ('Warning: asm instruction name: formatting: %r' + % fname) + return + + if syntax.arch.name == 'armv7': + base_ident = ident.split ("_")[0] + else: + tmp = ident.split('-') + if len(tmp) > 1: + base_ident = tmp[0][:-1] + else: + base_ident = tmp[0] + + if base_ident not in instruction_fun_specs: + print base_ident + assert False + return + + (impl_fname, regspecs) = instruction_fun_specs[base_ident] + +# impl_fname = 'asm_' + impl_fname + add_impl_fun (impl_fname, regspecs) + + (iscs, imems, _) = logic.split_scalar_pairs (functions[fname].inputs) + (oscs, omems, _) = logic.split_scalar_pairs (functions[fname].outputs) + + call = syntax.Node ('Call', 'Ret', ('r_' + impl_fname, + iscs + [syntax.mk_token (ident)] + imems, + [(v.name, v.typ) for v in oscs + omems])) + assert not functions[fname].nodes + functions[fname].nodes[1] = call + functions[fname].entry = 1 def add_inst_specs (report_problematic = True): - for f in functions.keys (): - mk_asm_inst_spec (f) - mk_bin_inst_spec (f) - if report_problematic: - problematic_instructions () + for f in functions.keys (): + mk_asm_inst_spec (f) + mk_bin_inst_spec (f) + if report_problematic: + problematic_instructions () def problematic_instructions (): - add_inst_specs (report_problematic = False) - unhandled = {} - for f in functions: - for f2 in functions[f].function_calls (): - if 
"instruction'" not in f2: - continue - if functions[f2].entry: - continue - unhandled.setdefault (f, []) - unhandled[f].append (f2) - for f in unhandled: - printout ('Function %r contains unhandled instructions:' % f) - printout (' %s' % unhandled[f]) - return unhandled + add_inst_specs (report_problematic = False) + unhandled = {} + for f in functions: + for f2 in functions[f].function_calls (): + if "instruction'" not in f2: + continue + if functions[f2].entry: + continue + unhandled.setdefault (f, []) + unhandled[f].append (f2) + for f in unhandled: + printout ('Function %r contains unhandled instructions:' % f) + printout (' %s' % unhandled[f]) + return unhandled diff --git a/logic.py b/logic.py index 164dcb9a..3d39f88d 100644 --- a/logic.py +++ b/logic.py @@ -5,902 +5,881 @@ # import syntax -from syntax import word32T, word8T, boolT, builtinTs, Expr, Node +from syntax import word64T, word32T, word8T, boolT, builtinTs, Expr, Node from syntax import true_term, false_term, mk_num from syntax import foldr1 -(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq, -mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8, -mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs, -mk_if, mk_meta_typ, mk_pvalid) = syntax.mks +from syntax import (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq, + mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word64, mk_word32, mk_word8, + mk_word32_maybe, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs, + mk_if, mk_meta_typ, mk_pvalid) from syntax import structs from target_objects import trace, printout def is_int (n): - return hasattr (n, '__int__') + return hasattr (n, '__int__') def mk_eq_with_cast (a, c): - return mk_eq (a, mk_cast (c, a.typ)) + return mk_eq (a, syntax.arch.mk_cast (c, a.typ)) def mk_rodata (m): - assert m.typ == builtinTs['Mem'] - return Expr ('Op', boolT, name = 'ROData', vals = [m]) - -def cast_pair (((a, a_addr), (c, c_addr))): - if a.typ != c.typ and c.typ == boolT: - c = mk_if (c, mk_word32 (1), mk_word32 (0)) - return ((a, a_addr), (mk_cast (c, a.typ), c_addr)) - -ghost_assertion_type = syntax.Type ('WordArray', 50, 32) + assert m.typ == builtinTs['Mem'] + return Expr ('Op', boolT, name = 'ROData', vals = [m]) def split_scalar_globals (vs): - for i in range (len (vs)): - if vs[i].typ.kind != 'Word' and vs[i].typ != boolT: - break - else: - i = len (vs) - scalars = vs[:i] - global_vars = vs[i:] - for v in global_vars: - if v.typ not in [builtinTs['Mem'], builtinTs['Dom'], - builtinTs['HTD'], builtinTs['PMS'], - ghost_assertion_type]: - assert not "scalar_global split expected", vs - memT = builtinTs['Mem'] - mems = [v for v in global_vars if v.typ == memT] - others = [v for v in global_vars if v.typ != memT] - return (scalars, mems, others) + for i in range (len (vs)): + if vs[i].typ.kind != 'Word' and vs[i].typ != boolT: + break + else: + i = len (vs) + scalars = vs[:i] + global_vars = vs[i:] + + for v in global_vars: + if v.typ not in [builtinTs['Mem'], builtinTs['Dom'], + builtinTs['HTD'], builtinTs['PMS'], + syntax.arch.ghost_assertion_type]: + assert not "scalar_global split expected", vs + + memT = builtinTs['Mem'] + mems = [v for v in global_vars if v.typ == memT] + others = [v for v in global_vars if v.typ != memT] + return (scalars, mems, others) def mk_vars (tups): - return [mk_var (nm, typ) for (nm, typ) in tups] + return [mk_var (nm, typ) for (nm, typ) in tups] def split_scalar_pairs (var_pairs): - return 
split_scalar_globals (mk_vars (var_pairs)) + return split_scalar_globals (mk_vars (var_pairs)) def azip (xs, ys): - assert len (xs) == len (ys) - return zip (xs, ys) + assert len (xs) == len (ys), (xs, ys) + return zip (xs, ys) def mk_mem_eqs (a_imem, c_imem, a_omem, c_omem, tags): - [a_imem] = a_imem - a_tag, c_tag = tags - (c_in, c_out) = (c_tag + '_IN', c_tag + '_OUT') - (a_in, a_out) = (a_tag + '_IN', a_tag + '_OUT') - if c_imem: - [c_imem] = c_imem - ieqs = [((a_imem, a_in), (c_imem, c_in)), - ((mk_rodata (c_imem), c_in), (true_term, c_in))] - else: - ieqs = [((mk_rodata (a_imem), a_in), (true_term, c_in))] - if c_omem: - [a_m] = a_omem - [c_omem] = c_omem - oeqs = [((a_m, a_out), (c_omem, c_out)), - ((mk_rodata (c_omem), c_out), (true_term, c_out))] - else: - oeqs = [((a_m, a_out), (a_imem, a_in)) for a_m in a_omem] - - return (ieqs, oeqs) - -def mk_fun_eqs (as_f, c_f, prunes = None): - (var_a_args, a_imem, glob_a_args) = split_scalar_pairs (as_f.inputs) - (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_f.inputs) - (var_a_rets, a_omem, glob_a_rets) = split_scalar_pairs (as_f.outputs) - (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_f.outputs) - - (mem_ieqs, mem_oeqs) = mk_mem_eqs (a_imem, c_imem, a_omem, c_omem, - ['ASM', 'C']) - - if not prunes: - prunes = (var_a_args, var_a_args) - assert len (prunes[0]) == len (var_c_args), (params, var_a_args, - var_c_args, prunes) - a_map = dict (azip (prunes[1], var_a_args)) - ivar_pairs = [((a_map[p], 'ASM_IN'), (c, 'C_IN')) for (p, c) - in azip (prunes[0], var_c_args) if p in a_map] - - ovar_pairs = [((a_ret, 'ASM_OUT'), (c_ret, 'C_OUT')) for (a_ret, c_ret) - in azip (var_a_rets, var_c_rets)] - return (map (cast_pair, mem_ieqs + ivar_pairs), - map (cast_pair, mem_oeqs + ovar_pairs)) + [a_imem] = a_imem + a_tag, c_tag = tags + (c_in, c_out) = (c_tag + '_IN', c_tag + '_OUT') + (a_in, a_out) = (a_tag + '_IN', a_tag + '_OUT') + if c_imem: + [c_imem] = c_imem + ieqs = [((a_imem, a_in), (c_imem, c_in)), + ((mk_rodata (c_imem), c_in), (true_term, c_in))] + else: + ieqs = [((mk_rodata (a_imem), a_in), (true_term, c_in))] + if c_omem: + [a_m] = a_omem + [c_omem] = c_omem + oeqs = [((a_m, a_out), (c_omem, c_out)), + ((mk_rodata (c_omem), c_out), (true_term, c_out))] + else: + oeqs = [((a_m, a_out), (a_imem, a_in)) for a_m in a_omem] + + return (ieqs, oeqs) def mk_var_list (vs, typ): - return [syntax.mk_var (v, typ) for v in vs] + return [syntax.mk_var (v, typ) for v in vs] def mk_offs_sequence (init, offs, n, do_reverse = False): - r = range (n) - if do_reverse: - r.reverse () - def mk_offs (n): - return Expr ('Num', init.typ, val = offs * n) - return [mk_plus (init, mk_offs (m)) for m in r] + r = range (n) + if do_reverse: + r.reverse () + def mk_offs (n): + return Expr ('Num', init.typ, val = offs * n) + return [mk_plus (init, mk_offs (m)) for m in r] def mk_stack_sequence (sp, offs, stack, typ, n, do_reverse = False): - return [(mk_memacc (stack, addr, typ), addr) - for addr in mk_offs_sequence (sp, offs, n, do_reverse)] + return [(mk_memacc (stack, addr, typ), addr) + for addr in mk_offs_sequence (sp, offs, n, do_reverse)] def mk_aligned (w, n): - assert w.typ.kind == 'Word' - mask = Expr ('Num', w.typ, val = ((1 << n) - 1)) - return mk_eq (mk_bwand (w, mask), mk_num (0, w.typ)) - -def mk_eqs_arm_none_eabi_gnu (var_c_args, var_c_rets, c_imem, c_omem, - min_stack_size): - arg_regs = mk_var_list (['r0', 'r1', 'r2', 'r3'], word32T) - r0 = arg_regs[0] - sp = mk_var ('r13', word32T) - st = mk_var ('stack', builtinTs['Mem']) - 
r0_input = mk_var ('ret_addr_input', word32T) - sregs = mk_stack_sequence (sp, 4, st, word32T, len (var_c_args) + 1) - - ret = mk_var ('ret', word32T) - preconds = [mk_aligned (sp, 2), mk_eq (ret, mk_var ('r14', word32T)), - mk_aligned (ret, 2), mk_eq (r0_input, r0), - mk_less_eq (min_stack_size, sp)] - post_eqs = [(x, x) for x in mk_var_list (['r4', 'r5', 'r6', 'r7', 'r8', - 'r9', 'r10', 'r11', 'r13'], word32T)] - - arg_seq = [(r, None) for r in arg_regs] + sregs - if len (var_c_rets) > 1: - # the 'return-too-much' issue. - # instead r0 is a save-returns-here pointer - arg_seq.pop (0) - preconds += [mk_aligned (r0, 2), mk_less_eq (sp, r0)] - save_seq = mk_stack_sequence (r0_input, 4, st, word32T, - len (var_c_rets)) - save_addrs = [addr for (_, addr) in save_seq] - post_eqs += [(r0_input, r0_input)] - out_eqs = zip (var_c_rets, [x for (x, _) in save_seq]) - out_eqs = [(c, mk_cast (a, c.typ)) for (c, a) in out_eqs] - init_save_seq = mk_stack_sequence (r0, 4, st, word32T, - len (var_c_rets)) - (_, last_arg_addr) = arg_seq[len (var_c_args) - 1] - preconds += [mk_less_eq (sp, addr) - for (_, addr) in init_save_seq[-1:]] - if last_arg_addr: - preconds += [mk_less (last_arg_addr, addr) - for (_, addr) in init_save_seq[:1]] - else: - out_eqs = zip (var_c_rets, [r0]) - save_addrs = [] - arg_seq_addrs = [addr for ((_, addr), _) in zip (arg_seq, var_c_args) - if addr != None] - swrap = mk_stack_wrapper (sp, st, arg_seq_addrs) - swrap2 = mk_stack_wrapper (sp, st, save_addrs) - post_eqs += [(swrap, swrap2)] - - mem = mk_var ('mem', builtinTs['Mem']) - (mem_ieqs, mem_oeqs) = mk_mem_eqs ([mem], c_imem, [mem], c_omem, - ['ASM', 'C']) - - addr = None - arg_eqs = [cast_pair (((a_x, 'ASM_IN'), (c_x, 'C_IN'))) - for (c_x, (a_x, addr)) in zip (var_c_args, arg_seq)] - if addr: - preconds += [mk_less_eq (sp, addr)] - ret_eqs = [cast_pair (((a_x, 'ASM_OUT'), (c_x, 'C_OUT'))) - for (c_x, a_x) in out_eqs] - preconds = [((a_x, 'ASM_IN'), (true_term, 'ASM_IN')) for a_x in preconds] - asm_invs = [((vin, 'ASM_IN'), (vout, 'ASM_OUT')) for (vin, vout) in post_eqs] - - return (arg_eqs + mem_ieqs + preconds, - ret_eqs + mem_oeqs + asm_invs) - -known_CPUs = { - 'arm-none-eabi-gnu': mk_eqs_arm_none_eabi_gnu -} - -def mk_fun_eqs_CPU (cpu_f, c_f, cpu_name, funcall_depth = 1): - cpu = known_CPUs[cpu_name] - (var_c_args, c_imem, glob_c_args) = split_scalar_pairs (c_f.inputs) - (var_c_rets, c_omem, glob_c_rets) = split_scalar_pairs (c_f.outputs) - - return cpu (var_c_args, var_c_rets, c_imem, c_omem, - (funcall_depth * 256) + 256) + assert w.typ.kind == 'Word' + mask = Expr ('Num', w.typ, val = ((1 << n) - 1)) + return mk_eq (mk_bwand (w, mask), mk_num (0, w.typ)) + +def mk_eqs(var_c_args, var_c_rets, c_imem, c_omem, min_stack_size): + arg_regs = mk_var_list(syntax.arch.argument_registers, syntax.arch.word_type) + ar0 = arg_regs[0] + + sp = mk_var(syntax.arch.sp_register, syntax.arch.word_type) + st = mk_var('stack', builtinTs['Mem']) + ra_input = mk_var('ret_addr_input', syntax.arch.word_type) + sregs = mk_stack_sequence(sp, syntax.arch.ptr_size, st, syntax.arch.word_type, len(var_c_args) + 1) + ret = mk_var('ret', syntax.arch.word_type) + + preconds = [ + mk_aligned(sp, syntax.arch.stack_alignment_bits), + mk_eq(ret, mk_var(syntax.arch.ra_register, syntax.arch.word_type)), + mk_aligned(ret, syntax.arch.ret_addr_alignment_bits), + mk_eq(ra_input, ar0), + mk_less_eq(min_stack_size, sp) + ] + + post_eqs = [ + (x, x) for x in mk_var_list( + syntax.arch.callee_saved_registers, + syntax.arch.word_type) + ] + + arg_seq = [(r, None) 
for r in arg_regs] + sregs + + if len(var_c_rets) > len(syntax.arch.return_registers): + # Our return value exceeds the capacity of the return + # registers. In this case, the caller will pass us a + # memory address where we can save our return value + # as an implicit first argument (passed in the first + # argument register). + arg_seq.pop(0) + preconds += [ + mk_aligned(ar0, syntax.arch.ptr_alignment_bits), + mk_less_eq(sp, ar0) + ] + save_seq = mk_stack_sequence( + ra_input, syntax.arch.ptr_size, st, + syntax.arch.word_type, len(var_c_rets) + ) + save_addrs = [addr for (_, addr) in save_seq] + post_eqs += [(ra_input, ra_input)] + out_eqs = zip(var_c_rets, [x for (x, _) in save_seq]) + out_eqs = [(c, syntax.arch.mk_cast(a, c.typ)) for (c, a) in out_eqs] + init_save_seq = mk_stack_sequence( + ar0, syntax.arch.ptr_size, st, + syntax.arch.word_type, len(var_c_rets) + ) + (_, last_arg_addr) = arg_seq[len(var_c_args) - 1] + preconds += [mk_less_eq (sp, addr) + for (_, addr) in init_save_seq[-1:]] + if last_arg_addr: + preconds += [mk_less (last_arg_addr, addr) + for (_, addr) in init_save_seq[:1]] + else: + out_eqs = zip(var_c_rets, arg_regs[:len(syntax.arch.argument_registers)]) + save_addrs = [] + + arg_seq_addrs = [addr for ((_, addr), _) in zip (arg_seq, var_c_args) + if addr != None] + swrap = mk_stack_wrapper (sp, st, arg_seq_addrs) + swrap2 = mk_stack_wrapper (sp, st, save_addrs) + + post_eqs += [(swrap, swrap2)] + + mem = mk_var ('mem', builtinTs['Mem']) + (mem_ieqs, mem_oeqs) = mk_mem_eqs ([mem], c_imem, [mem], c_omem, + ['ASM', 'C']) + + addr = None + arg_eqs = [syntax.arch.cast_pair(((a_x, 'ASM_IN'), (c_x, 'C_IN'))) + for (c_x, (a_x, addr)) in zip (var_c_args, arg_seq)] + + if addr: + preconds += [mk_less_eq (sp, addr)] + + ret_eqs = [syntax.arch.cast_pair(((a_x, 'ASM_OUT'), (c_x, 'C_OUT'))) + for (c_x, a_x) in out_eqs] + + preconds = [((a_x, 'ASM_IN'), (true_term, 'ASM_IN')) for a_x in preconds] + asm_invs = [((vin, 'ASM_IN'), (vout, 'ASM_OUT')) for (vin, vout) in post_eqs] + + return (arg_eqs + mem_ieqs + preconds, ret_eqs + mem_oeqs + asm_invs) + class Pairing: - def __init__ (self, tags, funs, eqs, notes = None): - [l_tag, r_tag] = tags - self.tags = tags - assert set (funs) == set (tags) - self.funs = funs - self.eqs = eqs - - self.l_f = funs[l_tag] - self.r_f = funs[r_tag] - self.name = 'Pairing (%s (%s) <= %s (%s))' % (self.l_f, - l_tag, self.r_f, r_tag) - - self.notes = {} - if notes != None: - self.notes.update (notes) - - def __str__ (self): - return self.name - - def __hash__ (self): - return hash (self.name) - - def __eq__ (self, other): - return self.name == other.name and self.eqs == other.eqs - - def __ne__ (self, other): - return not other or not self == other - -def mk_pairing (functions, c_f, as_f, prunes = None, cpu = None): - fs = (functions[as_f], functions[c_f]) - if cpu: - eqs = mk_fun_eqs_CPU (fs[0], fs[1], cpu, - funcall_depth = funcall_depth (functions, c_f)) - else: - eqs = mk_fun_eqs (fs[0], fs[1], prunes = prunes) - return Pairing (['ASM', 'C'], {'C': c_f, 'ASM': as_f}, eqs) + def __init__ (self, tags, funs, eqs, notes = None): + [l_tag, r_tag] = tags + self.tags = tags + assert set (funs) == set (tags) + self.funs = funs + self.eqs = eqs + + self.l_f = funs[l_tag] + self.r_f = funs[r_tag] + self.name = 'Pairing (%s (%s) <= %s (%s))' % (self.l_f, + l_tag, self.r_f, r_tag) + + self.notes = {} + if notes != None: + self.notes.update (notes) + + def __str__ (self): + return self.name + + def __hash__ (self): + return hash (self.name) + + def __eq__ 
(self, other): + return self.name == other.name and self.eqs == other.eqs + + def __ne__ (self, other): + return not other or not self == other def inst_eqs_pattern (pattern, params): - (pat_params, inp_eqs, out_eqs) = pattern - substs = [((x.name, x.typ), y) - for (pat_vs, vs) in azip (pat_params, params) - for (x, y) in azip (pat_vs, vs)] - substs = dict (substs) - subst = lambda x: var_subst (x, substs) - return ([(subst (x), subst (y)) for (x, y) in inp_eqs], - [(subst (x), subst (y)) for (x, y) in out_eqs]) + (pat_params, inp_eqs, out_eqs) = pattern + substs = [((x.name, x.typ), y) + for (pat_vs, vs) in azip (pat_params, params) + for (x, y) in azip (pat_vs, vs)] + substs = dict (substs) + subst = lambda x: var_subst (x, substs) + return ([(subst (x), subst (y)) for (x, y) in inp_eqs], + [(subst (x), subst (y)) for (x, y) in out_eqs]) def inst_eqs_pattern_tuples (pattern, params): - return inst_eqs_pattern (pattern, map (mk_vars, params)) + return inst_eqs_pattern (pattern, map (mk_vars, params)) def inst_eqs_pattern_exprs (pattern, params): - (inp_eqs, out_eqs) = inst_eqs_pattern (pattern, params) - return (foldr1 (mk_and, [mk_eq (a, c) for (a, c) in inp_eqs]), - foldr1 (mk_and, [mk_eq (a, c) for (a, c) in out_eqs])) + (inp_eqs, out_eqs) = inst_eqs_pattern (pattern, params) + return (foldr1 (mk_and, [mk_eq (a, c) for (a, c) in inp_eqs]), + foldr1 (mk_and, [mk_eq (a, c) for (a, c) in out_eqs])) def var_match (var_exp, conc_exp, assigns): - if var_exp.typ != conc_exp.typ: - return False - if var_exp.kind == 'Var': - key = (var_exp.name, var_exp.typ) - if key in assigns: - return conc_exp == assigns[key] - else: - assigns[key] = conc_exp - return True - elif var_exp.kind == 'Op': - if conc_exp.kind != 'Op' or var_exp.name != conc_exp.name: - return False - return all ([var_match (a, b, assigns) - for (a, b) in azip (var_exp.vals, conc_exp.vals)]) - else: - return False + if var_exp.typ != conc_exp.typ: + return False + if var_exp.kind == 'Var': + key = (var_exp.name, var_exp.typ) + if key in assigns: + return conc_exp == assigns[key] + else: + assigns[key] = conc_exp + return True + elif var_exp.kind == 'Op': + if conc_exp.kind != 'Op' or var_exp.name != conc_exp.name: + return False + return all ([var_match (a, b, assigns) + for (a, b) in azip (var_exp.vals, conc_exp.vals)]) + else: + return False def var_subst (var_exp, assigns, must_subst = True): - def substor (var_exp): - if var_exp.kind == 'Var': - k = (var_exp.name, var_exp.typ) - if must_subst or k in assigns: - return assigns[k] - else: - return None - else: - return None - return syntax.do_subst (var_exp, substor) + + # hack for void func(void) + must_subst = False + def substor (var_exp): + if var_exp.kind == 'Var': + k = (var_exp.name, var_exp.typ) + if must_subst or k in assigns: + return assigns[k] + else: + return None + else: + return None + return syntax.do_subst (var_exp, substor) def recursive_term_subst (eqs, expr): - if expr in eqs: - return eqs[expr] - if expr.kind == 'Op': - vals = [recursive_term_subst (eqs, x) for x in expr.vals] - return syntax.adjust_op_vals (expr, vals) - return expr + if expr in eqs: + return eqs[expr] + if expr.kind == 'Op': + vals = [recursive_term_subst (eqs, x) for x in expr.vals] + return syntax.adjust_op_vals (expr, vals) + return expr def mk_accum_rewrites (typ): - x = mk_var ('x', typ) - y = mk_var ('y', typ) - z = mk_var ('z', typ) - i = mk_var ('i', typ) - return [(x, i, mk_plus (x, y), mk_plus (x, mk_times (i, y)), - y), - (x, i, mk_plus (y, x), mk_plus (x, mk_times (i, y)), - y), 
- (x, i, mk_minus (x, y), mk_minus (x, mk_times (i, y)), - mk_uminus (y)), - (x, i, mk_plus (mk_plus (x, y), z), - mk_plus (x, mk_times (i, mk_plus (y, z))), - mk_plus (y, z)), - (x, i, mk_plus (mk_plus (y, x), z), - mk_plus (x, mk_times (i, mk_plus (y, z))), - mk_plus (y, z)), - (x, i, mk_plus (y, mk_plus (x, z)), - mk_plus (x, mk_times (i, mk_plus (y, z))), - mk_plus (y, z)), - (x, i, mk_plus (y, mk_plus (z, x)), - mk_plus (x, mk_times (i, mk_plus (y, z))), - mk_plus (y, z)), - (x, i, mk_minus (mk_minus (x, y), z), - mk_minus (x, mk_times (i, mk_plus (y, z))), - mk_plus (y, z)), - ] + x = mk_var ('x', typ) + y = mk_var ('y', typ) + z = mk_var ('z', typ) + i = mk_var ('i', typ) + return [(x, i, mk_plus (x, y), mk_plus (x, mk_times (i, y)), + y), + (x, i, mk_plus (y, x), mk_plus (x, mk_times (i, y)), + y), + (x, i, mk_minus (x, y), mk_minus (x, mk_times (i, y)), + mk_uminus (y)), + (x, i, mk_plus (mk_plus (x, y), z), + mk_plus (x, mk_times (i, mk_plus (y, z))), + mk_plus (y, z)), + (x, i, mk_plus (mk_plus (y, x), z), + mk_plus (x, mk_times (i, mk_plus (y, z))), + mk_plus (y, z)), + (x, i, mk_plus (y, mk_plus (x, z)), + mk_plus (x, mk_times (i, mk_plus (y, z))), + mk_plus (y, z)), + (x, i, mk_plus (y, mk_plus (z, x)), + mk_plus (x, mk_times (i, mk_plus (y, z))), + mk_plus (y, z)), + (x, i, mk_minus (mk_minus (x, y), z), + mk_minus (x, mk_times (i, mk_plus (y, z))), + mk_plus (y, z)), + ] def mk_all_accum_rewrites (): - return [rew for typ in [word8T, word32T, syntax.word16T, - syntax.word64T] - for rew in mk_accum_rewrites (typ)] + return [rew for typ in [word8T, word32T, syntax.word16T, + syntax.word64T] + for rew in mk_accum_rewrites (typ)] accum_rewrites = mk_all_accum_rewrites () def default_val (typ): - if typ.kind == 'Word': - return Expr ('Num', typ, val = 0) - elif typ == boolT: - return false_term - else: - assert not 'default value for type %s created', typ + if typ.kind == 'Word': + return Expr ('Num', typ, val = 0) + elif typ == boolT: + return false_term + else: + assert not 'default value for type %s created', typ trace_accumulators = [] def accumulator_closed_form (expr, (nm, typ), add_mask = None): - expr = toplevel_split_out_cast (expr) - n = get_bwand_mask (expr) - if n and not add_mask: - return accumulator_closed_form (expr.vals[0], (nm, typ), - add_mask = n) - - for (x, i, pattern, rewrite, offset) in accum_rewrites: - var = mk_var (nm, typ) - ass = {(x.name, x.typ): var} - m = var_match (pattern, expr, ass) - if m: - x2_def = default_val (typ) - i2_def = default_val (word32T) - def do_rewrite (x2 = x2_def, i2 = i2_def): - ass[(x.name, x.typ)] = x2 - ass[(i.name, i.typ)] = i2 - vs = var_subst (rewrite, ass) - if add_mask: - vs = mk_bwand_mask (vs, add_mask) - return vs - offs = var_subst (offset, ass) - return (do_rewrite, offs) - if trace_accumulators: - trace ('no accumulator %s' % ((expr, nm, typ), )) - return (None, None) + expr = toplevel_split_out_cast (expr) + n = get_bwand_mask (expr) + if n and not add_mask: + return accumulator_closed_form (expr.vals[0], (nm, typ), + add_mask = n) + + for (x, i, pattern, rewrite, offset) in accum_rewrites: + var = mk_var (nm, typ) + ass = {(x.name, x.typ): var} + m = var_match (pattern, expr, ass) + if m: + x2_def = default_val (typ) + i2_def = default_val (word32T) + def do_rewrite (x2 = x2_def, i2 = i2_def): + ass[(x.name, x.typ)] = x2 + ass[(i.name, i.typ)] = i2 + vs = var_subst (rewrite, ass) + if add_mask: + vs = mk_bwand_mask (vs, add_mask) + return vs + offs = var_subst (offset, ass) + return (do_rewrite, offs) + if 
trace_accumulators: + trace ('no accumulator %s' % ((expr, nm, typ), )) + return (None, None) def split_out_cast (expr, target_typ, bits): - """given a word-type expression expr (of any word length), - compute a simplified expression expr' of the target type, which will - have the property that expr' && mask bits = cast expr, - where && is bitwise-and (BWAnd), mask n is the bitpattern set at the - bottom n bits, e.g. (1 << n) - 1, and cast is WordCast.""" - if expr.is_op (['WordCast', 'WordCastSigned']): - [x] = expr.vals - if x.typ.num >= bits and expr.typ.num >= bits: - return split_out_cast (x, target_typ, bits) - else: - return mk_cast (expr, target_typ) - elif expr.is_op ('BWAnd'): - [x, y] = expr.vals - if y.kind == 'Num': - val = y.val - else: - val = 0 - full_mask = (1 << bits) - 1 - if val & full_mask == full_mask: - return split_out_cast (x, target_typ, bits) - else: - return mk_cast (expr, target_typ) - elif expr.is_op (['Plus', 'Minus']): - # rounding issues will appear if this arithmetic is done - # at a smaller number of bits than we'll eventually report - if expr.typ.num >= bits: - vals = [split_out_cast (x, target_typ, bits) - for x in expr.vals] - if expr.is_op ('Plus'): - return mk_plus (vals[0], vals[1]) - else: - return mk_minus (vals[0], vals[1]) - else: - return mk_cast (expr, target_typ) - else: - return mk_cast (expr, target_typ) + """given a word-type expression expr (of any word length), + compute a simplified expression expr' of the target type, which will + have the property that expr' && mask bits = cast expr, + where && is bitwise-and (BWAnd), mask n is the bitpattern set at the + bottom n bits, e.g. (1 << n) - 1, and cast is WordCast.""" + if expr.is_op (['WordCast', 'WordCastSigned']): + [x] = expr.vals + if x.typ.num >= bits and expr.typ.num >= bits: + return split_out_cast (x, target_typ, bits) + else: + return syntax.arch.mk_cast (expr, target_typ) + elif expr.is_op ('BWAnd'): + [x, y] = expr.vals + if y.kind == 'Num': + val = y.val + else: + val = 0 + full_mask = (1 << bits) - 1 + if val & full_mask == full_mask: + return split_out_cast (x, target_typ, bits) + else: + return syntax.arch.mk_cast(expr, target_typ) + elif expr.is_op (['Plus', 'Minus']): + # rounding issues will appear if this arithmetic is done + # at a smaller number of bits than we'll eventually report + if expr.typ.num >= bits: + vals = [split_out_cast (x, target_typ, bits) + for x in expr.vals] + if expr.is_op ('Plus'): + return mk_plus (vals[0], vals[1]) + else: + return mk_minus (vals[0], vals[1]) + else: + return syntax.arch.mk_cast(expr, target_typ) + else: + return syntax.arch.mk_cast(expr, target_typ) def toplevel_split_out_cast (expr): - bits = None - if expr.is_op (['WordCast', 'WordCastSigned']): - bits = min ([expr.typ.num, expr.vals[0].typ.num]) - elif expr.is_op ('BWAnd'): - bits = get_bwand_mask (expr) - - if bits: - expr = split_out_cast (expr, expr.typ, bits) - return mk_bwand_mask (expr, bits) - else: - return expr + bits = None + if expr.is_op (['WordCast', 'WordCastSigned']): + bits = min ([expr.typ.num, expr.vals[0].typ.num]) + elif expr.is_op ('BWAnd'): + bits = get_bwand_mask (expr) + + if bits: + expr = split_out_cast (expr, expr.typ, bits) + return mk_bwand_mask (expr, bits) + else: + return expr two_powers = {} def get_bwand_mask (expr): - """recognise (x && mask) opers, where mask = ((1 << n) - 1) - for some n""" - if not expr.is_op ('BWAnd'): - return - [x, y] = expr.vals - if not y.kind == 'Num': - return - val = y.val & ((1 << (y.typ.num)) - 1) - if not 
two_powers: - for i in range (129): - two_powers[1 << i] = i - return two_powers.get (val + 1) + """recognise (x && mask) opers, where mask = ((1 << n) - 1) + for some n""" + if not expr.is_op ('BWAnd'): + return + [x, y] = expr.vals + if not y.kind == 'Num': + return + val = y.val & ((1 << (y.typ.num)) - 1) + if not two_powers: + for i in range (129): + two_powers[1 << i] = i + return two_powers.get (val + 1) def mk_bwand_mask (expr, n): - return mk_bwand (expr, mk_num (((1 << n) - 1), expr.typ)) + return mk_bwand (expr, mk_num (((1 << n) - 1), expr.typ)) def end_addr (p, typ): - if typ[0] == 'Array': - (_, typ, n) = typ - sz = mk_times (mk_word32 (typ.size ()), n) - else: - assert typ[0] == 'Type', typ - (_, typ) = typ - sz = mk_word32 (typ.size ()) - return mk_plus (p, mk_minus (sz, mk_word32 (1))) + if typ[0] == 'Array': + (_, typ, n) = typ + sz = mk_times (syntax.arch.mk_word(typ.size ()), n) + else: + assert typ[0] == 'Type', typ + (_, typ) = typ + sz = syntax.arch.mk_word(typ.size ()) + return mk_plus (p, mk_minus (sz, syntax.arch.mk_word(1))) def pvalid_assertion1 ((typ, k, p, pv), (typ2, k2, p2, pv2)): - """first pointer validity assertion: incompatibility. - pvalid1 & pvalid2 --> non-overlapping OR somehow-contained. - typ/typ2 is ('Type', syntax.Type) or ('Array', Type, Expr) for - dynamically sized arrays. - """ - offs1 = mk_minus (p, p2) - cond1 = get_styp_condition (offs1, typ, typ2) - offs2 = mk_minus (p2, p) - cond2 = get_styp_condition (offs2, typ2, typ) - - out1 = mk_less (end_addr (p, typ), p2) - out2 = mk_less (end_addr (p2, typ2), p) - return mk_implies (mk_and (pv, pv2), foldr1 (mk_or, - [cond1, cond2, out1, out2])) + """first pointer validity assertion: incompatibility. + pvalid1 & pvalid2 --> non-overlapping OR somehow-contained. + typ/typ2 is ('Type', syntax.Type) or ('Array', Type, Expr) for + dynamically sized arrays. + """ + offs1 = mk_minus (p, p2) + cond1 = get_styp_condition (offs1, typ, typ2) + offs2 = mk_minus (p2, p) + cond2 = get_styp_condition (offs2, typ2, typ) + + out1 = mk_less (end_addr (p, typ), p2) + out2 = mk_less (end_addr (p2, typ2), p) + return mk_implies (mk_and (pv, pv2), foldr1 (mk_or, + [cond1, cond2, out1, out2])) def pvalid_assertion2 ((typ, k, p, pv), (typ2, k2, p2, pv2)): - """second pointer validity assertion: implication. - pvalid1 & strictly-contained --> pvalid2 - """ - if typ[0] == 'Array' and typ2[0] == 'Array': - # this is such a vague notion it's not worth it - return true_term - offs1 = mk_minus (p, p2) - cond1 = get_styp_condition (offs1, typ, typ2) - imp1 = mk_implies (mk_and (cond1, pv2), pv) - offs2 = mk_minus (p2, p) - cond2 = get_styp_condition (offs2, typ2, typ) - imp2 = mk_implies (mk_and (cond2, pv), pv2) - return mk_and (imp1, imp2) + """second pointer validity assertion: implication. 
+ pvalid1 & strictly-contained --> pvalid2 + """ + if typ[0] == 'Array' and typ2[0] == 'Array': + # this is such a vague notion it's not worth it + return true_term + offs1 = mk_minus (p, p2) + cond1 = get_styp_condition (offs1, typ, typ2) + imp1 = mk_implies (mk_and (cond1, pv2), pv) + offs2 = mk_minus (p2, p) + cond2 = get_styp_condition (offs2, typ2, typ) + imp2 = mk_implies (mk_and (cond2, pv), pv2) + return mk_and (imp1, imp2) def sym_distinct_assertion ((typ, p, pv), (start, end)): - out1 = mk_less (mk_plus (p, mk_word32 (typ.size () - 1)), mk_word32 (start)) - out2 = mk_less (mk_word32 (end), p) - return mk_implies (pv, mk_or (out1, out2)) + out1 = mk_less (mk_plus (p, syntax.arch.mk_word(typ.size () - 1)), syntax.arch.mk_word(start)) + out2 = mk_less (syntax.arch.mk_word(end), p) + return mk_implies (pv, mk_or (out1, out2)) def norm_array_type (t): - if t[0] == 'Type' and t[1].kind == 'Array': - (_, atyp) = t - return ('Array', atyp.el_typ_symb, mk_word32 (atyp.num), 'Strong') - elif t[0] == 'Array' and len (t) == 3: - (_, typ, l) = t - # these derive from PArrayValid assertions. we know the array is - # at least this long, but it might be longer. - return ('Array', typ, l, 'Weak') - else: - return t + if t[0] == 'Type' and t[1].kind == 'Array': + (_, atyp) = t + return ('Array', atyp.el_typ_symb, syntax.arch.mk_word(atyp.num), 'Strong') + elif t[0] == 'Array' and len (t) == 3: + (_, typ, l) = t + # these derive from PArrayValid assertions. we know the array is + # at least this long, but it might be longer. + return ('Array', typ, l, 'Weak') + else: + return t stored_styp_conditions = {} def get_styp_condition (offs, inner_typ, outer_typ): - r = get_styp_condition_inner1 (inner_typ, outer_typ) - if not r: - return false_term - else: - return r (offs) + r = get_styp_condition_inner1 (inner_typ, outer_typ) + if not r: + return false_term + else: + return r (offs) def get_styp_condition_inner1 (inner_typ, outer_typ): - inner_typ = norm_array_type (inner_typ) - outer_typ = norm_array_type (outer_typ) - k = (inner_typ, outer_typ) - if k in stored_styp_conditions: - return stored_styp_conditions[k] - r = get_styp_condition_inner2 (inner_typ, outer_typ) - stored_styp_conditions[k] = r - return r + inner_typ = norm_array_type (inner_typ) + outer_typ = norm_array_type (outer_typ) + k = (inner_typ, outer_typ) + if k in stored_styp_conditions: + return stored_styp_conditions[k] + r = get_styp_condition_inner2 (inner_typ, outer_typ) + stored_styp_conditions[k] = r + return r def array_typ_size ((kind, el_typ, num, _)): - el_size = mk_word32 (el_typ.size ()) - return mk_times (num, el_size) + el_size = syntax.arch.mk_word(el_typ.size ()) + return mk_times (num, el_size) def get_styp_condition_inner2 (inner_typ, outer_typ): - if inner_typ[0] == 'Array' and outer_typ[0] == 'Array': - (_, ityp, inum, _) = inner_typ - (_, otyp, onum, outer_bound) = outer_typ - # array fits in another array if the starting element is - # a sub-element, and if the size of the left array plus - # the offset fits in the right array - cond = get_styp_condition_inner1 (('Type', ityp), outer_typ) - isize = array_typ_size (inner_typ) - osize = array_typ_size (outer_typ) - if outer_bound == 'Strong' and cond: - return lambda offs: mk_and (cond (offs), - mk_less_eq (mk_plus (isize, offs), osize)) - else: - return cond - elif inner_typ == outer_typ: - return lambda offs: mk_eq (offs, mk_word32 (0)) - elif outer_typ[0] == 'Type' and outer_typ[1].kind == 'Struct': - conds = [(get_styp_condition_inner1 (inner_typ, - ('Type', 
sf_typ)), mk_word32 (offs2)) - for (_, offs2, sf_typ) - in structs[outer_typ[1].name].fields.itervalues()] - conds = [cond for cond in conds if cond[0]] - if conds: - return lambda offs: foldr1 (mk_or, - [c (mk_minus (offs, offs2)) - for (c, offs2) in conds]) - else: - return None - elif outer_typ[0] == 'Array': - (_, el_typ, n, bound) = outer_typ - cond = get_styp_condition_inner1 (inner_typ, ('Type', el_typ)) - el_size = mk_word32 (el_typ.size ()) - size = mk_times (n, el_size) - if bound == 'Strong' and cond: - return lambda offs: mk_and (mk_less (offs, size), - cond (mk_modulus (offs, el_size))) - elif cond: - return lambda offs: cond (mk_modulus (offs, el_size)) - else: - return None - else: - return None + if inner_typ[0] == 'Array' and outer_typ[0] == 'Array': + (_, ityp, inum, _) = inner_typ + (_, otyp, onum, outer_bound) = outer_typ + # array fits in another array if the starting element is + # a sub-element, and if the size of the left array plus + # the offset fits in the right array + cond = get_styp_condition_inner1 (('Type', ityp), outer_typ) + isize = array_typ_size (inner_typ) + osize = array_typ_size (outer_typ) + if outer_bound == 'Strong' and cond: + return lambda offs: mk_and (cond (offs), + mk_less_eq (mk_plus (isize, offs), osize)) + else: + return cond + elif inner_typ == outer_typ: + return lambda offs: mk_eq (offs, syntax.arch.mk_word(0)) + elif outer_typ[0] == 'Type' and outer_typ[1].kind == 'Struct': + conds = [(get_styp_condition_inner1 (inner_typ, + ('Type', sf_typ)), syntax.arch.mk_word(offs2)) + for (_, offs2, sf_typ) + in structs[outer_typ[1].name].fields.itervalues()] + conds = [cond for cond in conds if cond[0]] + if conds: + return lambda offs: foldr1 (mk_or, + [c (mk_minus (offs, offs2)) + for (c, offs2) in conds]) + else: + return None + elif outer_typ[0] == 'Array': + (_, el_typ, n, bound) = outer_typ + cond = get_styp_condition_inner1 (inner_typ, ('Type', el_typ)) + el_size = syntax.arch.mk_word(el_typ.size ()) + size = mk_times (n, el_size) + if bound == 'Strong' and cond: + return lambda offs: mk_and (mk_less (offs, size), + cond (mk_modulus (offs, el_size))) + elif cond: + return lambda offs: cond (mk_modulus (offs, el_size)) + else: + return None + else: + return None def all_vars_have_prop (expr, prop): - class Failed (Exception): - pass - def visit (expr): - if expr.kind != 'Var': - return - v2 = (expr.name, expr.typ) - if not prop (v2): - raise Failed () - try: - expr.visit (visit) - return True - except Failed: - return False + class Failed (Exception): + pass + def visit (expr): + if expr.kind != 'Var': + return + v2 = (expr.name, expr.typ) + if not prop (v2): + raise Failed () + try: + expr.visit (visit) + return True + except Failed: + return False def all_vars_in_set (expr, var_set): - return all_vars_have_prop (expr, lambda v: v in var_set) + return all_vars_have_prop (expr, lambda v: v in var_set) def var_not_in_expr (var, expr): - v2 = (var.name, var.typ) - return all_vars_have_prop (expr, lambda v: v != v2) + v2 = (var.name, var.typ) + return all_vars_have_prop (expr, lambda v: v != v2) def mk_array_size_ineq (typ, num, p): - align = typ.align () - size = mk_times (mk_word32 (typ.size ()), num) - size_lim = ((2 ** 32) - 4) / typ.size () - return mk_less_eq (num, mk_word32 (size_lim)) + align = typ.align () + size = mk_times (syntax.arch.mk_word(typ.size ()), num) + size_lim = ((2 ** syntax.arch.word_size) - syntax.arch.ptr_size) / typ.size() + return mk_less_eq (num, syntax.arch.mk_word(size_lim)) def mk_align_valid_ineq (typ, p): - if 
typ[0] == 'Type': - (_, typ) = typ - align = typ.align () - size = mk_word32 (typ.size ()) - size_req = [] - else: - assert typ[0] == 'Array', typ - (kind, typ, num) = typ - align = typ.align () - size = mk_times (mk_word32 (typ.size ()), num) - size_req = [mk_array_size_ineq (typ, num, p)] - assert align in [1, 4, 8] - w0 = mk_word32 (0) - if align > 1: - align_req = [mk_eq (mk_bwand (p, mk_word32 (align - 1)), w0)] - else: - align_req = [] - return foldr1 (mk_and, align_req + size_req + [mk_not (mk_eq (p, w0)), - mk_implies (mk_less (w0, size), - mk_less_eq (p, mk_uminus (size)))]) + if typ[0] == 'Type': + (_, typ) = typ + align = typ.align () + size = syntax.arch.mk_word(typ.size ()) + size_req = [] + else: + assert typ[0] == 'Array', typ + (kind, typ, num) = typ + align = typ.align () + size = mk_times (syntax.arch.mk_word(typ.size ()), num) + size_req = [mk_array_size_ineq (typ, num, p)] + assert align in [1, 4, 8] + w0 = syntax.arch.mk_word(0) + if align > 1: + align_req = [mk_eq (mk_bwand (p, syntax.arch.mk_word(align - 1)), w0)] + else: + align_req = [] + return foldr1 (mk_and, align_req + size_req + [mk_not (mk_eq (p, w0)), + mk_implies (mk_less (w0, size), + mk_less_eq (p, mk_uminus (size)))]) # generic operations on function/problem graphs def dict_list (xys, keys = None): - """dict_list ([(1, 2), (1, 3), (2, 4)]) = {1: [2, 3], 2: [4]}""" - d = {} - for (x, y) in xys: - d.setdefault (x, []) - d[x].append (y) - if keys: - for x in keys: - d.setdefault (x, []) - return d + """dict_list ([(1, 2), (1, 3), (2, 4)]) = {1: [2, 3], 2: [4]}""" + d = {} + for (x, y) in xys: + d.setdefault (x, []) + d[x].append (y) + if keys: + for x in keys: + d.setdefault (x, []) + return d def compute_preds (nodes): - preds = dict_list ([(c, n) for n in nodes - for c in nodes[n].get_conts ()], - keys = nodes) - for n in ['Ret', 'Err']: - preds.setdefault (n, []) - preds = dict ([(n, sorted (set (ps))) - for (n, ps) in preds.iteritems ()]) - return preds + preds = dict_list ([(c, n) for n in nodes + for c in nodes[n].get_conts ()], + keys = nodes) + for n in ['Ret', 'Err']: + preds.setdefault (n, []) + preds = dict ([(n, sorted (set (ps))) + for (n, ps) in preds.iteritems ()]) + return preds def simplify_node_elementary(node): - if node.kind == 'Cond' and node.cond == true_term: - return Node ('Basic', node.left, []) - elif node.kind == 'Cond' and node.cond == false_term: - return Node ('Basic', node.right, []) - elif node.kind == 'Cond' and node.left == node.right: - return Node ('Basic', node.left, []) - else: - return node + if node.kind == 'Cond' and node.cond == true_term: + return Node ('Basic', node.left, []) + elif node.kind == 'Cond' and node.cond == false_term: + return Node ('Basic', node.right, []) + elif node.kind == 'Cond' and node.left == node.right: + return Node ('Basic', node.left, []) + else: + return node def compute_var_flows (nodes, outputs, preds, override_lvals_rvals = {}): - # compute a graph of reverse var flows to pass to tarjan's algorithm - graph = {} - entries = ['Ret'] - for (n, node) in nodes.iteritems (): - if node.kind == 'Basic': - for (lv, rv) in node.upds: - graph[(n, 'Post', lv)] = [(n, 'Pre', v) - for v in syntax.get_expr_var_set (rv)] - elif node.is_noop (): - pass - else: - if n in override_lvals_rvals: - (lvals, rvals) = override_lvals_rvals[n] - else: - rvals = syntax.get_node_rvals (node) - rvals = set (rvals.iteritems ()) - lvals = set (node.get_lvals ()) - if node.kind != 'Basic': - lvals = list (lvals) + ['PC'] - entries.append ((n, 'Post', 'PC')) - for lv 
in lvals: - graph[(n, 'Post', lv)] = [(n, 'Pre', rv) - for rv in rvals] - graph['Ret'] = [(n, 'Post', v) - for n in preds['Ret'] for v in outputs (n)] - vs = set ([v for k in graph for (_, _, v) in graph[k]]) - for v in vs: - for n in nodes: - graph.setdefault ((n, 'Post', v), [(n, 'Pre', v)]) - graph[(n, 'Pre', v)] = [(n2, 'Post', v) - for n2 in preds[n]] - - comps = tarjan (graph, entries) - return (graph, comps) + # compute a graph of reverse var flows to pass to tarjan's algorithm + graph = {} + entries = ['Ret'] + for (n, node) in nodes.iteritems (): + if node.kind == 'Basic': + for (lv, rv) in node.upds: + graph[(n, 'Post', lv)] = [(n, 'Pre', v) + for v in syntax.get_expr_var_set (rv)] + elif node.is_noop (): + pass + else: + if n in override_lvals_rvals: + (lvals, rvals) = override_lvals_rvals[n] + else: + rvals = syntax.get_node_rvals (node) + rvals = set (rvals.iteritems ()) + lvals = set (node.get_lvals ()) + if node.kind != 'Basic': + lvals = list (lvals) + ['PC'] + entries.append ((n, 'Post', 'PC')) + for lv in lvals: + graph[(n, 'Post', lv)] = [(n, 'Pre', rv) + for rv in rvals] + graph['Ret'] = [(n, 'Post', v) + for n in preds['Ret'] for v in outputs (n)] + vs = set ([v for k in graph for (_, _, v) in graph[k]]) + for v in vs: + for n in nodes: + graph.setdefault ((n, 'Post', v), [(n, 'Pre', v)]) + graph[(n, 'Pre', v)] = [(n2, 'Post', v) + for n2 in preds[n]] + + comps = tarjan (graph, entries) + return (graph, comps) def mk_not_red (v): - if v.is_op ('Not'): - [v] = v.vals - return v - else: - return syntax.mk_not (v) + if v.is_op ('Not'): + [v] = v.vals + return v + else: + return syntax.mk_not (v) def cont_with_conds (nodes, n, conds): - while True: - if n not in nodes or nodes[n].kind != 'Cond': - return n - cond = nodes[n].cond - if cond in conds: - n = nodes[n].left - elif mk_not_red (cond) in conds: - n = nodes[n].right - else: - return n + while True: + if n not in nodes or nodes[n].kind != 'Cond': + return n + cond = nodes[n].cond + if cond in conds: + n = nodes[n].left + elif mk_not_red (cond) in conds: + n = nodes[n].right + else: + return n def contextual_conds (nodes, preds): - """computes a collection of conditions that can be assumed true - at any point in the node graph.""" - pre_conds = {} - arc_conds = {} - visit = [n for n in nodes if not (preds[n])] - while visit: - n = visit.pop () - if n not in nodes: - continue - in_arc_conds = [arc_conds.get ((pre, n), set ()) - for pre in preds[n]] - if not in_arc_conds: - conds = set () - else: - conds = set.intersection (* in_arc_conds) - if pre_conds.get (n) == conds: - continue - pre_conds[n] = conds - if n not in nodes: - continue - if nodes[n].kind == 'Cond' and nodes[n].left == nodes[n].right: - c_conds = [conds, conds] - elif nodes[n].kind == 'Cond': - c_conds = [nodes[n].cond, mk_not_red (nodes[n].cond)] - c_conds = [set.union (set ([c]), conds) - for c in c_conds] - else: - upds = set (nodes[n].get_lvals ()) - c_conds = [set ([c for c in conds if - not set.intersection (upds, - syntax.get_expr_var_set (c))])] - for (cont, conds) in zip (nodes[n].get_conts (), c_conds): - arc_conds[(n, cont)] = conds - visit.append (cont) - return (arc_conds, pre_conds) + """computes a collection of conditions that can be assumed true + at any point in the node graph.""" + pre_conds = {} + arc_conds = {} + visit = [n for n in nodes if not (preds[n])] + while visit: + n = visit.pop () + if n not in nodes: + continue + in_arc_conds = [arc_conds.get ((pre, n), set ()) + for pre in preds[n]] + if not in_arc_conds: + conds = set () 
+ else: + conds = set.intersection (* in_arc_conds) + if pre_conds.get (n) == conds: + continue + pre_conds[n] = conds + if n not in nodes: + continue + if nodes[n].kind == 'Cond' and nodes[n].left == nodes[n].right: + c_conds = [conds, conds] + elif nodes[n].kind == 'Cond': + c_conds = [nodes[n].cond, mk_not_red (nodes[n].cond)] + c_conds = [set.union (set ([c]), conds) + for c in c_conds] + else: + upds = set (nodes[n].get_lvals ()) + c_conds = [set ([c for c in conds if + not set.intersection (upds, + syntax.get_expr_var_set (c))])] + for (cont, conds) in zip (nodes[n].get_conts (), c_conds): + arc_conds[(n, cont)] = conds + visit.append (cont) + return (arc_conds, pre_conds) def contextual_cond_simps (nodes, preds): - """a common pattern in architectures with conditional operations is - a sequence of instructions with the same condition. - we can usually then reduce to a single contional block. - b e => b-e - / \ / \ => / \ - a-c-d-f-g => a-c-f-g - this is sometimes important if b calculates a register that e uses - since variable dependency analysis will see this register escape via - the impossible path a-c-d-e - """ - (arc_conds, pre_conds) = contextual_conds (nodes, preds) - nodes = dict (nodes) - for n in nodes: - if nodes[n].kind == 'Cond': - continue - cont = nodes[n].cont - conds = arc_conds[(n, cont)] - cont2 = cont_with_conds (nodes, cont, conds) - if cont2 != cont: - nodes[n] = syntax.copy_rename (nodes[n], - ({}, {cont: cont2})) - return nodes + """a common pattern in architectures with conditional operations is + a sequence of instructions with the same condition. + we can usually then reduce to a single contional block. + b e => b-e + / \ / \ => / \ + a-c-d-f-g => a-c-f-g + this is sometimes important if b calculates a register that e uses + since variable dependency analysis will see this register escape via + the impossible path a-c-d-e + """ + (arc_conds, pre_conds) = contextual_conds (nodes, preds) + nodes = dict (nodes) + for n in nodes: + if nodes[n].kind == 'Cond': + continue + cont = nodes[n].cont + conds = arc_conds[(n, cont)] + cont2 = cont_with_conds (nodes, cont, conds) + if cont2 != cont: + nodes[n] = syntax.copy_rename (nodes[n], + ({}, {cont: cont2})) + return nodes def minimal_loop_node_set (p): - """discover a minimal set of loop addresses, excluding some operations - using conditional instructions which are syntactically within the - loop but semantically must always be followed by an immediate loop - exit. - - amounts to rerunning loop detection after contextual_cond_simps.""" - - loop_ns = set (p.loop_data) - really_in_loop = {} - nodes = contextual_cond_simps (p.nodes, p.preds) - def is_really_in_loop (n): - if n in really_in_loop: - return really_in_loop[n] - ns = [] - r = None - while r == None: - ns.append (n) - if n not in loop_ns: - r = False - elif n in p.splittable_points (n): - r = True - else: - conts = [n2 for n2 in nodes[n].get_conts () - if n2 != 'Err'] - if len (conts) > 1: - r = True - else: - [n] = conts - for n in ns: - really_in_loop[n] = r - return r - return set ([n for n in loop_ns if is_really_in_loop (n)]) + """discover a minimal set of loop addresses, excluding some operations + using conditional instructions which are syntactically within the + loop but semantically must always be followed by an immediate loop + exit. 
+ + amounts to rerunning loop detection after contextual_cond_simps.""" + + loop_ns = set (p.loop_data) + really_in_loop = {} + nodes = contextual_cond_simps (p.nodes, p.preds) + def is_really_in_loop (n): + if n in really_in_loop: + return really_in_loop[n] + ns = [] + r = None + while r == None: + ns.append (n) + if n not in loop_ns: + r = False + elif n in p.splittable_points (n): + r = True + else: + conts = [n2 for n2 in nodes[n].get_conts () + if n2 != 'Err'] + if len (conts) > 1: + r = True + else: + [n] = conts + for n in ns: + really_in_loop[n] = r + return r + return set ([n for n in loop_ns if is_really_in_loop (n)]) def possible_graph_divs (p, min_cost = 20, max_cost = 20, ratio = 0.85, - trace = None): - es = [e[0] for e in p.entries] - divs = [] - direct_costs = {} - future_costs = {'Ret': set (), 'Err': set ()} - prev_costs = {} - int_costs = {} - fracs = {} - for n in p.nodes: - node = p.nodes[n] - if node.kind == 'Call': - cost = set ([(n, 20)]) - elif p.loop_id (n): - cost = set ([(p.loop_id (n), 50)]) - else: - cost = set ([(n, len (node.get_mem_accesses ()))]) - cost.discard ((n, 0)) - direct_costs[n] = cost - for n in p.tarjan_order: - prev_costs[n] = set.union (* ([direct_costs[n]] - + [prev_costs.get (c, set ()) for c in p.preds[n]])) - for n in reversed (p.tarjan_order): - cont_costs = [future_costs.get (c, set ()) - for c in p.nodes[n].get_conts ()] - cost = set.union (* ([direct_costs[n]] + cont_costs)) - p_ct = sum ([c for (_, c) in prev_costs[n]]) - future_costs[n] = cost - if p.nodes[n].kind != 'Cond' or p_ct > max_cost: - continue - ct = sum ([c for (_, c) in set.union (cost, prev_costs[n])]) - if ct < min_cost: - continue - [c1, c2] = [sum ([c for (_, c) - in set.union (cs, prev_costs[n])]) - for cs in cont_costs] - fracs[n] = ((c1 * c1) + (c2 * c2)) / (ct * ct * 1.0) - if fracs[n] < ratio: - divs.append (n) - divs.reverse () - if trace != None: - trace[0] = (direct_costs, future_costs, prev_costs, - int_costs, fracs) - return divs + trace = None): + es = [e[0] for e in p.entries] + divs = [] + direct_costs = {} + future_costs = {'Ret': set (), 'Err': set ()} + prev_costs = {} + int_costs = {} + fracs = {} + for n in p.nodes: + node = p.nodes[n] + if node.kind == 'Call': + cost = set ([(n, 20)]) + elif p.loop_id (n): + cost = set ([(p.loop_id (n), 50)]) + else: + cost = set ([(n, len (node.get_mem_accesses ()))]) + cost.discard ((n, 0)) + direct_costs[n] = cost + for n in p.tarjan_order: + prev_costs[n] = set.union (* ([direct_costs[n]] + + [prev_costs.get (c, set ()) for c in p.preds[n]])) + for n in reversed (p.tarjan_order): + cont_costs = [future_costs.get (c, set ()) + for c in p.nodes[n].get_conts ()] + cost = set.union (* ([direct_costs[n]] + cont_costs)) + p_ct = sum ([c for (_, c) in prev_costs[n]]) + future_costs[n] = cost + if p.nodes[n].kind != 'Cond' or p_ct > max_cost: + continue + ct = sum ([c for (_, c) in set.union (cost, prev_costs[n])]) + if ct < min_cost: + continue + [c1, c2] = [sum ([c for (_, c) + in set.union (cs, prev_costs[n])]) + for cs in cont_costs] + fracs[n] = ((c1 * c1) + (c2 * c2)) / (ct * ct * 1.0) + if fracs[n] < ratio: + divs.append (n) + divs.reverse () + if trace != None: + trace[0] = (direct_costs, future_costs, prev_costs, + int_costs, fracs) + return divs def compute_var_deps (nodes, outputs, preds, override_lvals_rvals = {}, - trace = None): - # outs = list of (outname, retvars) - var_deps = {} - visit = set () - visit.update (preds['Ret']) - visit.update (preds['Err']) - - nodes = contextual_cond_simps (nodes, preds) 
- - while visit: - n = visit.pop () - - node = simplify_node_elementary (nodes[n]) - if n in override_lvals_rvals: - (lvals, rvals) = override_lvals_rvals[n] - lvals = set (lvals) - rvals = set (rvals) - elif node.is_noop (): - lvals = set ([]) - rvals = set ([]) - else: - rvals = syntax.get_node_rvals (node) - rvals = set (rvals.iteritems ()) - lvals = set (node.get_lvals ()) - cont_vs = set () - - for c in node.get_conts (): - if c == 'Ret': - cont_vs.update (outputs (n)) - elif c == 'Err': - pass - else: - cont_vs.update (var_deps.get (c, [])) - vs = set.union (rvals, cont_vs - lvals) - - if n in var_deps and vs <= var_deps[n]: - continue - if trace and n in trace: - diff = vs - var_deps.get (n, set()) - printout ('add %s at %d' % (diff, n)) - printout (' %s, %s, %s, %s' % (len (vs), len (cont_vs), len (lvals), len (rvals))) - var_deps[n] = vs - visit.update (preds[n]) - - return var_deps + trace = None): + # outs = list of (outname, retvars) + var_deps = {} + visit = set () + visit.update (preds['Ret']) + visit.update (preds['Err']) + + nodes = contextual_cond_simps (nodes, preds) + + while visit: + n = visit.pop () + + node = simplify_node_elementary (nodes[n]) + if n in override_lvals_rvals: + (lvals, rvals) = override_lvals_rvals[n] + lvals = set (lvals) + rvals = set (rvals) + elif node.is_noop (): + lvals = set ([]) + rvals = set ([]) + else: + rvals = syntax.get_node_rvals (node) + rvals = set (rvals.iteritems ()) + lvals = set (node.get_lvals ()) + cont_vs = set () + + for c in node.get_conts (): + if c == 'Ret': + cont_vs.update (outputs (n)) + elif c == 'Err': + pass + else: + cont_vs.update (var_deps.get (c, [])) + vs = set.union (rvals, cont_vs - lvals) + + if n in var_deps and vs <= var_deps[n]: + continue + if trace and n in trace: + diff = vs - var_deps.get (n, set()) + printout ('add %s at %d' % (diff, n)) + printout (' %s, %s, %s, %s' % (len (vs), len (cont_vs), len (lvals), len (rvals))) + var_deps[n] = vs + visit.update (preds[n]) + + return var_deps def compute_loop_var_analysis (p, var_deps, n, override_nodes = None): - if override_nodes == None: - nodes = p.nodes - else: - nodes = override_nodes - - upd_vs = set ([v for n2 in p.loop_body (n) - if not nodes[n2].is_noop () - for v in nodes[n2].get_lvals ()]) - const_vs = set ([v for n2 in p.loop_body (n) - for v in var_deps[n2] if v not in upd_vs]) - - vca = compute_var_cycle_analysis (p, nodes, n, - const_vs, set (var_deps[n])) - vca = [(syntax.mk_var (nm, typ), data) - for ((nm, typ), data) in vca.items ()] - return vca + if override_nodes == None: + nodes = p.nodes + else: + nodes = override_nodes + + upd_vs = set ([v for n2 in p.loop_body (n) + if not nodes[n2].is_noop () + for v in nodes[n2].get_lvals ()]) + const_vs = set ([v for n2 in p.loop_body (n) + for v in var_deps[n2] if v not in upd_vs]) + + vca = compute_var_cycle_analysis (p, nodes, n, + const_vs, set (var_deps[n])) + vca = [(syntax.mk_var (nm, typ), data) + for ((nm, typ), data) in vca.items ()] + return vca cvca_trace = [] cvca_diag = [False] @@ -908,785 +887,782 @@ def compute_loop_var_analysis (p, var_deps, n, override_nodes = None): def compute_var_cycle_analysis (p, nodes, n, const_vars, vs, diag = None): - if diag == None: - diag = cvca_diag[0] - - cache = {} - del cvca_trace[:] - impossible_nodes = {} - loop = p.loop_body (n) - - def warm_cache_before (n2, v): - cvca_trace.append ((n2, v)) - cvca_trace.append ('(') - arc = [] - for i in range (100000): - opts = [n3 for n3 in p.preds[n2] if n3 in loop - if v not in nodes[n3].get_lvals () - if 
n3 != n - if (n3, v) not in cache] - if not opts: - break - n2 = opts[0] - arc.append (n2) - if not (len (arc) < 100000): - trace ('warmup overrun in compute_var_cycle_analysis') - trace ('chasing %s in %s' % (v, set (arc))) - assert False, (v, arc[-500:]) - for n2 in reversed (arc): - var_eval_before (n2, v) - cvca_trace.append (')') - - def var_eval_before (n2, v, do_cmp = True): - if (n2, v) in cache and do_cmp: - return cache[(n2, v)] - if n2 == n and do_cmp: - var_exp = mk_var (v[0], v[1]) - vs = set ([v for v in [v] if v not in const_vars]) - return (vs, var_exp) - warm_cache_before (n2, v) - ps = [n3 for n3 in p.preds[n2] if n3 in loop - if not node_impossible (n3)] - if not ps: - return None - vs = [var_eval_after (n3, v) for n3 in ps] - if not all ([v3 == vs[0] for v3 in vs]): - if diag: - trace ('vs disagree for %s @ %d: %s' % (v, n2, vs)) - r = None - else: - r = vs[0] - if do_cmp: - cache[(n2, v)] = r - return r - def var_eval_after (n2, v): - node = nodes[n2] - if node.kind == 'Call' and v in node.rets: - if diag: - trace ('fetched %s from call at %d' % (v, n2)) - return None - elif node.kind == 'Basic': - for (lv, val) in node.upds: - if lv == v: - return expr_eval_before (n2, val) - return var_eval_before (n2, v) - else: - return var_eval_before (n2, v) - def expr_eval_before (n2, expr): - if expr.kind == 'Op': - if expr.vals == []: - return (set(), expr) - vals = [expr_eval_before (n2, v) - for v in expr.vals] - if None in vals: - return None - s = set.union (* [s for (s, v) in vals]) - if len(s) > 1: - if diag: - trace ('too many vars for %s @ %d: %s' % (expr, n2, s)) - return None - return (s, Expr ('Op', expr.typ, - name = expr.name, - vals = [v for (s, v) in vals])) - elif expr.kind == 'Num': - return (set(), expr) - elif expr.kind == 'Var': - return var_eval_before (n2, - (expr.name, expr.typ)) - else: - if diag: - trace ('Unwalkable expr %s' % expr) - return None - def node_impossible (n2): - if n2 in impossible_nodes: - return impossible_nodes[n2] - if n2 == n or n2 in p.get_loop_splittables (n): - imposs = False - else: - pres = [n3 for n3 in p.preds[n2] - if n3 in loop if not node_impossible (n3)] - if n2 in impossible_nodes: - imposs = impossible_nodes[n2] - else: - imposs = not bool (pres) - impossible_nodes[n2] = imposs - node = nodes[n2] - if imposs or node.kind != 'Cond': - return imposs - if 1 >= len ([n3 for n3 in node.get_conts () - if n3 in loop]): - return imposs - c = expr_eval_before (n2, node.cond) - if c != None: - c = try_eval_expr (c[1]) - if c != None: - trace ('determined loop inner cond at %d equals %s' - % (n2, c == syntax.true_term)) - if c == syntax.true_term: - impossible_nodes[node.right] = True - elif c == syntax.false_term: - impossible_nodes[node.left] = True - return imposs - - vca = {} - for v in vs: - rv = var_eval_before (n, v, do_cmp = False) - if rv == None: - vca[v] = 'LoopVariable' - continue - (s, expr) = rv - if expr == mk_var (v[0], v[1]): - vca[v] = 'LoopConst' - continue - if all_vars_in_set (expr, const_vars): - # a repeatedly evaluated const expression, is const - vca[v] = 'LoopConst' - continue - if var_not_in_expr (mk_var (v[0], v[1]), expr): - # leaf calculations do not have data flow to - # themselves. the search algorithm doesn't - # have to worry about these. 
- vca[v] = 'LoopLeaf' - continue - (form, offs) = accumulator_closed_form (expr, v) - if form != None and all_vars_in_set (form (), const_vars): - vca[v] = ('LoopLinearSeries', form, offs) - else: - if diag: - trace ('No accumulator %s => %s' - % (v, expr)) - no_accum_expressions.add ((v, expr)) - vca[v] = 'LoopVariable' - return vca + if diag == None: + diag = cvca_diag[0] + + cache = {} + del cvca_trace[:] + impossible_nodes = {} + loop = p.loop_body (n) + + def warm_cache_before (n2, v): + cvca_trace.append ((n2, v)) + cvca_trace.append ('(') + arc = [] + for i in range (100000): + opts = [n3 for n3 in p.preds[n2] if n3 in loop + if v not in nodes[n3].get_lvals () + if n3 != n + if (n3, v) not in cache] + if not opts: + break + n2 = opts[0] + arc.append (n2) + if not (len (arc) < 100000): + trace ('warmup overrun in compute_var_cycle_analysis') + trace ('chasing %s in %s' % (v, set (arc))) + assert False, (v, arc[-500:]) + for n2 in reversed (arc): + var_eval_before (n2, v) + cvca_trace.append (')') + + def var_eval_before (n2, v, do_cmp = True): + if (n2, v) in cache and do_cmp: + return cache[(n2, v)] + if n2 == n and do_cmp: + var_exp = mk_var (v[0], v[1]) + vs = set ([v for v in [v] if v not in const_vars]) + return (vs, var_exp) + warm_cache_before (n2, v) + ps = [n3 for n3 in p.preds[n2] if n3 in loop + if not node_impossible (n3)] + if not ps: + return None + vs = [var_eval_after (n3, v) for n3 in ps] + if not all ([v3 == vs[0] for v3 in vs]): + if diag: + trace ('vs disagree for %s @ %d: %s' % (v, n2, vs)) + r = None + else: + r = vs[0] + if do_cmp: + cache[(n2, v)] = r + return r + def var_eval_after (n2, v): + node = nodes[n2] + if node.kind == 'Call' and v in node.rets: + if diag: + trace ('fetched %s from call at %d' % (v, n2)) + return None + elif node.kind == 'Basic': + for (lv, val) in node.upds: + if lv == v: + return expr_eval_before (n2, val) + return var_eval_before (n2, v) + else: + return var_eval_before (n2, v) + def expr_eval_before (n2, expr): + if expr.kind == 'Op': + if expr.vals == []: + return (set(), expr) + vals = [expr_eval_before (n2, v) + for v in expr.vals] + if None in vals: + return None + s = set.union (* [s for (s, v) in vals]) + if len(s) > 1: + if diag: + trace ('too many vars for %s @ %d: %s' % (expr, n2, s)) + return None + return (s, Expr ('Op', expr.typ, + name = expr.name, + vals = [v for (s, v) in vals])) + elif expr.kind == 'Num': + return (set(), expr) + elif expr.kind == 'Var': + return var_eval_before (n2, + (expr.name, expr.typ)) + else: + if diag: + trace ('Unwalkable expr %s' % expr) + return None + def node_impossible (n2): + if n2 in impossible_nodes: + return impossible_nodes[n2] + if n2 == n or n2 in p.get_loop_splittables (n): + imposs = False + else: + pres = [n3 for n3 in p.preds[n2] + if n3 in loop if not node_impossible (n3)] + if n2 in impossible_nodes: + imposs = impossible_nodes[n2] + else: + imposs = not bool (pres) + impossible_nodes[n2] = imposs + node = nodes[n2] + if imposs or node.kind != 'Cond': + return imposs + if 1 >= len ([n3 for n3 in node.get_conts () + if n3 in loop]): + return imposs + c = expr_eval_before (n2, node.cond) + if c != None: + c = try_eval_expr (c[1]) + if c != None: + trace ('determined loop inner cond at %d equals %s' + % (n2, c == syntax.true_term)) + if c == syntax.true_term: + impossible_nodes[node.right] = True + elif c == syntax.false_term: + impossible_nodes[node.left] = True + return imposs + + vca = {} + for v in vs: + rv = var_eval_before (n, v, do_cmp = False) + if rv == None: + 
vca[v] = 'LoopVariable' + continue + (s, expr) = rv + if expr == mk_var (v[0], v[1]): + vca[v] = 'LoopConst' + continue + if all_vars_in_set (expr, const_vars): + # a repeatedly evaluated const expression, is const + vca[v] = 'LoopConst' + continue + if var_not_in_expr (mk_var (v[0], v[1]), expr): + # leaf calculations do not have data flow to + # themselves. the search algorithm doesn't + # have to worry about these. + vca[v] = 'LoopLeaf' + continue + (form, offs) = accumulator_closed_form (expr, v) + if form != None and all_vars_in_set (form (), const_vars): + vca[v] = ('LoopLinearSeries', form, offs) + else: + if diag: + trace ('No accumulator %s => %s' + % (v, expr)) + no_accum_expressions.add ((v, expr)) + vca[v] = 'LoopVariable' + return vca eval_expr_solver = [None] def try_eval_expr (expr): - """attempt to reduce an expression to a single result, vaguely like - what constant propagation would do. it might work!""" - import search - if not eval_expr_solver[0]: - import solver - eval_expr_solver[0] = solver.Solver () - try: - return search.eval_model_expr ({}, eval_expr_solver[0], expr) - except KeyboardInterrupt, e: - raise e - except Exception, e: - return None + """attempt to reduce an expression to a single result, vaguely like + what constant propagation would do. it might work!""" + import search + if not eval_expr_solver[0]: + import solver + eval_expr_solver[0] = solver.Solver () + try: + return search.eval_model_expr ({}, eval_expr_solver[0], expr) + except KeyboardInterrupt, e: + raise e + except Exception, e: + return None expr_linear_sum = set (['Plus', 'Minus']) expr_linear_cast = set (['WordCast', 'WordCastSigned']) expr_linear_all = set.union (expr_linear_sum, expr_linear_cast, - ['Times', 'ShiftLeft']) + ['Times', 'ShiftLeft']) def possibly_linear (expr): - if expr.kind in set (['Var', 'Num', 'Symbol', 'Type', 'Token']): - return True - elif expr.is_op (expr_linear_all): - return all ([possibly_linear (x) for x in expr.vals]) - else: - return False + if expr.kind in set (['Var', 'Num', 'Symbol', 'Type', 'Token']): + return True + elif expr.is_op (expr_linear_all): + return all ([possibly_linear (x) for x in expr.vals]) + else: + return False def lv_expr (expr, env): - if expr in env: - return env[expr] - elif expr.kind in set (['Num', 'Symbol', 'Type', 'Token']): - return (expr, 'LoopConst', None, set ()) - elif expr.kind == 'Var': - return (None, None, None, None) - elif expr.kind != 'Op': - assert expr in env, expr - - lvs = [lv_expr (v, env) for v in expr.vals] - rs = [lv[1] for lv in lvs] - mk_offs = lambda vals: syntax.adjust_op_vals (expr, vals) - if None in rs: - return (None, None, None, None) - if set (rs) == set (['LoopConst']): - return (expr, 'LoopConst', None, set ()) - offs_set = set.union (* ([lv[3] for lv in lvs] + [set ()])) - arg_offs = [] - for (expr2, k, offs, _) in lvs: - if k == 'LoopConst' and expr2.typ.kind == 'Word': - arg_offs.append (syntax.mk_num (0, expr2.typ)) - else: - arg_offs.append (offs) - if expr.is_op (expr_linear_sum): - if set (rs) == set (['LoopConst', 'LoopLinearSeries']): - return (expr, 'LoopLinearSeries', mk_offs (arg_offs), - offs_set) - elif expr.is_op ('Times'): - if set (rs) == set (['LoopLinearSeries', 'LoopConst']): - # the new offset is the product of the linear offset - # and the constant value - [linear_offs] = [offs for (_, k, offs, _) in lvs - if k == 'LoopLinearSeries'] - [const_value] = [v for (v, k, _, _) in lvs - if k == 'LoopConst'] - return (expr, 'LoopLinearSeries', - mk_offs ([linear_offs, const_value]), 
offs_set) - if expr.is_op ('ShiftLeft'): - if rs == ['LoopLinearSeries', 'LoopConst']: - return (expr, 'LoopLinearSeries', - mk_offs ([arg_offs[0], lvs[1][0]]), offs_set) - if expr.is_op (expr_linear_cast): - if rs == ['LoopLinearSeries']: - return (expr, 'LoopLinearSeries', mk_offs (arg_offs), - offs_set) - return (None, None, None, None) + if expr in env: + return env[expr] + elif expr.kind in set (['Num', 'Symbol', 'Type', 'Token']): + return (expr, 'LoopConst', None, set ()) + elif expr.kind == 'Var': + return (None, None, None, None) + elif expr.kind != 'Op': + assert expr in env, expr + + lvs = [lv_expr (v, env) for v in expr.vals] + rs = [lv[1] for lv in lvs] + mk_offs = lambda vals: syntax.adjust_op_vals (expr, vals) + if None in rs: + return (None, None, None, None) + if set (rs) == set (['LoopConst']): + return (expr, 'LoopConst', None, set ()) + offs_set = set.union (* ([lv[3] for lv in lvs] + [set ()])) + arg_offs = [] + for (expr2, k, offs, _) in lvs: + if k == 'LoopConst' and expr2.typ.kind == 'Word': + arg_offs.append (syntax.mk_num (0, expr2.typ)) + else: + arg_offs.append (offs) + if expr.is_op (expr_linear_sum): + if set (rs) == set (['LoopConst', 'LoopLinearSeries']): + return (expr, 'LoopLinearSeries', mk_offs (arg_offs), + offs_set) + elif expr.is_op ('Times'): + if set (rs) == set (['LoopLinearSeries', 'LoopConst']): + # the new offset is the product of the linear offset + # and the constant value + [linear_offs] = [offs for (_, k, offs, _) in lvs + if k == 'LoopLinearSeries'] + [const_value] = [v for (v, k, _, _) in lvs + if k == 'LoopConst'] + return (expr, 'LoopLinearSeries', + mk_offs ([linear_offs, const_value]), offs_set) + if expr.is_op ('ShiftLeft'): + if rs == ['LoopLinearSeries', 'LoopConst']: + return (expr, 'LoopLinearSeries', + mk_offs ([arg_offs[0], lvs[1][0]]), offs_set) + if expr.is_op (expr_linear_cast): + if rs == ['LoopLinearSeries']: + return (expr, 'LoopLinearSeries', mk_offs (arg_offs), + offs_set) + return (None, None, None, None) # FIXME: this should probably be unified with compute_var_cycle_analysis, # but doing so is complicated def linear_series_exprs (p, loop, va): - def lv_init (v, data): - if data[0] == 'LoopLinearSeries': - return (v, 'LoopLinearSeries', data[2], set ([data[2]])) - elif data == 'LoopConst': - return (v, 'LoopConst', None, set ()) - else: - return (None, None, None, None) - cache = {loop: dict ([(v, lv_init (v, data)) for (v, data) in va])} - post_cache = {} - loop_body = p.loop_body (loop) - frontier = [n2 for n2 in p.nodes[loop].get_conts () - if n2 in loop_body] - def lv_merge ((v1, lv1, offs1, oset1), (v2, lv2, offs2, oset2)): - if v1 != v2: - return (None, None, None, None) - assert lv1 == lv2 and offs1 == offs2 - return (v1, lv1, offs1, oset1) - def compute_post (n): - if n in post_cache: - return post_cache[n] - pre_env = cache[n] - env = dict (cache[n]) - if p.nodes[n].kind == 'Basic': - for ((v, typ), rexpr) in p.nodes[n].upds: - env[mk_var (v, typ)] = lv_expr (rexpr, pre_env) - elif p.nodes[n].kind == 'Call': - for (v, typ) in p.nodes[n].get_lvals (): - env[mk_var (v, typ)] = (None, None, None, None) - post_cache[n] = env - return env - while frontier: - n = frontier.pop () - if [n2 for n2 in p.preds[n] if n2 in loop_body - if n2 not in cache]: - continue - if n in cache: - continue - envs = [compute_post (n2) for n2 in p.preds[n] - if n2 in loop_body] - all_vs = set.union (* [set (env) for env in envs]) - cache[n] = dict ([(v, foldr1 (lv_merge, - [env.get (v, (None, None, None, None)) - for env in envs])) - 
for v in all_vs]) - frontier.extend ([n2 for n2 in p.nodes[n].get_conts () - if n2 in loop_body]) - return cache + def lv_init (v, data): + if data[0] == 'LoopLinearSeries': + return (v, 'LoopLinearSeries', data[2], set ([data[2]])) + elif data == 'LoopConst': + return (v, 'LoopConst', None, set ()) + else: + return (None, None, None, None) + cache = {loop: dict ([(v, lv_init (v, data)) for (v, data) in va])} + post_cache = {} + loop_body = p.loop_body (loop) + frontier = [n2 for n2 in p.nodes[loop].get_conts () + if n2 in loop_body] + def lv_merge ((v1, lv1, offs1, oset1), (v2, lv2, offs2, oset2)): + if v1 != v2: + return (None, None, None, None) + assert lv1 == lv2 and offs1 == offs2 + return (v1, lv1, offs1, oset1) + def compute_post (n): + if n in post_cache: + return post_cache[n] + pre_env = cache[n] + env = dict (cache[n]) + if p.nodes[n].kind == 'Basic': + for ((v, typ), rexpr) in p.nodes[n].upds: + env[mk_var (v, typ)] = lv_expr (rexpr, pre_env) + elif p.nodes[n].kind == 'Call': + for (v, typ) in p.nodes[n].get_lvals (): + env[mk_var (v, typ)] = (None, None, None, None) + post_cache[n] = env + return env + while frontier: + n = frontier.pop () + if [n2 for n2 in p.preds[n] if n2 in loop_body + if n2 not in cache]: + continue + if n in cache: + continue + envs = [compute_post (n2) for n2 in p.preds[n] + if n2 in loop_body] + all_vs = set.union (* [set (env) for env in envs]) + cache[n] = dict ([(v, foldr1 (lv_merge, + [env.get (v, (None, None, None, None)) + for env in envs])) + for v in all_vs]) + frontier.extend ([n2 for n2 in p.nodes[n].get_conts () + if n2 in loop_body]) + return cache def get_loop_linear_offs (p, loop_head): - import search - va = search.get_loop_var_analysis_at (p, loop_head) - exprs = linear_series_exprs (p, loop_head, va) - def offs_fn (n, expr): - assert p.loop_id (n) == loop_head - env = exprs[n] - rv = lv_expr (expr, env) - if rv[1] == None: - return None - elif rv[1] == 'LoopConst': - return mk_num (0, expr.typ) - elif rv[1] == 'LoopLinearSeries': - return rv[2] - else: - assert not 'lv_expr kind understood', rv - return offs_fn + import search + va = search.get_loop_var_analysis_at (p, loop_head) + exprs = linear_series_exprs (p, loop_head, va) + def offs_fn (n, expr): + assert p.loop_id (n) == loop_head + env = exprs[n] + rv = lv_expr (expr, env) + if rv[1] == None: + return None + elif rv[1] == 'LoopConst': + return mk_num (0, expr.typ) + elif rv[1] == 'LoopLinearSeries': + return rv[2] + else: + assert not 'lv_expr kind understood', rv + return offs_fn def interesting_node_exprs (p, n, tags = None, use_pairings = True): - if tags == None: - tags = p.pairing.tags - node = p.nodes[n] - memaccs = node.get_mem_accesses () - vs = [(kind, ptr) for (kind, ptr, v, m) in memaccs] - vs += [('MemUpdateArg', v) for (kind, ptr, v, m) in memaccs - if kind == 'MemUpdate'] - - if node.kind == 'Call' and use_pairings: - tag = p.node_tags[n][0] - from target_objects import functions, pairings - import solver - fun = functions[node.fname] - arg_input_map = dict (azip (fun.inputs, node.args)) - pairs = [pair for pair in pairings.get (node.fname, []) - if pair.tags == tags] - if not pairs: - return vs - [pair] = pairs - in_eq_vs = [(('Call', pair.name, i), - var_subst (v, arg_input_map)) - for (i, ((lhs, l_s), (rhs, r_s))) - in enumerate (pair.eqs[0]) - if l_s.endswith ('_IN') and r_s.endswith ('_IN') - if l_s != r_s - if solver.typ_representable (lhs.typ) - for (v, site) in [(lhs, l_s), (rhs, r_s)] - if site == '%s_IN' % tag] - vs.extend (in_eq_vs) - return vs + if 
tags == None: + tags = p.pairing.tags + node = p.nodes[n] + memaccs = node.get_mem_accesses () + vs = [(kind, ptr) for (kind, ptr, v, m) in memaccs] + vs += [('MemUpdateArg', v) for (kind, ptr, v, m) in memaccs + if kind == 'MemUpdate'] + if node.kind == 'Call' and use_pairings: + tag = p.node_tags[n][0] + from target_objects import functions, pairings + import solver + fun = functions[node.fname] + arg_input_map = dict (azip (fun.inputs, node.args)) + pairs = [pair for pair in pairings.get (node.fname, []) + if pair.tags == tags] + if not pairs: + return vs + [pair] = pairs + in_eq_vs = [(('Call', pair.name, i), + var_subst (v, arg_input_map)) + for (i, ((lhs, l_s), (rhs, r_s))) + in enumerate (pair.eqs[0]) + if l_s.endswith ('_IN') and r_s.endswith ('_IN') + if l_s != r_s + if solver.typ_representable (lhs.typ) + for (v, site) in [(lhs, l_s), (rhs, r_s)] + if site == '%s_IN' % tag] + vs.extend (in_eq_vs) + return vs def interesting_linear_series_exprs (p, loop, va, tags = None, - use_pairings = True): - if tags == None: - tags = p.pairing.tags - expr_env = linear_series_exprs (p, loop, va) - res_env = {} - for (n, env) in expr_env.iteritems (): - vs = interesting_node_exprs (p, n) - - vs = [(kind, v, lv_expr (v, env)) for (kind, v) in vs] - vs = [(kind, v, offs, offs_set) - for (kind, v, (_, lv, offs, offs_set)) in vs - if lv == 'LoopLinearSeries'] - if vs: - res_env[n] = vs - return res_env + use_pairings = True): + if tags == None: + tags = p.pairing.tags + expr_env = linear_series_exprs (p, loop, va) + res_env = {} + for (n, env) in expr_env.iteritems (): + vs = interesting_node_exprs (p, n) + + vs = [(kind, v, lv_expr (v, env)) for (kind, v) in vs] + vs = [(kind, v, offs, offs_set) + for (kind, v, (_, lv, offs, offs_set)) in vs + if lv == 'LoopLinearSeries'] + if vs: + res_env[n] = vs + return res_env def mk_var_renames (xs, ys): - renames = {} - for (x, y) in azip (xs, ys): - assert x.kind == 'Var' and y.kind == 'Var' - assert x.name not in renames - renames[x.name] = y.name - return renames + renames = {} + for (x, y) in azip (xs, ys): + assert x.kind == 'Var' and y.kind == 'Var' + assert x.name not in renames + renames[x.name] = y.name + return renames def first_aligned_address (nodes, radix): - ks = [k for k in nodes - if k % radix == 0] - if ks: - return min (ks) - else: - return None + ks = [k for k in nodes + if k % radix == 0] + if ks: + return min (ks) + else: + return None def entry_aligned_address (fun, radix): - n = fun.entry - while n % radix != 0: - ns = fun.nodes[n].get_conts () - assert len (ns) == 1, (fun.name, n) - [n] = ns - return n + n = fun.entry + while n % radix != 0: + ns = fun.nodes[n].get_conts () + assert len (ns) == 1, (fun.name, n) + [n] = ns + return n def aligned_address_sanity (functions, symbols, radix): - for (f, func) in functions.iteritems (): - if f not in symbols: - # happens for static or invented functions sometimes - continue - if func.entry: - addr = first_aligned_address (func.nodes, radix) - if addr == None: - printout ('Warning: %s: no aligned instructions' % f) - continue - addr2 = symbols[f][0] - if addr != addr2: - printout ('target mismatch on func %s' % f) - printout (' (starts at 0x%x not 0x%x)' % (addr, addr2)) - return False - addr3 = entry_aligned_address (func, radix) - if addr3 != addr2: - printout ('entry mismatch on func %s' % f) - printout (' (enters at 0x%x not 0x%x)' % (addr3, addr2)) - return False - return True + for (f, func) in functions.iteritems (): + if f not in symbols: + # happens for static or invented functions 
sometimes + continue + if func.entry: + addr = first_aligned_address (func.nodes, radix) + if addr == None: + printout ('Warning: %s: no aligned instructions' % f) + continue + addr2 = symbols[f][0] + if addr != addr2: + printout ('target mismatch on func %s' % f) + printout (' (starts at 0x%x not 0x%x)' % (addr, addr2)) + return False + addr3 = entry_aligned_address (func, radix) + if addr3 != addr2: + printout ('entry mismatch on func %s' % f) + printout (' (enters at 0x%x not 0x%x)' % (addr3, addr2)) + return False + return True # variant of tarjan's strongly connected component algorithm def tarjan (graph, entries): - """tarjan (graph, entries) - variant of tarjan's strongly connected component algorithm - e.g. tarjan ({1: [2, 3], 3: [4, 5]}, [1]) - entries should not be reachable""" - data = {} - comps = [] - for v in entries: - assert v not in data - tarjan1 (graph, v, data, [], set ([]), comps) - return comps + """tarjan (graph, entries) + variant of tarjan's strongly connected component algorithm + e.g. tarjan ({1: [2, 3], 3: [4, 5]}, [1]) + entries should not be reachable""" + data = {} + comps = [] + for v in entries: + assert v not in data + tarjan1 (graph, v, data, [], set ([]), comps) + return comps def tarjan1 (graph, v, data, stack, stack_set, comps): - vs = [] - while True: - # skip through nodes with single successors - data[v] = [len(data), len(data)] - stack.append(v) - stack_set.add(v) - cs = graph[v] - if len (cs) != 1 or cs[0] in data: - break - vs.append ((v, cs[0])) - [v] = cs - - for c in graph[v]: - if c not in data: - tarjan1 (graph, c, data, stack, stack_set, comps) - data[v][1] = min (data[v][1], data[c][1]) - elif c in stack_set: - data[v][1] = min (data[v][1], data[c][0]) - - vs.reverse () - for (v2, c) in vs: - data[v2][1] = min (data[v2][1], data[c][1]) - - for (v2, _) in [(v, 0)] + vs: - if data[v2][1] == data[v2][0]: - comp = [] - while True: - x = stack.pop () - stack_set.remove (x) - if x == v2: - break - comp.append (x) - comps.append ((v2, comp)) + vs = [] + while True: + # skip through nodes with single successors + data[v] = [len(data), len(data)] + stack.append(v) + stack_set.add(v) + cs = graph[v] + if len (cs) != 1 or cs[0] in data: + break + vs.append ((v, cs[0])) + [v] = cs + + for c in graph[v]: + if c not in data: + tarjan1 (graph, c, data, stack, stack_set, comps) + data[v][1] = min (data[v][1], data[c][1]) + elif c in stack_set: + data[v][1] = min (data[v][1], data[c][0]) + + vs.reverse () + for (v2, c) in vs: + data[v2][1] = min (data[v2][1], data[c][1]) + + for (v2, _) in [(v, 0)] + vs: + if data[v2][1] == data[v2][0]: + comp = [] + while True: + x = stack.pop () + stack_set.remove (x) + if x == v2: + break + comp.append (x) + comps.append ((v2, comp)) def divides_loop (graph, split_set): - graph2 = dict (graph) - for n in split_set: - graph2[n] = [] - assert 'ENTRY_POINT' not in graph2 - graph2['ENTRY_POINT'] = list (graph) - comps = tarjan (graph2, ['ENTRY_POINT']) - return not ([(h, t) for (h, t) in comps if t]) + graph2 = dict (graph) + for n in split_set: + graph2[n] = [] + assert 'ENTRY_POINT' not in graph2 + graph2['ENTRY_POINT'] = list (graph) + comps = tarjan (graph2, ['ENTRY_POINT']) + return not ([(h, t) for (h, t) in comps if t]) def strongly_connected_split_points1 (graph): - """find the nodes of a strongly connected - component which, when removed, disconnect the component. 
- complex loops lack such a split point.""" - - # find one simple cycle in the graph - walk = [] - walk_set = set () - n = min (graph) - while n not in walk_set: - walk.append (n) - walk_set.add (n) - n = graph[n][0] - i = walk.index (n) - cycle = walk[i:] - - def subgraph_test (subgraph): - graph2 = dict ([(n, [n2 for n2 in graph[n] if n2 in subgraph]) - for n in subgraph]) - graph2['HEAD'] = list (subgraph) - comps = tarjan (graph2, ['HEAD']) - return bool ([h for (h, t) in comps if t]) - - cycle_set = set (cycle) - cycle = [('Node', set ([n]), False, - [n2 for n2 in graph[n] if n2 != graph[n][0]]) - for n in cycle] - i = 0 - while i < len (cycle): - print i, cycle - (kind, ns, test, unvisited) = cycle[i] - if not unvisited: - i += 1 - continue - n = unvisited.pop () - arc_set = set () - while n not in cycle_set: - if n in arc_set: - # found two totally disjoint loops, so there - # are no splitting points - return set () - arc_set.add (n) - n = graph[n][0] - if n in ns: - if kind == 'Node': - # only this node can be a splittable now. - if subgraph_test (set (graph) - set ([n])): - return set () - else: - return set ([n]) - else: - cycle[i] = (kind, ns, True, unvisited) - ns.update (arc_set) - continue - j = (i + 1) % len (cycle) - new_ns = set () - new_unvisited = set () - new_test = False - while n not in cycle[j][1]: - new_ns.update (cycle[j][1]) - new_unvisited.update (cycle[j][3]) - new_test = cycle[j][2] or new_test - j = (j + 1) % len (cycle) - new_ns.update (arc_set) - new_unvisited.update ([n3 for n2 in arc_set for n3 in graph[n2]]) - new_v = ('Group', new_ns, new_test, list (new_unvisited - new_ns)) - print i, j, n - if j > i: - cycle[i + 1:j] = [new_v] - else: - cycle = [cycle[i], new_v] + cycle[j:i] - i = 0 - cycle_set.update (new_ns) - for (kind, ns, test, unvisited) in cycle: - if test and subgraph_test (ns): - return set () - return set ([n for (kind, ns, _, _) in cycle - if kind == 'Node' for n in ns]) + """find the nodes of a strongly connected + component which, when removed, disconnect the component. + complex loops lack such a split point.""" + + # find one simple cycle in the graph + walk = [] + walk_set = set () + n = min (graph) + while n not in walk_set: + walk.append (n) + walk_set.add (n) + n = graph[n][0] + i = walk.index (n) + cycle = walk[i:] + + def subgraph_test (subgraph): + graph2 = dict ([(n, [n2 for n2 in graph[n] if n2 in subgraph]) + for n in subgraph]) + graph2['HEAD'] = list (subgraph) + comps = tarjan (graph2, ['HEAD']) + return bool ([h for (h, t) in comps if t]) + + cycle_set = set (cycle) + cycle = [('Node', set ([n]), False, + [n2 for n2 in graph[n] if n2 != graph[n][0]]) + for n in cycle] + i = 0 + while i < len (cycle): + (kind, ns, test, unvisited) = cycle[i] + if not unvisited: + i += 1 + continue + n = unvisited.pop () + arc_set = set () + while n not in cycle_set: + if n in arc_set: + # found two totally disjoint loops, so there + # are no splitting points + return set () + arc_set.add (n) + n = graph[n][0] + if n in ns: + if kind == 'Node': + # only this node can be a splittable now. 
+ if subgraph_test (set (graph) - set ([n])): + return set () + else: + return set ([n]) + else: + cycle[i] = (kind, ns, True, unvisited) + ns.update (arc_set) + continue + j = (i + 1) % len (cycle) + new_ns = set () + new_unvisited = set () + new_test = False + while n not in cycle[j][1]: + new_ns.update (cycle[j][1]) + new_unvisited.update (cycle[j][3]) + new_test = cycle[j][2] or new_test + j = (j + 1) % len (cycle) + new_ns.update (arc_set) + new_unvisited.update ([n3 for n2 in arc_set for n3 in graph[n2]]) + new_v = ('Group', new_ns, new_test, list (new_unvisited - new_ns)) + if j > i: + cycle[i + 1:j] = [new_v] + else: + cycle = [cycle[i], new_v] + cycle[j:i] + i = 0 + cycle_set.update (new_ns) + for (kind, ns, test, unvisited) in cycle: + if test and subgraph_test (ns): + return set () + return set ([n for (kind, ns, _, _) in cycle + if kind == 'Node' for n in ns]) def strongly_connected_split_points (graph): - res = strongly_connected_split_points1 (graph) - res2 = set () - for n in graph: - graph2 = dict (graph) - graph2[n] = [] - graph2['ENTRY'] = list (graph) - comps = tarjan (graph2, ['ENTRY']) - if not [comp for comp in comps if comp[1]]: - res2.add (n) - assert res == res2, (graph, res, res2) - return res + res = strongly_connected_split_points1 (graph) + res2 = set () + for n in graph: + graph2 = dict (graph) + graph2[n] = [] + graph2['ENTRY'] = list (graph) + comps = tarjan (graph2, ['ENTRY']) + if not [comp for comp in comps if comp[1]]: + res2.add (n) + assert res == res2, (graph, res, res2) + return res def get_one_loop_splittable (p, loop_set): - """discover a component of a strongly connected - component which, when removed, disconnects the component. - complex loops lack such a split point.""" - candidates = set (loop_set) - graph = dict ([(x, [y for y in p.nodes[x].get_conts () - if y in loop_set]) for x in loop_set]) - while candidates: - loop2 = find_loop_avoiding (graph, loop_set, candidates) - candidates = set.intersection (loop2, candidates) - if not candidates: - return None - n = candidates.pop () - graph2 = dict ([(x, [y for y in graph[x] if y != n]) - for x in graph]) - comps = tarjan (graph2, [n]) - comps = [(h, t) for (h, t) in comps if t] - if not comps: - return n - for (h, t) in comps: - s = set ([h] + t) - candidates = set.intersection (s, candidates) - return None + """discover a component of a strongly connected + component which, when removed, disconnects the component. 
+ complex loops lack such a split point.""" + candidates = set (loop_set) + graph = dict ([(x, [y for y in p.nodes[x].get_conts () + if y in loop_set]) for x in loop_set]) + while candidates: + loop2 = find_loop_avoiding (graph, loop_set, candidates) + candidates = set.intersection (loop2, candidates) + if not candidates: + return None + n = candidates.pop () + graph2 = dict ([(x, [y for y in graph[x] if y != n]) + for x in graph]) + comps = tarjan (graph2, [n]) + comps = [(h, t) for (h, t) in comps if t] + if not comps: + return n + for (h, t) in comps: + s = set ([h] + t) + candidates = set.intersection (s, candidates) + return None def find_loop_avoiding (graph, loop, avoid): - n = (list (loop - avoid) + list (loop))[0] - arc = [n] - visited = set ([n]) - while True: - cs = set (graph[n]) - acs = cs - avoid - vcs = set.intersection (cs, visited) - if vcs: - n = vcs.pop () - break - elif acs: - n = acs.pop () - else: - n = cs.pop () - visited.add (n) - arc.append (n) - [i] = [i for (i, n2) in enumerate (arc) if n2 == n] - return set (arc[i:]) + n = (list (loop - avoid) + list (loop))[0] + arc = [n] + visited = set ([n]) + while True: + cs = set (graph[n]) + acs = cs - avoid + vcs = set.intersection (cs, visited) + if vcs: + n = vcs.pop () + break + elif acs: + n = acs.pop () + else: + n = cs.pop () + visited.add (n) + arc.append (n) + [i] = [i for (i, n2) in enumerate (arc) if n2 == n] + return set (arc[i:]) # non-equality relations in proof hypotheses are recorded as a pretend # equality and reverted to their 'real' meaning here. def mk_stack_wrapper (stack_ptr, stack, excepts): - return syntax.mk_rel_wrapper ('StackWrapper', - [stack_ptr, stack] + excepts) + return syntax.mk_rel_wrapper ('StackWrapper', + [stack_ptr, stack] + excepts) def mk_mem_acc_wrapper (addr, v): - return syntax.mk_rel_wrapper ('MemAccWrapper', [addr, v]) + return syntax.mk_rel_wrapper ('MemAccWrapper', [addr, v]) def mk_mem_wrapper (m): - return syntax.mk_rel_wrapper ('MemWrapper', [m]) - -def tm_with_word32_list (xs): - if xs: - return foldr1 (mk_plus, map (mk_word32, xs)) - else: - return mk_uminus (mk_word32 (0)) - -def word32_list_from_tm (t): - xs = [] - while t.is_op ('Plus'): - [x, t] = t.vals - assert x.kind == 'Num' and x.typ == word32T - xs.append (x.val) - if t.kind == 'Num': - xs.append (t.val) - return xs + return syntax.mk_rel_wrapper ('MemWrapper', [m]) + +def list_from_tm (t): + xs = [] + while t.is_op ('Plus'): + [x, t] = t.vals + assert x.kind == 'Num' and x.typ == syntax.arch.word_type + xs.append (x.val) + if t.kind == 'Num': + xs.append (t.val) + return xs + +def tm_with_word_list(xs): + if xs: + return foldr1(mk_plus, map(syntax.arch.mk_word, xs)) + else: + return mk_uminus(syntax.arch.mk_word(0)) def mk_eq_selective_wrapper (v, (xs, ys)): - # this is a huge hack, but we need to put these lists somewhere - xs = tm_with_word32_list (xs) - ys = tm_with_word32_list (ys) - return syntax.mk_rel_wrapper ('EqSelectiveWrapper', [v, xs, ys]) + # this is a huge hack, but we need to put these lists somewhere + xs = tm_with_word_list (xs) + ys = tm_with_word_list (ys) + return syntax.mk_rel_wrapper ('EqSelectiveWrapper', [v, xs, ys]) def apply_rel_wrapper (lhs, rhs): - assert lhs.typ == syntax.builtinTs['RelWrapper'] - assert rhs.typ == syntax.builtinTs['RelWrapper'] - assert lhs.kind == 'Op' - assert rhs.kind == 'Op' - ops = set ([lhs.name, rhs.name]) - if ops == set (['StackWrapper']): - [sp1, st1] = lhs.vals[:2] - [sp2, st2] = rhs.vals[:2] - excepts = list (set (lhs.vals[2:] + rhs.vals[2:])) - for p 
in excepts: - st1 = syntax.mk_memupd (st1, p, syntax.mk_word32 (0)) - st2 = syntax.mk_memupd (st2, p, syntax.mk_word32 (0)) - return syntax.Expr ('Op', boolT, name = 'StackEquals', - vals = [sp1, st1, sp2, st2]) - elif ops == set (['MemAccWrapper', 'MemWrapper']): - [acc] = [v for v in [lhs, rhs] if v.is_op ('MemAccWrapper')] - [addr, val] = acc.vals - assert addr.typ == syntax.word32T - [m] = [v for v in [lhs, rhs] if v.is_op ('MemWrapper')] - [m] = m.vals - assert m.typ == builtinTs['Mem'] - expr = mk_eq (mk_memacc (m, addr, val.typ), val) - return expr - elif ops == set (['EqSelectiveWrapper']): - [lhs_v, _, _] = lhs.vals - [rhs_v, _, _] = rhs.vals - if lhs_v.typ == syntax.builtinTs['RelWrapper']: - return apply_rel_wrapper (lhs_v, rhs_v) - else: - return mk_eq (lhs, rhs) - else: - assert not 'rel wrapper opname understood' + assert lhs.typ == syntax.builtinTs['RelWrapper'] + assert rhs.typ == syntax.builtinTs['RelWrapper'] + assert lhs.kind == 'Op' + assert rhs.kind == 'Op' + ops = set ([lhs.name, rhs.name]) + if ops == set (['StackWrapper']): + [sp1, st1] = lhs.vals[:2] + [sp2, st2] = rhs.vals[:2] + excepts = list (set (lhs.vals[2:] + rhs.vals[2:])) + for p in excepts: + st1 = syntax.mk_memupd (st1, p, syntax.arch.mk_word(0)) + st2 = syntax.mk_memupd (st2, p, syntax.arch.mk_word(0)) + + return syntax.Expr ('Op', boolT, name = 'StackEquals', + vals = [sp1, st1, sp2, st2]) + elif ops == set (['MemAccWrapper', 'MemWrapper']): + [acc] = [v for v in [lhs, rhs] if v.is_op ('MemAccWrapper')] + [addr, val] = acc.vals + assert addr.typ == syntax.arch.wordT + [m] = [v for v in [lhs, rhs] if v.is_op ('MemWrapper')] + [m] = m.vals + assert m.typ == builtinTs['Mem'] + expr = mk_eq (mk_memacc (m, addr, val.typ), val) + return expr + elif ops == set (['EqSelectiveWrapper']): + [lhs_v, _, _] = lhs.vals + [rhs_v, _, _] = rhs.vals + if lhs_v.typ == syntax.builtinTs['RelWrapper']: + return apply_rel_wrapper (lhs_v, rhs_v) + else: + return mk_eq (lhs, rhs) + else: + assert not 'rel wrapper opname understood' def inst_eq_at_visit (exp, vis): - if not exp.is_op ('EqSelectiveWrapper'): - return True - [_, xs, ys] = exp.vals - # hacks - xs = word32_list_from_tm (xs) - ys = word32_list_from_tm (ys) - if vis.kind == 'Number': - return vis.n in xs - elif vis.kind == 'Offset': - return vis.n in ys - else: - assert not 'visit kind useable', vis + if not exp.is_op ('EqSelectiveWrapper'): + return True + [_, xs, ys] = exp.vals + xs = list_from_tm(xs) + ys = list_from_tm(ys) + if vis.kind == 'Number': + return vis.n in xs + elif vis.kind == 'Offset': + return vis.n in ys + else: + assert not 'visit kind useable', vis def strengthen_hyp (expr, sign = 1): - if not expr.kind == 'Op': - return expr - if expr.name in ['And', 'Or']: - vals = [strengthen_hyp (v, sign) for v in expr.vals] - return syntax.adjust_op_vals (expr, vals) - elif expr.name == 'Implies': - [l, r] = expr.vals - l = strengthen_hyp (l, - sign) - r = strengthen_hyp (r, sign) - return syntax.mk_implies (l, r) - elif expr.name == 'Not': - [x] = expr.vals - x = strengthen_hyp (x, - sign) - return syntax.mk_not (x) - elif expr.name == 'StackEquals': - if sign == 1: - return syntax.Expr ('Op', boolT, - name = 'ImpliesStackEquals', vals = expr.vals) - else: - return syntax.Expr ('Op', boolT, - name = 'StackEqualsImplies', vals = expr.vals) - elif expr.name == 'ROData': - if sign == 1: - return syntax.Expr ('Op', boolT, - name = 'ImpliesROData', vals = expr.vals) - else: - return expr - elif expr.name == 'Equals' and expr.vals[0].typ == boolT: - vals = 
expr.vals - if vals[1] in [syntax.true_term, syntax.false_term]: - vals = [vals[1], vals[0]] - if vals[0] == syntax.true_term: - return strengthen_hyp (vals[1], sign) - elif vals[0] == syntax.false_term: - return strengthen_hyp (syntax.mk_not (vals[1]), sign) - else: - return expr - else: - return expr + if not expr.kind == 'Op': + return expr + if expr.name in ['And', 'Or']: + vals = [strengthen_hyp (v, sign) for v in expr.vals] + return syntax.adjust_op_vals (expr, vals) + elif expr.name == 'Implies': + [l, r] = expr.vals + l = strengthen_hyp (l, - sign) + r = strengthen_hyp (r, sign) + return syntax.mk_implies (l, r) + elif expr.name == 'Not': + [x] = expr.vals + x = strengthen_hyp (x, - sign) + return syntax.mk_not (x) + elif expr.name == 'StackEquals': + if sign == 1: + return syntax.Expr ('Op', boolT, + name = 'ImpliesStackEquals', vals = expr.vals) + else: + return syntax.Expr ('Op', boolT, + name = 'StackEqualsImplies', vals = expr.vals) + elif expr.name == 'ROData': + if sign == 1: + return syntax.Expr ('Op', boolT, + name = 'ImpliesROData', vals = expr.vals) + else: + return expr + elif expr.name == 'Equals' and expr.vals[0].typ == boolT: + vals = expr.vals + if vals[1] in [syntax.true_term, syntax.false_term]: + vals = [vals[1], vals[0]] + if vals[0] == syntax.true_term: + return strengthen_hyp (vals[1], sign) + elif vals[0] == syntax.false_term: + return strengthen_hyp (syntax.mk_not (vals[1]), sign) + else: + return expr + else: + return expr def weaken_assert (expr): - return strengthen_hyp (expr, -1) + return strengthen_hyp (expr, -1) pred_logic_ops = set (['Not', 'And', 'Or', 'Implies']) def norm_neg (expr): - if not expr.is_op ('Not'): - return expr - [nexpr] = expr.vals - if not nexpr.is_op (pred_logic_ops): - return expr - if nexpr.is_op ('Not'): - [expr] = nexpr.vals - return norm_neg (expr) - [x, y] = nexpr.vals - if nexpr.is_op ('And'): - return mk_or (norm_mk_not (x), norm_mk_not (y)) - elif nexpr.is_op ('Or'): - return mk_and (norm_mk_not (x), norm_mk_not (y)) - elif nexpr.is_op ('Implies'): - return mk_and (x, mk_not (y)) + if not expr.is_op ('Not'): + return expr + [nexpr] = expr.vals + if not nexpr.is_op (pred_logic_ops): + return expr + if nexpr.is_op ('Not'): + [expr] = nexpr.vals + return norm_neg (expr) + [x, y] = nexpr.vals + if nexpr.is_op ('And'): + return mk_or (norm_mk_not (x), norm_mk_not (y)) + elif nexpr.is_op ('Or'): + return mk_and (norm_mk_not (x), norm_mk_not (y)) + elif nexpr.is_op ('Implies'): + return mk_and (x, mk_not (y)) def norm_mk_not (expr): - return norm_neg (mk_not (expr)) + return norm_neg (mk_not (expr)) def split_conjuncts (expr): - expr = norm_neg (expr) - if expr.is_op ('And'): - [x, y] = expr.vals - return split_conjuncts (x) + split_conjuncts (y) - else: - return [expr] + expr = norm_neg (expr) + if expr.is_op ('And'): + [x, y] = expr.vals + return split_conjuncts (x) + split_conjuncts (y) + else: + return [expr] def split_disjuncts (expr): - expr = norm_neg (expr) - if expr.is_op ('Or'): - [x, y] = expr.vals - return split_disjuncts (x) + split_disjuncts (y) - else: - return [expr] + expr = norm_neg (expr) + if expr.is_op ('Or'): + [x, y] = expr.vals + return split_disjuncts (x) + split_disjuncts (y) + else: + return [expr] def binary_search_least (test, minimum, maximum): - """find least n, minimum <= n <= maximum, for which test (n).""" - assert maximum >= minimum - if test (minimum): - return minimum - if maximum == minimum or not test (maximum): - return None - while maximum > minimum + 1: - cur = (minimum + maximum) / 2 - 
if test (cur): - maximum = cur - else: - minimum = cur + 1 - assert minimum + 1 == maximum - return maximum + """find least n, minimum <= n <= maximum, for which test (n).""" + assert maximum >= minimum + if test (minimum): + return minimum + if maximum == minimum or not test (maximum): + return None + while maximum > minimum + 1: + cur = (minimum + maximum) / 2 + if test (cur): + maximum = cur + else: + minimum = cur + 1 + assert minimum + 1 == maximum + return maximum def binary_search_greatest (test, minimum, maximum): - """find greatest n, minimum <= n <= maximum, for which test (n).""" - assert maximum >= minimum - if test (maximum): - return maximum - if maximum == minimum or not test (minimum): - return None - while maximum > minimum + 1: - cur = (minimum + maximum) / 2 - if test (cur): - minimum = cur - else: - maximum = cur - 1 - assert minimum + 1 == maximum - return minimum + """find greatest n, minimum <= n <= maximum, for which test (n).""" + assert maximum >= minimum + if test (maximum): + return maximum + if maximum == minimum or not test (minimum): + return None + while maximum > minimum + 1: + cur = (minimum + maximum) / 2 + if test (cur): + minimum = cur + else: + maximum = cur - 1 + assert minimum + 1 == maximum + return minimum diff --git a/loop-example/O1/target.py b/loop-example/O1/target.py index a7407638..89c668f8 100644 --- a/loop-example/O1/target.py +++ b/loop-example/O1/target.py @@ -39,12 +39,12 @@ syntax.check_funs (functions) def asm_split_pairings (): - pairs = [(s, 'Loop.' + s) for s in afunctions] - target_objects.use_hooks.add ('stack_logic') - import stack_logic - stack_bounds = '%s/StackBounds.txt' % target_dir - new_pairings = stack_logic.mk_stack_pairings (pairs, stack_bounds) - pairings.update (new_pairings) + pairs = [(s, 'Loop.' + s) for s in afunctions] + target_objects.use_hooks.add ('stack_logic') + import stack_logic + stack_bounds = '%s/StackBounds.txt' % target_dir + new_pairings = stack_logic.mk_stack_pairings (pairs, stack_bounds) + pairings.update (new_pairings) asm_split_pairings () diff --git a/loop-example/O2/target.py b/loop-example/O2/target.py index 312c57fc..c24ece2a 100644 --- a/loop-example/O2/target.py +++ b/loop-example/O2/target.py @@ -39,12 +39,12 @@ syntax.check_funs (functions) def asm_split_pairings (): - pairs = [(s, 'Loop.' + s) for s in afunctions] - target_objects.use_hooks.add ('stack_logic') - import stack_logic - stack_bounds = '%s/StackBounds.txt' % target_dir - new_pairings = stack_logic.mk_stack_pairings (pairs, stack_bounds) - pairings.update (new_pairings) + pairs = [(s, 'Loop.' + s) for s in afunctions] + target_objects.use_hooks.add ('stack_logic') + import stack_logic + stack_bounds = '%s/StackBounds.txt' % target_dir + new_pairings = stack_logic.mk_stack_pairings (pairs, stack_bounds) + pairings.update (new_pairings) asm_split_pairings () diff --git a/loop-example/synth/target.py b/loop-example/synth/target.py index da3a175b..02e9df21 100644 --- a/loop-example/synth/target.py +++ b/loop-example/synth/target.py @@ -25,16 +25,16 @@ #pseudo_compile.combine_function_duplicates (functions) def run_pairings (): - for f in functions: - if f.startswith ('C.'): - f2 = 'mc_' + f[2:] - else: - f2 = f + '_impl' - if f2 in functions: - pair = logic.mk_pairing (functions, f, f2) - pairings[f] = [pair] - pairings[f2] = [pair] - print '%d pairing halves built.' 
% (len (pairings)) + for f in functions: + if f.startswith ('C.'): + f2 = 'mc_' + f[2:] + else: + f2 = f + '_impl' + if f2 in functions: + pair = logic.mk_pairing (functions, f, f2) + pairings[f] = [pair] + pairings[f2] = [pair] + print '%d pairing halves built.' % (len (pairings)) run_pairings () diff --git a/loop_bounds.py b/loop_bounds.py index 8783276a..2bb66082 100644 --- a/loop_bounds.py +++ b/loop_bounds.py @@ -22,12 +22,12 @@ def downBinSearch(minimum, maximum, tryFun): upperBound = maximum lowerBound = minimum while upperBound > lowerBound: - print 'searching in %d - %d' % (lowerBound,upperBound) - cur = (lowerBound + upperBound) / 2 - if tryFun(cur): - upperBound = cur - else: - lowerBound = cur + 1 + print 'searching in %d - %d' % (lowerBound,upperBound) + cur = (lowerBound + upperBound) / 2 + if tryFun(cur): + upperBound = cur + else: + lowerBound = cur + 1 assert upperBound == lowerBound ret = lowerBound return ret @@ -40,19 +40,19 @@ def upDownBinSearch (minimum, maximum, tryFun): more than twice as high as the bound, which may avoid some issues.""" upperBound = 2 * minimum while upperBound < maximum: - if tryFun (upperBound): - return downBinSearch (minimum, upperBound, tryFun) - else: - upperBound *= 2 + if tryFun (upperBound): + return downBinSearch (minimum, upperBound, tryFun) + else: + upperBound *= 2 if tryFun (maximum): - return downBinSearch (minimum, maximum, tryFun) + return downBinSearch (minimum, maximum, tryFun) else: - return None + return None def addr_of_node (preds, n): - while not trace_refute.is_addr (n): - [n] = preds[n] - return n + while not trace_refute.is_addr (n): + [n] = preds[n] + return n def all_asm_functions (): ss = stack_logic.get_functions_with_tag ('ASM') @@ -62,31 +62,31 @@ def all_asm_functions (): def build_call_site_set (): for f in all_asm_functions (): - preds = logic.compute_preds (functions[f].nodes) - for (n, node) in functions[f].nodes.iteritems (): - if node.kind == 'Call': - s = call_site_set.setdefault (node.fname, set ()) - s.add (addr_of_node (preds, n)) + preds = logic.compute_preds (functions[f].nodes) + for (n, node) in functions[f].nodes.iteritems (): + if node.kind == 'Call': + s = call_site_set.setdefault (node.fname, set ()) + s.add (addr_of_node (preds, n)) call_site_set[('IsLoaded', None)] = True def all_call_sites (f): if not call_site_set: - build_call_site_set () + build_call_site_set () return list (call_site_set.get (f, [])) #naive binary search to find loop bounds def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None): if hyps == None: - hyps = [] + hyps = [] #print 'restrs: %s' % str(restrs) if try_seq == None: #bound_try_seq = [1,2,3,4,5,10,50,130,200,260] #bound_try_seq = [0,1,2,3,4,5,10,50,260] calls = [n for n in p.loop_body (p_n) if p.nodes[n].kind == 'Call'] if calls: - bound_try_seq = [0,1,20] + bound_try_seq = [0,1,20] else: - bound_try_seq = [0,1,20,34] + bound_try_seq = [0,1,20,34] else: bound_try_seq = try_seq rep = mk_graph_slice (p, fast = True) @@ -110,37 +110,37 @@ def findLoopBoundBS(p_n, p, restrs=None, hyps=None, try_seq=None): #do a downward binary search to find the concrete loop bound if index == 0: - loop_bound = bound_try_seq[0] - print 'bound = %d' % loop_bound - return loop_bound + loop_bound = bound_try_seq[0] + print 'bound = %d' % loop_bound + return loop_bound loop_bound = downBinSearch(bound_try_seq[index-1], bound_try_seq[index], lambda x: tryLoopBound(p_n,p,[x],rep,restrs=restrs, hyps=hyps, bin_return=True)) print 'bound = %d' % loop_bound return loop_bound def 
default_n_vc_cases (p, n): - head = p.loop_id (n) - general = [(n2, rep_graph.vc_options ([0], [1])) - for n2 in p.loop_heads () - if n2 != head] + head = p.loop_id (n) + general = [(n2, rep_graph.vc_options ([0], [1])) + for n2 in p.loop_heads () + if n2 != head] - if head: - return [(n, tuple (general + [(head, rep_graph.vc_num (1))])), - (n, tuple (general + [(head, rep_graph.vc_offs (1))]))] - specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head] - return [(n, tuple (general + specific))] + if head: + return [(n, tuple (general + [(head, rep_graph.vc_num (1))])), + (n, tuple (general + [(head, rep_graph.vc_offs (1))]))] + specific = [(head, rep_graph.vc_offs (1)) for _ in [1] if head] + return [(n, tuple (general + specific))] def callNodes(p, fs= None): - ns = [n for n in p.nodes if p.nodes[n].kind == 'Call'] - if fs != None: - ns = [n for n in ns if p.nodes[n].fname in fs] - return ns + ns = [n for n in p.nodes if p.nodes[n].kind == 'Call'] + if fs != None: + ns = [n for n in ns if p.nodes[n].fname in fs] + return ns def noHaltHyps(split,p): ret = [] all_halts = callNodes(p,fs=['halt']) for x in all_halts: - ret += [rep_graph.pc_false_hyp((n_vc, p.node_tags[x][0])) - for n_vc in default_n_vc_cases (p, x)] + ret += [rep_graph.pc_false_hyp((n_vc, p.node_tags[x][0])) + for n_vc in default_n_vc_cases (p, x)] return ret def tryLoopBound(p_n, p, bounds,rep,restrs =None, hints =None,kind = 'Number',bin_return = False,hyps = None): @@ -149,37 +149,37 @@ def tryLoopBound(p_n, p, bounds,rep,restrs =None, hints =None,kind = 'Number',bi if hints == None: hints = [] if hyps == None: - hyps = [] + hyps = [] tag = p.node_tags[p_n][0] from stack_logic import default_n_vc print 'trying bound: %s' % bounds ret_bounds = [] for (index,i) in enumerate(bounds): - print 'testing %d' % i - restrs2 = restrs + ((p_n, VisitCount (kind, i)), ) - try: - pc = rep.get_pc ((p_n, restrs2)) - except: - print 'get_pc failed' - if bin_return: - return False - else: - return -1 - #print 'got rep_.get_pc' - restrs3 = restr_others (p, restrs2, 2) - epc = rep.get_pc (('Err', restrs3), tag = tag) - hyp = mk_implies (mk_not (epc), mk_not (pc)) - hyps = hyps + noHaltHyps(p_n,p) - - #hyps = [] - #print 'calling test_hyp_whyps' - if rep.test_hyp_whyps (hyp, hyps): - print 'p_n %d: split limit found: %d' % (p_n, i) - if bin_return: - return True - return index + print 'testing %d' % i + restrs2 = restrs + ((p_n, VisitCount (kind, i)), ) + try: + pc = rep.get_pc ((p_n, restrs2)) + except: + print 'get_pc failed' + if bin_return: + return False + else: + return -1 + #print 'got rep_.get_pc' + restrs3 = restr_others (p, restrs2, 2) + epc = rep.get_pc (('Err', restrs3), tag = tag) + hyp = mk_implies (mk_not (epc), mk_not (pc)) + hyps = hyps + noHaltHyps(p_n,p) + + #hyps = [] + #print 'calling test_hyp_whyps' + if rep.test_hyp_whyps (hyp, hyps): + print 'p_n %d: split limit found: %d' % (p_n, i) + if bin_return: + return True + return index if bin_return: - return False + return False print 'loop bound not found!' 
return -1 assert False, 'failed to find loop bound for p_n %d' % p_n @@ -188,8 +188,8 @@ def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False): k = ('linear_series_eqs', split, restrs, tuple (hyps)) if k in p.cached_analysis: if omit_standard: - standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False)) - return set (p.cached_analysis[k]) - standard + standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False)) + return set (p.cached_analysis[k]) - standard return p.cached_analysis[k] cands = search.mk_seq_eqs (p, split, 1, with_rodata = False) @@ -200,9 +200,9 @@ def get_linear_series_eqs (p, split, restrs, hyps, omit_standard = False): def do_checks (eqs_assume, eqs): checks = (check.single_loop_induct_step_checks (p, restrs, hyps, tag, - split, 1, eqs, eqs_assume = eqs_assume) - + check.single_loop_induct_base_checks (p, restrs, hyps, tag, - split, 1, eqs)) + split, 1, eqs, eqs_assume = eqs_assume) + + check.single_loop_induct_base_checks (p, restrs, hyps, tag, + split, 1, eqs)) groups = check.proof_check_groups (checks) for group in groups: @@ -226,46 +226,46 @@ def do_checks (eqs_assume, eqs): assert do_checks ([], eqs) p.cached_analysis[k] = eqs if omit_standard: - standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False)) - return set (eqs) - standard + standard = set (search.mk_seq_eqs (p, split, 1, with_rodata = False)) + return set (eqs) - standard return eqs def get_linear_series_hyps (p, split, restrs, hyps): eqs = get_linear_series_eqs (p, split, restrs, hyps) (tag, _) = p.node_tags[split] hyps = [h for (h, _) in linear_eq_hyps_at_visit (tag, split, eqs, - restrs, vc_offs (0))] + restrs, vc_offs (0))] return hyps def is_zero (expr): return expr.kind == 'Num' and expr.val & ((1 << expr.typ.num) - 1) == 0 def candidate_additional_eqs (p, split): + eq_vals = set () def visitor (expr): - if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word': - [x, y] = expr.vals - eq_vals.update ([(x, y), (y, x)]) + if expr.is_op ('Equals') and expr.vals[0].typ.kind == 'Word': + [x, y] = expr.vals + eq_vals.update ([(x, y), (y, x)]) for n in p.loop_body (split): - p.nodes[n].visit (lambda x: (), visitor) + p.nodes[n].visit (lambda x: (), visitor) for (x, y) in list (eq_vals): if is_zero (x) and y.is_op ('Plus'): - [x, y] = y.vals - eq_vals.add ((x, syntax.mk_uminus (y))) - eq_vals.add ((y, syntax.mk_uminus (x))) + [x, y] = y.vals + eq_vals.add ((x, syntax.mk_uminus (y))) + eq_vals.add ((y, syntax.mk_uminus (x))) elif is_zero (x) and y.is_op ('Minus'): - [x, y] = y.vals - eq_vals.add ((x, y)) - eq_vals.add ((y, x)) - - loop = syntax.mk_var ('%i', syntax.word32T) + [x, y] = y.vals + eq_vals.add ((x, y)) + eq_vals.add ((y, x)) + loop = syntax.mk_var ('%i', syntax.arch.word_type) minus_loop_step = syntax.mk_uminus (loop) vas = search.get_loop_var_analysis_at(p, split) ls_vas = dict ([(var, [data]) for (var, data) in vas - if data[0] == 'LoopLinearSeries']) + if data[0] == 'LoopLinearSeries']) cmp_series = [(x, y, rew, offs) for (x, y) in eq_vals - for (_, rew, offs) in ls_vas.get (x, [])] + for (_, rew, offs) in ls_vas.get (x, [])] odd_eqs = [] for (x, y, rew, offs) in cmp_series: x_init_cmp1 = syntax.mk_less_eq (x, rew (x, minus_loop_step)) @@ -297,25 +297,25 @@ def get_call_ctxt_problem (split, call_ctxt, timing = True): from trace_refute import identify_function, build_compound_problem_with_links f = identify_function (call_ctxt, [split]) for (ctxt2, p, hyps, addr_map) in call_ctxt_problems: - if ctxt2 == (call_ctxt, f): - return (p, hyps, 
addr_map) + if ctxt2 == (call_ctxt, f): + return (p, hyps, addr_map) (p, hyps, addr_map) = build_compound_problem_with_links (call_ctxt, f) if avoid_C_information[0]: - hyps = [h for h in hyps if not has_C_information (p, h)] + hyps = [h for h in hyps if not has_C_information (p, h)] call_ctxt_problems.append(((call_ctxt, f), p, hyps, addr_map)) del call_ctxt_problems[: -20] end = time.time () if timing: - save_extra_timing ('GetProblem', call_ctxt + [split], end - start) + save_extra_timing ('GetProblem', call_ctxt + [split], end - start) return (p, hyps, addr_map) def has_C_information (p, hyp): for (n_vc, tag) in hyp.visits (): - if not p.hook_tag_hints.get (tag, None) == 'ASM': - return True + if not p.hook_tag_hints.get (tag, None) == 'ASM': + return True known_bound_restr_hyps = {} @@ -323,41 +323,41 @@ def has_C_information (p, hyp): def serialise_bound (addr, bound_info): if bound_info == None: - return [hex(addr), "None", "None"] + return [hex(addr), "None", "None"] else: - (bound, kind) = bound_info - assert logic.is_int (bound) - assert str (kind) == kind - return [hex (addr), str (bound), kind] + (bound, kind) = bound_info + assert logic.is_int (bound) + assert str (kind) == kind + return [hex (addr), str (bound), kind] def save_bound (glob, split_bin_addr, call_ctxt, prob_hash, prev_bounds, bound, - time = None): + time = None): f_names = [trace_refute.get_body_addrs_fun (x) - for x in call_ctxt + [split_bin_addr]] + for x in call_ctxt + [split_bin_addr]] loop_name = '<%s>' % ' -> '.join (f_names) comment = '# bound for loop in %s:' % loop_name ss = ['LoopBound'] + serialise_bound (split_bin_addr, bound) if glob: - ss[0] = 'GlobalLoopBound' + ss[0] = 'GlobalLoopBound' ss += [str (len (call_ctxt))] + map (hex, call_ctxt) ss += [str (prob_hash)] if glob: - assert prev_bounds == None + assert prev_bounds == None else: - ss += [str (len (prev_bounds))] - for (split, bound) in prev_bounds: - ss += serialise_bound (split, bound) + ss += [str (len (prev_bounds))] + for (split, bound) in prev_bounds: + ss += serialise_bound (split, bound) s = ' '.join (ss) f = open ('%s/LoopBounds.txt' % target_objects.target_dir, 'a') f.write (comment + '\n') f.write (s + '\n') if time != None: - ctxt2 = call_ctxt + [split_bin_addr] - ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2)) - f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time)) + ctxt2 = call_ctxt + [split_bin_addr] + ctxt2 = ' '.join ([str (len (ctxt2))] + map (hex, ctxt2)) + f.write ('LoopBoundTiming %s %s\n' % (ctxt2, time)) f.close () trace ('Found bound %s for 0x%x in %s.' 
% (bound, split_bin_addr, - loop_name)) + loop_name)) def save_extra_timing (nm, ctxt, time): ss = ['ExtraTiming', nm, str (len (ctxt))] + map (hex, ctxt) + [str(time)] @@ -369,12 +369,12 @@ def parse_bound (ss, n): addr = syntax.parse_int (ss[n]) bound = ss[n + 1] if bound == 'None': - bound = None - return (n + 3, (addr, None)) + bound = None + return (n + 3, (addr, None)) else: - bound = syntax.parse_int (bound) - kind = ss[n + 2] - return (n + 3, (addr, (bound, kind))) + bound = syntax.parse_int (bound) + kind = ss[n + 2] + return (n + 3, (addr, (bound, kind))) def parse_ctxt_id (bits, n): return (n + 1, syntax.parse_int (bits[n])) @@ -384,29 +384,29 @@ def parse_ctxt (bits, n): def load_bounds (): try: - f = open ('%s/LoopBounds.txt' % target_objects.target_dir) - ls = list (f) - f.close () + f = open ('%s/LoopBounds.txt' % target_objects.target_dir) + ls = list (f) + f.close () except IOError, e: - ls = [] + ls = [] from syntax import parse_int, parse_list for l in ls: - bits = l.split () - if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]: - continue - (n, (addr, bound)) = parse_bound (bits, 1) - (n, ctxt) = parse_ctxt (bits, n) - prob_hash = parse_int (bits[n]) - n += 1 - if bits[0] == 'LoopBound': - (n, prev_bounds) = parse_list (parse_bound, bits, n) - assert n == len (bits), bits - known = known_bounds.setdefault (addr, []) - known.append ((ctxt, prob_hash, prev_bounds, bound)) - else: - assert n == len (bits), bits - known = known_bounds.setdefault ((addr, 'Global'), []) - known.append ((ctxt, prob_hash, bound)) + bits = l.split () + if bits[:1] not in [['LoopBound'], ['GlobalLoopBound']]: + continue + (n, (addr, bound)) = parse_bound (bits, 1) + (n, ctxt) = parse_ctxt (bits, n) + prob_hash = parse_int (bits[n]) + n += 1 + if bits[0] == 'LoopBound': + (n, prev_bounds) = parse_list (parse_bound, bits, n) + assert n == len (bits), bits + known = known_bounds.setdefault (addr, []) + known.append ((ctxt, prob_hash, prev_bounds, bound)) + else: + assert n == len (bits), bits + known = known_bounds.setdefault ((addr, 'Global'), []) + known.append ((ctxt, prob_hash, bound)) known_bounds['Loaded'] = True def get_bound_ctxt (split, call_ctxt, use_cache = True): @@ -417,26 +417,26 @@ def get_bound_ctxt (split, call_ctxt, use_cache = True): split = p.loop_id (addr_map[split]) assert split, (orig_split, call_ctxt) split_bin_addr = min ([addr for addr in addr_map - if p.loop_id (addr_map[addr]) == split]) + if p.loop_id (addr_map[addr]) == split]) prior = get_prior_loop_heads (p, split) restrs = () prev_bounds = [] for split2 in prior: - # recursion! - split2 = p.loop_id (split2) - assert split2 - addr = min ([addr for addr in addr_map - if p.loop_id (addr_map[addr]) == split2]) - bound = get_bound_ctxt (addr, call_ctxt) - prev_bounds.append ((addr, bound)) - k = (p.name, split2, bound, restrs, tuple (hyps)) - if k in known_bound_restr_hyps: - (restrs, hyps) = known_bound_restr_hyps[k] - else: - (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps, - split2, bound, call_ctxt + [orig_split]) - known_bound_restr_hyps[k] = (restrs, hyps) + # recursion! 
+ split2 = p.loop_id (split2) + assert split2 + addr = min ([addr for addr in addr_map + if p.loop_id (addr_map[addr]) == split2]) + bound = get_bound_ctxt (addr, call_ctxt) + prev_bounds.append ((addr, bound)) + k = (p.name, split2, bound, restrs, tuple (hyps)) + if k in known_bound_restr_hyps: + (restrs, hyps) = known_bound_restr_hyps[k] + else: + (restrs, hyps) = add_loop_bound_restrs_hyps (p, restrs, hyps, + split2, bound, call_ctxt + [orig_split]) + known_bound_restr_hyps[k] = (restrs, hyps) # start timing now. we miss some setup time, but it avoids double counting # the recursive searches. @@ -445,53 +445,53 @@ def get_bound_ctxt (split, call_ctxt, use_cache = True): p_h = problem_hash (p) prev_bounds = sorted (prev_bounds) if not known_bounds: - load_bounds () + load_bounds () known = known_bounds.get (split_bin_addr, []) for (call_ctxt2, h, prev_bounds2, bound) in known: - match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2) - if match and use_cache and h == p_h and prev_bounds2 == prev_bounds: - return bound + match = (not call_ctxt2 or call_ctxt[- len (call_ctxt2):] == call_ctxt2) + if match and use_cache and h == p_h and prev_bounds2 == prev_bounds: + return bound bound = search_bin_bound (p, restrs, hyps, split) known = known_bounds.setdefault (split_bin_addr, []) known.append ((call_ctxt, p_h, prev_bounds, bound)) end = time.time () save_bound (False, split_bin_addr, call_ctxt, p_h, prev_bounds, bound, - time = end - start) + time = end - start) return bound def problem_hash (p): return syntax.hash_tuplify ([p.name, p.entries, - sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())]) + sorted (p.outputs.iteritems ()), sorted (p.nodes.iteritems ())]) def search_bin_bound (p, restrs, hyps, split): trace ('Searching for bound for 0x%x in %s.', (split, p.name)) bound = search_bound (p, restrs, hyps, split) if bound: - return bound + return bound # try to use a bound inferred from C if avoid_C_information[0]: - # OK told not to - return None + # OK told not to + return None if get_prior_loop_heads (p, split): - # too difficult for now - return None + # too difficult for now + return None asm_tag = p.node_tags[split][0] (_, fname, _) = p.get_entry_details (asm_tag) funs = [f for pair in target_objects.pairings[fname] - for f in pair.funs.values ()] + for f in pair.funs.values ()] c_tags = [tag for tag in p.tags () - if p.get_entry_details (tag)[1] in funs and tag != asm_tag] + if p.get_entry_details (tag)[1] in funs and tag != asm_tag] if len (c_tags) != 1: - print 'Surprised to see multiple matching tags %s' % c_tags - return None + print 'Surprised to see multiple matching tags %s' % c_tags + return None [c_tag] = c_tags rep = rep_graph.mk_graph_slice (p) if len (search.get_loop_entry_sites (rep, restrs, hyps, split)) != 1: - # technical, but it's not going to work in this case - return None + # technical, but it's not going to work in this case + return None return getBinaryBoundFromC (p, c_tag, split, restrs, hyps) @@ -511,9 +511,9 @@ def search_bound (p, restrs, hyps, split): # limit this to a small bound for time purposes # - for larger bounds the less naive approach can be faster bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps, - try_seq = [0, 1, 6]) + try_seq = [0, 1, 6]) if bound != None: - return (bound, 'NaiveBinSearch') + return (bound, 'NaiveBinSearch') l_hyps = get_linear_series_hyps (p, split, restrs, hyps) @@ -525,30 +525,30 @@ def test (n): visit = ((split, vc_offs (2)), ) + restrs continue_to_split_guess = rep.get_pc ((split, 
visit)) return rep.test_hyp_whyps (syntax.mk_not (continue_to_split_guess), - [hyp] + l_hyps + hyps) + [hyp] + l_hyps + hyps) # findLoopBoundBS always checks to at least 16 min_bound = 16 max_bound = max_acceptable_bound[0] bound = upDownBinSearch (min_bound, max_bound, test) if bound != None and test (bound): - return (bound, 'InductiveBinSearch') + return (bound, 'InductiveBinSearch') # let the naive bin search go a bit further bound = findLoopBoundBS(split, p, restrs=restrs, hyps=hyps) if bound != None: - return (bound, 'NaiveBinSearch') + return (bound, 'NaiveBinSearch') return None def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps): c_heads = [h for h in search.init_loops_to_split (p, restrs) - if p.node_tags[h][0] == c_tag] + if p.node_tags[h][0] == c_tag] c_bounds = [(p.loop_id (split), search_bound (p, (), hyps, split)) - for split in c_heads] + for split in c_heads] if not [b for (n, b) in c_bounds if b]: - trace ('no C bounds found (%s).' % c_bounds) - return None + trace ('no C bounds found (%s).' % c_bounds) + return None asm_tag = p.node_tags[asm_split][0] @@ -557,32 +557,32 @@ def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps): j_seq_opts = [(0, 1), (0, 2), (1, 1)] tags = [p.node_tags[asm_split][0], c_tag] try: - split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts, - j_seq_opts, 5, tags = [asm_tag, c_tag]) + split = search.find_split (rep, asm_split, restrs, hyps, i_seq_opts, + j_seq_opts, 5, tags = [asm_tag, c_tag]) except solver.SolverFailure, e: - return None + return None if not split or split[0] != 'Split': - trace ('no split found (%s).' % repr (split)) - return None + trace ('no split found (%s).' % repr (split)) + return None (_, split) = split rep = rep_graph.mk_graph_slice (p) checks = check.split_checks (p, (), hyps, split, tags = [asm_tag, c_tag]) groups = check.proof_check_groups (checks) try: - for group in groups: - (res, el) = check.test_hyp_group (rep, group) - if not res: - trace ('split check failed!') - trace ('failed at %s' % el) - return None + for group in groups: + (res, el) = check.test_hyp_group (rep, group) + if not res: + trace ('split check failed!') + trace ('failed at %s' % el) + return None except solver.SolverFailure, e: - return None + return None (as_details, c_details, _, n, _) = split (c_split, (seq_start, step), _) = c_details c_bound = dict (c_bounds).get (p.loop_id (c_split)) if not c_bound: - trace ('key split was not bounded (%r, %r).' % (c_split, c_bounds)) - return None + trace ('key split was not bounded (%r, %r).' 
% (c_split, c_bounds)) + return None (c_bound, _) = c_bound max_it = (c_bound - seq_start) / step assert max_it > n, (max_it, n) @@ -595,18 +595,18 @@ def getBinaryBoundFromC (p, c_tag, asm_split, restrs, hyps): def get_prior_loop_heads (p, split, use_rep = None): if use_rep: - rep = use_rep + rep = use_rep else: - rep = rep_graph.mk_graph_slice (p) + rep = rep_graph.mk_graph_slice (p) prior = [] split = p.loop_id (split) for h in p.loop_heads (): - s = set (prior) - if h not in s and rep.get_reachable (h, split) and h != split: - # need to recurse to ensure prior are in order - prior2 = get_prior_loop_heads (p, h, use_rep = rep) - prior.extend ([h2 for h2 in prior2 if h2 not in s]) - prior.append (h) + s = set (prior) + if h not in s and rep.get_reachable (h, split) and h != split: + # need to recurse to ensure prior are in order + prior2 = get_prior_loop_heads (p, h, use_rep = rep) + prior.extend ([h2 for h2 in prior2 if h2 not in s]) + prior.append (h) return prior def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt): @@ -632,7 +632,7 @@ def add_loop_bound_restrs_hyps (p, restrs, hyps, split, bound, ctxt): def get_functions_hash (): if functions_hash[0] != None: - return functions_hash[0] + return functions_hash[0] h = hash (tuple (sorted ([(f, hash (functions[f])) for f in functions]))) functions_hash[0] = h return h @@ -642,56 +642,57 @@ def get_functions_hash (): def addr_to_loop_id (split): if split not in addr_to_loop_id_cache: - add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) + add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) return addr_to_loop_id_cache[split] def is_complex_loop (split): split = addr_to_loop_id (split) if split not in complex_loop_id_cache: - add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) + add_fun_to_loop_data_cache (trace_refute.get_body_addrs_fun (split)) return complex_loop_id_cache[split] def get_loop_addrs (split): split = addr_to_loop_id (split) f = functions[trace_refute.get_body_addrs_fun (split)] return [addr for addr in f.nodes if trace_refute.is_addr (addr) - if addr_to_loop_id_cache.get (addr) == split] + if addr_to_loop_id_cache.get (addr) == split] def add_fun_to_loop_data_cache (fname): + #print fname p = functions[fname].as_problem (problem.Problem) p.do_loop_analysis () for h in p.loop_heads (): - addrs = [n for n in p.loop_body (h) - if trace_refute.is_addr (n)] - min_addr = min (addrs) - for addr in addrs: - addr_to_loop_id_cache[addr] = min_addr - complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h) + addrs = [n for n in p.loop_body (h) + if trace_refute.is_addr (n)] + min_addr = min (addrs) + for addr in addrs: + addr_to_loop_id_cache[addr] = min_addr + complex_loop_id_cache[min_addr] = problem.has_inner_loop (p, h) return min_addr def get_bound_super_ctxt (split, call_ctxt, no_splitting=False, - known_bound_only=False): + known_bound_only=False): if not known_bounds: - load_bounds () + load_bounds () for (ctxt2, fn_hash, bound) in known_bounds.get ((split, 'Global'), []): - if ctxt2 == call_ctxt and fn_hash == get_functions_hash (): - return bound + if ctxt2 == call_ctxt and fn_hash == get_functions_hash (): + return bound min_loop_addr = addr_to_loop_id (split) if min_loop_addr != split: - return get_bound_super_ctxt (min_loop_addr, call_ctxt, - no_splitting = no_splitting, known_bound_only = known_bound_only) + return get_bound_super_ctxt (min_loop_addr, call_ctxt, + no_splitting = no_splitting, known_bound_only = known_bound_only) if 
known_bound_only: return None no_splitting_abort = [False] try: - bound = get_bound_super_ctxt_inner (split, call_ctxt, - no_splitting = (no_splitting, no_splitting_abort)) + bound = get_bound_super_ctxt_inner (split, call_ctxt, + no_splitting = (no_splitting, no_splitting_abort)) except problem.Abort, e: - bound = None + bound = None if no_splitting_abort[0]: - # don't record this bound, since it might change if splitting was allowed - return bound + # don't record this bound, since it might change if splitting was allowed + return bound known = known_bounds.setdefault ((split, 'Global'), []) known.append ((call_ctxt, get_functions_hash (), bound)) save_bound (True, split, call_ctxt, get_functions_hash (), None, bound) @@ -701,68 +702,68 @@ def get_bound_super_ctxt (split, call_ctxt, no_splitting=False, def call_ctxt_computable (split, call_ctxt): fs = [trace_refute.identify_function ([], [call_site]) - for call_site in call_ctxt] + for call_site in call_ctxt] non_computable = [f for f in fs if trace_refute.has_complex_loop (f)] if non_computable: - trace ('avoiding functions with complex loops: %s' % non_computable) + trace ('avoiding functions with complex loops: %s' % non_computable) return not non_computable def get_bound_super_ctxt_inner (split, call_ctxt, - no_splitting = (False, None)): + no_splitting = (False, None)): first_f = trace_refute.identify_function ([], (call_ctxt + [split])[:1]) call_sites = all_call_sites (first_f) if function_limit (first_f) == 0: - return (0, 'FunctionLimit') + return (0, 'FunctionLimit') safe_call_sites = [cs for cs in call_sites - if ctxt_within_function_limits ([cs] + call_ctxt)] + if ctxt_within_function_limits ([cs] + call_ctxt)] if call_sites and not safe_call_sites: - return (0, 'FunctionLimit') + return (0, 'FunctionLimit') if len (call_ctxt) < 3 and len (safe_call_sites) == 1: - call_ctxt2 = list (safe_call_sites) + call_ctxt - if call_ctxt_computable (split, call_ctxt2): - trace ('using unique calling context %s' % str ((split, call_ctxt2))) - return get_bound_super_ctxt (split, call_ctxt2) + call_ctxt2 = list (safe_call_sites) + call_ctxt + if call_ctxt_computable (split, call_ctxt2): + trace ('using unique calling context %s' % str ((split, call_ctxt2))) + return get_bound_super_ctxt (split, call_ctxt2) fname = trace_refute.identify_function (call_ctxt, [split]) bound = function_limit_bound (fname, split) if bound: - return bound + return bound bound = get_bound_ctxt (split, call_ctxt) if bound: - return bound + return bound trace ('no bound found immediately.') if no_splitting[0]: - assert no_splitting[1], no_splitting - no_splitting[1][0] = True - trace ('cannot split by context (recursion).') - return None + assert no_splitting[1], no_splitting + no_splitting[1][0] = True + trace ('cannot split by context (recursion).') + return None # try to split over potential call sites if len (call_ctxt) >= 3: - trace ('cannot split by context (context depth).') - return None + trace ('cannot split by context (context depth).') + return None if len (call_sites) == 0: - # either entry point or nonsense - trace ('cannot split by context (reached top level).') - return None + # either entry point or nonsense + trace ('cannot split by context (reached top level).') + return None problem_sites = [call_site for call_site in safe_call_sites - if not call_ctxt_computable (split, [call_site] + call_ctxt)] + if not call_ctxt_computable (split, [call_site] + call_ctxt)] if problem_sites: - trace ('cannot split by context (issues in %s).' 
% problem_sites) - return None + trace ('cannot split by context (issues in %s).' % problem_sites) + return None anc_bounds = [get_bound_super_ctxt (split, [call_site] + call_ctxt, - no_splitting = True) - for call_site in safe_call_sites] + no_splitting = True) + for call_site in safe_call_sites] if None in anc_bounds: - return None + return None (bound, kind) = max (anc_bounds) return (bound, 'MergedBound') @@ -770,18 +771,18 @@ def function_limit_bound (fname, split): p = functions[fname].as_problem (problem.Problem) p.do_analysis () cuts = [n for n in p.loop_body (split) - if p.nodes[n].kind == 'Call' - if function_limit (p.nodes[n].fname) != None] + if p.nodes[n].kind == 'Call' + if function_limit (p.nodes[n].fname) != None] if not cuts: - return None + return None graph = p.mk_node_graph (p.loop_body (split)) # it is not possible to iterate the loop without visiting a bounded # function. naively, this sets the limit to the sum of all the possible # bounds, plus one because we can enter the loop a final time without # visiting any function call site yet. if logic.divides_loop (graph, set (cuts)): - fnames = set ([p.nodes[n].fname for n in cuts]) - return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit') + fnames = set ([p.nodes[n].fname for n in cuts]) + return (sum ([function_limit (f) for f in fnames]) + 1, 'FunctionLimit') def loop_bound_difficulty_estimates (split, ctxt): # various guesses at how hard the loop bounding problem is. @@ -810,55 +811,55 @@ def load_timing (): loop_time = 0.0 ext_time = 0.0 for line in f: - bits = line.split () - if not (bits and 'Timing' in bits[0]): - continue - if bits[0] == 'LoopBoundTiming': - (n, ext_ctxt) = parse_ctxt (bits, 1) - assert n == len (bits) - 1 - time = float (bits[n]) - ctxt = ext_ctxt[:-1] - split = ext_ctxt[-1] - timing[(split, tuple(ctxt))] = time - loop_time += time - elif bits[0] == 'ExtraTiming': - time = float (bits[-1]) - ext_time += time + bits = line.split () + if not (bits and 'Timing' in bits[0]): + continue + if bits[0] == 'LoopBoundTiming': + (n, ext_ctxt) = parse_ctxt (bits, 1) + assert n == len (bits) - 1 + time = float (bits[n]) + ctxt = ext_ctxt[:-1] + split = ext_ctxt[-1] + timing[(split, tuple(ctxt))] = time + loop_time += time + elif bits[0] == 'ExtraTiming': + time = float (bits[-1]) + ext_time += time f.close () f = open ('%s/time' % target_objects.target_dir) [l] = [l for l in f if '(wall clock)' in l] f.close () tot_time_str = l.split ()[-1] tot_time = sum ([float(s) * (60 ** i) - for (i, s) in enumerate (reversed (tot_time_str.split(':')))]) + for (i, s) in enumerate (reversed (tot_time_str.split(':')))]) return (loop_time, ext_time, tot_time, timing) def mk_timing_metrics (): if not known_bounds: - load_bounds () + load_bounds () probs = [(split_bin_addr, tuple (call_ctxt), bound) - for (split_bin_addr, known) in known_bounds.iteritems () - if type (split_bin_addr) == int - for (call_ctxt, h, prev_bounds, bound) in known] + for (split_bin_addr, known) in known_bounds.iteritems () + if type (split_bin_addr) == int + for (call_ctxt, h, prev_bounds, bound) in known] probs = set (probs) data = [(split, ctxt, bound, - loop_bound_difficulty_estimates (split, list (ctxt))) - for (split, ctxt, bound) in probs] + loop_bound_difficulty_estimates (split, list (ctxt))) + for (split, ctxt, bound) in probs] return data # sigh, this is so much work. 
bound_kind_nums = { - 'FunctionLimit': 2, - 'NaiveBinSearch': 3, - 'InductiveBinSearch': 4, - 'FromC': 5, - 'MergedBound': 6, + 'FunctionLimit': 2, + 'NaiveBinSearch': 3, + 'InductiveBinSearch': 4, + 'FromC': 5, + 'MergedBound': 6, } gnuplot_colours = [ - "dark-red", "dark-blue", "dark-green", "dark-grey", - "dark-orange", "dark-magenta", "dark-cyan"] + "dark-red", "dark-blue", "dark-green", "dark-grey", + "dark-orange", "dark-magenta", "dark-cyan"] def save_timing_metrics (num): (loop_time, ext_time, tot_time, timing) = load_timing () @@ -872,48 +873,48 @@ def save_timing_metrics (num): f.write ('"%s"\n' % short_name) for (split, ctxt, bound, ests) in time_ests: - time = timing[(split, tuple (ctxt))] - if bound == None: - bdata = "1000000 7" - else: - bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]]) - (l_i, f_i, ct_i) = ests - f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i, - bdata, col, time)) + time = timing[(split, tuple (ctxt))] + if bound == None: + bdata = "1000000 7" + else: + bdata = '%d %d' % (bound[0], bound_kind_nums[bound[1]]) + (l_i, f_i, ct_i) = ests + f.write ('%s %s %s %s %s %r %s\n' % (short_name, l_i, f_i, ct_i, + bdata, col, time)) f.close () def get_loop_heads (fun): if not fun.entry: - return [] + return [] p = fun.as_problem (problem.Problem) p.do_loop_analysis () loops = set () for h in p.loop_heads (): - # any address in the loop will do. pick the smallest one - addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)]) - loops.add ((addr, fun.name, problem.has_inner_loop (p, h))) + # any address in the loop will do. pick the smallest one + addr = min ([n for n in p.loop_body (h) if trace_refute.is_addr (n)]) + loops.add ((addr, fun.name, problem.has_inner_loop (p, h))) return list (loops) def get_all_loop_heads (): loops = set () abort_funs = set () for f in all_asm_functions (): - try: - loops.update (get_loop_heads (functions[f])) - except problem.Abort, e: - abort_funs.add (f) + try: + loops.update (get_loop_heads (functions[f])) + except problem.Abort, e: + abort_funs.add (f) if abort_funs: - trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs)) + trace ('Cannot analyse loops in: %s' % ', '.join (abort_funs)) return loops def get_complex_loops (): return [(loop, name) for (loop, name, compl) in get_all_loop_heads () - if compl] + if compl] def search_all_loops (): all_loops = get_all_loop_heads () for (loop, _, _) in all_loops: - get_bound_super_ctxt (loop, []) + get_bound_super_ctxt (loop, []) main = search_all_loops @@ -921,9 +922,9 @@ def search_all_loops (): import sys args = target_objects.load_target_args () if args == ['search']: - search_all_loops () + search_all_loops () elif args[:1] == ['metrics']: - num = args[1:].index (str (target_objects.target_dir)) - save_timing_metrics (num) + num = args[1:].index (str (target_objects.target_dir)) + save_timing_metrics (num) diff --git a/scripts/debug b/misc/debug similarity index 100% rename from scripts/debug rename to misc/debug diff --git a/nix/cross-tools.nix b/nix/cross-tools.nix new file mode 100644 index 00000000..2d547434 --- /dev/null +++ b/nix/cross-tools.nix @@ -0,0 +1,32 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Cross-compilers used to generate kernel binaries. 
+ +let + + inherit (import ./pins.nix) + mk-default-pkgs + mk-arm-cross-pkgs; + +in rec { + + mk-cross-pkgs = mk-pkgs: crossSystem: + let cross-pkgs = mk-pkgs { inherit crossSystem; }; + in cross-pkgs.pkgsBuildTarget; + + mk-cross-tools = mk-pkgs: crossSystem: + with (mk-cross-pkgs mk-pkgs crossSystem); [ gcc-unwrapped binutils-unwrapped ]; + + x86-cross-tools = + mk-cross-tools mk-default-pkgs { config = "x86_64-unknown-linux-gnu"; }; + + riscv64-cross-tools = + mk-cross-tools mk-default-pkgs { config = "riscv64-unknown-linux-gnu"; }; + + arm-embedded-tools = + mk-cross-tools mk-arm-cross-pkgs { config = "arm-none-eabi"; libc = "newlib"; }; + + cross-tools = x86-cross-tools ++ riscv64-cross-tools ++ arm-embedded-tools; + +} diff --git a/nix/graph-refine.nix b/nix/graph-refine.nix new file mode 100644 index 00000000..f47edea1 --- /dev/null +++ b/nix/graph-refine.nix @@ -0,0 +1,122 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Packages graph-refine and a solver configuration as a Nix derivation, +# and also produces a Docker image. + +{ solver_config ? {} }: + +let + + inherit (import ./pins.nix) pkgs python_2_7_pkgs lib stdenv; + inherit (import ./util.nix) explicit_sources; + + solvers = import ./solvers.nix { inherit solver_config; }; + + graph-refine-python = python_2_7_pkgs.python2.withPackages (python-pkgs: with python-pkgs; [ + enum34 typing psutil + ]); + + # List of graph-refine Python source files. + # We take only the files we need, to minimise Docker layer churn. + graph-refine-modules = [ + "c_rodata.py" + "check.py" + "graph-refine.py" + "inst_logic.py" + "logic.py" + "loop_bounds.py" + "objdump.py" + "parallel_solver.py" + "problem.py" + "pseudo_compile.py" + "rep_graph.py" + "search.py" + "solver.py" + "stack_logic.py" + "stats.py" + "syntax.py" + "target_objects.py" + "trace_refute.py" + "typed_commons.py" + ]; + + # Byte-code-compiled Python sources. + graph-refine-py = with lib; + let + modules = toString graph-refine-modules; + py = stdenv.mkDerivation rec { + name = "graph-refine-py"; + src = explicit_sources "graph-refine-sources" ./.. graph-refine-modules; + installPhase = '' + mkdir -p $out + (cd $src && tar cf - ${modules}) | (cd $out && tar xf -) + (cd $out && ${graph-refine-python.interpreter} -m compileall ${modules}) + ''; + }; + in py; + + # Wrapper that sets environment variables for graph-refine. + # We use runCommand instead of writeScriptBin so we can grab the store path from `$out`. + graph-refine = + let + text = '' + #!${pkgs.runtimeShell} + set -euo pipefail + export GRAPH_REFINE_SOLVERLIST_DIR="${solvers.solverlist}" + export GRAPH_REFINE_VERSION_INFO="NIX_STORE_PATH NIX_GRAPH_REFINE_OUTPUT_DIR" + exec "${graph-refine-python.interpreter}" "${graph-refine-py}/graph-refine.py" "$@" + ''; + deriv_args = { inherit text; passAsFile = [ "text" ]; nativeBuildInputs = [pkgs.perl]; }; + script = pkgs.runCommand "graph-refine" deriv_args '' + mkdir -p "$out/bin" + out_dir="$(basename "$out")" + perl -pe "s/\bNIX_GRAPH_REFINE_OUTPUT_DIR\b/$out_dir/" "$textPath" > "$out/bin/graph-refine" + chmod +x "$out/bin/graph-refine" + ''; + in script; + + # A graph-refine Docker image. 
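+  # As an illustrative usage sketch only (the invocation below is an assumption,
+  # not part of this file): streamLayeredImage builds a script that writes the image
+  # to stdout, so it can typically be loaded with something like
+  #   $(nix-build nix/graph-refine.nix -A graph-refine-image) | docker load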
+ graph-refine-image = with pkgs; dockerTools.streamLayeredImage { + name = "graph-refine"; + contents = [ bashInteractive coreutils graph-refine graph-refine-python ]; + config = { Entrypoint = [ "${graph-refine}/bin/graph-refine" ]; }; + }; + + graph-refine-runner-py = with pkgs; runCommand "graph-refine-runner-py" {} '' + mkdir -p "$out" + cp --preserve=all "${../ci/runner.py}" "$out/graph-refine-runner.py" + (cd "$out" && "${python3.interpreter}" -m compileall "graph-refine-runner.py") + ''; + + graph-refine-runner = with pkgs; writeScriptBin "graph-refine-runner" '' + #!${runtimeShell} + export GRAPH_REFINE_SCRIPT="${graph-refine}/bin/graph-refine" + exec "${python3.interpreter}" "${graph-refine-runner-py}/graph-refine-runner.py" "$@" + ''; + + graph-refine-runner-image = with pkgs; dockerTools.streamLayeredImage { + name = "graph-refine-runner"; + contents = [ + bashInteractive coreutils + graph-refine graph-refine-runner + graph-refine-python python3 + sqlite + ]; + config = { Entrypoint = [ "${graph-refine-runner}/bin/graph-refine-runner" ]; }; + }; + +in { + + inherit + graph-refine + graph-refine-image + graph-refine-runner + graph-refine-runner-image + graph-refine-python; + + inherit (solvers) + solverlist + solvers; + +} diff --git a/nix/l4v-deps.nix b/nix/l4v-deps.nix new file mode 100644 index 00000000..13a4ac4e --- /dev/null +++ b/nix/l4v-deps.nix @@ -0,0 +1,22 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Additional dependencies for running l4v proofs. + +let + + inherit (import ./pins.nix) pkgs; + + l4v-texlive = with pkgs.texlive; combine { + inherit + collection-latexextra + collection-fontsrecommended + collection-metapost + collection-bibtexextra; + }; + + l4v-deps = with pkgs; [ + l4v-texlive + ]; + +in { inherit l4v-deps; } diff --git a/nix/pins.nix b/nix/pins.nix new file mode 100644 index 00000000..69645537 --- /dev/null +++ b/nix/pins.nix @@ -0,0 +1,49 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Pin versions of nixpkgs and other repositories used for building +# graph-refine, the decompiler, etc. + +let + + inherit (import ./util.nix) fetchGitHub fetch-nixpkgs mk-rosetta-pkgs; + +in + +rec { + + # Recent nixpkgs pin used for most things. + mk-default-pkgs = fetch-nixpkgs { + # nixos-unstable 2023-01-13 + rev = "befc83905c965adfd33e5cae49acb0351f6e0404"; + sha256 = "sha256:0m0ik7z06q3rshhhrg2p0vsrkf2jnqcq5gq1q6wb9g291rhyk6h2"; + }; + + # Python 2.7 needs an older pin. + python_2_7_pkgs = fetch-nixpkgs { + # nixos-unstable 2022-05-31 + rev = "d4964be44cb430760b266f5df34a185f2920e80e"; + sha256 = "sha256:01wd40yn8crz1dmypd9vcc9gcv8d83haai2cdv704vg2s423gg88"; + } {}; + + # The arm-none-eabi cross compiler needs an older pin, + # and on aarch64-darwin, only builds via Rosetta. 
+ mk-arm-cross-pkgs = mk-rosetta-pkgs pkgs (fetch-nixpkgs { + # nixos-22.05 2023-01-13 + rev = "9e96b1562d67a90ca2f7cfb919f1e86b76a65a2c"; + sha256 = "sha256:0nma745rx2f2syggzl99r0mv1pmdy36nsar1wxggci647gdqriwf"; + }); + + pkgs = mk-default-pkgs {}; + inherit (pkgs) lib stdenv; + + rosetta-pkgs = mk-rosetta-pkgs pkgs mk-default-pkgs {}; + + herculesGitignore = import (fetchGitHub { + owner = "hercules-ci"; + repo = "gitignore.nix"; + rev = "a20de23b925fd8264fd7fad6454652e142fd7f73"; + sha256 = "sha256:07vg2i9va38zbld9abs9lzqblz193vc5wvqd6h7amkmwf66ljcgh"; + }) { inherit lib; }; + +} diff --git a/nix/sel4-deps.nix b/nix/sel4-deps.nix new file mode 100644 index 00000000..6fe78ba7 --- /dev/null +++ b/nix/sel4-deps.nix @@ -0,0 +1,28 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Dependencies for building seL4. + +let + + inherit (import ./pins.nix) pkgs; + inherit (import ./sel4-python.nix) sel4-python; + inherit (import ./cross-tools.nix) cross-tools; + + deps = with pkgs; [ + bash + cmake + coreutils + dtc + findutils + gnugrep + gnumake + ninja + libxml2 + sel4-python + which + ]; + + sel4-deps = deps ++ cross-tools; + +in { inherit sel4-deps sel4-python cross-tools; } diff --git a/nix/sel4-python.nix b/nix/sel4-python.nix new file mode 100644 index 00000000..4b225528 --- /dev/null +++ b/nix/sel4-python.nix @@ -0,0 +1,100 @@ +# Copyright 2020 Arm Limited +# Copyright 2022 Kry10 Limited +# +# SPDX-License-Identifier: MIT + +# A Python environment suitable for building seL4, adapted from: +# https://gitlab.com/icecap-project/icecap/-/blob/main/nix/framework/overlay/python-overrides.nix + +let + + inherit (import ./pins.nix) pkgs; + + sel4-python = pkgs.python3.withPackages (python-pkgs: with python-pkgs; + let + autopep8_1_4_3 = buildPythonPackage rec { + pname = "autopep8"; + version = "1.4.3"; + src = fetchPypi { + inherit pname version; + sha256 = "13140hs3kh5k13yrp1hjlyz2xad3xs1fjkw1811gn6kybcrbblik"; + }; + propagatedBuildInputs = [ + pycodestyle + ]; + doCheck = false; + checkInputs = [ glibcLocales ]; + LC_ALL = "en_US.UTF-8"; + }; + + cmake-format = buildPythonPackage rec { + pname = "cmake_format"; + version = "0.4.5"; + src = fetchPypi { + inherit pname version; + sha256 = "0nl78yb6zdxawidp62w9wcvwkfid9kg86n52ryg9ikblqw428q0n"; + }; + propagatedBuildInputs = [ + jinja2 + pyyaml + ]; + doCheck = false; + }; + + guardonce = buildPythonPackage rec { + pname = "guardonce"; + version = "2.4.0"; + src = fetchPypi { + inherit pname version; + sha256 = "0sr7c1f9mh2vp6pkw3bgpd7crldmaksjfafy8wp5vphxk98ix2f7"; + }; + buildInputs = [ + nose + ]; + }; + + pyfdt = buildPythonPackage rec { + pname = "pyfdt"; + version = "0.3"; + src = fetchPypi { + inherit pname version; + sha256 = "1w7lp421pssfgv901103521qigwb12i6sk68lqjllfgz0lh1qq31"; + }; + }; + + sel4-deps = buildPythonPackage rec { + pname = "sel4-deps"; + version = "0.3.1"; + src = fetchPypi { + inherit pname version; + sha256 = "09xjv4gc9cwanxdhpqg2sy2pfzn2rnrnxgjdw93nqxyrbpdagd5r"; + }; + postPatch = '' + substituteInPlace setup.py --replace bs4 beautifulsoup4 + ''; + propagatedBuildInputs = [ + autopep8_1_4_3 + beautifulsoup4 + cmake-format + future + guardonce + jinja2 + jsonschema + libarchive-c + lxml + pexpect + ply + psutil + pyaml + pyelftools + pyfdt + setuptools + six + sh + types-pyyaml + ]; + }; + + in [ sel4-deps mypy ]); + +in { inherit sel4-python; } diff --git a/nix/solvers.nix b/nix/solvers.nix new file mode 100644 index 00000000..16959020 --- /dev/null +++ b/nix/solvers.nix @@ -0,0 +1,131 
@@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Write a solver configuration (.solverlist) to the Nix store. + +let + + inherit (import ./pins.nix) pkgs lib; + sonolar_available = pkgs.system == "x86_64-linux"; + +in + +# Optionally accept a solver config, which may override +# attributes in default_solver_config below. +{ + use_sonolar ? sonolar_available, + solver_config ? {}, +}: + +let + + default_solver_list = + if use_sonolar + then [ "yices" "sonolar" "cvc4" ] + else [ "yices" "cvc4" ]; + + default_solver_config = { + online_solver = "cvc4"; + offline_solver = if use_sonolar then "sonolar" else "cvc4"; + strategy_solvers = default_solver_list; + model_solvers = default_solver_list; + }; + + # Fill in missing solver_config items using defaults. + solver_config_with_defaults = with builtins; + let + # Ensure solver_config has no unwanted items. + bad_attrs = filter (n: !hasAttr n default_solver_config) (attrNames solver_config); + checked_config = if bad_attrs == [] + then solver_config + else throw "solvers.nix: solver_config has unwanted attributes: ${toString bad_attrs}"; + in default_solver_config // checked_config; + +in + +with solver_config_with_defaults; + +let + + maybe_sonolar = if !sonolar_available then {} else { + sonolar = rec { + pkg = pkgs.fetchzip rec { + name = "sonolar-2014-12-04"; + url = "http://www.informatik.uni-bremen.de/agbs/florian/sonolar/${name}-x86_64-linux.tar.gz"; + sha256 = "sha256:01k6270ycv532w5hx0xhfms63migr7wq359lsnr4a6d047q15ix7"; + }; + offline = "${pkg}/bin/sonolar --input-format=smtlib2"; + }; + }; + + solver_cmds = maybe_sonolar // { + cvc4 = rec { + pkg = pkgs.cvc4; + online = "${pkg}/bin/cvc4 --incremental --lang smt --tlimit=120"; + offline = "${pkg}/bin/cvc4 --lang smt"; + }; + yices = rec { + pkg = pkgs.yices; + offline = "${pkg}/bin/yices-smt2"; + }; + }; + + # Names of selected solvers for each mode. May contain duplicates. + selected_solvers = { + offline = [ offline_solver ] ++ strategy_solvers ++ model_solvers; + online = [ online_solver ]; + }; + + # Solver packages selected by solver_config. + solvers = with lib; + mapAttrs (_: s: s.pkg) (getAttrs (concatLists (attrValues selected_solvers)) solver_cmds); + + # Templates for solverlist solver specifications, in offline and online modes. + mk_solvers = { + offline = name: cmd: [ + "${name}: offline:: ${cmd}" + "${name}-word8: offline: mem_mode=8: ${cmd}" + ]; + online = name: cmd: [ + "${name}: online:: ${cmd}" + ]; + }; + + # We currently always use both "all" and "hyp" strategies, + # and both machine-word- and byte-granular views of memory. + strategy = + let f = n: "${n} all, ${n} hyp, ${n}-word8 all, ${n}-word8 hyp"; + in lib.concatMapStringsSep ", " f strategy_solvers; + + # For model repair, first use all selected solvers with a machine-word-granular + # view of memory, then use the selected solvers with a byte-granular view. + model_strategy = with lib; + let solv_suf = suf: map (solv: solv + suf) model_solvers; + in concatStringsSep ", " (concatMap solv_suf ["" "-word8"]); + + # Render the solverlist templates for the given mode and selected solvers, + # returning a list of unterminated lines. + render_solvers_mode = mode: selected: with lib; + concatLists (mapAttrsToList (name: cmd: mk_solvers."${mode}" name cmd.${mode}) + (getAttrs selected solver_cmds)); + + # Render solverlist templates for both modes, emitting online templates first. 
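+  # For illustration only: with the default yices/cvc4 configuration (sonolar absent),
+  # the rendered entries come out roughly as follows, with ${pkg} expanded to Nix store paths:
+  #   cvc4: online:: /nix/store/...-cvc4/bin/cvc4 --incremental --lang smt --tlimit=120
+  #   cvc4: offline:: /nix/store/...-cvc4/bin/cvc4 --lang smt
+  #   cvc4-word8: offline: mem_mode=8: /nix/store/...-cvc4/bin/cvc4 --lang smt
+  #   yices: offline:: /nix/store/...-yices/bin/yices-smt2
+  #   yices-word8: offline: mem_mode=8: /nix/store/...-yices/bin/yices-smt2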
+ selected_solver_configs = with lib; + concatStringsSep "\n" + (concatLists (attrVals [ "online" "offline" ] + (mapAttrs render_solvers_mode selected_solvers))); + + # The solver configuration file. + solverlist = pkgs.writeTextFile { + name = "graph-refine-solverlist"; + destination = "/.solverlist"; + text = '' + strategy: ${strategy} + model-strategy: ${model_strategy} + online-solver: ${online_solver} + offline-solver: ${offline_solver} + ${selected_solver_configs} + ''; + }; + +in { inherit solverlist solvers; } diff --git a/nix/util.nix b/nix/util.nix new file mode 100644 index 00000000..d912d969 --- /dev/null +++ b/nix/util.nix @@ -0,0 +1,111 @@ +# Copyright 2023 Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Shared utility functions. + +let + + inherit (import ./pins.nix) pkgs lib; + +in + +rec { + + # Like nixpkgs fetchFromGitHub, but using the built-in fetchTarball. + fetchGitHub = { owner, repo, rev, name ? "${owner}-${repo}-${rev}", sha256 }: + fetchTarball { + inherit name sha256; + url = "https://github.com/${owner}/${repo}/archive/${rev}.tar.gz"; + }; + + # Fetch and import a pinned nixpkgs. + fetch-nixpkgs = { rev, sha256 }: import (fetchGitHub { + owner = "nixos"; repo = "nixpkgs"; name = "nixpkgs"; + inherit rev sha256; + }); + + # Some packages don't work natively on Apple Silicon, but work with Rosetta 2. + mk-rosetta-pkgs = pkgs: mk-pkgs: config: + mk-pkgs (if pkgs.system == "aarch64-darwin" + then { localSystem = pkgs.lib.systems.examples.x86_64-darwin; } // config + else config); + + # A source filter that accepts an explicit list of paths relative to a given root. + explicit_sources_filter = root: paths: + assert (builtins.isPath root); + with lib; + let + # Augment the list of paths with all of its own ancestor paths. + # E.g. [ "x/y/z" "a/b/c" ] -> [ "x" "a" "x/y" "a/b" "x/y/z" "a/b/c" ] + # This is necessary if the paths contain files in subdirectories, + # since source filtering will not descend into directories that do not + # themselves pass the filter. + prefix_paths = + let xs = genericClosure { startSet = map (p: { key = p; }) paths; + operator = { key }: let d = builtins.dirOf key; + s = { key = d; }; + in if d == "." then [] else [s]; }; + in map (x: x.key) xs; + + abs_path = path: toString (root + "/${path}"); + prefix_path_set = listToAttrs (map (p: nameValuePair (abs_path p) null) prefix_paths); + abs_paths = map abs_path paths; + is_desc_path = desc: builtins.any (anc: lib.hasPrefix "${anc}/" desc) abs_paths; + + filter = path: type: + let p = toString path; in hasAttr p prefix_path_set || is_desc_path p; + + in filter; + + explicit_sources = name: src: paths: lib.cleanSourceWith { + inherit name src; + filter = explicit_sources_filter src paths; + }; + + # A conjunction of two-place predicates. Can be used with path filters. + conj2 = preds: x: y: builtins.all (pred: pred x y) preds; + + # A disjunction of two-place predicates. Can be used with path filters. + disj2 = preds: x: y: builtins.any (pred: pred x y) preds; + + # Apply a function to an attribute value, + # with a default if the attribute is not present. + maybeGetAttr = name: set: default: f: + if lib.hasAttr name set then f set.${name} else default; + + # Creates a shell script snippet that checks for the presence of given + # environment variables, and prints a usage message if not. 
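+  # A hypothetical usage sketch (the name and variables below are examples only,
+  # not definitions from this repository):
+  #   mk-check-env {
+  #     name = "decompile";
+  #     env = {
+  #       L4V_ARCH = { help = "target architecture"; choices = [ "ARM" "RISCV64" ]; };
+  #       TARGET_DIR = { help = "directory containing the kernel build artifacts"; };
+  #     };
+  #   }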
+ mk-check-env = { name, env }: + let + env-checks = map (var: ''[ -z ''${${var}+x} ]'') (lib.attrNames env); + choices-check = var: info: + let + check = c: ''[ "''${${var}}" != "${c}" ]''; + checks = cs: '' + if ${lib.concatStringsSep " && " (map check cs)}; then + echo "${name}: error: invalid value for environment variable ${var}" >&2 + echo "${name}: valid choices for ${var}: ${lib.concatStringsSep " " cs}" >&2 + exit 1 + fi + ''; + in maybeGetAttr "choices" info "" (cs: lib.removeSuffix "\n" (checks cs)); + choices-checks = + lib.concatStringsSep "\n" (lib.mapAttrsToList choices-check env); + var-usage = var: info: + let + choices = maybeGetAttr "choices" info [] + (cs: ["one of: ${lib.concatStringsSep " " cs}"]); + help_ls = maybeGetAttr "help" info [] (h: [h]); + help = lib.concatStringsSep " - " ([var] ++ help_ls ++ choices); + in ''echo " ${help}" >&2''; + exists-check = '' + if ${lib.concatStringsSep " || " env-checks}; then + echo "${name}: error: required environment variables not set" >&2 + echo "${name}: required variables:" >&2 + ${lib.concatStringsSep "\n " (lib.mapAttrsToList var-usage env)} + exit 1 + fi + ''; + in lib.removeSuffix "\n" (exists-check + choices-checks); + +} diff --git a/objdump.py b/objdump.py index 22187e64..41890288 100644 --- a/objdump.py +++ b/objdump.py @@ -11,131 +11,138 @@ import re def build_syms (symtab): - syms = {} - for line in symtab: - bits = line.split () - try: - addr = int (bits[0], 16) - size = int (bits[-2], 16) - section = bits[-3] - syms[bits[-1]] = (addr, size, section) - except ValueError: - pass - except IndexError: - pass - - sections = {} - for (addr, size, section) in syms.itervalues (): - if not size: - continue - (start, end) = sections.get (section, (addr, addr)) - start = min (addr, start) - end = max (addr + size - 1, end) - sections[section] = (start, end) - - return (syms, sections) + syms = {} + for line in symtab: + bits = line.split () + try: + addr = int (bits[0], 16) + size = int (bits[-2], 16) + section = bits[-3] + syms[bits[-1]] = (addr, size, section) + except ValueError: + pass + except IndexError: + pass + + sections = {} + for (addr, size, section) in syms.itervalues (): + if not size: + continue + (start, end) = sections.get (section, (addr, addr)) + start = min (addr, start) + end = max (addr + size - 1, end) + sections[section] = (start, end) + + return (syms, sections) def install_syms (symtab): - (syms, sects) = build_syms (symtab) - import target_objects - target_objects.symbols.update (syms) - target_objects.sections.update (sects) + (syms, sects) = build_syms (symtab) + import target_objects + target_objects.symbols.update (syms) + target_objects.sections.update (sects) is_rodata_line = re.compile('^\s*[0-9a-fA-F]+:\s+[0-9a-fA-F]+\s+') def build_rodata (rodata_stream, rodata_ranges = [('Section', '.rodata')]): - from syntax import structs, fresh_name, Struct, mk_word32 - import syntax - from target_objects import symbols, sections, trace - - act_rodata_ranges = [] - for (kind, nm) in rodata_ranges: - if kind == 'Symbol': - (addr, size, _) = symbols[nm] - act_rodata_ranges.append ((addr, addr + size - 1)) - elif kind == 'Section': - if nm in sections: - act_rodata_ranges.append (sections[nm]) - else: - # it's reasonable to supply .rodata as the - # expected section only for it to be missing - trace ('No %r section in objdump.' 
% nm) - else: - assert kind in ['Symbol', 'Section'], rodata_ranges - - comb_ranges = [] - for (start, end) in sorted (act_rodata_ranges): - if comb_ranges and comb_ranges[-1][1] + 1 == start: - (start, _) = comb_ranges[-1] - comb_ranges[-1] = (start, end) - else: - comb_ranges.append ((start, end)) - - rodata = {} - for line in rodata_stream: - if not is_rodata_line.match (line): - continue - bits = line.split () - (addr, v) = (int (bits[0][:-1], 16), int (bits[1], 16)) - if [1 for (start, end) in comb_ranges - if start <= addr and addr <= end]: - assert addr % 4 == 0, addr - rodata[addr] = v - - if len (comb_ranges) == 1: - rodata_names = ['rodata_struct'] - else: - rodata_names = ['rodata_struct_%d' % (i + 1) - for (i, _) in enumerate (comb_ranges)] - - rodata_ptrs = [] - for ((start, end), name) in zip (comb_ranges, rodata_names): - struct_name = fresh_name (name, structs) - struct = Struct (struct_name, (end - start) + 1, 1) - structs[struct_name] = struct - typ = syntax.get_global_wrapper (struct.typ) - rodata_ptrs.append ((mk_word32 (start), typ)) - - return (rodata, comb_ranges, rodata_ptrs) + from syntax import structs, fresh_name, Struct, mk_word64, mk_word32, mk_word16 + import syntax + from target_objects import symbols, sections, trace + + act_rodata_ranges = [] + for (kind, nm) in rodata_ranges: + if kind == 'Symbol': + (addr, size, _) = symbols[nm] + act_rodata_ranges.append ((addr, addr + size - 1)) + elif kind == 'Section': + if nm in sections: + act_rodata_ranges.append (sections[nm]) + else: + # it's reasonable to supply .rodata as the + # expected section only for it to be missing + trace ('No %r section in objdump.' % nm) + else: + assert kind in ['Symbol', 'Section'], rodata_ranges + + comb_ranges = [] + for (start, end) in sorted (act_rodata_ranges): + if comb_ranges and comb_ranges[-1][1] + 1 == start: + (start, _) = comb_ranges[-1] + comb_ranges[-1] = (start, end) + else: + comb_ranges.append ((start, end)) + + rodata = {} + for line in rodata_stream: + if not is_rodata_line.match (line): + continue + bits = line.split () + (addr, v) = (int (bits[0][:-1], 16), int (bits[1], 16)) + if [1 for (start, end) in comb_ranges if start <= addr and addr <= end]: + if len(bits[1]) > 4: + # RISC-V rodata is little-endian + rodata[addr] = v & 0xffff + rodata[addr + 2] = v >> 16 + else: + rodata[addr] = v + + if len (comb_ranges) == 1: + rodata_names = ['rodata_struct'] + else: + rodata_names = ['rodata_struct_%d' % (i + 1) + for (i, _) in enumerate (comb_ranges)] + + rodata_ptrs = [] + for ((start, end), name) in zip (comb_ranges, rodata_names): + struct_name = fresh_name (name, structs) + struct = Struct (struct_name, (end - start) + 1, 1) + structs[struct_name] = struct + typ = syntax.get_global_wrapper (struct.typ) + if syntax.arch.is_64bit: + mk_word = mk_word64 + else: + mk_word = mk_word32 + rodata_ptrs.append ((mk_word(start), typ)) + + return (rodata, comb_ranges, rodata_ptrs) def install_rodata (rodata_stream, rodata_ranges = [('Section', '.rodata')]): - import target_objects - rodata = build_rodata (rodata_stream, rodata_ranges) - target_objects.rodata[:] = rodata + import target_objects + rodata = build_rodata (rodata_stream, rodata_ranges) + target_objects.rodata[:] = rodata # the prunes file isn't really an objdump file, but this seems the best place non_var_re = re.compile('[(),\s\[\]]+') def parse_prunes (prune_stream): - prunes = {} - for l in prune_stream: - [lname, rhs] = l.split ('from [') - bits = lname.split () - assert bits[:3] == ['Pruned', 'inputs', 
'of'] - name = bits[3] - [lvars, rvars] = rhs.split ('] to [') - lvars = [v for v in non_var_re.split (lvars) if v] - rvars = [v for v in non_var_re.split (rvars) if v] - if not (lvars[-2:] == ['dm', 'm'] - and rvars[-2:] == ['dm', 'm']): - continue - lvars = lvars[:-2] - rvars = rvars[:-2] - prunes['DecompiledFuns.' + name + '_step'] = (lvars, rvars) - return prunes + prunes = {} + for l in prune_stream: + [lname, rhs] = l.split ('from [') + bits = lname.split () + assert bits[:3] == ['Pruned', 'inputs', 'of'] + name = bits[3] + [lvars, rvars] = rhs.split ('] to [') + lvars = [v for v in non_var_re.split (lvars) if v] + rvars = [v for v in non_var_re.split (rvars) if v] + if not (lvars[-2:] == ['dm', 'm'] + and rvars[-2:] == ['dm', 'm']): + continue + lvars = lvars[:-2] + rvars = rvars[:-2] + prunes['DecompiledFuns.' + name + '_step'] = (lvars, rvars) + return prunes # likewise the signatures produced by the c-parser def parse_sigs (sig_stream): - sigs = {} - for l in sig_stream: - bits = l.split () - if not bits: - continue - ret = int (bits[0]) - nm = bits[1] - args = [int(bit) for bit in bits[2:]] - sigs[nm] = (args, ret) - return sigs + sigs = {} + for l in sig_stream: + bits = l.split () + if not bits: + continue + ret = int (bits[0]) + nm = bits[1] + args = [int(bit) for bit in bits[2:]] + sigs[nm] = (args, ret) + return sigs diff --git a/parallel_solver.py b/parallel_solver.py new file mode 100644 index 00000000..87fd737b --- /dev/null +++ b/parallel_solver.py @@ -0,0 +1,980 @@ +# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) +# +# SPDX-License-Identifier: BSD-2-Clause + +# You must have the typing module in your Python2.7 search path. +# Use `pip2.7 install typing` to install the module. +from typing import (Any, Dict, IO, List, NamedTuple, Optional, Set, Tuple, Union) + +# will probably change in Python3+, not sure how +from StringIO import StringIO + +# backported from Python3.4, use pip2.7 install enum34 +from enum import Enum + +import copy +import os +import select +import signal +import subprocess + +import typed_commons +from typed_commons import (CMIAbort, CMIContinue, CMIResult, CheckModelIterationState, CheckModelIterationVerdict, Hypothesis, LoggerProtocol, PersistableModel, PrintLogger, SMTExpr, OfflineSolverExecution, SolverContextProtocol, SolverImpl, VariableEnvironment) + +import syntax + +# Implementor's notes: +# +# While proving trace refinement, `graph-refine` tends to send large queries +# to SMT solvers. To guarantee good performance (given CPU cycles to burn), +# these queries are best sent to multiple different solvers that run in +# parallel in 'offline' mode (i.e. one dedicated process for each query), +# so that we can take the result from whichever solver returns first. +# +# Running solvers in parallel allows us to exploit the specific performance +# characteristics of different SMT algorithms/implementations, and even lets +# us test the same problem in both the 8-bit and the 32/64-bit (arch-native) +# memory representations. +# +# The `graph-refine` tool uses a 'model-guided' validation process: heuristics +# picking plausible-looking relations between binary and C nodes, with all +# relevant registers related to available C variables of 'matching' types. +# At first, these relations are unlikely to be valid, but `graph-refine` can +# make a request to an SMT solver to produce a model in which the nodes are +# hit, but the relation we picked fails to hold. 
If the solver indicates that
+# the desired conditions are unsatisfiable ('unsat'), then we have found a
+# valid relation; otherwise, the solver responds with the appropriate model,
+# which narrows the future iterations of the heuristic search substantially,
+# since the model can be used to eliminate more path-condition equalities and
+# variable relations than just the ones that were explicitly tested.
+#
+# Unfortunately, the SMT solvers often return incomplete ("bogus") models,
+# which need to be refined via additional queries: this is the so-called
+# 'model repair' process.
+#
+# The purpose of the parallel solvers mechanism (`ParallelTaskManager`) is to
+# control and coordinate the execution of parallel solvers. It is provided with
+# a list of hypotheses (the 'goals'), and sends them to SMT solvers in parallel
+# in an attempt to either confirm all the goals, or refute one of them. In the
+# latter case, the `ParallelTaskManager` should also ensure that the provided
+# countermodel is not bogus, by executing the model repair procedure if
+# necessary.
+#
+# ---
+#
+# Implementation-specific Glossary:
+#   hypothesis:
+#     A predicate that can be converted to an SMTLIB2-compatible assertion.
+#   goal:
+#     A hypothesis that the current parallel query is supposed to confirm,
+#     and could potentially refute with a counter-model.
+#   execution:
+#     The process of a currently running SMT solver implementation, along
+#     with its associated state (IO streams, filenames, etc.)
+#   task:
+#     A single query sent to a single SMT solver implementation, with the
+#     aim of confirming or refuting a given list of goals.
+#   prove task:
+#     A task that aims to directly confirm or refute a conjunction of
+#     goals by asking an SMT solver to produce a counter-model.
+#   model repair task:
+#     A task that indirectly aims to confirm or refute a conjunction of
+#     goals (of another, specified task), by repairing a 'bogus' counter-
+#     model previously returned by another task.
+#   cancelled (task state):
+#     A task whose associated process was deliberately terminated by the
+#     `ParallelTaskManager` before it could produce a result. Usually,
+#     this happens because the task became redundant during its execution.
+#   failed (task state):
+#     A task whose associated process returned, but without confirming or
+#     refuting any of the goals of the task.
+#   parent:
+#     The solver context that instantiated the `ParallelTaskManager`.
+#
+# Implementation details:
+#
+# `ParallelTaskManager` is built around (and encapsulates) a single piece of
+# mutable state, the 'task pool'. The task pool is a dictionary of tasks
+# indexed by unique task identifiers. The pool contains information about all
+# the tasks that are currently executing, as well as the tasks that have
+# finished executing, and acts as a single source of truth throughout the
+# parallel search process. Any other information required to make decisions
+# about task management can be (and _is_) calculated on demand by inspecting
+# the state of the task pool.
+#
+# Since the pool contains info about all the tasks that have ever executed,
+# the size of the task pool is guaranteed never to decrease throughout the
+# search process (i.e. no information is lost).
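To make the "take the first responder" pattern described above concrete, here is a minimal, self-contained sketch (editorial illustration, not part of the patch, written in the same Python 2 style as the module): it starts several offline solver processes on the same SMTLIB2 query, waits with `select` until one of them produces output, and discards the rest. The helper name `first_responder` and the solver command lines are placeholders; the real mechanism is the `ParallelTaskManager` defined in this file.

import select
import subprocess

def first_responder(solver_cmds, query_text):
    # Launch one process per solver, all working on the same query text.
    procs = []
    for cmd in solver_cmds:
        p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        p.stdin.write(query_text)
        p.stdin.close()
        procs.append(p)
    try:
        while procs:
            # Block until at least one solver has output ready to read.
            ready, _, _ = select.select([p.stdout for p in procs], [], [])
            for p in list(procs):
                if p.stdout not in ready:
                    continue
                answer = p.stdout.readline().strip()
                if answer in ('sat', 'unsat'):
                    return answer
                # 'unknown' or empty output: give up on this solver only.
                p.kill()
                procs.remove(p)
        return 'timeout'
    finally:
        # Kill whatever is still running, including after an early return.
        for p in procs:
            p.kill()

For example, `first_responder([['z3', '-in'], ['cvc4', '--lang', 'smt2']], smt_script)` would return whichever answer arrives first; the exact solver invocations depend on the local installation and are normally taken from the solverlist configuration.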
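The task-pool discipline from the implementation details can likewise be summarised in a few lines. The names below (`MiniTask`, `MiniPool`) are hypothetical simplifications of `ProveTask` and the pool inside `ParallelTaskManager`, offered only as a sketch of the design: tasks are immutable records in a dictionary keyed by integer ids, finishing a task stores a modified copy rather than mutating in place, and summary views are recomputed from the pool on demand.

import collections

MiniTask = collections.namedtuple('MiniTask', ['goals', 'state'])

class MiniPool(object):
    def __init__(self):
        self.tasks = {}      # task id -> MiniTask; only ever grows
        self.next_id = 0

    def add(self, goals):
        tid = self.next_id
        self.next_id += 1
        self.tasks[tid] = MiniTask(tuple(goals), 'running')
        return tid

    def finish(self, tid, outcome):
        # Tasks are immutable: replace the record, never update it in place.
        self.tasks[tid] = self.tasks[tid]._replace(state=outcome)

    def confirmed_goals(self):
        # Derived on demand from the pool, in the spirit of
        # collect_confirmed_goals below.
        return set(g for t in self.tasks.values()
                   if t.state == 'confirmed_all' for g in t.goals)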
+ + +Key = int +KeyedHypothesis = Tuple[Key, Hypothesis] + +SMTResponse = List[str] + + +class TaskStrategy(Enum): + ALL = 0 + HYP = 1 + + +class TaskOutcome(Enum): + confirmed_all = 0 + refuted_some = 1 + sent_for_model_repair = 2 + cancelled = 3 + failed = 4 + + +TaskStateFinished = NamedTuple('TaskStateFinished', [('solver', SolverImpl), + ('filename', str), + ('outcome', TaskOutcome), + ('raw_response', SMTResponse)]) +TaskState = Union[OfflineSolverExecution, TaskStateFinished] +# TaskState.__doc__ = """Stores the state of an offline SMT solver query that +# has finished executing and returned a result. +# +# Attributes: +# solver: The solver implementation that handled the resolved query. +# filename: The name of the file containing the resolved query. +# outcome: The outcome, as determined by interpreting the raw output of the SMT solver. +# raw_response: The full response written by the SMT solver to its standard output, as a list of lines. +# """ + +TaskId = NamedTuple('TaskId', [('to_int', int)]) + +ProveTask = NamedTuple('ProveTask', [('goals', List[KeyedHypothesis]), + ('strategy', TaskStrategy), + ('model', Optional[PersistableModel]), + ('state', TaskState)]) +# ProveTask.__doc__ = """A task that was created to confirm or refute a given conjunction of hypotheses. +# +# The purpose of a `ProveTask` is to confirm or refute a given conjunction of +# goal hypotheses, by sending a query to all eligible solvers in parallel, and +# waiting for at least one of them to respond. +# +# If the task resolves with one of the solver refuting the conjunction, +# a counter-model can be obtained. This counter-model will be built by using +# the solver's output to update the model given in the `model` attribute. +# +# Attributes: +# goals: The given goals whose conjunction is to be confirmed/refuted. +# strategy: If ALL, should be executed by a solver on the 'all' solverlist, otherwise one on the 'hyp' solverlist. +# model: If not `None`, base any potential countermodel on this one. If `None`, do not obtain countermodels. +# state: The execution state of the task. +# """ + +ModelRepairTask = NamedTuple('ModelRepairTask', [('provenance', TaskId), + ('check_model_iteration_result_state', CheckModelIterationState), + ('sent_hypotheses', List[KeyedHypothesis]), + ('model', PersistableModel), + ('state', TaskState)]) +# ModelRepairTask.__doc__ = """A task that was created to repair a counter-model to some hypothesis. +# +# When an SMT solver refutes one of the goals (or a conjunction thereof), it +# provides a countermodel, a variable assignment which makes the relevant +# hypotheses false. Unfortunately, these returned models can be incomplete, +# as determined by a 'bogus model check'. +# +# The purpose of a `ModelRepairTask` is to obtain a more complete countermodel +# to a previously refuted goal. Multiple iterations of the model repair +# procedure (followed by a bogus model check) may be required to fully repair +# a countermodel. +# +# Every `ModelRepairTask` keeps track of its own `provenance`, the finished +# task which triggered the execution of the model repair. This may be a refuted +# `ProveTask` (if the returned model was bogus), or it may be a previous +# `ModelRepairTask` (if its model repair procedure was finished, but the bogus +# model check determined that further repairs are required). +# +# The result of a `ModelRepairTask` is a more complete counter-model. 
This +# counter-model will be built by using the solver's output to update the model +# stored in the `model` attribute of the task. +# +# Attributes: +# provenance: The id of the finished task which triggered this task. +# check_model_iteration_result_state: Saved state of the bogus model check. +# sent_hypotheses: The list of hypotheses sent to the SMT solver. +# model: The counter-model to update. +# """ + +Task = Union[ProveTask, ModelRepairTask] + + +class ParallelTaskManager: + def __init__(self, parent, goals, environment, model, log=PrintLogger("PTM")): + # type: (SolverContextProtocol, List[KeyedHypothesis], VariableEnvironment, Optional[PersistableModel], LoggerProtocol) -> None + self.parent = parent + self.goals = goals + self.environment = environment + self.model = model + self.task_pool = {} # type: Dict[TaskId, Task] + self.log = log + + # wrapped syntax function helps the type-checker + def wrap_syntax_mk_not(self, x): + # (Any) -> Any + return syntax.mk_not(x) + + # wrapped syntax function helps the type-checker + def wrap_syntax_mk_and(self, x, y): + # (Any,Any) -> Any + return syntax.mk_and(x, y) + + # wrapped syntax function helps the type-checker + def wrap_syntax_foldr1(self, f, xs): + # (Any, Any) -> Any + return syntax.foldr1(f, xs) + + def smt_expr_from_goals(self, the_goals): + # type: (List[KeyedHypothesis]) -> SMTExpr + """Returns an assertable SMT expression stating that the given goals can fail to hold. + + The elements constituting the given list of hypotheses are required + to be 'goals', not arbitrary hypotheses (say ones generated as part + of the model repair procedure). This is because we're looking to find + counter-models to the goals, so the returned assertable `SMTExpr` will + state that one of the goals fails to hold. + + Args: + the_goals: The given goals. + """ + hypotheses = [h for (k, h) in the_goals] + # we want counter-models, so we negate everything in sight + gr_hypothesis = self.wrap_syntax_mk_not(self.wrap_syntax_foldr1(self.wrap_syntax_mk_and, hypotheses)) + return self.parent.smt_expr(gr_hypothesis, self.environment) + + def state_with_outcome(self, the_state, with_outcome): + # type: (TaskStateFinished, TaskOutcome) -> TaskStateFinished + """Returns a copy of the given state, with the value of the `outcome` field replaced by the given outcome. + + Note that this does not perform an in-place update: `TaskState`s are + immutable, so we return a modified copy instead. + + Args: + the_state: The given state. + with_outcome: The given outcome (resulting value of the `outcome` field). + """ + new_state = TaskStateFinished(the_state.solver, + the_state.filename, + with_outcome, + the_state.raw_response) + return new_state + + def task_with_state(self, the_task, with_state): + # type: (Task, TaskState) -> Task + """Returns a copy of the given `Task`, with the value of the `state` field replaced by the given state. + + Note that this does not constitute an in-place update: `Task` objects + are immutable, so we return a modified copy instead. + + Args: + the_task: The given task. + the_state: The given state (resulting value of the `state` field). 
+ """ + if isinstance(the_task, ProveTask): + return ProveTask(the_task.goals, + the_task.strategy, + the_task.model, + with_state) + elif isinstance(the_task, ModelRepairTask): + return ModelRepairTask(the_task.provenance, + the_task.check_model_iteration_result_state, + the_task.sent_hypotheses, + the_task.model, + with_state) + raise TypeError('inexhaustive pattern, unknown Task type %r' % type(the_task)) + + def extract_outcome_from_smt_response(self, the_raw_response): + # type: (SMTResponse) -> TaskOutcome + """Returns the outcome of a task by parsing its given SMT response. + + SMTLIB2-compatible solver implementations respond to the `(check-sat)` + command by searching for a model that satisfies all the previously + asserted formulae. Once the search concludes, the solver prints a + one-line response to its output. A 'sat' response indicates that the + solver has found a model, while an 'unsat' response indicates that + the solver has established the non-existence of such models. Other + responses, including 'unknown' or the empty string '' indicate that an + error occurred, or the search was otherwise inconclusive. + + The `extract-*` methods assume that they are reading the response to a + query emitted by `start_execution`: such queries contain a single + `(check-sat)` command, potentially followed by a `(get-model)` command, + and no other output-producing commands. Consequently, we assume that + the response consist of a one-line `(check-sat)` response, potentially + followed by a multi-line `(get-model)` response. + + This function parses the response, and determines the outcome of the + task that emitted the query (with respect to the associated goals). + + Args: + the_raw_response: The given raw response (to be parsed). + """ + if len(the_raw_response) < 1: + self.log.warning('failed to extract outcome, response was empty') + return TaskOutcome.failed + headline = the_raw_response[0].strip() + if headline == 'unsat': + return TaskOutcome.confirmed_all + if headline == 'sat': + return TaskOutcome.refuted_some + self.log.warning("failed to extract outcome, expected 'sat' or 'unsat', got %r" % headline) + return TaskOutcome.failed + + def extract_model_from_smt_response(self, the_raw_response): + # type: (SMTResponse) -> Optional[PersistableModel] + """Returns the model returned by a refuting task, by parsing its given SMT response. + + SMTLIB2-compatible solver implementations respnod to the `(get-model)` + command by printing a list of SMTLIB definitions specifying values of + all (and only) the user-declared function symbols in a previously found + model satisfying all the previously asserted formulae. + (cf. `extract_outcome_from_smt_response`). An error report is printed + if no such model has been found. + + The `extract-*` methods assume that they are reading the response to a + query emitted by `start_execution`: such queries contain a single + `(check-sat)` command, potentially followed by a `(get-model)` command, + and no other output-producing commands. Consequently, we assume that + the response consist of a one-line `(check-sat)` response, potentially + followed by a multi-line `(get-model)` response. + + This function parses the response, and creates a `PersistableModel` + instance mapping the user-declared function symbols to their + definitions. + + Args: + the_raw_response: The given raw response (to be parsed). 
+ """ + if len(the_raw_response) == 0: + self.log.warning('failed to extract model, response was empty') + return None + # fetch_model_response used to operate straight on the stdout of the solver, + # but was called after the first line of the stdout stream was consumed. + # we "fake" this behavior by turning the raw output into a stream. + the_model = PersistableModel({}) + faux_stream = StringIO('\n'.join(the_raw_response[1:])) # type: IO[str] + got_result = self.parent.fetch_model_response(the_model, stream=faux_stream) + if got_result is None: + # model could not be parsed, or was malformed + self.log.warning('failed to extract model, no parse') + return None + return the_model + + def get_task_by_id(self, task_id): + # type: (TaskId) -> Task + """Returns the task currently associated with the given `TaskId` in the task pool. + + Args: + task_id: The given task id. + """ + if task_id not in self.task_pool: + # Note that we avoid `Optional[Task]` for a reason: + # We don't expect invalid `TaskId`s to arise from thin air, and thus + # something must have gone wrong if we found ourselves here. + self.log.error('no task with id %s' % task_id.to_int) + raise KeyError(task_id) + return self.task_pool[task_id] + + def cancel_task_by_id(self, task_id): + # type: (TaskId) -> None + """Cancels the execution of the task in the task pool with the given task id. + + Args: + task_id: The given task id (whose execution is to be cancelled). + """ + the_task = self.get_task_by_id(task_id) + if not isinstance(the_task.state, OfflineSolverExecution): + self.log.error('cannot cancel non-running task of id %s' % task_id.to_int) + raise TypeError('expected OfflineSolverExecution, got %s' % type(the_task.state)) + assert isinstance(the_task.state, OfflineSolverExecution) + new_state = TaskStateFinished(the_task.state.solver, the_task.state.filename, TaskOutcome.cancelled, []) + the_task.state.kill() + self.task_pool[task_id] = self.task_with_state(the_task, new_state) + return + + def _task_id_to_int(self, task_id): + # type: (TaskId) -> int + # workaround for a mypy bug involving 1-element `NamedTuple`s in "for comprehensions" + return task_id.to_int + + def add_task_to_pool(self, the_task): + # type: (Task) -> TaskId + """Adds the given task to the task pool with a fresh task id, returning the task id. + + Note: this is the only way to mint a new task id. + + Args: + the_task: The given task (to be added to the pool). + """ + # we should probably check that we're not adding the same task twice + indices = [self._task_id_to_int(i) for i in self.task_pool.keys()] # type: List[int] + next_index = 0 if len(indices) == 0 else max(indices) + 1 + self.task_pool[TaskId(next_index)] = the_task + return TaskId(next_index) + + def get_goals_by_id(self, task_id): + # type: (TaskId) -> List[KeyedHypothesis] + """Returns the goals that would be confirmed if the task with the given task id was confirmed. + + Args: + task_id: The given task id. + """ + the_task = self.get_task_by_id(task_id) + if isinstance(the_task, ProveTask): + return the_task.goals + elif isinstance(the_task, ModelRepairTask): + return self.get_goals_by_id(the_task.provenance) + raise TypeError('inexhaustive pattern, unknown Task type %r' % type(the_task)) + + def get_strategy_by_id(self, task_id): + # type: (TaskId) -> TaskStrategy + """Returns the task strategy associated with the task with the given task id. 
+ + Note: each solver implementation (`SolverImpl`) has an associated + strategy (ALL or HYP) determined by the solver context and configured + in a solverlist file. This `TaskStrategy` is used to determine which + solvers receive a particular query. Generally, tasks that attempt to + confirm all goals at once use the ALL strategy, while tasks that aim to + confirm or refute a single hypothesis use the HYP strategy. + + Args: + task_id: The given task id. + """ + the_task = self.get_task_by_id(task_id) + if isinstance(the_task, ProveTask): + return the_task.strategy + elif isinstance(the_task, ModelRepairTask): + return self.get_strategy_by_id(the_task.provenance) + raise TypeError('inexhaustive pattern, unknown Task type %r' % type(the_task)) + + def get_solvers_by_strategy(self, the_strategy): + # type: (TaskStrategy) -> List[SolverImpl] + """Returns a list of SMT solver implementations suitable for use with the given task strategy. + + Note: each solver implementation (`SolverImpl`) has an associated + strategy (ALL or HYP) determined by the solver context and configured + in a solverlist file. This `TaskStrategy` is used to determine which + solvers receive a particular query. Generally, tasks that attempt to + confirm all goals at once use the ALL strategy, while tasks that aim to + confirm or refute a single hypothesis use the HYP strategy. + + Args: + the_strategy: The given strategy. + """ + result = [] # type: List[SolverImpl] + if the_strategy == TaskStrategy.ALL: + result = [solver for (solver, strat) in self.parent.get_strategy() if strat == 'all'] + elif the_strategy == TaskStrategy.HYP: + result = [solver for (solver, strat) in self.parent.get_strategy() if strat == 'hyp'] + else: + raise TypeError('inexhaustive pattern, unknown TaskStrategy type %r' % the_strategy) + if len(result) > 0: + return result + self.log.error('no solvers found for strategy %s' % the_strategy) + raise ValueError('no solvers found for strategy %s' % the_strategy) + + def start_execution(self, the_hypotheses, the_model, the_solver): + # type: (List[SMTExpr], Optional[PersistableModel], SolverImpl) -> OfflineSolverExecution + """Executes the given SMT solver implementation with the given hypotheses, returning an `OfflineSolverExecution` instance containing the resulting process and script file. + + Args: + the_hypotheses: The given hypotheses (to be sent to the solver). + the_model: If not None, generate a request to fetch model values. + the_solver: The given SMT solver implementation (which is to execute the query). + """ + smt_cmds = [] # type: List[str] + for h in the_hypotheses: + smt_cmds.append('(assert %s)' % str(h)) + smt_cmds.append('(check-sat)') + if the_model is not None: + smt_cmds.append(self.parent.fetch_model_request()) + return self.parent.exec_slow_solver(smt_cmds, timeout=the_solver.timeout, solver=the_solver) + + def start_prove_task_with_solver(self, goals, strategy, model, the_solver): + # type: (List[KeyedHypothesis], TaskStrategy, Optional[PersistableModel], SolverImpl) -> TaskId + """Start a task with the aim of confirming or refuting the given goals, using the given SMT solver implementation. + + Args: + goals: The given goals (to be confirmed or refuted). + strategy: The given task strategy. + model: The base model. If not `None`, generate a request to fetch values from the resulting counter-model. + the_solver: The given solver (to be used to settle the query). 
+ """ + smt_hypothesis = self.smt_expr_from_goals(goals) + model_copy = None # type: Optional[PersistableModel] + if model is not None: + model_copy = model.copy() + execution = self.start_execution([smt_hypothesis], model_copy, the_solver) + new_task = ProveTask(goals, strategy, model_copy, execution) + return self.add_task_to_pool(new_task) + + def start_prove_task(self, goals, strategy): + # type: (List[KeyedHypothesis], TaskStrategy) -> List[TaskId] + """Start parallel tasks with the aim of confirming or refuting the given goals, using all solvers configured for the given strategy. + + Args: + goals: The given goals (to be confirmed or refuted). + strategy: The given strategy (used to choose the solvers). + """ + if len(goals) == 0: + self.log.error('attempted to prove zero goals') + raise ValueError('list of goals must be non-empty') + self.log.info('starting prove task for %s goal(s)' % len(goals)) + solvers = self.get_solvers_by_strategy(strategy) + return [self.start_prove_task_with_solver(goals, strategy, self.model, solver) for solver in solvers] + + def start_model_repair_task(self, original_task_id, original_task_model, check_model_iteration_result): + # type: (TaskId, PersistableModel, CMIContinue) -> TaskId + """Start a task with the aim of repairing a given bogus model returned by a previously executed task. + + Args: + original_task_id: The task id of the previously executed task which returned the bogus model. + original_task_model: The given bogus model (to be repaired). + check_model_iteration_result: The result of the bogus model check which triggered the start of this model repair task. + """ + self.log.info('starting model repair task for %s hypotheses' % len(check_model_iteration_result.next_hypotheses)) + new_task_model = original_task_model.copy() + execution = self.start_execution(check_model_iteration_result.next_hypotheses, + new_task_model, + check_model_iteration_result.next_solver) + new_task = ModelRepairTask(provenance=original_task_id, + check_model_iteration_result_state=check_model_iteration_result.state, + sent_hypotheses=check_model_iteration_result.next_hypotheses, + model=new_task_model, + state=execution) + return self.add_task_to_pool(new_task) + + def restart_model_repair_task_change_solver(self, task_id): + # type: (TaskId) -> Optional[TaskId] + """Restarts the failed model repair task with the given id using the next available solver. + + Some SMT solver implementations may enter various error states, such as + segfaults, while performing the model repair process on a model emitted + by a prove task. Since model repairs are not parallel, in the sense + that only one solver works on repairing any given model at any point in + time, such failures may cause the search to stall (with `timeout`). + + This function allows us to restart failed model repair tasks with the + next available solver, skipping the failing solver, and ensuring that + we do not prematurely abandon a model repair attempt due to the failure + of the solver we tried first. + + Args: + task_id: The given task id (of the task to be restarted). + """ + # Uses `TaskId` instead of `ModelRepairTask` since we need to keep + # track of provenance. Fixes the "segfaulting SONOLAR" bug. 
+ the_task = self.get_task_by_id(task_id) + if not isinstance(the_task, ModelRepairTask): + self.log.info('no need to change solver, given task is not ModelRepair') + return None + assert isinstance(the_task, ModelRepairTask) + if not isinstance(the_task.state, TaskStateFinished): + self.log.info('no need to change solver, given task is still running') + return None + assert isinstance(the_task.state, TaskStateFinished) + if the_task.state.outcome != TaskOutcome.failed: + self.log.info('no need to change solver, given task did not fail') + return None + assert the_task.state.outcome == TaskOutcome.failed + the_cmi_state = the_task.check_model_iteration_result_state + next_solvers = the_cmi_state.remaining_solvers # type: List[SolverImpl] + if len(next_solvers) == 0: + self.log.warning('cannot change solver, no further solvers to try') + return None + self.log.info('restarting task with next solver') + new_candidate_model = the_cmi_state.candidate_model.copy() if the_cmi_state.candidate_model is not None else None + new_cmi_state = CheckModelIterationState(the_cmi_state.confirmed_hypotheses, + the_cmi_state.test_hypotheses, + new_candidate_model, + next_solvers[1:]) + new_task_model = the_task.model.copy() + execution = self.start_execution(the_task.sent_hypotheses, + new_task_model, + next_solvers[0]) + new_task = ModelRepairTask(provenance=task_id, + check_model_iteration_result_state=new_cmi_state, + sent_hypotheses=the_task.sent_hypotheses, + model=new_task_model, + state=execution) + return self.add_task_to_pool(new_task) + + def perform_bogus_model_check(self, task_id): + # type: (TaskId) -> None + """Checks if the task with the given id has produced a bogus or invalid counter-model, and if so, initiates model repair tasks. + + When a prove task ends with the refutation of a given goal, we may ask + for a counter-model. The counter-models provided by the SMT solvers + often end up incomplete ('bogus'): if so, additional model repair tasks + have to be executed to refine them into complete models usable by the + 'model-guided' search heuristics of `graph-refine`. + + This function is responsible for checking if the task with the given id + has produced such a bogus model. If so, it starts the appropriate model + repair tasks. Note that a model repair task may itself produce another + bogus model, in which case another model repair task may be started. + + In some cases, the bogus model check can indicate that further repairs + to the current bogus model are futile and won't result in a usable + counter-model. If so, the model repair task is considered a failure and + is abandoned. + + Args: + task_id: The given task id. 
+ """ + the_task = self.get_task_by_id(task_id) + if the_task.model is None: + # we won't be returning the model, so we can skip this check + self.log.info('check elided, no model will be returned') + return + assert the_task.model is not None + if not isinstance(the_task.state, TaskStateFinished): + self.log.info('check elided, given task is still running') + return + assert isinstance(the_task.state, TaskStateFinished) + if the_task.state.outcome != TaskOutcome.refuted_some: + if isinstance(the_task, ProveTask): + # on ModelRepair, we abort later due to inconsistent sat/unsat + self.log.info('check elided, given task did not refute any hypotheses') + return + if isinstance(the_task, ModelRepairTask) and the_task.state.outcome == TaskOutcome.failed: + self.log.warning('given task may have failed due to faulty solver, changing solver') + self.restart_model_repair_task_change_solver(task_id) + response_model = self.extract_model_from_smt_response(the_task.state.raw_response) + if response_model is None: + # The SMT solver claims to have refuted some of our hypotheses, but it did not provide a counter-model + # that we could parse. In accordance with the previous parallel solvers mechanism, we treat this as a + # failed (erroneous) SMT query. + self.log.warning('counter-model not found or could not be parsed, SMT query failed') + the_task.model.persist() + finished_state = self.state_with_outcome(the_task.state, TaskOutcome.failed) + self.task_pool[task_id] = self.task_with_state(the_task, finished_state) + return + assert response_model is not None + smt_hypothesis = self.smt_expr_from_goals(self.get_goals_by_id(task_id)) + state = the_task.check_model_iteration_result_state if isinstance(the_task, ModelRepairTask) else None + response_line = the_task.state.raw_response[0].strip() if len(the_task.state.raw_response) > 0 else 'unknown' + cmi_verdict = self.parent.check_model_iteration([smt_hypothesis], state, (response_line, response_model)) + if isinstance(cmi_verdict, CMIAbort): + # The model is still bogus (incomplete), but the check suggests that no further repair is possible: + # we have either tried all the solvers, or got inconsistent results. In accordance with the previous + # parallel solvers mechanism, we treat this as a failed SMT query. + self.log.info('model is incomplete, irreparable') + the_task.model.persist() + finished_state = self.state_with_outcome(the_task.state, TaskOutcome.failed) + self.task_pool[task_id] = self.task_with_state(the_task, finished_state) + return + elif isinstance(cmi_verdict, CMIContinue): + # The model is still bogus (incomplete), but we might be able to repair it by spawning a ModelRepair task. + self.log.info('model is incomplete, requires repair') + the_task.model.persist() + finished_state = self.state_with_outcome(the_task.state, TaskOutcome.sent_for_model_repair) + self.task_pool[task_id] = self.task_with_state(the_task, finished_state) + self.start_model_repair_task(task_id, the_task.model, cmi_verdict) + return + elif isinstance(cmi_verdict, CMIResult): + # We passed the check, and the model is complete. We return it. + self.log.info('model is complete, check passed') + the_task.model.update(cmi_verdict.candidate_model) + the_task.model.persist() + return + raise TypeError('inexhaustive pattern, unknown CheckModelIterationVerdict type %r' % type(cmi_verdict)) + + def handle_progress(self, task_id): + # type: (TaskId) -> None + """Updates the task pool according to the result of the finished task with the given task id. 
+ + Whenever an SMT solver implementation finishes its execution on a task, + the task outcome has to be determined, and the task pool has to be + updated to reflect the changed status of the task. Moreover, if the + SMT solver produced a refutation, other, auxiliary tasks such as model + repairs may have to be performed. This function performs the required + updates and starts the required auxiliary tasks. + + Args: + task_id: The given task id (whose progress is to be assessed). + """ + the_task = self.get_task_by_id(task_id) + original_state = the_task.state + if not isinstance(original_state, OfflineSolverExecution): + self.log.error('unable to read output, given task not running') + raise TypeError('expected OfflineSolverExecution, got %s' % type(original_state)) + assert isinstance(original_state, OfflineSolverExecution) + out, err = original_state.process.communicate() + the_raw_response = out.splitlines() + original_state.kill() + the_outcome = self.extract_outcome_from_smt_response(the_raw_response) # type: TaskOutcome + finished_state = TaskStateFinished(original_state.solver, original_state.filename, the_outcome, the_raw_response) + self.task_pool[task_id] = self.task_with_state(the_task, finished_state) + self.perform_bogus_model_check(task_id) + return + + def wait_for_progress(self): + # type: () -> List[TaskId] + """Synchronously monitors the currently running tasks until one or more of them make progrss, then returns a list of task ids that made progress. + + This function monitors the file descriptors of the currently running + tasks (as given in the task pool), waiting until one or more of the + descriptors become 'ready to read'. The task ids of the tasks that + are associated to the ready to read file descriptors are returned. + """ + running_tasks = [(task.state, id) for (id, task) in self.task_pool.iteritems() if isinstance(task.state, OfflineSolverExecution)] # type: List[Tuple[OfflineSolverExecution, TaskId]] + running_executions = [state for (state, id) in running_tasks] # type: List[OfflineSolverExecution] + # synchronously wait for IO completion using select: + # eventually returns a list of ready-to-read OfflineSolverExecution objects + (ready_to_read, _, _) = select.select(running_executions, [], []) + task_id_by_execution = dict(running_tasks) + return [task_id_by_execution[execution] for execution in ready_to_read] + + def collect_explicit_refutations(self): + # type: () -> List[TaskId] + """Returns the list of all current explicit refutations in the task pool. + + A task is considered a 'refutation' if it finished with one or more of + the associated goals (see above for the distinction between goals and + other hypotheses) being refuted. + + If a refutation has multiple associated goals, then all we know is that + we have a counter-model that refutes one of the goal hypotheses, but + we may not know precisely which hypotheses were refuted. Currently, + `graph-refine` is not equipped to figure out the precise identity of + the refuted hypotheses. Hence, we consider a refutation 'explicit' when + it has exactly one associated goal. 
+ """ + refutations = [] # List[TaskId] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, TaskStateFinished): + is_refutation = task.state.outcome == TaskOutcome.refuted_some + is_explicit = len(self.get_goals_by_id(id)) == 1 + if is_refutation and is_explicit: + refutations.append(id) + return refutations + + def collect_confirmed_goals(self): + # type: () -> Set[KeyedHypothesis] + """Returns the set of all current goals that have been confirmed by any task in the task pool.""" + confirmed_goals = [] # type: List[KeyedHypothesis] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, TaskStateFinished) and task.state.outcome == TaskOutcome.confirmed_all: + confirmed_goals.extend(self.get_goals_by_id(id)) + return set(confirmed_goals) + + def collect_explicit_refuted_goals(self): + # type: () -> Set[KeyedHypothesis] + """Returns the set of all current goals that have been explicitly refuted by some task in the task pool. + + See `collect_explicit_refutations` for the "refuted / explicit refuted" + distinction. + """ + explicit_refuted_goals = [] # List[TaskId] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, TaskStateFinished): + is_refutation = task.state.outcome == TaskOutcome.refuted_some + task_goals = self.get_goals_by_id(id) + is_explicit = len(task_goals) == 1 + if is_refutation and is_explicit: + explicit_refuted_goals.extend(task_goals) + return set(explicit_refuted_goals) + + def collect_running_goals(self): + # type: () -> Set[KeyedHypothesis] + """Returns the set of all current goals that are associated to a currently running task in the task pool.""" + running_goals = [] # type: List[KeyedHypothesis] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, OfflineSolverExecution): + running_goals.extend(self.get_goals_by_id(id)) + return set(running_goals) + + def collect_explicit_running_goals(self): + # type: () -> Set[KeyedHypothesis] + """Returns the set of all current goals that are associated to a currently running explicit task in the task pool. + + A task is considered 'explicit' if it has exactly one associated goal. + See `collect_explicit_refutations` for more on the "refuted / explicit + refuted" distinction. + """ + explicit_running_goals = [] # type: List[KeyedHypothesis] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, OfflineSolverExecution) and self.get_strategy_by_id(id) == TaskStrategy.HYP: + explicit_running_goals.extend(self.get_goals_by_id(id)) + return set(explicit_running_goals) + + def collect_explicit_failed_goals(self): + # type: () -> Set[KeyedHypothesis] + """Returns the set of all current goals whose associated explicit tasks have all finished and failed to produce a result. + + A task is considered 'explicit' if it has exactly one associated goal. + See `collect_explicit_refutations` for more on the "refuted / explicit + refuted" distinction. + + A goal that is not associated to any (running or finished) task is not + considered failed: we consider a task failed only if we have started + _explicit_ attempts to confirm or refute it, but all such attempts have + ended in failure. 
+ """ + confirmed_goals = self.collect_confirmed_goals() + explicit_refuted_goals = self.collect_explicit_refuted_goals() + running_goals = self.collect_explicit_running_goals() + excluded_goals = confirmed_goals.union(explicit_refuted_goals).union(running_goals) + + explicit_failed_goals = [] # type: List[KeyedHypothesis] + for (id, task) in self.task_pool.iteritems(): + if isinstance(task.state, TaskStateFinished): + is_failed = task.state.outcome == TaskOutcome.failed + task_goals = set(self.get_goals_by_id(id)) # type: Set[KeyedHypothesis] + is_explicit = len(task_goals) == 1 + if is_explicit and is_failed and task_goals.isdisjoint(excluded_goals): + explicit_failed_goals.extend(task_goals) + return set(explicit_failed_goals) + + def collect_startable_goals(self): + # type: () -> List[KeyedHypothesis] + """Returns the set of all current goals for which an explicit task can still be started.""" + settled_goals = self.collect_confirmed_goals().union(self.collect_explicit_refuted_goals()).union(self.collect_explicit_failed_goals()) + return [goal for goal in self.goals if goal not in settled_goals] + + def cancel_redundant_tasks(self): + # type: () -> None + """Cancels all currently running redundant tasks in the task pool. + + A task is considered redundant if either + a.) all its associated goals have already been confirmed; or + b.) one of its associated goals has already been explicitly refuted; + as determined by inspecting the task pool. + + Note that redundant tasks can always be safely cancelled: letting them + run to completion would not increase the number of confirmed or + explicit refuted hypotheses. + """ + confirmed_goals = self.collect_confirmed_goals() + refuted_goals = self.collect_explicit_refuted_goals() + running_tasks = [id for (id, task) in self.task_pool.iteritems() if isinstance(task.state, OfflineSolverExecution)] + for task_id in running_tasks: + task_goals = set(self.get_goals_by_id(task_id)) + if not task_goals.isdisjoint(refuted_goals): + self.log.info('task %s redundant, one of its goals is already refuted' % task_id.to_int) + self.cancel_task_by_id(task_id) + if task_goals.issubset(confirmed_goals): + self.log.info('task %s redundant, all of its goals are already confirmed' % task_id.to_int) + self.cancel_task_by_id(task_id) + + def cancel_all_tasks(self): + # type: () -> None + """Cancels all currently running tasks.""" + running_tasks = [id for (id, task) in self.task_pool.iteritems() if isinstance(task.state, OfflineSolverExecution)] + for task_id in running_tasks: + self.cancel_task_by_id(task_id) + return + + def start_next_explicit_goal(self): + # type: () -> List[TaskId] + """Starts an explicit prove task with the next startable goal (if there are any) as its only goal. 
Returns the list of tasks it started.""" + explicit_running_goals = self.collect_explicit_running_goals() + if len(explicit_running_goals) > 0: + self.log.info('cannot start new goal, other explicit goals still running') + return [] + startable_goals = self.collect_startable_goals() + if len(startable_goals) == 0: + self.log.warning('cannot start new goal, no further goals to start') + return [] + goal = startable_goals[0] + return self.start_prove_task([goal], TaskStrategy.HYP) + + def print_task_pool(self): + # type: () -> None + """Prints a human-readable summary of the current state of the task pool.""" + self.log.raw_print('\ncurrent task pool:\n--- [') + for (id, task) in self.task_pool.iteritems(): + outcome = 'running' # type: str + if isinstance(task.state, TaskStateFinished): + outcome = '%s %s on %s' % (str(task.state.outcome)[12:], task.state.solver.origname, task.state.filename) + elif isinstance(task.state, OfflineSolverExecution): + outcome = str(task.state) + outcome = outcome[:87] + task_type = ' ' + if isinstance(task, ModelRepairTask): + task_type = 'MR%s' % task.provenance.to_int + elif isinstance(task, ProveTask): + task_type = '%s' % (str(task.strategy)[13:]) + self.log.raw_print('%s %s %s' % (id.to_int, task_type, outcome)) + self.log.raw_print('--- ]\n') + + def main_loop(self): + # type: () -> Tuple[str, Optional[Key], Optional[PersistableModel]] + """Repeatedly performs task management operations until a termination condition is reached. Returns the final verdict. + + This function attempts to confirm all goals passed to the + `ParallelTaskManager`. The loop monitors the task pool, and attempts to + start new explicit prove tasks to settle as-yet-unconfirmed hypotheses. + + The loop terminates once all goals are confirmed; some goal is + explicitly refuted; or there are no further tasks left to start. + + The final verdict is returned in the non-typechecked format expected by + `parallel_check_hyps`. 
+ """ + if len(self.goals) > 1: + self.start_prove_task(self.goals, TaskStrategy.ALL) + target_goals = set(self.goals) + + while True: + self.print_task_pool() + confirmed_goals = self.collect_confirmed_goals() + refutations = self.collect_explicit_refutations() + running_goals = self.collect_running_goals() + explicit_running_goals = self.collect_explicit_running_goals() + + self.log.info('checking termination conditions...') + if len(refutations) > 0: + # we have explicitly refuted some hypothesis + refuted_task_id = refutations[0] # type: TaskId + refuted_model = self.get_task_by_id(refuted_task_id).model + refuted_goals = self.get_goals_by_id(refuted_task_id) # type: List[KeyedHypothesis] + self.log.info('termination condition reached, task %s refuted a hypothesis' % refuted_task_id.to_int) + self.cancel_all_tasks() + assert len(refuted_goals) == 1 + return ('sat', refuted_goals[0][0], refuted_model) + if target_goals.issubset(confirmed_goals): + # we have confirmed all hypotheses + self.log.info('termination condition reached, all hypotheses confirmed') + self.cancel_all_tasks() + return ('unsat', None, None) + if len(explicit_running_goals) == 0: + # no explicit tasks, try to start a new one + started_tasks = self.start_next_explicit_goal() + if len(started_tasks) == 0 and len(running_goals) == 0: + # there were no tasks left to start + self.log.warning('termination condition reached, all solvers failed or timed out') + return ('timeout', None, None) + self.log.info('termination conditions not satisfied, continuing') + + self.log.info('waiting for SMT solvers to make progress...') + progressed_tasks = self.wait_for_progress() # type: List[TaskId] + self.log.info('%s solver(s) made progress' % len(progressed_tasks)) + for task_id in progressed_tasks: + self.log.info('handling progress on task %s' % task_id.to_int) + self.handle_progress(task_id) + self.log.info('handled progress from %s solver(s)' % len(progressed_tasks)) + + self.cancel_redundant_tasks() + continue + + def run(self): + # type: () -> Tuple[str, Optional[Key], Optional[PersistableModel]] + """Repeatedly performs task management operations on the task pool until all goals are confirmed, or some goal is explicitly refuted. Returns the final verdict. + + The final verdict is returned in the non-typechecked format expected by + `parallel_check_hyps`. 
+ """ + self.log.info('started with %s goals' % len(self.goals)) + try: + verdict = self.main_loop() + except KeyboardInterrupt, e: + self.log.error('interrupted by user') + self.cancel_all_tasks() + raise e + self.print_task_pool() + self.log.info('finished') + return verdict diff --git a/problem.py b/problem.py index cb19bda6..8ae5549b 100644 --- a/problem.py +++ b/problem.py @@ -5,7 +5,7 @@ # from syntax import (Expr, mk_var, Node, true_term, false_term, - fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs) + fresh_name, word32T, word8T, mk_eq, mk_word32, builtinTs) import syntax from target_objects import functions, pairings, trace, printout @@ -14,827 +14,824 @@ from logic import azip class Abort(Exception): - pass + pass last_problem = [None] class Problem: - def __init__ (self, pairing, name = None): - if name == None: - name = pairing.name - self.name = 'Problem (%s)' % name - self.pairing = pairing - - self.nodes = {} - self.vs = {} - self.next_node_name = 1 - self.preds = {} - self.loop_data = {} - self.node_tags = {} - self.node_tag_revs = {} - self.inline_scripts = {} - self.entries = [] - self.outputs = {} - self.tarjan_order = [] - self.loop_var_analysis_cache = {} - - self.known_eqs = {} - self.cached_analysis = {} - self.hook_tag_hints = {} - - last_problem[0] = self - - def fail_msg (self): - return 'FAILED %s (size %05d)' % (self.name, len(self.nodes)) - - def alloc_node (self, tag, detail, loop_id = None, hint = None): - name = self.next_node_name - self.next_node_name = name + 1 - - self.node_tags[name] = (tag, detail) - self.node_tag_revs.setdefault ((tag, detail), []) - self.node_tag_revs[(tag, detail)].append (name) - - if loop_id != None: - self.loop_data[name] = ('Mem', loop_id) - - return name - - def fresh_var (self, name, typ): - name = fresh_name (name, self.vs, typ) - return mk_var (name, typ) - - def clone_function (self, fun, tag): - self.nodes = {} - self.vs = syntax.get_vars (fun) - for n in fun.reachable_nodes (): - self.nodes[n] = fun.nodes[n] - detail = (fun.name, n) - self.node_tags[n] = (tag, detail) - self.node_tag_revs.setdefault ((tag, detail), []) - self.node_tag_revs[(tag, detail)].append (n) - self.outputs[tag] = fun.outputs - self.entries = [(fun.entry, tag, fun.name, fun.inputs)] - self.next_node_name = max (self.nodes.keys () + [2]) + 1 - self.inline_scripts[tag] = [] - - def add_function (self, fun, tag, node_renames, loop_id = None): - if not fun.entry: - printout ('Aborting %s: underspecified %s' % ( - self.name, fun.name)) - raise Abort () - node_renames.setdefault('Ret', 'Ret') - node_renames.setdefault('Err', 'Err') - new_node_renames = {} - vs = syntax.get_vars (fun) - vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs]) - ns = fun.reachable_nodes () - check_no_symbols ([fun.nodes[n] for n in ns]) - for n in ns: - assert n not in node_renames - node_renames[n] = self.alloc_node (tag, (fun.name, n), - loop_id = loop_id, hint = n) - new_node_renames[n] = node_renames[n] - for n in ns: - self.nodes[node_renames[n]] = syntax.copy_rename ( - fun.nodes[n], (vs, node_renames)) - - return (new_node_renames, vs) - - def add_entry_function (self, fun, tag): - (ns, vs) = self.add_function (fun, tag, {}) - - entry = ns[fun.entry] - args = [(vs[v], typ) for (v, typ) in fun.inputs] - rets = [(vs[v], typ) for (v, typ) in fun.outputs] - self.entries.append((entry, tag, fun.name, args)) - self.outputs[tag] = rets - - self.inline_scripts[tag] = [] - - return (args, rets, entry) - - def get_entry_details (self, tag): - [(e, t, fname, 
args)] = [(e, t, fname, args) - for (e, t, fname, args) in self.entries if t == tag] - return (e, fname, args) - - def get_entry (self, tag): - (e, fname, args) = self.get_entry_details (tag) - return e - - def tags (self): - return self.outputs.keys () - - def entry_exit_renames (self, tags = None): - """computes the rename set of a function's formal parameters - to the actual input/output variable names at the various entry - and exit points""" - mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in - azip (xs, ys)]) - renames = {} - if tags == None: - tags = self.tags () - for tag in tags: - (_, fname, args) = self.get_entry_details (tag) - fun = functions[fname] - out = self.outputs[tag] - renames[tag + '_IN'] = mk (fun.inputs, args) - renames[tag + '_OUT'] = mk (fun.outputs, out) - return renames - - def redirect_conts (self, reds): - for node in self.nodes.itervalues(): - if node.kind == 'Cond': - node.left = reds.get(node.left, node.left) - node.right = reds.get(node.right, node.right) - else: - node.cont = reds.get(node.cont, node.cont) - - def do_analysis (self): - self.cached_analysis.clear () - self.compute_preds () - self.do_loop_analysis () - - def mk_node_graph (self, node_subset = None): - if node_subset == None: - node_subset = self.nodes - return dict ([(n, [c for c in self.nodes[n].get_conts () - if c in node_subset]) - for n in node_subset]) - - def do_loop_analysis (self): - entries = [e for (e, tag, nm, args) in self.entries] - self.loop_data = {} - - graph = self.mk_node_graph () - comps = logic.tarjan (graph, entries) - self.tarjan_order = [] - - for (head, tail) in comps: - self.tarjan_order.append (head) - self.tarjan_order.extend (tail) - if not tail and head not in graph[head]: - continue - trace ('Loop (%d, %s)' % (head, tail)) - - loop_set = set (tail) - loop_set.add (head) - - r = self.force_single_loop_return (head, loop_set) - if r != None: - tail.append (r) - loop_set.add (r) - self.tarjan_order.append (r) - self.compute_preds () - - self.loop_data[head] = ('Head', loop_set) - for t in tail: - self.loop_data[t] = ('Mem', head) - - # put this in first-to-last order. - self.tarjan_order.reverse () - - def check_no_inner_loops (self): - for loop in self.loop_heads (): - check_no_inner_loop (self, loop) - - def force_single_loop_return (self, head, loop_set): - rets = [n for n in self.preds[head] if n in loop_set] - if (len (rets) == 1 and rets[0] != head and - self.nodes[rets[0]].is_noop ()): - return None - r = self.alloc_node (self.node_tags[head][0], - 'LoopReturn', loop_id = head) - self.nodes[r] = Node ('Basic', head, []) - for r2 in rets: - self.nodes[r2] = syntax.copy_rename (self.nodes[r2], - ({}, {head: r})) - return r - - def splittable_points (self, n): - """splittable points are points which when removed, the loop - 'splits' and ceases to be a loop. 
- - equivalently, the set of splittable points is the intersection - of all sub-loops of the loop.""" - head = self.loop_id (n) - assert head != None - k = ('Splittables', head) - if k in self.cached_analysis: - return self.cached_analysis[k] - - # check if the head point is a split (the inner loop - # check does exactly that) - if has_inner_loop (self, head): - head = logic.get_one_loop_splittable (self, - self.loop_body (head)) - if head == None: - return set () - - splits = self.get_loop_splittables (head) - self.cached_analysis[k] = splits - return splits - - def get_loop_splittables (self, head): - loop_set = self.loop_body (head) - splittable = dict ([(n, False) for n in loop_set]) - arc = [head] - n = head - while True: - ns = [n2 for n2 in self.nodes[n].get_conts () - if n2 in loop_set] - ns2 = [x for x in ns if x == head or x not in arc] - #n = ns[0] - n = ns2[0] - arc.append (n) - splittable[n] = True - if n == head: - break - last_descs = {} - for i in range (len (arc)): - last_descs[arc[i]] = i - def last_desc (n): - if n in last_descs: - return last_descs[n] - n2s = [n2 for n2 in self.nodes[n].get_conts() - if n2 in loop_set] - last_descs[n] = None - for n2 in n2s: - x = last_desc(n2) - if last_descs[n] == None or x >= last_descs[n]: - last_descs[n] = x - return last_descs[n] - for i in range (len (arc)): - max_arc = max ([last_desc (n) - for n in self.nodes[arc[i]].get_conts () - if n in loop_set]) - for j in range (i + 1, max_arc): - splittable[arc[j]] = False - return set ([n for n in splittable if splittable[n]]) - - def loop_heads (self): - return [n for n in self.loop_data - if self.loop_data[n][0] == 'Head'] - - def loop_id (self, n): - if n not in self.loop_data: - return None - elif self.loop_data[n][0] == 'Head': - return n - else: - assert self.loop_data[n][0] == 'Mem' - return self.loop_data[n][1] - - def loop_body (self, n): - head = self.loop_id (n) - return self.loop_data[head][1] - - def compute_preds (self): - self.preds = logic.compute_preds (self.nodes) - - def var_dep_outputs (self, n): - return self.outputs[self.node_tags[n][0]] - - def compute_var_dependencies (self): - if 'var_dependencies' in self.cached_analysis: - return self.cached_analysis['var_dependencies'] - var_deps = logic.compute_var_deps (self.nodes, - self.var_dep_outputs, self.preds) - var_deps2 = dict ([(n, dict ([(v, None) - for v in var_deps.get (n, [])])) - for n in self.nodes]) - self.cached_analysis['var_dependencies'] = var_deps2 - return var_deps2 - - def get_loop_var_analysis (self, var_deps, n): - head = self.loop_id (n) - assert head, n - assert n in self.splittable_points (n) - loop_sort = tuple (sorted (self.loop_body (head))) - node_data = [(self.nodes[n2], sorted (self.preds[n]), - sorted (var_deps[n2].keys ())) - for n2 in loop_sort] - k = (n, loop_sort) - data = (node_data, n) - if k in self.loop_var_analysis_cache: - for (data2, va) in self.loop_var_analysis_cache[k]: - if data2 == data: - return va - va = logic.compute_loop_var_analysis (self, var_deps, n) - group = self.loop_var_analysis_cache.setdefault (k, []) - group.append ((data, va)) - del group[:-10] - return va - - def save_graph (self, fname): - cols = mk_graph_cols (self.node_tags) - save_graph (self.nodes, fname, cols = cols, - node_tags = self.node_tags) - - def save_graph_summ (self, fname): - node_ids = {} - def is_triv (n): - if n not in self.nodes: - return False - if len (self.preds[n]) != 1: - return False - node = self.nodes[n] - if node.kind == 'Basic': - return (True, node.cont) - elif node.kind == 
'Cond' and node.right == 'Err': - return (True, node.left) - else: - return False - for n in self.nodes: - if n in node_ids: - continue - ns = [] - while is_triv (n): - ns.append (n) - n = is_triv (n)[1] - for n2 in ns: - node_ids[n2] = n - nodes = {} - for n in self.nodes: - if is_triv (n): - continue - nodes[n] = syntax.copy_rename (self.nodes[n], - ({}, node_ids)) - cols = mk_graph_cols (self.node_tags) - save_graph (nodes, fname, cols = cols, - node_tags = self.node_tags) - - def serialise (self): - ss = ['Problem'] - for (n, tag, fname, inputs) in self.entries: - xs = ['Entry', '%d' % n, tag, fname, - '%d' % len (inputs)] - for (nm, typ) in inputs: - xs.append (nm) - typ.serialise (xs) - xs.append ('%d' % len (self.outputs[tag])) - for (nm, typ) in self.outputs[tag]: - xs.append (nm) - typ.serialise (xs) - ss.append (' '.join (xs)) - for n in self.nodes: - xs = ['%d' % n] - self.nodes[n].serialise (xs) - ss.append (' '.join (xs)) - ss.append ('EndProblem') - return ss - - def save_serialise (self, fname): - ss = self.serialise () - f = open (fname, 'w') - for s in ss: - f.write (s + '\n') - f.close () - - def pad_merge_points (self): - self.compute_preds () - - arcs = [(pred, n) for n in self.preds - if len (self.preds[n]) > 1 - if n in self.nodes - for pred in self.preds[n] - if (self.nodes[pred].kind != 'Basic' - or self.nodes[pred].upds != [])] - - for (pred, n) in arcs: - (tag, _) = self.node_tags[pred] - name = self.alloc_node (tag, 'MergePadding') - self.nodes[name] = Node ('Basic', n, []) - self.nodes[pred] = syntax.copy_rename (self.nodes[pred], - ({}, {n: name})) - - def function_call_addrs (self): - return [(n, self.nodes[n].fname) - for n in self.nodes if self.nodes[n].kind == 'Call'] - - def function_calls (self): - return set ([fn for (n, fn) in self.function_call_addrs ()]) - - def get_extensions (self): - if 'extensions' in self.cached_analysis: - return self.cached_analysis['extensions'] - extensions = set () - for node in self.nodes.itervalues (): - extensions.update (syntax.get_extensions (node)) - self.cached_analysis['extensions'] = extensions - return extensions - - def replay_inline_script (self, tag, script): - for (detail, idx, fname) in script: - n = self.node_tag_revs[(tag, detail)][idx] - assert self.nodes[n].kind == 'Call', self.nodes[n] - assert self.nodes[n].fname == fname, self.nodes[n] - inline_at_point (self, n, do_analysis = False) - if script: - self.do_analysis () - - def is_reachable_from (self, source, target): - '''discover if graph addr "target" is reachable - from starting node "source"''' - k = ('is_reachable_from', source) - if k in self.cached_analysis: - reachable = self.cached_analysis[k] - if target in reachable: - return reachable[target] - - reachable = {} - visit = [source] - while visit: - n = visit.pop () - if n not in self.nodes: - continue - for n2 in self.nodes[n].get_conts (): - if n2 not in reachable: - reachable[n2] = True - visit.append (n2) - for n in list (self.nodes) + ['Ret', 'Err']: - if n not in reachable: - reachable[n] = False - self.cached_analysis[k] = reachable - return reachable[target] - - def is_reachable_without (self, cutpoint, target): - '''discover if graph addr "target" is reachable - without visiting node "cutpoint" - (an oddity: cutpoint itself is considered reachable)''' - k = ('is_reachable_without', cutpoint) - if k in self.cached_analysis: - reachable = self.cached_analysis[k] - if target in reachable: - return reachable[target] - - reachable = dict ([(self.get_entry (t), True) - for t in self.tags 
()]) - for n in self.tarjan_order + ['Ret', 'Err']: - if n in reachable: - continue - reachable[n] = bool ([pred for pred in self.preds[n] - if pred != cutpoint - if reachable.get (pred) == True]) - self.cached_analysis[k] = reachable - return reachable[target] + def __init__ (self, pairing, name = None): + if name == None: + name = pairing.name + self.name = 'Problem (%s)' % name + self.pairing = pairing + self.nodes = {} + self.vs = {} + self.next_node_name = 1 + self.preds = {} + self.loop_data = {} + self.node_tags = {} + self.node_tag_revs = {} + self.inline_scripts = {} + self.entries = [] + self.outputs = {} + self.tarjan_order = [] + self.loop_var_analysis_cache = {} + + self.known_eqs = {} + self.cached_analysis = {} + self.hook_tag_hints = {} + + last_problem[0] = self + + def fail_msg (self): + return 'FAILED %s (size %05d)' % (self.name, len(self.nodes)) + + def alloc_node (self, tag, detail, loop_id = None, hint = None): + name = self.next_node_name + self.next_node_name = name + 1 + + self.node_tags[name] = (tag, detail) + self.node_tag_revs.setdefault ((tag, detail), []) + self.node_tag_revs[(tag, detail)].append (name) + + if loop_id != None: + self.loop_data[name] = ('Mem', loop_id) + + return name + + def fresh_var (self, name, typ): + name = fresh_name (name, self.vs, typ) + return mk_var (name, typ) + + def clone_function (self, fun, tag): + self.nodes = {} + self.vs = syntax.get_vars (fun) + for n in fun.reachable_nodes (): + self.nodes[n] = fun.nodes[n] + detail = (fun.name, n) + self.node_tags[n] = (tag, detail) + self.node_tag_revs.setdefault ((tag, detail), []) + self.node_tag_revs[(tag, detail)].append (n) + self.outputs[tag] = fun.outputs + self.entries = [(fun.entry, tag, fun.name, fun.inputs)] + self.next_node_name = max (self.nodes.keys () + [2]) + 1 + self.inline_scripts[tag] = [] + + def add_function (self, fun, tag, node_renames, loop_id = None): + if not fun.entry: + printout ('Aborting %s: underspecified %s' % ( + self.name, fun.name)) + raise Abort () + node_renames.setdefault('Ret', 'Ret') + node_renames.setdefault('Err', 'Err') + new_node_renames = {} + vs = syntax.get_vars (fun) + vs = dict ([(v, fresh_name (v, self.vs, vs[v])) for v in vs]) + ns = fun.reachable_nodes () + check_no_symbols (self.name, [fun.nodes[n] for n in ns]) + for n in ns: + assert n not in node_renames + node_renames[n] = self.alloc_node (tag, (fun.name, n), + loop_id = loop_id, hint = n) + new_node_renames[n] = node_renames[n] + for n in ns: + self.nodes[node_renames[n]] = syntax.copy_rename ( + fun.nodes[n], (vs, node_renames)) + + return (new_node_renames, vs) + + def add_entry_function (self, fun, tag): + (ns, vs) = self.add_function (fun, tag, {}) + entry = ns[fun.entry] + args = [(vs[v], typ) for (v, typ) in fun.inputs] + rets = [(vs[v], typ) for (v, typ) in fun.outputs] + self.entries.append((entry, tag, fun.name, args)) + self.outputs[tag] = rets + + self.inline_scripts[tag] = [] + return (args, rets, entry) + + def get_entry_details (self, tag): + [(e, t, fname, args)] = [(e, t, fname, args) + for (e, t, fname, args) in self.entries if t == tag] + return (e, fname, args) + + def get_entry (self, tag): + (e, fname, args) = self.get_entry_details (tag) + return e + + def tags (self): + return self.outputs.keys () + + def entry_exit_renames (self, tags = None): + """computes the rename set of a function's formal parameters + to the actual input/output variable names at the various entry + and exit points""" + mk = lambda xs, ys: dict ([(x[0], y[0]) for (x, y) in + azip (xs, 
ys)]) + renames = {} + if tags == None: + tags = self.tags () + for tag in tags: + (_, fname, args) = self.get_entry_details (tag) + fun = functions[fname] + out = self.outputs[tag] + renames[tag + '_IN'] = mk (fun.inputs, args) + renames[tag + '_OUT'] = mk (fun.outputs, out) + return renames + + def redirect_conts (self, reds): + for node in self.nodes.itervalues(): + if node.kind == 'Cond': + node.left = reds.get(node.left, node.left) + node.right = reds.get(node.right, node.right) + else: + node.cont = reds.get(node.cont, node.cont) + + def do_analysis (self): + self.cached_analysis.clear () + self.compute_preds () + self.do_loop_analysis () + + def mk_node_graph (self, node_subset = None): + if node_subset == None: + node_subset = self.nodes + return dict ([(n, [c for c in self.nodes[n].get_conts () + if c in node_subset]) + for n in node_subset]) + + def do_loop_analysis (self): + entries = [e for (e, tag, nm, args) in self.entries] + + self.loop_data = {} + + graph = self.mk_node_graph () + comps = logic.tarjan (graph, entries) + self.tarjan_order = [] + for (head, tail) in comps: + self.tarjan_order.append (head) + self.tarjan_order.extend (tail) + if not tail and head not in graph[head]: + continue + trace ('Loop (%d, %s)' % (head, tail)) + + loop_set = set (tail) + loop_set.add (head) + + r = self.force_single_loop_return (head, loop_set) + if r != None: + tail.append (r) + loop_set.add (r) + self.tarjan_order.append (r) + self.compute_preds () + + self.loop_data[head] = ('Head', loop_set) + for t in tail: + self.loop_data[t] = ('Mem', head) + + # put this in first-to-last order. + self.tarjan_order.reverse () + + def check_no_inner_loops (self): + for loop in self.loop_heads (): + check_no_inner_loop (self, loop) + + def force_single_loop_return (self, head, loop_set): + rets = [n for n in self.preds[head] if n in loop_set] + if (len (rets) == 1 and rets[0] != head and + self.nodes[rets[0]].is_noop ()): + return None + r = self.alloc_node (self.node_tags[head][0], + 'LoopReturn', loop_id = head) + self.nodes[r] = Node ('Basic', head, []) + for r2 in rets: + self.nodes[r2] = syntax.copy_rename (self.nodes[r2], + ({}, {head: r})) + return r + + def splittable_points (self, n): + """splittable points are points which when removed, the loop + 'splits' and ceases to be a loop. 
+ + equivalently, the set of splittable points is the intersection + of all sub-loops of the loop.""" + head = self.loop_id (n) + assert head != None + k = ('Splittables', head) + if k in self.cached_analysis: + return self.cached_analysis[k] + + # check if the head point is a split (the inner loop + # check does exactly that) + if has_inner_loop (self, head): + head = logic.get_one_loop_splittable (self, + self.loop_body (head)) + if head == None: + return set () + + splits = self.get_loop_splittables (head) + self.cached_analysis[k] = splits + return splits + + def get_loop_splittables (self, head): + loop_set = self.loop_body (head) + splittable = dict ([(n, False) for n in loop_set]) + arc = [head] + n = head + while True: + ns = [n2 for n2 in self.nodes[n].get_conts () + if n2 in loop_set] + ns2 = [x for x in ns if x == head or x not in arc] + #n = ns[0] + n = ns2[0] + arc.append (n) + splittable[n] = True + if n == head: + break + last_descs = {} + for i in range (len (arc)): + last_descs[arc[i]] = i + def last_desc (n): + if n in last_descs: + return last_descs[n] + n2s = [n2 for n2 in self.nodes[n].get_conts() + if n2 in loop_set] + last_descs[n] = None + for n2 in n2s: + x = last_desc(n2) + if last_descs[n] == None or x >= last_descs[n]: + last_descs[n] = x + return last_descs[n] + for i in range (len (arc)): + max_arc = max ([last_desc (n) + for n in self.nodes[arc[i]].get_conts () + if n in loop_set]) + for j in range (i + 1, max_arc): + splittable[arc[j]] = False + return set ([n for n in splittable if splittable[n]]) + + def loop_heads (self): + return [n for n in self.loop_data + if self.loop_data[n][0] == 'Head'] + + def loop_id (self, n): + if n not in self.loop_data: + return None + elif self.loop_data[n][0] == 'Head': + return n + else: + assert self.loop_data[n][0] == 'Mem' + return self.loop_data[n][1] + + def loop_body (self, n): + head = self.loop_id (n) + return self.loop_data[head][1] + + def compute_preds (self): + self.preds = logic.compute_preds (self.nodes) + + def var_dep_outputs (self, n): + return self.outputs[self.node_tags[n][0]] + + def compute_var_dependencies (self): + if 'var_dependencies' in self.cached_analysis: + return self.cached_analysis['var_dependencies'] + var_deps = logic.compute_var_deps (self.nodes, + self.var_dep_outputs, self.preds) + var_deps2 = dict ([(n, dict ([(v, None) + for v in var_deps.get (n, [])])) + for n in self.nodes]) + self.cached_analysis['var_dependencies'] = var_deps2 + return var_deps2 + + def get_loop_var_analysis (self, var_deps, n): + head = self.loop_id (n) + assert head, n + assert n in self.splittable_points (n) + loop_sort = tuple (sorted (self.loop_body (head))) + node_data = [(self.nodes[n2], sorted (self.preds[n]), + sorted (var_deps[n2].keys ())) + for n2 in loop_sort] + k = (n, loop_sort) + data = (node_data, n) + if k in self.loop_var_analysis_cache: + for (data2, va) in self.loop_var_analysis_cache[k]: + if data2 == data: + return va + va = logic.compute_loop_var_analysis (self, var_deps, n) + group = self.loop_var_analysis_cache.setdefault (k, []) + group.append ((data, va)) + del group[:-10] + return va + + def save_graph (self, fname): + cols = mk_graph_cols (self.node_tags) + save_graph (self.nodes, fname, cols = cols, + node_tags = self.node_tags) + + def save_graph_summ (self, fname): + node_ids = {} + def is_triv (n): + if n not in self.nodes: + return False + if len (self.preds[n]) != 1: + return False + node = self.nodes[n] + if node.kind == 'Basic': + return (True, node.cont) + elif node.kind == 
'Cond' and node.right == 'Err': + return (True, node.left) + else: + return False + for n in self.nodes: + if n in node_ids: + continue + ns = [] + while is_triv (n): + ns.append (n) + n = is_triv (n)[1] + for n2 in ns: + node_ids[n2] = n + nodes = {} + for n in self.nodes: + if is_triv (n): + continue + nodes[n] = syntax.copy_rename (self.nodes[n], + ({}, node_ids)) + cols = mk_graph_cols (self.node_tags) + save_graph (nodes, fname, cols = cols, + node_tags = self.node_tags) + + def serialise (self): + ss = ['Problem'] + for (n, tag, fname, inputs) in self.entries: + xs = ['Entry', '%d' % n, tag, fname, + '%d' % len (inputs)] + for (nm, typ) in inputs: + xs.append (nm) + typ.serialise (xs) + xs.append ('%d' % len (self.outputs[tag])) + for (nm, typ) in self.outputs[tag]: + xs.append (nm) + typ.serialise (xs) + ss.append (' '.join (xs)) + for n in self.nodes: + xs = ['%d' % n] + self.nodes[n].serialise (xs) + ss.append (' '.join (xs)) + ss.append ('EndProblem') + return ss + + def save_serialise (self, fname): + ss = self.serialise () + f = open (fname, 'w') + for s in ss: + f.write (s + '\n') + f.close () + + def pad_merge_points (self): + self.compute_preds () + + arcs = [(pred, n) for n in self.preds + if len (self.preds[n]) > 1 + if n in self.nodes + for pred in self.preds[n] + if (self.nodes[pred].kind != 'Basic' + or self.nodes[pred].upds != [])] + + for (pred, n) in arcs: + (tag, _) = self.node_tags[pred] + name = self.alloc_node (tag, 'MergePadding') + self.nodes[name] = Node ('Basic', n, []) + self.nodes[pred] = syntax.copy_rename (self.nodes[pred], + ({}, {n: name})) + + def function_call_addrs (self): + return [(n, self.nodes[n].fname) + for n in self.nodes if self.nodes[n].kind == 'Call'] + + def function_calls (self): + return set ([fn for (n, fn) in self.function_call_addrs ()]) + + def get_extensions (self): + if 'extensions' in self.cached_analysis: + return self.cached_analysis['extensions'] + extensions = set () + for node in self.nodes.itervalues (): + extensions.update (syntax.get_extensions (node)) + self.cached_analysis['extensions'] = extensions + return extensions + + def replay_inline_script (self, tag, script): + for (detail, idx, fname) in script: + n = self.node_tag_revs[(tag, detail)][idx] + assert self.nodes[n].kind == 'Call', self.nodes[n] + assert self.nodes[n].fname == fname, self.nodes[n] + inline_at_point (self, n, do_analysis = False) + if script: + self.do_analysis () + + def is_reachable_from (self, source, target): + '''discover if graph addr "target" is reachable + from starting node "source"''' + k = ('is_reachable_from', source) + if k in self.cached_analysis: + reachable = self.cached_analysis[k] + if target in reachable: + return reachable[target] + + reachable = {} + visit = [source] + while visit: + n = visit.pop () + if n not in self.nodes: + continue + for n2 in self.nodes[n].get_conts (): + if n2 not in reachable: + reachable[n2] = True + visit.append (n2) + for n in list (self.nodes) + ['Ret', 'Err']: + if n not in reachable: + reachable[n] = False + self.cached_analysis[k] = reachable + return reachable[target] + + def is_reachable_without (self, cutpoint, target): + '''discover if graph addr "target" is reachable + without visiting node "cutpoint" + (an oddity: cutpoint itself is considered reachable)''' + k = ('is_reachable_without', cutpoint) + if k in self.cached_analysis: + reachable = self.cached_analysis[k] + if target in reachable: + return reachable[target] + + reachable = dict ([(self.get_entry (t), True) + for t in self.tags 
()]) + for n in self.tarjan_order + ['Ret', 'Err']: + if n in reachable: + continue + reachable[n] = bool ([pred for pred in self.preds[n] + if pred != cutpoint + if reachable.get (pred) == True]) + self.cached_analysis[k] = reachable + return reachable[target] def deserialise (name, lines): - assert lines[0] == 'Problem', lines[0] - assert lines[-1] == 'EndProblem', lines[-1] - i = 1 - # not easy to reconstruct pairing - p = Problem (pairing = None, name = name) - while lines[i].startswith ('Entry'): - bits = lines[i].split () - en = int (bits[1]) - tag = bits[2] - fname = bits[3] - (n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4) - (n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n) - assert n == len (bits), (n, bits) - p.entries.append ((en, tag, fname, inputs)) - p.outputs[tag] = outputs - i += 1 - for i in range (i, len (lines) - 1): - bits = lines[i].split () - n = int (bits[0]) - node = syntax.parse_node (bits, 1) - p.nodes[n] = node - return p + assert lines[0] == 'Problem', lines[0] + assert lines[-1] == 'EndProblem', lines[-1] + i = 1 + # not easy to reconstruct pairing + p = Problem (pairing = None, name = name) + while lines[i].startswith ('Entry'): + bits = lines[i].split () + en = int (bits[1]) + tag = bits[2] + fname = bits[3] + (n, inputs) = syntax.parse_list (syntax.parse_lval, bits, 4) + (n, outputs) = syntax.parse_list (syntax.parse_lval, bits, n) + assert n == len (bits), (n, bits) + p.entries.append ((en, tag, fname, inputs)) + p.outputs[tag] = outputs + i += 1 + for i in range (i, len (lines) - 1): + bits = lines[i].split () + n = int (bits[0]) + node = syntax.parse_node (bits, 1) + p.nodes[n] = node + return p # trivia -def check_no_symbols (nodes): - import pseudo_compile - symbs = pseudo_compile.nodes_symbols (nodes) - if not symbs: - return - printout ('Aborting %s: undefined symbols %s' % (self.name, symbs)) - raise Abort () +def check_no_symbols (name, nodes): + import pseudo_compile + symbs = pseudo_compile.nodes_symbols (nodes) + if not symbs: + return + printout ('Aborting %s: undefined symbols %s' % (name, symbs)) + raise Abort () # printing of problem graphs def sanitise_str (s): - return s.replace ('"', '_').replace ("'", "_").replace (' ', '') + return s.replace ('"', '_').replace ("'", "_").replace (' ', '') def graph_name (nodes, node_tags, n, prev=None): - if type (n) == str: - return 't_%s_%d' % (n, prev) - if n not in nodes: - return 'unknown_%d' % n - if n not in node_tags: - ident = '%d' % n - else: - (tag, details) = node_tags[n] - if len (details) > 1 and logic.is_int (details[1]): - ident = '%d_%s_%s_0x%x' % (n, tag, - details[0], details[1]) - elif type (details) != str: - details = '_'.join (map (str, details)) - ident = '%d_%s_%s' % (n, tag, details) - else: - ident = '%d_%s_%s' % (n, tag, details) - ident = sanitise_str (ident) - node = nodes[n] - if node.kind == 'Call': - return 'fcall_%s' % ident - if node.kind == 'Cond': - return ident - if node.kind == 'Basic': - return 'ass_%s' % ident - assert not 'node kind understood' + if type (n) == str: + return 't_%s_%d' % (n, prev) + if n not in nodes: + return 'unknown_%d' % n + if n not in node_tags: + ident = '%d' % n + else: + (tag, details) = node_tags[n] + if len (details) > 1 and logic.is_int (details[1]): + ident = '%d_%s_%s_0x%x' % (n, tag, + details[0], details[1]) + elif type (details) != str: + details = '_'.join (map (str, details)) + ident = '%d_%s_%s' % (n, tag, details) + else: + ident = '%d_%s_%s' % (n, tag, details) + ident = sanitise_str (ident) + node = 
nodes[n] + if node.kind == 'Call': + return 'fcall_%s' % ident + if node.kind == 'Cond': + return ident + if node.kind == 'Basic': + return 'ass_%s' % ident + assert not 'node kind understood' def graph_node_tooltip (nodes, n): - if n == 'Err': - return 'Error point' - if n == 'Ret': - return 'Return point' - node = nodes[n] - if node.kind == 'Call': - return "%s: call to '%s'" % (n, sanitise_str (node.fname)) - if node.kind == 'Cond': - return '%s: conditional node' % n - if node.kind == 'Basic': - var_names = [sanitise_str (x[0][0]) for x in node.upds] - return '%s: assignment to [%s]' % (n, ', '.join (var_names)) - assert not 'node kind understood' + if n == 'Err': + return 'Error point' + if n == 'Ret': + return 'Return point' + node = nodes[n] + if node.kind == 'Call': + return "%s: call to '%s'" % (n, sanitise_str (node.fname)) + if node.kind == 'Cond': + return '%s: conditional node' % n + if node.kind == 'Basic': + var_names = [sanitise_str (x[0][0]) for x in node.upds] + return '%s: assignment to [%s]' % (n, ', '.join (var_names)) + assert not 'node kind understood' def graph_edges (nodes, n): - node = nodes[n] - if node.is_noop (): - return [(node.get_conts () [0], 'N')] - elif node.kind == 'Cond': - return [(node.left, 'T'), (node.right, 'F')] - else: - return [(node.cont, 'C')] + node = nodes[n] + if node.is_noop (): + return [(node.get_conts () [0], 'N')] + elif node.kind == 'Cond': + return [(node.left, 'T'), (node.right, 'F')] + else: + return [(node.cont, 'C')] def get_graph_font (n, col): - font = 'fontname = "Arial", fontsize = 20, penwidth=3' - if col: - font = font + ', color=%s, fontcolor=%s' % (col, col) - return font + font = 'fontname = "Arial", fontsize = 20, penwidth=3' + if col: + font = font + ', color=%s, fontcolor=%s' % (col, col) + return font def get_graph_loops (nodes): - graph = dict ([(n, [c for c in nodes[n].get_conts () - if type (c) != str]) for n in nodes]) - graph['ENTRY'] = list (nodes) - comps = logic.tarjan (graph, ['ENTRY']) - comp_ids = {} - for (head, tail) in comps: - comp_ids[head] = head - for n in tail: - comp_ids[n] = head - loops = set ([(n, n2) for n in graph for n2 in graph[n] - if comp_ids[n] == comp_ids[n2]]) - return loops + graph = dict ([(n, [c for c in nodes[n].get_conts () + if type (c) != str]) for n in nodes]) + graph['ENTRY'] = list (nodes) + comps = logic.tarjan (graph, ['ENTRY']) + comp_ids = {} + for (head, tail) in comps: + comp_ids[head] = head + for n in tail: + comp_ids[n] = head + loops = set ([(n, n2) for n in graph for n2 in graph[n] + if comp_ids[n] == comp_ids[n2]]) + return loops def make_graph (nodes, cols, node_tags = {}, entries = []): - graph = [] - graph.append ('digraph foo {') - - loops = get_graph_loops (nodes) - - for n in nodes: - n_nm = graph_name (nodes, node_tags, n) - f = get_graph_font (n, cols.get (n)) - graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n, - f, n_nm, graph_node_tooltip (nodes, n))) - for (c, l) in graph_edges (nodes, n): - if c in ['Ret', 'Err']: - c_nm = '%s_%s' % (c, n) - if c == 'Ret': - f2 = f + ', shape=doubleoctagon' - else: - f2 = f + ', shape=Mdiamond' - graph.append ('%s [label="%s", %s];' - % (c_nm, c, f2)) - else: - c_nm = c - ft = f - if (n, c) in loops: - ft = f + ', penwidth=6' - graph.append ('%s -> %s [label=%s, %s];' % ( - n, c_nm, l, ft)) - - for (i, (n, tag, inps)) in enumerate (entries): - f = get_graph_font (n, cols.get (n)) - nm1 = tag + ' ENTRY_POINT' - nm2 = 'entry_point_%d' % i - graph.extend (['%s -> %s [%s];' % (nm2, n, f), - '%s [label = "%s", 
shape=none, %s];' % (nm2, nm1, f)]) - - graph.append ('}') - return graph + graph = [] + graph.append ('digraph foo {') + + loops = get_graph_loops (nodes) + + for n in nodes: + n_nm = graph_name (nodes, node_tags, n) + f = get_graph_font (n, cols.get (n)) + graph.append ('%s [%s\n label="%s"\n tooltip="%s"];' % (n, + f, n_nm, graph_node_tooltip (nodes, n))) + for (c, l) in graph_edges (nodes, n): + if c in ['Ret', 'Err']: + c_nm = '%s_%s' % (c, n) + if c == 'Ret': + f2 = f + ', shape=doubleoctagon' + else: + f2 = f + ', shape=Mdiamond' + graph.append ('%s [label="%s", %s];' + % (c_nm, c, f2)) + else: + c_nm = c + ft = f + if (n, c) in loops: + ft = f + ', penwidth=6' + graph.append ('%s -> %s [label=%s, %s];' % ( + n, c_nm, l, ft)) + + for (i, (n, tag, inps)) in enumerate (entries): + f = get_graph_font (n, cols.get (n)) + nm1 = tag + ' ENTRY_POINT' + nm2 = 'entry_point_%d' % i + graph.extend (['%s -> %s [%s];' % (nm2, n, f), + '%s [label = "%s", shape=none, %s];' % (nm2, nm1, f)]) + + graph.append ('}') + return graph def print_graph (nodes, cols = {}, entries = []): - for line in make_graph (nodes, cols, entries): - print line + for line in make_graph (nodes, cols, entries): + print line def save_graph (nodes, fname, cols = {}, entries = [], node_tags = {}): - f = open (fname, 'w') - for line in make_graph (nodes, cols = cols, node_tags = node_tags, - entries = entries): - f.write (line + '\n') - f.close () + f = open (fname, 'w') + for line in make_graph (nodes, cols = cols, node_tags = node_tags, + entries = entries): + f.write (line + '\n') + f.close () def mk_graph_cols (node_tags): - known_cols = {'C': "forestgreen", 'ASM_adj': "darkblue", - 'ASM': "darkorange"} - cols = {} - for n in node_tags: - if node_tags[n][0] in known_cols: - cols[n] = known_cols[node_tags[n][0]] - return cols + known_cols = {'C': "forestgreen", 'ASM_adj': "darkblue", + 'ASM': "darkorange"} + cols = {} + for n in node_tags: + if node_tags[n][0] in known_cols: + cols[n] = known_cols[node_tags[n][0]] + return cols def make_graph_with_eqs (p, invis = False): - if invis: - invis_s = ', style=invis' - else: - invis_s = '' - cols = mk_graph_cols (p.node_tags) - graph = make_graph (p.nodes, cols = cols) - graph.pop () - for k in p.known_eqs: - if k == 'Hyps': - continue - (n_vc_x, tag_x) = k - nm1 = graph_name (p.nodes, p.node_tags, n_vc_x[0]) - for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]: - nm2 = graph_name (p.nodes, p.node_tags, n_vc_y[0]) - graph.extend ([('%s -> %s [ dir = back, color = blue, ' - 'penwidth = 3, weight = 0 %s ]') - % (nm2, nm1, invis_s)]) - graph.append ('}') - return graph + if invis: + invis_s = ', style=invis' + else: + invis_s = '' + cols = mk_graph_cols (p.node_tags) + graph = make_graph (p.nodes, cols = cols) + graph.pop () + for k in p.known_eqs: + if k == 'Hyps': + continue + (n_vc_x, tag_x) = k + nm1 = graph_name (p.nodes, p.node_tags, n_vc_x[0]) + for (x, n_vc_y, tag_y, y, hyps) in p.known_eqs[k]: + nm2 = graph_name (p.nodes, p.node_tags, n_vc_y[0]) + graph.extend ([('%s -> %s [ dir = back, color = blue, ' + 'penwidth = 3, weight = 0 %s ]') + % (nm2, nm1, invis_s)]) + graph.append ('}') + return graph def save_graph_with_eqs (p, fname = 'diagram.dot', invis = False): - graph = make_graph_with_eqs (p, invis = invis) - f = open (fname, 'w') - for s in graph: - f.write (s + '\n') - f.close () + graph = make_graph_with_eqs (p, invis = invis) + f = open (fname, 'w') + for s in graph: + f.write (s + '\n') + f.close () def get_problem_vars (p): - inout = set.union (* ([set(xs) for xs in 
p.outputs.itervalues ()] - + [set (args) for (_, _, _, args) in p.entries])) + inout = set.union (* ([set(xs) for xs in p.outputs.itervalues ()] + + [set (args) for (_, _, _, args) in p.entries])) - vs = dict(inout) - for node in p.nodes.itervalues(): - syntax.get_node_vars(node, vs) - return vs + vs = dict(inout) + for node in p.nodes.itervalues(): + syntax.get_node_vars(node, vs) + return vs def is_trivial_fun (fun): - for node in fun.nodes.itervalues (): - if node.is_noop (): - continue - if node.kind == 'Call': - return False - elif node.kind == 'Basic': - for (lv, v) in node.upds: - if v.kind not in ['Var', 'Num']: - return False - elif node.kind == 'Cond': - if node.cond.kind != 'Var' and node.cond not in [ - true_term, false_term]: - return False - return True + for node in fun.nodes.itervalues (): + if node.is_noop (): + continue + if node.kind == 'Call': + return False + elif node.kind == 'Basic': + for (lv, v) in node.upds: + if v.kind not in ['Var', 'Num']: + return False + elif node.kind == 'Cond': + if node.cond.kind != 'Var' and node.cond not in [ + true_term, false_term]: + return False + return True last_alt_nodes = [0] def avail_val (vs, typ): - for (nm, typ2) in vs: - if typ2 == typ: - return mk_var (nm, typ2) - return logic.default_val (typ) + for (nm, typ2) in vs: + if typ2 == typ: + return mk_var (nm, typ2) + return logic.default_val (typ) def inline_at_point (p, n, do_analysis = True): - node = p.nodes[n] - if node.kind != 'Call': - return + node = p.nodes[n] + if node.kind != 'Call': + return - f_nm = node.fname - fun = functions[f_nm] - (tag, detail) = p.node_tags[n] - idx = p.node_tag_revs[(tag, detail)].index (n) - p.inline_scripts[tag].append ((detail, idx, f_nm)) + f_nm = node.fname + fun = functions[f_nm] + (tag, detail) = p.node_tags[n] + idx = p.node_tag_revs[(tag, detail)].index (n) + p.inline_scripts[tag].append ((detail, idx, f_nm)) - trace ('Inlining %s into %s' % (f_nm, p.name)) - if n in p.loop_data: - trace (' inlining into loop %d!' % p.loop_id (n)) + trace ('Inlining %s into %s' % (f_nm, p.name)) + if n in p.loop_data: + trace (' inlining into loop %d!' 
% p.loop_id (n)) - ex = p.alloc_node (tag, (f_nm, 'RetToCaller')) + ex = p.alloc_node (tag, (f_nm, 'RetToCaller')) - (ns, vs) = p.add_function (fun, tag, {'Ret': ex}) - en = ns[fun.entry] + (ns, vs) = p.add_function (fun, tag, {'Ret': ex}) + en = ns[fun.entry] - inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs] - p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args)) + inp_lvs = [(vs[v], typ) for (v, typ) in fun.inputs] + p.nodes[n] = Node ('Basic', en, azip (inp_lvs, node.args)) - out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs] - p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs)) + out_vs = [mk_var (vs[v], typ) for (v, typ) in fun.outputs] + p.nodes[ex] = Node ('Basic', node.cont, azip (node.rets, out_vs)) - p.cached_analysis.clear () + p.cached_analysis.clear () - if do_analysis: - p.do_analysis () + if do_analysis: + p.do_analysis () - trace ('Problem size now %d' % len(p.nodes)) - sys.stdin.flush () + trace ('Problem size now %d' % len(p.nodes)) + sys.stdin.flush () - return ns.values () + return ns.values () def loop_body_inner_loops (p, head, loop_body): - loop_set_all = set (loop_body) - loop_set = loop_set_all - set ([head]) - graph = dict([(n, [c for c in p.nodes[n].get_conts () - if c in loop_set]) - for n in loop_set_all]) + loop_set_all = set (loop_body) + loop_set = loop_set_all - set ([head]) + graph = dict([(n, [c for c in p.nodes[n].get_conts () + if c in loop_set]) + for n in loop_set_all]) - comps = logic.tarjan (graph, [head]) - assert sum ([1 + len (t) for (_, t) in comps]) == len (loop_set_all) - return [comp for comp in comps if comp[1]] + comps = logic.tarjan (graph, [head]) + assert sum ([1 + len (t) for (_, t) in comps]) == len (loop_set_all) + return [comp for comp in comps if comp[1]] def loop_inner_loops (p, head): - k = ('inner_loop_set', head) - if k in p.cached_analysis: - return p.cached_analysis[k] - res = loop_body_inner_loops (p, head, p.loop_body (head)) - p.cached_analysis[k] = res - return res + k = ('inner_loop_set', head) + if k in p.cached_analysis: + return p.cached_analysis[k] + res = loop_body_inner_loops (p, head, p.loop_body (head)) + p.cached_analysis[k] = res + return res def loop_heads_including_inner (p): - heads = p.loop_heads () - check = [(head, p.loop_body (head)) for head in heads] - while check: - (head, body) = check.pop () - comps = loop_body_inner_loops (p, head, body) - heads.extend ([head for (head, _) in comps]) - check.extend ([(head, [head] + list (body)) - for (head, body) in comps]) - return heads + heads = p.loop_heads () + check = [(head, p.loop_body (head)) for head in heads] + while check: + (head, body) = check.pop () + comps = loop_body_inner_loops (p, head, body) + heads.extend ([head for (head, _) in comps]) + check.extend ([(head, [head] + list (body)) + for (head, body) in comps]) + return heads def check_no_inner_loop (p, head): - subs = loop_inner_loops (p, head) - if subs: - printout ('Aborting %s, complex loop' % p.name) - trace (' sub-loops %s of loop at %s' % (subs, head)) - for (h, _) in subs: - trace (' head %d tagged %s' % (h, p.node_tags[h])) - raise Abort () + subs = loop_inner_loops (p, head) + if subs: + printout ('Aborting %s, complex loop' % p.name) + trace (' sub-loops %s of loop at %s' % (subs, head)) + for (h, _) in subs: + trace (' head %d tagged %s' % (h, p.node_tags[h])) + raise Abort () def has_inner_loop (p, head): - return bool (loop_inner_loops (p, head)) + return bool (loop_inner_loops (p, head)) def fun_has_inner_loop (f): - p = f.as_problem (Problem) - 
p.do_analysis () - return bool ([head for head in p.loop_heads () - if has_inner_loop (p, head)]) + p = f.as_problem (Problem) + p.do_analysis () + return bool ([head for head in p.loop_heads () + if has_inner_loop (p, head)]) def loop_var_analysis (p, head, tail): - # getting the set of variables that go round the loop - nodes = set (tail) - nodes.add (head) - used_vs = set ([]) - created_vs_at = {} - visit = [] - - def process_node (n, created): - if p.nodes[n].is_noop (): - lvals = set ([]) - else: - vs = syntax.get_node_rvals (p.nodes[n]) - for rv in vs.iteritems (): - if rv not in created: - used_vs.add (rv) - lvals = set (p.nodes[n].get_lvals ()) - - created = set.union (created, lvals) - created_vs_at[n] = created - - visit.extend (p.nodes[n].get_conts ()) - - process_node (head, set ([])) - - while visit: - n = visit.pop () - if (n not in nodes) or (n in created_vs_at): - continue - if not all ([pr in created_vs_at for pr in p.preds[n]]): - continue - - pre_created = [created_vs_at[pr] for pr in p.preds[n]] - process_node (n, set.union (* pre_created)) - - final_pre_created = [created_vs_at[pr] for pr in p.preds[head] - if pr in nodes] - created = set.union (* final_pre_created) - - loop_vs = set.intersection (created, used_vs) - trace ('Loop vars at head: %s' % loop_vs) - - return loop_vs + # getting the set of variables that go round the loop + nodes = set (tail) + nodes.add (head) + used_vs = set ([]) + created_vs_at = {} + visit = [] + + def process_node (n, created): + if p.nodes[n].is_noop (): + lvals = set ([]) + else: + vs = syntax.get_node_rvals (p.nodes[n]) + for rv in vs.iteritems (): + if rv not in created: + used_vs.add (rv) + lvals = set (p.nodes[n].get_lvals ()) + + created = set.union (created, lvals) + created_vs_at[n] = created + + visit.extend (p.nodes[n].get_conts ()) + + process_node (head, set ([])) + + while visit: + n = visit.pop () + if (n not in nodes) or (n in created_vs_at): + continue + if not all ([pr in created_vs_at for pr in p.preds[n]]): + continue + + pre_created = [created_vs_at[pr] for pr in p.preds[n]] + process_node (n, set.union (* pre_created)) + + final_pre_created = [created_vs_at[pr] for pr in p.preds[head] + if pr in nodes] + created = set.union (* final_pre_created) + + loop_vs = set.intersection (created, used_vs) + trace ('Loop vars at head: %s' % loop_vs) + + return loop_vs diff --git a/pseudo_compile.py b/pseudo_compile.py index e9d9f8c1..9b5c2bcf 100644 --- a/pseudo_compile.py +++ b/pseudo_compile.py @@ -11,463 +11,477 @@ import logic -(mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq, -mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word32, mk_word8, -mk_word32_maybe, mk_cast, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs, -mk_if, mk_meta_typ, mk_pvalid) = syntax.mks +from syntax import (mk_var, mk_plus, mk_uminus, mk_minus, mk_times, mk_modulus, mk_bwand, mk_eq, + mk_less_eq, mk_less, mk_implies, mk_and, mk_or, mk_not, mk_word64, mk_word32, mk_word8, + mk_word32_maybe, mk_memacc, mk_memupd, mk_arr_index, mk_arroffs, + mk_if, mk_meta_typ, mk_pvalid) -from syntax import word32T, word8T +from syntax import word64T, word32T, word8T from syntax import fresh_name, foldr1 from target_objects import symbols, trace def compile_field_acc (name, expr, replaces): - '''pseudo-compile access to field (named name) of expr''' - if expr.kind == 'StructCons': - return expr.vals[name] - elif expr.kind == 'FieldUpd': - if expr.field[0] == name: - return expr.val - else: - return compile_field_acc (name, 
expr.struct, replaces) - elif expr.kind == 'Var': - assert expr.name in replaces - [(v_nm, typ)] = [(v_nm, typ) for (f_nm, v_nm, typ) - in replaces[expr.name] if f_nm == name] - return mk_var (v_nm, typ) - elif expr.is_op ('MemAcc'): - assert expr.typ.kind == 'Struct' - (typ, offs, _) = structs[expr.typ.name].fields[name] - [m, p] = expr.vals - return mk_memacc (m, mk_plus (p, mk_word32 (offs)), typ) - elif expr.kind == 'Field': - expr2 = compile_field_acc (expr.field[0], expr.struct, replaces) - return compile_field_acc (name, expr2, replaces) - elif expr.is_op ('ArrayIndex'): - [arr, i] = expr.vals - expr2 = compile_array_acc (i, arr, replaces, False) - assert expr2, (arr, i) - return compile_field_acc (name, expr2, replaces) - else: - print expr - assert not 'field acc compilable' + '''pseudo-compile access to field (named name) of expr''' + if expr.kind == 'StructCons': + return expr.vals[name] + elif expr.kind == 'FieldUpd': + if expr.field[0] == name: + return expr.val + else: + return compile_field_acc (name, expr.struct, replaces) + elif expr.kind == 'Var': + assert expr.name in replaces + [(v_nm, typ)] = [(v_nm, typ) for (f_nm, v_nm, typ) + in replaces[expr.name] if f_nm == name] + return mk_var (v_nm, typ) + elif expr.is_op ('MemAcc'): + assert expr.typ.kind == 'Struct' + (typ, offs, _) = structs[expr.typ.name].fields[name] + [m, p] = expr.vals + assert False + return mk_memacc (m, mk_plus (p, mk_word32 (offs)), typ) + elif expr.kind == 'Field': + expr2 = compile_field_acc (expr.field[0], expr.struct, replaces) + return compile_field_acc (name, expr2, replaces) + elif expr.is_op ('ArrayIndex'): + [arr, i] = expr.vals + expr2 = compile_array_acc (i, arr, replaces, False) + assert expr2, (arr, i) + return compile_field_acc (name, expr2, replaces) + else: + print expr + assert not 'field acc compilable' def compile_array_acc (i, expr, replaces, must = True): - '''pseudo-compile access to array element i of expr''' - if not logic.is_int (i) and i.kind == 'Num': - assert i.typ == word32T - i = i.val - if expr.kind == 'Array': - if logic.is_int (i): - return expr.vals[i] - else: - expr2 = expr.vals[-1] - for (j, v) in enumerate (expr.vals[:-1]): - expr2 = mk_if (mk_eq (i, mk_word32 (j)), v, expr2) - return expr2 - elif expr.is_op ('ArrayUpdate'): - [arr, j, v] = expr.vals - if j.kind == 'Num' and logic.is_int (i): - if i == j.val: - return v - else: - return compile_array_acc (i, arr, replaces) - else: - return mk_if (mk_eq (j, mk_word32_maybe (i)), v, - compile_array_acc (i, arr, replaces)) - elif expr.is_op ('MemAcc'): - [m, p] = expr.vals - return mk_memacc (m, mk_arroffs (p, expr.typ, i), expr.typ.el_typ) - elif expr.is_op ('IfThenElse'): - [cond, left, right] = expr.vals - return mk_if (cond, compile_array_acc (i, left, replaces), - compile_array_acc (i, right, replaces)) - elif expr.kind == 'Var': - assert expr.name in replaces - if logic.is_int (i): - (_, v_nm, typ) = replaces[expr.name][i] - return mk_var (v_nm, typ) - else: - vs = [(mk_word32 (j), mk_var (v_nm, typ)) - for (j, v_nm, typ) - in replaces[expr.name]] - expr2 = vs[0][1] - for (j, v) in vs[1:]: - expr2 = mk_if (mk_eq (i, j), v, expr2) - return expr2 - else: - if not must: - return None - return mk_arr_index (expr, mk_word32_maybe (i)) + '''pseudo-compile access to array element i of expr''' + if not logic.is_int (i) and i.kind == 'Num': + assert i.typ == word32T + i = i.val + if expr.kind == 'Array': + if logic.is_int (i): + return expr.vals[i] + else: + assert False + expr2 = expr.vals[-1] + for (j, v) in enumerate 
(expr.vals[:-1]): + expr2 = mk_if (mk_eq (i, mk_word32 (j)), v, expr2) + return expr2 + elif expr.is_op ('ArrayUpdate'): + [arr, j, v] = expr.vals + if j.kind == 'Num' and logic.is_int (i): + if i == j.val: + return v + else: + return compile_array_acc (i, arr, replaces) + else: + assert False + return mk_if (mk_eq (j, mk_word32_maybe (i)), v, + compile_array_acc (i, arr, replaces)) + elif expr.is_op ('MemAcc'): + [m, p] = expr.vals + return mk_memacc (m, mk_arroffs (p, expr.typ, i), expr.typ.el_typ) + elif expr.is_op ('IfThenElse'): + [cond, left, right] = expr.vals + return mk_if (cond, compile_array_acc (i, left, replaces), + compile_array_acc (i, right, replaces)) + elif expr.kind == 'Var': + assert expr.name in replaces + if logic.is_int (i): + (_, v_nm, typ) = replaces[expr.name][i] + return mk_var (v_nm, typ) + else: + assert False + vs = [(mk_word32 (j), mk_var (v_nm, typ)) + for (j, v_nm, typ) + in replaces[expr.name]] + expr2 = vs[0][1] + for (j, v) in vs[1:]: + expr2 = mk_if (mk_eq (i, j), v, expr2) + return expr2 + else: + if not must: + return None + assert False + return mk_arr_index (expr, mk_word32_maybe (i)) def num_fields (container, typ): - if container == typ: - return 1 - elif container.kind == 'Array': - return container.num * num_fields (container.el_typ, typ) - elif container.kind == 'Struct': - struct = structs[container.name] - return sum ([num_fields (typ2, typ) - for (nm, typ2) in struct.field_list]) - else: - return 0 + if container == typ: + return 1 + elif container.kind == 'Array': + return container.num * num_fields (container.el_typ, typ) + elif container.kind == 'Struct': + struct = structs[container.name] + return sum ([num_fields (typ2, typ) + for (nm, typ2) in struct.field_list]) + else: + return 0 def get_const_global_acc_offset (expr, offs, typ): - if expr.kind == 'ConstGlobal': - return (expr, offs) - elif expr.is_op ('ArrayIndex'): - [expr2, offs2] = expr.vals - offs = mk_plus (offs, mk_times (offs2, - mk_word32 (num_fields (expr.typ, typ)))) - return get_const_global_acc_offset (expr2, offs, typ) - elif expr.kind == 'Field': - struct = structs[expr.struct.typ.name] - offs2 = 0 - for (nm, typ2) in struct.field_list: - if (nm, typ2) == expr.field: - offs = mk_plus (offs, mk_word32 (offs2)) - return get_const_global_acc_offset ( - expr.struct, offs, typ) - else: - offs2 += num_fields (typ2, typ) - else: - return None + if expr.kind == 'ConstGlobal': + return (expr, offs) + elif expr.is_op ('ArrayIndex'): + [expr2, offs2] = expr.vals + assert False + offs = mk_plus (offs, mk_times (offs2, + mk_word32 (num_fields (expr.typ, typ)))) + return get_const_global_acc_offset (expr2, offs, typ) + elif expr.kind == 'Field': + struct = structs[expr.struct.typ.name] + offs2 = 0 + for (nm, typ2) in struct.field_list: + if (nm, typ2) == expr.field: + assert False + offs = mk_plus (offs, mk_word32 (offs2)) + return get_const_global_acc_offset ( + expr.struct, offs, typ) + else: + offs2 += num_fields (typ2, typ) + else: + return None def compile_const_global_acc (expr): - if expr.kind == 'ConstGlobal' or (expr.is_op ('ArrayIndex') - and expr.vals[0].kind == 'ConstGlobal'): - return None - if expr.typ.kind != 'Word': - return None - r = get_const_global_acc_offset (expr, mk_word32 (0), expr.typ) - if r == None: - return None - (cg, offs) = r - return mk_arr_index (cg, offs) + if expr.kind == 'ConstGlobal' or (expr.is_op ('ArrayIndex') + and expr.vals[0].kind == 'ConstGlobal'): + return None + if expr.typ.kind != 'Word': + return None + assert False + r = 
get_const_global_acc_offset (expr, mk_word32 (0), expr.typ) + if r == None: + return None + (cg, offs) = r + return mk_arr_index (cg, offs) def compile_val_fields (expr, replaces): - if expr.typ.kind == 'Array': - res = [] - for i in range (expr.typ.num): - acc = compile_array_acc (i, expr, replaces) - res.extend (compile_val_fields (acc, replaces)) - return res - elif expr.typ.kind == 'Struct': - res = [] - for (nm, typ2) in structs[expr.typ.name].field_list: - acc = compile_field_acc (nm, expr, replaces) - res.extend (compile_val_fields (acc, replaces)) - return res - else: - return [compile_accs (replaces, expr)] + if expr.typ.kind == 'Array': + res = [] + for i in range (expr.typ.num): + acc = compile_array_acc (i, expr, replaces) + res.extend (compile_val_fields (acc, replaces)) + return res + elif expr.typ.kind == 'Struct': + res = [] + for (nm, typ2) in structs[expr.typ.name].field_list: + acc = compile_field_acc (nm, expr, replaces) + res.extend (compile_val_fields (acc, replaces)) + return res + else: + return [compile_accs (replaces, expr)] def compile_val_fields_of_typ (expr, replaces, typ): - return [e for e in compile_val_fields (expr, replaces) - if e.typ == typ] + return [e for e in compile_val_fields (expr, replaces) + if e.typ == typ] # it helps in this compilation to know that the outermost expression we are # trying to fetch is always of basic type, never struct or array. # sort of fudged in the array indexing case here def compile_accs (replaces, expr): - r = compile_const_global_acc (expr) - if r: - return compile_accs (replaces, r) - - if expr.kind == 'Field': - expr = compile_field_acc (expr.field[0], expr.struct, replaces) - return compile_accs (replaces, expr) - elif expr.is_op ('ArrayIndex'): - [arr, n] = expr.vals - expr2 = compile_array_acc (n, arr, replaces, False) - if expr2: - return compile_accs (replaces, expr2) - arr = compile_accs (replaces, arr) - n = compile_accs (replaces, n) - expr2 = compile_array_acc (n, arr, replaces, False) - if expr2: - return compile_accs (replaces, expr2) - else: - return mk_arr_index (arr, n) - elif (expr.is_op ('MemUpdate') - and expr.vals[2].is_op ('MemAcc') - and expr.vals[2].vals[0] == expr.vals[0] - and expr.vals[2].vals[1] == expr.vals[1]): - # null memory copy. 
probably created by ops below - return compile_accs (replaces, expr.vals[0]) - elif (expr.is_op ('MemUpdate') - and expr.vals[2].kind == 'FieldUpd'): - [m, p, f_upd] = expr.vals - assert f_upd.typ.kind == 'Struct' - (typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]] - assert f_upd.val.typ == typ - return compile_accs (replaces, - mk_memupd (mk_memupd (m, p, f_upd.struct), - mk_plus (p, mk_word32 (offs)), f_upd.val)) - elif (expr.is_op ('MemUpdate') - and expr.vals[2].typ.kind == 'Struct'): - [m, p, s_val] = expr.vals - struct = structs[s_val.typ.name] - for (nm, (typ, offs, _)) in struct.fields.iteritems (): - f = compile_field_acc (nm, s_val, replaces) - assert f.typ == typ - m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f) - return compile_accs (replaces, m) - elif (expr.is_op ('MemUpdate') - and expr.vals[2].is_op ('ArrayUpdate')): - [m, p, arr_upd] = expr.vals - [arr, i, v] = arr_upd.vals - return compile_accs (replaces, - mk_memupd (mk_memupd (m, p, arr), - mk_arroffs (p, arr.typ, i), v)) - elif (expr.is_op ('MemUpdate') - and expr.vals[2].typ.kind == 'Array'): - [m, p, arr] = expr.vals - n = arr.typ.num - typ = arr.typ.el_typ - for i in range (n): - offs = i * typ.size () - assert offs == i or offs % 4 == 0 - e = compile_array_acc (i, arr, replaces) - m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e) - return compile_accs (replaces, m) - elif expr.is_op ('Equals') \ - and expr.vals[0].typ.kind in ['Struct', 'Array']: - [x, y] = expr.vals - assert x.typ == y.typ - xs = compile_val_fields (x, replaces) - ys = compile_val_fields (y, replaces) - eq = foldr1 (mk_and, map (mk_eq, xs, ys)) - return compile_accs (replaces, eq) - elif expr.is_op ('PAlignValid'): - [typ, p] = expr.vals - p = compile_accs (replaces, p) - assert typ.kind == 'Type' - return logic.mk_align_valid_ineq (('Type', typ.val), p) - elif expr.kind == 'Op': - vals = [compile_accs (replaces, v) for v in expr.vals] - return syntax.adjust_op_vals (expr, vals) - elif expr.kind == 'Symbol': - return mk_word32 (symbols[expr.name][0]) - else: - if expr.kind not in {'Var':True, 'ConstGlobal':True, - 'Num':True, 'Invent':True, 'Type':True}: - print expr - assert not 'field acc compiled' - return expr + r = compile_const_global_acc (expr) + if r: + return compile_accs (replaces, r) + + if expr.kind == 'Field': + expr = compile_field_acc (expr.field[0], expr.struct, replaces) + return compile_accs (replaces, expr) + elif expr.is_op ('ArrayIndex'): + [arr, n] = expr.vals + expr2 = compile_array_acc (n, arr, replaces, False) + if expr2: + return compile_accs (replaces, expr2) + arr = compile_accs (replaces, arr) + n = compile_accs (replaces, n) + expr2 = compile_array_acc (n, arr, replaces, False) + if expr2: + return compile_accs (replaces, expr2) + else: + return mk_arr_index (arr, n) + elif (expr.is_op ('MemUpdate') + and expr.vals[2].is_op ('MemAcc') + and expr.vals[2].vals[0] == expr.vals[0] + and expr.vals[2].vals[1] == expr.vals[1]): + # null memory copy. 
probably created by ops below + return compile_accs (replaces, expr.vals[0]) + elif (expr.is_op ('MemUpdate') + and expr.vals[2].kind == 'FieldUpd'): + [m, p, f_upd] = expr.vals + assert f_upd.typ.kind == 'Struct' + (typ, offs, _) = structs[f_upd.typ.name].fields[f_upd.field[0]] + assert f_upd.val.typ == typ + assert False + return compile_accs (replaces, + mk_memupd (mk_memupd (m, p, f_upd.struct), + mk_plus (p, mk_word32 (offs)), f_upd.val)) + elif (expr.is_op ('MemUpdate') + and expr.vals[2].typ.kind == 'Struct'): + [m, p, s_val] = expr.vals + struct = structs[s_val.typ.name] + for (nm, (typ, offs, _)) in struct.fields.iteritems (): + f = compile_field_acc (nm, s_val, replaces) + assert f.typ == typ + assert False + m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), f) + return compile_accs (replaces, m) + elif (expr.is_op ('MemUpdate') + and expr.vals[2].is_op ('ArrayUpdate')): + [m, p, arr_upd] = expr.vals + [arr, i, v] = arr_upd.vals + return compile_accs (replaces, + mk_memupd (mk_memupd (m, p, arr), + mk_arroffs (p, arr.typ, i), v)) + elif (expr.is_op ('MemUpdate') + and expr.vals[2].typ.kind == 'Array'): + [m, p, arr] = expr.vals + n = arr.typ.num + typ = arr.typ.el_typ + for i in range (n): + offs = i * typ.size () + assert offs == i or offs % 4 == 0 + e = compile_array_acc (i, arr, replaces) + assert False + m = mk_memupd (m, mk_plus (p, mk_word32 (offs)), e) + return compile_accs (replaces, m) + elif expr.is_op ('Equals') \ + and expr.vals[0].typ.kind in ['Struct', 'Array']: + [x, y] = expr.vals + assert x.typ == y.typ + xs = compile_val_fields (x, replaces) + ys = compile_val_fields (y, replaces) + eq = foldr1 (mk_and, map (mk_eq, xs, ys)) + return compile_accs (replaces, eq) + elif expr.is_op ('PAlignValid'): + [typ, p] = expr.vals + p = compile_accs (replaces, p) + assert typ.kind == 'Type' + return logic.mk_align_valid_ineq (('Type', typ.val), p) + elif expr.kind == 'Op': + vals = [compile_accs (replaces, v) for v in expr.vals] + return syntax.adjust_op_vals (expr, vals) + elif expr.kind == 'Symbol': + assert False + return mk_word32 (symbols[expr.name][0]) + + else: + if expr.kind not in {'Var':True, 'ConstGlobal':True, + 'Num':True, 'Invent':True, 'Type':True}: + print expr + assert not 'field acc compiled' + return expr def expand_arg_fields (replaces, args): - xs = [] - for arg in args: - if arg.typ.kind == 'Struct': - ys = [compile_field_acc (nm, arg, replaces) - for (nm, _) in structs[arg.typ.name].field_list] - xs.extend (expand_arg_fields (replaces, ys)) - elif arg.typ.kind == 'Array': - ys = [compile_array_acc (i, arg, replaces) - for i in range (arg.typ.num)] - xs.extend (expand_arg_fields (replaces, ys)) - else: - xs.append (compile_accs (replaces, arg)) - return xs + xs = [] + for arg in args: + if arg.typ.kind == 'Struct': + ys = [compile_field_acc (nm, arg, replaces) + for (nm, _) in structs[arg.typ.name].field_list] + xs.extend (expand_arg_fields (replaces, ys)) + elif arg.typ.kind == 'Array': + ys = [compile_array_acc (i, arg, replaces) + for i in range (arg.typ.num)] + xs.extend (expand_arg_fields (replaces, ys)) + else: + xs.append (compile_accs (replaces, arg)) + return xs def expand_lval_list (replaces, lvals): - xs = [] - for (nm, typ) in lvals: - if nm in replaces: - xs.extend (expand_lval_list (replaces, [(v_nm, typ) - for (f_nm, v_nm, typ) in replaces[nm]])) - else: - assert typ.kind not in ['Struct', 'Array'] - xs.append ((nm, typ)) - return xs + xs = [] + for (nm, typ) in lvals: + if nm in replaces: + xs.extend (expand_lval_list (replaces, [(v_nm, 
typ) + for (f_nm, v_nm, typ) in replaces[nm]])) + else: + assert typ.kind not in ['Struct', 'Array'] + xs.append ((nm, typ)) + return xs def mk_acc (idx, expr, replaces): - if logic.is_int (idx): - assert expr.typ.kind == 'Array' - return compile_array_acc (idx, expr, replaces) - else: - assert expr.typ.kind == 'Struct' - return compile_field_acc (idx, expr, replaces) + if logic.is_int (idx): + assert expr.typ.kind == 'Array' + return compile_array_acc (idx, expr, replaces) + else: + assert expr.typ.kind == 'Struct' + return compile_field_acc (idx, expr, replaces) def compile_upds (replaces, upds): - lvs = expand_lval_list (replaces, [lv for (lv, v) in upds]) - vs = expand_arg_fields (replaces, [v for (lv, v) in upds]) + lvs = expand_lval_list (replaces, [lv for (lv, v) in upds]) + vs = expand_arg_fields (replaces, [v for (lv, v) in upds]) - assert [typ for (nm, typ) in lvs] == map (get_expr_typ, vs), (lvs, vs) + assert [typ for (nm, typ) in lvs] == map (get_expr_typ, vs), (lvs, vs) - return [(lv, v) for (lv, v) in zip (lvs, vs) - if not v.is_var (lv)] + return [(lv, v) for (lv, v) in zip (lvs, vs) + if not v.is_var (lv)] def compile_struct_use (function): - trace ('Compiling in %s.' % function.name) - vs = get_vars (function) - max_node = max (function.nodes.keys () + [2]) - - visit_vs = vs.keys () - replaces = {} - while visit_vs: - v = visit_vs.pop () - typ = vs[v] - if typ.kind == 'Struct': - fields = structs[typ.name].field_list - elif typ.kind == 'Array': - fields = [(i, typ.el_typ) for i in range (typ.num)] - else: - continue - new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ) - for (nm, f_typ) in fields] - replaces[v] = new_vs - visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs]) - - for n in function.nodes: - node = function.nodes[n] - if node.kind == 'Basic': - node.upds = compile_upds (replaces, node.upds) - elif node.kind == 'Basic': - assert not node.lval[1].kind in ['Struct', 'Array'] - node.val = compile_accs (replaces, node.val) - elif node.kind == 'Call': - node.args = expand_arg_fields (replaces, node.args) - node.rets = expand_lval_list (replaces, node.rets) - elif node.kind == 'Cond': - node.cond = compile_accs (replaces, node.cond) - else: - assert not 'node kind understood' - - function.inputs = expand_lval_list (replaces, function.inputs) - function.outputs = expand_lval_list (replaces, function.outputs) - return len (replaces) == 0 + trace ('Compiling in %s.' 
% function.name) + vs = get_vars (function) + max_node = max (function.nodes.keys () + [2]) + + visit_vs = vs.keys () + replaces = {} + while visit_vs: + v = visit_vs.pop () + typ = vs[v] + if typ.kind == 'Struct': + fields = structs[typ.name].field_list + elif typ.kind == 'Array': + fields = [(i, typ.el_typ) for i in range (typ.num)] + else: + continue + new_vs = [(nm, fresh_name ('%s.%s' % (v, nm), vs, f_typ), f_typ) + for (nm, f_typ) in fields] + replaces[v] = new_vs + visit_vs.extend ([v_nm for (_, v_nm, _) in new_vs]) + + for n in function.nodes: + node = function.nodes[n] + if node.kind == 'Basic': + node.upds = compile_upds (replaces, node.upds) + elif node.kind == 'Basic': + assert not node.lval[1].kind in ['Struct', 'Array'] + node.val = compile_accs (replaces, node.val) + elif node.kind == 'Call': + node.args = expand_arg_fields (replaces, node.args) + node.rets = expand_lval_list (replaces, node.rets) + elif node.kind == 'Cond': + node.cond = compile_accs (replaces, node.cond) + else: + assert not 'node kind understood' + + function.inputs = expand_lval_list (replaces, function.inputs) + function.outputs = expand_lval_list (replaces, function.outputs) + return len (replaces) == 0 def check_compile (func): - for node in func.nodes.itervalues (): - vs = {} - get_node_vars (node, vs) - for (v_nm, typ) in vs.iteritems (): - if typ.kind == 'Struct': - print 'Failed to compile struct %s in %s' % (v_nm, func) - print node - assert not 'compiled' - if typ.kind == 'Array': - print 'Failed to compile array %s in %s' % (v_nm, func) - print node - assert not 'compiled' + for node in func.nodes.itervalues (): + vs = {} + get_node_vars (node, vs) + for (v_nm, typ) in vs.iteritems (): + if typ.kind == 'Struct': + print 'Failed to compile struct %s in %s' % (v_nm, func) + print node + assert not 'compiled' + if typ.kind == 'Array': + print 'Failed to compile array %s in %s' % (v_nm, func) + print node + assert not 'compiled' def subst_expr (expr): - if expr.kind == 'Symbol': - if expr.name in symbols: - return mk_word32 (symbols[expr.name][0]) - else: - return None - elif expr.is_op ('PAlignValid'): - [typ, p] = expr.vals - assert typ.kind == 'Type' - return logic.mk_align_valid_ineq (('Type', typ.val), p) - elif expr.kind in ['Op', 'Var', 'Num', 'Type']: - return None - else: - assert not 'expression simple-substitutable', expr + if expr.kind == 'Symbol': + if expr.name in symbols: + #FIXME: dubious assumption of native word size here + return syntax.arch.mk_word(symbols[expr.name][0]) + else: + return None + elif expr.is_op ('PAlignValid'): + [typ, p] = expr.vals + assert typ.kind == 'Type' + return logic.mk_align_valid_ineq (('Type', typ.val), p) + elif expr.kind in ['Op', 'Var', 'Num', 'Type']: + return None + else: + assert not 'expression simple-substitutable', expr def substitute_simple (func): - from syntax import Node - for (n, node) in func.nodes.items (): - func.nodes[n] = node.subst_exprs (subst_expr, - ss = set (['Symbol', 'PAlignValid'])) + from syntax import Node + for (n, node) in func.nodes.items (): + func.nodes[n] = node.subst_exprs (subst_expr, + ss = set (['Symbol', 'PAlignValid'])) def nodes_symbols (nodes): - symbols_needed = set() - def visitor (expr): - if expr.kind == 'Symbol': - symbols_needed.add(expr.name) - for node in nodes: - node.visit (lambda l: (), visitor) - return symbols_needed + symbols_needed = set() + def visitor (expr): + if expr.kind == 'Symbol': + symbols_needed.add(expr.name) + for node in nodes: + node.visit (lambda l: (), visitor) + return 
symbols_needed def missing_symbols (functions): - symbols_needed = nodes_symbols ([node - for func in functions.itervalues () - for node in func.nodes.itervalues ()]) - trouble = symbols_needed - set (symbols) - if trouble: - print ('Symbols missing for substitution: %r' % trouble) - return trouble + symbols_needed = nodes_symbols ([node + for func in functions.itervalues () + for node in func.nodes.itervalues ()]) + trouble = symbols_needed - set (symbols) + if trouble: + print ('Symbols missing for substitution: %r' % trouble) + return trouble def compile_funcs (functions): - missing_symbols (functions) - for (f, func) in functions.iteritems (): - substitute_simple (func) - check_compile (func) + missing_symbols (functions) + for (f, func) in functions.iteritems (): + substitute_simple (func) + check_compile (func) def combine_duplicate_nodes (nodes): - orig_size = len (nodes) - node_renames = {} - progress = True - while progress: - progress = False - node_names = {} - for (n, node) in nodes.items (): - if node in node_names: - node_renames[n] = node_names[node] - del nodes[n] - progress = True - else: - node_names[node] = n - - if not progress: - break - - for (n, node) in nodes.items (): - nodes[n] = rename_node_conts (node, node_renames) - - if len (nodes) < orig_size: - print 'Trimmed duplicates %d -> %d' % (orig_size, len (nodes)) - return node_renames + orig_size = len (nodes) + node_renames = {} + progress = True + while progress: + progress = False + node_names = {} + for (n, node) in nodes.items (): + if node in node_names: + node_renames[n] = node_names[node] + del nodes[n] + progress = True + else: + node_names[node] = n + + if not progress: + break + + for (n, node) in nodes.items (): + nodes[n] = rename_node_conts (node, node_renames) + + if len (nodes) < orig_size: + print 'Trimmed duplicates %d -> %d' % (orig_size, len (nodes)) + return node_renames def rename_node_conts (node, renames): - new_conts = [renames.get (c, c) for c in node.get_conts ()] - return Node (node.kind, new_conts, node.get_args ()) + new_conts = [renames.get (c, c) for c in node.get_conts ()] + return Node (node.kind, new_conts, node.get_args ()) def recommended_rename (s): - bits = s.split ('.') - if len (bits) != 2: - return s - if all ([x in '0123456789' for x in bits[1]]): - return bits[0] - else: - return s + bits = s.split ('.') + if len (bits) != 2: + return s + if all ([x in '0123456789' for x in bits[1]]): + return bits[0] + else: + return s def rename_vars (function): - preds = logic.compute_preds (function.nodes) - var_deps = logic.compute_var_deps (function.nodes, - lambda x: function.outputs, preds) - - vs = set () - dont_rename_vs = set () - for n in var_deps: - rev_renames = {} - for (v, t) in var_deps[n]: - v2 = recommended_rename (v) - rev_renames.setdefault (v2, []) - rev_renames[v2].append ((v, t)) - vs.add ((v, t)) - for (v2, vlist) in rev_renames.iteritems (): - if len (vlist) > 1: - dont_rename_vs.update (vlist) - - renames = dict ([(v, recommended_rename (v)) for (v, t) in vs - if (v, t) not in dont_rename_vs]) - - f = function - f.inputs = [(renames.get (v, v), t) for (v, t) in f.inputs] - f.outputs = [(renames.get (v, v), t) for (v, t) in f.outputs] - for n in f.nodes: - f.nodes[n] = syntax.copy_rename (f.nodes[n], (renames, {})) + preds = logic.compute_preds (function.nodes) + var_deps = logic.compute_var_deps (function.nodes, + lambda x: function.outputs, preds) + + vs = set () + dont_rename_vs = set () + for n in var_deps: + rev_renames = {} + for (v, t) in var_deps[n]: + 
v2 = recommended_rename (v) + rev_renames.setdefault (v2, []) + rev_renames[v2].append ((v, t)) + vs.add ((v, t)) + for (v2, vlist) in rev_renames.iteritems (): + if len (vlist) > 1: + dont_rename_vs.update (vlist) + + renames = dict ([(v, recommended_rename (v)) for (v, t) in vs + if (v, t) not in dont_rename_vs]) + + f = function + f.inputs = [(renames.get (v, v), t) for (v, t) in f.inputs] + f.outputs = [(renames.get (v, v), t) for (v, t) in f.outputs] + for n in f.nodes: + f.nodes[n] = syntax.copy_rename (f.nodes[n], (renames, {})) def rename_and_combine_function_duplicates (functions): - for (f, fun) in functions.iteritems (): - rename_vars (fun) - renames = combine_duplicate_nodes (fun.nodes) - fun.entry = renames.get (fun.entry, fun.entry) + for (f, fun) in functions.iteritems (): + rename_vars (fun) + renames = combine_duplicate_nodes (fun.nodes) + fun.entry = renames.get (fun.entry, fun.entry) diff --git a/rep_graph.py b/rep_graph.py index 41b0a0bf..6487955a 100644 --- a/rep_graph.py +++ b/rep_graph.py @@ -4,9 +4,11 @@ # SPDX-License-Identifier: BSD-2-Clause # +from typed_commons import PersistableModel from solver import Solver, merge_envs_pcs, smt_expr, mk_smt_expr, to_smt_expr from syntax import (true_term, false_term, boolT, mk_and, mk_not, mk_implies, - builtinTs, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word32, mk_var) + builtinTs, word64T, word32T, word8T, foldr1, mk_eq, mk_plus, mk_word64, + mk_word32, mk_var) import syntax import logic import solver @@ -16,1244 +18,1264 @@ import target_objects import problem +import itertools class VisitCount: - """Used to represent a target number of visits to a split point. - Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2), - or a list of options.""" - def __init__ (self, kind, value): - self.kind = kind - self.is_visit_count = True - if kind == 'Number': - self.n = value - elif kind == 'Offset': - self.n = value - elif kind == 'Options': - self.opts = tuple (value) - for opt in self.opts: - assert opt.kind in ['Number', 'Offset'] - else: - assert not 'VisitCount type understood' - - def __hash__ (self): - if self.kind == 'Options': - return hash (self.opts) - else: - return hash (self.kind) + self.n - - def __eq__ (self, other): - if not other: - return False - if self.kind == 'Options': - return (other.kind == 'Options' - and self.opts == other.opts) - else: - return self.kind == other.kind and self.n == other.n - - def __neq__ (self, other): - if not other: - return True - return not (self == other) - - def __str__ (self): - if self.kind == 'Number': - return str (self.n) - elif self.kind == 'Offset': - return 'i+%s' % self.n - elif self.kind == 'Options': - return '_'.join (map (str, self.opts)) - - def __repr__ (self): - (ns, os) = self.get_opts () - return 'vc_options (%r, %r)' % (ns, os) - - def get_opts (self): - if self.kind == 'Options': - opts = self.opts - else: - opts = [self] - ns = [vc.n for vc in opts if vc.kind == 'Number'] - os = [vc.n for vc in opts if vc.kind == 'Offset'] - return (ns, os) - - def serialise (self, ss): - ss.append ('VC') - (ns, os) = self.get_opts () - ss.append ('%d' % len (ns)) - ss.extend (['%d' % n for n in ns]) - ss.append ('%d' % len (os)) - ss.extend (['%d' % n for n in os]) - - def incr (self, incr): - if self.kind in ['Number', 'Offset']: - n = self.n + incr - if n < 0: - return None - return VisitCount (self.kind, n) - elif self.kind == 'Options': - opts = [vc.incr (incr) for vc in self.opts] - opts = [opt for opt in opts if opt] - if opts == []: - return None - 
return mk_vc_opts (opts) - else: - assert not 'VisitCount type understood' - - def has_zero (self): - if self.kind == 'Options': - return bool ([vc for vc in self.opts - if vc.has_zero ()]) - else: - return self.kind == 'Number' and self.n == 0 + """Used to represent a target number of visits to a split point. + Options include a number (0, 1, 2), a symbolic offset (i + 1, i + 2), + or a list of options.""" + def __init__ (self, kind, value): + self.kind = kind + self.is_visit_count = True + if kind == 'Number': + self.n = value + elif kind == 'Offset': + self.n = value + elif kind == 'Options': + self.opts = tuple (value) + else: + assert not 'VisitCount type understood' + self.assert_invs() + + def assert_invs(self): + """ + Check the structural invariants of the type: + - `Options` values contain more than one possible + option, so `VisitCount('Options',[x])` is flattened + to just `x`. + - `Options` values contain a list of 'Number' values, + followed by a list of 'Offset' values, in which + both sublists are sorted. + """ + if self.kind in ['Number', 'Offset']: + return + assert self.kind == 'Options' + assert len(self.opts) > 1 + structure = [k for k,g in itertools.groupby(self.opts, lambda x: x.kind)] + assert structure in [['Number','Offset'],['Number'],['Offset'],[]] + valueLists = [[x.n for x in g] for k,g in itertools.groupby(self.opts, lambda x: x.kind)] + for v in valueLists: + assert v == sorted(v) + + def __hash__ (self): + (ns, os) = self.get_opts() + return hash((tuple(ns),tuple(os))) + + def __eq__ (self, other): + if not isinstance(other, VisitCount): + return False + return self.get_opts() == other.get_opts() + + def __neq__ (self, other): + if not other: + return True + return not (self == other) + + def __cmp__(self, other): + " returning a negative value if self < other , positive if self > other , and zero if they were equal." 
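# Equality, hashing and ordering all go via the canonical (numbers, offsets)
# pair from get_opts (); for example, using the vc_* helpers defined after
# this class:
#   vc_upto (3).get_opts ()              == ([0, 1, 2], [])
#   vc_options ([0, 1], [0]).get_opts () == ([0, 1], [0])
#   vc_num (0).incr (-1) is None   (visit counts never become negative)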
+ return cmp(self.get_opts(), other.get_opts()) + + def __str__ (self): + if self.kind == 'Number': + return str (self.n) + elif self.kind == 'Offset': + return 'i+%s' % self.n + elif self.kind == 'Options': + return '_'.join (map (str, self.opts)) + + def __repr__ (self): + (ns, os) = self.get_opts () + return 'vc_options (%r, %r)' % (ns, os) + + def get_opts (self): + if self.kind == 'Options': + opts = self.opts + else: + opts = [self] + ns = [vc.n for vc in opts if vc.kind == 'Number'] + os = [vc.n for vc in opts if vc.kind == 'Offset'] + return (ns, os) + + def serialise (self, ss): + ss.append ('VC') + (ns, os) = self.get_opts () + ss.append ('%d' % len (ns)) + ss.extend (['%d' % n for n in ns]) + ss.append ('%d' % len (os)) + ss.extend (['%d' % n for n in os]) + + def incr (self, incr): + if self.kind in ['Number', 'Offset']: + n = self.n + incr + if n < 0: + return None + return VisitCount (self.kind, n) + elif self.kind == 'Options': + opts = [vc.incr (incr) for vc in self.opts] + opts = [opt for opt in opts if opt] + if opts == []: + return None + return mk_vc_opts (opts) + else: + assert not 'VisitCount type understood' + + def has_zero (self): + if self.kind == 'Options': + return bool ([vc for vc in self.opts + if vc.has_zero ()]) + else: + return self.kind == 'Number' and self.n == 0 def mk_vc_opts (opts): - if len (opts) == 1: - return opts[0] - else: - return VisitCount ('Options', opts) + if len (opts) == 1: + return opts[0] + else: + return VisitCount ('Options', opts) def vc_options (nums, offsets): - return mk_vc_opts (map (vc_num, nums) + map (vc_offs, offsets)) + return mk_vc_opts (map (vc_num, nums) + map (vc_offs, offsets)) def vc_num (n): - return VisitCount ('Number', n) + return VisitCount ('Number', n) def vc_upto (n): - return mk_vc_opts (map (vc_num, range (n))) + return mk_vc_opts (map (vc_num, range (n))) def vc_offs (n): - return VisitCount ('Offset', n) + return VisitCount ('Offset', n) def vc_offset_upto (n): - return mk_vc_opts (map (vc_offs, range (n))) + return mk_vc_opts (map (vc_offs, range (n))) def vc_double_range (n, m): - return mk_vc_opts (map (vc_num, range (n)) + map (vc_offs, range (m))) + return mk_vc_opts (map (vc_num, range (n)) + map (vc_offs, range (m))) class InlineEvent(Exception): - pass + pass class Hyp: - """Used to represent a proposition about path conditions or data at - various points in execution.""" - - def __init__ (self, kind, arg1, arg2, induct = None): - self.kind = kind - if kind == 'PCImp': - self.pcs = [arg1, arg2] - elif kind == 'Eq': - self.vals = [arg1, arg2] - self.induct = induct - elif kind == 'EqIfAt': - self.vals = [arg1, arg2] - self.induct = induct - else: - assert not 'hyp kind understood' - - def __repr__ (self): - if self.kind == 'PCImp': - vals = map (repr, self.pcs) - elif self.kind in ['Eq', 'EqIfAt']: - vals = map (repr, self.vals) - if self.induct: - vals += [repr (self.induct)] - else: - assert not 'hyp kind understood' - return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals)) - - def hyp_tuple (self): - if self.kind == 'PCImp': - return ('PCImp', self.pcs[0], self.pcs[1]) - elif self.kind in ['Eq', 'EqIfAt']: - return (self.kind, self.vals[0], - self.vals[1], self.induct) - else: - assert not 'hyp kind understood' - - def __hash__ (self): - return hash (self.hyp_tuple ()) - - def __ne__ (self, other): - return not other or not (self == other) - - def __cmp__ (self, other): - return cmp (self.hyp_tuple (), other.hyp_tuple ()) - - def visits (self): - if self.kind == 'PCImp': - return [vis for vis in 
self.pcs - if vis[0] != 'Bool'] - elif self.kind in ['Eq', 'EqIfAt']: - return [vis for (_, vis) in self.vals] - else: - assert not 'hyp kind understood' - - def get_vals (self): - if self.kind == 'PCImp': - return [] - else: - return [val for (val, _) in self.vals] - - def serialise_visit (self, (n, restrs), ss): - ss.append ('%s' % n) - ss.append ('%d' % len (restrs)) - for (n2, vc) in restrs: - ss.append ('%d' % n2) - vc.serialise (ss) - - def serialise_pc (self, pc, ss): - if pc[0] == 'Bool' and pc[1] == true_term: - ss.append ('True') - elif pc[0] == 'Bool' and pc[1] == false_term: - ss.append ('False') - else: - ss.append ('PC') - serialise_visit (pc[0], ss) - ss.append (pc[1]) - - def serialise_hyp (self, ss): - if self.kind == 'PCImp': - (visit1, visit2) = self.pcs - ss.append ('PCImp') - self.serialise_pc (visit1, ss) - self.serialise_pc (visit2, ss) - elif self.kind in ['Eq', 'EqIfAt']: - assert len (self.vals) == 2 - ss.extend (self.kind) - for (exp, visit) in self.vals: - exp.serialise (ss) - self.serialise_visit (visit, ss) - if induct: - ss.append ('%d' % induct[0]) - ss.append ('%d' % induct[1]) - else: - ss.extend (['None', 'None']) - else: - assert not 'hyp kind understood' - - def interpret (self, rep): - if self.kind == 'PCImp': - ((visit1, tag1), (visit2, tag2)) = self.pcs - if visit1 == 'Bool': - pc1 = tag1 - else: - pc1 = rep.get_pc (visit1, tag = tag1) - if visit2 == 'Bool': - pc2 = tag2 - else: - pc2 = rep.get_pc (visit2, tag = tag2) - return mk_implies (pc1, pc2) - elif self.kind in ['Eq', 'EqIfAt']: - [(x, xvis), (y, yvis)] = self.vals - if self.induct: - v = rep.get_induct_var (self.induct) - x = subst_induct (x, v) - y = subst_induct (y, v) - x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1]) - y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1]) - if x_pc_env == None or y_pc_env == None: - if self.kind == 'EqIfAt': - return syntax.true_term - else: - return syntax.false_term - ((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env) - eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv) - if self.kind == 'EqIfAt': - x_pc = rep.get_pc (xvis[0], tag = xvis[1]) - y_pc = rep.get_pc (yvis[0], tag = yvis[1]) - return syntax.mk_n_implies ([x_pc, y_pc], eq) - else: - return eq - else: - assert not 'hypothesis type understood' + """Used to represent a proposition about path conditions or data at + various points in execution.""" + + def __init__ (self, kind, arg1, arg2, induct = None): + self.kind = kind + if kind == 'PCImp': + self.pcs = [arg1, arg2] + elif kind == 'Eq': + self.vals = [arg1, arg2] + self.induct = induct + elif kind == 'EqIfAt': + self.vals = [arg1, arg2] + self.induct = induct + else: + assert not 'hyp kind understood' + + def __repr__ (self): + if self.kind == 'PCImp': + vals = map (repr, self.pcs) + elif self.kind in ['Eq', 'EqIfAt']: + vals = map (repr, self.vals) + if self.induct: + vals += [repr (self.induct)] + else: + assert not 'hyp kind understood' + return 'Hyp (%r, %s)' % (self.kind, ', '.join (vals)) + + def hyp_tuple (self): + if self.kind == 'PCImp': + return ('PCImp', self.pcs[0], self.pcs[1]) + elif self.kind in ['Eq', 'EqIfAt']: + return (self.kind, self.vals[0], + self.vals[1], self.induct) + else: + assert not 'hyp kind understood' + + def __hash__ (self): + return hash (self.hyp_tuple ()) + + def __ne__ (self, other): + return not other or not (self == other) + + def __cmp__ (self, other): + return cmp (self.hyp_tuple (), other.hyp_tuple ()) + + def visits (self): + if self.kind == 'PCImp': + return [vis for vis in self.pcs + if 
vis[0] != 'Bool'] + elif self.kind in ['Eq', 'EqIfAt']: + return [vis for (_, vis) in self.vals] + else: + assert not 'hyp kind understood' + + def get_vals (self): + if self.kind == 'PCImp': + return [] + else: + return [val for (val, _) in self.vals] + + def serialise_visit (self, (n, restrs), ss): + ss.append ('%s' % n) + ss.append ('%d' % len (restrs)) + for (n2, vc) in restrs: + ss.append ('%d' % n2) + vc.serialise (ss) + + def serialise_pc (self, pc, ss): + if pc[0] == 'Bool' and pc[1] == true_term: + ss.append ('True') + elif pc[0] == 'Bool' and pc[1] == false_term: + ss.append ('False') + else: + ss.append ('PC') + serialise_visit (pc[0], ss) + ss.append (pc[1]) + + def serialise_hyp (self, ss): + if self.kind == 'PCImp': + (visit1, visit2) = self.pcs + ss.append ('PCImp') + self.serialise_pc (visit1, ss) + self.serialise_pc (visit2, ss) + elif self.kind in ['Eq', 'EqIfAt']: + assert len (self.vals) == 2 + ss.extend (self.kind) + for (exp, visit) in self.vals: + exp.serialise (ss) + self.serialise_visit (visit, ss) + if induct: + ss.append ('%d' % induct[0]) + ss.append ('%d' % induct[1]) + else: + ss.extend (['None', 'None']) + else: + assert not 'hyp kind understood' + + def interpret (self, rep): + if self.kind == 'PCImp': + ((visit1, tag1), (visit2, tag2)) = self.pcs + if visit1 == 'Bool': + pc1 = tag1 + else: + pc1 = rep.get_pc (visit1, tag = tag1) + if visit2 == 'Bool': + pc2 = tag2 + else: + pc2 = rep.get_pc (visit2, tag = tag2) + return mk_implies (pc1, pc2) + elif self.kind in ['Eq', 'EqIfAt']: + [(x, xvis), (y, yvis)] = self.vals + if self.induct: + v = rep.get_induct_var (self.induct) + x = subst_induct (x, v) + y = subst_induct (y, v) + x_pc_env = rep.get_node_pc_env (xvis[0], tag = xvis[1]) + y_pc_env = rep.get_node_pc_env (yvis[0], tag = yvis[1]) + if x_pc_env == None or y_pc_env == None: + if self.kind == 'EqIfAt': + return syntax.true_term + else: + return syntax.false_term + ((_, xenv), (_, yenv)) = (x_pc_env, y_pc_env) + eq = inst_eq_with_envs ((x, xenv), (y, yenv), rep.solv) + if self.kind == 'EqIfAt': + x_pc = rep.get_pc (xvis[0], tag = xvis[1]) + y_pc = rep.get_pc (yvis[0], tag = yvis[1]) + return syntax.mk_n_implies ([x_pc, y_pc], eq) + else: + return eq + else: + assert not 'hypothesis type understood' def check_vis_is_vis (((n, vc), tag)): - assert vc[:0] == (), vc + assert vc[:0] == (), vc def eq_hyp (lhs, rhs, induct = None, use_if_at = False): - check_vis_is_vis (lhs[1]) - check_vis_is_vis (rhs[1]) - kind = 'Eq' - if use_if_at: - kind = 'EqIfAt' - return Hyp (kind, lhs, rhs, induct = induct) + check_vis_is_vis (lhs[1]) + check_vis_is_vis (rhs[1]) + kind = 'Eq' + if use_if_at: + kind = 'EqIfAt' + return Hyp (kind, lhs, rhs, induct = induct) def true_if_at_hyp (expr, vis, induct = None): - check_vis_is_vis (vis) - return Hyp ('EqIfAt', (expr, vis), (true_term, vis), - induct = induct) + check_vis_is_vis (vis) + return Hyp ('EqIfAt', (expr, vis), (true_term, vis), + induct = induct) def pc_true_hyp (vis): - check_vis_is_vis (vis) - return Hyp ('PCImp', ('Bool', true_term), vis) + check_vis_is_vis (vis) + return Hyp ('PCImp', ('Bool', true_term), vis) def pc_false_hyp (vis): - check_vis_is_vis (vis) - return Hyp ('PCImp', vis, ('Bool', false_term)) + check_vis_is_vis (vis) + return Hyp ('PCImp', vis, ('Bool', false_term)) def pc_triv_hyp (vis): - check_vis_is_vis (vis) - return Hyp ('PCImp', vis, vis) + check_vis_is_vis (vis) + return Hyp ('PCImp', vis, vis) class GraphSlice: - """Used to represent a slice of potential execution in a graph where - looping is 
limited to certain specific examples. For instance, we - might say that execution through node n will be represented only - by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The - variable state at visits 4 and i + 2 will be calculated but no - further execution will be done.""" - - def __init__ (self, p, solv, inliner = None, fast = False): - self.p = p - self.solv = solv - self.inp_envs = {} - self.mem_calls = {} - self.add_input_envs () - - self.node_pc_envs = {} - self.node_pc_env_order = [] - self.arc_pc_envs = {} - self.inliner = inliner - self.funcs = {} - self.pc_env_requests = set () - self.fast = fast - self.induct_var_env = {} - self.contractions = {} - - self.local_defs_unsat = False - self.use_known_eqs = True - - self.avail_hyps = set () - self.used_hyps = set () - - def add_input_envs (self): - for (entry, _, _, args) in self.p.entries: - self.inp_envs[entry] = mk_inp_env (entry, args, self) - - def get_reachable (self, split, n): - return self.p.is_reachable_from (split, n) - - class TooGeneral (Exception): - def __init__ (self, split): - self.split = split - - def get_tag_vcount (self, (n, vcount), tag): - if tag == None: - tag = self.p.node_tags[n][0] - vcount_r = [(split, count, self.get_reachable (split, n)) - for (split, count) in vcount - if self.p.node_tags[split][0] == tag] - for (split, count, r) in vcount_r: - if not r and not count.has_zero (): - return (tag, None) - assert count.is_visit_count - vcount = [(s, c) for (s, c, r) in vcount_r if r] - vcount = tuple (sorted (vcount)) - - loop_id = self.p.loop_id (n) - if loop_id != None: - for (split, visits) in vcount: - if (self.p.loop_id (split) == loop_id - and visits.kind == 'Options'): - raise self.TooGeneral (split) - - return (tag, vcount) - - def get_node_pc_env (self, (n, vcount), tag = None, request = True): - tag, vcount = self.get_tag_vcount ((n, vcount), tag) - if vcount == None: - return None - - if (tag, n, vcount) in self.node_pc_envs: - return self.node_pc_envs[(tag, n, vcount)] - - if request: - self.pc_env_requests.add (((n, vcount), tag)) - - self.warm_pc_env_cache ((n, vcount), tag) - - pc_env = self.get_node_pc_env_raw ((n, vcount), tag) - if pc_env: - pc_env = self.apply_known_eqs_pc_env ((n, vcount), - tag, pc_env) - - assert not (tag, n, vcount) in self.node_pc_envs - self.node_pc_envs[(tag, n, vcount)] = pc_env - if pc_env: - self.node_pc_env_order.append ((tag, n, vcount)) - - return pc_env - - def warm_pc_env_cache (self, n_vc, tag): - 'this is to avoid recursion limits and spot bugs' - prev_chain = [] - for i in range (5000): - prevs = self.prevs (n_vc) - try: - prevs = [p for p in prevs - if (tag, p[0], p[1]) - not in self.node_pc_envs - if self.get_tag_vcount (p, None) - == (tag, n_vc[1])] - except self.TooGeneral: - break - if not prevs: - break - n_vc = prevs[0] - prev_chain.append(n_vc) - if not (len (prev_chain) < 5000): - printout ([n for (n, vc) in prev_chain]) - assert len (prev_chain) < 5000, (prev_chain[:10], - prev_chain[-10:]) - - prev_chain.reverse () - for n_vc in prev_chain: - self.get_node_pc_env (n_vc, tag, request = False) - - def get_loop_pc_env (self, split, vcount): - vcount2 = dict (vcount) - vcount2[split] = vc_num (0) - vcount2 = tuple (sorted (vcount2.items ())) - prev_pc_env = self.get_node_pc_env ((split, vcount2)) - if prev_pc_env == None: - return None - (_, prev_env) = prev_pc_env - mem_calls = self.scan_mem_calls (prev_env) - mem_calls = self.add_loop_mem_calls (split, mem_calls) - def av (nm, typ, mem_name = None): - nm2 = '%s_loop_at_%s' % (nm, 
split) - return self.add_var (nm2, typ, - mem_name = mem_name, mem_calls = mem_calls) - env = {} - consts = set () - for (nm, typ) in prev_env: - check_const = self.fast or (typ in - [builtinTs['HTD'], builtinTs['Dom']]) - if check_const and self.is_synt_const (nm, typ, split): - env[(nm, typ)] = prev_env[(nm, typ)] - consts.add ((nm, typ)) - else: - env[(nm, typ)] = av (nm + '_after', typ, - ('Loop', prev_env[(nm, typ)])) - for (nm, typ) in prev_env: - if (nm, typ) in consts: - continue - z = self.var_rep_request ((nm, typ), 'Loop', - (split, vcount), env) - if z: - env[(nm, typ)] = z - - pc = mk_smt_expr (av ('pc_of', boolT), boolT) - if self.fast: - imp = syntax.mk_implies (pc, prev_pc_env[0]) - self.solv.assert_fact (imp, prev_env, - unsat_tag = ('LoopPCImp', split)) - - return (pc, env) - - def is_synt_const (self, nm, typ, split): - """check if a variable at a split point is a syntactic constant - which is always unmodified by the loop. - we allow cases where a variable is renamed and renamed back - during the loop (this often happens because of inlining). - the check is done by depth-first-search backward through the - graph looking for a source of a variant value.""" - loop = self.p.loop_id (split) - if problem.has_inner_loop (self.p, split): - return False - loop_set = set (self.p.loop_body (split)) - - orig_nm = nm - safe = set ([(orig_nm, split)]) - first_step = True - visit = [] - count = 0 - while first_step or visit: - if first_step: - (nm, n) = (orig_nm, split) - first_step = False - else: - (nm, n) = visit.pop () - if (nm, n) in safe: - continue - elif n == split: - return False - new_nm = nm - node = self.p.nodes[n] - if node.kind == 'Call': - if (nm, typ) not in node.rets: - pass - elif self.fast_const_ret (n, nm, typ): - pass - else: - return False - elif node.kind == 'Basic': - upds = [arg for (lv, arg) in node.upds - if lv == (nm, typ)] - if [v for v in upds if v.kind != 'Var']: - return False - if upds: - new_nm = upds[0].name - preds = [(new_nm, n2) for n2 in self.p.preds[n] - if n2 in loop_set] - unknowns = [p for p in preds if p not in safe] - if unknowns: - visit.extend ([(nm, n)] + unknowns) - else: - safe.add ((nm, n)) - count += 1 - if count % 100000 == 0: - trace ('is_synt_const: %d iterations' % count) - trace ('visit length %d' % len (visit)) - trace ('visit tail %s' % visit[-20:]) - return True - - def fast_const_ret (self, n, nm, typ): - """determine if we can heuristically consider this return - value to be the same as an input. this is known for some - function returns, e.g. memory. 
- this is important for heuristic "fast" analysis.""" - if not self.fast: - return False - node = self.p.nodes[n] - assert node.kind == 'Call' - for hook in target_objects.hooks ('rep_unsafe_const_ret'): - if hook (node, nm, typ): - return True - return False - - def get_node_pc_env_raw (self, (n, vcount), tag): - if n in self.inp_envs: - return (true_term, self.inp_envs[n]) - - for (split, count) in vcount: - if split == n and count == vc_offs (0): - return self.get_loop_pc_env (split, vcount) - - pc_envs = [pc_env for n_prev in self.p.preds[n] - if self.p.node_tags[n_prev][0] == tag - for pc_env in self.get_arc_pc_envs (n_prev, - (n, vcount))] - - pc_envs = [pc_env for pc_env in pc_envs if pc_env] - if pc_envs == []: - return None - - if n == 'Err': - # we'll never care about variable values here - # and there are sometimes a LOT of arcs to Err - # so we save a lot of merge effort - pc_envs = [(to_smt_expr (pc, env, self.solv), {}) - for (pc, env) in pc_envs] - - (pc, env, large) = merge_envs_pcs (pc_envs, self.solv) - - if pc.kind != 'SMTExpr': - name = self.path_cond_name ((n, vcount), tag) - name = self.solv.add_def (name, pc, env) - pc = mk_smt_expr (name, boolT) - - for (nm, typ) in env: - if len (env[(nm, typ)]) > 80: - env[(nm, typ)] = self.contract (nm, (n, vcount), - env[(nm, typ)], typ) - - return (pc, env) - - def contract (self, name, n_vc, val, typ): - if val in self.contractions: - return self.contractions[val] - - name = self.local_name_before (name, n_vc) - name = self.solv.add_def (name, mk_smt_expr (val, typ), {}) - - self.contractions[val] = name - return name - - def get_arc_pc_envs (self, n, n_vc2): - try: - prevs = [n_vc for n_vc in self.prevs (n_vc2) - if n_vc[0] == n] - assert len (prevs) <= 1 - return [self.get_arc_pc_env (n_vc, n_vc2) - for n_vc in prevs] - except self.TooGeneral, e: - # consider specialisations of the target - specs = self.specialise (n_vc2, e.split) - specs = [(n_vc2[0], spec) for spec in specs] - return [pc_env for spec in specs - for pc_env in self.get_arc_pc_envs (n, spec)] - - def get_arc_pc_env (self, (n, vcount), n2): - tag, vcount = self.get_tag_vcount ((n, vcount), None) - - if vcount == None: - return None - - assert self.is_cont ((n, vcount), n2), ((n, vcount), - n2, self.p.nodes[n].get_conts ()) - - if (n, vcount) in self.arc_pc_envs: - return self.arc_pc_envs[(n, vcount)].get (n2[0]) - - if self.get_node_pc_env ((n, vcount), request = False) == None: - return None - - arcs = self.emit_node ((n, vcount)) - self.post_emit_node_hooks ((n, vcount)) - arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs]) - - self.arc_pc_envs[(n, vcount)] = arcs - return arcs.get (n2[0]) - - def add_local_def (self, n, vname, name, val, env): - if self.local_defs_unsat: - smt_name = self.solv.add_var (name, val.typ) - eq = mk_eq (mk_smt_expr (smt_name, val.typ), val) - self.solv.assert_fact (eq, env, unsat_tag - = ('Def', n, vname)) - else: - smt_name = self.solv.add_def (name, val, env) - return smt_name - - def add_var (self, name, typ, mem_name = None, mem_calls = None): - r = self.solv.add_var_restr (name, typ, mem_name = mem_name) - if typ == syntax.builtinTs['Mem']: - r_x = solver.parse_s_expression (r) - self.mem_calls[r_x] = mem_calls - return r - - def var_rep_request (self, (nm, typ), kind, n_vc, env): - assert type (n_vc[0]) != str - for hook in target_objects.hooks ('problem_var_rep'): - z = hook (self.p, (nm, typ), kind, n_vc[0]) - if z == None: - continue - if z[0] == 'SplitMem': - assert typ == builtinTs['Mem'] - (_, addr) = z - 
addr = smt_expr (addr, env, self.solv) - name = '%s_for_%s' % (nm, - self.node_count_name (n_vc)) - return self.solv.add_split_mem_var (addr, name, - typ, mem_name = 'SplitMemNonsense') - else: - assert z == None - - def emit_node (self, n): - (pc, env) = self.get_node_pc_env (n, request = False) - tag = self.p.node_tags[n[0]][0] - app_eqs = self.apply_known_eqs_tm (n, tag) - # node = logic.simplify_node_elementary (self.p.nodes[n[0]]) - # whether to ignore unreachable Cond arcs seems to be a huge - # dilemma. if we ignore them, some reachable sites become - # unreachable and we can't interpret all hyps - # if we don't ignore them, the variable set disagrees with - # var_deps and so the abstracted loop pc/env may not be - # sufficient and we get EnvMiss again. I don't really know - # what to do about this corner case. - node = self.p.nodes[n[0]] - env = dict (env) - - if node.kind == 'Call': - self.try_inline (n[0], pc, env) - - if pc == false_term: - return [(c, false_term, {}) for c in node.get_conts()] - elif node.kind == 'Cond' and node.left == node.right: - return [(node.left, pc, env)] - elif node.kind == 'Cond' and node.cond == true_term: - return [(node.left, pc, env), - (node.right, false_term, env)] - elif node.kind == 'Basic': - upds = [] - for (lv, v) in node.upds: - if v.kind == 'Var': - upds.append ((lv, env[(v.name, v.typ)])) - else: - name = self.local_name (lv[0], n) - v = app_eqs (v) - vname = self.add_local_def (n, - ('Var', lv), name, v, env) - upds.append ((lv, vname)) - for (lv, v) in upds: - env[lv] = v - return [(node.cont, pc, env)] - elif node.kind == 'Cond': - name = self.cond_name (n) - cond = self.p.fresh_var (name, boolT) - env[(cond.name, boolT)] = self.add_local_def (n, - 'Cond', name, app_eqs (node.cond), env) - lpc = mk_and (cond, pc) - rpc = mk_and (mk_not (cond), pc) - return [(node.left, lpc, env), (node.right, rpc, env)] - elif node.kind == 'Call': - nm = self.success_name (node.fname, n) - success = self.solv.add_var (nm, boolT) - success = mk_smt_expr (success, boolT) - fun = functions[node.fname] - ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv)) - for ((x, typ), arg) in azip (fun.inputs, node.args)]) - mem_name = None - for (x, typ) in reversed (fun.inputs): - if typ == builtinTs['Mem']: - inp_mem = ins[(x, typ)] - mem_name = (node.fname, inp_mem) - mem_calls = self.scan_mem_calls (ins) - mem_calls = self.add_mem_call (node.fname, mem_calls) - outs = {} - for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs): - assert typ2 == typ - if self.fast_const_ret (n[0], x, typ): - outs[(y, typ2)] = env [(x, typ)] - continue - name = self.local_name (x, n) - env[(x, typ)] = self.add_var (name, typ, - mem_name = mem_name, - mem_calls = mem_calls) - outs[(y, typ2)] = env[(x, typ)] - for ((x, typ), (y, _)) in azip (node.rets, fun.outputs): - z = self.var_rep_request ((x, typ), - 'Call', n, env) - if z != None: - env[(x, typ)] = z - outs[(y, typ)] = z - self.add_func (node.fname, ins, outs, success, n) - return [(node.cont, pc, env)] - else: - assert not 'node kind understood' - - def post_emit_node_hooks (self, (n, vcount)): - for hook in target_objects.hooks ('post_emit_node'): - hook (self, (n, vcount)) - - def fetch_known_eqs (self, n_vc, tag): - if not self.use_known_eqs: - return None - eqs = self.p.known_eqs.get ((n_vc, tag)) - if eqs == None: - return None - avail = [] - for (x, n_vc_y, tag_y, y, hyps) in eqs: - if hyps <= self.avail_hyps: - (_, env) = self.get_node_pc_env (n_vc_y, tag_y) - avail.append ((x, smt_expr (y, env, 
self.solv))) - self.used_hyps.update (hyps) - if avail: - return avail - return None - - def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)): - eqs = self.fetch_known_eqs (n_vc, tag) - if eqs == None: - return (pc, env) - env = dict (env) - for (x, sx) in eqs: - if x.kind == 'Var': - cur_rhs = env[x.name] - for y in env: - if env[y] == cur_rhs: - trace ('substituted %s at %s.' % (y, n_vc)) - env[y] = sx - return (pc, env) - - def apply_known_eqs_tm (self, n_vc, tag): - eqs = self.fetch_known_eqs (n_vc, tag) - if eqs == None: - return lambda x: x - eqs = dict ([(x, mk_smt_expr (sexpr, x.typ)) - for (x, sexpr) in eqs]) - return lambda tm: logic.recursive_term_subst (eqs, tm) - - def rebuild (self, solv = None): - requests = self.pc_env_requests - - self.node_pc_env_order = [] - self.node_pc_envs = {} - self.arc_pc_envs = {} - self.funcs = {} - self.pc_env_requests = set () - self.induct_var_env = {} - self.contractions = {} - - if not solv: - solv = Solver (produce_unsat_cores - = self.local_defs_unsat) - self.solv = solv - - self.add_input_envs () - - self.used_hyps = set () - run_requests (self, requests) - - def add_func (self, name, inputs, outputs, success, n_vc): - assert n_vc not in self.funcs - self.funcs[n_vc] = (inputs, outputs, success) - for pair in pairings.get (name, []): - self.funcs.setdefault (pair.name, []) - group = self.funcs[pair.name] - for n_vc2 in group: - if self.get_func_pairing (n_vc, n_vc2): - self.add_func_assert (n_vc, n_vc2) - group.append (n_vc) - - def get_func (self, n_vc, tag = None): - """returns (input_env, output_env, success_var) for - function call at given n_vc.""" - tag, vc = self.get_tag_vcount (n_vc, tag) - n_vc = (n_vc[0], vc) - assert self.p.nodes[n_vc[0]].kind == 'Call' - - if n_vc not in self.funcs: - # try to ensure n_vc has been emitted - cont = self.get_cont (n_vc) - self.get_node_pc_env (cont, tag = tag) - - return self.funcs[n_vc] - - def get_func_pairing_nocheck (self, n_vc, n_vc2): - fnames = [self.p.nodes[n_vc[0]].fname, - self.p.nodes[n_vc2[0]].fname] - pairs = [pair for pair in pairings[list (fnames)[0]] - if set (pair.funs.values ()) == set (fnames)] - if not pairs: - return None - [pair] = pairs - if pair.funs[pair.tags[0]] == fnames[0]: - return (pair, n_vc, n_vc2) - else: - return (pair, n_vc2, n_vc) - - def get_func_pairing (self, n_vc, n_vc2): - res = self.get_func_pairing_nocheck (n_vc, n_vc2) - if not res: - return res - (pair, l_n_vc, r_n_vc) = res - (lin, _, _) = self.funcs[l_n_vc] - (rin, _, _) = self.funcs[r_n_vc] - l_mem_calls = self.scan_mem_calls (lin) - r_mem_calls = self.scan_mem_calls (rin) - tags = pair.tags - (c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls) - if not c: - trace ('skipped emitting func pairing %s -> %s' - % (l_n_vc, r_n_vc)) - trace (' ' + s) - return None - return res - - def get_func_assert (self, n_vc, n_vc2): - (pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2) - (ltag, rtag) = pair.tags - (inp_eqs, out_eqs) = pair.eqs - (lin, lout, lsucc) = self.funcs[l_n_vc] - (rin, rout, rsucc) = self.funcs[r_n_vc] - lpc = self.get_pc (l_n_vc) - rpc = self.get_pc (r_n_vc) - envs = {ltag + '_IN': lin, rtag + '_IN': rin, - ltag + '_OUT': lout, rtag + '_OUT': rout} - inp_eqs = inst_eqs (inp_eqs, envs, self.solv) - out_eqs = inst_eqs (out_eqs, envs, self.solv) - succ_imp = mk_implies (rsucc, lsucc) - - return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]), - foldr1 (mk_and, out_eqs + [succ_imp])) - - def add_func_assert (self, n_vc, n_vc2): - imp = self.get_func_assert (n_vc, n_vc2) - imp 
= logic.weaken_assert (imp) - if self.local_defs_unsat: - self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq', - ln, rn)) - else: - self.solv.assert_fact (imp, {}) - - def node_count_name (self, (n, vcount)): - name = str (n) - bits = [str (n)] + ['%s=%s' % (split, count) - for (split, count) in vcount] - return '_'.join (bits) - - def get_mem_calls (self, mem_sexpr): - mem_sexpr = solver.parse_s_expression (mem_sexpr) - return self.get_mem_calls_sexpr (mem_sexpr) - - def get_mem_calls_sexpr (self, mem_sexpr): - stores = set (['store-word32', 'store-word8', 'store-word64']) - if mem_sexpr in self.mem_calls: - return self.mem_calls[mem_sexpr] - elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores: - return self.get_mem_calls_sexpr (mem_sexpr[1]) - elif mem_sexpr[:1] == ('ite', ): - (_, _, x, y) = mem_sexpr - x_calls = self.get_mem_calls_sexpr (x) - y_calls = self.get_mem_calls_sexpr (y) - return merge_mem_calls (x_calls, y_calls) - elif mem_sexpr in self.solv.defs: - mem_sexpr = self.solv.defs[mem_sexpr] - return self.get_mem_calls_sexpr (mem_sexpr) - assert not "mem_calls fallthrough", mem_sexpr - - def scan_mem_calls (self, env): - mem_vs = [env[(nm, typ)] - for (nm, typ) in env - if typ == syntax.builtinTs['Mem']] - mem_calls = [self.get_mem_calls (v) - for v in mem_vs if v[0] != 'SplitMem'] - if mem_calls: - return foldr1 (merge_mem_calls, mem_calls) - else: - return None - - def add_mem_call (self, fname, mem_calls): - if mem_calls == None: - return None - mem_calls = dict (mem_calls) - (min_calls, max_calls) = mem_calls.get (fname, (0, 0)) - if max_calls == None: - mem_calls[fname] = (min_calls + 1, None) - else: - mem_calls[fname] = (min_calls + 1, max_calls + 1) - return mem_calls - - def add_loop_mem_calls (self, split, mem_calls): - if mem_calls == None: - return None - fnames = set ([self.p.nodes[n].fname - for n in self.p.loop_body (split) - if self.p.nodes[n].kind == 'Call']) - if not fnames: - return mem_calls - mem_calls = dict (mem_calls) - for fname in fnames: - if fname not in mem_calls: - mem_calls[fname] = (0, None) - else: - (min_calls, max_calls) = mem_calls[fname] - mem_calls[fname] = (min_calls, None) - return mem_calls - - # note these names are designed to be unique by suffix - # (so that smt names are independent of order of requests) - def local_name (self, s, n_vc): - return '%s_after_%s' % (s, self.node_count_name (n_vc)) - - def local_name_before (self, s, n_vc): - return '%s_v_at_%s' % (s, self.node_count_name (n_vc)) - - def cond_name (self, n_vc): - return 'cond_at_%s' % self.node_count_name (n_vc) - - def path_cond_name (self, n_vc, tag): - return 'path_cond_to_%s_%s' % ( - self.node_count_name (n_vc), tag) - - def success_name (self, fname, n_vc): - bits = fname.split ('.') - nms = ['_'.join (bits[i:]) for i in range (len (bits)) - if bits[i:][0].isalpha ()] - if nms: - nm = nms[-1] - else: - nm = 'fun' - return '%s_success_at_%s' % (nm, self.node_count_name (n_vc)) - - def try_inline (self, n, pc, env): - if not self.inliner: - return False - - inline = self.inliner ((self.p, n)) - if not inline: - return False - - # make sure this node is reachable before inlining - if self.solv.test_hyp (mk_not (pc), env): - trace ('Skipped inlining at %d.' % n) - return False - - trace ('Inlining at %d.' 
% n) - inline () - raise InlineEvent () - - def incr (self, vcount, n, incr): - vcount2 = dict (vcount) - vcount2[n] = vcount2[n].incr (incr) - if vcount2[n] == None: - return None - return tuple (sorted (vcount2.items ())) - - def get_cont (self, (n, vcount)): - [c] = self.p.nodes[n].get_conts () - vcount2 = dict (vcount) - if n in vcount2: - vcount = self.incr (vcount, n, 1) - cont = (c, vcount) - assert self.is_cont ((n, vcount), cont) - return cont - - def is_cont (self, (n, vcount), (n2, vcount2)): - if n2 not in self.p.nodes[n].get_conts (): - trace ('Not a graph cont.') - return False - - vcount_d = dict (vcount) - vcount_d2 = dict (vcount2) - if n in vcount_d2: - if n in vcount_d: - assert vcount_d[n].kind != 'Options' - vcount_d2[n] = vcount_d2[n].incr (-1) - - if not vcount_d <= vcount_d2: - trace ('Restrictions not subset.') - return False - - for (split, count) in vcount_d2.iteritems (): - if split in vcount_d: - continue - if self.get_reachable (split, n): - return False - if not count.has_zero (): - return False - - return True - - def prevs (self, (n, vcount)): - prevs = [] - vcount_d = dict (vcount) - for p in self.p.preds[n]: - if p in vcount_d: - vcount2 = self.incr (vcount, p, -1) - if vcount2 == None: - continue - prevs.append ((p, vcount2)) - else: - prevs.append ((p, vcount)) - return prevs - - def specialise (self, (n, vcount), split): - vcount = dict (vcount) - assert vcount[split].kind == 'Options' - specs = [] - for n in vcount[split].opts: - v = dict (vcount) - v[split] = n - specs.append (tuple (sorted (v.items ()))) - return specs - - def get_pc (self, (n, vcount), tag = None): - pc_env = self.get_node_pc_env ((n, vcount), tag = tag) - if pc_env == None: - trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag)) - return false_term - (pc, env) = pc_env - return to_smt_expr (pc, env, self.solv) - - def to_smt_expr (self, expr, (n, vcount), tag = None): - pc_env = self.get_node_pc_env ((n, vcount), tag = tag) - (pc, env) = pc_env - return to_smt_expr (expr, env, self.solv) - - def get_induct_var (self, (n1, n2)): - if (n1, n2) not in self.induct_var_env: - vname = self.solv.add_var ('induct_i_%d_%d' % (n1, n2), - word32T) - self.induct_var_env[(n1, n2)] = vname - self.pc_env_requests.add (((n1, n2), 'InductVar')) - else: - vname = self.induct_var_env[(n1, n2)] - return mk_smt_expr (vname, word32T) - - def interpret_hyp (self, hyp): - return hyp.interpret (self) - - def interpret_hyp_imps (self, hyps, concl): - hyps = map (self.interpret_hyp, hyps) - return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl)) - - def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False, - model = None): - self.avail_hyps = set (hyps) - if not self.used_hyps <= self.avail_hyps: - self.rebuild () - - last_test[0] = (hyp, hyps, list (self.pc_env_requests)) - - expr = self.interpret_hyp_imps (hyps, hyp) - - trace ('Testing hyp whyps', push = 1) - trace ('requests = %s' % self.pc_env_requests) - - expr_s = smt_expr (expr, {}, self.solv) - if cache and expr_s in cache: - trace ('Cached: %s' % cache[expr_s]) - return cache[expr_s] - if fast: - trace ('(not in cache)') - return None - - self.solv.add_pvalid_dom_assertions () - - (result, _, _) = self.solv.parallel_test_hyps ([(None, expr)], - {}, model = model) - trace ('Result: %s' % result, push = -1) - if cache != None: - cache[expr_s] = result - if not result: - last_failed_test[0] = last_test[0] - return result - - def test_hyp_imp (self, hyps, hyp, model = None): - return self.test_hyp_whyps 
(self.interpret_hyp (hyp), hyps, - model = model) - - def test_hyp_imps (self, imps): - last_hyp_imps[0] = imps - if imps == []: - return (True, None) - interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps, - self.interpret_hyp (hyp)) - for (hyps, hyp) in imps])) - reqs = list (self.pc_env_requests) - last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) - self.solv.add_pvalid_dom_assertions () - result = self.solv.parallel_test_hyps (interp_imps, {}) - assert result[0] in [True, False], result - if result[0] == False: - (hyps, hyp) = imps[result[1]] - last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) - last_failed_test[0] = last_test[0] - return result - - def replay_requests (self, reqs): - for ((n, vc), tag) in reqs: - self.get_node_pc_env ((n, vc), tag = tag) + """Used to represent a slice of potential execution in a graph where + looping is limited to certain specific examples. For instance, we + might say that execution through node n will be represented only + by visits 0, 1, 2, 3, i, and i + 1 (for a symbolic value i). The + variable state at visits 4 and i + 2 will be calculated but no + further execution will be done.""" + + def __init__ (self, p, solv, inliner = None, fast = False): + self.p = p + self.solv = solv + self.inp_envs = {} + self.mem_calls = {} + self.add_input_envs () + + self.node_pc_envs = {} + self.node_pc_env_order = [] + self.arc_pc_envs = {} + self.inliner = inliner + self.funcs = {} + self.pc_env_requests = set () + self.fast = fast + self.induct_var_env = {} + self.contractions = {} + + self.local_defs_unsat = False + self.use_known_eqs = True + + self.avail_hyps = set () + self.used_hyps = set () + + def add_input_envs (self): + for (entry, _, _, args) in self.p.entries: + self.inp_envs[entry] = mk_inp_env (entry, args, self) + + def get_reachable (self, split, n): + return self.p.is_reachable_from (split, n) + + class TooGeneral (Exception): + def __init__ (self, split): + self.split = split + + def get_tag_vcount (self, (n, vcount), tag): + if tag == None: + tag = self.p.node_tags[n][0] + vcount_r = [(split, count, self.get_reachable (split, n)) + for (split, count) in vcount + if self.p.node_tags[split][0] == tag] + for (split, count, r) in vcount_r: + if not r and not count.has_zero (): + return (tag, None) + assert count.is_visit_count + vcount = [(s, c) for (s, c, r) in vcount_r if r] + vcount = tuple (sorted (vcount)) + + loop_id = self.p.loop_id (n) + if loop_id != None: + for (split, visits) in vcount: + if (self.p.loop_id (split) == loop_id + and visits.kind == 'Options'): + raise self.TooGeneral (split) + + return (tag, vcount) + + def get_node_pc_env (self, (n, vcount), tag = None, request = True): + tag, vcount = self.get_tag_vcount ((n, vcount), tag) + if vcount == None: + return None + + if (tag, n, vcount) in self.node_pc_envs: + return self.node_pc_envs[(tag, n, vcount)] + + if request: + self.pc_env_requests.add (((n, vcount), tag)) + + self.warm_pc_env_cache ((n, vcount), tag) + + pc_env = self.get_node_pc_env_raw ((n, vcount), tag) + if pc_env: + pc_env = self.apply_known_eqs_pc_env ((n, vcount), + tag, pc_env) + + assert not (tag, n, vcount) in self.node_pc_envs + self.node_pc_envs[(tag, n, vcount)] = pc_env + if pc_env: + self.node_pc_env_order.append ((tag, n, vcount)) + + return pc_env + + def warm_pc_env_cache (self, n_vc, tag): + 'this is to avoid recursion limits and spot bugs' + prev_chain = [] + for i in range (5000): + prevs = self.prevs (n_vc) + try: + prevs = [p for p in prevs + if (tag, p[0], p[1]) + not in 
self.node_pc_envs + if self.get_tag_vcount (p, None) + == (tag, n_vc[1])] + except self.TooGeneral: + break + if not prevs: + break + n_vc = prevs[0] + prev_chain.append(n_vc) + if not (len (prev_chain) < 5000): + printout ([n for (n, vc) in prev_chain]) + assert len (prev_chain) < 5000, (prev_chain[:10], + prev_chain[-10:]) + + prev_chain.reverse () + for n_vc in prev_chain: + self.get_node_pc_env (n_vc, tag, request = False) + + def get_loop_pc_env (self, split, vcount): + vcount2 = dict (vcount) + vcount2[split] = vc_num (0) + vcount2 = tuple (sorted (vcount2.items ())) + prev_pc_env = self.get_node_pc_env ((split, vcount2)) + if prev_pc_env == None: + return None + (_, prev_env) = prev_pc_env + mem_calls = self.scan_mem_calls (prev_env) + mem_calls = self.add_loop_mem_calls (split, mem_calls) + def av (nm, typ, mem_name = None): + nm2 = '%s_loop_at_%s' % (nm, split) + return self.add_var (nm2, typ, + mem_name = mem_name, mem_calls = mem_calls) + env = {} + consts = set () + for (nm, typ) in prev_env: + check_const = self.fast or (typ in + [builtinTs['HTD'], builtinTs['Dom']]) + if check_const and self.is_synt_const (nm, typ, split): + env[(nm, typ)] = prev_env[(nm, typ)] + consts.add ((nm, typ)) + else: + env[(nm, typ)] = av (nm + '_after', typ, + ('Loop', prev_env[(nm, typ)])) + for (nm, typ) in prev_env: + if (nm, typ) in consts: + continue + z = self.var_rep_request ((nm, typ), 'Loop', + (split, vcount), env) + if z: + env[(nm, typ)] = z + + pc = mk_smt_expr (av ('pc_of', boolT), boolT) + if self.fast: + imp = syntax.mk_implies (pc, prev_pc_env[0]) + self.solv.assert_fact (imp, prev_env, + unsat_tag = ('LoopPCImp', split)) + + return (pc, env) + + def is_synt_const (self, nm, typ, split): + """check if a variable at a split point is a syntactic constant + which is always unmodified by the loop. + we allow cases where a variable is renamed and renamed back + during the loop (this often happens because of inlining). + the check is done by depth-first-search backward through the + graph looking for a source of a variant value.""" + loop = self.p.loop_id (split) + if problem.has_inner_loop (self.p, split): + return False + loop_set = set (self.p.loop_body (split)) + + orig_nm = nm + safe = set ([(orig_nm, split)]) + first_step = True + visit = [] + count = 0 + while first_step or visit: + if first_step: + (nm, n) = (orig_nm, split) + first_step = False + else: + (nm, n) = visit.pop () + if (nm, n) in safe: + continue + elif n == split: + return False + new_nm = nm + node = self.p.nodes[n] + if node.kind == 'Call': + if (nm, typ) not in node.rets: + pass + elif self.fast_const_ret (n, nm, typ): + pass + else: + return False + elif node.kind == 'Basic': + upds = [arg for (lv, arg) in node.upds + if lv == (nm, typ)] + if [v for v in upds if v.kind != 'Var']: + return False + if upds: + new_nm = upds[0].name + preds = [(new_nm, n2) for n2 in self.p.preds[n] + if n2 in loop_set] + unknowns = [p for p in preds if p not in safe] + if unknowns: + visit.extend ([(nm, n)] + unknowns) + else: + safe.add ((nm, n)) + count += 1 + if count % 100000 == 0: + trace ('is_synt_const: %d iterations' % count) + trace ('visit length %d' % len (visit)) + trace ('visit tail %s' % visit[-20:]) + return True + + def fast_const_ret (self, n, nm, typ): + """determine if we can heuristically consider this return + value to be the same as an input. this is known for some + function returns, e.g. memory. 
+ this is important for heuristic "fast" analysis.""" + if not self.fast: + return False + node = self.p.nodes[n] + assert node.kind == 'Call' + for hook in target_objects.hooks ('rep_unsafe_const_ret'): + if hook (node, nm, typ): + return True + return False + + def get_node_pc_env_raw (self, (n, vcount), tag): + if n in self.inp_envs: + return (true_term, self.inp_envs[n]) + + for (split, count) in vcount: + if split == n and count == vc_offs (0): + return self.get_loop_pc_env (split, vcount) + + pc_envs = [pc_env for n_prev in self.p.preds[n] + if self.p.node_tags[n_prev][0] == tag + for pc_env in self.get_arc_pc_envs (n_prev, + (n, vcount))] + + pc_envs = [pc_env for pc_env in pc_envs if pc_env] + if pc_envs == []: + return None + + if n == 'Err': + # we'll never care about variable values here + # and there are sometimes a LOT of arcs to Err + # so we save a lot of merge effort + pc_envs = [(to_smt_expr (pc, env, self.solv), {}) + for (pc, env) in pc_envs] + + (pc, env, large) = merge_envs_pcs (pc_envs, self.solv) + + if pc.kind != 'SMTExpr': + name = self.path_cond_name ((n, vcount), tag) + name = self.solv.add_def (name, pc, env) + pc = mk_smt_expr (name, boolT) + + for (nm, typ) in env: + if len (env[(nm, typ)]) > 80: + env[(nm, typ)] = self.contract (nm, (n, vcount), + env[(nm, typ)], typ) + + return (pc, env) + + def contract (self, name, n_vc, val, typ): + if val in self.contractions: + return self.contractions[val] + + name = self.local_name_before (name, n_vc) + name = self.solv.add_def (name, mk_smt_expr (val, typ), {}) + + self.contractions[val] = name + return name + + def get_arc_pc_envs (self, n, n_vc2): + try: + prevs = [n_vc for n_vc in self.prevs (n_vc2) + if n_vc[0] == n] + assert len (prevs) <= 1 + return [self.get_arc_pc_env (n_vc, n_vc2) + for n_vc in prevs] + except self.TooGeneral, e: + # consider specialisations of the target + specs = self.specialise (n_vc2, e.split) + specs = [(n_vc2[0], spec) for spec in specs] + return [pc_env for spec in specs + for pc_env in self.get_arc_pc_envs (n, spec)] + + def get_arc_pc_env (self, (n, vcount), n2): + tag, vcount = self.get_tag_vcount ((n, vcount), None) + + if vcount == None: + return None + + assert self.is_cont ((n, vcount), n2), ((n, vcount), + n2, self.p.nodes[n].get_conts ()) + + if (n, vcount) in self.arc_pc_envs: + return self.arc_pc_envs[(n, vcount)].get (n2[0]) + + if self.get_node_pc_env ((n, vcount), request = False) == None: + return None + + arcs = self.emit_node ((n, vcount)) + self.post_emit_node_hooks ((n, vcount)) + arcs = dict ([(cont, (pc, env)) for (cont, pc, env) in arcs]) + + self.arc_pc_envs[(n, vcount)] = arcs + return arcs.get (n2[0]) + + def add_local_def (self, n, vname, name, val, env): + if self.local_defs_unsat: + smt_name = self.solv.add_var (name, val.typ) + eq = mk_eq (mk_smt_expr (smt_name, val.typ), val) + self.solv.assert_fact (eq, env, unsat_tag + = ('Def', n, vname)) + else: + smt_name = self.solv.add_def (name, val, env) + return smt_name + + def add_var (self, name, typ, mem_name = None, mem_calls = None): + r = self.solv.add_var_restr (name, typ, mem_name = mem_name) + if typ == syntax.builtinTs['Mem']: + r_x = solver.parse_s_expression (r) + self.mem_calls[r_x] = mem_calls + return r + + def var_rep_request (self, (nm, typ), kind, n_vc, env): + assert type (n_vc[0]) != str + for hook in target_objects.hooks ('problem_var_rep'): + z = hook (self.p, (nm, typ), kind, n_vc[0]) + if z == None: + continue + if z[0] == 'SplitMem': + assert typ == builtinTs['Mem'] + (_, addr) = z + 
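# A 'problem_var_rep' hook returns either None (no special representation)
# or ('SplitMem', addr), asking for this Mem variable to be modelled as
# memory split about addr. A hypothetical hook (the variable name and the
# address below are assumptions, for illustration only) might look like:
#   def stack_var_rep (p, (nm, typ), kind, n):
#       if typ == syntax.builtinTs['Mem'] and nm == 'stack':
#           return ('SplitMem', syntax.mk_word32 (0x40000000))
#       return None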
addr = smt_expr (addr, env, self.solv) + name = '%s_for_%s' % (nm, + self.node_count_name (n_vc)) + return self.solv.add_split_mem_var (addr, name, + typ, mem_name = 'SplitMemNonsense') + else: + assert z == None + + def emit_node (self, n): + (pc, env) = self.get_node_pc_env (n, request = False) + tag = self.p.node_tags[n[0]][0] + app_eqs = self.apply_known_eqs_tm (n, tag) + # node = logic.simplify_node_elementary (self.p.nodes[n[0]]) + # whether to ignore unreachable Cond arcs seems to be a huge + # dilemma. if we ignore them, some reachable sites become + # unreachable and we can't interpret all hyps + # if we don't ignore them, the variable set disagrees with + # var_deps and so the abstracted loop pc/env may not be + # sufficient and we get EnvMiss again. I don't really know + # what to do about this corner case. + node = self.p.nodes[n[0]] + env = dict (env) + + if node.kind == 'Call': + self.try_inline (n[0], pc, env) + + if pc == false_term: + return [(c, false_term, {}) for c in node.get_conts()] + elif node.kind == 'Cond' and node.left == node.right: + return [(node.left, pc, env)] + elif node.kind == 'Cond' and node.cond == true_term: + return [(node.left, pc, env), + (node.right, false_term, env)] + elif node.kind == 'Basic': + upds = [] + for (lv, v) in node.upds: + if v.kind == 'Var': + upds.append ((lv, env[(v.name, v.typ)])) + else: + name = self.local_name (lv[0], n) + + v = app_eqs (v) + vname = self.add_local_def (n, + ('Var', lv), name, v, env) + upds.append ((lv, vname)) + for (lv, v) in upds: + env[lv] = v + return [(node.cont, pc, env)] + elif node.kind == 'Cond': + name = self.cond_name (n) + cond = self.p.fresh_var (name, boolT) + env[(cond.name, boolT)] = self.add_local_def (n, + 'Cond', name, app_eqs (node.cond), env) + lpc = mk_and (cond, pc) + rpc = mk_and (mk_not (cond), pc) + return [(node.left, lpc, env), (node.right, rpc, env)] + elif node.kind == 'Call': + nm = self.success_name (node.fname, n) + success = self.solv.add_var (nm, boolT) + success = mk_smt_expr (success, boolT) + fun = functions[node.fname] + + ins = dict ([((x, typ), smt_expr (app_eqs (arg), env, self.solv)) + for ((x, typ), arg) in azip (fun.inputs, node.args)]) + mem_name = None + for (x, typ) in reversed (fun.inputs): + if typ == builtinTs['Mem']: + inp_mem = ins[(x, typ)] + mem_name = (node.fname, inp_mem) + mem_calls = self.scan_mem_calls (ins) + mem_calls = self.add_mem_call (node.fname, mem_calls) + outs = {} + for ((x, typ), (y, typ2)) in azip (node.rets, fun.outputs): + assert typ2 == typ + if self.fast_const_ret (n[0], x, typ): + outs[(y, typ2)] = env [(x, typ)] + continue + name = self.local_name (x, n) + env[(x, typ)] = self.add_var (name, typ, + mem_name = mem_name, + mem_calls = mem_calls) + outs[(y, typ2)] = env[(x, typ)] + for ((x, typ), (y, _)) in azip (node.rets, fun.outputs): + z = self.var_rep_request ((x, typ), + 'Call', n, env) + if z != None: + env[(x, typ)] = z + outs[(y, typ)] = z + self.add_func (node.fname, ins, outs, success, n) + return [(node.cont, pc, env)] + else: + assert not 'node kind understood' + + def post_emit_node_hooks (self, (n, vcount)): + for hook in target_objects.hooks ('post_emit_node'): + hook (self, (n, vcount)) + + def fetch_known_eqs (self, n_vc, tag): + if not self.use_known_eqs: + return None + eqs = self.p.known_eqs.get ((n_vc, tag)) + if eqs == None: + return None + avail = [] + for (x, n_vc_y, tag_y, y, hyps) in eqs: + if hyps <= self.avail_hyps: + (_, env) = self.get_node_pc_env (n_vc_y, tag_y) + avail.append ((x, smt_expr (y, env, 
self.solv))) + self.used_hyps.update (hyps) + if avail: + return avail + return None + + def apply_known_eqs_pc_env (self, n_vc, tag, (pc, env)): + eqs = self.fetch_known_eqs (n_vc, tag) + if eqs == None: + return (pc, env) + env = dict (env) + for (x, sx) in eqs: + if x.kind == 'Var': + cur_rhs = env[x.name] + for y in env: + if env[y] == cur_rhs: + trace ('substituted %s at %s.' % (y, n_vc)) + env[y] = sx + return (pc, env) + + def apply_known_eqs_tm (self, n_vc, tag): + eqs = self.fetch_known_eqs (n_vc, tag) + if eqs == None: + return lambda x: x + eqs = dict ([(x, mk_smt_expr (sexpr, x.typ)) + for (x, sexpr) in eqs]) + return lambda tm: logic.recursive_term_subst (eqs, tm) + + def rebuild (self, solv = None): + requests = self.pc_env_requests + + self.node_pc_env_order = [] + self.node_pc_envs = {} + self.arc_pc_envs = {} + self.funcs = {} + self.pc_env_requests = set () + self.induct_var_env = {} + self.contractions = {} + + if not solv: + solv = Solver (produce_unsat_cores + = self.local_defs_unsat) + self.solv = solv + + self.add_input_envs () + + self.used_hyps = set () + run_requests (self, requests) + + def add_func (self, name, inputs, outputs, success, n_vc): + assert n_vc not in self.funcs + self.funcs[n_vc] = (inputs, outputs, success) + for pair in pairings.get (name, []): + self.funcs.setdefault (pair.name, []) + group = self.funcs[pair.name] + for n_vc2 in group: + if self.get_func_pairing (n_vc, n_vc2): + self.add_func_assert (n_vc, n_vc2) + group.append (n_vc) + + def get_func (self, n_vc, tag = None): + """returns (input_env, output_env, success_var) for + function call at given n_vc.""" + tag, vc = self.get_tag_vcount (n_vc, tag) + n_vc = (n_vc[0], vc) + assert self.p.nodes[n_vc[0]].kind == 'Call' + + if n_vc not in self.funcs: + # try to ensure n_vc has been emitted + cont = self.get_cont (n_vc) + self.get_node_pc_env (cont, tag = tag) + + return self.funcs[n_vc] + + def get_func_pairing_nocheck (self, n_vc, n_vc2): + fnames = [self.p.nodes[n_vc[0]].fname, + self.p.nodes[n_vc2[0]].fname] + pairs = [pair for pair in pairings[list (fnames)[0]] + if set (pair.funs.values ()) == set (fnames)] + if not pairs: + return None + [pair] = pairs + if pair.funs[pair.tags[0]] == fnames[0]: + return (pair, n_vc, n_vc2) + else: + return (pair, n_vc2, n_vc) + + def get_func_pairing (self, n_vc, n_vc2): + res = self.get_func_pairing_nocheck (n_vc, n_vc2) + if not res: + return res + (pair, l_n_vc, r_n_vc) = res + (lin, _, _) = self.funcs[l_n_vc] + (rin, _, _) = self.funcs[r_n_vc] + l_mem_calls = self.scan_mem_calls (lin) + r_mem_calls = self.scan_mem_calls (rin) + tags = pair.tags + (c, s) = mem_calls_compatible (tags, l_mem_calls, r_mem_calls) + if not c: + trace ('skipped emitting func pairing %s -> %s' + % (l_n_vc, r_n_vc)) + trace (' ' + s) + return None + return res + + def get_func_assert (self, n_vc, n_vc2): + (pair, l_n_vc, r_n_vc) = self.get_func_pairing (n_vc, n_vc2) + (ltag, rtag) = pair.tags + (inp_eqs, out_eqs) = pair.eqs + (lin, lout, lsucc) = self.funcs[l_n_vc] + (rin, rout, rsucc) = self.funcs[r_n_vc] + lpc = self.get_pc (l_n_vc) + rpc = self.get_pc (r_n_vc) + envs = {ltag + '_IN': lin, rtag + '_IN': rin, + ltag + '_OUT': lout, rtag + '_OUT': rout} + inp_eqs = inst_eqs (inp_eqs, envs, self.solv) + out_eqs = inst_eqs (out_eqs, envs, self.solv) + succ_imp = mk_implies (rsucc, lsucc) + + return mk_implies (foldr1 (mk_and, inp_eqs + [rpc]), + foldr1 (mk_and, out_eqs + [succ_imp])) + + def add_func_assert (self, n_vc, n_vc2): + imp = self.get_func_assert (n_vc, n_vc2) + imp 
= logic.weaken_assert (imp) + if self.local_defs_unsat: + self.solv.assert_fact (imp, {}, unsat_tag = ('FunEq', + ln, rn)) + else: + self.solv.assert_fact (imp, {}) + + def node_count_name (self, (n, vcount)): + name = str (n) + bits = [str (n)] + ['%s=%s' % (split, count) + for (split, count) in vcount] + return '_'.join (bits) + + def get_mem_calls (self, mem_sexpr): + mem_sexpr = solver.parse_s_expression (mem_sexpr) + return self.get_mem_calls_sexpr (mem_sexpr) + + def get_mem_calls_sexpr (self, mem_sexpr): + stores = set (['store-word32', 'store-word8', 'store-word64']) + if mem_sexpr in self.mem_calls: + return self.mem_calls[mem_sexpr] + elif len (mem_sexpr) == 4 and mem_sexpr[0] in stores: + return self.get_mem_calls_sexpr (mem_sexpr[1]) + elif mem_sexpr[:1] == ('ite', ): + (_, _, x, y) = mem_sexpr + x_calls = self.get_mem_calls_sexpr (x) + y_calls = self.get_mem_calls_sexpr (y) + return merge_mem_calls (x_calls, y_calls) + elif mem_sexpr in self.solv.defs: + mem_sexpr = self.solv.defs[mem_sexpr] + return self.get_mem_calls_sexpr (mem_sexpr) + assert not "mem_calls fallthrough", mem_sexpr + + def scan_mem_calls (self, env): + mem_vs = [env[(nm, typ)] + for (nm, typ) in env + if typ == syntax.builtinTs['Mem']] + mem_calls = [self.get_mem_calls (v) + for v in mem_vs if v[0] != 'SplitMem'] + if mem_calls: + return foldr1 (merge_mem_calls, mem_calls) + else: + return None + + def add_mem_call (self, fname, mem_calls): + if mem_calls == None: + return None + mem_calls = dict (mem_calls) + (min_calls, max_calls) = mem_calls.get (fname, (0, 0)) + if max_calls == None: + mem_calls[fname] = (min_calls + 1, None) + else: + mem_calls[fname] = (min_calls + 1, max_calls + 1) + return mem_calls + + def add_loop_mem_calls (self, split, mem_calls): + if mem_calls == None: + return None + fnames = set ([self.p.nodes[n].fname + for n in self.p.loop_body (split) + if self.p.nodes[n].kind == 'Call']) + if not fnames: + return mem_calls + mem_calls = dict (mem_calls) + for fname in fnames: + if fname not in mem_calls: + mem_calls[fname] = (0, None) + else: + (min_calls, max_calls) = mem_calls[fname] + mem_calls[fname] = (min_calls, None) + return mem_calls + + # note these names are designed to be unique by suffix + # (so that smt names are independent of order of requests) + def local_name (self, s, n_vc): + return '%s_after_%s' % (s, self.node_count_name (n_vc)) + + def local_name_before (self, s, n_vc): + return '%s_v_at_%s' % (s, self.node_count_name (n_vc)) + + def cond_name (self, n_vc): + return 'cond_at_%s' % self.node_count_name (n_vc) + + def path_cond_name (self, n_vc, tag): + return 'path_cond_to_%s_%s' % ( + self.node_count_name (n_vc), tag) + + def success_name (self, fname, n_vc): + bits = fname.split ('.') + nms = ['_'.join (bits[i:]) for i in range (len (bits)) + if bits[i:][0].isalpha ()] + if nms: + nm = nms[-1] + else: + nm = 'fun' + return '%s_success_at_%s' % (nm, self.node_count_name (n_vc)) + + def try_inline (self, n, pc, env): + if not self.inliner: + return False + + inline = self.inliner ((self.p, n)) + if not inline: + return False + + # make sure this node is reachable before inlining + if self.solv.test_hyp (mk_not (pc), env): + trace ('Skipped inlining at %d.' % n) + return False + + trace ('Inlining at %d.' 
% n) + inline () + raise InlineEvent () + + def incr (self, vcount, n, incr): + vcount2 = dict (vcount) + vcount2[n] = vcount2[n].incr (incr) + if vcount2[n] == None: + return None + return tuple (sorted (vcount2.items ())) + + def get_cont (self, (n, vcount)): + [c] = self.p.nodes[n].get_conts () + vcount2 = dict (vcount) + if n in vcount2: + vcount = self.incr (vcount, n, 1) + cont = (c, vcount) + assert self.is_cont ((n, vcount), cont) + return cont + + def is_cont (self, (n, vcount), (n2, vcount2)): + if n2 not in self.p.nodes[n].get_conts (): + trace ('Not a graph cont.') + return False + + vcount_d = dict (vcount) + vcount_d2 = dict (vcount2) + if n in vcount_d2: + if n in vcount_d: + assert vcount_d[n].kind != 'Options' + vcount_d2[n] = vcount_d2[n].incr (-1) + + if not vcount_d <= vcount_d2: + trace ('Restrictions not subset.') + return False + + for (split, count) in vcount_d2.iteritems (): + if split in vcount_d: + continue + if self.get_reachable (split, n): + return False + if not count.has_zero (): + return False + + return True + + def prevs (self, (n, vcount)): + prevs = [] + vcount_d = dict (vcount) + for p in self.p.preds[n]: + if p in vcount_d: + vcount2 = self.incr (vcount, p, -1) + if vcount2 == None: + continue + prevs.append ((p, vcount2)) + else: + prevs.append ((p, vcount)) + return prevs + + def specialise (self, (n, vcount), split): + vcount = dict (vcount) + assert vcount[split].kind == 'Options' + specs = [] + for n in vcount[split].opts: + v = dict (vcount) + v[split] = n + specs.append (tuple (sorted (v.items ()))) + return specs + + def get_pc (self, (n, vcount), tag = None): + pc_env = self.get_node_pc_env ((n, vcount), tag = tag) + if pc_env == None: + trace ('Warning: unreachable n_vc, tag: %s, %s' % ((n, vcount), tag)) + return false_term + (pc, env) = pc_env + return to_smt_expr (pc, env, self.solv) + + def to_smt_expr (self, expr, (n, vcount), tag = None): + pc_env = self.get_node_pc_env ((n, vcount), tag = tag) + (pc, env) = pc_env + return to_smt_expr (expr, env, self.solv) + + def get_induct_var (self, (n1, n2)): + if (n1, n2) not in self.induct_var_env: + vname = self.solv.add_var('induct_i_%d_%d' % (n1, n2), syntax.arch.word_type) + self.induct_var_env[(n1, n2)] = vname + self.pc_env_requests.add (((n1, n2), 'InductVar')) + else: + vname = self.induct_var_env[(n1, n2)] + return mk_smt_expr(vname, syntax.arch.word_type) + + def interpret_hyp (self, hyp): + return hyp.interpret (self) + + def interpret_hyp_imps (self, hyps, concl): + hyps = map (self.interpret_hyp, hyps) + return logic.strengthen_hyp (syntax.mk_n_implies (hyps, concl)) + + def test_hyp_whyps (self, hyp, hyps, cache = None, fast = False, + model = None): + self.avail_hyps = set (hyps) + if not self.used_hyps <= self.avail_hyps: + self.rebuild () + + last_test[0] = (hyp, hyps, list (self.pc_env_requests)) + + expr = self.interpret_hyp_imps (hyps, hyp) + + trace ('Testing hyp whyps', push = 1) + trace ('requests = %s' % self.pc_env_requests) + + expr_s = smt_expr (expr, {}, self.solv) + if cache and expr_s in cache: + trace ('Cached: %s' % cache[expr_s]) + return cache[expr_s] + if fast: + trace ('(not in cache)') + return None + + self.solv.add_pvalid_dom_assertions () + + (result, _, _) = self.solv.parallel_test_hyps ([(None, expr)], + {}, model = model) + trace ('Result: %s' % result, push = -1) + if cache != None: + cache[expr_s] = result + if not result: + last_failed_test[0] = last_test[0] + return result + + def test_hyp_imp (self, hyps, hyp, model = None): + return 
self.test_hyp_whyps (self.interpret_hyp (hyp), hyps, + model = model) + + def test_hyp_imps (self, imps): + last_hyp_imps[0] = imps + if imps == []: + return (True, None) + interp_imps = list (enumerate ([self.interpret_hyp_imps (hyps, + self.interpret_hyp (hyp)) + for (hyps, hyp) in imps])) + reqs = list (self.pc_env_requests) + last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) + self.solv.add_pvalid_dom_assertions () + result = self.solv.parallel_test_hyps(interp_imps, PersistableModel({})) + assert result[0] in [True, False], result + if result[0] == False: + (hyps, hyp) = imps[result[1]] + last_test[0] = (self.interpret_hyp (hyp), hyps, reqs) + last_failed_test[0] = last_test[0] + return result + + def replay_requests (self, reqs): + for ((n, vc), tag) in reqs: + self.get_node_pc_env ((n, vc), tag = tag) last_test = [0] last_failed_test = [0] last_hyp_imps = [0] def to_smt_expr_under_op (expr, env, solv): - if expr.kind == 'Op': - vals = [to_smt_expr (v, env, solv) for v in expr.vals] - return syntax.adjust_op_vals (expr, vals) - else: - return to_smt_expr (expr, env, solv) + if expr.kind == 'Op': + vals = [to_smt_expr (v, env, solv) for v in expr.vals] + return syntax.adjust_op_vals (expr, vals) + else: + return to_smt_expr (expr, env, solv) def inst_eq_with_envs ((x, env1), (y, env2), solv): - x = to_smt_expr_under_op (x, env1, solv) - y = to_smt_expr_under_op (y, env2, solv) - if x.typ == syntax.builtinTs['RelWrapper']: - return logic.apply_rel_wrapper (x, y) - else: - return mk_eq (x, y) + x = to_smt_expr_under_op (x, env1, solv) + y = to_smt_expr_under_op (y, env2, solv) + + if x.typ == syntax.builtinTs['RelWrapper']: + return logic.apply_rel_wrapper (x, y) + else: + return mk_eq (x, y) def inst_eqs (eqs, envs, solv): - return [inst_eq_with_envs ((x, envs[x_addr]), (y, envs[y_addr]), solv) - for ((x, x_addr), (y, y_addr)) in eqs] + return [inst_eq_with_envs ((x, envs[x_addr]), (y, envs[y_addr]), solv) + for ((x, x_addr), (y, y_addr)) in eqs] def subst_induct (expr, induct_var): - substs = {('%n', word32T): induct_var} - return logic.var_subst (expr, substs, must_subst = False) + substs = {('%n', syntax.arch.word_type): induct_var} + return logic.var_subst (expr, substs, must_subst = False) printed_hyps = {} def print_hyps (hyps): - hyps = tuple (hyps) - if hyps in printed_hyps: - trace ('hyps = %s' % printed_hyps[hyps]) - else: - hname = 'hyp_set_%d' % (len (printed_hyps) + 1) - trace ('%s = %s' % (hname, list (hyps))) - printed_hyps[hname] = hyps - trace ('hyps = %s' % hname) + hyps = tuple (hyps) + if hyps in printed_hyps: + trace ('hyps = %s' % printed_hyps[hyps]) + else: + hname = 'hyp_set_%d' % (len (printed_hyps) + 1) + trace ('%s = %s' % (hname, list (hyps))) + printed_hyps[hname] = hyps + trace ('hyps = %s' % hname) def merge_mem_calls (mem_calls_x, mem_calls_y): - if mem_calls_x == mem_calls_y: - return mem_calls_x - mem_calls = {} - for fname in set (mem_calls_x.keys () + mem_calls_y.keys ()): - (min_x, max_x) = mem_calls_x.get (fname, (0, 0)) - (min_y, max_y) = mem_calls_y.get (fname, (0, 0)) - if None in [max_x, max_y]: - max_v = None - else: - max_v = max (max_x, max_y) - mem_calls[fname] = (min (min_x, min_y), max_v) - return mem_calls + if mem_calls_x == mem_calls_y: + return mem_calls_x + mem_calls = {} + for fname in set (mem_calls_x.keys () + mem_calls_y.keys ()): + (min_x, max_x) = mem_calls_x.get (fname, (0, 0)) + (min_y, max_y) = mem_calls_y.get (fname, (0, 0)) + if None in [max_x, max_y]: + max_v = None + else: + max_v = max (max_x, max_y) + 
mem_calls[fname] = (min (min_x, min_y), max_v) + return mem_calls def mem_calls_compatible (tags, l_mem_calls, r_mem_calls): - if l_mem_calls == None or r_mem_calls == None: - return (True, None) - r_cast_calls = {} - for (fname, calls) in l_mem_calls.iteritems (): - pairs = [pair for pair in pairings[fname] - if pair.tags == tags] - if not pairs: - return (None, 'no pairing for %s' % fname) - assert len (pairs) <= 1, pairs - [pair] = pairs - r_fun = pair.funs[tags[1]] - if not [nm for (nm, typ) in functions[r_fun].outputs - if typ == syntax.builtinTs['Mem']]: - continue - r_cast_calls[pair.funs[tags[1]]] = calls - for fname in set (r_cast_calls.keys () + r_mem_calls.keys ()): - r_cast = r_cast_calls.get (fname, (0, 0)) - r_actual = r_mem_calls.get (fname, (0, 0)) - s = 'mismatch in calls to %s and pairs, %s / %s' % (fname, - r_cast, r_actual) - if r_cast[1] != None and r_cast[1] < r_actual[0]: - return (None, s) - if r_actual[1] != None and r_actual[1] < r_cast[0]: - return (None, s) - return (True, None) + if l_mem_calls == None or r_mem_calls == None: + return (True, None) + r_cast_calls = {} + for (fname, calls) in l_mem_calls.iteritems (): + pairs = [pair for pair in pairings[fname] + if pair.tags == tags] + if not pairs: + return (None, 'no pairing for %s' % fname) + assert len (pairs) <= 1, pairs + [pair] = pairs + r_fun = pair.funs[tags[1]] + if not [nm for (nm, typ) in functions[r_fun].outputs + if typ == syntax.builtinTs['Mem']]: + continue + r_cast_calls[pair.funs[tags[1]]] = calls + for fname in set (r_cast_calls.keys () + r_mem_calls.keys ()): + r_cast = r_cast_calls.get (fname, (0, 0)) + r_actual = r_mem_calls.get (fname, (0, 0)) + s = 'mismatch in calls to %s and pairs, %s / %s' % (fname, + r_cast, r_actual) + if r_cast[1] != None and r_cast[1] < r_actual[0]: + return (None, s) + if r_actual[1] != None and r_actual[1] < r_cast[0]: + return (None, s) + return (True, None) def mk_inp_env (n, args, rep): - trace ('rep_graph setting up input env at %d' % n, push = 1) - inp_env = {} + trace ('rep_graph setting up input env at %d' % n, push = 1) + inp_env = {} - for (v_nm, typ) in args: - inp_env[(v_nm, typ)] = rep.add_var (v_nm + '_init', typ, - mem_name = 'Init', mem_calls = {}) - for (v_nm, typ) in args: - z = rep.var_rep_request ((v_nm, typ), 'Init', (n, ()), inp_env) - if z: - inp_env[(v_nm, typ)] = z + for (v_nm, typ) in args: + inp_env[(v_nm, typ)] = rep.add_var (v_nm + '_init', typ, + mem_name = 'Init', mem_calls = {}) + for (v_nm, typ) in args: + z = rep.var_rep_request ((v_nm, typ), 'Init', (n, ()), inp_env) + if z: + inp_env[(v_nm, typ)] = z - trace ('done setting up input env at %d' % n, push = -1) - return inp_env + trace ('done setting up input env at %d' % n, push = -1) + return inp_env def mk_graph_slice (p, inliner = None, fast = False, mk_solver = Solver): - trace ('rep_graph setting up solver', push = 1) - solv = mk_solver () - trace ('rep_graph setting up solver', push = -1) - return GraphSlice (p, solv, inliner, fast = fast) + trace ('rep_graph setting up solver', push = 1) + solv = mk_solver () + trace ('rep_graph setting up solver', push = -1) + return GraphSlice (p, solv, inliner, fast = fast) def run_requests (rep, requests): - for (n_vc, tag) in requests: - if tag == 'InductVar': - rep.get_induct_var (n_vc) - else: - rep.get_pc (n_vc, tag = tag) - rep.solv.add_pvalid_dom_assertions () + for (n_vc, tag) in requests: + if tag == 'InductVar': + rep.get_induct_var (n_vc) + else: + rep.get_pc (n_vc, tag = tag) + rep.solv.add_pvalid_dom_assertions () import re 
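For illustration, a minimal standalone sketch of the interval merge performed by merge_mem_calls above (Python 3; the helper name merge_intervals is hypothetical and not part of this patch, and max = None stands for an unbounded call count):

def merge_intervals(calls_x, calls_y):
    # Each entry maps a function name to a (min, max) call-count interval;
    # a missing entry counts as (0, 0), and max = None means unbounded.
    if calls_x == calls_y:
        return calls_x
    merged = {}
    for fname in set(calls_x) | set(calls_y):
        (min_x, max_x) = calls_x.get(fname, (0, 0))
        (min_y, max_y) = calls_y.get(fname, (0, 0))
        max_v = None if None in (max_x, max_y) else max(max_x, max_y)
        merged[fname] = (min(min_x, min_y), max_v)
    return merged

# One path called f exactly twice, the other exactly once: the merge is (1, 2).
assert merge_intervals({'f': (2, 2)}, {'f': (1, 1)}) == {'f': (1, 2)}
# A path through a loop keeps the maximum unbounded.
assert merge_intervals({'f': (1, None)}, {'f': (0, 3)}) == {'f': (0, None)}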
paren_w_re = re.compile (r"(\(|\)|\w+)") def mk_function_link_hyps (p, call_vis, tag, adjust_eq_seq = None): - (entry, _, args) = p.get_entry_details (tag) - ((call_site, restrs), call_tag) = call_vis - assert p.nodes[call_site].kind == 'Call' - entry_vis = ((entry, ()), p.node_tags[entry][0]) - - args = [syntax.mk_var (nm, typ) for (nm, typ) in args] - - pc = pc_true_hyp (call_vis) - eq_seq = logic.azip (p.nodes[call_site].args, args) - if adjust_eq_seq: - eq_seq = adjust_eq_seq (eq_seq) - hyps = [pc] + [eq_hyp ((x, call_vis), (y, entry_vis)) - for (x, y) in eq_seq - if x.typ.kind == 'Word' or x.typ == syntax.builtinTs['Mem'] - or x.typ.kind == 'WordArray'] - - return hyps + (entry, _, args) = p.get_entry_details (tag) + ((call_site, restrs), call_tag) = call_vis + assert p.nodes[call_site].kind == 'Call' + entry_vis = ((entry, ()), p.node_tags[entry][0]) + + args = [syntax.mk_var (nm, typ) for (nm, typ) in args] + + pc = pc_true_hyp (call_vis) + eq_seq = logic.azip (p.nodes[call_site].args, args) + if adjust_eq_seq: + eq_seq = adjust_eq_seq (eq_seq) + hyps = [pc] + [eq_hyp ((x, call_vis), (y, entry_vis)) + for (x, y) in eq_seq + if x.typ.kind == 'Word' or x.typ == syntax.builtinTs['Mem'] + or x.typ.kind == 'WordArray'] + + return hyps diff --git a/scripts/setup-HOL4.sh b/scripts/setup-HOL4.sh deleted file mode 100755 index d423318d..00000000 --- a/scripts/setup-HOL4.sh +++ /dev/null @@ -1,93 +0,0 @@ -#! /bin/bash - -# -# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) -# -# SPDX-License-Identifier: BSD-2-Clause -# - -function report_err { - echo .. failed! - echo Short error output: - echo - tail -n 20 $1 - echo - echo " (more error output in $1)" - exit 1 -} - -HOL4_SEARCH="$PWD" -while [ "$HOL4_SEARCH" != "/" -a ! -d "$HOL4_SEARCH/HOL4" ]; do - HOL4_SEARCH=$(dirname "$HOL4_SEARCH") -done -HOL4_DIR=$(readlink -f "$HOL4_SEARCH/HOL4") -if [ ! -d "$HOL4_DIR" ]; then - echo "No HOL4 found" - exit 1 -fi -echo "Setting up HOL4 in $HOL4_DIR" - -POLY_DIR=$HOL4_DIR/polyml -POLY=$POLY_DIR/deploy/bin/poly -if [[ -e $POLY ]] -then - echo PolyML already built. -elif [[ -e $POLY_DIR/configure ]] -then - echo Building PolyML in $POLY_DIR - echo ' (tracing build progress to poly_output.txt)' - OUT=$(readlink -f poly_output.txt) - pushd $POLY_DIR - echo ' (configuring)' - (./configure --prefix=$POLY_DIR/deploy) &> $OUT - echo ' (building)' - (make && make install) &>> $OUT - if [[ -e $POLY ]] - then - echo Built PolyML - else - report_err poly_output.txt - exit 1 - fi - popd -elif [[ -e $POLY_DIR ]] -then - echo Missing PolyML source in $POLY_DIR - exit 1 -else - echo No PolyML dir $POLY_DIR - exit 1 -fi - -# this script cleans any previous build of HOL4 -# this is needed when pulling in new revisions to the base system -OUT=$(readlink -f hol4_output.txt) -echo output is $OUT -pushd $HOL4_DIR - -echo Cleaning HOL4 build in $HOL4_DIR -git clean -fdX -e polyml &> /dev/null - -echo Building HOL4 now. -echo ' (tracing build progress to hol4_output.txt)' -echo ' (configuring)' -$POLY < tools-poly/smart-configure.sml &> $OUT - -if [[ ! -e $HOL4_DIR/bin/build ]] -then - report_err hol4_output.txt - exit 1 -fi - -echo ' (building)' -PATH=$HOL4_DIR/bin:$PATH build &>> $OUT - -if ( tail $OUT | grep 'built successfully' ) -then - echo 'Built HOL4.' 
-else - report_err hol4_output.txt - exit 1 -fi -popd - diff --git a/seL4-example/.gitignore b/seL4-example/.gitignore index 2114c602..3e789177 100644 --- a/seL4-example/.gitignore +++ b/seL4-example/.gitignore @@ -4,4 +4,5 @@ # SPDX-License-Identifier: BSD-2-Clause # +/build/ /target/ diff --git a/seL4-example/Makefile b/seL4-example/Makefile index d4502f79..2f47cf8d 100644 --- a/seL4-example/Makefile +++ b/seL4-example/Makefile @@ -10,123 +10,89 @@ # (e.g. standalone c-parser and decompiler) properly, so may not know to # rebuild if a custom tool is updated. -ifndef CONFIG_OPTIMISATION_LEVEL - CONFIG_OPTIMISATION_LEVEL := -O1 +ifndef KERNEL_CONFIG_OPTIMISATION_LEVEL + KERNEL_CONFIG_OPTIMISATION_LEVEL := -O1 endif -# FIXME: solver self-test is currently broken -SKIP_SOLV_TEST := SKIP +L4V_CONFIG := ${L4V_ARCH}$(if ${L4V_FEATURES},-${L4V_FEATURES},) +TARGET_NAME := ${L4V_CONFIG}${KERNEL_CONFIG_OPTIMISATION_LEVEL} -ifndef GREF_ROOT - GREF_ROOT := $(realpath $(dir $(lastword ${MAKEFILE_LIST}))..) +ifndef TARGET_DIR + TARGET_DIR := target/${TARGET_NAME} endif -ifndef HOL4_ROOT - HOL4_ROOT := $(realpath ${GREF_ROOT}/../HOL4) +ifndef GRAPH_REFINE_ROOT + GRAPH_REFINE_ROOT := $(realpath $(dir $(lastword ${MAKEFILE_LIST}))..) endif -L4V_CONFIG := ${L4V_ARCH}$(if ${L4V_FEATURES},-${L4V_FEATURES},) -TARGET_NAME := ${L4V_CONFIG}${CONFIG_OPTIMISATION_LEVEL} -TARGET_DIR := target/${TARGET_NAME} +ifndef DECOMPILER + DECOMPILER := ${GRAPH_REFINE_ROOT}/decompiler/decompile +endif + +ifndef L4V_KERNEL_MK + L4V_KERNEL_MK := $(realpath ${GRAPH_REFINE_ROOT}/../l4v/spec/cspec/c/kernel.mk) +endif + +ifndef KERNEL_BUILD_ROOT + KERNEL_BUILD_ROOT=${GRAPH_REFINE_ROOT}/seL4-example/build/${TARGET_NAME} +endif + +KERNEL_BUILD_EXPORT_DIR := ${TARGET_DIR} # We build our own kernel locally, so we can store builds # according to their optimisation levels. -KERNEL_BUILD_ROOT := ${TARGET_DIR}/build -KERNEL_CMAKE_EXTRA_OPTIONS := -DKernelOptimisation=${CONFIG_OPTIMISATION_LEVEL} -include ${GREF_ROOT}/../l4v/spec/cspec/c/kernel.mk +# This depends on KERNEL_BUILD_ROOT and KERNEL_BUILD_EXPORT_DIR. +include ${L4V_KERNEL_MK} # However, CFunctions.txt depends on l4v's kernel build. -# FIXME: the l4v build directory should really depend on L4V_FEATURES. L4V_KERNEL_BUILD_DIR := build/${L4V_ARCH} L4V_KERNEL_BUILD_PATH := ${CSPEC_DIR}/c/${L4V_KERNEL_BUILD_DIR} -DECOMP_DIR := ${HOL4_ROOT}/examples/machine-code/graph -DECOMP_SCRIPT := $(shell PATH="${DECOMP_DIR}:${PATH}" sh -c "which decompile.py") - -# sanity test configuration - -$(if ${DECOMP_SCRIPT},,$(error decompile.py not executable in ${DECOMP_DIR})) - -$(if $(wildcard ${HOL4_ROOT}/bin/Holmake ${HOL4_ROOT}/bin/build),, \ - $(error Holmake/build not found in ${HOL4_ROOT}/bin - first configure HOL4. \ - See INSTALL in HOL4, but skip the bin/build step)) - -SOLV=python ${GREF_ROOT}/solver.py - -SOLV_TEST_SUCC := 'Solver self-test succ' -SOLV_TEST := $(shell $(if ${SKIP_SOLV_TEST}, echo ${SOLV_TEST_SUCC}, \ - ${SOLV} testq) | grep ${SOLV_TEST_SUCC}) -$(if ${SOLV_TEST},,$(error Solver self-test failed (${SOLV} test))) - -# compile and decompile - -${TARGET_DIR}/summary.txt: ${TARGET_DIR}/kernel_all.c_pp - echo Summary > pre_summary.txt - bash mk_summ ${SOURCE_ROOT} >> pre_summary.txt - bash mk_summ ${L4V_REPO_PATH} >> pre_summary.txt - bash mk_summ ${HOL4_ROOT} >> pre_summary.txt - bash mk_summ . 
 >> pre_summary.txt - mv pre_summary.txt summary.txt - -KERNEL_FILES := kernel.elf.rodata kernel.elf.txt kernel.elf.symtab kernel_all.c_pp kernel.sigs kernel.elf -TARGET_FILES := target.py CFunctions.txt ASMFunctions.txt - -KERNEL_PATHS := $(patsubst %, $(TARGET_DIR)/%, $(KERNEL_FILES)) -TARGET_PATHS := $(patsubst %, $(TARGET_DIR)/%, $(TARGET_FILES)) +${DECOMPILER}: + @echo "No decompiler found. Use setup-decompiler.py to install one." >&2 + @exit 1 -KERNEL_TGZ := ${TARGET_DIR}/kernel.tar.gz -TARGET_TGZ := ${TARGET_DIR}/target.tar.gz +# Compile and decompile -${KERNEL_TGZ}: ${KERNEL_PATHS} - tar -czf $@ -C ${TARGET_DIR} ${KERNEL_FILES} - -${TARGET_TGZ}: ${KERNEL_PATHS} ${TARGET_PATHS} - tar -czf $@ -C ${TARGET_DIR} ${KERNEL_FILES} ${TARGET_FILES} - -tar: ${KERNEL_TGZ} ${TARGET_TGZ} - -${KERNEL_PATHS}: ${TARGET_DIR}/%: ${KERNEL_BUILD_ROOT}/% - @mkdir -p ${TARGET_DIR} - cp $< $@ - -clean: - rm -rf build kernel.elf.* kernel_all* kernel.tar* - -.PHONY: clean tar - -H4PATH := $(realpath ${HOL4_ROOT}/bin):${PATH} - -IGNORES_ARM := restore_user_context,c_handle_fastpath_call,c_handle_fastpath_reply_recv -IGNORES_RISCV64 := # TODO - -KERNEL_ALL_PP_FILES := ${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp ${KERNEL_BUILD_ROOT}/kernel_all.c_pp +${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp: ${KERNEL_DEPS} ${CONFIG_DOMAIN_SCHEDULE} + MAKEFILES= make -C ${CSPEC_DIR}/c ${L4V_KERNEL_BUILD_DIR}/kernel_all.c_pp -# FIXME: This should be a prerequisite of some other essential target, -# but for convenience during development, it is currently not. -${TARGET_DIR}/.diff: ${KERNEL_ALL_PP_FILES} - diff -q --ignore-matching-lines='^#' ${KERNEL_ALL_PP_FILES} +# A quick check that the two kernel builds produce the same C source +${TARGET_DIR}/.diff: ${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp ${KERNEL_BUILD_ROOT}/kernel_all.c_pp + diff -q --ignore-matching-lines='^#' $^ @mkdir -p ${TARGET_DIR} @touch $@ diff: ${TARGET_DIR}/.diff .PHONY: diff -${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp: ${KERNEL_DEPS} ${CONFIG_DOMAIN_SCHEDULE} - MAKEFILES= make -C ${CSPEC_DIR}/c ${L4V_KERNEL_BUILD_DIR}/kernel_all.c_pp +ASM_FUNCTIONS_DEPS := ${TARGET_DIR}/kernel.elf.txt ${TARGET_DIR}/kernel.sigs ${DECOMPILER} ${AUIPC_FIXUP} +ASM_FUNCTIONS := ${TARGET_DIR}/ASMFunctions.txt +ASM_FUNCTIONS_OUT := ${ASM_FUNCTIONS} ${TARGET_DIR}/kernel_output.txt ${TARGET_DIR}/StackBounds.txt + +IGNORE_ARM := _start,c_handle_fastpath_call,c_handle_fastpath_reply_recv,restore_user_context -${TARGET_DIR}/ASMFunctions.txt: ${TARGET_DIR}/kernel.elf.txt ${TARGET_DIR}/kernel.sigs - cd ${TARGET_DIR} && PATH=${H4PATH} ${DECOMP_SCRIPT} --fast ./kernel --ignore=${IGNORES_${L4V_CONFIG}} - mv ${TARGET_DIR}/kernel_mc_graph.txt ${TARGET_DIR}/ASMFunctions.txt +${ASM_FUNCTIONS_OUT} &: ${ASM_FUNCTIONS_DEPS} + ${DECOMPILER} ${TARGET_DIR}/kernel --ignore=${IGNORE_${L4V_ARCH}} + ${GRAPH_REFINE_ROOT}/seL4-example/functions-tool.py \ + --target-dir ${TARGET_DIR} \ + --asm-functions-out ASMFunctions.txt \ + --stack-bounds-out StackBounds.txt -${TARGET_DIR}/CFunctions.txt: ${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp ${L4V_REPO_PATH}/tools/asmrefine/*.thy +C_FUNCTIONS_DEPS := \ + ${L4V_KERNEL_BUILD_PATH}/kernel_all.c_pp \ + ${L4V_REPO_PATH}/tools/asmrefine/*.thy \ + ${L4V_REPO_PATH}/tools/asmrefine/${L4V_ARCH}/*.thy + +${TARGET_DIR}/CFunctions.txt: ${C_FUNCTIONS_DEPS} @mkdir -p ${TARGET_DIR} MAKEFILES= make -C ${L4V_REPO_PATH}/proof/ SimplExport - # FIXME: the following path should really depend on L4V_FEATURES.
cp ${L4V_REPO_PATH}/proof/asmrefine/export/${L4V_ARCH}/CFunDump.txt $@ -${TARGET_DIR}/target.py: target.py +TARGET_PY := ${GRAPH_REFINE_ROOT}/seL4-example/target-${L4V_ARCH}.py +${TARGET_DIR}/target.py: ${TARGET_PY} @mkdir -p ${TARGET_DIR} - cp target.py $@ + cp ${TARGET_PY} $@ GRAPH_REFINE_INPUTS := \ ${TARGET_DIR}/kernel.elf.rodata \ @@ -134,12 +100,9 @@ GRAPH_REFINE_INPUTS := \ ${TARGET_DIR}/ASMFunctions.txt \ ${TARGET_DIR}/CFunctions.txt \ ${TARGET_DIR}/target.py \ - ${GREF_ROOT}/*.py - -GRAPH_REFINE := python ${GREF_ROOT}/graph-refine.py + ${GRAPH_REFINE_ROOT}/*.py -${TARGET_DIR}/StackBounds.txt: ${GRAPH_REFINE_INPUTS} - ${GRAPH_REFINE} ${TARGET_DIR} +GRAPH_REFINE := python2 ${GRAPH_REFINE_ROOT}/graph-refine.py ${TARGET_DIR}/demo-report.txt: ${TARGET_DIR}/StackBounds.txt ${GRAPH_REFINE_INPUTS} ${GRAPH_REFINE} ${TARGET_DIR} trace-to:$@.partial deps:Kernel_C.cancelAllIPC @@ -166,7 +129,7 @@ default: report # WCET (worst-case execution time) targets -GTG := ${GREF_ROOT}/graph-to-graph/ +GTG := ${GRAPH_REFINE_ROOT}/graph-to-graph/ TARGET_DIR_ABS := $(realpath TARGET_DIR) ${TARGET_DIR}/loop_counts_1.py: ${TARGET_DIR}/StackBounds.txt ${GRAPH_REFINE_INPUTS} diff --git a/seL4-example/configure_default.sh b/seL4-example/configure_default.sh deleted file mode 100755 index 18cd6ac8..00000000 --- a/seL4-example/configure_default.sh +++ /dev/null @@ -1,108 +0,0 @@ -#! /bin/bash - -# -# Copyright 2020, Data61, CSIRO (ABN 41 687 119 230) -# -# SPDX-License-Identifier: BSD-2-Clause -# - -# setup Isabelle -ISABELLE=../../isabelle/bin/isabelle -if [[ -e ~/.isabelle/etc/settings ]] -then - ISA_PLAT=$($ISABELLE env bash -c 'echo $ML_PLATFORM') - if echo $ISA_PLAT | grep -q 64 - then - echo Isabelle platform is $ISA_PLAT - else - echo Isabelle platform $ISA_PLAT not 64-bit - echo Will not be able to build seL4 C models. - echo Reconfigure in ~/.isabelle/etc/settings - exit 1 - fi -else - echo No Isabelle settings, setting defaults. - mkdir -p ~/.isabelle/etc/ - cp ../../l4v/misc/etc/settings ~/.isabelle/etc/ -fi -$ISABELLE components -a - -HOL4_DIR=$(readlink -f ../../HOL4) -POLY_DIR=$HOL4_DIR/polyml - -function mk_build_summ { - SUMM=$(readlink -f $HOL4_DIR/bin/$1) - pushd $HOL4_DIR - echo Results of '"git show"' in $HOL4_DIR before building. > $SUMM - git show >> $SUMM - popd - echo Results of '"find $POLY_DIR/deploy -type f | xargs md5sum"' before building. >> $SUMM - find $POLY_DIR/deploy -type f | xargs md5sum >> $SUMM -} - -# check if HOL4 is built, and if it is up to date -if [[ ! -e $HOL4_DIR/bin/build ]] -then - echo 'HOL4 not built.' - REBUILD='true' -elif [[ ! -e $HOL4_DIR/sigobj/realTheory.sig ]] -then - echo 'HOL4 theories not built' - REBUILD='true' -elif [[ -e $HOL4_DIR/bin/build_summ ]] -then - mk_build_summ build_summ2 - if ( diff -q $HOL4_DIR/bin/build_summ $HOL4_DIR/bin/build_summ2 ) - then - # curiously this is the equal case of diff - echo HOL4 matches last build status. - REBUILD='false' - else - echo HOL4 configuration changed from previous build. - REBUILD='true' - fi -fi - -if [[ $REBUILD == 'true' ]] -then - bash ../scripts/setup-HOL4.sh -fi - -# setup graph-refine to use CVC4 from Isabelle -SVFL=../../.solverlist -if python ../solver.py testq | grep -q 'Solver self-test succ' -then - echo Solvers already configured. -else if [[ -e $SVFL ]] -then - echo Solvers configured but self-test failed. - echo Try python ../solver.py test - echo and adjust $SVFL to succeed. 
- exit 1 -else - ISA_CVC4=$($ISABELLE env bash -c 'echo $CVC4_SOLVER') - echo '# minimum autogenerated .solverlist' > $SVFL - echo CVC4: online: $ISA_CVC4 --incremental --lang smt --tlimit=5000 >> $SVFL - echo CVC4: offline: $ISA_CVC4 --lang smt >> $SVFL - echo Configured graph-refine to use CVC4 SMT solver. - if python ../solver.py testq | grep -q 'Solver self-test succ' - then - echo Self test passed. - else - echo Self test failed! - echo Try python ../solver.py test - echo and adjust $SVFL to succeed. - exit 1 - fi -fi fi - -if which mlton -then - echo MLton available. -else - echo MLton not available or not found. - echo e.g. 'which mlton' should succeed. - exit 1 -fi - - diff --git a/seL4-example/default.nix b/seL4-example/default.nix new file mode 100644 index 00000000..be1e1654 --- /dev/null +++ b/seL4-example/default.nix @@ -0,0 +1,210 @@ +# Copyright (c) 2022, Kry10 Limited. +# SPDX-License-Identifier: BSD-2-Clause + +# Packages the decompiler with some extra tools and scripts, to make Docker +# images for producing graph-refine inputs other than CFunctions.txt. +# These can be useful when some other process (e.g. GitHub CI) is used to +# generate CFunctions.txt. +# - sel4-decompile can be used to perform decompilation and stack analysis, +# assuming the kernel has already been compiled. +# - sel4-compile-decompile can additionally perform kernel compilation. + +# We assume that PolyML and HOL4 sources have been checked out. +# These can be checked out using: +# ../decompiler/setup-decompiler checkout --upstream +# These are used to pre-build a decompiler. + +# For export-kernel-builds, we also assume l4v and isabelle checkouts. +# These are used to pre-build a standalone C parser. + +{ + l4v_src ? ../../l4v, + isabelle_src ? ../../isabelle, + polyml_src ? ../decompiler/src/polyml, + hol4_src ? ../decompiler/src/HOL4, +}: + +let + + pins = import ../nix/pins.nix; + inherit (pins) pkgs lib stdenv rosetta-pkgs; + inherit (pins.herculesGitignore) gitignoreFilter; + + inherit (import ../nix/util.nix) explicit_sources_filter explicit_sources conj2 mk-check-env; + inherit (import ../nix/sel4-deps.nix) sel4-deps; + + inherit (import ../decompiler { inherit polyml_src hol4_src; }) decompile-bin; + + isabelle-table-ml = explicit_sources "isabelle-table-ml" isabelle_src [ + "src/Pure/General/table.ML" + ]; + + graph-refine-seL4 = explicit_sources "graph-refine-seL4" ./. [ + "functions-tool.py" + "target-ARM.py" + "target-RISCV64.py" + ]; + + # The standalone C parser is needed to produce kernel.sigs. 
+ mk-arch-c-parser = arch: stdenv.mkDerivation { + name = "c-parser-${arch}"; + src = lib.cleanSourceWith { + name = "c-parser-src"; + src = l4v_src; + filter = conj2 [ (gitignoreFilter l4v_src) + (explicit_sources_filter l4v_src [ "tools/c-parser" ]) ]; + }; + nativeBuildInputs = [ rosetta-pkgs.mlton pkgs.perl ]; + buildPhase = '' + ISABELLE_HOME=${isabelle-table-ml} \ + SML_COMPILER=mlton \ + GLOBAL_MAKES_INCLUDED=true \ + make -C tools/c-parser/standalone-parser \ + "$PWD/tools/c-parser/standalone-parser/${arch}/c-parser" + ''; + installPhase = '' + mkdir -p "$out/bin" + cp -a "tools/c-parser/standalone-parser/${arch}/c-parser" "$out/bin" + ''; + }; + + perl = "${pkgs.perl}/bin/perl"; + + sel4-decompile-python = + let python = pkgs.python3.withPackages (p: with p; [networkx]); + in "${python}/bin/python"; + + sel4-decompile-arch-cmd = arch: + let + ignore = builtins.readFile (pkgs.runCommand "sel4-decompile-${arch}" {} '' + ${perl} -ne 'print $1 if /^IGNORE_${arch}\s*[\?:]?=\s*(\S*)\s*$/' ${./Makefile} > $out + ''); + decompile = if ignore == "" + then ''"${decompile-bin}/bin/decompile"'' + else ''"${decompile-bin}/bin/decompile" --ignore "${ignore}"''; + in decompile; + + mk-arches = arches: rec { + deriv_args = text: { + inherit text; + passthru = { inherit arches; }; + passAsFile = [ "text" ]; + }; + + writeScriptBin = name: text: pkgs.runCommand name (deriv_args text) '' + mkdir -p "$out/bin" + mv "$textPath" "$out/bin/${name}" + chmod +x "$out/bin/${name}" + ''; + + writeScript = name: text: pkgs.runCommand name (deriv_args text) '' + mv "$textPath" "$out" + chmod +x "$out" + ''; + + c-parser = + let + arch-case = arch: ''${arch}) exec ${mk-arch-c-parser arch}/bin/c-parser "$@";;''; + script = writeScriptBin "c-parser" '' + #!${pkgs.runtimeShell} + set -euo pipefail + case "$L4V_ARCH" in + ${lib.concatStringsSep "\n " (map arch-case arches)} + *) echo "error: unknown L4V_ARCH: $L4V_ARCH" >&2; exit 1;; + esac + ''; + in script; + + # A wrapper that runs a given command in an environment that has + # everything needed for building seL4 using the l4v kernel make files, + # including a pre-built standalone C parser. + sel4-build-env = + let + deriv_args = { nativeBuildInputs = [ pkgs.makeBinaryWrapper ]; }; + script = writeScript "sel4-build-env-exec" '' + #!${pkgs.runtimeShell} + exec "$@" + ''; + wrapper = pkgs.runCommand "sel4-build-env" deriv_args '' + makeWrapper "${script}" "$out" \ + --set PATH "${lib.makeBinPath sel4-deps}" \ + --set STANDALONE_C_PARSER_EXE "${c-parser}/bin/c-parser" + ''; + in wrapper; + + sel4-decompile-script = let writeScriptDefault = writeScript; in + { name ? "sel4-decompile", writeScript ? writeScriptDefault, decompile ? true }: + let + arch-case = arch: if decompile + then ''${arch}) ${sel4-decompile-arch-cmd arch} "$TARGET_DIR/kernel";;'' + else ''${arch}) ;;''; + + script = '' + #!${pkgs.runtimeShell} + set -euo pipefail + + if [ $# -eq 0 ]; then + echo "${name}: error: no arguments" >&2 + echo "${name} usage: ${name} TARGET_DIR..." >&2 + exit 1 + fi + + for TARGET_DIR in "$@"; do + if [ ! 
-f "$TARGET_DIR/config.env" ]; then + echo "${name}: error: $TARGET_DIR/config.env does not exist" >&2 + exit 1 + fi + done + + for TARGET_DIR in "$@"; do + export L4V_ARCH=$("${perl}" -ne 'print $1 if /^L4V_ARCH=(\S+)$/' "$TARGET_DIR/config.env") + + case "$L4V_ARCH" in + ${lib.concatStringsSep "\n " (map arch-case arches)} + *) echo "${name}: error: unknown L4V_ARCH in $TARGET_DIR/config.env: $L4V_ARCH" >&2; exit 1;; + esac + + if [ -f "$TARGET_DIR/CFunctions.txt" ]; then + FUNCTIONS_ARG="--functions-list-out functions-list.txt" + else + FUNCTIONS_ARG="" + fi + + "${sel4-decompile-python}" "${graph-refine-seL4}/functions-tool.py" \ + --target-dir "$TARGET_DIR" \ + --asm-functions-out ASMFunctions.txt \ + --stack-bounds-out StackBounds.txt \ + $FUNCTIONS_ARG + + cp "${graph-refine-seL4}/target-$L4V_ARCH.py" "$TARGET_DIR/target.py" + done + ''; + in writeScript name script; + + # For interactive use. + sel4-decompile = sel4-decompile-script {}; + + # For the Docker image. + sel4-decompile-bin = sel4-decompile-script { writeScript = writeScriptBin; }; + + # For development, skips the actual decompile, and just does the stuff after it. + sel4-decompile-post = sel4-decompile-script { name = "sel4-decompile-post"; decompile = false; }; + + sel4-decompiler-image = pkgs.dockerTools.streamLayeredImage { + name = "sel4-decompiler"; + contents = with pkgs; [ bashInteractive coreutils sel4-decompile-bin ]; + config = { EntryPoint = [ "${sel4-decompile-bin}/bin/sel4-decompile" ]; }; + }; + }; + +in { + + inherit (mk-arches [ "ARM" "RISCV64" ]) + c-parser + sel4-build-env + sel4-decompile + sel4-decompile-bin + sel4-decompile-post + sel4-decompiler-image; + +} diff --git a/seL4-example/functions-tool.py b/seL4-example/functions-tool.py new file mode 100755 index 00000000..f59354df --- /dev/null +++ b/seL4-example/functions-tool.py @@ -0,0 +1,977 @@ +#!/usr/bin/env python3 + +# Copyright 2021, Data61, CSIRO (ABN 41 687 119 230) +# Copyright 2023, Kry10 Limited +# SPDX-License-Identifier: BSD-2-Clause + +# Perform various minor fixups and analyses on the outputs of l4v +# and decompilation, before passing them as inputs to graph-refine: +# - Generate StackBounds.txt using a simplified stack usage analysis +# (see below). +# - Generate functions-list.txt containing the list of function names +# that are common to CFunctions.txt and ASMFunctions.txt. +# - RISCV64 ASMFunctions.txt fixups: +# - Fix the values loaded by AUIPC instructions, since the decompiler +# currently truncates these to 32 bits. +# - Rename `sfence_vma` to `sfence.vma`. + +import argparse +import os +import re +import sys + +from pathlib import Path +from typing import (Callable, Iterable, Iterator, Mapping, NamedTuple, + Optional, Sequence, Set, TextIO, Tuple, TypeVar, Union, Protocol) + + +K = TypeVar('K') +V = TypeVar('V') +R = TypeVar('R') +T = TypeVar('T') +L = TypeVar('L') + +Elim = Callable[[T], R] + + +# Exceptions that may occur during parsing. + +class UnexpectedEofError(Exception): + pass + + +class BadFileFormatError(Exception): + pass + + +class UnexpectedInput(Exception): + pass + + +class DuplicateNode(Exception): + pass + + +class DuplicateFunction(Exception): + pass + + +class MalformedFunction(Exception): + pass + + +class MalformedInstruction(Exception): + pass + + +# A simple parser for graph-lang files. + +class GraphLangNode(Protocol): + def elim(self, basic: Elim['BasicNode', R], call: Elim['CallNode', R], + cond: Elim['CondNode', R]) -> R: + ... 
+ + +class BasicNode(NamedTuple): + succ: str + assign: str + + def elim(self, basic: Elim['BasicNode', R], call: Elim['CallNode', R], + cond: Elim['CondNode', R]) -> R: + return basic(self) + + +class CallNode(NamedTuple): + succ: str + callee: str + args: str + + def elim(self, basic: Elim['BasicNode', R], call: Elim['CallNode', R], + cond: Elim['CondNode', R]) -> R: + return call(self) + + +class CondNode(NamedTuple): + succ_true: str + succ_false: str + expr: str + + def elim(self, basic: Elim['BasicNode', R], call: Elim['CallNode', R], + cond: Elim['CondNode', R]) -> R: + return cond(self) + + +GraphLangNodes = Mapping[str, GraphLangNode] + + +class GraphLangFunction(NamedTuple): + args: str + nodes: GraphLangNodes + entry: Optional[str] + + +GraphLang = Mapping[str, GraphLangFunction] + + +def parse_graph_lang(file_path: Path, structs: bool) -> GraphLang: + functions: dict[str, GraphLangFunction] = {} + + cur_fn_name: Optional[str] = None + cur_fn_args: Optional[str] = None + cur_fn_nodes: dict[str, GraphLangNode] = {} + cur_fn_entry: Optional[str] = None + + function_re = re.compile(r'Function (?P<name>\S+) (?P<args>\S.*)') + basic_re = re.compile(r'(?P