Skip to content

Commit 84eb8df

Browse files
authored
Merge pull request #577 from graphistry/dev/tolerate-cudf-fail
refactor(lazy import): centralize, optimize, CPU fallback when broken…
2 parents 710bab6 + 5a921a9 commit 84eb8df

19 files changed

+301
-252
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
77

88
## [Development]
99

10+
### Fixed
11+
12+
* Graceful CPU fallbacks: When lazy GPU dependency imports raise `ImportError` (commonly seen with broken CUDA environments, or with CUDA libraries installed but no GPU), warn and fall back to CPU.
13+
14+
* Ring layouts now support filtered inputs, giving expected positions
15+
16+
* `encode_axis()` updates are now functional, not in-place
17+
18+
### Changed
19+
20+
* Centralize lazy imports into `graphistry.utils.lazy_import`
21+
* Lazy imports distinguish `ModuleNotFoundError` (=> `False`) from other `ImportError`s (warn + `False`)
22+
1023
## [0.34.0 - 2024-07-17]
1124

1225
### Infra

graphistry/Engine.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from inspect import getmodule
12
import pandas as pd
23
from typing import Any, Optional, Union
34
from enum import Enum
5+
from graphistry.utils.lazy_import import lazy_cudf_import
46

57

68
class Engine(Enum):
@@ -21,18 +23,6 @@ class EngineAbstract(Enum):
2123
DataframeLocalLike = Any # pdf, cudf
2224
GraphistryLke = Any
2325

