diff --git a/crates/agent/src/agent.rs b/crates/agent/src/agent.rs index 62a26f5b10672e3d1367d0fb7b085602a049df47..37dee2d97f44f7290ad9a084fccb3fc226f6de52 100644 --- a/crates/agent/src/agent.rs +++ b/crates/agent/src/agent.rs @@ -727,7 +727,7 @@ impl NativeAgent { fn handle_models_updated_event( &mut self, _registry: Entity, - _event: &language_model::Event, + event: &language_model::Event, cx: &mut Context, ) { self.models.refresh_list(cx); @@ -744,7 +744,13 @@ impl NativeAgent { thread.set_model(model, cx); cx.notify(); } - thread.set_summarization_model(summarization_model.clone(), cx); + if let Some(model) = summarization_model.clone() { + if thread.summarization_model().is_none() + || matches!(event, language_model::Event::ThreadSummaryModelChanged) + { + thread.set_summarization_model(Some(model), cx); + } + } }); } } @@ -2456,6 +2462,61 @@ mod internal_tests { }); } + #[gpui::test] + async fn test_summarization_model_survives_transient_registry_clearing( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({ "a": {} })).await; + let project = Project::test(fs.clone(), [], cx).await; + + let thread_store = cx.new(|cx| ThreadStore::new(cx)); + let agent = + cx.update(|cx| NativeAgent::new(thread_store, Templates::new(), None, fs.clone(), cx)); + let connection = Rc::new(NativeAgentConnection(agent.clone())); + + let acp_thread = cx + .update(|cx| { + connection.clone().new_session( + project.clone(), + PathList::new(&[Path::new("/a")]), + cx, + ) + }) + .await + .unwrap(); + let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone()); + + let thread = agent.read_with(cx, |agent, _| { + agent.sessions.get(&session_id).unwrap().thread.clone() + }); + + thread.read_with(cx, |thread, _| { + assert!( + thread.summarization_model().is_some(), + "session should have a summarization model from the test registry" + ); + }); + + // Simulate what happens during a provider blip: + // update_active_language_model_from_settings calls set_default_model(None) + // when it can't resolve the model, clearing all fallbacks. + cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model(None, cx); + }); + }); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!( + thread.summarization_model().is_some(), + "summarization model should survive a transient default model clearing" + ); + }); + } + #[gpui::test] async fn test_loaded_thread_preserves_thinking_enabled(cx: &mut TestAppContext) { init_test(cx);