edit_agent.rs

   1mod edit_parser;
   2#[cfg(test)]
   3mod evals;
   4
   5use crate::{Template, Templates};
   6use aho_corasick::AhoCorasick;
   7use anyhow::Result;
   8use assistant_tool::ActionLog;
   9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
  10use futures::{
  11    Stream, StreamExt,
  12    channel::mpsc::{self, UnboundedReceiver},
  13    pin_mut,
  14    stream::BoxStream,
  15};
  16use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
  17use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
  18use language_model::{
  19    LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
  20    MessageContent, Role,
  21};
  22use serde::Serialize;
  23use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
  24use streaming_diff::{CharOperation, StreamingDiff};
  25
/// Values rendered into the `create_file_prompt.hbs` template.
///
/// Built by `EditAgent::overwrite` before it sends its request to the model.
#[derive(Serialize)]
struct CreateFilePromptTemplate {
    /// Result of `resolve_file_path` on the buffer snapshot; `None` when that
    /// call yields no path.
    path: Option<PathBuf>,
    /// Caller-supplied description, passed through unchanged.
    edit_description: String,
}

impl Template for CreateFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "create_file_prompt.hbs";
}
  35
/// Values rendered into the `edit_file_prompt.hbs` template.
///
/// Built by `EditAgent::edit` before it sends its request to the model.
#[derive(Serialize)]
struct EditFilePromptTemplate {
    /// Result of `resolve_file_path` on the buffer snapshot; `None` when that
    /// call yields no path.
    path: Option<PathBuf>,
    /// Caller-supplied description, passed through unchanged.
    edit_description: String,
}

impl Template for EditFilePromptTemplate {
    const TEMPLATE_NAME: &'static str = "edit_file_prompt.hbs";
}
  45
/// Progress notifications sent to callers while edits are being applied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EditAgentOutputEvent {
    /// A batch of changes was written to the buffer.
    Edited,
    /// The old-text query carried here could not be located in the buffer,
    /// so the edit that referenced it was skipped.
    OldTextNotFound(SharedString),
}
  51
/// Final result of an `EditAgent` run.
#[derive(Clone, Debug)]
pub struct EditAgentOutput {
    /// Concatenation of every text chunk received from the model.
    pub _raw_edits: String,
    /// Metrics returned by `EditParser::finish`, or the default value when the
    /// response was written to the buffer without going through the parser.
    pub _parser_metrics: EditParserMetrics,
}
  57
