buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use futures::{
  10    SinkExt, Stream, StreamExt, TryStreamExt as _,
  11    channel::mpsc,
  12    future::{LocalBoxFuture, Shared},
  13    join,
  14    stream::BoxStream,
  15};
  16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  17use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  18use language_model::{
  19    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  20    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  21    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  22    LanguageModelToolUse, Role, TokenUsage,
  23};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use prompt_store::PromptBuilder;
  27use rope::Rope;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings as _;
  31use smol::future::FutureExt;
  32use std::{
  33    cmp,
  34    future::Future,
  35    iter,
  36    ops::{Range, RangeInclusive},
  37    pin::Pin,
  38    sync::Arc,
  39    task::{self, Poll},
  40    time::Instant,
  41};
  42use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  43
// Input payload for the `failure_message` tool offered in `build_request_tools`.
// NOTE(review): the `///` doc comments below are not only documentation — schemars'
// `JsonSchema` derive turns doc comments into schema `description`s, and this schema is
// sent to the model via `root_schema_for::<FailureMessageInput>`. Changing their wording
// changes what the model is told.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `default`: a tool call that omits `message` deserializes to an empty string instead of failing.
    #[serde(default)]
    pub message: String,
}
  53
// Input payload for the `rewrite_section` tool offered in `build_request_tools`.
// NOTE(review): as with `FailureMessageInput`, the `///` comments here feed the JSON schema
// generated by the `JsonSchema` derive, so they are model-facing text, not just docs.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `default`: a missing `replacement_text` deserializes to an empty string instead of failing.
    #[serde(default)]
    pub replacement_text: String,
}
  63
/// Drives an inline-assist transformation of one range of a multibuffer, holding one
/// [`CodegenAlternative`] per model (the primary model first, then any configured
/// alternative models) and letting the user cycle between them.
pub struct BufferCodegen {
    /// One alternative per model; index 0 is the primary model's alternative.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently activated.
    pub active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Observation and event-forwarding subscriptions on the active alternative.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being transformed.
    buffer: Entity<MultiBuffer>,
    /// The range every alternative targets.
    range: Range<Anchor>,
    /// Transaction passed in at construction; undone (once) by `undo`.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder handed to each alternative.
    builder: Arc<PromptBuilder>,
    /// True when `range` was empty at construction time.
    pub is_insertion: bool,
    /// Inline-assist session shared by all alternatives.
    session_id: Uuid,
}
  76
  77pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
  78pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
  79
  80impl BufferCodegen {
  81    pub fn new(
  82        buffer: Entity<MultiBuffer>,
  83        range: Range<Anchor>,
  84        initial_transaction_id: Option<TransactionId>,
  85        session_id: Uuid,
  86        builder: Arc<PromptBuilder>,
  87        cx: &mut Context<Self>,
  88    ) -> Self {
  89        let codegen = cx.new(|cx| {
  90            CodegenAlternative::new(
  91                buffer.clone(),
  92                range.clone(),
  93                false,
  94                builder.clone(),
  95                session_id,
  96                cx,
  97            )
  98        });
  99        let mut this = Self {
 100            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 101            alternatives: vec![codegen],
 102            active_alternative: 0,
 103            seen_alternatives: HashSet::default(),
 104            subscriptions: Vec::new(),
 105            buffer,
 106            range,
 107            initial_transaction_id,
 108            builder,
 109            session_id,
 110        };
 111        this.activate(0, cx);
 112        this
 113    }
 114
 115    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 116        let codegen = self.active_alternative().clone();
 117        self.subscriptions.clear();
 118        self.subscriptions
 119            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 120        self.subscriptions
 121            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 122    }
 123
 124    pub fn active_completion(&self, cx: &App) -> Option<String> {
 125        self.active_alternative().read(cx).current_completion()
 126    }
 127
 128    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 129        &self.alternatives[self.active_alternative]
 130    }
 131
 132    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 133        self.active_alternative().read(cx).language_name(cx)
 134    }
 135
 136    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 137        &self.active_alternative().read(cx).status
 138    }
 139
 140    pub fn alternative_count(&self, cx: &App) -> usize {
 141        LanguageModelRegistry::read_global(cx)
 142            .inline_alternative_models()
 143            .len()
 144            + 1
 145    }
 146
 147    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 148        let next_active_ix = if self.active_alternative == 0 {
 149            self.alternatives.len() - 1
 150        } else {
 151            self.active_alternative - 1
 152        };
 153        self.activate(next_active_ix, cx);
 154    }
 155
 156    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 157        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 158        self.activate(next_active_ix, cx);
 159    }
 160
 161    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 162        self.active_alternative()
 163            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 164        self.seen_alternatives.insert(index);
 165        self.active_alternative = index;
 166        self.active_alternative()
 167            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 168        self.subscribe_to_alternative(cx);
 169        cx.notify();
 170    }
 171
 172    pub fn start(
 173        &mut self,
 174        primary_model: Arc<dyn LanguageModel>,
 175        user_prompt: String,
 176        context_task: Shared<Task<Option<LoadedContext>>>,
 177        cx: &mut Context<Self>,
 178    ) -> Result<()> {
 179        let alternative_models = LanguageModelRegistry::read_global(cx)
 180            .inline_alternative_models()
 181            .to_vec();
 182
 183        self.active_alternative()
 184            .update(cx, |alternative, cx| alternative.undo(cx));
 185        self.activate(0, cx);
 186        self.alternatives.truncate(1);
 187
 188        for _ in 0..alternative_models.len() {
 189            self.alternatives.push(cx.new(|cx| {
 190                CodegenAlternative::new(
 191                    self.buffer.clone(),
 192                    self.range.clone(),
 193                    false,
 194                    self.builder.clone(),
 195                    self.session_id,
 196                    cx,
 197                )
 198            }));
 199        }
 200
 201        for (model, alternative) in iter::once(primary_model)
 202            .chain(alternative_models)
 203            .zip(&self.alternatives)
 204        {
 205            alternative.update(cx, |alternative, cx| {
 206                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 207            })?;
 208        }
 209
 210        Ok(())
 211    }
 212
 213    pub fn stop(&mut self, cx: &mut Context<Self>) {
 214        for codegen in &self.alternatives {
 215            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 216        }
 217    }
 218
 219    pub fn undo(&mut self, cx: &mut Context<Self>) {
 220        self.active_alternative()
 221            .update(cx, |codegen, cx| codegen.undo(cx));
 222
 223        self.buffer.update(cx, |buffer, cx| {
 224            if let Some(transaction_id) = self.initial_transaction_id.take() {
 225                buffer.undo_transaction(transaction_id, cx);
 226                buffer.refresh_preview(cx);
 227            }
 228        });
 229    }
 230
 231    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 232        self.active_alternative().read(cx).buffer.clone()
 233    }
 234
 235    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 236        self.active_alternative().read(cx).old_buffer.clone()
 237    }
 238
 239    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 240        self.active_alternative().read(cx).snapshot.clone()
 241    }
 242
 243    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 244        self.active_alternative().read(cx).edit_position
 245    }
 246
 247    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 248        &self.active_alternative().read(cx).diff
 249    }
 250
 251    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 252        self.active_alternative().read(cx).last_equal_ranges()
 253    }
 254
 255    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 256        self.active_alternative().read(cx).selected_text()
 257    }
 258
 259    pub fn session_id(&self) -> Uuid {
 260        self.session_id
 261    }
 262}
 263