24-
#TODO use new importer when it lands (this is copied from umap_utils)
25-
def lazy_cudf_import_has_dependancy():
26-
try:
27-
import warnings
28-
29-
warnings.filterwarnings("ignore")
30-
import cudf # type: ignore
31-
32-
return True, "ok", cudf
33-
except ModuleNotFoundError as e:
34-
return False, e, None
35-
3626
def resolve_engine(
3727
engine: Union[EngineAbstract, str],
3828
g_or_df: Optional[Any] = None,
@@ -58,14 +48,15 @@ def resolve_engine(
5848
if isinstance(g_or_df, pd.DataFrame):
5949
return Engine.PANDAS
6050

61-
has_cudf_dependancy_, _, _ = lazy_cudf_import_has_dependancy()
62-
if has_cudf_dependancy_:
63-
import cudf
64-
if isinstance(g_or_df, cudf.DataFrame):
65-
return Engine.CUDF
66-
raise ValueError(f'Expected cudf dataframe, got: {type(g_or_df)}')
51+
if 'cudf.core.dataframe' in str(getmodule(g_or_df)):
52+
has_cudf_dependancy_, _, _ = lazy_cudf_import()
53+
if has_cudf_dependancy_:
54+
import cudf
55+
if isinstance(g_or_df, cudf.DataFrame):
56+
return Engine.CUDF
57+
raise ValueError(f'Expected cudf dataframe, got: {type(g_or_df)}')
6758

68-
has_cudf_dependancy_, _, _ = lazy_cudf_import_has_dependancy()
59+
has_cudf_dependancy_, _, _ = lazy_cudf_import()
6960
if has_cudf_dependancy_:
7061
return Engine.CUDF
7162
return Engine.PANDAS

graphistry/PlotterBase.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,11 @@ def encode_axis(self, rows: List[Dict] = []) -> Plottable:
373373
374374
"""
375375

376-
complex_encodings = self._complex_encodings or {}
377-
if 'node_encodings' not in complex_encodings:
378-
complex_encodings['node_encodings'] = {}
379-
node_encodings = complex_encodings['node_encodings']
380-
if 'current' not in node_encodings:
381-
node_encodings['current'] = {}
382-
if 'default' not in node_encodings:
383-
node_encodings['default'] = {}
376+
complex_encodings = {**self._complex_encodings} if self._complex_encodings else {}
377+
# Shallow-copy existing node encodings (so the source dict is not mutated) when present; otherwise start empty.
node_encodings = {**complex_encodings['node_encodings']} if 'node_encodings' in complex_encodings else {}
378+
complex_encodings['node_encodings'] = node_encodings
379+
node_encodings['current'] = {**node_encodings['current']} if 'current' in node_encodings else {}
380+
node_encodings['default'] = {**node_encodings['default']} if 'default' in node_encodings else {}
384381
node_encodings['default']["pointAxisEncoding"] = {
385382
"graphType": "point",
386383
"encodingType": "axis",

graphistry/compute/cluster.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from graphistry.constants import CUML, UMAP_LEARN, DBSCAN # noqa type: ignore
1111
from graphistry.features import ModelDict
1212
from graphistry.feature_utils import get_matrix_by_column_parts
13+
from graphistry.utils.lazy_import import lazy_cudf_import, lazy_dbscan_import
1314

1415
logger = logging.getLogger("compute.cluster")
1516

@@ -22,37 +23,6 @@
2223
DBSCANEngine = Literal[DBSCANEngineConcrete, "auto"]
2324

2425

25-
def lazy_dbscan_import_has_dependency():
26-
has_min_dependency = True
27-
DBSCAN = None
28-
try:
29-
from sklearn.cluster import DBSCAN
30-
except ImportError:
31-
has_min_dependency = False
32-
logger.info("Please install sklearn for CPU DBSCAN")
33-
34-
has_cuml_dependency = True
35-
cuDBSCAN = None
36-
try:
37-
from cuml import DBSCAN as cuDBSCAN
38-
except ImportError:
39-
has_cuml_dependency = False
40-
logger.info("Please install cuml for GPU DBSCAN")
41-
42-
return has_min_dependency, DBSCAN, has_cuml_dependency, cuDBSCAN
43-
44-
def lazy_cudf_import_has_dependancy():
45-
try:
46-
import warnings
47-
48-
warnings.filterwarnings("ignore")
49-
import cudf # type: ignore
50-
51-
return True, "ok", cudf
52-
except ModuleNotFoundError as e:
53-
return False, e, None
54-
55-
5626
def resolve_cpu_gpu_engine(
5727
engine: DBSCANEngine,
5828
) -> DBSCANEngineConcrete: # noqa
@@ -64,7 +34,7 @@ def resolve_cpu_gpu_engine(
6434
_,
6535
has_cuml_dependency,
6636
_,
67-
) = lazy_dbscan_import_has_dependency()
37+
) = lazy_dbscan_import()
6838
if has_cuml_dependency:
6939
return "cuml"
7040
if has_min_dependency:
@@ -90,7 +60,7 @@ def safe_cudf(X, y):
9060
new_kwargs[key] = value
9161
return new_kwargs['X'], new_kwargs['y']
9262

93-
has_cudf_dependancy_, _, cudf = lazy_cudf_import_has_dependancy()
63+
has_cudf_dependancy_, _, cudf = lazy_cudf_import()
9464
if has_cudf_dependancy_:
9565
# print('DBSCAN CUML Matrices')
9666
return safe_cudf(X, y)
@@ -209,7 +179,7 @@ def _cluster_dbscan(
209179
):
210180
"""DBSCAN clustering on cpu or gpu infered by .engine flag
211181
"""
212-
_, DBSCAN, _, cuDBSCAN = lazy_dbscan_import_has_dependency()
182+
_, DBSCAN, _, cuDBSCAN = lazy_dbscan_import()
213183

214184
if engine_dbscan in [CUML]:
215185
print('`g.transform_dbscan(..)` not supported for engine=cuml, will return `g.transform_umap(..)` instead')

graphistry/dgl_utils.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
import numpy as np
66
import pandas as pd
77

8+
from graphistry.utils.lazy_import import (
9+
lazy_dgl_import,
10+
lazy_torch_import_has_dependency
11+
)
812
from . import constants as config
913
from .feature_utils import (
1014
FeatureEngine,
@@ -34,26 +38,6 @@
3438
MIXIN_BASE = object
3539

3640

37-
def lazy_dgl_import_has_dependency():
38-
try:
39-
import warnings
40-
warnings.filterwarnings('ignore')
41-
import dgl # noqa: F811
42-
return True, 'ok', dgl
43-
except ModuleNotFoundError as e:
44-
return False, e, None
45-
46-
47-
def lazy_torch_import_has_dependency():
48-
try:
49-
import warnings
50-
warnings.filterwarnings('ignore')
51-
import torch # noqa: F811
52-
return True, 'ok', torch
53-
except ModuleNotFoundError as e:
54-
return False, e, None
55-
56-
5741
logger = setup_logger(name=__name__)
5842

5943

@@ -181,7 +165,7 @@ def pandas_to_dgl_graph(
181165
sp_mat: sparse scipy matrix
182166
ordered_nodes_dict: dict ordered from most common src and dst nodes
183167
"""
184-
_, _, dgl = lazy_dgl_import_has_dependency() # noqa: F811
168+
_, _, dgl = lazy_dgl_import() # noqa: F811
185169
sp_mat, ordered_nodes_dict = pandas_to_sparse_adjacency(df, src, dst, weight_col)
186170
g = dgl.from_scipy(sp_mat, device=device) # there are other ways too
187171
logger.info(f"Graph Type: {type(g)}")
@@ -225,7 +209,7 @@ def dgl_lazy_init(self, train_split: float = 0.8, device: str = "cpu"):
225209
"""
226210

227211
if not self.dgl_initialized:
228-
lazy_dgl_import_has_dependency()
212+
lazy_dgl_import()
229213
lazy_torch_import_has_dependency()
230214
self.train_split = train_split
231215
self.device = device

graphistry/embed_utils.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,11 @@
33
import pandas as pd
44
from typing import Optional, Union, Callable, List, TYPE_CHECKING, Any, Tuple
55

6+
from graphistry.utils.lazy_import import lazy_embed_import
67
from .PlotterBase import Plottable
78
from .compute.ComputeMixin import ComputeMixin
89

910

10-
def lazy_embed_import_dep():
11-
try:
12-
import torch
13-
import torch.nn as nn
14-
import dgl
15-
from dgl.dataloading import GraphDataLoader
16-
import torch.nn.functional as F
17-
from .networks import HeteroEmbed
18-
from tqdm import trange
19-
return True, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange
20-
21-
except:
22-
return False, None, None, None, None, None, None, None
23-
2411
def check_cudf():
2512
try:
2613
import cudf
@@ -30,7 +17,7 @@ def check_cudf():
3017

3118

3219
if TYPE_CHECKING:
33-
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
20+
_, torch, _, _, _, _, _, _ = lazy_embed_import()
3421
TT = torch.Tensor
3522
MIXIN_BASE = ComputeMixin
3623
else:
@@ -147,7 +134,7 @@ def _preprocess_embedding_data(self, res, train_split:Union[float, int] = 0.8) -
147134
return res
148135

149136
def _build_graph(self, res) -> Plottable:
150-
_, _, _, dgl, _, _, _, _ = lazy_embed_import_dep()
137+
_, _, _, dgl, _, _, _, _ = lazy_embed_import()
151138
s, r, t = res._triplets.T
152139

153140
if res._train_idx is not None:
@@ -169,7 +156,7 @@ def _build_graph(self, res) -> Plottable:
169156

170157

171158
def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, device):
172-
_, _, _, _, GraphDataLoader, HeteroEmbed, _, _ = lazy_embed_import_dep()
159+
_, _, _, _, GraphDataLoader, HeteroEmbed, _, _ = lazy_embed_import()
173160
g_iter = SubgraphIterator(res._kg_dgl, sample_size, num_steps)
174161
g_dataloader = GraphDataLoader(
175162
g_iter, batch_size=batch_size, collate_fn=lambda x: x[0]
@@ -188,7 +175,7 @@ def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, devic
188175
return model, g_dataloader
189176

190177
def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable:
191-
_, torch, nn, _, _, _, _, trange = lazy_embed_import_dep()
178+
_, torch, nn, _, _, _, _, trange = lazy_embed_import()
192179
log('Training embedding')
193180
model, g_dataloader = res._init_model(res, batch_size, sample_size, num_steps, device)
194181
if hasattr(res, "_embed_model") and not res._build_new_embedding_model:
@@ -232,7 +219,7 @@ def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_siz
232219

233220
@property
234221
def _gcn_node_embeddings(self):
235-
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
222+
_, torch, _, _, _, _, _, _ = lazy_embed_import()
236223
g_dgl = self._kg_dgl.to(self._device)
237224
em = self._embed_model(g_dgl).detach()
238225
torch.cuda.empty_cache()
@@ -540,7 +527,7 @@ def fetch_triplets_for_inference(x_r):
540527

541528

542529
def _score(self, triplets: Union[np.ndarray, TT]) -> TT: # type: ignore
543-
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
530+
_, torch, _, _, _, _, _, _ = lazy_embed_import()
544531
emb = self._kg_embeddings.clone().detach()
545532
if not isinstance(triplets, torch.Tensor):
546533
triplets = torch.tensor(triplets)
@@ -571,7 +558,7 @@ def __len__(self) -> int:
571558
return self.num_steps
572559

573560
def __getitem__(self, i:int):
574-
_, torch, nn, dgl, GraphDataLoader, _, F, _ = lazy_embed_import_dep()
561+
_, torch, nn, dgl, GraphDataLoader, _, F, _ = lazy_embed_import()
575562
eids = torch.from_numpy(np.random.choice(self.eids, self.sample_size))
576563

577564
src, dst = self.g.find_edges(eids)
@@ -593,7 +580,7 @@ def __getitem__(self, i:int):
593580

594581
@staticmethod
595582
def _sample_neg(triplets:np.ndarray, num_nodes:int) -> Tuple[TT, TT]: # type: ignore
596-
_, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
583+
_, torch, _, _, _, _, _, _ = lazy_embed_import()
597584
triplets = torch.tensor(triplets)
598585
h, r, t = triplets.T
599586
h_o_t = torch.randint(high=2, size=h.size())

0 commit comments

Comments
 (0)