1mod create_file_parser;
2mod edit_parser;
3#[cfg(test)]
4mod evals;
5mod streaming_fuzzy_matcher;
6
7use crate::{Template, Templates};
8use anyhow::Result;
9use assistant_tool::ActionLog;
10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
12use futures::{
13 Stream, StreamExt,
14 channel::mpsc::{self, UnboundedReceiver},
15 pin_mut,
16 stream::BoxStream,
17};
18use gpui::{AppContext, AsyncApp, Entity, Task};
19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
20use language_model::{
21 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
22 LanguageModelToolChoice, MessageContent, Role,
23};
24use project::{AgentLocation, Project};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
28use streaming_diff::{CharOperation, StreamingDiff};
29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
30use util::debug_panic;
31use zed_llm_client::CompletionIntent;
32
33#[derive(Serialize)]
34struct CreateFilePromptTemplate {
35 path: Option<PathBuf>,
36 edit_description: String,
37}
38
39impl Template for CreateFilePromptTemplate {
40 const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
41}
42
43#[derive(Serialize)]
44struct EditFilePromptTemplate {
45 path: Option<PathBuf>,
46 edit_description: String,
47}
48
49impl Template for EditFilePromptTemplate {
50 const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
51}
52
53#[derive(Clone, Debug, PartialEq, Eq)]
54pub enum EditAgentOutputEvent {
55 ResolvingEditRange(Range<Anchor>),
56 UnresolvedEditRange,
57 Edited,
58}
59
60#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
61pub struct EditAgentOutput {
62 pub raw_edits: String,
63 pub parser_metrics: EditParserMetrics,
64}
65
66#[derive(Clone)]
67pub struct EditAgent {
68 model: Arc<dyn LanguageModel>,
69 action_log: Entity<ActionLog>,
70 project: Entity<Project>,
71 templates: Arc<Templates>,
72}
73
impl EditAgent {
    /// Creates an edit agent.
    ///
    /// - `model` generates the completions (see [`Self::request`]).
    /// - `project` receives agent-location updates while edits stream in.
    /// - `action_log` is told when buffers are read, created, or edited.
    /// - `templates` renders the create-file and edit-file prompts.
    pub fn new(
        model: Arc<dyn LanguageModel>,
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        templates: Arc<Templates>,
    ) -> Self {
        EditAgent {
            model,
            project,
            action_log,
            templates,
        }
    }

    /// Replaces the entire content of `buffer` with text generated by the model.
    ///
    /// The prompt is rendered from [`CreateFilePromptTemplate`] using the
    /// buffer's resolved file path and `edit_description`. It is appended to a
    /// copy of `conversation` and sent with [`CompletionIntent::CreateFile`].
    ///
    /// Returns a task that resolves to the [`EditAgentOutput`] once the buffer
    /// has been rewritten. Also returns a receiver of the
    /// [`EditAgentOutputEvent`]s produced while writing.
    pub fn overwrite(
        &self,
        buffer: Entity<Buffer>,
        edit_description: String,
        conversation: &LanguageModelRequest,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let this = self.clone();
        let (events_tx, events_rx) = mpsc::unbounded();
        let conversation = conversation.clone();
        let output = cx.spawn(async move |cx| {
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
            let prompt = CreateFilePromptTemplate {
                path,
                edit_description,
            }
            .render(&this.templates)?;
            let new_chunks = this
                .request(conversation, CompletionIntent::CreateFile, prompt, cx)
                .await?;

            let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
            // Relay the inner events to the caller's receiver until the inner
            // channel closes. Send failures (caller dropped the receiver) are ignored.
            while let Some(event) = inner_events.next().await {
                events_tx.unbounded_send(event).ok();
            }
            output.await
        });
        (output, events_rx)
    }

