diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java index d35b53cead..0a1ef481b6 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterable.java @@ -27,22 +27,29 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collection; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.Queue; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedTransferQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.locks.LockSupport; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked; import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull; import static io.servicetalk.concurrent.internal.TerminalNotification.complete; import static io.servicetalk.concurrent.internal.TerminalNotification.error; +import static io.servicetalk.utils.internal.PlatformDependent.newLinkedSpscQueue; import static io.servicetalk.utils.internal.ThrowableUtils.throwException; +import static java.lang.Integer.getInteger; +import static java.lang.Long.getLong; import static java.lang.Math.min; import static java.lang.Thread.currentThread; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.NANOSECONDS; /** * As returned by {@link Publisher#toIterable(int)} and {@link Publisher#toIterable()}. @@ -50,6 +57,7 @@ * @param Type of items emitted by the {@link Publisher} from which this {@link BlockingIterable} is created. 
*/ final class PublisherAsBlockingIterable implements BlockingIterable { + private static final int MAX_OUTSTANDING_DEMAND = 128; /* Upper bound applied to queueCapacityHint in the constructor. */ final Publisher original; private final int queueCapacityHint; @@ -63,7 +71,7 @@ final class PublisherAsBlockingIterable implements BlockingIterable { throw new IllegalArgumentException("Invalid queueCapacityHint: " + queueCapacityHint + " (expected > 0)."); } // Add a sane upper bound to the capacity to reduce buffering. - this.queueCapacityHint = min(queueCapacityHint, 128); + this.queueCapacityHint = min(queueCapacityHint, MAX_OUTSTANDING_DEMAND); } @Override @@ -101,7 +109,7 @@ private static final class SubscriberAndIterator implements Subscriber, Bl SubscriberAndIterator(int queueCapacity) { requestN = queueCapacity; - data = new LinkedTransferQueue<>(); + data = new SpscBlockingQueue<>(newLinkedSpscQueue()); } @Override @@ -261,4 +269,371 @@ private T processNext() { return unwrapNullUnchecked(signal); } } + + private static final class SpscBlockingQueue implements BlockingQueue { /* BlockingQueue wrapper around the Queue passed to the constructor; the consumer waits via Thread.yield() and LockSupport park/unpark. */ + /** + * Number of times to call {@link Thread#yield()} before calling {@link LockSupport#park()}. + * {@link LockSupport#park()} can be expensive and if the producer is generating data it is likely we will see + * it without parking. + */ + private static final int POLL_YIELD_COUNT = + getInteger("io.servicetalk.concurrent.internal.blockingIterableYieldCount", 1); + /** + * Number of nanoseconds to spin on {@link Thread#yield()} before calling {@link LockSupport#parkNanos(long)}. + * {@link LockSupport#parkNanos(long)} can be expensive and if the producer is generating data it is likely + * we will see it without parking. 
+ */ + private static final long POLL_YIELD_SPIN_NS = + getLong("io.servicetalk.concurrent.internal.blockingIterableYieldNs", 1024); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater producerConsumerIndexUpdater = + AtomicLongFieldUpdater.newUpdater(SpscBlockingQueue.class, "producerConsumerIndex"); + private final Queue spscQueue; + @Nullable + private Thread consumerThread; + /** + * high 32 bits == producer index (see {@link #producerIndex(long)}) + * low 32 bits == consumer index (see {@link #consumerIndex(long)}) + * @see #combineIndexes(int, int) + */ + private volatile long producerConsumerIndex; + + SpscBlockingQueue(Queue spscQueue) { + this.spscQueue = requireNonNull(spscQueue); + } + + @Override + public boolean add(final T t) { + if (spscQueue.add(t)) { + producerSignalAdded(); /* Advance the producer index only after the item is in the queue. */ + return true; + } + return false; + } + + @Override + public boolean offer(final T t) { + if (spscQueue.offer(t)) { + producerSignalAdded(); /* Advance the producer index only after the item is in the queue. */ + return true; + } + return false; + } + + @Override + public T remove() { + final T t = spscQueue.remove(); + consumerSignalRemoved(1); + return t; + } + + @Override + public T poll() { + final T t = spscQueue.poll(); + if (t != null) { + consumerSignalRemoved(1); + } + return t; + } + + @Override + public T element() { /* Queue#element() retrieves but must not remove the head, so peek instead of poll. */ + final T t = peek(); + if (t == null) { + throw new NoSuchElementException(); + } + return t; + } + + @Override + public T peek() { + return spscQueue.peek(); + } + + @Override + public void put(final T t) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean offer(final T t, final long timeout, final TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public T take() throws InterruptedException { + return take0(this::pollAndParkIgnoreTime, 0, NANOSECONDS); /* The timeout arguments are ignored by pollAndParkIgnoreTime. */ + } + + @Override + public T poll(final long timeout, final TimeUnit unit) throws InterruptedException { + return take0(this::pollAndPark, timeout, unit); + } + + @Override + public 
int remainingCapacity() { /* Always reports Integer.MAX_VALUE. */ + return Integer.MAX_VALUE; + } + + @Override + public boolean remove(final Object o) { + if (spscQueue.remove(o)) { + consumerSignalRemoved(1); + return true; + } + return false; + } + + @Override + public boolean containsAll(final Collection c) { + return spscQueue.containsAll(c); + } + + @Override + public boolean addAll(final Collection c) { + boolean added = false; + for (T t : c) { + if (add(t)) { /* Goes through add(T) so every insertion also advances the producer index. */ + added = true; + } + } + return added; + } + + @Override + public boolean removeAll(final Collection c) { + int removed = 0; + try { + for (Object t : c) { + if (spscQueue.remove(t)) { + ++removed; + } + } + } finally { /* Record the removals counted so far even if remove(Object) threw part-way through. */ + consumerSignalRemoved(removed); + } + return removed > 0; + } + + @Override + public boolean retainAll(final Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + int removed = 0; + while (spscQueue.poll() != null) { /* Drain by polling so the number of discarded items is known. */ + ++removed; + } + consumerSignalRemoved(removed); + } + + @Override + public int size() { + return spscQueue.size(); + } + + @Override + public boolean isEmpty() { + return spscQueue.isEmpty(); + } + + @Override + public boolean contains(final Object o) { + return spscQueue.contains(o); + } + + @Override + public Iterator iterator() { + return spscQueue.iterator(); + } + + @Override + public Object[] toArray() { + return spscQueue.toArray(); + } + + @Override + public T1[] toArray(final T1[] a) { + return spscQueue.toArray(a); + } + + @Override + public int drainTo(final Collection c) { + int added = 0; + int removed = 0; + T item; + try { + while ((item = spscQueue.poll()) != null) { + ++removed; + if (c.add(item)) { /* NOTE(review): an item rejected by c.add has already been polled and is dropped - confirm intended. */ + ++added; + } + } + } finally { + consumerSignalRemoved(removed); + } + return added; + } + + @Override + public int drainTo(final Collection c, final int maxElements) { + int added = 0; + int removed = 0; + T item; + try { + while (added < maxElements && (item = spscQueue.poll()) != null) { /* Stops once maxElements items were accepted by c or the queue is empty. */ + ++removed; + if (c.add(item)) { + ++added; + } + } 
+ } finally { + consumerSignalRemoved(removed); + } + return added; + } + + @Override + public boolean equals(Object o) { + return o instanceof SpscBlockingQueue && spscQueue.equals(((SpscBlockingQueue) o).spscQueue); + } + + @Override + public int hashCode() { + return spscQueue.hashCode(); + } + + @Override + public String toString() { + return spscQueue.toString(); + } + + private void producerSignalAdded() { /* Increment the producer index; unpark the registered consumer thread if the consumer had caught up. */ + for (;;) { + final long currIndex = producerConsumerIndex; + final int producer = producerIndex(currIndex); + final int consumer = consumerIndex(currIndex); + if (producerConsumerIndexUpdater.compareAndSet(this, currIndex, + combineIndexes(producer + 1, consumer))) { + if (producer - consumer <= 0 && consumerThread != null) { + final Thread wakeThread = consumerThread; + consumerThread = null; + LockSupport.unpark(wakeThread); + } + break; + } + } + } + + private T take0(BiLongFunction taker, long timeout, TimeUnit unit) throws InterruptedException { /* Return a queued item, or register this thread and delegate the wait to taker when the indexes are equal. */ + final Thread currentThread = Thread.currentThread(); + for (;;) { + long currIndex = producerConsumerIndex; + final int producer = producerIndex(currIndex); + final int consumer = consumerIndex(currIndex); + if (producer == consumer) { + // Set consumerThread before pcIndex, to establish happens-before with producer thread. + consumerThread = currentThread; + if (producerConsumerIndexUpdater.compareAndSet(this, currIndex, + combineIndexes(producer, consumer + 1))) { + return taker.apply(timeout, unit); + } + } else { + final T item = spscQueue.poll(); + if (item != null) { + while (!producerConsumerIndexUpdater.compareAndSet(this, currIndex, + combineIndexes(producerIndex(currIndex), consumerIndex(currIndex) + 1))) { + currIndex = producerConsumerIndex; /* Rebuild from the fresh snapshot so a concurrent producer increment is not overwritten. */ + } + return item; + } + // It is possible the producer insertion is not yet visible to this thread, yield. 
+ Thread.yield(); + } + } + } + + private void consumerSignalRemoved(final int i) { /* Advance the consumer half of producerConsumerIndex by i. */ + for (;;) { /* CAS loop: re-read the combined index and retry until the update succeeds. */ + final long currIndex = producerConsumerIndex; + final int producer = producerIndex(currIndex); + final int consumer = consumerIndex(currIndex); + if (producerConsumerIndexUpdater.compareAndSet(this, currIndex, + combineIndexes(producer, consumer + i))) { + break; + } + } + } + + private T pollAndParkIgnoreTime(@SuppressWarnings("unused") final long timeout, + @SuppressWarnings("unused") final TimeUnit unit) throws InterruptedException { + T item; + int yieldCount = 0; /* Yields performed so far; once POLL_YIELD_COUNT is reached the loop parks instead. */ + while ((item = spscQueue.poll()) == null) { + // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and + // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost + // on the EventLoop thread and increase throughput in these scenarios. + if (yieldCount < POLL_YIELD_COUNT) { + Thread.yield(); + ++yieldCount; + } else { + LockSupport.park(); + } + checkInterrupted(); + } + return item; + } + + @Nullable + private T pollAndPark(final long timeout, final TimeUnit unit) throws InterruptedException { /* Timed variant: returns null when the timeout elapses before an item is polled. */ + T item; + final long originalNs = unit.toNanos(timeout); + long remainingNs = originalNs; + long beforeTimeNs = System.nanoTime(); + while ((item = spscQueue.poll()) == null) { + // Benchmarks show that park/unpark is expensive when producer is the EventLoop thread and + // unpark has to wakeup a thread that is parked. Yield has been shown to lower this cost + // on the EventLoop thread and increase throughput in these scenarios. 
+ if (originalNs - remainingNs <= POLL_YIELD_SPIN_NS) { + Thread.yield(); + } else { + LockSupport.parkNanos(remainingNs); + } + checkInterrupted(); + final long afterTimeNs = System.nanoTime(); + final long durationNs = afterTimeNs - beforeTimeNs; + if (durationNs > remainingNs) { + return null; /* NOTE(review): take0 advanced the consumer index before calling here and it is not rolled back on timeout - confirm intended. */ + } + remainingNs -= durationNs; + beforeTimeNs = afterTimeNs; + } + return item; + } + + private static void checkInterrupted() throws InterruptedException { + if (Thread.interrupted()) { /* Thread.interrupted() also clears the interrupt flag before the exception is thrown. */ + throw new InterruptedException(); + } + } + + private static long combineIndexes(int producer, int consumer) { /* Pack producer into the high 32 bits and consumer into the low 32 bits. */ + return ((long) producer << 32) | (consumer & 0xFFFFFFFFL); /* Mask: a negative consumer would otherwise sign-extend into the producer half. */ + } + + private static int consumerIndex(long producerConsumerIndex) { + return (int) producerConsumerIndex; + } + + private static int producerIndex(long producerConsumerIndex) { + return (int) (producerConsumerIndex >>> 32); + } + + private interface BiLongFunction { + R apply(long l, T t) throws InterruptedException; + } + } } diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterableTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterableTest.java index 894b8b309b..ba560bc917 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterableTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/PublisherAsBlockingIterableTest.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.NoSuchElementException; import java.util.Spliterator; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; import static io.servicetalk.concurrent.api.Publisher.from; @@ -33,6 +34,7 @@ import static java.util.stream.StreamSupport.stream; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; 
import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; @@ -366,6 +368,21 @@ void replenishingRequestedShouldHonourQueueContents() { assertThat("Item not expected but found.", iterator.hasNext(), is(false)); } + @Test + void spscThreads() throws ExecutionException, InterruptedException { /* Items published via publishOn(executor) must be iterated in emission order. */ + Executor executor = Executors.newCachedThreadExecutor(); + try { + int nextExpected = 0; + for (Integer integer : Publisher.range(0, 1000000) + .publishOn(executor) + .toIterable(Integer.MAX_VALUE)) { + assertThat(integer, equalTo(nextExpected++)); + } + } finally { /* Close the executor and wait for the close to finish, whether or not an assertion failed. */ + executor.closeAsync().toFuture().get(); + } + } + @Test void nullShouldBeEmitted() { Iterator iterator = Publisher.from((Void) null).toIterable().iterator(); diff --git a/servicetalk-utils-internal/src/main/java/io/servicetalk/utils/internal/PlatformDependent.java b/servicetalk-utils-internal/src/main/java/io/servicetalk/utils/internal/PlatformDependent.java index 82245e3bc8..174361dae1 100644 --- a/servicetalk-utils-internal/src/main/java/io/servicetalk/utils/internal/PlatformDependent.java +++ b/servicetalk-utils-internal/src/main/java/io/servicetalk/utils/internal/PlatformDependent.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018, 2022 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -34,16 +34,19 @@ import org.jctools.queues.MpscLinkedQueue; import org.jctools.queues.MpscUnboundedArrayQueue; import org.jctools.queues.SpscChunkedArrayQueue; +import org.jctools.queues.SpscLinkedQueue; import org.jctools.queues.SpscUnboundedArrayQueue; import org.jctools.queues.atomic.MpscGrowableAtomicArrayQueue; import org.jctools.queues.atomic.MpscLinkedAtomicQueue; import org.jctools.queues.atomic.MpscUnboundedAtomicArrayQueue; import org.jctools.queues.atomic.SpscGrowableAtomicArrayQueue; +import org.jctools.queues.atomic.SpscLinkedAtomicQueue; import org.jctools.queues.atomic.SpscUnboundedAtomicArrayQueue; import org.jctools.queues.ea.unpadded.MpscChunkedUnpaddedArrayQueue; import org.jctools.queues.ea.unpadded.MpscLinkedUnpaddedQueue; import org.jctools.queues.ea.unpadded.MpscUnboundedUnpaddedArrayQueue; import org.jctools.queues.ea.unpadded.SpscChunkedUnpaddedArrayQueue; +import org.jctools.queues.ea.unpadded.SpscLinkedUnpaddedQueue; import org.jctools.queues.ea.unpadded.SpscUnboundedUnpaddedArrayQueue; import org.jctools.util.Pow2; import org.jctools.util.UnsafeAccess; @@ -270,6 +273,16 @@ public static Queue newUnboundedSpscQueue(final int initialCapacity) { return Queues.newUnboundedSpscQueue(initialCapacity); } + /** + * Create a new unbounded {@link Queue} that uses a linked data structure which is safe to use for single producer + * (one thread!) and a single consumer (one thread!). + * @param Type of items stored in the queue. + * @return A new unbounded SPSC {@link Queue}. 
+ */ + public static Queue newLinkedSpscQueue() { + return Queues.newLinkedSpscQueue(); /* Implementation choice is delegated to the nested Queues holder. */ + } + private static final class Queues { private static final boolean USE_UNSAFE_QUEUES; private static final boolean USE_UNPADDED_QUEUES; @@ -367,5 +380,13 @@ static Queue newUnboundedSpscQueue(final int initialCapacity) { new SpscUnboundedArrayQueue<>(initialCapacity) : new SpscUnboundedAtomicArrayQueue<>(initialCapacity); } + + static Queue newLinkedSpscQueue() { /* With USE_UNSAFE_QUEUES: unpadded or padded linked queue per USE_UNPADDED_QUEUES; otherwise the atomic linked queue. */ + return USE_UNSAFE_QUEUES ? + USE_UNPADDED_QUEUES ? + new SpscLinkedUnpaddedQueue<>() : + new SpscLinkedQueue<>() + : new SpscLinkedAtomicQueue<>(); + } } }