edit_agent.rs

   1mod create_file_parser;
   2mod edit_parser;
   3#[cfg(test)]
   4mod evals;
   5mod streaming_fuzzy_matcher;
   6
   7use crate::{Template, Templates};
   8use anyhow::Result;
   9use assistant_tool::ActionLog;
  10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
  11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
  12use futures::{
  13    Stream, StreamExt,
  14    channel::mpsc::{self, UnboundedReceiver},
  15    pin_mut,
  16    stream::BoxStream,
  17};
  18use gpui::{AppContext, AsyncApp, Entity, Task};
  19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
  20use language_model::{
  21    LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelToolChoice, MessageContent, Role,
  23};
  24use project::{AgentLocation, Project};
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
  28use streaming_diff::{CharOperation, StreamingDiff};
  29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
  30use util::debug_panic;
  31
  32#[derive(Serialize)]
  33struct CreateFilePromptTemplate {
  34    path: Option<PathBuf>,
  35    edit_description: String,
  36}
  37
  38impl Template for CreateFilePromptTemplate {
  39    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
  40}
  41
  42#[derive(Serialize)]
  43struct EditFilePromptTemplate {
  44    path: Option<PathBuf>,
  45    edit_description: String,
  46}
  47
  48impl Template for EditFilePromptTemplate {
  49    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
  50}
  51
  52#[derive(Clone, Debug, PartialEq, Eq)]
  53pub enum EditAgentOutputEvent {
  54    ResolvingEditRange(Range<Anchor>),
  55    UnresolvedEditRange,
  56    Edited,
  57}
  58
  59#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  60pub struct EditAgentOutput {
  61    pub raw_edits: String,
  62    pub parser_metrics: EditParserMetrics,
  63}
  64
  65#[derive(Clone)]
  66pub struct EditAgent {
  67    model: Arc<dyn LanguageModel>,
  68    action_log: Entity<ActionLog>,
  69    project: Entity<Project>,
  70    templates: Arc<Templates>,
  71}
  72
  73impl EditAgent {
  74    pub fn new(
  75        model: Arc<dyn LanguageModel>,
  76        project: Entity<Project>,
  77        action_log: Entity<ActionLog>,
  78        templates: Arc<Templates>,
  79    ) -> Self {
  80        EditAgent {
  81            model,
  82            project,
  83            action_log,
  84            templates,
  85        }
  86    }
  87
  88    pub fn overwrite(
  89        &self,
  90        buffer: Entity<Buffer>,
  91        edit_description: String,
  92        conversation: &LanguageModelRequest,
  93        cx: &mut AsyncApp,
  94    ) -> (
  95        Task<Result<EditAgentOutput>>,
  96        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
  97    ) {
  98        let this = self.clone();
  99        let (events_tx, events_rx) = mpsc::unbounded();
 100        let conversation = conversation.clone();
 101        let output = cx.spawn(async move |cx| {
 102            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 103            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
 104            let prompt = CreateFilePromptTemplate {
 105                path,
 106                edit_description,
 107            }
 108            .render(&this.templates)?;
 109            let new_chunks = this.request(conversation, prompt, cx).await?;
 110
 111            let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
 112            while let Some(event) = inner_events.next().await {
 113                events_tx.unbounded_send(event).ok();
 114            }
 115            output.await
 116        });
 117        (output, events_rx)
 118    }
 119
 120    fn overwrite_with_chunks(
 121        &self,
 122        buffer: Entity<Buffer>,
 123        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 124        cx: &mut AsyncApp,
 125    ) -> (
 126        Task<Result<EditAgentOutput>>,
 127        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 128    ) {
 129        let (output_events_tx, output_events_rx) = mpsc::unbounded();
 130        let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
 131        let this = self.clone();
 132        let task = cx.spawn(async move |cx| {
 133            this.action_log
 134                .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
 135            this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
 136                .await?;
 137            parse_task.await
 138        });
 139        (task, output_events_rx)
 140    }
 141
 142    async fn overwrite_with_chunks_internal(
 143        &self,
 144        buffer: Entity<Buffer>,
 145        mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
 146        output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
 147        cx: &mut AsyncApp,
 148    ) -> Result<()> {
 149        cx.update(|cx| {
 150            buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
 151            self.action_log.update(cx, |log, cx| {
 152                log.buffer_edited(buffer.clone(), cx);
 153            });
 154            self.project.update(cx, |project, cx| {
 155                project.set_agent_location(
 156                    Some(AgentLocation {
 157                        buffer: buffer.downgrade(),
 158                        position: language::Anchor::MAX,
 159                    }),
 160                    cx,
 161                )
 162            });
 163            output_events_tx
 164                .unbounded_send(EditAgentOutputEvent::Edited)
 165                .ok();
 166        })?;
 167
 168        while let Some(event) = parse_rx.next().await {
 169            match event? {
 170                CreateFileParserEvent::NewTextChunk { chunk } => {
 171                    cx.update(|cx| {
 172                        buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
 173                        self.action_log
 174                            .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 175                        self.project.update(cx, |project, cx| {
 176                            project.set_agent_location(
 177                                Some(AgentLocation {
 178                                    buffer: buffer.downgrade(),
 179                                    position: language::Anchor::MAX,
 180                                }),
 181                                cx,
 182                            )
 183                        });
 184                    })?;
 185                    output_events_tx
 186                        .unbounded_send(EditAgentOutputEvent::Edited)
 187                        .ok();
 188                }
 189            }
 190        }
 191
 192        Ok(())
 193    }
 194
 195    pub fn edit(
 196        &self,
 197        buffer: Entity<Buffer>,
 198        edit_description: String,
 199        conversation: &LanguageModelRequest,
 200        cx: &mut AsyncApp,
 201    ) -> (
 202        Task<Result<EditAgentOutput>>,
 203        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 204    ) {
 205        let this = self.clone();
 206        let (events_tx, events_rx) = mpsc::unbounded();
 207        let conversation = conversation.clone();
 208        let output = cx.spawn(async move |cx| {
 209            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 210            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
 211            let prompt = EditFilePromptTemplate {
 212                path,
 213                edit_description,
 214            }
 215            .render(&this.templates)?;
 216            let edit_chunks = this.request(conversation, prompt, cx).await?;
 217            this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
 218                .await
 219        });
 220        (output, events_rx)
 221    }
 222
 223    async fn apply_edit_chunks(
 224        &self,
 225        buffer: Entity<Buffer>,
 226        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 227        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
 228        cx: &mut AsyncApp,
 229    ) -> Result<EditAgentOutput> {
 230        self.action_log
 231            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;
 232
 233        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
 234        let mut edit_events = edit_events.peekable();
 235        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
 236            // Skip events until we're at the start of a new edit.
 237            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
 238                edit_events.next().await.unwrap()?;
 239                continue;
 240            };
 241
 242            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 243
 244            // Resolve the old text in the background, updating the agent
 245            // location as we keep refining which range it corresponds to.
 246            let (resolve_old_text, mut old_range) =
 247                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
 248            while let Ok(old_range) = old_range.recv().await {
 249                if let Some(old_range) = old_range {
 250                    let old_range = snapshot.anchor_before(old_range.start)
 251                        ..snapshot.anchor_before(old_range.end);
 252                    self.project.update(cx, |project, cx| {
 253                        project.set_agent_location(
 254                            Some(AgentLocation {
 255                                buffer: buffer.downgrade(),
 256                                position: old_range.end,
 257                            }),
 258                            cx,
 259                        );
 260                    })?;
 261                    output_events
 262                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
 263                        .ok();
 264                }
 265            }
 266
 267            let (edit_events_, resolved_old_text) = resolve_old_text.await?;
 268            edit_events = edit_events_;
 269
 270            // If we can't resolve the old text, restart the loop waiting for a
 271            // new edit (or for the stream to end).
 272            let Some(resolved_old_text) = resolved_old_text else {
 273                output_events
 274                    .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
 275                    .ok();
 276                continue;
 277            };
 278
 279            // Compute edits in the background and apply them as they become
 280            // available.
 281            let (compute_edits, edits) =
 282                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
 283            let mut edits = edits.ready_chunks(32);
 284            while let Some(edits) = edits.next().await {
 285                if edits.is_empty() {
 286                    continue;
 287                }
 288
 289                // Edit the buffer and report edits to the action log as part of the
 290                // same effect cycle, otherwise the edit will be reported as if the
 291                // user made it.
 292                cx.update(|cx| {
 293                    let max_edit_end = buffer.update(cx, |buffer, cx| {
 294                        buffer.edit(edits.iter().cloned(), None, cx);
 295                        let max_edit_end = buffer
 296                            .summaries_for_anchors::<Point, _>(
 297                                edits.iter().map(|(range, _)| &range.end),
 298                            )
 299                            .max()
 300                            .unwrap();
 301                        buffer.anchor_before(max_edit_end)
 302                    });
 303                    self.action_log
 304                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 305                    self.project.update(cx, |project, cx| {
 306                        project.set_agent_location(
 307                            Some(AgentLocation {
 308                                buffer: buffer.downgrade(),
 309                                position: max_edit_end,
 310                            }),
 311                            cx,
 312                        );
 313                    });
 314                })?;
 315                output_events
 316                    .unbounded_send(EditAgentOutputEvent::Edited)
 317                    .ok();
 318            }
 319
 320            edit_events = compute_edits.await?;
 321        }
 322
 323        output.await
 324    }
 325
 326    fn parse_edit_chunks(
 327        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 328        cx: &mut AsyncApp,
 329    ) -> (
 330        Task<Result<EditAgentOutput>>,
 331        UnboundedReceiver<Result<EditParserEvent>>,
 332    ) {
 333        let (tx, rx) = mpsc::unbounded();
 334        let output = cx.background_spawn(async move {
 335            pin_mut!(chunks);
 336
 337            let mut parser = EditParser::new();
 338            let mut raw_edits = String::new();
 339            while let Some(chunk) = chunks.next().await {
 340                match chunk {
 341                    Ok(chunk) => {
 342                        raw_edits.push_str(&chunk);
 343                        for event in parser.push(&chunk) {
 344                            tx.unbounded_send(Ok(event))?;
 345                        }
 346                    }
 347                    Err(error) => {
 348                        tx.unbounded_send(Err(error.into()))?;
 349                    }
 350                }
 351            }
 352            Ok(EditAgentOutput {
 353                raw_edits,
 354                parser_metrics: parser.finish(),
 355            })
 356        });
 357        (output, rx)
 358    }
 359
 360    fn parse_create_file_chunks(
 361        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 362        cx: &mut AsyncApp,
 363    ) -> (
 364        Task<Result<EditAgentOutput>>,
 365        UnboundedReceiver<Result<CreateFileParserEvent>>,
 366    ) {
 367        let (tx, rx) = mpsc::unbounded();
 368        let output = cx.background_spawn(async move {
 369            pin_mut!(chunks);
 370
 371            let mut parser = CreateFileParser::new();
 372            let mut raw_edits = String::new();
 373            while let Some(chunk) = chunks.next().await {
 374                match chunk {
 375                    Ok(chunk) => {
 376                        raw_edits.push_str(&chunk);
 377                        for event in parser.push(Some(&chunk)) {
 378                            tx.unbounded_send(Ok(event))?;
 379                        }
 380                    }
 381                    Err(error) => {
 382                        tx.unbounded_send(Err(error.into()))?;
 383                    }
 384                }
 385            }
 386            // Send final events with None to indicate completion
 387            for event in parser.push(None) {
 388                tx.unbounded_send(Ok(event))?;
 389            }
 390            Ok(EditAgentOutput {
 391                raw_edits,
 392                parser_metrics: EditParserMetrics::default(),
 393            })
 394        });
 395        (output, rx)
 396    }
 397
 398    fn resolve_old_text<T>(
 399        snapshot: TextBufferSnapshot,
 400        mut edit_events: T,
 401        cx: &mut AsyncApp,
 402    ) -> (
 403        Task<Result<(T, Option<ResolvedOldText>)>>,
 404        async_watch::Receiver<Option<Range<usize>>>,
 405    )
 406    where
 407        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
 408    {
 409        let (old_range_tx, old_range_rx) = async_watch::channel(None);
 410        let task = cx.background_spawn(async move {
 411            let mut matcher = StreamingFuzzyMatcher::new(snapshot);
 412            while let Some(edit_event) = edit_events.next().await {
 413                let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else {
 414                    break;
 415                };
 416
 417                old_range_tx.send(matcher.push(&chunk))?;
 418                if done {
 419                    break;
 420                }
 421            }
 422
 423            let old_range = matcher.finish();
 424            old_range_tx.send(old_range.clone())?;
 425            if let Some(old_range) = old_range {
 426                let line_indent =
 427                    LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
 428                Ok((
 429                    edit_events,
 430                    Some(ResolvedOldText {
 431                        range: old_range,
 432                        indent: line_indent,
 433                    }),
 434                ))
 435            } else {
 436                Ok((edit_events, None))
 437            }
 438        });
 439
 440        (task, old_range_rx)
 441    }
 442
 443    fn compute_edits<T>(
 444        snapshot: BufferSnapshot,
 445        resolved_old_text: ResolvedOldText,
 446        mut edit_events: T,
 447        cx: &mut AsyncApp,
 448    ) -> (
 449        Task<Result<T>>,
 450        UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
 451    )
 452    where
 453        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
 454    {
 455        let (edits_tx, edits_rx) = mpsc::unbounded();
 456        let compute_edits = cx.background_spawn(async move {
 457            let buffer_start_indent = snapshot
 458                .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
 459            let indent_delta = if buffer_start_indent.tabs > 0 {
 460                IndentDelta::Tabs(
 461                    buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
 462                )
 463            } else {
 464                IndentDelta::Spaces(
 465                    buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
 466                )
 467            };
 468
 469            let old_text = snapshot
 470                .text_for_range(resolved_old_text.range.clone())
 471                .collect::<String>();
 472            let mut diff = StreamingDiff::new(old_text);
 473            let mut edit_start = resolved_old_text.range.start;
 474            let mut new_text_chunks =
 475                Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
 476            let mut done = false;
 477            while !done {
 478                let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
 479                    diff.push_new(&new_text_chunk?)
 480                } else {
 481                    done = true;
 482                    mem::take(&mut diff).finish()
 483                };
 484
 485                for op in char_operations {
 486                    match op {
 487                        CharOperation::Insert { text } => {
 488                            let edit_start = snapshot.anchor_after(edit_start);
 489                            edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
 490                        }
 491                        CharOperation::Delete { bytes } => {
 492                            let edit_end = edit_start + bytes;
 493                            let edit_range =
 494                                snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
 495                            edit_start = edit_end;
 496                            edits_tx.unbounded_send((edit_range, Arc::from("")))?;
 497                        }
 498                        CharOperation::Keep { bytes } => edit_start += bytes,
 499                    }
 500                }
 501            }
 502
 503            drop(new_text_chunks);
 504            anyhow::Ok(edit_events)
 505        });
 506
 507        (compute_edits, edits_rx)
 508    }
 509
 510    fn reindent_new_text_chunks(
 511        delta: IndentDelta,
 512        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
 513    ) -> impl Stream<Item = Result<String>> {
 514        let mut buffer = String::new();
 515        let mut in_leading_whitespace = true;
 516        let mut done = false;
 517        futures::stream::poll_fn(move |cx| {
 518            while !done {
 519                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
 520                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
 521                        (chunk, done)
 522                    }
 523                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
 524                    Poll::Pending => return Poll::Pending,
 525                    _ => return Poll::Ready(None),
 526                };
 527
 528                buffer.push_str(&chunk);
 529
 530                let mut indented_new_text = String::new();
 531                let mut start_ix = 0;
 532                let mut newlines = buffer.match_indices('\n').peekable();
 533                loop {
 534                    let (line_end, is_pending_line) = match newlines.next() {
 535                        Some((ix, _)) => (ix, false),
 536                        None => (buffer.len(), true),
 537                    };
 538                    let line = &buffer[start_ix..line_end];
 539
 540                    if in_leading_whitespace {
 541                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
 542                            // We found a non-whitespace character, adjust
 543                            // indentation based on the delta.
 544                            let new_indent_len =
 545                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
 546                            indented_new_text
 547                                .extend(iter::repeat(delta.character()).take(new_indent_len));
 548                            indented_new_text.push_str(&line[non_whitespace_ix..]);
 549                            in_leading_whitespace = false;
 550                        } else if is_pending_line {
 551                            // We're still in leading whitespace and this line is incomplete.
 552                            // Stop processing until we receive more input.
 553                            break;
 554                        } else {
 555                            // This line is entirely whitespace. Push it without indentation.
 556                            indented_new_text.push_str(line);
 557                        }
 558                    } else {
 559                        indented_new_text.push_str(line);
 560                    }
 561
 562                    if is_pending_line {
 563                        start_ix = line_end;
 564                        break;
 565                    } else {
 566                        in_leading_whitespace = true;
 567                        indented_new_text.push('\n');
 568                        start_ix = line_end + 1;
 569                    }
 570                }
 571                buffer.replace_range(..start_ix, "");
 572
 573                // This was the last chunk, push all the buffered content as-is.
 574                if is_last_chunk {
 575                    indented_new_text.push_str(&buffer);
 576                    buffer.clear();
 577                    done = true;
 578                }
 579
 580                if !indented_new_text.is_empty() {
 581                    return Poll::Ready(Some(Ok(indented_new_text)));
 582                }
 583            }
 584
 585            Poll::Ready(None)
 586        })
 587    }
 588
 589    async fn request(
 590        &self,
 591        mut conversation: LanguageModelRequest,
 592        prompt: String,
 593        cx: &mut AsyncApp,
 594    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
 595        let mut messages_iter = conversation.messages.iter_mut();
 596        if let Some(last_message) = messages_iter.next_back() {
 597            if last_message.role == Role::Assistant {
 598                let old_content_len = last_message.content.len();
 599                last_message
 600                    .content
 601                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
 602                let new_content_len = last_message.content.len();
 603
 604                // We just removed pending tool uses from the content of the
 605                // last message, so it doesn't make sense to cache it anymore
 606                // (e.g., the message will look very different on the next
 607                // request). Thus, we move the flag to the message prior to it,
 608                // as it will still be a valid prefix of the conversation.
 609                if old_content_len != new_content_len && last_message.cache {
 610                    if let Some(prev_message) = messages_iter.next_back() {
 611                        last_message.cache = false;
 612                        prev_message.cache = true;
 613                    }
 614                }
 615
 616                if last_message.content.is_empty() {
 617                    conversation.messages.pop();
 618                }
 619            } else {
 620                debug_panic!(
 621                    "Last message must be an Assistant tool calling! Got {:?}",
 622                    last_message.content
 623                );
 624            }
 625        }
 626
 627        conversation.messages.push(LanguageModelRequestMessage {
 628            role: Role::User,
 629            content: vec![MessageContent::Text(prompt)],
 630            cache: false,
 631        });
 632
 633        // Include tools in the request so that we can take advantage of
 634        // caching when ToolChoice::None is supported.
 635        let mut tool_choice = None;
 636        let mut tools = Vec::new();
 637        if !conversation.tools.is_empty()
 638            && self
 639                .model
 640                .supports_tool_choice(LanguageModelToolChoice::None)
 641        {
 642            tool_choice = Some(LanguageModelToolChoice::None);
 643            tools = conversation.tools.clone();
 644        }
 645
 646        let request = LanguageModelRequest {
 647            thread_id: conversation.thread_id,
 648            prompt_id: conversation.prompt_id,
 649            mode: conversation.mode,
 650            messages: conversation.messages,
 651            tool_choice,
 652            tools,
 653            stop: Vec::new(),
 654            temperature: None,
 655        };
 656
 657        Ok(self.model.stream_completion_text(request, cx).await?.stream)
 658    }
 659}
 660
 661struct ResolvedOldText {
 662    range: Range<usize>,
 663    indent: LineIndent,
 664}
 665
 666#[derive(Copy, Clone, Debug)]
 667enum IndentDelta {
 668    Spaces(isize),
 669    Tabs(isize),
 670}
 671
 672impl IndentDelta {
 673    fn character(&self) -> char {
 674        match self {
 675            IndentDelta::Spaces(_) => ' ',
 676            IndentDelta::Tabs(_) => '\t',
 677        }
 678    }
 679
 680    fn len(&self) -> isize {
 681        match self {
 682            IndentDelta::Spaces(n) => *n,
 683            IndentDelta::Tabs(n) => *n,
 684        }
 685    }
 686}
 687
 688#[cfg(test)]
 689mod tests {
 690    use super::*;
 691    use fs::FakeFs;
 692    use futures::stream;
 693    use gpui::{AppContext, TestAppContext};
 694    use indoc::indoc;
 695    use language_model::fake_provider::FakeLanguageModel;
 696    use project::{AgentLocation, Project};
 697    use rand::prelude::*;
 698    use rand::rngs::StdRng;
 699    use std::cmp;
 700
 701    #[gpui::test(iterations = 100)]
 702    async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
 703        let agent = init_test(cx).await;
 704        let buffer = cx.new(|cx| {
 705            Buffer::local(
 706                indoc! {"
 707                    abc
 708                    def
 709                    ghi
 710                "},
 711                cx,
 712            )
 713        });
 714        let (apply, _events) = agent.edit(
 715            buffer.clone(),
 716            String::new(),
 717            &LanguageModelRequest::default(),
 718            &mut cx.to_async(),
 719        );
 720        cx.run_until_parked();
 721
 722        simulate_llm_output(
 723            &agent,
 724            indoc! {"
 725                <old_text></old_text>
 726                <new_text>jkl</new_text>
 727                <old_text>def</old_text>
 728                <new_text>DEF</new_text>
 729            "},
 730            &mut rng,
 731            cx,
 732        );
 733        apply.await.unwrap();
 734
 735        pretty_assertions::assert_eq!(
 736            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 737            indoc! {"
 738                abc
 739                DEF
 740                ghi
 741            "}
 742        );
 743    }
 744
 745    #[gpui::test(iterations = 100)]
 746    async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
 747        let agent = init_test(cx).await;
 748        let buffer = cx.new(|cx| {
 749            Buffer::local(
 750                indoc! {"
 751                    lorem
 752                            ipsum
 753                            dolor
 754                            sit
 755                "},
 756                cx,
 757            )
 758        });
 759        let (apply, _events) = agent.edit(
 760            buffer.clone(),
 761            String::new(),
 762            &LanguageModelRequest::default(),
 763            &mut cx.to_async(),
 764        );
 765        cx.run_until_parked();
 766
 767        simulate_llm_output(
 768            &agent,
 769            indoc! {"
 770                <old_text>
 771                    ipsum
 772                    dolor
 773                    sit
 774                </old_text>
 775                <new_text>
 776                    ipsum
 777                    dolor
 778                    sit
 779                amet
 780                </new_text>
 781            "},
 782            &mut rng,
 783            cx,
 784        );
 785        apply.await.unwrap();
 786
 787        pretty_assertions::assert_eq!(
 788            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 789            indoc! {"
 790                lorem
 791                        ipsum
 792                        dolor
 793                        sit
 794                    amet
 795            "}
 796        );
 797    }
 798
 799    #[gpui::test(iterations = 100)]
 800    async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
 801        let agent = init_test(cx).await;
 802        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 803        let (apply, _events) = agent.edit(
 804            buffer.clone(),
 805            String::new(),
 806            &LanguageModelRequest::default(),
 807            &mut cx.to_async(),
 808        );
 809        cx.run_until_parked();
 810
 811        simulate_llm_output(
 812            &agent,
 813            indoc! {"
 814                <old_text>
 815                def
 816                </old_text>
 817                <new_text>
 818                DEF
 819                </new_text>
 820
 821                <old_text>
 822                DEF
 823                </old_text>
 824                <new_text>
 825                DeF
 826                </new_text>
 827            "},
 828            &mut rng,
 829            cx,
 830        );
 831        apply.await.unwrap();
 832
 833        assert_eq!(
 834            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 835            "abc\nDeF\nghi"
 836        );
 837    }
 838
 839    #[gpui::test(iterations = 100)]
 840    async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
 841        let agent = init_test(cx).await;
 842        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 843        let (apply, _events) = agent.edit(
 844            buffer.clone(),
 845            String::new(),
 846            &LanguageModelRequest::default(),
 847            &mut cx.to_async(),
 848        );
 849        cx.run_until_parked();
 850
 851        simulate_llm_output(
 852            &agent,
 853            indoc! {"
 854                <old_text>
 855                jkl
 856                </old_text>
 857                <new_text>
 858                mno
 859                </new_text>
 860
 861                <old_text>
 862                abc
 863                </old_text>
 864                <new_text>
 865                ABC
 866                </new_text>
 867            "},
 868            &mut rng,
 869            cx,
 870        );
 871        apply.await.unwrap();
 872
 873        assert_eq!(
 874            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 875            "ABC\ndef\nghi"
 876        );
 877    }
 878
 879    #[gpui::test]
 880    async fn test_edit_events(cx: &mut TestAppContext) {
 881        let agent = init_test(cx).await;
 882        let model = agent.model.as_fake();
 883        let project = agent
 884            .action_log
 885            .read_with(cx, |log, _| log.project().clone());
 886        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));
 887
 888        let mut async_cx = cx.to_async();
 889        let (apply, mut events) = agent.edit(
 890            buffer.clone(),
 891            String::new(),
 892            &LanguageModelRequest::default(),
 893            &mut async_cx,
 894        );
 895        cx.run_until_parked();
 896
 897        model.stream_last_completion_response("<old_text>a");
 898        cx.run_until_parked();
 899        assert_eq!(drain_events(&mut events), vec![]);
 900        assert_eq!(
 901            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 902            "abc\ndef\nghi\njkl"
 903        );
 904        assert_eq!(
 905            project.read_with(cx, |project, _| project.agent_location()),
 906            None
 907        );
 908
 909        model.stream_last_completion_response("bc</old_text>");
 910        cx.run_until_parked();
 911        assert_eq!(
 912            drain_events(&mut events),
 913            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
 914                cx,
 915                |buffer, _| buffer.anchor_before(Point::new(0, 0))
 916                    ..buffer.anchor_before(Point::new(0, 3))
 917            ))]
 918        );
 919        assert_eq!(
 920            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 921            "abc\ndef\nghi\njkl"
 922        );
 923        assert_eq!(
 924            project.read_with(cx, |project, _| project.agent_location()),
 925            Some(AgentLocation {
 926                buffer: buffer.downgrade(),
 927                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
 928            })
 929        );
 930
 931        model.stream_last_completion_response("<new_text>abX");
 932        cx.run_until_parked();
 933        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 934        assert_eq!(
 935            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 936            "abXc\ndef\nghi\njkl"
 937        );
 938        assert_eq!(
 939            project.read_with(cx, |project, _| project.agent_location()),
 940            Some(AgentLocation {
 941                buffer: buffer.downgrade(),
 942                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
 943            })
 944        );
 945
 946        model.stream_last_completion_response("cY");
 947        cx.run_until_parked();
 948        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 949        assert_eq!(
 950            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 951            "abXcY\ndef\nghi\njkl"
 952        );
 953        assert_eq!(
 954            project.read_with(cx, |project, _| project.agent_location()),
 955            Some(AgentLocation {
 956                buffer: buffer.downgrade(),
 957                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
 958            })
 959        );
 960
 961        model.stream_last_completion_response("</new_text>");
 962        model.stream_last_completion_response("<old_text>hall");
 963        cx.run_until_parked();
 964        assert_eq!(drain_events(&mut events), vec![]);
 965        assert_eq!(
 966            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 967            "abXcY\ndef\nghi\njkl"
 968        );
 969        assert_eq!(
 970            project.read_with(cx, |project, _| project.agent_location()),
 971            Some(AgentLocation {
 972                buffer: buffer.downgrade(),
 973                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
 974            })
 975        );
 976
 977        model.stream_last_completion_response("ucinated old</old_text>");
 978        model.stream_last_completion_response("<new_text>");
 979        cx.run_until_parked();
 980        assert_eq!(
 981            drain_events(&mut events),
 982            vec![EditAgentOutputEvent::UnresolvedEditRange]
 983        );
 984        assert_eq!(
 985            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 986            "abXcY\ndef\nghi\njkl"
 987        );
 988        assert_eq!(
 989            project.read_with(cx, |project, _| project.agent_location()),
 990            Some(AgentLocation {
 991                buffer: buffer.downgrade(),
 992                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
 993            })
 994        );
 995
 996        model.stream_last_completion_response("hallucinated new</new_");
 997        model.stream_last_completion_response("text>");
 998        cx.run_until_parked();
 999        assert_eq!(drain_events(&mut events), vec![]);
1000        assert_eq!(
1001            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1002            "abXcY\ndef\nghi\njkl"
1003        );
1004        assert_eq!(
1005            project.read_with(cx, |project, _| project.agent_location()),
1006            Some(AgentLocation {
1007                buffer: buffer.downgrade(),
1008                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
1009            })
1010        );
1011
1012        model.stream_last_completion_response("<old_text>\nghi\nj");
1013        cx.run_until_parked();
1014        assert_eq!(
1015            drain_events(&mut events),
1016            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1017                cx,
1018                |buffer, _| buffer.anchor_before(Point::new(2, 0))
1019                    ..buffer.anchor_before(Point::new(2, 3))
1020            ))]
1021        );
1022        assert_eq!(
1023            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1024            "abXcY\ndef\nghi\njkl"
1025        );
1026        assert_eq!(
1027            project.read_with(cx, |project, _| project.agent_location()),
1028            Some(AgentLocation {
1029                buffer: buffer.downgrade(),
1030                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1031            })
1032        );
1033
1034        model.stream_last_completion_response("kl</old_text>");
1035        model.stream_last_completion_response("<new_text>");
1036        cx.run_until_parked();
1037        assert_eq!(
1038            drain_events(&mut events),
1039            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
1040                cx,
1041                |buffer, _| buffer.anchor_before(Point::new(2, 0))
1042                    ..buffer.anchor_before(Point::new(3, 3))
1043            ))]
1044        );
1045        assert_eq!(
1046            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1047            "abXcY\ndef\nghi\njkl"
1048        );
1049        assert_eq!(
1050            project.read_with(cx, |project, _| project.agent_location()),
1051            Some(AgentLocation {
1052                buffer: buffer.downgrade(),
1053                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
1054            })
1055        );
1056
1057        model.stream_last_completion_response("GHI</new_text>");
1058        cx.run_until_parked();
1059        assert_eq!(
1060            drain_events(&mut events),
1061            vec![EditAgentOutputEvent::Edited]
1062        );
1063        assert_eq!(
1064            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1065            "abXcY\ndef\nGHI"
1066        );
1067        assert_eq!(
1068            project.read_with(cx, |project, _| project.agent_location()),
1069            Some(AgentLocation {
1070                buffer: buffer.downgrade(),
1071                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1072            })
1073        );
1074
1075        model.end_last_completion_stream();
1076        apply.await.unwrap();
1077        assert_eq!(
1078            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1079            "abXcY\ndef\nGHI"
1080        );
1081        assert_eq!(drain_events(&mut events), vec![]);
1082        assert_eq!(
1083            project.read_with(cx, |project, _| project.agent_location()),
1084            Some(AgentLocation {
1085                buffer: buffer.downgrade(),
1086                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
1087            })
1088        );
1089    }
1090
1091    #[gpui::test]
1092    async fn test_overwrite_events(cx: &mut TestAppContext) {
1093        let agent = init_test(cx).await;
1094        let project = agent
1095            .action_log
1096            .read_with(cx, |log, _| log.project().clone());
1097        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
1098        let (chunks_tx, chunks_rx) = mpsc::unbounded();
1099        let (apply, mut events) = agent.overwrite_with_chunks(
1100            buffer.clone(),
1101            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
1102            &mut cx.to_async(),
1103        );
1104
1105        cx.run_until_parked();
1106        assert_eq!(
1107            drain_events(&mut events),
1108            vec![EditAgentOutputEvent::Edited]
1109        );
1110        assert_eq!(
1111            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1112            ""
1113        );
1114        assert_eq!(
1115            project.read_with(cx, |project, _| project.agent_location()),
1116            Some(AgentLocation {
1117                buffer: buffer.downgrade(),
1118                position: language::Anchor::MAX
1119            })
1120        );
1121
1122        chunks_tx.unbounded_send("```\njkl\n").unwrap();
1123        cx.run_until_parked();
1124        assert_eq!(
1125            drain_events(&mut events),
1126            vec![EditAgentOutputEvent::Edited]
1127        );
1128        assert_eq!(
1129            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1130            "jkl"
1131        );
1132        assert_eq!(
1133            project.read_with(cx, |project, _| project.agent_location()),
1134            Some(AgentLocation {
1135                buffer: buffer.downgrade(),
1136                position: language::Anchor::MAX
1137            })
1138        );
1139
1140        chunks_tx.unbounded_send("mno\n").unwrap();
1141        cx.run_until_parked();
1142        assert_eq!(
1143            drain_events(&mut events),
1144            vec![EditAgentOutputEvent::Edited]
1145        );
1146        assert_eq!(
1147            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1148            "jkl\nmno"
1149        );
1150        assert_eq!(
1151            project.read_with(cx, |project, _| project.agent_location()),
1152            Some(AgentLocation {
1153                buffer: buffer.downgrade(),
1154                position: language::Anchor::MAX
1155            })
1156        );
1157
1158        chunks_tx.unbounded_send("pqr\n```").unwrap();
1159        cx.run_until_parked();
1160        assert_eq!(
1161            drain_events(&mut events),
1162            vec![EditAgentOutputEvent::Edited]
1163        );
1164        assert_eq!(
1165            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1166            "jkl\nmno\npqr"
1167        );
1168        assert_eq!(
1169            project.read_with(cx, |project, _| project.agent_location()),
1170            Some(AgentLocation {
1171                buffer: buffer.downgrade(),
1172                position: language::Anchor::MAX
1173            })
1174        );
1175
1176        drop(chunks_tx);
1177        apply.await.unwrap();
1178        assert_eq!(
1179            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
1180            "jkl\nmno\npqr"
1181        );
1182        assert_eq!(drain_events(&mut events), vec![]);
1183        assert_eq!(
1184            project.read_with(cx, |project, _| project.agent_location()),
1185            Some(AgentLocation {
1186                buffer: buffer.downgrade(),
1187                position: language::Anchor::MAX
1188            })
1189        );
1190    }
1191
1192    #[gpui::test(iterations = 100)]
1193    async fn test_indent_new_text_chunks(mut rng: StdRng) {
1194        let chunks = to_random_chunks(&mut rng, "    abc\n  def\n      ghi");
1195        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1196            Ok(EditParserEvent::NewTextChunk {
1197                chunk: chunk.clone(),
1198                done: index == chunks.len() - 1,
1199            })
1200        }));
1201        let indented_chunks =
1202            EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1203                .collect::<Vec<_>>()
1204                .await;
1205        let new_text = indented_chunks
1206            .into_iter()
1207            .collect::<Result<String>>()
1208            .unwrap();
1209        assert_eq!(new_text, "      abc\n    def\n        ghi");
1210    }
1211
1212    #[gpui::test(iterations = 100)]
1213    async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1214        let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1215        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1216            Ok(EditParserEvent::NewTextChunk {
1217                chunk: chunk.clone(),
1218                done: index == chunks.len() - 1,
1219            })
1220        }));
1221        let indented_chunks =
1222            EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1223                .collect::<Vec<_>>()
1224                .await;
1225        let new_text = indented_chunks
1226            .into_iter()
1227            .collect::<Result<String>>()
1228            .unwrap();
1229        assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1230    }
1231
1232    #[gpui::test(iterations = 100)]
1233    async fn test_random_indents(mut rng: StdRng) {
1234        let len = rng.gen_range(1..=100);
1235        let new_text = util::RandomCharIter::new(&mut rng)
1236            .with_simple_text()
1237            .take(len)
1238            .collect::<String>();
1239        let new_text = new_text
1240            .split('\n')
1241            .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1242            .collect::<Vec<_>>()
1243            .join("\n");
1244        let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1245
1246        let chunks = to_random_chunks(&mut rng, &new_text);
1247        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1248            Ok(EditParserEvent::NewTextChunk {
1249                chunk: chunk.clone(),
1250                done: index == chunks.len() - 1,
1251            })
1252        }));
1253        let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1254            .collect::<Vec<_>>()
1255            .await;
1256        let actual_reindented_text = reindented_chunks
1257            .into_iter()
1258            .collect::<Result<String>>()
1259            .unwrap();
1260        let expected_reindented_text = new_text
1261            .split('\n')
1262            .map(|line| {
1263                if let Some(ix) = line.find(|c| c != ' ') {
1264                    let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1265                    format!("{}{}", " ".repeat(new_indent), &line[ix..])
1266                } else {
1267                    line.to_string()
1268                }
1269            })
1270            .collect::<Vec<_>>()
1271            .join("\n");
1272        assert_eq!(actual_reindented_text, expected_reindented_text);
1273    }
1274
1275    fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1276        let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1277        let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1278        chunk_indices.sort();
1279        chunk_indices.push(input.len());
1280
1281        let mut chunks = Vec::new();
1282        let mut last_ix = 0;
1283        for chunk_ix in chunk_indices {
1284            chunks.push(input[last_ix..chunk_ix].to_string());
1285            last_ix = chunk_ix;
1286        }
1287        chunks
1288    }
1289
1290    fn simulate_llm_output(
1291        agent: &EditAgent,
1292        output: &str,
1293        rng: &mut StdRng,
1294        cx: &mut TestAppContext,
1295    ) {
1296        let executor = cx.executor();
1297        let chunks = to_random_chunks(rng, output);
1298        let model = agent.model.clone();
1299        cx.background_spawn(async move {
1300            for chunk in chunks {
1301                executor.simulate_random_delay().await;
1302                model.as_fake().stream_last_completion_response(chunk);
1303            }
1304            model.as_fake().end_last_completion_stream();
1305        })
1306        .detach();
1307    }
1308
1309    async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1310        cx.update(settings::init);
1311        cx.update(Project::init_settings);
1312        let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1313        let model = Arc::new(FakeLanguageModel::default());
1314        let action_log = cx.new(|_| ActionLog::new(project.clone()));
1315        EditAgent::new(model, project, action_log, Templates::new())
1316    }
1317
1318    fn drain_events(
1319        stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1320    ) -> Vec<EditAgentOutputEvent> {
1321        let mut events = Vec::new();
1322        while let Ok(Some(event)) = stream.try_next() {
1323            events.push(event);
1324        }
1325        events
1326    }
1327}