1mod create_file_parser;
2mod edit_parser;
3#[cfg(test)]
4mod evals;
5mod streaming_fuzzy_matcher;
6
7use crate::{Template, Templates};
8use anyhow::Result;
9use assistant_tool::ActionLog;
10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
12use futures::{
13 Stream, StreamExt,
14 channel::mpsc::{self, UnboundedReceiver},
15 pin_mut,
16 stream::BoxStream,
17};
18use gpui::{AppContext, AsyncApp, Entity, Task};
19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
20use language_model::{
21 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
22 LanguageModelToolChoice, MessageContent, Role,
23};
24use project::{AgentLocation, Project};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
28use streaming_diff::{CharOperation, StreamingDiff};
29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
30use util::debug_panic;
31
32#[derive(Serialize)]
33struct CreateFilePromptTemplate {
34 path: Option<PathBuf>,
35 edit_description: String,
36}
37
38impl Template for CreateFilePromptTemplate {
39 const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
40}
41
42#[derive(Serialize)]
43struct EditFilePromptTemplate {
44 path: Option<PathBuf>,
45 edit_description: String,
46}
47
48impl Template for EditFilePromptTemplate {
49 const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
50}
51
52#[derive(Clone, Debug, PartialEq, Eq)]
53pub enum EditAgentOutputEvent {
54 ResolvingEditRange(Range<Anchor>),
55 UnresolvedEditRange,
56 Edited,
57}
58
59#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
60pub struct EditAgentOutput {
61 pub raw_edits: String,
62 pub parser_metrics: EditParserMetrics,
63}
64
65#[derive(Clone)]
66pub struct EditAgent {
67 model: Arc<dyn LanguageModel>,
68 action_log: Entity<ActionLog>,
69 project: Entity<Project>,
70 templates: Arc<Templates>,
71}
72
73impl EditAgent {
74 pub fn new(
75 model: Arc<dyn LanguageModel>,
76 project: Entity<Project>,
77 action_log: Entity<ActionLog>,
78 templates: Arc<Templates>,
79 ) -> Self {
80 EditAgent {
81 model,
82 project,
83 action_log,
84 templates,
85 }
86 }
87
88 pub fn overwrite(
89 &self,
90 buffer: Entity<Buffer>,
91 edit_description: String,
92 conversation: &LanguageModelRequest,
93 cx: &mut AsyncApp,
94 ) -> (
95 Task<Result<EditAgentOutput>>,
96 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
97 ) {
98 let this = self.clone();
99 let (events_tx, events_rx) = mpsc::unbounded();
100 let conversation = conversation.clone();
101 let output = cx.spawn(async move |cx| {
102 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
103 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
104 let prompt = CreateFilePromptTemplate {
105 path,
106 edit_description,
107 }
108 .render(&this.templates)?;
109 let new_chunks = this.request(conversation, prompt, cx).await?;
110
111 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
112 while let Some(event) = inner_events.next().await {
113 events_tx.unbounded_send(event).ok();
114 }
115 output.await
116 });
117 (output, events_rx)
118 }
119
120 fn overwrite_with_chunks(
121 &self,
122 buffer: Entity<Buffer>,
123 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
124 cx: &mut AsyncApp,
125 ) -> (
126 Task<Result<EditAgentOutput>>,
127 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
128 ) {
129 let (output_events_tx, output_events_rx) = mpsc::unbounded();
130 let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
131 let this = self.clone();
132 let task = cx.spawn(async move |cx| {
133 this.action_log
134 .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
135 this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
136 .await?;
137 parse_task.await
138 });
139 (task, output_events_rx)
140 }
141
142 async fn overwrite_with_chunks_internal(
143 &self,
144 buffer: Entity<Buffer>,
145 mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
146 output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
147 cx: &mut AsyncApp,
148 ) -> Result<()> {
149 cx.update(|cx| {
150 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
151 self.action_log.update(cx, |log, cx| {
152 log.buffer_edited(buffer.clone(), cx);
153 });
154 self.project.update(cx, |project, cx| {
155 project.set_agent_location(
156 Some(AgentLocation {
157 buffer: buffer.downgrade(),
158 position: language::Anchor::MAX,
159 }),
160 cx,
161 )
162 });
163 output_events_tx
164 .unbounded_send(EditAgentOutputEvent::Edited)
165 .ok();
166 })?;
167
168 while let Some(event) = parse_rx.next().await {
169 match event? {
170 CreateFileParserEvent::NewTextChunk { chunk } => {
171 cx.update(|cx| {
172 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
173 self.action_log
174 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
175 self.project.update(cx, |project, cx| {
176 project.set_agent_location(
177 Some(AgentLocation {
178 buffer: buffer.downgrade(),
179 position: language::Anchor::MAX,
180 }),
181 cx,
182 )
183 });
184 })?;
185 output_events_tx
186 .unbounded_send(EditAgentOutputEvent::Edited)
187 .ok();
188 }
189 }
190 }
191
192 Ok(())
193 }
194
195 pub fn edit(
196 &self,
197 buffer: Entity<Buffer>,
198 edit_description: String,
199 conversation: &LanguageModelRequest,
200 cx: &mut AsyncApp,
201 ) -> (
202 Task<Result<EditAgentOutput>>,
203 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
204 ) {
205 let this = self.clone();
206 let (events_tx, events_rx) = mpsc::unbounded();
207 let conversation = conversation.clone();
208 let output = cx.spawn(async move |cx| {
209 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
210 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
211 let prompt = EditFilePromptTemplate {
212 path,
213 edit_description,
214 }
215 .render(&this.templates)?;
216 let edit_chunks = this.request(conversation, prompt, cx).await?;
217 this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
218 .await
219 });
220 (output, events_rx)
221 }
222
223 async fn apply_edit_chunks(
224 &self,
225 buffer: Entity<Buffer>,
226 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
227 output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
228 cx: &mut AsyncApp,
229 ) -> Result<EditAgentOutput> {
230 self.action_log
231 .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;
232
233 let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
234 let mut edit_events = edit_events.peekable();
235 while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
236 // Skip events until we're at the start of a new edit.
237 let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
238 edit_events.next().await.unwrap()?;
239 continue;
240 };
241
242 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
243
244 // Resolve the old text in the background, updating the agent
245 // location as we keep refining which range it corresponds to.
246 let (resolve_old_text, mut old_range) =
247 Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
248 while let Ok(old_range) = old_range.recv().await {
249 if let Some(old_range) = old_range {
250 let old_range = snapshot.anchor_before(old_range.start)
251 ..snapshot.anchor_before(old_range.end);
252 self.project.update(cx, |project, cx| {
253 project.set_agent_location(
254 Some(AgentLocation {
255 buffer: buffer.downgrade(),
256 position: old_range.end,
257 }),
258 cx,
259 );
260 })?;
261 output_events
262 .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
263 .ok();
264 }
265 }
266
267 let (edit_events_, resolved_old_text) = resolve_old_text.await?;
268 edit_events = edit_events_;
269
270 // If we can't resolve the old text, restart the loop waiting for a
271 // new edit (or for the stream to end).
272 let Some(resolved_old_text) = resolved_old_text else {
273 output_events
274 .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
275 .ok();
276 continue;
277 };
278
279 // Compute edits in the background and apply them as they become
280 // available.
281 let (compute_edits, edits) =
282 Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
283 let mut edits = edits.ready_chunks(32);
284 while let Some(edits) = edits.next().await {
285 if edits.is_empty() {
286 continue;
287 }
288
289 // Edit the buffer and report edits to the action log as part of the
290 // same effect cycle, otherwise the edit will be reported as if the
291 // user made it.
292 cx.update(|cx| {
293 let max_edit_end = buffer.update(cx, |buffer, cx| {
294 buffer.edit(edits.iter().cloned(), None, cx);
295 let max_edit_end = buffer
296 .summaries_for_anchors::<Point, _>(
297 edits.iter().map(|(range, _)| &range.end),
298 )
299 .max()
300 .unwrap();
301 buffer.anchor_before(max_edit_end)
302 });
303 self.action_log
304 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
305 self.project.update(cx, |project, cx| {
306 project.set_agent_location(
307 Some(AgentLocation {
308 buffer: buffer.downgrade(),
309 position: max_edit_end,
310 }),
311 cx,
312 );
313 });
314 })?;
315 output_events
316 .unbounded_send(EditAgentOutputEvent::Edited)
317 .ok();
318 }
319
320 edit_events = compute_edits.await?;
321 }
322
323 output.await
324 }
325
326 fn parse_edit_chunks(
327 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
328 cx: &mut AsyncApp,
329 ) -> (
330 Task<Result<EditAgentOutput>>,
331 UnboundedReceiver<Result<EditParserEvent>>,
332 ) {
333 let (tx, rx) = mpsc::unbounded();
334 let output = cx.background_spawn(async move {
335 pin_mut!(chunks);
336
337 let mut parser = EditParser::new();
338 let mut raw_edits = String::new();
339 while let Some(chunk) = chunks.next().await {
340 match chunk {
341 Ok(chunk) => {
342 raw_edits.push_str(&chunk);
343 for event in parser.push(&chunk) {
344 tx.unbounded_send(Ok(event))?;
345 }
346 }
347 Err(error) => {
348 tx.unbounded_send(Err(error.into()))?;
349 }
350 }
351 }
352 Ok(EditAgentOutput {
353 raw_edits,
354 parser_metrics: parser.finish(),
355 })
356 });
357 (output, rx)
358 }
359
360 fn parse_create_file_chunks(
361 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
362 cx: &mut AsyncApp,
363 ) -> (
364 Task<Result<EditAgentOutput>>,
365 UnboundedReceiver<Result<CreateFileParserEvent>>,
366 ) {
367 let (tx, rx) = mpsc::unbounded();
368 let output = cx.background_spawn(async move {
369 pin_mut!(chunks);
370
371 let mut parser = CreateFileParser::new();
372 let mut raw_edits = String::new();
373 while let Some(chunk) = chunks.next().await {
374 match chunk {
375 Ok(chunk) => {
376 raw_edits.push_str(&chunk);
377 for event in parser.push(Some(&chunk)) {
378 tx.unbounded_send(Ok(event))?;
379 }
380 }
381 Err(error) => {
382 tx.unbounded_send(Err(error.into()))?;
383 }
384 }
385 }
386 // Send final events with None to indicate completion
387 for event in parser.push(None) {
388 tx.unbounded_send(Ok(event))?;
389 }
390 Ok(EditAgentOutput {
391 raw_edits,
392 parser_metrics: EditParserMetrics::default(),
393 })
394 });
395 (output, rx)
396 }
397
398 fn resolve_old_text<T>(
399 snapshot: TextBufferSnapshot,
400 mut edit_events: T,
401 cx: &mut AsyncApp,
402 ) -> (
403 Task<Result<(T, Option<ResolvedOldText>)>>,
404 async_watch::Receiver<Option<Range<usize>>>,
405 )
406 where
407 T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
408 {
409 let (old_range_tx, old_range_rx) = async_watch::channel(None);
410 let task = cx.background_spawn(async move {
411 let mut matcher = StreamingFuzzyMatcher::new(snapshot);
412 while let Some(edit_event) = edit_events.next().await {
413 let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else {
414 break;
415 };
416
417 old_range_tx.send(matcher.push(&chunk))?;
418 if done {
419 break;
420 }
421 }
422
423 let old_range = matcher.finish();
424 old_range_tx.send(old_range.clone())?;
425 if let Some(old_range) = old_range {
426 let line_indent =
427 LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
428 Ok((
429 edit_events,
430 Some(ResolvedOldText {
431 range: old_range,
432 indent: line_indent,
433 }),
434 ))
435 } else {
436 Ok((edit_events, None))
437 }
438 });
439
440 (task, old_range_rx)
441 }
442
443 fn compute_edits<T>(
444 snapshot: BufferSnapshot,
445 resolved_old_text: ResolvedOldText,
446 mut edit_events: T,
447 cx: &mut AsyncApp,
448 ) -> (
449 Task<Result<T>>,
450 UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
451 )
452 where
453 T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
454 {
455 let (edits_tx, edits_rx) = mpsc::unbounded();
456 let compute_edits = cx.background_spawn(async move {
457 let buffer_start_indent = snapshot
458 .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
459 let indent_delta = if buffer_start_indent.tabs > 0 {
460 IndentDelta::Tabs(
461 buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
462 )
463 } else {
464 IndentDelta::Spaces(
465 buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
466 )
467 };
468
469 let old_text = snapshot
470 .text_for_range(resolved_old_text.range.clone())
471 .collect::<String>();
472 let mut diff = StreamingDiff::new(old_text);
473 let mut edit_start = resolved_old_text.range.start;
474 let mut new_text_chunks =
475 Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
476 let mut done = false;
477 while !done {
478 let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
479 diff.push_new(&new_text_chunk?)
480 } else {
481 done = true;
482 mem::take(&mut diff).finish()
483 };
484
485 for op in char_operations {
486 match op {
487 CharOperation::Insert { text } => {
488 let edit_start = snapshot.anchor_after(edit_start);
489 edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
490 }
491 CharOperation::Delete { bytes } => {
492 let edit_end = edit_start + bytes;
493 let edit_range =
494 snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
495 edit_start = edit_end;
496 edits_tx.unbounded_send((edit_range, Arc::from("")))?;
497 }
498 CharOperation::Keep { bytes } => edit_start += bytes,
499 }
500 }
501 }
502
503 drop(new_text_chunks);
504 anyhow::Ok(edit_events)
505 });
506
507 (compute_edits, edits_rx)
508 }
509
510 fn reindent_new_text_chunks(
511 delta: IndentDelta,
512 mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
513 ) -> impl Stream<Item = Result<String>> {
514 let mut buffer = String::new();
515 let mut in_leading_whitespace = true;
516 let mut done = false;
517 futures::stream::poll_fn(move |cx| {
518 while !done {
519 let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
520 Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
521 (chunk, done)
522 }
523 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
524 Poll::Pending => return Poll::Pending,
525 _ => return Poll::Ready(None),
526 };
527
528 buffer.push_str(&chunk);
529
530 let mut indented_new_text = String::new();
531 let mut start_ix = 0;
532 let mut newlines = buffer.match_indices('\n').peekable();
533 loop {
534 let (line_end, is_pending_line) = match newlines.next() {
535 Some((ix, _)) => (ix, false),
536 None => (buffer.len(), true),
537 };
538 let line = &buffer[start_ix..line_end];
539
540 if in_leading_whitespace {
541 if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
542 // We found a non-whitespace character, adjust
543 // indentation based on the delta.
544 let new_indent_len =
545 cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
546 indented_new_text
547 .extend(iter::repeat(delta.character()).take(new_indent_len));
548 indented_new_text.push_str(&line[non_whitespace_ix..]);
549 in_leading_whitespace = false;
550 } else if is_pending_line {
551 // We're still in leading whitespace and this line is incomplete.
552 // Stop processing until we receive more input.
553 break;
554 } else {
555 // This line is entirely whitespace. Push it without indentation.
556 indented_new_text.push_str(line);
557 }
558 } else {
559 indented_new_text.push_str(line);
560 }
561
562 if is_pending_line {
563 start_ix = line_end;
564 break;
565 } else {
566 in_leading_whitespace = true;
567 indented_new_text.push('\n');
568 start_ix = line_end + 1;
569 }
570 }
571 buffer.replace_range(..start_ix, "");
572
573 // This was the last chunk, push all the buffered content as-is.
574 if is_last_chunk {
575 indented_new_text.push_str(&buffer);
576 buffer.clear();
577 done = true;
578 }
579
580 if !indented_new_text.is_empty() {
581 return Poll::Ready(Some(Ok(indented_new_text)));
582 }
583 }
584
585 Poll::Ready(None)
586 })
587 }
588
589 async fn request(
590 &self,
591 mut conversation: LanguageModelRequest,
592 prompt: String,
593 cx: &mut AsyncApp,
594 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
595 let mut messages_iter = conversation.messages.iter_mut();
596 if let Some(last_message) = messages_iter.next_back() {
597 if last_message.role == Role::Assistant {
598 let old_content_len = last_message.content.len();
599 last_message
600 .content
601 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
602 let new_content_len = last_message.content.len();
603
604 // We just removed pending tool uses from the content of the
605 // last message, so it doesn't make sense to cache it anymore
606 // (e.g., the message will look very different on the next
607 // request). Thus, we move the flag to the message prior to it,
608 // as it will still be a valid prefix of the conversation.
609 if old_content_len != new_content_len && last_message.cache {
610 if let Some(prev_message) = messages_iter.next_back() {
611 last_message.cache = false;
612 prev_message.cache = true;
613 }
614 }
615
616 if last_message.content.is_empty() {
617 conversation.messages.pop();
618 }
619 } else {
620 debug_panic!(
621 "Last message must be an Assistant tool calling! Got {:?}",
622 last_message.content
623 );
624 }
625 }
626
627 conversation.messages.push(LanguageModelRequestMessage {
628 role: Role::User,
629 content: vec![MessageContent::Text(prompt)],
630 cache: false,
631 });
632
633 // Include tools in the request so that we can take advantage of
634 // caching when ToolChoice::None is supported.
635 let mut tool_choice = None;
636 let mut tools = Vec::new();
637 if !conversation.tools.is_empty()
638 && self
639 .model
640 .supports_tool_choice(LanguageModelToolChoice::None)
641 {
642 tool_choice = Some(LanguageModelToolChoice::None);
643 tools = conversation.tools.clone();
644 }
645
646 let request = LanguageModelRequest {
647 thread_id: conversation.thread_id,
648 prompt_id: conversation.prompt_id,
649 mode: conversation.mode,
650 messages: conversation.messages,
651 tool_choice,
652 tools,
653 stop: Vec::new(),
654 temperature: None,
655 };
656
657 Ok(self.model.stream_completion_text(request, cx).await?.stream)
658 }
659}
660
661struct ResolvedOldText {
662 range: Range<usize>,
663 indent: LineIndent,
664}
665
/// Signed change in indentation to apply to each line of new text, expressed
/// in a single indentation character.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum IndentDelta {
    /// Shift by this many spaces (negative values outdent).
    Spaces(isize),
    /// Shift by this many tabs (negative values outdent).
    Tabs(isize),
}

