agent2: Remove model param from Thread::send method (#35936)

Ben Brandt created

It instead uses the currently selected model

Release Notes:

- N/A

Change summary

crates/agent2/src/agent.rs     |  3 +--
crates/agent2/src/tests/mod.rs | 33 ++++++++++++++-------------------
crates/agent2/src/thread.rs    |  2 +-
3 files changed, 16 insertions(+), 22 deletions(-)

Detailed changes

crates/agent2/src/agent.rs 🔗

@@ -491,8 +491,7 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
             // Send to thread
             log::info!("Sending message to thread with model: {:?}", model.name());
-            let mut response_stream =
-                thread.update(cx, |thread, cx| thread.send(model, message, cx))?;
+            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
 
             // Handle response stream and forward to session.acp_thread
             while let Some(result) = response_stream.next().await {

crates/agent2/src/tests/mod.rs 🔗

@@ -29,11 +29,11 @@ use test_tools::*;
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_echo(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 
     let events = thread
         .update(cx, |thread, cx| {
-            thread.send(model.clone(), "Testing: Reply with 'Hello'", cx)
+            thread.send("Testing: Reply with 'Hello'", cx)
         })
         .collect()
         .await;
@@ -49,12 +49,11 @@ async fn test_echo(cx: &mut TestAppContext) {
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_thinking(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
 
     let events = thread
         .update(cx, |thread, cx| {
             thread.send(
-                model.clone(),
                 indoc! {"
                     Testing:
 
@@ -91,7 +90,7 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
 
     project_context.borrow_mut().shell = "test-shell".into();
     thread.update(cx, |thread, _| thread.add_tool(EchoTool));
-    thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
+    thread.update(cx, |thread, cx| thread.send("abc", cx));
     cx.run_until_parked();
     let mut pending_completions = fake_model.pending_completions();
     assert_eq!(
@@ -121,14 +120,13 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_basic_tool_calls(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 
     // Test a tool call that's likely to complete *before* streaming stops.
     let events = thread
         .update(cx, |thread, cx| {
             thread.add_tool(EchoTool);
             thread.send(
-                model.clone(),
                 "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
                 cx,
             )
@@ -143,7 +141,6 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
             thread.remove_tool(&AgentTool::name(&EchoTool));
             thread.add_tool(DelayTool);
             thread.send(
-                model.clone(),
                 "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
                 cx,
             )
@@ -171,12 +168,12 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 
     // Test a tool call that's likely to complete *before* streaming stops.
     let mut events = thread.update(cx, |thread, cx| {
         thread.add_tool(WordListTool);
-        thread.send(model.clone(), "Test the word_list tool.", cx)
+        thread.send("Test the word_list tool.", cx)
     });
 
     let mut saw_partial_tool_use = false;
@@ -223,7 +220,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
 
     let mut events = thread.update(cx, |thread, cx| {
         thread.add_tool(ToolRequiringPermission);
-        thread.send(model.clone(), "abc", cx)
+        thread.send("abc", cx)
     });
     cx.run_until_parked();
     fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -290,7 +287,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
-    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "abc", cx));
+    let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
     cx.run_until_parked();
     fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
         LanguageModelToolUse {
@@ -375,14 +372,13 @@ async fn next_tool_call_authorization(
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 
     // Test concurrent tool calls with different delay times
     let events = thread
         .update(cx, |thread, cx| {
             thread.add_tool(DelayTool);
             thread.send(
-                model.clone(),
                 "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
                 cx,
             )
@@ -414,13 +410,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 #[gpui::test]
 #[ignore = "can't run on CI yet"]
 async fn test_cancellation(cx: &mut TestAppContext) {
-    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Sonnet4).await;
+    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 
     let mut events = thread.update(cx, |thread, cx| {
         thread.add_tool(InfiniteTool);
         thread.add_tool(EchoTool);
         thread.send(
-            model.clone(),
             "Call the echo tool and then call the infinite tool, then explain their output",
             cx,
         )
@@ -466,7 +461,7 @@ async fn test_cancellation(cx: &mut TestAppContext) {
     // Ensure we can still send a new message after cancellation.
     let events = thread
         .update(cx, |thread, cx| {
-            thread.send(model.clone(), "Testing: reply with 'Hello' then stop.", cx)
+            thread.send("Testing: reply with 'Hello' then stop.", cx)
         })
         .collect::<Vec<_>>()
         .await;
@@ -484,7 +479,7 @@ async fn test_refusal(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
-    let events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Hello", cx));
+    let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
         assert_eq!(
@@ -648,7 +643,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
     thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
     let fake_model = model.as_fake();
 
-    let mut events = thread.update(cx, |thread, cx| thread.send(model.clone(), "Think", cx));
+    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
     cx.run_until_parked();
 
     // Simulate streaming partial input.

crates/agent2/src/thread.rs 🔗

@@ -200,11 +200,11 @@ impl Thread {
     /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
     pub fn send(
         &mut self,
-        model: Arc<dyn LanguageModel>,
         content: impl Into<MessageContent>,
         cx: &mut Context<Self>,
     ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
         let content = content.into();
+        let model = self.selected_model.clone();
         log::info!("Thread::send called with model: {:?}", model.name());
         log::debug!("Thread::send content: {:?}", content);