1mod create_file_parser;
2mod edit_parser;
3#[cfg(test)]
4mod evals;
5pub mod streaming_fuzzy_matcher;
6
7use crate::{Template, Templates};
8use action_log::ActionLog;
9use anyhow::Result;
10use cloud_llm_client::CompletionIntent;
11use create_file_parser::{CreateFileParser, CreateFileParserEvent};
12pub use edit_parser::EditFormat;
13use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
14use futures::{
15 Stream, StreamExt,
16 channel::mpsc::{self, UnboundedReceiver},
17 pin_mut,
18 stream::BoxStream,
19};
20use gpui::{AppContext, AsyncApp, Entity, Task};
21use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
22use language_model::{
23 LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
24 LanguageModelToolChoice, MessageContent, Role,
25};
26use project::{AgentLocation, Project};
27use schemars::JsonSchema;
28use serde::{Deserialize, Serialize};
29use std::{cmp, iter, mem, ops::Range, pin::Pin, sync::Arc, task::Poll};
30use streaming_diff::{CharOperation, StreamingDiff};
31use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
32
/// Context rendered into the `create_file_prompt.hbs` template, used when the
/// model is asked to produce a buffer's entire contents (`EditAgent::overwrite`).
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// Path of the buffer's file as returned by `resolve_file_path`, if any.
    path: Option<String>,
    /// Caller-supplied description of the content to generate.
    edit_description: String,
}

impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
42
/// Context rendered into the `edit_file_prompt_xml.hbs` template, used by
/// `EditAgent::edit` when the edit format is `EditFormat::XmlTags`.
#[derive(Serialize)]
struct EditFileXmlPromptTemplate {
    /// Path of the buffer's file as returned by `resolve_file_path`, if any.
    path: Option<String>,
    /// Caller-supplied description of the edit to perform.
    edit_description: String,
}

impl Template for EditFileXmlPromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt_xml.hbs";
}
52
/// Context rendered into the `edit_file_prompt_diff_fenced.hbs` template, used
/// by `EditAgent::edit` when the edit format is `EditFormat::DiffFenced`.
#[derive(Serialize)]
struct EditFileDiffFencedPromptTemplate {
    /// Path of the buffer's file as returned by `resolve_file_path`, if any.
    path: Option<String>,
    /// Caller-supplied description of the edit to perform.
    edit_description: String,
}

impl Template for EditFileDiffFencedPromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt_diff_fenced.hbs";
}
62
/// Progress events emitted while the agent applies model output to a buffer.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// The old text of an edit is still being matched against the buffer;
    /// carries the current candidate range.
    ResolvingEditRange(Range<Anchor>),
    /// The old text of an edit matched nothing in the buffer, so the edit was
    /// skipped.
    UnresolvedEditRange,
    /// The old text of an edit matched several places, so the edit was
    /// skipped. Each range holds the 1-based line numbers of one match's start
    /// and end.
    AmbiguousEditRange(Vec<Range<usize>>),
    /// The buffer was modified; carries the affected range (the whole buffer
    /// when overwriting).
    Edited(Range<Anchor>),
}
70
// Final result of an edit or overwrite request.
//
// Plain `//` comments are used deliberately: this type derives `JsonSchema`,
// and `///` doc comments would be copied into the generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditAgentOutput {
    // Concatenation of every text chunk streamed by the model.
    pub raw_edits: String,
    // Metrics reported by the edit parser; left at their defaults for
    // overwrites, which use the create-file parser instead.
    pub parser_metrics: EditParserMetrics,
}
76
/// Uses a language model to create or edit buffers, streaming the model's
/// output into them as it arrives.
#[derive(Clone)]
pub struct EditAgent {
    /// Model that generates new contents or edits.
    model: Arc<dyn LanguageModel>,
    /// Log notified when the agent reads, creates, or edits a buffer.
    action_log: Entity<ActionLog>,
    /// Project whose agent location is moved to where the agent is working.
    project: Entity<Project>,
    /// Prompt templates rendered for each request.
    templates: Arc<Templates>,
    /// Edit format the model is prompted with and whose output is parsed.
    edit_format: EditFormat,
    /// Forwarded as `thinking_allowed` on every completion request.
    thinking_allowed: bool,
}
86
impl EditAgent {
    /// Creates an edit agent that prompts `model` to create or edit buffers.
    ///
    /// Buffer reads, creations, and edits are reported to `action_log`, and the
    /// agent's position is published as `project`'s agent location.
    /// `edit_format` selects the prompt template and parser used by
    /// [`EditAgent::edit`], and `allow_thinking` is forwarded as
    /// `thinking_allowed` on every request.
    pub fn new(
        model: Arc<dyn LanguageModel>,
        project: Entity<Project>,
        action_log: Entity<ActionLog>,
        templates: Arc<Templates>,
        edit_format: EditFormat,
        allow_thinking: bool,
    ) -> Self {
        EditAgent {
            model,
            project,
            action_log,
            templates,
            edit_format,
            thinking_allowed: allow_thinking,
        }
    }

    /// Replaces the entire contents of `buffer` with text generated by the
    /// model.
    ///
    /// Renders the create-file prompt for `edit_description`, appends it to a
    /// copy of `conversation`, and streams the model's response into the
    /// buffer. Returns the task that resolves to the raw model output, plus a
    /// receiver of [`EditAgentOutputEvent`]s emitted while the buffer is being
    /// written.
    pub fn overwrite(
        &self,
        buffer: Entity<Buffer>,
        edit_description: String,
        conversation: &LanguageModelRequest,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let this = self.clone();
        let (events_tx, events_rx) = mpsc::unbounded();
        let conversation = conversation.clone();
        let output = cx.spawn(async move |cx| {
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
            let path = cx.update(|cx| snapshot.resolve_file_path(true, cx));
            let prompt = CreateFilePromptTemplate {
                path,
                edit_description,
            }
            .render(&this.templates)?;
            let new_chunks = this
                .request(conversation, CompletionIntent::CreateFile, prompt, cx)
                .await?;

            // Relay every inner event until the inner channel closes, then
            // surface the inner task's result. Send failures (the caller
            // dropped the receiver) are ignored.
            let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
            while let Some(event) = inner_events.next().await {
                events_tx.unbounded_send(event).ok();
            }
            output.await
        });
        (output, events_rx)
    }

