From 80dedf4b9e0814c78ca6c41468497abcc8dd76ca Mon Sep 17 00:00:00 2001
From: Alejandro Martinez Ruiz <amartine@redhat.com>
Date: Tue, 22 Dec 2020 12:41:05 +0100
Subject: [PATCH 1/3] dispatcher: add on_create_child_context for a simpler way
 to create contexts

This introduces a new hook that is called when a new child context is
created from a root context. The implementer can construct and specify the
type of the new context with a single hook method.

This patch is based on work by Yaroslav Skopets <y.skopets@gmail.com>.

Signed-off-by: Alejandro Martinez Ruiz <amartine@redhat.com>
---
 src/dispatcher.rs | 61 ++++++++++++++++++++++++++++++-----------------
 src/traits.rs     |  6 +++++
 src/types.rs      |  5 ++++
 3 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/src/dispatcher.rs b/src/dispatcher.rs
index d9c9d491..b1f5d3c0 100644
--- a/src/dispatcher.rs
+++ b/src/dispatcher.rs
@@ -117,14 +117,8 @@ impl Dispatcher {
             },
             None => panic!("invalid root_context_id"),
         };
-        if self
-            .streams
-            .borrow_mut()
-            .insert(context_id, new_context)
-            .is_some()
-        {
-            panic!("duplicate context_id")
-        }
+
+        self.register_stream_context(context_id, new_context);
     }
 
     fn create_http_context(&self, context_id: u32, root_context_id: u32) {
@@ -138,35 +132,58 @@ impl Dispatcher {
             },
             None => panic!("invalid root_context_id"),
         };