// `BufferCodegen` re-emits every `CodegenEvent` raised by its active alternative
// (wired up in `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 265
/// One candidate generation for an inline-assist request, driven by a single model.
///
/// Deactivating an alternative undoes its transformation transaction; activating it
/// re-applies its recorded edits (see `set_active`).
pub struct CodegenAlternative {
    /// The multibuffer this alternative edits.
    buffer: Entity<MultiBuffer>,
    /// Local copy of the targeted buffer's text, language and registry at creation time.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot used to resolve `range`; refreshed when a stream starts.
    snapshot: MultiBufferSnapshot,
    /// Presumably where streamed text is currently being written; `start` resets it to
    /// the right-biased start of `range`. TODO confirm.
    edit_position: Option<Anchor>,
    /// The range being transformed; re-anchored against a fresh snapshot in `handle_stream`.
    range: Range<Anchor>,
    /// Returned by `last_equal_ranges()`.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding this alternative's edits. It is undone on deactivation or when
    /// a new generation starts, and cleared when the user undoes it.
    transformation_transaction_id: Option<TransactionId>,
    /// Generation status; set to `Pending` when a stream starts.
    status: CodegenStatus,
    /// The current generation task; replaced with a ready task when the transformation
    /// is undone.
    generation: Task<()>,
    /// Diff state for the current generation; reset when a stream starts.
    diff: Diff,
    /// Keeps `handle_buffer_event` subscribed to `buffer`.
    _subscription: gpui::Subscription,
    /// Renders the inline-transformation prompts.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Recorded edits, re-applied by `set_active(true)`.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line-level diff operations, replayed on activation while status is `Pending`.
    line_operations: Vec<LineOperation>,
    /// Presumably the generation duration in seconds. TODO confirm where it is set.
    elapsed_time: Option<f64>,
    /// Presumably the raw completion text received so far. TODO confirm.
    completion: Option<String>,
    /// Text covered by `range` when the current stream started.
    selected_text: Option<String>,
    /// Presumably the provider's message id for the completion. TODO confirm.
    pub message_id: Option<String>,
    /// Inline-assist session this alternative belongs to.
    session_id: Uuid,
    /// Explanation from the model; cleared whenever a new generation starts.
    pub description: Option<String>,
    /// Presumably a failure message reported by the model. TODO confirm.
    pub failure: Option<String>,
}
 290
