diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java index 229470f018..141ad8948c 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java @@ -27,22 +27,29 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collection; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.Queue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.concurrent.locks.LockSupport; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked; import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull; import static io.servicetalk.concurrent.internal.TerminalNotification.complete; import static io.servicetalk.concurrent.internal.TerminalNotification.error; -import static io.servicetalk.utils.internal.PlatformDependent.throwException; +import static io.servicetalk.utils.internal.PlatformDependent.newUnboundedSpscQueue; +import static io.servicetalk.utils.internal.ThrowableUtils.throwException; +import static java.lang.Integer.getInteger; +import static java.lang.Long.getLong; import static java.lang.Math.min; import static java.lang.Thread.currentThread; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * As returned by {@link Publisher#toIterable(int)} and {@link Publisher#toIterable()}. @@ -101,7 +108,7 @@ private static final class SubscriberAndIterator implements Subscriber, Bl SubscriberAndIterator(int queueCapacity) { requestN = queueCapacity; - data = new LinkedBlockingQueue<>(); + data = new SpscBlockingQueue<>(newUnboundedSpscQueue(queueCapacity)); } @Override @@ -261,4 +268,366 @@ private T processNext() { return unwrapNullUnchecked(signal); } } + + private static final class SpscBlockingQueue implements BlockingQueue { + /** + * Amount of times to call {@link Thread#yield()} before calling {@link LockSupport#park()}. + * {@link LockSupport#park()} can be expensive and if the producer is generating data it is likely we will see + * it without parking. + */ + private static final int POLL_YIELD_COUNT = + getInteger("io.servicetalk.concurrent.internal.blockingIterableYieldCount", 1); + /** + * Amount of nanoseconds to spin on {@link Thread#yield()} before calling {@link LockSupport#parkNanos(long)}. + * {@link LockSupport#parkNanos(long)} can be expensive and if the producer is generating data it is likely + * we will see it without parking. + */ + private static final long POLL_YIELD_SPIN_NS = + getLong("io.servicetalk.concurrent.internal.blockingIterableYieldNs", 1024); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater threadStampUpdater = + AtomicReferenceFieldUpdater.newUpdater(SpscBlockingQueue.class, ThreadStamp.class, "threadStamp"); + private static final Thread PRODUCED_THREAD = new Thread(() -> { }); + private final Queue spscQueue; + @Nullable + private volatile ThreadStamp threadStamp; + + SpscBlockingQueue(Queue spscQueue) { + this.spscQueue = requireNonNull(spscQueue); + } + + @Override + public boolean add(final T t) { + if (spscQueue.add(t)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public boolean offer(final T t) { + if (spscQueue.offer(t)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public T remove() { + return spscQueue.remove(); + } + + @Override + public T poll() { + return spscQueue.poll(); + } + + @Override + public T element() { + final T t = poll(); + if (t == null) { + throw new NoSuchElementException(); + } + return t; + } + + @Override + public T peek() { + return spscQueue.peek(); + } + + @Override + public void put(final T t) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean offer(final T t, final long timeout, final TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public T take() throws InterruptedException { + return take0(this::pollAndParkIgnoreTime, 0, NANOSECONDS); + } + + @Override + public T poll(final long timeout, final TimeUnit unit) throws InterruptedException { + return take0(this::pollAndPark, timeout, unit); + } + + @Override + public int remainingCapacity() { + return Integer.MAX_VALUE; + } + + @Override + public boolean remove(final Object o) { + if (spscQueue.remove(o)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public boolean containsAll(final Collection c) { + return spscQueue.containsAll(c); + } + + @Override + public boolean addAll(final Collection c) { + if (spscQueue.addAll(c)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public boolean removeAll(final Collection c) { + if (spscQueue.removeAll(c)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public boolean retainAll(final Collection c) { + if (spscQueue.retainAll(c)) { + producerSignalConsumer(); + return true; + } + return false; + } + + @Override + public void clear() { + spscQueue.clear(); + producerSignalConsumer(); + } + + @Override + public int size() { + return spscQueue.size(); + } + + @Override + public boolean isEmpty() { + return spscQueue.isEmpty(); + } + + @Override + public boolean contains(final Object o) { + return spscQueue.contains(o); + } + + @Override + public Iterator iterator() { + return spscQueue.iterator(); + } + + @Override + public Object[] toArray() { + return spscQueue.toArray(); + } + + @Override + public T1[] toArray(final T1[] a) { + return spscQueue.toArray(a); + } + + @Override + public int drainTo(final Collection c) { + int i = 0; + T item; + while ((item = poll()) != null) { + if (c.add(item)) { + ++i; + } + } + return i; + } + + @Override + public int drainTo(final Collection c, final int maxElements) { + int i = 0; + T item; + while (i < maxElements && (item = poll()) != null) { + if (c.add(item)) { + ++i; + } + } + return i; + } + + @Override + public boolean equals(Object o) { + return o instanceof SpscBlockingQueue && spscQueue.equals(((SpscBlockingQueue) o).spscQueue); + } + + @Override + public int hashCode() { + return spscQueue.hashCode(); + } + + @Override + public String toString() { + return spscQueue.toString(); + } + + private void producerSignalConsumer() { + ThreadStamp nextStamp = null; + for (;;) { + final ThreadStamp currStamp = threadStamp; + if (currStamp == null) { + if (nextStamp == null) { + nextStamp = new ThreadStamp(PRODUCED_THREAD, 1); + } else { + nextStamp.count = 1; + } + if (threadStampUpdater.compareAndSet(this, null, nextStamp)) { + break; + } + } else if (currStamp.thread == PRODUCED_THREAD) { + if (nextStamp == null) { + nextStamp = new ThreadStamp(PRODUCED_THREAD, currStamp.count + 1); + } else { + nextStamp.count = currStamp.count + 1; + } + + if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp)) { + break; + } + } else if (threadStampUpdater.compareAndSet(this, currStamp, null)) { + LockSupport.unpark(currStamp.thread); + assert currStamp.count == 0; // only a single consumer allowed + break; + } + } + } + + private T take0(BiLongFunction taker, long timeout, TimeUnit unit) throws InterruptedException { + final Thread currentThread = Thread.currentThread(); + ThreadStamp nextStamp = new ThreadStamp(currentThread); + for (;;) { + final ThreadStamp currStamp = threadStamp; + if (currStamp == null) { + nextStamp.count = 0; + nextStamp.thread = currentThread; + if (threadStampUpdater.compareAndSet(this, null, nextStamp)) { + try { + return taker.apply(timeout, unit); + } finally { + threadStampUpdater.compareAndSet(this, nextStamp, null); + } + } + } else if (currStamp.thread == PRODUCED_THREAD) { + final ThreadStamp nextStamp2; + if (currStamp.count == 1) { + nextStamp2 = null; + } else { + nextStamp.count = currStamp.count - 1; + nextStamp.thread = PRODUCED_THREAD; + nextStamp2 = nextStamp; + } + if (threadStampUpdater.compareAndSet(this, currStamp, nextStamp2)) { + final T item = spscQueue.poll(); + assert item != null; + return item; + } + } else { + throwTooManyConsumers(currStamp.thread, currentThread); + } + } + } + + private T pollAndParkIgnoreTime(@SuppressWarnings("unused") final long timeout, + @SuppressWarnings("unused") final TimeUnit unit) throws InterruptedException { + T item; + int pollCount = 0; + while ((item = spscQueue.poll()) == null) { + // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and + // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost + // on the EventLoop thread and increase throughput in these scenarios. + if (pollCount++ <= POLL_YIELD_COUNT) { + Thread.yield(); + } else { + LockSupport.park(); + } + checkInterrupted(); + } + return item; + } + + @Nullable + private T pollAndPark(final long timeout, final TimeUnit unit) throws InterruptedException { + T item; + final long originalNs = unit.toNanos(timeout); + long remainingNs = originalNs; + long beforeTimeNs = System.nanoTime(); + while ((item = spscQueue.poll()) == null) { + // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and + // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost + // on the EventLoop thread and increase throughput in these scenarios. + if (originalNs - remainingNs <= POLL_YIELD_SPIN_NS) { + Thread.yield(); + } else { + LockSupport.parkNanos(remainingNs); + } + checkInterrupted(); + final long afterTimeNs = System.nanoTime(); + final long durationNs = afterTimeNs - beforeTimeNs; + if (durationNs > remainingNs) { + return null; + } + remainingNs -= durationNs; + beforeTimeNs = afterTimeNs; + } + return item; + } + + private static void throwTooManyConsumers(Thread currentConsumer, Thread currentThread) { + throw new IllegalStateException("Only single consumer allowed. Existing consumer: " + currentConsumer + + " attempted new consumer: " + currentThread); + } + + private static void checkInterrupted() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + } + + /** + * The producer thread may produce multiple items before the consumer thread consume the events. If the consumer + * thread changes we need to make sure the new consumer thread observes production events so this object + * contains the thread to wakeup and a count of how many items are in the queue and not yet consumed. + */ + private static final class ThreadStamp { + Thread thread; + int count; + + ThreadStamp(Thread thread) { + this.thread = thread; + } + + ThreadStamp(Thread thread, int count) { + this.thread = thread; + this.count = count; + } + + @Override + public String toString() { + return "thread: " + thread + " count: " + count; + } + } + + private interface BiLongFunction { + R apply(long l, T t) throws InterruptedException; + } + } }