buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
  10use futures::{
  11    SinkExt, Stream, StreamExt, TryStreamExt as _,
  12    channel::mpsc,
  13    future::{LocalBoxFuture, Shared},
  14    join,
  15    stream::BoxStream,
  16};
  17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  18use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  19use language_model::{
  20    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  23    LanguageModelToolUse, Role, TokenUsage,
  24};
  25use multi_buffer::MultiBufferRow;
  26use parking_lot::Mutex;
  27use prompt_store::PromptBuilder;
  28use rope::Rope;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings as _;
  32use smol::future::FutureExt;
  33use std::{
  34    cmp,
  35    future::Future,
  36    iter,
  37    ops::{Range, RangeInclusive},
  38    pin::Pin,
  39    sync::Arc,
  40    task::{self, Poll},
  41    time::Instant,
  42};
  43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  44
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
// Input for the `FAILURE_MESSAGE_TOOL_NAME` tool. Its JSON schema is generated from this type
// via `JsonSchema` and sent to the model as the tool's `input_schema`; schemars turns the `///`
// comments above and on the field into schema descriptions, so editing them changes model-facing
// text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `default`: a missing `message` key deserializes to an empty string (presumably so partial,
    // still-streaming tool input parses — confirm against the tool-handling code).
    #[serde(default)]
    pub message: String,
}
  54
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
// Input for the `REWRITE_SECTION_TOOL_NAME` tool. As with `FailureMessageInput`, the `///`
// comments here become descriptions in the JSON schema handed to the model, so treat them as
// model-facing text rather than ordinary documentation.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `default`: a missing `replacement_text` key deserializes to an empty string (presumably to
    // tolerate partial streamed tool input — confirm against the tool-handling code).
    #[serde(default)]
    pub replacement_text: String,
}
  64
/// Drives an inline-assist transformation of `range` in `buffer`, keeping one
/// [`CodegenAlternative`] per model and exposing whichever one is currently active.
pub struct BufferCodegen {
    /// One alternative per model; index 0 belongs to the primary model passed to `start`.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently applied/shown.
    pub active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions on the active alternative, forwarding its notifications and events.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being transformed.
    buffer: Entity<MultiBuffer>,
    /// Range handed to every alternative created for this codegen.
    range: Range<Anchor>,
    /// Optional transaction that `undo` reverts (at most once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder shared with every alternative.
    builder: Arc<PromptBuilder>,
    /// True when `range` was empty at construction, i.e. the assist inserts rather than rewrites.
    pub is_insertion: bool,
    /// Session id passed through to every alternative.
    session_id: Uuid,
}
  77
/// Name of the tool the model calls to supply replacement text; its input is [`RewriteSectionInput`].
pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
/// Name of the tool the model calls when it cannot (or should not) rewrite; its input is
/// [`FailureMessageInput`].
pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
  80
  81impl BufferCodegen {
  82    pub fn new(
  83        buffer: Entity<MultiBuffer>,
  84        range: Range<Anchor>,
  85        initial_transaction_id: Option<TransactionId>,
  86        session_id: Uuid,
  87        builder: Arc<PromptBuilder>,
  88        cx: &mut Context<Self>,
  89    ) -> Self {
  90        let codegen = cx.new(|cx| {
  91            CodegenAlternative::new(
  92                buffer.clone(),
  93                range.clone(),
  94                false,
  95                builder.clone(),
  96                session_id,
  97                cx,
  98            )
  99        });
 100        let mut this = Self {
 101            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 102            alternatives: vec![codegen],
 103            active_alternative: 0,
 104            seen_alternatives: HashSet::default(),
 105            subscriptions: Vec::new(),
 106            buffer,
 107            range,
 108            initial_transaction_id,
 109            builder,
 110            session_id,
 111        };
 112        this.activate(0, cx);
 113        this
 114    }
 115
 116    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 117        let codegen = self.active_alternative().clone();
 118        self.subscriptions.clear();
 119        self.subscriptions
 120            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 121        self.subscriptions
 122            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 123    }
 124
 125    pub fn active_completion(&self, cx: &App) -> Option<String> {
 126        self.active_alternative().read(cx).current_completion()
 127    }
 128
 129    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 130        &self.alternatives[self.active_alternative]
 131    }
 132
 133    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 134        self.active_alternative().read(cx).language_name(cx)
 135    }
 136
 137    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 138        &self.active_alternative().read(cx).status
 139    }
 140
 141    pub fn alternative_count(&self, cx: &App) -> usize {
 142        LanguageModelRegistry::read_global(cx)
 143            .inline_alternative_models()
 144            .len()
 145            + 1
 146    }
 147
 148    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 149        let next_active_ix = if self.active_alternative == 0 {
 150            self.alternatives.len() - 1
 151        } else {
 152            self.active_alternative - 1
 153        };
 154        self.activate(next_active_ix, cx);
 155    }
 156
 157    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 158        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 159        self.activate(next_active_ix, cx);
 160    }
 161
 162    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 163        self.active_alternative()
 164            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 165        self.seen_alternatives.insert(index);
 166        self.active_alternative = index;
 167        self.active_alternative()
 168            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 169        self.subscribe_to_alternative(cx);
 170        cx.notify();
 171    }
 172
 173    pub fn start(
 174        &mut self,
 175        primary_model: Arc<dyn LanguageModel>,
 176        user_prompt: String,
 177        context_task: Shared<Task<Option<LoadedContext>>>,
 178        cx: &mut Context<Self>,
 179    ) -> Result<()> {
 180        let alternative_models = LanguageModelRegistry::read_global(cx)
 181            .inline_alternative_models()
 182            .to_vec();
 183
 184        self.active_alternative()
 185            .update(cx, |alternative, cx| alternative.undo(cx));
 186        self.activate(0, cx);
 187        self.alternatives.truncate(1);
 188
 189        for _ in 0..alternative_models.len() {
 190            self.alternatives.push(cx.new(|cx| {
 191                CodegenAlternative::new(
 192                    self.buffer.clone(),
 193                    self.range.clone(),
 194                    false,
 195                    self.builder.clone(),
 196                    self.session_id,
 197                    cx,
 198                )
 199            }));
 200        }
 201
 202        for (model, alternative) in iter::once(primary_model)
 203            .chain(alternative_models)
 204            .zip(&self.alternatives)
 205        {
 206            alternative.update(cx, |alternative, cx| {
 207                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 208            })?;
 209        }
 210
 211        Ok(())
 212    }
 213
 214    pub fn stop(&mut self, cx: &mut Context<Self>) {
 215        for codegen in &self.alternatives {
 216            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 217        }
 218    }
 219
 220    pub fn undo(&mut self, cx: &mut Context<Self>) {
 221        self.active_alternative()
 222            .update(cx, |codegen, cx| codegen.undo(cx));
 223
 224        self.buffer.update(cx, |buffer, cx| {
 225            if let Some(transaction_id) = self.initial_transaction_id.take() {
 226                buffer.undo_transaction(transaction_id, cx);
 227                buffer.refresh_preview(cx);
 228            }
 229        });
 230    }
 231
 232    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 233        self.active_alternative().read(cx).buffer.clone()
 234    }
 235
 236    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 237        self.active_alternative().read(cx).old_buffer.clone()
 238    }
 239
 240    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 241        self.active_alternative().read(cx).snapshot.clone()
 242    }
 243
 244    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 245        self.active_alternative().read(cx).edit_position
 246    }
 247
 248    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 249        &self.active_alternative().read(cx).diff
 250    }
 251
 252    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 253        self.active_alternative().read(cx).last_equal_ranges()
 254    }
 255
 256    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 257        self.active_alternative().read(cx).selected_text()
 258    }
 259
 260    pub fn session_id(&self) -> Uuid {
 261        self.session_id
 262    }
 263}
 264
 265impl EventEmitter<CodegenEvent> for BufferCodegen {}
 266
