edit_agent.rs

   1mod create_file_parser;
   2mod edit_parser;
   3#[cfg(test)]
   4mod evals;
   5mod streaming_fuzzy_matcher;
   6
   7use crate::{Template, Templates};
   8use anyhow::Result;
   9use assistant_tool::ActionLog;
  10use create_file_parser::{CreateFileParser, CreateFileParserEvent};
  11use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
  12use futures::{
  13    Stream, StreamExt,
  14    channel::mpsc::{self, UnboundedReceiver},
  15    pin_mut,
  16    stream::BoxStream,
  17};
  18use gpui::{AppContext, AsyncApp, Entity, Task};
  19use language::{Anchor, Buffer, BufferSnapshot, LineIndent, Point, TextBufferSnapshot};
  20use language_model::{
  21    LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelToolChoice, MessageContent, Role,
  23};
  24use project::{AgentLocation, Project};
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use std::{cmp, iter, mem, ops::Range, path::PathBuf, pin::Pin, sync::Arc, task::Poll};
  28use streaming_diff::{CharOperation, StreamingDiff};
  29use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
  30use util::debug_panic;
  31use zed_llm_client::CompletionIntent;
  32
  33#[derive(Serialize)]
  34struct CreateFilePromptTemplate {
  35    path: Option<PathBuf>,
  36    edit_description: String,
  37}
  38
  39impl Template for CreateFilePromptTemplate {
  40    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
  41}
  42
  43#[derive(Serialize)]
  44struct EditFilePromptTemplate {
  45    path: Option<PathBuf>,
  46    edit_description: String,
  47}
  48
  49impl Template for EditFilePromptTemplate {
  50    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
  51}
  52
  53#[derive(Clone, Debug, PartialEq, Eq)]
  54pub enum EditAgentOutputEvent {
  55    ResolvingEditRange(Range<Anchor>),
  56    UnresolvedEditRange,
  57    Edited,
  58}
  59
  60#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  61pub struct EditAgentOutput {
  62    pub raw_edits: String,
  63    pub parser_metrics: EditParserMetrics,
  64}
  65
  66#[derive(Clone)]
  67pub struct EditAgent {
  68    model: Arc<dyn LanguageModel>,
  69    action_log: Entity<ActionLog>,
  70    project: Entity<Project>,
  71    templates: Arc<Templates>,
  72}
  73
  74impl EditAgent {
  75    pub fn new(
  76        model: Arc<dyn LanguageModel>,
  77        project: Entity<Project>,
  78        action_log: Entity<ActionLog>,
  79        templates: Arc<Templates>,
  80    ) -> Self {
  81        EditAgent {
  82            model,
  83            project,
  84            action_log,
  85            templates,
  86        }
  87    }
  88
  89    pub fn overwrite(
  90        &self,
  91        buffer: Entity<Buffer>,
  92        edit_description: String,
  93        conversation: &LanguageModelRequest,
  94        cx: &mut AsyncApp,
  95    ) -> (
  96        Task<Result<EditAgentOutput>>,
  97        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
  98    ) {
  99        let this = self.clone();
 100        let (events_tx, events_rx) = mpsc::unbounded();
 101        let conversation = conversation.clone();
 102        let output = cx.spawn(async move |cx| {
 103            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 104            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
 105            let prompt = CreateFilePromptTemplate {
 106                path,
 107                edit_description,
 108            }
 109            .render(&this.templates)?;
 110            let new_chunks = this
 111                .request(conversation, CompletionIntent::CreateFile, prompt, cx)
 112                .await?;
 113
 114            let (output, mut inner_events) = this.overwrite_with_chunks(buffer, new_chunks, cx);
 115            while let Some(event) = inner_events.next().await {
 116                events_tx.unbounded_send(event).ok();
 117            }
 118            output.await
 119        });
 120        (output, events_rx)
 121    }
 122
 123    fn overwrite_with_chunks(
 124        &self,
 125        buffer: Entity<Buffer>,
 126        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 127        cx: &mut AsyncApp,
 128    ) -> (
 129        Task<Result<EditAgentOutput>>,
 130        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 131    ) {
 132        let (output_events_tx, output_events_rx) = mpsc::unbounded();
 133        let (parse_task, parse_rx) = Self::parse_create_file_chunks(edit_chunks, cx);
 134        let this = self.clone();
 135        let task = cx.spawn(async move |cx| {
 136            this.action_log
 137                .update(cx, |log, cx| log.buffer_created(buffer.clone(), cx))?;
 138            this.overwrite_with_chunks_internal(buffer, parse_rx, output_events_tx, cx)
 139                .await?;
 140            parse_task.await
 141        });
 142        (task, output_events_rx)
 143    }
 144
 145    async fn overwrite_with_chunks_internal(
 146        &self,
 147        buffer: Entity<Buffer>,
 148        mut parse_rx: UnboundedReceiver<Result<CreateFileParserEvent>>,
 149        output_events_tx: mpsc::UnboundedSender<EditAgentOutputEvent>,
 150        cx: &mut AsyncApp,
 151    ) -> Result<()> {
 152        cx.update(|cx| {
 153            buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
 154            self.action_log.update(cx, |log, cx| {
 155                log.buffer_edited(buffer.clone(), cx);
 156            });
 157            self.project.update(cx, |project, cx| {
 158                project.set_agent_location(
 159                    Some(AgentLocation {
 160                        buffer: buffer.downgrade(),
 161                        position: language::Anchor::MAX,
 162                    }),
 163                    cx,
 164                )
 165            });
 166            output_events_tx
 167                .unbounded_send(EditAgentOutputEvent::Edited)
 168                .ok();
 169        })?;
 170
 171        while let Some(event) = parse_rx.next().await {
 172            match event? {
 173                CreateFileParserEvent::NewTextChunk { chunk } => {
 174                    cx.update(|cx| {
 175                        buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
 176                        self.action_log
 177                            .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 178                        self.project.update(cx, |project, cx| {
 179                            project.set_agent_location(
 180                                Some(AgentLocation {
 181                                    buffer: buffer.downgrade(),
 182                                    position: language::Anchor::MAX,
 183                                }),
 184                                cx,
 185                            )
 186                        });
 187                    })?;
 188                    output_events_tx
 189                        .unbounded_send(EditAgentOutputEvent::Edited)
 190                        .ok();
 191                }
 192            }
 193        }
 194
 195        Ok(())
 196    }
 197
 198    pub fn edit(
 199        &self,
 200        buffer: Entity<Buffer>,
 201        edit_description: String,
 202        conversation: &LanguageModelRequest,
 203        cx: &mut AsyncApp,
 204    ) -> (
 205        Task<Result<EditAgentOutput>>,
 206        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 207    ) {
 208        let this = self.clone();
 209        let (events_tx, events_rx) = mpsc::unbounded();
 210        let conversation = conversation.clone();
 211        let output = cx.spawn(async move |cx| {
 212            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 213            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
 214            let prompt = EditFilePromptTemplate {
 215                path,
 216                edit_description,
 217            }
 218            .render(&this.templates)?;
 219            let edit_chunks = this
 220                .request(conversation, CompletionIntent::EditFile, prompt, cx)
 221                .await?;
 222            this.apply_edit_chunks(buffer, edit_chunks, events_tx, cx)
 223                .await
 224        });
 225        (output, events_rx)
 226    }
 227
 228    async fn apply_edit_chunks(
 229        &self,
 230        buffer: Entity<Buffer>,
 231        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 232        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
 233        cx: &mut AsyncApp,
 234    ) -> Result<EditAgentOutput> {
 235        self.action_log
 236            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx))?;
 237
 238        let (output, edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
 239        let mut edit_events = edit_events.peekable();
 240        while let Some(edit_event) = Pin::new(&mut edit_events).peek().await {
 241            // Skip events until we're at the start of a new edit.
 242            let Ok(EditParserEvent::OldTextChunk { .. }) = edit_event else {
 243                edit_events.next().await.unwrap()?;
 244                continue;
 245            };
 246
 247            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 248
 249            // Resolve the old text in the background, updating the agent
 250            // location as we keep refining which range it corresponds to.
 251            let (resolve_old_text, mut old_range) =
 252                Self::resolve_old_text(snapshot.text.clone(), edit_events, cx);
 253            while let Ok(old_range) = old_range.recv().await {
 254                if let Some(old_range) = old_range {
 255                    let old_range = snapshot.anchor_before(old_range.start)
 256                        ..snapshot.anchor_before(old_range.end);
 257                    self.project.update(cx, |project, cx| {
 258                        project.set_agent_location(
 259                            Some(AgentLocation {
 260                                buffer: buffer.downgrade(),
 261                                position: old_range.end,
 262                            }),
 263                            cx,
 264                        );
 265                    })?;
 266                    output_events
 267                        .unbounded_send(EditAgentOutputEvent::ResolvingEditRange(old_range))
 268                        .ok();
 269                }
 270            }
 271
 272            let (edit_events_, resolved_old_text) = resolve_old_text.await?;
 273            edit_events = edit_events_;
 274
 275            // If we can't resolve the old text, restart the loop waiting for a
 276            // new edit (or for the stream to end).
 277            let Some(resolved_old_text) = resolved_old_text else {
 278                output_events
 279                    .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
 280                    .ok();
 281                continue;
 282            };
 283
 284            // Compute edits in the background and apply them as they become
 285            // available.
 286            let (compute_edits, edits) =
 287                Self::compute_edits(snapshot, resolved_old_text, edit_events, cx);
 288            let mut edits = edits.ready_chunks(32);
 289            while let Some(edits) = edits.next().await {
 290                if edits.is_empty() {
 291                    continue;
 292                }
 293
 294                // Edit the buffer and report edits to the action log as part of the
 295                // same effect cycle, otherwise the edit will be reported as if the
 296                // user made it.
 297                cx.update(|cx| {
 298                    let max_edit_end = buffer.update(cx, |buffer, cx| {
 299                        buffer.edit(edits.iter().cloned(), None, cx);
 300                        let max_edit_end = buffer
 301                            .summaries_for_anchors::<Point, _>(
 302                                edits.iter().map(|(range, _)| &range.end),
 303                            )
 304                            .max()
 305                            .unwrap();
 306                        buffer.anchor_before(max_edit_end)
 307                    });
 308                    self.action_log
 309                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 310                    self.project.update(cx, |project, cx| {
 311                        project.set_agent_location(
 312                            Some(AgentLocation {
 313                                buffer: buffer.downgrade(),
 314                                position: max_edit_end,
 315                            }),
 316                            cx,
 317                        );
 318                    });
 319                })?;
 320                output_events
 321                    .unbounded_send(EditAgentOutputEvent::Edited)
 322                    .ok();
 323            }
 324
 325            edit_events = compute_edits.await?;
 326        }
 327
 328        output.await
 329    }
 330
 331    fn parse_edit_chunks(
 332        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 333        cx: &mut AsyncApp,
 334    ) -> (
 335        Task<Result<EditAgentOutput>>,
 336        UnboundedReceiver<Result<EditParserEvent>>,
 337    ) {
 338        let (tx, rx) = mpsc::unbounded();
 339        let output = cx.background_spawn(async move {
 340            pin_mut!(chunks);
 341
 342            let mut parser = EditParser::new();
 343            let mut raw_edits = String::new();
 344            while let Some(chunk) = chunks.next().await {
 345                match chunk {
 346                    Ok(chunk) => {
 347                        raw_edits.push_str(&chunk);
 348                        for event in parser.push(&chunk) {
 349                            tx.unbounded_send(Ok(event))?;
 350                        }
 351                    }
 352                    Err(error) => {
 353                        tx.unbounded_send(Err(error.into()))?;
 354                    }
 355                }
 356            }
 357            Ok(EditAgentOutput {
 358                raw_edits,
 359                parser_metrics: parser.finish(),
 360            })
 361        });
 362        (output, rx)
 363    }
 364
 365    fn parse_create_file_chunks(
 366        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 367        cx: &mut AsyncApp,
 368    ) -> (
 369        Task<Result<EditAgentOutput>>,
 370        UnboundedReceiver<Result<CreateFileParserEvent>>,
 371    ) {
 372        let (tx, rx) = mpsc::unbounded();
 373        let output = cx.background_spawn(async move {
 374            pin_mut!(chunks);
 375
 376            let mut parser = CreateFileParser::new();
 377            let mut raw_edits = String::new();
 378            while let Some(chunk) = chunks.next().await {
 379                match chunk {
 380                    Ok(chunk) => {
 381                        raw_edits.push_str(&chunk);
 382                        for event in parser.push(Some(&chunk)) {
 383                            tx.unbounded_send(Ok(event))?;
 384                        }
 385                    }
 386                    Err(error) => {
 387                        tx.unbounded_send(Err(error.into()))?;
 388                    }
 389                }
 390            }
 391            // Send final events with None to indicate completion
 392            for event in parser.push(None) {
 393                tx.unbounded_send(Ok(event))?;
 394            }
 395            Ok(EditAgentOutput {
 396                raw_edits,
 397                parser_metrics: EditParserMetrics::default(),
 398            })
 399        });
 400        (output, rx)
 401    }
 402
 403    fn resolve_old_text<T>(
 404        snapshot: TextBufferSnapshot,
 405        mut edit_events: T,
 406        cx: &mut AsyncApp,
 407    ) -> (
 408        Task<Result<(T, Option<ResolvedOldText>)>>,
 409        async_watch::Receiver<Option<Range<usize>>>,
 410    )
 411    where
 412        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
 413    {
 414        let (old_range_tx, old_range_rx) = async_watch::channel(None);
 415        let task = cx.background_spawn(async move {
 416            let mut matcher = StreamingFuzzyMatcher::new(snapshot);
 417            while let Some(edit_event) = edit_events.next().await {
 418                let EditParserEvent::OldTextChunk { chunk, done } = edit_event? else {
 419                    break;
 420                };
 421
 422                old_range_tx.send(matcher.push(&chunk))?;
 423                if done {
 424                    break;
 425                }
 426            }
 427
 428            let old_range = matcher.finish();
 429            old_range_tx.send(old_range.clone())?;
 430            if let Some(old_range) = old_range {
 431                let line_indent =
 432                    LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
 433                Ok((
 434                    edit_events,
 435                    Some(ResolvedOldText {
 436                        range: old_range,
 437                        indent: line_indent,
 438                    }),
 439                ))
 440            } else {
 441                Ok((edit_events, None))
 442            }
 443        });
 444
 445        (task, old_range_rx)
 446    }
 447
 448    fn compute_edits<T>(
 449        snapshot: BufferSnapshot,
 450        resolved_old_text: ResolvedOldText,
 451        mut edit_events: T,
 452        cx: &mut AsyncApp,
 453    ) -> (
 454        Task<Result<T>>,
 455        UnboundedReceiver<(Range<Anchor>, Arc<str>)>,
 456    )
 457    where
 458        T: 'static + Send + Unpin + Stream<Item = Result<EditParserEvent>>,
 459    {
 460        let (edits_tx, edits_rx) = mpsc::unbounded();
 461        let compute_edits = cx.background_spawn(async move {
 462            let buffer_start_indent = snapshot
 463                .line_indent_for_row(snapshot.offset_to_point(resolved_old_text.range.start).row);
 464            let indent_delta = if buffer_start_indent.tabs > 0 {
 465                IndentDelta::Tabs(
 466                    buffer_start_indent.tabs as isize - resolved_old_text.indent.tabs as isize,
 467                )
 468            } else {
 469                IndentDelta::Spaces(
 470                    buffer_start_indent.spaces as isize - resolved_old_text.indent.spaces as isize,
 471                )
 472            };
 473
 474            let old_text = snapshot
 475                .text_for_range(resolved_old_text.range.clone())
 476                .collect::<String>();
 477            let mut diff = StreamingDiff::new(old_text);
 478            let mut edit_start = resolved_old_text.range.start;
 479            let mut new_text_chunks =
 480                Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
 481            let mut done = false;
 482            while !done {
 483                let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await {
 484                    diff.push_new(&new_text_chunk?)
 485                } else {
 486                    done = true;
 487                    mem::take(&mut diff).finish()
 488                };
 489
 490                for op in char_operations {
 491                    match op {
 492                        CharOperation::Insert { text } => {
 493                            let edit_start = snapshot.anchor_after(edit_start);
 494                            edits_tx.unbounded_send((edit_start..edit_start, Arc::from(text)))?;
 495                        }
 496                        CharOperation::Delete { bytes } => {
 497                            let edit_end = edit_start + bytes;
 498                            let edit_range =
 499                                snapshot.anchor_after(edit_start)..snapshot.anchor_before(edit_end);
 500                            edit_start = edit_end;
 501                            edits_tx.unbounded_send((edit_range, Arc::from("")))?;
 502                        }
 503                        CharOperation::Keep { bytes } => edit_start += bytes,
 504                    }
 505                }
 506            }
 507
 508            drop(new_text_chunks);
 509            anyhow::Ok(edit_events)
 510        });
 511
 512        (compute_edits, edits_rx)
 513    }
 514
 515    fn reindent_new_text_chunks(
 516        delta: IndentDelta,
 517        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
 518    ) -> impl Stream<Item = Result<String>> {
 519        let mut buffer = String::new();
 520        let mut in_leading_whitespace = true;
 521        let mut done = false;
 522        futures::stream::poll_fn(move |cx| {
 523            while !done {
 524                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
 525                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
 526                        (chunk, done)
 527                    }
 528                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
 529                    Poll::Pending => return Poll::Pending,
 530                    _ => return Poll::Ready(None),
 531                };
 532
 533                buffer.push_str(&chunk);
 534
 535                let mut indented_new_text = String::new();
 536                let mut start_ix = 0;
 537                let mut newlines = buffer.match_indices('\n').peekable();
 538                loop {
 539                    let (line_end, is_pending_line) = match newlines.next() {
 540                        Some((ix, _)) => (ix, false),
 541                        None => (buffer.len(), true),
 542                    };
 543                    let line = &buffer[start_ix..line_end];
 544
 545                    if in_leading_whitespace {
 546                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
 547                            // We found a non-whitespace character, adjust
 548                            // indentation based on the delta.
 549                            let new_indent_len =
 550                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
 551                            indented_new_text
 552                                .extend(iter::repeat(delta.character()).take(new_indent_len));
 553                            indented_new_text.push_str(&line[non_whitespace_ix..]);
 554                            in_leading_whitespace = false;
 555                        } else if is_pending_line {
 556                            // We're still in leading whitespace and this line is incomplete.
 557                            // Stop processing until we receive more input.
 558                            break;
 559                        } else {
 560                            // This line is entirely whitespace. Push it without indentation.
 561                            indented_new_text.push_str(line);
 562                        }
 563                    } else {
 564                        indented_new_text.push_str(line);
 565                    }
 566
 567                    if is_pending_line {
 568                        start_ix = line_end;
 569                        break;
 570                    } else {
 571                        in_leading_whitespace = true;
 572                        indented_new_text.push('\n');
 573                        start_ix = line_end + 1;
 574                    }
 575                }
 576                buffer.replace_range(..start_ix, "");
 577
 578                // This was the last chunk, push all the buffered content as-is.
 579                if is_last_chunk {
 580                    indented_new_text.push_str(&buffer);
 581                    buffer.clear();
 582                    done = true;
 583                }
 584
 585                if !indented_new_text.is_empty() {
 586                    return Poll::Ready(Some(Ok(indented_new_text)));
 587                }
 588            }
 589
 590            Poll::Ready(None)
 591        })
 592    }
 593
 594    async fn request(
 595        &self,
 596        mut conversation: LanguageModelRequest,
 597        intent: CompletionIntent,
 598        prompt: String,
 599        cx: &mut AsyncApp,
 600    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
 601        let mut messages_iter = conversation.messages.iter_mut();
 602        if let Some(last_message) = messages_iter.next_back() {
 603            if last_message.role == Role::Assistant {
 604                let old_content_len = last_message.content.len();
 605                last_message
 606                    .content
 607                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
 608                let new_content_len = last_message.content.len();
 609
 610                // We just removed pending tool uses from the content of the
 611                // last message, so it doesn't make sense to cache it anymore
 612                // (e.g., the message will look very different on the next
 613                // request). Thus, we move the flag to the message prior to it,
 614                // as it will still be a valid prefix of the conversation.
 615                if old_content_len != new_content_len && last_message.cache {
 616                    if let Some(prev_message) = messages_iter.next_back() {
 617                        last_message.cache = false;
 618                        prev_message.cache = true;
 619                    }
 620                }
 621
 622                if last_message.content.is_empty() {
 623                    conversation.messages.pop();
 624                }
 625            } else {
 626                debug_panic!(
 627                    "Last message must be an Assistant tool calling! Got {:?}",
 628                    last_message.content
 629                );
 630            }
 631        }
 632
 633        conversation.messages.push(LanguageModelRequestMessage {
 634            role: Role::User,
 635            content: vec![MessageContent::Text(prompt)],
 636            cache: false,
 637        });
 638
 639        // Include tools in the request so that we can take advantage of
 640        // caching when ToolChoice::None is supported.
 641        let mut tool_choice = None;
 642        let mut tools = Vec::new();
 643        if !conversation.tools.is_empty()
 644            && self
 645                .model
 646                .supports_tool_choice(LanguageModelToolChoice::None)
 647        {
 648            tool_choice = Some(LanguageModelToolChoice::None);
 649            tools = conversation.tools.clone();
 650        }
 651
 652        let request = LanguageModelRequest {
 653            thread_id: conversation.thread_id,
 654            prompt_id: conversation.prompt_id,
 655            intent: Some(intent),
 656            mode: conversation.mode,
 657            messages: conversation.messages,
 658            tool_choice,
 659            tools,
 660            stop: Vec::new(),
 661            temperature: None,
 662        };
 663
 664        Ok(self.model.stream_completion_text(request, cx).await?.stream)
 665    }
 666}
 667
 668struct ResolvedOldText {
 669    range: Range<usize>,
 670    indent: LineIndent,
 671}
 672
 673#[derive(Copy, Clone, Debug)]
 674enum IndentDelta {
 675    Spaces(isize),
 676    Tabs(isize),
 677}
 678
 679impl IndentDelta {
 680    fn character(&self) -> char {
 681        match self {
 682            IndentDelta::Spaces(_) => ' ',
 683            IndentDelta::Tabs(_) => '\t',
 684        }
 685    }
 686
 687    fn len(&self) -> isize {
 688        match self {
 689            IndentDelta::Spaces(n) => *n,
 690            IndentDelta::Tabs(n) => *n,
 691        }
 692    }
 693}
 694
 695#[cfg(test)]
 696mod tests {
 697    use super::*;
 698    use fs::FakeFs;
 699    use futures::stream;
 700    use gpui::{AppContext, TestAppContext};
 701    use indoc::indoc;
 702    use language_model::fake_provider::FakeLanguageModel;
 703    use project::{AgentLocation, Project};
 704    use rand::prelude::*;
 705    use rand::rngs::StdRng;
 706    use std::cmp;
 707
 708    #[gpui::test(iterations = 100)]
 709    async fn test_empty_old_text(cx: &mut TestAppContext, mut rng: StdRng) {
 710        let agent = init_test(cx).await;
 711        let buffer = cx.new(|cx| {
 712            Buffer::local(
 713                indoc! {"
 714                    abc
 715                    def
 716                    ghi
 717                "},
 718                cx,
 719            )
 720        });
 721        let (apply, _events) = agent.edit(
 722            buffer.clone(),
 723            String::new(),
 724            &LanguageModelRequest::default(),
 725            &mut cx.to_async(),
 726        );
 727        cx.run_until_parked();
 728
 729        simulate_llm_output(
 730            &agent,
 731            indoc! {"
 732                <old_text></old_text>
 733                <new_text>jkl</new_text>
 734                <old_text>def</old_text>
 735                <new_text>DEF</new_text>
 736            "},
 737            &mut rng,
 738            cx,
 739        );
 740        apply.await.unwrap();
 741
 742        pretty_assertions::assert_eq!(
 743            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 744            indoc! {"
 745                abc
 746                DEF
 747                ghi
 748            "}
 749        );
 750    }
 751
 752    #[gpui::test(iterations = 100)]
 753    async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
 754        let agent = init_test(cx).await;
 755        let buffer = cx.new(|cx| {
 756            Buffer::local(
 757                indoc! {"
 758                    lorem
 759                            ipsum
 760                            dolor
 761                            sit
 762                "},
 763                cx,
 764            )
 765        });
 766        let (apply, _events) = agent.edit(
 767            buffer.clone(),
 768            String::new(),
 769            &LanguageModelRequest::default(),
 770            &mut cx.to_async(),
 771        );
 772        cx.run_until_parked();
 773
 774        simulate_llm_output(
 775            &agent,
 776            indoc! {"
 777                <old_text>
 778                    ipsum
 779                    dolor
 780                    sit
 781                </old_text>
 782                <new_text>
 783                    ipsum
 784                    dolor
 785                    sit
 786                amet
 787                </new_text>
 788            "},
 789            &mut rng,
 790            cx,
 791        );
 792        apply.await.unwrap();
 793
 794        pretty_assertions::assert_eq!(
 795            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 796            indoc! {"
 797                lorem
 798                        ipsum
 799                        dolor
 800                        sit
 801                    amet
 802            "}
 803        );
 804    }
 805
 806    #[gpui::test(iterations = 100)]
 807    async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
 808        let agent = init_test(cx).await;
 809        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 810        let (apply, _events) = agent.edit(
 811            buffer.clone(),
 812            String::new(),
 813            &LanguageModelRequest::default(),
 814            &mut cx.to_async(),
 815        );
 816        cx.run_until_parked();
 817
 818        simulate_llm_output(
 819            &agent,
 820            indoc! {"
 821                <old_text>
 822                def
 823                </old_text>
 824                <new_text>
 825                DEF
 826                </new_text>
 827
 828                <old_text>
 829                DEF
 830                </old_text>
 831                <new_text>
 832                DeF
 833                </new_text>
 834            "},
 835            &mut rng,
 836            cx,
 837        );
 838        apply.await.unwrap();
 839
 840        assert_eq!(
 841            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 842            "abc\nDeF\nghi"
 843        );
 844    }
 845
 846    #[gpui::test(iterations = 100)]
 847    async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
 848        let agent = init_test(cx).await;
 849        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 850        let (apply, _events) = agent.edit(
 851            buffer.clone(),
 852            String::new(),
 853            &LanguageModelRequest::default(),
 854            &mut cx.to_async(),
 855        );
 856        cx.run_until_parked();
 857
 858        simulate_llm_output(
 859            &agent,
 860            indoc! {"
 861                <old_text>
 862                jkl
 863                </old_text>
 864                <new_text>
 865                mno
 866                </new_text>
 867
 868                <old_text>
 869                abc
 870                </old_text>
 871                <new_text>
 872                ABC
 873                </new_text>
 874            "},
 875            &mut rng,
 876            cx,
 877        );
 878        apply.await.unwrap();
 879
 880        assert_eq!(
 881            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 882            "ABC\ndef\nghi"
 883        );
 884    }
 885
    /// Streams a scripted response to the fake model piece by piece and checks,
    /// after each step, the emitted `EditAgentOutputEvent`s, the buffer text, and
    /// the project's agent location. It covers three cases:
    /// - an edit that resolves and is applied;
    /// - a hallucinated `<old_text>` that matches nothing;
    /// - an `<old_text>` whose resolved range grows as more of it streams in.
    #[gpui::test]
    async fn test_edit_events(cx: &mut TestAppContext) {
        let agent = init_test(cx).await;
        let model = agent.model.as_fake();
        let project = agent
            .action_log
            .read_with(cx, |log, _| log.project().clone());
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi\njkl", cx));

        let mut async_cx = cx.to_async();
        let (apply, mut events) = agent.edit(
            buffer.clone(),
            String::new(),
            &LanguageModelRequest::default(),
            &mut async_cx,
        );
        cx.run_until_parked();

        // A partial `<old_text>` produces no events, no edits, and no agent location yet.
        model.stream_last_completion_response("<old_text>a");
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            None
        );

        // Once `<old_text>` closes, the matched range (row 0, cols 0..3) is reported
        // and the agent location moves to its end, but the buffer is not edited yet.
        model.stream_last_completion_response("bc</old_text>");
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
                cx,
                |buffer, _| buffer.anchor_before(Point::new(0, 0))
                    ..buffer.anchor_before(Point::new(0, 3))
            ))]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abc\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
            })
        );

        // New text is applied while it is still streaming: the buffer already
        // reflects the partial replacement "abX".
        model.stream_last_completion_response("<new_text>abX");
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXc\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 3)))
            })
        );

        // The rest of the new text lands and the agent location advances to (0, 5).
        model.stream_last_completion_response("cY");
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
            })
        );

        // A second, hallucinated `<old_text>` that is still streaming produces no events.
        model.stream_last_completion_response("</new_text>");
        model.stream_last_completion_response("<old_text>hall");
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
            })
        );

        // When the hallucinated `<old_text>` closes, `UnresolvedEditRange` is emitted.
        // The buffer and the agent location are left unchanged.
        model.stream_last_completion_response("ucinated old</old_text>");
        model.stream_last_completion_response("<new_text>");
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::UnresolvedEditRange]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
            })
        );

        // The `<new_text>` paired with the unresolved range is discarded: no events, no edits.
        model.stream_last_completion_response("hallucinated new</new_");
        model.stream_last_completion_response("text>");
        cx.run_until_parked();
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(0, 5)))
            })
        );

        // A partial `<old_text>` is already resolved to a range (row 2, cols 0..3)
        // before the closing tag arrives.
        model.stream_last_completion_response("<old_text>\nghi\nj");
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
                cx,
                |buffer, _| buffer.anchor_before(Point::new(2, 0))
                    ..buffer.anchor_before(Point::new(2, 3))
            ))]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
            })
        );

        // Completing the `<old_text>` widens the resolved range to (2, 0)..(3, 3).
        model.stream_last_completion_response("kl</old_text>");
        model.stream_last_completion_response("<new_text>");
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::ResolvingEditRange(buffer.read_with(
                cx,
                |buffer, _| buffer.anchor_before(Point::new(2, 0))
                    ..buffer.anchor_before(Point::new(3, 3))
            ))]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nghi\njkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(3, 3)))
            })
        );

        // The replacement text replaces the whole resolved two-line range.
        model.stream_last_completion_response("GHI</new_text>");
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
            })
        );

        // Ending the stream completes `apply`. It emits no further events and
        // makes no further buffer changes.
        model.end_last_completion_stream();
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "abXcY\ndef\nGHI"
        );
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(2, 3)))
            })
        );
    }
