77from __future__ import annotations
88
99import logging
10- from typing import TYPE_CHECKING , Any
10+ from typing import TYPE_CHECKING , Any , Literal
1111
1212import numpy as np
13- import pandas as pd
1413import xarray as xr
1514
1615from linopy .expressions import LinearExpression
2120logger = logging .getLogger (__name__ )
2221
2322
24- def _var_lookup (m : Model ) -> dict :
23+ def _skip (
24+ da : xr .DataArray , component_type : Literal ["variable" , "constraint" ], name : str
25+ ) -> bool :
26+ """
27+ Determine whether to skip processing a variable or constraint based on its labels.
28+
29+ Parameters
30+ ----------
31+ da : xr.DataArray
32+ The labels DataArray of the variable or constraint.
33+ component_type : Literal["variable", "constraint"]
34+ The type of component being checked, used for logging.
35+ name : str
36+ The name of the variable or constraint, used for logging.
37+
38+ Returns
39+ -------
40+ bool
41+ True if the component should be skipped (empty or fully masked), False otherwise.
42+ """
43+ if da .size == 0 :
44+ logger .debug (f"Skipping empty { component_type } '{ name } '." )
45+ return True
46+
47+ if (da == - 1 ).all ():
48+ logger .debug (f"{ component_type } '{ name } ' is fully masked, skipping." )
49+ return True
50+ return False
51+
52+
53+ def _lookup (
54+ labels : xr .DataArray , name : str , component_type : Literal ["variable" , "constraint" ]
55+ ) -> dict [int , tuple [str , dict ]]:
56+ """
57+ Create a lookup dictionary mapping labels to their corresponding names and coordinates.
58+
59+ Parameters
60+ ----------
61+ labels : xr.DataArray
62+ Array of labels.
63+ name : str
64+ Name of the component.
65+ component_type : Literal["variable", "constraint"]
66+ Type of the component.
67+
68+ Returns
69+ -------
70+ dict[int, tuple[str, dict]]
71+ Mapping from flat integer label to (name, coord_dict) tuple.
72+ """
73+ lookup : dict [int , tuple [str , dict ]] = {}
74+
75+ vals = labels .values
76+ if _skip (labels , component_type , name ):
77+ return lookup
78+
79+ logger .debug (
80+ f"Creating label lookup for { component_type } '{ name } ' with shape { labels .shape } and dims { labels .dims } ."
81+ )
82+
83+ if labels .ndim == 0 :
84+ lookup [int (vals .item ())] = (name , {})
85+ return lookup
86+
87+ coord_values = [labels .coords [d ].values for d in labels .dims ]
88+
89+ # Choosing np.ndindex over np.argwhere or da.to_series for memory efficiency on large n-dimensional arrays
90+ for idx in np .ndindex (vals .shape ):
91+ label = int (vals [idx ])
92+ if label == - 1 :
93+ continue
94+ lookup [label ] = (
95+ name ,
96+ {dim : coord_values [i ][idx [i ]] for i , dim in enumerate (labels .dims )},
97+ )
98+
99+ return lookup
100+
101+
102+ def _var_lookup (m : Model ) -> dict [int , tuple [str , dict ]]:
25103 """
26104 Build a flat label -> (var_name, coord_dict) lookup for all variables in m.
27105
@@ -43,40 +121,12 @@ def _var_lookup(m: Model) -> dict:
43121 var_lookup = {}
44122 logger .debug ("Building variable label lookup." )
45123 for var_name , var in m .variables .items ():
46- labels = var .labels
47- flat_labels = labels .values .flatten ()
48-
49- if len (flat_labels ) == 0 :
50- logger .debug (f"Skipping empty variable '{ var_name } '." )
51- continue
52- if not (flat_labels != - 1 ).any ():
53- logger .debug (f"Variable '{ var_name } ' is fully masked, skipping." )
54- continue
55-
56- logger .debug (
57- f"Creating label lookup for variable '{ var_name } ' with shape { labels .shape } and dims { labels .dims } ."
58- )
59-
60- coord_arrays = (
61- np .meshgrid (
62- * [labels .coords [dim ].values for dim in labels .dims ], indexing = "ij"
63- )
64- if len (labels .dims ) > 0
65- else []
66- )
67- flat_coords = [arr .flatten () for arr in coord_arrays ]
68-
69- for k , flat in enumerate (flat_labels ):
70- if flat != - 1 :
71- var_lookup [int (flat )] = (
72- var_name ,
73- {dim : flat_coords [i ][k ] for i , dim in enumerate (labels .dims )},
74- )
75-
124+ lookup = _lookup (var .labels , var_name , "variable" )
125+ var_lookup .update (lookup )
76126 return var_lookup
77127
78128
79- def _con_lookup (m : Model ) -> dict :
129+ def _con_lookup (m : Model ) -> dict [ int , tuple [ str , dict ]] :
80130 """
81131 Build a flat label -> (con_name, coord_dict) lookup for all constraints in m.
82132
@@ -98,36 +148,8 @@ def _con_lookup(m: Model) -> dict:
98148 con_lookup = {}
99149 logger .debug ("Building constraint label lookup." )
100150 for con_name , con in m .constraints .items ():
101- labels = con .labels
102- flat_labels = labels .values .flatten ()
103-
104- if len (flat_labels ) == 0 :
105- logger .debug (f"Skipping empty constraint '{ con_name } '." )
106- continue
107- if not (flat_labels != - 1 ).any ():
108- logger .debug (f"Constraint '{ con_name } ' is fully masked, skipping." )
109- continue
110-
111- logger .debug (
112- f"Creating label lookup for constraint '{ con_name } ' with shape { labels .shape } and dims { labels .dims } ."
113- )
114-
115- coord_arrays = (
116- np .meshgrid (
117- * [labels .coords [dim ].values for dim in labels .dims ], indexing = "ij"
118- )
119- if len (labels .dims ) > 0
120- else []
121- )
122- flat_coords = [arr .flatten () for arr in coord_arrays ]
123-
124- for k , flat in enumerate (flat_labels ):
125- if flat != - 1 :
126- con_lookup [int (flat )] = (
127- con_name ,
128- {dim : flat_coords [i ][k ] for i , dim in enumerate (labels .dims )},
129- )
130-
151+ lookup = _lookup (con .labels , con_name , "constraint" )
152+ con_lookup .update (lookup )
131153 return con_lookup
132154
133155
@@ -228,44 +250,35 @@ def _add_dual_variables(m: Model, m2: Model) -> dict:
228250
229251 dual_vars = {}
230252 for name , con in m .constraints .items ():
231- sign_vals = con .sign .values .flatten ()
232-
233- if len (sign_vals ) == 0 :
234- logger .warning (f"Constraint '{ name } ' has no sign values, skipping." )
253+ if _skip (con .labels , "constraint" , name ):
235254 continue
236255
237256 mask = con .labels != - 1
238- if not mask .any ():
239- logger .debug (f"Constraint '{ name } ' is fully masked, skipping." )
240- continue
241-
242- if sign_vals [0 ] == "=" :
243- lower , upper = - np .inf , np .inf
244- var_type = "free"
245- elif sign_vals [0 ] == "<=" :
246- lower , upper = (- np .inf , 0 ) if primal_is_min else (0 , np .inf )
247- var_type = "non-positive" if primal_is_min else "non-negative"
248- elif sign_vals [0 ] == ">=" :
249- lower , upper = (0 , np .inf ) if primal_is_min else (- np .inf , 0 )
250- var_type = "non-negative" if primal_is_min else "non-positive"
251- else :
252- logger .warning (
253- f"Constraint '{ name } ' has unrecognized sign '{ sign_vals [0 ]} ', skipping."
254- )
255- continue
257+ sign = con .sign .isel ({d : 0 for d in con .sign .dims }).item ()
258+
259+ match sign :
260+ case "=" :
261+ lower , upper = - np .inf , np .inf
262+ var_type = "free"
263+ case "<=" :
264+ lower , upper = (- np .inf , 0 ) if primal_is_min else (0 , np .inf )
265+ var_type = "non-positive" if primal_is_min else "non-negative"
266+ case ">=" :
267+ lower , upper = (0 , np .inf ) if primal_is_min else (- np .inf , 0 )
268+ var_type = "non-negative" if primal_is_min else "non-positive"
269+ case _:
270+ logger .warning (
271+ f"Constraint '{ name } ' has unrecognized sign '{ sign } ', skipping."
272+ )
273+ continue
256274
257275 logger .debug (
258276 f"Adding { var_type } dual variable for constraint '{ name } ' with shape { con .shape } and dims { con .labels .dims } ."
259277 )
260- coords = (
261- [con .labels .coords [dim ] for dim in con .labels .dims ]
262- if con .labels .dims
263- else None
264- )
265278 dual_vars [name ] = m2 .add_variables (
266279 lower = lower ,
267280 upper = upper ,
268- coords = coords ,
281+ coords = con . labels . coords ,
269282 name = name ,
270283 mask = mask ,
271284 )
@@ -409,37 +422,29 @@ def _add_dual_feasibility_constraints(
409422 # add dual feasibility constraints to m2
410423 logger .debug ("Adding dual feasibility constraints to model." )
411424 for var_name , var in m .variables .items ():
412- coords = [
413- pd .Index (var .labels .coords [dim ].values , name = dim ) for dim in var .labels .dims
414- ]
415425 mask = var .labels != - 1
416426
417427 c_vals = xr .DataArray (
418428 np .vectorize (lambda flat : c_by_label .get (flat , 0.0 ))(var .labels .values ),
419429 coords = var .labels .coords ,
420430 )
421431
422- def rule (
423- m : Model ,
424- * coord_vals : Any ,
425- vname : str = var_name ,
426- vdims : tuple = var .labels .dims ,
427- ) -> LinearExpression | None :
428- coord_dict = dict (zip (vdims , coord_vals ))
432+ def __rule (m : Model , * coord_vals : Any ) -> LinearExpression | None :
433+ coord_dict = {
434+ str (dim ): val for dim , val in zip (var .labels .dims , coord_vals )
435+ }
429436 flat = var .labels .sel (** coord_dict ).item ()
430- if flat == - 1 :
431- return None
432- if flat not in dual_feas_terms [vname ]:
437+ if flat == - 1 or flat not in (term_dict := dual_feas_terms [var_name ]):
433438 return None
434- _ , terms , _ = dual_feas_terms [ vname ] [flat ]
439+ _ , terms , _ = term_dict [flat ]
435440 if not terms :
436441 return None
437442 return sum (
438443 coeff * dual_vars [con_name ].at [tuple (con_coords .values ())]
439444 for con_name , con_coords , coeff in terms
440445 )
441446
442- lhs = LinearExpression .from_rule (m2 , rule , coords )
447+ lhs = LinearExpression .from_rule (m2 , __rule , var . labels . coords )
443448 m2 .add_constraints (lhs == c_vals , name = var_name , mask = mask )
444449
445450
0 commit comments