Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions baselines/menu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import argparse
import os
import re
import requests
import json
import subprocess
import glob
from pathlib import Path
import uuid
from red_gym_env import RedGymEnv
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.callbacks import CheckpointCallback
import webbrowser

# Server used for checkpoint upload/download when no --url override is given.
DEFAULT_BASE_URL = "http://127.0.0.1:5000"
# Local cache directory for checkpoints fetched from the remote server.
directory_path = 'downloaded_checkpoints'

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(directory_path, exist_ok=True)

def make_env(rank, env_conf, seed=0):
    """Build a zero-argument environment factory for a vectorized worker.

    SubprocVecEnv expects a list of callables; each one constructs its own
    RedGymEnv instance inside the worker process, reset with a per-rank seed
    so workers do not replay identical trajectories.
    """
    def _init():
        environment = RedGymEnv(env_conf)
        environment.reset(seed=seed + rank)
        return environment

    set_random_seed(seed)
    return _init

def parse_args(argv=None):
    """Parse command-line options for the menu script.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse falls back to sys.argv[1:] — existing callers that
            invoke parse_args() with no arguments are unaffected.

    Returns:
        argparse.Namespace with attributes menu, restore, upload and url,
        each None when the corresponding flag is absent.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--menu')
    parser.add_argument('--restore')
    parser.add_argument('--upload')
    parser.add_argument('--url')
    return parser.parse_args(argv)

def show_menu(selected_checkpoint):
    """Interactive menu loop; returns the checkpoint path to resume from.

    Lists local session checkpoints, previously downloaded checkpoints and
    run_*.py launcher scripts, plus remote download (97), remote upload (98)
    and progress-page generation (99) actions.

    Returns:
        Path string of the chosen checkpoint, or the incoming
        selected_checkpoint unchanged when no checkpoints exist.
    """
    while True:
        session_dict, downloaded_checkpoints = list_all_sessions_and_pokes()
        if not session_dict:
            print("No checkpoints found.")
            return selected_checkpoint

        session_count = len(session_dict)
        print(f"\nAvailable sessions sorted by their largest checkpoints:")
        for i, (session, largest_step) in enumerate(session_dict.items()):
            # Display the real on-disk name (poke_<steps>_steps.zip); the old
            # listing printed "poke-<steps>" which matched no actual file.
            print(f" {i + 1}. {session}/poke_{largest_step}_steps.zip")

        print("\nDownloaded checkpoints:")
        for i, checkpoint in enumerate(downloaded_checkpoints, start=session_count + 1):
            print(f" {i}. {checkpoint}")

        print("\nDefault Runs:")
        matching_files = [file for file in os.listdir(os.getcwd())
                          if file.startswith("run_") and file.endswith(".py")]
        # BUG FIX: runs used to be numbered from session_count + 1, which
        # collided with the downloaded-checkpoint numbers printed above.
        # Offset by both lists so every menu entry gets a unique number,
        # matching the selection arithmetic below.
        run_start = session_count + len(downloaded_checkpoints) + 1
        for i, file in enumerate(matching_files, start=run_start):
            print(f" {i}. {file}")

        print("\n 97. Resume from remote")
        print(" 98. Upload to remote")
        print(" 99. View progress using index.html")
        menu_selection = input("Enter the number of the menu option: ")

        if not menu_selection.isdigit():
            print("Invalid input. Please enter a valid number.")
            continue

        selection = int(menu_selection)
        if 1 <= selection <= session_count:
            selected_session = list(session_dict.keys())[selection - 1]
            selected_step = session_dict[selected_session]
            return f"{selected_session}/poke_{selected_step}_steps.zip"
        elif session_count + 1 <= selection <= session_count + len(downloaded_checkpoints):
            return os.path.join(
                'downloaded_checkpoints',
                downloaded_checkpoints[selection - session_count - 1])
        elif run_start <= selection < run_start + len(matching_files):
            selected_run = matching_files[selection - run_start]
            subprocess.run(f"python3 {selected_run}", shell=True)
        elif menu_selection == '97':
            selected_checkpoint = remote_actions()
            if selected_checkpoint:
                return selected_checkpoint
        elif menu_selection == '98':
            selection = int(input("Enter your selection for remote upload: "))
            upload(selection, session_dict)
        elif menu_selection == '99':
            create_index('index.html')
            print('Open index.html to monitor the newest run.')
        else:
            print("Invalid selection.")

def list_all_sessions_and_pokes():
    """Scan the working directory for sessions and their largest checkpoints.

    Returns:
        A pair (sessions, downloads): sessions maps each session_<hex8>
        folder name to the highest step count among its poke_*_steps.zip
        files, ordered by step count descending; downloads is the list of
        .zip files sitting in downloaded_checkpoints/.
    """
    step_pattern = re.compile(r'poke_(\d+)_steps')
    sessions = {}
    for folder in os.listdir():
        if not re.match(r'session_[0-9a-fA-F]{8}', folder):
            continue
        steps = [int(step_pattern.search(path).group(1))
                 for path in glob.glob(f"{folder}/poke_*_steps.zip")]
        if steps:
            sessions[folder] = max(steps)

    downloads = [name for name in os.listdir('downloaded_checkpoints')
                 if name.endswith('.zip')]
    ordered = dict(sorted(sessions.items(), key=lambda item: item[1], reverse=True))
    return ordered, downloads

def remote_actions():
    """Fetch checkpoint metadata from the server and download a selection.

    Returns:
        Local path of the downloaded checkpoint on success, or None on any
        failure, decode error, or invalid selection.
    """
    BASE_URL = DEFAULT_BASE_URL
    response = requests.get(f"{BASE_URL}/uploads/metadata.txt")
    if response.status_code != 200:
        print("Failed to fetch metadata from the server.")
        return None
    server_metadata = response.text.strip()
    if not server_metadata:
        print("No checkpoint metadata found. Is this an empty server?")
        return None
    try:
        server_metadata = json.loads(server_metadata)
    except json.decoder.JSONDecodeError as e:
        print("Error decoding JSON:", str(e))
        return None

    print(f"\nAvailable remote checkpoints:")
    for i, entry in enumerate(server_metadata):
        print(f"{i + 1}. Filename: {entry['filename']}, Steps: {entry['steps']}")

    server_selection = input("Enter the number of the checkpoint you want to download: ")
    try:
        server_selection = int(server_selection)
        if 1 <= server_selection <= len(server_metadata):
            selected_server_entry = server_metadata[server_selection - 1]
            filename = selected_server_entry['filename']
            download_response = requests.get(f"{BASE_URL}/uploads/{filename}")
            if download_response.status_code == 200:
                local_path = os.path.join('downloaded_checkpoints', filename)
                with open(local_path, 'wb') as f:
                    f.write(download_response.content)
                print(f"Downloaded checkpoint: {filename}")
                # Hand the local path back so show_menu can resume from it;
                # returning None here left option 97 unable to resume.
                return local_path
            else:
                print(f"Failed to download the selected checkpoint: {filename}")
        else:
            print("Invalid selection.")
    except ValueError:
        print("Invalid input. Please enter a valid number.")
    return None

def restore(url, download_selection):
    """Download a checkpoint from *url* into the current directory.

    Args:
        url: Direct URL to the checkpoint archive.
        download_selection: Unused; kept so existing callers keep working.

    Returns:
        The saved filename on success, or None on failure.
    """
    response = requests.get(url)
    if response.status_code == 200:
        filename = url.split("/")[-1]
        with open(filename, 'wb') as file:
            file.write(response.content)
        print(f"Downloaded checkpoint: {filename}")
        return filename
    print("Failed to download checkpoint.")
    return None

def upload(selection, session_dict):
    """Upload the selected session's largest checkpoint to the server.

    Args:
        selection: 1-based index into session_dict's key order.
        session_dict: Mapping of session folder name -> largest step count.
    """
    try:
        selected_session = list(session_dict.keys())[selection - 1]
        selected_step = session_dict[selected_session]
        file_path = f"{selected_session}/poke_{selected_step}_steps.zip"
        # Argument list with shell=False avoids shell injection through the
        # session folder name (the old shell string interpolated it raw).
        subprocess.run(
            ["curl", "-X", "POST", "-F", f"file=@{file_path}",
             "http://127.0.0.1:5000/upload"],
            shell=False,
        )
    except (ValueError, IndexError):
        print("Invalid selection")

def main(selected_checkpoint):
    """Resume PPO training from *selected_checkpoint* under a fresh session dir."""
    # New unique session folder; checkpoints and screenshots land here.
    sess_path = Path(f'session_{str(uuid.uuid4())[:8]}')
    # One rollout = 20480 env steps; reused as the checkpoint save interval.
    ep_length = 2048 * 10
    env_config = {
        'headless': True, 'save_final_state': True, 'early_stop': False,
        'action_freq': 24, 'init_state': '../has_pokedex_nballs.state', 'max_steps': ep_length,
        'print_rewards': True, 'save_video': False, 'fast_video': True, 'session_path': sess_path,
        'gb_path': '../PokemonRed.gb', 'debug': False, 'sim_frame_dist': 2_000_000.0,
        'use_screen_explore': True, 'reward_scale': 4, 'extra_buttons': False,
        'explore_weight': 3
    }
    print(env_config)
    # Number of parallel emulator worker processes.
    num_cpu = 16
    env = SubprocVecEnv([make_env(i, env_config) for i in range(num_cpu)])
    checkpoint_callback = CheckpointCallback(save_freq=ep_length, save_path=sess_path, name_prefix='poke')
    learn_steps = 40

    print('\nLoading checkpoint', selected_checkpoint, ' ... \n')
    model = PPO.load(selected_checkpoint, env=env)
    # The checkpoint may have been trained with a different n_steps/n_envs;
    # patch the model and rebuild its rollout buffer so collection matches
    # the new vectorized env. NOTE(review): this pokes stable-baselines3
    # internals — confirm it still behaves across SB3 versions.
    model.n_steps = ep_length
    model.n_envs = num_cpu
    model.rollout_buffer.buffer_size = ep_length
    model.rollout_buffer.n_envs = num_cpu
    model.rollout_buffer.reset()
    for i in range(learn_steps):
        model.learn(total_timesteps=(ep_length) * num_cpu * 1000, callback=checkpoint_callback)


def create_index(output_file='index.html'):
    """Write an auto-refreshing HTML grid of the newest session's screenshots.

    Args:
        output_file: Path of the HTML file to (over)write.
    """
    # Find all session folders within the current working directory.
    session_folders = [folder for folder in os.listdir() if folder.startswith('session_')]
    if not session_folders:
        print("No 'session_' folders found in the current working directory.")
        return

    # Pick the most recently *created* session folder (by ctime — the old
    # comment claimed name/timestamp sorting, which is not what max() does).
    newest_session = max(session_folders, key=lambda folder: os.path.getctime(folder))

    # Collect screenshot names, sorted so the grid layout is deterministic
    # across the page's periodic refreshes (os.listdir order is arbitrary).
    image_names = sorted(
        filename for filename in os.listdir(newest_session)
        if filename.endswith('.jpeg')
    )

    # Page skeleton; the meta refresh reloads the grid every 2 seconds.
    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Dynamic Photo Grid</title>
<script src="https://cdn.tailwindcss.com"></script>
<meta http-equiv="refresh" content="2"> <!-- Refresh the grid frame every 2 seconds -->
</head>
<body class="p-4">
<div class="grid grid-cols-4 gap-4">
"""

    # One <img> per screenshot, pointing into the session folder.
    for i, image_name in enumerate(image_names, start=1):
        image_src = os.path.join(newest_session, image_name)
        html_content += f' <img src="{image_src}" class="w-full h-auto max-w-lg" alt="Image {i}">\n'

    html_content += """ </div>
<style>
.max-w-lg {
max-width: 250px; /* Adjust this value to make the images slightly bigger */
}
</style>
</body>
</html>
"""
    # Save the updated HTML content to a file.
    with open(output_file, 'w') as file:
        file.write(html_content)

    print(f"HTML content updated and saved as '{output_file}'.")

