Skip to content

Commit 3e0f512

Browse files
brynpickeringbobbyxnglkstrp
authored
Clean up some methods in dual module (#629)
* Clean up some methods in `dual` module * Typing fix (L434). * Ignore type assessment in expressions.py --------- Co-authored-by: Bobby Xiong <bobbyxng@gmail.com> Co-authored-by: Lukas Trippe <lkstrp@pm.me>
1 parent e90bea0 commit 3e0f512

2 files changed

Lines changed: 113 additions & 108 deletions

File tree

linopy/dual.py

Lines changed: 112 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from __future__ import annotations
88

99
import logging
10-
from typing import TYPE_CHECKING, Any
10+
from typing import TYPE_CHECKING, Any, Literal
1111

1212
import numpy as np
13-
import pandas as pd
1413
import xarray as xr
1514

1615
from linopy.expressions import LinearExpression
@@ -21,7 +20,86 @@
2120
logger = logging.getLogger(__name__)
2221

2322

24-
def _var_lookup(m: Model) -> dict:
23+
def _skip(
24+
da: xr.DataArray, component_type: Literal["variable", "constraint"], name: str
25+
) -> bool:
26+
"""
27+
Determine whether to skip processing a variable or constraint based on its labels.
28+
29+
Parameters
30+
----------
31+
da : xr.DataArray
32+
The labels DataArray of the variable or constraint.
33+
component_type : Literal["variable", "constraint"]
34+
The type of component being checked, used for logging.
35+
name : str
36+
The name of the variable or constraint, used for logging.
37+
38+
Returns
39+
-------
40+
bool
41+
True if the component should be skipped (empty or fully masked), False otherwise.
42+
"""
43+
if da.size == 0:
44+
logger.debug(f"Skipping empty {component_type} '{name}'.")
45+
return True
46+
47+
if (da == -1).all():
48+
logger.debug(f"{component_type} '{name}' is fully masked, skipping.")
49+
return True
50+
return False
51+
52+
53+
def _lookup(
54+
labels: xr.DataArray, name: str, component_type: Literal["variable", "constraint"]
55+
) -> dict[int, tuple[str, dict]]:
56+
"""
57+
Create a lookup dictionary mapping labels to their corresponding names and coordinates.
58+
59+
Parameters
60+
----------
61+
labels : xr.DataArray
62+
Array of labels.
63+
name : str
64+
Name of the component.
65+
component_type : Literal["variable", "constraint"]
66+
Type of the component.
67+
68+
Returns
69+
-------
70+
dict[int, tuple[str, dict]]
71+
Mapping from flat integer label to (name, coord_dict) tuple.
72+
"""
73+
lookup: dict[int, tuple[str, dict]] = {}
74+
75+
vals = labels.values
76+
if _skip(labels, component_type, name):
77+
return lookup
78+
79+
logger.debug(
80+
f"Creating label lookup for {component_type} '{name}' with shape {labels.shape} and dims {labels.dims}."
81+
)
82+
83+
if labels.ndim == 0:
84+
lookup[int(vals.item())] = (name, {})
85+
return lookup
86+
87+
coord_values = [labels.coords[d].values for d in labels.dims]
88+
89+
# Choosing np.ndindex over np.argwhere or da.to_series for memory efficiency on large n-dimensional arrays
90+
for idx in np.ndindex(vals.shape):
91+
label = int(vals[idx])
92+
if label == -1:
93+
continue
94+
lookup[label] = (
95+
name,
96+
{dim: coord_values[i][idx[i]] for i, dim in enumerate(labels.dims)},
97+
)
98+
99+
return lookup
100+
101+
102+
def _var_lookup(m: Model) -> dict[int, tuple[str, dict]]:
25103
"""
26104
Build a flat label -> (var_name, coord_dict) lookup for all variables in m.
27105
@@ -43,40 +121,12 @@ def _var_lookup(m: Model) -> dict:
43121
var_lookup = {}
44122
logger.debug("Building variable label lookup.")
45123
for var_name, var in m.variables.items():
46-
labels = var.labels
47-
flat_labels = labels.values.flatten()
48-
49-
if len(flat_labels) == 0:
50-
logger.debug(f"Skipping empty variable '{var_name}'.")
51-
continue
52-
if not (flat_labels != -1).any():
53-
logger.debug(f"Variable '{var_name}' is fully masked, skipping.")
54-
continue
55-
56-
logger.debug(
57-
f"Creating label lookup for variable '{var_name}' with shape {labels.shape} and dims {labels.dims}."
58-
)
59-
60-
coord_arrays = (
61-
np.meshgrid(
62-
*[labels.coords[dim].values for dim in labels.dims], indexing="ij"
63-
)
64-
if len(labels.dims) > 0
65-
else []
66-
)
67-
flat_coords = [arr.flatten() for arr in coord_arrays]
68-
69-
for k, flat in enumerate(flat_labels):
70-
if flat != -1:
71-
var_lookup[int(flat)] = (
72-
var_name,
73-
{dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)},
74-
)
75-
124+
lookup = _lookup(var.labels, var_name, "variable")
125+
var_lookup.update(lookup)
76126
return var_lookup
77127

78128

79-
def _con_lookup(m: Model) -> dict:
129+
def _con_lookup(m: Model) -> dict[int, tuple[str, dict]]:
80130
"""
81131
Build a flat label -> (con_name, coord_dict) lookup for all constraints in m.
82132
@@ -98,36 +148,8 @@ def _con_lookup(m: Model) -> dict:
98148
con_lookup = {}
99149
logger.debug("Building constraint label lookup.")
100150
for con_name, con in m.constraints.items():
101-
labels = con.labels
102-
flat_labels = labels.values.flatten()
103-
104-
if len(flat_labels) == 0:
105-
logger.debug(f"Skipping empty constraint '{con_name}'.")
106-
continue
107-
if not (flat_labels != -1).any():
108-
logger.debug(f"Constraint '{con_name}' is fully masked, skipping.")
109-
continue
110-
111-
logger.debug(
112-
f"Creating label lookup for constraint '{con_name}' with shape {labels.shape} and dims {labels.dims}."
113-
)
114-
115-
coord_arrays = (
116-
np.meshgrid(
117-
*[labels.coords[dim].values for dim in labels.dims], indexing="ij"
118-
)
119-
if len(labels.dims) > 0
120-
else []
121-
)
122-
flat_coords = [arr.flatten() for arr in coord_arrays]
123-
124-
for k, flat in enumerate(flat_labels):
125-
if flat != -1:
126-
con_lookup[int(flat)] = (
127-
con_name,
128-
{dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)},
129-
)
130-
151+
lookup = _lookup(con.labels, con_name, "constraint")
152+
con_lookup.update(lookup)
131153
return con_lookup
132154

