Skip to content

Commit

Permalink
PublisherAsBlockingIterable LinkedBlockingQueue -> SpscBlockingQueue
Browse files Browse the repository at this point in the history
Motivation:
LinkedBlockingQueue allows for multiple producers and multiple consumers.
It uses LockSupport park in take and unpark in offer. LockSupport unpark
on the EventLoop thread has been shown to impact throughput during benchmarks.

Before:
```
```

After:
```
```
  • Loading branch information
Scottmitch committed Oct 2, 2022
1 parent 450096e commit b5163cf
Showing 1 changed file with 308 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,23 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collection;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.concurrent.locks.LockSupport;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
import static io.servicetalk.concurrent.internal.TerminalNotification.complete;
import static io.servicetalk.concurrent.internal.TerminalNotification.error;
import static io.servicetalk.utils.internal.PlatformDependent.throwException;
import static io.servicetalk.utils.internal.PlatformDependent.newUnboundedSpscQueue;
import static io.servicetalk.utils.internal.ThrowableUtils.throwException;
import static java.lang.Math.min;
import static java.lang.Thread.currentThread;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -101,7 +105,7 @@ private static final class SubscriberAndIterator<T> implements Subscriber<T>, Bl

// Stores queueCapacity in requestN and assigns the queue held in `data`.
// NOTE(review): this span is rendered diff context, so both the LinkedBlockingQueue assignment and the
// SpscBlockingQueue assignment are shown; presumably the former is the removed line and the latter its
// replacement -- confirm against the committed file.
SubscriberAndIterator(int queueCapacity) {
requestN = queueCapacity;
data = new LinkedBlockingQueue<>();
// The same queueCapacity value is passed to newUnboundedSpscQueue(...) to build the wrapped queue.
data = new SpscBlockingQueue<>(newUnboundedSpscQueue(queueCapacity));
}

@Override
Expand Down Expand Up @@ -261,4 +265,305 @@ private T processNext() {
return unwrapNullUnchecked(signal);
}
}

/**
 * A {@link BlockingQueue} view over a non-blocking {@link Queue} supplied at construction time.
 * <p>
 * Insertions are delegated to the wrapped queue and followed by a wake-up signal for the consumer. The blocking
 * retrieval methods ({@link #take()} and {@link #poll(long, TimeUnit)}) register the calling thread in
 * {@code consumerThread}, spin with {@link Thread#yield()} for a short while and then park until signalled.
 * Only one thread may block in the retrieval methods at a time; a second thread gets an
 * {@link IllegalStateException}. The blocking insertion methods are not supported.
 *
 * @param <T> the type of elements held in this queue.
 */
private static final class SpscBlockingQueue<T> implements BlockingQueue<T> {
    /**
     * Number of times to call {@link Thread#yield()} before calling {@link LockSupport#park()}.
     * {@link LockSupport#park()} can be expensive and if the producer is generating data it is likely we will see
     * it without park/unpark.
     */
    private static final int POLL_YIELD_SPIN_COUNT =
            Integer.getInteger("io.servicetalk.concurrent.internal.blockingIterableYieldSpinCount", 1);
    /**
     * Number of nanoseconds to spin on {@link Thread#yield()} before calling {@link LockSupport#parkNanos(long)}.
     * {@link LockSupport#parkNanos(long)} can be expensive and if the producer is generating data it is likely
     * we will see it without park/unpark.
     */
    private static final long POLL_YIELD_SPIN_NS =
            Long.getLong("io.servicetalk.concurrent.internal.blockingIterableYieldSpinNs", 1024);
    @SuppressWarnings("rawtypes")
    private static final AtomicReferenceFieldUpdater<SpscBlockingQueue, Thread> consumerThreadUpdater =
            AtomicReferenceFieldUpdater.newUpdater(SpscBlockingQueue.class, Thread.class, "consumerThread");
    /**
     * Sentinel stored in {@code consumerThread} by {@link #signalConsumer()}. It is never started; it is only
     * compared by identity.
     */
    private static final Thread PRODUCED_THREAD = new Thread(() -> { });
    private final Queue<T> spscQueue;
    /**
     * {@code null} when no consumer is registered, the blocked consumer while one is waiting, or
     * {@link #PRODUCED_THREAD} after a signal has been issued.
     */
    @Nullable
    private volatile Thread consumerThread;

    /**
     * Create a new instance.
     *
     * @param spscQueue the queue all operations are delegated to.
     * @throws NullPointerException if {@code spscQueue} is {@code null}.
     */
    SpscBlockingQueue(Queue<T> spscQueue) {
        this.spscQueue = requireNonNull(spscQueue);
    }

