Skip to content

Commit

Permalink
Add support for $addToSet
Browse files Browse the repository at this point in the history
  • Loading branch information
denis_savitsky committed Feb 2, 2024
1 parent 07a46f9 commit caa32d6
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 9 deletions.
24 changes: 24 additions & 0 deletions oolong-core/src/main/scala/oolong/AstParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser {
val value = getValue(valueExpr)
parseUpdater(updater, FieldUpdateExpr.SetOnInsert(UExpr.Prop(prop), value) :: acc)

case '{type t; ($updater: Updater[Doc]).addToSetAll[`t`, `t`]($selectProp, ($valueExpr: Iterable[`t`]))} =>
val prop = parsePropSelector(selectProp)
val value = getValueOrIterable(valueExpr)
parseUpdater(updater, FieldUpdateExpr.AddToSet(UExpr.Prop(prop), value, multipleValues = true) :: acc)

case '{type t; ($updater: Updater[Doc]).addToSet[`t`, `t`]($selectProp, ($valueExpr: `t` ))} =>
val prop = parsePropSelector(selectProp)
val value = getValueOrIterable(valueExpr)
parseUpdater(updater, FieldUpdateExpr.AddToSet(UExpr.Prop(prop), value, multipleValues = false) :: acc)

case '{ $updater: Updater[Doc] } =>
updater match {
case AsTerm(Ident(name)) if name == paramName =>
Expand Down Expand Up @@ -346,6 +356,20 @@ private[oolong] class DefaultAstParser(using quotes: Quotes) extends AstParser {
}
}

private def getValueOrIterable(expr: Expr[Any]): UExpr =
expr match
case '{ $iter: Iterable[t] } => getIterable(iter)
case base => getValue(base)

def getIterable[T: Type](expr: Expr[Iterable[T]]): UExpr =
expr match {
// AsIterable can ignore lift e.g. in following case: lift(List(List(Random.nextInt()))
case '{ type t; lift($x: Iterable[`t`]) } => UExpr.ScalaCodeIterable(x)
case AsIterable(elems) => UExpr.UIterable(elems.map(getConstant).toList)
case _ =>
report.errorAndAbort("Unexpected expr while parsing AST: " + expr.asTerm.show(using Printer.TreeStructure))
}

private def getValue(expr: Expr[Any]): UExpr =
expr match
case '{ lift($x) } => UExpr.ScalaCode(x)
Expand Down
8 changes: 8 additions & 0 deletions oolong-core/src/main/scala/oolong/UExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@ private[oolong] object UExpr {
case class Prop(path: String) extends UExpr

case class Constant[T](t: T) extends UExpr
case class UIterable[T](t: List[UExpr]) extends UExpr

case class ScalaCode(code: Expr[Any]) extends UExpr

case class ScalaCodeIterable(code: Expr[Iterable[Any]]) extends UExpr

@nowarn("msg=unused explicit parameter") // used in macro
sealed abstract class FieldUpdateExpr(prop: Prop)

object FieldUpdateExpr {

// field update operators
case class Set(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop: Prop)

case class Inc(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop)
Expand All @@ -35,6 +39,10 @@ private[oolong] object UExpr {
case class Rename(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop)

case class SetOnInsert(prop: Prop, expr: UExpr) extends FieldUpdateExpr(prop)

// array update operators
case class AddToSet(prop: Prop, expr: UExpr, multipleValues: Boolean) extends FieldUpdateExpr(prop)

}

}
8 changes: 8 additions & 0 deletions oolong-core/src/main/scala/oolong/dsl/Dsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,12 @@ sealed trait Updater[DocT] {
def setOnInsert[PropT, ValueT](selectProp: DocT => PropT, value: ValueT)(using
PropT =:= ValueT,
): Updater[DocT] = useWithinMacro("setOnInsert")

def addToSet[PropT, ValueT](selectProp: DocT => Iterable[PropT], value: ValueT)(using
PropT =:= ValueT
): Updater[DocT] = useWithinMacro("addToSet")

def addToSetAll[PropT, ValueT](selectProp: DocT => Iterable[PropT], value: Iterable[ValueT])(using
PropT =:= ValueT
): Updater[DocT] = useWithinMacro("addToSet")
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import scala.concurrent.ExecutionContext
import com.dimafeng.testcontainers.ForAllTestContainer
import com.dimafeng.testcontainers.MongoDBContainer
import concurrent.duration.DurationInt
import oolong.bson.BsonDecoder
import oolong.dsl.*
import org.mongodb.scala.MongoClient
import org.mongodb.scala.bson.BsonDocument
Expand Down Expand Up @@ -129,4 +130,42 @@ class OolongMongoUpdateSpec extends AsyncFlatSpec with ForAllTestContainer with
} yield assert(res.wasAcknowledged() && res.getMatchedCount == 1 && res.getModifiedCount == 1)
}

it should "update with $addToSet" in {
for {
res <- collection
.updateOne(
query[TestClass](_.field1 == "0"),
update[TestClass](
_.addToSet(_.field4, 3)
)
)
.head()
upd <- collection
.find(query[TestClass](_.field1 == "0"))
.head()
.map(BsonDecoder[TestClass].fromBson(_).get)
} yield assert(
res.wasAcknowledged() && res.getModifiedCount == 1 && upd.field4.size == 3
)
}

it should "update with $addToSet using $each" in {
for {
res <- collection
.updateOne(
query[TestClass](_.field1 == "0"),
update[TestClass](
_.addToSetAll(_.field4, List(4, 5))
)
)
.head()
upd <- collection
.find(query[TestClass](_.field1 == "0"))
.head()
.map(BsonDecoder[TestClass].fromBson(_).get)
} yield assert(
res.wasAcknowledged() && res.getModifiedCount == 1 && upd.field4.size == 4
)
}

}
49 changes: 41 additions & 8 deletions oolong-mongo/src/main/scala/oolong/mongo/MongoUpdateCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import scala.quoted.Type
import oolong.*
import oolong.UExpr.FieldUpdateExpr
import oolong.bson.meta.QueryMeta
import oolong.mongo.MongoUpdateNode.MongoUpdateOp
import oolong.mongo.MongoUpdateNode as MU
import org.mongodb.scala.bson.BsonArray
import org.mongodb.scala.bson.BsonBoolean
import org.mongodb.scala.bson.BsonDocument
import org.mongodb.scala.bson.BsonDouble
Expand Down Expand Up @@ -44,10 +46,14 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
case FieldUpdateExpr.SetOnInsert(prop, expr) =>
MU.MongoUpdateOp.SetOnInsert(MU.Prop(renames.getOrElse(prop.path, prop.path)), rec(expr))
case FieldUpdateExpr.Unset(prop) => MU.MongoUpdateOp.Unset(MU.Prop(renames.getOrElse(prop.path, prop.path)))
case FieldUpdateExpr.AddToSet(prop, expr, each) =>
MU.MongoUpdateOp.AddToSet(MU.Prop(renames.getOrElse(prop.path, prop.path)), rec(expr), each)
})
case UExpr.ScalaCode(code) => MU.ScalaCode(code)
case UExpr.Constant(t) => MU.Constant(t)
case _ => report.errorAndAbort("Unexpected expr " + pprint(ast))
case UExpr.ScalaCode(code) => MU.ScalaCode(code)
case UExpr.Constant(t) => MU.Constant(t)
case UExpr.UIterable(t) => MU.UIterable(t.map(rec(_)))
case UExpr.ScalaCodeIterable(t) => MU.ScalaCodeIterable(t)
case _ => report.errorAndAbort("Unexpected expr " + pprint(ast))
}

