diff --git a/vm/devices/vmbus/vmbus_server/src/channels.rs b/vm/devices/vmbus/vmbus_server/src/channels.rs index fc8bd075ab..3e353fa5c2 100644 --- a/vm/devices/vmbus/vmbus_server/src/channels.rs +++ b/vm/devices/vmbus/vmbus_server/src/channels.rs @@ -1920,19 +1920,7 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> { modify_sent: false, } => { if self.are_channels_reset(matches!(next_action, ConnectionAction::Reset)) { - self.inner.state = ConnectionState::Disconnecting { - next_action, - modify_sent: true, - }; - - // Reset server state and disconnect the relay if there is one. - self.notifier - .modify_connection(ModifyConnectionRequest { - monitor_page: Update::Reset, - interrupt_page: Update::Reset, - ..Default::default() - }) - .expect("resetting state should not fail"); + self.notify_disconnect(next_action); } } ConnectionState::Disconnecting { @@ -1944,6 +1932,25 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> { } } + /// Informs the notifier to reset the connection state when disconnecting. + fn notify_disconnect(&mut self, next_action: ConnectionAction) { + // Assert this on debug only because it is an expensive check if there are many channels. + debug_assert!(self.are_channels_reset(matches!(next_action, ConnectionAction::Reset))); + self.inner.state = ConnectionState::Disconnecting { + next_action, + modify_sent: true, + }; + + // Reset server state and disconnect the relay if there is one. + self.notifier + .modify_connection(ModifyConnectionRequest { + monitor_page: Update::Reset, + interrupt_page: Update::Reset, + ..Default::default() + }) + .expect("resetting state should not fail"); + } + /// If true, the server is mid-reset and cannot take certain actions such /// as handling synic messages or saving state. fn is_resetting(&self) -> bool { @@ -2421,7 +2428,7 @@ impl<'a, N: 'a + Notifier> ServerWithNotifier<'a, N> { ConnectionState::Connected { .. } => { if self.are_channels_reset(vm_reset) { - self.inner.state = ConnectionState::Disconnected; + self.notify_disconnect(new_action); } else { self.inner.state = ConnectionState::Disconnecting { next_action: new_action, @@ -3905,8 +3912,7 @@ mod tests { #[test] fn test_version_negotiation_feature_flags() { - let (mut notifier, _recv) = TestNotifier::new(); - let mut server = Server::new(Vtl::Vtl0, MESSAGE_CONNECTION_ID, 0); + let mut env = TestEnv::new(); // Test with no feature flags. let mut target_info = TargetInfo::new() @@ -3914,14 +3920,17 @@ mod tests { .with_vtl(0) .with_feature_flags(FeatureFlags::new().into()); test_initiate_contact( - &mut server, - &mut notifier, + &mut env.server, + &mut env.notifier, Version::Copper as u32, target_info.into(), true, 0, ); + env.c().handle_unload(); + env.complete_reset(); + env.notifier.messages.clear(); // Request supported feature flags. target_info.set_feature_flags( FeatureFlags::new() @@ -3929,8 +3938,8 @@ mod tests { .into(), ); test_initiate_contact( - &mut server, - &mut notifier, + &mut env.server, + &mut env.notifier, Version::Copper as u32, target_info.into(), true, @@ -3939,14 +3948,17 @@ mod tests { .into(), ); + env.c().handle_unload(); + env.complete_reset(); + env.notifier.messages.clear(); // Request unsupported feature flags. This will succeed and report back the supported ones. target_info.set_feature_flags( u32::from(FeatureFlags::new().with_guest_specified_signal_parameters(true)) | 0xf0000000, ); test_initiate_contact( - &mut server, - &mut notifier, + &mut env.server, + &mut env.notifier, Version::Copper as u32, target_info.into(), true, @@ -3955,11 +3967,14 @@ mod tests { .into(), ); + env.c().handle_unload(); + env.complete_reset(); + env.notifier.messages.clear(); // Verify client ID feature flag. target_info.set_feature_flags(FeatureFlags::new().with_client_id(true).into()); test_initiate_contact( - &mut server, - &mut notifier, + &mut env.server, + &mut env.notifier, Version::Copper as u32, target_info.into(), true, @@ -4597,9 +4612,7 @@ mod tests { self.server.with_notifier(&mut self.notifier) } - // Completes a reset operation if the server send a modify request as part of it. This - // shouldn't be called if the server was not connected or had no open channels or gpadls - // during the reset. + // Completes a reset operation if the server sends a modify request as part of it. fn complete_reset(&mut self) { let _ = self.next_action(); self.c() @@ -5053,6 +5066,7 @@ mod tests { env.c().reset(); // We have to "complete" the connection to let the reset go through. env.complete_connect(); + env.complete_reset(); env.notifier.check_reset(); env.c().restore(state).unwrap(); @@ -5115,6 +5129,7 @@ mod tests { let state = env.server.save(); env.c().reset(); + env.complete_reset(); env.notifier.check_reset(); env.c().restore(state).unwrap(); @@ -5495,6 +5510,7 @@ mod tests { // Reserved channels and gpadls should stay open across unloads env.c().handle_unload(); + env.complete_reset(); // Closing while disconnected should work env.close_reserved(2, 2, SINT.into()); @@ -5553,6 +5569,7 @@ mod tests { env.c().open_complete(offer_id1, 0); env.c().handle_unload(); + env.complete_reset(); // Reset while disconnected should cleanup reserved channels // and complete disconnect automatically @@ -5575,6 +5592,7 @@ mod tests { env.c().open_complete(offer_id2, 0); env.c().handle_unload(); + env.complete_reset(); env.close_reserved(2, 2, SINT.into()); env.c().close_complete(offer_id2); @@ -5912,4 +5930,138 @@ mod tests { } ); } + + #[test] + fn test_disconnect() { + let mut env = TestEnv::new(); + let _offer_id1 = env.offer(1); + let _offer_id2 = env.offer(2); + let _offer_id3 = env.offer(3); + + env.connect(Version::Win10, FeatureFlags::new()); + env.c().handle_request_offers().unwrap(); + + // Send unload message with all channels already closed. + env.c().handle_unload(); + + // Check that modify_connection was invoked on the notifier. + let req = env.notifier.next_action(); + assert_eq!( + req, + ModifyConnectionRequest { + monitor_page: Update::Reset, + interrupt_page: Update::Reset, + ..Default::default() + } + ); + + env.notifier.messages.clear(); + env.c().complete_disconnect(); + env.notifier + .check_message(OutgoingMessage::new(&protocol::UnloadComplete {})); + } + + #[test] + fn test_disconnect_open_channels() { + let mut env = TestEnv::new(); + let offer_id1 = env.offer(1); + let offer_id2 = env.offer(2); + let _offer_id3 = env.offer(3); + + env.connect(Version::Win10, FeatureFlags::new()); + env.c().handle_request_offers().unwrap(); + + // Open two channels. + env.open(1); + env.open(2); + + env.c().open_complete(offer_id1, 0); + env.c().open_complete(offer_id2, 0); + + // Send unload message with channels still open. + env.c().handle_unload(); + + assert!(env.notifier.modify_requests.is_empty()); + + // Unload will close the channels, so complete that operation. + env.c().close_complete(offer_id1); + env.c().close_complete(offer_id2); + + // Modify connection will be invoked once all channels are closed. + let req = env.notifier.next_action(); + assert_eq!( + req, + ModifyConnectionRequest { + monitor_page: Update::Reset, + interrupt_page: Update::Reset, + ..Default::default() + } + ); + + env.notifier.messages.clear(); + env.c().complete_disconnect(); + env.notifier + .check_message(OutgoingMessage::new(&protocol::UnloadComplete {})); + } + + #[test] + fn test_reinitiate_contact() { + let mut env = TestEnv::new(); + let _offer_id1 = env.offer(1); + let _offer_id2 = env.offer(2); + let _offer_id3 = env.offer(3); + + env.connect(Version::Win10, FeatureFlags::new()); + env.c().handle_request_offers().unwrap(); + env.notifier.messages.clear(); + + // Send a new InitiateContact message to force a disconnect without using reload. + let result = env.c().handle_synic_message(in_msg_ex( + protocol::MessageType::INITIATE_CONTACT, + protocol::InitiateContact { + version_requested: Version::Win10 as u32, + interrupt_page_or_target_info: TargetInfo::new().with_sint(SINT).with_vtl(0).into(), + child_to_parent_monitor_page_gpa: 0x123f000, + parent_to_child_monitor_page_gpa: 0x321f000, + ..FromZeros::new_zeroed() + }, + false, + false, + )); + assert!(result.is_ok()); + + // We will first receive a request indicating the forced disconnect. + let req = env.notifier.next_action(); + assert_eq!( + req, + ModifyConnectionRequest { + monitor_page: Update::Reset, + interrupt_page: Update::Reset, + ..Default::default() + } + ); + + env.c().complete_disconnect(); + + // No UnloadComplete is sent in this case since Unload was not sent. + assert!(env.notifier.messages.is_empty()); + + // Now we receive the request for the new connection. + let req = env.notifier.next_action(); + assert_eq!( + req, + ModifyConnectionRequest { + version: Some(Version::Win10 as u32), + monitor_page: Update::Set(MonitorPageGpas { + child_to_parent: 0x123f000, + parent_to_child: 0x321f000, + }), + interrupt_page: Update::Reset, + target_message_vp: Some(0), + ..Default::default() + } + ); + + env.complete_connect(); + } }