edit_agent.rs

   1mod edit_parser;
   2#[cfg(test)]
   3mod evals;
   4
   5use crate::{Template, Templates};
   6use aho_corasick::AhoCorasick;
   7use anyhow::Result;
   8use assistant_tool::ActionLog;
   9use edit_parser::{EditParser, EditParserEvent, EditParserMetrics};
  10use futures::{
  11    Stream, StreamExt,
  12    channel::mpsc::{self, UnboundedReceiver},
  13    stream::BoxStream,
  14};
  15use gpui::{AppContext, AsyncApp, Entity, SharedString, Task};
  16use language::{Bias, Buffer, BufferSnapshot, LineIndent, Point};
  17use language_model::{
  18    LanguageModel, LanguageModelCompletionError, LanguageModelRequest, LanguageModelRequestMessage,
  19    MessageContent, Role,
  20};
  21use serde::Serialize;
  22use std::{cmp, iter, mem, ops::Range, path::PathBuf, sync::Arc, task::Poll};
  23use streaming_diff::{CharOperation, StreamingDiff};
  24
  25#[derive(Serialize)]
  26pub struct EditAgentTemplate {
  27    path: Option<PathBuf>,
  28    edit_description: String,
  29}
  30
  31impl Template for EditAgentTemplate {
  32    const TEMPLATE_NAME: &'static str = "edit_agent.hbs";
  33}
  34
  35#[derive(Clone, Debug, PartialEq, Eq)]
  36pub enum EditAgentOutputEvent {
  37    Edited,
  38    HallucinatedOldText(SharedString),
  39}
  40
  41#[derive(Clone, Debug)]
  42pub struct EditAgentOutput {
  43    pub _raw_edits: String,
  44    pub _parser_metrics: EditParserMetrics,
  45}
  46
  47#[derive(Clone)]
  48pub struct EditAgent {
  49    model: Arc<dyn LanguageModel>,
  50    action_log: Entity<ActionLog>,
  51    templates: Arc<Templates>,
  52}
  53
impl EditAgent {
    /// Creates an agent that requests edits from `model`, records the edits it
    /// applies in `action_log`, and renders its prompt from `templates`.
    pub fn new(
        model: Arc<dyn LanguageModel>,
        action_log: Entity<ActionLog>,
        templates: Arc<Templates>,
    ) -> Self {
        EditAgent {
            model,
            action_log,
            templates,
        }
    }

    /// Asks the model to edit `buffer` as described by `edit_description` and
    /// applies the streamed edits to the buffer as they arrive.
    ///
    /// `previous_messages` are sent ahead of the edit prompt (see
    /// `request_edits`). Returns the task driving the whole operation, which
    /// resolves to the model's raw output and parser metrics, together with a
    /// receiver of [`EditAgentOutputEvent`]s.
    ///
    /// The task fails if the buffer can't be read, the completion request
    /// fails, or applying the edits fails. Events are forwarded from inside the
    /// returned task.
    pub fn edit(
        &self,
        buffer: Entity<Buffer>,
        edit_description: String,
        previous_messages: Vec<LanguageModelRequestMessage>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let this = self.clone();
        let (events_tx, events_rx) = mpsc::unbounded();
        let output = cx.spawn(async move |cx| {
            let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
            let edit_chunks = this
                .request_edits(snapshot, edit_description, previous_messages, cx)
                .await?;
            let (output, mut inner_events) = this.apply_edits(buffer, edit_chunks, cx);
            // Relay events until `apply_edits` drops its sender, then surface
            // its result. A closed receiver on our side is ignored.
            while let Some(event) = inner_events.next().await {
                events_tx.unbounded_send(event).ok();
            }
            output.await
        });
        (output, events_rx)
    }

    /// Applies raw model output (`edit_chunks`, in the format understood by
    /// `EditParser`) to `buffer` on a spawned task.
    ///
    /// Returns that task, which resolves to the raw output and parser metrics,
    /// and a receiver of [`EditAgentOutputEvent`]s emitted as edits are applied
    /// or as old-text queries fail to resolve.
    fn apply_edits(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        mpsc::UnboundedReceiver<EditAgentOutputEvent>,
    ) {
        let (output_events_tx, output_events_rx) = mpsc::unbounded();
        let this = self.clone();
        let task = cx.spawn(async move |mut cx| {
            this.apply_edits_internal(buffer, edit_chunks, output_events_tx, &mut cx)
                .await
        });
        (task, output_events_rx)
    }

