1mod create_file_parser;
2mod edit_parser;
3#[cfg(test)]
4mod evals;
5mod streaming_fuzzy_matcher;
6
7use crate::{Template, Templates};
8use anyhow::Result;
9use assistant_tool::ActionLog;
10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
12use futures::{
13 Stream, StreamExt,
14 channel::mpsc::{self, UnboundedReceiver},
15 pin_mut,
16 stream::BoxStream,
17};
18use gpui::{AppContext, AsyncApp, Entity, Task};
19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
20use language_model::{
21 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
22 LanguageModelToolChoice, MessageContent, Role,
23};
24use project::{AgentLocation, Project};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
28use streaming_diff::{CharOperation, StreamingDiff};
29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
30use util::debug_panic;
31use zed_llm_client::CompletionIntent;
32
/// Data rendered into the prompt that asks the model to write a file's
/// entire contents (used by `EditAgent::overwrite`).
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// Path of the file being written, when one could be resolved for the buffer.
    path: Option<PathBuf>,
    /// Caller-supplied description of what the file should contain.
    edit_description: String,
}

// Rendered with the `create_file_prompt.hbs` template.
impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
42
/// Data rendered into the prompt that asks the model to produce edits for an
/// existing file (used by `EditAgent::edit`).
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// Path of the file being edited, when one could be resolved for the buffer.
    path: Option<PathBuf>,
    /// Caller-supplied description of the edit to perform.
    edit_description: String,
}

