
Commit 5ef68a8

Merge branch 'main' into dev/camintrinsics_and_video_undistort_ops

2 parents fa45d2d + 36e6389

30 files changed: +2778 −33 lines

.github/workflows/deploy_sphinx_docs.yml (33 additions, 2 deletions)
@@ -23,23 +23,48 @@ jobs:
       REPO_OWNER: ${{ github.repository_owner }}
       MIN_TAG: v1.4.0
     steps:
+      - name: Mount /mnt into workspace (writable)
+        run: |
+          set -euxo pipefail
+          sudo mkdir -p /mnt/repo
+          sudo chown -R "$USER:$USER" /mnt/repo
+          mkdir -p "$GITHUB_WORKSPACE/repo"
+          sudo mount --bind /mnt/repo "$GITHUB_WORKSPACE/repo"
+          sudo chown -R "$USER:$USER" "$GITHUB_WORKSPACE/repo"
+          ls -ld /mnt/repo "$GITHUB_WORKSPACE/repo"
       - name: Checkout
         uses: actions/checkout@v4
         with:
           fetch-depth: 0
+          path: repo
       - name: Setup Python ${{ matrix.python-version }}
         uses: actions/setup-python@master
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Free disk space
+        run: |
+          sudo swapoff -a
+          sudo rm -f /swapfile
+          sudo apt-get autoremove -y >/dev/null 2>&1
+          sudo apt-get autoclean -y >/dev/null 2>&1
+          sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
+          sudo apt clean
+          df -h
       - name: Install uv
         uses: astral-sh/setup-uv@v7
         with:
           enable-cache: true
       - name: Install dependencies with uv
+        working-directory: repo
         run: |
           uv pip install --system --upgrade pip
           uv pip install --system -e .[all]
+      - name: Check disk
+        run: |
+          set -euxo pipefail
+          df -h
       - name: Fetch Data-Juicer Sphinx Template
+        working-directory: repo
         run: |
           set -e
           echo "Cloning sphinx template..."
@@ -57,31 +82,37 @@
           echo "Restoring custom files..."
           cp -rf /tmp/custom_files/source/* docs/sphinx_doc/source
           echo "Done!"
+          df -h
       - name: Get git tags
+        working-directory: repo
         run: |
           git fetch --all --tags
           git branch -a
           git tag
       - id: build
         name: Build Documentation
+        working-directory: repo
         run: |
           cd docs/sphinx_doc
           python build_versions.py --tags
+          df -h
       - name: Redirect index.html
+        working-directory: repo
         run: |
           REPOSITORY_OWNER="${GITHUB_REPOSITORY_OWNER}"
           cd docs/sphinx_doc
           cp ./redirect.html build/index.html
           sed -i "s/\[REPOSITORY_OWNER\]/${REPOSITORY_OWNER}/g" build/index.html
           sed -i "s/\[PROJECT\]/${PROJECT}/g" build/index.html
           cp build/index.html build/404.html
+          df -h
       - name: Upload Documentation
         uses: actions/upload-artifact@v4
         with:
           name: SphinxDoc
-          path: "docs/sphinx_doc/build"
+          path: "repo/docs/sphinx_doc/build"
       - uses: peaceiris/actions-gh-pages@v3
         if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/')) }}
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
-          publish_dir: "docs/sphinx_doc/build"
+          publish_dir: "repo/docs/sphinx_doc/build"
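The new mount step works around the small root volume on GitHub-hosted runners: /mnt typically sits on a separate, larger temporary disk, so bind-mounting a /mnt directory over the checkout location gives the repository and the multi-version docs build more headroom. A quick way to see the difference on a runner (an illustrative snippet, not part of the workflow):

import shutil

# Compare free space on the root volume vs the runner's temporary disk.
for mount in ("/", "/mnt"):
    usage = shutil.disk_usage(mount)
    print(f"{mount}: {usage.free / 2**30:.1f} GiB free of {usage.total / 2**30:.1f} GiB")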

.pre-commit-hooks/build_op_doc.py (4 additions, 0 deletions)
@@ -316,6 +316,8 @@ def get_op_list_from_code_for_formatter():
         test_path = os.path.join(FORMATTER_TEST_PREFIX, f"test_{formatter}")
         if os.path.isdir(code_path):
             continue
+        if "_cpp" in code_path:
+            continue
         docstrings = get_class_and_docstring(code_path)
         _, doc = docstrings[0]
         op_record_list.append(
@@ -351,6 +353,8 @@ def get_op_list_from_code():
             test_path = os.path.join(OP_TEST_PREFIX, type, f"test_{op}")
             if os.path.isdir(code_path):
                 continue
+            if not code_path.endswith(".py") or "_cpp" in code_path:
+                continue
             docstrings = get_class_and_docstring(code_path)
             _, doc = docstrings[0]
             info = info_link(op.replace(".py", ""))
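Both guards keep the doc generator from treating C++ sources and their `_cpp` wrapper modules as documentable operators before `get_class_and_docstring` runs. A minimal sketch of the second, stricter guard (the helper name and sample paths are illustrative, not from the repository):

# Hypothetical helper mirroring the stricter of the two new guards above.
def should_skip_for_docs(code_path: str) -> bool:
    """True for any path the op-doc generator should ignore."""
    return not code_path.endswith(".py") or "_cpp" in code_path

assert should_skip_for_docs("deduplicator/minhash.cpp")                          # not a .py file
assert should_skip_for_docs("deduplicator/ray_bts_minhash_cpp_deduplicator.py")  # C++-backed wrapper
assert not should_skip_for_docs("deduplicator/ray_bts_minhash_deduplicator.py")  # regular Python op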

data_juicer/core/data/load_strategy.py (26 additions, 10 deletions)
@@ -558,6 +558,7 @@ class RayS3DataLoadStrategy(RayDataLoadStrategy):
             "aws_session_token",
             "aws_region",
             "endpoint_url",
+            "format",
         ],
         "field_types": {"path": str},
         "custom_validators": {
@@ -590,23 +591,38 @@ def load_data(self, **kwargs):
         }

         auto_detect = False
-        data_source = self.ds_config.get("source", None)
-        if data_source is None:
+        data_format = self.ds_config.get("format", None)
+        if data_format is None:
             auto_detect = True
         else:
-            suffix = os.path.splitext(data_source)[1]
-            if suffix in file_extension_map:
-                data_format = file_extension_map[suffix]
-            elif "." + data_source in file_extension_map:
-                data_format = file_extension_map["." + data_source]
+            # First check if it's already a valid format name
+            valid_formats = set(file_extension_map.values())
+            if data_format in valid_formats:
+                pass  # It's a valid format name, use it as is
             else:
-                auto_detect = True
+                # Try to interpret as an extension or filename
+                suffix = os.path.splitext(data_format)[1]
+                if suffix in file_extension_map:
+                    data_format = file_extension_map[suffix]
+                elif "." + data_format in file_extension_map:
+                    data_format = file_extension_map["." + data_format]
+                else:
+                    auto_detect = True

         if auto_detect:
             # Extract extension from path
             file_extension = os.path.splitext(path)[1]
-            data_format = file_extension_map.get(file_extension, "parquet")  # Default to parquet for S3
-            logger.info(f"Auto-detected data format: {data_format}")
+            if file_extension in file_extension_map:
+                data_format = file_extension_map[file_extension]
+                logger.info(f"Auto-detected data format: {data_format} from extension: {file_extension}")
+            else:
+                data_format = "parquet"
+                logger.warning(
+                    f"Could not determine data format from path '{path}' "
+                    f"(extension: '{file_extension or '(none)'}'), "
+                    f"defaulting to 'parquet'. "
+                    f"Consider explicitly specifying 'format' field in dataset config."
+                )
         else:
             logger.info(f"Using specified data format: {data_format}")

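Taken together, the rewritten block resolves the format in a fixed order: an explicit format name from the config wins; otherwise the `format` value is interpreted as a filename or bare suffix; failing that, the extension of `path` is tried; only then does the loader fall back to parquet with a warning. A standalone sketch of that order (the map below is a hypothetical abbreviation of the module's `file_extension_map`, and `resolve_format` is not a repository function):

import os

# Hypothetical, abbreviated stand-in for load_strategy.py's file_extension_map.
file_extension_map = {".json": "json", ".jsonl": "json", ".parquet": "parquet", ".csv": "csv"}

def resolve_format(path: str, fmt=None) -> str:
    """Mirror the resolution order of the new load_data logic (sketch only)."""
    if fmt in set(file_extension_map.values()):
        return fmt                                     # 1. already a valid format name
    if fmt is not None:
        suffix = os.path.splitext(fmt)[1]
        if suffix in file_extension_map:
            return file_extension_map[suffix]          # 2. looks like a filename, e.g. "data.jsonl"
        if "." + fmt in file_extension_map:
            return file_extension_map["." + fmt]       # 3. bare suffix, e.g. "jsonl"
    ext = os.path.splitext(path)[1]
    return file_extension_map.get(ext, "parquet")      # 4. auto-detect, else default

assert resolve_format("s3://bucket/data.csv") == "csv"              # auto-detected from path
assert resolve_format("s3://bucket/data", fmt="data.jsonl") == "json"
assert resolve_format("s3://bucket/data", fmt="jsonl") == "json"
assert resolve_format("s3://bucket/data") == "parquet"              # fallback (real code warns here)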
data_juicer/ops/deduplicator/__init__.py (2 additions, 0 deletions)
@@ -6,6 +6,7 @@
 from .document_simhash_deduplicator import DocumentSimhashDeduplicator
 from .image_deduplicator import ImageDeduplicator
 from .ray_basic_deduplicator import RayBasicDeduplicator
+from .ray_bts_minhash_cpp_deduplicator import RayBTSMinhashCppDeduplicator
 from .ray_bts_minhash_deduplicator import (
     RayBTSMinhashDeduplicator,
     RayBTSMinhashDeduplicatorWithUid,
@@ -27,5 +28,6 @@
     "RayVideoDeduplicator",
     "RayBTSMinhashDeduplicator",
     "RayBTSMinhashDeduplicatorWithUid",
+    "RayBTSMinhashCppDeduplicator",
     "VideoDeduplicator",
 ]
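With this export in place, the C++-accelerated variant is importable next to the pure-Python one (constructor arguments are not shown in this diff, so none are assumed here):

# Both names now resolve from the package's public API.
from data_juicer.ops.deduplicator import (
    RayBTSMinhashCppDeduplicator,  # new: backed by the pybind11 minhash module below
    RayBTSMinhashDeduplicator,
)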
[new file] (195 additions, 0 deletions)
@@ -0,0 +1,195 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/numpy.h>
+#include <pybind11/stl.h>
+#include <vector>
+#include <cstdint>
+#include <algorithm>
+#include <omp.h>
+
+namespace py = pybind11;
+
+// Constants
+const uint32_t MERSENNE_PRIME = 2147483647; // 2^31 - 1
+const uint32_t MAX_HASH = 4294967295; // 2^32 - 1
+
+uint32_t simple_hash(const std::string& token) {
+    uint32_t hash = 5381;
+    for (const uint8_t c : token) {
+        hash = ((hash << 5) + hash) + c;
+    }
+    return hash;
+}
+
+std::vector<std::tuple<uint32_t, py::bytes, uint64_t>> calc_minhash_c(
+    const std::vector<std::string>& tokens,
+    const py::array_t<uint32_t>& perm_a,
+    const py::array_t<uint32_t>& perm_b,
+    const py::bytes& empty_hash_value,
+    const std::vector<std::pair<size_t, size_t>>& hash_ranges,
+    uint32_t union_find_parallel_num,
+    uint64_t uid)
+{
+    std::vector<std::tuple<uint32_t, py::bytes, uint64_t>> pairs;
+
+    if (tokens.empty()) {
+        pairs.emplace_back(MAX_HASH % union_find_parallel_num, empty_hash_value, uid);
+        return pairs;
+    }
+
+    std::vector<uint32_t> hv;
+    hv.reserve(tokens.size());
+    for (const std::string& token : tokens) {
+        hv.push_back(simple_hash(token));
+    }
+
+    auto perm_a_data = perm_a.unchecked<1>();
+    auto perm_b_data = perm_b.unchecked<1>();
+    size_t num_permutation = perm_a.shape(0);
+
+    std::vector<uint32_t> hash_values(num_permutation, MAX_HASH);
+    for (size_t i = 0; i < num_permutation; ++i) {
+        for (uint32_t h : hv) {
+            uint32_t phv = ((static_cast<uint64_t>(h) * perm_a_data(i) + perm_b_data(i)) % MERSENNE_PRIME) & MAX_HASH;
+            hash_values[i] = std::min(hash_values[i], phv);
+        }
+    }
+
+
+    for (size_t i = 0; i < hash_ranges.size(); ++i) {
+        const auto& [start, end] = hash_ranges[i];
+        std::vector<uint32_t> band_hash_values(hash_values.begin() + start, hash_values.begin() + end);
+
+        py::bytes hash_value = py::bytes(
+            std::string(reinterpret_cast<char*>(&i), sizeof(uint32_t)) +
+            std::string(reinterpret_cast<char*>(band_hash_values.data()), band_hash_values.size() * sizeof(uint32_t))
+        );
+
+        uint32_t hash_table_id = hash_values[start] % union_find_parallel_num;
+        pairs.emplace_back(hash_table_id, hash_value, uid);
+    }
+
+    return pairs;
+}
+
+py::list calc_minhash_batch_c(
+    const std::vector<std::vector<std::string>>& tokens_list,
+    const uint64_t uid_begin,
+    const std::vector<uint64_t>& perm_a,
+    const std::vector<uint64_t>& perm_b,
+    const std::string& empty_hash_value,
+    const std::vector<std::pair<size_t, size_t>>& hash_ranges,
+    uint32_t union_find_parallel_num,
+    uint32_t num_threads)
+{
+    omp_set_num_threads(num_threads);
+    size_t total_docs = tokens_list.size();
+    std::vector<std::tuple<uint32_t, std::string, uint64_t>> intermediate_pairs;
+    intermediate_pairs.reserve(total_docs * hash_ranges.size());
+
+    size_t num_permutation = perm_a.size();
+
+    #pragma omp parallel
+    {
+        std::vector<std::tuple<uint32_t, std::string, uint64_t>> local_pairs;
+        local_pairs.reserve(total_docs * hash_ranges.size() / num_threads);
+        std::vector<uint32_t> hash_values(num_permutation);
+
+        #pragma omp for nowait
+        for (size_t doc_idx = 0; doc_idx < total_docs; ++doc_idx) {
+            const auto& tokens = tokens_list[doc_idx];
+            uint64_t uid = uid_begin + doc_idx;
+
+            if (tokens.empty()) {
+                local_pairs.emplace_back(MAX_HASH % union_find_parallel_num, empty_hash_value, uid);
+                continue;
+            }
+
+            std::fill(hash_values.begin(), hash_values.end(), MAX_HASH);
+            for (const auto& token : tokens) {
+                uint32_t h = simple_hash(token);
+                for (size_t i = 0; i < num_permutation; ++i) {
+                    uint32_t phv = (static_cast<uint64_t>(h) * perm_a[i] + perm_b[i]) >> 32;
+                    hash_values[i] = std::min(hash_values[i], phv);
+                }
+            }
+
+            for (size_t i = 0; i < hash_ranges.size(); ++i) {
+                const auto& [start, end] = hash_ranges[i];
+                std::string hash_value(reinterpret_cast<char*>(&i), sizeof(uint32_t));
+                hash_value.append(reinterpret_cast<char*>(&hash_values[start]), (end - start) * sizeof(uint32_t));
+
+                uint32_t hash_table_id = hash_values[start] % union_find_parallel_num;
+                local_pairs.emplace_back(hash_table_id, std::move(hash_value), uid);
+            }
+        }
+
+        #pragma omp critical
+        {
+            intermediate_pairs.insert(intermediate_pairs.end(), local_pairs.begin(), local_pairs.end());
+        }
+    }
+    py::list result;
+    for (const auto& item : intermediate_pairs) {
+        uint32_t first = std::get<0>(item);
+        py::bytes second = py::bytes(std::get<1>(item));
+        uint64_t third = std::get<2>(item);
+        result.append(py::make_tuple(first, second, third));
+    }
+    return result;
+}
+
+std::vector<std::tuple<uint32_t, py::bytes>> calc_simple_minhash_c(
+    const std::vector<std::string>& tokens,
+    const py::array_t<uint32_t>& perm_a,
+    const py::array_t<uint32_t>& perm_b,
+    const std::vector<std::pair<size_t, size_t>>& hash_ranges,
+    uint32_t bucket_per_band,
+    uint64_t uid)
+{
+    std::vector<std::tuple<uint32_t, py::bytes>> pairs;
+
+    if (tokens.empty()) {
+        pairs.emplace_back(0, py::bytes(""));
+        return pairs;
+    }
+
+    std::vector<uint32_t> hv;
+    hv.reserve(tokens.size());
+    for (const std::string& token : tokens) {
+        hv.push_back(simple_hash(token));
+    }
+
+    auto perm_a_data = perm_a.unchecked<1>();
+    auto perm_b_data = perm_b.unchecked<1>();
+    size_t num_permutation = perm_a.shape(0);
+
+    std::vector<uint32_t> hash_values(num_permutation, MAX_HASH);
+    for (size_t i = 0; i < num_permutation; ++i) {
+        for (uint32_t h : hv) {
+            uint32_t phv = ((static_cast<uint64_t>(h) * perm_a_data(i) + perm_b_data(i)) % MERSENNE_PRIME) & MAX_HASH;
+            hash_values[i] = std::min(hash_values[i], phv);
+        }
+    }
+
+
+    for (size_t i = 0; i < hash_ranges.size(); ++i) {
+        const auto& [start, end] = hash_ranges[i];
+        std::vector<uint32_t> band_hash_values(hash_values.begin() + start, hash_values.begin() + end);
+
+        py::bytes hash_value = py::bytes(
+            std::string(reinterpret_cast<char*>(band_hash_values.data()), band_hash_values.size() * sizeof(uint32_t))
+        );
+
+        uint32_t hash_table_id = bucket_per_band * i + (hash_values[start] % bucket_per_band);
+        pairs.emplace_back(hash_table_id, hash_value);
+    }
+
+    return pairs;
+}
+
+
+PYBIND11_MODULE(minhash, m) {
+    m.def("calc_minhash_c", &calc_minhash_c, "C++ implementation of calc_minhash");
+    m.def("calc_simple_minhash_c", &calc_simple_minhash_c, "C++ implementation of calc_simple_minhash");
+    m.def("calc_minhash_batch_c", &calc_minhash_batch_c, "C++ implementation of calc_minhash (batch version)");
+}
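For readers following the C++ in Python terms: per document, the batch function computes a djb2 hash for every token, one multiply-shift "permutation" per MinHash slot (the high 32 bits of h * a + b in 64-bit arithmetic), takes the minimum over all tokens, and emits one (hash_table_id, band_bytes, uid) tuple per LSH band. A rough, unoptimized Python equivalent of calc_minhash_batch_c is sketched below; it is for understanding only and is not part of the commit (it assumes a little-endian layout, matching the raw byte copies in the C++):

import struct

def simple_hash(token: str) -> int:
    """djb2 over UTF-8 bytes, matching simple_hash() above (32-bit wraparound)."""
    h = 5381
    for c in token.encode("utf-8"):
        h = ((h << 5) + h + c) & 0xFFFFFFFF
    return h

def calc_minhash_batch(tokens_list, uid_begin, perm_a, perm_b,
                       empty_hash_value, hash_ranges, union_find_parallel_num):
    MAX_HASH = 0xFFFFFFFF
    pairs = []
    for doc_idx, tokens in enumerate(tokens_list):
        uid = uid_begin + doc_idx
        if not tokens:  # empty docs all share one sentinel signature
            pairs.append((MAX_HASH % union_find_parallel_num, empty_hash_value, uid))
            continue
        # MinHash: minimum permuted hash per slot, over all tokens.
        hash_values = [MAX_HASH] * len(perm_a)
        for token in tokens:
            h = simple_hash(token)
            for i, (a, b) in enumerate(zip(perm_a, perm_b)):
                # High 32 bits of the 64-bit product, as in the C++ batch path.
                phv = ((h * a + b) & 0xFFFFFFFFFFFFFFFF) >> 32
                if phv < hash_values[i]:
                    hash_values[i] = phv
        # One record per LSH band: the band index plus that band's slot minima as
        # raw little-endian uint32s, routed to a union-find shard by the first slot.
        for i, (start, end) in enumerate(hash_ranges):
            band = struct.pack(f"<{1 + end - start}I", i, *hash_values[start:end])
            pairs.append((hash_values[start] % union_find_parallel_num, band, uid))
    return pairs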