133155

@@ -228,44 +250,35 @@ def _add_dual_variables(m: Model, m2: Model) -> dict:
228250

229251
dual_vars = {}
230252
for name, con in m.constraints.items():
231-
sign_vals = con.sign.values.flatten()
232-
233-
if len(sign_vals) == 0:
234-
logger.warning(f"Constraint '{name}' has no sign values, skipping.")
253+
if _skip(con.labels, "constraint", name):
235254
continue
236255

237256
mask = con.labels != -1
238-
if not mask.any():
239-
logger.debug(f"Constraint '{name}' is fully masked, skipping.")
240-
continue
241-
242-
if sign_vals[0] == "=":
243-
lower, upper = -np.inf, np.inf
244-
var_type = "free"
245-
elif sign_vals[0] == "<=":
246-
lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf)
247-
var_type = "non-positive" if primal_is_min else "non-negative"
248-
elif sign_vals[0] == ">=":
249-
lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0)
250-
var_type = "non-negative" if primal_is_min else "non-positive"
251-
else:
252-
logger.warning(
253-
f"Constraint '{name}' has unrecognized sign '{sign_vals[0]}', skipping."
254-
)
255-
continue
257+
sign = con.sign.isel({d: 0 for d in con.sign.dims}).item()
258+
259+
match sign:
260+
case "=":
261+
lower, upper = -np.inf, np.inf
262+
var_type = "free"
263+
case "<=":
264+
lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf)
265+
var_type = "non-positive" if primal_is_min else "non-negative"
266+
case ">=":
267+
lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0)
268+
var_type = "non-negative" if primal_is_min else "non-positive"
269+
case _:
270+
logger.warning(
271+
f"Constraint '{name}' has unrecognized sign '{sign}', skipping."
272+
)
273+
continue
256274

