edit_file_thread_test.rs

  1use super::*;
  2use crate::{AgentTool, EditFileTool, ReadFileTool};
  3use acp_thread::UserMessageId;
  4use action_log::ActionLog;
  5use fs::FakeFs;
  6use language_model::{
  7    LanguageModelCompletionEvent, LanguageModelToolUse, MessageContent, StopReason,
  8    fake_provider::FakeLanguageModel,
  9};
 10use prompt_store::ProjectContext;
 11use serde_json::json;
 12use std::{collections::BTreeMap, sync::Arc, time::Duration};
 13use util::path;
 14
 15#[gpui::test]
 16async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
 17    // This test verifies that the edit_file tool works correctly when invoked
 18    // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
 19    // This is different from tests that call tool.run() directly.
 20    super::init_test(cx);
 21    super::always_allow_tools(cx);
 22
 23    let fs = FakeFs::new(cx.executor());
 24    fs.insert_tree(
 25        path!("/project"),
 26        json!({
 27            "src": {
 28                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
 29            }
 30        }),
 31    )
 32    .await;
 33
 34    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
 35    let project_context = cx.new(|_cx| ProjectContext::default());
 36    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
 37    let context_server_registry =
 38        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
 39    let model = Arc::new(FakeLanguageModel::default());
 40    let fake_model = model.as_fake();
 41
 42    let thread = cx.new(|cx| {
 43        let mut thread = crate::Thread::new(
 44            project.clone(),
 45            project_context,
 46            context_server_registry,
 47            crate::Templates::new(),
 48            Some(model.clone()),
 49            cx,
 50        );
 51        // Add just the tools we need for this test
 52        let language_registry = project.read(cx).languages().clone();
 53        thread.add_tool(crate::ReadFileTool::new(
 54            cx.weak_entity(),
 55            project.clone(),
 56            thread.action_log().clone(),
 57        ));
 58        thread.add_tool(crate::EditFileTool::new(
 59            project.clone(),
 60            cx.weak_entity(),
 61            language_registry,
 62            crate::Templates::new(),
 63        ));
 64        thread
 65    });
 66
 67    // First, read the file so the thread knows about its contents
 68    let _events = thread
 69        .update(cx, |thread, cx| {
 70            thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
 71        })
 72        .unwrap();
 73    cx.run_until_parked();
 74
 75    // Model calls read_file tool
 76    let read_tool_use = LanguageModelToolUse {
 77        id: "read_tool_1".into(),
 78        name: ReadFileTool::NAME.into(),
 79        raw_input: json!({"path": "project/src/main.rs"}).to_string(),
 80        input: json!({"path": "project/src/main.rs"}),
 81        is_input_complete: true,
 82        thought_signature: None,
 83    };
 84    fake_model
 85        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
 86    fake_model
 87        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
 88    fake_model.end_last_completion_stream();
 89    cx.run_until_parked();
 90
 91    // Wait for the read tool to complete and model to be called again
 92    while fake_model.pending_completions().is_empty() {
 93        cx.run_until_parked();
 94    }
 95
 96    // Model responds after seeing the file content, then calls edit_file
 97    fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
 98    let edit_tool_use = LanguageModelToolUse {
 99        id: "edit_tool_1".into(),
100        name: EditFileTool::NAME.into(),
101        raw_input: json!({
102            "display_description": "Change greeting message",
103            "path": "project/src/main.rs",
104            "mode": "edit"
105        })
106        .to_string(),
107        input: json!({
108            "display_description": "Change greeting message",
109            "path": "project/src/main.rs",
110            "mode": "edit"
111        }),
112        is_input_complete: true,
113        thought_signature: None,
114    };
115    fake_model
116        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
117    fake_model
118        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
119    fake_model.end_last_completion_stream();
120    cx.run_until_parked();
121
122    // The edit_file tool creates an EditAgent which makes its own model request.
123    // We need to respond to that request with the edit instructions.
124    // Wait for the edit agent's completion request
125    let deadline = std::time::Instant::now() + Duration::from_secs(5);
126    while fake_model.pending_completions().is_empty() {
127        if std::time::Instant::now() >= deadline {
128            panic!(
129                "Timed out waiting for edit agent completion request. Pending: {}",
130                fake_model.pending_completions().len()
131            );
132        }
133        cx.run_until_parked();
134        cx.background_executor
135            .timer(Duration::from_millis(10))
136            .await;
137    }
138
139    // Send the edit agent's response with the XML format it expects
140    let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
141    fake_model.send_last_completion_stream_text_chunk(edit_response);
142    fake_model.end_last_completion_stream();
143    cx.run_until_parked();
144
145    // Wait for the edit to complete and the thread to call the model again with tool results
146    let deadline = std::time::Instant::now() + Duration::from_secs(5);
147    while fake_model.pending_completions().is_empty() {
148        if std::time::Instant::now() >= deadline {
149            panic!("Timed out waiting for model to be called after edit completion");
150        }
151        cx.run_until_parked();
152        cx.background_executor
153            .timer(Duration::from_millis(10))
154            .await;
155    }
156
157    // Verify the file was edited
158    let file_content = fs
159        .load(path!("/project/src/main.rs").as_ref())
160        .await
161        .expect("file should exist");
162    assert!(
163        file_content.contains("Hello, Zed!"),
164        "File should have been edited. Content: {}",
165        file_content
166    );
167    assert!(
168        !file_content.contains("Hello, world!"),
169        "Old content should be replaced. Content: {}",
170        file_content
171    );
172
173    // Verify the tool result was sent back to the model
174    let pending = fake_model.pending_completions();
175    assert!(
176        !pending.is_empty(),
177        "Model should have been called with tool result"
178    );
179
180    let last_request = pending.last().unwrap();
181    let has_tool_result = last_request.messages.iter().any(|m| {
182        m.content
183            .iter()
184            .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
185    });
186    assert!(
187        has_tool_result,
188        "Tool result should be in the messages sent back to the model"
189    );
190
191    // Complete the turn
192    fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
193    fake_model
194        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
195    fake_model.end_last_completion_stream();
196    cx.run_until_parked();
197
198    // Verify the thread completed successfully
199    thread.update(cx, |thread, _cx| {
200        assert!(
201            thread.is_turn_complete(),
202            "Thread should be complete after the turn ends"
203        );
204    });
205}
206
207#[gpui::test]
208async fn test_subagent_uses_read_file_tool(cx: &mut TestAppContext) {
209    // This test verifies that subagents can successfully use the read_file tool
210    // through the full thread flow, and that tools are properly rebound to use
211    // the subagent's thread ID instead of the parent's.
212    super::init_test(cx);
213    super::always_allow_tools(cx);
214
215    cx.update(|cx| {
216        cx.update_flags(true, vec!["subagents".to_string()]);
217    });
218
219    let fs = FakeFs::new(cx.executor());
220    fs.insert_tree(
221        path!("/project"),
222        json!({
223            "src": {
224                "lib.rs": "pub fn hello() -> &'static str {\n    \"Hello from lib!\"\n}\n"
225            }
226        }),
227    )
228    .await;
229
230    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
231    let project_context = cx.new(|_cx| ProjectContext::default());
232    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
233    let context_server_registry =
234        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
235    let model = Arc::new(FakeLanguageModel::default());
236    let fake_model = model.as_fake();
237
238    // Create subagent context
239    let subagent_context = crate::SubagentContext {
240        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
241        tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
242        depth: 1,
243        summary_prompt: "Summarize what you found".to_string(),
244        context_low_prompt: "Context low".to_string(),
245    };
246
247    // Create parent tools that will be passed to the subagent
248    // This simulates how the subagent_tool passes tools to new_subagent
249    let parent_tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> = {
250        let action_log = cx.new(|_| ActionLog::new(project.clone()));
251        // Create a "fake" parent thread reference - this should get rebound
252        let fake_parent_thread = cx.new(|cx| {
253            crate::Thread::new(
254                project.clone(),
255                cx.new(|_cx| ProjectContext::default()),
256                cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
257                crate::Templates::new(),
258                Some(model.clone()),
259                cx,
260            )
261        });
262        let mut tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> =
263            BTreeMap::new();
264        tools.insert(
265            ReadFileTool::NAME.into(),
266            crate::ReadFileTool::new(fake_parent_thread.downgrade(), project.clone(), action_log)
267                .erase(),
268        );
269        tools
270    };
271
272    // Create subagent - tools should be rebound to use subagent's thread
273    let subagent = cx.new(|cx| {
274        crate::Thread::new_subagent(
275            project.clone(),
276            project_context,
277            context_server_registry,
278            crate::Templates::new(),
279            model.clone(),
280            subagent_context,
281            parent_tools,
282            cx,
283        )
284    });
285
286    // Get the subagent's thread ID
287    let _subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string());
288
289    // Verify the subagent has the read_file tool
290    subagent.read_with(cx, |thread, _| {
291        assert!(
292            thread.has_registered_tool(ReadFileTool::NAME),
293            "subagent should have read_file tool"
294        );
295    });
296
297    // Submit a user message to the subagent
298    subagent
299        .update(cx, |thread, cx| {
300            thread.submit_user_message("Read the file src/lib.rs", cx)
301        })
302        .unwrap();
303    cx.run_until_parked();
304
305    // Simulate the model calling the read_file tool
306    let read_tool_use = LanguageModelToolUse {
307        id: "read_tool_1".into(),
308        name: ReadFileTool::NAME.into(),
309        raw_input: json!({"path": "project/src/lib.rs"}).to_string(),
310        input: json!({"path": "project/src/lib.rs"}),
311        is_input_complete: true,
312        thought_signature: None,
313    };
314    fake_model
315        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
316    fake_model.end_last_completion_stream();
317    cx.run_until_parked();
318
319    // Wait for the tool to complete and the model to be called again with tool results
320    let deadline = std::time::Instant::now() + Duration::from_secs(5);
321    while fake_model.pending_completions().is_empty() {
322        if std::time::Instant::now() >= deadline {
323            panic!("Timed out waiting for model to be called after read_file tool completion");
324        }
325        cx.run_until_parked();
326        cx.background_executor
327            .timer(Duration::from_millis(10))
328            .await;
329    }
330
331    // Verify the tool result was sent back to the model
332    let pending = fake_model.pending_completions();
333    assert!(
334        !pending.is_empty(),
335        "Model should have been called with tool result"
336    );
337
338    let last_request = pending.last().unwrap();
339    let tool_result = last_request.messages.iter().find_map(|m| {
340        m.content.iter().find_map(|c| match c {
341            MessageContent::ToolResult(result) => Some(result),
342            _ => None,
343        })
344    });
345    assert!(
346        tool_result.is_some(),
347        "Tool result should be in the messages sent back to the model"
348    );
349
350    // Verify the tool result contains the file content
351    let result = tool_result.unwrap();
352    let result_text = match &result.content {
353        language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
354        _ => panic!("expected text content in tool result"),
355    };
356    assert!(
357        result_text.contains("Hello from lib!"),
358        "Tool result should contain file content, got: {}",
359        result_text
360    );
361
362    // Verify the subagent is ready for more input (tool completed, model called again)
363    // This test verifies the subagent can successfully use read_file tool.
364    // The summary flow is tested separately in test_subagent_returns_summary_on_completion.
365}
366
367#[gpui::test]
368async fn test_subagent_uses_edit_file_tool(cx: &mut TestAppContext) {
369    // This test verifies that subagents can successfully use the edit_file tool
370    // through the full thread flow, including the edit agent's model request.
371    // It also verifies that the edit agent uses the subagent's thread ID, not the parent's.
372    super::init_test(cx);
373    super::always_allow_tools(cx);
374
375    cx.update(|cx| {
376        cx.update_flags(true, vec!["subagents".to_string()]);
377    });
378
379    let fs = FakeFs::new(cx.executor());
380    fs.insert_tree(
381        path!("/project"),
382        json!({
383            "src": {
384                "config.rs": "pub const VERSION: &str = \"1.0.0\";\n"
385            }
386        }),
387    )
388    .await;
389
390    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
391    let project_context = cx.new(|_cx| ProjectContext::default());
392    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
393    let context_server_registry =
394        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
395    let model = Arc::new(FakeLanguageModel::default());
396    let fake_model = model.as_fake();
397
398    // Create a "parent" thread to simulate the real scenario where tools are inherited
399    let parent_thread = cx.new(|cx| {
400        crate::Thread::new(
401            project.clone(),
402            cx.new(|_cx| ProjectContext::default()),
403            cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx)),
404            crate::Templates::new(),
405            Some(model.clone()),
406            cx,
407        )
408    });
409    let parent_thread_id = parent_thread.read_with(cx, |thread, _| thread.id().to_string());
410
411    // Create parent tools that reference the parent thread
412    let parent_tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> = {
413        let action_log = cx.new(|_| ActionLog::new(project.clone()));
414        let language_registry = project.read_with(cx, |p, _| p.languages().clone());
415        let mut tools: BTreeMap<gpui::SharedString, std::sync::Arc<dyn crate::AnyAgentTool>> =
416            BTreeMap::new();
417        tools.insert(
418            ReadFileTool::NAME.into(),
419            crate::ReadFileTool::new(parent_thread.downgrade(), project.clone(), action_log)
420                .erase(),
421        );
422        tools.insert(
423            EditFileTool::NAME.into(),
424            crate::EditFileTool::new(
425                project.clone(),
426                parent_thread.downgrade(),
427                language_registry,
428                crate::Templates::new(),
429            )
430            .erase(),
431        );
432        tools
433    };
434
435    // Create subagent context
436    let subagent_context = crate::SubagentContext {
437        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
438        tool_use_id: language_model::LanguageModelToolUseId::from("subagent-tool-use-id"),
439        depth: 1,
440        summary_prompt: "Summarize what you changed".to_string(),
441        context_low_prompt: "Context low".to_string(),
442    };
443
444    // Create subagent - tools should be rebound to use subagent's thread
445    let subagent = cx.new(|cx| {
446        crate::Thread::new_subagent(
447            project.clone(),
448            project_context,
449            context_server_registry,
450            crate::Templates::new(),
451            model.clone(),
452            subagent_context,
453            parent_tools,
454            cx,
455        )
456    });
457
458    // Get the subagent's thread ID - it should be different from parent
459    let subagent_thread_id = subagent.read_with(cx, |thread, _| thread.id().to_string());
460    assert_ne!(
461        parent_thread_id, subagent_thread_id,
462        "Subagent should have a different thread ID than parent"
463    );
464
465    // Verify the subagent has the tools
466    subagent.read_with(cx, |thread, _| {
467        assert!(
468            thread.has_registered_tool(ReadFileTool::NAME),
469            "subagent should have read_file tool"
470        );
471        assert!(
472            thread.has_registered_tool(EditFileTool::NAME),
473            "subagent should have edit_file tool"
474        );
475    });
476
477    // Submit a user message to the subagent
478    subagent
479        .update(cx, |thread, cx| {
480            thread.submit_user_message("Update the version in config.rs to 2.0.0", cx)
481        })
482        .unwrap();
483    cx.run_until_parked();
484
485    // First, model calls read_file to see the current content
486    let read_tool_use = LanguageModelToolUse {
487        id: "read_tool_1".into(),
488        name: ReadFileTool::NAME.into(),
489        raw_input: json!({"path": "project/src/config.rs"}).to_string(),
490        input: json!({"path": "project/src/config.rs"}),
491        is_input_complete: true,
492        thought_signature: None,
493    };
494    fake_model
495        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
496    fake_model.end_last_completion_stream();
497    cx.run_until_parked();
498
499    // Wait for the read tool to complete and model to be called again
500    let deadline = std::time::Instant::now() + Duration::from_secs(5);
501    while fake_model.pending_completions().is_empty() {
502        if std::time::Instant::now() >= deadline {
503            panic!("Timed out waiting for model to be called after read_file tool");
504        }
505        cx.run_until_parked();
506        cx.background_executor
507            .timer(Duration::from_millis(10))
508            .await;
509    }
510
511    // Model responds and calls edit_file
512    fake_model.send_last_completion_stream_text_chunk("I'll update the version now.");
513    let edit_tool_use = LanguageModelToolUse {
514        id: "edit_tool_1".into(),
515        name: EditFileTool::NAME.into(),
516        raw_input: json!({
517            "display_description": "Update version to 2.0.0",
518            "path": "project/src/config.rs",
519            "mode": "edit"
520        })
521        .to_string(),
522        input: json!({
523            "display_description": "Update version to 2.0.0",
524            "path": "project/src/config.rs",
525            "mode": "edit"
526        }),
527        is_input_complete: true,
528        thought_signature: None,
529    };
530    fake_model
531        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
532    fake_model.end_last_completion_stream();
533    cx.run_until_parked();
534
535    // The edit_file tool creates an EditAgent which makes its own model request.
536    // Wait for that request.
537    let deadline = std::time::Instant::now() + Duration::from_secs(5);
538    while fake_model.pending_completions().is_empty() {
539        if std::time::Instant::now() >= deadline {
540            panic!(
541                "Timed out waiting for edit agent completion request in subagent. Pending: {}",
542                fake_model.pending_completions().len()
543            );
544        }
545        cx.run_until_parked();
546        cx.background_executor
547            .timer(Duration::from_millis(10))
548            .await;
549    }
550
551    // Verify the edit agent's request uses the SUBAGENT's thread ID, not the parent's
552    let pending = fake_model.pending_completions();
553    let edit_agent_request = pending.last().unwrap();
554    let edit_agent_thread_id = edit_agent_request.thread_id.as_ref().unwrap();
555    std::assert_eq!(
556        edit_agent_thread_id,
557        &subagent_thread_id,
558        "Edit agent should use subagent's thread ID, not parent's. Got: {}, expected: {}",
559        edit_agent_thread_id,
560        subagent_thread_id
561    );
562    std::assert_ne!(
563        edit_agent_thread_id,
564        &parent_thread_id,
565        "Edit agent should NOT use parent's thread ID"
566    );
567
568    // Send the edit agent's response with the XML format it expects
569    let edit_response = "<old_text>pub const VERSION: &str = \"1.0.0\";</old_text>\n<new_text>pub const VERSION: &str = \"2.0.0\";</new_text>";
570    fake_model.send_last_completion_stream_text_chunk(edit_response);
571    fake_model.end_last_completion_stream();
572    cx.run_until_parked();
573
574    // Wait for the edit to complete and the thread to call the model again with tool results
575    let deadline = std::time::Instant::now() + Duration::from_secs(5);
576    while fake_model.pending_completions().is_empty() {
577        if std::time::Instant::now() >= deadline {
578            panic!("Timed out waiting for model to be called after edit completion in subagent");
579        }
580        cx.run_until_parked();
581        cx.background_executor
582            .timer(Duration::from_millis(10))
583            .await;
584    }
585
586    // Verify the file was edited
587    let file_content = fs
588        .load(path!("/project/src/config.rs").as_ref())
589        .await
590        .expect("file should exist");
591    assert!(
592        file_content.contains("2.0.0"),
593        "File should have been edited to contain new version. Content: {}",
594        file_content
595    );
596    assert!(
597        !file_content.contains("1.0.0"),
598        "Old version should be replaced. Content: {}",
599        file_content
600    );
601
602    // Verify the tool result was sent back to the model
603    let pending = fake_model.pending_completions();
604    assert!(
605        !pending.is_empty(),
606        "Model should have been called with tool result"
607    );
608
609    let last_request = pending.last().unwrap();
610    let has_tool_result = last_request.messages.iter().any(|m| {
611        m.content
612            .iter()
613            .any(|c| matches!(c, MessageContent::ToolResult(_)))
614    });
615    assert!(
616        has_tool_result,
617        "Tool result should be in the messages sent back to the model"
618    );
619}