if __name__ == '__main__':
    # show_menu returns None when no checkpoint was ever chosen; guard so
    # PPO.load is never handed a null path.
    selected_checkpoint = show_menu(None)
    if selected_checkpoint:
        main(selected_checkpoint)
    else:
        print("No checkpoint selected; exiting.")
102 changes: 102 additions & 0 deletions www/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import os
import json
from flask import Flask, request, jsonify, render_template, send_from_directory
from werkzeug.utils import secure_filename
import hashlib
from datetime import datetime
import re

app = Flask(__name__)

# Uploaded checkpoints and their metadata file live side by side in ./uploads.
app.config['UPLOAD_FOLDER'] = './uploads'
app.config['METADATA_FILE'] = './uploads/metadata.txt'

# In-memory cache of checkpoint metadata dicts; populated by read_metadata()
# and persisted back to METADATA_FILE by write_metadata().
files_data = []
@app.route('/uploads')
def list_files():
    """Display a list of uploaded files for download, newest (most steps) first."""
    read_metadata()
    # 'steps' is None for uploads whose filename didn't match the expected
    # pattern; coalesce to 0 so sorted() never compares None with int.
    sorted_files = sorted(files_data, key=lambda x: x.get('steps') or 0, reverse=True)
    return render_template('list_files.html', files=sorted_files)

