1mod create_file_parser;
2mod edit_parser;
3#[cfg(test)]
4mod evals;
5mod streaming_fuzzy_matcher;
6
7use crate::{Template, Templates};
8use anyhow::Result;
9use assistant_tool::ActionLog;
10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
12use futures::{
13 Stream, StreamExt,
14 channel::mpsc::{self, UnboundedReceiver},
15 pin_mut,
16 stream::BoxStream,
17};
18use gpui::{AppContext, AsyncApp, Entity, Task};
19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
20use language_model::{
21 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
22 LanguageModelToolChoice, MessageContent, Role,
23};
24use project::{AgentLocation, Project};
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
28use streaming_diff::{CharOperation, StreamingDiff};
29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
30use util::debug_panic;
31use zed_llm_client::CompletionIntent;
32
/// Values rendered into the `create_file_prompt.hbs` template when asking the
/// model to produce a file's complete contents (used by `EditAgent::overwrite`).
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// Path of the target file as returned by `resolve_file_path` on the
    /// buffer snapshot; `None` when no path was resolved.
    path: Option<PathBuf>,
    /// The caller's description of the change to make.
    edit_description: String,
}

impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
42
/// Values rendered into the `edit_file_prompt.hbs` template when asking the
/// model for targeted edits to an existing file (used by `EditAgent::edit`).
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// Path of the target file as returned by `resolve_file_path` on the
    /// buffer snapshot; `None` when no path was resolved.
    path: Option<PathBuf>,
    /// The caller's description of the change to make.
    edit_description: String,
}