// Alternatives emit `CodegenEvent`s (e.g. `CodegenEvent::Undone` from `handle_buffer_event`);
// `BufferCodegen` forwards those of its active alternative.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 292
 293impl CodegenAlternative {
 294    pub fn new(
 295        buffer: Entity<MultiBuffer>,
 296        range: Range<Anchor>,
 297        active: bool,
 298        builder: Arc<PromptBuilder>,
 299        session_id: Uuid,
 300        cx: &mut Context<Self>,
 301    ) -> Self {
 302        let snapshot = buffer.read(cx).snapshot(cx);
 303
 304        let (old_buffer, _, _) = snapshot
 305            .range_to_buffer_ranges(range.start..=range.end)
 306            .pop()
 307            .unwrap();
 308        let old_buffer = cx.new(|cx| {
 309            let text = old_buffer.as_rope().clone();
 310            let line_ending = old_buffer.line_ending();
 311            let language = old_buffer.language().cloned();
 312            let language_registry = buffer
 313                .read(cx)
 314                .buffer(old_buffer.remote_id())
 315                .unwrap()
 316                .read(cx)
 317                .language_registry();
 318
 319            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 320            buffer.set_language(language, cx);
 321            if let Some(language_registry) = language_registry {
 322                buffer.set_language_registry(language_registry);
 323            }
 324            buffer
 325        });
 326
 327        Self {
 328            buffer: buffer.clone(),
 329            old_buffer,
 330            edit_position: None,
 331            message_id: None,
 332            snapshot,
 333            last_equal_ranges: Default::default(),
 334            transformation_transaction_id: None,
 335            status: CodegenStatus::Idle,
 336            generation: Task::ready(()),
 337            diff: Diff::default(),
 338            builder,
 339            active: active,
 340            edits: Vec::new(),
 341            line_operations: Vec::new(),
 342            range,
 343            elapsed_time: None,
 344            completion: None,
 345            selected_text: None,
 346            session_id,
 347            description: None,
 348            failure: None,
 349            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 350        }
 351    }
 352
 353    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 354        self.old_buffer
 355            .read(cx)
 356            .language()
 357            .map(|language| language.name())
 358    }
 359
 360    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 361        if active != self.active {
 362            self.active = active;
 363
 364            if self.active {
 365                let edits = self.edits.clone();
 366                self.apply_edits(edits, cx);
 367                if matches!(self.status, CodegenStatus::Pending) {
 368                    let line_operations = self.line_operations.clone();
 369                    self.reapply_line_based_diff(line_operations, cx);
 370                } else {
 371                    self.reapply_batch_diff(cx).detach();
 372                }
 373            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 374                self.buffer.update(cx, |buffer, cx| {
 375                    buffer.undo_transaction(transaction_id, cx);
 376                    buffer.forget_transaction(transaction_id, cx);
 377                });
 378            }
 379        }
 380    }
 381
 382    fn handle_buffer_event(
 383        &mut self,
 384        _buffer: Entity<MultiBuffer>,
 385        event: &multi_buffer::Event,
 386        cx: &mut Context<Self>,
 387    ) {
 388        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 389            && self.transformation_transaction_id == Some(*transaction_id)
 390        {
 391            self.transformation_transaction_id = None;
 392            self.generation = Task::ready(());
 393            cx.emit(CodegenEvent::Undone);
 394        }
 395    }
 396
 397    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 398        &self.last_equal_ranges
 399    }
 400
 401    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 402        model.supports_streaming_tools()
 403            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 404    }
 405
 406    pub fn start(
 407        &mut self,
 408        user_prompt: String,
 409        context_task: Shared<Task<Option<LoadedContext>>>,
 410        model: Arc<dyn LanguageModel>,
 411        cx: &mut Context<Self>,
 412    ) -> Result<()> {
 413        // Clear the model explanation since the user has started a new generation.
 414        self.description = None;
 415
 416        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 417            self.buffer.update(cx, |buffer, cx| {
 418                buffer.undo_transaction(transformation_transaction_id, cx);
 419            });
 420        }
 421
 422        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 423
 424        if Self::use_streaming_tools(model.as_ref(), cx) {
 425            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 426            let completion_events = cx.spawn({
 427                let model = model.clone();
 428                async move |_, cx| model.stream_completion(request.await, cx).await
 429            });
 430            self.generation = self.handle_completion(model, completion_events, cx);
 431        } else {
 432            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 433                if user_prompt.trim().to_lowercase() == "delete" {
 434                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 435                } else {
 436                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 437                    cx.spawn({
 438                        let model = model.clone();
 439                        async move |_, cx| {
 440                            Ok(model.stream_completion_text(request.await, cx).await?)
 441                        }
 442                    })
 443                    .boxed_local()
 444                };
 445            self.generation =
 446                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 447        }
 448
 449        Ok(())
 450    }
 451
 452    fn build_request_tools(
 453        &self,
 454        model: &Arc<dyn LanguageModel>,
 455        user_prompt: String,
 456        context_task: Shared<Task<Option<LoadedContext>>>,
 457        cx: &mut App,
 458    ) -> Result<Task<LanguageModelRequest>> {
 459        let buffer = self.buffer.read(cx).snapshot(cx);
 460        let language = buffer.language_at(self.range.start);
 461        let language_name = if let Some(language) = language.as_ref() {
 462            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 463                None
 464            } else {
 465                Some(language.name())
 466            }
 467        } else {
 468            None
 469        };
 470
 471        let language_name = language_name.as_ref();
 472        let start = buffer.point_to_buffer_offset(self.range.start);
 473        let end = buffer.point_to_buffer_offset(self.range.end);
 474        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 475            let (start_buffer, start_buffer_offset) = start;
 476            let (end_buffer, end_buffer_offset) = end;
 477            if start_buffer.remote_id() == end_buffer.remote_id() {
 478                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 479            } else {
 480                anyhow::bail!("invalid transformation range");
 481            }
 482        } else {
 483            anyhow::bail!("invalid transformation range");
 484        };
 485
 486        let system_prompt = self
 487            .builder
 488            .generate_inline_transformation_prompt_tools(
 489                language_name,
 490                buffer,
 491                range.start.0..range.end.0,
 492            )
 493            .context("generating content prompt")?;
 494
 495        let temperature = AgentSettings::temperature_for_model(model, cx);
 496
 497        let tool_input_format = model.tool_input_format();
 498        let tool_choice = model
 499            .supports_tool_choice(LanguageModelToolChoice::Any)
 500            .then_some(LanguageModelToolChoice::Any);
 501
 502        Ok(cx.spawn(async move |_cx| {
 503            let mut messages = vec![LanguageModelRequestMessage {
 504                role: Role::System,
 505                content: vec![system_prompt.into()],
 506                cache: false,
 507                reasoning_details: None,
 508            }];
 509
 510            let mut user_message = LanguageModelRequestMessage {
 511                role: Role::User,
 512                content: Vec::new(),
 513                cache: false,
 514                reasoning_details: None,
 515            };
 516
 517            if let Some(context) = context_task.await {
 518                context.add_to_request_message(&mut user_message);
 519            }
 520
 521            user_message.content.push(user_prompt.into());
 522            messages.push(user_message);
 523
 524            let tools = vec![
 525                LanguageModelRequestTool {
 526                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
 527                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 528                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 529                    use_input_streaming: false,
 530                },
 531                LanguageModelRequestTool {
 532                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
 533                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 534                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 535                    use_input_streaming: false,
 536                },
 537            ];
 538
 539            LanguageModelRequest {
 540                thread_id: None,
 541                prompt_id: None,
 542                intent: Some(CompletionIntent::InlineAssist),
 543                tools,
 544                tool_choice,
 545                stop: Vec::new(),
 546                temperature,
 547                messages,
 548                thinking_allowed: false,
 549                thinking_effort: None,
 550                speed: None,
 551            }
 552        }))
 553    }
 554
 555    fn build_request(
 556        &self,
 557        model: &Arc<dyn LanguageModel>,
 558        user_prompt: String,
 559        context_task: Shared<Task<Option<LoadedContext>>>,
 560        cx: &mut App,
 561    ) -> Result<Task<LanguageModelRequest>> {
 562        if Self::use_streaming_tools(model.as_ref(), cx) {
 563            return self.build_request_tools(model, user_prompt, context_task, cx);
 564        }
 565
 566        let buffer = self.buffer.read(cx).snapshot(cx);
 567        let language = buffer.language_at(self.range.start);
 568        let language_name = if let Some(language) = language.as_ref() {
 569            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 570                None
 571            } else {
 572                Some(language.name())
 573            }
 574        } else {
 575            None
 576        };
 577
 578        let language_name = language_name.as_ref();
 579        let start = buffer.point_to_buffer_offset(self.range.start);
 580        let end = buffer.point_to_buffer_offset(self.range.end);
 581        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 582            let (start_buffer, start_buffer_offset) = start;
 583            let (end_buffer, end_buffer_offset) = end;
 584            if start_buffer.remote_id() == end_buffer.remote_id() {
 585                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 586            } else {
 587                anyhow::bail!("invalid transformation range");
 588            }
 589        } else {
 590            anyhow::bail!("invalid transformation range");
 591        };
 592
 593        let prompt = self
 594            .builder
 595            .generate_inline_transformation_prompt(
 596                user_prompt,
 597                language_name,
 598                buffer,
 599                range.start.0..range.end.0,
 600            )
 601            .context("generating content prompt")?;
 602
 603        let temperature = AgentSettings::temperature_for_model(model, cx);
 604
 605        Ok(cx.spawn(async move |_cx| {
 606            let mut request_message = LanguageModelRequestMessage {
 607                role: Role::User,
 608                content: Vec::new(),
 609                cache: false,
 610                reasoning_details: None,
 611            };
 612
 613            if let Some(context) = context_task.await {
 614                context.add_to_request_message(&mut request_message);
 615            }
 616
 617            request_message.content.push(prompt.into());
 618
 619            LanguageModelRequest {
 620                thread_id: None,
 621                prompt_id: None,
 622                intent: Some(CompletionIntent::InlineAssist),
 623                tools: Vec::new(),
 624                tool_choice: None,
 625                stop: Vec::new(),
 626                temperature,
 627                messages: vec![request_message],
 628                thinking_allowed: false,
 629                thinking_effort: None,
 630                speed: None,
 631            }
 632        }))
 633    }
 634
 635    pub fn handle_stream(
 636        &mut self,
 637        model: Arc<dyn LanguageModel>,
 638        strip_invalid_spans: bool,
 639        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 640        cx: &mut Context<Self>,
 641    ) -> Task<()> {
 642        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 643        let session_id = self.session_id;
 644        let model_telemetry_id = model.telemetry_id();
 645        let model_provider_id = model.provider_id().to_string();
 646        let start_time = Instant::now();
 647
 648        // Make a new snapshot and re-resolve anchor in case the document was modified.
 649        // This can happen often if the editor loses focus and is saved + reformatted,
 650        // as in https://github.com/zed-industries/zed/issues/39088
 651        self.snapshot = self.buffer.read(cx).snapshot(cx);
 652        self.range = self.snapshot.anchor_after(self.range.start)
 653            ..self.snapshot.anchor_after(self.range.end);
 654
 655        let snapshot = self.snapshot.clone();
 656        let selected_text = snapshot
 657            .text_for_range(self.range.start..self.range.end)
 658            .collect::<Rope>();
 659
 660        self.selected_text = Some(selected_text.to_string());
 661
 662        let selection_start = self.range.start.to_point(&snapshot);
 663
 664        // Start with the indentation of the first line in the selection
 665        let mut suggested_line_indent = snapshot
 666            .suggested_indents(selection_start.row..=selection_start.row, cx)
 667            .into_values()
 668            .next()
 669            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 670
 671        // If the first line in the selection does not have indentation, check the following lines
 672        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 673            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 674                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 675                // Prefer tabs if a line in the selection uses tabs as indentation
 676                if line_indent.kind == IndentKind::Tab {
 677                    suggested_line_indent.kind = IndentKind::Tab;
 678                    break;
 679                }
 680            }
 681        }
 682
 683        let language_name = {
 684            let multibuffer = self.buffer.read(cx);
 685            let snapshot = multibuffer.snapshot(cx);
 686            let ranges = snapshot.range_to_buffer_ranges(self.range.start..=self.range.end);
 687            ranges
 688                .first()
 689                .and_then(|(buffer, _, _)| buffer.language())
 690                .map(|language| language.name())
 691        };
 692
 693        self.diff = Diff::default();
 694        self.status = CodegenStatus::Pending;
 695        let mut edit_start = self.range.start.to_offset(&snapshot);
 696        let completion = Arc::new(Mutex::new(String::new()));
 697        let completion_clone = completion.clone();
 698
 699        cx.notify();
 700        cx.spawn(async move |codegen, cx| {
 701            let stream = stream.await;
 702
 703            let token_usage = stream
 704                .as_ref()
 705                .ok()
 706                .map(|stream| stream.last_token_usage.clone());
 707            let message_id = stream
 708                .as_ref()
 709                .ok()
 710                .and_then(|stream| stream.message_id.clone());
 711            let generate = async {
 712                let model_telemetry_id = model_telemetry_id.clone();
 713                let model_provider_id = model_provider_id.clone();
 714                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 715                let message_id = message_id.clone();
 716                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 717                    let anthropic_reporter = anthropic_reporter.clone();
 718                    let language_name = language_name.clone();
 719                    async move {
 720                        let mut response_latency = None;
 721                        let request_start = Instant::now();
 722                        let diff = async {
 723                            let raw_stream = stream?.stream.map_err(|error| error.into());
 724
 725                            let stripped;
 726                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 727                                if strip_invalid_spans {
 728                                    stripped = StripInvalidSpans::new(raw_stream);
 729                                    Box::pin(stripped)
 730                                } else {
 731                                    Box::pin(raw_stream)
 732                                };
 733
 734                            let mut diff = StreamingDiff::new(selected_text.to_string());
 735                            let mut line_diff = LineDiff::default();
 736
 737                            let mut new_text = String::new();
 738                            let mut base_indent = None;
 739                            let mut line_indent = None;
 740                            let mut first_line = true;
 741
 742                            while let Some(chunk) = chunks.next().await {
 743                                if response_latency.is_none() {
 744                                    response_latency = Some(request_start.elapsed());
 745                                }
 746                                let chunk = chunk?;
 747                                completion_clone.lock().push_str(&chunk);
 748
 749                                let mut lines = chunk.split('\n').peekable();
 750                                while let Some(line) = lines.next() {
 751                                    new_text.push_str(line);
 752                                    if line_indent.is_none()
 753                                        && let Some(non_whitespace_ch_ix) =
 754                                            new_text.find(|ch: char| !ch.is_whitespace())
 755                                    {
 756                                        line_indent = Some(non_whitespace_ch_ix);
 757                                        base_indent = base_indent.or(line_indent);
 758
 759                                        let line_indent = line_indent.unwrap();
 760                                        let base_indent = base_indent.unwrap();
 761                                        let indent_delta = line_indent as i32 - base_indent as i32;
 762                                        let mut corrected_indent_len = cmp::max(
 763                                            0,
 764                                            suggested_line_indent.len as i32 + indent_delta,
 765                                        )
 766                                            as usize;
 767                                        if first_line {
 768                                            corrected_indent_len = corrected_indent_len
 769                                                .saturating_sub(selection_start.column as usize);
 770                                        }
 771
 772                                        let indent_char = suggested_line_indent.char();
 773                                        let mut indent_buffer = [0; 4];
 774                                        let indent_str =
 775                                            indent_char.encode_utf8(&mut indent_buffer);
 776                                        new_text.replace_range(
 777                                            ..line_indent,
 778                                            &indent_str.repeat(corrected_indent_len),
 779                                        );
 780                                    }
 781
 782                                    if line_indent.is_some() {
 783                                        let char_ops = diff.push_new(&new_text);
 784                                        line_diff.push_char_operations(&char_ops, &selected_text);
 785                                        diff_tx
 786                                            .send((char_ops, line_diff.line_operations()))
 787                                            .await?;
 788                                        new_text.clear();
 789                                    }
 790
 791                                    if lines.peek().is_some() {
 792                                        let char_ops = diff.push_new("\n");
 793                                        line_diff.push_char_operations(&char_ops, &selected_text);
 794                                        diff_tx
 795                                            .send((char_ops, line_diff.line_operations()))
 796                                            .await?;
 797                                        if line_indent.is_none() {
 798                                            // Don't write out the leading indentation in empty lines on the next line
 799                                            // This is the case where the above if statement didn't clear the buffer
 800                                            new_text.clear();
 801                                        }
 802                                        line_indent = None;
 803                                        first_line = false;
 804                                    }
 805                                }
 806                            }
 807
 808                            let mut char_ops = diff.push_new(&new_text);
 809                            char_ops.extend(diff.finish());
 810                            line_diff.push_char_operations(&char_ops, &selected_text);
 811                            line_diff.finish(&selected_text);
 812                            diff_tx
 813                                .send((char_ops, line_diff.line_operations()))
 814                                .await?;
 815
 816                            anyhow::Ok(())
 817                        };
 818
 819                        let result = diff.await;
 820
 821                        let error_message = result.as_ref().err().map(|error| error.to_string());
 822                        telemetry::event!(
 823                            "Assistant Responded",
 824                            kind = "inline",
 825                            phase = "response",
 826                            session_id = session_id.to_string(),
 827                            model = model_telemetry_id,
 828                            model_provider = model_provider_id,
 829                            language_name = language_name.as_ref().map(|n| n.to_string()),
 830                            message_id = message_id.as_deref(),
 831                            response_latency = response_latency,
 832                            error_message = error_message.as_deref(),
 833                        );
 834
 835                        anthropic_reporter.report(language_model::AnthropicEventData {
 836                            completion_type: language_model::AnthropicCompletionType::Editor,
 837                            event: language_model::AnthropicEventType::Response,
 838                            language_name: language_name.map(|n| n.to_string()),
 839                            message_id,
 840                        });
 841
 842                        result?;
 843                        Ok(())
 844                    }
 845                });
 846
 847                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 848                    codegen.update(cx, |codegen, cx| {
 849                        codegen.last_equal_ranges.clear();
 850
 851                        let edits = char_ops
 852                            .into_iter()
 853                            .filter_map(|operation| match operation {
 854                                CharOperation::Insert { text } => {
 855                                    let edit_start = snapshot.anchor_after(edit_start);
 856                                    Some((edit_start..edit_start, text))
 857                                }
 858                                CharOperation::Delete { bytes } => {
 859                                    let edit_end = edit_start + bytes;
 860                                    let edit_range = snapshot.anchor_after(edit_start)
 861                                        ..snapshot.anchor_before(edit_end);
 862                                    edit_start = edit_end;
 863                                    Some((edit_range, String::new()))
 864                                }
 865                                CharOperation::Keep { bytes } => {
 866                                    let edit_end = edit_start + bytes;
 867                                    let edit_range = snapshot.anchor_after(edit_start)
 868                                        ..snapshot.anchor_before(edit_end);
 869                                    edit_start = edit_end;
 870                                    codegen.last_equal_ranges.push(edit_range);
 871                                    None
 872                                }
 873                            })
 874                            .collect::<Vec<_>>();
 875
 876                        if codegen.active {
 877                            codegen.apply_edits(edits.iter().cloned(), cx);
 878                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 879                        }
 880                        codegen.edits.extend(edits);
 881                        codegen.line_operations = line_ops;
 882                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 883
 884                        cx.notify();
 885                    })?;
 886                }
 887
 888                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 889                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 890                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 891                let batch_diff_task =
 892                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 893                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 894                line_based_stream_diff?;
 895
 896                anyhow::Ok(())
 897            };
 898
 899            let result = generate.await;
 900            let elapsed_time = start_time.elapsed().as_secs_f64();
 901
 902            codegen
 903                .update(cx, |this, cx| {
 904                    this.message_id = message_id;
 905                    this.last_equal_ranges.clear();
 906                    if let Err(error) = result {
 907                        this.status = CodegenStatus::Error(error);
 908                    } else {
 909                        this.status = CodegenStatus::Done;
 910                    }
 911                    this.elapsed_time = Some(elapsed_time);
 912                    this.completion = Some(completion.lock().clone());
 913                    if let Some(usage) = token_usage {
 914                        let usage = usage.lock();
 915                        telemetry::event!(
 916                            "Inline Assistant Completion",
 917                            model = model_telemetry_id,
 918                            model_provider = model_provider_id,
 919                            input_tokens = usage.input_tokens,
 920                            output_tokens = usage.output_tokens,
 921                        )
 922                    }
 923
 924                    cx.emit(CodegenEvent::Finished);
 925                    cx.notify();
 926                })
 927                .ok();
 928        })
 929    }
 930
 931    pub fn current_completion(&self) -> Option<String> {
 932        self.completion.clone()
 933    }
 934
 935    #[cfg(any(test, feature = "test-support"))]
 936    pub fn current_description(&self) -> Option<String> {
 937        self.description.clone()
 938    }
 939
 940    #[cfg(any(test, feature = "test-support"))]
 941    pub fn current_failure(&self) -> Option<String> {
 942        self.failure.clone()
 943    }
 944
 945    pub fn selected_text(&self) -> Option<&str> {
 946        self.selected_text.as_deref()
 947    }
 948
 949    pub fn stop(&mut self, cx: &mut Context<Self>) {
 950        self.last_equal_ranges.clear();
 951        if self.diff.is_empty() {
 952            self.status = CodegenStatus::Idle;
 953        } else {
 954            self.status = CodegenStatus::Done;
 955        }
 956        self.generation = Task::ready(());
 957        cx.emit(CodegenEvent::Finished);
 958        cx.notify();
 959    }
 960
 961    pub fn undo(&mut self, cx: &mut Context<Self>) {
 962        self.buffer.update(cx, |buffer, cx| {
 963            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 964                buffer.undo_transaction(transaction_id, cx);
 965                buffer.refresh_preview(cx);
 966            }
 967        });
 968    }
 969
 970    fn apply_edits(
 971        &mut self,
 972        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 973        cx: &mut Context<CodegenAlternative>,
 974    ) {
 975        let transaction = self.buffer.update(cx, |buffer, cx| {
 976            // Avoid grouping agent edits with user edits.
 977            buffer.finalize_last_transaction(cx);
 978            buffer.start_transaction(cx);
 979            buffer.edit(edits, None, cx);
 980            buffer.end_transaction(cx)
 981        });
 982
 983        if let Some(transaction) = transaction {
 984            if let Some(first_transaction) = self.transformation_transaction_id {
 985                // Group all agent edits into the first transaction.
 986                self.buffer.update(cx, |buffer, cx| {
 987                    buffer.merge_transactions(transaction, first_transaction, cx)
 988                });
 989            } else {
 990                self.transformation_transaction_id = Some(transaction);
 991                self.buffer
 992                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 993            }
 994        }
 995    }
 996
 997    fn reapply_line_based_diff(
 998        &mut self,
 999        line_operations: impl IntoIterator<Item = LineOperation>,
1000        cx: &mut Context<Self>,
1001    ) {
1002        let old_snapshot = self.snapshot.clone();
1003        let old_range = self.range.to_point(&old_snapshot);
1004        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1005        let new_range = self.range.to_point(&new_snapshot);
1006
1007        let mut old_row = old_range.start.row;
1008        let mut new_row = new_range.start.row;
1009
1010        self.diff.deleted_row_ranges.clear();
1011        self.diff.inserted_row_ranges.clear();
1012        for operation in line_operations {
1013            match operation {
1014                LineOperation::Keep { lines } => {
1015                    old_row += lines;
1016                    new_row += lines;
1017                }
1018                LineOperation::Delete { lines } => {
1019                    let old_end_row = old_row + lines - 1;
1020                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1021
1022                    if let Some((_, last_deleted_row_range)) =
1023                        self.diff.deleted_row_ranges.last_mut()
1024                    {
1025                        if *last_deleted_row_range.end() + 1 == old_row {
1026                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1027                        } else {
1028                            self.diff
1029                                .deleted_row_ranges
1030                                .push((new_row, old_row..=old_end_row));
1031                        }
1032                    } else {
1033                        self.diff
1034                            .deleted_row_ranges
1035                            .push((new_row, old_row..=old_end_row));
1036                    }
1037
1038                    old_row += lines;
1039                }
1040                LineOperation::Insert { lines } => {
1041                    let new_end_row = new_row + lines - 1;
1042                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1043                    let end = new_snapshot.anchor_before(Point::new(
1044                        new_end_row,
1045                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1046                    ));
1047                    self.diff.inserted_row_ranges.push(start..end);
1048                    new_row += lines;
1049                }
1050            }
1051
1052            cx.notify();
1053        }
1054    }
1055
1056    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1057        let old_snapshot = self.snapshot.clone();
1058        let old_range = self.range.to_point(&old_snapshot);
1059        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1060        let new_range = self.range.to_point(&new_snapshot);
1061
1062        cx.spawn(async move |codegen, cx| {
1063            let (deleted_row_ranges, inserted_row_ranges) = cx
1064                .background_spawn(async move {
1065                    let old_text = old_snapshot
1066                        .text_for_range(
1067                            Point::new(old_range.start.row, 0)
1068                                ..Point::new(
1069                                    old_range.end.row,
1070                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1071                                ),
1072                        )
1073                        .collect::<String>();
1074                    let new_text = new_snapshot
1075                        .text_for_range(
1076                            Point::new(new_range.start.row, 0)
1077                                ..Point::new(
1078                                    new_range.end.row,
1079                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1080                                ),
1081                        )
1082                        .collect::<String>();
1083
1084                    let old_start_row = old_range.start.row;
1085                    let new_start_row = new_range.start.row;
1086                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1087                    let mut inserted_row_ranges = Vec::new();
1088                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1089                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1090                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1091                        if !old_rows.is_empty() {
1092                            deleted_row_ranges.push((
1093                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1094                                old_rows.start..=old_rows.end - 1,
1095                            ));
1096                        }
1097                        if !new_rows.is_empty() {
1098                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1099                            let new_end_row = new_rows.end - 1;
1100                            let end = new_snapshot.anchor_before(Point::new(
1101                                new_end_row,
1102                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1103                            ));
1104                            inserted_row_ranges.push(start..end);
1105                        }
1106                    }
1107                    (deleted_row_ranges, inserted_row_ranges)
1108                })
1109                .await;
1110
1111            codegen
1112                .update(cx, |codegen, cx| {
1113                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1114                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1115                    cx.notify();
1116                })
1117                .ok();
1118        })
1119    }
1120
1121    fn handle_completion(
1122        &mut self,
1123        model: Arc<dyn LanguageModel>,
1124        completion_stream: Task<
1125            Result<
1126                BoxStream<
1127                    'static,
1128                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1129                >,
1130                LanguageModelCompletionError,
1131            >,
1132        >,
1133        cx: &mut Context<Self>,
1134    ) -> Task<()> {
1135        self.diff = Diff::default();
1136        self.status = CodegenStatus::Pending;
1137
1138        cx.notify();
1139        // Leaving this in generation so that STOP equivalent events are respected even
1140        // while we're still pre-processing the completion event
1141        cx.spawn(async move |codegen, cx| {
1142            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1143                let _ = codegen.update(cx, |this, cx| {
1144                    this.status = status;
1145                    cx.emit(CodegenEvent::Finished);
1146                    cx.notify();
1147                });
1148            };
1149
1150            let mut completion_events = match completion_stream.await {
1151                Ok(events) => events,
1152                Err(err) => {
1153                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1154                    return;
1155                }
1156            };
1157
1158            enum ToolUseOutput {
1159                Rewrite {
1160                    text: String,
1161                    description: Option<String>,
1162                },
1163                Failure(String),
1164            }
1165
1166            enum ModelUpdate {
1167                Description(String),
1168                Failure(String),
1169            }
1170
1171            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1172            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1173                let mut chars_read_so_far = chars_read_so_far.lock();
1174                match tool_use.name.as_ref() {
1175                    REWRITE_SECTION_TOOL_NAME => {
1176                        let Ok(input) =
1177                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1178                        else {
1179                            return None;
1180                        };
1181                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1182                        *chars_read_so_far = input.replacement_text.len();
1183                        Some(ToolUseOutput::Rewrite {
1184                            text,
1185                            description: None,
1186                        })
1187                    }
1188                    FAILURE_MESSAGE_TOOL_NAME => {
1189                        let Ok(mut input) =
1190                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1191                        else {
1192                            return None;
1193                        };
1194                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1195                    }
1196                    _ => None,
1197                }
1198            };
1199
1200            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1201
1202            cx.spawn({
1203                let codegen = codegen.clone();
1204                async move |cx| {
1205                    while let Some(update) = message_rx.next().await {
1206                        let _ = codegen.update(cx, |this, _cx| match update {
1207                            ModelUpdate::Description(d) => this.description = Some(d),
1208                            ModelUpdate::Failure(f) => this.failure = Some(f),
1209                        });
1210                    }
1211                }
1212            })
1213            .detach();
1214
1215            let mut message_id = None;
1216            let mut first_text = None;
1217            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1218            let total_text = Arc::new(Mutex::new(String::new()));
1219
1220            loop {
1221                if let Some(first_event) = completion_events.next().await {
1222                    match first_event {
1223                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1224                            message_id = Some(id);
1225                        }
1226                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1227                            if let Some(output) = process_tool_use(tool_use) {
1228                                let (text, update) = match output {
1229                                    ToolUseOutput::Rewrite { text, description } => {
1230                                        (Some(text), description.map(ModelUpdate::Description))
1231                                    }
1232                                    ToolUseOutput::Failure(message) => {
1233                                        (None, Some(ModelUpdate::Failure(message)))
1234                                    }
1235                                };
1236                                if let Some(update) = update {
1237                                    let _ = message_tx.unbounded_send(update);
1238                                }
1239                                first_text = text;
1240                                if first_text.is_some() {
1241                                    break;
1242                                }
1243                            }
1244                        }
1245                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1246                            *last_token_usage.lock() = token_usage;
1247                        }
1248                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1249                            let mut lock = total_text.lock();
1250                            lock.push_str(&text);
1251                        }
1252                        Ok(e) => {
1253                            log::warn!("Unexpected event: {:?}", e);
1254                            break;
1255                        }
1256                        Err(e) => {
1257                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1258                            break;
1259                        }
1260                    }
1261                }
1262            }
1263
1264            let Some(first_text) = first_text else {
1265                finish_with_status(CodegenStatus::Done, cx);
1266                return;
1267            };
1268
1269            let move_last_token_usage = last_token_usage.clone();
1270
1271            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1272                completion_events.filter_map(move |e| {
1273                    let process_tool_use = process_tool_use.clone();
1274                    let last_token_usage = move_last_token_usage.clone();
1275                    let total_text = total_text.clone();
1276                    let mut message_tx = message_tx.clone();
1277                    async move {
1278                        match e {
1279                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1280                                let Some(output) = process_tool_use(tool_use) else {
1281                                    return None;
1282                                };
1283                                let (text, update) = match output {
1284                                    ToolUseOutput::Rewrite { text, description } => {
1285                                        (Some(text), description.map(ModelUpdate::Description))
1286                                    }
1287                                    ToolUseOutput::Failure(message) => {
1288                                        (None, Some(ModelUpdate::Failure(message)))
1289                                    }
1290                                };
1291                                if let Some(update) = update {
1292                                    let _ = message_tx.send(update).await;
1293                                }
1294                                text.map(Ok)
1295                            }
1296                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1297                                *last_token_usage.lock() = token_usage;
1298                                None
1299                            }
1300                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1301                                let mut lock = total_text.lock();
1302                                lock.push_str(&text);
1303                                None
1304                            }
1305                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1306                            e => {
1307                                log::error!("UNEXPECTED EVENT {:?}", e);
1308                                None
1309                            }
1310                        }
1311                    }
1312                }),
1313            ));
1314
1315            let language_model_text_stream = LanguageModelTextStream {
1316                message_id: message_id,
1317                stream: text_stream,
1318                last_token_usage,
1319            };
1320
1321            let Some(task) = codegen
1322                .update(cx, move |codegen, cx| {
1323                    codegen.handle_stream(
1324                        model,
1325                        /* strip_invalid_spans: */ false,
1326                        async { Ok(language_model_text_stream) },
1327                        cx,
1328                    )
1329                })
1330                .ok()
1331            else {
1332                return;
1333            };
1334
1335            task.await;
1336        })
1337    }
1338}
1339
/// Notifications emitted by the codegen entities.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation ended. It may have completed, failed, or been stopped.
    Finished,
    /// Presumably signals that generated edits were reverted. It is not emitted in the
    /// portion of this file shown here.
    Undone,
}
1345
/// Stream adapter that removes `<|CURSOR|>` markers and a wrapping Markdown code fence from
/// streamed model output.
struct StripInvalidSpans<T> {
    /// The wrapped chunk stream.
    stream: T,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// Input received but not yet emitted, held back until it can be classified.
    buffer: String,
    /// Whether the current position is still on the first line of the output.
    first_line: bool,
    /// Whether a complete line was emitted whose trailing newline is still owed to the output.
    line_end: bool,
    /// Whether the first line was an opening code fence that got dropped.
    starts_with_code_block: bool,
}
1354
1355impl<T> StripInvalidSpans<T>
1356where
1357    T: Stream<Item = Result<String>>,
1358{
1359    fn new(stream: T) -> Self {
1360        Self {
1361            stream,
1362            stream_done: false,
1363            buffer: String::new(),
1364            first_line: true,
1365            line_end: false,
1366            starts_with_code_block: false,
1367        }
1368    }
1369}
1370
1371impl<T> Stream for StripInvalidSpans<T>
1372where
1373    T: Stream<Item = Result<String>>,
1374{
1375    type Item = Result<String>;
1376
1377    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1378        const CODE_BLOCK_DELIMITER: &str = "```";
1379        const CURSOR_SPAN: &str = "<|CURSOR|>";
1380
1381        let this = unsafe { self.get_unchecked_mut() };
1382        loop {
1383            if !this.stream_done {
1384                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1385                match stream.as_mut().poll_next(cx) {
1386                    Poll::Ready(Some(Ok(chunk))) => {
1387                        this.buffer.push_str(&chunk);
1388                    }
1389                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1390                    Poll::Ready(None) => {
1391                        this.stream_done = true;
1392                    }
1393                    Poll::Pending => return Poll::Pending,
1394                }
1395            }
1396
1397            let mut chunk = String::new();
1398            let mut consumed = 0;
1399            if !this.buffer.is_empty() {
1400                let mut lines = this.buffer.split('\n').enumerate().peekable();
1401                while let Some((line_ix, line)) = lines.next() {
1402                    if line_ix > 0 {
1403                        this.first_line = false;
1404                    }
1405
1406                    if this.first_line {
1407                        let trimmed_line = line.trim();
1408                        if lines.peek().is_some() {
1409                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1410                                consumed += line.len() + 1;
1411                                this.starts_with_code_block = true;
1412                                continue;
1413                            }
1414                        } else if trimmed_line.is_empty()
1415                            || prefixes(CODE_BLOCK_DELIMITER)
1416                                .any(|prefix| trimmed_line.starts_with(prefix))
1417                        {
1418                            break;
1419                        }
1420                    }
1421
1422                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1423                    if lines.peek().is_some() {
1424                        if this.line_end {
1425                            chunk.push('\n');
1426                        }
1427
1428                        chunk.push_str(&line_without_cursor);
1429                        this.line_end = true;
1430                        consumed += line.len() + 1;
1431                    } else if this.stream_done {
1432                        if !this.starts_with_code_block
1433                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1434                        {
1435                            if this.line_end {
1436                                chunk.push('\n');
1437                            }
1438
1439                            chunk.push_str(line);
1440                        }
1441
1442                        consumed += line.len();
1443                    } else {
1444                        let trimmed_line = line.trim();
1445                        if trimmed_line.is_empty()
1446                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1447                            || prefixes(CODE_BLOCK_DELIMITER)
1448                                .any(|prefix| trimmed_line.ends_with(prefix))
1449                        {
1450                            break;
1451                        } else {
1452                            if this.line_end {
1453                                chunk.push('\n');
1454                                this.line_end = false;
1455                            }
1456
1457                            chunk.push_str(&line_without_cursor);
1458                            consumed += line.len();
1459                        }
1460                    }
1461                }
1462            }
1463
1464            this.buffer = this.buffer.split_off(consumed);
1465            if !chunk.is_empty() {
1466                return Poll::Ready(Some(Ok(chunk)));
1467            } else if this.stream_done {
1468                return Poll::Ready(None);
1469            }
1470        }
1471    }
1472}
1473
/// Returns every proper, non-empty prefix of `text`, shortest first.
///
/// For example, `prefixes("```")` yields `` "`" `` and `` "``" ``. The full
/// string itself is never yielded, and an empty or single-character `text`
/// yields nothing. Every prefix ends on a `char` boundary, so non-ASCII input
/// is handled without panicking.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Char start offsets after the first one are exactly the end offsets of
    // the proper prefixes. Skipping offset 0 avoids yielding the empty
    // prefix. Using `char_indices` (instead of `0..text.len() - 1`) avoids
    // underflow on an empty string and slicing inside a multi-byte char.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1477
