buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use futures::{
  10    SinkExt, Stream, StreamExt, TryStreamExt as _,
  11    channel::mpsc,
  12    future::{LocalBoxFuture, Shared},
  13    join,
  14    stream::BoxStream,
  15};
  16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  17use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  18use language_model::{
  19    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  20    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  21    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  22    LanguageModelToolUse, Role, TokenUsage,
  23};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use prompt_store::PromptBuilder;
  27use rope::Rope;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings as _;
  31use smol::future::FutureExt;
  32use std::{
  33    cmp,
  34    future::Future,
  35    iter,
  36    ops::{Range, RangeInclusive},
  37    pin::Pin,
  38    sync::Arc,
  39    task::{self, Poll},
  40    time::Instant,
  41};
  42use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  43
// NOTE(review): this type derives `JsonSchema` and its schema is attached to the
// `failure_message` tool in `build_request_tools`. The `///` text below presumably
// ends up in that schema as descriptions shown to the model, so treat it as prompt
// text rather than ordinary documentation — confirm before rewording.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `serde(default)`: a payload without a `message` key deserializes to an empty string.
    #[serde(default)]
    pub message: String,
}
  53
// NOTE(review): this type derives `JsonSchema` and its schema is attached to the
// `rewrite_section` tool in `build_request_tools`. The `///` text below presumably
// ends up in that schema as descriptions shown to the model, so treat it as prompt
// text rather than ordinary documentation — confirm before rewording.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `serde(default)`: a payload without a `replacement_text` key deserializes to an empty string.
    #[serde(default)]
    pub replacement_text: String,
}
  63
/// Coordinates one or more [`CodegenAlternative`]s that rewrite the same range of a
/// multibuffer, exposing whichever alternative is currently active.
pub struct BufferCodegen {
    /// All alternatives for the current generation. Never empty; `start` pairs index 0
    /// with the primary model and the remaining entries with the alternative models.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently shown.
    pub active_alternative: usize,
    /// Indices that have been passed to `activate` so far.
    seen_alternatives: HashSet<usize>,
    /// Observation and event subscriptions on the active alternative; replaced every
    /// time a different alternative is activated.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// The anchor range handed to every alternative as the section to rewrite.
    range: Range<Anchor>,
    /// Transaction undone (once) by `undo` after the active alternative has been undone.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder shared with every alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty in the buffer snapshot taken at construction time.
    pub is_insertion: bool,
    /// Identifier passed to every alternative created by this codegen.
    session_id: Uuid,
}
  76
/// Name under which the tool with the [`RewriteSectionInput`] schema is declared in
/// `build_request_tools`.
pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
/// Name under which the tool with the [`FailureMessageInput`] schema is declared in
/// `build_request_tools`.
pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
  79
  80impl BufferCodegen {
  81    pub fn new(
  82        buffer: Entity<MultiBuffer>,
  83        range: Range<Anchor>,
  84        initial_transaction_id: Option<TransactionId>,
  85        session_id: Uuid,
  86        builder: Arc<PromptBuilder>,
  87        cx: &mut Context<Self>,
  88    ) -> Self {
  89        let codegen = cx.new(|cx| {
  90            CodegenAlternative::new(
  91                buffer.clone(),
  92                range.clone(),
  93                false,
  94                builder.clone(),
  95                session_id,
  96                cx,
  97            )
  98        });
  99        let mut this = Self {
 100            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 101            alternatives: vec![codegen],
 102            active_alternative: 0,
 103            seen_alternatives: HashSet::default(),
 104            subscriptions: Vec::new(),
 105            buffer,
 106            range,
 107            initial_transaction_id,
 108            builder,
 109            session_id,
 110        };
 111        this.activate(0, cx);
 112        this
 113    }
 114
 115    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 116        let codegen = self.active_alternative().clone();
 117        self.subscriptions.clear();
 118        self.subscriptions
 119            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 120        self.subscriptions
 121            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 122    }
 123
 124    pub fn active_completion(&self, cx: &App) -> Option<String> {
 125        self.active_alternative().read(cx).current_completion()
 126    }
 127
 128    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 129        &self.alternatives[self.active_alternative]
 130    }
 131
 132    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 133        self.active_alternative().read(cx).language_name(cx)
 134    }
 135
 136    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 137        &self.active_alternative().read(cx).status
 138    }
 139
 140    pub fn alternative_count(&self, cx: &App) -> usize {
 141        LanguageModelRegistry::read_global(cx)
 142            .inline_alternative_models()
 143            .len()
 144            + 1
 145    }
 146
 147    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 148        let next_active_ix = if self.active_alternative == 0 {
 149            self.alternatives.len() - 1
 150        } else {
 151            self.active_alternative - 1
 152        };
 153        self.activate(next_active_ix, cx);
 154    }
 155
 156    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 157        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 158        self.activate(next_active_ix, cx);
 159    }
 160
 161    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 162        self.active_alternative()
 163            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 164        self.seen_alternatives.insert(index);
 165        self.active_alternative = index;
 166        self.active_alternative()
 167            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 168        self.subscribe_to_alternative(cx);
 169        cx.notify();
 170    }
 171
 172    pub fn start(
 173        &mut self,
 174        primary_model: Arc<dyn LanguageModel>,
 175        user_prompt: String,
 176        context_task: Shared<Task<Option<LoadedContext>>>,
 177        cx: &mut Context<Self>,
 178    ) -> Result<()> {
 179        let alternative_models = LanguageModelRegistry::read_global(cx)
 180            .inline_alternative_models()
 181            .to_vec();
 182
 183        self.active_alternative()
 184            .update(cx, |alternative, cx| alternative.undo(cx));
 185        self.activate(0, cx);
 186        self.alternatives.truncate(1);
 187
 188        for _ in 0..alternative_models.len() {
 189            self.alternatives.push(cx.new(|cx| {
 190                CodegenAlternative::new(
 191                    self.buffer.clone(),
 192                    self.range.clone(),
 193                    false,
 194                    self.builder.clone(),
 195                    self.session_id,
 196                    cx,
 197                )
 198            }));
 199        }
 200
 201        for (model, alternative) in iter::once(primary_model)
 202            .chain(alternative_models)
 203            .zip(&self.alternatives)
 204        {
 205            alternative.update(cx, |alternative, cx| {
 206                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 207            })?;
 208        }
 209
 210        Ok(())
 211    }
 212
 213    pub fn stop(&mut self, cx: &mut Context<Self>) {
 214        for codegen in &self.alternatives {
 215            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 216        }
 217    }
 218
 219    pub fn undo(&mut self, cx: &mut Context<Self>) {
 220        self.active_alternative()
 221            .update(cx, |codegen, cx| codegen.undo(cx));
 222
 223        self.buffer.update(cx, |buffer, cx| {
 224            if let Some(transaction_id) = self.initial_transaction_id.take() {
 225                buffer.undo_transaction(transaction_id, cx);
 226                buffer.refresh_preview(cx);
 227            }
 228        });
 229    }
 230
 231    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 232        self.active_alternative().read(cx).buffer.clone()
 233    }
 234
 235    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 236        self.active_alternative().read(cx).old_buffer.clone()
 237    }
 238
 239    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 240        self.active_alternative().read(cx).snapshot.clone()
 241    }
 242
 243    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 244        self.active_alternative().read(cx).edit_position
 245    }
 246
 247    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 248        &self.active_alternative().read(cx).diff
 249    }
 250
 251    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 252        self.active_alternative().read(cx).last_equal_ranges()
 253    }
 254
 255    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 256        self.active_alternative().read(cx).selected_text()
 257    }
 258
 259    pub fn session_id(&self) -> Uuid {
 260        self.session_id
 261    }
 262}
 263