def read_metadata():
    """Load checkpoint metadata from METADATA_FILE into the files_data cache.

    A missing file resets the cache to empty; any other error is printed and
    leaves the current cache untouched (best effort).
    """
    global files_data
    try:
        with open(app.config['METADATA_FILE'], 'r') as handle:
            files_data = json.load(handle)
    except FileNotFoundError:
        files_data = []
    except Exception as exc:
        print(f"Error reading metadata: {str(exc)}")

def write_metadata(data):
    """Persist *data* (the list of file-info dicts) to METADATA_FILE as JSON."""
    metadata_path = app.config['METADATA_FILE']
    with open(metadata_path, 'w') as handle:
        json.dump(data, handle)

# Warm the in-memory metadata cache at import time so the first request
# does not serve an empty listing.
read_metadata()

@app.route('/')
def index():
    """Display a list of uploaded files with metadata sorted by steps."""
    read_metadata()
    # 'steps' is None for uploads whose filename didn't match the expected
    # pattern; coalesce to 0 so sorted() never compares None with int.
    sorted_files = sorted(files_data, key=lambda x: x.get('steps') or 0, reverse=True)
    return render_template('index.html', files=sorted_files)

@app.route('/uploads/<filename>')
def download_file(filename):
    """Serve an uploaded checkpoint by name out of UPLOAD_FOLDER."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)

@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept a checkpoint upload, derive its metadata, and persist both.

    Expects a multipart form field named 'file'. The step count is parsed
    from the original filename (poke_<steps>_steps.zip); an existing entry
    with the same step count is replaced rather than duplicated.

    Returns:
        JSON {'success': True} on success, or {'success': False, ...} with
        HTTP 400 when no file was provided.
    """
    global files_data

    uploaded_file = request.files.get('file')
    if uploaded_file is None or not uploaded_file.filename:
        # Previously a missing field raised KeyError (HTTP 500); report a
        # proper client error instead.
        return jsonify({'success': False, 'error': 'no file provided'}), 400

    # Unique server-side name: upload timestamp plus a short content SHA1.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    sha1 = hashlib.sha1(uploaded_file.read()).hexdigest()[:10]
    uploaded_file.seek(0)  # rewind after hashing so save() writes everything
    original_filename = secure_filename(uploaded_file.filename)

    # Extract 'steps' from the original filename; None when it doesn't match.
    match = re.search(r'poke_(\d+)_steps\.zip', original_filename)
    steps = int(match.group(1)) if match else None

    filename = f"poke_{steps}_steps_{sha1}_{timestamp}.zip"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    # Replace any existing entry that has the same step count.
    existing_entry = next((entry for entry in files_data if entry.get('steps') == steps), None)
    if existing_entry:
        existing_entry['filename'] = filename
        existing_entry['filepath'] = filepath
        existing_entry['timestamp'] = timestamp
    else:
        files_data.append({'filename': filename, 'filepath': filepath,
                           'timestamp': timestamp, 'steps': steps})

    # Save the uploaded file to the specified filepath.
    uploaded_file.save(filepath)

    # Coalesce None steps to 0 so the sort never compares None with int.
    files_data.sort(key=lambda x: x.get('steps') or 0, reverse=True)
    write_metadata(files_data)

    return jsonify({'success': True})

if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader —
    # fine for local development, but do not expose this server publicly as-is.
    app.run(debug=True)
Loading