1mod edit_parser;
2#[cfg(test)]
3mod evals;
4
5use crate::{Template, Templates};
6use aho_corasick::AhoCorasick;
7use anyhow::Result;
8use assistant_tool::ActionLog;
9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
10use futures::{
11 Stream, StreamExt,
12 channel::mpsc::{self, UnboundedReceiver},
13 pin_mut,
14 stream::BoxStream,
15};
16use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
17use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
18use language_model::{
19 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
20 MessageContent, Role,
21};
22use project::{AgentLocation, Project};
23use serde::Serialize;
24use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
25use streaming_diff::{CharOperation, StreamingDiff};
26
/// Values serialized into the `create_file_prompt.hbs` template, which is
/// rendered when the agent is asked to write a buffer's entire contents.
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// Path of the file backing the target buffer, when one could be resolved.
    path: Option<PathBuf>,
    /// Description of the desired contents, passed through from the caller.
    edit_description: String,
}
32
// Binds the struct to the template file it is rendered with.
impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
36
/// Values serialized into the `edit_file_prompt.hbs` template, which is
/// rendered when the agent is asked to make targeted edits to a buffer.
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// Path of the file backing the target buffer, when one could be resolved.
    path: Option<PathBuf>,
    /// Description of the desired edits, passed through from the caller.
    edit_description: String,
}
42
// Binds the struct to the template file it is rendered with.
impl Template for EditFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
46
/// Progress notifications emitted while the agent is modifying a buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// Sent after a batch of changes has been applied to the buffer.
    Edited,
    /// Sent when the old text of an edit could not be located in the buffer;
    /// carries the text that was searched for.
    OldTextNotFound(SharedString),
}
52
/// Final result of an `overwrite` or `edit` run.
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk received from the model.
    pub _raw_edits: String,
    /// Metrics reported by the edit parser when it finished. The overwrite
    /// path does not run the parser and stores the default value here.
    pub _parser_metrics: EditParserMetrics,
}
58
/// Drives a language model to overwrite or edit a buffer, streaming the
/// model's output into the buffer as it arrives.
///
/// The agent is cloned into each task it spawns.
#[derive(Clone)]
pub struct EditAgent {
    /// Model used to generate file contents and edits.
    model: Arc<dyn LanguageModel>,
    /// Log that is told about every buffer the agent creates, reads, or edits.
    action_log: Entity<ActionLog>,
    /// Project whose agent location is updated as edits are applied.
    project: Entity<Project>,
    /// Template registry used to render the prompts.
    templates: Arc<Templates>,
}
66
67impl EditAgent {
68 pub fn new(
69 model: Arc<dyn LanguageModel>,
70 project: Entity<Project>,
71 action_log: Entity<ActionLog>,
72 templates: Arc<Templates>,
73 ) -> Self {
74 EditAgent {
75 model,
76 project,
77 action_log,
78 templates,
79 }
80 }
81
82 pub fn overwrite(
83 &self,
84 buffer: Entity<Buffer>,
85 edit_description: String,
86 previous_messages: Vec<LanguageModelRequestMessage>,
87 cx: &mut AsyncApp,
88 ) -> (
89 Task<Result<EditAgentOutput>>,
90 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
91 ) {
92 let this = self.clone();
93 let (events_tx, events_rx) = mpsc::unbounded();
94 let output = cx.spawn(async move |cx| {
95 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
96 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
97 let prompt = CreateFilePromptTemplate {
98 path,
99 edit_description,
100 }
101 .render(&this.templates)?;
102 let new_chunks = this.request(previous_messages, prompt, cx).await?;
103
104 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
105 while let Some(event) = inner_events.next().await {
106 events_tx.unbounded_send(event).ok();
107 }
108 output.await
109 });
110 (output, events_rx)
111 }
112
113 fn overwrite_with_chunks(
114 &self,
115 buffer: Entity<Buffer>,
116 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
117 cx: &mut AsyncApp,
118 ) -> (
119 Task<Result<EditAgentOutput>>,
120 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
121 ) {
122 let (output_events_tx, output_events_rx) = mpsc::unbounded();
123 let this = self.clone();
124 let task = cx.spawn(async move |cx| {
125 this.action_log
126 .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
127 let output = this
128 .overwrite_with_chunks_internal(buffer, edit_chunks, output_events_tx, cx)
129 .await;
130 this.project
131 .update(cx, |project, cx| project.set_agent_location(None, cx))?;
132 output
133 });
134 (task, output_events_rx)
135 }
136
137 async fn overwrite_with_chunks_internal(
138 &self,
139 buffer: Entity<Buffer>,
140 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
141 output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
142 cx: &mut AsyncApp,
143 ) -> Result<EditAgentOutput> {
144 cx.update(|cx| {
145 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
146 self.action_log.update(cx, |log, cx| {
147 log.buffer_edited(buffer.clone(), cx);
148 });
149 self.project.update(cx, |project, cx| {
150 project.set_agent_location(
151 Some(AgentLocation {
152 buffer: buffer.downgrade(),
153 position: language::Anchor::MAX,
154 }),
155 cx,
156 )
157 });
158 output_events_tx
159 .unbounded_send(EditAgentOutputEvent::Edited)
160 .ok();
161 })?;
162
163 let mut raw_edits = String::new();
164 pin_mut!(edit_chunks);
165 while let Some(chunk) = edit_chunks.next().await {
166 let chunk = chunk?;
167 raw_edits.push_str(&chunk);
168 cx.update(|cx| {
169 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
170 self.action_log
171 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
172 self.project.update(cx, |project, cx| {
173 project.set_agent_location(
174 Some(AgentLocation {
175 buffer: buffer.downgrade(),
176 position: language::Anchor::MAX,
177 }),
178 cx,
179 )
180 });
181 })?;
182 output_events_tx
183 .unbounded_send(EditAgentOutputEvent::Edited)
184 .ok();
185 }
186
187 Ok(EditAgentOutput {
188 _raw_edits: raw_edits,
189 _parser_metrics: EditParserMetrics::default(),
190 })
191 }
192
193 pub fn edit(
194 &self,
195 buffer: Entity<Buffer>,
196 edit_description: String,
197 previous_messages: Vec<LanguageModelRequestMessage>,
198 cx: &mut AsyncApp,
199 ) -> (
200 Task<Result<EditAgentOutput>>,
201 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
202 ) {
203 self.project
204 .update(cx, |project, cx| {
205 project.set_agent_location(
206 Some(AgentLocation {
207 buffer: buffer.downgrade(),
208 position: language::Anchor::MIN,
209 }),
210 cx,
211 );
212 })
213 .ok();
214
215 let this = self.clone();
216 let (events_tx, events_rx) = mpsc::unbounded();
217 let output = cx.spawn(async move |cx| {
218 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
219 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
220 let prompt = EditFilePromptTemplate {
221 path,
222 edit_description,
223 }
224 .render(&this.templates)?;
225 let edit_chunks = this.request(previous_messages, prompt, cx).await?;
226
227 let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
228 while let Some(event) = inner_events.next().await {
229 events_tx.unbounded_send(event).ok();
230 }
231 output.await
232 });
233 (output, events_rx)
234 }
235
236 fn apply_edit_chunks(
237 &self,
238 buffer: Entity<Buffer>,
239 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
240 cx: &mut AsyncApp,
241 ) -> (
242 Task<Result<EditAgentOutput>>,
243 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
244 ) {
245 let (output_events_tx, output_events_rx) = mpsc::unbounded();
246 let this = self.clone();
247 let task = cx.spawn(async move |mut cx| {
248 this.action_log
249 .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;
250 let output = this
251 .apply_edit_chunks_internal(buffer, edit_chunks, output_events_tx, &mut cx)
252 .await;
253 this.project
254 .update(cx, |project, cx| project.set_agent_location(None, cx))?;
255 output
256 });
257 (task, output_events_rx)
258 }
259
    /// Parses `edit_chunks` into old-text/new-text edits and applies each one
    /// to `buffer` while the model output is still streaming in.
    ///
    /// For every `OldText` event the text is located in a fresh snapshot of
    /// the buffer. If it cannot be found, an `OldTextNotFound` event is sent
    /// and the edit is skipped. Otherwise a background task diffs the matched
    /// buffer text against the incoming new text, while this function applies
    /// the resulting edits in batches, notifying the action log, moving the
    /// agent location, and sending an `Edited` event after each batch.
    ///
    /// Returns the output of the parsing task, or the first error encountered.
    async fn apply_edit_chunks_internal(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        while let Some(edit_event) = edit_events.next().await {
            // Only `OldText` starts an edit. Any other event seen here (for
            // example the new text of an edit that was skipped) is discarded.
            let EditParserEvent::OldText(old_text_query) = edit_event? else {
                continue;
            };
            let old_text_query = SharedString::from(old_text_query);

            // Carries `(anchor range, replacement text)` pairs from the
            // background diffing task to the loop below that applies them.
            let (edits_tx, edits_rx) = mpsc::unbounded();
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            // Search for the old text off the calling task.
            let old_range = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let old_text_query = old_text_query.clone();
                    async move { Self::resolve_location(&snapshot, &old_text_query) }
                })
                .await;
            let Some(old_range) = old_range else {
                // We couldn't find the old text in the buffer. Report the error.
                output_events
                    .unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
                    .ok();
                continue;
            };

            // This task takes ownership of `edit_events` and hands it back as
            // its result so the outer loop can resume with the next edit.
            let compute_edits = cx.background_spawn(async move {
                // Compare the indentation of the matched buffer line with that
                // of the first line of the old text to learn how far the new
                // text has to be shifted.
                let buffer_start_indent =
                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
                let old_text_start_indent = old_text_query
                    .lines()
                    .next()
                    .map_or(buffer_start_indent, |line| {
                        LineIndent::from_iter(line.chars())
                    });
                // Shift by tabs when the buffer line is indented with any
                // tabs, otherwise shift by spaces.
                let indent_delta = if buffer_start_indent.tabs > 0 {
                    IndentDelta::Tabs(
                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
                    )
                } else {
                    IndentDelta::Spaces(
                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
                    )
                };

                let old_text = snapshot
                    .text_for_range(old_range.clone())
                    .collect::<String>();
                let mut diff = StreamingDiff::new(old_text);
                // Offset into `snapshot` up to which the old text has been
                // consumed by the diff so far.
                let mut edit_start = old_range.start;
                let mut new_text_chunks =
                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
                let mut done = false;
                while !done {
                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
                    {
                        diff.push_new(&new_text_chunk?)
                    } else {
                        // No more new text: take whatever operations remain.
                        done = true;
                        mem::take(&mut diff).finish()
                    };

                    for op in char_operations {
                        match op {
                            CharOperation::Insert { text } => {
                                // Insert at the current position without
                                // consuming any old text.
                                let edit_start = snapshot.anchor_after(edit_start);
                                edits_tx
                                    .unbounded_send((edit_start..edit_start, Arc::from(text)))?;
                            }
                            CharOperation::Delete { bytes } => {
                                // Remove the next `bytes` bytes of old text.
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                edits_tx.unbounded_send((edit_range, Arc::from("")))?;
                            }
                            // Unchanged text: just advance past it.
                            CharOperation::Keep { bytes } => edit_start += bytes,
                        }
                    }
                }

                // Release the mutable borrow of `edit_events` before moving it
                // out of the task.
                drop(new_text_chunks);
                anyhow::Ok(edit_events)
            });

            // TODO: group all edits into one transaction
            // Apply the edits that are already available in batches of up to
            // 32. This loop ends when the task above drops `edits_tx`.
            let mut edits_rx = edits_rx.ready_chunks(32);
            while let Some(edits) = edits_rx.next().await {
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    let max_edit_end = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        // `edits` is non-empty (checked above), so `max()`
                        // always yields a value.
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        buffer.anchor_before(max_edit_end)
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    // Move the agent location to the furthest end position
                    // among the edits of this batch.
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Take the event stream back, propagating any diffing error.
            edit_events = compute_edits.await?;
        }

        output.await
    }
