diff --git a/include/proxy-wasm/context.h b/include/proxy-wasm/context.h index 12937041f..7f675525c 100644 --- a/include/proxy-wasm/context.h +++ b/include/proxy-wasm/context.h @@ -397,6 +397,13 @@ class ContextBase : public RootInterface, bool destroyed_ = false; bool stream_failed_ = false; // Set true after failStream is called in case of VM failure. + // If true, convertVmCallResultToFilterHeadersStatus() propagates + // FilterHeadersStatus::StopIteration unmodified to callers. If false, it + // translates FilterHeaderStatus::StopIteration to + // FilterHeadersStatus::StopAllIterationAndWatermark, which is the default + // behavior for v0.2.* of the Proxy-Wasm ABI. + bool allow_on_headers_stop_iteration_ = false; + private: // helper functions FilterHeadersStatus convertVmCallResultToFilterHeadersStatus(uint64_t result); diff --git a/src/context.cc b/src/context.cc index 5353a52a5..d8dbc9a2a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -493,10 +493,12 @@ FilterHeadersStatus ContextBase::convertVmCallResultToFilterHeadersStatus(uint64 result > static_cast(FilterHeadersStatus::StopAllIterationAndWatermark)) { return FilterHeadersStatus::StopAllIterationAndWatermark; } - if (result == static_cast(FilterHeadersStatus::StopIteration)) { - // Always convert StopIteration (pause processing headers, but continue processing body) - // to StopAllIterationAndWatermark (pause all processing), since the former breaks all - // assumptions about HTTP processing. + if (result == static_cast(FilterHeadersStatus::StopIteration) && + !allow_on_headers_stop_iteration_) { + // Default behavior for Proxy-Wasm 0.2.* ABI is to translate StopIteration + // (pause processing headers, but continue processing body) to + // StopAllIterationAndWatermark (pause all processing), as described in + // https://github.com/proxy-wasm/proxy-wasm-cpp-host/issues/143. return FilterHeadersStatus::StopAllIterationAndWatermark; } return static_cast(result); diff --git a/test/BUILD b/test/BUILD index 61973ce17..73787e4b0 100644 --- a/test/BUILD +++ b/test/BUILD @@ -132,6 +132,21 @@ cc_test( ], ) +cc_test( + name = "stop_iteration_test", + srcs = ["stop_iteration_test.cc"], + data = [ + "//test/test_data:stop_iteration.wasm", + ], + linkstatic = 1, + deps = [ + ":utility_lib", + "//:lib", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "security_test", srcs = ["security_test.cc"], diff --git a/test/stop_iteration_test.cc b/test/stop_iteration_test.cc new file mode 100644 index 000000000..9ff443f24 --- /dev/null +++ b/test/stop_iteration_test.cc @@ -0,0 +1,81 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "include/proxy-wasm/wasm.h" +#include "test/utility.h" + +namespace proxy_wasm { + +INSTANTIATE_TEST_SUITE_P(WasmEngines, TestVm, testing::ValuesIn(getWasmEngines()), + [](const testing::TestParamInfo &info) { + return info.param; + }); + +// TestVm is parameterized for each engine and creates a VM on construction. +TEST_P(TestVm, AllowOnHeadersStopIteration) { + // Read the wasm source. + auto source = readTestWasmFile("stop_iteration.wasm"); + ASSERT_FALSE(source.empty()); + + // Create a WasmBase and load the plugin. + auto wasm = std::make_shared(std::move(vm_)); + ASSERT_TRUE(wasm->load(source, /*allow_precompiled=*/false)); + ASSERT_TRUE(wasm->initialize()); + + // Create a plugin. + const auto plugin = std::make_shared( + /*name=*/"test", /*root_id=*/"", /*vm_id=*/"", + /*engine=*/wasm->wasm_vm()->getEngineName(), /*plugin_config=*/"", + /*fail_open=*/false, /*key=*/""); + + // Create root context, call onStart() and onConfigure() + ContextBase *root_context = wasm->start(plugin); + ASSERT_TRUE(root_context != nullptr); + ASSERT_TRUE(wasm->configure(root_context, plugin)); + + auto wasm_handle = std::make_shared(wasm); + auto plugin_handle = std::make_shared(wasm_handle, plugin); + + // By default, stream context onRequestHeaders and onResponseHeaders + // translates FilterHeadersStatus::StopIteration to + // FilterHeadersStatus::StopAllIterationAndWatermark. + { + auto stream_context = TestContext(wasm.get(), root_context->id(), plugin_handle); + stream_context.onCreate(); + EXPECT_EQ(stream_context.onRequestHeaders(/*headers=*/0, /*end_of_stream=*/false), + FilterHeadersStatus::StopAllIterationAndWatermark); + EXPECT_EQ(stream_context.onResponseHeaders(/*headers=*/0, /*end_of_stream=*/false), + FilterHeadersStatus::StopAllIterationAndWatermark); + stream_context.onDone(); + stream_context.onDelete(); + } + ASSERT_FALSE(wasm->isFailed()); + + // Create a stream context that propagates FilterHeadersStatus::StopIteration. + { + auto stream_context = TestContext(wasm.get(), root_context->id(), plugin_handle); + stream_context.set_allow_on_headers_stop_iteration(true); + stream_context.onCreate(); + EXPECT_EQ(stream_context.onRequestHeaders(/*headers=*/0, /*end_of_stream=*/false), + FilterHeadersStatus::StopIteration); + EXPECT_EQ(stream_context.onResponseHeaders(/*headers=*/0, /*end_of_stream=*/false), + FilterHeadersStatus::StopIteration); + stream_context.onDone(); + stream_context.onDelete(); + } + ASSERT_FALSE(wasm->isFailed()); +} + +} // namespace proxy_wasm diff --git a/test/test_data/BUILD b/test/test_data/BUILD index bd70b8eb9..e5ecd439e 100644 --- a/test/test_data/BUILD +++ b/test/test_data/BUILD @@ -89,3 +89,8 @@ proxy_wasm_cc_binary( name = "http_logging.wasm", srcs = ["http_logging.cc"], ) + +proxy_wasm_cc_binary( + name = "stop_iteration.wasm", + srcs = ["stop_iteration.cc"], +) diff --git a/test/test_data/stop_iteration.cc b/test/test_data/stop_iteration.cc new file mode 100644 index 000000000..55594285f --- /dev/null +++ b/test/test_data/stop_iteration.cc @@ -0,0 +1,31 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "proxy_wasm_intrinsics.h" + +class StopIterationContext : public Context { +public: + explicit StopIterationContext(uint32_t id, RootContext *root) : Context(id, root) {} + + FilterHeadersStatus onRequestHeaders(uint32_t headers, bool end_of_stream) override { + return FilterHeadersStatus::StopIteration; + } + + FilterHeadersStatus onResponseHeaders(uint32_t headers, bool end_of_stream) override { + return FilterHeadersStatus::StopIteration; + } +}; + +static RegisterContextFactory register_StaticContext(CONTEXT_FACTORY(StopIterationContext), + ROOT_FACTORY(RootContext)); diff --git a/test/utility.h b/test/utility.h index 27b3b0493..ccd2a59b6 100644 --- a/test/utility.h +++ b/test/utility.h @@ -133,6 +133,8 @@ class TestContext : public ContextBase { .count(); } + void set_allow_on_headers_stop_iteration(bool allow) { allow_on_headers_stop_iteration_ = allow; } + private: std::string log_; static std::string global_log_;