    /// Drives the edit loop. For every old-text query emitted by the parser,
    /// it locates the query in the buffer. It then streams the matching new
    /// text into the buffer as a character-level diff against the located
    /// text.
    ///
    /// Queries that can't be located are reported as
    /// `EditAgentOutputEvent::HallucinatedOldText` and skipped. Every batch of
    /// applied edits is reported as `EditAgentOutputEvent::Edited`.
    ///
    /// Returns an error on the first parser/model error, buffer access
    /// failure, or action-log update failure. Otherwise it returns the parse
    /// task's output.
    async fn apply_edits_internal(
        &self,
        buffer: Entity<Buffer>,
        edit_chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        output_events: mpsc::UnboundedSender<EditAgentOutputEvent>,
        cx: &mut AsyncApp,
    ) -> Result<EditAgentOutput> {
        // Ensure the buffer is tracked by the action log.
        self.action_log
            .update(cx, |log, cx| log.track_buffer(buffer.clone(), cx))?;

        let (output, mut edit_events) = Self::parse_edit_chunks(edit_chunks, cx);
        while let Some(edit_event) = edit_events.next().await {
            // Only `OldText` starts an edit. Any other event reaching this
            // point is skipped, e.g. the new-text chunks that follow a query
            // that couldn't be located.
            let EditParserEvent::OldText(old_text_query) = edit_event? else {
                continue;
            };
            let old_text_query = SharedString::from(old_text_query);

            let (edits_tx, edits_rx) = mpsc::unbounded();
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
            // Search a snapshot on the background executor; the live buffer is
            // only touched from the foreground loop below.
            let old_range = cx
                .background_spawn({
                    let snapshot = snapshot.clone();
                    let old_text_query = old_text_query.clone();
                    async move { Self::resolve_location(&snapshot, &old_text_query) }
                })
                .await;
            let Some(old_range) = old_range else {
                // We couldn't find the old text in the buffer. Report the error.
                output_events
                    .unbounded_send(EditAgentOutputEvent::HallucinatedOldText(old_text_query))
                    .ok();
                continue;
            };

            // `edit_events` is moved into this task, which reads this query's
            // new-text chunks from it. The task hands the stream back when it
            // finishes, so the outer loop can read the next old-text query.
            let compute_edits = cx.background_spawn(async move {
                // Shift the new text by the difference between the buffer's
                // indentation at the start of the matched range and the
                // indentation of the query's first line. Tabs are used when
                // that buffer line is tab-indented.
                let buffer_start_indent =
                    snapshot.line_indent_for_row(snapshot.offset_to_point(old_range.start).row);
                let old_text_start_indent = old_text_query
                    .lines()
                    .next()
                    .map_or(buffer_start_indent, |line| {
                        LineIndent::from_iter(line.chars())
                    });
                let indent_delta = if buffer_start_indent.tabs > 0 {
                    IndentDelta::Tabs(
                        buffer_start_indent.tabs as isize - old_text_start_indent.tabs as isize,
                    )
                } else {
                    IndentDelta::Spaces(
                        buffer_start_indent.spaces as isize - old_text_start_indent.spaces as isize,
                    )
                };

                let old_text = snapshot
                    .text_for_range(old_range.clone())
                    .collect::<String>();
                let mut diff = StreamingDiff::new(old_text);
                // Byte offset (in `snapshot`) of the next old-text character
                // the diff will report on.
                let mut edit_start = old_range.start;
                let mut new_text_chunks =
                    Self::reindent_new_text_chunks(indent_delta, &mut edit_events);
                let mut done = false;
                while !done {
                    // Feed each new-text chunk into the diff. Once the chunks
                    // run out, `finish` (which consumes the diff, hence
                    // `mem::take`) yields the final operations.
                    let char_operations = if let Some(new_text_chunk) = new_text_chunks.next().await
                    {
                        diff.push_new(&new_text_chunk?)
                    } else {
                        done = true;
                        mem::take(&mut diff).finish()
                    };

                    // Offsets are converted to anchors in `snapshot` because
                    // the buffer is being edited (in the loop below) while
                    // this diff is still streaming.
                    for op in char_operations {
                        match op {
                            CharOperation::Insert { text } => {
                                let edit_start = snapshot.anchor_after(edit_start);
                                edits_tx.unbounded_send((edit_start..edit_start, text))?;
                            }
                            CharOperation::Delete { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                edits_tx.unbounded_send((edit_range, String::new()))?;
                            }
                            CharOperation::Keep { bytes } => edit_start += bytes,
                        }
                    }
                }

                // Release the `&mut edit_events` borrow held by the reindenting
                // stream so the stream itself can be returned.
                drop(new_text_chunks);
                anyhow::Ok(edit_events)
            });

            // TODO: group all edits into one transaction
            // Apply whatever edits are ready, at most 32 per batch. This loop
            // ends when `compute_edits` finishes and drops `edits_tx`.
            let mut edits_rx = edits_rx.ready_chunks(32);
            while let Some(edits) = edits_rx.next().await {
                // Edit the buffer and report edits to the action log as part of the
                // same effect cycle, otherwise the edit will be reported as if the
                // user made it.
                cx.update(|cx| {
                    buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx));
                    self.action_log
                        .update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx))
                })?;
                output_events
                    .unbounded_send(EditAgentOutputEvent::Edited)
                    .ok();
            }

            // Reclaim the event stream (or surface the diff task's error).
            edit_events = compute_edits.await?;
        }

        output.await
    }

    /// Spawns a background task that feeds model output through an
    /// `EditParser`. Each parser event, or completion error, is forwarded on
    /// the returned channel.
    ///
    /// The task resolves to the concatenated raw output and the parser's
    /// metrics once `chunks` ends. It returns an error as soon as sending on
    /// the channel fails, i.e. when the receiver has been dropped.
    fn parse_edit_chunks(
        chunks: impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>>,
        cx: &mut AsyncApp,
    ) -> (
        Task<Result<EditAgentOutput>>,
        UnboundedReceiver<Result<EditParserEvent>>,
    ) {
        let (tx, rx) = mpsc::unbounded();
        let output = cx.background_spawn(async move {
            futures::pin_mut!(chunks);

            let mut parser = EditParser::new();
            let mut raw_edits = String::new();
            while let Some(chunk) = chunks.next().await {
                match chunk {
                    Ok(chunk) => {
                        raw_edits.push_str(&chunk);
                        for event in parser.push(&chunk) {
                            tx.unbounded_send(Ok(event))?;
                        }
                    }
                    Err(error) => {
                        // Forward the error and keep reading; the consumer
                        // (`apply_edits_internal`) bails on the first error.
                        tx.unbounded_send(Err(error.into()))?;
                    }
                }
            }
            Ok(EditAgentOutput {
                _raw_edits: raw_edits,
                _parser_metrics: parser.finish(),
            })
        });
        (output, rx)
    }

    /// Turns the `NewTextChunk` events at the front of `stream` into a stream
    /// of text whose indentation has been shifted by `delta`.
    ///
    /// Each line's leading run of `delta.character()` is lengthened or
    /// shortened by `delta.len()` characters, clamped at zero. Complete lines
    /// made up entirely of that character pass through unchanged. An
    /// incomplete line whose indentation hasn't ended yet is held back until
    /// more input arrives. When the chunk marked `done` arrives, any held-back
    /// text is emitted as-is and this stream ends.
    ///
    /// Errors from `stream` are forwarded. Any other event, or the end of
    /// `stream`, also ends this stream.
    ///
    /// NOTE(review): a non-`NewTextChunk` event (e.g. `OldText`) is consumed
    /// from `stream` here and discarded. Text still held back when `stream`
    /// ends without a `done` chunk is dropped. Both are harmless only if the
    /// parser always closes new text with a `done` chunk — confirm against
    /// `EditParser`.
    fn reindent_new_text_chunks(
        delta: IndentDelta,
        mut stream: impl Unpin + Stream<Item = Result<EditParserEvent>>,
    ) -> impl Stream<Item = Result<String>> {
        // Text received but not yet emitted (an incomplete leading-indent line).
        let mut buffer = String::new();
        // Whether the current line's indentation is still being read.
        let mut in_leading_whitespace = true;
        let mut done = false;
        futures::stream::poll_fn(move |cx| {
            while !done {
                // The `done` bound in this pattern is the parser's
                // end-of-new-text flag; it shadows the closure's `done` only
                // inside this arm.
                let (chunk, is_last_chunk) = match stream.poll_next_unpin(cx) {
                    Poll::Ready(Some(Ok(EditParserEvent::NewTextChunk { chunk, done }))) => {
                        (chunk, done)
                    }
                    Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err))),
                    Poll::Pending => return Poll::Pending,
                    _ => return Poll::Ready(None),
                };

                buffer.push_str(&chunk);

                let mut indented_new_text = String::new();
                let mut start_ix = 0;
                let mut newlines = buffer.match_indices('\n').peekable();
                loop {
                    let (line_end, is_pending_line) = match newlines.next() {
                        Some((ix, _)) => (ix, false),
                        None => (buffer.len(), true),
                    };
                    let line = &buffer[start_ix..line_end];

                    if in_leading_whitespace {
                        // "Whitespace" here means `delta.character()` only, so
                        // for a spaces delta a tab ends the indentation.
                        if let Some(non_whitespace_ix) = line.find(|c| delta.character() != c) {
                            // We found a non-whitespace character, adjust
                            // indentation based on the delta.
                            let new_indent_len =
                                cmp::max(0, non_whitespace_ix as isize + delta.len()) as usize;
                            indented_new_text
                                .extend(iter::repeat(delta.character()).take(new_indent_len));
                            indented_new_text.push_str(&line[non_whitespace_ix..]);
                            in_leading_whitespace = false;
                        } else if is_pending_line {
                            // We're still in leading whitespace and this line is incomplete.
                            // Stop processing until we receive more input.
                            break;
                        } else {
                            // This line is entirely whitespace. Push it without indentation.
                            indented_new_text.push_str(line);
                        }
                    } else {
                        indented_new_text.push_str(line);
                    }

                    if is_pending_line {
                        start_ix = line_end;
                        break;
                    } else {
                        in_leading_whitespace = true;
                        indented_new_text.push('\n');
                        start_ix = line_end + 1;
                    }
                }
                // Drop everything already emitted; only an incomplete
                // leading-indent line (if any) stays buffered.
                buffer.replace_range(..start_ix, "");

                // This was the last chunk, push all the buffered content as-is.
                if is_last_chunk {
                    indented_new_text.push_str(&buffer);
                    buffer.clear();
                    done = true;
                }

                if !indented_new_text.is_empty() {
                    return Poll::Ready(Some(Ok(indented_new_text)));
                }
            }

            Poll::Ready(None)
        })
    }

    /// Builds the edit prompt for `snapshot`'s file and starts streaming the
    /// model's text completion.
    ///
    /// `messages` is the preceding conversation. If it ends with an assistant
    /// message, the tool-use content of that message is removed first, and
    /// the message is dropped if nothing remains. The rendered prompt is then
    /// appended as a new user message.
    ///
    /// Fails if the path can't be resolved in the app context, the template
    /// fails to render, or the completion request fails.
    async fn request_edits(
        &self,
        snapshot: BufferSnapshot,
        edit_description: String,
        mut messages: Vec<LanguageModelRequestMessage>,
        cx: &mut AsyncApp,
    ) -> Result<BoxStream<'static, Result<String, LanguageModelCompletionError>>> {
        let path = cx.update(|cx| snapshot.resolve_file_path(cx, true))?;
        let prompt = EditAgentTemplate {
            path,
            edit_description,
        }
        .render(&self.templates)?;

        let mut message_content = Vec::new();
        // NOTE(review): presumably tool uses are stripped because this request
        // carries no matching tool results — confirm.
        if let Some(last_message) = messages.last_mut() {
            if last_message.role == Role::Assistant {
                last_message
                    .content
                    .retain(|content| !matches!(content, MessageContent::ToolUse(_)));
                if last_message.content.is_empty() {
                    messages.pop();
                }
            }
        }
        message_content.push(MessageContent::Text(prompt));
        messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: message_content,
            cache: false,
        });

        let request = LanguageModelRequest {
            messages,
            ..Default::default()
        };
        Ok(self.model.stream_completion_text(request, cx).await?.stream)
    }

    /// Finds the byte range in `buffer` that `search_query` refers to. It
    /// tries an exact match first and falls back to a line-based fuzzy match.
    ///
    /// The result is widened to whole lines. It starts at column 0 and,
    /// unless it already ends at the start of a line, extends to the end of
    /// its last line. Returns `None` when neither strategy finds a match.
    fn resolve_location(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        let range = Self::resolve_location_exact(buffer, search_query)
            .or_else(|| Self::resolve_location_fuzzy(buffer, search_query))?;

        // Expand the range to include entire lines.
        let mut start = buffer.offset_to_point(buffer.clip_offset(range.start, Bias::Left));
        start.column = 0;
        let mut end = buffer.offset_to_point(buffer.clip_offset(range.end, Bias::Right));
        if end.column > 0 {
            end.column = buffer.line_len(end.row);
        }

        Some(buffer.point_to_offset(start)..buffer.point_to_offset(end))
    }

    /// Returns the byte range of the first exact occurrence of `search_query`
    /// in `buffer`. Returns `None` if there is no occurrence or if the
    /// Aho-Corasick matcher can't be built.
    fn resolve_location_exact(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        let search = AhoCorasick::new([search_query]).ok()?;
        // `stream_find_iter` yields `io::Result`s. Reading the snapshot's bytes
        // is assumed never to fail, per the `expect` message.
        let mat = search
            .stream_find_iter(buffer.bytes_in_range(0..buffer.len()))
            .next()?
            .expect("buffer can't error");
        Some(mat.range())
    }

    /// Locates `search_query` by aligning its lines against a contiguous run
    /// of buffer lines. This is a line-level edit distance in which lines are
    /// compared after trimming, using `fuzzy_eq`.
    ///
    /// Skipping a query line costs `DELETION_COST`, skipping a buffer line
    /// costs `INSERTION_COST`, and pairing two dissimilar lines costs both.
    /// The cheapest alignment is accepted only if at least 80% of the larger
    /// of (query lines, spanned buffer lines) are matched. On success it
    /// returns the byte range from the start of the first matched buffer line
    /// to the end of the last one.
    fn resolve_location_fuzzy(buffer: &BufferSnapshot, search_query: &str) -> Option<Range<usize>> {
        const INSERTION_COST: u32 = 3;
        const DELETION_COST: u32 = 10;

        // Cell (r, c) holds the cheapest alignment of the first `r` query
        // lines that ends at buffer line `c` (1-based). Row 0 is left at cost
        // 0 so a match may start at any buffer line. Column 0 accumulates one
        // deletion per query line; its direction is never read because the
        // traceback stops at column 0.
        let buffer_line_count = buffer.max_point().row as usize + 1;
        let query_line_count = search_query.lines().count();
        let mut matrix = SearchMatrix::new(query_line_count + 1, buffer_line_count + 1);
        let mut leading_deletion_cost = 0_u32;
        for (row, query_line) in search_query.lines().enumerate() {
            let query_line = query_line.trim();
            leading_deletion_cost = leading_deletion_cost.saturating_add(DELETION_COST);
            matrix.set(
                row + 1,
                0,
                SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
            );

            let mut buffer_lines = buffer.as_rope().chunks().lines();
            let mut col = 0;
            while let Some(buffer_line) = buffer_lines.next() {
                let buffer_line = buffer_line.trim();
                let up = SearchState::new(
                    matrix.get(row, col + 1).cost.saturating_add(DELETION_COST),
                    SearchDirection::Up,
                );
                let left = SearchState::new(
                    matrix.get(row + 1, col).cost.saturating_add(INSERTION_COST),
                    SearchDirection::Left,
                );
                let diagonal = SearchState::new(
                    if fuzzy_eq(query_line, buffer_line) {
                        matrix.get(row, col).cost
                    } else {
                        matrix
                            .get(row, col)
                            .cost
                            .saturating_add(DELETION_COST + INSERTION_COST)
                    },
                    SearchDirection::Diagonal,
                );
                // `SearchState` orders by cost, then by direction, so equal
                // costs resolve to Up, then Left, then Diagonal.
                matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
                col += 1;
            }
        }

        // Traceback to find the best match
        // Pick the cheapest end column; on ties, the earliest one wins
        // (strict `<`).
        let mut buffer_row_end = buffer_line_count as u32;
        let mut best_cost = u32::MAX;
        for col in 1..=buffer_line_count {
            let cost = matrix.get(query_line_count, col).cost;
            if cost < best_cost {
                best_cost = cost;
                buffer_row_end = col as u32;
            }
        }

        // A mismatched diagonal (+DELETION_COST+INSERTION_COST) is never
        // cheaper than the Up step from the same cell, and it loses ties to
        // it. So every Diagonal step reached here is a `fuzzy_eq` match.
        let mut matched_lines = 0;
        let mut query_row = query_line_count;
        let mut buffer_row_start = buffer_row_end;
        while query_row > 0 && buffer_row_start > 0 {
            let current = matrix.get(query_row, buffer_row_start as usize);
            match current.direction {
                SearchDirection::Diagonal => {
                    query_row -= 1;
                    buffer_row_start -= 1;
                    matched_lines += 1;
                }
                SearchDirection::Up => {
                    query_row -= 1;
                }
                SearchDirection::Left => {
                    buffer_row_start -= 1;
                }
            }
        }

        // With an empty query this is 0/0 = NaN, which fails the comparison
        // below and yields `None`.
        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
        let matched_ratio =
            matched_lines as f32 / (matched_buffer_row_count as f32).max(query_line_count as f32);
        if matched_ratio >= 0.8 {
            // `buffer_row_start` is a 0-based row; `buffer_row_end` is a
            // 1-based column index, so the last matched row is `end - 1`.
            let buffer_start_ix = buffer.point_to_offset(Point::new(buffer_row_start, 0));
            let buffer_end_ix = buffer.point_to_offset(Point::new(
                buffer_row_end - 1,
                buffer.line_len(buffer_row_end - 1),
            ));
            Some(buffer_start_ix..buffer_end_ix)
        } else {
            None
        }
    }
}
 493
 494fn fuzzy_eq(left: &str, right: &str) -> bool {
 495    let min_levenshtein = left.len().abs_diff(right.len());
 496    let min_normalized_levenshtein =
 497        1. - (min_levenshtein as f32 / cmp::max(left.len(), right.len()) as f32);
 498    if min_normalized_levenshtein < 0.8 {
 499        return false;
 500    }
 501
 502    strsim::normalized_levenshtein(left, right) >= 0.8
 503}
 504
 505#[derive(Copy, Clone, Debug)]
 506enum IndentDelta {
 507    Spaces(isize),
 508    Tabs(isize),
 509}
 510
 511impl IndentDelta {
 512    fn character(&self) -> char {
 513        match self {
 514            IndentDelta::Spaces(_) => ' ',
 515            IndentDelta::Tabs(_) => '\t',
 516        }
 517    }
 518
 519    fn len(&self) -> isize {
 520        match self {
 521            IndentDelta::Spaces(n) => *n,
 522            IndentDelta::Tabs(n) => *n,
 523        }
 524    }
 525}
 526
 527#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
 528enum SearchDirection {
 529    Up,
 530    Left,
 531    Diagonal,
 532}
 533
 534#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 535struct SearchState {
 536    cost: u32,
 537    direction: SearchDirection,
 538}
 539
 540impl SearchState {
 541    fn new(cost: u32, direction: SearchDirection) -> Self {
 542        Self { cost, direction }
 543    }
 544}
 545
 546struct SearchMatrix {
 547    cols: usize,
 548    data: Vec<SearchState>,
 549}
 550
 551impl SearchMatrix {
 552    fn new(rows: usize, cols: usize) -> Self {
 553        SearchMatrix {
 554            cols,
 555            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
 556        }
 557    }
 558
 559    fn get(&self, row: usize, col: usize) -> SearchState {
 560        self.data[row * self.cols + col]
 561    }
 562
 563    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
 564        self.data[row * self.cols + col] = cost;
 565    }
 566}
 567
 568#[cfg(test)]
 569mod tests {
 570    use super::*;
 571    use fs::FakeFs;
 572    use futures::stream;
 573    use gpui::{App, AppContext, TestAppContext};
 574    use indoc::indoc;
 575    use language_model::fake_provider::FakeLanguageModel;
 576    use project::Project;
 577    use rand::prelude::*;
 578    use rand::rngs::StdRng;
 579    use std::cmp;
 580    use unindent::Unindent;
 581    use util::test::{generate_marked_text, marked_text_ranges};
 582
 583    #[gpui::test(iterations = 100)]
 584    async fn test_indentation(cx: &mut TestAppContext, mut rng: StdRng) {
 585        let agent = init_test(cx).await;
 586        let buffer = cx.new(|cx| {
 587            Buffer::local(
 588                indoc! {"
 589                    lorem
 590                            ipsum
 591                            dolor
 592                            sit
 593                "},
 594                cx,
 595            )
 596        });
 597        let raw_edits = simulate_llm_output(
 598            indoc! {"
 599                <old_text>
 600                    ipsum
 601                    dolor
 602                    sit
 603                </old_text>
 604                <new_text>
 605                    ipsum
 606                    dolor
 607                    sit
 608                amet
 609                </new_text>
 610            "},
 611            &mut rng,
 612            cx,
 613        );
 614        let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
 615        apply.await.unwrap();
 616        pretty_assertions::assert_eq!(
 617            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 618            indoc! {"
 619                lorem
 620                        ipsum
 621                        dolor
 622                        sit
 623                    amet
 624            "}
 625        );
 626    }
 627
 628    #[gpui::test(iterations = 100)]
 629    async fn test_dependent_edits(cx: &mut TestAppContext, mut rng: StdRng) {
 630        let agent = init_test(cx).await;
 631        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 632        let raw_edits = simulate_llm_output(
 633            indoc! {"
 634                <old_text>
 635                def
 636                </old_text>
 637                <new_text>
 638                DEF
 639                </new_text>
 640
 641                <old_text>
 642                DEF
 643                </old_text>
 644                <new_text>
 645                DeF
 646                </new_text>
 647            "},
 648            &mut rng,
 649            cx,
 650        );
 651        let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
 652        apply.await.unwrap();
 653        assert_eq!(
 654            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 655            "abc\nDeF\nghi"
 656        );
 657    }
 658
 659    #[gpui::test(iterations = 100)]
 660    async fn test_old_text_hallucination(cx: &mut TestAppContext, mut rng: StdRng) {
 661        let agent = init_test(cx).await;
 662        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 663        let raw_edits = simulate_llm_output(
 664            indoc! {"
 665                <old_text>
 666                jkl
 667                </old_text>
 668                <new_text>
 669                mno
 670                </new_text>
 671
 672                <old_text>
 673                abc
 674                </old_text>
 675                <new_text>
 676                ABC
 677                </new_text>
 678            "},
 679            &mut rng,
 680            cx,
 681        );
 682        let (apply, _events) = agent.apply_edits(buffer.clone(), raw_edits, &mut cx.to_async());
 683        apply.await.unwrap();
 684        assert_eq!(
 685            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 686            "ABC\ndef\nghi"
 687        );
 688    }
 689
 690    #[gpui::test]
 691    async fn test_events(cx: &mut TestAppContext) {
 692        let agent = init_test(cx).await;
 693        let buffer = cx.new(|cx| Buffer::local("abc\ndef\nghi", cx));
 694        let (chunks_tx, chunks_rx) = mpsc::unbounded();
 695        let (apply, mut events) = agent.apply_edits(
 696            buffer.clone(),
 697            chunks_rx.map(|chunk: &str| Ok(chunk.to_string())),
 698            &mut cx.to_async(),
 699        );
 700
 701        chunks_tx.unbounded_send("<old_text>a").unwrap();
 702        cx.run_until_parked();
 703        assert_eq!(drain_events(&mut events), vec![]);
 704        assert_eq!(
 705            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 706            "abc\ndef\nghi"
 707        );
 708
 709        chunks_tx.unbounded_send("bc</old_text>").unwrap();
 710        cx.run_until_parked();
 711        assert_eq!(drain_events(&mut events), vec![]);
 712        assert_eq!(
 713            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 714            "abc\ndef\nghi"
 715        );
 716
 717        chunks_tx.unbounded_send("<new_text>abX").unwrap();
 718        cx.run_until_parked();
 719        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 720        assert_eq!(
 721            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 722            "abXc\ndef\nghi"
 723        );
 724
 725        chunks_tx.unbounded_send("cY").unwrap();
 726        cx.run_until_parked();
 727        assert_eq!(drain_events(&mut events), [EditAgentOutputEvent::Edited]);
 728        assert_eq!(
 729            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 730            "abXcY\ndef\nghi"
 731        );
 732
 733        chunks_tx.unbounded_send("</new_text>").unwrap();
 734        chunks_tx.unbounded_send("<old_text>hall").unwrap();
 735        cx.run_until_parked();
 736        assert_eq!(drain_events(&mut events), vec![]);
 737        assert_eq!(
 738            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 739            "abXcY\ndef\nghi"
 740        );
 741
 742        chunks_tx.unbounded_send("ucinated old</old_text>").unwrap();
 743        chunks_tx.unbounded_send("<new_text>").unwrap();
 744        cx.run_until_parked();
 745        assert_eq!(
 746            drain_events(&mut events),
 747            vec![EditAgentOutputEvent::HallucinatedOldText(
 748                "hallucinated old".into()
 749            )]
 750        );
 751        assert_eq!(
 752            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 753            "abXcY\ndef\nghi"
 754        );
 755
 756        chunks_tx.unbounded_send("hallucinated new</new_").unwrap();
 757        chunks_tx.unbounded_send("text>").unwrap();
 758        cx.run_until_parked();
 759        assert_eq!(drain_events(&mut events), vec![]);
 760        assert_eq!(
 761            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 762            "abXcY\ndef\nghi"
 763        );
 764
 765        chunks_tx.unbounded_send("<old_text>gh").unwrap();
 766        chunks_tx.unbounded_send("i</old_text>").unwrap();
 767        chunks_tx.unbounded_send("<new_text>").unwrap();
 768        cx.run_until_parked();
 769        assert_eq!(drain_events(&mut events), vec![]);
 770        assert_eq!(
 771            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 772            "abXcY\ndef\nghi"
 773        );
 774
 775        chunks_tx.unbounded_send("GHI</new_text>").unwrap();
 776        cx.run_until_parked();
 777        assert_eq!(
 778            drain_events(&mut events),
 779            vec![EditAgentOutputEvent::Edited]
 780        );
 781        assert_eq!(
 782            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 783            "abXcY\ndef\nGHI"
 784        );
 785
 786        drop(chunks_tx);
 787        apply.await.unwrap();
 788        assert_eq!(
 789            buffer.read_with(cx, |buffer, _| buffer.snapshot().text()),
 790            "abXcY\ndef\nGHI"
 791        );
 792        assert_eq!(drain_events(&mut events), vec![]);
 793
 794        fn drain_events(
 795            stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
 796        ) -> Vec<EditAgentOutputEvent> {
 797            let mut events = Vec::new();
 798            while let Ok(Some(event)) = stream.try_next() {
 799                events.push(event);
 800            }
 801            events
 802        }
 803    }
 804
 805    #[gpui::test]
 806    fn test_resolve_location(cx: &mut App) {
 807        assert_location_resolution(
 808            concat!(
 809                "    Lorem\n",
 810                "«    ipsum»\n",
 811                "    dolor sit amet\n",
 812                "    consecteur",
 813            ),
 814            "ipsum",
 815            cx,
 816        );
 817
 818        assert_location_resolution(
 819            concat!(
 820                "    Lorem\n",
 821                "«    ipsum\n",
 822                "    dolor sit amet»\n",
 823                "    consecteur",
 824            ),
 825            "ipsum\ndolor sit amet",
 826            cx,
 827        );
 828
 829        assert_location_resolution(
 830            &"
 831            «fn foo1(a: usize) -> usize {
 832                40
 833 834
 835            fn foo2(b: usize) -> usize {
 836                42
 837            }
 838            "
 839            .unindent(),
 840            "fn foo1(a: usize) -> u32 {\n40\n}",
 841            cx,
 842        );
 843
 844        assert_location_resolution(
 845            &"
 846            class Something {
 847                one() { return 1; }
 848            «    two() { return 2222; }
 849                three() { return 333; }
 850                four() { return 4444; }
 851                five() { return 5555; }
 852                six() { return 6666; }»
 853                seven() { return 7; }
 854                eight() { return 8; }
 855            }
 856            "
 857            .unindent(),
 858            &"
 859                two() { return 2222; }
 860                four() { return 4444; }
 861                five() { return 5555; }
 862                six() { return 6666; }
 863            "
 864            .unindent(),
 865            cx,
 866        );
 867
 868        assert_location_resolution(
 869            &"
 870                use std::ops::Range;
 871                use std::sync::Mutex;
 872                use std::{
 873                    collections::HashMap,
 874                    env,
 875                    ffi::{OsStr, OsString},
 876                    fs,
 877                    io::{BufRead, BufReader},
 878                    mem,
 879                    path::{Path, PathBuf},
 880                    process::Command,
 881                    sync::LazyLock,
 882                    time::SystemTime,
 883                };
 884            "
 885            .unindent(),
 886            &"
 887                use std::collections::{HashMap, HashSet};
 888                use std::ffi::{OsStr, OsString};
 889                use std::fmt::Write as _;
 890                use std::fs;
 891                use std::io::{BufReader, Read, Write};
 892                use std::mem;
 893                use std::path::{Path, PathBuf};
 894                use std::process::Command;
 895                use std::sync::Arc;
 896            "
 897            .unindent(),
 898            cx,
 899        );
 900
 901        assert_location_resolution(
 902            indoc! {"
 903                impl Foo {
 904                    fn new() -> Self {
 905                        Self {
 906                            subscriptions: vec![
 907                                cx.observe_window_activation(window, |editor, window, cx| {
 908                                    let active = window.is_window_active();
 909                                    editor.blink_manager.update(cx, |blink_manager, cx| {
 910                                        if active {
 911                                            blink_manager.enable(cx);
 912                                        } else {
 913                                            blink_manager.disable(cx);
 914                                        }
 915                                    });
 916                                }),
 917                            ];
 918                        }
 919                    }
 920                }
 921            "},
 922            concat!(
 923                "                    editor.blink_manager.update(cx, |blink_manager, cx| {\n",
 924                "                        blink_manager.enable(cx);\n",
 925                "                    });",
 926            ),
 927            cx,
 928        );
 929
 930        assert_location_resolution(
 931            indoc! {r#"
 932                let tool = cx
 933                    .update(|cx| working_set.tool(&tool_name, cx))
 934                    .map_err(|err| {
 935                        anyhow!("Failed to look up tool '{}': {}", tool_name, err)
 936                    })?;
 937
 938                let Some(tool) = tool else {
 939                    return Err(anyhow!("Tool '{}' not found", tool_name));
 940                };
 941
 942                let project = project.clone();
 943                let action_log = action_log.clone();
 944                let messages = messages.clone();
 945                let tool_result = cx
 946                    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
 947                    .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;
 948
 949                tasks.push(tool_result.output);
 950            "#},
 951            concat!(
 952                "let tool_result = cx\n",
 953                "    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
 954                "    .output;",
 955            ),
 956            cx,
 957        );
 958    }
 959
 960    #[gpui::test(iterations = 100)]
 961    async fn test_indent_new_text_chunks(mut rng: StdRng) {
 962        let chunks = to_random_chunks(&mut rng, "    abc\n  def\n      ghi");
 963        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
 964            Ok(EditParserEvent::NewTextChunk {
 965                chunk: chunk.clone(),
 966                done: index == chunks.len() - 1,
 967            })
 968        }));
 969        let indented_chunks =
 970            EditAgent::reindent_new_text_chunks(IndentDelta::Spaces(2), new_text_chunks)
 971                .collect::<Vec<_>>()
 972                .await;
 973        let new_text = indented_chunks
 974            .into_iter()
 975            .collect::<Result<String>>()
 976            .unwrap();
 977        assert_eq!(new_text, "      abc\n    def\n        ghi");
 978    }
 979
 980    #[gpui::test(iterations = 100)]
 981    async fn test_outdent_new_text_chunks(mut rng: StdRng) {
 982        let chunks = to_random_chunks(&mut rng, "\t\t\t\tabc\n\t\tdef\n\t\t\t\t\t\tghi");
 983        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
 984            Ok(EditParserEvent::NewTextChunk {
 985                chunk: chunk.clone(),
 986                done: index == chunks.len() - 1,
 987            })
 988        }));
 989        let indented_chunks =
 990            EditAgent::reindent_new_text_chunks(IndentDelta::Tabs(-2), new_text_chunks)
 991                .collect::<Vec<_>>()
 992                .await;
 993        let new_text = indented_chunks
 994            .into_iter()
 995            .collect::<Result<String>>()
 996            .unwrap();
 997        assert_eq!(new_text, "\t\tabc\ndef\n\t\t\t\tghi");
 998    }
 999
1000    #[gpui::test(iterations = 100)]
1001    async fn test_random_indents(mut rng: StdRng) {
1002        let len = rng.gen_range(1..=100);
1003        let new_text = util::RandomCharIter::new(&mut rng)
1004            .with_simple_text()
1005            .take(len)
1006            .collect::<String>();
1007        let new_text = new_text
1008            .split('\n')
1009            .map(|line| format!("{}{}", " ".repeat(rng.gen_range(0..=8)), line))
1010            .collect::<Vec<_>>()
1011            .join("\n");
1012        let delta = IndentDelta::Spaces(rng.gen_range(-4..=4));
1013
1014        let chunks = to_random_chunks(&mut rng, &new_text);
1015        let new_text_chunks = stream::iter(chunks.iter().enumerate().map(|(index, chunk)| {
1016            Ok(EditParserEvent::NewTextChunk {
1017                chunk: chunk.clone(),
1018                done: index == chunks.len() - 1,
1019            })
1020        }));
1021        let reindented_chunks = EditAgent::reindent_new_text_chunks(delta, new_text_chunks)
1022            .collect::<Vec<_>>()
1023            .await;
1024        let actual_reindented_text = reindented_chunks
1025            .into_iter()
1026            .collect::<Result<String>>()
1027            .unwrap();
1028        let expected_reindented_text = new_text
1029            .split('\n')
1030            .map(|line| {
1031                if let Some(ix) = line.find(|c| c != ' ') {
1032                    let new_indent = cmp::max(0, ix as isize + delta.len()) as usize;
1033                    format!("{}{}", " ".repeat(new_indent), &line[ix..])
1034                } else {
1035                    line.to_string()
1036                }
1037            })
1038            .collect::<Vec<_>>()
1039            .join("\n");
1040        assert_eq!(actual_reindented_text, expected_reindented_text);
1041    }
1042
1043    #[track_caller]
1044    fn assert_location_resolution(text_with_expected_range: &str, query: &str, cx: &mut App) {
1045        let (text, _) = marked_text_ranges(text_with_expected_range, false);
1046        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
1047        let snapshot = buffer.read(cx).snapshot();
1048        let mut ranges = Vec::new();
1049        ranges.extend(EditAgent::resolve_location(&snapshot, query));
1050        let text_with_actual_range = generate_marked_text(&text, &ranges, false);
1051        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
1052    }
1053
1054    fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
1055        let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
1056        let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
1057        chunk_indices.sort();
1058        chunk_indices.push(input.len());
1059
1060        let mut chunks = Vec::new();
1061        let mut last_ix = 0;
1062        for chunk_ix in chunk_indices {
1063            chunks.push(input[last_ix..chunk_ix].to_string());
1064            last_ix = chunk_ix;
1065        }
1066        chunks
1067    }
1068
1069    fn simulate_llm_output(
1070        output: &str,
1071        rng: &mut StdRng,
1072        cx: &mut TestAppContext,
1073    ) -> impl 'static + Send + Stream<Item = Result<String, LanguageModelCompletionError>> {
1074        let executor = cx.executor();
1075        stream::iter(to_random_chunks(rng, output).into_iter().map(Ok)).then(move |chunk| {
1076            let executor = executor.clone();
1077            async move {
1078                executor.simulate_random_delay().await;
1079                chunk
1080            }
1081        })
1082    }
1083
1084    async fn init_test(cx: &mut TestAppContext) -> EditAgent {
1085        cx.update(settings::init);
1086        cx.update(Project::init_settings);
1087        let project = Project::test(FakeFs::new(cx.executor()), [], cx).await;
1088        let model = Arc::new(FakeLanguageModel::default());
1089        let action_log = cx.new(|_| ActionLog::new(project));
1090        EditAgent::new(model, action_log, Templates::new())
1091    }
1092}