From 4163c1bac6986a93b58cbe74e9e8ae62298f70a9 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sat, 18 May 2024 22:48:01 +0200 Subject: [PATCH 1/7] enh(constraints): Improve infer_schema_polars --- linopy/common.py | 6 ++++-- linopy/constraints.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 4e28fe63..22767bec 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -260,7 +260,7 @@ def check_has_nulls(df: pd.DataFrame, name: str): raise ValueError(f"{name} contains nan's in field(s) {fields}") -def infer_schema_polars(ds: pl.DataFrame) -> dict: +def infer_schema_polars(ds: Dataset, overwrites: dict[str, pl.DataType]) -> dict: """ Infer the schema for a Polars DataFrame based on the data types of its columns. @@ -272,7 +272,9 @@ def infer_schema_polars(ds: pl.DataFrame) -> dict: """ schema = {} for col_name, array in ds.items(): - if np.issubdtype(array.dtype, np.integer): + if col_name in overwrites: + schema[col_name] = overwrites[col_name] + elif np.issubdtype(array.dtype, np.integer): schema[col_name] = pl.Int32 if os.name == "nt" else pl.Int64 elif np.issubdtype(array.dtype, np.floating): schema[col_name] = pl.Float64 diff --git a/linopy/constraints.py b/linopy/constraints.py index 0d7740b3..d96d0004 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -578,8 +578,7 @@ def to_polars(self): check_has_nulls_polars(long, name=f"{self.type} {self.name}") short = ds[[k for k in ds if "_term" not in ds[k].dims]] - schema = infer_schema_polars(short) - schema["sign"] = pl.Enum(["=", "<=", ">="]) + schema = infer_schema_polars(short, overwrites={"sign": pl.Enum(["=", "<=", ">="])}) short = to_polars(short, schema=schema) short = filter_nulls_polars(short) check_has_nulls_polars(short, name=f"{self.type} {self.name}") From 16b86658be1d7efbf899968203f968d5fe16464b Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sat, 18 May 2024 22:51:02 +0200 Subject: [PATCH 2/7] Make polars frames lazy and stream into csv --- linopy/common.py | 8 ++-- linopy/constraints.py | 8 ++-- linopy/expressions.py | 30 +++++++------- linopy/io.py | 94 ++++++++++++++++++++++--------------------- linopy/variables.py | 16 ++++---- 5 files changed, 79 insertions(+), 77 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 22767bec..a573182e 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -303,10 +303,10 @@ def to_polars(ds: Dataset, **kwargs) -> pl.DataFrame: DataFrame constructor. """ data = broadcast(ds)[0] - return pl.DataFrame({k: v.values.reshape(-1) for k, v in data.items()}, **kwargs) + return pl.LazyFrame({k: v.values.reshape(-1) for k, v in data.items()}, **kwargs) -def check_has_nulls_polars(df: pl.DataFrame, name: str = "") -> None: +def check_has_nulls_polars(df: pl.LazyFrame, name: str = "") -> None: """ Checks if the given DataFrame contains any null values and raises a ValueError if it does. @@ -318,7 +318,7 @@ def check_has_nulls_polars(df: pl.DataFrame, name: str = "") -> None: ValueError: If the DataFrame contains null values, a ValueError is raised with a message indicating the name of the constraint and the fields containing null values. """ - has_nulls = df.select(pl.col("*").is_null().any()) + has_nulls = df.select(pl.col("*").is_null().any()).collect() null_columns = [col for col in has_nulls.columns if has_nulls[col][0]] if null_columns: raise ValueError(f"{name} contains nan's in field(s) {null_columns}") @@ -347,7 +347,7 @@ def filter_nulls_polars(df: pl.DataFrame) -> pl.DataFrame: return df.filter(cond) -def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame: +def group_terms_polars(df: pl.LazyFrame) -> pl.LazyFrame: """ Groups terms in a polars DataFrame. diff --git a/linopy/constraints.py b/linopy/constraints.py index d96d0004..a5b78cdd 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -583,12 +583,12 @@ def to_polars(self): short = filter_nulls_polars(short) check_has_nulls_polars(short, name=f"{self.type} {self.name}") - df = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"]) + lf = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"]) # delete subsequent non-null rhs (happens is all vars per label are -1) - is_non_null = df["rhs"].is_not_null() + is_non_null = lf["rhs"].is_not_null() prev_non_is_null = is_non_null.shift(1).fill_null(False) - df = df.filter(is_non_null & ~prev_non_is_null | ~is_non_null) - return df[["labels", "coeffs", "vars", "sign", "rhs"]] + lf = lf.filter(is_non_null & ~prev_non_is_null | ~is_non_null) + return lf[["labels", "coeffs", "vars", "sign", "rhs"]] sel = conwrap(Dataset.sel) diff --git a/linopy/expressions.py b/linopy/expressions.py index 124af6be..08b541ef 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -1268,9 +1268,9 @@ def mask_func(data): check_has_nulls(df, name=self.type) return df - def to_polars(self) -> pl.DataFrame: + def to_polars(self) -> pl.LazyFrame: """ - Convert the expression to a polars DataFrame. + Convert the expression to a polars lazyFrame. The resulting DataFrame represents a long table format of the all non-masked expressions with non-zero coefficients. It contains the @@ -1278,13 +1278,13 @@ def to_polars(self) -> pl.DataFrame: Returns ------- - df : polars.DataFrame + lf : polars.LazyFrame """ - df = to_polars(self.data) - df = filter_nulls_polars(df) - df = group_terms_polars(df) - check_has_nulls_polars(df, name=self.type) - return df + lf = to_polars(self.data) + lf = filter_nulls_polars(lf) + lf = group_terms_polars(lf) + check_has_nulls_polars(lf, name=self.type) + return lf # Wrapped function which would convert variable to dataarray assign = exprwrap(Dataset.assign) @@ -1480,7 +1480,7 @@ def mask_func(data): check_has_nulls(df, name=self.type) return df - def to_polars(self, **kwargs): + def to_polars(self, **kwargs) -> pl.LazyFrame: """ Convert the expression to a polars DataFrame. @@ -1490,17 +1490,17 @@ def to_polars(self, **kwargs): Returns ------- - df : polars.DataFrame + lf : polars.LazyFrame """ vars = self.data.vars.assign_coords( {FACTOR_DIM: ["vars1", "vars2"]} ).to_dataset(FACTOR_DIM) ds = self.data.drop_vars("vars").assign(vars) - df = to_polars(ds, **kwargs) - df = filter_nulls_polars(df) - df = group_terms_polars(df) - check_has_nulls_polars(df, name=self.type) - return df + lf = to_polars(ds, **kwargs) + lf = filter_nulls_polars(lf) + lf = group_terms_polars(lf) + check_has_nulls_polars(lf, name=self.type) + return lf def to_matrix(self): """ diff --git a/linopy/io.py b/linopy/io.py index 779d3636..a4df25c6 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -3,6 +3,7 @@ """ Module containing all import/export functionalities. """ + import logging import shutil import time @@ -12,6 +13,8 @@ import numpy as np import pandas as pd import polars as pl +import pyarrow as pa +import pyarrow.csv import xarray as xr from numpy import ones_like, zeros_like from scipy.sparse import tril, triu @@ -278,20 +281,43 @@ def to_lp_file(m, fn, integer_label): logger.info(f" Writing time: {round(time.time()-start, 2)}s") -def objective_write_linear_terms_polars(f, df): +def write_lazyframe(f, lf): + lf = lf.fill_null("") + + def to_pyarrow_schema(schema): + return pa.schema( + (k, pl.datatypes.py_type_to_arrow_type(pl.datatypes.dtype_to_py_type(v))) + for k, v in lf.schema.items() + ) + + writer = pa.csv.CSVWriter( + f, + to_pyarrow_schema(lf.schema), + write_options=pa.csv.WriteOptions( + include_header=False, delimiter=" ", quoting_style="none" + ), + ) + + def write_batch(batch): + writer.write(batch.to_arrow()) + return pl.DataFrame({"written": np.ones(batch.shape[0], dtype=int)}) + + lf.map_batches( + write_batch, schema={"written": pl.Int64}, streamable=True + ).sum().collect() + + +def objective_write_linear_terms_polars(f, lf): cols = [ pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), pl.col("coeffs").cast(pl.String), pl.lit(" x"), pl.col("vars").cast(pl.String), ] - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + write_lazyframe(f, lf.select(pl.concat_str(cols, ignore_nulls=True))) -def objective_write_quadratic_terms_polars(f, df): +def objective_write_quadratic_terms_polars(f, lf): cols = [ pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), pl.col("coeffs").mul(2).cast(pl.String), @@ -301,10 +327,7 @@ def objective_write_quadratic_terms_polars(f, df): pl.col("vars2").cast(pl.String), ] f.write(b"+ [\n") - df = df.select(pl.concat_str(cols, ignore_nulls=True)) - df.write_csv( - f, separator=" ", null_value="", quote_style="never", include_header=False - ) + write_lazyframe(lf.select(pl.concat_str(cols, ignore_nulls=True))) f.write(b"] / 2\n") @@ -317,13 +340,13 @@ def objective_to_file_polars(m, f, log=False): sense = m.objective.sense f.write(f"{sense}\n\nobj:\n\n".encode("utf-8")) - df = m.objective.to_polars() + lf = m.objective.to_polars() if m.is_linear: - objective_write_linear_terms_polars(f, df) + objective_write_linear_terms_polars(f, lf) elif m.is_quadratic: - lins = df.filter(pl.col("vars1").eq(-1) | pl.col("vars2").eq(-1)) + lins = lf.filter(pl.col("vars1").eq(-1) | pl.col("vars2").eq(-1)) lins = lins.with_columns( pl.when(pl.col("vars1").eq(-1)) .then(pl.col("vars2")) @@ -332,7 +355,7 @@ def objective_to_file_polars(m, f, log=False): ) objective_write_linear_terms_polars(f, lins) - quads = df.filter(pl.col("vars1").ne(-1) & pl.col("vars2").ne(-1)) + quads = lf.filter(pl.col("vars1").ne(-1) & pl.col("vars2").ne(-1)) objective_write_quadratic_terms_polars(f, quads) @@ -353,7 +376,7 @@ def bounds_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.when(pl.col("lower") >= 0).then(pl.lit("+")).otherwise(pl.lit("")), @@ -365,11 +388,7 @@ def bounds_to_file_polars(m, f, log=False): pl.col("upper").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def binaries_to_file_polars(m, f, log=False): @@ -389,18 +408,14 @@ def binaries_to_file_polars(m, f, log=False): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.lit("x"), pl.col("labels").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def integers_to_file_polars(m, f, log=False, integer_label="general"): @@ -420,18 +435,14 @@ def integers_to_file_polars(m, f, log=False, integer_label="general"): ) for name in names: - df = m.variables[name].to_polars() + lf = m.variables[name].to_polars() columns = [ pl.lit("x"), pl.col("labels").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def constraints_to_file_polars(m, f, log=False, lazy=False): @@ -447,14 +458,14 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): colour=TQDM_COLOR, ) - # to make this even faster, we can use polars expression + # to make this even faster, we could create a custom polars expression plugin # https://docs.pola.rs/user-guide/expressions/plugins/#output-data-types for name in names: - df = m.constraints[name].to_polars() + lf = m.constraints[name].to_polars() - # df = df.lazy() + # lf = lf.lazy() # filter out repeated label values - df = df.with_columns( + lf = lf.with_columns( pl.when(pl.col("labels").is_first_distinct()) .then(pl.col("labels")) .otherwise(pl.lit(None)) @@ -474,16 +485,7 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): pl.col("rhs").cast(pl.String), ] - kwargs = dict( - separator=" ", null_value="", quote_style="never", include_header=False - ) - formatted = df.select(pl.concat_str(columns, ignore_nulls=True)) - formatted.write_csv(f, **kwargs) - - # in the future, we could use lazy dataframes when they support appending - # tp existent files - # formatted = df.lazy().select(pl.concat_str(columns, ignore_nulls=True)) - # formatted.sink_csv(f, **kwargs) + write_lazyframe(f, lf.select(pl.concat_str(columns, ignore_nulls=True))) def to_lp_file_polars(m, fn, integer_label="general"): diff --git a/linopy/variables.py b/linopy/variables.py index 8c00e97a..7e9fcabf 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -159,7 +159,6 @@ def __init__(self, data: Dataset, model: Any, name: str): self._model = model def __getitem__(self, selector) -> Union["Variable", "ScalarVariable"]: - keys = selector if isinstance(selector, tuple) else (selector,) if all(map(pd.api.types.is_scalar, keys)): warn( @@ -799,10 +798,10 @@ def to_polars(self) -> pl.DataFrame: ------- pl.DataFrame """ - df = to_polars(self.data) - df = filter_nulls_polars(df) - check_has_nulls_polars(df, name=f"{self.type} {self.name}") - return df + lf = to_polars(self.data) + lf = filter_nulls_polars(lf) + check_has_nulls_polars(lf, name=f"{self.type} {self.name}") + return lf def sum(self, dim=None, **kwargs): """ @@ -935,7 +934,8 @@ def ffill(self, dim, limit=None): self.data.where(self.labels != -1) # .ffill(dim, limit=limit) # breaks with Dataset.ffill, use map instead - .map(DataArray.ffill, dim=dim, limit=limit).fillna(self._fill_value) + .map(DataArray.ffill, dim=dim, limit=limit) + .fillna(self._fill_value) ) data = data.assign(labels=data.labels.astype(int)) return self.__class__(data, self.model, self.name) @@ -962,7 +962,8 @@ def bfill(self, dim, limit=None): self.data.where(~self.isnull()) # .bfill(dim, limit=limit) # breaks with Dataset.bfill, use map instead - .map(DataArray.bfill, dim=dim, limit=limit).fillna(self._fill_value) + .map(DataArray.bfill, dim=dim, limit=limit) + .fillna(self._fill_value) ) data = data.assign(labels=data.labels.astype(int)) return self.__class__(data, self.model, self.name) @@ -1020,7 +1021,6 @@ def __init__(self, obj): self.object = obj def __getitem__(self, keys) -> "ScalarVariable": - keys = keys if isinstance(keys, tuple) else (keys,) object = self.object From cba0269d388604988f432455cfadba376c9a0ef1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 18 May 2024 20:53:56 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- linopy/constraints.py | 4 +++- linopy/variables.py | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index a5b78cdd..7b286a87 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -578,7 +578,9 @@ def to_polars(self): check_has_nulls_polars(long, name=f"{self.type} {self.name}") short = ds[[k for k in ds if "_term" not in ds[k].dims]] - schema = infer_schema_polars(short, overwrites={"sign": pl.Enum(["=", "<=", ">="])}) + schema = infer_schema_polars( + short, overwrites={"sign": pl.Enum(["=", "<=", ">="])} + ) short = to_polars(short, schema=schema) short = filter_nulls_polars(short) check_has_nulls_polars(short, name=f"{self.type} {self.name}") diff --git a/linopy/variables.py b/linopy/variables.py index 7e9fcabf..03afd0ba 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -934,8 +934,7 @@ def ffill(self, dim, limit=None): self.data.where(self.labels != -1) # .ffill(dim, limit=limit) # breaks with Dataset.ffill, use map instead - .map(DataArray.ffill, dim=dim, limit=limit) - .fillna(self._fill_value) + .map(DataArray.ffill, dim=dim, limit=limit).fillna(self._fill_value) ) data = data.assign(labels=data.labels.astype(int)) return self.__class__(data, self.model, self.name) @@ -962,8 +961,7 @@ def bfill(self, dim, limit=None): self.data.where(~self.isnull()) # .bfill(dim, limit=limit) # breaks with Dataset.bfill, use map instead - .map(DataArray.bfill, dim=dim, limit=limit) - .fillna(self._fill_value) + .map(DataArray.bfill, dim=dim, limit=limit).fillna(self._fill_value) ) data = data.assign(labels=data.labels.astype(int)) return self.__class__(data, self.model, self.name) From 71958e923a2793f1fa78d0f7870a1473be7ba681 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sat, 18 May 2024 22:58:27 +0200 Subject: [PATCH 4/7] map_batches is also fine with an empty frame --- linopy/io.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index a4df25c6..20173146 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -300,11 +300,9 @@ def to_pyarrow_schema(schema): def write_batch(batch): writer.write(batch.to_arrow()) - return pl.DataFrame({"written": np.ones(batch.shape[0], dtype=int)}) + return pl.DataFrame() - lf.map_batches( - write_batch, schema={"written": pl.Int64}, streamable=True - ).sum().collect() + lf.map_batches(write_batch, schema={}, streamable=True).collect() def objective_write_linear_terms_polars(f, lf): From 7c9fd2f9dd945c26edbf6eb7ecfd3ef944a5b48d Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sat, 18 May 2024 23:03:38 +0200 Subject: [PATCH 5/7] fix typo --- linopy/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/linopy/io.py b/linopy/io.py index 20173146..855dac1d 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -287,7 +287,7 @@ def write_lazyframe(f, lf): def to_pyarrow_schema(schema): return pa.schema( (k, pl.datatypes.py_type_to_arrow_type(pl.datatypes.dtype_to_py_type(v))) - for k, v in lf.schema.items() + for k, v in schema.items() ) writer = pa.csv.CSVWriter( From db64c01f11f0659c13b0efcc947d75d1c263f3a0 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sun, 19 May 2024 22:43:08 +0200 Subject: [PATCH 6/7] Fix tests --- linopy/constraints.py | 4 ++-- linopy/io.py | 12 ++++++------ test/test_constraint.py | 2 +- test/test_linear_expression.py | 4 ++-- test/test_quadratic_expression.py | 4 ++-- test/test_variable.py | 4 ++-- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/linopy/constraints.py b/linopy/constraints.py index 7b286a87..3985c909 100644 --- a/linopy/constraints.py +++ b/linopy/constraints.py @@ -587,10 +587,10 @@ def to_polars(self): lf = pl.concat([short, long], how="diagonal").sort(["labels", "rhs"]) # delete subsequent non-null rhs (happens is all vars per label are -1) - is_non_null = lf["rhs"].is_not_null() + is_non_null = pl.col("rhs").is_not_null() prev_non_is_null = is_non_null.shift(1).fill_null(False) lf = lf.filter(is_non_null & ~prev_non_is_null | ~is_non_null) - return lf[["labels", "coeffs", "vars", "sign", "rhs"]] + return lf.select(pl.col(["labels", "coeffs", "vars", "sign", "rhs"])) sel = conwrap(Dataset.sel) diff --git a/linopy/io.py b/linopy/io.py index 855dac1d..9582fb32 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -294,7 +294,7 @@ def to_pyarrow_schema(schema): f, to_pyarrow_schema(lf.schema), write_options=pa.csv.WriteOptions( - include_header=False, delimiter=" ", quoting_style="none" + include_header=False, delimiter=";", quoting_style="none" ), ) @@ -325,7 +325,7 @@ def objective_write_quadratic_terms_polars(f, lf): pl.col("vars2").cast(pl.String), ] f.write(b"+ [\n") - write_lazyframe(lf.select(pl.concat_str(cols, ignore_nulls=True))) + write_lazyframe(f, lf.select(pl.concat_str(cols, ignore_nulls=True))) f.write(b"] / 2\n") @@ -471,14 +471,14 @@ def constraints_to_file_polars(m, f, log=False, lazy=False): ) columns = [ - pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")).alias("c"), + pl.when(pl.col("labels").is_not_null()).then(pl.lit("c")), pl.col("labels").cast(pl.String), - pl.when(pl.col("labels").is_not_null()).then(pl.lit(":\n")).alias(":"), + pl.when(pl.col("labels").is_not_null()).then(pl.lit(": ")), pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")), pl.col("coeffs").cast(pl.String), - pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")).alias("x"), + pl.when(pl.col("vars").is_not_null()).then(pl.lit(" x")), pl.col("vars").cast(pl.String), - "sign", + pl.col("sign"), pl.lit(" "), pl.col("rhs").cast(pl.String), ] diff --git a/test/test_constraint.py b/test/test_constraint.py index 7866fc94..93b12bac 100644 --- a/test/test_constraint.py +++ b/test/test_constraint.py @@ -335,7 +335,7 @@ def test_constraint_flat(c): def test_constraint_to_polars(c): - assert isinstance(c.to_polars(), pl.DataFrame) + assert isinstance(c.to_polars(), pl.LazyFrame) def test_constraint_assignment_with_anonymous_constraints(m, x, y): diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 74c31fe7..07c519be 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -529,8 +529,8 @@ def test_linear_expression_to_polars(v): coeff = np.arange(1, 21) # use non-zero coefficients expr = coeff * v df = expr.to_polars() - assert isinstance(df, pl.DataFrame) - assert (df["coeffs"].to_numpy() == coeff).all() + assert isinstance(df, pl.LazyFrame) + assert (df.collect()["coeffs"].to_numpy() == coeff).all() def test_linear_expression_where(v): diff --git a/test/test_quadratic_expression.py b/test/test_quadratic_expression.py index 2d70e621..c46907df 100644 --- a/test/test_quadratic_expression.py +++ b/test/test_quadratic_expression.py @@ -221,10 +221,10 @@ def test_quadratic_expression_flat(x, y): def test_linear_expression_to_polars(x, y): expr = x * y + x + 5 df = expr.to_polars() - assert isinstance(df, pl.DataFrame) + assert isinstance(df, pl.LazyFrame) assert "vars1" in df.columns assert "vars2" in df.columns - assert len(df) == expr.nterm * 2 + assert len(df.collect()) == expr.nterm * 2 def test_quadratic_expression_to_matrix(model, x, y): diff --git a/test/test_variable.py b/test/test_variable.py index df6fee28..b461ac2b 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -260,8 +260,8 @@ def test_variable_flat(x): def test_variable_polars(x): result = x.to_polars() - assert isinstance(result, pl.DataFrame) - assert len(result) == x.size + assert isinstance(result, pl.LazyFrame) + assert len(result.collect()) == x.size def test_variable_sanitize(x): From d82e97e3f9156bd0c431a417a518acf90ce2e429 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Sun, 19 May 2024 22:46:35 +0200 Subject: [PATCH 7/7] Add pyarrow dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 116a62d8..7a7d00af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "xarray>=2023.9.0", "dask>=0.18.0", "polars", + "pyarrow", "tqdm", "deprecation", ]