257275
logger.debug(
258276
f"Adding {var_type} dual variable for constraint '{name}' with shape {con.shape} and dims {con.labels.dims}."
259277
)
260-
coords = (
261-
[con.labels.coords[dim] for dim in con.labels.dims]
262-
if con.labels.dims
263-
else None
264-
)
265278
dual_vars[name] = m2.add_variables(
266279
lower=lower,
267280
upper=upper,
268-
coords=coords,
281+
coords=con.labels.coords,
269282
name=name,
270283
mask=mask,
271284
)
@@ -409,37 +422,29 @@ def _add_dual_feasibility_constraints(
409422
# add dual feasibility constraints to m2
410423
logger.debug("Adding dual feasibility constraints to model.")
411424
for var_name, var in m.variables.items():
412-
coords = [
413-
pd.Index(var.labels.coords[dim].values, name=dim) for dim in var.labels.dims
414-
]
415425
mask = var.labels != -1
416426

417427
c_vals = xr.DataArray(
418428
np.vectorize(lambda flat: c_by_label.get(flat, 0.0))(var.labels.values),
419429
coords=var.labels.coords,
420430
)
421431

422-
def rule(
423-
m: Model,
424-
*coord_vals: Any,
425-
vname: str = var_name,
426-
vdims: tuple = var.labels.dims,
427-
) -> LinearExpression | None:
428-
coord_dict = dict(zip(vdims, coord_vals))
432+
def __rule(m: Model, *coord_vals: Any) -> LinearExpression | None:
433+
coord_dict = {
434+
str(dim): val for dim, val in zip(var.labels.dims, coord_vals)
435+
}
429436
flat = var.labels.sel(**coord_dict).item()
430-
if flat == -1:
431-
return None
432-
if flat not in dual_feas_terms[vname]:
437+
if flat == -1 or flat not in (term_dict := dual_feas_terms[var_name]):
433438
return None
434-
_, terms, _ = dual_feas_terms[vname][flat]
439+
_, terms, _ = term_dict[flat]
435440
if not terms:
436441
return None
437442
return sum(
438443
coeff * dual_vars[con_name].at[tuple(con_coords.values())]
439444
for con_name, con_coords, coeff in terms
440445
)
441446

442-
lhs = LinearExpression.from_rule(m2, rule, coords)
447+
lhs = LinearExpression.from_rule(m2, __rule, var.labels.coords)
443448
m2.add_constraints(lhs == c_vals, name=var_name, mask=mask)
444449

445450

linopy/expressions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2384,7 +2384,7 @@ def merge(
23842384
has_quad_expression = any(type(e) is QuadraticExpression for e in exprs)
23852385
has_linear_expression = any(type(e) is LinearExpression for e in exprs)
23862386
if cls is None:
2387-
cls = QuadraticExpression if has_quad_expression else LinearExpression
2387+
cls = QuadraticExpression if has_quad_expression else LinearExpression # type: ignore[assignment]
23882388

23892389
if cls is QuadraticExpression and dim == TERM_DIM and has_linear_expression:
23902390
raise ValueError(

0 commit comments

Comments
 (0)