Skip to content

Commit

Permalink
Added main module detection and command line construction, now should…
Browse files Browse the repository at this point in the history
… show the right name (instead of __main__.py) when called for modules

Co-Authored-By: Henry Fredrick Schreiner <[email protected]>
  • Loading branch information
KOLANICH and henryiii committed Oct 6, 2022
1 parent ef13cbc commit 8958983
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
10 changes: 8 additions & 2 deletions plumbum/cli/application.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import functools
import inspect
import os
import re
import sys
from collections import defaultdict
from textwrap import TextWrapper

from plumbum import colors, local
from plumbum.cli.i18n import get_translation_for
from plumbum.lib import getdoc
from plumbum.lib import getdoc, get_main_module_frame

from .switches import (
CountOf,
Expand Down Expand Up @@ -77,6 +78,7 @@ def __repr__(self):
# CLI Application base class
# ===================================================================================================

main_module_ending_rx = re.compile("\.__main__$")

class Application:
"""The base class for CLI applications; your "entry point" class should derive from it,
Expand Down Expand Up @@ -182,7 +184,11 @@ def __init__(self, executable):
# Filter colors

if self.PROGNAME is None:
self.PROGNAME = os.path.basename(executable)
spec = get_main_module_frame().f_globals.get("__spec__", None)
if spec:
self.PROGNAME = " ".join(("python -m", main_module_ending_rx.sub("", spec.name)))
else:
self.PROGNAME = os.path.basename(executable)
elif isinstance(self.PROGNAME, colors._style):
self.PROGNAME = self.PROGNAME | os.path.basename(executable)
elif colors.filter(self.PROGNAME) == "":
Expand Down
9 changes: 9 additions & 0 deletions plumbum/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ def read_fd_decode_safely(fd, size=4096):
data += os.read(fd.fileno(), 1)

return data, data.decode("utf-8")

def get_main_module_frame():
"""
Gets the frame of the __main__ module (the one which is called with command line) of an app.
"""
fr = sys._getframe(0)
while fr and fr.f_globals["__name__"] != "__main__":
fr = fr.f_back
return fr

0 comments on commit 8958983

Please sign in to comment.