buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use futures::{
  10    SinkExt, Stream, StreamExt, TryStreamExt as _,
  11    channel::mpsc,
  12    future::{LocalBoxFuture, Shared},
  13    join,
  14    stream::BoxStream,
  15};
  16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  17use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  18use language_model::{
  19    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  20    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  21    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  22    LanguageModelToolUse, Role, TokenUsage,
  23};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use prompt_store::PromptBuilder;
  27use rope::Rope;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings as _;
  31use smol::future::FutureExt;
  32use std::{
  33    cmp,
  34    future::Future,
  35    iter,
  36    ops::{Range, RangeInclusive},
  37    pin::Pin,
  38    sync::Arc,
  39    task::{self, Poll},
  40    time::Instant,
  41};
  42use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  43
// Input of the `failure_message` tool (`FAILURE_MESSAGE_TOOL_NAME`); its JSON schema is
// generated with `root_schema_for::<FailureMessageInput>` when streaming-tool requests are built.
// NOTE(review): schemars presumably turns the `///` comments below into schema descriptions,
// which makes that wording model-facing — change it deliberately.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // A missing `message` field deserializes to an empty string.
    #[serde(default)]
    pub message: String,
}
  53
// Input of the `rewrite_section` tool (`REWRITE_SECTION_TOOL_NAME`); its JSON schema is
// generated with `root_schema_for::<RewriteSectionInput>` when streaming-tool requests are built.
// NOTE(review): schemars presumably turns the `///` comments below into schema descriptions,
// which makes that wording model-facing — change it deliberately.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // A missing `replacement_text` field deserializes to an empty string.
    #[serde(default)]
    pub replacement_text: String,
}
  63
/// Drives an inline transformation of `range` in `buffer`, keeping one
/// `CodegenAlternative` per model so the user can cycle between their results.
pub struct BufferCodegen {
    /// One alternative per model: index 0 is started with the primary model passed to
    /// `start`, the rest with the registry's inline alternative models.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the currently active alternative.
    pub active_alternative: usize,
    /// Indices of every alternative that has been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; rebuilt on every activation.
    subscriptions: Vec<Subscription>,
    /// Multibuffer being transformed; handed to every alternative.
    buffer: Entity<MultiBuffer>,
    /// Range being transformed; handed to every alternative.
    range: Range<Anchor>,
    /// Transaction undone (at most once) by `undo`, after the active alternative's own undo.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder shared with every alternative.
    builder: Arc<PromptBuilder>,
    /// `true` when `range` was empty at construction time.
    pub is_insertion: bool,
    /// Session identifier shared with every alternative.
    session_id: Uuid,
}
  76
/// Tool name under which `RewriteSectionInput` is offered to the model.
pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
/// Tool name under which `FailureMessageInput` is offered to the model.
pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
  79
  80impl BufferCodegen {
  81    pub fn new(
  82        buffer: Entity<MultiBuffer>,
  83        range: Range<Anchor>,
  84        initial_transaction_id: Option<TransactionId>,
  85        session_id: Uuid,
  86        builder: Arc<PromptBuilder>,
  87        cx: &mut Context<Self>,
  88    ) -> Self {
  89        let codegen = cx.new(|cx| {
  90            CodegenAlternative::new(
  91                buffer.clone(),
  92                range.clone(),
  93                false,
  94                builder.clone(),
  95                session_id,
  96                cx,
  97            )
  98        });
  99        let mut this = Self {
 100            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 101            alternatives: vec![codegen],
 102            active_alternative: 0,
 103            seen_alternatives: HashSet::default(),
 104            subscriptions: Vec::new(),
 105            buffer,
 106            range,
 107            initial_transaction_id,
 108            builder,
 109            session_id,
 110        };
 111        this.activate(0, cx);
 112        this
 113    }
 114
 115    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 116        let codegen = self.active_alternative().clone();
 117        self.subscriptions.clear();
 118        self.subscriptions
 119            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 120        self.subscriptions
 121            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 122    }
 123
 124    pub fn active_completion(&self, cx: &App) -> Option<String> {
 125        self.active_alternative().read(cx).current_completion()
 126    }
 127
 128    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 129        &self.alternatives[self.active_alternative]
 130    }
 131
 132    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 133        self.active_alternative().read(cx).language_name(cx)
 134    }
 135
 136    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 137        &self.active_alternative().read(cx).status
 138    }
 139
 140    pub fn alternative_count(&self, cx: &App) -> usize {
 141        LanguageModelRegistry::read_global(cx)
 142            .inline_alternative_models()
 143            .len()
 144            + 1
 145    }
 146
 147    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 148        let next_active_ix = if self.active_alternative == 0 {
 149            self.alternatives.len() - 1
 150        } else {
 151            self.active_alternative - 1
 152        };
 153        self.activate(next_active_ix, cx);
 154    }
 155
 156    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 157        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 158        self.activate(next_active_ix, cx);
 159    }
 160
 161    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 162        self.active_alternative()
 163            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 164        self.seen_alternatives.insert(index);
 165        self.active_alternative = index;
 166        self.active_alternative()
 167            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 168        self.subscribe_to_alternative(cx);
 169        cx.notify();
 170    }
 171
 172    pub fn start(
 173        &mut self,
 174        primary_model: Arc<dyn LanguageModel>,
 175        user_prompt: String,
 176        context_task: Shared<Task<Option<LoadedContext>>>,
 177        cx: &mut Context<Self>,
 178    ) -> Result<()> {
 179        let alternative_models = LanguageModelRegistry::read_global(cx)
 180            .inline_alternative_models()
 181            .to_vec();
 182
 183        self.active_alternative()
 184            .update(cx, |alternative, cx| alternative.undo(cx));
 185        self.activate(0, cx);
 186        self.alternatives.truncate(1);
 187
 188        for _ in 0..alternative_models.len() {
 189            self.alternatives.push(cx.new(|cx| {
 190                CodegenAlternative::new(
 191                    self.buffer.clone(),
 192                    self.range.clone(),
 193                    false,
 194                    self.builder.clone(),
 195                    self.session_id,
 196                    cx,
 197                )
 198            }));
 199        }
 200
 201        for (model, alternative) in iter::once(primary_model)
 202            .chain(alternative_models)
 203            .zip(&self.alternatives)
 204        {
 205            alternative.update(cx, |alternative, cx| {
 206                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 207            })?;
 208        }
 209
 210        Ok(())
 211    }
 212
 213    pub fn stop(&mut self, cx: &mut Context<Self>) {
 214        for codegen in &self.alternatives {
 215            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 216        }
 217    }
 218
 219    pub fn undo(&mut self, cx: &mut Context<Self>) {
 220        self.active_alternative()
 221            .update(cx, |codegen, cx| codegen.undo(cx));
 222
 223        self.buffer.update(cx, |buffer, cx| {
 224            if let Some(transaction_id) = self.initial_transaction_id.take() {
 225                buffer.undo_transaction(transaction_id, cx);
 226                buffer.refresh_preview(cx);
 227            }
 228        });
 229    }
 230
 231    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 232        self.active_alternative().read(cx).buffer.clone()
 233    }
 234
 235    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 236        self.active_alternative().read(cx).old_buffer.clone()
 237    }
 238
 239    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 240        self.active_alternative().read(cx).snapshot.clone()
 241    }
 242
 243    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 244        self.active_alternative().read(cx).edit_position
 245    }
 246
 247    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 248        &self.active_alternative().read(cx).diff
 249    }
 250
 251    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 252        self.active_alternative().read(cx).last_equal_ranges()
 253    }
 254
 255    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 256        self.active_alternative().read(cx).selected_text()
 257    }
 258
 259    pub fn session_id(&self) -> Uuid {
 260        self.session_id
 261    }
 262}
 263
