Adds Windows support #53

Open · wants to merge 3 commits into master
2 changes: 1 addition & 1 deletion human_eval/data.py
@@ -5,7 +5,7 @@


ROOT = os.path.dirname(os.path.abspath(__file__))
HUMAN_EVAL = os.path.join(ROOT, "..", "data", "HumanEval.jsonl.gz")
HUMAN_EVAL = os.path.join(ROOT, "data", "HumanEval.jsonl.gz")


def read_problems(evalset_file: str = HUMAN_EVAL) -> Dict[str, Dict]:
3 files renamed without changes.
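With the leading ".." dropped, HUMAN_EVAL now resolves to a data directory inside the human_eval package itself (the three renamed files above are presumably the dataset files moved accordingly), so an installed copy finds HumanEval.jsonl.gz without relying on the source-checkout layout. A quick sanity check, not part of this PR's diff:

    from human_eval.data import HUMAN_EVAL, read_problems

    print(HUMAN_EVAL)           # .../human_eval/data/HumanEval.jsonl.gz under the new layout
    problems = read_problems()  # dict keyed by task_id, e.g. "HumanEval/0"
    print(len(problems))        # 164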
4 changes: 3 additions & 1 deletion human_eval/evaluation.py
@@ -42,6 +42,7 @@ def evaluate_functional_correctness(
n_workers: int = 4,
timeout: float = 3.0,
problem_file: str = HUMAN_EVAL,
ignore_incomplete: bool = False
):
"""
Evaluates the functional correctness of generated samples, and writes
@@ -68,7 +69,8 @@
completion_id[task_id] += 1
n_samples += 1

assert len(completion_id) == len(problems), "Some problems are not attempted."
if not ignore_incomplete:
assert len(completion_id) == len(problems), "Some problems are not attempted."

print("Running test suites...")
for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
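The new ignore_incomplete flag relaxes the check that every problem has at least one completion, which helps when scoring a partial sample file (for example, while a generation run is still in progress). A hedged usage sketch; the positional sample-file argument is assumed from the upstream signature, which this hunk does not show:

    from human_eval.evaluation import evaluate_functional_correctness

    # Score a JSONL file that covers only some tasks. Without
    # ignore_incomplete=True this would hit the "Some problems are not
    # attempted." assertion.
    results = evaluate_functional_correctness(
        "partial_samples.jsonl",   # hypothetical file name
        n_workers=4,
        timeout=3.0,
        ignore_incomplete=True,
    )
    print(results)                 # pass@k estimates for the attempted problems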
260 changes: 179 additions & 81 deletions human_eval/execution.py
@@ -1,4 +1,4 @@
from typing import Optional, Callable, Dict
from typing import Optional, Callable, Dict, List
import ast
import contextlib
import faulthandler
@@ -8,73 +8,95 @@
import platform
import signal
import tempfile


def check_correctness(problem: Dict, completion: str, timeout: float,
completion_id: Optional[int] = None) -> Dict:
import threading
import ctypes

IS_WINDOWS = platform.system() == "Windows"


def unsafe_execute(problem: Dict, completion: str, timeout: float, result: List):
"""Runs untrusted code in a sandboxed environment."""
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil

rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir

# Disable functionalities that can make destructive changes to the test.
reliability_guard()

# Construct the check program and run it.
check_program = (
problem["prompt"]
+ completion
+ "\n"
+ problem["test"]
+ "\n"
+ f"check({problem['entry_point']})"
)

try:
exec_globals = {}
with swallow_io():
with time_limit(timeout):
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network. For more
# information on how OpenAI sandboxes its code, see the accompanying paper.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
exec(check_program, exec_globals)
result.append("passed")
except TimeoutException:
result.append("timed out")
except BaseException as e:
result.append(f"failed: {e}")

# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir


def check_correctness(
problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None
) -> Dict:
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.

:param completion_id: an optional completion ID so we can match
the results later even if execution finishes asynchronously.
suite provided in the problem.
"""
manager = multiprocessing.Manager()
result = manager.list()

def unsafe_execute():

with create_tempdir():

# These system calls are needed when cleaning up tempdir.
import os
import shutil
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir

# Disable functionalities that can make destructive changes to the test.
reliability_guard()

# Construct the check program and run it.
check_program = (
problem["prompt"] + completion + "\n" +
problem["test"] + "\n" +
f"check({problem['entry_point']})"
if IS_WINDOWS:
ctx = multiprocessing.get_context("spawn")
with ctx.Pool(processes=1) as pool:
# Use apply_async with a timeout instead of Process
async_result = pool.apply_async(
unsafe_execute, args=(problem, completion, timeout, result)
)

try:
exec_globals = {}
with swallow_io():
with time_limit(timeout):
# WARNING
# This program exists to execute untrusted model-generated code. Although
# it is highly unlikely that model-generated code will do something overtly
# malicious in response to this test suite, model-generated code may act
# destructively due to a lack of model capability or alignment.
# Users are strongly encouraged to sandbox this evaluation suite so that it
# does not perform destructive actions on their host or network. For more
# information on how OpenAI sandboxes its code, see the accompanying paper.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
# exec(check_program, exec_globals)
result.append("passed")
except TimeoutException:
async_result.get(timeout=timeout + 1)
except multiprocessing.TimeoutError:
result.append("timed out")
except BaseException as e:
result.append(f"failed: {e}")

# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir

manager = multiprocessing.Manager()
result = manager.list()

p = multiprocessing.Process(target=unsafe_execute)
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
pool.terminate() # Ensure process is terminated on timeout
pool.join() # Wait for cleanup
else:
p = multiprocessing.Process(
target=unsafe_execute, args=(problem, completion, timeout, result)
)
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.terminate()
p.join()

if not result:
result.append("timed out")
@@ -87,16 +109,54 @@ def unsafe_execute():
)
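Hoisting unsafe_execute out of check_correctness is what makes the Windows path work at all: Windows has no fork(), so the "spawn" start method imports the module in a fresh interpreter and pickles the target callable plus its arguments, and a closure defined inside another function cannot be pickled. A standalone sketch of that constraint, separate from this PR's code:

    import multiprocessing


    def square_into(x, out):
        # Module-level functions are picklable, so "spawn" can ship them to
        # the child process; nested functions are not.
        out.append(x * x)


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        with multiprocessing.Manager() as manager:
            out = manager.list()   # proxy object, safe to pass to the child
            p = ctx.Process(target=square_into, args=(3, out))
            p.start()
            p.join()
            print(list(out))       # [9]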


@contextlib.contextmanager
def time_limit(seconds: float):
def _windows_time_limit(seconds: float):
"""Windows-specific timeout implementation using threading."""
done = threading.Event()
raised = threading.Event()
target_thread_id = threading.current_thread().ident

def timeout_function():
if not done.is_set() and target_thread_id:
raised.set()
ctypes.pythonapi.PyThreadState_SetAsyncExc(
target_thread_id, ctypes.py_object(TimeoutException)
)

timer = threading.Timer(seconds, timeout_function)
timer.start()
try:
yield
finally:
done.set()
timer.cancel()
timer.join()
# If we raised but didn't catch, clear the exception state
if raised.is_set():
ctypes.pythonapi.PyThreadState_SetAsyncExc(target_thread_id, None)


def _unix_time_limit(seconds: float):
"""Unix-specific timeout implementation using SIGALRM."""

def signal_handler(signum, frame):
raise TimeoutException("Timed out!")

old_handler = signal.signal(signal.SIGALRM, signal_handler)
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
signal.signal(signal.SIGALRM, old_handler)


@contextlib.contextmanager
def time_limit(seconds: float):
"""Cross-platform timeout context manager."""
if IS_WINDOWS:
yield from _windows_time_limit(seconds)
else:
yield from _unix_time_limit(seconds)
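time_limit is now a thin cross-platform dispatcher: SIGALRM-based on Unix, and on Windows a timer thread that injects TimeoutException into the worker thread via ctypes.pythonapi.PyThreadState_SetAsyncExc. One caveat with the Windows path: an async exception is only delivered between Python bytecode instructions, so a single long-running C-level call may not be interrupted promptly; the process-level timeout in check_correctness remains the backstop. A usage sketch, assuming the names are importable from human_eval.execution:

    from human_eval.execution import TimeoutException, time_limit

    try:
        with time_limit(1.0):
            total = 0
            while True:        # pure-Python loop, so the timeout can land here
                total += 1
    except TimeoutException:
        print("timed out as expected")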


@contextlib.contextmanager
@@ -120,7 +180,7 @@ class TimeoutException(Exception):


class WriteOnlyStringIO(io.StringIO):
""" StringIO that throws an exception when it's read from """
"""StringIO that throws an exception when it's read from"""

def read(self, *args, **kwargs):
raise IOError
@@ -132,12 +192,12 @@ def readlines(self, *args, **kwargs):
raise IOError

def readable(self, *args, **kwargs):
""" Returns True if the IO object can be read. """
"""Returns True if the IO object can be read."""
return False


class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = 'stdin'
_stream = "stdin"


@contextlib.contextmanager
@@ -163,26 +223,61 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):

WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
generated code, should not be blindly executed outside of one. See the
Codex paper for more information about OpenAI's code sandbox, and proceed
with caution.
"""

if maximum_memory_bytes is not None:
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if not platform.uname().system == 'Darwin':
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
if not IS_WINDOWS:
try:
import resource

resource.setrlimit(
resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
)
resource.setrlimit(
resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
)
if not platform.uname().system == "Darwin":
resource.setrlimit(
resource.RLIMIT_STACK,
(maximum_memory_bytes, maximum_memory_bytes),
)
except ImportError:
pass
else:
# On Windows, use job objects to limit memory
import win32job
import win32api

job = win32job.CreateJobObject(None, "")
process = win32api.GetCurrentProcess()

# Set memory limit
info = win32job.QueryInformationJobObject(
job, win32job.JobObjectExtendedLimitInformation
)
info["BasicLimitInformation"][
"LimitFlags"
] |= win32job.JOB_OBJECT_LIMIT_PROCESS_MEMORY
info["ProcessMemoryLimit"] = maximum_memory_bytes
win32job.SetInformationJobObject(
job, win32job.JobObjectExtendedLimitInformation, info
)

# Assign current process to job
win32job.AssignProcessToJobObject(job, process)

faulthandler.disable()

import builtins

builtins.exit = None
builtins.quit = None

import os
os.environ['OMP_NUM_THREADS'] = '1'

os.environ["OMP_NUM_THREADS"] = "1"

os.kill = None
os.system = None
@@ -213,18 +308,21 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
os.chdir = None

import shutil

shutil.rmtree = None
shutil.move = None
shutil.chown = None

import subprocess

subprocess.Popen = None # type: ignore

__builtins__['help'] = None
__builtins__["help"] = None

import sys
sys.modules['ipdb'] = None
sys.modules['joblib'] = None
sys.modules['resource'] = None
sys.modules['psutil'] = None
sys.modules['tkinter'] = None

sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
tqdm
fire
numpy
pywin32>=306; platform_system == "Windows"
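The pywin32 line uses a PEP 508 environment marker, so pip installs it only when platform_system == "Windows"; Linux and macOS environments are left untouched. If the same gating were wanted at package-install time, an equivalent marker could be added to setup.py's install_requires, for example (a sketch; the surrounding list is assumed from requirements.txt, not shown in this diff):

    install_requires=[
        "tqdm",
        "fire",
        "numpy",
        'pywin32>=306; platform_system == "Windows"',
    ],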
4 changes: 2 additions & 2 deletions setup.py
@@ -7,7 +7,7 @@
setup(
name="human-eval",
py_modules=["human-eval"],
version="1.0",
version="1.0.4",
description="",
author="OpenAI",
packages=find_packages(),
@@ -19,7 +19,7 @@
],
entry_points={
"console_scripts": [
"evaluate_functional_correctness = human_eval.evaluate_functional_correctness",
"evaluate_functional_correctness = human_eval.evaluate_functional_correctness:main",
]
}
)
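The console_scripts fix matters because entry points must name a callable in module:attr form; pointing at the bare module is not a valid console-script target. A rough sketch of how the referenced main typically wires up with fire (the project's CLI dependency); the real human_eval/evaluate_functional_correctness.py may differ:

    # human_eval/evaluate_functional_correctness.py (sketch)
    import fire

    from human_eval.evaluation import evaluate_functional_correctness


    def main():
        # Expose evaluate_functional_correctness as a command-line interface.
        fire.Fire(evaluate_functional_correctness)


    if __name__ == "__main__":
        main()

After pip install -e ., the command can then be run as: evaluate_functional_correctness samples.jsonl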