Skip to content

Commit

Permalink
Merge pull request #418 from rabbitmq/rabbtimq-dotnet-client-417
Browse files Browse the repository at this point in the history
Lock when writing to an SslStream
  • Loading branch information
michaelklishin authored May 14, 2018
2 parents 2d3eff6 + c428e03 commit 60226be
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 58 deletions.
133 changes: 81 additions & 52 deletions projects/client/RabbitMQ.Client/src/client/impl/ModelBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public abstract class ModelBase : IFullModel, IRecoverable
private readonly object m_eventLock = new object();
private readonly object m_flowSendLock = new object();
private readonly object m_shutdownLock = new object();
private readonly object _rpcLock = new object();

private readonly SynchronizedList<ulong> m_unconfirmedSet = new SynchronizedList<ulong>();

Expand Down Expand Up @@ -358,36 +359,43 @@ public string ConnectionOpen(string virtualHost,
bool insist)
{
var k = new ConnectionOpenContinuation();
Enqueue(k);
try
{
_Private_ConnectionOpen(virtualHost, capabilities, insist);
}
catch (AlreadyClosedException)
lock(_rpcLock)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
Enqueue(k);
try
{
_Private_ConnectionOpen(virtualHost, capabilities, insist);
}
catch (AlreadyClosedException)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
}
k.GetReply(HandshakeContinuationTimeout);
}
k.GetReply(HandshakeContinuationTimeout);

return k.m_knownHosts;
}

public ConnectionSecureOrTune ConnectionSecureOk(byte[] response)
{
var k = new ConnectionStartRpcContinuation();
Enqueue(k);
try
lock(_rpcLock)
{
_Private_ConnectionSecureOk(response);
}
catch (AlreadyClosedException)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
Enqueue(k);
try
{
_Private_ConnectionSecureOk(response);
}
catch (AlreadyClosedException)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
}
k.GetReply(HandshakeContinuationTimeout);
}
k.GetReply(HandshakeContinuationTimeout);
return k.m_result;
}

Expand All @@ -397,19 +405,22 @@ public ConnectionSecureOrTune ConnectionStartOk(IDictionary<string, object> clie
string locale)
{
var k = new ConnectionStartRpcContinuation();
Enqueue(k);
try
{
_Private_ConnectionStartOk(clientProperties, mechanism,
response, locale);
}
catch (AlreadyClosedException)
lock(_rpcLock)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
Enqueue(k);
try
{
_Private_ConnectionStartOk(clientProperties, mechanism,
response, locale);
}
catch (AlreadyClosedException)
{
// let continuation throw OperationInterruptedException,
// which is a much more suitable exception before connection
// negotiation finishes
}
k.GetReply(HandshakeContinuationTimeout);
}
k.GetReply(HandshakeContinuationTimeout);
return k.m_result;
}

Expand Down Expand Up @@ -456,8 +467,11 @@ public void HandleCommand(ISession session, Command cmd)
public MethodBase ModelRpc(MethodBase method, ContentHeaderBase header, byte[] body)
{
var k = new SimpleBlockingRpcContinuation();
TransmitAndEnqueue(new Command(method, header, body), k);
return k.GetReply(this.ContinuationTimeout).Method;
lock(_rpcLock)
{
TransmitAndEnqueue(new Command(method, header, body), k);
return k.GetReply(this.ContinuationTimeout).Method;
}
}