impl Template for EditFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
52
/// Progress notifications emitted while an `EditAgent` edits or overwrites a
/// buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// The old text of the current edit currently matches this range of the
    /// buffer. May be sent several times as more old text streams in and the
    /// match is refined.
    ResolvingEditRange(Range<Anchor>),
    /// The old text of an edit matched nothing in the buffer; that edit is
    /// skipped.
    UnresolvedEditRange,
    /// The old text of an edit matched several places. Each range holds the
    /// 1-based line numbers of a candidate's start and end; the edit is
    /// skipped.
    AmbiguousEditRange(Vec<Range<usize>>),
    /// The buffer was changed (a batch of edits applied, or — when
    /// overwriting — the buffer cleared or a chunk appended).
    Edited,
}
60
/// Final result of an edit or overwrite request.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditAgentOutput {
    /// The model's output exactly as streamed: every successful chunk,
    /// concatenated in arrival order.
    pub raw_edits: String,
    /// Metrics reported by the edit parser. Overwrites do not use the edit
    /// parser and leave this at its default value.
    pub parser_metrics: EditParserMetrics,
}
66
/// Uses a language model to either rewrite a buffer wholesale (`overwrite`) or
/// apply targeted edits to it (`edit`), streaming the model's output into the
/// buffer as it arrives.
///
/// Every change is reported to the `ActionLog`, and the project's agent
/// location is moved to follow the edits.
#[derive(Clone)]
pub struct EditAgent {
    /// Model that generates the new contents or the edits.
    model: Arc<dyn LanguageModel>,
    /// Notified when the agent reads, creates, or edits a buffer.
    action_log: Entity<ActionLog>,
    /// Project whose agent location tracks where the agent is working.
    project: Entity<Project>,
    /// Templates used to render the create-file and edit-file prompts.
    templates: Arc<Templates>,
}
74
75impl EditAgent {
76 pub fn new(
77 model: Arc<dyn LanguageModel>,
78 project: Entity<Project>,
79 action_log: Entity<ActionLog>,
80 templates: Arc<Templates>,
81 ) -> Self {
82 EditAgent {
83 model,
84 project,
85 action_log,
86 templates,
87 }
88 }
89
90 pub fn overwrite(
91 &self,
92 buffer: Entity<Buffer>,
93 edit_description: String,
94 conversation: &LanguageModelRequest,
95 cx: &mut AsyncApp,
96 ) -> (
97 Task<Result<EditAgentOutput>>,
98 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
99 ) {
100 let this = self.clone();
101 let (events_tx, events_rx) = mpsc::unbounded();
102 let conversation = conversation.clone();
103 let output = cx.spawn(async move |cx| {
104 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
105 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
106 let prompt = CreateFilePromptTemplate {
107 path,
108 edit_description,
109 }
110 .render(&this.templates)?;
111 let new_chunks = this
112 .request(conversation, CompletionIntent::CreateFile, prompt, cx)
113 .await?;
114
115 let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
116 while let Some(event) = inner_events.next().await {
117 events_tx.unbounded_send(event).ok();
118 }
119 output.await
120 });
121 (output, events_rx)
122 }
123
124 fn overwrite_with_chunks(
125 &self,
126 buffer: Entity<Buffer>,
127 edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
128 cx: &mut AsyncApp,
129 ) -> (
130 Task<Result<EditAgentOutput>>,
131 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
132 ) {
133 let (output_events_tx, output_events_rx) = mpsc::unbounded();
134 let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
135 let this = self.clone();
136 let task = cx.spawn(async move |cx| {
137 this.action_log
138 .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
139 this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
140 .await?;
141 parse_task.await
142 });
143 (task, output_events_rx)
144 }
145
146 async fn overwrite_with_chunks_internal(
147 &self,
148 buffer: Entity<Buffer>,
149 mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
150 output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
151 cx: &mut AsyncApp,
152 ) -> Result<()> {
153 cx.update(|cx| {
154 buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
155 self.action_log.update(cx, |log, cx| {
156 log.buffer_edited(buffer.clone(), cx);
157 });
158 self.project.update(cx, |project, cx| {
159 project.set_agent_location(
160 Some(AgentLocation {
161 buffer: buffer.downgrade(),
162 position: language::Anchor::MAX,
163 }),
164 cx,
165 )
166 });
167 output_events_tx
168 .unbounded_send(EditAgentOutputEvent::Edited)
169 .ok();
170 })?;
171
172 while let Some(event) = parse_rx.next().await {
173 match event? {
174 CreateFileParserEvent::NewTextChunk { chunk } => {
175 cx.update(|cx| {
176 buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
177 self.action_log
178 .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
179 self.project.update(cx, |project, cx| {
180 project.set_agent_location(
181 Some(AgentLocation {
182 buffer: buffer.downgrade(),
183 position: language::Anchor::MAX,
184 }),
185 cx,
186 )
187 });
188 })?;
189 output_events_tx
190 .unbounded_send(EditAgentOutputEvent::Edited)
191 .ok();
192 }
193 }
194 }
195
196 Ok(())
197 }
198
199 pub fn edit(
200 &self,
201 buffer: Entity<Buffer>,
202 edit_description: String,
203 conversation: &LanguageModelRequest,
204 cx: &mut AsyncApp,
205 ) -> (
206 Task<Result<EditAgentOutput>>,
207 mpsc::UnboundedReceiver<EditAgentOutputEvent>,
208 ) {
209 let this = self.clone();
210 let (events_tx, events_rx) = mpsc::unbounded();
211 let conversation = conversation.clone();
212 let output = cx.spawn(async move |cx| {
213 let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
214 let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
215 let prompt = EditFilePromptTemplate {
216 path,
217 edit_description,
218 }
219 .render(&this.templates)?;
220 let edit_chunks = this
221 .request(conversation, CompletionIntent::EditFile, prompt, cx)
222 .await?;
223 this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
224 .await
225 });
226 (output, events_rx)
227 }
228
    /// Applies the `<old_text>`/`<new_text>` edits streamed in `edit_chunks` to
    /// `buffer`, returning the parser's output once the stream is exhausted.
    ///
    /// The chunks are parsed in the background by `parse_edit_chunks`. For each
    /// edit, the old text is located in a fresh snapshot of the buffer by
    /// `resolve_old_text`. If exactly one range matches, `compute_edits` diffs
    /// the new text against it, and the resulting edits are applied in batches.
    /// Progress is reported through `output_events`:
    /// - `ResolvingEditRange` each time the candidate old-text range changes,
    /// - `UnresolvedEditRange` / `AmbiguousEditRange` when the old text matches
    ///   zero / several ranges (that edit is then skipped),
    /// - `Edited` after each batch of buffer edits.
    ///
    /// # Errors
    /// Returns an error if an app/entity update fails, if the parsed event
    /// stream yields an error, or if one of the background tasks fails.
    async fn apply_edit_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        // Report the buffer to the action log via `buffer_read` before editing.
        self.action_log
            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;

        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        // Peekable so we can inspect an event without consuming it: the
        // `OldTextChunk` that starts an edit must reach `resolve_old_text`.
        let mut edit_events = edit_events.peekable();
        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
            // Skip events until we're at the start of a new edit.
            // Errors are consumed and propagated here via `?`.
            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
                edit_events.next().await.unwrap()?;
                continue;
            };

            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;

            // Resolve the old text in the background, updating the agent
            // location as we keep refining which range it corresponds to.
            // `edit_events` is moved into the task and handed back when it ends.
            let (resolve_old_text, mut old_range) =
                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
            // The watch receiver errors once the resolving task drops its sender.
            while let Ok(old_range) = old_range.recv().await {
                if let Some(old_range) = old_range {
                    let old_range = snapshot.anchor_before(old_range.start)
                        ..snapshot.anchor_before(old_range.end);
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: old_range.end,
                            }),
                            cx,
                        );
                    })?;
                    output_events
                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
                        .ok();
                }
            }

            let (edit_events_, mut resolved_old_text) = resolve_old_text.await?;
            edit_events = edit_events_;

            // If we can't resolve the old text, restart the loop waiting for a
            // new edit (or for the stream to end).
            let resolved_old_text = match resolved_old_text.len() {
                1 => resolved_old_text.pop().unwrap(),
                0 => {
                    output_events
                        .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
                        .ok();
                    continue;
                }
                _ => {
                    // Report each candidate as 1-based line numbers of its
                    // start and end rows.
                    let ranges = resolved_old_text
                        .into_iter()
                        .map(|text| {
                            let start_line =
                                (snapshot.offset_to_point(text.range.start).row + 1) as usize;
                            let end_line =
                                (snapshot.offset_to_point(text.range.end).row + 1) as usize;
                            start_line..end_line
                        })
                        .collect();
                    output_events
                        .unbounded_send(EditAgentOutputEvent::AmbiguousEditRange(ranges))
                        .ok();
                    continue;
                }
            };

            // Compute edits in the background and apply them as they become
            // available.
            let (compute_edits, edits) =
                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
            // Group edits that are already available (up to 32) so each batch is
            // applied with a single `buffer.edit` call.
            let mut edits = edits.ready_chunks(32);
            while let Some(edits) = edits.next().await {
                // Guards the `max().unwrap()` below, which needs at least one edit.
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    // Move the agent location to the furthest edit end in the batch.
                    let max_edit_end = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        buffer.anchor_before(max_edit_end)
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Take the event stream back so the outer loop can find the next edit.
            edit_events = compute_edits.await?;
        }

        output.await
    }