/// One candidate transformation of the assist range, generated by a single model.
///
/// `BufferCodegen` keeps one of these per model. `set_active` re-applies this alternative's
/// edits when it becomes active and undoes its transaction when it is deactivated.
pub struct CodegenAlternative {
    /// The live multibuffer this alternative edits.
    buffer: Entity<MultiBuffer>,
    /// Standalone copy of the source buffer made at construction (text, line ending, language).
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot; taken at construction and refreshed when a text stream starts.
    snapshot: MultiBufferSnapshot,
    /// Set to the right-biased start of `range` when a generation starts.
    // NOTE(review): presumably advanced as streamed edits land — confirm in the streaming code.
    edit_position: Option<Anchor>,
    /// The range being transformed; re-anchored against a fresh snapshot when a text stream starts.
    range: Range<Anchor>,
    // NOTE(review): not written in the visible code; presumably the ranges left unchanged by the
    // last diff — confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding this alternative's edits. Undone when a new generation starts or the
    /// alternative is deactivated; cleared when the user undoes it.
    transformation_transaction_id: Option<TransactionId>,
    /// Generation status; set to `Pending` when a text stream starts.
    status: CodegenStatus,
    /// The current generation task; replaced with a ready task when the transformation is undone.
    generation: Task<()>,
    /// Diff for the current generation; reset when a text stream starts.
    diff: Diff,
    /// Keeps `handle_buffer_event` subscribed to `buffer`.
    _subscription: gpui::Subscription,
    /// Builds the prompts sent to the model.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative is the one currently applied (see `set_active`).
    active: bool,
    /// Edits re-applied through `apply_edits` when this alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line-level diff operations re-applied on activation while a generation is pending.
    line_operations: Vec<LineOperation>,
    // NOTE(review): not written in the visible code; presumably generation duration — confirm.
    elapsed_time: Option<f64>,
    // NOTE(review): presumably the accumulated completion text — confirm.
    completion: Option<String>,
    /// The text of `range` captured when a text stream starts.
    selected_text: Option<String>,
    // NOTE(review): presumably the id of the model's response message — confirm.
    pub message_id: Option<String>,
    /// Session id handed in by the owning `BufferCodegen`.
    session_id: Uuid,
    /// The model's explanation of its change; cleared when a new generation starts.
    pub description: Option<String>,
    // NOTE(review): presumably the message reported via the `failure_message` tool — confirm.
    pub failure: Option<String>,
}