    /// Rewrites `buffer` from a raw stream of model output chunks.
    ///
    /// The chunks are parsed on a background task by [`CreateFileParser`].
    /// Before any text is written, the buffer is reported to the action log as
    /// created. The returned task first writes all parsed text into the buffer,
    /// then resolves to the parser task's output.
    fn overwrite_with_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let (output_events_tx, output_events_rx) = mpsc::unbounded();
        let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
        let this = self.clone();
        let task = cx.spawn(async move |cx| {
            this.action_log
                .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
            this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
                .await?;
            parse_task.await
        });
        (task, output_events_rx)
    }

    /// Clears `buffer`, then appends each text chunk received on `parse_rx`.
    ///
    /// The same three things happen after the initial clear and after every
    /// appended chunk:
    /// - the edit is reported to the action log;
    /// - the project's agent location moves to the end of the buffer
    ///   (`Anchor::MAX`);
    /// - an [`EditAgentOutputEvent::Edited`] event is sent.
    ///
    /// Failures to send events (receiver dropped) are ignored.
    ///
    /// # Errors
    ///
    /// Returns the first error received from `parse_rx`, or any error returned
    /// by `cx.update`.
    async fn overwrite_with_chunks_internal(
        &self,
        buffer: Entity<Buffer>,
        mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
        output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        cx.update(|cx| {
            buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
            self.action_log.update(cx, |log, cx| {
                log.buffer_edited(buffer.clone(), cx);
            });
            self.project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: language::Anchor::MAX,
                    }),
                    cx,
                )
            });
            output_events_tx
                .unbounded_send(EditAgentOutputEvent::Edited)
                .ok();
        })?;

        while let Some(event) = parse_rx.next().await {
            match event? {
                CreateFileParserEvent::NewTextChunk { chunk } => {
                    // Append and report within one update so the action log
                    // attributes the change to the agent.
                    cx.update(|cx| {
                        buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
                        self.action_log
                            .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                        self.project.update(cx, |project, cx| {
                            project.set_agent_location(
                                Some(AgentLocation {
                                    buffer: buffer.downgrade(),
                                    position: language::Anchor::MAX,
                                }),
                                cx,
                            )
                        });
                    })?;
                    output_events_tx
                        .unbounded_send(EditAgentOutputEvent::Edited)
                        .ok();
                }
            }
        }

        Ok(())
    }

    /// Applies model-generated edits to the existing content of `buffer`.
    ///
    /// The prompt is rendered from [`EditFilePromptTemplate`] using the
    /// buffer's resolved file path and `edit_description`. It is appended to a
    /// copy of `conversation` and sent with [`CompletionIntent::EditFile`].
    /// The streamed response is applied by [`Self::apply_edit_chunks`].
    ///
    /// Returns the task that performs the edits together with a receiver of
    /// the [`EditAgentOutputEvent`]s it emits.
    pub fn edit(
        &self,
        buffer: Entity<Buffer>,
        edit_description: String,
        conversation: &LanguageModelRequest,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let this = self.clone();
        let (events_tx, events_rx) = mpsc::unbounded();
        let conversation = conversation.clone();
        let output = cx.spawn(async move |cx| {
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
            let prompt = EditFilePromptTemplate {
                path,
                edit_description,
            }
            .render(&this.templates)?;
            let edit_chunks = this
                .request(conversation, CompletionIntent::EditFile, prompt, cx)
                .await?;
            this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
                .await
        });
        (output, events_rx)
    }

    /// Parses `edit_chunks` into old-text/new-text pairs and applies each pair
    /// to `buffer` while the completion is still streaming.
    ///
    /// The buffer is first reported to the action log as read. Then, for each edit:
    /// 1. The old text is fuzzy-matched against a snapshot of the buffer.
    ///    Each intermediate candidate range becomes the agent location and is
    ///    sent as [`EditAgentOutputEvent::ResolvingEditRange`].
    /// 2. If no range is found, [`EditAgentOutputEvent::UnresolvedEditRange`]
    ///    is sent and the edit is skipped.
    /// 3. Otherwise the new text is diffed against the matched text. The
    ///    resulting edits are applied in batches of up to 32, and each batch
    ///    is followed by [`EditAgentOutputEvent::Edited`].
    ///
    /// Returns the parser output once the stream is exhausted.
    ///
    /// # Errors
    ///
    /// Propagates errors from the completion stream, from the background
    /// tasks, and from the app context updates.
    async fn apply_edit_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        self.action_log
            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;

        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        let mut edit_events = edit_events.peekable();
        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
            // Skip events until we're at the start of a new edit. `peek` just
            // yielded an item, so the `next()` below cannot return `None`.
            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
                edit_events.next().await.unwrap()?;
                continue;
            };

            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;

            // Resolve the old text in the background, updating the agent
            // location as we keep refining which range it corresponds to.
            // The loop below ends once the resolver task drops its sender.
            let (resolve_old_text, mut old_range) =
                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
            while let Ok(old_range) = old_range.recv().await {
                if let Some(old_range) = old_range {
                    let old_range = snapshot.anchor_before(old_range.start)
                        ..snapshot.anchor_before(old_range.end);
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: old_range.end,
                            }),
                            cx,
                        );
                    })?;
                    output_events
                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
                        .ok();
                }
            }

            // Take the event stream back from the resolver task.
            let (edit_events_, resolved_old_text) = resolve_old_text.await?;
            edit_events = edit_events_;

            // If we can't resolve the old text, restart the loop waiting for a
            // new edit (or for the stream to end).
            let Some(resolved_old_text) = resolved_old_text else {
                output_events
                    .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
                    .ok();
                continue;
            };

            // Compute edits in the background and apply them as they become
            // available, batching up to 32 already-available edits at a time.
            let (compute_edits, edits) =
                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
            let mut edits = edits.ready_chunks(32);
            while let Some(edits) = edits.next().await {
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    let max_edit_end = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        // `edits` is non-empty here, so `max` always yields a value.
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        buffer.anchor_before(max_edit_end)
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Take the event stream back for the next edit.
            edit_events = compute_edits.await?;
        }

        output.await
    }

    /// Feeds `chunks` into an [`EditParser`] on a background task.
    ///
    /// Every parser event is sent to the returned receiver as `Ok`. Errors from
    /// the completion stream are forwarded as `Err`, and reading continues with
    /// the next chunk. The task resolves to the concatenation of all successful
    /// chunks plus the parser's metrics. It fails if the receiver is dropped
    /// before it finishes sending.
    fn parse_edit_chunks(
        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        UnboundedReceiver<Result<EditParserEvent>>,
    ) {
        let (tx, rx) = mpsc::unbounded();
        let output = cx.background_spawn(async move {
            pin_mut!(chunks);

            let mut parser = EditParser::new();
            let mut raw_edits = String::new();
            while let Some(chunk) = chunks.next().await {
                match chunk {
                    Ok(chunk) => {
                        raw_edits.push_str(&chunk);
                        for event in parser.push(&chunk) {
                            tx.unbounded_send(Ok(event))?;
                        }
                    }
                    Err(error) => {
                        tx.unbounded_send(Err(error.into()))?;
                    }
                }
            }
            Ok(EditAgentOutput {
                raw_edits,
                parser_metrics: parser.finish(),
            })
        });
        (output, rx)
    }

    /// Feeds `chunks` into a [`CreateFileParser`] on a background task.
    ///
    /// This behaves like [`Self::parse_edit_chunks`] with two differences.
    /// After the stream ends, the parser is pushed `None` so it can emit any
    /// remaining text. No parser metrics are collected, so
    /// `EditParserMetrics::default()` is reported.
    fn parse_create_file_chunks(
        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        UnboundedReceiver<Result<CreateFileParserEvent>>,
    ) {
        let (tx, rx) = mpsc::unbounded();
        let output = cx.background_spawn(async move {
            pin_mut!(chunks);

            let mut parser = CreateFileParser::new();
            let mut raw_edits = String::new();
            while let Some(chunk) = chunks.next().await {
                match chunk {
                    Ok(chunk) => {
                        raw_edits.push_str(&chunk);
                        for event in parser.push(Some(&chunk)) {
                            tx.unbounded_send(Ok(event))?;
                        }
                    }
                    Err(error) => {
                        tx.unbounded_send(Err(error.into()))?;
                    }
                }
            }
            // Send final events with None to indicate completion
            for event in parser.push(None) {
                tx.unbounded_send(Ok(event))?;
            }
            Ok(EditAgentOutput {
                raw_edits,
                parser_metrics: EditParserMetrics::default(),
            })
        });
        (output, rx)
    }

    /// Locates the current edit's old text inside `snapshot` on a background task.
    ///
    /// `OldTextChunk` events are pulled from `edit_events` and pushed into a
    /// [`StreamingFuzzyMatcher`] until a chunk marked `done` arrives. The
    /// returned watch receiver observes the matcher's current candidate range
    /// after every chunk, followed by the final range from `finish()`.
    ///
    /// The task hands `edit_events` back, paired with either the resolved range
    /// and the indentation of the first old-text line, or `None` when no match
    /// was found.
    ///
    /// NOTE(review): if any event other than `OldTextChunk` arrives before
    /// `done`, the loop stops. That event has already been taken from
    /// `edit_events`, so later consumers never see it.
    fn resolve_old_text<T>(
        snapshot: TextBufferSnapshot,
        mut edit_events: T,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<(T, Option<ResolvedOldText>)>>,
        async_watch::Receiver<Option<Range<usize>>>,
    )
    where
        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
    {
        let (old_range_tx, old_range_rx) = async_watch::channel(None);
        let task = cx.background_spawn(async move {
            let mut matcher = StreamingFuzzyMatcher::new(snapshot);
            while let Some(edit_event) = edit_events.next().await {
                let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else {
                    break;
                };

                old_range_tx.send(matcher.push(&chunk))?;
                if done {
                    break;
                }
            }

            let old_range = matcher.finish();
            old_range_tx.send(old_range.clone())?;
            if let Some(old_range) = old_range {
                // NOTE(review): assumes the matcher only resolves a range after at
                // least one query line was pushed; otherwise this `unwrap` panics.
                // Confirm against `StreamingFuzzyMatcher`.
                let line_indent =
                    LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
                Ok((
                    edit_events,
                    Some(ResolvedOldText {
                        range: old_range,
                        indent: line_indent,
                    }),
                ))
            } else {
                Ok((edit_events, None))
            }
        });

        (task, old_range_rx)
    }

    /// Streams the buffer edits that turn the resolved old text into the new
    /// text. The edits are computed on a background task.
    ///
    /// New-text chunks are first re-indented by the difference between two
    /// indentations:
    /// - the indentation of the buffer row where the match starts, and
    /// - the indentation of the first old-text line.
    ///
    /// The delta is counted in tabs if that buffer row has leading tabs, and in
    /// spaces otherwise. The re-indented text is then diffed against the
    /// matched text with [`StreamingDiff`]. Each insertion or deletion is sent
    /// on the returned receiver as a range of anchors into `snapshot`. The task
    /// hands `edit_events` back once the new text is exhausted.
    fn compute_edits<T>(
        snapshot: BufferSnapshot,
        resolved_old_text: ResolvedOldText,
        mut edit_events: T,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<T>>,
        UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
    )
    where
        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
    {
        let (edits_tx, edits_rx) = mpsc::unbounded();
        let compute_edits = cx.background_spawn(async move {
            let buffer_start_indent = snapshot
                .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
            let indent_delta = if buffer_start_indent.tabs > 0 {
                IndentDelta::Tabs(
                    buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
                )
            } else {
                IndentDelta::Spaces(
                    buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
                )
            };

            let old_text = snapshot
                .text_for_range(resolved_old_text.range.clone())
                .collect::<String>();
            let mut diff = StreamingDiff::new(old_text);
            // Offset into `snapshot` of the next old-text character the diff
            // refers to. Only `Keep` and `Delete` advance it.
            let mut edit_start = resolved_old_text.range.start;
            let mut new_text_chunks =
                Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
            let mut done = false;
            while !done {
                let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
                    diff.push_new(&new_text_chunk?)
                } else {
                    // New text is exhausted: flush the remaining diff operations.
                    done = true;
                    mem::take(&mut diff).finish()
                };

                // Edits are expressed as anchors into `snapshot`, not offsets.
                // Presumably this keeps them valid while earlier batches are
                // applied to the live buffer.
                for op in char_operations {
                    match op {
                        CharOperation::Insert { text } => {
                            let edit_start = snapshot.anchor_after(edit_start);
                            edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
                        }
                        CharOperation::Delete { bytes } => {
                            let edit_end = edit_start + bytes;
                            let edit_range =
                                snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
                            edit_start = edit_end;
                            edits_tx.unbounded_send((edit_range, Arc::from("")))?;
                        }
                        CharOperation::Keep { bytes } => edit_start += bytes,
                    }
                }
            }

            // Release the adapter, which mutably borrows `edit_events`, so the
            // stream can be handed back.
            drop(new_text_chunks);
            anyhow::Ok(edit_events)
        });

        (compute_edits, edits_rx)
    }

    /// Adapts a stream of `NewTextChunk` events into re-indented text.
    ///
    /// On each line, the leading run of `delta.character()` is lengthened or
    /// shortened by `delta.len()`, never below zero. Complete lines that are
    /// empty or made only of that character pass through unchanged. A trailing
    /// partial line that is still all indentation is held back until more input
    /// arrives. When the chunk marked `done` arrives, that held-back text is
    /// emitted unchanged.
    ///
    /// The output ends in any of three cases:
    /// - after the `done` chunk;
    /// - when `stream` ends;
    /// - when `stream` yields any other event, which is consumed.
    ///
    /// Errors from `stream` are forwarded.
    ///
    /// NOTE(review): if `stream` ends before a `done` chunk, any held-back
    /// indentation is discarded.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted: the held-back partial line.
        let mut buffer = String::new();
        let mut in_leading_whitespace = true;
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                // Inside the first arm, `done` is the chunk's own flag; it
                // shadows the adapter's `done` state.
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found the first character that isn't the indent
                            // character; adjust indentation based on the delta.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input.
                            break;
                        } else {
                            // This line is empty or consists only of the indent
                            // character. Push it without reindenting.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        start_ix = line_end;
                        break;
                    } else {
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                // Drop everything that was emitted; keep only the held-back tail.
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }

    /// Sends `prompt` to the model as a follow-up to `conversation` and returns
    /// the streamed completion text.
    ///
    /// The last message of `conversation` is expected to be an assistant
    /// message; otherwise `debug_panic!` is invoked. That message is adjusted
    /// as follows:
    /// - its `ToolUse` content is removed;
    /// - if that changed the message and it was marked for caching, the cache
    ///   flag moves to the previous message (when one exists);
    /// - the message is removed entirely if nothing remains.
    ///
    /// `prompt` is then appended as a user message. The conversation's tools are
    /// included with `ToolChoice::None` only when the model supports that choice.
    async fn request(
        &self,
        mut conversation: LanguageModelRequest,
        intent: CompletionIntent,
        prompt: String,
        cx: &mut AsyncApp,
    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
        let mut messages_iter = conversation.messages.iter_mut();
        if let Some(last_message) = messages_iter.next_back() {
            if last_message.role == Role::Assistant {
                let old_content_len = last_message.content.len();
                last_message
                    .content
                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
                let new_content_len = last_message.content.len();

                // We just removed pending tool uses from the content of the
                // last message, so it doesn't make sense to cache it anymore
                // (e.g., the message will look very different on the next
                // request). Thus, we move the flag to the message prior to it,
                // as it will still be a valid prefix of the conversation.
                if old_content_len != new_content_len && last_message.cache {
                    if let Some(prev_message) = messages_iter.next_back() {
                        last_message.cache = false;
                        prev_message.cache = true;
                    }
                }

                if last_message.content.is_empty() {
                    conversation.messages.pop();
                }
            } else {
                debug_panic!(
                    "Last message must be an Assistant tool calling! Got {:?}",
                    last_message.content
                );
            }
        }

        conversation.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(prompt)],
            cache: false,
        });

        // Include tools in the request so that we can take advantage of
        // caching when ToolChoice::None is supported.
        let mut tool_choice = None;
        let mut tools = Vec::new();
        if !conversation.tools.is_empty()
            && self
                .model
                .supports_tool_choice(LanguageModelToolChoice::None)
        {
            tool_choice = Some(LanguageModelToolChoice::None);
            tools = conversation.tools.clone();
        }

        let request = LanguageModelRequest {
            thread_id: conversation.thread_id,
            prompt_id: conversation.prompt_id,
            intent: Some(intent),
            mode: conversation.mode,
            messages: conversation.messages,
            tool_choice,
            tools,
            stop: Vec::new(),
            temperature: None,
        };

        Ok(self.model.stream_completion_text(request, cx).await?.stream)
    }
}
667
668struct ResolvedOldText {
669 range: Range<usize>,
670 indent: LineIndent,
671}
672
/// A signed change to a line's indentation, counted in one kind of indent
/// character.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    /// Add (positive) or remove (negative) this many spaces.
    Spaces(isize),
    /// Add (positive) or remove (negative) this many tabs.
    Tabs(isize),
}