1478#[derive(Default)]
1479pub struct Diff {
1480    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1481    pub inserted_row_ranges: Vec<Range<Anchor>>,
1482}
1483
1484impl Diff {
1485    fn is_empty(&self) -> bool {
1486        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1487    }
1488}
1489
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Stream, stream};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
        LanguageModelToolUse, StopReason, TokenUsage,
    };
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // Streams a replacement whose lines carry 7 (and 11) spaces of indentation,
    // split into randomly sized chunks, and checks that the buffer ends up
    // indented with the 4/8-space levels of the surrounding Rust code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // The selection is empty and sits after a partially typed `le`; the model
    // continues the word and emits unindented lines, which must be indented to
    // match the function body.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // The empty selection sits after two spaces on an otherwise blank line,
    // i.e. before the expected indentation level; generated lines must still
    // end up at the function body's 4-space indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // A tab-indented buffer with no language attached: the tab indentation of
    // the generated text must be preserved rather than converted to spaces.
    // `iterations` is kept even without an rng parameter so that gpui exercises
    // several executor scheduling seeds.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    // A codegen alternative created inactive must leave the buffer untouched
    // until it is activated; deactivating it again reverts its edits.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // When not streaming tool calls, we strip backticks as part of parsing the model's
    // plain text response. This is a regression test for a bug where we stripped
    // backticks incorrectly.
    #[gpui::test]
    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
        init_test(cx);
        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
        let buffer = cx.new(|cx| Buffer::local("", cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let events_tx = simulate_tool_based_completion(&codegen, cx);
        // First send a partial tool input that stops right before the first
        // backtick, then the complete input.
        let chunk_len = text.find('`').unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
            .unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_2", text, true))
            .unwrap();
        events_tx
            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
            .unwrap();
        drop(events_tx);
        cx.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // `StripInvalidSpans` must produce the same output no matter how the input
    // is split into chunks, so every input is replayed at every chunk size.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Feeds `text` through `StripInvalidSpans` at every chunk size from 1
        /// up to its length in chars and asserts the output is `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            // `chunks` splits by chars, so bound the chunk size by the char
            // count rather than the byte length.
            for chunk_size in 1..=text.chars().count() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {chunk_size}"
                );
            }
        }

        /// Splits `text` into successful stream items of at most `size` chars each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test `LanguageModelRegistry` and a test `SettingsStore` global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Starts `codegen` on a plain-text response stream backed by a fake model
    /// (with span stripping disabled) and returns the sender used to push text
    /// chunks. Dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                /* strip_invalid_spans: */ false,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Starts `codegen` on a completion-event stream backed by a fake model and
    /// returns the sender used to push completion events. Dropping the sender
    /// ends the stream.
    fn simulate_tool_based_completion(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
        let (events_tx, events_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
                as BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >));
            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
        });
        events_tx
    }

    /// Builds a `ToolUse` event for the rewrite-section tool whose input carries
    /// `replacement_text`; `is_complete` marks whether the input is final.
    fn rewrite_tool_use(
        id: &str,
        replacement_text: &str,
        is_complete: bool,
    ) -> LanguageModelCompletionEvent {
        let input = RewriteSectionInput {
            replacement_text: replacement_text.into(),
        };
        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
            id: id.into(),
            name: REWRITE_SECTION_TOOL_NAME.into(),
            raw_input: serde_json::to_string(&input).unwrap(),
            input: serde_json::to_value(&input).unwrap(),
            is_input_complete: is_complete,
            thought_signature: None,
        })
    }
}