Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AWS DocumentDB support through OP_MSG and handling different OK type #2616

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 132 additions & 56 deletions mongodb/vibe/db/mongo/connection.d
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import vibe.db.mongo.flags;
import vibe.inet.webform;
import vibe.stream.tls;

import std.algorithm : map, splitter;
import std.algorithm : map, splitter, canFind;
import std.array;
import std.conv;
import std.digest.md;
Expand Down Expand Up @@ -208,20 +208,22 @@ final class MongoConnection {
"os": Bson(["type": Bson(os.os.to!string), "architecture": Bson(hostArchitecture)]),
"platform": Bson(platform)
]);
string cn = (m_settings.database == string.init ? "admin" : m_settings.database) ~ ".$cmd";

if (m_settings.appName.length) {
enforce!MongoAuthException(m_settings.appName.length <= 128,
"The application name may not be larger than 128 bytes");
handshake["client"]["application"] = Bson(["name": Bson(m_settings.appName)]);
}

query!Bson("$external.$cmd", QueryFlags.none, 0, -1, handshake, Bson(null),
query!Bson(cn, QueryFlags.none, 0, -1, handshake, Bson(null),
(cursor, flags, first_doc, num_docs) {
enforce!MongoDriverException(!(flags & ReplyFlags.QueryFailure) && num_docs == 1,
"Authentication handshake failed.");
},
(idx, ref doc) {
enforce!MongoAuthException(doc["ok"].get!double == 1.0, "Authentication failed.");
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
enforce!MongoAuthException(flag == 1.0, "Authentication failed.");
m_description = deserializeBson!ServerDescription(doc);
});

Expand Down Expand Up @@ -306,26 +308,50 @@ final class MongoConnection {
{
scope(failure) disconnect();
foreach (d; documents) if (d["_id"].isNull()) d["_id"] = Bson(BsonObjectID.generate());
send(OpCode.Insert, -1, cast(int)flags, collection_name, documents);
string collection = collection_name.canFind(".") ? collection_name.split(".")[1] : collection_name;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would be the semantic of the first part of collection'_name here, the database name? Since it is unused, would that result in the expected behavior if it mismatches?

Bson operation = Bson([
"insert": Bson(collection),
"$db": Bson(m_settings.database == string.init ? "admin" : m_settings.database),
"documents": Bson(documents)
]);
auto id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This uses 0 instead of the flags argument that was used for OpCode.insert - I'm not sure how the semantics translate to OpCode.Msg, though, probably they have to be added as fields to operation?

recvReply!Bson(id, (long cursor, ReplyFlags flags, int first, int num) {

}, (size_t idx, ref Bson doc) {});
if (m_settings.safe) checkForError(collection_name);
}

void query(T)(string collection_name, QueryFlags flags, int nskip, int nret, Bson query, Bson returnFieldSelector, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc)
void query(T)(string collection_name, QueryFlags flags, int nskip, int nret, Bson query, Bson returnFieldSelector, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc, bool isAggregate = false)
{
scope(failure) disconnect();
flags |= m_settings.defQueryFlags;
int id;
if (returnFieldSelector.isNull)
id = send(OpCode.Query, -1, cast(int)flags, collection_name, nskip, nret, query);
else
id = send(OpCode.Query, -1, cast(int)flags, collection_name, nskip, nret, query, returnFieldSelector);
recvReply!T(id, on_msg, on_doc);
Bson op = query;
op["$db"] = Bson(m_settings.database == string.init ? "admin" : m_settings.database);

if (!returnFieldSelector.isNull)
op["projection"] = returnFieldSelector;

if (flags && QueryFlags.tailableCursor)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (flags && QueryFlags.tailableCursor)
if (flags & QueryFlags.tailableCursor)

op["tailable"] = Bson(true);
if (flags && QueryFlags.awaitData)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (flags && QueryFlags.awaitData)
if (flags & QueryFlags.awaitData)

op["awaitData"] = Bson(true);

Bson operation = op;
id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few more query flags that are unhandled here: slaveOk, oplogReplay, noCursorTimeout, exhaust, partial

Also, the collection_name parameter is now unused.

recvReply!T(id, on_msg, on_doc, isAggregate);
}

void getMore(T)(string collection_name, int nret, long cursor_id, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc)
{
scope(failure) disconnect();
auto id = send(OpCode.GetMore, -1, cast(int)0, collection_name, nret, cursor_id);
string collection = collection_name.canFind(".") ? collection_name.split(".")[1] : collection_name;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as for insert here

Bson operation = Bson([
"getMore": Bson(cursor_id),
"$db": Bson(m_settings.database == string.init ? "admin" : m_settings.database),
"collection": Bson(collection),
]);
auto id = send(OpCode.Msg, -1, cast(int)0, cast(ubyte)0, operation);
recvReply!T(id, on_msg, on_doc);
}

Expand Down Expand Up @@ -418,7 +444,8 @@ final class MongoConnection {

Bson result;
void on_doc(size_t idx, ref Bson doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("listDatabases failed.");

result = doc["databases"];
Expand All @@ -429,7 +456,7 @@ final class MongoConnection {
return result.byValue.map!toInfo;
}

private int recvReply(T)(int reqid, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc)
private int recvReply(T)(int reqid, scope ReplyDelegate on_msg, scope DocDelegate!T on_doc, bool isAggregate = false)
{
import std.traits;

Expand All @@ -440,47 +467,86 @@ final class MongoConnection {
int opcode = recvInt();

enforce(respto == reqid, "Reply is not for the expected message on a sequential connection!");
enforce(opcode == OpCode.Reply, "Got a non-'Reply' reply!");

auto flags = cast(ReplyFlags)recvInt();
long cursor = recvLong();
int start = recvInt();
int numret = recvInt();

scope (exit) {
if (m_bytesRead - bytes_read < msglen) {
logWarn("MongoDB reply was longer than expected, skipping the rest: %d vs. %d", msglen, m_bytesRead - bytes_read);
ubyte[] dst = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)];
recv(dst);
} else if (m_bytesRead - bytes_read > msglen) {
logWarn("MongoDB reply was shorter than expected. Dropping connection.");
disconnect();
throw new MongoDriverException("MongoDB reply was too short for data.");
enforce(opcode == OpCode.Reply || opcode == OpCode.Msg, "Got a non-'Reply' reply!");

if (opcode == OpCode.Reply) {
auto flags = cast(ReplyFlags)recvInt();
long cursor = recvLong();
int start = recvInt();
int numret = recvInt();

scope (exit) {
if (m_bytesRead - bytes_read < msglen) {
logWarn("MongoDB reply was longer than expected, skipping the rest: %d vs. %d", msglen, m_bytesRead - bytes_read);
ubyte[] dst = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)];
recv(dst);
} else if (m_bytesRead - bytes_read > msglen) {
logWarn("MongoDB reply was shorter than expected. Dropping connection.");
disconnect();
throw new MongoDriverException("MongoDB reply was too short for data.");
}
}
}

on_msg(cursor, flags, start, numret);
static if (hasIndirections!T || is(T == Bson))
auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)];
foreach (i; 0 .. cast(size_t)numret) {
// TODO: directly deserialize from the wire
static if (!hasIndirections!T && !is(T == Bson)) {
ubyte[256] buf = void;
ubyte[] bufsl = buf;
auto bson = () @trusted { return recvBson(bufsl); } ();
} else {
auto bson = () @trusted { return recvBson(buf); } ();
}
on_msg(cursor, flags, start, numret);
static if (hasIndirections!T || is(T == Bson))
auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)];
foreach (i; 0 .. cast(size_t)numret) {
// TODO: directly deserialize from the wire
static if (!hasIndirections!T && !is(T == Bson)) {
ubyte[256] buf = void;
ubyte[] bufsl = buf;
auto bson = () @trusted { return recvBson(bufsl); } ();
} else {
auto bson = () @trusted { return recvBson(buf); } ();
}

// logDebugV("Received mongo response on %s:%s: %s", reqid, i, bson);
static if (is(T == Bson)) on_doc(i, bson);
else {
T doc = deserializeBson!T(bson);
on_doc(i, doc);
}
}

static if (is(T == Bson)) on_doc(i, bson);
else {
T doc = deserializeBson!T(bson);
on_doc(i, doc);
return resid;
} else {
auto flags = cast(ReplyFlags)recvInt();
ubyte kind = recvByte();
if (kind == 0) {
on_msg(0, flags, 0, 1);
static if (hasIndirections!T || is(T == Bson))
auto buf = new ubyte[msglen - cast(size_t)(m_bytesRead - bytes_read)];
static if (!hasIndirections!T && !is(T == Bson)) {
ubyte[256] buf = void;
ubyte[] bufsl = buf;
auto bson = () @trusted { return recvBson(bufsl); } ();
} else {
auto bson = () @trusted { return recvBson(buf); } ();
}

if (!bson["cursor"].isNull && !isAggregate) {
Bson cursor = bson["cursor"];
on_msg(cursor["id"].get!long, flags, 0, cast(int)cursor["firstBatch"].length);
auto index = 0;
foreach (element; cursor["firstBatch"].byValue()) {
static if (is(T == Bson)) on_doc(index, element);
else {
T doc = deserializeBson!T(element);
on_doc(index, doc);
}
index++;
}
} else {
on_msg(0, flags, 0, 1);
static if (is(T == Bson)) on_doc(0, bson);
else {
T doc = deserializeBson!T(bson);
on_doc(0, doc);
}
}
} else {
throw new MongoDriverException("Kind in reply was 1?? Why?");
}
}

return resid;
}