// `BufferCodegen` re-emits the `CodegenEvent`s of its active alternative
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 265
/// A single model's attempt at rewriting `range` of a multibuffer.
pub struct CodegenAlternative {
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// Local buffer created in `new` from the text, line ending and language of the
    /// underlying buffer as it was at construction time.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot; taken in `new` and refreshed by `handle_stream`.
    snapshot: MultiBufferSnapshot,
    /// Set by `start` to the start of `range`, biased right.
    edit_position: Option<Anchor>,
    /// The section to rewrite.
    range: Range<Anchor>,
    /// Exposed through `last_equal_ranges()`.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction undone when this alternative is deactivated or restarted.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation status; `Idle` after construction.
    status: CodegenStatus,
    /// Task driving the current generation; a ready task when nothing is running.
    generation: Task<()>,
    /// Reset to the default when a stream starts being handled.
    diff: Diff,
    /// Keeps the subscription to `buffer` events (see `handle_buffer_event`) alive.
    _subscription: gpui::Subscription,
    /// Builds the prompts sent to the model.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative is the one currently applied; toggled by `set_active`.
    active: bool,
    /// Edits re-applied when the alternative becomes active again.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations re-applied on activation while the status is `Pending`.
    line_operations: Vec<LineOperation>,
    elapsed_time: Option<f64>,
    completion: Option<String>,
    /// Text covered by `range`, captured when a stream starts being handled.
    selected_text: Option<String>,
    pub message_id: Option<String>,
    session_id: Uuid,
    /// Cleared by `start` at the beginning of each generation.
    pub description: Option<String>,
    pub failure: Option<String>,
}
 290
