1mod edit_parser;
2#[cfg(test)]
3mod evals;
4
5use crate::{Template, Templates};
6use aho_corasick::AhoCorasick;
7use anyhow::Result;
8use assistant_tool::ActionLog;
9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
10use futures::{
11 Stream, StreamExt,
12 channel::mpsc::{self, UnboundedReceiver},
13 pin_mut,
14 stream::BoxStream,
15};
16use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
17use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
18use language_model::{
19 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
20 MessageContent, Role,
21};
22use project::{AgentLocation, Project};
23use serde::Serialize;
24use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
25use streaming_diff::{CharOperation, StreamingDiff};
26
/// Data serialized into the `create_file_prompt.hbs` template, used when the
/// model is asked to produce the full content of a buffer.
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// File path resolved from the buffer snapshot, if one was available.
    path: Option<PathBuf>,
    /// Caller-supplied description of the desired content, passed through as-is.
    edit_description: String,
}

impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
36
/// Data serialized into the `edit_file_prompt.hbs` template, used when the
/// model is asked to produce incremental edits to an existing buffer.
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// File path resolved from the buffer snapshot, if one was available.
    path: Option<PathBuf>,
    /// Caller-supplied description of the desired edit, passed through as-is.
    edit_description: String,
}

impl Template for EditFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
46
/// Progress notification sent to the caller while an edit or overwrite is
/// being applied to a buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// The buffer was modified (sent after each applied batch of changes).
    Edited,
    /// An old-text query could not be located in the buffer; carries the
    /// query text that failed to resolve.
    OldTextNotFound(SharedString),
}
52
/// Final result of an edit or overwrite run.
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk streamed back by the model.
    pub _raw_edits: String,
    /// Metrics reported by the edit parser when it finishes; overwrites do not
    /// run the parser and store the default value here.
    pub _parser_metrics: EditParserMetrics,
}
58
/// Drives a language model to edit or overwrite a buffer, streaming the
/// model's output into the buffer as it arrives.
#[derive(Clone)]
pub struct EditAgent {
    /// Model whose text completion stream supplies the edits.
    model: Arc<dyn LanguageModel>,
    /// Log that is told when the agent creates, reads, or edits a buffer.
    action_log: Entity<ActionLog>,
    /// Project whose agent location is updated as edits are applied.
    project: Entity<Project>,
    /// Template registry used to render the prompts sent to the model.
    templates: Arc<Templates>,
}
66
67impl EditAgent {
68 pub fn new(
69 model: Arc<dyn LanguageModel>,
70 project: Entity<Project>,
71 action_log: Entity<ActionLog>,
72 templates: Arc<Templates>,
73 ) -> Self {
74 EditAgent {
75 model,
76 project,
77 action_log,
78 templates,
79 }
80 }
81
82 pub fn overwrite(
83 &self,
84 buffer: Entity<Buffer>,
85 edit_description: String,
86 previous_messages: Vec<LanguageModelRequestMessage>,
87 cx: &mut AsyncApp,
88 ) -> (
89 Task<Result<EditAgentOutput>>,
90 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
91 ) {
92 let this = self.clone();
93 let (events_tx, events_rx) = mpsc::unbounded();
94 let output = cx.spawn(async move |cx| {
95 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
96 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
97 let prompt = CreateFilePromptTemplate {
98 path,
99 edit_description,
100 }
101 .render(&this.templates)?;
102 let new_chunks = this.request(previous_messages, prompt, cx).await?;
103
104 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
105 while let Some(event) = inner_events.next().await {
106 events_tx.unbounded_send(event).ok();
107 }
108 output.await
109 });
110 (output, events_rx)
111 }
112
113 fn overwrite_with_chunks(
114 &self,
115 buffer: Entity<Buffer>,
116 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
117 cx: &mut AsyncApp,
118 ) -> (
119 Task<Result<EditAgentOutput>>,
120 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
121 ) {
122 let (output_events_tx, output_events_rx) = mpsc::unbounded();
123 let this = self.clone();
124 let task = cx.spawn(async move |cx| {
125 this.action_log
126 .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
127 let output = this
128 .overwrite_with_chunks_internal(buffer, edit_chunks, output_events_tx, cx)
129 .await;
130 this.project
131 .update(cx, |project, cx| project.set_agent_location(None, cx))?;
132 output
133 });
134 (task, output_events_rx)
135 }
136
137 async fn overwrite_with_chunks_internal(
138 &self,
139 buffer: Entity<Buffer>,
140 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
141 output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
142 cx: &mut AsyncApp,
143 ) -> Result<EditAgentOutput> {
144 cx.update(|cx| {
145 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
146 self.action_log.update(cx, |log, cx| {
147 log.buffer_edited(buffer.clone(), cx);
148 });
149 self.project.update(cx, |project, cx| {
150 project.set_agent_location(
151 Some(AgentLocation {
152 buffer: buffer.downgrade(),
153 position: language::Anchor::MAX,
154 }),
155 cx,
156 )
157 });
158 output_events_tx
159 .unbounded_send(EditAgentOutputEvent::Edited)
160 .ok();
161 })?;
162
163 let mut raw_edits = String::new();
164 pin_mut!(edit_chunks);
165 while let Some(chunk) = edit_chunks.next().await {
166 let chunk = chunk?;
167 raw_edits.push_str(&chunk);
168 cx.update(|cx| {
169 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
170 self.action_log
171 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
172 self.project.update(cx, |project, cx| {
173 project.set_agent_location(
174 Some(AgentLocation {
175 buffer: buffer.downgrade(),
176 position: language::Anchor::MAX,
177 }),
178 cx,
179 )
180 });
181 })?;
182 output_events_tx
183 .unbounded_send(EditAgentOutputEvent::Edited)
184 .ok();
185 }
186
187 Ok(EditAgentOutput {
188 _raw_edits: raw_edits,
189 _parser_metrics: EditParserMetrics::default(),
190 })
191 }
192
193 pub fn edit(
194 &self,
195 buffer: Entity<Buffer>,
196 edit_description: String,
197 previous_messages: Vec<LanguageModelRequestMessage>,
198 cx: &mut AsyncApp,
199 ) -> (
200 Task<Result<EditAgentOutput>>,
201 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
202 ) {
203 self.project
204 .update(cx, |project, cx| {
205 project.set_agent_location(
206 Some(AgentLocation {
207 buffer: buffer.downgrade(),
208 position: language::Anchor::MIN,
209 }),
210 cx,
211 );
212 })
213 .ok();
214
215 let this = self.clone();
216 let (events_tx, events_rx) = mpsc::unbounded();
217 let output = cx.spawn(async move |cx| {
218 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
219 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
220 let prompt = EditFilePromptTemplate {
221 path,
222 edit_description,
223 }
224 .render(&this.templates)?;
225 let edit_chunks = this.request(previous_messages, prompt, cx).await?;
226
227 let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
228 while let Some(event) = inner_events.next().await {
229 events_tx.unbounded_send(event).ok();
230 }
231 output.await
232 });
233 (output, events_rx)
234 }
235
236 fn apply_edit_chunks(
237 &self,
238 buffer: Entity<Buffer>,
239 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
240 cx: &mut AsyncApp,
241 ) -> (
242 Task<Result<EditAgentOutput>>,
243 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
244 ) {
245 let (output_events_tx, output_events_rx) = mpsc::unbounded();
246 let this = self.clone();
247 let task = cx.spawn(async move |mut cx| {
248 this.action_log
249 .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;
250 let output = this
251 .apply_edit_chunks_internal(buffer, edit_chunks, output_events_tx, &mut cx)
252 .await;
253 this.project
254 .update(cx, |project, cx| project.set_agent_location(None, cx))?;
255 output
256 });
257 (task, output_events_rx)
258 }
259
    /// Parses `edit_chunks` into old-text/new-text edits and applies them to
    /// `buffer` one edit at a time.
    ///
    /// For each `OldText` event the query is resolved against a fresh snapshot
    /// of the buffer. If it resolves, the following new-text chunks are
    /// re-indented, diffed against the old text on a background task, and the
    /// resulting character operations are applied to the buffer in batches,
    /// sending `Edited` after each batch. If it does not resolve,
    /// `OldTextNotFound` is sent and the edit is skipped. Parser or stream
    /// errors abort the run. Returns the output of the parsing task.
    async fn apply_edit_chunks_internal(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        while let Some(edit_event) = edit_events.next().await {
            // Only an `OldText` event starts an edit. Anything else seen here
            // (such as new-text chunks of an edit whose old text was not
            // found) is skipped.
            let EditParserEvent::OldText(old_text_query) = edit_event? else {
                continue;
            };
            let old_text_query = SharedString::from(old_text_query);

            let (edits_tx, edits_rx) = mpsc::unbounded();
            // Take the snapshot per edit so that earlier edits in this run are
            // visible when resolving later ones.
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            let old_range = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let old_text_query = old_text_query.clone();
                    async move { Self::resolve_location(&snapshot, &old_text_query) }
                })
                .await;
            let Some(old_range) = old_range else {
                // We couldn't find the old text in the buffer. Report the error.
                output_events
                    .unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
                    .ok();
                continue;
            };

            // The background task takes ownership of `edit_events` and hands it
            // back once this edit's new text has been fully consumed.
            let compute_edits = cx.background_spawn(async move {
                // Compare the indentation of the matched buffer line with the
                // indentation of the first line of the query to work out how
                // much the new text must be shifted.
                let buffer_start_indent =
                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
                let old_text_start_indent = old_text_query
                    .lines()
                    .next()
                    .map_or(buffer_start_indent, |line| {
                        LineIndent::from_iter(line.chars())
                    });
                // If the buffer line has any tabs the delta is measured in
                // tabs, otherwise in spaces.
                let indent_delta = if buffer_start_indent.tabs > 0 {
                    IndentDelta::Tabs(
                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
                    )
                } else {
                    IndentDelta::Spaces(
                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
                    )
                };

                let old_text = snapshot
                    .text_for_range(old_range.clone())
                    .collect::<String>();
                let mut diff = StreamingDiff::new(old_text);
                // Offset into `snapshot` (the pre-edit text) of the next
                // character the diff has not yet kept or deleted.
                let mut edit_start = old_range.start;
                let mut new_text_chunks =
                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
                let mut done = false;
                while !done {
                    // Feed each new-text chunk into the diff; when the chunks
                    // run out, finish the diff to flush its remaining
                    // operations.
                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
                    {
                        diff.push_new(&new_text_chunk?)
                    } else {
                        done = true;
                        mem::take(&mut diff).finish()
                    };

                    // Translate each operation into an anchored buffer edit and
                    // send it to the foreground loop below.
                    for op in char_operations {
                        match op {
                            CharOperation::Insert { text } => {
                                let edit_start = snapshot.anchor_after(edit_start);
                                edits_tx
                                    .unbounded_send((edit_start..edit_start, Arc::from(text)))?;
                            }
                            CharOperation::Delete { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                edits_tx.unbounded_send((edit_range, Arc::from("")))?;
                            }
                            CharOperation::Keep { bytes } => edit_start += bytes,
                        }
                    }
                }

                // Release the borrow of `edit_events` so it can be returned.
                drop(new_text_chunks);
                anyhow::Ok(edit_events)
            });

            // TODO: group all edits into one transaction
            // Apply whatever edits are ready, up to 32 at a time.
            let mut edits_rx = edits_rx.ready_chunks(32);
            while let Some(edits) = edits_rx.next().await {
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    let max_edit_end = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        // `edits` is non-empty (checked above), so `max()`
                        // always yields a value.
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        buffer.anchor_before(max_edit_end)
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    // Move the agent location to the furthest point touched by
                    // this batch.
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Recover the event stream (or the task's error) before looking
            // for the next edit.
            edit_events = compute_edits.await?;
        }

        output.await
    }
393
394 fn parse_edit_chunks(
395 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
396 cx: &mut AsyncApp,
397 ) -> (
398 Task<Result<EditAgentOutput>>,
399 UnboundedReceiver<Result<EditParserEvent>>,
400 ) {
401 let (tx, rx) = mpsc::unbounded();
402 let output = cx.background_spawn(async move {
403 pin_mut!(chunks);
404
405 let mut parser = EditParser::new();
406 let mut raw_edits = String::new();
407 while let Some(chunk) = chunks.next().await {
408 match chunk {
409 Ok(chunk) => {
410 raw_edits.push_str(&chunk);
411 for event in parser.push(&chunk) {
412 tx.unbounded_send(Ok(event))?;
413 }
414 }
415 Err(error) => {
416 tx.unbounded_send(Err(error.into()))?;
417 }
418 }
419 }
420 Ok(EditAgentOutput {
421 _raw_edits: raw_edits,
422 _parser_metrics: parser.finish(),
423 })
424 });
425 (output, rx)
426 }
427
428 fn reindent_new_text_chunks(
429 delta: IndentDelta,
430 mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
431 ) -> impl Stream<Item = Result<String>> {
432 let mut buffer = String::new();
433 let mut in_leading_whitespace = true;
434 let mut done = false;
435 futures::stream::poll_fn(move |cx| {
436 while !done {
437 let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
438 Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
439 (chunk, done)
440 }
441 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
442 Poll::Pending => return Poll::Pending,
443 _ => return Poll::Ready(None),
444 };
445
446 buffer.push_str(&chunk);
447
448 let mut indented_new_text = String::new();
449 let mut start_ix = 0;
450 let mut newlines = buffer.match_indices('\n').peekable();
451 loop {
452 let (line_end, is_pending_line) = match newlines.next() {
453 Some((ix, _)) => (ix, false),
454 None => (buffer.len(), true),
455 };
456 let line = &buffer[start_ix..line_end];
457
458 if in_leading_whitespace {
459 if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
460 // We found a non-whitespace character, adjust
461 // indentation based on the delta.
462 let new_indent_len =
463 cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
464 indented_new_text
465 .extend(iter::repeat(delta.character()).take(new_indent_len));
466 indented_new_text.push_str(&line[non_whitespace_ix..]);
467 in_leading_whitespace = false;
468 } else if is_pending_line {
469 // We're still in leading whitespace and this line is incomplete.
470 // Stop processing until we receive more input.
471 break;
472 } else {
473 // This line is entirely whitespace. Push it without indentation.
474 indented_new_text.push_str(line);
475 }
476 } else {
477 indented_new_text.push_str(line);
478 }
479
480 if is_pending_line {
481 start_ix = line_end;
482 break;
483 } else {
484 in_leading_whitespace = true;
485 indented_new_text.push('\n');
486 start_ix = line_end + 1;
487 }
488 }
489 buffer.replace_range(..start_ix, "");
490
491 // This was the last chunk, push all the buffered content as-is.
492 if is_last_chunk {
493 indented_new_text.push_str(&buffer);
494 buffer.clear();
495 done = true;
496 }
497
498 if !indented_new_text.is_empty() {
499 return Poll::Ready(Some(Ok(indented_new_text)));
500 }
501 }
502
503 Poll::Ready(None)
504 })
505 }
506
507 async fn request(
508 &self,
509 mut messages: Vec<LanguageModelRequestMessage>,
510 prompt: String,
511 cx: &mut AsyncApp,
512 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
513 let mut message_content = Vec::new();
514 if let Some(last_message) = messages.last_mut() {
515 if last_message.role == Role::Assistant {
516 last_message
517 .content
518 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
519 if last_message.content.is_empty() {
520 messages.pop();
521 }
522 }
523 }
524 message_content.push(MessageContent::Text(prompt));
525 messages.push(LanguageModelRequestMessage {
526 role: Role::User,
527 content: message_content,
528 cache: false,
529 });
530
531 let request = LanguageModelRequest {
532 messages,
533 ..Default::default()
534 };
535 Ok(self.model.stream_completion_text(request, cx).await?.stream)
536 }
537
538 fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
539 let range = Self::resolve_location_exact(buffer, search_query)
540 .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;
541
542 // Expand the range to include entire lines.
543 let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
544 start.column = 0;
545 let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
546 if end.column > 0 {
547 end.column = buffer.line_len(end.row);
548 }
549
550 Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
551 }
552
553 fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
554 let search = AhoCorasick::new([search_query]).ok()?;
555 let mat = search
556 .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
557 .next()?
558 .expect("buffer can't error");
559 Some(mat.range())
560 }
561
    /// Finds the run of buffer lines that best matches the lines of
    /// `search_query`, tolerating small per-line differences and missing or
    /// extra lines.
    ///
    /// Lines are compared after trimming, using `fuzzy_eq`. A cost matrix is
    /// filled with one row per query line and one column per buffer line, the
    /// cheapest end column is chosen, and the path is traced back to find the
    /// start row. The match is accepted only if the number of diagonal steps
    /// is at least 80% of the larger of the query line count and the matched
    /// buffer line count. Returns the byte range from the start of the first
    /// matched line to the end of the last one.
    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        // Cost of a buffer line that has no counterpart in the query.
        const INSERTION_COST: u32 = 3;
        // Cost of a query line that has no counterpart in the buffer.
        const DELETION_COST: u32 = 10;

        let buffer_line_count = buffer.max_point().row as usize + 1;
        let query_line_count = search_query.lines().count();
        // Row 0 and column 0 represent "no query lines" / "no buffer lines"
        // consumed. Row 0 keeps the matrix's initial zero cost, so a match may
        // begin at any buffer line for free.
        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
        let mut leading_deletion_cost = 0_u32;
        for (row, query_line) in search_query.lines().enumerate() {
            let query_line = query_line.trim();
            // Column 0: every query line so far was dropped.
            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
            matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
            );

            // NOTE(review): assumes `lines()` yields exactly `buffer_line_count`
            // items; if it yields fewer, the trailing columns keep their
            // initial zero cost — TODO confirm.
            let mut buffer_lines = buffer.as_rope().chunks().lines();
            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                // Skip this query line.
                let up = SearchState::new(
                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                // Skip this buffer line.
                let left = SearchState::new(
                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                // Pair the two lines: free when they are fuzzily equal,
                // otherwise charged as a deletion plus an insertion.
                let diagonal = SearchState::new(
                    if fuzzy_eq(query_line, buffer_line) {
                        matrix.get(row, col).cost
                    } else {
                        matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // `min` uses the derived ordering of `SearchState`: lowest cost
                // first, ties broken by direction order.
                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Traceback to find the best match
        // Pick the first column with the lowest cost in the final row as the
        // (exclusive) end row of the match.
        let mut buffer_row_end = buffer_line_count as u32;
        let mut best_cost = u32::MAX;
        for col in 1..=buffer_line_count {
            let cost = matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                buffer_row_end = col as u32;
            }
        }

        // Walk the recorded directions back until every query line has been
        // accounted for (or the first buffer line is reached).
        // NOTE(review): a `Diagonal` step is recorded for both fuzzily-equal and
        // mismatched line pairs, so `matched_lines` counts both — confirm this
        // is intended.
        let mut matched_lines = 0;
        let mut query_row = query_line_count;
        let mut buffer_row_start = buffer_row_end;
        while query_row > 0 && buffer_row_start > 0 {
            let current = matrix.get(query_row, buffer_row_start as usize);
            match current.direction {
                SearchDirection::Diagonal => {
                    query_row -= 1;
                    buffer_row_start -= 1;
                    matched_lines += 1;
                }
                SearchDirection::Up => {
                    query_row -= 1;
                }
                SearchDirection::Left => {
                    buffer_row_start -= 1;
                }
            }
        }

        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
        // With an empty query both counts are zero and the ratio is NaN, which
        // fails the comparison below and yields `None`.
        let matched_ratio =
            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
        if matched_ratio >= 0.8 {
            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
            let buffer_end_ix = buffer.point_to_offset(Point::new(
                buffer_row_end - 1,
                buffer.line_len(buffer_row_end - 1),
            ));
            Some(buffer_start_ix..buffer_end_ix)
        } else {
            None
        }
    }
652}
653
654fn fuzzy_eq(left: &str, right: &str) -> bool {
655 let min_levenshtein = left.len().abs_diff(right.len());
656 let min_normalized_levenshtein =
657 1. - (min_levenshtein as f32 / cmp::max(left.len(), right.len()) as f32);
658 if min_normalized_levenshtein < 0.8 {
659 return false;
660 }
661
662 strsim::normalized_levenshtein(left, right) >= 0.8
663}
664
/// Signed change in indentation to apply to each line of new text, expressed
/// in either spaces or tabs.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    /// Shift by this many space characters (negative removes indentation).
    Spaces(isize),
    /// Shift by this many tab characters (negative removes indentation).
    Tabs(isize),
}