1097
    /// Drives `overwrite_with_chunks` through an unbounded channel. After each
    /// chunk it checks the emitted events, the buffer text, and the project's
    /// agent location. The existing buffer contents are cleared before any chunk
    /// arrives, and the surrounding ``` fences never reach the buffer.
    #[gpui::test]
    async fn test_overwrite_events(cx: &mut TestAppContext) {
        let agent = init_test(cx).await;
        let project = agent
            .action_log
            .read_with(cx, |log, _| log.project().clone());
        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let (apply, mut events) = agent.overwrite_with_chunks(
            buffer.clone(),
            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
            &mut cx.to_async(),
        );

        // Before any chunk is sent, the buffer is already emptied and an `Edited`
        // event is emitted; the agent location is pinned to `Anchor::MAX`.
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            ""
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX
            })
        );

        // The opening fence is not written. The trailing newline does not show up
        // until more text follows it.
        chunks_tx.unbounded_send("```\njkl\n").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "jkl"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX
            })
        );

        // The previously withheld newline appears once "mno" arrives.
        chunks_tx.unbounded_send("mno\n").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "jkl\nmno"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX
            })
        );

        // Neither the closing fence nor the newline before it is written.
        chunks_tx.unbounded_send("pqr\n```").unwrap();
        cx.run_until_parked();
        assert_eq!(
            drain_events(&mut events),
            vec![EditAgentOutputEvent::Edited]
        );
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "jkl\nmno\npqr"
        );
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX
            })
        );

        // Dropping the sender closes the chunk stream, which lets `apply` finish.
        // No further events or edits follow.
        drop(chunks_tx);
        apply.await.unwrap();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
            "jkl\nmno\npqr"
        );
        assert_eq!(drain_events(&mut events), vec![]);
        assert_eq!(
            project.read_with(cx, |project, _| project.agent_location()),
            Some(AgentLocation {
                buffer: buffer.downgrade(),
                position: language::Anchor::MAX
            })
        );
    }