393
394 fn parse_edit_chunks(
395 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
396 cx: &mut AsyncApp,
397 ) -> (
398 Task<Result<EditAgentOutput>>,
399 UnboundedReceiver<Result<EditParserEvent>>,
400 ) {
401 let (tx, rx) = mpsc::unbounded();
402 let output = cx.background_spawn(async move {
403 pin_mut!(chunks);
404
405 let mut parser = EditParser::new();
406 let mut raw_edits = String::new();
407 while let Some(chunk) = chunks.next().await {
408 match chunk {
409 Ok(chunk) => {
410 raw_edits.push_str(&chunk);
411 for event in parser.push(&chunk) {
412 tx.unbounded_send(Ok(event))?;
413 }
414 }
415 Err(error) => {
416 tx.unbounded_send(Err(error.into()))?;
417 }
418 }
419 }
420 Ok(EditAgentOutput {
421 _raw_edits: raw_edits,
422 _parser_metrics: parser.finish(),
423 })
424 });
425 (output, rx)
426 }
427
    /// Adapts a stream of parser events into a stream of new-text strings in
    /// which the leading indentation of every line has been shifted by
    /// `delta`.
    ///
    /// Only `NewTextChunk` events are consumed as text. The returned stream
    /// ends after the chunk flagged `done`, or as soon as the underlying
    /// stream ends or yields any other kind of event. Errors from the
    /// underlying stream are passed through.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted: a trailing partial line whose
        // indentation cannot be decided yet.
        let mut buffer = String::new();
        // True while the current line's indentation has not been adjusted.
        let mut in_leading_whitespace = true;
        // Set once the chunk flagged as last has been processed.
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    // Note: the `done` bound here shadows the outer flag only
                    // within this arm.
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    // End of the stream, or an event that is not new text.
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                // Start of the first byte of `buffer` not yet copied out.
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    // A line is "pending" when no newline has arrived for it.
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        // Only `delta`'s own character counts as indentation
                        // here; any other character ends the leading run.
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found a non-whitespace character, adjust
                            // indentation based on the delta.
                            // A negative result is clamped to no indentation.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input.
                            break;
                        } else {
                            // This line is entirely whitespace. Push it without indentation.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        // The partial line was emitted in full; nothing of it
                        // stays buffered.
                        start_ix = line_end;
                        break;
                    } else {
                        // A complete line: the next one starts with fresh
                        // indentation.
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                // Discard everything that has been copied to the output.
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                // Nothing to emit yet: poll the underlying stream again.
                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }
506
507 async fn request(
508 &self,
509 mut messages: Vec<LanguageModelRequestMessage>,
510 prompt: String,
511 cx: &mut AsyncApp,
512 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
513 let mut message_content = Vec::new();
514 if let Some(last_message) = messages.last_mut() {
515 if last_message.role == Role::Assistant {
516 last_message
517 .content
518 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
519 if last_message.content.is_empty() {
520 messages.pop();
521 }
522 }
523 }
524 message_content.push(MessageContent::Text(prompt));
525 messages.push(LanguageModelRequestMessage {
526 role: Role::User,
527 content: message_content,
528 cache: false,
529 });
530
531 let request = LanguageModelRequest {
532 messages,
533 ..Default::default()
534 };
535 Ok(self.model.stream_completion_text(request, cx).await?.stream)
536 }
537
538 fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
539 let range = Self::resolve_location_exact(buffer, search_query)
540 .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;
541
542 // Expand the range to include entire lines.
543 let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
544 start.column = 0;
545 let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
546 if end.column > 0 {
547 end.column = buffer.line_len(end.row);
548 }
549
550 Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
551 }
552
553 fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
554 let search = AhoCorasick::new([search_query]).ok()?;
555 let mat = search
556 .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
557 .next()?
558 .expect("buffer can't error");
559 Some(mat.range())
560 }
561
    /// Finds the run of buffer lines that best aligns with the lines of
    /// `search_query`, tolerating small differences, and returns its byte
    /// range.
    ///
    /// Lines are compared after trimming, using `fuzzy_eq`. The alignment is
    /// computed with a cost matrix whose rows are query lines and whose
    /// columns are buffer lines. A match is accepted only when at least 80% of
    /// the lines on the chosen path are paired diagonally; otherwise `None` is
    /// returned.
    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        // Cost of leaving a buffer line unpaired inside the match.
        const INSERTION_COST: u32 = 3;
        // Cost of leaving a query line unpaired.
        const DELETION_COST: u32 = 10;

        let buffer_line_count = buffer.max_point().row as usize + 1;
        let query_line_count = search_query.lines().count();
        // Row 0 keeps its initial cost of zero in every column, so a match can
        // begin at any buffer line without penalty.
        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
        let mut leading_deletion_cost = 0_u32;
        for (row, query_line) in search_query.lines().enumerate() {
            let query_line = query_line.trim();
            // Column 0: every query line up to this one is unpaired.
            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
            matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
            );

            let mut buffer_lines = buffer.as_rope().chunks().lines();
            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                // Leave this query line unpaired.
                let up = SearchState::new(
                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                // Leave this buffer line unpaired.
                let left = SearchState::new(
                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                // Pair the two lines: free when they are similar enough,
                // otherwise charged as a deletion plus an insertion.
                let diagonal = SearchState::new(
                    if fuzzy_eq(query_line, buffer_line) {
                        matrix.get(row, col).cost
                    } else {
                        matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // Ties on cost are broken by the ordering of `SearchDirection`.
                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Traceback to find the best match
        // The match ends at the column with the lowest cost in the last row;
        // the first such column wins because the comparison is strict.
        let mut buffer_row_end = buffer_line_count as u32;
        let mut best_cost = u32::MAX;
        for col in 1..=buffer_line_count {
            let cost = matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                buffer_row_end = col as u32;
            }
        }

        // Follow the recorded directions back until the query is exhausted to
        // find where the match starts. Every diagonal step is counted as a
        // matched line, whether or not the paired lines were similar.
        let mut matched_lines = 0;
        let mut query_row = query_line_count;
        let mut buffer_row_start = buffer_row_end;
        while query_row > 0 && buffer_row_start > 0 {
            let current = matrix.get(query_row, buffer_row_start as usize);
            match current.direction {
                SearchDirection::Diagonal => {
                    query_row -= 1;
                    buffer_row_start -= 1;
                    matched_lines += 1;
                }
                SearchDirection::Up => {
                    query_row -= 1;
                }
                SearchDirection::Left => {
                    buffer_row_start -= 1;
                }
            }
        }

        // Relate the matched lines to the longer of the two sides so that both
        // unpaired query lines and unpaired buffer lines lower the ratio.
        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
        let matched_ratio =
            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
        if matched_ratio >= 0.8 {
            // `buffer_row_end` is one past the last matched row.
            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
            let buffer_end_ix = buffer.point_to_offset(Point::new(
                buffer_row_end - 1,
                buffer.line_len(buffer_row_end - 1),
            ));
            Some(buffer_start_ix..buffer_end_ix)
        } else {
            None
        }
    }
652}
653
654fn fuzzy_eq(left: &str, right: &str) -> bool {
655 const THRESHOLD: f64 = 0.8;
656
657 let min_levenshtein = left.len().abs_diff(right.len());
658 let min_normalized_levenshtein =
659 1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
660 if min_normalized_levenshtein < THRESHOLD {
661 return false;
662 }
663
664 strsim::normalized_levenshtein(left, right) >= THRESHOLD
665}
666
/// Signed change in indentation to apply to each line of new text, expressed
/// either in spaces or in tabs.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    /// Shift by this many space characters (negative values outdent).
    Spaces(isize),
    /// Shift by this many tab characters (negative values outdent).
    Tabs(isize),
}