// Emits `CodegenEvent::Undone` when this alternative's transaction is undone (see
// `handle_buffer_event`); `BufferCodegen` forwards these events.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 293
 294impl CodegenAlternative {
 295    pub fn new(
 296        buffer: Entity<MultiBuffer>,
 297        range: Range<Anchor>,
 298        active: bool,
 299        builder: Arc<PromptBuilder>,
 300        session_id: Uuid,
 301        cx: &mut Context<Self>,
 302    ) -> Self {
 303        let snapshot = buffer.read(cx).snapshot(cx);
 304
 305        let (old_buffer, _, _) = snapshot
 306            .range_to_buffer_ranges(range.clone())
 307            .pop()
 308            .unwrap();
 309        let old_buffer = cx.new(|cx| {
 310            let text = old_buffer.as_rope().clone();
 311            let line_ending = old_buffer.line_ending();
 312            let language = old_buffer.language().cloned();
 313            let language_registry = buffer
 314                .read(cx)
 315                .buffer(old_buffer.remote_id())
 316                .unwrap()
 317                .read(cx)
 318                .language_registry();
 319
 320            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 321            buffer.set_language(language, cx);
 322            if let Some(language_registry) = language_registry {
 323                buffer.set_language_registry(language_registry);
 324            }
 325            buffer
 326        });
 327
 328        Self {
 329            buffer: buffer.clone(),
 330            old_buffer,
 331            edit_position: None,
 332            message_id: None,
 333            snapshot,
 334            last_equal_ranges: Default::default(),
 335            transformation_transaction_id: None,
 336            status: CodegenStatus::Idle,
 337            generation: Task::ready(()),
 338            diff: Diff::default(),
 339            builder,
 340            active: active,
 341            edits: Vec::new(),
 342            line_operations: Vec::new(),
 343            range,
 344            elapsed_time: None,
 345            completion: None,
 346            selected_text: None,
 347            session_id,
 348            description: None,
 349            failure: None,
 350            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 351        }
 352    }
 353
 354    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 355        self.old_buffer
 356            .read(cx)
 357            .language()
 358            .map(|language| language.name())
 359    }
 360
 361    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 362        if active != self.active {
 363            self.active = active;
 364
 365            if self.active {
 366                let edits = self.edits.clone();
 367                self.apply_edits(edits, cx);
 368                if matches!(self.status, CodegenStatus::Pending) {
 369                    let line_operations = self.line_operations.clone();
 370                    self.reapply_line_based_diff(line_operations, cx);
 371                } else {
 372                    self.reapply_batch_diff(cx).detach();
 373                }
 374            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 375                self.buffer.update(cx, |buffer, cx| {
 376                    buffer.undo_transaction(transaction_id, cx);
 377                    buffer.forget_transaction(transaction_id, cx);
 378                });
 379            }
 380        }
 381    }
 382
 383    fn handle_buffer_event(
 384        &mut self,
 385        _buffer: Entity<MultiBuffer>,
 386        event: &multi_buffer::Event,
 387        cx: &mut Context<Self>,
 388    ) {
 389        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 390            && self.transformation_transaction_id == Some(*transaction_id)
 391        {
 392            self.transformation_transaction_id = None;
 393            self.generation = Task::ready(());
 394            cx.emit(CodegenEvent::Undone);
 395        }
 396    }
 397
 398    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 399        &self.last_equal_ranges
 400    }
 401
 402    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 403        model.supports_streaming_tools()
 404            && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
 405            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 406    }
 407
 408    pub fn start(
 409        &mut self,
 410        user_prompt: String,
 411        context_task: Shared<Task<Option<LoadedContext>>>,
 412        model: Arc<dyn LanguageModel>,
 413        cx: &mut Context<Self>,
 414    ) -> Result<()> {
 415        // Clear the model explanation since the user has started a new generation.
 416        self.description = None;
 417
 418        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 419            self.buffer.update(cx, |buffer, cx| {
 420                buffer.undo_transaction(transformation_transaction_id, cx);
 421            });
 422        }
 423
 424        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 425
 426        if Self::use_streaming_tools(model.as_ref(), cx) {
 427            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 428            let completion_events = cx.spawn({
 429                let model = model.clone();
 430                async move |_, cx| model.stream_completion(request.await, cx).await
 431            });
 432            self.generation = self.handle_completion(model, completion_events, cx);
 433        } else {
 434            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 435                if user_prompt.trim().to_lowercase() == "delete" {
 436                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 437                } else {
 438                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 439                    cx.spawn({
 440                        let model = model.clone();
 441                        async move |_, cx| {
 442                            Ok(model.stream_completion_text(request.await, cx).await?)
 443                        }
 444                    })
 445                    .boxed_local()
 446                };
 447            self.generation =
 448                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 449        }
 450
 451        Ok(())
 452    }
 453
 454    fn build_request_tools(
 455        &self,
 456        model: &Arc<dyn LanguageModel>,
 457        user_prompt: String,
 458        context_task: Shared<Task<Option<LoadedContext>>>,
 459        cx: &mut App,
 460    ) -> Result<Task<LanguageModelRequest>> {
 461        let buffer = self.buffer.read(cx).snapshot(cx);
 462        let language = buffer.language_at(self.range.start);
 463        let language_name = if let Some(language) = language.as_ref() {
 464            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 465                None
 466            } else {
 467                Some(language.name())
 468            }
 469        } else {
 470            None
 471        };
 472
 473        let language_name = language_name.as_ref();
 474        let start = buffer.point_to_buffer_offset(self.range.start);
 475        let end = buffer.point_to_buffer_offset(self.range.end);
 476        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 477            let (start_buffer, start_buffer_offset) = start;
 478            let (end_buffer, end_buffer_offset) = end;
 479            if start_buffer.remote_id() == end_buffer.remote_id() {
 480                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 481            } else {
 482                anyhow::bail!("invalid transformation range");
 483            }
 484        } else {
 485            anyhow::bail!("invalid transformation range");
 486        };
 487
 488        let system_prompt = self
 489            .builder
 490            .generate_inline_transformation_prompt_tools(
 491                language_name,
 492                buffer,
 493                range.start.0..range.end.0,
 494            )
 495            .context("generating content prompt")?;
 496
 497        let temperature = AgentSettings::temperature_for_model(model, cx);
 498
 499        let tool_input_format = model.tool_input_format();
 500        let tool_choice = model
 501            .supports_tool_choice(LanguageModelToolChoice::Any)
 502            .then_some(LanguageModelToolChoice::Any);
 503
 504        Ok(cx.spawn(async move |_cx| {
 505            let mut messages = vec![LanguageModelRequestMessage {
 506                role: Role::System,
 507                content: vec![system_prompt.into()],
 508                cache: false,
 509                reasoning_details: None,
 510            }];
 511
 512            let mut user_message = LanguageModelRequestMessage {
 513                role: Role::User,
 514                content: Vec::new(),
 515                cache: false,
 516                reasoning_details: None,
 517            };
 518
 519            if let Some(context) = context_task.await {
 520                context.add_to_request_message(&mut user_message);
 521            }
 522
 523            user_message.content.push(user_prompt.into());
 524            messages.push(user_message);
 525
 526            let tools = vec![
 527                LanguageModelRequestTool {
 528                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
 529                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 530                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 531                },
 532                LanguageModelRequestTool {
 533                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
 534                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 535                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 536                },
 537            ];
 538
 539            LanguageModelRequest {
 540                thread_id: None,
 541                prompt_id: None,
 542                intent: Some(CompletionIntent::InlineAssist),
 543                mode: None,
 544                tools,
 545                tool_choice,
 546                stop: Vec::new(),
 547                temperature,
 548                messages,
 549                thinking_allowed: false,
 550            }
 551        }))
 552    }
 553
 554    fn build_request(
 555        &self,
 556        model: &Arc<dyn LanguageModel>,
 557        user_prompt: String,
 558        context_task: Shared<Task<Option<LoadedContext>>>,
 559        cx: &mut App,
 560    ) -> Result<Task<LanguageModelRequest>> {
 561        if Self::use_streaming_tools(model.as_ref(), cx) {
 562            return self.build_request_tools(model, user_prompt, context_task, cx);
 563        }
 564
 565        let buffer = self.buffer.read(cx).snapshot(cx);
 566        let language = buffer.language_at(self.range.start);
 567        let language_name = if let Some(language) = language.as_ref() {
 568            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 569                None
 570            } else {
 571                Some(language.name())
 572            }
 573        } else {
 574            None
 575        };
 576
 577        let language_name = language_name.as_ref();
 578        let start = buffer.point_to_buffer_offset(self.range.start);
 579        let end = buffer.point_to_buffer_offset(self.range.end);
 580        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 581            let (start_buffer, start_buffer_offset) = start;
 582            let (end_buffer, end_buffer_offset) = end;
 583            if start_buffer.remote_id() == end_buffer.remote_id() {
 584                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 585            } else {
 586                anyhow::bail!("invalid transformation range");
 587            }
 588        } else {
 589            anyhow::bail!("invalid transformation range");
 590        };
 591
 592        let prompt = self
 593            .builder
 594            .generate_inline_transformation_prompt(
 595                user_prompt,
 596                language_name,
 597                buffer,
 598                range.start.0..range.end.0,
 599            )
 600            .context("generating content prompt")?;
 601
 602        let temperature = AgentSettings::temperature_for_model(model, cx);
 603
 604        Ok(cx.spawn(async move |_cx| {
 605            let mut request_message = LanguageModelRequestMessage {
 606                role: Role::User,
 607                content: Vec::new(),
 608                cache: false,
 609                reasoning_details: None,
 610            };
 611
 612            if let Some(context) = context_task.await {
 613                context.add_to_request_message(&mut request_message);
 614            }
 615
 616            request_message.content.push(prompt.into());
 617
 618            LanguageModelRequest {
 619                thread_id: None,
 620                prompt_id: None,
 621                intent: Some(CompletionIntent::InlineAssist),
 622                mode: None,
 623                tools: Vec::new(),
 624                tool_choice: None,
 625                stop: Vec::new(),
 626                temperature,
 627                messages: vec![request_message],
 628                thinking_allowed: false,
 629            }
 630        }))
 631    }
 632
 633    pub fn handle_stream(
 634        &mut self,
 635        model: Arc<dyn LanguageModel>,
 636        strip_invalid_spans: bool,
 637        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 638        cx: &mut Context<Self>,
 639    ) -> Task<()> {
 640        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 641        let session_id = self.session_id;
 642        let model_telemetry_id = model.telemetry_id();
 643        let model_provider_id = model.provider_id().to_string();
 644        let start_time = Instant::now();
 645
 646        // Make a new snapshot and re-resolve anchor in case the document was modified.
 647        // This can happen often if the editor loses focus and is saved + reformatted,
 648        // as in https://github.com/zed-industries/zed/issues/39088
 649        self.snapshot = self.buffer.read(cx).snapshot(cx);
 650        self.range = self.snapshot.anchor_after(self.range.start)
 651            ..self.snapshot.anchor_after(self.range.end);
 652
 653        let snapshot = self.snapshot.clone();
 654        let selected_text = snapshot
 655            .text_for_range(self.range.start..self.range.end)
 656            .collect::<Rope>();
 657
 658        self.selected_text = Some(selected_text.to_string());
 659
 660        let selection_start = self.range.start.to_point(&snapshot);
 661
 662        // Start with the indentation of the first line in the selection
 663        let mut suggested_line_indent = snapshot
 664            .suggested_indents(selection_start.row..=selection_start.row, cx)
 665            .into_values()
 666            .next()
 667            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 668
 669        // If the first line in the selection does not have indentation, check the following lines
 670        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 671            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 672                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 673                // Prefer tabs if a line in the selection uses tabs as indentation
 674                if line_indent.kind == IndentKind::Tab {
 675                    suggested_line_indent.kind = IndentKind::Tab;
 676                    break;
 677                }
 678            }
 679        }
 680
 681        let language_name = {
 682            let multibuffer = self.buffer.read(cx);
 683            let snapshot = multibuffer.snapshot(cx);
 684            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 685            ranges
 686                .first()
 687                .and_then(|(buffer, _, _)| buffer.language())
 688                .map(|language| language.name())
 689        };
 690
 691        self.diff = Diff::default();
 692        self.status = CodegenStatus::Pending;
 693        let mut edit_start = self.range.start.to_offset(&snapshot);
 694        let completion = Arc::new(Mutex::new(String::new()));
 695        let completion_clone = completion.clone();
 696
 697        cx.notify();
 698        cx.spawn(async move |codegen, cx| {
 699            let stream = stream.await;
 700
 701            let token_usage = stream
 702                .as_ref()
 703                .ok()
 704                .map(|stream| stream.last_token_usage.clone());
 705            let message_id = stream
 706                .as_ref()
 707                .ok()
 708                .and_then(|stream| stream.message_id.clone());
 709            let generate = async {
 710                let model_telemetry_id = model_telemetry_id.clone();
 711                let model_provider_id = model_provider_id.clone();
 712                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 713                let message_id = message_id.clone();
 714                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 715                    let anthropic_reporter = anthropic_reporter.clone();
 716                    let language_name = language_name.clone();
 717                    async move {
 718                        let mut response_latency = None;
 719                        let request_start = Instant::now();
 720                        let diff = async {
 721                            let raw_stream = stream?.stream.map_err(|error| error.into());
 722
 723                            let stripped;
 724                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 725                                if strip_invalid_spans {
 726                                    stripped = StripInvalidSpans::new(raw_stream);
 727                                    Box::pin(stripped)
 728                                } else {
 729                                    Box::pin(raw_stream)
 730                                };
 731
 732                            let mut diff = StreamingDiff::new(selected_text.to_string());
 733                            let mut line_diff = LineDiff::default();
 734
 735                            let mut new_text = String::new();
 736                            let mut base_indent = None;
 737                            let mut line_indent = None;
 738                            let mut first_line = true;
 739
 740                            while let Some(chunk) = chunks.next().await {
 741                                if response_latency.is_none() {
 742                                    response_latency = Some(request_start.elapsed());
 743                                }
 744                                let chunk = chunk?;
 745                                completion_clone.lock().push_str(&chunk);
 746
 747                                let mut lines = chunk.split('\n').peekable();
 748                                while let Some(line) = lines.next() {
 749                                    new_text.push_str(line);
 750                                    if line_indent.is_none()
 751                                        && let Some(non_whitespace_ch_ix) =
 752                                            new_text.find(|ch: char| !ch.is_whitespace())
 753                                    {
 754                                        line_indent = Some(non_whitespace_ch_ix);
 755                                        base_indent = base_indent.or(line_indent);
 756
 757                                        let line_indent = line_indent.unwrap();
 758                                        let base_indent = base_indent.unwrap();
 759                                        let indent_delta = line_indent as i32 - base_indent as i32;
 760                                        let mut corrected_indent_len = cmp::max(
 761                                            0,
 762                                            suggested_line_indent.len as i32 + indent_delta,
 763                                        )
 764                                            as usize;
 765                                        if first_line {
 766                                            corrected_indent_len = corrected_indent_len
 767                                                .saturating_sub(selection_start.column as usize);
 768                                        }
 769
 770                                        let indent_char = suggested_line_indent.char();
 771                                        let mut indent_buffer = [0; 4];
 772                                        let indent_str =
 773                                            indent_char.encode_utf8(&mut indent_buffer);
 774                                        new_text.replace_range(
 775                                            ..line_indent,
 776                                            &indent_str.repeat(corrected_indent_len),
 777                                        );
 778                                    }
 779
 780                                    if line_indent.is_some() {
 781                                        let char_ops = diff.push_new(&new_text);
 782                                        line_diff.push_char_operations(&char_ops, &selected_text);
 783                                        diff_tx
 784                                            .send((char_ops, line_diff.line_operations()))
 785                                            .await?;
 786                                        new_text.clear();
 787                                    }
 788
 789                                    if lines.peek().is_some() {
 790                                        let char_ops = diff.push_new("\n");
 791                                        line_diff.push_char_operations(&char_ops, &selected_text);
 792                                        diff_tx
 793                                            .send((char_ops, line_diff.line_operations()))
 794                                            .await?;
 795                                        if line_indent.is_none() {
 796                                            // Don't write out the leading indentation in empty lines on the next line
 797                                            // This is the case where the above if statement didn't clear the buffer
 798                                            new_text.clear();
 799                                        }
 800                                        line_indent = None;
 801                                        first_line = false;
 802                                    }
 803                                }
 804                            }
 805
 806                            let mut char_ops = diff.push_new(&new_text);
 807                            char_ops.extend(diff.finish());
 808                            line_diff.push_char_operations(&char_ops, &selected_text);
 809                            line_diff.finish(&selected_text);
 810                            diff_tx
 811                                .send((char_ops, line_diff.line_operations()))
 812                                .await?;
 813
 814                            anyhow::Ok(())
 815                        };
 816
 817                        let result = diff.await;
 818
 819                        let error_message = result.as_ref().err().map(|error| error.to_string());
 820                        telemetry::event!(
 821                            "Assistant Responded",
 822                            kind = "inline",
 823                            phase = "response",
 824                            session_id = session_id.to_string(),
 825                            model = model_telemetry_id,
 826                            model_provider = model_provider_id,
 827                            language_name = language_name.as_ref().map(|n| n.to_string()),
 828                            message_id = message_id.as_deref(),
 829                            response_latency = response_latency,
 830                            error_message = error_message.as_deref(),
 831                        );
 832
 833                        anthropic_reporter.report(language_model::AnthropicEventData {
 834                            completion_type: language_model::AnthropicCompletionType::Editor,
 835                            event: language_model::AnthropicEventType::Response,
 836                            language_name: language_name.map(|n| n.to_string()),
 837                            message_id,
 838                        });
 839
 840                        result?;
 841                        Ok(())
 842                    }
 843                });
 844
 845                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 846                    codegen.update(cx, |codegen, cx| {
 847                        codegen.last_equal_ranges.clear();
 848
 849                        let edits = char_ops
 850                            .into_iter()
 851                            .filter_map(|operation| match operation {
 852                                CharOperation::Insert { text } => {
 853                                    let edit_start = snapshot.anchor_after(edit_start);
 854                                    Some((edit_start..edit_start, text))
 855                                }
 856                                CharOperation::Delete { bytes } => {
 857                                    let edit_end = edit_start + bytes;
 858                                    let edit_range = snapshot.anchor_after(edit_start)
 859                                        ..snapshot.anchor_before(edit_end);
 860                                    edit_start = edit_end;
 861                                    Some((edit_range, String::new()))
 862                                }
 863                                CharOperation::Keep { bytes } => {
 864                                    let edit_end = edit_start + bytes;
 865                                    let edit_range = snapshot.anchor_after(edit_start)
 866                                        ..snapshot.anchor_before(edit_end);
 867                                    edit_start = edit_end;
 868                                    codegen.last_equal_ranges.push(edit_range);
 869                                    None
 870                                }
 871                            })
 872                            .collect::<Vec<_>>();
 873
 874                        if codegen.active {
 875                            codegen.apply_edits(edits.iter().cloned(), cx);
 876                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 877                        }
 878                        codegen.edits.extend(edits);
 879                        codegen.line_operations = line_ops;
 880                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 881
 882                        cx.notify();
 883                    })?;
 884                }
 885
 886                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 887                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 888                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 889                let batch_diff_task =
 890                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 891                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 892                line_based_stream_diff?;
 893
 894                anyhow::Ok(())
 895            };
 896
 897            let result = generate.await;
 898            let elapsed_time = start_time.elapsed().as_secs_f64();
 899
 900            codegen
 901                .update(cx, |this, cx| {
 902                    this.message_id = message_id;
 903                    this.last_equal_ranges.clear();
 904                    if let Err(error) = result {
 905                        this.status = CodegenStatus::Error(error);
 906                    } else {
 907                        this.status = CodegenStatus::Done;
 908                    }
 909                    this.elapsed_time = Some(elapsed_time);
 910                    this.completion = Some(completion.lock().clone());
 911                    if let Some(usage) = token_usage {
 912                        let usage = usage.lock();
 913                        telemetry::event!(
 914                            "Inline Assistant Completion",
 915                            model = model_telemetry_id,
 916                            model_provider = model_provider_id,
 917                            input_tokens = usage.input_tokens,
 918                            output_tokens = usage.output_tokens,
 919                        )
 920                    }
 921
 922                    cx.emit(CodegenEvent::Finished);
 923                    cx.notify();
 924                })
 925                .ok();
 926        })
 927    }
 928
 929    pub fn current_completion(&self) -> Option<String> {
 930        self.completion.clone()
 931    }
 932
 933    #[cfg(any(test, feature = "test-support"))]
 934    pub fn current_description(&self) -> Option<String> {
 935        self.description.clone()
 936    }
 937
 938    #[cfg(any(test, feature = "test-support"))]
 939    pub fn current_failure(&self) -> Option<String> {
 940        self.failure.clone()
 941    }
 942
    /// Returns `self.selected_text` borrowed as a `&str`, or `None` when it is unset.
    pub fn selected_text(&self) -> Option<&str> {
        self.selected_text.as_deref()
    }
 946
 947    pub fn stop(&mut self, cx: &mut Context<Self>) {
 948        self.last_equal_ranges.clear();
 949        if self.diff.is_empty() {
 950            self.status = CodegenStatus::Idle;
 951        } else {
 952            self.status = CodegenStatus::Done;
 953        }
 954        self.generation = Task::ready(());
 955        cx.emit(CodegenEvent::Finished);
 956        cx.notify();
 957    }
 958
 959    pub fn undo(&mut self, cx: &mut Context<Self>) {
 960        self.buffer.update(cx, |buffer, cx| {
 961            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 962                buffer.undo_transaction(transaction_id, cx);
 963                buffer.refresh_preview(cx);
 964            }
 965        });
 966    }
 967
 968    fn apply_edits(
 969        &mut self,
 970        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 971        cx: &mut Context<CodegenAlternative>,
 972    ) {
 973        let transaction = self.buffer.update(cx, |buffer, cx| {
 974            // Avoid grouping agent edits with user edits.
 975            buffer.finalize_last_transaction(cx);
 976            buffer.start_transaction(cx);
 977            buffer.edit(edits, None, cx);
 978            buffer.end_transaction(cx)
 979        });
 980
 981        if let Some(transaction) = transaction {
 982            if let Some(first_transaction) = self.transformation_transaction_id {
 983                // Group all agent edits into the first transaction.
 984                self.buffer.update(cx, |buffer, cx| {
 985                    buffer.merge_transactions(transaction, first_transaction, cx)
 986                });
 987            } else {
 988                self.transformation_transaction_id = Some(transaction);
 989                self.buffer
 990                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 991            }
 992        }
 993    }
 994
 995    fn reapply_line_based_diff(
 996        &mut self,
 997        line_operations: impl IntoIterator<Item = LineOperation>,
 998        cx: &mut Context<Self>,
 999    ) {
1000        let old_snapshot = self.snapshot.clone();
1001        let old_range = self.range.to_point(&old_snapshot);
1002        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1003        let new_range = self.range.to_point(&new_snapshot);
1004
1005        let mut old_row = old_range.start.row;
1006        let mut new_row = new_range.start.row;
1007
1008        self.diff.deleted_row_ranges.clear();
1009        self.diff.inserted_row_ranges.clear();
1010        for operation in line_operations {
1011            match operation {
1012                LineOperation::Keep { lines } => {
1013                    old_row += lines;
1014                    new_row += lines;
1015                }
1016                LineOperation::Delete { lines } => {
1017                    let old_end_row = old_row + lines - 1;
1018                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1019
1020                    if let Some((_, last_deleted_row_range)) =
1021                        self.diff.deleted_row_ranges.last_mut()
1022                    {
1023                        if *last_deleted_row_range.end() + 1 == old_row {
1024                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1025                        } else {
1026                            self.diff
1027                                .deleted_row_ranges
1028                                .push((new_row, old_row..=old_end_row));
1029                        }
1030                    } else {
1031                        self.diff
1032                            .deleted_row_ranges
1033                            .push((new_row, old_row..=old_end_row));
1034                    }
1035
1036                    old_row += lines;
1037                }
1038                LineOperation::Insert { lines } => {
1039                    let new_end_row = new_row + lines - 1;
1040                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1041                    let end = new_snapshot.anchor_before(Point::new(
1042                        new_end_row,
1043                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1044                    ));
1045                    self.diff.inserted_row_ranges.push(start..end);
1046                    new_row += lines;
1047                }
1048            }
1049
1050            cx.notify();
1051        }
1052    }
1053
1054    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1055        let old_snapshot = self.snapshot.clone();
1056        let old_range = self.range.to_point(&old_snapshot);
1057        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1058        let new_range = self.range.to_point(&new_snapshot);
1059
1060        cx.spawn(async move |codegen, cx| {
1061            let (deleted_row_ranges, inserted_row_ranges) = cx
1062                .background_spawn(async move {
1063                    let old_text = old_snapshot
1064                        .text_for_range(
1065                            Point::new(old_range.start.row, 0)
1066                                ..Point::new(
1067                                    old_range.end.row,
1068                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1069                                ),
1070                        )
1071                        .collect::<String>();
1072                    let new_text = new_snapshot
1073                        .text_for_range(
1074                            Point::new(new_range.start.row, 0)
1075                                ..Point::new(
1076                                    new_range.end.row,
1077                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1078                                ),
1079                        )
1080                        .collect::<String>();
1081
1082                    let old_start_row = old_range.start.row;
1083                    let new_start_row = new_range.start.row;
1084                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1085                    let mut inserted_row_ranges = Vec::new();
1086                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1087                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1088                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1089                        if !old_rows.is_empty() {
1090                            deleted_row_ranges.push((
1091                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1092                                old_rows.start..=old_rows.end - 1,
1093                            ));
1094                        }
1095                        if !new_rows.is_empty() {
1096                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1097                            let new_end_row = new_rows.end - 1;
1098                            let end = new_snapshot.anchor_before(Point::new(
1099                                new_end_row,
1100                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1101                            ));
1102                            inserted_row_ranges.push(start..end);
1103                        }
1104                    }
1105                    (deleted_row_ranges, inserted_row_ranges)
1106                })
1107                .await;
1108
1109            codegen
1110                .update(cx, |codegen, cx| {
1111                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1112                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1113                    cx.notify();
1114                })
1115                .ok();
1116        })
1117    }
1118
1119    fn handle_completion(
1120        &mut self,
1121        model: Arc<dyn LanguageModel>,
1122        completion_stream: Task<
1123            Result<
1124                BoxStream<
1125                    'static,
1126                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1127                >,
1128                LanguageModelCompletionError,
1129            >,
1130        >,
1131        cx: &mut Context<Self>,
1132    ) -> Task<()> {
1133        self.diff = Diff::default();
1134        self.status = CodegenStatus::Pending;
1135
1136        cx.notify();
1137        // Leaving this in generation so that STOP equivalent events are respected even
1138        // while we're still pre-processing the completion event
1139        cx.spawn(async move |codegen, cx| {
1140            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1141                let _ = codegen.update(cx, |this, cx| {
1142                    this.status = status;
1143                    cx.emit(CodegenEvent::Finished);
1144                    cx.notify();
1145                });
1146            };
1147
1148            let mut completion_events = match completion_stream.await {
1149                Ok(events) => events,
1150                Err(err) => {
1151                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1152                    return;
1153                }
1154            };
1155
1156            enum ToolUseOutput {
1157                Rewrite {
1158                    text: String,
1159                    description: Option<String>,
1160                },
1161                Failure(String),
1162            }
1163
1164            enum ModelUpdate {
1165                Description(String),
1166                Failure(String),
1167            }
1168
1169            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1170            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1171                let mut chars_read_so_far = chars_read_so_far.lock();
1172                match tool_use.name.as_ref() {
1173                    REWRITE_SECTION_TOOL_NAME => {
1174                        let Ok(input) =
1175                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1176                        else {
1177                            return None;
1178                        };
1179                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1180                        *chars_read_so_far = input.replacement_text.len();
1181                        Some(ToolUseOutput::Rewrite {
1182                            text,
1183                            description: None,
1184                        })
1185                    }
1186                    FAILURE_MESSAGE_TOOL_NAME => {
1187                        let Ok(mut input) =
1188                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1189                        else {
1190                            return None;
1191                        };
1192                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1193                    }
1194                    _ => None,
1195                }
1196            };
1197
1198            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1199
1200            cx.spawn({
1201                let codegen = codegen.clone();
1202                async move |cx| {
1203                    while let Some(update) = message_rx.next().await {
1204                        let _ = codegen.update(cx, |this, _cx| match update {
1205                            ModelUpdate::Description(d) => this.description = Some(d),
1206                            ModelUpdate::Failure(f) => this.failure = Some(f),
1207                        });
1208                    }
1209                }
1210            })
1211            .detach();
1212
1213            let mut message_id = None;
1214            let mut first_text = None;
1215            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1216            let total_text = Arc::new(Mutex::new(String::new()));
1217
1218            loop {
1219                if let Some(first_event) = completion_events.next().await {
1220                    match first_event {
1221                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1222                            message_id = Some(id);
1223                        }
1224                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1225                            if let Some(output) = process_tool_use(tool_use) {
1226                                let (text, update) = match output {
1227                                    ToolUseOutput::Rewrite { text, description } => {
1228                                        (Some(text), description.map(ModelUpdate::Description))
1229                                    }
1230                                    ToolUseOutput::Failure(message) => {
1231                                        (None, Some(ModelUpdate::Failure(message)))
1232                                    }
1233                                };
1234                                if let Some(update) = update {
1235                                    let _ = message_tx.unbounded_send(update);
1236                                }
1237                                first_text = text;
1238                                if first_text.is_some() {
1239                                    break;
1240                                }
1241                            }
1242                        }
1243                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1244                            *last_token_usage.lock() = token_usage;
1245                        }
1246                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1247                            let mut lock = total_text.lock();
1248                            lock.push_str(&text);
1249                        }
1250                        Ok(e) => {
1251                            log::warn!("Unexpected event: {:?}", e);
1252                            break;
1253                        }
1254                        Err(e) => {
1255                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1256                            break;
1257                        }
1258                    }
1259                }
1260            }
1261
1262            let Some(first_text) = first_text else {
1263                finish_with_status(CodegenStatus::Done, cx);
1264                return;
1265            };
1266
1267            let move_last_token_usage = last_token_usage.clone();
1268
1269            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1270                completion_events.filter_map(move |e| {
1271                    let process_tool_use = process_tool_use.clone();
1272                    let last_token_usage = move_last_token_usage.clone();
1273                    let total_text = total_text.clone();
1274                    let mut message_tx = message_tx.clone();
1275                    async move {
1276                        match e {
1277                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1278                                let Some(output) = process_tool_use(tool_use) else {
1279                                    return None;
1280                                };
1281                                let (text, update) = match output {
1282                                    ToolUseOutput::Rewrite { text, description } => {
1283                                        (Some(text), description.map(ModelUpdate::Description))
1284                                    }
1285                                    ToolUseOutput::Failure(message) => {
1286                                        (None, Some(ModelUpdate::Failure(message)))
1287                                    }
1288                                };
1289                                if let Some(update) = update {
1290                                    let _ = message_tx.send(update).await;
1291                                }
1292                                text.map(Ok)
1293                            }
1294                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1295                                *last_token_usage.lock() = token_usage;
1296                                None
1297                            }
1298                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1299                                let mut lock = total_text.lock();
1300                                lock.push_str(&text);
1301                                None
1302                            }
1303                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1304                            e => {
1305                                log::error!("UNEXPECTED EVENT {:?}", e);
1306                                None
1307                            }
1308                        }
1309                    }
1310                }),
1311            ));
1312
1313            let language_model_text_stream = LanguageModelTextStream {
1314                message_id: message_id,
1315                stream: text_stream,
1316                last_token_usage,
1317            };
1318
1319            let Some(task) = codegen
1320                .update(cx, move |codegen, cx| {
1321                    codegen.handle_stream(
1322                        model,
1323                        /* strip_invalid_spans: */ false,
1324                        async { Ok(language_model_text_stream) },
1325                        cx,
1326                    )
1327                })
1328                .ok()
1329            else {
1330                return;
1331            };
1332
1333            task.await;
1334        })
1335    }
1336}
1337
/// Events emitted by a codegen alternative to its observers.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// The generation has ended: it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): not emitted anywhere in this section; presumably signals that the
    /// generated edits were undone — confirm against the emitting code.
    Undone,
}
1343
/// Stream adapter that cleans raw model output before it is applied to the buffer.
///
/// It drops an opening Markdown code fence on the first line and, when one was seen,
/// the closing fence on the final line. It also strips `<|CURSOR|>` markers. Text is
/// buffered whenever a chunk might end partway through one of those tokens.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has returned `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted or discarded yet.
    buffer: String,
    /// `true` until a second line is reached; only the first line can open a fence.
    first_line: bool,
    /// A complete line was emitted, so a `'\n'` is owed before the next emitted text.
    line_end: bool,
    /// The output began with a code fence, so a fence on the final line is dropped.
    starts_with_code_block: bool,
}
1352
1353impl<T> StripInvalidSpans<T>
1354where
1355    T: Stream<Item = Result<String>>,
1356{
1357    fn new(stream: T) -> Self {
1358        Self {
1359            stream,
1360            stream_done: false,
1361            buffer: String::new(),
1362            first_line: true,
1363            line_end: false,
1364            starts_with_code_block: false,
1365        }
1366    }
1367}
1368
1369impl<T> Stream for StripInvalidSpans<T>
1370where
1371    T: Stream<Item = Result<String>>,
1372{
1373    type Item = Result<String>;
1374
1375    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1376        const CODE_BLOCK_DELIMITER: &str = "```";
1377        const CURSOR_SPAN: &str = "<|CURSOR|>";
1378
1379        let this = unsafe { self.get_unchecked_mut() };
1380        loop {
1381            if !this.stream_done {
1382                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1383                match stream.as_mut().poll_next(cx) {
1384                    Poll::Ready(Some(Ok(chunk))) => {
1385                        this.buffer.push_str(&chunk);
1386                    }
1387                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1388                    Poll::Ready(None) => {
1389                        this.stream_done = true;
1390                    }
1391                    Poll::Pending => return Poll::Pending,
1392                }
1393            }
1394
1395            let mut chunk = String::new();
1396            let mut consumed = 0;
1397            if !this.buffer.is_empty() {
1398                let mut lines = this.buffer.split('\n').enumerate().peekable();
1399                while let Some((line_ix, line)) = lines.next() {
1400                    if line_ix > 0 {
1401                        this.first_line = false;
1402                    }
1403
1404                    if this.first_line {
1405                        let trimmed_line = line.trim();
1406                        if lines.peek().is_some() {
1407                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1408                                consumed += line.len() + 1;
1409                                this.starts_with_code_block = true;
1410                                continue;
1411                            }
1412                        } else if trimmed_line.is_empty()
1413                            || prefixes(CODE_BLOCK_DELIMITER)
1414                                .any(|prefix| trimmed_line.starts_with(prefix))
1415                        {
1416                            break;
1417                        }
1418                    }
1419
1420                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1421                    if lines.peek().is_some() {
1422                        if this.line_end {
1423                            chunk.push('\n');
1424                        }
1425
1426                        chunk.push_str(&line_without_cursor);
1427                        this.line_end = true;
1428                        consumed += line.len() + 1;
1429                    } else if this.stream_done {
1430                        if !this.starts_with_code_block
1431                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1432                        {
1433                            if this.line_end {
1434                                chunk.push('\n');
1435                            }
1436
1437                            chunk.push_str(line);
1438                        }
1439
1440                        consumed += line.len();
1441                    } else {
1442                        let trimmed_line = line.trim();
1443                        if trimmed_line.is_empty()
1444                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1445                            || prefixes(CODE_BLOCK_DELIMITER)
1446                                .any(|prefix| trimmed_line.ends_with(prefix))
1447                        {
1448                            break;
1449                        } else {
1450                            if this.line_end {
1451                                chunk.push('\n');
1452                                this.line_end = false;
1453                            }
1454
1455                            chunk.push_str(&line_without_cursor);
1456                            consumed += line.len();
1457                        }
1458                    }
1459                }
1460            }
1461
1462            this.buffer = this.buffer.split_off(consumed);
1463            if !chunk.is_empty() {
1464                return Poll::Ready(Some(Ok(chunk)));
1465            } else if this.stream_done {
1466                return Poll::Ready(None);
1467            }
1468        }
1469    }
1470}
1471
/// Returns every proper, non-empty prefix of `text`, shortest first.
///
/// For example, `prefixes("```")` yields `` "`" `` and `` "``" ``; the full string itself is
/// never yielded. The streaming code in this file uses it to decide whether a partially
/// received line could still grow into a delimiter or marker once more chunks arrive.
///
/// Prefixes are cut on `char` boundaries, so non-ASCII input never panics. An empty (or
/// single-`char`) `text` yields nothing instead of underflowing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Skip boundary 0 (the empty prefix). The end of the string is not a `char_indices`
    // entry, so the full string is excluded automatically.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1475