350
351 fn parse_edit_chunks(
352 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
353 cx: &mut AsyncApp,
354 ) -> (
355 Task<Result<EditAgentOutput>>,
356 UnboundedReceiver<Result<EditParserEvent>>,
357 ) {
358 let (tx, rx) = mpsc::unbounded();
359 let output = cx.background_spawn(async move {
360 pin_mut!(chunks);
361
362 let mut parser = EditParser::new();
363 let mut raw_edits = String::new();
364 while let Some(chunk) = chunks.next().await {
365 match chunk {
366 Ok(chunk) => {
367 raw_edits.push_str(&chunk);
368 for event in parser.push(&chunk) {
369 tx.unbounded_send(Ok(event))?;
370 }
371 }
372 Err(error) => {
373 tx.unbounded_send(Err(error.into()))?;
374 }
375 }
376 }
377 Ok(EditAgentOutput {
378 raw_edits,
379 parser_metrics: parser.finish(),
380 })
381 });
382 (output, rx)
383 }
384
385 fn parse_create_file_chunks(
386 chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
387 cx: &mut AsyncApp,
388 ) -> (
389 Task<Result<EditAgentOutput>>,
390 UnboundedReceiver<Result<CreateFileParserEvent>>,
391 ) {
392 let (tx, rx) = mpsc::unbounded();
393 let output = cx.background_spawn(async move {
394 pin_mut!(chunks);
395
396 let mut parser = CreateFileParser::new();
397 let mut raw_edits = String::new();
398 while let Some(chunk) = chunks.next().await {
399 match chunk {
400 Ok(chunk) => {
401 raw_edits.push_str(&chunk);
402 for event in parser.push(Some(&chunk)) {
403 tx.unbounded_send(Ok(event))?;
404 }
405 }
406 Err(error) => {
407 tx.unbounded_send(Err(error.into()))?;
408 }
409 }
410 }
411 // Send final events with None to indicate completion
412 for event in parser.push(None) {
413 tx.unbounded_send(Ok(event))?;
414 }
415 Ok(EditAgentOutput {
416 raw_edits,
417 parser_metrics: EditParserMetrics::default(),
418 })
419 });
420 (output, rx)
421 }
422
423 fn resolve_old_text<T>(
424 snapshot: TextBufferSnapshot,
425 mut edit_events: T,
426 cx: &mut AsyncApp,
427 ) -> (
428 Task<Result<(T, Vec<ResolvedOldText>)>>,
429 watch::Receiver<Option<Range<usize>>>,
430 )
431 where
432 T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
433 {
434 let (mut old_range_tx, old_range_rx) = watch::channel(None);
435 let task = cx.background_spawn(async move {
436 let mut matcher = StreamingFuzzyMatcher::new(snapshot);
437 while let Some(edit_event) = edit_events.next().await {
438 let EditParserEvent::OldTextChunk {
439 chunk,
440 done,
441 line_hint,
442 } = edit_event?
443 else {
444 break;
445 };
446
447 old_range_tx.send(matcher.push(&chunk, line_hint))?;
448 if done {
449 break;
450 }
451 }
452
453 let matches = matcher.finish();
454 let best_match = matcher.select_best_match();
455
456 old_range_tx.send(best_match.clone())?;
457
458 let indent = LineIndent::from_iter(
459 matcher
460 .query_lines()
461 .first()
462 .unwrap_or(&String::new())
463 .chars(),
464 );
465
466 let resolved_old_texts = if let Some(best_match) = best_match {
467 vec![ResolvedOldText {
468 range: best_match,
469 indent,
470 }]
471 } else {
472 matches
473 .into_iter()
474 .map(|range| ResolvedOldText { range, indent })
475 .collect::<Vec<_>>()
476 };
477
478 Ok((edit_events, resolved_old_texts))
479 });
480
481 (task, old_range_rx)
482 }
483
    /// Streams the buffer edits that turn the resolved old text into the new
    /// text read from `edit_events`.
    ///
    /// Runs on a background task. The new text is re-indented by the difference
    /// between two indents: that of the buffer row where the match starts, and
    /// that of the first line of the model's old text. Tabs are used when that
    /// buffer row contains at least one tab, spaces otherwise.
    ///
    /// The new text is then diffed against the old text with `StreamingDiff`.
    /// Each insertion or deletion is sent on the returned receiver as an
    /// anchored `(range, text)` edit; kept text only advances the byte cursor.
    /// The task resolves to `edit_events` so the caller can keep reading
    /// subsequent edits from it.
    fn compute_edits<T>(
        snapshot: BufferSnapshot,
        resolved_old_text: ResolvedOldText,
        mut edit_events: T,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<T>>,
        UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
    )
    where
        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
    {
        let (edits_tx, edits_rx) = mpsc::unbounded();
        let compute_edits = cx.background_spawn(async move {
            let buffer_start_indent = snapshot
                .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
            let indent_delta = if buffer_start_indent.tabs > 0 {
                IndentDelta::Tabs(
                    buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
                )
            } else {
                IndentDelta::Spaces(
                    buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
                )
            };

            let old_text = snapshot
                .text_for_range(resolved_old_text.range.clone())
                .collect::<String>();
            let mut diff = StreamingDiff::new(old_text);
            // Byte offset into `snapshot` of the next not-yet-diffed old text.
            let mut edit_start = resolved_old_text.range.start;
            // Borrows `edit_events` mutably until dropped below.
            let mut new_text_chunks =
                Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
            let mut done = false;
            while !done {
                let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
                    diff.push_new(&new_text_chunk?)
                } else {
                    done = true;
                    // `finish` consumes the diff, so take it out of the local
                    // (leaving a default value) to flush the remaining operations.
                    mem::take(&mut diff).finish()
                };

                // Edits are expressed as anchors into `snapshot` rather than
                // offsets, presumably so they stay valid after earlier batches
                // have already been applied to the live buffer.
                for op in char_operations {
                    match op {
                        CharOperation::Insert { text } => {
                            let edit_start = snapshot.anchor_after(edit_start);
                            edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
                        }
                        CharOperation::Delete { bytes } => {
                            let edit_end = edit_start + bytes;
                            let edit_range =
                                snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
                            edit_start = edit_end;
                            edits_tx.unbounded_send((edit_range, Arc::from("")))?;
                        }
                        CharOperation::Keep { bytes } => edit_start += bytes,
                    }
                }
            }

            // Release the borrow of `edit_events` so it can be returned.
            drop(new_text_chunks);
            anyhow::Ok(edit_events)
        });

        (compute_edits, edits_rx)
    }
