Skip to content

Commit 01260df

Browse files
committed
Add unit tests
1 parent 4925f72 commit 01260df

File tree

3 files changed

+234
-16
lines changed

3 files changed

+234
-16
lines changed

polars_bio/quality_stats.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def base_sequence_quality(
15-
df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame],
15+
df: Union[str, Path, pl.DataFrame, pl.LazyFrame, pd.DataFrame],
1616
quality_scores_column: str = "quality_scores",
1717
output_type: str = "polars.DataFrame",
1818
target_partitions: int = 8,
@@ -32,7 +32,8 @@ def base_sequence_quality(
3232
"datafusion.execution.target_partitions", str(target_partitions), False
3333
)
3434

35-
if isinstance(df, str):
35+
if isinstance(df, (str, Path)):
36+
df = str(df)
3637
supported_exts = {".parquet", ".csv", ".bed", ".vcf", ".fastq"}
3738
ext = set(Path(df).suffixes)
3839
if not (supported_exts & ext or not ext):

src/quantile_stats.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -235,24 +235,26 @@ fn calculate_histogram_stats(hist: &[u64]) -> Option<(f64, f64, f64, f64, f64, f
235235

236236
/// Returns the `quantile`-th quantile (0.0..=1.0) of the samples described by `hist`.
///
/// `hist[v]` is the number of samples with value `v`, and `total` is the sum of
/// all counts. The result is linearly interpolated between the two neighbouring
/// samples when the fractional rank `quantile * (total - 1)` falls between two
/// different values.
///
/// Returns `0.0` when the histogram holds no samples.
fn quantile(hist: &[u64], quantile: f64, total: u64) -> f64 {
    // Fractional 0-based rank of the requested quantile in the sorted samples.
    // `saturating_sub` keeps `total == 0` from underflowing the u64 subtraction.
    let target = quantile * total.saturating_sub(1) as f64;
    let rank_floor = target.floor();
    // Interpolation weight between the sample at floor(target) and the next one.
    let delta = target - rank_floor;
    // 1-based rank of the sample sitting at floor(target).
    let n = rank_floor as u64 + 1;

    // `lo` is the most recently visited non-empty bin; `acc` counts the samples
    // in all bins visited so far.
    let mut lo: Option<usize> = None;
    let mut acc = 0u64;
    for (hi, &count) in hist.iter().enumerate().filter(|&(_, &count)| count > 0) {
        match lo {
            // The n-th sample is the last one in `lo` and the next sample lives
            // in `hi`: interpolate between the two bin values.
            Some(lo) if acc == n => {
                let lo = lo as f64;
                return lo + (hi as f64 - lo) * delta;
            }
            // Both the n-th sample and its successor fall inside `hi`.
            _ if acc + count > n => return hi as f64,
            _ => {}
        }
        acc += count;
        lo = Some(hi);
    }

    // The rank points at the very last sample: answer with the highest
    // non-empty bin (or 0 when every bin is empty).
    hist.iter().rposition(|&count| count > 0).unwrap_or(0) as f64
}
257259

258260
let q1 = quantile(hist, 0.25, total_count);
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from pathlib import Path
2+
import polars as pl
3+
import pandas as pd
4+
5+
import pytest
6+
7+
import polars_bio as pb
8+
9+
10+
class TestBaseSequenceQuality:
    """Tests for ``pb.base_sequence_quality`` on tiny hand-written FASTQ inputs.

    Quality characters are produced as ``chr(score + 33)``, so every test
    states its quality scores as plain integers. In the multi-read cases the
    expected ``lower``/``upper`` values equal ``q1 - 1.5 * (q3 - q1)`` and
    ``q3 + 1.5 * (q3 - q1)``.

    Each test requests ``tmp_path`` as an argument, which is enough to
    activate the fixture; no ``usefixtures`` mark is needed.
    """

    def record_header(self, length: int) -> str:
        """Return the header, sequence and separator lines of one FASTQ record.

        The sequence line consists of ``length`` ``N`` characters; the caller
        is expected to append the quality line.
        """
        return f"@test\n{'N' * length}\n+\n"

    def _fastq_record(self, scores) -> str:
        """Return one complete FASTQ record whose quality line encodes ``scores``.

        The sequence length is derived from the number of scores so that the
        sequence and quality lines always have the same length.
        """
        quality = "".join(chr(score + 33) for score in scores)
        return self.record_header(len(quality)) + quality + "\n"

    def _write_fastq(self, tmp_path: Path, reads) -> Path:
        """Write one record per entry of ``reads`` to ``test.fastq``.

        ``reads`` is an iterable of per-read score sequences; an empty
        iterable produces an empty file. Returns the path of the file.
        """
        fastq_file = tmp_path / "test.fastq"
        content = "".join(self._fastq_record(scores) for scores in reads)
        fastq_file.write_text(content, encoding="ascii")
        return fastq_file

    def simple_test_data(self):
        """Return ``(fastq_text, expected_frame)`` for one read with score 5."""
        data = self._fastq_record([5])
        result = pl.DataFrame(
            {
                "pos": [0],
                "avg": [5.0],
                "q1": [5.0],
                "median": [5.0],
                "q3": [5.0],
                "lower": [5.0],
                "upper": [5.0],
            }
        )

        return data, result

    def _write_simple_fastq(self, tmp_path: Path):
        """Write ``simple_test_data`` to disk; return ``(path, expected_frame)``."""
        data, expected_result = self.simple_test_data()
        fastq_file = tmp_path / "test.fastq"
        fastq_file.write_text(data, encoding="ascii")
        return fastq_file, expected_result

    def test_result_from_empty_fastq_should_be_empty_df_polars(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [])

        result = pb.base_sequence_quality(fastq_file)
        assert isinstance(result, pl.DataFrame)
        assert result.is_empty()

    def test_result_from_empty_fastq_should_be_empty_df_pandas(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [])

        result = pb.base_sequence_quality(fastq_file, output_type="pandas.DataFrame")
        assert isinstance(result, pd.DataFrame)
        # Check the pandas result itself rather than a second, polars-typed run.
        assert result.empty

    def test_one_record_one_length(self, tmp_path: Path):
        fastq_file, expected_result = self._write_simple_fastq(tmp_path)

        result = pb.base_sequence_quality(fastq_file)
        assert result.equals(expected_result)

    def test_one_record_one_length_lazyframe(self, tmp_path: Path):
        fastq_file, expected_result = self._write_simple_fastq(tmp_path)

        lf = pb.read_fastq(str(fastq_file))
        result = pb.base_sequence_quality(lf)

        assert result.equals(expected_result)

    def test_one_record_one_length_polars(self, tmp_path: Path):
        fastq_file, expected_result = self._write_simple_fastq(tmp_path)

        polars_df = pb.read_fastq(str(fastq_file)).collect()
        result = pb.base_sequence_quality(polars_df)

        assert result.equals(expected_result)

    def test_one_record_one_length_pandas(self, tmp_path: Path):
        fastq_file, expected_result = self._write_simple_fastq(tmp_path)

        pandas_df = pb.read_fastq(str(fastq_file)).collect().to_pandas()
        result = pb.base_sequence_quality(pandas_df)

        assert result.equals(expected_result)

    def test_one_record_two_length(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [[5, 6]])

        result = pb.base_sequence_quality(fastq_file).sort(by="pos")
        assert result.equals(
            pl.DataFrame(
                {
                    "pos": [0, 1],
                    "avg": [5.0, 6.0],
                    "q1": [5.0, 6.0],
                    "median": [5.0, 6.0],
                    "q3": [5.0, 6.0],
                    "lower": [5.0, 6.0],
                    "upper": [5.0, 6.0],
                }
            )
        )

    def test_two_record_one_length(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [[0], [2]])

        result = pb.base_sequence_quality(fastq_file)
        assert result.equals(
            pl.DataFrame(
                {
                    "pos": [0],
                    "avg": [1.0],
                    "q1": [0.5],
                    "median": [1.0],
                    "q3": [1.5],
                    "lower": [-1.0],
                    "upper": [3.0],
                }
            )
        )

    def test_three_record_one_length(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [[0], [2], [4]])

        result = pb.base_sequence_quality(fastq_file)
        assert result.equals(
            pl.DataFrame(
                {
                    "pos": [0],
                    "avg": [2.0],
                    "q1": [1.0],
                    "median": [2.0],
                    "q3": [3.0],
                    "lower": [-2.0],
                    "upper": [6.0],
                }
            )
        )

    def test_four_record_one_length(self, tmp_path: Path):
        fastq_file = self._write_fastq(tmp_path, [[0], [2], [4], [9]])

        result = pb.base_sequence_quality(fastq_file)
        assert result.equals(
            pl.DataFrame(
                {
                    "pos": [0],
                    "avg": [3.75],
                    "q1": [1.5],
                    "median": [3.0],
                    "q3": [5.25],
                    "lower": [-4.125],
                    "upper": [10.875],
                }
            )
        )

    def test_four_record_two_length(self, tmp_path: Path):
        reads = list(zip([0, 2, 4, 9], [1, 3, 5, 10]))
        fastq_file = self._write_fastq(tmp_path, reads)

        result = pb.base_sequence_quality(fastq_file).sort(by="pos")
        assert result.equals(
            pl.DataFrame(
                {
                    "pos": [0, 1],
                    "avg": [3.75, 4.75],
                    "q1": [1.5, 2.5],
                    "median": [3.0, 4.0],
                    "q3": [5.25, 6.25],
                    "lower": [-4.125, -3.125],
                    "upper": [10.875, 11.875],
                }
            )
        )

0 commit comments

Comments
 (0)