Skip to content

Commit

Permalink
Added main module detection and command line construction, now should…
Browse files Browse the repository at this point in the history
… show the right name (instead of __main__.py) when called for modules

Co-Authored-By: Henry Fredrick Schreiner <[email protected]>
  • Loading branch information
KOLANICH and henryiii committed Jan 12, 2021
1 parent 9883fc8 commit ae39984
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
10 changes: 8 additions & 2 deletions plumbum/cli/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import os
import sys
import functools
import re
from textwrap import TextWrapper
from collections import defaultdict

from plumbum.lib import six, getdoc
from plumbum.lib import six, getdoc, get_main_module_frame
from .terminal import get_terminal_size
from .switches import (SwitchError, UnknownSwitch, MissingArgument,
WrongArgumentType, MissingMandatorySwitch,
Expand Down Expand Up @@ -68,6 +69,7 @@ def __repr__(self):
# CLI Application base class
#===================================================================================================

main_module_ending_rx = re.compile("\.__main__$")

class Application(object):
"""The base class for CLI applications; your "entry point" class should derive from it,
Expand Down Expand Up @@ -166,7 +168,11 @@ def __init__(self, executable):
# Filter colors

if self.PROGNAME is None:
self.PROGNAME = os.path.basename(executable)
spec = get_main_module_frame().f_globals.get("__spec__", None)
if spec:
self.PROGNAME = " ".join(("python -m", main_module_ending_rx.sub("", spec.name)))
else:
self.PROGNAME = os.path.basename(executable)
elif isinstance(self.PROGNAME, colors._style):
self.PROGNAME = self.PROGNAME | os.path.basename(executable)
elif colors.filter(self.PROGNAME) == '':
Expand Down
9 changes: 9 additions & 0 deletions plumbum/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,12 @@ def read_fd_decode_safely(fd, size=4096):
if i == 3:
raise
data += os.read(fd.fileno(), 1)

def get_main_module_frame():
"""
Gets the frame of the __main__ module (the one which is called with command line) of an app.
"""
fr=sys._getframe(0)
while fr and fr.f_globals['__name__'] != '__main__':
fr=fr.f_back
return fr

0 comments on commit ae39984

Please sign in to comment.