@@ -1239,6 +1239,52 @@ object doublearrays:
12391239
12401240 // def max: Double =
12411241 // vec(blas.idamax(vec.length, vec, 1)) // No JS version
1242+
1243+ inline def `zeroWhere!` (
1244+ other : Array [Double ],
1245+ threshold : Double ,
1246+ inline op : ComparisonOp
1247+ ): Unit =
1248+ assert(vec.length == other.length)
1249+ val zero = DoubleVector .zero(spd)
1250+ val thresh = DoubleVector .broadcast(spd, threshold)
1251+ var i = 0
1252+
1253+ while i < spd.loopBound(vec.length) do
1254+ val values = DoubleVector .fromArray(spd, vec, i)
1255+ val cmp = DoubleVector .fromArray(spd, other, i)
1256+ val mask = inline op match
1257+ case ComparisonOp .LE => cmp.compare(VectorOperators .LE , thresh)
1258+ case ComparisonOp .LT => cmp.compare(VectorOperators .LT , thresh)
1259+ case ComparisonOp .GE => cmp.compare(VectorOperators .GE , thresh)
1260+ case ComparisonOp .GT => cmp.compare(VectorOperators .GT , thresh)
1261+ case ComparisonOp .EQ => cmp.compare(VectorOperators .EQ , thresh)
1262+ case ComparisonOp .NE => cmp.compare(VectorOperators .NE , thresh)
1263+ values.blend(zero, mask).intoArray(vec, i)
1264+ i += spdl
1265+ end while
1266+
1267+ while i < vec.length do
1268+ val hit = inline op match
1269+ case ComparisonOp .LE => other(i) <= threshold
1270+ case ComparisonOp .LT => other(i) < threshold
1271+ case ComparisonOp .GE => other(i) >= threshold
1272+ case ComparisonOp .GT => other(i) > threshold
1273+ case ComparisonOp .EQ => other(i) == threshold
1274+ case ComparisonOp .NE => other(i) != threshold
1275+ if hit then vec(i) = 0.0
1276+ end if
1277+ i += 1
1278+ end while
1279+ end `zeroWhere!`
1280+
1281+ inline def zeroWhere (
1282+ other : Array [Double ],
1283+ threshold : Double ,
1284+ inline op : ComparisonOp
1285+ ): Array [Double ] =
1286+ vec.clone().tap(_.`zeroWhere!`(other, threshold, op))
1287+
12421288 end extension
12431289
12441290end doublearrays
0 commit comments