Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 2 additions & 53 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
.idea
gtfs/__version__.py

# Generated / downloaded files
*.zip
*.p
*.csv
*.html
patco-gtfs/
transitfeedcrash.txt
*.pkl

# virtualenv
.venv/
Expand All @@ -18,51 +15,3 @@ transitfeedcrash.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# PyBuilder
target/
Binary file added feeds/AlbanyNy.pkl
Binary file not shown.
Binary file added feeds/Berlin.pkl
Binary file not shown.
206 changes: 148 additions & 58 deletions gtfs/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
"""Command line interface for fetching GTFS."""
import logging
import os
import threading
from typing import Optional

import typer
Expand All @@ -9,11 +10,9 @@

from .feed_source import FeedSource
from .feed_sources import feed_sources
from .utils.constants import Predicate, spinner
from .utils.constants import LOG, Predicate, spinner
from .utils.geom import Bbox, bbox_contains_bbox, bbox_intersects_bbox

logging.basicConfig()
LOG = logging.getLogger()
app = typer.Typer()


Expand All @@ -39,6 +38,18 @@ def check_bbox(bbox: str) -> Optional[Bbox]:
return Bbox(min_x, min_y, max_x, max_y)


def check_sources(sources: str) -> Optional[str]:
"""Check if the sources are valid."""
if sources is None:
return None
sources = sources.split(",")
for source in sources:
if not any(src.__name__.lower() == source.lower() for src in feed_sources):
raise typer.BadParameter(f"{source} is not a valid feed source!")

return ",".join(sources)