    /// Streams `edit_chunks` (raw model output) into `buffer`, replacing its
    /// contents.
    ///
    /// The chunks are parsed on a background task, and the buffer is reported
    /// to the action log as created before any text is written. The returned
    /// task resolves to the parser's [`EditAgentOutput`] once all text has
    /// been applied, or to the first error encountered.
    fn overwrite_with_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let (output_events_tx, output_events_rx) = mpsc::unbounded();
        let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
        let this = self.clone();
        let task = cx.spawn(async move |cx| {
            this.action_log
                .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx));
            this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
                .await?;
            parse_task.await
        });
        (task, output_events_rx)
    }

    /// Applies create-file parser events from `parse_rx` to `buffer`.
    ///
    /// The first text chunk replaces the whole buffer and later chunks are
    /// appended. After each chunk the edit is reported to the action log, the
    /// agent location is moved to the end of the buffer, and an
    /// [`EditAgentOutputEvent::Edited`] spanning the whole buffer is sent. If
    /// no text chunk arrives at all, the buffer is cleared.
    ///
    /// Returns the first error received from the parser.
    async fn overwrite_with_chunks_internal(
        &self,
        buffer: Entity<Buffer>,
        mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
        output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        // Point the agent location at the start of the buffer before any text
        // has been written.
        let buffer_id = cx.update(|cx| {
            let buffer_id = buffer.read(cx).remote_id();
            self.project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: language::Anchor::min_for_buffer(buffer_id),
                    }),
                    cx,
                )
            });
            buffer_id
        });

        let send_edit_event = || {
            output_events_tx
                .unbounded_send(EditAgentOutputEvent::Edited(
                    Anchor::min_max_range_for_buffer(buffer_id),
                ))
                .ok()
        };
        let set_agent_location = |cx: &mut _| {
            self.project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: language::Anchor::max_for_buffer(buffer_id),
                    }),
                    cx,
                )
            })
        };
        // `mem::take` flips this to `false` the first time a chunk is applied,
        // so only the first chunk overwrites the previous contents.
        let mut first_chunk = true;
        while let Some(event) = parse_rx.next().await {
            match event? {
                CreateFileParserEvent::NewTextChunk { chunk } => {
                    cx.update(|cx| {
                        buffer.update(cx, |buffer, cx| {
                            if mem::take(&mut first_chunk) {
                                buffer.set_text(chunk, cx)
                            } else {
                                buffer.append(chunk, cx)
                            }
                        });
                        self.action_log
                            .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                        set_agent_location(cx);
                    });
                    send_edit_event();
                }
            }
        }

        // No text chunk was received: clear the previous contents so the
        // buffer still ends up overwritten (with nothing).
        if first_chunk {
            cx.update(|cx| {
                buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
                self.action_log
                    .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                set_agent_location(cx);
            });
            send_edit_event();
        }

        Ok(())
    }

    /// Edits `buffer` according to `edit_description` using the model.
    ///
    /// The prompt template is chosen by the agent's [`EditFormat`]. The model's
    /// streamed response is parsed into old/new text pairs that are applied
    /// incrementally by `apply_edit_chunks`. Returns the task resolving to the
    /// raw model output, plus a receiver of [`EditAgentOutputEvent`]s.
    pub fn edit(
        &self,
        buffer: Entity<Buffer>,
        edit_description: String,
        conversation: &LanguageModelRequest,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let this = self.clone();
        let (events_tx, events_rx) = mpsc::unbounded();
        let conversation = conversation.clone();
        let edit_format = self.edit_format;
        let output = cx.spawn(async move |cx| {
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
            let path = cx.update(|cx| snapshot.resolve_file_path(true, cx));
            let prompt = match edit_format {
                EditFormat::XmlTags => EditFileXmlPromptTemplate {
                    path,
                    edit_description,
                }
                .render(&this.templates)?,
                EditFormat::DiffFenced => EditFileDiffFencedPromptTemplate {
                    path,
                    edit_description,
                }
                .render(&this.templates)?,
            };

            let edit_chunks = this
                .request(conversation, CompletionIntent::EditFile, prompt, cx)
                .await?;
            this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
                .await
        });
        (output, events_rx)
    }

    /// Parses `edit_chunks` and applies each old/new text edit to `buffer` as
    /// it streams in.
    ///
    /// For every edit, the old text is resolved against a snapshot of the
    /// buffer with `StreamingFuzzyMatcher`, and
    /// [`EditAgentOutputEvent::ResolvingEditRange`] is sent as the candidate
    /// range is refined. The edit is skipped in two cases:
    /// - nothing matches: [`EditAgentOutputEvent::UnresolvedEditRange`] is sent;
    /// - several places match: [`EditAgentOutputEvent::AmbiguousEditRange`] is
    ///   sent with their 1-based line numbers.
    ///
    /// Otherwise the new text is diffed against the matched text and applied
    /// in batches of up to 32 edits. Each batch is followed by an
    /// [`EditAgentOutputEvent::Edited`] event.
    ///
    /// Returns the parser output once the stream ends, or the first error.
    async fn apply_edit_chunks(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        self.action_log
            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));

        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, self.edit_format, cx);
        let mut edit_events = edit_events.peekable();
        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
            // Skip events until we're at the start of a new edit. `peek` just
            // returned `Some`, so `next` yields that same event; parser errors
            // are propagated.
            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
                edit_events.next().await.unwrap()?;
                continue;
            };

            // Both old-text resolution and edit computation below work
            // against this snapshot.
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());

            // Resolve the old text in the background, updating the agent
            // location as we keep refining which range it corresponds to.
            // The event stream is moved into the task and handed back when
            // resolution finishes.
            let (resolve_old_text, mut old_range) =
                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
            while let Ok(old_range) = old_range.recv().await {
                if let Some(old_range) = old_range {
                    let old_range = snapshot.anchor_before(old_range.start)
                        ..snapshot.anchor_before(old_range.end);
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: old_range.end,
                            }),
                            cx,
                        );
                    });
                    output_events
                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
                        .ok();
                }
            }

            let (edit_events_, mut resolved_old_text) = resolve_old_text.await?;
            edit_events = edit_events_;

            // `resolve_old_text` yields exactly one entry when it found a best
            // match, and otherwise every candidate match (possibly none).
            // If we can't resolve the old text, restart the loop waiting for a
            // new edit (or for the stream to end).
            let resolved_old_text = match resolved_old_text.len() {
                1 => resolved_old_text.pop().unwrap(),
                0 => {
                    output_events
                        .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
                        .ok();
                    continue;
                }
                _ => {
                    let ranges = resolved_old_text
                        .into_iter()
                        .map(|text| {
                            let start_line =
                                (snapshot.offset_to_point(text.range.start).row + 1) as usize;
                            let end_line =
                                (snapshot.offset_to_point(text.range.end).row + 1) as usize;
                            start_line..end_line
                        })
                        .collect();
                    output_events
                        .unbounded_send(EditAgentOutputEvent::AmbiguousEditRange(ranges))
                        .ok();
                    continue;
                }
            };

            // Compute edits in the background and apply them as they become
            // available.
            let (compute_edits, edits) =
                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
            let mut edits = edits.ready_chunks(32);
            while let Some(edits) = edits.next().await {
                // The `unwrap`s on `max`/`min` below rely on `edits` being
                // non-empty.
                if edits.is_empty() {
                    continue;
                }

                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                let (min_edit_start, max_edit_end) = cx.update(|cx| {
                    let (min_edit_start, max_edit_end) = buffer.update(cx, |buffer, cx| {
                        buffer.edit(edits.iter().cloned(), None, cx);
                        // Span of the buffer touched by this batch, resolved
                        // after the edit has been applied.
                        let max_edit_end = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.end),
                            )
                            .max()
                            .unwrap();
                        let min_edit_start = buffer
                            .summaries_for_anchors::<Point, _>(
                                edits.iter().map(|(range, _)| &range.start),
                            )
                            .min()
                            .unwrap();
                        (
                            buffer.anchor_after(min_edit_start),
                            buffer.anchor_before(max_edit_end),
                        )
                    });
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
                    self.project.update(cx, |project, cx| {
                        project.set_agent_location(
                            Some(AgentLocation {
                                buffer: buffer.downgrade(),
                                position: max_edit_end,
                            }),
                            cx,
                        );
                    });
                    (min_edit_start, max_edit_end)
                });
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited(min_edit_start..max_edit_end))
                    .ok();
            }

            // Reclaim the event stream so the next edit can be processed.
            edit_events = compute_edits.await?;
        }

        output.await
    }

    /// Feeds raw model output through an `EditParser` on a background task.
    ///
    /// Parsed events are forwarded over the returned receiver. Completion
    /// errors are forwarded too, and parsing continues with the next chunk.
    /// The task resolves to the concatenated raw output plus the parser's
    /// metrics, or fails if forwarding to the receiver fails.
    fn parse_edit_chunks(
        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        edit_format: EditFormat,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        UnboundedReceiver<Result<EditParserEvent>>,
    ) {
        let (tx, rx) = mpsc::unbounded();
        let output = cx.background_spawn(async move {
            pin_mut!(chunks);

            let mut parser = EditParser::new(edit_format);
            let mut raw_edits = String::new();
            while let Some(chunk) = chunks.next().await {
                match chunk {
                    Ok(chunk) => {
                        raw_edits.push_str(&chunk);
                        for event in parser.push(&chunk) {
                            tx.unbounded_send(Ok(event))?;
                        }
                    }
                    Err(error) => {
                        tx.unbounded_send(Err(error.into()))?;
                    }
                }
            }
            Ok(EditAgentOutput {
                raw_edits,
                parser_metrics: parser.finish(),
            })
        });
        (output, rx)
    }

    /// Feeds raw model output through a `CreateFileParser` on a background
    /// task.
    ///
    /// Behaves like `parse_edit_chunks`, except for two differences:
    /// - once the input ends, the parser is pushed `None` and any resulting
    ///   events are forwarded;
    /// - the resulting output carries default parser metrics.
    fn parse_create_file_chunks(
        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        UnboundedReceiver<Result<CreateFileParserEvent>>,
    ) {
        let (tx, rx) = mpsc::unbounded();
        let output = cx.background_spawn(async move {
            pin_mut!(chunks);

            let mut parser = CreateFileParser::new();
            let mut raw_edits = String::new();
            while let Some(chunk) = chunks.next().await {
                match chunk {
                    Ok(chunk) => {
                        raw_edits.push_str(&chunk);
                        for event in parser.push(Some(&chunk)) {
                            tx.unbounded_send(Ok(event))?;
                        }
                    }
                    Err(error) => {
                        tx.unbounded_send(Err(error.into()))?;
                    }
                }
            }
            // Signal end of input with `None` and forward whatever events the
            // parser emits in response.
            for event in parser.push(None) {
                tx.unbounded_send(Ok(event))?;
            }
            Ok(EditAgentOutput {
                raw_edits,
                parser_metrics: EditParserMetrics::default(),
            })
        });
        (output, rx)
    }

    /// Matches the streamed old text of one edit against `snapshot` on a
    /// background task.
    ///
    /// `OldTextChunk` events are consumed from `edit_events` until one is
    /// marked `done`, or until a different event or the end of the stream is
    /// reached. After each chunk, the matcher's current candidate range is
    /// published on the returned watch receiver; the final best match is
    /// published at the end.
    ///
    /// The task hands `edit_events` back together with the resolved
    /// candidates: the single best match if there is one, otherwise every
    /// match found (possibly none). Each candidate carries the indentation of
    /// the first old-text line.
    fn resolve_old_text<T>(
        snapshot: TextBufferSnapshot,
        mut edit_events: T,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<(T, Vec<ResolvedOldText>)>>,
        watch::Receiver<Option<Range<usize>>>,
    )
    where
        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
    {
        let (mut old_range_tx, old_range_rx) = watch::channel(None);
        let task = cx.background_spawn(async move {
            let mut matcher = StreamingFuzzyMatcher::new(snapshot);
            while let Some(edit_event) = edit_events.next().await {
                // NOTE(review): an event that is not an `OldTextChunk` is
                // consumed and discarded here; this presumably relies on the
                // parser marking the last old-text chunk with `done` — confirm.
                let EditParserEvent::OldTextChunk {
                    chunk,
                    done,
                    line_hint,
                } = edit_event?
                else {
                    break;
                };

                old_range_tx.send(matcher.push(&chunk, line_hint))?;
                if done {
                    break;
                }
            }

            let matches = matcher.finish();
            let best_match = matcher.select_best_match();

            old_range_tx.send(best_match.clone())?;

            // Indentation of the model's first old-text line (empty if the
            // query had no lines); used later to reindent the new text.
            let indent = LineIndent::from_iter(
                matcher
                    .query_lines()
                    .first()
                    .unwrap_or(&String::new())
                    .chars(),
            );

            let resolved_old_texts = if let Some(best_match) = best_match {
                vec![ResolvedOldText {
                    range: best_match,
                    indent,
                }]
            } else {
                matches
                    .into_iter()
                    .map(|range| ResolvedOldText { range, indent })
                    .collect::<Vec<_>>()
            };

            Ok((edit_events, resolved_old_texts))
        });

        (task, old_range_rx)
    }

    /// Streams the edits that turn the matched old text into the model's new
    /// text, computed on a background task.
    ///
    /// The new text is reindented by the difference between two indents:
    /// - the buffer's indentation on the line where the match starts
    ///   (measured in tabs if that line uses tabs, spaces otherwise);
    /// - the indentation of the model's first old-text line.
    ///
    /// It is then diffed incrementally against the matched text with
    /// `StreamingDiff`. Each insert or delete is sent as a `(range, text)`
    /// edit anchored in `snapshot`. The task hands `edit_events` back once
    /// the new text ends.
    fn compute_edits<T>(
        snapshot: BufferSnapshot,
        resolved_old_text: ResolvedOldText,
        mut edit_events: T,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<T>>,
        UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
    )
    where
        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
    {
        let (edits_tx, edits_rx) = mpsc::unbounded();
        let compute_edits = cx.background_spawn(async move {
            let buffer_start_indent = snapshot
                .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
            let indent_delta = if buffer_start_indent.tabs > 0 {
                IndentDelta::Tabs(
                    buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
                )
            } else {
                IndentDelta::Spaces(
                    buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
                )
            };

            let old_text = snapshot
                .text_for_range(resolved_old_text.range.clone())
                .collect::<String>();
            let mut diff = StreamingDiff::new(old_text);
            // Offset into `snapshot` of the next unconsumed old-text byte.
            let mut edit_start = resolved_old_text.range.start;
            let mut new_text_chunks =
                Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
            let mut done = false;
            while !done {
                let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
                    diff.push_new(&new_text_chunk?)
                } else {
                    // New text ended: move the diff out so `finish` can
                    // consume it and flush the remaining operations.
                    done = true;
                    mem::take(&mut diff).finish()
                };

                for op in char_operations {
                    match op {
                        // Inserts consume no old text, so `edit_start` stays put.
                        CharOperation::Insert { text } => {
                            let edit_start = snapshot.anchor_after(edit_start);
                            edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
                        }
                        CharOperation::Delete { bytes } => {
                            let edit_end = edit_start + bytes;
                            let edit_range =
                                snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
                            edit_start = edit_end;
                            edits_tx.unbounded_send((edit_range, Arc::from("")))?;
                        }
                        CharOperation::Keep { bytes } => edit_start += bytes,
                    }
                }
            }

            // Release the `&mut edit_events` borrow held by the reindenting
            // stream before handing the events back.
            drop(new_text_chunks);
            anyhow::Ok(edit_events)
        });

        (compute_edits, edits_rx)
    }

    /// Adjusts the leading indentation of each streamed new-text line by
    /// `delta`.
    ///
    /// Only runs of `delta.character()` count as indentation, and the
    /// resulting indent length is clamped at zero. Lines made up entirely of
    /// that character are emitted unchanged. An incomplete line that is still
    /// in its leading indentation is held back until more input shows where
    /// the indentation ends. The final chunk (`done == true`) flushes any
    /// held-back text as-is.
    ///
    /// The stream ends after the final chunk, when the underlying stream ends,
    /// or when the underlying stream yields an event other than
    /// `EditParserEvent::NewTextChunk`. Parser errors are forwarded.
    /// NOTE(review): such a non-`NewTextChunk` event is consumed here rather
    /// than left for the caller, and held-back text is dropped if the
    /// underlying stream ends without a final chunk.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted: at most an incomplete trailing
        // line that is still in its leading indentation.
        let mut buffer = String::new();
        // Persists across chunks so a line split over chunks is reindented once.
        let mut in_leading_whitespace = true;
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found the first character that isn't
                            // `delta.character()`; adjust indentation based
                            // on the delta.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input; `start_ix`
                            // stays at the start of this line so it is kept in `buffer`.
                            break;
                        } else {
                            // This line is entirely whitespace. Push it without indentation.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        start_ix = line_end;
                        break;
                    } else {
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                // Keep polling when nothing could be emitted yet (e.g. the
                // chunk was only leading indentation).
                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }

    /// Sends `prompt` to the model as a new user message appended to
    /// `conversation`, and returns the streamed completion text.
    ///
    /// If the conversation ends with an assistant message, its tool uses are
    /// removed, and the message is dropped if nothing remains. When tool uses
    /// were removed from a cached message, the cache flag moves to the message
    /// before it.
    ///
    /// The conversation's tools are sent with `ToolChoice::None` when the
    /// model supports that choice; otherwise no tools are sent.
    async fn request(
        &self,
        mut conversation: LanguageModelRequest,
        intent: CompletionIntent,
        prompt: String,
        cx: &mut AsyncApp,
    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
        let mut messages_iter = conversation.messages.iter_mut();
        if let Some(last_message) = messages_iter.next_back()
            && last_message.role == Role::Assistant
        {
            let old_content_len = last_message.content.len();
            last_message
                .content
                .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
            let new_content_len = last_message.content.len();

            // We just removed pending tool uses from the content of the
            // last message, so it doesn't make sense to cache it anymore
            // (e.g., the message will look very different on the next
            // request). Thus, we move the flag to the message prior to it,
            // as it will still be a valid prefix of the conversation.
            if old_content_len != new_content_len
                && last_message.cache
                && let Some(prev_message) = messages_iter.next_back()
            {
                last_message.cache = false;
                prev_message.cache = true;
            }

            // Drop the assistant message if it has no content left.
            if last_message.content.is_empty() {
                conversation.messages.pop();
            }
        }

        conversation.messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text(prompt)],
            cache: false,
            reasoning_details: None,
        });

        // Include tools in the request so that we can take advantage of
        // caching when ToolChoice::None is supported.
        let mut tool_choice = None;
        let mut tools = Vec::new();
        if !conversation.tools.is_empty()
            && self
                .model
                .supports_tool_choice(LanguageModelToolChoice::None)
        {
            tool_choice = Some(LanguageModelToolChoice::None);
            tools = conversation.tools.clone();
        }

        let request = LanguageModelRequest {
            thread_id: conversation.thread_id,
            prompt_id: conversation.prompt_id,
            intent: Some(intent),
            messages: conversation.messages,
            tool_choice,
            tools,
            stop: Vec::new(),
            temperature: None,
            thinking_allowed: self.thinking_allowed,
            thinking_effort: None,
        };

        Ok(self.model.stream_completion_text(request, cx).await?.stream)
    }
}
757
/// A place in the buffer that the model's old text was matched to.
struct ResolvedOldText {
    /// Offset range of the match within the buffer snapshot it was resolved
    /// against.
    range: Range<usize>,
    /// Indentation of the first line of the model's old text; compared with
    /// the buffer's indentation to reindent the new text.
    indent: LineIndent,
}
762
/// Signed indentation adjustment applied to the leading whitespace of each
/// line of new text, counted in a single indentation character.
#[derive(Copy, Clone, Debug)]
enum IndentDelta {
    /// Add (positive) or remove (negative) this many spaces.
    Spaces(isize),
    /// Add (positive) or remove (negative) this many tabs.
    Tabs(isize),
}

