Generating thread title

Conrad Irwin created

Change summary

crates/agent2/src/agent.rs                |  4 +-
crates/agent2/src/tests/mod.rs            |  1 
crates/agent2/src/thread.rs               | 34 ++++++++++++++++++------
crates/agent2/src/tools/edit_file_tool.rs |  4 ++
crates/agent_ui/src/agent_panel.rs        |  6 ++--
5 files changed, 34 insertions(+), 15 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -139,7 +139,6 @@ impl LanguageModels {
         &self,
         model_id: &acp_thread::AgentModelId,
     ) -> Option<Arc<dyn LanguageModel>> {
-        dbg!(&self.models.len());
         self.models.get(model_id).cloned()
     }
 
@@ -277,6 +276,7 @@ impl NativeAgent {
         let thread_database = self.thread_database.clone();
         session.save_task = cx.spawn(async move |this, cx| {
             cx.background_executor().timer(SAVE_THREAD_DEBOUNCE).await;
+
             let db_thread = thread.update(cx, |thread, cx| thread.to_db(cx))?.await;
             thread_database.save_thread(id, db_thread).await?;
             this.update(cx, |this, cx| this.reload_history(cx))?;
@@ -527,7 +527,7 @@ impl NativeAgent {
                 if thread.model().is_none()
                     && let Some(model) = default_model.clone()
                 {
-                    thread.set_model(model);
+                    thread.set_model(model, cx);
                     cx.notify();
                 }
                 let summarization_model = registry

crates/agent2/src/tests/mod.rs 🔗

@@ -1554,6 +1554,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
             action_log,
             templates,
             Some(model.clone()),
+            None,
             cx,
         )
     });

crates/agent2/src/thread.rs 🔗

@@ -495,6 +495,7 @@ pub struct Thread {
     project_context: Rc<RefCell<ProjectContext>>,
     templates: Arc<Templates>,
     model: Option<Arc<dyn LanguageModel>>,
+    summarization_model: Option<Arc<dyn LanguageModel>>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
 }
@@ -508,6 +509,7 @@ impl Thread {
         action_log: Entity<ActionLog>,
         templates: Arc<Templates>,
         model: Option<Arc<dyn LanguageModel>>,
+        summarization_model: Option<Arc<dyn LanguageModel>>,
         cx: &mut Context<Self>,
     ) -> Self {
         let profile_id = AgentSettings::get_global(cx).default_profile.clone();
@@ -561,7 +563,7 @@ impl Thread {
             context_server_registry,
             action_log,
             Templates::new(),
-            model,
+            Some(model),
             None,
             cx,
         )
@@ -604,7 +606,7 @@ impl Thread {
             profile_id,
             project_context,
             templates,
-            model,
+            model: Some(model),
             summarization_model,
             project,
             action_log,
@@ -622,9 +624,9 @@ impl Thread {
             initial_project_snapshot: None,
             cumulative_token_usage: self.cumulative_token_usage.clone(),
             request_token_usage: self.request_token_usage.clone(),
-            model: Some(DbLanguageModel {
-                provider: self.model.provider_id().to_string(),
-                model: self.model.name().0.to_string(),
+            model: self.model.as_ref().map(|model| DbLanguageModel {
+                provider: model.provider_id().to_string(),
+                model: model.name().0.to_string(),
             }),
             completion_mode: Some(self.completion_mode.into()),
             profile: Some(self.profile_id.clone()),
@@ -850,8 +852,18 @@ impl Thread {
         self.model.as_ref()
     }
 
-    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>) {
+    pub fn set_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut Context<Self>) {
         self.model = Some(model);
+        cx.notify()
+    }
+
+    pub fn set_summarization_model(
+        &mut self,
+        model: Option<Arc<dyn LanguageModel>>,
+        cx: &mut Context<Self>,
+    ) {
+        self.summarization_model = model;
+        cx.notify()
     }
 
     pub fn completion_mode(&self) -> CompletionMode {
@@ -931,7 +943,7 @@ impl Thread {
         id: UserMessageId,
         content: impl IntoIterator<Item = T>,
         cx: &mut Context<Self>,
-    ) -> mpsc::UnboundedReceiver<Result<ThreadEvent>>
+    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>>
     where
         T: Into<UserMessageContent>,
     {
@@ -951,10 +963,13 @@ impl Thread {
         self.run_turn(cx)
     }
 
-    fn run_turn(&mut self, cx: &mut Context<Self>) -> mpsc::UnboundedReceiver<Result<ThreadEvent>> {
+    fn run_turn(
+        &mut self,
+        cx: &mut Context<Self>,
+    ) -> Result<mpsc::UnboundedReceiver<Result<ThreadEvent>>> {
         self.cancel(cx);
 
-        let model = self.model.clone();
+        let model = self.model.clone().context("No language model configured")?;
         let (events_tx, events_rx) = mpsc::unbounded::<Result<ThreadEvent>>();
         let event_stream = ThreadEventStream(events_tx);
         let message_ix = self.messages.len().saturating_sub(1);
@@ -1145,6 +1160,7 @@ impl Thread {
         });
 
         self.title = ThreadTitle::Pending(task);
+        cx.notify()
     }
 
     pub fn build_system_message(&self) -> LanguageModelRequestMessage {

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -257,8 +257,10 @@ impl AgentTool for EditFileTool {
 
             let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
                 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
-                (request, thread.model().clone(), thread.action_log().clone())
+                (request, thread.model().cloned(), thread.action_log().clone())
             })?;
+            let request = request?;
+            let model = model.context("No language model configured")?;
 
             let edit_format = EditFormat::from_model(model.clone())?;
             let edit_agent = EditAgent::new(

crates/agent_ui/src/agent_panel.rs 🔗

@@ -1697,13 +1697,13 @@ impl AgentPanel {
                 window.dispatch_action(NewTextThread.boxed_clone(), cx);
             }
             AgentType::NativeAgent => {
-                self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), window, cx)
+                self.new_external_thread(Some(crate::ExternalAgent::NativeAgent), None, window, cx)
             }
             AgentType::Gemini => {
-                self.new_external_thread(Some(crate::ExternalAgent::Gemini), window, cx)
+                self.new_external_thread(Some(crate::ExternalAgent::Gemini), None, window, cx)
             }
             AgentType::ClaudeCode => {
-                self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), window, cx)
+                self.new_external_thread(Some(crate::ExternalAgent::ClaudeCode), None, window, cx)
             }
         }
     }