+
+        self.register_http_context(context_id, new_context);
+    }
+
+    fn register_stream_context(&self, context_id: u32, stream_context: Box<dyn StreamContext>) {
+        if self
+            .streams
+            .borrow_mut()
+            .insert(context_id, stream_context)
+            .is_some()
+        {
+            panic!("duplicate context_id {}", context_id);
+        }
+    }
+
+    fn register_http_context(&self, context_id: u32, http_context: Box<dyn HttpContext>) {
         if self
             .http_streams
             .borrow_mut()
-            .insert(context_id, new_context)
+            .insert(context_id, http_context)
             .is_some()
         {
-            panic!("duplicate context_id")
+            panic!("duplicate context_id {}", context_id);
         }
     }
 
     fn on_create_context(&self, context_id: u32, root_context_id: u32) {
         if root_context_id == 0 {
             self.create_root_context(context_id);
-        } else if self.new_http_stream.get().is_some() {
-            self.create_http_context(context_id, root_context_id);
-        } else if self.new_stream.get().is_some() {
-            self.create_stream_context(context_id, root_context_id);
-        } else if let Some(root_context) = self.roots.borrow().get(&root_context_id) {
-            match root_context.get_type() {
-                Some(ContextType::HttpContext) => {
-                    self.create_http_context(context_id, root_context_id)
+            return;
+        }
+
+        if let Some(root_context) = self.roots.borrow_mut().get_mut(&root_context_id) {
+            match root_context.on_create_child_context(context_id) {
+                Some(ChildContext::HttpContext(http_context)) => {
+                    self.register_http_context(context_id, http_context);
+                }
+                Some(ChildContext::StreamContext(stream_context)) => {
+                    self.register_stream_context(context_id, stream_context);
                 }
-                Some(ContextType::StreamContext) => {
-                    self.create_stream_context(context_id, root_context_id)
+                None => match root_context.get_type() {
+                    Some(ContextType::HttpContext) => {
+                        self.create_http_context(context_id, root_context_id);
+                    }
+                    Some(ContextType::StreamContext) => {
+                        self.create_stream_context(context_id, root_context_id);
+                    }
+                    None => panic!("you must define on_create_child_context or get_type() and create_http/stream_context in your root context"),
                 }
-                None => panic!("missing ContextType on root_context"),
             }
         } else {
-            panic!("invalid root_context_id and missing constructors");
+            panic!("invalid root_context_id {}", root_context_id);
         }
     }
 
diff --git a/src/traits.rs b/src/traits.rs
index 5b7fc4be..980f8963 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -122,6 +122,12 @@ pub trait RootContext: Context {
 
     fn on_log(&mut self) {}
 
+    fn on_create_child_context(&mut self, _context_id: u32) -> Option<ChildContext> {
+        // on_create_child_context has higher priority than any other methods
+        // for creating non root contexts
+        None
+    }
+
     fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
         None
     }
diff --git a/src/types.rs b/src/types.rs
index 855a414b..79bb27f0 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -18,6 +18,11 @@ pub type NewRootContext = fn(context_id: u32) -> Box<dyn RootContext>;
 pub type NewStreamContext = fn(context_id: u32, root_context_id: u32) -> Box<dyn StreamContext>;
 pub type NewHttpContext = fn(context_id: u32, root_context_id: u32) -> Box<dyn HttpContext>;
 
+pub enum ChildContext {
+    StreamContext(Box<dyn StreamContext>),
+    HttpContext(Box<dyn HttpContext>),
+}
+
 #[repr(u32)]
 #[derive(Debug)]
 pub enum LogLevel {

From 2d6c658b5d98cc5d46f861b301698b7f90815452 Mon Sep 17 00:00:00 2001
From: Alejandro Martinez Ruiz <amartine@redhat.com>
Date: Tue, 22 Dec 2020 19:04:56 +0100
Subject: [PATCH 2/3] examples: use on_create_child_context to create non http
 or stream contexts

Signed-off-by: Alejandro Martinez Ruiz <amartine@redhat.com>
---
 examples/http_auth_random.rs | 12 +++++++++++-
 examples/http_body.rs        |  8 ++------
 examples/http_config.rs      | 10 +++-------
 examples/http_headers.rs     | 10 ++++------
 4 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/examples/http_auth_random.rs b/examples/http_auth_random.rs
index b9e747c6..ddb5fbf7 100644
--- a/examples/http_auth_random.rs
+++ b/examples/http_auth_random.rs
@@ -20,7 +20,17 @@ use std::time::Duration;
 #[no_mangle]
 pub fn _start() {
     proxy_wasm::set_log_level(LogLevel::Trace);
-    proxy_wasm::set_http_context(|_, _| -> Box<dyn HttpContext> { Box::new(HttpAuthRandom) });
+    proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> { Box::new(HttpAuthRandomRoot) });
+}
+
+struct HttpAuthRandomRoot;
+
+impl Context for HttpAuthRandomRoot {}
+
+impl RootContext for HttpAuthRandomRoot {
+    fn on_create_child_context(&mut self, _context_id: u32) -> Option<ChildContext> {
+        Some(ChildContext::HttpContext(Box::new(HttpAuthRandom)))
+    }
 }
 
 struct HttpAuthRandom;
diff --git a/examples/http_body.rs b/examples/http_body.rs
index ff884648..cf160fd0 100644
--- a/examples/http_body.rs
+++ b/examples/http_body.rs
@@ -26,12 +26,8 @@ struct HttpBodyRoot;
 impl Context for HttpBodyRoot {}
 
 impl RootContext for HttpBodyRoot {
-    fn get_type(&self) -> Option<ContextType> {
-        Some(ContextType::HttpContext)
-    }
-
-    fn create_http_context(&self, _: u32) -> Option<Box<dyn HttpContext>> {
-        Some(Box::new(HttpBody))
+    fn on_create_child_context(&mut self, _context_id: u32) -> Option<ChildContext> {
+        Some(ChildContext::HttpContext(Box::new(HttpBody)))
     }
 }
 
diff --git a/examples/http_config.rs b/examples/http_config.rs
index d912ae03..81838315 100644
--- a/examples/http_config.rs
+++ b/examples/http_config.rs
@@ -52,13 +52,9 @@ impl RootContext for HttpConfigHeaderRoot {
         true
     }
 
-    fn create_http_context(&self, _: u32) -> Option<Box<dyn HttpContext>> {
-        Some(Box::new(HttpConfigHeader {
+    fn on_create_child_context(&mut self, _context_id: u32) -> Option<ChildContext> {
+        Some(ChildContext::HttpContext(Box::new(HttpConfigHeader {
             header_content: self.header_content.clone(),
-        }))
-    }
-
-    fn get_type(&self) -> Option<ContextType> {
-        Some(ContextType::HttpContext)
+        })))
     }
 }
diff --git a/examples/http_headers.rs b/examples/http_headers.rs
index b0f1a745..d9692698 100644
--- a/examples/http_headers.rs
+++ b/examples/http_headers.rs
@@ -27,12 +27,10 @@ struct HttpHeadersRoot;
 impl Context for HttpHeadersRoot {}
 
 impl RootContext for HttpHeadersRoot {
-    fn get_type(&self) -> Option<ContextType> {
-        Some(ContextType::HttpContext)
-    }
-
-    fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
-        Some(Box::new(HttpHeaders { context_id }))
+    fn on_create_child_context(&mut self, context_id: u32) -> Option<ChildContext> {
+        Some(ChildContext::HttpContext(Box::new(HttpHeaders {
+            context_id,
+        })))
     }
 }
 

From 471fc5413b1d9a3d81ba8f61da5f63a4cc2c6bdb Mon Sep 17 00:00:00 2001
From: Alejandro Martinez Ruiz <amartine@redhat.com>
Date: Wed, 23 Dec 2020 00:21:27 +0100
Subject: [PATCH 3/3] dispatcher: make on_create_child_context the sole way to
 create non root ctxts

Having two ways to create non root contexts before introducing
on_create_child_context can cause confusion. This new method is simple
and should cover the usages of the other two in a cleaner way, so let's
remove the others in this _breaking_ change.

Signed-off-by: Alejandro Martinez Ruiz <amartine@redhat.com>
---
 src/dispatcher.rs | 60 +----------------------------------------------
 src/lib.rs        |  8 -------
 src/traits.rs     | 15 +-----------
 src/types.rs      |  9 -------
 4 files changed, 2 insertions(+), 90 deletions(-)

diff --git a/src/dispatcher.rs b/src/dispatcher.rs
index b1f5d3c0..769921d0 100644
--- a/src/dispatcher.rs
+++ b/src/dispatcher.rs
@@ -26,14 +26,6 @@ pub(crate) fn set_root_context(callback: NewRootContext) {
     DISPATCHER.with(|dispatcher| dispatcher.set_root_context(callback));
 }
 
-pub(crate) fn set_stream_context(callback: NewStreamContext) {
-    DISPATCHER.with(|dispatcher| dispatcher.set_stream_context(callback));
-}
-
-pub(crate) fn set_http_context(callback: NewHttpContext) {
-    DISPATCHER.with(|dispatcher| dispatcher.set_http_context(callback));
-}
-
 pub(crate) fn register_callout(token_id: u32) {
     DISPATCHER.with(|dispatcher| dispatcher.register_callout(token_id));
 }
@@ -46,9 +38,7 @@ impl RootContext for NoopRoot {}
 struct Dispatcher {
     new_root: Cell<Option<NewRootContext>>,
     roots: RefCell<HashMap<u32, Box<dyn RootContext>>>,
-    new_stream: Cell<Option<NewStreamContext>>,
     streams: RefCell<HashMap<u32, Box<dyn StreamContext>>>,
-    new_http_stream: Cell<Option<NewHttpContext>>,
     http_streams: RefCell<HashMap<u32, Box<dyn HttpContext>>>,
     active_id: Cell<u32>,
     callouts: RefCell<HashMap<u32, u32>>,
@@ -59,9 +49,7 @@ impl Dispatcher {
         Dispatcher {
             new_root: Cell::new(None),
             roots: RefCell::new(HashMap::new()),
-            new_stream: Cell::new(None),
             streams: RefCell::new(HashMap::new()),
-            new_http_stream: Cell::new(None),
             http_streams: RefCell::new(HashMap::new()),
             active_id: Cell::new(0),
             callouts: RefCell::new(HashMap::new()),
@@ -72,14 +60,6 @@ impl Dispatcher {
         self.new_root.set(Some(callback));
     }
 
-    fn set_stream_context(&self, callback: NewStreamContext) {
-        self.new_stream.set(Some(callback));
-    }
-
-    fn set_http_context(&self, callback: NewHttpContext) {
-        self.new_http_stream.set(Some(callback));
-    }
-
     fn register_callout(&self, token_id: u32) {
         if self
             .callouts
@@ -106,36 +86,6 @@ impl Dispatcher {
         }
     }
 
-    fn create_stream_context(&self, context_id: u32, root_context_id: u32) {
-        let new_context = match self.roots.borrow().get(&root_context_id) {
-            Some(root_context) => match self.new_stream.get() {
-                Some(f) => f(context_id, root_context_id),
-                None => match root_context.create_stream_context(context_id) {
-                    Some(stream_context) => stream_context,
-                    None => panic!("create_stream_context returned None"),
-                },
-            },
-            None => panic!("invalid root_context_id"),
-        };
-
-        self.register_stream_context(context_id, new_context);
-    }
-
-    fn create_http_context(&self, context_id: u32, root_context_id: u32) {
-        let new_context = match self.roots.borrow().get(&root_context_id) {
-            Some(root_context) => match self.new_http_stream.get() {
-                Some(f) => f(context_id, root_context_id),
-                None => match root_context.create_http_context(context_id) {
-                    Some(stream_context) => stream_context,
-                    None => panic!("create_http_context returned None"),
-                },
-            },
-            None => panic!("invalid root_context_id"),
-        };
-
-        self.register_http_context(context_id, new_context);
-    }
-
     fn register_stream_context(&self, context_id: u32, stream_context: Box<dyn StreamContext>) {
         if self
             .streams
@@ -172,15 +122,7 @@ impl Dispatcher {
                 Some(ChildContext::StreamContext(stream_context)) => {
                     self.register_stream_context(context_id, stream_context);
                 }
-                None => match root_context.get_type() {
-                    Some(ContextType::HttpContext) => {
-                        self.create_http_context(context_id, root_context_id);
-                    }
-                    Some(ContextType::StreamContext) => {
-                        self.create_stream_context(context_id, root_context_id);
-                    }
-                    None => panic!("you must define on_create_child_context or get_type() and create_http/stream_context in your root context"),
-                }
+                None => panic!("you must implement on_create_child_context in your root context"),
             }
         } else {
             panic!("invalid root_context_id {}", root_context_id);
diff --git a/src/lib.rs b/src/lib.rs
index cee98b72..ff55cbac 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -28,13 +28,5 @@ pub fn set_root_context(callback: types::NewRootContext) {
     dispatcher::set_root_context(callback);
 }
 
-pub fn set_stream_context(callback: types::NewStreamContext) {
-    dispatcher::set_stream_context(callback);
-}
-
-pub fn set_http_context(callback: types::NewHttpContext) {
-    dispatcher::set_http_context(callback);
-}
-
 #[no_mangle]
 pub extern "C" fn proxy_abi_version_0_1_0() {}
diff --git a/src/traits.rs b/src/traits.rs
index 980f8963..a02a2512 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -123,20 +123,7 @@ pub trait RootContext: Context {
     fn on_log(&mut self) {}
 
     fn on_create_child_context(&mut self, _context_id: u32) -> Option<ChildContext> {
-        // on_create_child_context has higher priority than any other methods
-        // for creating non root contexts
-        None
-    }
-
-    fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
-        None
-    }
-
-    fn create_stream_context(&self, _context_id: u32) -> Option<Box<dyn StreamContext>> {
-        None
-    }
-
-    fn get_type(&self) -> Option<ContextType> {
+        // on_create_child_context is required to create non root contexts
         None
     }
 }
diff --git a/src/types.rs b/src/types.rs
index 79bb27f0..743a9f8a 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -15,8 +15,6 @@
 use crate::traits::*;
 
 pub type NewRootContext = fn(context_id: u32) -> Box<dyn RootContext>;
-pub type NewStreamContext = fn(context_id: u32, root_context_id: u32) -> Box<dyn StreamContext>;
-pub type NewHttpContext = fn(context_id: u32, root_context_id: u32) -> Box<dyn HttpContext>;
 
 pub enum ChildContext {
     StreamContext(Box<dyn StreamContext>),
@@ -52,13 +50,6 @@ pub enum Status {
     InternalFailure = 10,
 }
 
-#[repr(u32)]
-#[derive(Debug)]
-pub enum ContextType {
-    HttpContext = 0,
-    StreamContext = 1,
-}
-
 #[repr(u32)]
 #[derive(Debug)]
 pub enum BufferType {