@app.command()
def list_feeds(
bbox: Annotated[
Expand Down Expand Up @@ -67,7 +78,12 @@ def list_feeds(
),
] = False,
) -> None:
"""Filter feeds spatially based on bounding box."""
"""Filter feeds spatially based on bounding box or list all of them.

:param bbox: set of coordinates to filter feeds spatially
:param predicate: the gtfs feed should intersect or should be contained inside the user's bbox
:param pretty: display feeds inside a pretty table
"""
if bbox is None and predicate is not None:
raise typer.BadParameter(
f"Please pass a bbox if you want to filter feeds spatially based on predicate = {predicate}!"
Expand All @@ -83,6 +99,8 @@ def list_feeds(
["Feed Source", "Transit URL", "Bounding Box"], theme=Themes.OCEAN, hrules=1
)

filtered_srcs = ""

for src in feed_sources:
feed_bbox: Bbox = src.bbox
if bbox is not None and predicate == "contains":
Expand All @@ -94,6 +112,8 @@ def list_feeds(
):
continue

filtered_srcs += src.__name__ + ", "

if pretty is True:
pretty_output.add_row(
[
Expand All @@ -109,67 +129,137 @@ def list_feeds(
if pretty is True:
print("\n" + pretty_output.get_string())

if typer.confirm("Do you want to fetch feeds from these sources?"):
fetch_feeds(sources=filtered_srcs[:-1])


@app.command()
def fetch_feeds(sources=None):
"""
def fetch_feeds(
sources: Annotated[
Optional[str],
typer.Option(
"--sources",
"-src",
help="pass value as a string separated by commas like this: Berlin,AlbanyNy,...",
callback=check_sources,
),
] = None,
search: Annotated[
Optional[str],
typer.Option(
"--search",
"-s",
help="search for feeds based on a string",
),
] = None,
Comment thread
Ananya2001-an marked this conversation as resolved.
Outdated
output_dir: Annotated[
Optional[str],
typer.Option(
"--output-dir",
"-o",
help="the directory where the downloaded feeds will be saved, default is feeds",
),
] = "feeds",
concurrency: Annotated[
Optional[int],
typer.Option(
"--concurrency",
"-c",
help="the number of concurrent downloads, default is 4",
),
] = 4,
) -> None:
"""Fetch feeds from sources.

:param sources: List of :FeedSource: modules to fetch; if not set, will fetch all available.
:param search: Search for feeds based on a string.
:param output_dir: The directory where the downloaded feeds will be saved; default is feeds.
:param concurrency: The number of concurrent downloads; default is 4.
"""
statuses = {} # collect the statuses for all the files
# statuses = {} # collect the statuses for all the files

# default to use all of them
if not sources:
sources = feed_sources

LOG.info("Going to fetch feeds from sources: %s", sources)
for src in sources:
LOG.debug("Going to start fetch for %s...", src)
try:
if issubclass(src, FeedSource):
inst = src()
inst.fetch()
statuses.update(inst.status)
else:
LOG.warning(
"Skipping class %s, which does not subclass FeedSource.",
src.__name__,
)
except AttributeError:
LOG.error("Skipping feed %s, which could not be found.", src)

# remove last check key set at top level of each status dictionary
if "last_check" in statuses:
del statuses["last_check"]

ptable = ColorTable(
[
"file",
"new?",
"valid?",
"current?",
"newly effective?",
"error",
],
theme=Themes.OCEAN,
hrules=1,
)

for file_name in statuses:
stat = statuses[file_name]
msg = []
msg.append(file_name)
msg.append("x" if "is_new" in stat and stat["is_new"] else "")
msg.append("x" if "is_valid" in stat and stat["is_valid"] else "")
msg.append("x" if "is_current" in stat and stat["is_current"] else "")
msg.append("x" if "newly_effective" in stat and stat.get("newly_effective") else "")
if "error" in stat:
msg.append(stat["error"])
if not search:
# fetch all feeds
sources = feed_sources
else:
msg.append("")
ptable.add_row(msg)
# fetch feeds based on search
sources = [
src
for src in feed_sources
if search.lower() in src.__name__.lower() or search.lower() in src.url.lower()
]
else:
if search:
raise typer.BadParameter("Please pass either sources or search, not both at the same time!")
else:
sources = [src for src in feed_sources if src.__name__.lower() in sources.lower()]

output_dir_path = os.path.join(os.getcwd(), output_dir)
Comment thread
Ananya2001-an marked this conversation as resolved.
Outdated
if not os.path.exists(output_dir_path):
os.makedirs(output_dir_path)

LOG.info(f"Going to fetch feeds from sources: {sources}")

threads = []

def thread_worker():
Comment thread
Ananya2001-an marked this conversation as resolved.
Outdated
while True:
try:
src = sources.pop(0)
except IndexError:
break

LOG.debug(f"Going to start fetch for {src}...")
try:
if issubclass(src, FeedSource):
inst = src()
inst.ddir = output_dir_path
inst.status_file = os.path.join(inst.ddir, src.__name__ + ".pkl")
inst.fetch()
# statuses.update(inst.status)
else:
LOG.warning(f"Skipping class {src.__name__}, which does not subclass FeedSource.")
except AttributeError:
LOG.error(f"Skipping feed {src}, which could not be found.")

for _ in range(concurrency):
thread = threading.Thread(target=thread_worker)
thread.start()
threads.append(thread)

# Wait for all threads to complete
for thread in threads:
thread.join()

LOG.info("Results:\n%s", ptable.get_string())
LOG.info("All done!")
# ptable = ColorTable(
# [
# "file",
# "new?",
# "valid?",
# "current?",
# "newly effective?",
# "error",
# ],
# theme=Themes.OCEAN,
# hrules=1,
# )
#
# for file_name in statuses:
# stat = statuses[file_name]
# msg = []
# msg.append(file_name)
# msg.append("x" if "is_new" in stat and stat["is_new"] else "")
# msg.append("x" if "is_valid" in stat and stat["is_valid"] else "")
# msg.append("x" if "is_current" in stat and stat["is_current"] else "")
# msg.append("x" if "newly_effective" in stat and stat.get("newly_effective") else "")
# if "error" in stat:
# msg.append(stat["error"])
# else:
# msg.append("")
# ptable.add_row(msg)
#
# LOG.info("\n" + ptable.get_string())


if __name__ == "__main__":
Expand Down
Loading