public void ModelSend(MethodBase method, ContentHeaderBase header, byte[] body)
Expand Down Expand Up @@ -1146,10 +1160,12 @@ public void BasicCancel(string consumerTag)
{
var k = new BasicConsumerRpcContinuation { m_consumerTag = consumerTag };

Enqueue(k);

_Private_BasicCancel(consumerTag, false);
k.GetReply(this.ContinuationTimeout);
lock(_rpcLock)
{
Enqueue(k);
_Private_BasicCancel(consumerTag, false);
k.GetReply(this.ContinuationTimeout);
}
lock (m_consumers)
{
m_consumers.Remove(consumerTag);
Expand Down Expand Up @@ -1180,12 +1196,15 @@ public string BasicConsume(string queue,

var k = new BasicConsumerRpcContinuation { m_consumer = consumer };

Enqueue(k);
// Non-nowait. We have an unconventional means of getting
// the RPC response, but a response is still expected.
_Private_BasicConsume(queue, consumerTag, noLocal, autoAck, exclusive,
/*nowait:*/ false, arguments);
k.GetReply(this.ContinuationTimeout);
lock(_rpcLock)
{
Enqueue(k);
// Non-nowait. We have an unconventional means of getting
// the RPC response, but a response is still expected.
_Private_BasicConsume(queue, consumerTag, noLocal, autoAck, exclusive,
/*nowait:*/ false, arguments);
k.GetReply(this.ContinuationTimeout);
}
string actualConsumerTag = k.m_consumerTag;

return actualConsumerTag;
Expand All @@ -1195,9 +1214,13 @@ public BasicGetResult BasicGet(string queue,
bool autoAck)
{
var k = new BasicGetRpcContinuation();
Enqueue(k);
_Private_BasicGet(queue, autoAck);
k.GetReply(this.ContinuationTimeout);
lock(_rpcLock)
{
Enqueue(k);
_Private_BasicGet(queue, autoAck);
k.GetReply(this.ContinuationTimeout);
}

return k.m_result;
}

Expand Down Expand Up @@ -1261,9 +1284,12 @@ public void BasicRecover(bool requeue)
{
var k = new SimpleBlockingRpcContinuation();

Enqueue(k);
_Private_BasicRecover(requeue);
k.GetReply(this.ContinuationTimeout);
lock(_rpcLock)
{
Enqueue(k);
_Private_BasicRecover(requeue);
k.GetReply(this.ContinuationTimeout);
}
}

public abstract void BasicRecoverAsync(bool requeue);
Expand Down Expand Up @@ -1559,9 +1585,12 @@ private QueueDeclareOk QueueDeclare(string queue, bool passive, bool durable, bo
bool autoDelete, IDictionary<string, object> arguments)
{
var k = new QueueDeclareRpcContinuation();
Enqueue(k);
_Private_QueueDeclare(queue, passive, durable, exclusive, autoDelete, false, arguments);
k.GetReply(this.ContinuationTimeout);
lock(_rpcLock)
{
Enqueue(k);
_Private_QueueDeclare(queue, passive, durable, exclusive, autoDelete, false, arguments);
k.GetReply(this.ContinuationTimeout);
}
return k.m_result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ public class SocketFrameHandler : IFrameHandler
private readonly ITcpClient m_socket;
private readonly NetworkBinaryWriter m_writer;
private readonly object _semaphore = new object();
private readonly object _sslStreamLock = new object();
private bool _closed;
private bool _ssl = false;
public SocketFrameHandler(AmqpTcpEndpoint endpoint,
Func<AddressFamily, ITcpClient> socketFactory,
int connectionTimeout, int readTimeout, int writeTimeout)
Expand Down Expand Up @@ -108,6 +110,7 @@ public SocketFrameHandler(AmqpTcpEndpoint endpoint,
try
{
netstream = SslHelper.TcpUpgrade(netstream, endpoint.Ssl);
_ssl = true;
}
catch (Exception)
{
Expand Down Expand Up @@ -217,7 +220,7 @@ public void SendHeader()
nbw.Write((byte)Endpoint.Protocol.MajorVersion);
nbw.Write((byte)Endpoint.Protocol.MinorVersion);
}
m_writer.Write(ms.ToArray());
Write(ms.ToArray());
}

public void WriteFrame(OutboundFrame frame)
Expand All @@ -226,7 +229,7 @@ public void WriteFrame(OutboundFrame frame)
var nbw = new NetworkBinaryWriter(ms);
frame.WriteTo(nbw);
m_socket.Client.Poll(m_writeableStateTimeout, SelectMode.SelectWrite);
m_writer.Write(ms.ToArray());
Write(ms.ToArray());
}

public void WriteFrameSet(IList<OutboundFrame> frames)
Expand All @@ -235,7 +238,22 @@ public void WriteFrameSet(IList<OutboundFrame> frames)
var nbw = new NetworkBinaryWriter(ms);
foreach (var f in frames) f.WriteTo(nbw);
m_socket.Client.Poll(m_writeableStateTimeout, SelectMode.SelectWrite);
m_writer.Write(ms.ToArray());
Write(ms.ToArray());
}

private void Write(byte [] buffer)
{
if(_ssl)
{
lock (_sslStreamLock)
{
m_writer.Write(buffer);
}
}
else
{
m_writer.Write(buffer);
}
}

private bool ShouldTryIPv6(AmqpTcpEndpoint endpoint)
Expand Down
4 changes: 2 additions & 2 deletions projects/client/Unit/src/unit/TestExchangeDeclare.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public class TestExchangeDeclare : IntegrationFixture {

[Test]
[Category("RequireSMP")]
public void TestConcurrentQueueDeclare()
public void TestConcurrentExchangeDeclare()
{
string x = GenerateExchangeName();
Random rnd = new Random();
Expand Down Expand Up @@ -83,7 +83,7 @@ public void TestConcurrentQueueDeclare()
t.Join();
}

Assert.IsNotNull(nse);
Assert.IsNull(nse);
Model.ExchangeDelete(x);
}
}
Expand Down
2 changes: 1 addition & 1 deletion projects/client/Unit/src/unit/TestQueueDeclare.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public void TestConcurrentQueueDeclare()
t.Join();
}

Assert.IsNotNull(nse);
Assert.IsNull(nse);
Model.QueueDelete(q);
}
}
Expand Down

0 comments on commit 60226be

Please sign in to comment.