// `CodegenAlternative` emits `CodegenEvent::Undone` from `handle_buffer_event` when its
// transformation transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 292
 293impl CodegenAlternative {
 294    pub fn new(
 295        buffer: Entity<MultiBuffer>,
 296        range: Range<Anchor>,
 297        active: bool,
 298        builder: Arc<PromptBuilder>,
 299        session_id: Uuid,
 300        cx: &mut Context<Self>,
 301    ) -> Self {
 302        let snapshot = buffer.read(cx).snapshot(cx);
 303
 304        let (old_buffer, _, _) = snapshot
 305            .range_to_buffer_ranges(range.clone())
 306            .pop()
 307            .unwrap();
 308        let old_buffer = cx.new(|cx| {
 309            let text = old_buffer.as_rope().clone();
 310            let line_ending = old_buffer.line_ending();
 311            let language = old_buffer.language().cloned();
 312            let language_registry = buffer
 313                .read(cx)
 314                .buffer(old_buffer.remote_id())
 315                .unwrap()
 316                .read(cx)
 317                .language_registry();
 318
 319            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 320            buffer.set_language(language, cx);
 321            if let Some(language_registry) = language_registry {
 322                buffer.set_language_registry(language_registry);
 323            }
 324            buffer
 325        });
 326
 327        Self {
 328            buffer: buffer.clone(),
 329            old_buffer,
 330            edit_position: None,
 331            message_id: None,
 332            snapshot,
 333            last_equal_ranges: Default::default(),
 334            transformation_transaction_id: None,
 335            status: CodegenStatus::Idle,
 336            generation: Task::ready(()),
 337            diff: Diff::default(),
 338            builder,
 339            active: active,
 340            edits: Vec::new(),
 341            line_operations: Vec::new(),
 342            range,
 343            elapsed_time: None,
 344            completion: None,
 345            selected_text: None,
 346            session_id,
 347            description: None,
 348            failure: None,
 349            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 350        }
 351    }
 352
 353    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 354        self.old_buffer
 355            .read(cx)
 356            .language()
 357            .map(|language| language.name())
 358    }
 359
 360    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 361        if active != self.active {
 362            self.active = active;
 363
 364            if self.active {
 365                let edits = self.edits.clone();
 366                self.apply_edits(edits, cx);
 367                if matches!(self.status, CodegenStatus::Pending) {
 368                    let line_operations = self.line_operations.clone();
 369                    self.reapply_line_based_diff(line_operations, cx);
 370                } else {
 371                    self.reapply_batch_diff(cx).detach();
 372                }
 373            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 374                self.buffer.update(cx, |buffer, cx| {
 375                    buffer.undo_transaction(transaction_id, cx);
 376                    buffer.forget_transaction(transaction_id, cx);
 377                });
 378            }
 379        }
 380    }
 381
 382    fn handle_buffer_event(
 383        &mut self,
 384        _buffer: Entity<MultiBuffer>,
 385        event: &multi_buffer::Event,
 386        cx: &mut Context<Self>,
 387    ) {
 388        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 389            && self.transformation_transaction_id == Some(*transaction_id)
 390        {
 391            self.transformation_transaction_id = None;
 392            self.generation = Task::ready(());
 393            cx.emit(CodegenEvent::Undone);
 394        }
 395    }
 396
 397    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 398        &self.last_equal_ranges
 399    }
 400
 401    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 402        model.supports_streaming_tools()
 403            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 404    }
 405
 406    pub fn start(
 407        &mut self,
 408        user_prompt: String,
 409        context_task: Shared<Task<Option<LoadedContext>>>,
 410        model: Arc<dyn LanguageModel>,
 411        cx: &mut Context<Self>,
 412    ) -> Result<()> {
 413        // Clear the model explanation since the user has started a new generation.
 414        self.description = None;
 415
 416        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 417            self.buffer.update(cx, |buffer, cx| {
 418                buffer.undo_transaction(transformation_transaction_id, cx);
 419            });
 420        }
 421
 422        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 423
 424        if Self::use_streaming_tools(model.as_ref(), cx) {
 425            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 426            let completion_events = cx.spawn({
 427                let model = model.clone();
 428                async move |_, cx| model.stream_completion(request.await, cx).await
 429            });
 430            self.generation = self.handle_completion(model, completion_events, cx);
 431        } else {
 432            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 433                if user_prompt.trim().to_lowercase() == "delete" {
 434                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 435                } else {
 436                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 437                    cx.spawn({
 438                        let model = model.clone();
 439                        async move |_, cx| {
 440                            Ok(model.stream_completion_text(request.await, cx).await?)
 441                        }
 442                    })
 443                    .boxed_local()
 444                };
 445            self.generation =
 446                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 447        }
 448
 449        Ok(())
 450    }
 451
 452    fn build_request_tools(
 453        &self,
 454        model: &Arc<dyn LanguageModel>,
 455        user_prompt: String,
 456        context_task: Shared<Task<Option<LoadedContext>>>,
 457        cx: &mut App,
 458    ) -> Result<Task<LanguageModelRequest>> {
 459        let buffer = self.buffer.read(cx).snapshot(cx);
 460        let language = buffer.language_at(self.range.start);
 461        let language_name = if let Some(language) = language.as_ref() {
 462            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 463                None
 464            } else {
 465                Some(language.name())
 466            }
 467        } else {
 468            None
 469        };
 470
 471        let language_name = language_name.as_ref();
 472        let start = buffer.point_to_buffer_offset(self.range.start);
 473        let end = buffer.point_to_buffer_offset(self.range.end);
 474        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 475            let (start_buffer, start_buffer_offset) = start;
 476            let (end_buffer, end_buffer_offset) = end;
 477            if start_buffer.remote_id() == end_buffer.remote_id() {
 478                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 479            } else {
 480                anyhow::bail!("invalid transformation range");
 481            }
 482        } else {
 483            anyhow::bail!("invalid transformation range");
 484        };
 485
 486        let system_prompt = self
 487            .builder
 488            .generate_inline_transformation_prompt_tools(
 489                language_name,
 490                buffer,
 491                range.start.0..range.end.0,
 492            )
 493            .context("generating content prompt")?;
 494
 495        let temperature = AgentSettings::temperature_for_model(model, cx);
 496
 497        let tool_input_format = model.tool_input_format();
 498        let tool_choice = model
 499            .supports_tool_choice(LanguageModelToolChoice::Any)
 500            .then_some(LanguageModelToolChoice::Any);
 501
 502        Ok(cx.spawn(async move |_cx| {
 503            let mut messages = vec![LanguageModelRequestMessage {
 504                role: Role::System,
 505                content: vec![system_prompt.into()],
 506                cache: false,
 507                reasoning_details: None,
 508            }];
 509
 510            let mut user_message = LanguageModelRequestMessage {
 511                role: Role::User,
 512                content: Vec::new(),
 513                cache: false,
 514                reasoning_details: None,
 515            };
 516
 517            if let Some(context) = context_task.await {
 518                context.add_to_request_message(&mut user_message);
 519            }
 520
 521            user_message.content.push(user_prompt.into());
 522            messages.push(user_message);
 523
 524            let tools = vec![
 525                LanguageModelRequestTool {
 526                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
 527                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 528                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 529                },
 530                LanguageModelRequestTool {
 531                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
 532                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 533                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 534                },
 535            ];
 536
 537            LanguageModelRequest {
 538                thread_id: None,
 539                prompt_id: None,
 540                intent: Some(CompletionIntent::InlineAssist),
 541                mode: None,
 542                tools,
 543                tool_choice,
 544                stop: Vec::new(),
 545                temperature,
 546                messages,
 547                thinking_allowed: false,
 548            }
 549        }))
 550    }
 551
 552    fn build_request(
 553        &self,
 554        model: &Arc<dyn LanguageModel>,
 555        user_prompt: String,
 556        context_task: Shared<Task<Option<LoadedContext>>>,
 557        cx: &mut App,
 558    ) -> Result<Task<LanguageModelRequest>> {
 559        if Self::use_streaming_tools(model.as_ref(), cx) {
 560            return self.build_request_tools(model, user_prompt, context_task, cx);
 561        }
 562
 563        let buffer = self.buffer.read(cx).snapshot(cx);
 564        let language = buffer.language_at(self.range.start);
 565        let language_name = if let Some(language) = language.as_ref() {
 566            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 567                None
 568            } else {
 569                Some(language.name())
 570            }
 571        } else {
 572            None
 573        };
 574
 575        let language_name = language_name.as_ref();
 576        let start = buffer.point_to_buffer_offset(self.range.start);
 577        let end = buffer.point_to_buffer_offset(self.range.end);
 578        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 579            let (start_buffer, start_buffer_offset) = start;
 580            let (end_buffer, end_buffer_offset) = end;
 581            if start_buffer.remote_id() == end_buffer.remote_id() {
 582                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 583            } else {
 584                anyhow::bail!("invalid transformation range");
 585            }
 586        } else {
 587            anyhow::bail!("invalid transformation range");
 588        };
 589
 590        let prompt = self
 591            .builder
 592            .generate_inline_transformation_prompt(
 593                user_prompt,
 594                language_name,
 595                buffer,
 596                range.start.0..range.end.0,
 597            )
 598            .context("generating content prompt")?;
 599
 600        let temperature = AgentSettings::temperature_for_model(model, cx);
 601
 602        Ok(cx.spawn(async move |_cx| {
 603            let mut request_message = LanguageModelRequestMessage {
 604                role: Role::User,
 605                content: Vec::new(),
 606                cache: false,
 607                reasoning_details: None,
 608            };
 609
 610            if let Some(context) = context_task.await {
 611                context.add_to_request_message(&mut request_message);
 612            }
 613
 614            request_message.content.push(prompt.into());
 615
 616            LanguageModelRequest {
 617                thread_id: None,
 618                prompt_id: None,
 619                intent: Some(CompletionIntent::InlineAssist),
 620                mode: None,
 621                tools: Vec::new(),
 622                tool_choice: None,
 623                stop: Vec::new(),
 624                temperature,
 625                messages: vec![request_message],
 626                thinking_allowed: false,
 627            }
 628        }))
 629    }
 630
 631    pub fn handle_stream(
 632        &mut self,
 633        model: Arc<dyn LanguageModel>,
 634        strip_invalid_spans: bool,
 635        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 636        cx: &mut Context<Self>,
 637    ) -> Task<()> {
 638        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 639        let session_id = self.session_id;
 640        let model_telemetry_id = model.telemetry_id();
 641        let model_provider_id = model.provider_id().to_string();
 642        let start_time = Instant::now();
 643
 644        // Make a new snapshot and re-resolve anchor in case the document was modified.
 645        // This can happen often if the editor loses focus and is saved + reformatted,
 646        // as in https://github.com/zed-industries/zed/issues/39088
 647        self.snapshot = self.buffer.read(cx).snapshot(cx);
 648        self.range = self.snapshot.anchor_after(self.range.start)
 649            ..self.snapshot.anchor_after(self.range.end);
 650
 651        let snapshot = self.snapshot.clone();
 652        let selected_text = snapshot
 653            .text_for_range(self.range.start..self.range.end)
 654            .collect::<Rope>();
 655
 656        self.selected_text = Some(selected_text.to_string());
 657
 658        let selection_start = self.range.start.to_point(&snapshot);
 659
 660        // Start with the indentation of the first line in the selection
 661        let mut suggested_line_indent = snapshot
 662            .suggested_indents(selection_start.row..=selection_start.row, cx)
 663            .into_values()
 664            .next()
 665            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 666
 667        // If the first line in the selection does not have indentation, check the following lines
 668        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 669            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 670                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 671                // Prefer tabs if a line in the selection uses tabs as indentation
 672                if line_indent.kind == IndentKind::Tab {
 673                    suggested_line_indent.kind = IndentKind::Tab;
 674                    break;
 675                }
 676            }
 677        }
 678
 679        let language_name = {
 680            let multibuffer = self.buffer.read(cx);
 681            let snapshot = multibuffer.snapshot(cx);
 682            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 683            ranges
 684                .first()
 685                .and_then(|(buffer, _, _)| buffer.language())
 686                .map(|language| language.name())
 687        };
 688
 689        self.diff = Diff::default();
 690        self.status = CodegenStatus::Pending;
 691        let mut edit_start = self.range.start.to_offset(&snapshot);
 692        let completion = Arc::new(Mutex::new(String::new()));
 693        let completion_clone = completion.clone();
 694
 695        cx.notify();
 696        cx.spawn(async move |codegen, cx| {
 697            let stream = stream.await;
 698
 699            let token_usage = stream
 700                .as_ref()
 701                .ok()
 702                .map(|stream| stream.last_token_usage.clone());
 703            let message_id = stream
 704                .as_ref()
 705                .ok()
 706                .and_then(|stream| stream.message_id.clone());
 707            let generate = async {
 708                let model_telemetry_id = model_telemetry_id.clone();
 709                let model_provider_id = model_provider_id.clone();
 710                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 711                let message_id = message_id.clone();
 712                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 713                    let anthropic_reporter = anthropic_reporter.clone();
 714                    let language_name = language_name.clone();
 715                    async move {
 716                        let mut response_latency = None;
 717                        let request_start = Instant::now();
 718                        let diff = async {
 719                            let raw_stream = stream?.stream.map_err(|error| error.into());
 720
 721                            let stripped;
 722                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 723                                if strip_invalid_spans {
 724                                    stripped = StripInvalidSpans::new(raw_stream);
 725                                    Box::pin(stripped)
 726                                } else {
 727                                    Box::pin(raw_stream)
 728                                };
 729
 730                            let mut diff = StreamingDiff::new(selected_text.to_string());
 731                            let mut line_diff = LineDiff::default();
 732
 733                            let mut new_text = String::new();
 734                            let mut base_indent = None;
 735                            let mut line_indent = None;
 736                            let mut first_line = true;
 737
 738                            while let Some(chunk) = chunks.next().await {
 739                                if response_latency.is_none() {
 740                                    response_latency = Some(request_start.elapsed());
 741                                }
 742                                let chunk = chunk?;
 743                                completion_clone.lock().push_str(&chunk);
 744
 745                                let mut lines = chunk.split('\n').peekable();
 746                                while let Some(line) = lines.next() {
 747                                    new_text.push_str(line);
 748                                    if line_indent.is_none()
 749                                        && let Some(non_whitespace_ch_ix) =
 750                                            new_text.find(|ch: char| !ch.is_whitespace())
 751                                    {
 752                                        line_indent = Some(non_whitespace_ch_ix);
 753                                        base_indent = base_indent.or(line_indent);
 754
 755                                        let line_indent = line_indent.unwrap();
 756                                        let base_indent = base_indent.unwrap();
 757                                        let indent_delta = line_indent as i32 - base_indent as i32;
 758                                        let mut corrected_indent_len = cmp::max(
 759                                            0,
 760                                            suggested_line_indent.len as i32 + indent_delta,
 761                                        )
 762                                            as usize;
 763                                        if first_line {
 764                                            corrected_indent_len = corrected_indent_len
 765                                                .saturating_sub(selection_start.column as usize);
 766                                        }
 767
 768                                        let indent_char = suggested_line_indent.char();
 769                                        let mut indent_buffer = [0; 4];
 770                                        let indent_str =
 771                                            indent_char.encode_utf8(&mut indent_buffer);
 772                                        new_text.replace_range(
 773                                            ..line_indent,
 774                                            &indent_str.repeat(corrected_indent_len),
 775                                        );
 776                                    }
 777
 778                                    if line_indent.is_some() {
 779                                        let char_ops = diff.push_new(&new_text);
 780                                        line_diff.push_char_operations(&char_ops, &selected_text);
 781                                        diff_tx
 782                                            .send((char_ops, line_diff.line_operations()))
 783                                            .await?;
 784                                        new_text.clear();
 785                                    }
 786
 787                                    if lines.peek().is_some() {
 788                                        let char_ops = diff.push_new("\n");
 789                                        line_diff.push_char_operations(&char_ops, &selected_text);
 790                                        diff_tx
 791                                            .send((char_ops, line_diff.line_operations()))
 792                                            .await?;
 793                                        if line_indent.is_none() {
 794                                            // Don't write out the leading indentation in empty lines on the next line
 795                                            // This is the case where the above if statement didn't clear the buffer
 796                                            new_text.clear();
 797                                        }
 798                                        line_indent = None;
 799                                        first_line = false;
 800                                    }
 801                                }
 802                            }
 803
 804                            let mut char_ops = diff.push_new(&new_text);
 805                            char_ops.extend(diff.finish());
 806                            line_diff.push_char_operations(&char_ops, &selected_text);
 807                            line_diff.finish(&selected_text);
 808                            diff_tx
 809                                .send((char_ops, line_diff.line_operations()))
 810                                .await?;
 811
 812                            anyhow::Ok(())
 813                        };
 814
 815                        let result = diff.await;
 816
 817                        let error_message = result.as_ref().err().map(|error| error.to_string());
 818                        telemetry::event!(
 819                            "Assistant Responded",
 820                            kind = "inline",
 821                            phase = "response",
 822                            session_id = session_id.to_string(),
 823                            model = model_telemetry_id,
 824                            model_provider = model_provider_id,
 825                            language_name = language_name.as_ref().map(|n| n.to_string()),
 826                            message_id = message_id.as_deref(),
 827                            response_latency = response_latency,
 828                            error_message = error_message.as_deref(),
 829                        );
 830
 831                        anthropic_reporter.report(language_model::AnthropicEventData {
 832                            completion_type: language_model::AnthropicCompletionType::Editor,
 833                            event: language_model::AnthropicEventType::Response,
 834                            language_name: language_name.map(|n| n.to_string()),
 835                            message_id,
 836                        });
 837
 838                        result?;
 839                        Ok(())
 840                    }
 841                });
 842
 843                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 844                    codegen.update(cx, |codegen, cx| {
 845                        codegen.last_equal_ranges.clear();
 846
 847                        let edits = char_ops
 848                            .into_iter()
 849                            .filter_map(|operation| match operation {
 850                                CharOperation::Insert { text } => {
 851                                    let edit_start = snapshot.anchor_after(edit_start);
 852                                    Some((edit_start..edit_start, text))
 853                                }
 854                                CharOperation::Delete { bytes } => {
 855                                    let edit_end = edit_start + bytes;
 856                                    let edit_range = snapshot.anchor_after(edit_start)
 857                                        ..snapshot.anchor_before(edit_end);
 858                                    edit_start = edit_end;
 859                                    Some((edit_range, String::new()))
 860                                }
 861                                CharOperation::Keep { bytes } => {
 862                                    let edit_end = edit_start + bytes;
 863                                    let edit_range = snapshot.anchor_after(edit_start)
 864                                        ..snapshot.anchor_before(edit_end);
 865                                    edit_start = edit_end;
 866                                    codegen.last_equal_ranges.push(edit_range);
 867                                    None
 868                                }
 869                            })
 870                            .collect::<Vec<_>>();
 871
 872                        if codegen.active {
 873                            codegen.apply_edits(edits.iter().cloned(), cx);
 874                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 875                        }
 876                        codegen.edits.extend(edits);
 877                        codegen.line_operations = line_ops;
 878                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 879
 880                        cx.notify();
 881                    })?;
 882                }
 883
 884                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 885                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 886                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 887                let batch_diff_task =
 888                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 889                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 890                line_based_stream_diff?;
 891
 892                anyhow::Ok(())
 893            };
 894
 895            let result = generate.await;
 896            let elapsed_time = start_time.elapsed().as_secs_f64();
 897
 898            codegen
 899                .update(cx, |this, cx| {
 900                    this.message_id = message_id;
 901                    this.last_equal_ranges.clear();
 902                    if let Err(error) = result {
 903                        this.status = CodegenStatus::Error(error);
 904                    } else {
 905                        this.status = CodegenStatus::Done;
 906                    }
 907                    this.elapsed_time = Some(elapsed_time);
 908                    this.completion = Some(completion.lock().clone());
 909                    if let Some(usage) = token_usage {
 910                        let usage = usage.lock();
 911                        telemetry::event!(
 912                            "Inline Assistant Completion",
 913                            model = model_telemetry_id,
 914                            model_provider = model_provider_id,
 915                            input_tokens = usage.input_tokens,
 916                            output_tokens = usage.output_tokens,
 917                        )
 918                    }
 919
 920                    cx.emit(CodegenEvent::Finished);
 921                    cx.notify();
 922                })
 923                .ok();
 924        })
 925    }
 926
 927    pub fn current_completion(&self) -> Option<String> {
 928        self.completion.clone()
 929    }
 930
 931    #[cfg(any(test, feature = "test-support"))]
 932    pub fn current_description(&self) -> Option<String> {
 933        self.description.clone()
 934    }
 935
 936    #[cfg(any(test, feature = "test-support"))]
 937    pub fn current_failure(&self) -> Option<String> {
 938        self.failure.clone()
 939    }
 940
 941    pub fn selected_text(&self) -> Option<&str> {
 942        self.selected_text.as_deref()
 943    }
 944
 945    pub fn stop(&mut self, cx: &mut Context<Self>) {
 946        self.last_equal_ranges.clear();
 947        if self.diff.is_empty() {
 948            self.status = CodegenStatus::Idle;
 949        } else {
 950            self.status = CodegenStatus::Done;
 951        }
 952        self.generation = Task::ready(());
 953        cx.emit(CodegenEvent::Finished);
 954        cx.notify();
 955    }
 956
 957    pub fn undo(&mut self, cx: &mut Context<Self>) {
 958        self.buffer.update(cx, |buffer, cx| {
 959            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 960                buffer.undo_transaction(transaction_id, cx);
 961                buffer.refresh_preview(cx);
 962            }
 963        });
 964    }
 965
 966    fn apply_edits(
 967        &mut self,
 968        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 969        cx: &mut Context<CodegenAlternative>,
 970    ) {
 971        let transaction = self.buffer.update(cx, |buffer, cx| {
 972            // Avoid grouping agent edits with user edits.
 973            buffer.finalize_last_transaction(cx);
 974            buffer.start_transaction(cx);
 975            buffer.edit(edits, None, cx);
 976            buffer.end_transaction(cx)
 977        });
 978
 979        if let Some(transaction) = transaction {
 980            if let Some(first_transaction) = self.transformation_transaction_id {
 981                // Group all agent edits into the first transaction.
 982                self.buffer.update(cx, |buffer, cx| {
 983                    buffer.merge_transactions(transaction, first_transaction, cx)
 984                });
 985            } else {
 986                self.transformation_transaction_id = Some(transaction);
 987                self.buffer
 988                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 989            }
 990        }
 991    }
 992
 993    fn reapply_line_based_diff(
 994        &mut self,
 995        line_operations: impl IntoIterator<Item = LineOperation>,
 996        cx: &mut Context<Self>,
 997    ) {
 998        let old_snapshot = self.snapshot.clone();
 999        let old_range = self.range.to_point(&old_snapshot);
1000        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1001        let new_range = self.range.to_point(&new_snapshot);
1002
1003        let mut old_row = old_range.start.row;
1004        let mut new_row = new_range.start.row;
1005
1006        self.diff.deleted_row_ranges.clear();
1007        self.diff.inserted_row_ranges.clear();
1008        for operation in line_operations {
1009            match operation {
1010                LineOperation::Keep { lines } => {
1011                    old_row += lines;
1012                    new_row += lines;
1013                }
1014                LineOperation::Delete { lines } => {
1015                    let old_end_row = old_row + lines - 1;
1016                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1017
1018                    if let Some((_, last_deleted_row_range)) =
1019                        self.diff.deleted_row_ranges.last_mut()
1020                    {
1021                        if *last_deleted_row_range.end() + 1 == old_row {
1022                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1023                        } else {
1024                            self.diff
1025                                .deleted_row_ranges
1026                                .push((new_row, old_row..=old_end_row));
1027                        }
1028                    } else {
1029                        self.diff
1030                            .deleted_row_ranges
1031                            .push((new_row, old_row..=old_end_row));
1032                    }
1033
1034                    old_row += lines;
1035                }
1036                LineOperation::Insert { lines } => {
1037                    let new_end_row = new_row + lines - 1;
1038                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1039                    let end = new_snapshot.anchor_before(Point::new(
1040                        new_end_row,
1041                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1042                    ));
1043                    self.diff.inserted_row_ranges.push(start..end);
1044                    new_row += lines;
1045                }
1046            }
1047
1048            cx.notify();
1049        }
1050    }
1051
1052    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1053        let old_snapshot = self.snapshot.clone();
1054        let old_range = self.range.to_point(&old_snapshot);
1055        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1056        let new_range = self.range.to_point(&new_snapshot);
1057
1058        cx.spawn(async move |codegen, cx| {
1059            let (deleted_row_ranges, inserted_row_ranges) = cx
1060                .background_spawn(async move {
1061                    let old_text = old_snapshot
1062                        .text_for_range(
1063                            Point::new(old_range.start.row, 0)
1064                                ..Point::new(
1065                                    old_range.end.row,
1066                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1067                                ),
1068                        )
1069                        .collect::<String>();
1070                    let new_text = new_snapshot
1071                        .text_for_range(
1072                            Point::new(new_range.start.row, 0)
1073                                ..Point::new(
1074                                    new_range.end.row,
1075                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1076                                ),
1077                        )
1078                        .collect::<String>();
1079
1080                    let old_start_row = old_range.start.row;
1081                    let new_start_row = new_range.start.row;
1082                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1083                    let mut inserted_row_ranges = Vec::new();
1084                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1085                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1086                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1087                        if !old_rows.is_empty() {
1088                            deleted_row_ranges.push((
1089                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1090                                old_rows.start..=old_rows.end - 1,
1091                            ));
1092                        }
1093                        if !new_rows.is_empty() {
1094                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1095                            let new_end_row = new_rows.end - 1;
1096                            let end = new_snapshot.anchor_before(Point::new(
1097                                new_end_row,
1098                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1099                            ));
1100                            inserted_row_ranges.push(start..end);
1101                        }
1102                    }
1103                    (deleted_row_ranges, inserted_row_ranges)
1104                })
1105                .await;
1106
1107            codegen
1108                .update(cx, |codegen, cx| {
1109                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1110                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1111                    cx.notify();
1112                })
1113                .ok();
1114        })
1115    }
1116
1117    fn handle_completion(
1118        &mut self,
1119        model: Arc<dyn LanguageModel>,
1120        completion_stream: Task<
1121            Result<
1122                BoxStream<
1123                    'static,
1124                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1125                >,
1126                LanguageModelCompletionError,
1127            >,
1128        >,
1129        cx: &mut Context<Self>,
1130    ) -> Task<()> {
1131        self.diff = Diff::default();
1132        self.status = CodegenStatus::Pending;
1133
1134        cx.notify();
1135        // Leaving this in generation so that STOP equivalent events are respected even
1136        // while we're still pre-processing the completion event
1137        cx.spawn(async move |codegen, cx| {
1138            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1139                let _ = codegen.update(cx, |this, cx| {
1140                    this.status = status;
1141                    cx.emit(CodegenEvent::Finished);
1142                    cx.notify();
1143                });
1144            };
1145
1146            let mut completion_events = match completion_stream.await {
1147                Ok(events) => events,
1148                Err(err) => {
1149                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1150                    return;
1151                }
1152            };
1153
1154            enum ToolUseOutput {
1155                Rewrite {
1156                    text: String,
1157                    description: Option<String>,
1158                },
1159                Failure(String),
1160            }
1161
1162            enum ModelUpdate {
1163                Description(String),
1164                Failure(String),
1165            }
1166
1167            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1168            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1169                let mut chars_read_so_far = chars_read_so_far.lock();
1170                match tool_use.name.as_ref() {
1171                    REWRITE_SECTION_TOOL_NAME => {
1172                        let Ok(input) =
1173                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1174                        else {
1175                            return None;
1176                        };
1177                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1178                        *chars_read_so_far = input.replacement_text.len();
1179                        Some(ToolUseOutput::Rewrite {
1180                            text,
1181                            description: None,
1182                        })
1183                    }
1184                    FAILURE_MESSAGE_TOOL_NAME => {
1185                        let Ok(mut input) =
1186                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1187                        else {
1188                            return None;
1189                        };
1190                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1191                    }
1192                    _ => None,
1193                }
1194            };
1195
1196            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1197
1198            cx.spawn({
1199                let codegen = codegen.clone();
1200                async move |cx| {
1201                    while let Some(update) = message_rx.next().await {
1202                        let _ = codegen.update(cx, |this, _cx| match update {
1203                            ModelUpdate::Description(d) => this.description = Some(d),
1204                            ModelUpdate::Failure(f) => this.failure = Some(f),
1205                        });
1206                    }
1207                }
1208            })
1209            .detach();
1210
1211            let mut message_id = None;
1212            let mut first_text = None;
1213            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1214            let total_text = Arc::new(Mutex::new(String::new()));
1215
1216            loop {
1217                if let Some(first_event) = completion_events.next().await {
1218                    match first_event {
1219                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1220                            message_id = Some(id);
1221                        }
1222                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1223                            if let Some(output) = process_tool_use(tool_use) {
1224                                let (text, update) = match output {
1225                                    ToolUseOutput::Rewrite { text, description } => {
1226                                        (Some(text), description.map(ModelUpdate::Description))
1227                                    }
1228                                    ToolUseOutput::Failure(message) => {
1229                                        (None, Some(ModelUpdate::Failure(message)))
1230                                    }
1231                                };
1232                                if let Some(update) = update {
1233                                    let _ = message_tx.unbounded_send(update);
1234                                }
1235                                first_text = text;
1236                                if first_text.is_some() {
1237                                    break;
1238                                }
1239                            }
1240                        }
1241                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1242                            *last_token_usage.lock() = token_usage;
1243                        }
1244                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1245                            let mut lock = total_text.lock();
1246                            lock.push_str(&text);
1247                        }
1248                        Ok(e) => {
1249                            log::warn!("Unexpected event: {:?}", e);
1250                            break;
1251                        }
1252                        Err(e) => {
1253                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1254                            break;
1255                        }
1256                    }
1257                }
1258            }
1259
1260            let Some(first_text) = first_text else {
1261                finish_with_status(CodegenStatus::Done, cx);
1262                return;
1263            };
1264
1265            let move_last_token_usage = last_token_usage.clone();
1266
1267            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1268                completion_events.filter_map(move |e| {
1269                    let process_tool_use = process_tool_use.clone();
1270                    let last_token_usage = move_last_token_usage.clone();
1271                    let total_text = total_text.clone();
1272                    let mut message_tx = message_tx.clone();
1273                    async move {
1274                        match e {
1275                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1276                                let Some(output) = process_tool_use(tool_use) else {
1277                                    return None;
1278                                };
1279                                let (text, update) = match output {
1280                                    ToolUseOutput::Rewrite { text, description } => {
1281                                        (Some(text), description.map(ModelUpdate::Description))
1282                                    }
1283                                    ToolUseOutput::Failure(message) => {
1284                                        (None, Some(ModelUpdate::Failure(message)))
1285                                    }
1286                                };
1287                                if let Some(update) = update {
1288                                    let _ = message_tx.send(update).await;
1289                                }
1290                                text.map(Ok)
1291                            }
1292                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1293                                *last_token_usage.lock() = token_usage;
1294                                None
1295                            }
1296                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1297                                let mut lock = total_text.lock();
1298                                lock.push_str(&text);
1299                                None
1300                            }
1301                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1302                            e => {
1303                                log::error!("UNEXPECTED EVENT {:?}", e);
1304                                None
1305                            }
1306                        }
1307                    }
1308                }),
1309            ));
1310
1311            let language_model_text_stream = LanguageModelTextStream {
1312                message_id: message_id,
1313                stream: text_stream,
1314                last_token_usage,
1315            };
1316
1317            let Some(task) = codegen
1318                .update(cx, move |codegen, cx| {
1319                    codegen.handle_stream(
1320                        model,
1321                        /* strip_invalid_spans: */ false,
1322                        async { Ok(language_model_text_stream) },
1323                        cx,
1324                    )
1325                })
1326                .ok()
1327            else {
1328                return;
1329            };
1330
1331            task.await;
1332        })
1333    }
1334}
1335
/// Events emitted by a codegen entity.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// Generation has ended, whether it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): not emitted in this part of the file; presumably signals
    /// that a transformation was undone. Confirm at the emit sites.
    Undone,
}
1341
/// Stream adapter that removes Markdown code-fence lines and `<|CURSOR|>`
/// markers from a stream of text chunks (see its `Stream` implementation).
struct StripInvalidSpans<T> {
    /// The wrapped source of text chunks.
    stream: T,
    /// Set once the wrapped stream has returned `None`; it is not polled again.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted or discarded yet.
    buffer: String,
    /// True until a second line has been observed in `buffer`.
    first_line: bool,
    /// True when a complete line was emitted and its trailing newline is still
    /// owed before the next piece of text.
    line_end: bool,
    /// True when the first line was a code-fence line, so a closing fence at
    /// the end of the stream should be dropped as well.
    starts_with_code_block: bool,
}
1350
1351impl<T> StripInvalidSpans<T>
1352where
1353    T: Stream<Item = Result<String>>,
1354{
1355    fn new(stream: T) -> Self {
1356        Self {
1357            stream,
1358            stream_done: false,
1359            buffer: String::new(),
1360            first_line: true,
1361            line_end: false,
1362            starts_with_code_block: false,
1363        }
1364    }
1365}
1366
1367impl<T> Stream for StripInvalidSpans<T>
1368where
1369    T: Stream<Item = Result<String>>,
1370{
1371    type Item = Result<String>;
1372
1373    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1374        const CODE_BLOCK_DELIMITER: &str = "```";
1375        const CURSOR_SPAN: &str = "<|CURSOR|>";
1376
1377        let this = unsafe { self.get_unchecked_mut() };
1378        loop {
1379            if !this.stream_done {
1380                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1381                match stream.as_mut().poll_next(cx) {
1382                    Poll::Ready(Some(Ok(chunk))) => {
1383                        this.buffer.push_str(&chunk);
1384                    }
1385                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1386                    Poll::Ready(None) => {
1387                        this.stream_done = true;
1388                    }
1389                    Poll::Pending => return Poll::Pending,
1390                }
1391            }
1392
1393            let mut chunk = String::new();
1394            let mut consumed = 0;
1395            if !this.buffer.is_empty() {
1396                let mut lines = this.buffer.split('\n').enumerate().peekable();
1397                while let Some((line_ix, line)) = lines.next() {
1398                    if line_ix > 0 {
1399                        this.first_line = false;
1400                    }
1401
1402                    if this.first_line {
1403                        let trimmed_line = line.trim();
1404                        if lines.peek().is_some() {
1405                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1406                                consumed += line.len() + 1;
1407                                this.starts_with_code_block = true;
1408                                continue;
1409                            }
1410                        } else if trimmed_line.is_empty()
1411                            || prefixes(CODE_BLOCK_DELIMITER)
1412                                .any(|prefix| trimmed_line.starts_with(prefix))
1413                        {
1414                            break;
1415                        }
1416                    }
1417
1418                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1419                    if lines.peek().is_some() {
1420                        if this.line_end {
1421                            chunk.push('\n');
1422                        }
1423
1424                        chunk.push_str(&line_without_cursor);
1425                        this.line_end = true;
1426                        consumed += line.len() + 1;
1427                    } else if this.stream_done {
1428                        if !this.starts_with_code_block
1429                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1430                        {
1431                            if this.line_end {
1432                                chunk.push('\n');
1433                            }
1434
1435                            chunk.push_str(line);
1436                        }
1437
1438                        consumed += line.len();
1439                    } else {
1440                        let trimmed_line = line.trim();
1441                        if trimmed_line.is_empty()
1442                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1443                            || prefixes(CODE_BLOCK_DELIMITER)
1444                                .any(|prefix| trimmed_line.ends_with(prefix))
1445                        {
1446                            break;
1447                        } else {
1448                            if this.line_end {
1449                                chunk.push('\n');
1450                                this.line_end = false;
1451                            }
1452
1453                            chunk.push_str(&line_without_cursor);
1454                            consumed += line.len();
1455                        }
1456                    }
1457                }
1458            }
1459
1460            this.buffer = this.buffer.split_off(consumed);
1461            if !chunk.is_empty() {
1462                return Poll::Ready(Some(Ok(chunk)));
1463            } else if this.stream_done {
1464                return Poll::Ready(None);
1465            }
1466        }
1467    }
1468}
1469
/// Yields every non-empty proper prefix of `text`, shortest first.
///
/// For example, "abc" yields "a" and "ab"; the full string itself is never
/// produced. Empty and single-character inputs yield nothing.
///
/// Prefixes are cut on `char` boundaries, so non-ASCII input is sliced safely
/// instead of panicking on a byte index that falls inside a character.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // The start offset of every character after the first is exactly the end
    // offset of one proper prefix. Skipping the first start (offset 0) drops
    // the empty prefix, and the full string is never reached because there is
    // no character starting at `text.len()`.
    text.char_indices().skip(1).map(move |(end, _)| &text[..end])
}
1473
1474#[derive(Default)]
1475pub struct Diff {
1476    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1477    pub inserted_row_ranges: Vec<Range<Anchor>>,
1478}
1479
1480impl Diff {
1481    fn is_empty(&self) -> bool {
1482        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1483    }
1484}
1485
1486#[cfg(test)]
1487mod tests {
1488    use super::*;
1489    use futures::{
1490        Stream,
1491        stream::{self},
1492    };
1493    use gpui::TestAppContext;
1494    use indoc::indoc;
1495    use language::{Buffer, Point};
1496    use language_model::fake_provider::FakeLanguageModel;
1497    use language_model::{
1498        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
1499        LanguageModelToolUse, StopReason, TokenUsage,
1500    };
1501    use languages::rust_lang;
1502    use rand::prelude::*;
1503    use settings::SettingsStore;
1504    use std::{future, sync::Arc};
1505
1506    #[gpui::test(iterations = 10)]
1507    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1508        init_test(cx);
1509
1510        let text = indoc! {"
1511            fn main() {
1512                let x = 0;
1513                for _ in 0..10 {
1514                    x += 1;
1515                }
1516            }
1517        "};
1518        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1519        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1520        let range = buffer.read_with(cx, |buffer, cx| {
1521            let snapshot = buffer.snapshot(cx);
1522            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1523        });
1524        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1525        let codegen = cx.new(|cx| {
1526            CodegenAlternative::new(
1527                buffer.clone(),
1528                range.clone(),
1529                true,
1530                prompt_builder,
1531                Uuid::new_v4(),
1532                cx,
1533            )
1534        });
1535
1536        let chunks_tx = simulate_response_stream(&codegen, cx);
1537
1538        let mut new_text = concat!(
1539            "       let mut x = 0;\n",
1540            "       while x < 10 {\n",
1541            "           x += 1;\n",
1542            "       }",
1543        );
1544        while !new_text.is_empty() {
1545            let max_len = cmp::min(new_text.len(), 10);
1546            let len = rng.random_range(1..=max_len);
1547            let (chunk, suffix) = new_text.split_at(len);
1548            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1549            new_text = suffix;
1550            cx.background_executor.run_until_parked();
1551        }
1552        drop(chunks_tx);
1553        cx.background_executor.run_until_parked();
1554
1555        assert_eq!(
1556            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1557            indoc! {"
1558                fn main() {
1559                    let mut x = 0;
1560                    while x < 10 {
1561                        x += 1;
1562                    }
1563                }
1564            "}
1565        );
1566    }
1567
1568    #[gpui::test(iterations = 10)]
1569    async fn test_autoindent_when_generating_past_indentation(
1570        cx: &mut TestAppContext,
1571        mut rng: StdRng,
1572    ) {
1573        init_test(cx);
1574
1575        let text = indoc! {"
1576            fn main() {
1577                le
1578            }
1579        "};
1580        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1581        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1582        let range = buffer.read_with(cx, |buffer, cx| {
1583            let snapshot = buffer.snapshot(cx);
1584            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1585        });
1586        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1587        let codegen = cx.new(|cx| {
1588            CodegenAlternative::new(
1589                buffer.clone(),
1590                range.clone(),
1591                true,
1592                prompt_builder,
1593                Uuid::new_v4(),
1594                cx,
1595            )
1596        });
1597
1598        let chunks_tx = simulate_response_stream(&codegen, cx);
1599
1600        cx.background_executor.run_until_parked();
1601
1602        let mut new_text = concat!(
1603            "t mut x = 0;\n",
1604            "while x < 10 {\n",
1605            "    x += 1;\n",
1606            "}", //
1607        );
1608        while !new_text.is_empty() {
1609            let max_len = cmp::min(new_text.len(), 10);
1610            let len = rng.random_range(1..=max_len);
1611            let (chunk, suffix) = new_text.split_at(len);
1612            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1613            new_text = suffix;
1614            cx.background_executor.run_until_parked();
1615        }
1616        drop(chunks_tx);
1617        cx.background_executor.run_until_parked();
1618
1619        assert_eq!(
1620            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1621            indoc! {"
1622                fn main() {
1623                    let mut x = 0;
1624                    while x < 10 {
1625                        x += 1;
1626                    }
1627                }
1628            "}
1629        );
1630    }
1631
1632    #[gpui::test(iterations = 10)]
1633    async fn test_autoindent_when_generating_before_indentation(
1634        cx: &mut TestAppContext,
1635        mut rng: StdRng,
1636    ) {
1637        init_test(cx);
1638
1639        let text = concat!(
1640            "fn main() {\n",
1641            "  \n",
1642            "}\n" //
1643        );
1644        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1645        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1646        let range = buffer.read_with(cx, |buffer, cx| {
1647            let snapshot = buffer.snapshot(cx);
1648            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1649        });
1650        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1651        let codegen = cx.new(|cx| {
1652            CodegenAlternative::new(
1653                buffer.clone(),
1654                range.clone(),
1655                true,
1656                prompt_builder,
1657                Uuid::new_v4(),
1658                cx,
1659            )
1660        });
1661
1662        let chunks_tx = simulate_response_stream(&codegen, cx);
1663
1664        cx.background_executor.run_until_parked();
1665
1666        let mut new_text = concat!(
1667            "let mut x = 0;\n",
1668            "while x < 10 {\n",
1669            "    x += 1;\n",
1670            "}", //
1671        );
1672        while !new_text.is_empty() {
1673            let max_len = cmp::min(new_text.len(), 10);
1674            let len = rng.random_range(1..=max_len);
1675            let (chunk, suffix) = new_text.split_at(len);
1676            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1677            new_text = suffix;
1678            cx.background_executor.run_until_parked();
1679        }
1680        drop(chunks_tx);
1681        cx.background_executor.run_until_parked();
1682
1683        assert_eq!(
1684            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1685            indoc! {"
1686                fn main() {
1687                    let mut x = 0;
1688                    while x < 10 {
1689                        x += 1;
1690                    }
1691                }
1692            "}
1693        );
1694    }
1695
1696    #[gpui::test(iterations = 10)]
1697    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1698        init_test(cx);
1699
1700        let text = indoc! {"
1701            func main() {
1702            \tx := 0
1703            \tfor i := 0; i < 10; i++ {
1704            \t\tx++
1705            \t}
1706            }
1707        "};
1708        let buffer = cx.new(|cx| Buffer::local(text, cx));
1709        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1710        let range = buffer.read_with(cx, |buffer, cx| {
1711            let snapshot = buffer.snapshot(cx);
1712            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1713        });
1714        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1715        let codegen = cx.new(|cx| {
1716            CodegenAlternative::new(
1717                buffer.clone(),
1718                range.clone(),
1719                true,
1720                prompt_builder,
1721                Uuid::new_v4(),
1722                cx,
1723            )
1724        });
1725
1726        let chunks_tx = simulate_response_stream(&codegen, cx);
1727        let new_text = concat!(
1728            "func main() {\n",
1729            "\tx := 0\n",
1730            "\tfor x < 10 {\n",
1731            "\t\tx++\n",
1732            "\t}", //
1733        );
1734        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1735        drop(chunks_tx);
1736        cx.background_executor.run_until_parked();
1737
1738        assert_eq!(
1739            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1740            indoc! {"
1741                func main() {
1742                \tx := 0
1743                \tfor x < 10 {
1744                \t\tx++
1745                \t}
1746                }
1747            "}
1748        );
1749    }
1750
1751    #[gpui::test]
1752    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1753        init_test(cx);
1754
1755        let text = indoc! {"
1756            fn main() {
1757                let x = 0;
1758            }
1759        "};
1760        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1761        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1762        let range = buffer.read_with(cx, |buffer, cx| {
1763            let snapshot = buffer.snapshot(cx);
1764            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1765        });
1766        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1767        let codegen = cx.new(|cx| {
1768            CodegenAlternative::new(
1769                buffer.clone(),
1770                range.clone(),
1771                false,
1772                prompt_builder,
1773                Uuid::new_v4(),
1774                cx,
1775            )
1776        });
1777
1778        let chunks_tx = simulate_response_stream(&codegen, cx);
1779        chunks_tx
1780            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1781            .unwrap();
1782        drop(chunks_tx);
1783        cx.run_until_parked();
1784
1785        // The codegen is inactive, so the buffer doesn't get modified.
1786        assert_eq!(
1787            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1788            text
1789        );
1790
1791        // Activating the codegen applies the changes.
1792        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1793        assert_eq!(
1794            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1795            indoc! {"
1796                fn main() {
1797                    let mut x = 0;
1798                    x += 1;
1799                }
1800            "}
1801        );
1802
1803        // Deactivating the codegen undoes the changes.
1804        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1805        cx.run_until_parked();
1806        assert_eq!(
1807            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1808            text
1809        );
1810    }
1811
1812    // When not streaming tool calls, we strip backticks as part of parsing the model's
1813    // plain text response. This is a regression test for a bug where we stripped
1814    // backticks incorrectly.
1815    #[gpui::test]
1816    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
1817        init_test(cx);
1818        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
1819        let buffer = cx.new(|cx| Buffer::local("", cx));
1820        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1821        let range = buffer.read_with(cx, |buffer, cx| {
1822            let snapshot = buffer.snapshot(cx);
1823            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
1824        });
1825        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1826        let codegen = cx.new(|cx| {
1827            CodegenAlternative::new(
1828                buffer.clone(),
1829                range.clone(),
1830                true,
1831                prompt_builder,
1832                Uuid::new_v4(),
1833                cx,
1834            )
1835        });
1836
1837        let events_tx = simulate_tool_based_completion(&codegen, cx);
1838        let chunk_len = text.find('`').unwrap();
1839        events_tx
1840            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
1841            .unwrap();
1842        events_tx
1843            .unbounded_send(rewrite_tool_use("tool_2", &text, true))
1844            .unwrap();
1845        events_tx
1846            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
1847            .unwrap();
1848        drop(events_tx);
1849        cx.run_until_parked();
1850
1851        assert_eq!(
1852            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1853            text
1854        );
1855    }
1856
1857    #[gpui::test]
1858    async fn test_strip_invalid_spans_from_codeblock() {
1859        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1860        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1861        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1862        assert_chunks(
1863            "```html\n```js\nLorem ipsum dolor\n```\n```",
1864            "```js\nLorem ipsum dolor\n```",
1865        )
1866        .await;
1867        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1868        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1869        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1870        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1871
1872        async fn assert_chunks(text: &str, expected_text: &str) {
1873            for chunk_size in 1..=text.len() {
1874                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1875                    .map(|chunk| chunk.unwrap())
1876                    .collect::<String>()
1877                    .await;
1878                assert_eq!(
1879                    actual_text, expected_text,
1880                    "failed to strip invalid spans, chunk size: {}",
1881                    chunk_size
1882                );
1883            }
1884        }
1885
1886        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1887            stream::iter(
1888                text.chars()
1889                    .collect::<Vec<_>>()
1890                    .chunks(size)
1891                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
1892                    .collect::<Vec<_>>(),
1893            )
1894        }
1895    }
1896
1897    fn init_test(cx: &mut TestAppContext) {
1898        cx.update(LanguageModelRegistry::test);
1899        cx.set_global(cx.update(SettingsStore::test));
1900    }
1901
1902    fn simulate_response_stream(
1903        codegen: &Entity<CodegenAlternative>,
1904        cx: &mut TestAppContext,
1905    ) -> mpsc::UnboundedSender<String> {
1906        let (chunks_tx, chunks_rx) = mpsc::unbounded();
1907        let model = Arc::new(FakeLanguageModel::default());
1908        codegen.update(cx, |codegen, cx| {
1909            codegen.generation = codegen.handle_stream(
1910                model,
1911                /* strip_invalid_spans: */ false,
1912                future::ready(Ok(LanguageModelTextStream {
1913                    message_id: None,
1914                    stream: chunks_rx.map(Ok).boxed(),
1915                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1916                })),
1917                cx,
1918            );
1919        });
1920        chunks_tx
1921    }
1922
1923    fn simulate_tool_based_completion(
1924        codegen: &Entity<CodegenAlternative>,
1925        cx: &mut TestAppContext,
1926    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
1927        let (events_tx, events_rx) = mpsc::unbounded();
1928        let model = Arc::new(FakeLanguageModel::default());
1929        codegen.update(cx, |codegen, cx| {
1930            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
1931                as BoxStream<
1932                    'static,
1933                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1934                >));
1935            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
1936        });
1937        events_tx
1938    }
1939
1940    fn rewrite_tool_use(
1941        id: &str,
1942        replacement_text: &str,
1943        is_complete: bool,
1944    ) -> LanguageModelCompletionEvent {
1945        let input = RewriteSectionInput {
1946            replacement_text: replacement_text.into(),
1947        };
1948        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
1949            id: id.into(),
1950            name: REWRITE_SECTION_TOOL_NAME.into(),
1951            raw_input: serde_json::to_string(&input).unwrap(),
1952            input: serde_json::to_value(&input).unwrap(),
1953            is_input_complete: is_complete,
1954            thought_signature: None,
1955        })
1956    }
1957}