Expand All @@ -507,6 +573,7 @@ final class MongoConnection {
{
import std.traits;
static if (is(T == int)) sendBytes(toBsonData(value));
else static if (is(T == ubyte)) sendBytes(toBsonData(value));
else static if (is(T == long)) sendBytes(toBsonData(value));
else static if (is(T == Bson)) sendBytes(value.data);
else static if (is(T == string)) {
Expand All @@ -520,6 +587,7 @@ final class MongoConnection {

private void sendBytes(in ubyte[] data){ m_outRange.put(data); }

private byte recvByte() { ubyte[1] ret; recv(ret); return fromBsonData!byte(ret); }
private int recvInt() { ubyte[int.sizeof] ret; recv(ret); return fromBsonData!int(ret); }
private long recvLong() { ubyte[long.sizeof] ret; recv(ret); return fromBsonData!long(ret); }
private Bson recvBson(ref ubyte[] buf)
Expand Down Expand Up @@ -567,13 +635,15 @@ final class MongoConnection {

cmd["user"] = Bson(m_settings.username);
}
query!Bson("$external.$cmd", QueryFlags.None, 0, -1, cmd, Bson(null),
string cn = (m_settings.database == string.init ? "admin" : m_settings.database) ~ ".$cmd";
query!Bson(cn, QueryFlags.None, 0, -1, cmd, Bson(null),
(cursor, flags, first_doc, num_docs) {
if ((flags & ReplyFlags.QueryFailure) || num_docs != 1)
throw new MongoDriverException("Calling authenticate failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("Authentication failed.");
}
);
Expand All @@ -592,7 +662,8 @@ final class MongoConnection {
throw new MongoDriverException("Calling getNonce failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoDriverException("getNonce failed.");
nonce = doc["nonce"].get!string;
key = toLower(toHexString(md5Of(nonce ~ m_settings.username ~ m_settings.digest)).idup);
Expand All @@ -611,7 +682,8 @@ final class MongoConnection {
throw new MongoDriverException("Calling authenticate failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("Authentication failed.");
}
);
Expand All @@ -637,7 +709,8 @@ final class MongoConnection {
throw new MongoDriverException("SASL start failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("Authentication failed.");
response = cast(string)doc["payload"].get!BsonBinData().rawData;
conversationId = doc["conversationId"];
Expand All @@ -653,7 +726,8 @@ final class MongoConnection {
throw new MongoDriverException("SASL continue failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("Authentication failed.");
response = cast(string)doc["payload"].get!BsonBinData().rawData;
});
Expand All @@ -669,15 +743,16 @@ final class MongoConnection {
throw new MongoDriverException("SASL finish failed.");
},
(idx, ref doc) {
if (doc["ok"].get!double != 1.0)
auto flag = doc["ok"].type == Bson.Type.int_ ? cast(double)doc["ok"].get!int : doc["ok"].get!double;
if (flag != 1.0)
throw new MongoAuthException("Authentication failed.");
});
}
}

private enum OpCode : int {
Reply = 1, // sent only by DB
Msg = 1000,
Msg = 2013,
Update = 2001,
Insert = 2002,
Reserved1 = 2003,
Expand All @@ -703,6 +778,7 @@ private int sendLength(ARGS...)(ARGS args)
static if (ARGS.length == 1) {
alias T = ARGS[0];
static if (is(T == string)) return cast(int)args[0].length + 1;
else static if (is(T == ubyte)) return 1;
else static if (is(T == int)) return 4;
else static if (is(T == long)) return 8;
else static if (is(T == Bson)) return cast(int)args[0].data.length;
Expand Down
26 changes: 16 additions & 10 deletions mongodb/vibe/db/mongo/cursor.d
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ public import vibe.data.bson;
import vibe.db.mongo.connection;
import vibe.db.mongo.client;

import std.array : array;
import std.algorithm : map, max, min;
import std.array : array, split;
import std.algorithm : map, max, min, canFind, endsWith;
import std.exception;

deprecated alias MongoCursor(Q, R = Bson, S = Bson) = MongoCursor!R;
Expand Down Expand Up @@ -352,20 +352,26 @@ private class MongoFindCursor(Q, R, S) : MongoCursorData!R {
query = () @trusted { return serializeToBson(m_query, query_buf); } ();
}

Bson full_query;
Bson full_query = Bson.emptyObject;

if (!query["query"].isNull() || !query["$query"].isNull()) {
// TODO: emit deprecation warning
full_query = query;
string collection;
if (query["aggregate"].isNull && m_collection.canFind(".") && !m_collection.endsWith("$cmd")) {
collection = m_collection.split(".")[1];
} else {
full_query = Bson.emptyObject;
full_query["$query"] = query;
collection = m_collection;
}

if (!m_sort.isNull()) full_query["orderby"] = m_sort;
if (query["aggregate"].isNull) {
full_query["find"] = Bson(collection);
full_query["filter"] = query;
} else {
full_query = query;
}

conn.query!R(m_collection, m_flags, m_nskip, m_nret, full_query, selector, &handleReply, &handleDocument);
if (!m_sort.isNull()) full_query["orderby"] = m_sort;

conn.query!R(collection, m_flags, m_nskip, m_nret, full_query, selector, &handleReply, &handleDocument, !query["aggregate"].isNull);

m_iterationStarted = true;
}
}
Expand Down