@@ -47,6 +47,12 @@ pub enum ConnectError {
4747 #[ error( "expected path header" ) ]
4848 WrongPath ,
4949
50+ #[ error( "invalid protocol header" ) ]
51+ InvalidProtocol ,
52+
53+ #[ error( "structured field error: {0}" ) ]
54+ StructuredFieldError ( Arc < sfv:: Error > ) ,
55+
5056 #[ error( "non-200 status: {0:?}" ) ]
5157 ErrorStatus ( http:: StatusCode ) ,
5258
@@ -60,12 +66,45 @@ impl From<std::io::Error> for ConnectError {
6066 }
6167}
6268
63- #[ derive( Debug ) ]
69+ impl From < sfv:: Error > for ConnectError {
70+ fn from ( err : sfv:: Error ) -> Self {
71+ ConnectError :: StructuredFieldError ( Arc :: new ( err) )
72+ }
73+ }
74+
75+ /// A CONNECT request to initiate a WebTransport session.
76+ #[ non_exhaustive]
77+ #[ derive( Debug , Clone ) ]
6478pub struct ConnectRequest {
79+ /// The URL to connect to.
6580 pub url : Url ,
81+
82+ /// The subprotocols requested (if any).
83+ pub protocols : Vec < String > ,
6684}
6785
6886impl ConnectRequest {
87+ pub fn new ( url : impl Into < Url > ) -> Self {
88+ Self {
89+ url : url. into ( ) ,
90+ protocols : Vec :: new ( ) ,
91+ }
92+ }
93+
94+ pub fn with_protocol ( mut self , protocol : impl Into < String > ) -> Self {
95+ self . protocols . push ( protocol. into ( ) ) ;
96+ self
97+ }
98+
99+ pub fn with_protocols (
100+ mut self ,
101+ protocols : impl IntoIterator < Item = impl Into < String > > ,
102+ ) -> Self {
103+ self . protocols
104+ . extend ( protocols. into_iter ( ) . map ( |p| p. into ( ) ) ) ;
105+ self
106+ }
107+
69108 pub fn decode < B : Buf > ( buf : & mut B ) -> Result < Self , ConnectError > {
70109 let ( typ, mut data) = Frame :: read ( buf) . map_err ( |_| ConnectError :: UnexpectedEnd ) ?;
71110 if typ != Frame :: HEADERS {
@@ -102,9 +141,16 @@ impl ConnectRequest {
102141 return Err ( ConnectError :: WrongProtocol ( protocol. map ( |s| s. to_string ( ) ) ) ) ;
103142 }
104143
144+ let protocols = headers
145+ . get ( protocol_negotiation:: AVAILABLE_NAME )
146+ . map ( protocol_negotiation:: decode_list)
147+ . transpose ( )
148+ . map_err ( |_| ConnectError :: InvalidProtocol ) ?
149+ . unwrap_or_default ( ) ;
150+
105151 let url = Url :: parse ( & format ! ( "{scheme}://{authority}{path_and_query}" ) ) ?;
106152
107- Ok ( Self { url } )
153+ Ok ( Self { url, protocols } )
108154 }
109155
110156 pub async fn read < S : AsyncRead + Unpin > ( stream : & mut S ) -> Result < Self , ConnectError > {
@@ -123,7 +169,7 @@ impl ConnectRequest {
123169 }
124170 }
125171
126- pub fn encode < B : BufMut > ( & self , buf : & mut B ) {
172+ pub fn encode < B : BufMut > ( & self , buf : & mut B ) -> Result < ( ) , ConnectError > {
127173 let mut headers = qpack:: Headers :: default ( ) ;
128174 headers. set ( ":method" , "CONNECT" ) ;
129175 headers. set ( ":scheme" , self . url . scheme ( ) ) ;
@@ -135,6 +181,11 @@ impl ConnectRequest {
135181 headers. set ( ":path" , & path_and_query) ;
136182 headers. set ( ":protocol" , "webtransport" ) ;
137183
184+ if !self . protocols . is_empty ( ) {
185+ let encoded = protocol_negotiation:: encode_list ( & self . protocols ) ?;
186+ headers. set ( protocol_negotiation:: AVAILABLE_NAME , & encoded) ;
187+ }
188+
138189 // Use a temporary buffer so we can compute the size.
139190 let mut tmp = Vec :: new ( ) ;
140191 headers. encode ( & mut tmp) ;
@@ -143,22 +194,51 @@ impl ConnectRequest {
143194 Frame :: HEADERS . encode ( buf) ;
144195 size. encode ( buf) ;
145196 buf. put_slice ( & tmp) ;
197+
198+ Ok ( ( ) )
146199 }
147200
148201 pub async fn write < S : AsyncWrite + Unpin > ( & self , stream : & mut S ) -> Result < ( ) , ConnectError > {
149202 let mut buf = BytesMut :: new ( ) ;
150- self . encode ( & mut buf) ;
203+ self . encode ( & mut buf) ? ;
151204 stream. write_all_buf ( & mut buf) . await ?;
152205 Ok ( ( ) )
153206 }
154207}
155208
156- #[ derive( Debug ) ]
209+ impl From < Url > for ConnectRequest {
210+ fn from ( url : Url ) -> Self {
211+ Self {
212+ url,
213+ protocols : Vec :: new ( ) ,
214+ }
215+ }
216+ }
217+
218+ /// A CONNECT response to accept or reject a WebTransport session.
219+ #[ non_exhaustive]
220+ #[ derive( Debug , Clone ) ]
157221pub struct ConnectResponse {
222+ /// The status code of the response.
158223 pub status : http:: status:: StatusCode ,
224+
225+ /// The subprotocol selected by the server, if any
226+ pub protocol : Option < String > ,
159227}
160228
161229impl ConnectResponse {
230+ pub fn new ( status : http:: StatusCode ) -> Self {
231+ Self {
232+ status,
233+ protocol : None ,
234+ }
235+ }
236+
237+ pub fn with_protocol ( mut self , protocol : impl Into < String > ) -> Self {
238+ self . protocol = Some ( protocol. into ( ) ) ;
239+ self
240+ }
241+
162242 pub fn decode < B : Buf > ( buf : & mut B ) -> Result < Self , ConnectError > {
163243 let ( typ, mut data) = Frame :: read ( buf) . map_err ( |_| ConnectError :: UnexpectedEnd ) ?;
164244 if typ != Frame :: HEADERS {
@@ -178,7 +258,13 @@ impl ConnectResponse {
178258 o => return Err ( ConnectError :: WrongStatus ( o) ) ,
179259 } ;
180260
181- Ok ( Self { status } )
261+ let protocol = headers
262+ . get ( protocol_negotiation:: SELECTED_NAME )
263+ . map ( protocol_negotiation:: decode_item)
264+ . transpose ( )
265+ . map_err ( |_| ConnectError :: InvalidProtocol ) ?;
266+
267+ Ok ( Self { status, protocol } )
182268 }
183269
184270 pub async fn read < S : AsyncRead + Unpin > ( stream : & mut S ) -> Result < Self , ConnectError > {
@@ -197,11 +283,16 @@ impl ConnectResponse {
197283 }
198284 }
199285
200- pub fn encode < B : BufMut > ( & self , buf : & mut B ) {
286+ pub fn encode < B : BufMut > ( & self , buf : & mut B ) -> Result < ( ) , ConnectError > {
201287 let mut headers = qpack:: Headers :: default ( ) ;
202288 headers. set ( ":status" , self . status . as_str ( ) ) ;
203289 headers. set ( "sec-webtransport-http3-draft" , "draft02" ) ;
204290
291+ if let Some ( protocol) = self . protocol . as_ref ( ) {
292+ let encoded = protocol_negotiation:: encode_item ( protocol) ?;
293+ headers. set ( protocol_negotiation:: SELECTED_NAME , & encoded) ;
294+ }
295+
205296 // Use a temporary buffer so we can compute the size.
206297 let mut tmp = Vec :: new ( ) ;
207298 headers. encode ( & mut tmp) ;
@@ -210,12 +301,82 @@ impl ConnectResponse {
210301 Frame :: HEADERS . encode ( buf) ;
211302 size. encode ( buf) ;
212303 buf. put_slice ( & tmp) ;
304+
305+ Ok ( ( ) )
213306 }
214307
215308 pub async fn write < S : AsyncWrite + Unpin > ( & self , stream : & mut S ) -> Result < ( ) , ConnectError > {
216309 let mut buf = BytesMut :: new ( ) ;
217- self . encode ( & mut buf) ;
310+ self . encode ( & mut buf) ? ;
218311 stream. write_all_buf ( & mut buf) . await ?;
219312 Ok ( ( ) )
220313 }
221314}
315+
316+ impl From < http:: StatusCode > for ConnectResponse {
317+ fn from ( status : http:: StatusCode ) -> Self {
318+ Self {
319+ status,
320+ protocol : None ,
321+ }
322+ }
323+ }
324+
325+ mod protocol_negotiation {
326+ //! WebTransport sub-protocol negotiation using RFC 8941 Structured Fields,
327+ //!
328+ //! according to [draft 14](https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14.html#section-3.3)
329+
330+ use sfv:: { Item , ItemSerializer , List , ListEntry , ListSerializer , Parser , StringRef } ;
331+
332+ use crate :: ConnectError ;
333+
334+ /// The header name for the available protocols, sent within the WebTransport Connect request.
335+ pub const AVAILABLE_NAME : & str = "wt-available-protocols" ;
336+ /// The header name for the selected protocol, sent within the WebTransport Connect response.
337+ pub const SELECTED_NAME : & str = "wt-protocol" ;
338+
339+ /// Encode a list of protocol strings as an RFC 8941 Structured Field List.
340+ pub fn encode_list ( protocols : & [ String ] ) -> Result < String , ConnectError > {
341+ let mut serializer = ListSerializer :: new ( ) ;
342+ for protocol in protocols {
343+ let s = StringRef :: from_str ( protocol) ?;
344+ let _ = serializer. bare_item ( s) ;
345+ }
346+ serializer. finish ( ) . ok_or ( ConnectError :: InvalidProtocol )
347+ }
348+
349+ /// Decode an RFC 8941 Structured Field List of strings.
350+ pub fn decode_list ( value : & str ) -> Result < Vec < String > , ConnectError > {
351+ let list = Parser :: new ( value) . parse :: < List > ( ) ?;
352+
353+ list. iter ( )
354+ . map ( |entry| match entry {
355+ ListEntry :: Item ( item) => Ok ( item
356+ . bare_item
357+ . as_string ( )
358+ . ok_or ( ConnectError :: InvalidProtocol ) ?
359+ . as_str ( )
360+ . to_string ( ) ) ,
361+ _ => Err ( ConnectError :: InvalidProtocol ) ,
362+ } )
363+ . collect ( )
364+ }
365+
366+ /// Encode a single string as an RFC 8941 Structured Field Item.
367+ pub fn encode_item ( protocol : & str ) -> Result < String , ConnectError > {
368+ let s = StringRef :: from_str ( protocol) ?;
369+ Ok ( ItemSerializer :: new ( ) . bare_item ( s) . finish ( ) )
370+ }
371+
372+ /// Decode an RFC 8941 Structured Field Item (single string).
373+ pub fn decode_item ( value : & str ) -> Result < String , ConnectError > {
374+ let item = Parser :: new ( value) . parse :: < Item > ( ) ?;
375+ Ok ( item
376+ . bare_item
377+ . as_string ( )
378+ . ok_or ( ConnectError :: InvalidProtocol ) ?
379+ . as_str ( )
380+ . to_string ( ) )
381+ }
382+ }
0 commit comments