From 9fa7399de209c98284422b4e8562594f6027e26d Mon Sep 17 00:00:00 2001 From: Jarrett Tierney Date: Mon, 4 Oct 2021 11:32:37 -0700 Subject: [PATCH] AWS DocumentDB support through OP_MSG and handling different OK type --- mongodb/vibe/db/mongo/connection.d | 188 ++++++++++++++++++++--------- mongodb/vibe/db/mongo/cursor.d | 26 ++-- 2 files changed, 148 insertions(+), 66 deletions(-) diff --git a/mongodb/vibe/db/mongo/connection.d b/mongodb/vibe/db/mongo/connection.d index 27aaab3566..f17e35218f 100644 --- a/mongodb/vibe/db/mongo/connection.d +++ b/mongodb/vibe/db/mongo/connection.d @@ -17,7 +17,7 @@ import vibe.db.mongo.flags; import vibe.inet.webform; import vibe.stream.tls; -import std.algorithm : map, splitter; +import std.algorithm : map, splitter, canFind; import std.array; import std.conv; import std.digest.md; @@ -208,6 +208,7 @@ final class MongoConnection { "os": Bson(["type": Bson(os.os.to!string), "architecture": Bson(hostArchitecture)]), "platform": Bson(platform) ]); + string cn = (m_settings.database == string.init ? "admin" : m_settings.database) ~ ".$cmd"; if (m_settings.appName.length) { enforce!MongoAuthException(m_settings.appName.length <= 128, @@ -215,13 +216,14 @@ final class MongoConnection { handshake["client"]["application"] = Bson(["name": Bson(m_settings.appName)]); } - query!Bson("$external.$cmd", QueryFlags.none, 0, -1, handshake, Bson(null), + query!Bson(cn, QueryFlags.none, 0, -1, handshake, Bson(null), (cursor, flags, first_doc, num_docs) { enforce!MongoDriverException(!(flags & ReplyFlags.QueryFailure) && num_docs == 1, "Authentication handshake failed."); }, (idx, ref doc) { - enforce!MongoAuthException(doc["ok"].get!double == 1.0, "Authentication failed."); + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + enforce!MongoAuthException(flag == 1.0, "Authentication failed."); m_description = deserializeBson!ServerDescription(doc); }); @@ -306,26 +308,50 @@ final class MongoConnection { { scope(failure) disconnect(); foreach (d; documents) if (d["_id"].isNull()) d["_id"] = Bson(BsonObjectID.generate()); - send(OpCode.Insert, -1, cast(int)flags, collection_name, documents); + string collection = collection_name.canFind(".") ? collection_name.split(".")[1] : collection_name; + Bson operation = Bson([ + "insert": Bson(collection), + "$db": Bson(m_settings.database == string.init ? "admin" : m_settings.database), + "documents": Bson(documents) + ]); + auto id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation); + recvReply!Bson(id, (long cursor, ReplyFlags flags, int first, int num) { + + }, (size_t idx, ref Bson doc) {}); if (m_settings.safe) checkForError(collection_name); } - void query(T)(string collection_name, QueryFlags flags, int nskip, int nret, Bson query, Bson returnFieldSelector, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc) + void query(T)(string collection_name, QueryFlags flags, int nskip, int nret, Bson query, Bson returnFieldSelector, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc, bool isAggregate = false) { scope(failure) disconnect(); flags |= m_settings.defQueryFlags; int id; - if (returnFieldSelector.isNull) - id = send(OpCode.Query, -1, cast(int)flags, collection_name, nskip, nret, query); - else - id = send(OpCode.Query, -1, cast(int)flags, collection_name, nskip, nret, query, returnFieldSelector); - recvReply!T(id, on_msg, on_doc); + Bson op = query; + op["$db"] = Bson(m_settings.database == string.init ? "admin" : m_settings.database); + + if (!returnFieldSelector.isNull) + op["projection"] = returnFieldSelector; + + if (flags && QueryFlags.tailableCursor) + op["tailable"] = Bson(true); + if (flags && QueryFlags.awaitData) + op["awaitData"] = Bson(true); + + Bson operation = op; + id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation); + recvReply!T(id, on_msg, on_doc, isAggregate); } void getMore(T)(string collection_name, int nret, long cursor_id, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc) { scope(failure) disconnect(); - auto id = send(OpCode.GetMore, -1, cast(int)0, collection_name, nret, cursor_id); + string collection = collection_name.canFind(".") ? collection_name.split(".")[1] : collection_name; + Bson operation = Bson([ + "getMore": Bson(cursor_id), + "$db": Bson(m_settings.database == string.init ? "admin" : m_settings.database), + "collection": Bson(collection), + ]); + auto id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation); recvReply!T(id, on_msg, on_doc); } @@ -418,7 +444,8 @@ final class MongoConnection { Bson result; void on_doc(size_t idx, ref Bson doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("listDatabases failed."); result = doc["databases"]; @@ -429,7 +456,7 @@ final class MongoConnection { return result.byValue.map!toInfo; } - private int recvReply(T)(int reqid, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc) + private int recvReply(T)(int reqid, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc, bool isAggregate = false) { import std.traits; @@ -440,47 +467,86 @@ final class MongoConnection { int opcode = recvInt(); enforce(respto == reqid, "Reply is not for the expected message on a sequential connection!"); - enforce(opcode == OpCode.Reply, "Got a non-'Reply' reply!"); - - auto flags = cast(ReplyFlags)recvInt(); - long cursor = recvLong(); - int start = recvInt(); - int numret = recvInt(); - - scope (exit) { - if (m_bytesRead - bytes_read < msglen) { - logWarn("MongoDB reply was longer than expected, skipping the rest: %d vs. %d", msglen, m_bytesRead - bytes_read); - ubyte[] dst = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)]; - recv(dst); - } else if (m_bytesRead - bytes_read > msglen) { - logWarn("MongoDB reply was shorter than expected. Dropping connection."); - disconnect(); - throw new MongoDriverException("MongoDB reply was too short for data."); + enforce(opcode == OpCode.Reply || opcode == OpCode.Msg, "Got a non-'Reply' reply!"); + + if (opcode == OpCode.Reply) { + auto flags = cast(ReplyFlags)recvInt(); + long cursor = recvLong(); + int start = recvInt(); + int numret = recvInt(); + + scope (exit) { + if (m_bytesRead - bytes_read < msglen) { + logWarn("MongoDB reply was longer than expected, skipping the rest: %d vs. %d", msglen, m_bytesRead - bytes_read); + ubyte[] dst = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)]; + recv(dst); + } else if (m_bytesRead - bytes_read > msglen) { + logWarn("MongoDB reply was shorter than expected. Dropping connection."); + disconnect(); + throw new MongoDriverException("MongoDB reply was too short for data."); + } } - } - on_msg(cursor, flags, start, numret); - static if (hasIndirections!T || is(T == Bson)) - auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)]; - foreach (i; 0 .. cast(size_t)numret) { - // TODO: directly deserialize from the wire - static if (!hasIndirections!T && !is(T == Bson)) { - ubyte[256] buf = void; - ubyte[] bufsl = buf; - auto bson = () @trusted { return recvBson(bufsl); } (); - } else { - auto bson = () @trusted { return recvBson(buf); } (); - } + on_msg(cursor, flags, start, numret); + static if (hasIndirections!T || is(T == Bson)) + auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)]; + foreach (i; 0 .. cast(size_t)numret) { + // TODO: directly deserialize from the wire + static if (!hasIndirections!T && !is(T == Bson)) { + ubyte[256] buf = void; + ubyte[] bufsl = buf; + auto bson = () @trusted { return recvBson(bufsl); } (); + } else { + auto bson = () @trusted { return recvBson(buf); } (); + } - // logDebugV("Received mongo response on %s:%s: %s", reqid, i, bson); + static if (is(T == Bson)) on_doc(i, bson); + else { + T doc = deserializeBson!T(bson); + on_doc(i, doc); + } + } - static if (is(T == Bson)) on_doc(i, bson); - else { - T doc = deserializeBson!T(bson); - on_doc(i, doc); + return resid; + } else { + auto flags = cast(ReplyFlags)recvInt(); + ubyte kind = recvByte(); + if (kind == 0) { + on_msg(0, flags, 0, 1); + static if (hasIndirections!T || is(T == Bson)) + auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)]; + static if (!hasIndirections!T && !is(T == Bson)) { + ubyte[256] buf = void; + ubyte[] bufsl = buf; + auto bson = () @trusted { return recvBson(bufsl); } (); + } else { + auto bson = () @trusted { return recvBson(buf); } (); + } + + if (!bson["cursor"].isNull && !isAggregate) { + Bson cursor = bson["cursor"]; + on_msg(cursor["id"].get!long, flags, 0, cast(int)cursor["firstBatch"].length); + auto index = 0; + foreach (element; cursor["firstBatch"].byValue()) { + static if (is(T == Bson)) on_doc(index, element); + else { + T doc = deserializeBson!T(element); + on_doc(index, doc); + } + index++; + } + } else { + on_msg(0, flags, 0, 1); + static if (is(T == Bson)) on_doc(0, bson); + else { + T doc = deserializeBson!T(bson); + on_doc(0, doc); + } + } + } else { + throw new MongoDriverException("Kind in reply was 1?? Why?"); } } - return resid; } @@ -507,6 +573,7 @@ final class MongoConnection { { import std.traits; static if (is(T == int)) sendBytes(toBsonData(value)); + else static if (is(T == ubyte)) sendBytes(toBsonData(value)); else static if (is(T == long)) sendBytes(toBsonData(value)); else static if (is(T == Bson)) sendBytes(value.data); else static if (is(T == string)) { @@ -520,6 +587,7 @@ final class MongoConnection { private void sendBytes(in ubyte[] data){ m_outRange.put(data); } + private byte recvByte() { ubyte[1] ret; recv(ret); return fromBsonData!byte(ret); } private int recvInt() { ubyte[int.sizeof] ret; recv(ret); return fromBsonData!int(ret); } private long recvLong() { ubyte[long.sizeof] ret; recv(ret); return fromBsonData!long(ret); } private Bson recvBson(ref ubyte[] buf) @@ -567,13 +635,15 @@ final class MongoConnection { cmd["user"] = Bson(m_settings.username); } - query!Bson("$external.$cmd", QueryFlags.None, 0, -1, cmd, Bson(null), + string cn = (m_settings.database == string.init ? "admin" : m_settings.database) ~ ".$cmd"; + query!Bson(cn, QueryFlags.None, 0, -1, cmd, Bson(null), (cursor, flags, first_doc, num_docs) { if ((flags & ReplyFlags.QueryFailure) || num_docs != 1) throw new MongoDriverException("Calling authenticate failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("Authentication failed."); } ); @@ -592,7 +662,8 @@ final class MongoConnection { throw new MongoDriverException("Calling getNonce failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoDriverException("getNonce failed."); nonce = doc["nonce"].get!string; key = toLower(toHexString(md5Of(nonce ~ m_settings.username ~ m_settings.digest)).idup); @@ -611,7 +682,8 @@ final class MongoConnection { throw new MongoDriverException("Calling authenticate failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("Authentication failed."); } ); @@ -637,7 +709,8 @@ final class MongoConnection { throw new MongoDriverException("SASL start failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("Authentication failed."); response = cast(string)doc["payload"].get!BsonBinData().rawData; conversationId = doc["conversationId"]; @@ -653,7 +726,8 @@ final class MongoConnection { throw new MongoDriverException("SASL continue failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("Authentication failed."); response = cast(string)doc["payload"].get!BsonBinData().rawData; }); @@ -669,7 +743,8 @@ final class MongoConnection { throw new MongoDriverException("SASL finish failed."); }, (idx, ref doc) { - if (doc["ok"].get!double != 1.0) + auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double; + if (flag != 1.0) throw new MongoAuthException("Authentication failed."); }); } @@ -677,7 +752,7 @@ final class MongoConnection { private enum OpCode : int { Reply = 1, // sent only by DB - Msg = 1000, + Msg = 2013, Update = 2001, Insert = 2002, Reserved1 = 2003, @@ -703,6 +778,7 @@ private int sendLength(ARGS...)(ARGS args) static if (ARGS.length == 1) { alias T = ARGS[0]; static if (is(T == string)) return cast(int)args[0].length + 1; + else static if (is(T == ubyte)) return 1; else static if (is(T == int)) return 4; else static if (is(T == long)) return 8; else static if (is(T == Bson)) return cast(int)args[0].data.length; diff --git a/mongodb/vibe/db/mongo/cursor.d b/mongodb/vibe/db/mongo/cursor.d index daae027d32..c414355294 100644 --- a/mongodb/vibe/db/mongo/cursor.d +++ b/mongodb/vibe/db/mongo/cursor.d @@ -12,8 +12,8 @@ public import vibe.data.bson; import vibe.db.mongo.connection; import vibe.db.mongo.client; -import std.array : array; -import std.algorithm : map, max, min; +import std.array : array, split; +import std.algorithm : map, max, min, canFind, endsWith; import std.exception; deprecated alias MongoCursor(Q, R = Bson, S = Bson) = MongoCursor!R; @@ -352,20 +352,26 @@ private class MongoFindCursor(Q, R, S) : MongoCursorData!R { query = () @trusted { return serializeToBson(m_query, query_buf); } (); } - Bson full_query; + Bson full_query = Bson.emptyObject; - if (!query["query"].isNull() || !query["$query"].isNull()) { - // TODO: emit deprecation warning - full_query = query; + string collection; + if (query["aggregate"].isNull && m_collection.canFind(".") && !m_collection.endsWith("$cmd")) { + collection = m_collection.split(".")[1]; } else { - full_query = Bson.emptyObject; - full_query["$query"] = query; + collection = m_collection; } - if (!m_sort.isNull()) full_query["orderby"] = m_sort; + if (query["aggregate"].isNull) { + full_query["find"] = Bson(collection); + full_query["filter"] = query; + } else { + full_query = query; + } - conn.query!R(m_collection, m_flags, m_nskip, m_nret, full_query, selector, &handleReply, &handleDocument); + if (!m_sort.isNull()) full_query["orderby"] = m_sort; + conn.query!R(collection, m_flags, m_nskip, m_nret, full_query, selector, &handleReply, &handleDocument, !query["aggregate"].isNull); + m_iterationStarted = true; } }