// Rendered with the `edit_file_prompt.hbs` template.
impl Template for EditFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
52
/// Progress events emitted by [`EditAgent`] while it streams edits into a buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// The old text of an edit is being matched against the buffer; carries the
    /// range it currently resolves to.
    ResolvingEditRange(Range<Anchor>),
    /// The old text of an edit did not match anything in the buffer; the edit
    /// is skipped.
    UnresolvedEditRange,
    /// The old text of an edit matched more than one location; carries the
    /// offset ranges of every candidate. The edit is skipped.
    AmbiguousEditRange(Vec<Range<usize>>),
    /// The buffer was modified.
    Edited,
}
60
/// Final result of an [`EditAgent`] run.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk received from the model.
    pub raw_edits: String,
    /// Metrics reported by the edit parser once the stream has ended. Left at
    /// the default value when a file is overwritten rather than edited.
    pub parser_metrics: EditParserMetrics,
}
66
/// Asks a language model for edits and streams them into a buffer, reporting
/// progress through [`EditAgentOutputEvent`]s.
#[derive(Clone)]
pub struct EditAgent {
    /// Model that generates the edits.
    model: Arc<dyn LanguageModel>,
    /// Log that is told whenever the agent reads, creates, or edits a buffer.
    action_log: Entity<ActionLog>,
    /// Project whose agent location is updated as edits are applied.
    project: Entity<Project>,
    /// Templates used to render the edit and create-file prompts.
    templates: Arc<Templates>,
}
74
75impl EditAgent {
76 pub fn new(
77 model: Arc<dyn LanguageModel>,
78 project: Entity<Project>,
79 action_log: Entity<ActionLog>,
80 templates: Arc<Templates>,
81 ) -> Self {
82 EditAgent {
83 model,
84 project,
85 action_log,
86 templates,
87 }
88 }
89
90 pub fn overwrite(
91 &self,
92 buffer: Entity<Buffer>,
93 edit_description: String,
94 conversation: &LanguageModelRequest,
95 cx: &mut AsyncApp,
96 ) -> (
97 Task<Result<EditAgentOutput>>,
98 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
99 ) {
100 let this = self.clone();
101 let (events_tx, events_rx) = mpsc::unbounded();
102 let conversation = conversation.clone();
103 let output = cx.spawn(async move |cx| {
104 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
105 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
106 let prompt = CreateFilePromptTemplate {
107 path,
108 edit_description,
109 }
110 .render(&this.templates)?;
111 let new_chunks = this
112 .request(conversation, CompletionIntent::CreateFile, prompt, cx)
113 .await?;
114
115 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
116 while let Some(event) = inner_events.next().await {
117 events_tx.unbounded_send(event).ok();
118 }
119 output.await
120 });
121 (output, events_rx)
122 }
123
124 fn overwrite_with_chunks(
125 &self,
126 buffer: Entity<Buffer>,
127 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
128 cx: &mut AsyncApp,
129 ) -> (
130 Task<Result<EditAgentOutput>>,
131 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
132 ) {
133 let (output_events_tx, output_events_rx) = mpsc::unbounded();
134 let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
135 let this = self.clone();
136 let task = cx.spawn(async move |cx| {
137 this.action_log
138 .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
139 this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
140 .await?;
141 parse_task.await
142 });
143 (task, output_events_rx)
144 }
145
146 async fn overwrite_with_chunks_internal(
147 &self,
148 buffer: Entity<Buffer>,
149 mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
150 output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
151 cx: &mut AsyncApp,
152 ) -> Result<()> {
153 cx.update(|cx| {
154 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
155 self.action_log.update(cx, |log, cx| {
156 log.buffer_edited(buffer.clone(), cx);
157 });
158 self.project.update(cx, |project, cx| {
159 project.set_agent_location(
160 Some(AgentLocation {
161 buffer: buffer.downgrade(),
162 position: language::Anchor::MAX,
163 }),
164 cx,
165 )
166 });
167 output_events_tx
168 .unbounded_send(EditAgentOutputEvent::Edited)
169 .ok();
170 })?;
171
172 while let Some(event) = parse_rx.next().await {
173 match event? {
174 CreateFileParserEvent::NewTextChunk { chunk } => {
175 cx.update(|cx| {
176 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
177 self.action_log
178 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
179 self.project.update(cx, |project, cx| {
180 project.set_agent_location(
181 Some(AgentLocation {
182 buffer: buffer.downgrade(),
183 position: language::Anchor::MAX,
184 }),
185 cx,
186 )
187 });
188 })?;
189 output_events_tx
190 .unbounded_send(EditAgentOutputEvent::Edited)
191 .ok();
192 }
193 }
194 }
195
196 Ok(())
197 }
198
199 pub fn edit(
200 &self,
201 buffer: Entity<Buffer>,
202 edit_description: String,
203 conversation: &LanguageModelRequest,
204 cx: &mut AsyncApp,
205 ) -> (
206 Task<Result<EditAgentOutput>>,
207 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
208 ) {
209 let this = self.clone();
210 let (events_tx, events_rx) = mpsc::unbounded();
211 let conversation = conversation.clone();
212 let output = cx.spawn(async move |cx| {
213 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
214 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
215 let prompt = EditFilePromptTemplate {
216 path,
217 edit_description,
218 }
219 .render(&this.templates)?;
220 let edit_chunks = this
221 .request(conversation, CompletionIntent::EditFile, prompt, cx)
222 .await?;
223 this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
224 .await
225 });
226 (output, events_rx)
227 }
228
    /// Parses `edit_chunks` into edits and applies them to `buffer` one at a time.
    ///
    /// For every edit in the stream:
    /// 1. The old text is matched against a fresh snapshot of the buffer. While
    ///    matching, `ResolvingEditRange` events are emitted and the project's
    ///    agent location follows the end of the candidate range.
    /// 2. If the old text matches nowhere (`UnresolvedEditRange`) or in several
    ///    places (`AmbiguousEditRange`), the edit is skipped.
    /// 3. Otherwise the new text is diffed against the matched range and the
    ///    resulting edits are applied in batches, each followed by an `Edited`
    ///    event.
    ///
    /// Returns the output of the chunk parser, or the first error encountered.
    async fn apply_edit_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        self.action_log
            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;

        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        // Peek rather than consume, so that the first `OldTextChunk` of an edit
        // is still in the stream when it is handed to `resolve_old_text`.
        let mut edit_events = edit_events.peekable();
        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
            // Skip events until we're at the start of a new edit.
            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
                // `peek` just returned `Some`, so `next` cannot yield `None`;
                // the `?` propagates a stream error, if that is what we peeked.
                edit_events.next().await.unwrap()?;
                continue;
            };

            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;

            // Resolve the old text in the background, updating the agent
            // location as we keep refining which range it corresponds to.
            let (resolve_old_text, mut old_range) =
                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
            // NOTE(review): this loop presumably ends once the resolver task
            // drops its sender and `recv` returns an error — confirm against
            // the `watch` crate.
            while let Ok(old_range) = old_range.recv().await {
                if let Some(old_range) = old_range {
                    let old_range = snapshot.anchor_before(old_range.start)
                        ..snapshot.anchor_before(old_range.end);
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: old_range.end,
                            }),
                            cx,
                        );
                    })?;
                    output_events
                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
                        .ok();
                }
            }

            // The resolver hands the event stream back along with its matches.
            let (edit_events_, mut resolved_old_text) = resolve_old_text.await?;
            edit_events = edit_events_;

            // If we can't resolve the old text, restart the loop waiting for a
            // new edit (or for the stream to end).
            let resolved_old_text = match resolved_old_text.len() {
                1 => resolved_old_text.pop().unwrap(),
                0 => {
                    output_events
                        .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
                        .ok();
                    continue;
                }
                _ => {
                    let ranges = resolved_old_text
                        .into_iter()
                        .map(|text| text.range)
                        .collect();
                    output_events
                        .unbounded_send(EditAgentOutputEvent::AmbiguousEditRange(ranges))
                        .ok();
                    continue;
                }
            };

            // Compute edits in the background and apply them as they become
            // available.
            let (compute_edits, edits) =
                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
            // Apply the edits that are already available in batches of up to 32.
            let mut edits = edits.ready_chunks(32);
            while let Some(edits) = edits.next().await {
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    let max_edit_end = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        // `edits` is non-empty (checked above), so `max` is `Some`.
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        buffer.anchor_before(max_edit_end)
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Take the event stream back so the next edit can be processed.
            edit_events = compute_edits.await?;
        }

        output.await
    }
344
345 fn parse_edit_chunks(
346 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
347 cx: &mut AsyncApp,
348 ) -> (
349 Task<Result<EditAgentOutput>>,
350 UnboundedReceiver<Result<EditParserEvent>>,
351 ) {
352 let (tx, rx) = mpsc::unbounded();
353 let output = cx.background_spawn(async move {
354 pin_mut!(chunks);
355
356 let mut parser = EditParser::new();
357 let mut raw_edits = String::new();
358 while let Some(chunk) = chunks.next().await {
359 match chunk {
360 Ok(chunk) => {
361 raw_edits.push_str(&chunk);
362 for event in parser.push(&chunk) {
363 tx.unbounded_send(Ok(event))?;
364 }
365 }
366 Err(error) => {
367 tx.unbounded_send(Err(error.into()))?;
368 }
369 }
370 }
371 Ok(EditAgentOutput {
372 raw_edits,
373 parser_metrics: parser.finish(),
374 })
375 });
376 (output, rx)
377 }
378
379 fn parse_create_file_chunks(
380 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
381 cx: &mut AsyncApp,
382 ) -> (
383 Task<Result<EditAgentOutput>>,
384 UnboundedReceiver<Result<CreateFileParserEvent>>,
385 ) {
386 let (tx, rx) = mpsc::unbounded();
387 let output = cx.background_spawn(async move {
388 pin_mut!(chunks);
389
390 let mut parser = CreateFileParser::new();
391 let mut raw_edits = String::new();
392 while let Some(chunk) = chunks.next().await {
393 match chunk {
394 Ok(chunk) => {
395 raw_edits.push_str(&chunk);
396 for event in parser.push(Some(&chunk)) {
397 tx.unbounded_send(Ok(event))?;
398 }
399 }
400 Err(error) => {
401 tx.unbounded_send(Err(error.into()))?;
402 }
403 }
404 }
405 // Send final events with None to indicate completion
406 for event in parser.push(None) {
407 tx.unbounded_send(Ok(event))?;
408 }
409 Ok(EditAgentOutput {
410 raw_edits,
411 parser_metrics: EditParserMetrics::default(),
412 })
413 });
414 (output, rx)
415 }
416
417 fn resolve_old_text<T>(
418 snapshot: TextBufferSnapshot,
419 mut edit_events: T,
420 cx: &mut AsyncApp,
421 ) -> (
422 Task<Result<(T, Vec<ResolvedOldText>)>>,
423 watch::Receiver<Option<Range<usize>>>,
424 )
425 where
426 T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
427 {
428 let (mut old_range_tx, old_range_rx) = watch::channel(None);
429 let task = cx.background_spawn(async move {
430 let mut matcher = StreamingFuzzyMatcher::new(snapshot);
431 while let Some(edit_event) = edit_events.next().await {
432 let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else {
433 break;
434 };
435
436 old_range_tx.send(matcher.push(&chunk))?;
437 if done {
438 break;
439 }
440 }
441
442 let matches = matcher.finish();
443
444 let old_range = if matches.len() == 1 {
445 matches.first()
446 } else {
447 // No matches or multiple ambiguous matches
448 None
449 };
450 old_range_tx.send(old_range.cloned())?;
451
452 let indent = LineIndent::from_iter(
453 matcher
454 .query_lines()
455 .first()
456 .unwrap_or(&String::new())
457 .chars(),
458 );
459 let resolved_old_texts = matches
460 .into_iter()
461 .map(|range| ResolvedOldText { range, indent })
462 .collect::<Vec<_>>();
463
464 Ok((edit_events, resolved_old_texts))
465 });
466
467 (task, old_range_rx)
468 }
469
470 fn compute_edits<T>(
471 snapshot: BufferSnapshot,
472 resolved_old_text: ResolvedOldText,
473 mut edit_events: T,
474 cx: &mut AsyncApp,
475 ) -> (
476 Task<Result<T>>,
477 UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
478 )
479 where
480 T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
481 {
482 let (edits_tx, edits_rx) = mpsc::unbounded();
483 let compute_edits = cx.background_spawn(async move {
484 let buffer_start_indent = snapshot
485 .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
486 let indent_delta = if buffer_start_indent.tabs > 0 {
487 IndentDelta::Tabs(
488 buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
489 )
490 } else {
491 IndentDelta::Spaces(
492 buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
493 )
494 };
495
496 let old_text = snapshot
497 .text_for_range(resolved_old_text.range.clone())
498 .collect::<String>();
499 let mut diff = StreamingDiff::new(old_text);
500 let mut edit_start = resolved_old_text.range.start;
501 let mut new_text_chunks =
502 Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
503 let mut done = false;
504 while !done {
505 let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
506 diff.push_new(&new_text_chunk?)
507 } else {
508 done = true;
509 mem::take(&mut diff).finish()
510 };
511
512 for op in char_operations {
513 match op {
514 CharOperation::Insert { text } => {
515 let edit_start = snapshot.anchor_after(edit_start);
516 edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
517 }
518 CharOperation::Delete { bytes } => {
519 let edit_end = edit_start + bytes;
520 let edit_range =
521 snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
522 edit_start = edit_end;
523 edits_tx.unbounded_send((edit_range, Arc::from("")))?;
524 }
525 CharOperation::Keep { bytes } => edit_start += bytes,
526 }
527 }
528 }
529
530 drop(new_text_chunks);
531 anyhow::Ok(edit_events)
532 });
533
534 (compute_edits, edits_rx)
535 }
536
537 fn reindent_new_text_chunks(
538 delta: IndentDelta,
539 mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
540 ) -> impl Stream<Item = Result<String>> {
541 let mut buffer = String::new();
542 let mut in_leading_whitespace = true;
543 let mut done = false;
544 futures::stream::poll_fn(move |cx| {
545 while !done {
546 let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
547 Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
548 (chunk, done)
549 }
550 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
551 Poll::Pending => return Poll::Pending,
552 _ => return Poll::Ready(None),
553 };
554
555 buffer.push_str(&chunk);
556
557 let mut indented_new_text = String::new();
558 let mut start_ix = 0;
559 let mut newlines = buffer.match_indices('\n').peekable();
560 loop {
561 let (line_end, is_pending_line) = match newlines.next() {
562 Some((ix, _)) => (ix, false),
563 None => (buffer.len(), true),
564 };
565 let line = &buffer[start_ix..line_end];
566
567 if in_leading_whitespace {
568 if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
569 // We found a non-whitespace character, adjust
570 // indentation based on the delta.
571 let new_indent_len =
572 cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
573 indented_new_text
574 .extend(iter::repeat(delta.character()).take(new_indent_len));
575 indented_new_text.push_str(&line[non_whitespace_ix..]);
576 in_leading_whitespace = false;
577 } else if is_pending_line {
578 // We're still in leading whitespace and this line is incomplete.
579 // Stop processing until we receive more input.
580 break;
581 } else {
582 // This line is entirely whitespace. Push it without indentation.
583 indented_new_text.push_str(line);
584 }
585 } else {
586 indented_new_text.push_str(line);
587 }
588
589 if is_pending_line {
590 start_ix = line_end;
591 break;
592 } else {
593 in_leading_whitespace = true;
594 indented_new_text.push('\n');
595 start_ix = line_end + 1;
596 }
597 }
598 buffer.replace_range(..start_ix, "");
599
600 // This was the last chunk, push all the buffered content as-is.
601 if is_last_chunk {
602 indented_new_text.push_str(&buffer);
603 buffer.clear();
604 done = true;
605 }
606
607 if !indented_new_text.is_empty() {
608 return Poll::Ready(Some(Ok(indented_new_text)));
609 }
610 }
611
612 Poll::Ready(None)
613 })
614 }
615
616 async fn request(
617 &self,
618 mut conversation: LanguageModelRequest,
619 intent: CompletionIntent,
620 prompt: String,
621 cx: &mut AsyncApp,
622 ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
623 let mut messages_iter = conversation.messages.iter_mut();
624 if let Some(last_message) = messages_iter.next_back() {
625 if last_message.role == Role::Assistant {
626 let old_content_len = last_message.content.len();
627 last_message
628 .content
629 .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
630 let new_content_len = last_message.content.len();
631
632 // We just removed pending tool uses from the content of the
633 // last message, so it doesn't make sense to cache it anymore
634 // (e.g., the message will look very different on the next
635 // request). Thus, we move the flag to the message prior to it,
636 // as it will still be a valid prefix of the conversation.
637 if old_content_len != new_content_len && last_message.cache {
638 if let Some(prev_message) = messages_iter.next_back() {
639 last_message.cache = false;
640 prev_message.cache = true;
641 }
642 }
643
644 if last_message.content.is_empty() {
645 conversation.messages.pop();
646 }
647 } else {
648 debug_panic!(
649 "Last message must be an Assistant tool calling! Got {:?}",
650 last_message.content
651 );
652 }
653 }
654
655 conversation.messages.push(LanguageModelRequestMessage {
656 role: Role::User,
657 content: vec![MessageContent::Text(prompt)],
658 cache: false,
659 });
660
661 // Include tools in the request so that we can take advantage of
662 // caching when ToolChoice::None is supported.
663 let mut tool_choice = None;
664 let mut tools = Vec::new();
665 if !conversation.tools.is_empty()
666 && self
667 .model
668 .supports_tool_choice(LanguageModelToolChoice::None)
669 {
670 tool_choice = Some(LanguageModelToolChoice::None);
671 tools = conversation.tools.clone();
672 }
673
674 let request = LanguageModelRequest {
675 thread_id: conversation.thread_id,
676 prompt_id: conversation.prompt_id,
677 intent: Some(intent),
678 mode: conversation.mode,
679 messages: conversation.messages,
680 tool_choice,
681 tools,
682 stop: Vec::new(),
683 temperature: None,
684 };
685
686 Ok(self.model.stream_completion_text(request, cx).await?.stream)
687 }
688}
689
/// A location in the buffer snapshot that the old text of an edit matched.
struct ResolvedOldText {
    /// Offset range of the match within the snapshot.
    range: Range<usize>,
    /// Indentation of the first line of the old text as written by the model.
    indent: LineIndent,
}
694
/// Signed change in indentation to apply to each line of new text, expressed
/// in either spaces or tabs.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    Spaces(isize),
    Tabs(isize),
}