550
551 fn reindent_new_text_chunks(
552 delta: IndentDelta,
553 mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
554 ) -> impl Stream<Item = Result<String>> {
555 let mut buffer = String::new();
556 let mut in_leading_whitespace = true;
557 let mut done = false;
558 futures::stream::poll_fn(move |cx| {
559 while !done {
560 let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
561 Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
562 (chunk, done)
563 }
564 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
565 Poll::Pending => return Poll::Pending,
566 _ => return Poll::Ready(None),
567 };
568
569 buffer.push_str(&chunk);
570
571 let mut indented_new_text = String::new();
572 let mut start_ix = 0;
573 let mut newlines = buffer.match_indices('\n').peekable();
574 loop {
575 let (line_end, is_pending_line) = match newlines.next() {
576 Some((ix, _)) => (ix, false),
577 None => (buffer.len(), true),
578 };
579 let line = &buffer[start_ix..line_end];
580
581 if in_leading_whitespace {
582 if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
583 // We found a non-whitespace character, adjust
584 // indentation based on the delta.
585 let new_indent_len =
586 cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
587 indented_new_text
588 .extend(iter::repeat(delta.character()).take(new_indent_len));
589 indented_new_text.push_str(&line[non_whitespace_ix..]);
590 in_leading_whitespace = false;
591 } else if is_pending_line {
592 // We're still in leading whitespace and this line is incomplete.
593 // Stop processing until we receive more input.
594 break;
595 } else {
596 // This line is entirely whitespace. Push it without indentation.
597 indented_new_text.push_str(line);
598 }
599 } else {
600 indented_new_text.push_str(line);
601 }
602
603 if is_pending_line {
604 start_ix = line_end;
605 break;
606 } else {
607 in_leading_whitespace = true;
608 indented_new_text.push('\n');
609 start_ix = line_end + 1;
610 }
611 }
612 buffer.replace_range(..start_ix, "");
613
614 // This was the last chunk, push all the buffered content as-is.
615 if is_last_chunk {
616 indented_new_text.push_str(&buffer);
617 buffer.clear();
618 done = true;
619 }
620
621 if !indented_new_text.is_empty() {
622 return Poll::Ready(Some(Ok(indented_new_text)));
623 }
624 }
625
626 Poll::Ready(None)
627 })
628 }
629
    /// Sends `conversation`, followed by `prompt` as a new user message, to the
    /// model with the given `intent`, and returns the stream of completion text.
    ///
    /// The last message of `conversation` is expected to be the assistant
    /// message that invoked the tool; `debug_panic!` fires otherwise. Its
    /// tool-use content is removed, and the message is dropped entirely if
    /// nothing else remains. If that message was a cache point and lost
    /// content, the cache flag moves to the previous message.
    ///
    /// When the model supports `LanguageModelToolChoice::None`, the
    /// conversation's tools are still sent (with that tool choice) so the
    /// request can benefit from caching. Otherwise no tools are sent.
    async fn request(
        &self,
        mut conversation: LanguageModelRequest,
        intent: CompletionIntent,
        prompt: String,
        cx: &mut AsyncApp,
    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
        let mut messages_iter = conversation.messages.iter_mut();
        if let Some(last_message) = messages_iter.next_back() {
            if last_message.role == Role::Assistant {
                let old_content_len = last_message.content.len();
                last_message
                    .content
                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
                let new_content_len = last_message.content.len();

                // We just removed pending tool uses from the content of the
                // last message, so it doesn't make sense to cache it anymore
                // (e.g., the message will look very different on the next
                // request). Thus, we move the flag to the message prior to it,
                // as it will still be a valid prefix of the conversation.
                if old_content_len != new_content_len && last_message.cache {
                    if let Some(prev_message) = messages_iter.next_back() {
                        last_message.cache = false;
                        prev_message.cache = true;
                    }
                }

                // `last_message` / `messages_iter` are not used after this
                // check, so (under NLL) their borrows have ended before `pop`.
                if last_message.content.is_empty() {
                    conversation.messages.pop();
                }
            } else {
                debug_panic!(
                    "Last message must be an Assistant tool calling! Got {:?}",
                    last_message.content
                );
            }
        }

        conversation.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(prompt)],
            cache: false,
        });

        // Include tools in the request so that we can take advantage of
        // caching when ToolChoice::None is supported.
        let mut tool_choice = None;
        let mut tools = Vec::new();
        if !conversation.tools.is_empty()
            && self
                .model
                .supports_tool_choice(LanguageModelToolChoice::None)
        {
            tool_choice = Some(LanguageModelToolChoice::None);
            tools = conversation.tools.clone();
        }

        let request = LanguageModelRequest {
            thread_id: conversation.thread_id,
            prompt_id: conversation.prompt_id,
            intent: Some(intent),
            mode: conversation.mode,
            messages: conversation.messages,
            tool_choice,
            tools,
            stop: Vec::new(),
            temperature: None,
        };

        Ok(self.model.stream_completion_text(request, cx).await?.stream)
    }