// `BufferCodegen` re-emits the events of its active alternative (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 265
/// One candidate transformation of `range` in `buffer`, generated by a single model.
///
/// Only the active alternative keeps its edits applied to the buffer (see `set_active`).
pub struct CodegenAlternative {
    /// The multibuffer being transformed.
    buffer: Entity<MultiBuffer>,
    /// Standalone copy of the source buffer's text, line ending, language and language
    /// registry, taken when this alternative is created.
    old_buffer: Entity<Buffer>,
    /// Snapshot of `buffer`; taken at creation and refreshed whenever a stream starts.
    snapshot: MultiBufferSnapshot,
    /// Set to the right-biased start of `range` when generation starts.
    edit_position: Option<Anchor>,
    /// Range being transformed; re-anchored against a fresh snapshot when a stream starts.
    range: Range<Anchor>,
    /// Returned by `last_equal_ranges()`. NOTE(review): populated outside this view.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding this alternative's edits. It is undone on deactivation and on
    /// restart, and cleared when the buffer reports it undone.
    transformation_transaction_id: Option<TransactionId>,
    /// Starts as `Idle`; set to `Pending` when a stream begins.
    status: CodegenStatus,
    /// Task driving the current generation; replaced with a ready task when this
    /// alternative's transaction is undone.
    generation: Task<()>,
    /// Diff state; reset when a stream begins.
    diff: Diff,
    /// Routes `buffer` events to `handle_buffer_event`.
    _subscription: gpui::Subscription,
    /// Builds the prompts sent to the model.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to `buffer`.
    active: bool,
    /// Edits replayed via `apply_edits` when this alternative becomes active.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line-diff operations replayed when activated while generation is still pending.
    line_operations: Vec<LineOperation>,
    /// NOTE(review): presumably the generation duration; written outside this view.
    elapsed_time: Option<f64>,
    /// NOTE(review): presumably the finished completion text; written outside this view.
    completion: Option<String>,
    /// Text of `range` captured when a stream starts.
    selected_text: Option<String>,
    /// NOTE(review): presumably the provider's message id for the response; written outside this view.
    pub message_id: Option<String>,
    /// Session identifier shared with the owning `BufferCodegen`.
    session_id: Uuid,
    /// Model explanation for the current generation; cleared when a new generation starts.
    pub description: Option<String>,
    /// NOTE(review): presumably the message reported through the failure tool; written outside this view.
    pub failure: Option<String>,
}
 290
