Skip to content

Commit 9ba30cc

Browse files
authored
Merge pull request #338 from GeospatialPython/Writer_speed_tests
Add Writer speed tests
2 parents b2e513b + 94ef7dd commit 9ba30cc

File tree

2 files changed

+75
-39
lines changed

2 files changed

+75
-39
lines changed

run_benchmarks.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22

33
from __future__ import annotations
44

5+
import collections
56
import functools
67
import os
78
import timeit
89
from collections.abc import Callable
910
from pathlib import Path
11+
from tempfile import TemporaryFile as TempF
1012
from typing import Union
1113

12-
import shapefile as shp
14+
import shapefile
1315

1416
# For shapefiles from https://github.com/JamesParrott/PyShp_test_shapefile
1517
DEFAULT_PYSHP_TEST_REPO = (
@@ -31,26 +33,41 @@ def benchmark(
3133
name: str,
3234
run_count: int,
3335
func: Callable,
34-
col_width: tuple,
36+
col_widths: tuple,
3537
compare_to: float | None = None,
3638
) -> float:
3739
placeholder = "Running..."
38-
print(f"{name:>{col_width[0]}} | {placeholder}", end="", flush=True)
40+
print(f"{name:>{col_widths[0]}} | {placeholder}", end="", flush=True)
3941
time_taken = timeit.timeit(func, number=run_count)
4042
print("\b" * len(placeholder), end="")
4143
time_suffix = " s"
42-
print(f"{time_taken:{col_width[1]-len(time_suffix)}.3g}{time_suffix}", end="")
44+
print(f"{time_taken:{col_widths[1]-len(time_suffix)}.3g}{time_suffix}", end="")
4345
print()
4446
return time_taken
4547

4648

49+
fields = {}
50+
shapeRecords = collections.defaultdict(list)
51+
52+
4753
def open_shapefile_with_PyShp(target: Union[str, os.PathLike]):
48-
with shp.Reader(target) as r:
54+
with shapefile.Reader(target) as r:
55+
fields[target] = r.fields
4956
for shapeRecord in r.iterShapeRecords():
50-
pass
57+
shapeRecords[target].append(shapeRecord)
58+
59+
60+
def write_shapefile_with_PyShp(target: Union[str, os.PathLike]):
61+
with TempF("wb") as shp, TempF("wb") as dbf, TempF("wb") as shx:
62+
with shapefile.Writer(shp=shp, dbf=dbf, shx=shx) as w: # type: ignore [arg-type]
63+
for field_info_tuple in fields[target]:
64+
w.field(*field_info_tuple)
65+
for shapeRecord in shapeRecords[target]:
66+
w.shape(shapeRecord.shape)
67+
w.record(*shapeRecord.record)
5168

5269

53-
READER_TESTS = {
70+
SHAPEFILES = {
5471
"Blockgroups": blockgroups_file,
5572
"Edit": edit_file,
5673
"Merge": merge_file,
@@ -60,24 +77,47 @@ def open_shapefile_with_PyShp(target: Union[str, os.PathLike]):
6077
}
6178

6279

63-
def run(run_count: int) -> None:
64-
col_width = (21, 10)
80+
# Load files to avoid one-off delays that only affect first disk seek
81+
for file_path in SHAPEFILES.values():
82+
file_path.read_bytes()
83+
84+
reader_benchmarks = [
85+
functools.partial(
86+
benchmark,
87+
name=f"Read {test_name}",
88+
func=functools.partial(open_shapefile_with_PyShp, target=target),
89+
)
90+
for test_name, target in SHAPEFILES.items()
91+
]
92+
93+
# Require fields and shapeRecords to first have been populated
94+
# from data from previously running the reader_benchmarks
95+
writer_benchmarks = [
96+
functools.partial(
97+
benchmark,
98+
name=f"Write {test_name}",
99+
func=functools.partial(write_shapefile_with_PyShp, target=target),
100+
)
101+
for test_name, target in SHAPEFILES.items()
102+
]
103+
104+
105+
def run(run_count: int, benchmarks: list[Callable[[], None]]) -> None:
106+
col_widths = (22, 10)
65107
col_head = ("parser", "exec time", "performance (more is better)")
66-
# Load files to avoid one off delays that only affect first disk seek
67-
for file_path in READER_TESTS.values():
68-
file_path.read_bytes()
69108
print(f"Running benchmarks {run_count} times:")
70-
print("-" * col_width[0] + "---" + "-" * col_width[1])
71-
print(f"{col_head[0]:>{col_width[0]}} | {col_head[1]:>{col_width[1]}}")
72-
print("-" * col_width[0] + "-+-" + "-" * col_width[1])
73-
for test_name, target in READER_TESTS.items():
74-
benchmark(
75-
f"Read {test_name}",
76-
run_count,
77-
functools.partial(open_shapefile_with_PyShp, target=target),
78-
col_width,
109+
print("-" * col_widths[0] + "---" + "-" * col_widths[1])
110+
print(f"{col_head[0]:>{col_widths[0]}} | {col_head[1]:>{col_widths[1]}}")
111+
print("-" * col_widths[0] + "-+-" + "-" * col_widths[1])
112+
for benchmark in benchmarks:
113+
benchmark( # type: ignore [call-arg]
114+
run_count=run_count,
115+
col_widths=col_widths,
79116
)
80117

81118

82119
if __name__ == "__main__":
83-
run(1)
120+
print("Reader tests:")
121+
run(1, reader_benchmarks) # type: ignore [arg-type]
122+
print("\n\nWriter tests:")
123+
run(1, writer_benchmarks) # type: ignore [arg-type]

test_shapefile.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,7 @@ def test_record_oid():
987987
assert shaperec.record.oid == i
988988

989989

990+
@pytest.mark.slow
990991
def test_iterRecords_start_stop():
991992
"""
992993
Assert that Reader.iterRecords(start, stop)
@@ -999,36 +1000,31 @@ def test_iterRecords_start_stop():
9991000

10001001
# Arbitrary selection of record indices
10011002
# (there are 663 records in blockgroups.dbf).
1002-
for i in [
1003+
indices = [
10031004
0,
10041005
1,
10051006
2,
1006-
3,
10071007
5,
10081008
11,
1009-
17,
1010-
33,
1011-
51,
1012-
103,
1013-
170,
1014-
234,
1015-
435,
1016-
543,
1009+
41,
1010+
310,
1011+
513,
10171012
N - 3,
1018-
N - 2,
10191013
N - 1,
1020-
]:
1021-
for record in sf.iterRecords(start=i):
1014+
]
1015+
for i, index in enumerate(indices):
1016+
for record in sf.iterRecords(start=index):
10221017
assert record == sf.record(record.oid)
10231018

1024-
for record in sf.iterRecords(stop=i):
1019+
for record in sf.iterRecords(stop=index):
10251020
assert record == sf.record(record.oid)
10261021

1027-
for stop in range(i, len(sf)):
1022+
for j in range(i + 1, len(indices)):
1023+
stop = indices[j]
10281024
# test negative indexing from end, as well as
10291025
# positive values of stop, and its default
1030-
for stop_arg in (stop, stop - len(sf)):
1031-
for record in sf.iterRecords(start=i, stop=stop_arg):
1026+
for stop_arg in (stop, stop - N):
1027+
for record in sf.iterRecords(start=index, stop=stop_arg):
10321028
assert record == sf.record(record.oid)
10331029

10341030

0 commit comments

Comments
 (0)