impl IndentDelta {
    /// The indentation character this delta is expressed in.
    fn character(&self) -> char {
        match self {
            IndentDelta::Spaces(_) => ' ',
            IndentDelta::Tabs(_) => '\t',
        }
    }

    /// The signed number of indentation characters to add.
    fn len(&self) -> isize {
        match self {
            IndentDelta::Spaces(n) | IndentDelta::Tabs(n) => *n,
        }
    }
}
687
688#[cfg(test)]
689mod tests {
690 use super::*;
691 use fs::FakeFs;
692 use futures::stream;
693 use gpui::{AppContext, TestAppContext};
694 use indoc::indoc;
695 use language_model::fake_provider::FakeLanguageModel;
696 use project::{AgentLocation, Project};
697 use rand::prelude::*;
698 use rand::rngs::StdRng;
699 use std::cmp;
700
701 #[gpui::test(iterations = 100)]
702 async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
703 let agent = init_test(cx).await;
704 let buffer = cx.new(|cx| {
705 Buffer::local(
706 indoc! {"
707 abc
708 def
709 ghi
710 "},
711 cx,
712 )
713 });
714 let (apply, _events) = agent.edit(
715 buffer.clone(),
716 String::new(),
717 &LanguageModelRequest::default(),
718 &mut cx.to_async(),
719 );
720 cx.run_until_parked();
721
722 simulate_llm_output(
723 &agent,
724 indoc! {"
725 <old_text></old_text>
726 <new_text>jkl</new_text>
727 <old_text>def</old_text>
728 <new_text>DEF</new_text>
729 "},
730 &mut rng,
731 cx,
732 );
733 apply.await.unwrap();
734
735 pretty_assertions::assert_eq!(
736 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
737 indoc! {"
738 abc
739 DEF
740 ghi
741 "}
742 );
743 }
744
745 #[gpui::test(iterations = 100)]
746 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
747 let agent = init_test(cx).await;
748 let buffer = cx.new(|cx| {
749 Buffer::local(
750 indoc! {"
751 lorem
752 ipsum
753 dolor
754 sit
755 "},
756 cx,
757 )
758 });
759 let (apply, _events) = agent.edit(
760 buffer.clone(),
761 String::new(),
762 &LanguageModelRequest::default(),
763 &mut cx.to_async(),
764 );
765 cx.run_until_parked();
766
767 simulate_llm_output(
768 &agent,
769 indoc! {"
770 <old_text>
771 ipsum
772 dolor
773 sit
774 </old_text>
775 <new_text>
776 ipsum
777 dolor
778 sit
779 amet
780 </new_text>
781 "},
782 &mut rng,
783 cx,
784 );
785 apply.await.unwrap();
786
787 pretty_assertions::assert_eq!(
788 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
789 indoc! {"
790 lorem
791 ipsum
792 dolor
793 sit
794 amet
795 "}
796 );
797 }
798
799 #[gpui::test(iterations = 100)]
800 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
801 let agent = init_test(cx).await;
802 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
803 let (apply, _events) = agent.edit(
804 buffer.clone(),
805 String::new(),
806 &LanguageModelRequest::default(),
807 &mut cx.to_async(),
808 );
809 cx.run_until_parked();
810
811 simulate_llm_output(
812 &agent,
813 indoc! {"
814 <old_text>
815 def
816 </old_text>
817 <new_text>
818 DEF
819 </new_text>
820
821 <old_text>
822 DEF
823 </old_text>
824 <new_text>
825 DeF
826 </new_text>
827 "},
828 &mut rng,
829 cx,
830 );
831 apply.await.unwrap();
832
833 assert_eq!(
834 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
835 "abc\nDeF\nghi"
836 );
837 }
838
839 #[gpui::test(iterations = 100)]
840 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
841 let agent = init_test(cx).await;
842 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
843 let (apply, _events) = agent.edit(
844 buffer.clone(),
845 String::new(),
846 &LanguageModelRequest::default(),
847 &mut cx.to_async(),
848 );
849 cx.run_until_parked();
850
851 simulate_llm_output(
852 &agent,
853 indoc! {"
854 <old_text>
855 jkl
856 </old_text>
857 <new_text>
858 mno
859 </new_text>
860
861 <old_text>
862 abc
863 </old_text>
864 <new_text>
865 ABC
866 </new_text>
867 "},
868 &mut rng,
869 cx,
870 );
871 apply.await.unwrap();
872
873 assert_eq!(
874 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
875 "ABC\ndef\nghi"
876 );
877 }
878
879 #[gpui::test]
880 async fn test_edit_events(cx: &mut TestAppContext) {
881 let agent = init_test(cx).await;
882 let model = agent.model.as_fake();
883 let project = agent
884 .action_log
885 .read_with(cx, |log, _| log.project().clone());
886 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
887
888 let mut async_cx = cx.to_async();
889 let (apply, mut events) = agent.edit(
890 buffer.clone(),
891 String::new(),
892 &LanguageModelRequest::default(),
893 &mut async_cx,
894 );
895 cx.run_until_parked();
896
897 model.stream_last_completion_response("<old_text>a");
898 cx.run_until_parked();
899 assert_eq!(drain_events(&mut events), vec![]);
900 assert_eq!(
901 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
902 "abc\ndef\nghi\njkl"
903 );
904 assert_eq!(
905 project.read_with(cx, |project, _| project.agent_location()),
906 None
907 );
908
909 model.stream_last_completion_response("bc</old_text>");
910 cx.run_until_parked();
911 assert_eq!(
912 drain_events(&mut events),
913 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
914 cx,
915 |buffer, _| buffer.anchor_before(Point::new(0, 0))
916 ..buffer.anchor_before(Point::new(0, 3))
917 ))]
918 );
919 assert_eq!(
920 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
921 "abc\ndef\nghi\njkl"
922 );
923 assert_eq!(
924 project.read_with(cx, |project, _| project.agent_location()),
925 Some(AgentLocation {
926 buffer: buffer.downgrade(),
927 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
928 })
929 );
930
931 model.stream_last_completion_response("<new_text>abX");
932 cx.run_until_parked();
933 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
934 assert_eq!(
935 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
936 "abXc\ndef\nghi\njkl"
937 );
938 assert_eq!(
939 project.read_with(cx, |project, _| project.agent_location()),
940 Some(AgentLocation {
941 buffer: buffer.downgrade(),
942 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
943 })
944 );
945
946 model.stream_last_completion_response("cY");
947 cx.run_until_parked();
948 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
949 assert_eq!(
950 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
951 "abXcY\ndef\nghi\njkl"
952 );
953 assert_eq!(
954 project.read_with(cx, |project, _| project.agent_location()),
955 Some(AgentLocation {
956 buffer: buffer.downgrade(),
957 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
958 })
959 );
960
961 model.stream_last_completion_response("</new_text>");
962 model.stream_last_completion_response("<old_text>hall");
963 cx.run_until_parked();
964 assert_eq!(drain_events(&mut events), vec![]);
965 assert_eq!(
966 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
967 "abXcY\ndef\nghi\njkl"
968 );
969 assert_eq!(
970 project.read_with(cx, |project, _| project.agent_location()),
971 Some(AgentLocation {
972 buffer: buffer.downgrade(),
973 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
974 })
975 );
976
977 model.stream_last_completion_response("ucinated old</old_text>");
978 model.stream_last_completion_response("<new_text>");
979 cx.run_until_parked();
980 assert_eq!(
981 drain_events(&mut events),
982 vec![EditAgentOutputEvent::UnresolvedEditRange]
983 );
984 assert_eq!(
985 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
986 "abXcY\ndef\nghi\njkl"
987 );
988 assert_eq!(
989 project.read_with(cx, |project, _| project.agent_location()),
990 Some(AgentLocation {
991 buffer: buffer.downgrade(),
992 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
993 })
994 );
995
996 model.stream_last_completion_response("hallucinated new</new_");
997 model.stream_last_completion_response("text>");
998 cx.run_until_parked();
999 assert_eq!(drain_events(&mut events), vec![]);
1000 assert_eq!(
1001 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1002 "abXcY\ndef\nghi\njkl"
1003 );
1004 assert_eq!(
1005 project.read_with(cx, |project, _| project.agent_location()),
1006 Some(AgentLocation {
1007 buffer: buffer.downgrade(),
1008 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1009 })
1010 );
1011
1012 model.stream_last_completion_response("<old_text>\nghi\nj");
1013 cx.run_until_parked();
1014 assert_eq!(
1015 drain_events(&mut events),
1016 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1017 cx,
1018 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1019 ..buffer.anchor_before(Point::new(2, 3))
1020 ))]
1021 );
1022 assert_eq!(
1023 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1024 "abXcY\ndef\nghi\njkl"
1025 );
1026 assert_eq!(
1027 project.read_with(cx, |project, _| project.agent_location()),
1028 Some(AgentLocation {
1029 buffer: buffer.downgrade(),
1030 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1031 })
1032 );
1033
1034 model.stream_last_completion_response("kl</old_text>");
1035 model.stream_last_completion_response("<new_text>");
1036 cx.run_until_parked();
1037 assert_eq!(
1038 drain_events(&mut events),
1039 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1040 cx,
1041 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1042 ..buffer.anchor_before(Point::new(3, 3))
1043 ))]
1044 );
1045 assert_eq!(
1046 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1047 "abXcY\ndef\nghi\njkl"
1048 );
1049 assert_eq!(
1050 project.read_with(cx, |project, _| project.agent_location()),
1051 Some(AgentLocation {
1052 buffer: buffer.downgrade(),
1053 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1054 })
1055 );
1056
1057 model.stream_last_completion_response("GHI</new_text>");
1058 cx.run_until_parked();
1059 assert_eq!(
1060 drain_events(&mut events),
1061 vec![EditAgentOutputEvent::Edited]
1062 );
1063 assert_eq!(
1064 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1065 "abXcY\ndef\nGHI"
1066 );
1067 assert_eq!(
1068 project.read_with(cx, |project, _| project.agent_location()),
1069 Some(AgentLocation {
1070 buffer: buffer.downgrade(),
1071 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1072 })
1073 );
1074
1075 model.end_last_completion_stream();
1076 apply.await.unwrap();
1077 assert_eq!(
1078 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1079 "abXcY\ndef\nGHI"
1080 );
1081 assert_eq!(drain_events(&mut events), vec![]);
1082 assert_eq!(
1083 project.read_with(cx, |project, _| project.agent_location()),
1084 Some(AgentLocation {
1085 buffer: buffer.downgrade(),
1086 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1087 })
1088 );
1089 }
1090
1091 #[gpui::test]
1092 async fn test_overwrite_events(cx: &mut TestAppContext) {
1093 let agent = init_test(cx).await;
1094 let project = agent
1095 .action_log
1096 .read_with(cx, |log, _| log.project().clone());
1097 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1098 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1099 let (apply, mut events) = agent.overwrite_with_chunks(
1100 buffer.clone(),
1101 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1102 &mut cx.to_async(),
1103 );
1104
1105 cx.run_until_parked();
1106 assert_eq!(
1107 drain_events(&mut events),
1108 vec![EditAgentOutputEvent::Edited]
1109 );
1110 assert_eq!(
1111 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1112 ""
1113 );
1114 assert_eq!(
1115 project.read_with(cx, |project, _| project.agent_location()),
1116 Some(AgentLocation {
1117 buffer: buffer.downgrade(),
1118 position: language::Anchor::MAX
1119 })
1120 );
1121
1122 chunks_tx.unbounded_send("```\njkl\n").unwrap();
1123 cx.run_until_parked();
1124 assert_eq!(
1125 drain_events(&mut events),
1126 vec![EditAgentOutputEvent::Edited]
1127 );
1128 assert_eq!(
1129 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1130 "jkl"
1131 );
1132 assert_eq!(
1133 project.read_with(cx, |project, _| project.agent_location()),
1134 Some(AgentLocation {
1135 buffer: buffer.downgrade(),
1136 position: language::Anchor::MAX
1137 })
1138 );
1139
1140 chunks_tx.unbounded_send("mno\n").unwrap();
1141 cx.run_until_parked();
1142 assert_eq!(
1143 drain_events(&mut events),
1144 vec![EditAgentOutputEvent::Edited]
1145 );
1146 assert_eq!(
1147 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1148 "jkl\nmno"
1149 );
1150 assert_eq!(
1151 project.read_with(cx, |project, _| project.agent_location()),
1152 Some(AgentLocation {
1153 buffer: buffer.downgrade(),
1154 position: language::Anchor::MAX
1155 })
1156 );
1157
1158 chunks_tx.unbounded_send("pqr\n```").unwrap();
1159 cx.run_until_parked();
1160 assert_eq!(
1161 drain_events(&mut events),
1162 vec![EditAgentOutputEvent::Edited]
1163 );
1164 assert_eq!(
1165 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1166 "jkl\nmno\npqr"
1167 );
1168 assert_eq!(
1169 project.read_with(cx, |project, _| project.agent_location()),
1170 Some(AgentLocation {
1171 buffer: buffer.downgrade(),
1172 position: language::Anchor::MAX
1173 })
1174 );
1175
1176 drop(chunks_tx);
1177 apply.await.unwrap();
1178 assert_eq!(
1179 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1180 "jkl\nmno\npqr"
1181 );
1182 assert_eq!(drain_events(&mut events), vec![]);
1183 assert_eq!(
1184 project.read_with(cx, |project, _| project.agent_location()),
1185 Some(AgentLocation {
1186 buffer: buffer.downgrade(),
1187 position: language::Anchor::MAX
1188 })
1189 );
1190 }
1191
1192 #[gpui::test(iterations = 100)]
1193 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1194 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1195 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1196 Ok(EditParserEvent::NewTextChunk {
1197 chunk: chunk.clone(),
1198 done: index == chunks.len() - 1,
1199 })
1200 }));
1201 let indented_chunks =
1202 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1203 .collect::<Vec<_>>()
1204 .await;
1205 let new_text = indented_chunks
1206 .into_iter()
1207 .collect::<Result<String>>()
1208 .unwrap();
1209 assert_eq!(new_text, " abc\n def\n ghi");
1210 }
1211
1212 #[gpui::test(iterations = 100)]
1213 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1214 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1215 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1216 Ok(EditParserEvent::NewTextChunk {
1217 chunk: chunk.clone(),
1218 done: index == chunks.len() - 1,
1219 })
1220 }));
1221 let indented_chunks =
1222 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1223 .collect::<Vec<_>>()
1224 .await;
1225 let new_text = indented_chunks
1226 .into_iter()
1227 .collect::<Result<String>>()
1228 .unwrap();
1229 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1230 }
1231
1232 #[gpui::test(iterations = 100)]
1233 async fn test_random_indents(mut rng: StdRng) {
1234 let len = rng.gen_range(1..=100);
1235 let new_text = util::RandomCharIter::new(&mut rng)
1236 .with_simple_text()
1237 .take(len)
1238 .collect::<String>();
1239 let new_text = new_text
1240 .split('\n')
1241 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1242 .collect::<Vec<_>>()
1243 .join("\n");
1244 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1245
1246 let chunks = to_random_chunks(&mut rng, &new_text);
1247 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1248 Ok(EditParserEvent::NewTextChunk {
1249 chunk: chunk.clone(),
1250 done: index == chunks.len() - 1,
1251 })
1252 }));
1253 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1254 .collect::<Vec<_>>()
1255 .await;
1256 let actual_reindented_text = reindented_chunks
1257 .into_iter()
1258 .collect::<Result<String>>()
1259 .unwrap();
1260 let expected_reindented_text = new_text
1261 .split('\n')
1262 .map(|line| {
1263 if let Some(ix) = line.find(|c| c != ' ') {
1264 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1265 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1266 } else {
1267 line.to_string()
1268 }
1269 })
1270 .collect::<Vec<_>>()
1271 .join("\n");
1272 assert_eq!(actual_reindented_text, expected_reindented_text);
1273 }
1274
1275 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1276 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1277 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1278 chunk_indices.sort();
1279 chunk_indices.push(input.len());
1280
1281 let mut chunks = Vec::new();
1282 let mut last_ix = 0;
1283 for chunk_ix in chunk_indices {
1284 chunks.push(input[last_ix..chunk_ix].to_string());
1285 last_ix = chunk_ix;
1286 }
1287 chunks
1288 }
1289
1290 fn simulate_llm_output(
1291 agent: &EditAgent,
1292 output: &str,
1293 rng: &mut StdRng,
1294 cx: &mut TestAppContext,
1295 ) {
1296 let executor = cx.executor();
1297 let chunks = to_random_chunks(rng, output);
1298 let model = agent.model.clone();
1299 cx.background_spawn(async move {
1300 for chunk in chunks {
1301 executor.simulate_random_delay().await;
1302 model.as_fake().stream_last_completion_response(chunk);
1303 }
1304 model.as_fake().end_last_completion_stream();
1305 })
1306 .detach();
1307 }
1308
1309 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1310 cx.update(settings::init);
1311 cx.update(Project::init_settings);
1312 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1313 let model = Arc::new(FakeLanguageModel::default());
1314 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1315 EditAgent::new(model, project, action_log, Templates::new())
1316 }
1317
1318 fn drain_events(
1319 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1320 ) -> Vec<EditAgentOutputEvent> {
1321 let mut events = Vec::new();
1322 while let Ok(Some(event)) = stream.try_next() {
1323 events.push(event);
1324 }
1325 events
1326 }
1327}