rec(ast, meta)
Expand Down Expand Up @@ -82,7 +88,10 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
)("$rename"),
renderOps(
ops.collect { case s: MU.MongoUpdateOp.SetOnInsert => s }.map(op => render(op.prop) + ": " + render(op.value))
)("$setOnInsert")
)("$setOnInsert"),
renderOps(
ops.collect { case s: MU.MongoUpdateOp.AddToSet => s }.map(renderAddToSet)
)("$addToSet")
).flatten
.mkString("{\n", ",\n", "\n}")

Expand All @@ -100,10 +109,20 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
case '{ ${ x }: t } => RenderUtils.renderCaseClass[t](x)
case _ => "?"

case MU.UIterable(iterable) => iterable.map(render).mkString("[", ",", "]")
case MU.ScalaCodeIterable(_) => "[ ? ]"

case _ => report.errorAndAbort(s"Wrong term: $query")
}

def renderOps(ops: List[String])(op: String) =
private def renderAddToSet(op: MU.MongoUpdateOp.AddToSet)(using Quotes): String =
val renderOfValue = render(op.value)
val finalRenderOfValue =
if op.each then s"""{ "$$each" : $renderOfValue }"""
else renderOfValue
render(op.prop) + ": " + finalRenderOfValue

private def renderOps(ops: List[String])(op: String) =
ops match
case Nil => None
case list => Some(s"\t \"$op\": { " + list.mkString(", ") + " }")
Expand All @@ -112,10 +131,15 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
import quotes.reflect.*

