@@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler {
8181 fd : File ,
8282 ) -> Result < Option < File > > ;
8383 fn check_device_state ( & self ) -> Result < ( ) > ;
84+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > ;
8485 #[ cfg( feature = "postcopy" ) ]
8586 fn postcopy_advice ( & self ) -> Result < File > ;
8687 #[ cfg( feature = "postcopy" ) ]
@@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut {
146147 fd : File ,
147148 ) -> Result < Option < File > > ;
148149 fn check_device_state ( & mut self ) -> Result < ( ) > ;
150+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > ;
149151 #[ cfg( feature = "postcopy" ) ]
150152 fn postcopy_advice ( & mut self ) -> Result < File > ;
151153 #[ cfg( feature = "postcopy" ) ]
@@ -289,6 +291,10 @@ impl<T: VhostUserBackendReqHandlerMut> VhostUserBackendReqHandler for Mutex<T> {
289291 self . lock ( ) . unwrap ( ) . check_device_state ( )
290292 }
291293
294+ fn get_shmem_config ( & self ) -> Result < VhostUserShMemConfig > {
295+ self . lock ( ) . unwrap ( ) . get_shmem_config ( )
296+ }
297+
292298 #[ cfg( feature = "postcopy" ) ]
293299 fn postcopy_advice ( & self ) -> Result < File > {
294300 self . lock ( ) . unwrap ( ) . postcopy_advice ( )
@@ -679,6 +685,11 @@ impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
679685 } ;
680686 self . send_reply_message ( & hdr, & msg) ?;
681687 }
688+ Ok ( FrontendReq :: GET_SHMEM_CONFIG ) => {
689+ self . check_proto_feature ( VhostUserProtocolFeatures :: SHMEM ) ?;
690+ let msg = self . backend . get_shmem_config ( ) . unwrap_or_default ( ) ;
691+ self . send_reply_message ( & hdr, & msg) ?;
692+ }
682693 #[ cfg( feature = "postcopy" ) ]
683694 Ok ( FrontendReq :: POSTCOPY_ADVISE ) => {
684695 self . check_proto_feature ( VhostUserProtocolFeatures :: PAGEFAULT ) ?;
@@ -1038,4 +1049,111 @@ mod tests {
10381049 handler. check_state ( ) . unwrap_err ( ) ;
10391050 assert ! ( handler. as_raw_fd( ) >= 0 ) ;
10401051 }
1052+
1053+ #[ test]
1054+ fn test_get_shmem_config_max_regions ( ) {
1055+ // Create a configuration with maximum number of regions (8)
1056+ let memory_sizes = [
1057+ 0x1000 , 0x2000 , 0x3000 , 0x4000 , 0x5000 , 0x6000 , 0x7000 , 0x8000 ,
1058+ ] ;
1059+ let config = VhostUserShMemConfig :: new ( 8 , & memory_sizes) ;
1060+
1061+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1062+ let mut dummy_backend = DummyBackendReqHandler :: new ( ) ;
1063+ dummy_backend. set_shmem_config ( config) ;
1064+ let backend = Arc :: new ( Mutex :: new ( dummy_backend) ) ;
1065+ let mut handler = BackendReqHandler :: new (
1066+ Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p1) ,
1067+ backend,
1068+ ) ;
1069+ handler. acked_protocol_features = VhostUserProtocolFeatures :: SHMEM . bits ( ) ;
1070+
1071+ let mut frontend_endpoint = Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p2) ;
1072+
1073+ let handle = std:: thread:: spawn ( move || {
1074+ let hdr = VhostUserMsgHeader :: new ( FrontendReq :: GET_SHMEM_CONFIG , 0 , 0 ) ;
1075+ frontend_endpoint
1076+ . send_message ( & hdr, & VhostUserEmpty , None )
1077+ . unwrap ( ) ;
1078+
1079+ let ( reply_hdr, reply_config, rfds) = frontend_endpoint
1080+ . recv_body :: < VhostUserShMemConfig > ( )
1081+ . unwrap ( ) ;
1082+ assert_eq ! ( reply_hdr. get_code( ) . unwrap( ) , FrontendReq :: GET_SHMEM_CONFIG ) ;
1083+ assert ! ( reply_hdr. is_reply( ) ) ;
1084+ assert ! ( rfds. is_none( ) ) ;
1085+ reply_config
1086+ } ) ;
1087+
1088+ handler. handle_request ( ) . unwrap ( ) ;
1089+
1090+ let reply_config = handle. join ( ) . unwrap ( ) ;
1091+ assert_eq ! ( reply_config. nregions, 8 ) ;
1092+ for i in 0 ..8 {
1093+ assert_eq ! ( reply_config. memory_sizes[ i] , ( i as u64 + 1 ) * 0x1000 ) ;
1094+ }
1095+ }
1096+
1097+ #[ test]
1098+ fn test_get_shmem_config_non_continuous_regions ( ) {
1099+ // Create a configuration with non-continuous regions
1100+ let memory_sizes = [ 0x10000 , 0 , 0x20000 , 0 , 0 , 0 , 0 , 0 ] ;
1101+ let config = VhostUserShMemConfig :: new ( 2 , & memory_sizes) ;
1102+
1103+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1104+ let mut dummy_backend = DummyBackendReqHandler :: new ( ) ;
1105+ dummy_backend. set_shmem_config ( config) ;
1106+ let backend = Arc :: new ( Mutex :: new ( dummy_backend) ) ;
1107+ let mut handler = BackendReqHandler :: new (
1108+ Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p1) ,
1109+ backend,
1110+ ) ;
1111+ handler. acked_protocol_features = VhostUserProtocolFeatures :: SHMEM . bits ( ) ;
1112+
1113+ let mut frontend_endpoint = Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p2) ;
1114+
1115+ let handle = std:: thread:: spawn ( move || {
1116+ let hdr = VhostUserMsgHeader :: new ( FrontendReq :: GET_SHMEM_CONFIG , 0 , 0 ) ;
1117+ frontend_endpoint
1118+ . send_message ( & hdr, & VhostUserEmpty , None )
1119+ . unwrap ( ) ;
1120+
1121+ let ( reply_hdr, reply_config, rfds) = frontend_endpoint
1122+ . recv_body :: < VhostUserShMemConfig > ( )
1123+ . unwrap ( ) ;
1124+ assert_eq ! ( reply_hdr. get_code( ) . unwrap( ) , FrontendReq :: GET_SHMEM_CONFIG ) ;
1125+ assert ! ( reply_hdr. is_reply( ) ) ;
1126+ assert ! ( rfds. is_none( ) ) ;
1127+ reply_config
1128+ } ) ;
1129+
1130+ handler. handle_request ( ) . unwrap ( ) ;
1131+
1132+ let reply_config = handle. join ( ) . unwrap ( ) ;
1133+ assert_eq ! ( reply_config. nregions, 2 ) ;
1134+ assert_eq ! ( reply_config. memory_sizes[ 0 ] , 0x10000 ) ;
1135+ assert_eq ! ( reply_config. memory_sizes[ 1 ] , 0 ) ;
1136+ assert_eq ! ( reply_config. memory_sizes[ 2 ] , 0x20000 ) ;
1137+ for i in 3 ..8 {
1138+ assert_eq ! ( reply_config. memory_sizes[ i] , 0 ) ;
1139+ }
1140+ }
1141+
1142+ #[ test]
1143+ fn test_get_shmem_config_feature_not_negotiated ( ) {
1144+ // Test that the request fails when SHMEM protocol feature is not negotiated
1145+ let ( p1, p2) = UnixStream :: pair ( ) . unwrap ( ) ;
1146+ let backend = Arc :: new ( Mutex :: new ( DummyBackendReqHandler :: new ( ) ) ) ;
1147+ let mut handler = BackendReqHandler :: new (
1148+ Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p1) ,
1149+ backend,
1150+ ) ;
1151+ let mut frontend_endpoint = Endpoint :: < VhostUserMsgHeader < FrontendReq > > :: from_stream ( p2) ;
1152+
1153+ std:: thread:: spawn ( move || {
1154+ let hdr = VhostUserMsgHeader :: new ( FrontendReq :: GET_SHMEM_CONFIG , 0 , 0 ) ;
1155+ let _ = frontend_endpoint. send_message ( & hdr, & VhostUserEmpty , None ) ;
1156+ } ) ;
1157+ assert ! ( handler. handle_request( ) . is_err( ) ) ;
1158+ }
10411159}
0 commit comments