1476#[derive(Default)]
1477pub struct Diff {
1478    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1479    pub inserted_row_ranges: Vec<Range<Anchor>>,
1480}
1481
1482impl Diff {
1483    fn is_empty(&self) -> bool {
1484        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1485    }
1486}
1487
#[cfg(test)]
mod tests {
    // Tests for autoindentation of streamed rewrites, toggling an inactive
    // `CodegenAlternative`, tool-based rewrites containing backticks, and
    // `StripInvalidSpans`, plus helpers that drive a codegen with fake model output.
    use super::*;
    use futures::{
        Stream,
        stream::{self},
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
        LanguageModelToolUse, StopReason, TokenUsage,
    };
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // Rewrites rows 1..=4 of a Rust function with a response that uses 7-space
    // indentation, and checks that the result is re-indented to the 4-space
    // style of the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // From the start of `let x = 0;` to the end of the `for` loop's closing brace.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the response in random 1..=10 byte chunks, letting the executor settle
        // after each one, so the incremental path is exercised.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Inserts at an empty range just after `    le` (row 1, column 6). The model
    // finishes the partial word (`t mut x = 0;`), and the following lines must be
    // indented relative to the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the response in random 1..=10 byte chunks.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Inserts at the end of a line holding only two spaces (row 1, column 2). The
    // model's output is unindented, and the result must be indented to 4 spaces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the response in random 1..=10 byte chunks.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Rewrites tab-indented text in a buffer with no language set, sent as a
    // single chunk, and checks that the tab indentation is preserved.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // From the start of the buffer to just past `\t}` on row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    // A codegen constructed inactive (`false` third argument) must not touch the
    // buffer. Toggling it active applies the generated text, and toggling it back
    // restores the original.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Covers the whole `    let x = 0;` line (14 columns).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // When not streaming tool calls, we strip backticks as part of parsing the model's
    // plain text response. This is a regression test for a bug where we stripped
    // backticks incorrectly.
    #[gpui::test]
    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
        init_test(cx);
        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
        let buffer = cx.new(|cx| Buffer::local("", cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let events_tx = simulate_tool_based_completion(&codegen, cx);
        // First, a partial tool input cut just before the first backtick. Then the
        // complete input, and finally the end of the turn.
        let chunk_len = text.find('`').unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
            .unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_2", &text, true))
            .unwrap();
        events_tx
            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
            .unwrap();
        drop(events_tx);
        cx.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // Each case feeds the input through `StripInvalidSpans` at every chunk size and
    // expects surrounding ``` fences and `<|CURSOR|>` markers to be removed, while
    // other text (including a bare ``) is kept.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Runs `text` through `StripInvalidSpans` once per chunk size in
        // `1..=text.len()` and compares the concatenated output to `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into a stream of `Ok` chunks of `size` chars each (the last
        // chunk may be shorter). Sizes are counted in chars; every input above is ASCII,
        // so this matches the byte length used by `assert_chunks`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test `LanguageModelRegistry` and a test `SettingsStore` global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Starts a plain-text generation on `codegen` (via `handle_stream`, with
    /// `strip_invalid_spans` disabled) whose text chunks come from the returned
    /// sender. Dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                /* strip_invalid_spans: */ false,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Starts a tool-based generation on `codegen` (via `handle_completion`) whose
    /// completion events come from the returned sender. Dropping the sender ends the
    /// event stream.
    fn simulate_tool_based_completion(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
        let (events_tx, events_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
                as BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >));
            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
        });
        events_tx
    }

    /// Builds a `ToolUse` event for the rewrite-section tool carrying
    /// `replacement_text`. `is_complete` becomes `is_input_complete`, distinguishing a
    /// partial tool input from the final one.
    fn rewrite_tool_use(
        id: &str,
        replacement_text: &str,
        is_complete: bool,
    ) -> LanguageModelCompletionEvent {
        let input = RewriteSectionInput {
            replacement_text: replacement_text.into(),
        };
        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
            id: id.into(),
            name: REWRITE_SECTION_TOOL_NAME.into(),
            raw_input: serde_json::to_string(&input).unwrap(),
            input: serde_json::to_value(&input).unwrap(),
            is_input_complete: is_complete,
            thought_signature: None,
        })
    }
}