Skip to content

Commit b6acb42

Browse files
CopilotQuafadasautofix-ci[bot]
authored
Add zeroWhere! fused masked zeroing primitive with ComparisonOp (#91)
* Initial plan * Add zeroWhere! fused masked zeroing operator with ComparisonOp enum Agent-Logs-Url: https://github.com/Quafadas/vecxt/sessions/2050e7c1-327a-4bfb-83c2-b678afa73246 Co-authored-by: Quafadas <24899792+Quafadas@users.noreply.github.com> * Remove large array JVM test from CI (correctness only) Agent-Logs-Url: https://github.com/Quafadas/vecxt/sessions/ab12a61a-1a73-407b-b445-42d377914d57 Co-authored-by: Quafadas <24899792+Quafadas@users.noreply.github.com> * [autofix.ci] apply automated fixes --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Quafadas <24899792+Quafadas@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent d827c11 commit b6acb42

File tree

9 files changed

+442
-0
lines changed

9 files changed

+442
-0
lines changed

vecxt/src-js/doublearrays.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,35 @@ object doublearrays:
767767

768768
def maxElement: Double = vec.max
769769
// val t = js.Math.max( vec.toArray: _* )
770+
771+
inline def `zeroWhere!`(
772+
other: Array[Double],
773+
threshold: Double,
774+
inline op: ComparisonOp
775+
): Unit =
776+
assert(vec.length == other.length)
777+
var i = 0
778+
while i < vec.length do
779+
val hit = inline op match
780+
case ComparisonOp.LE => other(i) <= threshold
781+
case ComparisonOp.LT => other(i) < threshold
782+
case ComparisonOp.GE => other(i) >= threshold
783+
case ComparisonOp.GT => other(i) > threshold
784+
case ComparisonOp.EQ => other(i) == threshold
785+
case ComparisonOp.NE => other(i) != threshold
786+
if hit then vec(i) = 0.0
787+
end if
788+
i += 1
789+
end while
790+
end `zeroWhere!`
791+
792+
inline def zeroWhere(
793+
other: Array[Double],
794+
threshold: Double,
795+
inline op: ComparisonOp
796+
): Array[Double] =
797+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
798+
770799
end extension
771800

772801
end doublearrays

vecxt/src-js/floatarrays.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,34 @@ object floatarrays:
835835
(cv / (vec.length - 1)).toFloat
836836
end covariance
837837

838+
inline def `zeroWhere!`(
839+
other: Array[Float],
840+
threshold: Float,
841+
inline op: ComparisonOp
842+
): Unit =
843+
assert(vec.length == other.length)
844+
var i = 0
845+
while i < vec.length do
846+
val hit = inline op match
847+
case ComparisonOp.LE => other(i) <= threshold
848+
case ComparisonOp.LT => other(i) < threshold
849+
case ComparisonOp.GE => other(i) >= threshold
850+
case ComparisonOp.GT => other(i) > threshold
851+
case ComparisonOp.EQ => other(i) == threshold
852+
case ComparisonOp.NE => other(i) != threshold
853+
if hit then vec(i) = 0.0f
854+
end if
855+
i += 1
856+
end while
857+
end `zeroWhere!`
858+
859+
inline def zeroWhere(
860+
other: Array[Float],
861+
threshold: Float,
862+
inline op: ComparisonOp
863+
): Array[Float] =
864+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
865+
838866
end extension
839867

840868
extension (vec: Array[Array[Double]])

vecxt/src-jvm/doublearrays.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,52 @@ object doublearrays:
12391239

12401240
// def max: Double =
12411241
// vec(blas.idamax(vec.length, vec, 1)) // No JS version
1242+
1243+
inline def `zeroWhere!`(
1244+
other: Array[Double],
1245+
threshold: Double,
1246+
inline op: ComparisonOp
1247+
): Unit =
1248+
assert(vec.length == other.length)
1249+
val zero = DoubleVector.zero(spd)
1250+
val thresh = DoubleVector.broadcast(spd, threshold)
1251+
var i = 0
1252+
1253+
while i < spd.loopBound(vec.length) do
1254+
val values = DoubleVector.fromArray(spd, vec, i)
1255+
val cmp = DoubleVector.fromArray(spd, other, i)
1256+
val mask = inline op match
1257+
case ComparisonOp.LE => cmp.compare(VectorOperators.LE, thresh)
1258+
case ComparisonOp.LT => cmp.compare(VectorOperators.LT, thresh)
1259+
case ComparisonOp.GE => cmp.compare(VectorOperators.GE, thresh)
1260+
case ComparisonOp.GT => cmp.compare(VectorOperators.GT, thresh)
1261+
case ComparisonOp.EQ => cmp.compare(VectorOperators.EQ, thresh)
1262+
case ComparisonOp.NE => cmp.compare(VectorOperators.NE, thresh)
1263+
values.blend(zero, mask).intoArray(vec, i)
1264+
i += spdl
1265+
end while
1266+
1267+
while i < vec.length do
1268+
val hit = inline op match
1269+
case ComparisonOp.LE => other(i) <= threshold
1270+
case ComparisonOp.LT => other(i) < threshold
1271+
case ComparisonOp.GE => other(i) >= threshold
1272+
case ComparisonOp.GT => other(i) > threshold
1273+
case ComparisonOp.EQ => other(i) == threshold
1274+
case ComparisonOp.NE => other(i) != threshold
1275+
if hit then vec(i) = 0.0
1276+
end if
1277+
i += 1
1278+
end while
1279+
end `zeroWhere!`
1280+
1281+
inline def zeroWhere(
1282+
other: Array[Double],
1283+
threshold: Double,
1284+
inline op: ComparisonOp
1285+
): Array[Double] =
1286+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
1287+
12421288
end extension
12431289

12441290
end doublearrays

vecxt/src-jvm/floatarrays.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -938,5 +938,50 @@ object floatarrays:
938938
Matrix(out, (n, m))(using BoundsCheck.DoBoundsCheck.no)
939939
end outer
940940

941+
inline def `zeroWhere!`(
942+
other: Array[Float],
943+
threshold: Float,
944+
inline op: ComparisonOp
945+
): Unit =
946+
assert(vec.length == other.length)
947+
val zero = FloatVector.zero(spf)
948+
val thresh = FloatVector.broadcast(spf, threshold)
949+
var i = 0
950+
951+
while i < spf.loopBound(vec.length) do
952+
val values = FloatVector.fromArray(spf, vec, i)
953+
val cmp = FloatVector.fromArray(spf, other, i)
954+
val mask = inline op match
955+
case ComparisonOp.LE => cmp.compare(VectorOperators.LE, thresh)
956+
case ComparisonOp.LT => cmp.compare(VectorOperators.LT, thresh)
957+
case ComparisonOp.GE => cmp.compare(VectorOperators.GE, thresh)
958+
case ComparisonOp.GT => cmp.compare(VectorOperators.GT, thresh)
959+
case ComparisonOp.EQ => cmp.compare(VectorOperators.EQ, thresh)
960+
case ComparisonOp.NE => cmp.compare(VectorOperators.NE, thresh)
961+
values.blend(zero, mask).intoArray(vec, i)
962+
i += spfl
963+
end while
964+
965+
while i < vec.length do
966+
val hit = inline op match
967+
case ComparisonOp.LE => other(i) <= threshold
968+
case ComparisonOp.LT => other(i) < threshold
969+
case ComparisonOp.GE => other(i) >= threshold
970+
case ComparisonOp.GT => other(i) > threshold
971+
case ComparisonOp.EQ => other(i) == threshold
972+
case ComparisonOp.NE => other(i) != threshold
973+
if hit then vec(i) = 0.0f
974+
end if
975+
i += 1
976+
end while
977+
end `zeroWhere!`
978+
979+
inline def zeroWhere(
980+
other: Array[Float],
981+
threshold: Float,
982+
inline op: ComparisonOp
983+
): Array[Float] =
984+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
985+
941986
end extension
942987
end floatarrays

vecxt/src-native/doublearrays.scala

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,35 @@ object doublearrays:
751751
end covariance
752752

753753
// def max: Double = vec(blas.cblas_idamax(vec.length, vec.at(0), 1)) // No JS version
754+
755+
inline def `zeroWhere!`(
756+
other: Array[Double],
757+
threshold: Double,
758+
inline op: ComparisonOp
759+
): Unit =
760+
assert(vec.length == other.length)
761+
var i = 0
762+
while i < vec.length do
763+
val hit = inline op match
764+
case ComparisonOp.LE => other(i) <= threshold
765+
case ComparisonOp.LT => other(i) < threshold
766+
case ComparisonOp.GE => other(i) >= threshold
767+
case ComparisonOp.GT => other(i) > threshold
768+
case ComparisonOp.EQ => other(i) == threshold
769+
case ComparisonOp.NE => other(i) != threshold
770+
if hit then vec(i) = 0.0
771+
end if
772+
i += 1
773+
end while
774+
end `zeroWhere!`
775+
776+
inline def zeroWhere(
777+
other: Array[Double],
778+
threshold: Double,
779+
inline op: ComparisonOp
780+
): Array[Double] =
781+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
782+
754783
end extension
755784

756785
end doublearrays

vecxt/src-native/floatarrays.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,34 @@ object floatarrays:
818818
(cv / (vec.length - 1)).toFloat
819819
end covariance
820820

821+
inline def `zeroWhere!`(
822+
other: Array[Float],
823+
threshold: Float,
824+
inline op: ComparisonOp
825+
): Unit =
826+
assert(vec.length == other.length)
827+
var i = 0
828+
while i < vec.length do
829+
val hit = inline op match
830+
case ComparisonOp.LE => other(i) <= threshold
831+
case ComparisonOp.LT => other(i) < threshold
832+
case ComparisonOp.GE => other(i) >= threshold
833+
case ComparisonOp.GT => other(i) > threshold
834+
case ComparisonOp.EQ => other(i) == threshold
835+
case ComparisonOp.NE => other(i) != threshold
836+
if hit then vec(i) = 0.0f
837+
end if
838+
i += 1
839+
end while
840+
end `zeroWhere!`
841+
842+
inline def zeroWhere(
843+
other: Array[Float],
844+
threshold: Float,
845+
inline op: ComparisonOp
846+
): Array[Float] =
847+
vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
848+
821849
end extension
822850

823851
extension (vec: Array[Array[Double]])

vecxt/src/ComparisonOp.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package vecxt
2+
3+
enum ComparisonOp:
4+
case LT, LE, GT, GE, EQ, NE
5+
end ComparisonOp

vecxt/src/all.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ object all:
1818
export vecxt.IntArraysX.*
1919

2020
export vecxt.VarianceMode
21+
export vecxt.ComparisonOp
2122

2223
// matricies
2324
export vecxt.OneAndZero.given_OneAndZero_Boolean

0 commit comments

Comments
 (0)