702}
703
/// A location in the buffer snapshot where the model's old text was found.
struct ResolvedOldText {
    /// Byte-offset range of the match within the snapshot.
    range: Range<usize>,
    /// Indentation of the first line of the old text as the model wrote it;
    /// compared with the buffer's indentation to compute an `IndentDelta`.
    indent: LineIndent,
}
708
/// Signed change in indentation applied to streamed new text, counted in
/// either spaces or tabs.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    Spaces(isize),
    Tabs(isize),
}

impl IndentDelta {
    /// The indentation character this delta is measured in.
    fn character(&self) -> char {
        match *self {
            Self::Tabs(_) => '\t',
            Self::Spaces(_) => ' ',
        }
    }

    /// How many indentation characters to add (positive) or remove (negative).
    fn len(&self) -> isize {
        let (Self::Spaces(count) | Self::Tabs(count)) = *self;
        count
    }
}
730
731#[cfg(test)]
732mod tests {
733 use super::*;
734 use fs::FakeFs;
735 use futures::stream;
736 use gpui::{AppContext, TestAppContext};
737 use indoc::indoc;
738 use language_model::fake_provider::FakeLanguageModel;
739 use project::{AgentLocation, Project};
740 use rand::prelude::*;
741 use rand::rngs::StdRng;
742 use std::cmp;
743
744 #[gpui::test(iterations = 100)]
745 async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
746 let agent = init_test(cx).await;
747 let buffer = cx.new(|cx| {
748 Buffer::local(
749 indoc! {"
750 abc
751 def
752 ghi
753 "},
754 cx,
755 )
756 });
757 let (apply, _events) = agent.edit(
758 buffer.clone(),
759 String::new(),
760 &LanguageModelRequest::default(),
761 &mut cx.to_async(),
762 );
763 cx.run_until_parked();
764
765 simulate_llm_output(
766 &agent,
767 indoc! {"
768 <old_text></old_text>
769 <new_text>jkl</new_text>
770 <old_text>def</old_text>
771 <new_text>DEF</new_text>
772 "},
773 &mut rng,
774 cx,
775 );
776 apply.await.unwrap();
777
778 pretty_assertions::assert_eq!(
779 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
780 indoc! {"
781 abc
782 DEF
783 ghi
784 "}
785 );
786 }
787
788 #[gpui::test(iterations = 100)]
789 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
790 let agent = init_test(cx).await;
791 let buffer = cx.new(|cx| {
792 Buffer::local(
793 indoc! {"
794 lorem
795 ipsum
796 dolor
797 sit
798 "},
799 cx,
800 )
801 });
802 let (apply, _events) = agent.edit(
803 buffer.clone(),
804 String::new(),
805 &LanguageModelRequest::default(),
806 &mut cx.to_async(),
807 );
808 cx.run_until_parked();
809
810 simulate_llm_output(
811 &agent,
812 indoc! {"
813 <old_text>
814 ipsum
815 dolor
816 sit
817 </old_text>
818 <new_text>
819 ipsum
820 dolor
821 sit
822 amet
823 </new_text>
824 "},
825 &mut rng,
826 cx,
827 );
828 apply.await.unwrap();
829
830 pretty_assertions::assert_eq!(
831 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
832 indoc! {"
833 lorem
834 ipsum
835 dolor
836 sit
837 amet
838 "}
839 );
840 }
841
842 #[gpui::test(iterations = 100)]
843 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
844 let agent = init_test(cx).await;
845 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
846 let (apply, _events) = agent.edit(
847 buffer.clone(),
848 String::new(),
849 &LanguageModelRequest::default(),
850 &mut cx.to_async(),
851 );
852 cx.run_until_parked();
853
854 simulate_llm_output(
855 &agent,
856 indoc! {"
857 <old_text>
858 def
859 </old_text>
860 <new_text>
861 DEF
862 </new_text>
863
864 <old_text>
865 DEF
866 </old_text>
867 <new_text>
868 DeF
869 </new_text>
870 "},
871 &mut rng,
872 cx,
873 );
874 apply.await.unwrap();
875
876 assert_eq!(
877 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
878 "abc\nDeF\nghi"
879 );
880 }
881
882 #[gpui::test(iterations = 100)]
883 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
884 let agent = init_test(cx).await;
885 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
886 let (apply, _events) = agent.edit(
887 buffer.clone(),
888 String::new(),
889 &LanguageModelRequest::default(),
890 &mut cx.to_async(),
891 );
892 cx.run_until_parked();
893
894 simulate_llm_output(
895 &agent,
896 indoc! {"
897 <old_text>
898 jkl
899 </old_text>
900 <new_text>
901 mno
902 </new_text>
903
904 <old_text>
905 abc
906 </old_text>
907 <new_text>
908 ABC
909 </new_text>
910 "},
911 &mut rng,
912 cx,
913 );
914 apply.await.unwrap();
915
916 assert_eq!(
917 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
918 "ABC\ndef\nghi"
919 );
920 }
921
922 #[gpui::test]
923 async fn test_edit_events(cx: &mut TestAppContext) {
924 let agent = init_test(cx).await;
925 let model = agent.model.as_fake();
926 let project = agent
927 .action_log
928 .read_with(cx, |log, _| log.project().clone());
929 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
930
931 let mut async_cx = cx.to_async();
932 let (apply, mut events) = agent.edit(
933 buffer.clone(),
934 String::new(),
935 &LanguageModelRequest::default(),
936 &mut async_cx,
937 );
938 cx.run_until_parked();
939
940 model.stream_last_completion_response("<old_text>a");
941 cx.run_until_parked();
942 assert_eq!(drain_events(&mut events), vec![]);
943 assert_eq!(
944 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
945 "abc\ndef\nghi\njkl"
946 );
947 assert_eq!(
948 project.read_with(cx, |project, _| project.agent_location()),
949 None
950 );
951
952 model.stream_last_completion_response("bc</old_text>");
953 cx.run_until_parked();
954 assert_eq!(
955 drain_events(&mut events),
956 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
957 cx,
958 |buffer, _| buffer.anchor_before(Point::new(0, 0))
959 ..buffer.anchor_before(Point::new(0, 3))
960 ))]
961 );
962 assert_eq!(
963 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
964 "abc\ndef\nghi\njkl"
965 );
966 assert_eq!(
967 project.read_with(cx, |project, _| project.agent_location()),
968 Some(AgentLocation {
969 buffer: buffer.downgrade(),
970 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
971 })
972 );
973
974 model.stream_last_completion_response("<new_text>abX");
975 cx.run_until_parked();
976 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
977 assert_eq!(
978 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
979 "abXc\ndef\nghi\njkl"
980 );
981 assert_eq!(
982 project.read_with(cx, |project, _| project.agent_location()),
983 Some(AgentLocation {
984 buffer: buffer.downgrade(),
985 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
986 })
987 );
988
989 model.stream_last_completion_response("cY");
990 cx.run_until_parked();
991 assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
992 assert_eq!(
993 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
994 "abXcY\ndef\nghi\njkl"
995 );
996 assert_eq!(
997 project.read_with(cx, |project, _| project.agent_location()),
998 Some(AgentLocation {
999 buffer: buffer.downgrade(),
1000 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1001 })
1002 );
1003
1004 model.stream_last_completion_response("</new_text>");
1005 model.stream_last_completion_response("<old_text>hall");
1006 cx.run_until_parked();
1007 assert_eq!(drain_events(&mut events), vec![]);
1008 assert_eq!(
1009 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1010 "abXcY\ndef\nghi\njkl"
1011 );
1012 assert_eq!(
1013 project.read_with(cx, |project, _| project.agent_location()),
1014 Some(AgentLocation {
1015 buffer: buffer.downgrade(),
1016 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1017 })
1018 );
1019
1020 model.stream_last_completion_response("ucinated old</old_text>");
1021 model.stream_last_completion_response("<new_text>");
1022 cx.run_until_parked();
1023 assert_eq!(
1024 drain_events(&mut events),
1025 vec![EditAgentOutputEvent::UnresolvedEditRange]
1026 );
1027 assert_eq!(
1028 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1029 "abXcY\ndef\nghi\njkl"
1030 );
1031 assert_eq!(
1032 project.read_with(cx, |project, _| project.agent_location()),
1033 Some(AgentLocation {
1034 buffer: buffer.downgrade(),
1035 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1036 })
1037 );
1038
1039 model.stream_last_completion_response("hallucinated new</new_");
1040 model.stream_last_completion_response("text>");
1041 cx.run_until_parked();
1042 assert_eq!(drain_events(&mut events), vec![]);
1043 assert_eq!(
1044 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1045 "abXcY\ndef\nghi\njkl"
1046 );
1047 assert_eq!(
1048 project.read_with(cx, |project, _| project.agent_location()),
1049 Some(AgentLocation {
1050 buffer: buffer.downgrade(),
1051 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1052 })
1053 );
1054
1055 model.stream_last_completion_response("<old_text>\nghi\nj");
1056 cx.run_until_parked();
1057 assert_eq!(
1058 drain_events(&mut events),
1059 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1060 cx,
1061 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1062 ..buffer.anchor_before(Point::new(2, 3))
1063 ))]
1064 );
1065 assert_eq!(
1066 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1067 "abXcY\ndef\nghi\njkl"
1068 );
1069 assert_eq!(
1070 project.read_with(cx, |project, _| project.agent_location()),
1071 Some(AgentLocation {
1072 buffer: buffer.downgrade(),
1073 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1074 })
1075 );
1076
1077 model.stream_last_completion_response("kl</old_text>");
1078 model.stream_last_completion_response("<new_text>");
1079 cx.run_until_parked();
1080 assert_eq!(
1081 drain_events(&mut events),
1082 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1083 cx,
1084 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1085 ..buffer.anchor_before(Point::new(3, 3))
1086 ))]
1087 );
1088 assert_eq!(
1089 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1090 "abXcY\ndef\nghi\njkl"
1091 );
1092 assert_eq!(
1093 project.read_with(cx, |project, _| project.agent_location()),
1094 Some(AgentLocation {
1095 buffer: buffer.downgrade(),
1096 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1097 })
1098 );
1099
1100 model.stream_last_completion_response("GHI</new_text>");
1101 cx.run_until_parked();
1102 assert_eq!(
1103 drain_events(&mut events),
1104 vec![EditAgentOutputEvent::Edited]
1105 );
1106 assert_eq!(
1107 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1108 "abXcY\ndef\nGHI"
1109 );
1110 assert_eq!(
1111 project.read_with(cx, |project, _| project.agent_location()),
1112 Some(AgentLocation {
1113 buffer: buffer.downgrade(),
1114 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1115 })
1116 );
1117
1118 model.end_last_completion_stream();
1119 apply.await.unwrap();
1120 assert_eq!(
1121 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1122 "abXcY\ndef\nGHI"
1123 );
1124 assert_eq!(drain_events(&mut events), vec![]);
1125 assert_eq!(
1126 project.read_with(cx, |project, _| project.agent_location()),
1127 Some(AgentLocation {
1128 buffer: buffer.downgrade(),
1129 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1130 })
1131 );
1132 }
1133
1134 #[gpui::test]
1135 async fn test_overwrite_events(cx: &mut TestAppContext) {
1136 let agent = init_test(cx).await;
1137 let project = agent
1138 .action_log
1139 .read_with(cx, |log, _| log.project().clone());
1140 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1141 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1142 let (apply, mut events) = agent.overwrite_with_chunks(
1143 buffer.clone(),
1144 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1145 &mut cx.to_async(),
1146 );
1147
1148 cx.run_until_parked();
1149 assert_eq!(
1150 drain_events(&mut events),
1151 vec![EditAgentOutputEvent::Edited]
1152 );
1153 assert_eq!(
1154 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1155 ""
1156 );
1157 assert_eq!(
1158 project.read_with(cx, |project, _| project.agent_location()),
1159 Some(AgentLocation {
1160 buffer: buffer.downgrade(),
1161 position: language::Anchor::MAX
1162 })
1163 );
1164
1165 chunks_tx.unbounded_send("```\njkl\n").unwrap();
1166 cx.run_until_parked();
1167 assert_eq!(
1168 drain_events(&mut events),
1169 vec![EditAgentOutputEvent::Edited]
1170 );
1171 assert_eq!(
1172 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1173 "jkl"
1174 );
1175 assert_eq!(
1176 project.read_with(cx, |project, _| project.agent_location()),
1177 Some(AgentLocation {
1178 buffer: buffer.downgrade(),
1179 position: language::Anchor::MAX
1180 })
1181 );
1182
1183 chunks_tx.unbounded_send("mno\n").unwrap();
1184 cx.run_until_parked();
1185 assert_eq!(
1186 drain_events(&mut events),
1187 vec![EditAgentOutputEvent::Edited]
1188 );
1189 assert_eq!(
1190 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1191 "jkl\nmno"
1192 );
1193 assert_eq!(
1194 project.read_with(cx, |project, _| project.agent_location()),
1195 Some(AgentLocation {
1196 buffer: buffer.downgrade(),
1197 position: language::Anchor::MAX
1198 })
1199 );
1200
1201 chunks_tx.unbounded_send("pqr\n```").unwrap();
1202 cx.run_until_parked();
1203 assert_eq!(
1204 drain_events(&mut events),
1205 vec![EditAgentOutputEvent::Edited]
1206 );
1207 assert_eq!(
1208 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1209 "jkl\nmno\npqr"
1210 );
1211 assert_eq!(
1212 project.read_with(cx, |project, _| project.agent_location()),
1213 Some(AgentLocation {
1214 buffer: buffer.downgrade(),
1215 position: language::Anchor::MAX
1216 })
1217 );
1218
1219 drop(chunks_tx);
1220 apply.await.unwrap();
1221 assert_eq!(
1222 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1223 "jkl\nmno\npqr"
1224 );
1225 assert_eq!(drain_events(&mut events), vec![]);
1226 assert_eq!(
1227 project.read_with(cx, |project, _| project.agent_location()),
1228 Some(AgentLocation {
1229 buffer: buffer.downgrade(),
1230 position: language::Anchor::MAX
1231 })
1232 );
1233 }
1234
1235 #[gpui::test(iterations = 100)]
1236 async fn test_indent_new_text_chunks(mut rng: StdRng) {
1237 let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
1238 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1239 Ok(EditParserEvent::NewTextChunk {
1240 chunk: chunk.clone(),
1241 done: index == chunks.len() - 1,
1242 })
1243 }));
1244 let indented_chunks =
1245 EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1246 .collect::<Vec<_>>()
1247 .await;
1248 let new_text = indented_chunks
1249 .into_iter()
1250 .collect::<Result<String>>()
1251 .unwrap();
1252 assert_eq!(new_text, " abc\n def\n ghi");
1253 }
1254
1255 #[gpui::test(iterations = 100)]
1256 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1257 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1258 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1259 Ok(EditParserEvent::NewTextChunk {
1260 chunk: chunk.clone(),
1261 done: index == chunks.len() - 1,
1262 })
1263 }));
1264 let indented_chunks =
1265 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1266 .collect::<Vec<_>>()
1267 .await;
1268 let new_text = indented_chunks
1269 .into_iter()
1270 .collect::<Result<String>>()
1271 .unwrap();
1272 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1273 }
1274
1275 #[gpui::test(iterations = 100)]
1276 async fn test_random_indents(mut rng: StdRng) {
1277 let len = rng.gen_range(1..=100);
1278 let new_text = util::RandomCharIter::new(&mut rng)
1279 .with_simple_text()
1280 .take(len)
1281 .collect::<String>();
1282 let new_text = new_text
1283 .split('\n')
1284 .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1285 .collect::<Vec<_>>()
1286 .join("\n");
1287 let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1288
1289 let chunks = to_random_chunks(&mut rng, &new_text);
1290 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1291 Ok(EditParserEvent::NewTextChunk {
1292 chunk: chunk.clone(),
1293 done: index == chunks.len() - 1,
1294 })
1295 }));
1296 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1297 .collect::<Vec<_>>()
1298 .await;
1299 let actual_reindented_text = reindented_chunks
1300 .into_iter()
1301 .collect::<Result<String>>()
1302 .unwrap();
1303 let expected_reindented_text = new_text
1304 .split('\n')
1305 .map(|line| {
1306 if let Some(ix) = line.find(|c| c != ' ') {
1307 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1308 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1309 } else {
1310 line.to_string()
1311 }
1312 })
1313 .collect::<Vec<_>>()
1314 .join("\n");
1315 assert_eq!(actual_reindented_text, expected_reindented_text);
1316 }
1317
1318 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1319 let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1320 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1321 chunk_indices.sort();
1322 chunk_indices.push(input.len());
1323
1324 let mut chunks = Vec::new();
1325 let mut last_ix = 0;
1326 for chunk_ix in chunk_indices {
1327 chunks.push(input[last_ix..chunk_ix].to_string());
1328 last_ix = chunk_ix;
1329 }
1330 chunks
1331 }
1332
1333 fn simulate_llm_output(
1334 agent: &EditAgent,
1335 output: &str,
1336 rng: &mut StdRng,
1337 cx: &mut TestAppContext,
1338 ) {
1339 let executor = cx.executor();
1340 let chunks = to_random_chunks(rng, output);
1341 let model = agent.model.clone();
1342 cx.background_spawn(async move {
1343 for chunk in chunks {
1344 executor.simulate_random_delay().await;
1345 model.as_fake().stream_last_completion_response(chunk);
1346 }
1347 model.as_fake().end_last_completion_stream();
1348 })
1349 .detach();
1350 }
1351
1352 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1353 cx.update(settings::init);
1354 cx.update(Project::init_settings);
1355 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1356 let model = Arc::new(FakeLanguageModel::default());
1357 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1358 EditAgent::new(model, project, action_log, Templates::new())
1359 }
1360
    // Verifies that when `<old_text>` matches several places in the buffer, the edit
    // is not applied and an `AmbiguousEditRange` event listing the candidate ranges
    // is emitted instead.
    #[gpui::test(iterations = 10)]
    async fn test_non_unique_text_error(cx: &mut TestAppContext, mut rng: StdRng) {
        let agent = init_test(cx).await;
        // Three functions with identical bodies, so the `<old_text>` below is ambiguous.
        let original_text = indoc! {"
            function foo() {
                return 42;
            }

            function bar() {
                return 42;
            }

            function baz() {
                return 42;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(original_text, cx));
        let (apply, mut events) = agent.edit(
            buffer.clone(),
            String::new(),
            &LanguageModelRequest::default(),
            &mut cx.to_async(),
        );
        cx.run_until_parked();

        // When <old_text> matches text in more than one place
        simulate_llm_output(
            &agent,
            indoc! {"
                <old_text>
                    return 42;
                }
                </old_text>
                <new_text>
                    return 100;
                }
                </new_text>
            "},
            &mut rng,
            cx,
        );
        apply.await.unwrap();

        // Then the text should remain unchanged
        let result_text = buffer.read_with(cx, |buffer, _| buffer.snapshot().text());
        assert_eq!(
            result_text,
            indoc! {"
                function foo() {
                    return 42;
                }

                function bar() {
                    return 42;
                }

                function baz() {
                    return 42;
                }
            "},
            "Text should remain unchanged when there are multiple matches"
        );

        // And an AmbiguousEditRange event should be emitted.
        // NOTE(review): these look like 1-based line ranges of the three matches —
        // confirm against how the agent builds `AmbiguousEditRange`.
        let events = drain_events(&mut events);
        let ambiguous_ranges = vec![2..3, 6..7, 10..11];
        assert!(
            events.contains(&EditAgentOutputEvent::AmbiguousEditRange(ambiguous_ranges)),
            "Should emit AmbiguousEditRange for non-unique text"
        );
    }
1432
1433 fn drain_events(
1434 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1435 ) -> Vec<EditAgentOutputEvent> {
1436 let mut events = Vec::new();
1437 while let Ok(Some(event)) = stream.try_next() {
1438 events.push(event);
1439 }
1440 events
1441 }
1442}