impl IndentDelta {
    /// Indentation character this delta is counted in.
    fn character(&self) -> char {
        match *self {
            Self::Tabs(_) => '\t',
            Self::Spaces(_) => ' ',
        }
    }

    /// Signed number of indentation characters to add (or, if negative,
    /// remove).
    fn len(&self) -> isize {
        let (Self::Spaces(count) | Self::Tabs(count)) = *self;
        count
    }
}
784
785#[cfg(test)]
786mod tests {
787 use super::*;
788 use fs::FakeFs;
789 use futures::stream;
790 use gpui::{AppContext, TestAppContext};
791 use indoc::indoc;
792 use language_model::fake_provider::FakeLanguageModel;
793 use pretty_assertions::assert_matches;
794 use project::{AgentLocation, Project};
795 use rand::prelude::*;
796 use rand::rngs::StdRng;
797 use std::cmp;
798
799 #[gpui::test(iterations = 100)]
800 async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
801 let agent = init_test(cx).await;
802 let buffer = cx.new(|cx| {
803 Buffer::local(
804 indoc! {"
805 abc
806 def
807 ghi
808 "},
809 cx,
810 )
811 });
812 let (apply, _events) = agent.edit(
813 buffer.clone(),
814 String::new(),
815 &LanguageModelRequest::default(),
816 &mut cx.to_async(),
817 );
818 cx.run_until_parked();
819
820 simulate_llm_output(
821 &agent,
822 indoc! {"
823 <old_text></old_text>
824 <new_text>jkl</new_text>
825 <old_text>def</old_text>
826 <new_text>DEF</new_text>
827 "},
828 &mut rng,
829 cx,
830 );
831 apply.await.unwrap();
832
833 pretty_assertions::assert_eq!(
834 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
835 indoc! {"
836 abc
837 DEF
838 ghi
839 "}
840 );
841 }
842
843 #[gpui::test(iterations = 100)]
844 async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
845 let agent = init_test(cx).await;
846 let buffer = cx.new(|cx| {
847 Buffer::local(
848 indoc! {"
849 lorem
850 ipsum
851 dolor
852 sit
853 "},
854 cx,
855 )
856 });
857 let (apply, _events) = agent.edit(
858 buffer.clone(),
859 String::new(),
860 &LanguageModelRequest::default(),
861 &mut cx.to_async(),
862 );
863 cx.run_until_parked();
864
865 simulate_llm_output(
866 &agent,
867 indoc! {"
868 <old_text>
869 ipsum
870 dolor
871 sit
872 </old_text>
873 <new_text>
874 ipsum
875 dolor
876 sit
877 amet
878 </new_text>
879 "},
880 &mut rng,
881 cx,
882 );
883 apply.await.unwrap();
884
885 pretty_assertions::assert_eq!(
886 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
887 indoc! {"
888 lorem
889 ipsum
890 dolor
891 sit
892 amet
893 "}
894 );
895 }
896
897 #[gpui::test(iterations = 100)]
898 async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
899 let agent = init_test(cx).await;
900 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
901 let (apply, _events) = agent.edit(
902 buffer.clone(),
903 String::new(),
904 &LanguageModelRequest::default(),
905 &mut cx.to_async(),
906 );
907 cx.run_until_parked();
908
909 simulate_llm_output(
910 &agent,
911 indoc! {"
912 <old_text>
913 def
914 </old_text>
915 <new_text>
916 DEF
917 </new_text>
918
919 <old_text>
920 DEF
921 </old_text>
922 <new_text>
923 DeF
924 </new_text>
925 "},
926 &mut rng,
927 cx,
928 );
929 apply.await.unwrap();
930
931 assert_eq!(
932 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
933 "abc\nDeF\nghi"
934 );
935 }
936
937 #[gpui::test(iterations = 100)]
938 async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
939 let agent = init_test(cx).await;
940 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
941 let (apply, _events) = agent.edit(
942 buffer.clone(),
943 String::new(),
944 &LanguageModelRequest::default(),
945 &mut cx.to_async(),
946 );
947 cx.run_until_parked();
948
949 simulate_llm_output(
950 &agent,
951 indoc! {"
952 <old_text>
953 jkl
954 </old_text>
955 <new_text>
956 mno
957 </new_text>
958
959 <old_text>
960 abc
961 </old_text>
962 <new_text>
963 ABC
964 </new_text>
965 "},
966 &mut rng,
967 cx,
968 );
969 apply.await.unwrap();
970
971 assert_eq!(
972 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
973 "ABC\ndef\nghi"
974 );
975 }
976
977 #[gpui::test]
978 async fn test_edit_events(cx: &mut TestAppContext) {
979 let agent = init_test(cx).await;
980 let model = agent.model.as_fake();
981 let project = agent
982 .action_log
983 .read_with(cx, |log, _| log.project().clone());
984 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
985
986 let mut async_cx = cx.to_async();
987 let (apply, mut events) = agent.edit(
988 buffer.clone(),
989 String::new(),
990 &LanguageModelRequest::default(),
991 &mut async_cx,
992 );
993 cx.run_until_parked();
994
995 model.send_last_completion_stream_text_chunk("<old_text>a");
996 cx.run_until_parked();
997 assert_eq!(drain_events(&mut events), vec![]);
998 assert_eq!(
999 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1000 "abc\ndef\nghi\njkl"
1001 );
1002 assert_eq!(
1003 project.read_with(cx, |project, _| project.agent_location()),
1004 None
1005 );
1006
1007 model.send_last_completion_stream_text_chunk("bc</old_text>");
1008 cx.run_until_parked();
1009 assert_eq!(
1010 drain_events(&mut events),
1011 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1012 cx,
1013 |buffer, _| buffer.anchor_before(Point::new(0, 0))
1014 ..buffer.anchor_before(Point::new(0, 3))
1015 ))]
1016 );
1017 assert_eq!(
1018 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1019 "abc\ndef\nghi\njkl"
1020 );
1021 assert_eq!(
1022 project.read_with(cx, |project, _| project.agent_location()),
1023 Some(AgentLocation {
1024 buffer: buffer.downgrade(),
1025 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
1026 })
1027 );
1028
1029 model.send_last_completion_stream_text_chunk("<new_text>abX");
1030 cx.run_until_parked();
1031 assert_matches!(
1032 drain_events(&mut events).as_slice(),
1033 [EditAgentOutputEvent::Edited(_)]
1034 );
1035 assert_eq!(
1036 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1037 "abXc\ndef\nghi\njkl"
1038 );
1039 assert_eq!(
1040 project.read_with(cx, |project, _| project.agent_location()),
1041 Some(AgentLocation {
1042 buffer: buffer.downgrade(),
1043 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
1044 })
1045 );
1046
1047 model.send_last_completion_stream_text_chunk("cY");
1048 cx.run_until_parked();
1049 assert_matches!(
1050 drain_events(&mut events).as_slice(),
1051 [EditAgentOutputEvent::Edited { .. }]
1052 );
1053 assert_eq!(
1054 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1055 "abXcY\ndef\nghi\njkl"
1056 );
1057 assert_eq!(
1058 project.read_with(cx, |project, _| project.agent_location()),
1059 Some(AgentLocation {
1060 buffer: buffer.downgrade(),
1061 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1062 })
1063 );
1064
1065 model.send_last_completion_stream_text_chunk("</new_text>");
1066 model.send_last_completion_stream_text_chunk("<old_text>hall");
1067 cx.run_until_parked();
1068 assert_eq!(drain_events(&mut events), vec![]);
1069 assert_eq!(
1070 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1071 "abXcY\ndef\nghi\njkl"
1072 );
1073 assert_eq!(
1074 project.read_with(cx, |project, _| project.agent_location()),
1075 Some(AgentLocation {
1076 buffer: buffer.downgrade(),
1077 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1078 })
1079 );
1080
1081 model.send_last_completion_stream_text_chunk("ucinated old</old_text>");
1082 model.send_last_completion_stream_text_chunk("<new_text>");
1083 cx.run_until_parked();
1084 assert_eq!(
1085 drain_events(&mut events),
1086 vec![EditAgentOutputEvent::UnresolvedEditRange]
1087 );
1088 assert_eq!(
1089 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1090 "abXcY\ndef\nghi\njkl"
1091 );
1092 assert_eq!(
1093 project.read_with(cx, |project, _| project.agent_location()),
1094 Some(AgentLocation {
1095 buffer: buffer.downgrade(),
1096 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1097 })
1098 );
1099
1100 model.send_last_completion_stream_text_chunk("hallucinated new</new_");
1101 model.send_last_completion_stream_text_chunk("text>");
1102 cx.run_until_parked();
1103 assert_eq!(drain_events(&mut events), vec![]);
1104 assert_eq!(
1105 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1106 "abXcY\ndef\nghi\njkl"
1107 );
1108 assert_eq!(
1109 project.read_with(cx, |project, _| project.agent_location()),
1110 Some(AgentLocation {
1111 buffer: buffer.downgrade(),
1112 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1113 })
1114 );
1115
1116 model.send_last_completion_stream_text_chunk("<old_text>\nghi\nj");
1117 cx.run_until_parked();
1118 assert_eq!(
1119 drain_events(&mut events),
1120 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1121 cx,
1122 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1123 ..buffer.anchor_before(Point::new(2, 3))
1124 ))]
1125 );
1126 assert_eq!(
1127 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1128 "abXcY\ndef\nghi\njkl"
1129 );
1130 assert_eq!(
1131 project.read_with(cx, |project, _| project.agent_location()),
1132 Some(AgentLocation {
1133 buffer: buffer.downgrade(),
1134 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1135 })
1136 );
1137
1138 model.send_last_completion_stream_text_chunk("kl</old_text>");
1139 model.send_last_completion_stream_text_chunk("<new_text>");
1140 cx.run_until_parked();
1141 assert_eq!(
1142 drain_events(&mut events),
1143 vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1144 cx,
1145 |buffer, _| buffer.anchor_before(Point::new(2, 0))
1146 ..buffer.anchor_before(Point::new(3, 3))
1147 ))]
1148 );
1149 assert_eq!(
1150 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1151 "abXcY\ndef\nghi\njkl"
1152 );
1153 assert_eq!(
1154 project.read_with(cx, |project, _| project.agent_location()),
1155 Some(AgentLocation {
1156 buffer: buffer.downgrade(),
1157 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1158 })
1159 );
1160
1161 model.send_last_completion_stream_text_chunk("GHI</new_text>");
1162 cx.run_until_parked();
1163 assert_matches!(
1164 drain_events(&mut events).as_slice(),
1165 [EditAgentOutputEvent::Edited { .. }]
1166 );
1167 assert_eq!(
1168 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1169 "abXcY\ndef\nGHI"
1170 );
1171 assert_eq!(
1172 project.read_with(cx, |project, _| project.agent_location()),
1173 Some(AgentLocation {
1174 buffer: buffer.downgrade(),
1175 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1176 })
1177 );
1178
1179 model.end_last_completion_stream();
1180 apply.await.unwrap();
1181 assert_eq!(
1182 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1183 "abXcY\ndef\nGHI"
1184 );
1185 assert_eq!(drain_events(&mut events), vec![]);
1186 assert_eq!(
1187 project.read_with(cx, |project, _| project.agent_location()),
1188 Some(AgentLocation {
1189 buffer: buffer.downgrade(),
1190 position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1191 })
1192 );
1193 }
1194
1195 #[gpui::test]
1196 async fn test_overwrite_events(cx: &mut TestAppContext) {
1197 let agent = init_test(cx).await;
1198 let project = agent
1199 .action_log
1200 .read_with(cx, |log, _| log.project().clone());
1201 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1202 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1203 let (apply, mut events) = agent.overwrite_with_chunks(
1204 buffer.clone(),
1205 chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1206 &mut cx.to_async(),
1207 );
1208
1209 cx.run_until_parked();
1210 assert_eq!(drain_events(&mut events).as_slice(), []);
1211 assert_eq!(
1212 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1213 "abc\ndef\nghi"
1214 );
1215 assert_eq!(
1216 project.read_with(cx, |project, _| project.agent_location()),
1217 Some(AgentLocation {
1218 buffer: buffer.downgrade(),
1219 position: language::Anchor::min_for_buffer(
1220 cx.update(|cx| buffer.read(cx).remote_id())
1221 ),
1222 })
1223 );
1224
1225 chunks_tx.unbounded_send("```\njkl\n").unwrap();
1226 cx.run_until_parked();
1227 assert_matches!(
1228 drain_events(&mut events).as_slice(),
1229 [EditAgentOutputEvent::Edited { .. }]
1230 );
1231 assert_eq!(
1232 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1233 "jkl"
1234 );
1235 assert_eq!(
1236 project.read_with(cx, |project, _| project.agent_location()),
1237 Some(AgentLocation {
1238 buffer: buffer.downgrade(),
1239 position: language::Anchor::max_for_buffer(
1240 cx.update(|cx| buffer.read(cx).remote_id())
1241 ),
1242 })
1243 );
1244
1245 chunks_tx.unbounded_send("mno\n").unwrap();
1246 cx.run_until_parked();
1247 assert_matches!(
1248 drain_events(&mut events).as_slice(),
1249 [EditAgentOutputEvent::Edited { .. }]
1250 );
1251 assert_eq!(
1252 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1253 "jkl\nmno"
1254 );
1255 assert_eq!(
1256 project.read_with(cx, |project, _| project.agent_location()),
1257 Some(AgentLocation {
1258 buffer: buffer.downgrade(),
1259 position: language::Anchor::max_for_buffer(
1260 cx.update(|cx| buffer.read(cx).remote_id())
1261 ),
1262 })
1263 );
1264
1265 chunks_tx.unbounded_send("pqr\n```").unwrap();
1266 cx.run_until_parked();
1267 assert_matches!(
1268 drain_events(&mut events).as_slice(),
1269 [EditAgentOutputEvent::Edited(_)],
1270 );
1271 assert_eq!(
1272 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1273 "jkl\nmno\npqr"
1274 );
1275 assert_eq!(
1276 project.read_with(cx, |project, _| project.agent_location()),
1277 Some(AgentLocation {
1278 buffer: buffer.downgrade(),
1279 position: language::Anchor::max_for_buffer(
1280 cx.update(|cx| buffer.read(cx).remote_id())
1281 ),
1282 })
1283 );
1284
1285 drop(chunks_tx);
1286 apply.await.unwrap();
1287 assert_eq!(
1288 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1289 "jkl\nmno\npqr"
1290 );
1291 assert_eq!(drain_events(&mut events), vec![]);
1292 assert_eq!(
1293 project.read_with(cx, |project, _| project.agent_location()),
1294 Some(AgentLocation {
1295 buffer: buffer.downgrade(),
1296 position: language::Anchor::max_for_buffer(
1297 cx.update(|cx| buffer.read(cx).remote_id())
1298 ),
1299 })
1300 );
1301 }
1302
1303 #[gpui::test]
1304 async fn test_overwrite_no_content(cx: &mut TestAppContext) {
1305 let agent = init_test(cx).await;
1306 let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1307 let (chunks_tx, chunks_rx) = mpsc::unbounded::<&str>();
1308 let (apply, mut events) = agent.overwrite_with_chunks(
1309 buffer.clone(),
1310 chunks_rx.map(|chunk| Ok(chunk.to_string())),
1311 &mut cx.to_async(),
1312 );
1313
1314 drop(chunks_tx);
1315 cx.run_until_parked();
1316
1317 let result = apply.await;
1318 assert!(result.is_ok(),);
1319 assert_matches!(
1320 drain_events(&mut events).as_slice(),
1321 [EditAgentOutputEvent::Edited { .. }]
1322 );
1323 assert_eq!(
1324 buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1325 ""
1326 );
1327 }
1328
#[gpui::test(iterations = 100)]
async fn test_indent_new_text_chunks(mut rng: StdRng) {
    // Split the input into randomly sized chunks so reindentation is exercised across
    // arbitrary chunk boundaries; only the final chunk is flagged as `done`.
    let chunks = to_random_chunks(&mut rng, " abc\n def\n ghi");
    let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
        Ok(EditParserEvent::NewTextChunk {
            chunk: chunk.clone(),
            done: index == chunks.len() - 1,
        })
    }));
    // Apply a +2 space indent delta to the streamed chunks and gather every output item.
    let indented_chunks =
        EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
            .collect::<Vec<_>>()
            .await;
    // Reassemble the output into a single string, failing the test on any error item.
    let new_text = indented_chunks
        .into_iter()
        .collect::<Result<String>>()
        .unwrap();
    assert_eq!(new_text, " abc\n def\n ghi");
}
1348
1349 #[gpui::test(iterations = 100)]
1350 async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1351 let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1352 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1353 Ok(EditParserEvent::NewTextChunk {
1354 chunk: chunk.clone(),
1355 done: index == chunks.len() - 1,
1356 })
1357 }));
1358 let indented_chunks =
1359 EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1360 .collect::<Vec<_>>()
1361 .await;
1362 let new_text = indented_chunks
1363 .into_iter()
1364 .collect::<Result<String>>()
1365 .unwrap();
1366 assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1367 }
1368
1369 #[gpui::test(iterations = 100)]
1370 async fn test_random_indents(mut rng: StdRng) {
1371 let len = rng.random_range(1..=100);
1372 let new_text = util::RandomCharIter::new(&mut rng)
1373 .with_simple_text()
1374 .take(len)
1375 .collect::<String>();
1376 let new_text = new_text
1377 .split('\n')
1378 .map(|line| format!("{}{}", " ".repeat(rng.random_range(0..=8)), line))
1379 .collect::<Vec<_>>()
1380 .join("\n");
1381 let delta = IndentDelta::Spaces(rng.random_range(-4i8..=4i8) as isize);
1382
1383 let chunks = to_random_chunks(&mut rng, &new_text);
1384 let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1385 Ok(EditParserEvent::NewTextChunk {
1386 chunk: chunk.clone(),
1387 done: index == chunks.len() - 1,
1388 })
1389 }));
1390 let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1391 .collect::<Vec<_>>()
1392 .await;
1393 let actual_reindented_text = reindented_chunks
1394 .into_iter()
1395 .collect::<Result<String>>()
1396 .unwrap();
1397 let expected_reindented_text = new_text
1398 .split('\n')
1399 .map(|line| {
1400 if let Some(ix) = line.find(|c| c != ' ') {
1401 let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1402 format!("{}{}", " ".repeat(new_indent), &line[ix..])
1403 } else {
1404 line.to_string()
1405 }
1406 })
1407 .collect::<Vec<_>>()
1408 .join("\n");
1409 assert_eq!(actual_reindented_text, expected_reindented_text);
1410 }
1411
1412 fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1413 let chunk_count = rng.random_range(1..=cmp::min(input.len(), 50));
1414 let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1415 chunk_indices.sort();
1416 chunk_indices.push(input.len());
1417
1418 let mut chunks = Vec::new();
1419 let mut last_ix = 0;
1420 for chunk_ix in chunk_indices {
1421 chunks.push(input[last_ix..chunk_ix].to_string());
1422 last_ix = chunk_ix;
1423 }
1424 chunks
1425 }
1426
1427 fn simulate_llm_output(
1428 agent: &EditAgent,
1429 output: &str,
1430 rng: &mut StdRng,
1431 cx: &mut TestAppContext,
1432 ) {
1433 let executor = cx.executor();
1434 let chunks = to_random_chunks(rng, output);
1435 let model = agent.model.clone();
1436 cx.background_spawn(async move {
1437 for chunk in chunks {
1438 executor.simulate_random_delay().await;
1439 model
1440 .as_fake()
1441 .send_last_completion_stream_text_chunk(chunk);
1442 }
1443 model.as_fake().end_last_completion_stream();
1444 })
1445 .detach();
1446 }
1447
1448 async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1449 init_test_with_thinking(cx, true).await
1450 }
1451
1452 async fn init_test_with_thinking(cx: &mut TestAppContext, thinking_allowed: bool) -> EditAgent {
1453 cx.update(settings::init);
1454
1455 let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1456 let model = Arc::new(FakeLanguageModel::default());
1457 let action_log = cx.new(|_| ActionLog::new(project.clone()));
1458 EditAgent::new(
1459 model,
1460 project,
1461 action_log,
1462 Templates::new(),
1463 EditFormat::XmlTags,
1464 thinking_allowed,
1465 )
1466 }
1467
1468 #[gpui::test(iterations = 10)]
1469 async fn test_non_unique_text_error(cx: &mut TestAppContext, mut rng: StdRng) {
1470 let agent = init_test(cx).await;
1471 let original_text = indoc! {"
1472 function foo() {
1473 return 42;
1474 }
1475
1476 function bar() {
1477 return 42;
1478 }
1479
1480 function baz() {
1481 return 42;
1482 }
1483 "};
1484 let buffer = cx.new(|cx| Buffer::local(original_text, cx));
1485 let (apply, mut events) = agent.edit(
1486 buffer.clone(),
1487 String::new(),
1488 &LanguageModelRequest::default(),
1489 &mut cx.to_async(),
1490 );
1491 cx.run_until_parked();
1492
1493 // When <old_text> matches text in more than one place
1494 simulate_llm_output(
1495 &agent,
1496 indoc! {"
1497 <old_text>
1498 return 42;
1499 }
1500 </old_text>
1501 <new_text>
1502 return 100;
1503 }
1504 </new_text>
1505 "},
1506 &mut rng,
1507 cx,
1508 );
1509 apply.await.unwrap();
1510
1511 // Then the text should remain unchanged
1512 let result_text = buffer.read_with(cx, |buffer, _| buffer.snapshot().text());
1513 assert_eq!(
1514 result_text,
1515 indoc! {"
1516 function foo() {
1517 return 42;
1518 }
1519
1520 function bar() {
1521 return 42;
1522 }
1523
1524 function baz() {
1525 return 42;
1526 }
1527 "},
1528 "Text should remain unchanged when there are multiple matches"
1529 );
1530
1531 // And AmbiguousEditRange even should be emitted
1532 let events = drain_events(&mut events);
1533 let ambiguous_ranges = vec![2..3, 6..7, 10..11];
1534 assert!(
1535 events.contains(&EditAgentOutputEvent::AmbiguousEditRange(ambiguous_ranges)),
1536 "Should emit AmbiguousEditRange for non-unique text"
1537 );
1538 }
1539
1540 #[gpui::test]
1541 async fn test_thinking_allowed_forwarded_to_request(cx: &mut TestAppContext) {
1542 let agent = init_test_with_thinking(cx, false).await;
1543 let buffer = cx.new(|cx| Buffer::local("hello\n", cx));
1544 let (_apply, _events) = agent.edit(
1545 buffer.clone(),
1546 String::new(),
1547 &LanguageModelRequest::default(),
1548 &mut cx.to_async(),
1549 );
1550 cx.run_until_parked();
1551
1552 let pending = agent.model.as_fake().pending_completions();
1553 assert_eq!(pending.len(), 1);
1554 assert!(
1555 !pending[0].thinking_allowed,
1556 "Expected thinking_allowed to be false when EditAgent is constructed with allow_thinking=false"
1557 );
1558 agent.model.as_fake().end_last_completion_stream();
1559
1560 let agent = init_test_with_thinking(cx, true).await;
1561 let buffer = cx.new(|cx| Buffer::local("hello\n", cx));
1562 let (_apply, _events) = agent.edit(
1563 buffer,
1564 String::new(),
1565 &LanguageModelRequest::default(),
1566 &mut cx.to_async(),
1567 );
1568 cx.run_until_parked();
1569
1570 let pending = agent.model.as_fake().pending_completions();
1571 assert_eq!(pending.len(), 1);
1572 assert!(
1573 pending[0].thinking_allowed,
1574 "Expected thinking_allowed to be true when EditAgent is constructed with allow_thinking=true"
1575 );
1576 agent.model.as_fake().end_last_completion_stream();
1577 }
1578
1579 fn drain_events(
1580 stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1581 ) -> Vec<EditAgentOutputEvent> {
1582 let mut events = Vec::new();
1583 while let Ok(Some(event)) = stream.try_next() {
1584 events.push(event);
1585 }
1586 events
1587 }
1588}