diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 4bfd0b25..3a8341e3 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -20,6 +20,7 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; @@ -35,6 +36,7 @@ public class ParallelAgent extends BaseAgent { private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class); + private final Scheduler scheduler; /** * Constructor for ParallelAgent. @@ -44,24 +46,34 @@ public class ParallelAgent extends BaseAgent { * @param subAgents The list of sub-agents to run in parallel. * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. + * @param scheduler The scheduler to use for parallel execution. */ private ParallelAgent( String name, String description, List subAgents, List beforeAgentCallback, - List afterAgentCallback) { + List afterAgentCallback, + Scheduler scheduler) { super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + this.scheduler = scheduler; } /** Builder for {@link ParallelAgent}. */ public static class Builder extends BaseAgent.Builder { + private Scheduler scheduler = io.reactivex.rxjava3.schedulers.Schedulers.io(); + + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + @Override public ParallelAgent build() { return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler); } } @@ -131,9 +143,10 @@ protected Flowable runAsyncImpl(InvocationContext invocationContext) { List> agentFlowables = new ArrayList<>(); for (BaseAgent subAgent : currentSubAgents) { - agentFlowables.add(subAgent.runAsync(invocationContext)); + agentFlowables.add(subAgent.runAsync(invocationContext).subscribeOn(scheduler)); } - return Flowable.merge(agentFlowables); + return Flowable.merge(agentFlowables) + .takeUntil((Event event) -> event.actions().escalate().orElse(false)); } /** diff --git a/core/src/test/java/com/google/adk/agents/ParallelAgentEscalationTest.java b/core/src/test/java/com/google/adk/agents/ParallelAgentEscalationTest.java new file mode 100644 index 00000000..9a3d0e88 --- /dev/null +++ b/core/src/test/java/com/google/adk/agents/ParallelAgentEscalationTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.agents; + +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.adk.events.Event; +import com.google.adk.events.EventActions; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Scheduler; +import io.reactivex.rxjava3.schedulers.TestScheduler; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class ParallelAgentEscalationTest { + + static class EscalatingAgent extends BaseAgent { + private final long delayMillis; + private final Scheduler scheduler; + + private EscalatingAgent(String name, long delayMillis, Scheduler scheduler) { + super(name, "Escalating Agent", ImmutableList.of(), null, null); + this.delayMillis = delayMillis; + this.scheduler = scheduler; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + Flowable event = + Flowable.fromCallable( + () -> + Event.builder() + .author(name()) + .branch(invocationContext.branch()) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Escalating!"))) + .actions(EventActions.builder().escalate(true).build()) + .build()); + + if (delayMillis > 0) { + return event.delay(delayMillis, MILLISECONDS, scheduler); + } + return event; + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + static class SlowAgent extends BaseAgent { + private final long delayMillis; + private final Scheduler scheduler; + + private SlowAgent(String name, long delayMillis, Scheduler scheduler) { + super(name, "Slow Agent", ImmutableList.of(), null, null); + this.delayMillis = delayMillis; + this.scheduler = scheduler; + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + Flowable event = + Flowable.fromCallable( + () -> + Event.builder() + .author(name()) + .branch(invocationContext.branch()) + .invocationId(invocationContext.invocationId()) + .content(Content.fromParts(Part.fromText("Finished"))) + .build()); + + if (delayMillis > 0) { + return event.delay(delayMillis, MILLISECONDS, scheduler); + } + return event; + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + throw new UnsupportedOperationException("Not implemented"); + } + } + + @Test + public void runAsync_escalationEvent_shortCircuitsOtherAgents() { + TestScheduler testScheduler = new TestScheduler(); + + EscalatingAgent agent1 = new EscalatingAgent("agent1", 100, testScheduler); + SlowAgent agent2 = new SlowAgent("agent2", 500, testScheduler); + + ParallelAgent parallelAgent = + ParallelAgent.builder() + .name("parallel_agent") + .subAgents(agent1, agent2) + .scheduler(testScheduler) + .build(); + + InvocationContext invocationContext = createInvocationContext(parallelAgent); + + var subscriber = parallelAgent.runAsync(invocationContext).test(); + + testScheduler.advanceTimeBy(200, MILLISECONDS); + + subscriber.assertValueCount(1); + Event event = subscriber.values().get(0); + assertThat(event.author()).isEqualTo("agent1"); + assertThat(event.actions().escalate()).hasValue(true); + + subscriber.assertComplete(); + testScheduler.advanceTimeBy(1000, MILLISECONDS); + subscriber.assertValueCount(1); + } +}