// Alternatives emit `CodegenEvent`s (e.g. `Undone` from `handle_buffer_event`).
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 292
 293impl CodegenAlternative {
 294    pub fn new(
 295        buffer: Entity<MultiBuffer>,
 296        range: Range<Anchor>,
 297        active: bool,
 298        builder: Arc<PromptBuilder>,
 299        session_id: Uuid,
 300        cx: &mut Context<Self>,
 301    ) -> Self {
 302        let snapshot = buffer.read(cx).snapshot(cx);
 303
 304        let (old_buffer, _, _) = snapshot
 305            .range_to_buffer_ranges(range.clone())
 306            .pop()
 307            .unwrap();
 308        let old_buffer = cx.new(|cx| {
 309            let text = old_buffer.as_rope().clone();
 310            let line_ending = old_buffer.line_ending();
 311            let language = old_buffer.language().cloned();
 312            let language_registry = buffer
 313                .read(cx)
 314                .buffer(old_buffer.remote_id())
 315                .unwrap()
 316                .read(cx)
 317                .language_registry();
 318
 319            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 320            buffer.set_language(language, cx);
 321            if let Some(language_registry) = language_registry {
 322                buffer.set_language_registry(language_registry);
 323            }
 324            buffer
 325        });
 326
 327        Self {
 328            buffer: buffer.clone(),
 329            old_buffer,
 330            edit_position: None,
 331            message_id: None,
 332            snapshot,
 333            last_equal_ranges: Default::default(),
 334            transformation_transaction_id: None,
 335            status: CodegenStatus::Idle,
 336            generation: Task::ready(()),
 337            diff: Diff::default(),
 338            builder,
 339            active: active,
 340            edits: Vec::new(),
 341            line_operations: Vec::new(),
 342            range,
 343            elapsed_time: None,
 344            completion: None,
 345            selected_text: None,
 346            session_id,
 347            description: None,
 348            failure: None,
 349            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 350        }
 351    }
 352
 353    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 354        self.old_buffer
 355            .read(cx)
 356            .language()
 357            .map(|language| language.name())
 358    }
 359
 360    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 361        if active != self.active {
 362            self.active = active;
 363
 364            if self.active {
 365                let edits = self.edits.clone();
 366                self.apply_edits(edits, cx);
 367                if matches!(self.status, CodegenStatus::Pending) {
 368                    let line_operations = self.line_operations.clone();
 369                    self.reapply_line_based_diff(line_operations, cx);
 370                } else {
 371                    self.reapply_batch_diff(cx).detach();
 372                }
 373            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 374                self.buffer.update(cx, |buffer, cx| {
 375                    buffer.undo_transaction(transaction_id, cx);
 376                    buffer.forget_transaction(transaction_id, cx);
 377                });
 378            }
 379        }
 380    }
 381
 382    fn handle_buffer_event(
 383        &mut self,
 384        _buffer: Entity<MultiBuffer>,
 385        event: &multi_buffer::Event,
 386        cx: &mut Context<Self>,
 387    ) {
 388        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 389            && self.transformation_transaction_id == Some(*transaction_id)
 390        {
 391            self.transformation_transaction_id = None;
 392            self.generation = Task::ready(());
 393            cx.emit(CodegenEvent::Undone);
 394        }
 395    }
 396
 397    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 398        &self.last_equal_ranges
 399    }
 400
 401    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 402        model.supports_streaming_tools()
 403            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 404    }
 405
 406    pub fn start(
 407        &mut self,
 408        user_prompt: String,
 409        context_task: Shared<Task<Option<LoadedContext>>>,
 410        model: Arc<dyn LanguageModel>,
 411        cx: &mut Context<Self>,
 412    ) -> Result<()> {
 413        // Clear the model explanation since the user has started a new generation.
 414        self.description = None;
 415
 416        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 417            self.buffer.update(cx, |buffer, cx| {
 418                buffer.undo_transaction(transformation_transaction_id, cx);
 419            });
 420        }
 421
 422        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 423
 424        if Self::use_streaming_tools(model.as_ref(), cx) {
 425            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 426            let completion_events = cx.spawn({
 427                let model = model.clone();
 428                async move |_, cx| model.stream_completion(request.await, cx).await
 429            });
 430            self.generation = self.handle_completion(model, completion_events, cx);
 431        } else {
 432            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 433                if user_prompt.trim().to_lowercase() == "delete" {
 434                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 435                } else {
 436                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 437                    cx.spawn({
 438                        let model = model.clone();
 439                        async move |_, cx| {
 440                            Ok(model.stream_completion_text(request.await, cx).await?)
 441                        }
 442                    })
 443                    .boxed_local()
 444                };
 445            self.generation =
 446                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 447        }
 448
 449        Ok(())
 450    }
 451
 452    fn build_request_tools(
 453        &self,
 454        model: &Arc<dyn LanguageModel>,
 455        user_prompt: String,
 456        context_task: Shared<Task<Option<LoadedContext>>>,
 457        cx: &mut App,
 458    ) -> Result<Task<LanguageModelRequest>> {
 459        let buffer = self.buffer.read(cx).snapshot(cx);
 460        let language = buffer.language_at(self.range.start);
 461        let language_name = if let Some(language) = language.as_ref() {
 462            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 463                None
 464            } else {
 465                Some(language.name())
 466            }
 467        } else {
 468            None
 469        };
 470
 471        let language_name = language_name.as_ref();
 472        let start = buffer.point_to_buffer_offset(self.range.start);
 473        let end = buffer.point_to_buffer_offset(self.range.end);
 474        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 475            let (start_buffer, start_buffer_offset) = start;
 476            let (end_buffer, end_buffer_offset) = end;
 477            if start_buffer.remote_id() == end_buffer.remote_id() {
 478                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 479            } else {
 480                anyhow::bail!("invalid transformation range");
 481            }
 482        } else {
 483            anyhow::bail!("invalid transformation range");
 484        };
 485
 486        let system_prompt = self
 487            .builder
 488            .generate_inline_transformation_prompt_tools(
 489                language_name,
 490                buffer,
 491                range.start.0..range.end.0,
 492            )
 493            .context("generating content prompt")?;
 494
 495        let temperature = AgentSettings::temperature_for_model(model, cx);
 496
 497        let tool_input_format = model.tool_input_format();
 498        let tool_choice = model
 499            .supports_tool_choice(LanguageModelToolChoice::Any)
 500            .then_some(LanguageModelToolChoice::Any);
 501
 502        Ok(cx.spawn(async move |_cx| {
 503            let mut messages = vec![LanguageModelRequestMessage {
 504                role: Role::System,
 505                content: vec![system_prompt.into()],
 506                cache: false,
 507                reasoning_details: None,
 508            }];
 509
 510            let mut user_message = LanguageModelRequestMessage {
 511                role: Role::User,
 512                content: Vec::new(),
 513                cache: false,
 514                reasoning_details: None,
 515            };
 516
 517            if let Some(context) = context_task.await {
 518                context.add_to_request_message(&mut user_message);
 519            }
 520
 521            user_message.content.push(user_prompt.into());
 522            messages.push(user_message);
 523
 524            let tools = vec![
 525                LanguageModelRequestTool {
 526                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
 527                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 528                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 529                },
 530                LanguageModelRequestTool {
 531                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
 532                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 533                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 534                },
 535            ];
 536
 537            LanguageModelRequest {
 538                thread_id: None,
 539                prompt_id: None,
 540                intent: Some(CompletionIntent::InlineAssist),
 541                tools,
 542                tool_choice,
 543                stop: Vec::new(),
 544                temperature,
 545                messages,
 546                thinking_allowed: false,
 547            }
 548        }))
 549    }
 550
 551    fn build_request(
 552        &self,
 553        model: &Arc<dyn LanguageModel>,
 554        user_prompt: String,
 555        context_task: Shared<Task<Option<LoadedContext>>>,
 556        cx: &mut App,
 557    ) -> Result<Task<LanguageModelRequest>> {
 558        if Self::use_streaming_tools(model.as_ref(), cx) {
 559            return self.build_request_tools(model, user_prompt, context_task, cx);
 560        }
 561
 562        let buffer = self.buffer.read(cx).snapshot(cx);
 563        let language = buffer.language_at(self.range.start);
 564        let language_name = if let Some(language) = language.as_ref() {
 565            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 566                None
 567            } else {
 568                Some(language.name())
 569            }
 570        } else {
 571            None
 572        };
 573
 574        let language_name = language_name.as_ref();
 575        let start = buffer.point_to_buffer_offset(self.range.start);
 576        let end = buffer.point_to_buffer_offset(self.range.end);
 577        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 578            let (start_buffer, start_buffer_offset) = start;
 579            let (end_buffer, end_buffer_offset) = end;
 580            if start_buffer.remote_id() == end_buffer.remote_id() {
 581                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 582            } else {
 583                anyhow::bail!("invalid transformation range");
 584            }
 585        } else {
 586            anyhow::bail!("invalid transformation range");
 587        };
 588
 589        let prompt = self
 590            .builder
 591            .generate_inline_transformation_prompt(
 592                user_prompt,
 593                language_name,
 594                buffer,
 595                range.start.0..range.end.0,
 596            )
 597            .context("generating content prompt")?;
 598
 599        let temperature = AgentSettings::temperature_for_model(model, cx);
 600
 601        Ok(cx.spawn(async move |_cx| {
 602            let mut request_message = LanguageModelRequestMessage {
 603                role: Role::User,
 604                content: Vec::new(),
 605                cache: false,
 606                reasoning_details: None,
 607            };
 608
 609            if let Some(context) = context_task.await {
 610                context.add_to_request_message(&mut request_message);
 611            }
 612
 613            request_message.content.push(prompt.into());
 614
 615            LanguageModelRequest {
 616                thread_id: None,
 617                prompt_id: None,
 618                intent: Some(CompletionIntent::InlineAssist),
 619                tools: Vec::new(),
 620                tool_choice: None,
 621                stop: Vec::new(),
 622                temperature,
 623                messages: vec![request_message],
 624                thinking_allowed: false,
 625            }
 626        }))
 627    }
 628
 629    pub fn handle_stream(
 630        &mut self,
 631        model: Arc<dyn LanguageModel>,
 632        strip_invalid_spans: bool,
 633        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 634        cx: &mut Context<Self>,
 635    ) -> Task<()> {
 636        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 637        let session_id = self.session_id;
 638        let model_telemetry_id = model.telemetry_id();
 639        let model_provider_id = model.provider_id().to_string();
 640        let start_time = Instant::now();
 641
 642        // Make a new snapshot and re-resolve anchor in case the document was modified.
 643        // This can happen often if the editor loses focus and is saved + reformatted,
 644        // as in https://github.com/zed-industries/zed/issues/39088
 645        self.snapshot = self.buffer.read(cx).snapshot(cx);
 646        self.range = self.snapshot.anchor_after(self.range.start)
 647            ..self.snapshot.anchor_after(self.range.end);
 648
 649        let snapshot = self.snapshot.clone();
 650        let selected_text = snapshot
 651            .text_for_range(self.range.start..self.range.end)
 652            .collect::<Rope>();
 653
 654        self.selected_text = Some(selected_text.to_string());
 655
 656        let selection_start = self.range.start.to_point(&snapshot);
 657
 658        // Start with the indentation of the first line in the selection
 659        let mut suggested_line_indent = snapshot
 660            .suggested_indents(selection_start.row..=selection_start.row, cx)
 661            .into_values()
 662            .next()
 663            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 664
 665        // If the first line in the selection does not have indentation, check the following lines
 666        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 667            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 668                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 669                // Prefer tabs if a line in the selection uses tabs as indentation
 670                if line_indent.kind == IndentKind::Tab {
 671                    suggested_line_indent.kind = IndentKind::Tab;
 672                    break;
 673                }
 674            }
 675        }
 676
 677        let language_name = {
 678            let multibuffer = self.buffer.read(cx);
 679            let snapshot = multibuffer.snapshot(cx);
 680            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 681            ranges
 682                .first()
 683                .and_then(|(buffer, _, _)| buffer.language())
 684                .map(|language| language.name())
 685        };
 686
 687        self.diff = Diff::default();
 688        self.status = CodegenStatus::Pending;
 689        let mut edit_start = self.range.start.to_offset(&snapshot);
 690        let completion = Arc::new(Mutex::new(String::new()));
 691        let completion_clone = completion.clone();
 692
 693        cx.notify();
 694        cx.spawn(async move |codegen, cx| {
 695            let stream = stream.await;
 696
 697            let token_usage = stream
 698                .as_ref()
 699                .ok()
 700                .map(|stream| stream.last_token_usage.clone());
 701            let message_id = stream
 702                .as_ref()
 703                .ok()
 704                .and_then(|stream| stream.message_id.clone());
 705            let generate = async {
 706                let model_telemetry_id = model_telemetry_id.clone();
 707                let model_provider_id = model_provider_id.clone();
 708                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 709                let message_id = message_id.clone();
 710                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 711                    let anthropic_reporter = anthropic_reporter.clone();
 712                    let language_name = language_name.clone();
 713                    async move {
 714                        let mut response_latency = None;
 715                        let request_start = Instant::now();
 716                        let diff = async {
 717                            let raw_stream = stream?.stream.map_err(|error| error.into());
 718
 719                            let stripped;
 720                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 721                                if strip_invalid_spans {
 722                                    stripped = StripInvalidSpans::new(raw_stream);
 723                                    Box::pin(stripped)
 724                                } else {
 725                                    Box::pin(raw_stream)
 726                                };
 727
 728                            let mut diff = StreamingDiff::new(selected_text.to_string());
 729                            let mut line_diff = LineDiff::default();
 730
 731                            let mut new_text = String::new();
 732                            let mut base_indent = None;
 733                            let mut line_indent = None;
 734                            let mut first_line = true;
 735
 736                            while let Some(chunk) = chunks.next().await {
 737                                if response_latency.is_none() {
 738                                    response_latency = Some(request_start.elapsed());
 739                                }
 740                                let chunk = chunk?;
 741                                completion_clone.lock().push_str(&chunk);
 742
 743                                let mut lines = chunk.split('\n').peekable();
 744                                while let Some(line) = lines.next() {
 745                                    new_text.push_str(line);
 746                                    if line_indent.is_none()
 747                                        && let Some(non_whitespace_ch_ix) =
 748                                            new_text.find(|ch: char| !ch.is_whitespace())
 749                                    {
 750                                        line_indent = Some(non_whitespace_ch_ix);
 751                                        base_indent = base_indent.or(line_indent);
 752
 753                                        let line_indent = line_indent.unwrap();
 754                                        let base_indent = base_indent.unwrap();
 755                                        let indent_delta = line_indent as i32 - base_indent as i32;
 756                                        let mut corrected_indent_len = cmp::max(
 757                                            0,
 758                                            suggested_line_indent.len as i32 + indent_delta,
 759                                        )
 760                                            as usize;
 761                                        if first_line {
 762                                            corrected_indent_len = corrected_indent_len
 763                                                .saturating_sub(selection_start.column as usize);
 764                                        }
 765
 766                                        let indent_char = suggested_line_indent.char();
 767                                        let mut indent_buffer = [0; 4];
 768                                        let indent_str =
 769                                            indent_char.encode_utf8(&mut indent_buffer);
 770                                        new_text.replace_range(
 771                                            ..line_indent,
 772                                            &indent_str.repeat(corrected_indent_len),
 773                                        );
 774                                    }
 775
 776                                    if line_indent.is_some() {
 777                                        let char_ops = diff.push_new(&new_text);
 778                                        line_diff.push_char_operations(&char_ops, &selected_text);
 779                                        diff_tx
 780                                            .send((char_ops, line_diff.line_operations()))
 781                                            .await?;
 782                                        new_text.clear();
 783                                    }
 784
 785                                    if lines.peek().is_some() {
 786                                        let char_ops = diff.push_new("\n");
 787                                        line_diff.push_char_operations(&char_ops, &selected_text);
 788                                        diff_tx
 789                                            .send((char_ops, line_diff.line_operations()))
 790                                            .await?;
 791                                        if line_indent.is_none() {
 792                                            // Don't write out the leading indentation in empty lines on the next line
 793                                            // This is the case where the above if statement didn't clear the buffer
 794                                            new_text.clear();
 795                                        }
 796                                        line_indent = None;
 797                                        first_line = false;
 798                                    }
 799                                }
 800                            }
 801
 802                            let mut char_ops = diff.push_new(&new_text);
 803                            char_ops.extend(diff.finish());
 804                            line_diff.push_char_operations(&char_ops, &selected_text);
 805                            line_diff.finish(&selected_text);
 806                            diff_tx
 807                                .send((char_ops, line_diff.line_operations()))
 808                                .await?;
 809
 810                            anyhow::Ok(())
 811                        };
 812
 813                        let result = diff.await;
 814
 815                        let error_message = result.as_ref().err().map(|error| error.to_string());
 816                        telemetry::event!(
 817                            "Assistant Responded",
 818                            kind = "inline",
 819                            phase = "response",
 820                            session_id = session_id.to_string(),
 821                            model = model_telemetry_id,
 822                            model_provider = model_provider_id,
 823                            language_name = language_name.as_ref().map(|n| n.to_string()),
 824                            message_id = message_id.as_deref(),
 825                            response_latency = response_latency,
 826                            error_message = error_message.as_deref(),
 827                        );
 828
 829                        anthropic_reporter.report(language_model::AnthropicEventData {
 830                            completion_type: language_model::AnthropicCompletionType::Editor,
 831                            event: language_model::AnthropicEventType::Response,
 832                            language_name: language_name.map(|n| n.to_string()),
 833                            message_id,
 834                        });
 835
 836                        result?;
 837                        Ok(())
 838                    }
 839                });
 840
 841                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 842                    codegen.update(cx, |codegen, cx| {
 843                        codegen.last_equal_ranges.clear();
 844
 845                        let edits = char_ops
 846                            .into_iter()
 847                            .filter_map(|operation| match operation {
 848                                CharOperation::Insert { text } => {
 849                                    let edit_start = snapshot.anchor_after(edit_start);
 850                                    Some((edit_start..edit_start, text))
 851                                }
 852                                CharOperation::Delete { bytes } => {
 853                                    let edit_end = edit_start + bytes;
 854                                    let edit_range = snapshot.anchor_after(edit_start)
 855                                        ..snapshot.anchor_before(edit_end);
 856                                    edit_start = edit_end;
 857                                    Some((edit_range, String::new()))
 858                                }
 859                                CharOperation::Keep { bytes } => {
 860                                    let edit_end = edit_start + bytes;
 861                                    let edit_range = snapshot.anchor_after(edit_start)
 862                                        ..snapshot.anchor_before(edit_end);
 863                                    edit_start = edit_end;
 864                                    codegen.last_equal_ranges.push(edit_range);
 865                                    None
 866                                }
 867                            })
 868                            .collect::<Vec<_>>();
 869
 870                        if codegen.active {
 871                            codegen.apply_edits(edits.iter().cloned(), cx);
 872                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 873                        }
 874                        codegen.edits.extend(edits);
 875                        codegen.line_operations = line_ops;
 876                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 877
 878                        cx.notify();
 879                    })?;
 880                }
 881
 882                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 883                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 884                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 885                let batch_diff_task =
 886                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 887                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 888                line_based_stream_diff?;
 889
 890                anyhow::Ok(())
 891            };
 892
 893            let result = generate.await;
 894            let elapsed_time = start_time.elapsed().as_secs_f64();
 895
 896            codegen
 897                .update(cx, |this, cx| {
 898                    this.message_id = message_id;
 899                    this.last_equal_ranges.clear();
 900                    if let Err(error) = result {
 901                        this.status = CodegenStatus::Error(error);
 902                    } else {
 903                        this.status = CodegenStatus::Done;
 904                    }
 905                    this.elapsed_time = Some(elapsed_time);
 906                    this.completion = Some(completion.lock().clone());
 907                    if let Some(usage) = token_usage {
 908                        let usage = usage.lock();
 909                        telemetry::event!(
 910                            "Inline Assistant Completion",
 911                            model = model_telemetry_id,
 912                            model_provider = model_provider_id,
 913                            input_tokens = usage.input_tokens,
 914                            output_tokens = usage.output_tokens,
 915                        )
 916                    }
 917
 918                    cx.emit(CodegenEvent::Finished);
 919                    cx.notify();
 920                })
 921                .ok();
 922        })
 923    }
 924
 925    pub fn current_completion(&self) -> Option<String> {
 926        self.completion.clone()
 927    }
 928
 929    #[cfg(any(test, feature = "test-support"))]
 930    pub fn current_description(&self) -> Option<String> {
 931        self.description.clone()
 932    }
 933
 934    #[cfg(any(test, feature = "test-support"))]
 935    pub fn current_failure(&self) -> Option<String> {
 936        self.failure.clone()
 937    }
 938
 939    pub fn selected_text(&self) -> Option<&str> {
 940        self.selected_text.as_deref()
 941    }
 942
 943    pub fn stop(&mut self, cx: &mut Context<Self>) {
 944        self.last_equal_ranges.clear();
 945        if self.diff.is_empty() {
 946            self.status = CodegenStatus::Idle;
 947        } else {
 948            self.status = CodegenStatus::Done;
 949        }
 950        self.generation = Task::ready(());
 951        cx.emit(CodegenEvent::Finished);
 952        cx.notify();
 953    }
 954
 955    pub fn undo(&mut self, cx: &mut Context<Self>) {
 956        self.buffer.update(cx, |buffer, cx| {
 957            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 958                buffer.undo_transaction(transaction_id, cx);
 959                buffer.refresh_preview(cx);
 960            }
 961        });
 962    }
 963
 964    fn apply_edits(
 965        &mut self,
 966        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 967        cx: &mut Context<CodegenAlternative>,
 968    ) {
 969        let transaction = self.buffer.update(cx, |buffer, cx| {
 970            // Avoid grouping agent edits with user edits.
 971            buffer.finalize_last_transaction(cx);
 972            buffer.start_transaction(cx);
 973            buffer.edit(edits, None, cx);
 974            buffer.end_transaction(cx)
 975        });
 976
 977        if let Some(transaction) = transaction {
 978            if let Some(first_transaction) = self.transformation_transaction_id {
 979                // Group all agent edits into the first transaction.
 980                self.buffer.update(cx, |buffer, cx| {
 981                    buffer.merge_transactions(transaction, first_transaction, cx)
 982                });
 983            } else {
 984                self.transformation_transaction_id = Some(transaction);
 985                self.buffer
 986                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 987            }
 988        }
 989    }
 990
 991    fn reapply_line_based_diff(
 992        &mut self,
 993        line_operations: impl IntoIterator<Item = LineOperation>,
 994        cx: &mut Context<Self>,
 995    ) {
 996        let old_snapshot = self.snapshot.clone();
 997        let old_range = self.range.to_point(&old_snapshot);
 998        let new_snapshot = self.buffer.read(cx).snapshot(cx);
 999        let new_range = self.range.to_point(&new_snapshot);
1000
1001        let mut old_row = old_range.start.row;
1002        let mut new_row = new_range.start.row;
1003
1004        self.diff.deleted_row_ranges.clear();
1005        self.diff.inserted_row_ranges.clear();
1006        for operation in line_operations {
1007            match operation {
1008                LineOperation::Keep { lines } => {
1009                    old_row += lines;
1010                    new_row += lines;
1011                }
1012                LineOperation::Delete { lines } => {
1013                    let old_end_row = old_row + lines - 1;
1014                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1015
1016                    if let Some((_, last_deleted_row_range)) =
1017                        self.diff.deleted_row_ranges.last_mut()
1018                    {
1019                        if *last_deleted_row_range.end() + 1 == old_row {
1020                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1021                        } else {
1022                            self.diff
1023                                .deleted_row_ranges
1024                                .push((new_row, old_row..=old_end_row));
1025                        }
1026                    } else {
1027                        self.diff
1028                            .deleted_row_ranges
1029                            .push((new_row, old_row..=old_end_row));
1030                    }
1031
1032                    old_row += lines;
1033                }
1034                LineOperation::Insert { lines } => {
1035                    let new_end_row = new_row + lines - 1;
1036                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1037                    let end = new_snapshot.anchor_before(Point::new(
1038                        new_end_row,
1039                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1040                    ));
1041                    self.diff.inserted_row_ranges.push(start..end);
1042                    new_row += lines;
1043                }
1044            }
1045
1046            cx.notify();
1047        }
1048    }
1049
1050    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1051        let old_snapshot = self.snapshot.clone();
1052        let old_range = self.range.to_point(&old_snapshot);
1053        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1054        let new_range = self.range.to_point(&new_snapshot);
1055
1056        cx.spawn(async move |codegen, cx| {
1057            let (deleted_row_ranges, inserted_row_ranges) = cx
1058                .background_spawn(async move {
1059                    let old_text = old_snapshot
1060                        .text_for_range(
1061                            Point::new(old_range.start.row, 0)
1062                                ..Point::new(
1063                                    old_range.end.row,
1064                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1065                                ),
1066                        )
1067                        .collect::<String>();
1068                    let new_text = new_snapshot
1069                        .text_for_range(
1070                            Point::new(new_range.start.row, 0)
1071                                ..Point::new(
1072                                    new_range.end.row,
1073                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1074                                ),
1075                        )
1076                        .collect::<String>();
1077
1078                    let old_start_row = old_range.start.row;
1079                    let new_start_row = new_range.start.row;
1080                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1081                    let mut inserted_row_ranges = Vec::new();
1082                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1083                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1084                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1085                        if !old_rows.is_empty() {
1086                            deleted_row_ranges.push((
1087                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1088                                old_rows.start..=old_rows.end - 1,
1089                            ));
1090                        }
1091                        if !new_rows.is_empty() {
1092                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1093                            let new_end_row = new_rows.end - 1;
1094                            let end = new_snapshot.anchor_before(Point::new(
1095                                new_end_row,
1096                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1097                            ));
1098                            inserted_row_ranges.push(start..end);
1099                        }
1100                    }
1101                    (deleted_row_ranges, inserted_row_ranges)
1102                })
1103                .await;
1104
1105            codegen
1106                .update(cx, |codegen, cx| {
1107                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1108                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1109                    cx.notify();
1110                })
1111                .ok();
1112        })
1113    }
1114
1115    fn handle_completion(
1116        &mut self,
1117        model: Arc<dyn LanguageModel>,
1118        completion_stream: Task<
1119            Result<
1120                BoxStream<
1121                    'static,
1122                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1123                >,
1124                LanguageModelCompletionError,
1125            >,
1126        >,
1127        cx: &mut Context<Self>,
1128    ) -> Task<()> {
1129        self.diff = Diff::default();
1130        self.status = CodegenStatus::Pending;
1131
1132        cx.notify();
1133        // Leaving this in generation so that STOP equivalent events are respected even
1134        // while we're still pre-processing the completion event
1135        cx.spawn(async move |codegen, cx| {
1136            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1137                let _ = codegen.update(cx, |this, cx| {
1138                    this.status = status;
1139                    cx.emit(CodegenEvent::Finished);
1140                    cx.notify();
1141                });
1142            };
1143
1144            let mut completion_events = match completion_stream.await {
1145                Ok(events) => events,
1146                Err(err) => {
1147                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1148                    return;
1149                }
1150            };
1151
1152            enum ToolUseOutput {
1153                Rewrite {
1154                    text: String,
1155                    description: Option<String>,
1156                },
1157                Failure(String),
1158            }
1159
1160            enum ModelUpdate {
1161                Description(String),
1162                Failure(String),
1163            }
1164
1165            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1166            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1167                let mut chars_read_so_far = chars_read_so_far.lock();
1168                match tool_use.name.as_ref() {
1169                    REWRITE_SECTION_TOOL_NAME => {
1170                        let Ok(input) =
1171                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1172                        else {
1173                            return None;
1174                        };
1175                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1176                        *chars_read_so_far = input.replacement_text.len();
1177                        Some(ToolUseOutput::Rewrite {
1178                            text,
1179                            description: None,
1180                        })
1181                    }
1182                    FAILURE_MESSAGE_TOOL_NAME => {
1183                        let Ok(mut input) =
1184                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1185                        else {
1186                            return None;
1187                        };
1188                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1189                    }
1190                    _ => None,
1191                }
1192            };
1193
1194            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1195
1196            cx.spawn({
1197                let codegen = codegen.clone();
1198                async move |cx| {
1199                    while let Some(update) = message_rx.next().await {
1200                        let _ = codegen.update(cx, |this, _cx| match update {
1201                            ModelUpdate::Description(d) => this.description = Some(d),
1202                            ModelUpdate::Failure(f) => this.failure = Some(f),
1203                        });
1204                    }
1205                }
1206            })
1207            .detach();
1208
1209            let mut message_id = None;
1210            let mut first_text = None;
1211            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1212            let total_text = Arc::new(Mutex::new(String::new()));
1213
1214            loop {
1215                if let Some(first_event) = completion_events.next().await {
1216                    match first_event {
1217                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1218                            message_id = Some(id);
1219                        }
1220                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1221                            if let Some(output) = process_tool_use(tool_use) {
1222                                let (text, update) = match output {
1223                                    ToolUseOutput::Rewrite { text, description } => {
1224                                        (Some(text), description.map(ModelUpdate::Description))
1225                                    }
1226                                    ToolUseOutput::Failure(message) => {
1227                                        (None, Some(ModelUpdate::Failure(message)))
1228                                    }
1229                                };
1230                                if let Some(update) = update {
1231                                    let _ = message_tx.unbounded_send(update);
1232                                }
1233                                first_text = text;
1234                                if first_text.is_some() {
1235                                    break;
1236                                }
1237                            }
1238                        }
1239                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1240                            *last_token_usage.lock() = token_usage;
1241                        }
1242                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1243                            let mut lock = total_text.lock();
1244                            lock.push_str(&text);
1245                        }
1246                        Ok(e) => {
1247                            log::warn!("Unexpected event: {:?}", e);
1248                            break;
1249                        }
1250                        Err(e) => {
1251                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1252                            break;
1253                        }
1254                    }
1255                }
1256            }
1257
1258            let Some(first_text) = first_text else {
1259                finish_with_status(CodegenStatus::Done, cx);
1260                return;
1261            };
1262
1263            let move_last_token_usage = last_token_usage.clone();
1264
1265            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1266                completion_events.filter_map(move |e| {
1267                    let process_tool_use = process_tool_use.clone();
1268                    let last_token_usage = move_last_token_usage.clone();
1269                    let total_text = total_text.clone();
1270                    let mut message_tx = message_tx.clone();
1271                    async move {
1272                        match e {
1273                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1274                                let Some(output) = process_tool_use(tool_use) else {
1275                                    return None;
1276                                };
1277                                let (text, update) = match output {
1278                                    ToolUseOutput::Rewrite { text, description } => {
1279                                        (Some(text), description.map(ModelUpdate::Description))
1280                                    }
1281                                    ToolUseOutput::Failure(message) => {
1282                                        (None, Some(ModelUpdate::Failure(message)))
1283                                    }
1284                                };
1285                                if let Some(update) = update {
1286                                    let _ = message_tx.send(update).await;
1287                                }
1288                                text.map(Ok)
1289                            }
1290                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1291                                *last_token_usage.lock() = token_usage;
1292                                None
1293                            }
1294                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1295                                let mut lock = total_text.lock();
1296                                lock.push_str(&text);
1297                                None
1298                            }
1299                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1300                            e => {
1301                                log::error!("UNEXPECTED EVENT {:?}", e);
1302                                None
1303                            }
1304                        }
1305                    }
1306                }),
1307            ));
1308
1309            let language_model_text_stream = LanguageModelTextStream {
1310                message_id: message_id,
1311                stream: text_stream,
1312                last_token_usage,
1313            };
1314
1315            let Some(task) = codegen
1316                .update(cx, move |codegen, cx| {
1317                    codegen.handle_stream(
1318                        model,
1319                        /* strip_invalid_spans: */ false,
1320                        async { Ok(language_model_text_stream) },
1321                        cx,
1322                    )
1323                })
1324                .ok()
1325            else {
1326                return;
1327            };
1328
1329            task.await;
1330        })
1331    }
1332}
1333
/// Notifications emitted by a codegen alternative to its observers.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// Generation ended: it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): presumably signals that generated edits were undone; it is not emitted
    /// anywhere in this section, so confirm at the emit sites.
    Undone,
}
1339
/// Stream adapter that cleans raw model output before it is applied to the buffer.
///
/// It drops an opening Markdown code-fence line, and the matching closing fence on the final
/// line, and removes `<|CURSOR|>` markers. A trailing partial line is held back while it could
/// still turn into one of those spans.
struct StripInvalidSpans<T> {
    /// Underlying stream of text chunks.
    stream: T,
    /// Text received from `stream` that has not been emitted yet.
    buffer: String,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// Whether output is still on its first line (where an opening fence may appear).
    first_line: bool,
    /// Whether a `'\n'` is owed before the next emitted text.
    line_end: bool,
    /// Whether the output opened with a code fence, so a closing fence should be dropped.
    starts_with_code_block: bool,
}
1348
1349impl<T> StripInvalidSpans<T>
1350where
1351    T: Stream<Item = Result<String>>,
1352{
1353    fn new(stream: T) -> Self {
1354        Self {
1355            stream,
1356            stream_done: false,
1357            buffer: String::new(),
1358            first_line: true,
1359            line_end: false,
1360            starts_with_code_block: false,
1361        }
1362    }
1363}
1364
1365impl<T> Stream for StripInvalidSpans<T>
1366where
1367    T: Stream<Item = Result<String>>,
1368{
1369    type Item = Result<String>;
1370
1371    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1372        const CODE_BLOCK_DELIMITER: &str = "```";
1373        const CURSOR_SPAN: &str = "<|CURSOR|>";
1374
1375        let this = unsafe { self.get_unchecked_mut() };
1376        loop {
1377            if !this.stream_done {
1378                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1379                match stream.as_mut().poll_next(cx) {
1380                    Poll::Ready(Some(Ok(chunk))) => {
1381                        this.buffer.push_str(&chunk);
1382                    }
1383                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1384                    Poll::Ready(None) => {
1385                        this.stream_done = true;
1386                    }
1387                    Poll::Pending => return Poll::Pending,
1388                }
1389            }
1390
1391            let mut chunk = String::new();
1392            let mut consumed = 0;
1393            if !this.buffer.is_empty() {
1394                let mut lines = this.buffer.split('\n').enumerate().peekable();
1395                while let Some((line_ix, line)) = lines.next() {
1396                    if line_ix > 0 {
1397                        this.first_line = false;
1398                    }
1399
1400                    if this.first_line {
1401                        let trimmed_line = line.trim();
1402                        if lines.peek().is_some() {
1403                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1404                                consumed += line.len() + 1;
1405                                this.starts_with_code_block = true;
1406                                continue;
1407                            }
1408                        } else if trimmed_line.is_empty()
1409                            || prefixes(CODE_BLOCK_DELIMITER)
1410                                .any(|prefix| trimmed_line.starts_with(prefix))
1411                        {
1412                            break;
1413                        }
1414                    }
1415
1416                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1417                    if lines.peek().is_some() {
1418                        if this.line_end {
1419                            chunk.push('\n');
1420                        }
1421
1422                        chunk.push_str(&line_without_cursor);
1423                        this.line_end = true;
1424                        consumed += line.len() + 1;
1425                    } else if this.stream_done {
1426                        if !this.starts_with_code_block
1427                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1428                        {
1429                            if this.line_end {
1430                                chunk.push('\n');
1431                            }
1432
1433                            chunk.push_str(line);
1434                        }
1435
1436                        consumed += line.len();
1437                    } else {
1438                        let trimmed_line = line.trim();
1439                        if trimmed_line.is_empty()
1440                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1441                            || prefixes(CODE_BLOCK_DELIMITER)
1442                                .any(|prefix| trimmed_line.ends_with(prefix))
1443                        {
1444                            break;
1445                        } else {
1446                            if this.line_end {
1447                                chunk.push('\n');
1448                                this.line_end = false;
1449                            }
1450
1451                            chunk.push_str(&line_without_cursor);
1452                            consumed += line.len();
1453                        }
1454                    }
1455                }
1456            }
1457
1458            this.buffer = this.buffer.split_off(consumed);
1459            if !chunk.is_empty() {
1460                return Poll::Ready(Some(Ok(chunk)));
1461            } else if this.stream_done {
1462                return Poll::Ready(None);
1463            }
1464        }
1465    }
1466}
1467
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// `text` itself is not included: for `"```"` this yields `` "`" `` and `` "``" ``.
/// The stream-stripping logic in this file checks whether buffered text ends with one
/// of these prefixes, treating that as a delimiter that may still be streaming in.
///
/// Prefixes always end on `char` boundaries, so non-ASCII input cannot trigger a
/// slicing panic, and an empty (or single-char) `text` simply yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // The byte offset of every char after the first is the end of a proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1471
1472#[derive(Default)]
1473pub struct Diff {
1474    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1475    pub inserted_row_ranges: Vec<Range<Anchor>>,
1476}
1477
1478impl Diff {
1479    fn is_empty(&self) -> bool {
1480        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1481    }
1482}
1483
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Stream, stream};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
        LanguageModelToolUse, StopReason, TokenUsage,
    };
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        // The generated lines use a different indentation (7 spaces) than the
        // selection; the expected buffer below shows them re-indented.
        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        send_in_random_chunks(chunks_tx, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        // The insertion point sits right after the existing "le", so the first
        // generated chunk completes that word.
        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(chunks_tx, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        // The cursor is at column 2, shallower than the expected 4-space indent.
        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(chunks_tx, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        // Sent as a single chunk; the tab-indented output must stay tab-indented.
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // When not streaming tool calls, we strip backticks as part of parsing the model's
    // plain text response. This is a regression test for a bug where we stripped
    // backticks incorrectly.
    #[gpui::test]
    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
        init_test(cx);
        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
        let buffer = cx.new(|cx| Buffer::local("", cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let events_tx = simulate_tool_based_completion(&codegen, cx);
        // First a partial tool call that stops right before the first backtick,
        // then the complete text.
        let chunk_len = text.find('`').unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
            .unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_2", text, true))
            .unwrap();
        events_tx
            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
            .unwrap();
        drop(events_tx);
        cx.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Feeds `text` through `StripInvalidSpans` split into every chunk size from
        /// 1 up to `text.len()`, asserting the collected output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into `size`-char pieces and yields them as a successful stream.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test `LanguageModelRegistry` and a test `SettingsStore` global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Points `codegen` at a plain-text stream fed by the returned sender.
    ///
    /// Uses a `FakeLanguageModel` and disables invalid-span stripping. Dropping the
    /// sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                /* strip_invalid_spans: */ false,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Sends `text` through `chunks_tx` in randomly sized pieces of 1..=10 bytes.
    ///
    /// The executor runs until parked after every piece. The sender is then dropped
    /// to end the stream, and the executor runs until parked once more. The split is
    /// by byte offset, so `text` is expected to be ASCII.
    fn send_in_random_chunks(
        chunks_tx: mpsc::UnboundedSender<String>,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();
    }

    /// Points `codegen` at a completion-event stream fed by the returned sender.
    ///
    /// The completion events are routed through `handle_completion` with a
    /// `FakeLanguageModel`.
    fn simulate_tool_based_completion(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
        let (events_tx, events_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
                as BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >));
            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
        });
        events_tx
    }

    /// Builds a `ToolUse` event for the rewrite-section tool.
    ///
    /// `replacement_text` is serialized both as the raw JSON input and as the parsed
    /// JSON value.
    fn rewrite_tool_use(
        id: &str,
        replacement_text: &str,
        is_complete: bool,
    ) -> LanguageModelCompletionEvent {
        let input = RewriteSectionInput {
            replacement_text: replacement_text.into(),
        };
        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
            id: id.into(),
            name: REWRITE_SECTION_TOOL_NAME.into(),
            raw_input: serde_json::to_string(&input).unwrap(),
            input: serde_json::to_value(&input).unwrap(),
            is_input_complete: is_complete,
            thought_signature: None,
        })
    }
}