impl IndentDelta {
    /// The indentation character this delta is measured in.
    fn character(&self) -> char {
        match *self {
            IndentDelta::Tabs(_) => '\t',
            IndentDelta::Spaces(_) => ' ',
        }
    }

    /// The signed number of indentation characters to add.
    fn len(&self) -> isize {
        match *self {
            IndentDelta::Spaces(amount) | IndentDelta::Tabs(amount) => amount,
        }
    }
}
686
/// Step taken to reach a cell of the fuzzy-search cost matrix.
///
/// The variant order matters: it is the tie-breaker when two `SearchState`s
/// with equal cost are compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Came from the previous row, same column.
    Up,
    /// Came from the previous column, same row.
    Left,
    /// Came from the previous row and previous column.
    Diagonal,
}

/// One cell of the cost matrix: the accumulated cost and the step that
/// produced it. Ordered by cost first, then by direction.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    /// Pairs a cost with the direction it was reached from.
    fn new(cost: u32, direction: SearchDirection) -> Self {
        SearchState { cost, direction }
    }
}

/// Dense, row-major matrix of `SearchState` cells.
struct SearchMatrix {
    /// Number of columns per row, used to compute flat indices.
    cols: usize,
    /// Cells stored row after row.
    data: Vec<SearchState>,
}

