-
Notifications
You must be signed in to change notification settings - Fork 894
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
951e5ec
commit 7b9292f
Showing
4 changed files
with
110 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Use `aria2` as downloader | ||
|
||
Two environment variables are needed to use `aria2` as the downloader. | ||
|
||
```bash | ||
export COMFYUI_MANAGER_ARIA2_SERVER=http://127.0.0.1:6800 | ||
export COMFYUI_MANAGER_ARIA2_SECRET=__YOU_MUST_CHANGE_IT__ | ||
``` | ||
|
||
An example `docker-compose.yml` | ||
|
||
```yaml | ||
services: | ||
|
||
aria2: | ||
container_name: aria2 | ||
image: p3terx/aria2-pro | ||
environment: | ||
- PUID=1000 | ||
- PGID=1000 | ||
- UMASK_SET=022 | ||
- RPC_SECRET=__YOU_MUST_CHANGE_IT__ | ||
- RPC_PORT=5080 | ||
- DISK_CACHE=64M | ||
- IPV6_MODE=false | ||
- UPDATE_TRACKERS=false | ||
- CUSTOM_TRACKER_URL= | ||
volumes: | ||
- ./config:/config | ||
- ./downloads:/downloads | ||
- ~/ComfyUI/models:/models | ||
- ~/ComfyUI/custom_nodes:/custom_nodes | ||
ports: | ||
- 6800:6800 | ||
restart: unless-stopped | ||
logging: | ||
driver: json-file | ||
options: | ||
max-size: 1m | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
|
||
aria2 = os.getenv('COMFYUI_MANAGER_ARIA2_SERVER') | ||
HF_ENDPOINT = os.getenv('HF_ENDPOINT') | ||
|
||
if aria2 is not None: | ||
secret = os.getenv('COMFYUI_MANAGER_ARIA2_SECRET') | ||
host, port = aria2.split(':') | ||
import aria2p | ||
|
||
aria2 = aria2p.API(aria2p.Client(host=host, port=port, secret=secret)) | ||
|
||
|
||
def download_url(model_url: str, model_dir: str, filename: str): | ||
if aria2: | ||
return aria2_download_url(model_url, model_dir, filename) | ||
else: | ||
from torchvision.datasets.utils import download_url as torchvision_download_url | ||
|
||
return torchvision_download_url(model_url, model_dir, filename) | ||
|
||
|
||
def aria2_find_task(dir: str, filename: str): | ||
target = os.path.join(dir, filename) | ||
|
||
downloads = aria2.get_downloads() | ||
|
||
for download in downloads: | ||
for file in download.files: | ||
if file.is_metadata: | ||
continue | ||
if str(file.path) == target: | ||
return download | ||
|
||
|
||
def aria2_download_url(model_url: str, model_dir: str, filename: str): | ||
import manager_core as core | ||
import tqdm | ||
import time | ||
|
||
if model_dir.startswith(core.comfy_path): | ||
model_dir = model_dir[len(core.comfy_path) :] | ||
|
||
if HF_ENDPOINT: | ||
model_url = model_url.replace('https://huggingface.co', HF_ENDPOINT) | ||
|
||
download_dir = model_dir if model_dir.startswith('/') else os.path.join('/models', model_dir) | ||
|
||
download = aria2_find_task(download_dir, filename) | ||
if download is None: | ||
options = {'dir': download_dir, 'out': filename} | ||
download = aria2.add(model_url, options)[0] | ||
|
||
if download.is_active: | ||
with tqdm.tqdm( | ||
total=download.total_length, | ||
bar_format='{l_bar}{bar}{r_bar}', | ||
desc=filename, | ||
unit='B', | ||
unit_scale=True, | ||
) as progress_bar: | ||
while download.is_active: | ||
if progress_bar.total == 0 and download.total_length != 0: | ||
progress_bar.reset(download.total_length) | ||
progress_bar.update(download.completed_length - progress_bar.n) | ||
time.sleep(1) | ||
download.update() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters