1mod edit_parser;
2#[cfg(test)]
3mod evals;
4
5use crate::{Template, Templates};
6use aho_corasick::AhoCorasick;
7use anyhow::Result;
8use assistant_tool::ActionLog;
9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
10use futures::{
11 Stream, StreamExt,
12 channel::mpsc::{self, UnboundedReceiver},
13 pin_mut,
14 stream::BoxStream,
15};
16use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
17use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
18use language_model::{
19 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
20 MessageContent, Role,
21};
22use serde::Serialize;
23use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
24use streaming_diff::{CharOperation, StreamingDiff};
25
/// Context rendered into the prompt that `EditAgent::overwrite` sends to the
/// model.
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// File path resolved from the buffer snapshot, when one is available.
    path: Option<PathBuf>,
    /// Caller-supplied description of what the file should contain.
    edit_description: String,
}

impl Template for CreateFilePromptTemplate {
    // Name of the template file this context is rendered with.
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
35
/// Context rendered into the prompt that `EditAgent::edit` sends to the model.
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// File path resolved from the buffer snapshot, when one is available.
    path: Option<PathBuf>,
    /// Caller-supplied description of the change to make.
    edit_description: String,
}

impl Template for EditFilePromptTemplate {
    // Name of the template file this context is rendered with.
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
45
/// Progress notifications emitted on the event channel returned alongside an
/// edit or overwrite task.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// A batch of changes was applied to the buffer.
    Edited,
    /// The old text requested by the model could not be located in the
    /// buffer; carries the text that was searched for.
    OldTextNotFound(SharedString),
}
51
/// Value produced when an edit or overwrite task completes successfully.
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk received from the model.
    pub _raw_edits: String,
    /// Metrics reported by the edit parser; left at their default values when
    /// the text was written without going through the parser.
    pub _parser_metrics: EditParserMetrics,
}
57
/// Requests edits from a language model and streams them into a buffer,
/// reporting each change to the action log.
#[derive(Clone)]
pub struct EditAgent {
    /// Model the completion requests are sent to.
    model: Arc<dyn LanguageModel>,
    /// Log that is told about every buffer this agent tracks and edits.
    action_log: Entity<ActionLog>,
    /// Templates used to render the prompts.
    templates: Arc<Templates>,
}
64
65impl EditAgent {
66 pub fn new(
67 model: Arc<dyn LanguageModel>,
68 action_log: Entity<ActionLog>,
69 templates: Arc<Templates>,
70 ) -> Self {
71 EditAgent {
72 model,
73 action_log,
74 templates,
75 }
76 }
77
78 pub fn overwrite(
79 &self,
80 buffer: Entity<Buffer>,
81 edit_description: String,
82 previous_messages: Vec<LanguageModelRequestMessage>,
83 cx: &mut AsyncApp,
84 ) -> (
85 Task<Result<EditAgentOutput>>,
86 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
87 ) {
88 let this = self.clone();
89 let (events_tx, events_rx) = mpsc::unbounded();
90 let output = cx.spawn(async move |cx| {
91 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
92 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
93 let prompt = CreateFilePromptTemplate {
94 path,
95 edit_description,
96 }
97 .render(&this.templates)?;
98 let new_chunks = this.request(previous_messages, prompt, cx).await?;
99
100 let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
101 while let Some(event) = inner_events.next().await {
102 events_tx.unbounded_send(event).ok();
103 }
104 output.await
105 });
106 (output, events_rx)
107 }
108
109 fn replace_text_with_chunks(
110 &self,
111 buffer: Entity<Buffer>,
112 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
113 cx: &mut AsyncApp,
114 ) -> (
115 Task<Result<EditAgentOutput>>,
116 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
117 ) {
118 let (output_events_tx, output_events_rx) = mpsc::unbounded();
119 let this = self.clone();
120 let task = cx.spawn(async move |cx| {
121 // Ensure the buffer is tracked by the action log.
122 this.action_log
123 .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
124
125 cx.update(|cx| {
126 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
127 this.action_log
128 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
129 })?;
130
131 let mut raw_edits = String::new();
132 pin_mut!(edit_chunks);
133 while let Some(chunk) = edit_chunks.next().await {
134 let chunk = chunk?;
135 raw_edits.push_str(&chunk);
136 cx.update(|cx| {
137 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
138 this.action_log
139 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
140 })?;
141 output_events_tx
142 .unbounded_send(EditAgentOutputEvent::Edited)
143 .ok();
144 }
145
146 Ok(EditAgentOutput {
147 _raw_edits: raw_edits,
148 _parser_metrics: EditParserMetrics::default(),
149 })
150 });
151 (task, output_events_rx)
152 }
153
154 pub fn edit(
155 &self,
156 buffer: Entity<Buffer>,
157 edit_description: String,
158 previous_messages: Vec<LanguageModelRequestMessage>,
159 cx: &mut AsyncApp,
160 ) -> (
161 Task<Result<EditAgentOutput>>,
162 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
163 ) {
164 let this = self.clone();
165 let (events_tx, events_rx) = mpsc::unbounded();
166 let output = cx.spawn(async move |cx| {
167 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
168 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
169 let prompt = EditFilePromptTemplate {
170 path,
171 edit_description,
172 }
173 .render(&this.templates)?;
174 let edit_chunks = this.request(previous_messages, prompt, cx).await?;
175
176 let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
177 while let Some(event) = inner_events.next().await {
178 events_tx.unbounded_send(event).ok();
179 }
180 output.await
181 });
182 (output, events_rx)
183 }
184
185 fn apply_edit_chunks(
186 &self,
187 buffer: Entity<Buffer>,
188 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
189 cx: &mut AsyncApp,
190 ) -> (
191 Task<Result<EditAgentOutput>>,
192 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
193 ) {
194 let (output_events_tx, output_events_rx) = mpsc::unbounded();
195 let this = self.clone();
196 let task = cx.spawn(async move |mut cx| {
197 this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
198 .await
199 });
200 (task, output_events_rx)
201 }
202
    /// Parses `edit_chunks` into edit events and applies each edit to `buffer`
    /// while its replacement text is still streaming in.
    ///
    /// For every old-text event the matching range is resolved in the current
    /// buffer snapshot. If it cannot be found, `OldTextNotFound` is sent on
    /// `output_events` and processing moves on. Otherwise the incoming new
    /// text is re-indented, diffed against the old text, and the resulting
    /// character operations are applied in batches, with an `Edited` event
    /// sent after each batch.
    ///
    /// Returns the parser task's output; fails on the first event, diff, or
    /// context-update error.
    async fn apply_edits_internal(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        // Ensure the buffer is tracked by the action log.
        self.action_log
            .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;

        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        while let Some(edit_event) = edit_events.next().await {
            // Only an old-text event starts an edit; any other event reaching
            // this point (e.g. new text following an unresolved old text) is
            // skipped.
            let EditParserEvent::OldText(old_text_query) = edit_event? else {
                continue;
            };
            let old_text_query = SharedString::from(old_text_query);

            let (edits_tx, edits_rx) = mpsc::unbounded();
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            // Searching the buffer runs on the background executor.
            let old_range = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let old_text_query = old_text_query.clone();
                    async move { Self::resolve_location(&snapshot, &old_text_query) }
                })
                .await;
            let Some(old_range) = old_range else {
                // We couldn't find the old text in the buffer. Report the error.
                output_events
                    .unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
                    .ok();
                continue;
            };

            // The event stream is moved into this background task and handed
            // back as its result so the outer loop can resume afterwards.
            let compute_edits = cx.background_spawn(async move {
                // Difference between the indentation of the first matched
                // buffer line and the first line of the model's old text.
                let buffer_start_indent =
                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
                let old_text_start_indent = old_text_query
                    .lines()
                    .next()
                    .map_or(buffer_start_indent, |line| {
                        LineIndent::from_iter(line.chars())
                    });
                // The delta is measured in tabs when the buffer line is
                // indented with tabs, and in spaces otherwise.
                let indent_delta = if buffer_start_indent.tabs > 0 {
                    IndentDelta::Tabs(
                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
                    )
                } else {
                    IndentDelta::Spaces(
                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
                    )
                };

                let old_text = snapshot
                    .text_for_range(old_range.clone())
                    .collect::<String>();
                let mut diff = StreamingDiff::new(old_text);
                // Offset into the snapshot at which the next operation applies.
                let mut edit_start = old_range.start;
                let mut new_text_chunks =
                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
                let mut done = false;
                while !done {
                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
                    {
                        diff.push_new(&new_text_chunk?)
                    } else {
                        // No more new text: collect the remaining operations.
                        done = true;
                        mem::take(&mut diff).finish()
                    };

                    for op in char_operations {
                        match op {
                            CharOperation::Insert { text } => {
                                let edit_start = snapshot.anchor_after(edit_start);
                                edits_tx.unbounded_send((edit_start..edit_start, text))?;
                            }
                            CharOperation::Delete { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                edits_tx.unbounded_send((edit_range, String::new()))?;
                            }
                            CharOperation::Keep { bytes } => edit_start += bytes,
                        }
                    }
                }

                // Release the borrow of `edit_events` before returning it.
                drop(new_text_chunks);
                anyhow::Ok(edit_events)
            });

            // TODO: group all edits into one transaction
            // Apply whatever edits are ready, up to 32 at a time. The channel
            // closes once the background task above finishes and drops its
            // sender, which ends this loop.
            let mut edits_rx = edits_rx.ready_chunks(32);
            while let Some(edits) = edits_rx.next().await {
                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Take the event stream back (or surface the task's error).
            edit_events = compute_edits.await?;
        }

        output.await
    }
317
318 fn parse_edit_chunks(
319 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
320 cx: &mut AsyncApp,
321 ) -> (
322 Task<Result<EditAgentOutput>>,
323 UnboundedReceiver<Result<EditParserEvent>>,
324 ) {
325 let (tx, rx) = mpsc::unbounded();
326 let output = cx.background_spawn(async move {
327 pin_mut!(chunks);
328
329 let mut parser = EditParser::new();
330 let mut raw_edits = String::new();
331 while let Some(chunk) = chunks.next().await {
332 match chunk {
333 Ok(chunk) => {
334 raw_edits.push_str(&chunk);
335 for event in parser.push(&chunk) {
336 tx.unbounded_send(Ok(event))?;
337 }
338 }
339 Err(error) => {
340 tx.unbounded_send(Err(error.into()))?;
341 }
342 }
343 }
344 Ok(EditAgentOutput {
345 _raw_edits: raw_edits,
346 _parser_metrics: parser.finish(),
347 })
348 });
349 (output, rx)
350 }
351
    /// Adapts a stream of parser events into a stream of new-text chunks whose
    /// line indentation has been shifted by `delta`.
    ///
    /// Only leading runs of `delta.character()` are adjusted; the resulting
    /// indentation is clamped at zero. Lines consisting solely of that
    /// character are passed through unchanged. An incomplete line that is
    /// still all indentation is held back until more input arrives or the
    /// last chunk is seen.
    ///
    /// The output ends after the chunk flagged `done`, or as soon as the
    /// input ends or yields any event other than a new-text chunk (that event
    /// is consumed). Input errors are forwarded as stream items.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted (a partial line still inside its
        // leading whitespace).
        let mut buffer = String::new();
        // Whether the current line has shown only indentation characters so far.
        let mut in_leading_whitespace = true;
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                // Note: the `done` bound in the pattern below shadows the
                // outer flag inside that arm; it is the chunk's own marker.
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                // Start of the portion of `buffer` not yet copied to the output.
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    // A line without a terminating newline is still pending.
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found a non-whitespace character, adjust
                            // indentation based on the delta.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input.
                            break;
                        } else {
                            // This line is entirely whitespace. Push it without indentation.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        start_ix = line_end;
                        break;
                    } else {
                        // A newline was consumed: the next line starts in
                        // leading whitespace again.
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                // Drop everything that was copied to the output.
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                // Nothing to emit yet: poll the input again.
                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }
430
431 async fn request(
432 &self,
433 mut messages: Vec<LanguageModelRequestMessage>,
434 prompt: String,
435 cx: &mut AsyncApp,
436 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
437 let mut message_content = Vec::new();
438 if let Some(last_message) = messages.last_mut() {
439 if last_message.role == Role::Assistant {
440 last_message
441 .content
442 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
443 if last_message.content.is_empty() {
444 messages.pop();
445 }
446 }
447 }
448 message_content.push(MessageContent::Text(prompt));
449 messages.push(LanguageModelRequestMessage {
450 role: Role::User,
451 content: message_content,
452 cache: false,
453 });
454
455 let request = LanguageModelRequest {
456 messages,
457 ..Default::default()
458 };
459 Ok(self.model.stream_completion_text(request, cx).await?.stream)
460 }
461
462 fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
463 let range = Self::resolve_location_exact(buffer, search_query)
464 .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;
465
466 // Expand the range to include entire lines.
467 let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
468 start.column = 0;
469 let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
470 if end.column > 0 {
471 end.column = buffer.line_len(end.row);
472 }
473
474 Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
475 }
476
477 fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
478 let search = AhoCorasick::new([search_query]).ok()?;
479 let mat = search
480 .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
481 .next()?
482 .expect("buffer can't error");
483 Some(mat.range())
484 }
485
    /// Locates the run of buffer lines that best matches the lines of
    /// `search_query`, comparing trimmed lines with `fuzzy_eq`.
    ///
    /// A cost matrix is filled with one row per query line and one column per
    /// buffer line. Skipping a query line costs `DELETION_COST`, skipping a
    /// buffer line costs `INSERTION_COST`, and pairing two lines that are not
    /// fuzzy-equal costs both. The cheapest cell in the last row marks the end
    /// of the match, and the recorded directions are followed back to find
    /// its start.
    ///
    /// Returns the byte range covering the matched lines when the ratio of
    /// diagonal steps to the larger of (matched buffer rows, query lines) is
    /// at least 0.8, and `None` otherwise.
    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        const INSERTION_COST: u32 = 3;
        const DELETION_COST: u32 = 10;

        let buffer_line_count = buffer.max_point().row as usize + 1;
        let query_line_count = search_query.lines().count();
        // Row 0 keeps the zero cost the matrix is initialised with, so a match
        // may begin at any buffer line without penalty.
        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
        let mut leading_deletion_cost = 0_u32;
        for (row, query_line) in search_query.lines().enumerate() {
            let query_line = query_line.trim();
            // Column 0: every query line so far was skipped.
            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
            matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
            );

            let mut buffer_lines = buffer.as_rope().chunks().lines();
            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                // Skip the query line.
                let up = SearchState::new(
                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                // Skip the buffer line.
                let left = SearchState::new(
                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                // Pair the two lines; free when they are fuzzy-equal.
                let diagonal = SearchState::new(
                    if fuzzy_eq(query_line, buffer_line) {
                        matrix.get(row, col).cost
                    } else {
                        matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // Lowest cost wins; equal costs are ordered by direction.
                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Traceback to find the best match
        // The first column with the lowest cost in the last row is chosen.
        let mut buffer_row_end = buffer_line_count as u32;
        let mut best_cost = u32::MAX;
        for col in 1..=buffer_line_count {
            let cost = matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                buffer_row_end = col as u32;
            }
        }

        // Counts every diagonal step, whether or not the paired lines were
        // fuzzy-equal.
        let mut matched_lines = 0;
        let mut query_row = query_line_count;
        let mut buffer_row_start = buffer_row_end;
        while query_row > 0 && buffer_row_start > 0 {
            let current = matrix.get(query_row, buffer_row_start as usize);
            match current.direction {
                SearchDirection::Diagonal => {
                    query_row -= 1;
                    buffer_row_start -= 1;
                    matched_lines += 1;
                }
                SearchDirection::Up => {
                    query_row -= 1;
                }
                SearchDirection::Left => {
                    buffer_row_start -= 1;
                }
            }
        }

        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
        let matched_ratio =
            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
        if matched_ratio >= 0.8 {
            // `buffer_row_end` is exclusive, so the last matched row is one less.
            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
            let buffer_end_ix = buffer.point_to_offset(Point::new(
                buffer_row_end - 1,
                buffer.line_len(buffer_row_end - 1),
            ));
            Some(buffer_start_ix..buffer_end_ix)
        } else {
            None
        }
    }
576}
577
578fn fuzzy_eq(left: &str, right: &str) -> bool {
579 let min_levenshtein = left.len().abs_diff(right.len());
580 let min_normalized_levenshtein =
581 1. - (min_levenshtein as f32 / cmp::max(left.len(), right.len()) as f32);
582 if min_normalized_levenshtein < 0.8 {
583 return false;
584 }
585
586 strsim::normalized_levenshtein(left, right) >= 0.8
587}
588
/// A signed change in indentation, measured in the indentation character it
/// applies to.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    /// Change measured in space characters.
    Spaces(isize),
    /// Change measured in tab characters.
    Tabs(isize),
}

impl IndentDelta {
    /// The indentation character this delta is measured in.
    fn character(&self) -> char {
        match *self {
            Self::Tabs(_) => '\t',
            Self::Spaces(_) => ' ',
        }
    }

    /// The signed number of indentation characters to add (negative values
    /// remove indentation).
    fn len(&self) -> isize {
        match *self {
            Self::Spaces(amount) | Self::Tabs(amount) => amount,
        }
    }
}
610
/// Step recorded for a matrix cell, naming the neighbour its cost came from.
///
/// The variant order matters: the derived ordering breaks ties between
/// states of equal cost.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    Up,
    Left,
    Diagonal,
}

/// One matrix cell: an accumulated cost and the step that produced it.
///
/// The field order matters: the derived ordering compares `cost` first and
/// falls back to `direction`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    /// Bundles a cost with the direction it was reached from.
    fn new(cost: u32, direction: SearchDirection) -> Self {
        SearchState { cost, direction }
    }
}

/// Dense row-major grid of [`SearchState`] cells.
struct SearchMatrix {
    cols: usize,
    data: Vec<SearchState>,
}

impl SearchMatrix {
    /// Creates a `rows` by `cols` grid whose cells all start at cost zero with
    /// a diagonal direction.
    fn new(rows: usize, cols: usize) -> Self {
        let initial = SearchState::new(0, SearchDirection::Diagonal);
        Self {
            cols,
            data: vec![initial; rows * cols],
        }
    }

    /// Position of the cell at (`row`, `col`) within the flat storage.
    fn offset(&self, row: usize, col: usize) -> usize {
        row * self.cols + col
    }

    /// Reads the cell at (`row`, `col`); panics when it lies outside the grid's storage.
    fn get(&self, row: usize, col: usize) -> SearchState {
        self.data[self.offset(row, col)]
    }

    /// Overwrites the cell at (`row`, `col`); panics when it lies outside the grid's storage.
    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
        let offset = self.offset(row, col);
        self.data[offset] = cost;
    }
}
651
652#[cfg(test)]
653mod tests {
654 use super::*;
655 use fs::FakeFs;
656 use futures::stream;
657 use gpui::{App, AppContext, TestAppContext};
658 use indoc::indoc;
659 use language_model::fake_provider::FakeLanguageModel;
660 use project::Project;
661 use rand::prelude::*;
662 use rand::rngs::StdRng;
663 use std::cmp;
664 use unindent::Unindent;
665 use util::test::{generate_marked_text, marked_text_ranges};
666
    // Applies a randomly chunked edit that adds an `amet` line after `sit`
    // and checks the resulting buffer text.
    // NOTE(review): leading whitespace inside these fixtures is significant to
    // what this test exercises — confirm it against the canonical file.
    #[gpui::test(iterations = 100)]
    async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
        let agent = init_test(cx).await;
        let buffer = cx.new(|cx| {
            Buffer::local(
                indoc! {"
 lorem
 ipsum
 dolor
 sit
 "},
                cx,
            )
        });
        // The model output is split at random positions and delivered with
        // random delays.
        let raw_edits = simulate_llm_output(
            indoc! {"
 <old_text>
 ipsum
 dolor
 sit
 </old_text>
 <new_text>
 ipsum
 dolor
 sit
 amet
 </new_text>
 "},
            &mut rng,
            cx,
        );
        let (apply, _events) =
            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
        apply.await.unwrap();
        pretty_assertions::assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            indoc! {"
 lorem
 ipsum
 dolor
 sit
 amet
 "}
        );
    }
712
    // A later edit may target text produced by an earlier one: `def` becomes
    // `DEF`, which the second edit then rewrites to `DeF`.
    #[gpui::test(iterations = 100)]
    async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
        let agent = init_test(cx).await;
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
        let raw_edits = simulate_llm_output(
            indoc! {"
 <old_text>
 def
 </old_text>
 <new_text>
 DEF
 </new_text>

 <old_text>
 DEF
 </old_text>
 <new_text>
 DeF
 </new_text>
 "},
            &mut rng,
            cx,
        );
        let (apply, _events) =
            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\nDeF\nghi"
        );
    }
744
    // An edit whose old text (`jkl`) is absent from the buffer is skipped,
    // while the following valid edit (`abc` -> `ABC`) is still applied.
    #[gpui::test(iterations = 100)]
    async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
        let agent = init_test(cx).await;
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
        let raw_edits = simulate_llm_output(
            indoc! {"
 <old_text>
 jkl
 </old_text>
 <new_text>
 mno
 </new_text>

 <old_text>
 abc
 </old_text>
 <new_text>
 ABC
 </new_text>
 "},
            &mut rng,
            cx,
        );
        let (apply, _events) =
            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "ABC\ndef\nghi"
        );
    }
776
    // Feeds the model output piece by piece and, after each piece, checks both
    // the events emitted so far and the buffer contents.
    #[gpui::test]
    async fn test_events(cx: &mut TestAppContext) {
        let agent = init_test(cx).await;
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let (apply, mut events) = agent.apply_edit_chunks(
            buffer.clone(),
            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
            &mut cx.to_async(),
        );

        // An incomplete old text produces no events and no edits.
        chunks_tx.unbounded_send("<old_text>a").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi"
        );

        // Completing the old text alone still changes nothing.
        chunks_tx.unbounded_send("bc</old_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi"
        );

        // New text is applied while it is still streaming in.
        chunks_tx.unbounded_send("<new_text>abX").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXc\ndef\nghi"
        );

        chunks_tx.unbounded_send("cY").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // Closing the new text and starting another old text emits nothing.
        chunks_tx.unbounded_send("</new_text>").unwrap();
        chunks_tx.unbounded_send("<old_text>hall").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // Old text that is not in the buffer is reported and leaves the
        // buffer untouched.
        chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
        chunks_tx.unbounded_send("<new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::OldTextNotFound(
                "hallucinated old".into()
            )]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // The new text belonging to the unresolved old text is discarded.
        chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
        chunks_tx.unbounded_send("text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // A subsequent valid edit is processed normally.
        chunks_tx.unbounded_send("<old_text>gh").unwrap();
        chunks_tx.unbounded_send("i</old_text>").unwrap();
        chunks_tx.unbounded_send("<new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        chunks_tx.unbounded_send("GHI</new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );

        // Closing the input lets the task finish without further events.
        drop(chunks_tx);
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );
        assert_eq!(drain_events(&mut events), vec![]);

        // Collects every event that is already available without waiting.
        fn drain_events(
            stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
        ) -> Vec<EditAgentOutputEvent> {
            let mut events = Vec::new();
            while let Ok(Some(event)) = stream.try_next() {
                events.push(event);
            }
            events
        }
    }
891
    // Checks `EditAgent::resolve_location` against marked-up fixtures: the
    // «…» markers delimit the range the query is expected to resolve to, and
    // a fixture without markers expects no match.
    // NOTE(review): whitespace inside these fixtures is significant — confirm
    // it against the canonical file.
    #[gpui::test]
    fn test_resolve_location(cx: &mut App) {
        // A single-line match is widened to the whole line.
        assert_location_resolution(
            concat!(
                " Lorem\n",
                "« ipsum»\n",
                " dolor sit amet\n",
                " consecteur",
            ),
            "ipsum",
            cx,
        );

        // A multi-line query.
        assert_location_resolution(
            concat!(
                " Lorem\n",
                "« ipsum\n",
                " dolor sit amet»\n",
                " consecteur",
            ),
            "ipsum\ndolor sit amet",
            cx,
        );

        // The query's first line differs slightly from the buffer (`u32` vs `usize`).
        assert_location_resolution(
            &"
 «fn foo1(a: usize) -> usize {
 40
 }»

 fn foo2(b: usize) -> usize {
 42
 }
 "
            .unindent(),
            "fn foo1(a: usize) -> u32 {\n40\n}",
            cx,
        );

        // The query omits one of the buffer's lines (`three`).
        assert_location_resolution(
            &"
 class Something {
 one() { return 1; }
 « two() { return 2222; }
 three() { return 333; }
 four() { return 4444; }
 five() { return 5555; }
 six() { return 6666; }»
 seven() { return 7; }
 eight() { return 8; }
 }
 "
            .unindent(),
            &"
 two() { return 2222; }
 four() { return 4444; }
 five() { return 5555; }
 six() { return 6666; }
 "
            .unindent(),
            cx,
        );

        // No markers: the query is expected not to resolve.
        assert_location_resolution(
            &"
 use std::ops::Range;
 use std::sync::Mutex;
 use std::{
 collections::HashMap,
 env,
 ffi::{OsStr, OsString},
 fs,
 io::{BufRead, BufReader},
 mem,
 path::{Path, PathBuf},
 process::Command,
 sync::LazyLock,
 time::SystemTime,
 };
 "
            .unindent(),
            &"
 use std::collections::{HashMap, HashSet};
 use std::ffi::{OsStr, OsString};
 use std::fmt::Write as _;
 use std::fs;
 use std::io::{BufReader, Read, Write};
 use std::mem;
 use std::path::{Path, PathBuf};
 use std::process::Command;
 use std::sync::Arc;
 "
            .unindent(),
            cx,
        );

        // No markers: the query is expected not to resolve.
        assert_location_resolution(
            indoc! {"
 impl Foo {
 fn new() -> Self {
 Self {
 subscriptions: vec![
 cx.observe_window_activation(window, |editor, window, cx| {
 let active = window.is_window_active();
 editor.blink_manager.update(cx, |blink_manager, cx| {
 if active {
 blink_manager.enable(cx);
 } else {
 blink_manager.disable(cx);
 }
 });
 }),
 ];
 }
 }
 }
 "},
            concat!(
                " editor.blink_manager.update(cx, |blink_manager, cx| {\n",
                " blink_manager.enable(cx);\n",
                " });",
            ),
            cx,
        );

        // No markers: the query is expected not to resolve.
        assert_location_resolution(
            indoc! {r#"
 let tool = cx
 .update(|cx| working_set.tool(&tool_name, cx))
 .map_err(|err| {
 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
 })?;

 let Some(tool) = tool else {
 return Err(anyhow!("Tool '{}' not found", tool_name));
 };

 let project = project.clone();
 let action_log = action_log.clone();
 let messages = messages.clone();
 let tool_result = cx
 .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;

 tasks.push(tool_result.output);
 "#},
            concat!(
                "let tool_result = cx\n",
                " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
                " .output;",
            ),
            cx,
        );
    }
1046
    // Re-indents randomly chunked text with a delta of two spaces and checks
    // the concatenated result.
    // NOTE(review): the amount of whitespace inside these literals is what the
    // test asserts on — confirm it against the canonical file.
    #[gpui::test(iterations = 100)]
    async fn test_indent_new_text_chunks(mut rng: StdRng) {
        let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
        // Only the final chunk is flagged as done.
        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
            Ok(EditParserEvent::NewTextChunk {
                chunk: chunk.clone(),
                done: index == chunks.len() - 1,
            })
        }));
        let indented_chunks =
            EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
                .collect::<Vec<_>>()
                .await;
        let new_text = indented_chunks
            .into_iter()
            .collect::<Result<String>>()
            .unwrap();
        assert_eq!(new_text, " abc\n def\n ghi");
    }
1066
    // Removes two tabs of indentation from randomly chunked text; a line with
    // exactly two tabs ends up with none.
    #[gpui::test(iterations = 100)]
    async fn test_outdent_new_text_chunks(mut rng: StdRng) {
        let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
        // Only the final chunk is flagged as done.
        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
            Ok(EditParserEvent::NewTextChunk {
                chunk: chunk.clone(),
                done: index == chunks.len() - 1,
            })
        }));
        let indented_chunks =
            EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
                .collect::<Vec<_>>()
                .await;
        let new_text = indented_chunks
            .into_iter()
            .collect::<Result<String>>()
            .unwrap();
        assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
    }
1086
    // Randomised check: re-indenting randomly chunked text must give the same
    // result as re-indenting each line of the whole text directly.
    #[gpui::test(iterations = 100)]
    async fn test_random_indents(mut rng: StdRng) {
        let len = rng.gen_range(1..=100);
        let new_text = util::RandomCharIter::new(&mut rng)
            .with_simple_text()
            .take(len)
            .collect::<String>();
        // Prefix every line with between zero and eight spaces.
        let new_text = new_text
            .split('\n')
            .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
            .collect::<Vec<_>>()
            .join("\n");
        let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));

        let chunks = to_random_chunks(&mut rng, &new_text);
        // Only the final chunk is flagged as done.
        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
            Ok(EditParserEvent::NewTextChunk {
                chunk: chunk.clone(),
                done: index == chunks.len() - 1,
            })
        }));
        let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
            .collect::<Vec<_>>()
            .await;
        let actual_reindented_text = reindented_chunks
            .into_iter()
            .collect::<Result<String>>()
            .unwrap();
        // Reference result: shift each line's leading spaces by the delta,
        // clamped at zero; lines made only of spaces are left unchanged.
        let expected_reindented_text = new_text
            .split('\n')
            .map(|line| {
                if let Some(ix) = line.find(|c| c != ' ') {
                    let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
                    format!("{}{}", " ".repeat(new_indent), &line[ix..])
                } else {
                    line.to_string()
                }
            })
            .collect::<Vec<_>>()
            .join("\n");
        assert_eq!(actual_reindented_text, expected_reindented_text);
    }
1129
1130 #[track_caller]
1131 fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1132 let (text, _) = marked_text_ranges(text_with_expected_range, false);
1133 let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1134 let snapshot = buffer.read(cx).snapshot();
1135 let mut ranges = Vec::new();
1136 ranges.extend(EditAgent::resolve_location(&snapshot, query));
1137 let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1138 pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1139 }
1140
1141 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1142 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1143 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1144 chunk_indices.sort();
1145 chunk_indices.push(input.len());
1146
1147 let mut chunks = Vec::new();
1148 let mut last_ix = 0;
1149 for chunk_ix in chunk_indices {
1150 chunks.push(input[last_ix..chunk_ix].to_string());
1151 last_ix = chunk_ix;
1152 }
1153 chunks
1154 }
1155
1156 fn simulate_llm_output(
1157 output: &str,
1158 rng: &mut StdRng,
1159 cx: &mut TestAppContext,
1160 ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1161 let executor = cx.executor();
1162 stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1163 let executor = executor.clone();
1164 async move {
1165 executor.simulate_random_delay().await;
1166 chunk
1167 }
1168 })
1169 }
1170
1171 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1172 cx.update(settings::init);
1173 cx.update(Project::init_settings);
1174 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1175 let model = Arc::new(FakeLanguageModel::default());
1176 let action_log = cx.new(|_| ActionLog::new(project));
1177 EditAgent::new(model, action_log, Templates::new())
1178 }
1179}