Skip to content

Commit 87740a9

Browse files
committed
Support downloading from DiffusionDB Large
Signed-off-by: Jay Wang <jay@zijie.wang>
1 parent 944f96d commit 87740a9

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"python.linting.pylintEnabled": false,
2+
"python.linting.pylintEnabled": true,
33
"python.linting.enabled": true,
44
"python.linting.flake8Enabled": false,
55
"docwriter.style": "Auto-detect"

datasheet.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Organization: Georgia Institute of Technology
1919

2020
1. **For what purpose was the dataset created?** Was there a specific task in mind? Was there a specific gap that needed to be filled? Please provide a description.
2121

22-
The DiffusionDB project was inspired by important needs in research focused on diffusion models and prompt engineering. As large text-to-image models are relatively new, there is a pressing need to understand how these models work, how to write effective prompts, and how to design tools to help users generate images. To tackle these critical challenges, we present DiffusionDB, the first large-scale prompt dataset with 2 million real prompt-image pairs.
22+
The DiffusionDB project was inspired by important needs in research focused on diffusion models and prompt engineering. As large text-to-image models are relatively new, there is a pressing need to understand how these models work, how to write effective prompts, and how to design tools to help users generate images. To tackle these critical challenges, we present DiffusionDB, the first large-scale prompt dataset with 14 million real prompt-image pairs.
2323

2424
2. **Who created this dataset (e.g. which team, research group) and on behalf of which entity (e.g. company, institution, organization)**?
2525

@@ -45,7 +45,7 @@ of Technology.
4545

4646
2. **How many instances are there in total (of each type, if appropriate)?**
4747

48-
There are 2 million instances in total in the dataset.
48+
There are 14 million instances in total in the dataset.
4949

5050
3. **Does the dataset contain all possible instances or is it a sample (not necessarily random) of instances from a larger set?** If the dataset is a sample, then what is the larger set? Is the sample representative of the larger set (e.g. geographic coverage)? If so, please describe how this representativeness was validated/verified. If it is not representative of the larger set, please describe why not (e.g. to cover a more diverse range of instances, because instances were withheld or unavailable).
5151

scripts/download.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
"""A script to make downloading the DiffusionDB dataset easier."""
55
from urllib.error import HTTPError
66
from urllib.request import urlretrieve
7-
import shutil
8-
97
from alive_progress import alive_bar
8+
from os.path import exists
9+
10+
import shutil
11+
import os
1012
import time
1113
import argparse
1214

1315
index = None # initiate main arguments as None
1416
range_max = None
1517
output = None
1618
unzip = None
19+
large = None
1720

1821
parser = argparse.ArgumentParser(description="Download a file from a URL") #
1922

@@ -43,6 +46,13 @@
4346
# It's setting the argument to True if it's provided.
4447
action="store_true",
4548
)
49+
parser.add_argument(
50+
"-l",
51+
"--large",
52+
default=False,
53+
help="Download from DiffusionDB Large (14 million images)",
54+
action="store_true",
55+
)
4656

4757
args = parser.parse_args() # parse the arguments
4858

@@ -56,32 +66,49 @@
5666
output = args.output
5767
if args.unzip:
5868
unzip = args.unzip
69+
if args.large:
70+
large = args.large
5971

6072
if (
61-
args.index and args.range and args.output and args.unzip is None
73+
args.index and args.range and args.output and args.unzip and args.large is None
6274
): # if no arguments are provided, set default behaviour
6375
index = 1
6476
range_max = 2000
6577
output = "images"
6678
unzip = False
79+
large = False
6780

6881

69-
def download(index=1, range_index=0, output=""):
82+
def download(index=1, range_index=0, output="", large=False):
7083
"""
7184
Download a file from a URL and save it to a local file
7285
7386
:param index: The index of the file to download, defaults to 1 (optional)
7487
:param range_index: The number of files to download. If you want to download
75-
all files, set this to
76-
the number of files you want to download, defaults to 0 (optional) :param
77-
output: The directory to download the files to :return: A list of files to
78-
unzip
88+
all files, set this to the number of files you want to download,
89+
defaults to 0 (optional)
90+
:param output: The directory to download the files to :return: A list of
91+
files to unzip
92+
:param large: If downloading from DiffusionDB Large (14 million images)
93+
instead of DiffusionDB 2M (2 million images)
7994
"""
8095
baseurl = "https://huggingface.co/datasets/poloclub/diffusiondb/resolve/main/"
8196
files_to_unzip = []
82-
url = f"{baseurl}images/part-{index:06}.zip"
97+
98+
if large:
99+
if index <= 10000:
100+
url = f"{baseurl}diffusiondb-large-part-1/part-{index:06}.zip"
101+
else:
102+
url = f"{baseurl}diffusiondb-large-part-2/part-{index:06}.zip"
103+
else:
104+
url = f"{baseurl}images/part-{index:06}.zip"
105+
83106
if output != "":
84107
output = f"{output}/"
108+
109+
if not exists(output):
110+
os.makedirs(output)
111+
85112
if range_index == 0:
86113
print("Downloading file: ", url)
87114
file_path = f"{output}part-{index:06}.zip"
@@ -93,9 +120,16 @@ def download(index=1, range_index=0, output=""):
93120
unzip(file_path)
94121
else:
95122
# It's downloading the files numbered from index to range_index.
96-
with alive_bar(range_index, title="Downloading files") as bar:
123+
with alive_bar(range_index - index, title="Downloading files") as bar:
97124
for idx in range(index, range_index):
98-
url = f"{baseurl}images/part-{idx:06}.zip"
125+
if large:
126+
if idx <= 10000:
127+
url = f"{baseurl}diffusiondb-large-part-1/part-{idx:06}.zip"
128+
else:
129+
url = f"{baseurl}diffusiondb-large-part-2/part-{idx:06}.zip"
130+
else:
131+
url = f"{baseurl}images/part-{idx:06}.zip"
132+
99133
loop_file_path = f"{output}part-{idx:06}.zip"
100134
# It's trying to download the file, and if it encounters an
101135
# HTTPError, it prints the error.
@@ -117,7 +151,7 @@ def download(index=1, range_index=0, output=""):
117151
return files_to_unzip
118152

119153

120-
def unzip(file: str):
154+
def unzip_file(file: str):
121155
"""
122156
> This function takes a zip file as an argument and unpacks it
123157
@@ -138,12 +172,12 @@ def unzip_all(files: list):
138172
"""
139173
with alive_bar(len(files), title="Unzipping files") as bar:
140174
for file in files:
141-
unzip(file)
175+
unzip_file(file)
142176
time.sleep(0.1)
143177
bar()
144178

145179

146-
def main(index=None, range_max=None, output=None, unzip=None):
180+
def main(index=None, range_max=None, output=None, unzip=None, large=None):
147181
"""
148182
`main` is a function that takes in an index, a range_max, an output, and an
149183
unzip, and if the user confirms that they have enough space, it downloads
@@ -154,17 +188,20 @@ def main(index=None, range_max=None, output=None, unzip=None):
154188
:param output: The directory to download the files to
155189
:param unzip: If you want to unzip the files after downloading them, set
156190
this to True
191+
:param large: If you want to download from DiffusionDB Large (14 million
192+
images) instead of DiffusionDB 2M (2 million images)
157193
:return: A list of files that have been downloaded
158194
"""
159-
confirmation = input("Do you have at least 1.7Tb free: (y/n)")
160-
if confirmation != "y":
161-
return
195+
if range_max - index > 1999:
196+
confirmation = input("Do you have at least 1.7Tb free: (y/n)")
197+
if confirmation != "y":
198+
return
162199
if index and range_max:
163-
files = download(index, range_max, output)
200+
files = download(index, range_max, output, large)
164201
if unzip:
165202
unzip_all(files)
166203
elif index:
167-
download(index, output=output)
204+
download(index, output=output, large=large)
168205
else:
169206
print("No index provided")
170207

@@ -174,4 +211,4 @@ def main(index=None, range_max=None, output=None, unzip=None):
174211
# to import the script into the interpreter without automatically running the
175212
# main function.
176213
if __name__ == "__main__":
177-
main(index, range_max, output, unzip)
214+
main(index, range_max, output, unzip, large)

0 commit comments

Comments
 (0)