diff --git a/crates/agent2/src/agent.rs b/crates/agent2/src/agent.rs index 5b5dbff5892c66b0aa2008514f7457eebf0b50aa..be0bac047f61509cdf2d103f144d40f50d86c109 100644 --- a/crates/agent2/src/agent.rs +++ b/crates/agent2/src/agent.rs @@ -139,7 +139,6 @@ impl LanguageModels { &self, model_id: &acp_thread::AgentModelId, ) -> Option> { - dbg!(&self.models.len()); self.models.get(model_id).cloned() } @@ -277,6 +276,7 @@ impl NativeAgent { let thread_database = self.thread_database.clone(); session.save_task = cx.spawn(async move |this, cx| { cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await; + let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await; thread_database.save_thread(id, db_thread).await?; this.update(cx, |this, cx| this.reload_history(cx))?; @@ -527,7 +527,7 @@ impl NativeAgent { if thread.model().is_none() && let Some(model) = default_model.clone() { - thread.set_model(model); + thread.set_model(model, cx); cx.notify(); } let summarization_model = registry diff --git a/crates/agent2/src/tests/mod.rs b/crates/agent2/src/tests/mod.rs index 2678a511261c13a6cb66ad259a620f1c4b816282..9aac27dcd613315eb1a0d3ded76cb66d2d1302c0 100644 --- a/crates/agent2/src/tests/mod.rs +++ b/crates/agent2/src/tests/mod.rs @@ -1554,6 +1554,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { action_log, templates, Some(model.clone()), + None, cx, ) }); diff --git a/crates/agent2/src/thread.rs b/crates/agent2/src/thread.rs index 52430d67b9118660c7b268fd9c9acc21a3e4b4cf..93a6fad23a67795801f74901085082177f6fc32b 100644 --- a/crates/agent2/src/thread.rs +++ b/crates/agent2/src/thread.rs @@ -495,6 +495,7 @@ pub struct Thread { project_context: Rc>, templates: Arc, model: Option>, + summarization_model: Option>, project: Entity, action_log: Entity, } @@ -508,6 +509,7 @@ impl Thread { action_log: Entity, templates: Arc, model: Option>, + summarization_model: Option>, cx: &mut Context, ) -> Self { let profile_id = AgentSettings::get_global(cx).default_profile.clone(); @@ -561,7 +563,7 @@ impl Thread { context_server_registry, action_log, Templates::new(), - model, + Some(model), None, cx, ) @@ -604,7 +606,7 @@ impl Thread { profile_id, project_context, templates, - model, + model: Some(model), summarization_model, project, action_log, @@ -622,9 +624,9 @@ impl Thread { initial_project_snapshot: None, cumulative_token_usage: self.cumulative_token_usage.clone(), request_token_usage: self.request_token_usage.clone(), - model: Some(DbLanguageModel { - provider: self.model.provider_id().to_string(), - model: self.model.name().0.to_string(), + model: self.model.as_ref().map(|model| DbLanguageModel { + provider: model.provider_id().to_string(), + model: model.name().0.to_string(), }), completion_mode: Some(self.completion_mode.into()), profile: Some(self.profile_id.clone()), @@ -850,8 +852,18 @@ impl Thread { self.model.as_ref() } - pub fn set_model(&mut self, model: Arc) { + pub fn set_model(&mut self, model: Arc, cx: &mut Context) { self.model = Some(model); + cx.notify() + } + + pub fn set_summarization_model( + &mut self, + model: Option>, + cx: &mut Context, + ) { + self.summarization_model = model; + cx.notify() } pub fn completion_mode(&self) -> CompletionMode { @@ -931,7 +943,7 @@ impl Thread { id: UserMessageId, content: impl IntoIterator, cx: &mut Context, - ) -> mpsc::UnboundedReceiver> + ) -> Result>> where T: Into, { @@ -951,10 +963,13 @@ impl Thread { self.run_turn(cx) } - fn run_turn(&mut self, cx: &mut Context) -> mpsc::UnboundedReceiver> { + fn run_turn( + &mut self, + cx: &mut Context, + ) -> Result>> { self.cancel(cx); - let model = self.model.clone(); + let model = self.model.clone().context("No language model configured")?; let (events_tx, events_rx) = mpsc::unbounded::>(); let event_stream = ThreadEventStream(events_tx); let message_ix = self.messages.len().saturating_sub(1); @@ -1145,6 +1160,7 @@ impl Thread { }); self.title = ThreadTitle::Pending(task); + cx.notify() } pub fn build_system_message(&self) -> LanguageModelRequestMessage { diff --git a/crates/agent2/src/tools/edit_file_tool.rs b/crates/agent2/src/tools/edit_file_tool.rs index f540349f82e42533dba6e608e1ca3ec0721358d5..756698bf3fc70f2f23315f3ee4d032addf0c4a7f 100644 --- a/crates/agent2/src/tools/edit_file_tool.rs +++ b/crates/agent2/src/tools/edit_file_tool.rs @@ -257,8 +257,10 @@ impl AgentTool for EditFileTool { let (request, model, action_log) = self.thread.update(cx, |thread, cx| { let request = thread.build_completion_request(CompletionIntent::ToolResults, cx); - (request, thread.model().clone(), thread.action_log().clone()) + (request, thread.model().cloned(), thread.action_log().clone()) })?; + let request = request?; + let model = model.context("No language model configured")?; let edit_format = EditFormat::from_model(model.clone())?; let edit_agent = EditAgent::new( diff --git a/crates/agent_ui/src/agent_panel.rs b/crates/agent_ui/src/agent_panel.rs index 24599ab62184c1b985b6fea7b496f056b1bafaa1..20e6206fa2add33487a8a5ea67ceaee7e2709157 100644 --- a/crates/agent_ui/src/agent_panel.rs +++ b/crates/agent_ui/src/agent_panel.rs @@ -1697,13 +1697,13 @@ impl AgentPanel { window.dispatch_action(NewTextThread.boxed_clone(), cx); } AgentType::NativeAgent => { - self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx) } AgentType::Gemini => { - self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx) } AgentType::ClaudeCode => { - self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx) + self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx) } } }