    @Override
    public boolean add(final T t) {
        if (spscQueue.add(t)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    @Override
    public boolean offer(final T t) {
        if (spscQueue.offer(t)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    /**
     * Marks the queue as signalled and wakes the registered consumer, if any.
     * <p>
     * The previous value of {@code consumerThread} is swapped for {@link #PRODUCED_THREAD}; {@code unpark} is only
     * issued when that previous value was a real thread, which avoids the cost of {@code unpark} when nobody is
     * registered.
     */
    private void signalConsumer() {
        final Thread thread = consumerThreadUpdater.getAndSet(this, PRODUCED_THREAD);
        if (thread != null && thread != PRODUCED_THREAD) {
            LockSupport.unpark(thread);
        }
    }

    @Override
    public T remove() {
        return spscQueue.remove();
    }

    @Override
    public T poll() {
        return spscQueue.poll();
    }

    /**
     * Retrieves, but does not remove, the head of this queue.
     *
     * @return the head of this queue.
     * @throws NoSuchElementException if this queue is empty.
     */
    @Override
    public T element() {
        // peek() rather than poll(): element() must leave the head in the queue.
        final T t = peek();
        if (t == null) {
            throw new NoSuchElementException();
        }
        return t;
    }

    @Override
    public T peek() {
        return spscQueue.peek();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public void put(final T t) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public boolean offer(final T t, final long timeout, final TimeUnit unit) {
        throw new UnsupportedOperationException();
    }

    /**
     * Retrieves and removes the head of this queue, waiting until an element becomes available.
     *
     * @return the head of this queue.
     * @throws InterruptedException if the calling thread is interrupted while waiting.
     * @throws IllegalStateException if a different thread is already registered as the consumer.
     */
    @Override
    public T take() throws InterruptedException {
        final Thread currentThread = Thread.currentThread();
        for (;;) {
            final Thread thread = consumerThread;
            if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
                throwTooManyConsumers(currentThread);
            } else if (thread == currentThread ||
                    consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
                try {
                    T item;
                    int pollCount = 0;
                    while ((item = spscQueue.poll()) == null) {
                        // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
                        // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost
                        // on the EventLoop thread and increase throughput in these scenarios.
                        if (pollCount++ > POLL_YIELD_SPIN_COUNT) {
                            // Only park while registered; otherwise re-poll first (see armConsumerBeforePark).
                            if (armConsumerBeforePark(currentThread)) {
                                LockSupport.park();
                            }
                        } else {
                            Thread.yield();
                        }
                        checkInterrupted();
                    }

                    return item;
                } finally {
                    // If this call changed the consumerThread before the poll call we should restore it after.
                    // This should be done atomically in case another thread has produced concurrently and swapped
                    // the value to PRODUCED_THREAD.
                    if (thread != currentThread) {
                        consumerThreadUpdater.compareAndSet(this, currentThread, null);
                    }
                }
            }
        }
    }

    /**
     * Retrieves and removes the head of this queue, waiting up to the specified time for an element to become
     * available.
     *
     * @param timeout how long to wait before giving up, in units of {@code unit}.
     * @param unit the unit of {@code timeout}.
     * @return the head of this queue, or {@code null} if the wait time elapsed before an element was available.
     * @throws InterruptedException if the calling thread is interrupted while waiting.
     * @throws IllegalStateException if a different thread is already registered as the consumer.
     */
    @Override
    public T poll(final long timeout, final TimeUnit unit) throws InterruptedException {
        final Thread currentThread = Thread.currentThread();
        for (;;) {
            final Thread thread = consumerThread;
            if (thread != null && thread != currentThread && thread != PRODUCED_THREAD) {
                throwTooManyConsumers(currentThread);
            } else if (thread == currentThread ||
                    consumerThreadUpdater.compareAndSet(this, thread, currentThread)) {
                try {
                    final long originalNs = unit.toNanos(timeout);
                    long remainingNs = originalNs;
                    long beforeTimeNs = System.nanoTime();
                    T item;
                    while ((item = spscQueue.poll()) == null) {
                        // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and
                        // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost
                        // on the EventLoop thread and increase throughput in these scenarios.
                        if (originalNs - remainingNs > POLL_YIELD_SPIN_NS) {
                            // Only park while registered; otherwise re-poll first (see armConsumerBeforePark).
                            if (armConsumerBeforePark(currentThread)) {
                                LockSupport.parkNanos(remainingNs);
                            }
                        } else {
                            Thread.yield();
                        }
                        checkInterrupted();
                        final long afterTimeNs = System.nanoTime();
                        final long durationNs = afterTimeNs - beforeTimeNs;
                        // >= so that an exactly exhausted budget returns instead of spinning one more round.
                        if (durationNs >= remainingNs) {
                            return null;
                        }
                        remainingNs -= durationNs;
                        beforeTimeNs = afterTimeNs;
                    }

                    return item;
                } finally {
                    // If this call changed the consumerThread before the poll call we should restore it after.
                    // This should be done atomically in case another thread has produced concurrently and swapped
                    // the value to PRODUCED_THREAD.
                    if (thread != currentThread) {
                        consumerThreadUpdater.compareAndSet(this, currentThread, null);
                    }
                }
            }
        }
    }

    /**
     * Ensures the calling consumer is registered in {@code consumerThread} before it parks.
     * <p>
     * {@link #signalConsumer()} replaces the registered thread with {@link #PRODUCED_THREAD}. A signal that belongs
     * to an element which was already consumed can therefore arrive while the consumer is waiting for the next
     * element. If the consumer parked again without re-registering, every later signal would observe
     * {@link #PRODUCED_THREAD}, skip {@code unpark}, and leave the consumer parked with elements in the queue.
     *
     * @param currentThread the consumer thread about to park.
     * @return {@code true} if {@code currentThread} is registered and parking is safe; {@code false} if the
     * registration was just renewed and the caller must poll the queue again before parking.
     * @throws IllegalStateException if a different thread is registered as the consumer.
     */
    private boolean armConsumerBeforePark(final Thread currentThread) {
        final Thread registered = consumerThread;
        if (registered == currentThread) {
            return true;
        }
        if (registered != null && registered != PRODUCED_THREAD) {
            throwTooManyConsumers(currentThread);
        }
        // The result is intentionally ignored: on failure the next loop iteration re-evaluates the state.
        consumerThreadUpdater.compareAndSet(this, registered, currentThread);
        return false;
    }

    private static void throwTooManyConsumers(Thread currentThread) {
        throw new IllegalStateException("Only single consumer allowed, current consumer: " + currentThread);
    }

    /**
     * Throws if the calling thread has been interrupted. {@link Thread#interrupted()} clears the interrupt status.
     *
     * @throws InterruptedException if the interrupt status was set.
     */
    private static void checkInterrupted() throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
    }

    @Override
    public int remainingCapacity() {
        return Integer.MAX_VALUE;
    }

    @Override
    public boolean remove(final Object o) {
        if (spscQueue.remove(o)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    @Override
    public boolean containsAll(final Collection<?> c) {
        return spscQueue.containsAll(c);
    }

    @Override
    public boolean addAll(final Collection<? extends T> c) {
        if (spscQueue.addAll(c)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    @Override
    public boolean removeAll(final Collection<?> c) {
        if (spscQueue.removeAll(c)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    @Override
    public boolean retainAll(final Collection<?> c) {
        if (spscQueue.retainAll(c)) {
            signalConsumer();
            return true;
        }
        return false;
    }

    @Override
    public void clear() {
        spscQueue.clear();
        signalConsumer();
    }

    @Override
    public int size() {
        return spscQueue.size();
    }

    @Override
    public boolean isEmpty() {
        return spscQueue.isEmpty();
    }

    @Override
    public boolean contains(final Object o) {
        return spscQueue.contains(o);
    }

    @Override
    public Iterator<T> iterator() {
        return spscQueue.iterator();
    }

    @Override
    public Object[] toArray() {
        return spscQueue.toArray();
    }

    @Override
    public <T1> T1[] toArray(final T1[] a) {
        return spscQueue.toArray(a);
    }

    /**
     * Removes all currently available elements and adds them to {@code c}.
     *
     * @param c the collection to transfer elements into.
     * @return the number of elements removed from this queue.
     * @throws NullPointerException if {@code c} is {@code null}.
     * @throws IllegalArgumentException if {@code c} is this queue.
     */
    @Override
    public int drainTo(final Collection<? super T> c) {
        return drainTo(c, Integer.MAX_VALUE);
    }

    /**
     * Removes at most {@code maxElements} currently available elements and adds them to {@code c}.
     *
     * @param c the collection to transfer elements into.
     * @param maxElements the maximum number of elements to remove from this queue.
     * @return the number of elements removed from this queue.
     * @throws NullPointerException if {@code c} is {@code null}.
     * @throws IllegalArgumentException if {@code c} is this queue.
     */
    @Override
    public int drainTo(final Collection<? super T> c, final int maxElements) {
        requireNonNull(c);
        if (c == this) {
            // Draining into itself would re-enqueue every polled element and never terminate.
            throw new IllegalArgumentException("Cannot drain a queue into itself");
        }
        int i = 0;
        T item;
        while (i < maxElements && (item = poll()) != null) {
            c.add(item);
            // Count every removed element, even if c.add returned false, so the maxElements bound holds.
            ++i;
        }
        return i;
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof SpscBlockingQueue && spscQueue.equals(((SpscBlockingQueue<?>) o).spscQueue);
    }

    @Override
    public int hashCode() {
        return spscQueue.hashCode();
    }

    @Override
    public String toString() {
        return spscQueue.toString();
    }
}
}

0 comments on commit b5163cf

Please sign in to comment.