impl IndentDelta {
    /// The character this delta is measured in: a space or a tab.
    fn character(&self) -> char {
        if matches!(self, IndentDelta::Tabs(_)) {
            '\t'
        } else {
            ' '
        }
    }

    /// The signed number of indentation characters to add (negative removes).
    fn len(&self) -> isize {
        let (IndentDelta::Spaces(amount) | IndentDelta::Tabs(amount)) = *self;
        amount
    }
}
716
717#[cfg(test)]
718mod tests {
719 use super::*;
720 use fs::FakeFs;
721 use futures::stream;
722 use gpui::{AppContext, TestAppContext};
723 use indoc::indoc;
724 use language_model::fake_provider::FakeLanguageModel;
725 use project::{AgentLocation, Project};
726 use rand::prelude::*;
727 use rand::rngs::StdRng;
728 use std::cmp;
729
730 #[gpui::test(iterations = 100)]
731 async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
732 let agent = init_test(cx).await;
733 let buffer = cx.new(|cx| {
734 Buffer::local(
735 indoc! {"
736 abc
737 def
738 ghi
739 "},
740 cx,
741 )
742 });
743 let (apply, _events) = agent.edit(
744 buffer.clone(),
745 String::new(),
746 &LanguageModelRequest::default(),
747 &mut cx.to_async(),
748 );
749 cx.run_until_parked();
750
751 simulate_llm_output(
752 &agent,
753 indoc! {"
754 <old_text></old_text>
755 <new_text>jkl</new_text>
756 <old_text>def</old_text>
757 <new_text>DEF</new_text>
758 "},
759 &mut rng,
760 cx,
761 );
762 apply.await.unwrap();
763
764 pretty_assertions::assert_eq!(
765 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
766 indoc! {"
767 abc
768 DEF
769 ghi
770 "}
771 );
772 }
773
774 #[gpui::test(iterations = 100)]
775 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
776 let agent = init_test(cx).await;
777 let buffer = cx.new(|cx| {
778 Buffer::local(
779 indoc! {"
780 lorem
781 ipsum
782 dolor
783 sit
784 "},
785 cx,
786 )
787 });
788 let (apply, _events) = agent.edit(
789 buffer.clone(),
790 String::new(),
791 &LanguageModelRequest::default(),
792 &mut cx.to_async(),
793 );
794 cx.run_until_parked();
795
796 simulate_llm_output(
797 &agent,
798 indoc! {"
799 <old_text>
800 ipsum
801 dolor
802 sit
803 </old_text>
804 <new_text>
805 ipsum
806 dolor
807 sit
808 amet
809 </new_text>
810 "},
811 &mut rng,
812 cx,
813 );
814 apply.await.unwrap();
815
816 pretty_assertions::assert_eq!(
817 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
818 indoc! {"
819 lorem
820 ipsum
821 dolor
822 sit
823 amet
824 "}
825 );
826 }
827
828 #[gpui::test(iterations = 100)]
829 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
830 let agent = init_test(cx).await;
831 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
832 let (apply, _events) = agent.edit(
833 buffer.clone(),
834 String::new(),
835 &LanguageModelRequest::default(),
836 &mut cx.to_async(),
837 );
838 cx.run_until_parked();
839
840 simulate_llm_output(
841 &agent,
842 indoc! {"
843 <old_text>
844 def
845 </old_text>
846 <new_text>
847 DEF
848 </new_text>
849
850 <old_text>
851 DEF
852 </old_text>
853 <new_text>
854 DeF
855 </new_text>
856 "},
857 &mut rng,
858 cx,
859 );
860 apply.await.unwrap();
861
862 assert_eq!(
863 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
864 "abc\nDeF\nghi"
865 );
866 }
867
868 #[gpui::test(iterations = 100)]
869 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
870 let agent = init_test(cx).await;
871 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
872 let (apply, _events) = agent.edit(
873 buffer.clone(),
874 String::new(),
875 &LanguageModelRequest::default(),
876 &mut cx.to_async(),
877 );
878 cx.run_until_parked();
879
880 simulate_llm_output(
881 &agent,
882 indoc! {"
883 <old_text>
884 jkl
885 </old_text>
886 <new_text>
887 mno
888 </new_text>
889
890 <old_text>
891 abc
892 </old_text>
893 <new_text>
894 ABC
895 </new_text>
896 "},
897 &mut rng,
898 cx,
899 );
900 apply.await.unwrap();
901
902 assert_eq!(
903 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
904 "ABC\ndef\nghi"
905 );
906 }
907
908 #[gpui::test]
909 async fn test_edit_events(cx: &mut TestAppContext) {
910 let agent = init_test(cx).await;
911 let model = agent.model.as_fake();
912 let project = agent
913 .action_log
914 .read_with(cx, |log, _| log.project().clone());
915 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
916
917 let mut async_cx = cx.to_async();
918 let (apply, mut events) = agent.edit(
919 buffer.clone(),
920 String::new(),
921 &LanguageModelRequest::default(),
922 &mut async_cx,
923 );
924 cx.run_until_parked();
925
926 model.stream_last_completion_response("<old_text>a");
927 cx.run_until_parked();
928 assert_eq!(drain_events(&mut events), vec![]);
929 assert_eq!(
930 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
931 "abc\ndef\nghi\njkl"
932 );
933 assert_eq!(
934 project.read_with(cx, |project, _| project.agent_location()),
935 None
936 );
937
938 model.stream_last_completion_response("bc</old_text>");
939 cx.run_until_parked();
940 assert_eq!(
941 drain_events(&mut events),
942 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
943 cx,
944 |buffer, _| buffer.anchor_before(Point::new(0, 0))
945 ..buffer.anchor_before(Point::new(0, 3))
946 ))]
947 );
948 assert_eq!(
949 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
950 "abc\ndef\nghi\njkl"
951 );
952 assert_eq!(
953 project.read_with(cx, |project, _| project.agent_location()),
954 Some(AgentLocation {
955 buffer: buffer.downgrade(),
956 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
957 })
958 );
959
960 model.stream_last_completion_response("<new_text>abX");
961 cx.run_until_parked();
962 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
963 assert_eq!(
964 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
965 "abXc\ndef\nghi\njkl"
966 );
967 assert_eq!(
968 project.read_with(cx, |project, _| project.agent_location()),
969 Some(AgentLocation {
970 buffer: buffer.downgrade(),
971 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
972 })
973 );
974
975 model.stream_last_completion_response("cY");
976 cx.run_until_parked();
977 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
978 assert_eq!(
979 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
980 "abXcY\ndef\nghi\njkl"
981 );
982 assert_eq!(
983 project.read_with(cx, |project, _| project.agent_location()),
984 Some(AgentLocation {
985 buffer: buffer.downgrade(),
986 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
987 })
988 );
989
990 model.stream_last_completion_response("</new_text>");
991 model.stream_last_completion_response("<old_text>hall");
992 cx.run_until_parked();
993 assert_eq!(drain_events(&mut events), vec![]);
994 assert_eq!(
995 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
996 "abXcY\ndef\nghi\njkl"
997 );
998 assert_eq!(
999 project.read_with(cx, |project, _| project.agent_location()),
1000 Some(AgentLocation {
1001 buffer: buffer.downgrade(),
1002 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1003 })
1004 );
1005
1006 model.stream_last_completion_response("ucinated old</old_text>");
1007 model.stream_last_completion_response("<new_text>");
1008 cx.run_until_parked();
1009 assert_eq!(
1010 drain_events(&mut events),
1011 vec![EditAgentOutputEvent::UnresolvedEditRange]
1012 );
1013 assert_eq!(
1014 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1015 "abXcY\ndef\nghi\njkl"
1016 );
1017 assert_eq!(
1018 project.read_with(cx, |project, _| project.agent_location()),
1019 Some(AgentLocation {
1020 buffer: buffer.downgrade(),
1021 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1022 })
1023 );
1024
1025 model.stream_last_completion_response("hallucinated new</new_");
1026 model.stream_last_completion_response("text>");
1027 cx.run_until_parked();
1028 assert_eq!(drain_events(&mut events), vec![]);
1029 assert_eq!(
1030 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1031 "abXcY\ndef\nghi\njkl"
1032 );
1033 assert_eq!(
1034 project.read_with(cx, |project, _| project.agent_location()),
1035 Some(AgentLocation {
1036 buffer: buffer.downgrade(),
1037 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1038 })
1039 );
1040
1041 model.stream_last_completion_response("<old_text>\nghi\nj");
1042 cx.run_until_parked();
1043 assert_eq!(
1044 drain_events(&mut events),
1045 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1046 cx,
1047 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1048 ..buffer.anchor_before(Point::new(2, 3))
1049 ))]
1050 );
1051 assert_eq!(
1052 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1053 "abXcY\ndef\nghi\njkl"
1054 );
1055 assert_eq!(
1056 project.read_with(cx, |project, _| project.agent_location()),
1057 Some(AgentLocation {
1058 buffer: buffer.downgrade(),
1059 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1060 })
1061 );
1062
1063 model.stream_last_completion_response("kl</old_text>");
1064 model.stream_last_completion_response("<new_text>");
1065 cx.run_until_parked();
1066 assert_eq!(
1067 drain_events(&mut events),
1068 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1069 cx,
1070 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1071 ..buffer.anchor_before(Point::new(3, 3))
1072 ))]
1073 );
1074 assert_eq!(
1075 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1076 "abXcY\ndef\nghi\njkl"
1077 );
1078 assert_eq!(
1079 project.read_with(cx, |project, _| project.agent_location()),
1080 Some(AgentLocation {
1081 buffer: buffer.downgrade(),
1082 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1083 })
1084 );
1085
1086 model.stream_last_completion_response("GHI</new_text>");
1087 cx.run_until_parked();
1088 assert_eq!(
1089 drain_events(&mut events),
1090 vec![EditAgentOutputEvent::Edited]
1091 );
1092 assert_eq!(
1093 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1094 "abXcY\ndef\nGHI"
1095 );
1096 assert_eq!(
1097 project.read_with(cx, |project, _| project.agent_location()),
1098 Some(AgentLocation {
1099 buffer: buffer.downgrade(),
1100 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1101 })
1102 );
1103
1104 model.end_last_completion_stream();
1105 apply.await.unwrap();
1106 assert_eq!(
1107 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1108 "abXcY\ndef\nGHI"
1109 );
1110 assert_eq!(drain_events(&mut events), vec![]);
1111 assert_eq!(
1112 project.read_with(cx, |project, _| project.agent_location()),
1113 Some(AgentLocation {
1114 buffer: buffer.downgrade(),
1115 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1116 })
1117 );
1118 }
1119
1120 #[gpui::test]
1121 async fn test_overwrite_events(cx: &mut TestAppContext) {
1122 let agent = init_test(cx).await;
1123 let project = agent
1124 .action_log
1125 .read_with(cx, |log, _| log.project().clone());
1126 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1127 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1128 let (apply, mut events) = agent.overwrite_with_chunks(
1129 buffer.clone(),
1130 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1131 &mut cx.to_async(),
1132 );
1133
1134 cx.run_until_parked();
1135 assert_eq!(
1136 drain_events(&mut events),
1137 vec![EditAgentOutputEvent::Edited]
1138 );
1139 assert_eq!(
1140 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1141 ""
1142 );
1143 assert_eq!(
1144 project.read_with(cx, |project, _| project.agent_location()),
1145 Some(AgentLocation {
1146 buffer: buffer.downgrade(),
1147 position: language::Anchor::MAX
1148 })
1149 );
1150
1151 chunks_tx.unbounded_send("```\njkl\n").unwrap();
1152 cx.run_until_parked();
1153 assert_eq!(
1154 drain_events(&mut events),
1155 vec![EditAgentOutputEvent::Edited]
1156 );
1157 assert_eq!(
1158 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1159 "jkl"
1160 );
1161 assert_eq!(
1162 project.read_with(cx, |project, _| project.agent_location()),
1163 Some(AgentLocation {
1164 buffer: buffer.downgrade(),
1165 position: language::Anchor::MAX
1166 })
1167 );
1168
1169 chunks_tx.unbounded_send("mno\n").unwrap();
1170 cx.run_until_parked();
1171 assert_eq!(
1172 drain_events(&mut events),
1173 vec![EditAgentOutputEvent::Edited]
1174 );
1175 assert_eq!(
1176 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1177 "jkl\nmno"
1178 );
1179 assert_eq!(
1180 project.read_with(cx, |project, _| project.agent_location()),
1181 Some(AgentLocation {
1182 buffer: buffer.downgrade(),
1183 position: language::Anchor::MAX
1184 })
1185 );
1186
1187 chunks_tx.unbounded_send("pqr\n```").unwrap();
1188 cx.run_until_parked();
1189 assert_eq!(
1190 drain_events(&mut events),
1191 vec![EditAgentOutputEvent::Edited]
1192 );
1193 assert_eq!(
1194 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1195 "jkl\nmno\npqr"
1196 );
1197 assert_eq!(
1198 project.read_with(cx, |project, _| project.agent_location()),
1199 Some(AgentLocation {
1200 buffer: buffer.downgrade(),
1201 position: language::Anchor::MAX
1202 })
1203 );
1204
1205 drop(chunks_tx);
1206 apply.await.unwrap();
1207 assert_eq!(
1208 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1209 "jkl\nmno\npqr"
1210 );
1211 assert_eq!(drain_events(&mut events), vec![]);
1212 assert_eq!(
1213 project.read_with(cx, |project, _| project.agent_location()),
1214 Some(AgentLocation {
1215 buffer: buffer.downgrade(),
1216 position: language::Anchor::MAX
1217 })
1218 );
1219 }
1220
1221 #[gpui::test(iterations = 100)]
1222 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1223 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1224 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1225 Ok(EditParserEvent::NewTextChunk {
1226 chunk: chunk.clone(),
1227 done: index == chunks.len() - 1,
1228 })
1229 }));
1230 let indented_chunks =
1231 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1232 .collect::<Vec<_>>()
1233 .await;
1234 let new_text = indented_chunks
1235 .into_iter()
1236 .collect::<Result<String>>()
1237 .unwrap();
1238 assert_eq!(new_text, " abc\n def\n ghi");
1239 }
1240
1241 #[gpui::test(iterations = 100)]
1242 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1243 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1244 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1245 Ok(EditParserEvent::NewTextChunk {
1246 chunk: chunk.clone(),
1247 done: index == chunks.len() - 1,
1248 })
1249 }));
1250 let indented_chunks =
1251 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1252 .collect::<Vec<_>>()
1253 .await;
1254 let new_text = indented_chunks
1255 .into_iter()
1256 .collect::<Result<String>>()
1257 .unwrap();
1258 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1259 }
1260
1261 #[gpui::test(iterations = 100)]
1262 async fn test_random_indents(mut rng: StdRng) {
1263 let len = rng.gen_range(1..=100);
1264 let new_text = util::RandomCharIter::new(&mut rng)
1265 .with_simple_text()
1266 .take(len)
1267 .collect::<String>();
1268 let new_text = new_text
1269 .split('\n')
1270 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1271 .collect::<Vec<_>>()
1272 .join("\n");
1273 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1274
1275 let chunks = to_random_chunks(&mut rng, &new_text);
1276 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1277 Ok(EditParserEvent::NewTextChunk {
1278 chunk: chunk.clone(),
1279 done: index == chunks.len() - 1,
1280 })
1281 }));
1282 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1283 .collect::<Vec<_>>()
1284 .await;
1285 let actual_reindented_text = reindented_chunks
1286 .into_iter()
1287 .collect::<Result<String>>()
1288 .unwrap();
1289 let expected_reindented_text = new_text
1290 .split('\n')
1291 .map(|line| {
1292 if let Some(ix) = line.find(|c| c != ' ') {
1293 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1294 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1295 } else {
1296 line.to_string()
1297 }
1298 })
1299 .collect::<Vec<_>>()
1300 .join("\n");
1301 assert_eq!(actual_reindented_text, expected_reindented_text);
1302 }
1303
1304 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1305 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1306 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1307 chunk_indices.sort();
1308 chunk_indices.push(input.len());
1309
1310 let mut chunks = Vec::new();
1311 let mut last_ix = 0;
1312 for chunk_ix in chunk_indices {
1313 chunks.push(input[last_ix..chunk_ix].to_string());
1314 last_ix = chunk_ix;
1315 }
1316 chunks
1317 }
1318
1319 fn simulate_llm_output(
1320 agent: &EditAgent,
1321 output: &str,
1322 rng: &mut StdRng,
1323 cx: &mut TestAppContext,
1324 ) {
1325 let executor = cx.executor();
1326 let chunks = to_random_chunks(rng, output);
1327 let model = agent.model.clone();
1328 cx.background_spawn(async move {
1329 for chunk in chunks {
1330 executor.simulate_random_delay().await;
1331 model.as_fake().stream_last_completion_response(chunk);
1332 }
1333 model.as_fake().end_last_completion_stream();
1334 })
1335 .detach();
1336 }
1337
1338 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1339 cx.update(settings::init);
1340 cx.update(Project::init_settings);
1341 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1342 let model = Arc::new(FakeLanguageModel::default());
1343 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1344 EditAgent::new(model, project, action_log, Templates::new())
1345 }
1346
1347 #[gpui::test(iterations = 10)]
1348 async fn test_non_unique_text_error(cx: &mut TestAppContext, mut rng: StdRng) {
1349 let agent = init_test(cx).await;
1350 let original_text = indoc! {"
1351 function foo() {
1352 return 42;
1353 }
1354
1355 function bar() {
1356 return 42;
1357 }
1358
1359 function baz() {
1360 return 42;
1361 }
1362 "};
1363 let buffer = cx.new(|cx| Buffer::local(original_text, cx));
1364 let (apply, mut events) = agent.edit(
1365 buffer.clone(),
1366 String::new(),
1367 &LanguageModelRequest::default(),
1368 &mut cx.to_async(),
1369 );
1370 cx.run_until_parked();
1371
1372 // When <old_text> matches text in more than one place
1373 simulate_llm_output(
1374 &agent,
1375 indoc! {"
1376 <old_text>
1377 return 42;
1378 </old_text>
1379 <new_text>
1380 return 100;
1381 </new_text>
1382 "},
1383 &mut rng,
1384 cx,
1385 );
1386 apply.await.unwrap();
1387
1388 // Then the text should remain unchanged
1389 let result_text = buffer.read_with(cx, |buffer, _| buffer.snapshot().text());
1390 assert_eq!(
1391 result_text,
1392 indoc! {"
1393 function foo() {
1394 return 42;
1395 }
1396
1397 function bar() {
1398 return 42;
1399 }
1400
1401 function baz() {
1402 return 42;
1403 }
1404 "},
1405 "Text should remain unchanged when there are multiple matches"
1406 );
1407
1408 // And AmbiguousEditRange even should be emitted
1409 let events = drain_events(&mut events);
1410 let ambiguous_ranges = vec![17..31, 52..66, 87..101];
1411 assert!(
1412 events.contains(&EditAgentOutputEvent::AmbiguousEditRange(ambiguous_ranges)),
1413 "Should emit AmbiguousEditRange for non-unique text"
1414 );
1415 }
1416
1417 fn drain_events(
1418 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1419 ) -> Vec<EditAgentOutputEvent> {
1420 let mut events = Vec::new();
1421 while let Ok(Some(event)) = stream.try_next() {
1422 events.push(event);
1423 }
1424 events
1425 }
1426}