Skip to content

Commit

Permalink
support download via aria2 (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
dishuostec authored Jun 22, 2024
1 parent 951e5ec commit 7b9292f
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ NODE_CLASS_MAPPINGS.update({
* When you create the `pip_overrides.json` file, it changes the installation of specific pip packages to installations defined by the user.
* Please refer to the `pip_overrides.json.template` file.

* Use `aria2` as downloader
* [howto](docs/en/use_aria2.md)

## Scanner
When you run the `scan.sh` script:
Expand Down
40 changes: 40 additions & 0 deletions docs/en/use_aria2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Use `aria2` as downloader

Two environment variables are needed to use `aria2` as the downloader.

```bash
export COMFYUI_MANAGER_ARIA2_SERVER=http://127.0.0.1:6800
export COMFYUI_MANAGER_ARIA2_SECRET=__YOU_MUST_CHANGE_IT__
```

An example `docker-compose.yml`

```yaml
services:

aria2:
container_name: aria2
image: p3terx/aria2-pro
environment:
- PUID=1000
- PGID=1000
- UMASK_SET=022
- RPC_SECRET=__YOU_MUST_CHANGE_IT__
- RPC_PORT=5080
- DISK_CACHE=64M
- IPV6_MODE=false
- UPDATE_TRACKERS=false
- CUSTOM_TRACKER_URL=
volumes:
- ./config:/config
- ./downloads:/downloads
- ~/ComfyUI/models:/models
- ~/ComfyUI/custom_nodes:/custom_nodes
ports:
- 6800:6800
restart: unless-stopped
logging:
driver: json-file
options:
max-size: 1m
```
67 changes: 67 additions & 0 deletions glob/manager_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os

aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER')
HF_ENDPOINT = os.getenv('HF_ENDPOINT')

if aria2 is not None:
secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET')
host, port = aria2.split(':')
import aria2p

aria2 = aria2p.API(aria2p.Client(host=host, port=port, secret=secret))


def download_url(model_url: str, model_dir: str, filename: str):
if aria2:
return aria2_download_url(model_url, model_dir, filename)
else:
from torchvision.datasets.utils import download_url as torchvision_download_url

return torchvision_download_url(model_url, model_dir, filename)


def aria2_find_task(dir: str, filename: str):
target = os.path.join(dir, filename)

downloads = aria2.get_downloads()

for download in downloads:
for file in download.files:
if file.is_metadata:
continue
if str(file.path) == target:
return download


def aria2_download_url(model_url: str, model_dir: str, filename: str):
import manager_core as core
import tqdm
import time

if model_dir.startswith(core.comfy_path):
model_dir = model_dir[len(core.comfy_path) :]

if HF_ENDPOINT:
model_url = model_url.replace('https://huggingface.co', HF_ENDPOINT)

download_dir = model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir)

download = aria2_find_task(download_dir, filename)
if download is None:
options = {'dir': download_dir, 'out': filename}
download = aria2.add(model_url, options)[0]

if download.is_active:
with tqdm.tqdm(
total=download.total_length,
bar_format='{l_bar}{bar}{r_bar}',
desc=filename,
unit='B',
unit_scale=True,
) as progress_bar:
while download.is_active:
if progress_bar.total == 0 and download.total_length != 0:
progress_bar.reset(download.total_length)
progress_bar.update(download.completed_length - progress_bar.n)
time.sleep(1)
download.update()
2 changes: 1 addition & 1 deletion glob/manager_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run_script(self, cmd, cwd='.'):

sys.path.append('../..')

from torchvision.datasets.utils import download_url
from manager_downloader import download_url

core.comfy_path = os.path.dirname(folder_paths.__file__)
core.js_path = os.path.join(core.comfy_path, "web", "extensions")
Expand Down

0 comments on commit 7b9292f

Please sign in to comment.