impl SearchMatrix {
    /// Creates a `rows` x `cols` matrix with every cell at cost zero and
    /// direction `Diagonal`.
    fn new(rows: usize, cols: usize) -> Self {
        let initial = SearchState::new(0, SearchDirection::Diagonal);
        Self {
            cols,
            data: vec![initial; rows * cols],
        }
    }

    /// Flat index of the cell at (`row`, `col`).
    fn index_of(&self, row: usize, col: usize) -> usize {
        row * self.cols + col
    }

    /// Returns a copy of the cell at (`row`, `col`); panics if the flat index
    /// is out of bounds.
    fn get(&self, row: usize, col: usize) -> SearchState {
        self.data[self.index_of(row, col)]
    }

    /// Overwrites the cell at (`row`, `col`); panics if the flat index is out
    /// of bounds.
    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
        let ix = self.index_of(row, col);
        self.data[ix] = cost;
    }
}
727
728#[cfg(test)]
729mod tests {
730 use super::*;
731 use fs::FakeFs;
732 use futures::stream;
733 use gpui::{App, AppContext, TestAppContext};
734 use indoc::indoc;
735 use language_model::fake_provider::FakeLanguageModel;
736 use project::{AgentLocation, Project};
737 use rand::prelude::*;
738 use rand::rngs::StdRng;
739 use std::cmp;
740 use unindent::Unindent;
741 use util::test::{generate_marked_text, marked_text_ranges};
742
743 #[gpui::test(iterations = 100)]
744 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
745 let agent = init_test(cx).await;
746 let buffer = cx.new(|cx| {
747 Buffer::local(
748 indoc! {"
749 lorem
750 ipsum
751 dolor
752 sit
753 "},
754 cx,
755 )
756 });
757 let raw_edits = simulate_llm_output(
758 indoc! {"
759 <old_text>
760 ipsum
761 dolor
762 sit
763 </old_text>
764 <new_text>
765 ipsum
766 dolor
767 sit
768 amet
769 </new_text>
770 "},
771 &mut rng,
772 cx,
773 );
774 let (apply, _events) =
775 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
776 apply.await.unwrap();
777 pretty_assertions::assert_eq!(
778 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
779 indoc! {"
780 lorem
781 ipsum
782 dolor
783 sit
784 amet
785 "}
786 );
787 }
788
789 #[gpui::test(iterations = 100)]
790 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
791 let agent = init_test(cx).await;
792 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
793 let raw_edits = simulate_llm_output(
794 indoc! {"
795 <old_text>
796 def
797 </old_text>
798 <new_text>
799 DEF
800 </new_text>
801
802 <old_text>
803 DEF
804 </old_text>
805 <new_text>
806 DeF
807 </new_text>
808 "},
809 &mut rng,
810 cx,
811 );
812 let (apply, _events) =
813 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
814 apply.await.unwrap();
815 assert_eq!(
816 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
817 "abc\nDeF\nghi"
818 );
819 }
820
821 #[gpui::test(iterations = 100)]
822 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
823 let agent = init_test(cx).await;
824 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
825 let raw_edits = simulate_llm_output(
826 indoc! {"
827 <old_text>
828 jkl
829 </old_text>
830 <new_text>
831 mno
832 </new_text>
833
834 <old_text>
835 abc
836 </old_text>
837 <new_text>
838 ABC
839 </new_text>
840 "},
841 &mut rng,
842 cx,
843 );
844 let (apply, _events) =
845 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
846 apply.await.unwrap();
847 assert_eq!(
848 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
849 "ABC\ndef\nghi"
850 );
851 }
852
853 #[gpui::test]
854 async fn test_edit_events(cx: &mut TestAppContext) {
855 let agent = init_test(cx).await;
856 let project = agent
857 .action_log
858 .read_with(cx, |log, _| log.project().clone());
859 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
860 let (chunks_tx, chunks_rx) = mpsc::unbounded();
861 let (apply, mut events) = agent.apply_edit_chunks(
862 buffer.clone(),
863 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
864 &mut cx.to_async(),
865 );
866
867 chunks_tx.unbounded_send("<old_text>a").unwrap();
868 cx.run_until_parked();
869 assert_eq!(drain_events(&mut events), vec![]);
870 assert_eq!(
871 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
872 "abc\ndef\nghi"
873 );
874 assert_eq!(
875 project.read_with(cx, |project, _| project.agent_location()),
876 None
877 );
878
879 chunks_tx.unbounded_send("bc</old_text>").unwrap();
880 cx.run_until_parked();
881 assert_eq!(drain_events(&mut events), vec![]);
882 assert_eq!(
883 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
884 "abc\ndef\nghi"
885 );
886 assert_eq!(
887 project.read_with(cx, |project, _| project.agent_location()),
888 None
889 );
890
891 chunks_tx.unbounded_send("<new_text>abX").unwrap();
892 cx.run_until_parked();
893 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
894 assert_eq!(
895 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
896 "abXc\ndef\nghi"
897 );
898 assert_eq!(
899 project.read_with(cx, |project, _| project.agent_location()),
900 Some(AgentLocation {
901 buffer: buffer.downgrade(),
902 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
903 })
904 );
905
906 chunks_tx.unbounded_send("cY").unwrap();
907 cx.run_until_parked();
908 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
909 assert_eq!(
910 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
911 "abXcY\ndef\nghi"
912 );
913 assert_eq!(
914 project.read_with(cx, |project, _| project.agent_location()),
915 Some(AgentLocation {
916 buffer: buffer.downgrade(),
917 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
918 })
919 );
920
921 chunks_tx.unbounded_send("</new_text>").unwrap();
922 chunks_tx.unbounded_send("<old_text>hall").unwrap();
923 cx.run_until_parked();
924 assert_eq!(drain_events(&mut events), vec![]);
925 assert_eq!(
926 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
927 "abXcY\ndef\nghi"
928 );
929 assert_eq!(
930 project.read_with(cx, |project, _| project.agent_location()),
931 Some(AgentLocation {
932 buffer: buffer.downgrade(),
933 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
934 })
935 );
936
937 chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
938 chunks_tx.unbounded_send("<new_text>").unwrap();
939 cx.run_until_parked();
940 assert_eq!(
941 drain_events(&mut events),
942 vec![EditAgentOutputEvent::OldTextNotFound(
943 "hallucinated old".into()
944 )]
945 );
946 assert_eq!(
947 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
948 "abXcY\ndef\nghi"
949 );
950 assert_eq!(
951 project.read_with(cx, |project, _| project.agent_location()),
952 Some(AgentLocation {
953 buffer: buffer.downgrade(),
954 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
955 })
956 );
957
958 chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
959 chunks_tx.unbounded_send("text>").unwrap();
960 cx.run_until_parked();
961 assert_eq!(drain_events(&mut events), vec![]);
962 assert_eq!(
963 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
964 "abXcY\ndef\nghi"
965 );
966 assert_eq!(
967 project.read_with(cx, |project, _| project.agent_location()),
968 Some(AgentLocation {
969 buffer: buffer.downgrade(),
970 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
971 })
972 );
973
974 chunks_tx.unbounded_send("<old_text>gh").unwrap();
975 chunks_tx.unbounded_send("i</old_text>").unwrap();
976 chunks_tx.unbounded_send("<new_text>").unwrap();
977 cx.run_until_parked();
978 assert_eq!(drain_events(&mut events), vec![]);
979 assert_eq!(
980 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
981 "abXcY\ndef\nghi"
982 );
983 assert_eq!(
984 project.read_with(cx, |project, _| project.agent_location()),
985 Some(AgentLocation {
986 buffer: buffer.downgrade(),
987 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
988 })
989 );
990
991 chunks_tx.unbounded_send("GHI</new_text>").unwrap();
992 cx.run_until_parked();
993 assert_eq!(
994 drain_events(&mut events),
995 vec![EditAgentOutputEvent::Edited]
996 );
997 assert_eq!(
998 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
999 "abXcY\ndef\nGHI"
1000 );
1001 assert_eq!(
1002 project.read_with(cx, |project, _| project.agent_location()),
1003 Some(AgentLocation {
1004 buffer: buffer.downgrade(),
1005 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1006 })
1007 );
1008
1009 drop(chunks_tx);
1010 apply.await.unwrap();
1011 assert_eq!(
1012 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1013 "abXcY\ndef\nGHI"
1014 );
1015 assert_eq!(drain_events(&mut events), vec![]);
1016 assert_eq!(
1017 project.read_with(cx, |project, _| project.agent_location()),
1018 None
1019 );
1020 }
1021
1022 #[gpui::test]
1023 async fn test_overwrite_events(cx: &mut TestAppContext) {
1024 let agent = init_test(cx).await;
1025 let project = agent
1026 .action_log
1027 .read_with(cx, |log, _| log.project().clone());
1028 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1029 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1030 let (apply, mut events) = agent.overwrite_with_chunks(
1031 buffer.clone(),
1032 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1033 &mut cx.to_async(),
1034 );
1035
1036 cx.run_until_parked();
1037 assert_eq!(
1038 drain_events(&mut events),
1039 vec![EditAgentOutputEvent::Edited]
1040 );
1041 assert_eq!(
1042 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1043 ""
1044 );
1045 assert_eq!(
1046 project.read_with(cx, |project, _| project.agent_location()),
1047 Some(AgentLocation {
1048 buffer: buffer.downgrade(),
1049 position: language::Anchor::MAX
1050 })
1051 );
1052
1053 chunks_tx.unbounded_send("jkl\n").unwrap();
1054 cx.run_until_parked();
1055 assert_eq!(
1056 drain_events(&mut events),
1057 vec![EditAgentOutputEvent::Edited]
1058 );
1059 assert_eq!(
1060 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1061 "jkl\n"
1062 );
1063 assert_eq!(
1064 project.read_with(cx, |project, _| project.agent_location()),
1065 Some(AgentLocation {
1066 buffer: buffer.downgrade(),
1067 position: language::Anchor::MAX
1068 })
1069 );
1070
1071 chunks_tx.unbounded_send("mno\n").unwrap();
1072 cx.run_until_parked();
1073 assert_eq!(
1074 drain_events(&mut events),
1075 vec![EditAgentOutputEvent::Edited]
1076 );
1077 assert_eq!(
1078 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1079 "jkl\nmno\n"
1080 );
1081 assert_eq!(
1082 project.read_with(cx, |project, _| project.agent_location()),
1083 Some(AgentLocation {
1084 buffer: buffer.downgrade(),
1085 position: language::Anchor::MAX
1086 })
1087 );
1088
1089 chunks_tx.unbounded_send("pqr").unwrap();
1090 cx.run_until_parked();
1091 assert_eq!(
1092 drain_events(&mut events),
1093 vec![EditAgentOutputEvent::Edited]
1094 );
1095 assert_eq!(
1096 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1097 "jkl\nmno\npqr"
1098 );
1099 assert_eq!(
1100 project.read_with(cx, |project, _| project.agent_location()),
1101 Some(AgentLocation {
1102 buffer: buffer.downgrade(),
1103 position: language::Anchor::MAX
1104 })
1105 );
1106
1107 drop(chunks_tx);
1108 apply.await.unwrap();
1109 assert_eq!(
1110 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1111 "jkl\nmno\npqr"
1112 );
1113 assert_eq!(drain_events(&mut events), vec![]);
1114 assert_eq!(
1115 project.read_with(cx, |project, _| project.agent_location()),
1116 None
1117 );
1118 }
1119
1120 #[gpui::test]
1121 fn test_resolve_location(cx: &mut App) {
1122 assert_location_resolution(
1123 concat!(
1124 " Lorem\n",
1125 "« ipsum»\n",
1126 " dolor sit amet\n",
1127 " consecteur",
1128 ),
1129 "ipsum",
1130 cx,
1131 );
1132
1133 assert_location_resolution(
1134 concat!(
1135 " Lorem\n",
1136 "« ipsum\n",
1137 " dolor sit amet»\n",
1138 " consecteur",
1139 ),
1140 "ipsum\ndolor sit amet",
1141 cx,
1142 );
1143
1144 assert_location_resolution(
1145 &"
1146 «fn foo1(a: usize) -> usize {
1147 40
1148 }»
1149
1150 fn foo2(b: usize) -> usize {
1151 42
1152 }
1153 "
1154 .unindent(),
1155 "fn foo1(a: usize) -> u32 {\n40\n}",
1156 cx,
1157 );
1158
1159 assert_location_resolution(
1160 &"
1161 class Something {
1162 one() { return 1; }
1163 « two() { return 2222; }
1164 three() { return 333; }
1165 four() { return 4444; }
1166 five() { return 5555; }
1167 six() { return 6666; }»
1168 seven() { return 7; }
1169 eight() { return 8; }
1170 }
1171 "
1172 .unindent(),
1173 &"
1174 two() { return 2222; }
1175 four() { return 4444; }
1176 five() { return 5555; }
1177 six() { return 6666; }
1178 "
1179 .unindent(),
1180 cx,
1181 );
1182
1183 assert_location_resolution(
1184 &"
1185 use std::ops::Range;
1186 use std::sync::Mutex;
1187 use std::{
1188 collections::HashMap,
1189 env,
1190 ffi::{OsStr, OsString},
1191 fs,
1192 io::{BufRead, BufReader},
1193 mem,
1194 path::{Path, PathBuf},
1195 process::Command,
1196 sync::LazyLock,
1197 time::SystemTime,
1198 };
1199 "
1200 .unindent(),
1201 &"
1202 use std::collections::{HashMap, HashSet};
1203 use std::ffi::{OsStr, OsString};
1204 use std::fmt::Write as _;
1205 use std::fs;
1206 use std::io::{BufReader, Read, Write};
1207 use std::mem;
1208 use std::path::{Path, PathBuf};
1209 use std::process::Command;
1210 use std::sync::Arc;
1211 "
1212 .unindent(),
1213 cx,
1214 );
1215
1216 assert_location_resolution(
1217 indoc! {"
1218 impl Foo {
1219 fn new() -> Self {
1220 Self {
1221 subscriptions: vec![
1222 cx.observe_window_activation(window, |editor, window, cx| {
1223 let active = window.is_window_active();
1224 editor.blink_manager.update(cx, |blink_manager, cx| {
1225 if active {
1226 blink_manager.enable(cx);
1227 } else {
1228 blink_manager.disable(cx);
1229 }
1230 });
1231 }),
1232 ];
1233 }
1234 }
1235 }
1236 "},
1237 concat!(
1238 " editor.blink_manager.update(cx, |blink_manager, cx| {\n",
1239 " blink_manager.enable(cx);\n",
1240 " });",
1241 ),
1242 cx,
1243 );
1244
1245 assert_location_resolution(
1246 indoc! {r#"
1247 let tool = cx
1248 .update(|cx| working_set.tool(&tool_name, cx))
1249 .map_err(|err| {
1250 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
1251 })?;
1252
1253 let Some(tool) = tool else {
1254 return Err(anyhow!("Tool '{}' not found", tool_name));
1255 };
1256
1257 let project = project.clone();
1258 let action_log = action_log.clone();
1259 let messages = messages.clone();
1260 let tool_result = cx
1261 .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
1262 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
1263
1264 tasks.push(tool_result.output);
1265 "#},
1266 concat!(
1267 "let tool_result = cx\n",
1268 " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
1269 " .output;",
1270 ),
1271 cx,
1272 );
1273 }
1274
1275 #[gpui::test(iterations = 100)]
1276 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1277 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1278 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1279 Ok(EditParserEvent::NewTextChunk {
1280 chunk: chunk.clone(),
1281 done: index == chunks.len() - 1,
1282 })
1283 }));
1284 let indented_chunks =
1285 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1286 .collect::<Vec<_>>()
1287 .await;
1288 let new_text = indented_chunks
1289 .into_iter()
1290 .collect::<Result<String>>()
1291 .unwrap();
1292 assert_eq!(new_text, " abc\n def\n ghi");
1293 }
1294
1295 #[gpui::test(iterations = 100)]
1296 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1297 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1298 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1299 Ok(EditParserEvent::NewTextChunk {
1300 chunk: chunk.clone(),
1301 done: index == chunks.len() - 1,
1302 })
1303 }));
1304 let indented_chunks =
1305 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1306 .collect::<Vec<_>>()
1307 .await;
1308 let new_text = indented_chunks
1309 .into_iter()
1310 .collect::<Result<String>>()
1311 .unwrap();
1312 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1313 }
1314
1315 #[gpui::test(iterations = 100)]
1316 async fn test_random_indents(mut rng: StdRng) {
1317 let len = rng.gen_range(1..=100);
1318 let new_text = util::RandomCharIter::new(&mut rng)
1319 .with_simple_text()
1320 .take(len)
1321 .collect::<String>();
1322 let new_text = new_text
1323 .split('\n')
1324 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1325 .collect::<Vec<_>>()
1326 .join("\n");
1327 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1328
1329 let chunks = to_random_chunks(&mut rng, &new_text);
1330 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1331 Ok(EditParserEvent::NewTextChunk {
1332 chunk: chunk.clone(),
1333 done: index == chunks.len() - 1,
1334 })
1335 }));
1336 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1337 .collect::<Vec<_>>()
1338 .await;
1339 let actual_reindented_text = reindented_chunks
1340 .into_iter()
1341 .collect::<Result<String>>()
1342 .unwrap();
1343 let expected_reindented_text = new_text
1344 .split('\n')
1345 .map(|line| {
1346 if let Some(ix) = line.find(|c| c != ' ') {
1347 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1348 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1349 } else {
1350 line.to_string()
1351 }
1352 })
1353 .collect::<Vec<_>>()
1354 .join("\n");
1355 assert_eq!(actual_reindented_text, expected_reindented_text);
1356 }
1357
1358 #[track_caller]
1359 fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1360 let (text, _) = marked_text_ranges(text_with_expected_range, false);
1361 let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1362 let snapshot = buffer.read(cx).snapshot();
1363 let mut ranges = Vec::new();
1364 ranges.extend(EditAgent::resolve_location(&snapshot, query));
1365 let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1366 pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1367 }
1368
1369 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1370 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1371 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1372 chunk_indices.sort();
1373 chunk_indices.push(input.len());
1374
1375 let mut chunks = Vec::new();
1376 let mut last_ix = 0;
1377 for chunk_ix in chunk_indices {
1378 chunks.push(input[last_ix..chunk_ix].to_string());
1379 last_ix = chunk_ix;
1380 }
1381 chunks
1382 }
1383
1384 fn simulate_llm_output(
1385 output: &str,
1386 rng: &mut StdRng,
1387 cx: &mut TestAppContext,
1388 ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1389 let executor = cx.executor();
1390 stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1391 let executor = executor.clone();
1392 async move {
1393 executor.simulate_random_delay().await;
1394 chunk
1395 }
1396 })
1397 }
1398
1399 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1400 cx.update(settings::init);
1401 cx.update(Project::init_settings);
1402 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1403 let model = Arc::new(FakeLanguageModel::default());
1404 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1405 EditAgent::new(model, project, action_log, Templates::new())
1406 }
1407
1408 fn drain_events(
1409 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1410 ) -> Vec<EditAgentOutputEvent> {
1411 let mut events = Vec::new();
1412 while let Ok(Some(event)) = stream.try_next() {
1413 events.push(event);
1414 }
1415 events
1416 }
1417}