1198
1199    #[gpui::test(iterations = 100)]
1200    async fn test_indent_new_text_chunks(mut rng: StdRng) {
1201        let chunks = to_random_chunks(&mut rng, "    abc\n  def\n      ghi");
1202        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1203            Ok(EditParserEvent::NewTextChunk {
1204                chunk: chunk.clone(),
1205                done: index == chunks.len() - 1,
1206            })
1207        }));
1208        let indented_chunks =
1209            EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1210                .collect::<Vec<_>>()
1211                .await;
1212        let new_text = indented_chunks
1213            .into_iter()
1214            .collect::<Result<String>>()
1215            .unwrap();
1216        assert_eq!(new_text, "      abc\n    def\n        ghi");
1217    }
1218
1219    #[gpui::test(iterations = 100)]
1220    async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1221        let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1222        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1223            Ok(EditParserEvent::NewTextChunk {
1224                chunk: chunk.clone(),
1225                done: index == chunks.len() - 1,
1226            })
1227        }));
1228        let indented_chunks =
1229            EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1230                .collect::<Vec<_>>()
1231                .await;
1232        let new_text = indented_chunks
1233            .into_iter()
1234            .collect::<Result<String>>()
1235            .unwrap();
1236        assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1237    }
1238
1239    #[gpui::test(iterations = 100)]
1240    async fn test_random_indents(mut rng: StdRng) {
1241        let len = rng.gen_range(1..=100);
1242        let new_text = util::RandomCharIter::new(&mut rng)
1243            .with_simple_text()
1244            .take(len)
1245            .collect::<String>();
1246        let new_text = new_text
1247            .split('\n')
1248            .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1249            .collect::<Vec<_>>()
1250            .join("\n");
1251        let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1252
1253        let chunks = to_random_chunks(&mut rng, &new_text);
1254        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1255            Ok(EditParserEvent::NewTextChunk {
1256                chunk: chunk.clone(),
1257                done: index == chunks.len() - 1,
1258            })
1259        }));
1260        let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1261            .collect::<Vec<_>>()
1262            .await;
1263        let actual_reindented_text = reindented_chunks
1264            .into_iter()
1265            .collect::<Result<String>>()
1266            .unwrap();
1267        let expected_reindented_text = new_text
1268            .split('\n')
1269            .map(|line| {
1270                if let Some(ix) = line.find(|c| c != ' ') {
1271                    let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1272                    format!("{}{}", " ".repeat(new_indent), &line[ix..])
1273                } else {
1274                    line.to_string()
1275                }
1276            })
1277            .collect::<Vec<_>>()
1278            .join("\n");
1279        assert_eq!(actual_reindented_text, expected_reindented_text);
1280    }
1281
1282    fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1283        let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1284        let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1285        chunk_indices.sort();
1286        chunk_indices.push(input.len());
1287
1288        let mut chunks = Vec::new();
1289        let mut last_ix = 0;
1290        for chunk_ix in chunk_indices {
1291            chunks.push(input[last_ix..chunk_ix].to_string());
1292            last_ix = chunk_ix;
1293        }
1294        chunks
1295    }
1296
1297    fn simulate_llm_output(
1298        agent: &EditAgent,
1299        output: &str,
1300        rng: &mut StdRng,
1301        cx: &mut TestAppContext,
1302    ) {
1303        let executor = cx.executor();
1304        let chunks = to_random_chunks(rng, output);
1305        let model = agent.model.clone();
1306        cx.background_spawn(async move {
1307            for chunk in chunks {
1308                executor.simulate_random_delay().await;
1309                model.as_fake().stream_last_completion_response(chunk);
1310            }
1311            model.as_fake().end_last_completion_stream();
1312        })
1313        .detach();
1314    }
1315
1316    async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1317        cx.update(settings::init);
1318        cx.update(Project::init_settings);
1319        let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1320        let model = Arc::new(FakeLanguageModel::default());
1321        let action_log = cx.new(|_| ActionLog::new(project.clone()));
1322        EditAgent::new(model, project, action_log, Templates::new())
1323    }
1324
1325    fn drain_events(
1326        stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
1327    ) -> Vec<EditAgentOutputEvent> {
1328        let mut events = Vec::new();
1329        while let Ok(Some(event)) = stream.try_next() {
1330            events.push(event);
1331        }
1332        events
1333    }
1334}