impl IndentDelta {
    /// The indent character this delta is measured in.
    fn character(&self) -> char {
        if matches!(self, IndentDelta::Tabs(_)) {
            '\t'
        } else {
            ' '
        }
    }

    /// The signed number of indent characters to add.
    fn len(&self) -> isize {
        let (IndentDelta::Spaces(count) | IndentDelta::Tabs(count)) = self;
        *count
    }
}
694
695#[cfg(test)]
696mod tests {
697 use super::*;
698 use fs::FakeFs;
699 use futures::stream;
700 use gpui::{AppContext, TestAppContext};
701 use indoc::indoc;
702 use language_model::fake_provider::FakeLanguageModel;
703 use project::{AgentLocation, Project};
704 use rand::prelude::*;
705 use rand::rngs::StdRng;
706 use std::cmp;
707
708 #[gpui::test(iterations = 100)]
709 async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
710 let agent = init_test(cx).await;
711 let buffer = cx.new(|cx| {
712 Buffer::local(
713 indoc! {"
714 abc
715 def
716 ghi
717 "},
718 cx,
719 )
720 });
721 let (apply, _events) = agent.edit(
722 buffer.clone(),
723 String::new(),
724 &LanguageModelRequest::default(),
725 &mut cx.to_async(),
726 );
727 cx.run_until_parked();
728
729 simulate_llm_output(
730 &agent,
731 indoc! {"
732 <old_text></old_text>
733 <new_text>jkl</new_text>
734 <old_text>def</old_text>
735 <new_text>DEF</new_text>
736 "},
737 &mut rng,
738 cx,
739 );
740 apply.await.unwrap();
741
742 pretty_assertions::assert_eq!(
743 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
744 indoc! {"
745 abc
746 DEF
747 ghi
748 "}
749 );
750 }
751
752 #[gpui::test(iterations = 100)]
753 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
754 let agent = init_test(cx).await;
755 let buffer = cx.new(|cx| {
756 Buffer::local(
757 indoc! {"
758 lorem
759 ipsum
760 dolor
761 sit
762 "},
763 cx,
764 )
765 });
766 let (apply, _events) = agent.edit(
767 buffer.clone(),
768 String::new(),
769 &LanguageModelRequest::default(),
770 &mut cx.to_async(),
771 );
772 cx.run_until_parked();
773
774 simulate_llm_output(
775 &agent,
776 indoc! {"
777 <old_text>
778 ipsum
779 dolor
780 sit
781 </old_text>
782 <new_text>
783 ipsum
784 dolor
785 sit
786 amet
787 </new_text>
788 "},
789 &mut rng,
790 cx,
791 );
792 apply.await.unwrap();
793
794 pretty_assertions::assert_eq!(
795 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
796 indoc! {"
797 lorem
798 ipsum
799 dolor
800 sit
801 amet
802 "}
803 );
804 }
805
806 #[gpui::test(iterations = 100)]
807 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
808 let agent = init_test(cx).await;
809 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
810 let (apply, _events) = agent.edit(
811 buffer.clone(),
812 String::new(),
813 &LanguageModelRequest::default(),
814 &mut cx.to_async(),
815 );
816 cx.run_until_parked();
817
818 simulate_llm_output(
819 &agent,
820 indoc! {"
821 <old_text>
822 def
823 </old_text>
824 <new_text>
825 DEF
826 </new_text>
827
828 <old_text>
829 DEF
830 </old_text>
831 <new_text>
832 DeF
833 </new_text>
834 "},
835 &mut rng,
836 cx,
837 );
838 apply.await.unwrap();
839
840 assert_eq!(
841 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
842 "abc\nDeF\nghi"
843 );
844 }
845
846 #[gpui::test(iterations = 100)]
847 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
848 let agent = init_test(cx).await;
849 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
850 let (apply, _events) = agent.edit(
851 buffer.clone(),
852 String::new(),
853 &LanguageModelRequest::default(),
854 &mut cx.to_async(),
855 );
856 cx.run_until_parked();
857
858 simulate_llm_output(
859 &agent,
860 indoc! {"
861 <old_text>
862 jkl
863 </old_text>
864 <new_text>
865 mno
866 </new_text>
867
868 <old_text>
869 abc
870 </old_text>
871 <new_text>
872 ABC
873 </new_text>
874 "},
875 &mut rng,
876 cx,
877 );
878 apply.await.unwrap();
879
880 assert_eq!(
881 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
882 "ABC\ndef\nghi"
883 );
884 }
885
886 #[gpui::test]
887 async fn test_edit_events(cx: &mut TestAppContext) {
888 let agent = init_test(cx).await;
889 let model = agent.model.as_fake();
890 let project = agent
891 .action_log
892 .read_with(cx, |log, _| log.project().clone());
893 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
894
895 let mut async_cx = cx.to_async();
896 let (apply, mut events) = agent.edit(
897 buffer.clone(),
898 String::new(),
899 &LanguageModelRequest::default(),
900 &mut async_cx,
901 );
902 cx.run_until_parked();
903
904 model.stream_last_completion_response("<old_text>a");
905 cx.run_until_parked();
906 assert_eq!(drain_events(&mut events), vec![]);
907 assert_eq!(
908 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
909 "abc\ndef\nghi\njkl"
910 );
911 assert_eq!(
912 project.read_with(cx, |project, _| project.agent_location()),
913 None
914 );
915
916 model.stream_last_completion_response("bc</old_text>");
917 cx.run_until_parked();
918 assert_eq!(
919 drain_events(&mut events),
920 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
921 cx,
922 |buffer, _| buffer.anchor_before(Point::new(0, 0))
923 ..buffer.anchor_before(Point::new(0, 3))
924 ))]
925 );
926 assert_eq!(
927 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
928 "abc\ndef\nghi\njkl"
929 );
930 assert_eq!(
931 project.read_with(cx, |project, _| project.agent_location()),
932 Some(AgentLocation {
933 buffer: buffer.downgrade(),
934 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
935 })
936 );
937
938 model.stream_last_completion_response("<new_text>abX");
939 cx.run_until_parked();
940 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
941 assert_eq!(
942 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
943 "abXc\ndef\nghi\njkl"
944 );
945 assert_eq!(
946 project.read_with(cx, |project, _| project.agent_location()),
947 Some(AgentLocation {
948 buffer: buffer.downgrade(),
949 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
950 })
951 );
952
953 model.stream_last_completion_response("cY");
954 cx.run_until_parked();
955 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
956 assert_eq!(
957 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
958 "abXcY\ndef\nghi\njkl"
959 );
960 assert_eq!(
961 project.read_with(cx, |project, _| project.agent_location()),
962 Some(AgentLocation {
963 buffer: buffer.downgrade(),
964 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
965 })
966 );
967
968 model.stream_last_completion_response("</new_text>");
969 model.stream_last_completion_response("<old_text>hall");
970 cx.run_until_parked();
971 assert_eq!(drain_events(&mut events), vec![]);
972 assert_eq!(
973 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
974 "abXcY\ndef\nghi\njkl"
975 );
976 assert_eq!(
977 project.read_with(cx, |project, _| project.agent_location()),
978 Some(AgentLocation {
979 buffer: buffer.downgrade(),
980 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
981 })
982 );
983
984 model.stream_last_completion_response("ucinated old</old_text>");
985 model.stream_last_completion_response("<new_text>");
986 cx.run_until_parked();
987 assert_eq!(
988 drain_events(&mut events),
989 vec![EditAgentOutputEvent::UnresolvedEditRange]
990 );
991 assert_eq!(
992 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
993 "abXcY\ndef\nghi\njkl"
994 );
995 assert_eq!(
996 project.read_with(cx, |project, _| project.agent_location()),
997 Some(AgentLocation {
998 buffer: buffer.downgrade(),
999 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1000 })
1001 );
1002
1003 model.stream_last_completion_response("hallucinated new</new_");
1004 model.stream_last_completion_response("text>");
1005 cx.run_until_parked();
1006 assert_eq!(drain_events(&mut events), vec![]);
1007 assert_eq!(
1008 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1009 "abXcY\ndef\nghi\njkl"
1010 );
1011 assert_eq!(
1012 project.read_with(cx, |project, _| project.agent_location()),
1013 Some(AgentLocation {
1014 buffer: buffer.downgrade(),
1015 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1016 })
1017 );
1018
1019 model.stream_last_completion_response("<old_text>\nghi\nj");
1020 cx.run_until_parked();
1021 assert_eq!(
1022 drain_events(&mut events),
1023 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1024 cx,
1025 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1026 ..buffer.anchor_before(Point::new(2, 3))
1027 ))]
1028 );
1029 assert_eq!(
1030 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1031 "abXcY\ndef\nghi\njkl"
1032 );
1033 assert_eq!(
1034 project.read_with(cx, |project, _| project.agent_location()),
1035 Some(AgentLocation {
1036 buffer: buffer.downgrade(),
1037 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1038 })
1039 );
1040
1041 model.stream_last_completion_response("kl</old_text>");
1042 model.stream_last_completion_response("<new_text>");
1043 cx.run_until_parked();
1044 assert_eq!(
1045 drain_events(&mut events),
1046 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1047 cx,
1048 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1049 ..buffer.anchor_before(Point::new(3, 3))
1050 ))]
1051 );
1052 assert_eq!(
1053 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1054 "abXcY\ndef\nghi\njkl"
1055 );
1056 assert_eq!(
1057 project.read_with(cx, |project, _| project.agent_location()),
1058 Some(AgentLocation {
1059 buffer: buffer.downgrade(),
1060 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1061 })
1062 );
1063
1064 model.stream_last_completion_response("GHI</new_text>");
1065 cx.run_until_parked();
1066 assert_eq!(
1067 drain_events(&mut events),
1068 vec![EditAgentOutputEvent::Edited]
1069 );
1070 assert_eq!(
1071 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1072 "abXcY\ndef\nGHI"
1073 );
1074 assert_eq!(
1075 project.read_with(cx, |project, _| project.agent_location()),
1076 Some(AgentLocation {
1077 buffer: buffer.downgrade(),
1078 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1079 })
1080 );
1081
1082 model.end_last_completion_stream();
1083 apply.await.unwrap();
1084 assert_eq!(
1085 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1086 "abXcY\ndef\nGHI"
1087 );
1088 assert_eq!(drain_events(&mut events), vec![]);
1089 assert_eq!(
1090 project.read_with(cx, |project, _| project.agent_location()),
1091 Some(AgentLocation {
1092 buffer: buffer.downgrade(),
1093 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1094 })
1095 );
1096 }
1097
1098 #[gpui::test]
1099 async fn test_overwrite_events(cx: &mut TestAppContext) {
1100 let agent = init_test(cx).await;
1101 let project = agent
1102 .action_log
1103 .read_with(cx, |log, _| log.project().clone());
1104 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1105 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1106 let (apply, mut events) = agent.overwrite_with_chunks(
1107 buffer.clone(),
1108 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1109 &mut cx.to_async(),
1110 );
1111
1112 cx.run_until_parked();
1113 assert_eq!(
1114 drain_events(&mut events),
1115 vec![EditAgentOutputEvent::Edited]
1116 );
1117 assert_eq!(
1118 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1119 ""
1120 );
1121 assert_eq!(
1122 project.read_with(cx, |project, _| project.agent_location()),
1123 Some(AgentLocation {
1124 buffer: buffer.downgrade(),
1125 position: language::Anchor::MAX
1126 })
1127 );
1128
1129 chunks_tx.unbounded_send("```\njkl\n").unwrap();
1130 cx.run_until_parked();
1131 assert_eq!(
1132 drain_events(&mut events),
1133 vec![EditAgentOutputEvent::Edited]
1134 );
1135 assert_eq!(
1136 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1137 "jkl"
1138 );
1139 assert_eq!(
1140 project.read_with(cx, |project, _| project.agent_location()),
1141 Some(AgentLocation {
1142 buffer: buffer.downgrade(),
1143 position: language::Anchor::MAX
1144 })
1145 );
1146
1147 chunks_tx.unbounded_send("mno\n").unwrap();
1148 cx.run_until_parked();
1149 assert_eq!(
1150 drain_events(&mut events),
1151 vec![EditAgentOutputEvent::Edited]
1152 );
1153 assert_eq!(
1154 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1155 "jkl\nmno"
1156 );
1157 assert_eq!(
1158 project.read_with(cx, |project, _| project.agent_location()),
1159 Some(AgentLocation {
1160 buffer: buffer.downgrade(),
1161 position: language::Anchor::MAX
1162 })
1163 );
1164
1165 chunks_tx.unbounded_send("pqr\n```").unwrap();
1166 cx.run_until_parked();
1167 assert_eq!(
1168 drain_events(&mut events),
1169 vec![EditAgentOutputEvent::Edited]
1170 );
1171 assert_eq!(
1172 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1173 "jkl\nmno\npqr"
1174 );
1175 assert_eq!(
1176 project.read_with(cx, |project, _| project.agent_location()),
1177 Some(AgentLocation {
1178 buffer: buffer.downgrade(),
1179 position: language::Anchor::MAX
1180 })
1181 );
1182
1183 drop(chunks_tx);
1184 apply.await.unwrap();
1185 assert_eq!(
1186 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1187 "jkl\nmno\npqr"
1188 );
1189 assert_eq!(drain_events(&mut events), vec![]);
1190 assert_eq!(
1191 project.read_with(cx, |project, _| project.agent_location()),
1192 Some(AgentLocation {
1193 buffer: buffer.downgrade(),
1194 position: language::Anchor::MAX
1195 })
1196 );
1197 }
1198
1199 #[gpui::test(iterations = 100)]
1200 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1201 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1202 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1203 Ok(EditParserEvent::NewTextChunk {
1204 chunk: chunk.clone(),
1205 done: index == chunks.len() - 1,
1206 })
1207 }));
1208 let indented_chunks =
1209 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1210 .collect::<Vec<_>>()
1211 .await;
1212 let new_text = indented_chunks
1213 .into_iter()
1214 .collect::<Result<String>>()
1215 .unwrap();
1216 assert_eq!(new_text, " abc\n def\n ghi");
1217 }
1218
1219 #[gpui::test(iterations = 100)]
1220 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1221 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1222 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1223 Ok(EditParserEvent::NewTextChunk {
1224 chunk: chunk.clone(),
1225 done: index == chunks.len() - 1,
1226 })
1227 }));
1228 let indented_chunks =
1229 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1230 .collect::<Vec<_>>()
1231 .await;
1232 let new_text = indented_chunks
1233 .into_iter()
1234 .collect::<Result<String>>()
1235 .unwrap();
1236 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1237 }
1238
1239 #[gpui::test(iterations = 100)]
1240 async fn test_random_indents(mut rng: StdRng) {
1241 let len = rng.gen_range(1..=100);
1242 let new_text = util::RandomCharIter::new(&mut rng)
1243 .with_simple_text()
1244 .take(len)
1245 .collect::<String>();
1246 let new_text = new_text
1247 .split('\n')
1248 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1249 .collect::<Vec<_>>()
1250 .join("\n");
1251 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1252
1253 let chunks = to_random_chunks(&mut rng, &new_text);
1254 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1255 Ok(EditParserEvent::NewTextChunk {
1256 chunk: chunk.clone(),
1257 done: index == chunks.len() - 1,
1258 })
1259 }));
1260 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1261 .collect::<Vec<_>>()
1262 .await;
1263 let actual_reindented_text = reindented_chunks
1264 .into_iter()
1265 .collect::<Result<String>>()
1266 .unwrap();
1267 let expected_reindented_text = new_text
1268 .split('\n')
1269 .map(|line| {
1270 if let Some(ix) = line.find(|c| c != ' ') {
1271 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1272 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1273 } else {
1274 line.to_string()
1275 }
1276 })
1277 .collect::<Vec<_>>()
1278 .join("\n");
1279 assert_eq!(actual_reindented_text, expected_reindented_text);
1280 }
1281
1282 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1283 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1284 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1285 chunk_indices.sort();
1286 chunk_indices.push(input.len());
1287
1288 let mut chunks = Vec::new();
1289 let mut last_ix = 0;
1290 for chunk_ix in chunk_indices {
1291 chunks.push(input[last_ix..chunk_ix].to_string());
1292 last_ix = chunk_ix;
1293 }
1294 chunks
1295 }
1296
1297 fn simulate_llm_output(
1298 agent: &EditAgent,
1299 output: &str,
1300 rng: &mut StdRng,
1301 cx: &mut TestAppContext,
1302 ) {
1303 let executor = cx.executor();
1304 let chunks = to_random_chunks(rng, output);
1305 let model = agent.model.clone();
1306 cx.background_spawn(async move {
1307 for chunk in chunks {
1308 executor.simulate_random_delay().await;
1309 model.as_fake().stream_last_completion_response(chunk);
1310 }
1311 model.as_fake().end_last_completion_stream();
1312 })
1313 .detach();
1314 }
1315
1316 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1317 cx.update(settings::init);
1318 cx.update(Project::init_settings);
1319 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1320 let model = Arc::new(FakeLanguageModel::default());
1321 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1322 EditAgent::new(model, project, action_log, Templates::new())
1323 }
1324
1325 fn drain_events(
1326 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1327 ) -> Vec<EditAgentOutputEvent> {
1328 let mut events = Vec::new();
1329 while let Ok(Some(event)) = stream.try_next() {
1330 events.push(event);
1331 }
1332 events
1333 }
1334}