def targetOps(setters: List[MU.MongoUpdateOp]): List[Expr[(String, BsonValue)]] =
setters.map { case op: MU.MongoUpdateOp =>
setters.map { op =>
val key = op.prop.path
val valueExpr = handleValues(op.value)
'{ ${ Expr(key) } -> $valueExpr }
val finalValueExpr = op match
case addToSet: MongoUpdateOp.AddToSet =>
if addToSet.each then '{ BsonDocument("$each" -> $valueExpr) }
else valueExpr
case _ => valueExpr
'{ ${ Expr(key) } -> $finalValueExpr }
}

optRepr match {
Expand All @@ -128,6 +152,7 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
val tMuls = targetOps(ops.collect { case s: MU.MongoUpdateOp.Mul => s })
val tRenames = targetOps(ops.collect { case s: MU.MongoUpdateOp.Rename => s })
val tSetOnInserts = targetOps(ops.collect { case s: MU.MongoUpdateOp.SetOnInsert => s })
val tAddToSets = targetOps(ops.collect { case s: MU.MongoUpdateOp.AddToSet => s })

// format: off
def updaterGroup(groupName: String, updaters: List[Expr[(String, BsonValue)]]): Option[Expr[(String, BsonDocument)]] =
Expand All @@ -147,6 +172,7 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
updaterGroup("$mul", tMuls),
updaterGroup("$rename", tRenames),
updaterGroup("$setOnInsert", tSetOnInserts),
updaterGroup("$addToSet", tAddToSets),
).flatten

'{
Expand Down Expand Up @@ -181,6 +207,13 @@ object MongoUpdateCompiler extends Backend[UExpr, MU, BsonDocument] {
case MU.Constant(b: Boolean) =>
'{ BsonBoolean.apply(${ Expr(b: Boolean) }) }
case MU.ScalaCode(code) => BsonUtils.extractLifted(code)
case _ => report.errorAndAbort(s"Given type is not literal constant")
case MU.UIterable(list) =>
'{
BsonArray.fromIterable(${
Expr.ofList(list.map(handleValues))
})
}
case MU.ScalaCodeIterable(exprList) => BsonUtils.extractLifted(exprList)
case _ => report.errorAndAbort(s"Given type is not literal constant")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ case object MongoUpdateNode {

case class Constant[T](t: T) extends MU

case class UIterable[T](t: List[MU]) extends MU

case class ScalaCode(code: Expr[Any]) extends MU

case class ScalaCodeIterable(code: Expr[Iterable[Any]]) extends MU

sealed abstract class MongoUpdateOp(val prop: Prop, val value: MU) extends MU
object MongoUpdateOp {
case class Set(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value)
Expand All @@ -25,5 +29,7 @@ case object MongoUpdateNode {
case class Mul(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value)
case class Rename(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value)
case class SetOnInsert(override val prop: Prop, override val value: MU) extends MongoUpdateOp(prop, value)

case class AddToSet(override val prop: Prop, override val value: MU, each: Boolean) extends MongoUpdateOp(prop, value)
}
}
50 changes: 49 additions & 1 deletion oolong-mongo/src/test/scala/oolong/mongo/UpdateSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class UpdateSpec extends AnyFunSuite {
dateField: LocalDate,
innerClassField: InnerClass,
optionField: Option[Long],
optionInnerClassField: Option[InnerClass]
optionInnerClassField: Option[InnerClass],
listField: List[Int],
classInnerClassField: List[InnerClass],
nestedListField: List[List[Int]]
)

case class InnerClass(
Expand Down Expand Up @@ -170,6 +173,51 @@ class UpdateSpec extends AnyFunSuite {
)
}

test("$addToSet") {
val q = update[TestClass](_.addToSet(_.listField, 1))
val repr = renderUpdate[TestClass](_.addToSet(_.listField, 1))
test(
q,
repr,
BsonDocument("$addToSet" -> BsonDocument("listField" -> BsonInt32(1)))
)
}

test("$addToSet nested") {
val q = update[TestClass](_.addToSet(_.nestedListField, List(1, 2, 3)))
val repr = renderUpdate[TestClass](_.addToSet(_.nestedListField, List(1, 2, 3)))
test(
q,
repr,
BsonDocument("$addToSet" -> BsonDocument("nestedListField" -> BsonArray(BsonInt32(1), BsonInt32(2), BsonInt32(3))))
)
}

test("$addToSet with $each") {
val q = update[TestClass](_.addToSetAll(_.listField, List(1)))
val repr = renderUpdate[TestClass](_.addToSetAll(_.listField, List(1)))
test(
q,
repr,
BsonDocument("$addToSet" -> BsonDocument("listField" -> BsonDocument("$each" -> BsonArray(BsonInt32(1)))))
)
}

test("$addToSet with $each nested") {
val q = update[TestClass](_.addToSetAll(_.nestedListField, lift(List(List(1, 2, 3)))))
val repr = renderUpdate[TestClass](_.addToSetAll(_.nestedListField, lift(List(List(1, 2, 3)))))
test(
q,
repr,
BsonDocument(
"$addToSet" -> BsonDocument(
"nestedListField" -> BsonDocument("$each" -> BsonArray(BsonArray(BsonInt32(1), BsonInt32(2), BsonInt32(3))))
)
),
ignoreRender = true
)
}

test("several update operators combined") {
val q = update[TestClass](
_.unset(_.dateField)
Expand Down

0 comments on commit caa32d6

Please sign in to comment.