/// Requests edits from a language model and streams them into a buffer.
#[derive(Clone)]
pub struct EditAgent {
    /// Model whose `stream_completion_text` supplies the edit chunks.
    model: Arc<dyn LanguageModel>,
    /// Log that is told about every buffer tracked and edited by the agent.
    action_log: Entity<ActionLog>,
    /// Template registry used to render the prompts.
    templates: Arc<Templates>,
}
  64
  65impl EditAgent {
  66    pub fn new(
  67        model: Arc<dyn LanguageModel>,
  68        action_log: Entity<ActionLog>,
  69        templates: Arc<Templates>,
  70    ) -> Self {
  71        EditAgent {
  72            model,
  73            action_log,
  74            templates,
  75        }
  76    }
  77
  78    pub fn overwrite(
  79        &self,
  80        buffer: Entity<Buffer>,
  81        edit_description: String,
  82        previous_messages: Vec<LanguageModelRequestMessage>,
  83        cx: &mut AsyncApp,
  84    ) -> (
  85        Task<Result<EditAgentOutput>>,
  86        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
  87    ) {
  88        let this = self.clone();
  89        let (events_tx, events_rx) = mpsc::unbounded();
  90        let output = cx.spawn(async move |cx| {
  91            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
  92            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
  93            let prompt = CreateFilePromptTemplate {
  94                path,
  95                edit_description,
  96            }
  97            .render(&this.templates)?;
  98            let new_chunks = this.request(previous_messages, prompt, cx).await?;
  99
 100            let (output, mut inner_events) = this.replace_text_with_chunks(buffer, new_chunks, cx);
 101            while let Some(event) = inner_events.next().await {
 102                events_tx.unbounded_send(event).ok();
 103            }
 104            output.await
 105        });
 106        (output, events_rx)
 107    }
 108
 109    fn replace_text_with_chunks(
 110        &self,
 111        buffer: Entity<Buffer>,
 112        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 113        cx: &mut AsyncApp,
 114    ) -> (
 115        Task<Result<EditAgentOutput>>,
 116        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 117    ) {
 118        let (output_events_tx, output_events_rx) = mpsc::unbounded();
 119        let this = self.clone();
 120        let task = cx.spawn(async move |cx| {
 121            // Ensure the buffer is tracked by the action log.
 122            this.action_log
 123                .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
 124
 125            cx.update(|cx| {
 126                buffer.update(cx, |buffer, cx| buffer.set_text("", cx));
 127                this.action_log
 128                    .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 129            })?;
 130
 131            let mut raw_edits = String::new();
 132            pin_mut!(edit_chunks);
 133            while let Some(chunk) = edit_chunks.next().await {
 134                let chunk = chunk?;
 135                raw_edits.push_str(&chunk);
 136                cx.update(|cx| {
 137                    buffer.update(cx, |buffer, cx| buffer.append(chunk, cx));
 138                    this.action_log
 139                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
 140                })?;
 141                output_events_tx
 142                    .unbounded_send(EditAgentOutputEvent::Edited)
 143                    .ok();
 144            }
 145
 146            Ok(EditAgentOutput {
 147                _raw_edits: raw_edits,
 148                _parser_metrics: EditParserMetrics::default(),
 149            })
 150        });
 151        (task, output_events_rx)
 152    }
 153
 154    pub fn edit(
 155        &self,
 156        buffer: Entity<Buffer>,
 157        edit_description: String,
 158        previous_messages: Vec<LanguageModelRequestMessage>,
 159        cx: &mut AsyncApp,
 160    ) -> (
 161        Task<Result<EditAgentOutput>>,
 162        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 163    ) {
 164        let this = self.clone();
 165        let (events_tx, events_rx) = mpsc::unbounded();
 166        let output = cx.spawn(async move |cx| {
 167            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 168            let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
 169            let prompt = EditFilePromptTemplate {
 170                path,
 171                edit_description,
 172            }
 173            .render(&this.templates)?;
 174            let edit_chunks = this.request(previous_messages, prompt, cx).await?;
 175
 176            let (output, mut inner_events) = this.apply_edit_chunks(buffer, edit_chunks, cx);
 177            while let Some(event) = inner_events.next().await {
 178                events_tx.unbounded_send(event).ok();
 179            }
 180            output.await
 181        });
 182        (output, events_rx)
 183    }
 184
 185    fn apply_edit_chunks(
 186        &self,
 187        buffer: Entity<Buffer>,
 188        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 189        cx: &mut AsyncApp,
 190    ) -> (
 191        Task<Result<EditAgentOutput>>,
 192        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
 193    ) {
 194        let (output_events_tx, output_events_rx) = mpsc::unbounded();
 195        let this = self.clone();
 196        let task = cx.spawn(async move |mut cx| {
 197            this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
 198                .await
 199        });
 200        (task, output_events_rx)
 201    }
 202
 203    async fn apply_edits_internal(
 204        &self,
 205        buffer: Entity<Buffer>,
 206        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 207        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
 208        cx: &mut AsyncApp,
 209    ) -> Result<EditAgentOutput> {
 210        // Ensure the buffer is tracked by the action log.
 211        self.action_log
 212            .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;
 213
 214        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
 215        while let Some(edit_event) = edit_events.next().await {
 216            let EditParserEvent::OldText(old_text_query) = edit_event? else {
 217                continue;
 218            };
 219            let old_text_query = SharedString::from(old_text_query);
 220
 221            let (edits_tx, edits_rx) = mpsc::unbounded();
 222            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 223            let old_range = cx
 224                .background_spawn({
 225                    let snapshot = snapshot.clone();
 226                    let old_text_query = old_text_query.clone();
 227                    async move { Self::resolve_location(&snapshot, &old_text_query) }
 228                })
 229                .await;
 230            let Some(old_range) = old_range else {
 231                // We couldn't find the old text in the buffer. Report the error.
 232                output_events
 233                    .unbounded_send(EditAgentOutputEvent::OldTextNotFound(old_text_query))
 234                    .ok();
 235                continue;
 236            };
 237
 238            let compute_edits = cx.background_spawn(async move {
 239                let buffer_start_indent =
 240                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
 241                let old_text_start_indent = old_text_query
 242                    .lines()
 243                    .next()
 244                    .map_or(buffer_start_indent, |line| {
 245                        LineIndent::from_iter(line.chars())
 246                    });
 247                let indent_delta = if buffer_start_indent.tabs > 0 {
 248                    IndentDelta::Tabs(
 249                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
 250                    )
 251                } else {
 252                    IndentDelta::Spaces(
 253                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
 254                    )
 255                };
 256
 257                let old_text = snapshot
 258                    .text_for_range(old_range.clone())
 259                    .collect::<String>();
 260                let mut diff = StreamingDiff::new(old_text);
 261                let mut edit_start = old_range.start;
 262                let mut new_text_chunks =
 263                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
 264                let mut done = false;
 265                while !done {
 266                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
 267                    {
 268                        diff.push_new(&new_text_chunk?)
 269                    } else {
 270                        done = true;
 271                        mem::take(&mut diff).finish()
 272                    };
 273
 274                    for op in char_operations {
 275                        match op {
 276                            CharOperation::Insert { text } => {
 277                                let edit_start = snapshot.anchor_after(edit_start);
 278                                edits_tx.unbounded_send((edit_start..edit_start, text))?;
 279                            }
 280                            CharOperation::Delete { bytes } => {
 281                                let edit_end = edit_start + bytes;
 282                                let edit_range = snapshot.anchor_after(edit_start)
 283                                    ..snapshot.anchor_before(edit_end);
 284                                edit_start = edit_end;
 285                                edits_tx.unbounded_send((edit_range, String::new()))?;
 286                            }
 287                            CharOperation::Keep { bytes } => edit_start += bytes,
 288                        }
 289                    }
 290                }
 291
 292                drop(new_text_chunks);
 293                anyhow::Ok(edit_events)
 294            });
 295
 296            // TODO: group all edits into one transaction
 297            let mut edits_rx = edits_rx.ready_chunks(32);
 298            while let Some(edits) = edits_rx.next().await {
 299                // Edit the buffer and report edits to the action log as part of the
 300                // same effect cycle, otherwise the edit will be reported as if the
 301                // user made it.
 302                cx.update(|cx| {
 303                    buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
 304                    self.action_log
 305                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
 306                })?;
 307                output_events
 308                    .unbounded_send(EditAgentOutputEvent::Edited)
 309                    .ok();
 310            }
 311
 312            edit_events = compute_edits.await?;
 313        }
 314
 315        output.await
 316    }
 317
 318    fn parse_edit_chunks(
 319        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
 320        cx: &mut AsyncApp,
 321    ) -> (
 322        Task<Result<EditAgentOutput>>,
 323        UnboundedReceiver<Result<EditParserEvent>>,
 324    ) {
 325        let (tx, rx) = mpsc::unbounded();
 326        let output = cx.background_spawn(async move {
 327            pin_mut!(chunks);
 328
 329            let mut parser = EditParser::new();
 330            let mut raw_edits = String::new();
 331            while let Some(chunk) = chunks.next().await {
 332                match chunk {
 333                    Ok(chunk) => {
 334                        raw_edits.push_str(&chunk);
 335                        for event in parser.push(&chunk) {
 336                            tx.unbounded_send(Ok(event))?;
 337                        }
 338                    }
 339                    Err(error) => {
 340                        tx.unbounded_send(Err(error.into()))?;
 341                    }
 342                }
 343            }
 344            Ok(EditAgentOutput {
 345                _raw_edits: raw_edits,
 346                _parser_metrics: parser.finish(),
 347            })
 348        });
 349        (output, rx)
 350    }
 351
 352    fn reindent_new_text_chunks(
 353        delta: IndentDelta,
 354        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
 355    ) -> impl Stream<Item = Result<String>> {
 356        let mut buffer = String::new();
 357        let mut in_leading_whitespace = true;
 358        let mut done = false;
 359        futures::stream::poll_fn(move |cx| {
 360            while !done {
 361                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
 362                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
 363                        (chunk, done)
 364                    }
 365                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
 366                    Poll::Pending => return Poll::Pending,
 367                    _ => return Poll::Ready(None),
 368                };
 369
 370                buffer.push_str(&chunk);
 371
 372                let mut indented_new_text = String::new();
 373                let mut start_ix = 0;
 374                let mut newlines = buffer.match_indices('\n').peekable();
 375                loop {
 376                    let (line_end, is_pending_line) = match newlines.next() {
 377                        Some((ix, _)) => (ix, false),
 378                        None => (buffer.len(), true),
 379                    };
 380                    let line = &buffer[start_ix..line_end];
 381
 382                    if in_leading_whitespace {
 383                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
 384                            // We found a non-whitespace character, adjust
 385                            // indentation based on the delta.
 386                            let new_indent_len =
 387                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
 388                            indented_new_text
 389                                .extend(iter::repeat(delta.character()).take(new_indent_len));
 390                            indented_new_text.push_str(&line[non_whitespace_ix..]);
 391                            in_leading_whitespace = false;
 392                        } else if is_pending_line {
 393                            // We're still in leading whitespace and this line is incomplete.
 394                            // Stop processing until we receive more input.
 395                            break;
 396                        } else {
 397                            // This line is entirely whitespace. Push it without indentation.
 398                            indented_new_text.push_str(line);
 399                        }
 400                    } else {
 401                        indented_new_text.push_str(line);
 402                    }
 403
 404                    if is_pending_line {
 405                        start_ix = line_end;
 406                        break;
 407                    } else {
 408                        in_leading_whitespace = true;
 409                        indented_new_text.push('\n');
 410                        start_ix = line_end + 1;
 411                    }
 412                }
 413                buffer.replace_range(..start_ix, "");
 414
 415                // This was the last chunk, push all the buffered content as-is.
 416                if is_last_chunk {
 417                    indented_new_text.push_str(&buffer);
 418                    buffer.clear();
 419                    done = true;
 420                }
 421
 422                if !indented_new_text.is_empty() {
 423                    return Poll::Ready(Some(Ok(indented_new_text)));
 424                }
 425            }
 426
 427            Poll::Ready(None)
 428        })
 429    }
 430
 431    async fn request(
 432        &self,
 433        mut messages: Vec<LanguageModelRequestMessage>,
 434        prompt: String,
 435        cx: &mut AsyncApp,
 436    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
 437        let mut message_content = Vec::new();
 438        if let Some(last_message) = messages.last_mut() {
 439            if last_message.role == Role::Assistant {
 440                last_message
 441                    .content
 442                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
 443                if last_message.content.is_empty() {
 444                    messages.pop();
 445                }
 446            }
 447        }
 448        message_content.push(MessageContent::Text(prompt));
 449        messages.push(LanguageModelRequestMessage {
 450            role: Role::User,
 451            content: message_content,
 452            cache: false,
 453        });
 454
 455        let request = LanguageModelRequest {
 456            messages,
 457            ..Default::default()
 458        };
 459        Ok(self.model.stream_completion_text(request, cx).await?.stream)
 460    }
 461
 462    fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
 463        let range = Self::resolve_location_exact(buffer, search_query)
 464            .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;
 465
 466        // Expand the range to include entire lines.
 467        let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
 468        start.column = 0;
 469        let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
 470        if end.column > 0 {
 471            end.column = buffer.line_len(end.row);
 472        }
 473
 474        Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
 475    }
 476
 477    fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
 478        let search = AhoCorasick::new([search_query]).ok()?;
 479        let mat = search
 480            .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
 481            .next()?
 482            .expect("buffer can't error");
 483        Some(mat.range())
 484    }
 485
 486    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
 487        const INSERTION_COST: u32 = 3;
 488        const DELETION_COST: u32 = 10;
 489
 490        let buffer_line_count = buffer.max_point().row as usize + 1;
 491        let query_line_count = search_query.lines().count();
 492        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
 493        let mut leading_deletion_cost = 0_u32;
 494        for (row, query_line) in search_query.lines().enumerate() {
 495            let query_line = query_line.trim();
 496            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
 497            matrix.set(
 498                row + 1,
 499                0,
 500                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
 501            );
 502
 503            let mut buffer_lines = buffer.as_rope().chunks().lines();
 504            let mut col = 0;
 505            while let Some(buffer_line) = buffer_lines.next() {
 506                let buffer_line = buffer_line.trim();
 507                let up = SearchState::new(
 508                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
 509                    SearchDirection::Up,
 510                );
 511                let left = SearchState::new(
 512                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
 513                    SearchDirection::Left,
 514                );
 515                let diagonal = SearchState::new(
 516                    if fuzzy_eq(query_line, buffer_line) {
 517                        matrix.get(row, col).cost
 518                    } else {
 519                        matrix
 520                            .get(row, col)
 521                            .cost
 522                            .saturating_add(DELETION_COST + INSERTION_COST)
 523                    },
 524                    SearchDirection::Diagonal,
 525                );
 526                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
 527                col += 1;
 528            }
 529        }
 530
 531        // Traceback to find the best match
 532        let mut buffer_row_end = buffer_line_count as u32;
 533        let mut best_cost = u32::MAX;
 534        for col in 1..=buffer_line_count {
 535            let cost = matrix.get(query_line_count, col).cost;
 536            if cost < best_cost {
 537                best_cost = cost;
 538                buffer_row_end = col as u32;
 539            }
 540        }
 541
 542        let mut matched_lines = 0;
 543        let mut query_row = query_line_count;
 544        let mut buffer_row_start = buffer_row_end;
 545        while query_row > 0 && buffer_row_start > 0 {
 546            let current = matrix.get(query_row, buffer_row_start as usize);
 547            match current.direction {
 548                SearchDirection::Diagonal => {
 549                    query_row -= 1;
 550                    buffer_row_start -= 1;
 551                    matched_lines += 1;
 552                }
 553                SearchDirection::Up => {
 554                    query_row -= 1;
 555                }
 556                SearchDirection::Left => {
 557                    buffer_row_start -= 1;
 558                }
 559            }
 560        }
 561
 562        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
 563        let matched_ratio =
 564            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
 565        if matched_ratio >= 0.8 {
 566            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
 567            let buffer_end_ix = buffer.point_to_offset(Point::new(
 568                buffer_row_end - 1,
 569                buffer.line_len(buffer_row_end - 1),
 570            ));
 571            Some(buffer_start_ix..buffer_end_ix)
 572        } else {
 573            None
 574        }
 575    }
 576}
 577
 578fn fuzzy_eq(left: &str, right: &str) -> bool {
 579    let min_levenshtein = left.len().abs_diff(right.len());
 580    let min_normalized_levenshtein =
 581        1. - (min_levenshtein as f32 / cmp::max(left.len(), right.len()) as f32);
 582    if min_normalized_levenshtein < 0.8 {
 583        return false;
 584    }
 585
 586    strsim::normalized_levenshtein(left, right) >= 0.8
 587}
 588
 589#[derive(Copy, Clone, Debug)]
 590enum IndentDelta {
 591    Spaces(isize),
 592    Tabs(isize),
 593}
 594
 595impl IndentDelta {
 596    fn character(&self) -> char {
 597        match self {
 598            IndentDelta::Spaces(_) => ' ',
 599            IndentDelta::Tabs(_) => '\t',
 600        }
 601    }
 602
 603    fn len(&self) -> isize {
 604        match self {
 605            IndentDelta::Spaces(n) => *n,
 606            IndentDelta::Tabs(n) => *n,
 607        }
 608    }
 609}
 610
 611#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
 612enum SearchDirection {
 613    Up,
 614    Left,
 615    Diagonal,
 616}
 617
 618#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 619struct SearchState {
 620    cost: u32,
 621    direction: SearchDirection,
 622}
 623
 624impl SearchState {
 625    fn new(cost: u32, direction: SearchDirection) -> Self {
 626        Self { cost, direction }
 627    }
 628}
 629
 630struct SearchMatrix {
 631    cols: usize,
 632    data: Vec<SearchState>,
 633}
 634
 635impl SearchMatrix {
 636    fn new(rows: usize, cols: usize) -> Self {
 637        SearchMatrix {
 638            cols,
 639            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
 640        }
 641    }
 642
 643    fn get(&self, row: usize, col: usize) -> SearchState {
 644        self.data[row * self.cols + col]
 645    }
 646
 647    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
 648        self.data[row * self.cols + col] = cost;
 649    }
 650}
 651
 652#[cfg(test)]
 653mod tests {
 654    use super::*;
 655    use fs::FakeFs;
 656    use futures::stream;
 657    use gpui::{App, AppContext, TestAppContext};
 658    use indoc::indoc;
 659    use language_model::fake_provider::FakeLanguageModel;
 660    use project::Project;
 661    use rand::prelude::*;
 662    use rand::rngs::StdRng;
 663    use std::cmp;
 664    use unindent::Unindent;
 665    use util::test::{generate_marked_text, marked_text_ranges};
 666
 667    #[gpui::test(iterations = 100)]
 668    async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
 669        let agent = init_test(cx).await;
 670        let buffer = cx.new(|cx| {
 671            Buffer::local(
 672                indoc! {"
 673                    lorem
 674                            ipsum
 675                            dolor
 676                            sit
 677                "},
 678                cx,
 679            )
 680        });
 681        let raw_edits = simulate_llm_output(
 682            indoc! {"
 683                <old_text>
 684                    ipsum
 685                    dolor
 686                    sit
 687                </old_text>
 688                <new_text>
 689                    ipsum
 690                    dolor
 691                    sit
 692                amet
 693                </new_text>
 694            "},
 695            &mut rng,
 696            cx,
 697        );
 698        let (apply, _events) =
 699            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
 700        apply.await.unwrap();
 701        pretty_assertions::assert_eq!(
 702            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 703            indoc! {"
 704                lorem
 705                        ipsum
 706                        dolor
 707                        sit
 708                    amet
 709            "}
 710        );
 711    }
 712
 713    #[gpui::test(iterations = 100)]
 714    async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
 715        let agent = init_test(cx).await;
 716        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 717        let raw_edits = simulate_llm_output(
 718            indoc! {"
 719                <old_text>
 720                def
 721                </old_text>
 722                <new_text>
 723                DEF
 724                </new_text>
 725
 726                <old_text>
 727                DEF
 728                </old_text>
 729                <new_text>
 730                DeF
 731                </new_text>
 732            "},
 733            &mut rng,
 734            cx,
 735        );
 736        let (apply, _events) =
 737            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
 738        apply.await.unwrap();
 739        assert_eq!(
 740            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 741            "abc\nDeF\nghi"
 742        );
 743    }
 744
 745    #[gpui::test(iterations = 100)]
 746    async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
 747        let agent = init_test(cx).await;
 748        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 749        let raw_edits = simulate_llm_output(
 750            indoc! {"
 751                <old_text>
 752                jkl
 753                </old_text>
 754                <new_text>
 755                mno
 756                </new_text>
 757
 758                <old_text>
 759                abc
 760                </old_text>
 761                <new_text>
 762                ABC
 763                </new_text>
 764            "},
 765            &mut rng,
 766            cx,
 767        );
 768        let (apply, _events) =
 769            agent.apply_edit_chunks(buffer.clone(), raw_edits, &mut cx.to_async());
 770        apply.await.unwrap();
 771        assert_eq!(
 772            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 773            "ABC\ndef\nghi"
 774        );
 775    }
 776
 777    #[gpui::test]
 778    async fn test_events(cx: &mut TestAppContext) {
 779        let agent = init_test(cx).await;
 780        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 781        let (chunks_tx, chunks_rx) = mpsc::unbounded();
 782        let (apply, mut events) = agent.apply_edit_chunks(
 783            buffer.clone(),
 784            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
 785            &mut cx.to_async(),
 786        );
 787
 788        chunks_tx.unbounded_send("<old_text>a").unwrap();
 789        cx.run_until_parked();
 790        assert_eq!(drain_events(&mut events), vec![]);
 791        assert_eq!(
 792            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 793            "abc\ndef\nghi"
 794        );
 795
 796        chunks_tx.unbounded_send("bc</old_text>").unwrap();
 797        cx.run_until_parked();
 798        assert_eq!(drain_events(&mut events), vec![]);
 799        assert_eq!(
 800            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 801            "abc\ndef\nghi"
 802        );
 803
 804        chunks_tx.unbounded_send("<new_text>abX").unwrap();
 805        cx.run_until_parked();
 806        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 807        assert_eq!(
 808            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 809            "abXc\ndef\nghi"
 810        );
 811
 812        chunks_tx.unbounded_send("cY").unwrap();
 813        cx.run_until_parked();
 814        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 815        assert_eq!(
 816            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 817            "abXcY\ndef\nghi"
 818        );
 819
 820        chunks_tx.unbounded_send("</new_text>").unwrap();
 821        chunks_tx.unbounded_send("<old_text>hall").unwrap();
 822        cx.run_until_parked();
 823        assert_eq!(drain_events(&mut events), vec![]);
 824        assert_eq!(
 825            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 826            "abXcY\ndef\nghi"
 827        );
 828
 829        chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
 830        chunks_tx.unbounded_send("<new_text>").unwrap();
 831        cx.run_until_parked();
 832        assert_eq!(
 833            drain_events(&mut events),
 834            vec![EditAgentOutputEvent::OldTextNotFound(
 835                "hallucinated old".into()
 836            )]
 837        );
 838        assert_eq!(
 839            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 840            "abXcY\ndef\nghi"
 841        );
 842
 843        chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
 844        chunks_tx.unbounded_send("text>").unwrap();
 845        cx.run_until_parked();
 846        assert_eq!(drain_events(&mut events), vec![]);
 847        assert_eq!(
 848            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 849            "abXcY\ndef\nghi"
 850        );
 851
 852        chunks_tx.unbounded_send("<old_text>gh").unwrap();
 853        chunks_tx.unbounded_send("i</old_text>").unwrap();
 854        chunks_tx.unbounded_send("<new_text>").unwrap();
 855        cx.run_until_parked();
 856        assert_eq!(drain_events(&mut events), vec![]);
 857        assert_eq!(
 858            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 859            "abXcY\ndef\nghi"
 860        );
 861
 862        chunks_tx.unbounded_send("GHI</new_text>").unwrap();
 863        cx.run_until_parked();
 864        assert_eq!(
 865            drain_events(&mut events),
 866            vec![EditAgentOutputEvent::Edited]
 867        );
 868        assert_eq!(
 869            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 870            "abXcY\ndef\nGHI"
 871        );
 872
 873        drop(chunks_tx);
 874        apply.await.unwrap();
 875        assert_eq!(
 876            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 877            "abXcY\ndef\nGHI"
 878        );
 879        assert_eq!(drain_events(&mut events), vec![]);
 880
 881        fn drain_events(
 882            stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
 883        ) -> Vec<EditAgentOutputEvent> {
 884            let mut events = Vec::new();
 885            while let Ok(Some(event)) = stream.try_next() {
 886                events.push(event);
 887            }
 888            events
 889        }
 890    }
 891
 892    #[gpui::test]
 893    fn test_resolve_location(cx: &mut App) {
 894        assert_location_resolution(
 895            concat!(
 896                "    Lorem\n",
 897                "«    ipsum»\n",
 898                "    dolor sit amet\n",
 899                "    consecteur",
 900            ),
 901            "ipsum",
 902            cx,
 903        );
 904
 905        assert_location_resolution(
 906            concat!(
 907                "    Lorem\n",
 908                "«    ipsum\n",
 909                "    dolor sit amet»\n",
 910                "    consecteur",
 911            ),
 912            "ipsum\ndolor sit amet",
 913            cx,
 914        );
 915
 916        assert_location_resolution(
 917            &"
 918            «fn foo1(a: usize) -> usize {
 919                40
 920 921
 922            fn foo2(b: usize) -> usize {
 923                42
 924            }
 925            "
 926            .unindent(),
 927            "fn foo1(a: usize) -> u32 {\n40\n}",
 928            cx,
 929        );
 930
 931        assert_location_resolution(
 932            &"
 933            class Something {
 934                one() { return 1; }
 935            «    two() { return 2222; }
 936                three() { return 333; }
 937                four() { return 4444; }
 938                five() { return 5555; }
 939                six() { return 6666; }»
 940                seven() { return 7; }
 941                eight() { return 8; }
 942            }
 943            "
 944            .unindent(),
 945            &"
 946                two() { return 2222; }
 947                four() { return 4444; }
 948                five() { return 5555; }
 949                six() { return 6666; }
 950            "
 951            .unindent(),
 952            cx,
 953        );
 954
 955        assert_location_resolution(
 956            &"
 957                use std::ops::Range;
 958                use std::sync::Mutex;
 959                use std::{
 960                    collections::HashMap,
 961                    env,
 962                    ffi::{OsStr, OsString},
 963                    fs,
 964                    io::{BufRead, BufReader},
 965                    mem,
 966                    path::{Path, PathBuf},
 967                    process::Command,
 968                    sync::LazyLock,
 969                    time::SystemTime,
 970                };
 971            "
 972            .unindent(),
 973            &"
 974                use std::collections::{HashMap, HashSet};
 975                use std::ffi::{OsStr, OsString};
 976                use std::fmt::Write as _;
 977                use std::fs;
 978                use std::io::{BufReader, Read, Write};
 979                use std::mem;
 980                use std::path::{Path, PathBuf};
 981                use std::process::Command;
 982                use std::sync::Arc;
 983            "
 984            .unindent(),
 985            cx,
 986        );
 987
 988        assert_location_resolution(
 989            indoc! {"
 990                impl Foo {
 991                    fn new() -> Self {
 992                        Self {
 993                            subscriptions: vec![
 994                                cx.observe_window_activation(window, |editor, window, cx| {
 995                                    let active = window.is_window_active();
 996                                    editor.blink_manager.update(cx, |blink_manager, cx| {
 997                                        if active {
 998                                            blink_manager.enable(cx);
 999                                        } else {
1000                                            blink_manager.disable(cx);
1001                                        }
1002                                    });
1003                                }),
1004                            ];
1005                        }
1006                    }
1007                }
1008            "},
1009            concat!(
1010                "                    editor.blink_manager.update(cx, |blink_manager, cx| {\n",
1011                "                        blink_manager.enable(cx);\n",
1012                "                    });",
1013            ),
1014            cx,
1015        );
1016
1017        assert_location_resolution(
1018            indoc! {r#"
1019                let tool = cx
1020                    .update(|cx| working_set.tool(&tool_name, cx))
1021                    .map_err(|err| {
1022                        anyhow!("Failed to look up tool '{}': {}", tool_name, err)
1023                    })?;
1024
1025                let Some(tool) = tool else {
1026                    return Err(anyhow!("Tool '{}' not found", tool_name));
1027                };
1028
1029                let project = project.clone();
1030                let action_log = action_log.clone();
1031                let messages = messages.clone();
1032                let tool_result = cx
1033                    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
1034                    .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
1035
1036                tasks.push(tool_result.output);
1037            "#},
1038            concat!(
1039                "let tool_result = cx\n",
1040                "    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
1041                "    .output;",
1042            ),
1043            cx,
1044        );
1045    }
1046
1047    #[gpui::test(iterations = 100)]
1048    async fn test_indent_new_text_chunks(mut rng: StdRng) {
1049        let chunks = to_random_chunks(&mut rng, "    abc\n  def\n      ghi");
1050        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1051            Ok(EditParserEvent::NewTextChunk {
1052                chunk: chunk.clone(),
1053                done: index == chunks.len() - 1,
1054            })
1055        }));
1056        let indented_chunks =
1057            EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
1058                .collect::<Vec<_>>()
1059                .await;
1060        let new_text = indented_chunks
1061            .into_iter()
1062            .collect::<Result<String>>()
1063            .unwrap();
1064        assert_eq!(new_text, "      abc\n    def\n        ghi");
1065    }
1066
1067    #[gpui::test(iterations = 100)]
1068    async fn test_outdent_new_text_chunks(mut rng: StdRng) {
1069        let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
1070        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1071            Ok(EditParserEvent::NewTextChunk {
1072                chunk: chunk.clone(),
1073                done: index == chunks.len() - 1,
1074            })
1075        }));
1076        let indented_chunks =
1077            EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
1078                .collect::<Vec<_>>()
1079                .await;
1080        let new_text = indented_chunks
1081            .into_iter()
1082            .collect::<Result<String>>()
1083            .unwrap();
1084        assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
1085    }
1086
1087    #[gpui::test(iterations = 100)]
1088    async fn test_random_indents(mut rng: StdRng) {
1089        let len = rng.gen_range(1..=100);
1090        let new_text = util::RandomCharIter::new(&mut rng)
1091            .with_simple_text()
1092            .take(len)
1093            .collect::<String>();
1094        let new_text = new_text
1095            .split('\n')
1096            .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1097            .collect::<Vec<_>>()
1098            .join("\n");
1099        let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1100
1101        let chunks = to_random_chunks(&mut rng, &new_text);
1102        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1103            Ok(EditParserEvent::NewTextChunk {
1104                chunk: chunk.clone(),
1105                done: index == chunks.len() - 1,
1106            })
1107        }));
1108        let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1109            .collect::<Vec<_>>()
1110            .await;
1111        let actual_reindented_text = reindented_chunks
1112            .into_iter()
1113            .collect::<Result<String>>()
1114            .unwrap();
1115        let expected_reindented_text = new_text
1116            .split('\n')
1117            .map(|line| {
1118                if let Some(ix) = line.find(|c| c != ' ') {
1119                    let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1120                    format!("{}{}", " ".repeat(new_indent), &line[ix..])
1121                } else {
1122                    line.to_string()
1123                }
1124            })
1125            .collect::<Vec<_>>()
1126            .join("\n");
1127        assert_eq!(actual_reindented_text, expected_reindented_text);
1128    }
1129
1130    #[track_caller]
1131    fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1132        let (text, _) = marked_text_ranges(text_with_expected_range, false);
1133        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1134        let snapshot = buffer.read(cx).snapshot();
1135        let mut ranges = Vec::new();
1136        ranges.extend(EditAgent::resolve_location(&snapshot, query));
1137        let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1138        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1139    }
1140
1141    fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1142        let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1143        let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1144        chunk_indices.sort();
1145        chunk_indices.push(input.len());
1146
1147        let mut chunks = Vec::new();
1148        let mut last_ix = 0;
1149        for chunk_ix in chunk_indices {
1150            chunks.push(input[last_ix..chunk_ix].to_string());
1151            last_ix = chunk_ix;
1152        }
1153        chunks
1154    }
1155
1156    fn simulate_llm_output(
1157        output: &str,
1158        rng: &mut StdRng,
1159        cx: &mut TestAppContext,
1160    ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1161        let executor = cx.executor();
1162        stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1163            let executor = executor.clone();
1164            async move {
1165                executor.simulate_random_delay().await;
1166                chunk
1167            }
1168        })
1169    }
1170
1171    async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1172        cx.update(settings::init);
1173        cx.update(Project::init_settings);
1174        let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1175        let model = Arc::new(FakeLanguageModel::default());
1176        let action_log = cx.new(|_| ActionLog::new(project));
1177        EditAgent::new(model, action_log, Templates::new())
1178    }
1179}