diff --git a/CHANGELOG.md b/CHANGELOG.md index 75adedf36..fbd0a4f44 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added - Added basic type stubs to help with IDE autocompletion and type checking. ### Fixed +- Implemented all binary operations between MatrixExpr and GenExpr - Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr. ### Changed - Speed up MatrixVariable.sum(axis=None) via quicksum diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index 2fc56f5cb..7806686db 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -146,7 +146,7 @@ def buildGenExprObj(expr): GenExprs = np.empty(expr.shape, dtype=object) for idx in np.ndindex(expr.shape): GenExprs[idx] = buildGenExprObj(expr[idx]) - return GenExprs + return GenExprs.view(MatrixExpr) else: assert isinstance(expr, GenExpr) @@ -223,6 +223,9 @@ cdef class Expr: return self def __mul__(self, other): + if isinstance(other, MatrixExpr): + return other * self + if _is_number(other): f = float(other) return Expr({v:f*c for v,c in self.terms.items()}) @@ -420,6 +423,9 @@ cdef class GenExpr: return UnaryExpr(Operator.fabs, self) def __add__(self, other): + if isinstance(other, MatrixExpr): + return other + self + left = buildGenExprObj(self) right = buildGenExprObj(other) ans = SumExpr() @@ -475,6 +481,9 @@ cdef class GenExpr: # return self def __mul__(self, other): + if isinstance(other, MatrixExpr): + return other * self + left = buildGenExprObj(self) right = buildGenExprObj(other) ans = ProdExpr() @@ -537,7 +546,7 @@ cdef class GenExpr: def __truediv__(self,other): divisor = buildGenExprObj(other) # we can't divide by 0 - if divisor.getOp() == Operator.const and divisor.number == 0.0: + if isinstance(divisor, GenExpr) and divisor.getOp() == Operator.const and divisor.number == 0.0: raise ZeroDivisionError("cannot divide by 0") return self * divisor**(-1) diff --git a/tests/test_matrix_variable.py b/tests/test_matrix_variable.py index 04f9d0fcc..5f0ee0ca0 100644 --- a/tests/test_matrix_variable.py +++ b/tests/test_matrix_variable.py @@ -1,3 +1,10 @@ +import operator +import pdb +import pprint +import pytest +from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt +from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons +from pyscipopt.scip import GenExpr from time import time import numpy as np @@ -209,7 +216,7 @@ def test_matrix_sum_argument(): assert (m.getVal(x) == np.full((2, 3), 4)).all().all() assert (m.getVal(y) == np.full((2, 4), 3)).all().all() - +@pytest.mark.skip(reason="Performance test") def test_sum_performance(): n = 1000 model = Model() @@ -442,6 +449,25 @@ def test_matrix_cons_indicator(): assert m.getVal(z) == 1 +_binop_model = Model() + +def var(): + return _binop_model.addVar() + +def genexpr(): + return _binop_model.addVar() ** 0.6 + +def matvar(): + return _binop_model.addMatrixVar((1,)) + +@pytest.mark.parametrize("right", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"]) +@pytest.mark.parametrize("left", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"]) +@pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul, operator.truediv]) +def test_binop(op, left, right): + res = op(left, right) + assert isinstance(res, (Expr, GenExpr, MatrixExpr)) + + def test_matrix_matmul_return_type(): # test #1058, require returning type is MatrixExpr not MatrixVariable m = Model()