Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Contributing to NiceWebRL

Thank you for your interest in contributing to NiceWebRL! This document provides guidelines and instructions for contributing to the project.

## Development Environment Setup

1. Clone the repository:
```bash
git clone https://github.com/yourusername/nicewebrl.git
cd nicewebrl
```

2. Create and activate a virtual environment:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```

3. Install development dependencies:
```bash
pip install -r requirements.txt
pip install -r requirements-dev.txt # Development dependencies
```

4. Install JAX and JAXlib:
```bash
pip install "jax>=0.2.26" "jaxlib>=0.1.74"
```

## Code Style and Formatting

We use Ruff for code formatting and linting. To ensure your code follows our style guidelines:

1. Install Ruff:
```bash
pip install ruff
```

2. Run Ruff to check your code:
```bash
ruff check .
```

3. Run Ruff to automatically fix issues:
```bash
ruff check --fix .
```

4. Format your code:
```bash
ruff format .
```

### Style Guidelines

- Follow the Google Python Style Guide
- Use type hints for all function parameters and return values
- Write docstrings for all public functions, classes, and methods
- Keep lines under 80 characters
- Use meaningful variable and function names
- Add comments for complex logic

## Running Tests

Before submitting a pull request, ensure all tests pass:

1. Run the test suite:
```bash
pytest
```

2. Run tests with coverage:
```bash
pytest --cov=nicewebrl
```

## Pull Request Process

1. **Issue First**: For significant changes, please open an issue first to discuss the proposed changes.

2. **Work-in-Progress PRs**: We accept work-in-progress PRs for early feedback. Please mark them with the "WIP" prefix in the title.

3. **Branch Naming**: Use descriptive branch names:
- `feature/your-feature-name`
- `fix/your-fix-name`
- `docs/your-docs-update`

4. **Commit Messages**: Write clear, descriptive commit messages that explain the "why" of your changes.

5. **PR Description**: Include:
- A clear description of the changes
- Related issue numbers
- Any breaking changes
- Screenshots for UI changes

## Development Workflow

1. Keep your fork up to date:
```bash
git remote add upstream https://github.com/original-owner/nicewebrl.git
git fetch upstream
git checkout main
git merge upstream/main
```

2. Create a new branch for your changes:
```bash
git checkout -b feature/your-feature-name
```

3. Make your changes and commit them:
```bash
git add .
git commit -m "Description of your changes"
```

4. Push to your fork:
```bash
git push origin feature/your-feature-name
```

## Additional Guidelines

- Write clear, concise documentation
- Add tests for new features
- Update existing tests if you change functionality
- Keep dependencies up to date
- Follow semantic versioning for releases

## Questions?

If you have any questions about contributing, please open an issue or contact the maintainers.

Thank you for contributing to NiceWebRL!
5 changes: 2 additions & 3 deletions nicewebrl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from nicewebrl.dataframe import DataFrame
from nicewebrl.dataframe import concat_dataframes

from nicewebrl.dataframe import concat_list as concat_dataframes
from nicewebrl.utils import toggle_fullscreen
from nicewebrl.utils import check_fullscreen
from nicewebrl.utils import clear_element
Expand Down Expand Up @@ -31,7 +30,7 @@
from nicewebrl.nicejax import base64_npimage
from nicewebrl.nicejax import StepType
from nicewebrl.nicejax import TimeStep
from nicewebrl.nicejax import EnvParams
#from nicewebrl.nicejax import EnvParams
from nicewebrl.nicejax import TimestepWrapper
from nicewebrl.nicejax import JaxWebEnv
from nicewebrl.nicejax import MultiAgentJaxWebEnv
Expand Down
4 changes: 3 additions & 1 deletion nicewebrl/container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import dataclasses
import uuid
from asyncio import Lock

from nicegui import app, ui
import uuid

from nicewebrl.utils import get_user_lock


Expand Down
29 changes: 16 additions & 13 deletions nicewebrl/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import List, Callable, Optional
import polars as pl
from typing import Callable, List, Optional

import numpy as np
import polars as pl
from flax import struct

Remove = bool
Episode = struct.PyTreeNode
EpisodeFilter = Callable[[Episode], Remove]
# Type definitions
REMOVE = bool
EPISODE = struct.PyTreeNode
EPISODE_FILTER = Callable[[EPISODE], REMOVE]


class DataFrame(object):
Expand Down Expand Up @@ -69,7 +71,7 @@ def wrapped_method(*args, **kwargs):
def filter(
self,
*args,
episode_filter: Optional[EpisodeFilter] = None,
episode_filter: Optional[EPISODE_FILTER] = None,
reindex: Optional[bool] = True,
**kwargs,
):
Expand Down Expand Up @@ -98,7 +100,7 @@ def filter(
df = df._filter_episodes(episode_filter)
return df

def _filter_episodes(self, episode_filter: EpisodeFilter):
def _filter_episodes(self, episode_filter: EPISODE_FILTER):
"""
Filter rows and episodes based on a given filter function.

Expand All @@ -125,10 +127,10 @@ def _filter_episodes(self, episode_filter: EpisodeFilter):

def filter_by_group(
self,
input_episode_filter: EpisodeFilter,
input_episode_filter: EPISODE_FILTER,
input_settings: dict,
output_settings: dict,
output_episode_filter: Optional[EpisodeFilter] = None,
output_episode_filter: Optional[EPISODE_FILTER] = None,
group_key: str = "user_id",
):
"""
Expand Down Expand Up @@ -197,7 +199,7 @@ def filter_by_group(
def apply(
self,
fn,
episode_filter: Optional[EpisodeFilter] = None,
episode_filter: Optional[EPISODE_FILTER] = None,
output_transform=lambda x: x,
**kwargs,
):
Expand Down Expand Up @@ -231,10 +233,10 @@ def apply(
def apply_by_group(
self,
fn,
input_episode_filter: EpisodeFilter,
input_episode_filter: EPISODE_FILTER,
input_settings: dict,
output_settings: dict,
output_episode_filter: Optional[EpisodeFilter] = None,
output_episode_filter: Optional[EPISODE_FILTER] = None,
output_transform=lambda x: x,
splitting_key: str = "user_id",
):
Expand Down Expand Up @@ -311,4 +313,5 @@ def concat_list(*dfs: List[DataFrame]) -> DataFrame:
return DataFrame(df=pl.concat(_dfs, how="diagonal_relaxed"), episodes=episodes)


concat_dataframes = concat_list
# Module-level variables
CONCAT_DATAFRAMES = concat_list
12 changes: 7 additions & 5 deletions nicewebrl/experiment.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from nicegui import app
from typing import List, Union
import dataclasses
import uuid
from typing import List, Union

import jax.numpy as jnp
import jax.random
from nicegui import app

from nicewebrl.stages import Block, Stage
from nicewebrl.container import Container
from nicewebrl.nicejax import new_rng
from nicewebrl.logging import get_logger
from nicewebrl.utils import get_user_lock, get_progress
from nicewebrl.nicejax import new_rng
from nicewebrl.stages import Block, Stage
from nicewebrl.utils import get_progress, get_user_lock

# Module-level variables
logger = get_logger(__name__)


Expand Down
10 changes: 5 additions & 5 deletions nicewebrl/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def log_filename_fn(log_dir, user_id):
> (2/31) 3266458590: 2024/10/26 20:24:47 - some_name - WARNING - some warning message
"""

from typing import Optional, Callable
import traceback
import io

import logging
import sys
import os
from nicegui import app
import sys
import traceback
from functools import lru_cache
from typing import Callable, Optional

from nicegui import app


class UserAwareOutput(io.TextIOBase):
Expand Down
38 changes: 19 additions & 19 deletions nicewebrl/nicejax.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
import time
import typing
from datetime import datetime
from typing import Optional
from typing import Union, Any, Callable, Tuple
from typing import get_type_hints
from base64 import b64encode
from flax import struct
from flax import serialization
from flax.core import FrozenDict
import io
import inspect
import random
import sys
import time
from base64 import b64encode
from datetime import datetime
from typing import Any, Callable, Optional, Tuple, Union, get_type_hints

import jax
import jax.numpy as jnp
import jax.random
import numpy as np
import random
import sys
from flax import serialization, struct
from flax.core import FrozenDict
from nicegui import app, ui
from PIL import Image

from nicewebrl.logging import get_logger

Timestep = Any
RenderFn = Callable[[Timestep], jax.Array]
# Type definitions
TIMESTEP = Any
RENDER_FN = Callable[[TIMESTEP], jax.Array]

# Module-level variables
logger = get_logger(__name__)


Expand Down Expand Up @@ -109,7 +109,7 @@ class StepType(jnp.uint8):
MID: jax.Array = jnp.asarray(1, dtype=jnp.uint8)
LAST: jax.Array = jnp.asarray(2, dtype=jnp.uint8)

EnvParams = struct.PyTreeNode
ENV_PARAMS = struct.PyTreeNode

class TimeStep(struct.PyTreeNode):
state: struct.PyTreeNode
Expand Down Expand Up @@ -267,8 +267,8 @@ def precompile(self, dummy_env_params: Optional[struct.PyTreeNode] = None) -> No
print(f"\tstep time: {time.time() - start}")

def precompile_vmap_render_fn(
self, render_fn: RenderFn, dummy_env_params: struct.PyTreeNode
) -> RenderFn:
self, render_fn: RENDER_FN, dummy_env_params: struct.PyTreeNode
) -> RENDER_FN:
"""Call this function to pre-compile a multi-render function before experiment starts."""
print("Compiling multi-render function.")
start = time.time()
Expand Down Expand Up @@ -324,8 +324,8 @@ def precompile(self, dummy_env_params: struct.PyTreeNode={'random_reset_fn': 're

def precompile_vmap_render_fn(
self,
render_fn: RenderFn,
dummy_env_params: struct.PyTreeNode={'random_reset_fn': 'reset_all'}) -> RenderFn:
render_fn: RENDER_FN,
dummy_env_params: struct.PyTreeNode={'random_reset_fn': 'reset_all'}) -> RENDER_FN:
"""Call this function to pre-compile a multi-render function before experiment starts."""
logger.info("Compiling multi-render function.")
start = time.time()
Expand Down
Loading