impl IndentDelta {
    /// The indentation character this delta is measured in.
    fn character(&self) -> char {
        match *self {
            IndentDelta::Tabs(_) => '\t',
            IndentDelta::Spaces(_) => ' ',
        }
    }

    /// The signed number of indentation characters to add.
    fn len(&self) -> isize {
        let (IndentDelta::Spaces(amount) | IndentDelta::Tabs(amount)) = *self;
        amount
    }
}
688
/// Step taken to reach a cell of the fuzzy-search matrix.
///
/// The declaration order matters: it defines the derived ordering that is
/// used to break ties between states of equal cost.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Reached from the row above, in the same column.
    Up,
    /// Reached from the column to the left, in the same row.
    Left,
    /// Reached from the cell above and to the left.
    Diagonal,
}

/// One cell of the fuzzy-search matrix: the accumulated cost and the step
/// that produced it. States order by cost first, then by direction.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
    cost: u32,
    direction: SearchDirection,
}

impl SearchState {
    /// Pairs a cost with the direction it was reached from.
    fn new(cost: u32, direction: SearchDirection) -> Self {
        SearchState { cost, direction }
    }
}
707
708struct SearchMatrix {
709 cols: usize,
710 data: Vec<SearchState>,
711}
712
713impl SearchMatrix {
714 fn new(rows: usize, cols: usize) -> Self {
715 SearchMatrix {
716 cols,
717 data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
718 }
719 }
720
721 fn get(&self, row: usize, col: usize) -> SearchState {
722 self.data[row * self.cols + col]
723 }
724
725 fn set(&mut self, row: usize, col: usize, cost: SearchState) {
726 self.data[row * self.cols + col] = cost;
727 }
728}
729
730#[cfg(test)]
731mod tests {
732 use super::*;
733 use fs::FakeFs;
734 use futures::stream;
735 use gpui::{App, AppContext, TestAppContext};
736 use indoc::indoc;
737 use language_model::fake_provider::FakeLanguageModel;
738 use project::{AgentLocation, Project};
739 use rand::prelude::*;
740 use rand::rngs::StdRng;
741 use std::cmp;
742 use unindent::Unindent;
743 use util::test::{generate_marked_text, marked_text_ranges};
744
745 #[gpui::test(iterations = 100)]
746 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
747 let agent = init_test(cx).await;
748 let buffer = cx.new(|cx| {
749 Buffer::local(
750 indoc! {"
751 lorem
752 ipsum
753 dolor
754 sit
755 "},
756 cx,
757 )
758 });
759 let raw_edits = simulate_llm_output(
760 indoc! {"
761 <old_text>
762 ipsum
763 dolor
764 sit
765 </old_text>
766 <new_text>
767 ipsum
768 dolor
769 sit
770 amet
771 </new_text>
772 "},
773 &mut rng,
774 cx,
775 );
776 let (apply, _events) =
777 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
778 apply.await.unwrap();
779 pretty_assertions::assert_eq!(
780 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
781 indoc! {"
782 lorem
783 ipsum
784 dolor
785 sit
786 amet
787 "}
788 );
789 }
790
791 #[gpui::test(iterations = 100)]
792 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
793 let agent = init_test(cx).await;
794 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
795 let raw_edits = simulate_llm_output(
796 indoc! {"
797 <old_text>
798 def
799 </old_text>
800 <new_text>
801 DEF
802 </new_text>
803
804 <old_text>
805 DEF
806 </old_text>
807 <new_text>
808 DeF
809 </new_text>
810 "},
811 &mut rng,
812 cx,
813 );
814 let (apply, _events) =
815 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
816 apply.await.unwrap();
817 assert_eq!(
818 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
819 "abc\nDeF\nghi"
820 );
821 }
822
823 #[gpui::test(iterations = 100)]
824 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
825 let agent = init_test(cx).await;
826 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
827 let raw_edits = simulate_llm_output(
828 indoc! {"
829 <old_text>
830 jkl
831 </old_text>
832 <new_text>
833 mno
834 </new_text>
835
836 <old_text>
837 abc
838 </old_text>
839 <new_text>
840 ABC
841 </new_text>
842 "},
843 &mut rng,
844 cx,
845 );
846 let (apply, _events) =
847 agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
848 apply.await.unwrap();
849 assert_eq!(
850 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
851 "ABC\ndef\nghi"
852 );
853 }
854
855 #[gpui::test]
856 async fn test_edit_events(cx: &mut TestAppContext) {
857 let agent = init_test(cx).await;
858 let project = agent
859 .action_log
860 .read_with(cx, |log, _| log.project().clone());
861 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
862 let (chunks_tx, chunks_rx) = mpsc::unbounded();
863 let (apply, mut events) = agent.apply_edit_chunks(
864 buffer.clone(),
865 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
866 &mut cx.to_async(),
867 );
868
869 chunks_tx.unbounded_send("<old_text>a").unwrap();
870 cx.run_until_parked();
871 assert_eq!(drain_events(&mut events), vec![]);
872 assert_eq!(
873 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
874 "abc\ndef\nghi"
875 );
876 assert_eq!(
877 project.read_with(cx, |project, _| project.agent_location()),
878 None
879 );
880
881 chunks_tx.unbounded_send("bc</old_text>").unwrap();
882 cx.run_until_parked();
883 assert_eq!(drain_events(&mut events), vec![]);
884 assert_eq!(
885 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
886 "abc\ndef\nghi"
887 );
888 assert_eq!(
889 project.read_with(cx, |project, _| project.agent_location()),
890 None
891 );
892
893 chunks_tx.unbounded_send("<new_text>abX").unwrap();
894 cx.run_until_parked();
895 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
896 assert_eq!(
897 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
898 "abXc\ndef\nghi"
899 );
900 assert_eq!(
901 project.read_with(cx, |project, _| project.agent_location()),
902 Some(AgentLocation {
903 buffer: buffer.downgrade(),
904 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
905 })
906 );
907
908 chunks_tx.unbounded_send("cY").unwrap();
909 cx.run_until_parked();
910 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
911 assert_eq!(
912 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
913 "abXcY\ndef\nghi"
914 );
915 assert_eq!(
916 project.read_with(cx, |project, _| project.agent_location()),
917 Some(AgentLocation {
918 buffer: buffer.downgrade(),
919 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
920 })
921 );
922
923 chunks_tx.unbounded_send("</new_text>").unwrap();
924 chunks_tx.unbounded_send("<old_text>hall").unwrap();
925 cx.run_until_parked();
926 assert_eq!(drain_events(&mut events), vec![]);
927 assert_eq!(
928 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
929 "abXcY\ndef\nghi"
930 );
931 assert_eq!(
932 project.read_with(cx, |project, _| project.agent_location()),
933 Some(AgentLocation {
934 buffer: buffer.downgrade(),
935 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
936 })
937 );
938
939 chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
940 chunks_tx.unbounded_send("<new_text>").unwrap();
941 cx.run_until_parked();
942 assert_eq!(
943 drain_events(&mut events),
944 vec![EditAgentOutputEvent::OldTextNotFound(
945 "hallucinated old".into()
946 )]
947 );
948 assert_eq!(
949 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
950 "abXcY\ndef\nghi"
951 );
952 assert_eq!(
953 project.read_with(cx, |project, _| project.agent_location()),
954 Some(AgentLocation {
955 buffer: buffer.downgrade(),
956 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
957 })
958 );
959
960 chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
961 chunks_tx.unbounded_send("text>").unwrap();
962 cx.run_until_parked();
963 assert_eq!(drain_events(&mut events), vec![]);
964 assert_eq!(
965 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
966 "abXcY\ndef\nghi"
967 );
968 assert_eq!(
969 project.read_with(cx, |project, _| project.agent_location()),
970 Some(AgentLocation {
971 buffer: buffer.downgrade(),
972 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
973 })
974 );
975
976 chunks_tx.unbounded_send("<old_text>gh").unwrap();
977 chunks_tx.unbounded_send("i</old_text>").unwrap();
978 chunks_tx.unbounded_send("<new_text>").unwrap();
979 cx.run_until_parked();
980 assert_eq!(drain_events(&mut events), vec![]);
981 assert_eq!(
982 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
983 "abXcY\ndef\nghi"
984 );
985 assert_eq!(
986 project.read_with(cx, |project, _| project.agent_location()),
987 Some(AgentLocation {
988 buffer: buffer.downgrade(),
989 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
990 })
991 );
992
993 chunks_tx.unbounded_send("GHI</new_text>").unwrap();
994 cx.run_until_parked();
995 assert_eq!(
996 drain_events(&mut events),
997 vec![EditAgentOutputEvent::Edited]
998 );
999 assert_eq!(
1000 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1001 "abXcY\ndef\nGHI"
1002 );
1003 assert_eq!(
1004 project.read_with(cx, |project, _| project.agent_location()),
1005 Some(AgentLocation {
1006 buffer: buffer.downgrade(),
1007 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1008 })
1009 );
1010
1011 drop(chunks_tx);
1012 apply.await.unwrap();
1013 assert_eq!(
1014 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1015 "abXcY\ndef\nGHI"
1016 );
1017 assert_eq!(drain_events(&mut events), vec![]);
1018 assert_eq!(
1019 project.read_with(cx, |project, _| project.agent_location()),
1020 None
1021 );
1022 }
1023
1024 #[gpui::test]
1025 async fn test_overwrite_events(cx: &mut TestAppContext) {
1026 let agent = init_test(cx).await;
1027 let project = agent
1028 .action_log
1029 .read_with(cx, |log, _| log.project().clone());
1030 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1031 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1032 let (apply, mut events) = agent.overwrite_with_chunks(
1033 buffer.clone(),
1034 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1035 &mut cx.to_async(),
1036 );
1037
1038 cx.run_until_parked();
1039 assert_eq!(
1040 drain_events(&mut events),
1041 vec![EditAgentOutputEvent::Edited]
1042 );
1043 assert_eq!(
1044 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1045 ""
1046 );
1047 assert_eq!(
1048 project.read_with(cx, |project, _| project.agent_location()),
1049 Some(AgentLocation {
1050 buffer: buffer.downgrade(),
1051 position: language::Anchor::MAX
1052 })
1053 );
1054
1055 chunks_tx.unbounded_send("jkl\n").unwrap();
1056 cx.run_until_parked();
1057 assert_eq!(
1058 drain_events(&mut events),
1059 vec![EditAgentOutputEvent::Edited]
1060 );
1061 assert_eq!(
1062 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1063 "jkl\n"
1064 );
1065 assert_eq!(
1066 project.read_with(cx, |project, _| project.agent_location()),
1067 Some(AgentLocation {
1068 buffer: buffer.downgrade(),
1069 position: language::Anchor::MAX
1070 })
1071 );
1072
1073 chunks_tx.unbounded_send("mno\n").unwrap();
1074 cx.run_until_parked();
1075 assert_eq!(
1076 drain_events(&mut events),
1077 vec![EditAgentOutputEvent::Edited]
1078 );
1079 assert_eq!(
1080 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1081 "jkl\nmno\n"
1082 );
1083 assert_eq!(
1084 project.read_with(cx, |project, _| project.agent_location()),
1085 Some(AgentLocation {
1086 buffer: buffer.downgrade(),
1087 position: language::Anchor::MAX
1088 })
1089 );
1090
1091 chunks_tx.unbounded_send("pqr").unwrap();
1092 cx.run_until_parked();
1093 assert_eq!(
1094 drain_events(&mut events),
1095 vec![EditAgentOutputEvent::Edited]
1096 );
1097 assert_eq!(
1098 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1099 "jkl\nmno\npqr"
1100 );
1101 assert_eq!(
1102 project.read_with(cx, |project, _| project.agent_location()),
1103 Some(AgentLocation {
1104 buffer: buffer.downgrade(),
1105 position: language::Anchor::MAX
1106 })
1107 );
1108
1109 drop(chunks_tx);
1110 apply.await.unwrap();
1111 assert_eq!(
1112 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1113 "jkl\nmno\npqr"
1114 );
1115 assert_eq!(drain_events(&mut events), vec![]);
1116 assert_eq!(
1117 project.read_with(cx, |project, _| project.agent_location()),
1118 None
1119 );
1120 }
1121
// Exercises `EditAgent::resolve_location` through `assert_location_resolution`.
// For every case the first argument is the buffer text, in which `«…»` wraps the
// range the query is expected to resolve to, and the second argument is the query.
// Cases whose buffer text carries no markers presumably expect the query to
// resolve to nothing (the helper compares the re-marked text with the input).
1122 #[gpui::test]
1123 fn test_resolve_location(cx: &mut App) {
// Single-line query: the marked range covers the whole `ipsum` line,
// including the whitespace in front of the word.
1124 assert_location_resolution(
1125 concat!(
1126 " Lorem\n",
1127 "« ipsum»\n",
1128 " dolor sit amet\n",
1129 " consecteur",
1130 ),
1131 "ipsum",
1132 cx,
1133 );
1134
// Two-line query: the marked range spans both consecutive lines.
1135 assert_location_resolution(
1136 concat!(
1137 " Lorem\n",
1138 "« ipsum\n",
1139 " dolor sit amet»\n",
1140 " consecteur",
1141 ),
1142 "ipsum\ndolor sit amet",
1143 cx,
1144 );
1145
// Inexact query: it says `u32` where the buffer says `usize`, yet the whole
// `foo1` function is still the expected range.
1146 assert_location_resolution(
1147 &"
1148 «fn foo1(a: usize) -> usize {
1149 40
1150 }»
1151
1152 fn foo2(b: usize) -> usize {
1153 42
1154 }
1155 "
1156 .unindent(),
1157 "fn foo1(a: usize) -> u32 {\n40\n}",
1158 cx,
1159 );
1160
// Query with a missing line: `three()` is left out of the query, but the
// expected range still runs from `two()` through `six()`.
1161 assert_location_resolution(
1162 &"
1163 class Something {
1164 one() { return 1; }
1165 « two() { return 2222; }
1166 three() { return 333; }
1167 four() { return 4444; }
1168 five() { return 5555; }
1169 six() { return 6666; }»
1170 seven() { return 7; }
1171 eight() { return 8; }
1172 }
1173 "
1174 .unindent(),
1175 &"
1176 two() { return 2222; }
1177 four() { return 4444; }
1178 five() { return 5555; }
1179 six() { return 6666; }
1180 "
1181 .unindent(),
1182 cx,
1183 );
1184
// No markers in the buffer text: the query is a differently structured list
// of `use` lines.
1185 assert_location_resolution(
1186 &"
1187 use std::ops::Range;
1188 use std::sync::Mutex;
1189 use std::{
1190 collections::HashMap,
1191 env,
1192 ffi::{OsStr, OsString},
1193 fs,
1194 io::{BufRead, BufReader},
1195 mem,
1196 path::{Path, PathBuf},
1197 process::Command,
1198 sync::LazyLock,
1199 time::SystemTime,
1200 };
1201 "
1202 .unindent(),
1203 &"
1204 use std::collections::{HashMap, HashSet};
1205 use std::ffi::{OsStr, OsString};
1206 use std::fmt::Write as _;
1207 use std::fs;
1208 use std::io::{BufReader, Read, Write};
1209 use std::mem;
1210 use std::path::{Path, PathBuf};
1211 use std::process::Command;
1212 use std::sync::Arc;
1213 "
1214 .unindent(),
1215 cx,
1216 );
1217
// No markers in the buffer text: the query keeps the `blink_manager.update`
// call but drops the `if`/`else` lines that surround `enable` in the buffer.
1218 assert_location_resolution(
1219 indoc! {"
1220 impl Foo {
1221 fn new() -> Self {
1222 Self {
1223 subscriptions: vec![
1224 cx.observe_window_activation(window, |editor, window, cx| {
1225 let active = window.is_window_active();
1226 editor.blink_manager.update(cx, |blink_manager, cx| {
1227 if active {
1228 blink_manager.enable(cx);
1229 } else {
1230 blink_manager.disable(cx);
1231 }
1232 });
1233 }),
1234 ];
1235 }
1236 }
1237 }
1238 "},
1239 concat!(
1240 " editor.blink_manager.update(cx, |blink_manager, cx| {\n",
1241 " blink_manager.enable(cx);\n",
1242 " });",
1243 ),
1244 cx,
1245 );
1246
// No markers in the buffer text: the query ends with a `.output;` line, whereas
// the buffer statement continues with `.map_err(...)`.
1247 assert_location_resolution(
1248 indoc! {r#"
1249 let tool = cx
1250 .update(|cx| working_set.tool(&tool_name, cx))
1251 .map_err(|err| {
1252 anyhow!("Failed to look up tool '{}': {}", tool_name, err)
1253 })?;
1254
1255 let Some(tool) = tool else {
1256 return Err(anyhow!("Tool '{}' not found", tool_name));
1257 };
1258
1259 let project = project.clone();
1260 let action_log = action_log.clone();
1261 let messages = messages.clone();
1262 let tool_result = cx
1263 .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
1264 .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
1265
1266 tasks.push(tool_result.output);
1267 "#},
1268 concat!(
1269 "let tool_result = cx\n",
1270 " .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
1271 " .output;",
1272 ),
1273 cx,
1274 );
1275 }
1276
1277 #[gpui::test(iterations = 100)]
1278 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1279 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1280 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1281 Ok(EditParserEvent::NewTextChunk {
1282 chunk: chunk.clone(),
1283 done: index == chunks.len() - 1,
1284 })
1285 }));
1286 let indented_chunks =
1287 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1288 .collect::<Vec<_>>()
1289 .await;
1290 let new_text = indented_chunks
1291 .into_iter()
1292 .collect::<Result<String>>()
1293 .unwrap();
1294 assert_eq!(new_text, " abc\n def\n ghi");
1295 }
1296
1297 #[gpui::test(iterations = 100)]
1298 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1299 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1300 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1301 Ok(EditParserEvent::NewTextChunk {
1302 chunk: chunk.clone(),
1303 done: index == chunks.len() - 1,
1304 })
1305 }));
1306 let indented_chunks =
1307 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1308 .collect::<Vec<_>>()
1309 .await;
1310 let new_text = indented_chunks
1311 .into_iter()
1312 .collect::<Result<String>>()
1313 .unwrap();
1314 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1315 }
1316
1317 #[gpui::test(iterations = 100)]
1318 async fn test_random_indents(mut rng: StdRng) {
1319 let len = rng.gen_range(1..=100);
1320 let new_text = util::RandomCharIter::new(&mut rng)
1321 .with_simple_text()
1322 .take(len)
1323 .collect::<String>();
1324 let new_text = new_text
1325 .split('\n')
1326 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1327 .collect::<Vec<_>>()
1328 .join("\n");
1329 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1330
1331 let chunks = to_random_chunks(&mut rng, &new_text);
1332 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1333 Ok(EditParserEvent::NewTextChunk {
1334 chunk: chunk.clone(),
1335 done: index == chunks.len() - 1,
1336 })
1337 }));
1338 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1339 .collect::<Vec<_>>()
1340 .await;
1341 let actual_reindented_text = reindented_chunks
1342 .into_iter()
1343 .collect::<Result<String>>()
1344 .unwrap();
1345 let expected_reindented_text = new_text
1346 .split('\n')
1347 .map(|line| {
1348 if let Some(ix) = line.find(|c| c != ' ') {
1349 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1350 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1351 } else {
1352 line.to_string()
1353 }
1354 })
1355 .collect::<Vec<_>>()
1356 .join("\n");
1357 assert_eq!(actual_reindented_text, expected_reindented_text);
1358 }
1359
1360 #[track_caller]
1361 fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1362 let (text, _) = marked_text_ranges(text_with_expected_range, false);
1363 let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1364 let snapshot = buffer.read(cx).snapshot();
1365 let mut ranges = Vec::new();
1366 ranges.extend(EditAgent::resolve_location(&snapshot, query));
1367 let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1368 pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1369 }
1370
1371 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1372 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1373 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1374 chunk_indices.sort();
1375 chunk_indices.push(input.len());
1376
1377 let mut chunks = Vec::new();
1378 let mut last_ix = 0;
1379 for chunk_ix in chunk_indices {
1380 chunks.push(input[last_ix..chunk_ix].to_string());
1381 last_ix = chunk_ix;
1382 }
1383 chunks
1384 }
1385
1386 fn simulate_llm_output(
1387 output: &str,
1388 rng: &mut StdRng,
1389 cx: &mut TestAppContext,
1390 ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1391 let executor = cx.executor();
1392 stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1393 let executor = executor.clone();
1394 async move {
1395 executor.simulate_random_delay().await;
1396 chunk
1397 }
1398 })
1399 }
1400
1401 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1402 cx.update(settings::init);
1403 cx.update(Project::init_settings);
1404 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1405 let model = Arc::new(FakeLanguageModel::default());
1406 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1407 EditAgent::new(model, project, action_log, Templates::new())
1408 }
1409
1410 fn drain_events(
1411 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1412 ) -> Vec<EditAgentOutputEvent> {
1413 let mut events = Vec::new();
1414 while let Ok(Some(event)) = stream.try_next() {
1415 events.push(event);
1416 }
1417 events
1418 }
1419}