11from __future__ import annotations
22
33import inspect
4- from collections .abc import Mapping , Sequence # noqa: TC003
4+ import re
5+ import textwrap
6+ from collections .abc import Sequence
57from copy import copy
6- from functools import partial
8+ from functools import cache , partial
79from itertools import combinations , product
810from numbers import Integral
9- from typing import (
10- TYPE_CHECKING ,
11- Any , # noqa: TC003
12- Literal , # noqa: TC003
13- )
11+ from typing import TYPE_CHECKING
1412
1513import numpy as np
1614import pandas as pd
17- from anndata import AnnData # noqa: TC002
18- from cycler import Cycler # noqa: TC002
1915from matplotlib import colormaps , colors , patheffects , rcParams
2016from matplotlib import pyplot as plt
21- from matplotlib .axes import Axes # noqa: TC002
22- from matplotlib .colors import (
23- Colormap , # noqa: TC002
24- Normalize ,
25- )
26- from matplotlib .figure import Figure # noqa: TC002
17+ from matplotlib .colors import Normalize
2718from matplotlib .markers import MarkerStyle
28- from numpy .typing import NDArray # noqa: TC002
2919
3020from ... import logging as logg
3121from ..._compat import deprecated
3222from ..._settings import settings
33- from ..._utils import (
34- Empty , # noqa: TC001
35- _doc_params ,
36- _empty ,
37- sanitize_anndata ,
38- )
23+ from ..._utils import _doc_params , _empty , sanitize_anndata
3924from ..._utils ._doctests import doctest_internet
4025from ...get import _check_mask
41- from ...tools ._draw_graph import _Layout # noqa: TC001
4226from .. import _utils
4327from .._docs import (
4428 doc_adata_color_etc ,
4731 doc_scatter_spatial ,
4832 doc_show_save_ax ,
4933)
50- from .._utils import (
51- ColorLike , # noqa: TC001
52- VBound , # noqa: TC001
53- _FontSize , # noqa: TC001
54- _FontWeight , # noqa: TC001
55- _LegendLoc , # noqa: TC001
56- check_colornorm ,
57- check_projection ,
58- circles ,
59- )
34+ from .._utils import check_colornorm , check_projection , circles
6035
6136if TYPE_CHECKING :
62- from collections .abc import Collection
37+ from collections .abc import Callable , Collection , Mapping
38+ from types import FunctionType
39+ from typing import Any , Literal
40+
41+ from anndata import AnnData
42+ from cycler import Cycler
43+ from matplotlib .axes import Axes
44+ from matplotlib .colors import Colormap
45+ from matplotlib .figure import Figure
46+ from numpy .typing import NDArray
47+
48+ from ..._utils import Empty
49+ from ...tools ._draw_graph import _Layout
50+ from .._utils import ColorLike , VBound , _FontSize , _FontWeight , _LegendLoc
6351
6452
6553@_doc_params (
@@ -600,10 +588,29 @@ def _get_vboundnorm(
600588 return tuple (out )
601589
602590
603- def _wraps_plot_scatter (wrapper ):
591+ _TYPE_GUARD_IMPORT_RE = re .compile (r"\nif TYPE_CHECKING:[^\n]*([\s\S]*?)(?=\n\S)" )
592+
593+
594+ @cache
595+ def _get_guarded_imports (obj : FunctionType ) -> Mapping [str , Any ]:
596+ """Simplified version from `sphinx-autodoc-typehints`."""
597+ module = inspect .getmodule (obj )
598+ assert module
599+ code = inspect .getsource (module )
600+ rv : dict [str , Any ] = {}
601+ for m in _TYPE_GUARD_IMPORT_RE .finditer (code ):
602+ guarded_code = textwrap .dedent (m .group (1 ))
603+ rv .update (obj .__globals__ )
604+ exec (guarded_code , rv )
605+ for k in obj .__globals__ :
606+ del rv [k ]
607+ return rv
608+
609+
610+ def _wraps_plot_scatter [** P , R ](wrapper : Callable [P , R ]) -> Callable [P , R ]:
604611 """Update the wrapper function to use the correct signature."""
605- params = inspect .signature (embedding , eval_str = True ).parameters .copy ()
606- wrapper_sig = inspect .signature (wrapper , eval_str = True )
612+ params = inspect .signature (embedding ).parameters .copy ()
613+ wrapper_sig = inspect .signature (wrapper )
607614 wrapper_params = wrapper_sig .parameters .copy ()
608615
609616 params .pop ("basis" )
@@ -619,6 +626,14 @@ def _wraps_plot_scatter(wrapper):
619626 if wrapper_sig .return_annotation is not inspect .Signature .empty :
620627 annotations ["return" ] = wrapper_sig .return_annotation
621628
629+ # `sphinx-autodoc-typehints` can execute `if TYPECHECKING` blocks,
630+ # but all all users of `_wraps_plot_scatter` that aren’t in this module
631+ # won’t have any imports. So we execute and inject the imports here.
632+ wrapper .__globals__ .update ({
633+ k : v
634+ for k , v in {** embedding .__globals__ , ** _get_guarded_imports (embedding )}.items ()
635+ if k not in wrapper .__globals__
636+ })
622637 wrapper .__signature__ = inspect .Signature (
623638 list (params .values ()), return_annotation = wrapper_sig .return_annotation
624639 )
0 commit comments