diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index be87a6a1e1e5ddba8a5d4b3b5bca82168a141840..148702e1bafeae05ac67c6127d8259581aff93dd 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1452,6 +1452,7 @@ impl Thread { self.add_tool(StreamingEditFileTool::new( self.project.clone(), cx.weak_entity(), + self.action_log.clone(), language_registry, )); self.add_tool(FetchTool::new(self.project.read(cx).client().http_client())); diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index 62b96d569f34d65889abee6be803674dfa42e709..7140029df7b029d4fcb947bfabe99535ece7169d 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -73,7 +73,7 @@ pub struct StreamingEditFileToolInput { /// /// `frontend/db.js` /// - pub path: String, + pub path: PathBuf, /// The mode of operation on the file. Possible values: /// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field. @@ -93,7 +93,7 @@ pub struct StreamingEditFileToolInput { pub edits: Option>, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum StreamingEditFileMode { /// Overwrite the file with new content (replacing any existing content). @@ -187,20 +187,23 @@ impl From for LanguageModelToolResultContent { } pub struct StreamingEditFileTool { + project: Entity, thread: WeakEntity, + action_log: Entity, language_registry: Arc, - project: Entity, } impl StreamingEditFileTool { pub fn new( project: Entity, thread: WeakEntity, + action_log: Entity, language_registry: Arc, ) -> Self { Self { project, thread, + action_log, language_registry, } } @@ -264,11 +267,11 @@ impl AgentTool for StreamingEditFileTool { .read(cx) .short_full_path_for_project_path(&project_path, cx) }) - .unwrap_or(input.path) + .unwrap_or(input.path.to_string_lossy().into_owned()) .into(), Err(raw_input) => { - if let Some(input) = - serde_json::from_value::(raw_input).ok() + if let Ok(input) = + serde_json::from_value::(raw_input) { let path = input.path.unwrap_or_default(); let path = path.trim(); @@ -311,24 +314,37 @@ impl AgentTool for StreamingEditFileTool { partial = input.recv_partial().fuse() => { let Some(partial_value) = partial else { break }; if let Ok(parsed) = serde_json::from_value::(partial_value) { - if state.is_none() && let Some(path_str) = &parsed.path - && let Some(display_description) = &parsed.display_description - && let Some(mode) = parsed.mode.clone() { - state = Some( - EditSession::new( - path_str, - display_description, - mode, - &self, - &event_stream, - cx, - ) - .await?, - ); + if state.is_none() + && let StreamingEditFileToolPartialInput { + path: Some(path), + display_description: Some(display_description), + mode: Some(mode), + .. + } = &parsed + { + match EditSession::new( + &PathBuf::from(path), + display_description, + *mode, + &self, + &event_stream, + cx, + ) + .await + { + Ok(session) => state = Some(session), + Err(e) => { + log::error!("Failed to create edit session: {}", e); + return Err(e); + } + } } if let Some(state) = &mut state { - state.process(parsed, &self, &event_stream, cx)?; + if let Err(e) = state.process(parsed, &self, &event_stream, cx) { + log::error!("Failed to process edit: {}", e); + return Err(e); + } } } } @@ -341,22 +357,39 @@ impl AgentTool for StreamingEditFileTool { input .recv() .await - .map_err(|e| StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")))?; + .map_err(|e| { + let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")); + log::error!("Failed to receive tool input: {e}"); + err + })?; let mut state = if let Some(state) = state { state } else { - EditSession::new( + match EditSession::new( &full_input.path, &full_input.display_description, - full_input.mode.clone(), + full_input.mode, &self, &event_stream, cx, ) - .await? + .await + { + Ok(session) => session, + Err(e) => { + log::error!("Failed to create edit session: {}", e); + return Err(e); + } + } }; - state.finalize(full_input, &self, &event_stream, cx).await + match state.finalize(full_input, &self, &event_stream, cx).await { + Ok(output) => Ok(output), + Err(e) => { + log::error!("Failed to finalize edit: {}", e); + Err(e) + } + } }) } @@ -442,30 +475,24 @@ impl EditPipeline { } } -/// Compute the `LineIndent` of the first line in a set of query lines. -fn query_first_line_indent(query_lines: &[String]) -> text::LineIndent { - let first_line = query_lines.first().map(|s| s.as_str()).unwrap_or(""); - text::LineIndent::from_iter(first_line.chars()) -} - impl EditSession { async fn new( - path_str: &str, + path: &PathBuf, display_description: &str, mode: StreamingEditFileMode, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result { - let path = PathBuf::from(path_str); let project_path = cx - .update(|cx| resolve_path(mode.clone(), &path, &tool.project, cx)) + .update(|cx| resolve_path(mode, &path, &tool.project, cx)) .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) else { return Err(StreamingEditFileToolOutput::error(format!( - "Worktree at '{path_str}' does not exist" + "Worktree at '{}' does not exist", + path.to_string_lossy() ))); }; @@ -483,12 +510,7 @@ impl EditSession { .await .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; - let action_log = tool - .thread - .read_with(cx, |thread, _cx| thread.action_log().clone()) - .ok(); - - ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?; + ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); event_stream.update_diff(diff.clone()); @@ -500,9 +522,8 @@ impl EditSession { } }) as Box); - if let Some(action_log) = &action_log { - action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); - } + tool.action_log + .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let old_text = cx @@ -531,69 +552,31 @@ impl EditSession { event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result { - let Self { - buffer, - old_text, - diff, - abs_path, - parser, - pipeline, - .. - } = self; - - let action_log = tool - .thread - .read_with(cx, |thread, _cx| thread.action_log().clone()) - .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; + let old_text = self.old_text.clone(); match input.mode { StreamingEditFileMode::Write => { - action_log.update(cx, |log, cx| { - log.buffer_created(buffer.clone(), cx); - }); let content = input.content.ok_or_else(|| { StreamingEditFileToolOutput::error("'content' field is required for write mode") })?; - let events = parser.finalize_content(&content); - Self::process_events( - &events, - buffer, - diff, - pipeline, - abs_path, - tool, - event_stream, - cx, - )?; + let events = self.parser.finalize_content(&content); + self.process_events(&events, tool, event_stream, cx)?; + + tool.action_log.update(cx, |log, cx| { + log.buffer_created(self.buffer.clone(), cx); + }); } StreamingEditFileMode::Edit => { let edits = input.edits.ok_or_else(|| { StreamingEditFileToolOutput::error("'edits' field is required for edit mode") })?; - - let final_edits = edits - .into_iter() - .map(|e| Edit { - old_text: e.old_text, - new_text: e.new_text, - }) - .collect::>(); - let events = parser.finalize_edits(&final_edits); - Self::process_events( - &events, - buffer, - diff, - pipeline, - abs_path, - tool, - event_stream, - cx, - )?; + let events = self.parser.finalize_edits(&edits); + self.process_events(&events, tool, event_stream, cx)?; } } - let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| { let settings = language_settings::language_settings( buffer.language().map(|l| l.name()), buffer.file(), @@ -603,13 +586,13 @@ impl EditSession { }); if format_on_save_enabled { - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); + tool.action_log.update(cx, |log, cx| { + log.buffer_edited(self.buffer.clone(), cx); }); let format_task = tool.project.update(cx, |project, cx| { project.format( - HashSet::from_iter([buffer.clone()]), + HashSet::from_iter([self.buffer.clone()]), LspFormatTarget::Buffers, false, FormatTrigger::Save, @@ -624,9 +607,9 @@ impl EditSession { }; } - let save_task = tool - .project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + let save_task = tool.project.update(cx, |project, cx| { + project.save_buffer(self.buffer.clone(), cx) + }); futures::select! { result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; }, _ = event_stream.cancelled_by_user().fuse() => { @@ -634,11 +617,11 @@ impl EditSession { } }; - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); + tool.action_log.update(cx, |log, cx| { + log.buffer_edited(self.buffer.clone(), cx); }); - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let (new_text, unified_diff) = cx .background_spawn({ let new_snapshot = new_snapshot.clone(); @@ -652,7 +635,7 @@ impl EditSession { .await; let output = StreamingEditFileToolOutput::Success { - input_path: PathBuf::from(input.path), + input_path: input.path, new_text, old_text: old_text.clone(), diff: unified_diff, @@ -671,31 +654,13 @@ impl EditSession { StreamingEditFileMode::Write => { if let Some(content) = &partial.content { let events = self.parser.push_content(content); - Self::process_events( - &events, - &self.buffer, - &self.diff, - &mut self.pipeline, - &self.abs_path, - tool, - event_stream, - cx, - )?; + self.process_events(&events, tool, event_stream, cx)?; } } StreamingEditFileMode::Edit => { if let Some(edits) = partial.edits { let events = self.parser.push_edits(&edits); - Self::process_events( - &events, - &self.buffer, - &self.diff, - &mut self.pipeline, - &self.abs_path, - tool, - event_stream, - cx, - )?; + self.process_events(&events, tool, event_stream, cx)?; } } } @@ -703,46 +668,38 @@ impl EditSession { } fn process_events( + &mut self, events: &[ToolEditEvent], - buffer: &Entity, - diff: &Entity, - pipeline: &mut EditPipeline, - abs_path: &PathBuf, tool: &StreamingEditFileTool, event_stream: &ToolCallEventStream, cx: &mut AsyncApp, ) -> Result<(), StreamingEditFileToolOutput> { - let action_log = tool - .thread - .read_with(cx, |thread, _cx| thread.action_log().clone()) - .ok(); - for event in events { match event { ToolEditEvent::ContentChunk { chunk } => { - let (buffer_id, insert_at) = buffer.read_with(cx, |buffer, _cx| { - let insert_at = if !pipeline.content_written && buffer.len() > 0 { - 0..buffer.len() - } else { - let len = buffer.len(); - len..len - }; - (buffer.remote_id(), insert_at) - }); + let (buffer_id, buffer_len) = self + .buffer + .read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len())); + let edit_range = if self.pipeline.content_written { + buffer_len..buffer_len + } else { + 0..buffer_len + }; + agent_edit_buffer( - buffer, - [(insert_at, chunk.as_str())], - action_log.as_ref(), + &self.buffer, + [(edit_range, chunk.as_str())], + &tool.action_log, cx, ); cx.update(|cx| { tool.set_agent_location( - buffer.downgrade(), + self.buffer.downgrade(), text::Anchor::max_for_buffer(buffer_id), cx, ); }); - pipeline.content_written = true; + self.pipeline.content_written = true; } ToolEditEvent::OldTextChunk { @@ -750,23 +707,24 @@ impl EditSession { chunk, done: false, } => { - pipeline.ensure_resolving_old_text(*edit_index, buffer, cx); + self.pipeline + .ensure_resolving_old_text(*edit_index, &self.buffer, cx); if let EditPipelineEntry::ResolvingOldText { matcher } = - &mut pipeline.edits[*edit_index] + &mut self.pipeline.edits[*edit_index] + && !chunk.is_empty() { - if !chunk.is_empty() { - if let Some(match_range) = matcher.push(chunk, None) { - let anchor_range = buffer.read_with(cx, |buffer, _cx| { - buffer.anchor_range_between(match_range.clone()) - }); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - - cx.update(|cx| { - let position = buffer.read(cx).anchor_before(match_range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } + if let Some(match_range) = matcher.push(chunk, None) { + let anchor_range = self.buffer.read_with(cx, |buffer, _cx| { + buffer.anchor_range_between(match_range.clone()) + }); + self.diff + .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + + cx.update(|cx| { + let position = self.buffer.read(cx).anchor_before(match_range.end); + tool.set_agent_location(self.buffer.downgrade(), position, cx); + }); } } } @@ -776,10 +734,11 @@ impl EditSession { chunk, done: true, } => { - pipeline.ensure_resolving_old_text(*edit_index, buffer, cx); + self.pipeline + .ensure_resolving_old_text(*edit_index, &self.buffer, cx); let EditPipelineEntry::ResolvingOldText { matcher } = - &mut pipeline.edits[*edit_index] + &mut self.pipeline.edits[*edit_index] else { continue; }; @@ -787,60 +746,47 @@ impl EditSession { if !chunk.is_empty() { matcher.push(chunk, None); } - let matches = matcher.finish(); - - if matches.is_empty() { - return Err(StreamingEditFileToolOutput::error(format!( - "Could not find matching text for edit at index {}. \ - The old_text did not match any content in the file. \ - Please read the file again to get the current content.", - edit_index, - ))); - } - if matches.len() > 1 { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let lines = matches - .iter() - .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) - .collect::>() - .join(", "); - return Err(StreamingEditFileToolOutput::error(format!( - "Edit {} matched multiple locations in the file at lines: {}. \ - Please provide more context in old_text to uniquely \ - identify the location.", - edit_index, lines - ))); - } - - let range = matches.into_iter().next().expect("checked len above"); + let range = extract_match(matcher.finish(), &self.buffer, edit_index, cx)?; - let anchor_range = buffer + let anchor_range = self + .buffer .read_with(cx, |buffer, _cx| buffer.anchor_range_between(range.clone())); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + self.diff + .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); let line = snapshot.offset_to_point(range.start).row; event_stream.update_fields( - ToolCallUpdateFields::new() - .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), + ToolCallUpdateFields::new().locations(vec![ + ToolCallLocation::new(&self.abs_path).line(Some(line)), + ]), ); let EditPipelineEntry::ResolvingOldText { matcher } = - &pipeline.edits[*edit_index] + &self.pipeline.edits[*edit_index] else { continue; }; let buffer_indent = snapshot.line_indent_for_row(snapshot.offset_to_point(range.start).row); - let query_indent = query_first_line_indent(matcher.query_lines()); + let query_indent = text::LineIndent::from_iter( + matcher + .query_lines() + .first() + .map(|s| s.as_str()) + .unwrap_or("") + .chars(), + ); let indent_delta = compute_indent_delta(buffer_indent, query_indent); let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::(); - let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); - pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText { + let text_snapshot = self + .buffer + .read_with(cx, |buffer, _cx| buffer.text_snapshot()); + self.pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText { streaming_diff: StreamingDiff::new(old_text_in_buffer), edit_cursor: range.start, reindenter: Reindenter::new(indent_delta), @@ -848,8 +794,8 @@ impl EditSession { }; cx.update(|cx| { - let position = buffer.read(cx).anchor_before(range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); + let position = self.buffer.read(cx).anchor_before(range.end); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } @@ -858,7 +804,7 @@ impl EditSession { chunk, done: false, } => { - if *edit_index >= pipeline.edits.len() { + if *edit_index >= self.pipeline.edits.len() { continue; } let EditPipelineEntry::StreamingNewText { @@ -867,7 +813,7 @@ impl EditSession { reindenter, original_snapshot, .. - } = &mut pipeline.edits[*edit_index] + } = &mut self.pipeline.edits[*edit_index] else { continue; }; @@ -878,18 +824,18 @@ impl EditSession { } let char_ops = streaming_diff.push_new(&reindented); - Self::apply_char_operations( + apply_char_operations( &char_ops, - buffer, + &self.buffer, original_snapshot, edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); let position = original_snapshot.anchor_before(*edit_cursor); cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } @@ -898,7 +844,7 @@ impl EditSession { chunk, done: true, } => { - if *edit_index >= pipeline.edits.len() { + if *edit_index >= self.pipeline.edits.len() { continue; } @@ -908,7 +854,7 @@ impl EditSession { mut reindenter, original_snapshot, } = std::mem::replace( - &mut pipeline.edits[*edit_index], + &mut self.pipeline.edits[*edit_index], EditPipelineEntry::Done, ) else { @@ -921,64 +867,95 @@ impl EditSession { if !final_text.is_empty() { let char_ops = streaming_diff.push_new(&final_text); - Self::apply_char_operations( + apply_char_operations( &char_ops, - buffer, + &self.buffer, &original_snapshot, &mut edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); } let remaining_ops = streaming_diff.finish(); - Self::apply_char_operations( + apply_char_operations( &remaining_ops, - buffer, + &self.buffer, &original_snapshot, &mut edit_cursor, - action_log.as_ref(), + &tool.action_log, cx, ); let position = original_snapshot.anchor_before(edit_cursor); cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); + tool.set_agent_location(self.buffer.downgrade(), position, cx); }); } } } Ok(()) } +} - fn apply_char_operations( - ops: &[CharOperation], - buffer: &Entity, - snapshot: &text::BufferSnapshot, - edit_cursor: &mut usize, - action_log: Option<&Entity>, - cx: &mut AsyncApp, - ) { - for op in ops { - match op { - CharOperation::Insert { text } => { - let anchor = snapshot.anchor_after(*edit_cursor); - agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); - } - CharOperation::Delete { bytes } => { - let delete_end = *edit_cursor + bytes; - let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end); - agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); - *edit_cursor = delete_end; - } - CharOperation::Keep { bytes } => { - *edit_cursor += bytes; - } +fn apply_char_operations( + ops: &[CharOperation], + buffer: &Entity, + snapshot: &text::BufferSnapshot, + edit_cursor: &mut usize, + action_log: &Entity, + cx: &mut AsyncApp, +) { + for op in ops { + match op { + CharOperation::Insert { text } => { + let anchor = snapshot.anchor_after(*edit_cursor); + agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); + } + CharOperation::Delete { bytes } => { + let delete_end = *edit_cursor + bytes; + let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end); + agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); + *edit_cursor = delete_end; + } + CharOperation::Keep { bytes } => { + *edit_cursor += bytes; } } } } +fn extract_match( + matches: Vec>, + buffer: &Entity, + edit_index: &usize, + cx: &mut AsyncApp, +) -> Result, StreamingEditFileToolOutput> { + match matches.len() { + 0 => Err(StreamingEditFileToolOutput::error(format!( + "Could not find matching text for edit at index {}. \ + The old_text did not match any content in the file. \ + Please read the file again to get the current content.", + edit_index, + ))), + 1 => Ok(matches.into_iter().next().unwrap()), + _ => { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let lines = matches + .iter() + .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) + .collect::>() + .join(", "); + Err(StreamingEditFileToolOutput::error(format!( + "Edit {} matched multiple locations in the file at lines: {}. \ + Please provide more context in old_text to uniquely \ + identify the location.", + edit_index, lines + ))) + } + } +} + /// Edits a buffer and reports the edit to the action log in the same effect /// cycle. This ensures the action log's subscription handler sees the version /// already updated by `buffer_edited`, so it does not misattribute the agent's @@ -986,7 +963,7 @@ impl EditSession { fn agent_edit_buffer( buffer: &Entity, edits: I, - action_log: Option<&Entity>, + action_log: &Entity, cx: &mut AsyncApp, ) where I: IntoIterator, T)>, @@ -997,9 +974,7 @@ fn agent_edit_buffer( buffer.update(cx, |buffer, cx| { buffer.edit(edits, None, cx); }); - if let Some(action_log) = action_log { - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - } + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); }); } @@ -1007,11 +982,11 @@ fn ensure_buffer_saved( buffer: &Entity, abs_path: &PathBuf, tool: &StreamingEditFileTool, - action_log: Option<&Entity>, cx: &mut AsyncApp, ) -> Result<(), StreamingEditFileToolOutput> { - let last_read_mtime = - action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path))); + let last_read_mtime = tool + .action_log + .read_with(cx, |log, _| log.file_read_time(abs_path)); let check_result = tool.thread.read_with(cx, |thread, cx| { let current = buffer .read(cx) @@ -1140,42 +1115,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_create_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Create new file".into(), - path: "root/dir/new_file.txt".into(), - mode: StreamingEditFileMode::Write, - content: Some("Hello, World!".into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Create new file".into(), + path: "root/dir/new_file.txt".into(), + mode: StreamingEditFileMode::Write, + content: Some("Hello, World!".into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -1191,43 +1141,18 @@ mod tests { #[gpui::test] async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"file.txt": "old content"})) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "old content"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Overwrite file".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Write, - content: Some("new content".into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Overwrite file".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Write, + content: Some("new content".into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -1246,51 +1171,21 @@ mod tests { #[gpui::test] async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "line 2".into(), - new_text: "modified line 2".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "line 2".into(), + new_text: "modified line 2".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1305,57 +1200,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit multiple lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 5".into(), - new_text: "modified line 5".into(), - }, - Edit { - old_text: "line 1".into(), - new_text: "modified line 1".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit multiple lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 5".into(), + new_text: "modified line 5".into(), + }, + Edit { + old_text: "line 1".into(), + new_text: "modified line 1".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1373,57 +1241,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_adjacent_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit adjacent lines".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 2".into(), - new_text: "modified line 2".into(), - }, - Edit { - old_text: "line 3".into(), - new_text: "modified line 3".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit adjacent lines".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 2".into(), + new_text: "modified line 2".into(), + }, + Edit { + old_text: "line 3".into(), + new_text: "modified line 3".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1441,57 +1282,30 @@ mod tests { #[gpui::test] async fn test_streaming_edit_ascending_order_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit multiple lines in ascending order".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![ - Edit { - old_text: "line 1".into(), - new_text: "modified line 1".into(), - }, - Edit { - old_text: "line 5".into(), - new_text: "modified line 5".into(), - }, - ]), - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit multiple lines in ascending order".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![ + Edit { + old_text: "line 1".into(), + new_text: "modified line 1".into(), + }, + Edit { + old_text: "line 5".into(), + new_text: "modified line 5".into(), + }, + ]), + }), ToolCallEventStream::test().0, cx, ) @@ -1509,45 +1323,20 @@ mod tests { #[gpui::test] async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Some edit".into(), - path: "root/nonexistent_file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "foo".into(), - new_text: "bar".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Some edit".into(), + path: "root/nonexistent_file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "foo".into(), + new_text: "bar".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1562,46 +1351,21 @@ mod tests { #[gpui::test] async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"file.txt": "hello world"})) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Edit file".into(), - path: "root/file.txt".into(), - mode: StreamingEditFileMode::Edit, - content: None, - edits: Some(vec![Edit { - old_text: "nonexistent text that is not in the file".into(), - new_text: "replacement".into(), - }]), - }; - Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit file".into(), + path: "root/file.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![Edit { + old_text: "nonexistent text that is not in the file".into(), + new_text: "replacement".into(), + }]), + }), ToolCallEventStream::test().0, cx, ) @@ -1619,42 +1383,11 @@ mod tests { #[gpui::test] async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send partials simulating LLM streaming: description first, then path, then mode sender.send_partial(json!({"display_description": "Edit lines"})); @@ -1691,42 +1424,11 @@ mod tests { #[gpui::test] async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send partial with path but NO mode — path should NOT be treated as complete sender.send_partial(json!({ @@ -1760,43 +1462,12 @@ mod tests { #[gpui::test] async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send a partial sender.send_partial(json!({"display_description": "Edit"})); @@ -1822,42 +1493,14 @@ mod tests { #[gpui::test] async fn test_streaming_edit_with_multiple_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Simulate fine-grained streaming of the JSON sender.send_partial(json!({"display_description": "Edit multiple"})); @@ -1918,36 +1561,10 @@ mod tests { #[gpui::test] async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Stream partials for create mode sender.send_partial(json!({"display_description": "Create new file"})); @@ -1985,42 +1602,11 @@ mod tests { #[gpui::test] async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) sender.send_final(json!({ @@ -2039,42 +1625,14 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_edit_application(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" - }), + let (tool, project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Stream description, path, mode sender.send_partial(json!({"display_description": "Edit multiple lines"})); @@ -2168,42 +1726,11 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "aaa\nbbb\nccc\nddd\neee\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup: description + path + mode sender.send_partial(json!({ @@ -2287,44 +1814,13 @@ mod tests { assert_eq!(new_text, "AAA\nbbb\nCCC\nddd\nEEE\n"); } - #[gpui::test] - async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - - let (sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + #[gpui::test] + async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup sender.send_partial(json!({ @@ -2401,42 +1897,11 @@ mod tests { #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup + single edit that stays in-progress (no second edit to prove completion) sender.send_partial(json!({ @@ -2480,44 +1945,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send progressively more complete partial snapshots, as the LLM would sender.send_partial(json!({ @@ -2557,44 +1990,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello world\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Send a partial then drop the sender without sending final sender.send_partial(json!({ @@ -2613,41 +2014,14 @@ mod tests { #[gpui::test] async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; // Create a channel and send multiple partials before a final, then use // ToolInput::resolved-style immediate delivery to confirm recv() works // when partials are already buffered. let (sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run(input, event_stream, cx) - }); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Buffer several partials before sending the final sender.send_partial(json!({"display_description": "Create"})); @@ -2746,7 +2120,7 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.update(|cx| resolve_path(mode.clone(), &PathBuf::from(path), &project, cx)) + cx.update(|cx| resolve_path(*mode, &PathBuf::from(path), &project, cx)) } #[track_caller] @@ -2761,8 +2135,8 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let (tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let rust_language = Arc::new(language::Language::new( language::LanguageConfig { @@ -2811,9 +2185,10 @@ mod tests { project.register_buffer_with_language_servers(&buffer, cx) }); - const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; - const FORMATTED_CONTENT: &str = - "This file was formatted by the fake formatter in the test.\n"; + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\ +"; + const FORMATTED_CONTENT: &str = "This file was formatted by the fake formatter in the test.\ +"; // Get the fake language server and set up formatting handler let fake_language_server = fake_language_servers.next().await.unwrap(); @@ -2826,20 +2201,6 @@ mod tests { } }); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - // Test with format_on_save enabled cx.update(|cx| { SettingsStore::update_global(cx, |store, cx| { @@ -2855,13 +2216,7 @@ mod tests { let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry.clone(), - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Create main function", @@ -2912,13 +2267,14 @@ mod tests { let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let tool = Arc::new(StreamingEditFileTool::new( + let tool2 = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), + action_log.clone(), language_registry, )); - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool2.run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Update main function", @@ -2953,7 +2309,6 @@ mod tests { let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/root", json!({"src": {}})).await; - fs.save( path!("/root/src/main.rs").as_ref(), &"initial content".into(), @@ -2961,22 +2316,9 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let (tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); // Test with remove_trailing_whitespace_on_save enabled cx.update(|cx| { @@ -2996,20 +2338,14 @@ mod tests { let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Create main function".into(), - path: "root/src/main.rs".into(), - mode: StreamingEditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry.clone(), - )) - .run( - ToolInput::resolved(input), + tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Write, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -3041,22 +2377,23 @@ mod tests { }); }); + let tool2 = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + let result = cx .update(|cx| { - let input = StreamingEditFileToolInput { - display_description: "Update main function".into(), - path: "root/src/main.rs".into(), - mode: StreamingEditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }; - Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )) - .run( - ToolInput::resolved(input), + tool2.run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Write, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }), ToolCallEventStream::test().0, cx, ) @@ -3076,29 +2413,7 @@ mod tests { #[gpui::test] async fn test_streaming_authorize(cx: &mut TestAppContext) { - init_test(cx); - let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - fs.insert_tree("/root", json!({})).await; + let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; // Test 1: Path with .zed component should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); @@ -3219,27 +2534,8 @@ mod tests { fs.insert_tree("/outside", json!({})).await; fs.insert_symlink("/root/link", PathBuf::from("/outside")) .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3289,42 +2585,21 @@ mod tests { }), ) .await; - fs.insert_tree( - path!("/outside"), - json!({ - "config.txt": "old content" - }), - ) - .await; - fs.create_symlink( - path!("/root/link_to_external").as_ref(), - PathBuf::from("/outside"), - ) - .await - .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + fs.insert_tree( + path!("/outside"), + json!({ + "config.txt": "old content" + }), + ) + .await; + fs.create_symlink( + path!("/root/link_to_external").as_ref(), + PathBuf::from("/outside"), + ) + .await + .unwrap(); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let _authorize_task = cx.update(|cx| { @@ -3369,29 +2644,8 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let authorize_task = cx.update(|cx| { @@ -3446,29 +2700,8 @@ mod tests { ) .await .unwrap(); - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.executor().run_until_parked(); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let result = cx @@ -3497,26 +2730,8 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ ( @@ -3559,7 +2774,6 @@ mod tests { async fn test_streaming_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { init_test(cx); let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( "/workspace/frontend", json!({ @@ -3587,36 +2801,16 @@ mod tests { }), ) .await; - - let project = Project::test( - fs.clone(), - [ + let (tool, _project, _action_log, _fs, _thread) = setup_test_with_fs( + cx, + fs, + &[ path!("/workspace/frontend").as_ref(), path!("/workspace/backend").as_ref(), path!("/workspace/shared").as_ref(), ], - cx, ) .await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let test_cases = vec![ ("frontend/src/main.js", false, "File in first worktree"), @@ -3671,26 +2865,8 @@ mod tests { }), ) .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ ("", false, "Empty path is treated as project root"), @@ -3746,26 +2922,8 @@ mod tests { }), ) .await; - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let modes = vec![StreamingEditFileMode::Edit, StreamingEditFileMode::Write]; @@ -3816,26 +2974,9 @@ mod tests { async fn test_streaming_initial_title_with_partial_input(cx: &mut TestAppContext) { init_test(cx); let fs = project::FakeFs::new(cx.executor()); - let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let tool = Arc::new(StreamingEditFileTool::new( - project, - thread.downgrade(), - language_registry, - )); + fs.insert_tree("/project", json!({})).await; + let (tool, _project, _action_log, _fs, _thread) = + setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; cx.update(|cx| { assert_eq!( @@ -3890,33 +3031,15 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/", json!({"main.rs": ""})).await; - - let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; - let languages = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry.clone(), - Templates::new(), - Some(model.clone()), - cx, - ) - }); + let (tool, project, action_log, _fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); // Ensure the diff is finalized after the edit completes. { - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - languages.clone(), - )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit file".into(), path: path!("/main.rs").into(), @@ -3941,7 +3064,8 @@ mod tests { let tool = Arc::new(StreamingEditFileTool::new( project.clone(), thread.downgrade(), - languages.clone(), + action_log, + language_registry, )); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let edit = cx.update(|cx| { @@ -3968,38 +3092,12 @@ mod tests { #[gpui::test] async fn test_streaming_consecutive_edits_work(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); - let edit_tool = Arc::new(StreamingEditFileTool::new( + let (tool, project, action_log, _fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; + let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4020,7 +3118,7 @@ mod tests { // First edit should work let edit_result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "First edit".into(), path: "root/test.txt".into(), @@ -4045,7 +3143,7 @@ mod tests { // Second edit should also work because the edit updated the recorded read time let edit_result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Second edit".into(), path: "root/test.txt".into(), @@ -4070,38 +3168,12 @@ mod tests { #[gpui::test] async fn test_streaming_external_modification_detected(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); - let edit_tool = Arc::new(StreamingEditFileTool::new( + let (tool, project, action_log, fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; + let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4150,7 +3222,7 @@ mod tests { // Try to edit - should fail because file was modified externally let result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit after external change".into(), path: "root/test.txt".into(), @@ -4165,52 +3237,26 @@ mod tests { cx, ) }) - .await; - - let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { - panic!("expected error"); - }; - assert!( - error.contains("has been modified since you last read it"), - "Error should mention file modification, got: {}", - error - ); - } - - #[gpui::test] - async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "test.txt": "original content" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model.clone()), - cx, - ) - }); - let languages = project.read_with(cx, |project, _| project.languages().clone()); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + .await; + + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("has been modified since you last read it"), + "Error should mention file modification, got: {}", + error + ); + } - let read_tool = Arc::new(crate::ReadFileTool::new(project.clone(), action_log, true)); - let edit_tool = Arc::new(StreamingEditFileTool::new( + #[gpui::test] + async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { + let (tool, project, action_log, _fs, _thread) = + setup_test(cx, json!({"test.txt": "original content"})).await; + let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), - thread.downgrade(), - languages, + action_log.clone(), + true, )); // Read the file first @@ -4250,7 +3296,7 @@ mod tests { // Try to edit - should fail because buffer has unsaved changes let result = cx .update(|cx| { - edit_tool.clone().run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit with dirty buffer".into(), path: "root/test.txt".into(), @@ -4289,46 +3335,15 @@ mod tests { #[gpui::test] async fn test_streaming_overlapping_edits_resolved_sequentially(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); // Edit 1's replacement introduces text that contains edit 2's // old_text as a substring. Because edits resolve sequentially // against the current buffer, edit 2 finds a unique match in // the modified buffer and succeeds. - fs.insert_tree( - "/root", - json!({ - "file.txt": "aaa\nbbb\nccc\nddd\neee\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Setup: resolve the buffer sender.send_partial(json!({ @@ -4376,36 +3391,10 @@ mod tests { #[gpui::test] async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"dir": {}})).await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4473,42 +3462,14 @@ mod tests { #[gpui::test] async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "old line 1\nold line 2\nold line 3\n" - }), + let (tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, mut receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4566,42 +3527,14 @@ mod tests { #[gpui::test] async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "old line 1\nold line 2\nold line 3\n" - }), + let (tool, project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), ) .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); // Transition to BufferResolved sender.send_partial(json!({ @@ -4665,42 +3598,11 @@ mod tests { #[gpui::test] async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - "/root", - json!({ - "file.txt": "hello\nworld\nfoo\n" - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let model = Arc::new(FakeLanguageModel::default()); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - Some(model), - cx, - ) - }); - + let (tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; let (sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); - - let task = cx.update(|cx| tool.run(input, event_stream, cx)); + let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "display_description": "Edit", @@ -4750,47 +3652,17 @@ mod tests { // reports changed buffers so that the Accept All / Reject All review UI appears. #[gpui::test] async fn test_streaming_edit_file_tool_registers_changed_buffers(cx: &mut TestAppContext) { - init_test(cx); + let (tool, _project, action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); settings.tool_permissions.default = settings::ToolPermissionMode::Allow; agent_settings::AgentSettings::override_global(settings, cx); }); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "file.txt": "line 1\nline 2\nline 3\n" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - None, - cx, - ) - }); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Edit lines".to_string(), path: "root/file.txt".into(), @@ -4814,7 +3686,7 @@ mod tests { let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); assert!( !changed.is_empty(), - "action_log.changed_buffers() should be non-empty after streaming edit, \ + "action_log.changed_buffers() should be non-empty after streaming edit, but no changed buffers were found \u{2014} Accept All / Reject All will not appear" ); } @@ -4824,47 +3696,17 @@ mod tests { async fn test_streaming_edit_file_tool_write_mode_registers_changed_buffers( cx: &mut TestAppContext, ) { - init_test(cx); + let (tool, _project, action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "original content"})).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); settings.tool_permissions.default = settings::ToolPermissionMode::Allow; agent_settings::AgentSettings::override_global(settings, cx); }); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "file.txt": "original content" - }), - ) - .await; - - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); - let context_server_registry = - cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); - let thread = cx.new(|cx| { - crate::Thread::new( - project.clone(), - cx.new(|_cx| ProjectContext::default()), - context_server_registry, - Templates::new(), - None, - cx, - ) - }); - let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - - let tool = Arc::new(StreamingEditFileTool::new( - project.clone(), - thread.downgrade(), - language_registry, - )); let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.run( + tool.clone().run( ToolInput::resolved(StreamingEditFileToolInput { display_description: "Overwrite file".to_string(), path: "root/file.txt".into(), @@ -4890,6 +3732,58 @@ mod tests { ); } + async fn setup_test_with_fs( + cx: &mut TestAppContext, + fs: Arc, + worktree_paths: &[&std::path::Path], + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + let project = Project::test(fs.clone(), worktree_paths.iter().copied(), cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + (tool, project, action_log, fs, thread) + } + + async fn setup_test( + cx: &mut TestAppContext, + initial_tree: serde_json::Value, + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", initial_tree).await; + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx);