1mod edit_parser;
2#[cfg(test)]
3mod evals;
4
5use crate::{Template, Templates};
6use aho_corasick::AhoCorasick;
7use anyhow::Result;
8use assistant_tool::ActionLog;
9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
10use futures::{
11 Stream, StreamExt,
12 channel::mpsc::{self, UnboundedReceiver},
13 stream::BoxStream,
14};
15use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
16use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
17use language_model::{
18 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
19 MessageContent, Role,
20};
21use serde::Serialize;
22use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
23use streaming_diff::{CharOperation, StreamingDiff};
24
/// Context rendered into the `edit_agent.hbs` prompt template.
#[derive(Serialize)]
pub struct EditAgentTemplate {
    /// Path of the file being edited, when one could be resolved for the buffer.
    path: Option<PathBuf>,
    /// Description of the requested edit, as passed to `EditAgent::edit`.
    edit_description: String,
}

impl Template for EditAgentTemplate {
    /// Name of the template this context is rendered with.
    const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
}
34
/// Progress events emitted while streamed edits are applied to a buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// A batch of edits was applied to the buffer.
    Edited,
    /// The model referenced old text that could not be located in the buffer.
    /// Carries the text that failed to resolve.
    HallucinatedOldText(SharedString),
}
40
/// Final result of an edit run.
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk received from the model.
    pub _raw_edits: String,
    /// Metrics returned by the edit parser once the chunk stream has ended.
    pub _parser_metrics: EditParserMetrics,
}
46
/// Requests edits from a language model and applies them to a buffer as they stream in.
#[derive(Clone)]
pub struct EditAgent {
    /// Model that generates the edits.
    model: Arc<dyn LanguageModel>,
    /// Log that is told about every buffer this agent tracks and edits.
    action_log: Entity<ActionLog>,
    /// Template registry used to render the edit prompt.
    templates: Arc<Templates>,
}
53
54impl EditAgent {
55 pub fn new(
56 model: Arc<dyn LanguageModel>,
57 action_log: Entity<ActionLog>,
58 templates: Arc<Templates>,
59 ) -> Self {
60 EditAgent {
61 model,
62 action_log,
63 templates,
64 }
65 }
66
67 pub fn edit(
68 &self,
69 buffer: Entity<Buffer>,
70 edit_description: String,
71 previous_messages: Vec<LanguageModelRequestMessage>,
72 cx: &mut AsyncApp,
73 ) -> (
74 Task<Result<EditAgentOutput>>,
75 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
76 ) {
77 let this = self.clone();
78 let (events_tx, events_rx) = mpsc::unbounded();
79 let output = cx.spawn(async move |cx| {
80 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
81 let edit_chunks = this
82 .request_edits(snapshot, edit_description, previous_messages, cx)
83 .await?;
84 let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
85 while let Some(event) = inner_events.next().await {
86 events_tx.unbounded_send(event).ok();
87 }
88 output.await
89 });
90 (output, events_rx)
91 }
92
93 fn apply_edits(
94 &self,
95 buffer: Entity<Buffer>,
96 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
97 cx: &mut AsyncApp,
98 ) -> (
99 Task<Result<EditAgentOutput>>,
100 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
101 ) {
102 let (output_events_tx, output_events_rx) = mpsc::unbounded();
103 let this = self.clone();
104 let task = cx.spawn(async move |mut cx| {
105 this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
106 .await
107 });
108 (task, output_events_rx)
109 }
110
    /// Applies streamed model output to `buffer`, reporting progress on `output_events`.
    ///
    /// For each old-text event reported by the parser, the text is located in a
    /// fresh snapshot of the buffer. The new-text chunks that follow are then
    /// re-indented, diffed against the located text and applied to the buffer
    /// incrementally. Resolves to the output of the parsing task once the parser
    /// event stream is exhausted.
    async fn apply_edits_internal(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        // Ensure the buffer is tracked by the action log.
        self.action_log
            .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;

        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        while let Some(edit_event) = edit_events.next().await {
            // Only an old-text event starts a new edit. Anything else seen here
            // (for example new-text chunks after an unresolved old text) is skipped.
            let EditParserEvent::OldText(old_text_query) = edit_event? else {
                continue;
            };
            let old_text_query = SharedString::from(old_text_query);

            let (edits_tx, edits_rx) = mpsc::unbounded();
            // Take a fresh snapshot per edit so that earlier edits are visible
            // when locating the old text.
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            let old_range = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let old_text_query = old_text_query.clone();
                    async move { Self::resolve_location(&snapshot, &old_text_query) }
                })
                .await;
            let Some(old_range) = old_range else {
                // We couldn't find the old text in the buffer. Report the error.
                output_events
                    .unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
                    .ok();
                continue;
            };

            // The background task takes ownership of `edit_events` while it
            // consumes the new-text chunks, and hands the stream back when done.
            let compute_edits = cx.background_spawn(async move {
                // Compare the indentation of the first matched buffer line with
                // the indentation of the first line of the model's old text.
                let buffer_start_indent =
                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
                let old_text_start_indent = old_text_query
                    .lines()
                    .next()
                    .map_or(buffer_start_indent, |line| {
                        LineIndent::from_iter(line.chars())
                    });
                // The delta is expressed in tabs when the buffer line contains
                // tabs in its indent, otherwise in spaces.
                let indent_delta = if buffer_start_indent.tabs > 0 {
                    IndentDelta::Tabs(
                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
                    )
                } else {
                    IndentDelta::Spaces(
                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
                    )
                };

                let old_text = snapshot
                    .text_for_range(old_range.clone())
                    .collect::<String>();
                let mut diff = StreamingDiff::new(old_text);
                // Offset in `snapshot` up to which the old text has been consumed.
                let mut edit_start = old_range.start;
                let mut new_text_chunks =
                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
                let mut done = false;
                while !done {
                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
                    {
                        diff.push_new(&new_text_chunk?)
                    } else {
                        // No more new text: flush whatever the diff still holds.
                        done = true;
                        mem::take(&mut diff).finish()
                    };

                    for op in char_operations {
                        match op {
                            CharOperation::Insert { text } => {
                                // Insertions happen at the current position and
                                // do not consume any old text.
                                let edit_start = snapshot.anchor_after(edit_start);
                                edits_tx.unbounded_send((edit_start..edit_start, text))?;
                            }
                            CharOperation::Delete { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                edits_tx.unbounded_send((edit_range, String::new()))?;
                            }
                            CharOperation::Keep { bytes } => edit_start += bytes,
                        }
                    }
                }

                // Release the borrow of `edit_events` before returning it.
                drop(new_text_chunks);
                anyhow::Ok(edit_events)
            });

            // TODO: group all edits into one transaction
            // Apply edits in batches of up to 32 that are ready at the same time.
            let mut edits_rx = edits_rx.ready_chunks(32);
            while let Some(edits) = edits_rx.next().await {
                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Regain the event stream (or surface the background task's error)
            // before looking for the next old-text event.
            edit_events = compute_edits.await?;
        }

        output.await
    }
225
226 fn parse_edit_chunks(
227 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
228 cx: &mut AsyncApp,
229 ) -> (
230 Task<Result<EditAgentOutput>>,
231 UnboundedReceiver<Result<EditParserEvent>>,
232 ) {
233 let (tx, rx) = mpsc::unbounded();
234 let output = cx.background_spawn(async move {
235 futures::pin_mut!(chunks);
236
237 let mut parser = EditParser::new();
238 let mut raw_edits = String::new();
239 while let Some(chunk) = chunks.next().await {
240 match chunk {
241 Ok(chunk) => {
242 raw_edits.push_str(&chunk);
243 for event in parser.push(&chunk) {
244 tx.unbounded_send(Ok(event))?;
245 }
246 }
247 Err(error) => {
248 tx.unbounded_send(Err(error.into()))?;
249 }
250 }
251 }
252 Ok(EditAgentOutput {
253 _raw_edits: raw_edits,
254 _parser_metrics: parser.finish(),
255 })
256 });
257 (output, rx)
258 }
259
    /// Adapts a stream of parser events into a stream of re-indented new-text chunks.
    ///
    /// Consumes `NewTextChunk` events from `stream` and shifts the leading
    /// indentation of every line by `delta`. Incoming text is buffered so that a
    /// line's indentation is only rewritten once its first non-indent character
    /// has been seen. The returned stream ends after the chunk flagged `done`,
    /// when `stream` ends, or when any other kind of event arrives. Errors from
    /// `stream` are forwarded as items.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted (an unfinished, whitespace-only line).
        let mut buffer = String::new();
        // Whether the next character to process is still part of a line's indent.
        let mut in_leading_whitespace = true;
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    // End of stream or a non-new-text event: stop producing chunks.
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                // Start of the line currently being processed within `buffer`.
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    // A line without a terminating newline is "pending": more of
                    // it may arrive in a later chunk.
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        // Only the delta's own indent character counts as indentation.
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found a non-whitespace character, adjust
                            // indentation based on the delta.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input.
                            break;
                        } else {
                            // This line is entirely whitespace. Push it without indentation.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        start_ix = line_end;
                        break;
                    } else {
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                // Drop everything that was emitted, keeping only the unprocessed tail.
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                // Nothing to emit yet: poll the source again instead of yielding
                // an empty chunk.
                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }
338
339 async fn request_edits(
340 &self,
341 snapshot: BufferSnapshot,
342 edit_description: String,
343 mut messages: Vec<LanguageModelRequestMessage>,
344 cx: &mut AsyncApp,
345 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
346 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
347 let prompt = EditAgentTemplate {
348 path,
349 edit_description,
350 }
351 .render(&self.templates)?;
352
353 let mut message_content = Vec::new();
354 if let Some(last_message) = messages.last_mut() {
355 if last_message.role == Role::Assistant {
356 last_message
357 .content
358 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
359 if last_message.content.is_empty() {
360 messages.pop();
361 }
362 }
363 }
364 message_content.push(MessageContent::Text(prompt));
365 messages.push(LanguageModelRequestMessage {
366 role: Role::User,
367 content: message_content,
368 cache: false,
369 });
370
371 let request = LanguageModelRequest {
372 messages,
373 ..Default::default()
374 };
375 Ok(self.model.stream_completion_text(request, cx).await?.stream)
376 }
377
378 fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
379 let range = Self::resolve_location_exact(buffer, search_query)
380 .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;
381
382 // Expand the range to include entire lines.
383 let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
384 start.column = 0;
385 let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
386 if end.column > 0 {
387 end.column = buffer.line_len(end.row);
388 }
389
390 Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
391 }
392
393 fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
394 let search = AhoCorasick::new([search_query]).ok()?;
395 let mat = search
396 .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
397 .next()?
398 .expect("buffer can't error");
399 Some(mat.range())
400 }
401
    /// Finds the run of buffer lines that best matches the lines of `search_query`.
    ///
    /// Runs a line-level edit-distance computation in which lines are compared
    /// with [`fuzzy_eq`] after trimming. Rows of the matrix correspond to query
    /// lines and columns to buffer lines. Returns the byte range of the matched
    /// buffer lines when at least 80% of the lines line up, otherwise `None`.
    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        // Cost of a buffer line that has no counterpart in the query.
        const INSERTION_COST: u32 = 3;
        // Cost of a query line that has no counterpart in the buffer.
        const DELETION_COST: u32 = 10;

        let buffer_line_count = buffer.max_point().row as usize + 1;
        let query_line_count = search_query.lines().count();
        // Row 0 keeps the initial cost of 0 in every column, so a match may
        // begin at any buffer line without penalty.
        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
        let mut leading_deletion_cost = 0_u32;
        for (row, query_line) in search_query.lines().enumerate() {
            let query_line = query_line.trim();
            // Column 0: all query lines so far were deleted.
            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
            matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
            );

            let mut buffer_lines = buffer.as_rope().chunks().lines();
            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                // Skip the query line.
                let up = SearchState::new(
                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                // Skip the buffer line.
                let left = SearchState::new(
                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                // Pair the two lines: free when they are similar, otherwise
                // charged as a deletion plus an insertion.
                let diagonal = SearchState::new(
                    if fuzzy_eq(query_line, buffer_line) {
                        matrix.get(row, col).cost
                    } else {
                        matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // `SearchState` orders by cost first and direction second, so
                // on equal cost `Up` wins over `Left`, which wins over `Diagonal`.
                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Traceback to find the best match
        // The cheapest cell in the last row marks where the match ends. The
        // first such column wins on ties.
        let mut buffer_row_end = buffer_line_count as u32;
        let mut best_cost = u32::MAX;
        for col in 1..=buffer_line_count {
            let cost = matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                buffer_row_end = col as u32;
            }
        }

        // Follow the recorded directions back to find where the match starts,
        // counting the diagonal steps along the way.
        let mut matched_lines = 0;
        let mut query_row = query_line_count;
        let mut buffer_row_start = buffer_row_end;
        while query_row > 0 && buffer_row_start > 0 {
            let current = matrix.get(query_row, buffer_row_start as usize);
            match current.direction {
                SearchDirection::Diagonal => {
                    query_row -= 1;
                    buffer_row_start -= 1;
                    matched_lines += 1;
                }
                SearchDirection::Up => {
                    query_row -= 1;
                }
                SearchDirection::Left => {
                    buffer_row_start -= 1;
                }
            }
        }

        // Accept the match only when the diagonal steps cover at least 80% of
        // the longer of the two spans (matched buffer rows vs. query lines).
        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
        let matched_ratio =
            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
        if matched_ratio >= 0.8 {
            // `buffer_row_end` is one past the last matched buffer row.
            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
            let buffer_end_ix = buffer.point_to_offset(Point::new(
                buffer_row_end - 1,
                buffer.line_len(buffer_row_end - 1),
            ));
            Some(buffer_start_ix..buffer_end_ix)
        } else {
            None
        }
    }
492}
493
494fn fuzzy_eq(left: &str, right: &str) -> bool {
495 let min_levenshtein = left.len().abs_diff(right.len());
496 let min_normalized_levenshtein =
497 1. - (min_levenshtein as f32 / cmp::max(left.len(), right.len()) as f32);
498 if min_normalized_levenshtein < 0.8 {
499 return false;
500 }
501
502 strsim::normalized_levenshtein(left, right) >= 0.8
503}
504
/// Signed change in indentation, measured in either spaces or tabs.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    Spaces(isize),
    Tabs(isize),
}

impl IndentDelta {
    /// The indent character this delta is measured in.
    fn character(&self) -> char {
        match *self {
            Self::Tabs(_) => '\t',
            Self::Spaces(_) => ' ',
        }
    }

    /// Number of indent characters to add (positive) or remove (negative).
    fn len(&self) -> isize {
        match *self {
            Self::Spaces(amount) | Self::Tabs(amount) => amount,
        }
    }
}
526
/// Step taken to reach a cell of the fuzzy-search matrix.
///
/// The declaration order matters: it is the tie-breaker when two
/// [`SearchState`]s have the same cost.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    Up,
    Left,
    Diagonal,
}

/// Accumulated cost of a matrix cell and the step that produced it.
///
/// The derived ordering compares `cost` first and `direction` second.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    fn new(cost: u32, direction: SearchDirection) -> Self {
        Self { cost, direction }
    }
}

/// Dense row-major matrix of [`SearchState`]s.
struct SearchMatrix {
    cols: usize,
    data: Vec<SearchState>,
}

impl SearchMatrix {
    /// Creates a `rows` x `cols` matrix in which every cell has cost 0 and
    /// direction `Diagonal`.
    fn new(rows: usize, cols: usize) -> Self {
        SearchMatrix {
            cols,
            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
        }
    }

    /// Flat index of the cell at (`row`, `col`).
    ///
    /// A column past the end of a row would otherwise silently address a cell
    /// of the following row, so it is rejected in debug builds. An out-of-range
    /// row is caught by the bounds check on `data`.
    fn index(&self, row: usize, col: usize) -> usize {
        debug_assert!(
            col < self.cols,
            "column {col} out of range for a matrix with {} columns",
            self.cols
        );
        row * self.cols + col
    }

    /// Returns the state stored at (`row`, `col`).
    fn get(&self, row: usize, col: usize) -> SearchState {
        self.data[self.index(row, col)]
    }

    /// Stores `cost` at (`row`, `col`).
    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
        let index = self.index(row, col);
        self.data[index] = cost;
    }
}
567
568#[cfg(test)]
569mod tests {
570 use super::*;
571 use fs::FakeFs;
572 use futures::stream;
573 use gpui::{App, AppContext, TestAppContext};
574 use indoc::indoc;
575 use language_model::fake_provider::FakeLanguageModel;
576 use project::Project;
577 use rand::prelude::*;
578 use rand::rngs::StdRng;
579 use std::cmp;
580 use unindent::Unindent;
581 use util::test::{generate_marked_text, marked_text_ranges};
582
583 #[gpui::test(iterations = 100)]
584 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
585 let agent = init_test(cx).await;
586 let buffer = cx.new(|cx| {
587 Buffer::local(
588 indoc! {"
589 lorem
590 ipsum
591 dolor
592 sit
593 "},
594 cx,
595 )
596 });
597 let raw_edits = simulate_llm_output(
598 indoc! {"
599 <old_text>
600 ipsum
601 dolor
602 sit
603 </old_text>
604 <new_text>
605 ipsum
606 dolor
607 sit
608 amet
609 </new_text>
610 "},
611 &mut rng,
612 cx,
613 );
614 let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
615 apply.await.unwrap();
616 pretty_assertions::assert_eq!(
617 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
618 indoc! {"
619 lorem
620 ipsum
621 dolor
622 sit
623 amet
624 "}
625 );
626 }
627
628 #[gpui::test(iterations = 100)]
629 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
630 let agent = init_test(cx).await;
631 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
632 let raw_edits = simulate_llm_output(
633 indoc! {"
634 <old_text>
635 def
636 </old_text>
637 <new_text>
638 DEF
639 </new_text>
640
641 <old_text>
642 DEF
643 </old_text>
644 <new_text>
645 DeF
646 </new_text>
647 "},
648 &mut rng,
649 cx,
650 );
651 let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
652 apply.await.unwrap();
653 assert_eq!(
654 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
655 "abc\nDeF\nghi"
656 );
657 }
658
659 #[gpui::test(iterations = 100)]
660 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
661 let agent = init_test(cx).await;
662 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
663 let raw_edits = simulate_llm_output(
664 indoc! {"
665 <old_text>
666 jkl
667 </old_text>
668 <new_text>
669 mno
670 </new_text>
671
672 <old_text>
673 abc
674 </old_text>
675 <new_text>
676 ABC
677 </new_text>
678 "},
679 &mut rng,
680 cx,
681 );
682 let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
683 apply.await.unwrap();
684 assert_eq!(
685 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
686 "ABC\ndef\nghi"
687 );
688 }
689
    /// Feeds model output chunk by chunk and checks, after every step, which
    /// events were emitted and what the buffer contains.
    #[gpui::test]
    async fn test_events(cx: &mut TestAppContext) {
        let agent = init_test(cx).await;
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let (apply, mut events) = agent.apply_edits(
            buffer.clone(),
            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
            &mut cx.to_async(),
        );

        // An unfinished old text produces no events and no edits.
        chunks_tx.unbounded_send("<old_text>a").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi"
        );

        // Completing the old text alone still changes nothing.
        chunks_tx.unbounded_send("bc</old_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi"
        );

        // New text is applied as soon as it arrives, before the closing tag.
        chunks_tx.unbounded_send("<new_text>abX").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXc\ndef\nghi"
        );

        chunks_tx.unbounded_send("cY").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // Closing the new text and starting another old text emits nothing.
        chunks_tx.unbounded_send("</new_text>").unwrap();
        chunks_tx.unbounded_send("<old_text>hall").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // Old text that is not in the buffer is reported as hallucinated.
        chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
        chunks_tx.unbounded_send("<new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::HallucinatedOldText(
                "hallucinated old".into()
            )]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // The new text belonging to the hallucinated old text is discarded.
        chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
        chunks_tx.unbounded_send("text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        // A subsequent valid edit is still processed.
        chunks_tx.unbounded_send("<old_text>gh").unwrap();
        chunks_tx.unbounded_send("i</old_text>").unwrap();
        chunks_tx.unbounded_send("<new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi"
        );

        chunks_tx.unbounded_send("GHI</new_text>").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );

        // Closing the chunk stream lets the apply task finish without any
        // further edits or events.
        drop(chunks_tx);
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );
        assert_eq!(drain_events(&mut events), vec![]);

        /// Collects every event that is immediately available on `stream`,
        /// without waiting for more.
        fn drain_events(
            stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
        ) -> Vec<EditAgentOutputEvent> {
            let mut events = Vec::new();
            while let Ok(Some(event)) = stream.try_next() {
                events.push(event);
            }
            events
        }
    }
804
    /// Exercises `EditAgent::resolve_location` on exact, fuzzy and unmatched queries.
    ///
    /// `«…»` marks the range the query is expected to resolve to. A fixture
    /// without markers expects the query not to resolve at all.
    // NOTE(review): leading whitespace inside these literals may not reflect the
    // intended fixtures — confirm against version control.
    #[gpui::test]
    fn test_resolve_location(cx: &mut App) {
        // A single-line query resolves to the whole line it occurs on.
        assert_location_resolution(
            concat!(
                " Lorem\n",
                "« ipsum»\n",
                " dolor sit amet\n",
                " consecteur",
            ),
            "ipsum",
            cx,
        );

        // A query spanning two lines resolves to both full lines.
        assert_location_resolution(
            concat!(
                " Lorem\n",
                "« ipsum\n",
                " dolor sit amet»\n",
                " consecteur",
            ),
            "ipsum\ndolor sit amet",
            cx,
        );

        // The query's first line differs slightly (`u32` vs `usize`) and still matches.
        assert_location_resolution(
            &"
            «fn foo1(a: usize) -> usize {
            40
            }»

            fn foo2(b: usize) -> usize {
            42
            }
            "
            .unindent(),
            "fn foo1(a: usize) -> u32 {\n40\n}",
            cx,
        );

        // The query omits the `three()` line, yet the whole span is matched.
        assert_location_resolution(
            &"
            class Something {
            one() { return 1; }
            « two() { return 2222; }
            three() { return 333; }
            four() { return 4444; }
            five() { return 5555; }
            six() { return 6666; }»
            seven() { return 7; }
            eight() { return 8; }
            }
            "
            .unindent(),
            &"
            two() { return 2222; }
            four() { return 4444; }
            five() { return 5555; }
            six() { return 6666; }
            "
            .unindent(),
            cx,
        );

        // No markers: a loosely related list of imports must not resolve.
        assert_location_resolution(
            &"
            use std::ops::Range;
            use std::sync::Mutex;
            use std::{
            collections::HashMap,
            env,
            ffi::{OsStr, OsString},
            fs,
            io::{BufRead, BufReader},
            mem,
            path::{Path, PathBuf},
            process::Command,
            sync::LazyLock,
            time::SystemTime,
            };
            "
            .unindent(),
            &"
            use std::collections::{HashMap, HashSet};
            use std::ffi::{OsStr, OsString};
            use std::fmt::Write as _;
            use std::fs;
            use std::io::{BufReader, Read, Write};
            use std::mem;
            use std::path::{Path, PathBuf};
            use std::process::Command;
            use std::sync::Arc;
            "
            .unindent(),
            cx,
        );

        // No markers: the query skips too many of the buffer's lines to resolve.
        assert_location_resolution(
            indoc! {"
                impl Foo {
                fn new() -> Self {
                Self {
                subscriptions: vec![
                cx.observe_window_activation(window, |editor, window, cx| {
                let active = window.is_window_active();
                editor.blink_manager.update(cx, |blink_manager, cx| {
                if active {
                blink_manager.enable(cx);
                } else {
                blink_manager.disable(cx);
                }
                });
                }),
                ];
                }
                }
                }
            "},
            concat!(
                " editor.blink_manager.update(cx, |blink_manager, cx| {\n",
                " blink_manager.enable(cx);\n",
                " });",
            ),
            cx,
        );

        // No markers: only part of the query's lines occur in the buffer.
        assert_location_resolution(
            indoc! {r#"
                let tool = cx
                .update(|cx| working_set.tool(&tool_name, cx))
                .map_err(|err| {
                anyhow!("Failed to look up tool '{}': {}", tool_name, err)
                })?;

                let Some(tool) = tool else {
                return Err(anyhow!("Tool '{}' not found", tool_name));
                };

                let project = project.clone();
                let action_log = action_log.clone();
                let messages = messages.clone();
                let tool_result = cx
                .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
                .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;

                tasks.push(tool_result.output);
            "#},
            concat!(
                "let tool_result = cx\n",
                " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
                " .output;",
            ),
            cx,
        );
    }
959
960 #[gpui::test(iterations = 100)]
961 async fn test_indent_new_text_chunks(mut rng: StdRng) {
962 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
963 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
964 Ok(EditParserEvent::NewTextChunk {
965 chunk: chunk.clone(),
966 done: index == chunks.len() - 1,
967 })
968 }));
969 let indented_chunks =
970 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
971 .collect::<Vec<_>>()
972 .await;
973 let new_text = indented_chunks
974 .into_iter()
975 .collect::<Result<String>>()
976 .unwrap();
977 assert_eq!(new_text, " abc\n def\n ghi");
978 }
979
980 #[gpui::test(iterations = 100)]
981 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
982 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
983 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
984 Ok(EditParserEvent::NewTextChunk {
985 chunk: chunk.clone(),
986 done: index == chunks.len() - 1,
987 })
988 }));
989 let indented_chunks =
990 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
991 .collect::<Vec<_>>()
992 .await;
993 let new_text = indented_chunks
994 .into_iter()
995 .collect::<Result<String>>()
996 .unwrap();
997 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
998 }
999
1000 #[gpui::test(iterations = 100)]
1001 async fn test_random_indents(mut rng: StdRng) {
1002 let len = rng.gen_range(1..=100);
1003 let new_text = util::RandomCharIter::new(&mut rng)
1004 .with_simple_text()
1005 .take(len)
1006 .collect::<String>();
1007 let new_text = new_text
1008 .split('\n')
1009 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1010 .collect::<Vec<_>>()
1011 .join("\n");
1012 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1013
1014 let chunks = to_random_chunks(&mut rng, &new_text);
1015 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1016 Ok(EditParserEvent::NewTextChunk {
1017 chunk: chunk.clone(),
1018 done: index == chunks.len() - 1,
1019 })
1020 }));
1021 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1022 .collect::<Vec<_>>()
1023 .await;
1024 let actual_reindented_text = reindented_chunks
1025 .into_iter()
1026 .collect::<Result<String>>()
1027 .unwrap();
1028 let expected_reindented_text = new_text
1029 .split('\n')
1030 .map(|line| {
1031 if let Some(ix) = line.find(|c| c != ' ') {
1032 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1033 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1034 } else {
1035 line.to_string()
1036 }
1037 })
1038 .collect::<Vec<_>>()
1039 .join("\n");
1040 assert_eq!(actual_reindented_text, expected_reindented_text);
1041 }
1042
1043 #[track_caller]
1044 fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1045 let (text, _) = marked_text_ranges(text_with_expected_range, false);
1046 let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1047 let snapshot = buffer.read(cx).snapshot();
1048 let mut ranges = Vec::new();
1049 ranges.extend(EditAgent::resolve_location(&snapshot, query));
1050 let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1051 pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1052 }
1053
1054 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1055 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1056 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1057 chunk_indices.sort();
1058 chunk_indices.push(input.len());
1059
1060 let mut chunks = Vec::new();
1061 let mut last_ix = 0;
1062 for chunk_ix in chunk_indices {
1063 chunks.push(input[last_ix..chunk_ix].to_string());
1064 last_ix = chunk_ix;
1065 }
1066 chunks
1067 }
1068
1069 fn simulate_llm_output(
1070 output: &str,
1071 rng: &mut StdRng,
1072 cx: &mut TestAppContext,
1073 ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1074 let executor = cx.executor();
1075 stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1076 let executor = executor.clone();
1077 async move {
1078 executor.simulate_random_delay().await;
1079 chunk
1080 }
1081 })
1082 }
1083
1084 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1085 cx.update(settings::init);
1086 cx.update(Project::init_settings);
1087 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1088 let model = Arc::new(FakeLanguageModel::default());
1089 let action_log = cx.new(|_| ActionLog::new(project));
1090 EditAgent::new(model, action_log, Templates::new())
1091 }
1092}