buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
  10use futures::{
  11    SinkExt, Stream, StreamExt, TryStreamExt as _,
  12    channel::mpsc,
  13    future::{LocalBoxFuture, Shared},
  14    join,
  15    stream::BoxStream,
  16};
  17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  18use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  19use language_model::{
  20    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  23    LanguageModelToolUse, Role, TokenUsage,
  24};
  25use multi_buffer::MultiBufferRow;
  26use parking_lot::Mutex;
  27use prompt_store::PromptBuilder;
  28use rope::Rope;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings as _;
  32use smol::future::FutureExt;
  33use std::{
  34    cmp,
  35    future::Future,
  36    iter,
  37    ops::{Range, RangeInclusive},
  38    pin::Pin,
  39    sync::Arc,
  40    task::{self, Poll},
  41    time::Instant,
  42};
  43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  44
// Input type whose JSON schema is registered for the `failure_message` tool in
// `build_request_tools`.
// NOTE(review): this type derives `JsonSchema`; schemars normally copies `///` doc comments
// into the generated schema's descriptions, so the wording below is presumably sent to the
// model verbatim — confirm before rewording it.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `#[serde(default)]` lets deserialization succeed with an empty string when the field is absent.
    #[serde(default)]
    pub message: String,
}
  54
// Input type whose JSON schema is registered for the `rewrite_section` tool in
// `build_request_tools`.
// NOTE(review): this type derives `JsonSchema`; schemars normally copies `///` doc comments
// into the generated schema's descriptions, so the wording below is presumably sent to the
// model verbatim — confirm before rewording it.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `#[serde(default)]` lets deserialization succeed with an empty string when the field is absent.
    #[serde(default)]
    pub replacement_text: String,
}
  64
/// Owns one or more [`CodegenAlternative`]s that rewrite the same range of a multibuffer and
/// forwards state and events from whichever alternative is currently active.
pub struct BufferCodegen {
    /// All alternatives of the current generation. `start` pairs index 0 with the primary
    /// model and the remaining entries with the registry's inline alternative models.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the active alternative.
    pub active_alternative: usize,
    /// Every index that has been passed to `activate` so far.
    seen_alternatives: HashSet<usize>,
    /// Observation and event subscriptions on the active alternative; rebuilt each time
    /// `activate` runs.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being edited; handed to every alternative.
    buffer: Entity<MultiBuffer>,
    /// The anchor range being rewritten; handed to every alternative.
    range: Range<Anchor>,
    /// Transaction that `undo` reverts (at most once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder shared with every alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was constructed.
    pub is_insertion: bool,
    /// Identifier supplied by the caller and forwarded to every alternative.
    session_id: Uuid,
}
  77
  78impl BufferCodegen {
  79    pub fn new(
  80        buffer: Entity<MultiBuffer>,
  81        range: Range<Anchor>,
  82        initial_transaction_id: Option<TransactionId>,
  83        session_id: Uuid,
  84        builder: Arc<PromptBuilder>,
  85        cx: &mut Context<Self>,
  86    ) -> Self {
  87        let codegen = cx.new(|cx| {
  88            CodegenAlternative::new(
  89                buffer.clone(),
  90                range.clone(),
  91                false,
  92                builder.clone(),
  93                session_id,
  94                cx,
  95            )
  96        });
  97        let mut this = Self {
  98            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
  99            alternatives: vec![codegen],
 100            active_alternative: 0,
 101            seen_alternatives: HashSet::default(),
 102            subscriptions: Vec::new(),
 103            buffer,
 104            range,
 105            initial_transaction_id,
 106            builder,
 107            session_id,
 108        };
 109        this.activate(0, cx);
 110        this
 111    }
 112
 113    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 114        let codegen = self.active_alternative().clone();
 115        self.subscriptions.clear();
 116        self.subscriptions
 117            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 118        self.subscriptions
 119            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 120    }
 121
 122    pub fn active_completion(&self, cx: &App) -> Option<String> {
 123        self.active_alternative().read(cx).current_completion()
 124    }
 125
 126    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 127        &self.alternatives[self.active_alternative]
 128    }
 129
 130    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 131        self.active_alternative().read(cx).language_name(cx)
 132    }
 133
 134    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 135        &self.active_alternative().read(cx).status
 136    }
 137
 138    pub fn alternative_count(&self, cx: &App) -> usize {
 139        LanguageModelRegistry::read_global(cx)
 140            .inline_alternative_models()
 141            .len()
 142            + 1
 143    }
 144
 145    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 146        let next_active_ix = if self.active_alternative == 0 {
 147            self.alternatives.len() - 1
 148        } else {
 149            self.active_alternative - 1
 150        };
 151        self.activate(next_active_ix, cx);
 152    }
 153
 154    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 155        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 156        self.activate(next_active_ix, cx);
 157    }
 158
 159    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 160        self.active_alternative()
 161            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 162        self.seen_alternatives.insert(index);
 163        self.active_alternative = index;
 164        self.active_alternative()
 165            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 166        self.subscribe_to_alternative(cx);
 167        cx.notify();
 168    }
 169
 170    pub fn start(
 171        &mut self,
 172        primary_model: Arc<dyn LanguageModel>,
 173        user_prompt: String,
 174        context_task: Shared<Task<Option<LoadedContext>>>,
 175        cx: &mut Context<Self>,
 176    ) -> Result<()> {
 177        let alternative_models = LanguageModelRegistry::read_global(cx)
 178            .inline_alternative_models()
 179            .to_vec();
 180
 181        self.active_alternative()
 182            .update(cx, |alternative, cx| alternative.undo(cx));
 183        self.activate(0, cx);
 184        self.alternatives.truncate(1);
 185
 186        for _ in 0..alternative_models.len() {
 187            self.alternatives.push(cx.new(|cx| {
 188                CodegenAlternative::new(
 189                    self.buffer.clone(),
 190                    self.range.clone(),
 191                    false,
 192                    self.builder.clone(),
 193                    self.session_id,
 194                    cx,
 195                )
 196            }));
 197        }
 198
 199        for (model, alternative) in iter::once(primary_model)
 200            .chain(alternative_models)
 201            .zip(&self.alternatives)
 202        {
 203            alternative.update(cx, |alternative, cx| {
 204                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 205            })?;
 206        }
 207
 208        Ok(())
 209    }
 210
 211    pub fn stop(&mut self, cx: &mut Context<Self>) {
 212        for codegen in &self.alternatives {
 213            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 214        }
 215    }
 216
 217    pub fn undo(&mut self, cx: &mut Context<Self>) {
 218        self.active_alternative()
 219            .update(cx, |codegen, cx| codegen.undo(cx));
 220
 221        self.buffer.update(cx, |buffer, cx| {
 222            if let Some(transaction_id) = self.initial_transaction_id.take() {
 223                buffer.undo_transaction(transaction_id, cx);
 224                buffer.refresh_preview(cx);
 225            }
 226        });
 227    }
 228
 229    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 230        self.active_alternative().read(cx).buffer.clone()
 231    }
 232
 233    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 234        self.active_alternative().read(cx).old_buffer.clone()
 235    }
 236
 237    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 238        self.active_alternative().read(cx).snapshot.clone()
 239    }
 240
 241    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 242        self.active_alternative().read(cx).edit_position
 243    }
 244
 245    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 246        &self.active_alternative().read(cx).diff
 247    }
 248
 249    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 250        self.active_alternative().read(cx).last_equal_ranges()
 251    }
 252
 253    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 254        self.active_alternative().read(cx).selected_text()
 255    }
 256
 257    pub fn session_id(&self) -> Uuid {
 258        self.session_id
 259    }
 260}
 261
// `BufferCodegen` re-emits the `CodegenEvent`s of its active alternative
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 263
/// One model's attempt at rewriting a range of a multibuffer, together with the state needed
/// to apply it to, and remove it from, the buffer.
pub struct CodegenAlternative {
    /// The multibuffer this alternative edits.
    buffer: Entity<MultiBuffer>,
    /// Local buffer created in `new` from the text, line ending, and language of the
    /// underlying buffer that `range` resolves to.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot taken in `new` and refreshed at the start of `handle_stream`.
    snapshot: MultiBufferSnapshot,
    /// Set by `start` to the right-biased start of `range`.
    edit_position: Option<Anchor>,
    /// The range being rewritten; re-anchored against a fresh snapshot in `handle_stream`.
    range: Range<Anchor>,
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction undone when this alternative is deactivated or restarted; cleared when
    /// the buffer reports that transaction as undone.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation status; `Idle` after `new`, `Pending` once a stream is handled.
    status: CodegenStatus,
    /// Task driving the current generation; replaced by `start` and reset to a ready task
    /// when the transformation transaction is undone.
    generation: Task<()>,
    /// Reset to its default at the start of `handle_stream`.
    diff: Diff,
    /// Keeps the subscription to `buffer` events (handled by `handle_buffer_event`) alive.
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative is the active one; changed through `set_active`.
    active: bool,
    /// Edits re-applied to the buffer when this alternative becomes active.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations re-applied on activation while the status is still `Pending`.
    line_operations: Vec<LineOperation>,
    elapsed_time: Option<f64>,
    completion: Option<String>,
    /// Text of `range` captured at the start of `handle_stream`.
    selected_text: Option<String>,
    pub message_id: Option<String>,
    session_id: Uuid,
    /// The model's explanation; cleared whenever `start` begins a new generation.
    pub description: Option<String>,
    pub failure: Option<String>,
}
 288
// Emits `CodegenEvent`s, e.g. `CodegenEvent::Undone` from `handle_buffer_event` when the
// transformation transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 290
 291impl CodegenAlternative {
 292    pub fn new(
 293        buffer: Entity<MultiBuffer>,
 294        range: Range<Anchor>,
 295        active: bool,
 296        builder: Arc<PromptBuilder>,
 297        session_id: Uuid,
 298        cx: &mut Context<Self>,
 299    ) -> Self {
 300        let snapshot = buffer.read(cx).snapshot(cx);
 301
 302        let (old_buffer, _, _) = snapshot
 303            .range_to_buffer_ranges(range.clone())
 304            .pop()
 305            .unwrap();
 306        let old_buffer = cx.new(|cx| {
 307            let text = old_buffer.as_rope().clone();
 308            let line_ending = old_buffer.line_ending();
 309            let language = old_buffer.language().cloned();
 310            let language_registry = buffer
 311                .read(cx)
 312                .buffer(old_buffer.remote_id())
 313                .unwrap()
 314                .read(cx)
 315                .language_registry();
 316
 317            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 318            buffer.set_language(language, cx);
 319            if let Some(language_registry) = language_registry {
 320                buffer.set_language_registry(language_registry);
 321            }
 322            buffer
 323        });
 324
 325        Self {
 326            buffer: buffer.clone(),
 327            old_buffer,
 328            edit_position: None,
 329            message_id: None,
 330            snapshot,
 331            last_equal_ranges: Default::default(),
 332            transformation_transaction_id: None,
 333            status: CodegenStatus::Idle,
 334            generation: Task::ready(()),
 335            diff: Diff::default(),
 336            builder,
 337            active: active,
 338            edits: Vec::new(),
 339            line_operations: Vec::new(),
 340            range,
 341            elapsed_time: None,
 342            completion: None,
 343            selected_text: None,
 344            session_id,
 345            description: None,
 346            failure: None,
 347            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 348        }
 349    }
 350
 351    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 352        self.old_buffer
 353            .read(cx)
 354            .language()
 355            .map(|language| language.name())
 356    }
 357
 358    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 359        if active != self.active {
 360            self.active = active;
 361
 362            if self.active {
 363                let edits = self.edits.clone();
 364                self.apply_edits(edits, cx);
 365                if matches!(self.status, CodegenStatus::Pending) {
 366                    let line_operations = self.line_operations.clone();
 367                    self.reapply_line_based_diff(line_operations, cx);
 368                } else {
 369                    self.reapply_batch_diff(cx).detach();
 370                }
 371            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 372                self.buffer.update(cx, |buffer, cx| {
 373                    buffer.undo_transaction(transaction_id, cx);
 374                    buffer.forget_transaction(transaction_id, cx);
 375                });
 376            }
 377        }
 378    }
 379
 380    fn handle_buffer_event(
 381        &mut self,
 382        _buffer: Entity<MultiBuffer>,
 383        event: &multi_buffer::Event,
 384        cx: &mut Context<Self>,
 385    ) {
 386        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 387            && self.transformation_transaction_id == Some(*transaction_id)
 388        {
 389            self.transformation_transaction_id = None;
 390            self.generation = Task::ready(());
 391            cx.emit(CodegenEvent::Undone);
 392        }
 393    }
 394
 395    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 396        &self.last_equal_ranges
 397    }
 398
 399    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 400        model.supports_streaming_tools()
 401            && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
 402            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 403    }
 404
 405    pub fn start(
 406        &mut self,
 407        user_prompt: String,
 408        context_task: Shared<Task<Option<LoadedContext>>>,
 409        model: Arc<dyn LanguageModel>,
 410        cx: &mut Context<Self>,
 411    ) -> Result<()> {
 412        // Clear the model explanation since the user has started a new generation.
 413        self.description = None;
 414
 415        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 416            self.buffer.update(cx, |buffer, cx| {
 417                buffer.undo_transaction(transformation_transaction_id, cx);
 418            });
 419        }
 420
 421        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 422
 423        if Self::use_streaming_tools(model.as_ref(), cx) {
 424            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 425            let completion_events = cx.spawn({
 426                let model = model.clone();
 427                async move |_, cx| model.stream_completion(request.await, cx).await
 428            });
 429            self.generation = self.handle_completion(model, completion_events, cx);
 430        } else {
 431            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 432                if user_prompt.trim().to_lowercase() == "delete" {
 433                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 434                } else {
 435                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 436                    cx.spawn({
 437                        let model = model.clone();
 438                        async move |_, cx| {
 439                            Ok(model.stream_completion_text(request.await, cx).await?)
 440                        }
 441                    })
 442                    .boxed_local()
 443                };
 444            self.generation = self.handle_stream(model, stream, cx);
 445        }
 446
 447        Ok(())
 448    }
 449
 450    fn build_request_tools(
 451        &self,
 452        model: &Arc<dyn LanguageModel>,
 453        user_prompt: String,
 454        context_task: Shared<Task<Option<LoadedContext>>>,
 455        cx: &mut App,
 456    ) -> Result<Task<LanguageModelRequest>> {
 457        let buffer = self.buffer.read(cx).snapshot(cx);
 458        let language = buffer.language_at(self.range.start);
 459        let language_name = if let Some(language) = language.as_ref() {
 460            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 461                None
 462            } else {
 463                Some(language.name())
 464            }
 465        } else {
 466            None
 467        };
 468
 469        let language_name = language_name.as_ref();
 470        let start = buffer.point_to_buffer_offset(self.range.start);
 471        let end = buffer.point_to_buffer_offset(self.range.end);
 472        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 473            let (start_buffer, start_buffer_offset) = start;
 474            let (end_buffer, end_buffer_offset) = end;
 475            if start_buffer.remote_id() == end_buffer.remote_id() {
 476                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 477            } else {
 478                anyhow::bail!("invalid transformation range");
 479            }
 480        } else {
 481            anyhow::bail!("invalid transformation range");
 482        };
 483
 484        let system_prompt = self
 485            .builder
 486            .generate_inline_transformation_prompt_tools(
 487                language_name,
 488                buffer,
 489                range.start.0..range.end.0,
 490            )
 491            .context("generating content prompt")?;
 492
 493        let temperature = AgentSettings::temperature_for_model(model, cx);
 494
 495        let tool_input_format = model.tool_input_format();
 496        let tool_choice = model
 497            .supports_tool_choice(LanguageModelToolChoice::Any)
 498            .then_some(LanguageModelToolChoice::Any);
 499
 500        Ok(cx.spawn(async move |_cx| {
 501            let mut messages = vec![LanguageModelRequestMessage {
 502                role: Role::System,
 503                content: vec![system_prompt.into()],
 504                cache: false,
 505                reasoning_details: None,
 506            }];
 507
 508            let mut user_message = LanguageModelRequestMessage {
 509                role: Role::User,
 510                content: Vec::new(),
 511                cache: false,
 512                reasoning_details: None,
 513            };
 514
 515            if let Some(context) = context_task.await {
 516                context.add_to_request_message(&mut user_message);
 517            }
 518
 519            user_message.content.push(user_prompt.into());
 520            messages.push(user_message);
 521
 522            let tools = vec![
 523                LanguageModelRequestTool {
 524                    name: "rewrite_section".to_string(),
 525                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 526                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 527                },
 528                LanguageModelRequestTool {
 529                    name: "failure_message".to_string(),
 530                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 531                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 532                },
 533            ];
 534
 535            LanguageModelRequest {
 536                thread_id: None,
 537                prompt_id: None,
 538                intent: Some(CompletionIntent::InlineAssist),
 539                mode: None,
 540                tools,
 541                tool_choice,
 542                stop: Vec::new(),
 543                temperature,
 544                messages,
 545                thinking_allowed: false,
 546            }
 547        }))
 548    }
 549
 550    fn build_request(
 551        &self,
 552        model: &Arc<dyn LanguageModel>,
 553        user_prompt: String,
 554        context_task: Shared<Task<Option<LoadedContext>>>,
 555        cx: &mut App,
 556    ) -> Result<Task<LanguageModelRequest>> {
 557        if Self::use_streaming_tools(model.as_ref(), cx) {
 558            return self.build_request_tools(model, user_prompt, context_task, cx);
 559        }
 560
 561        let buffer = self.buffer.read(cx).snapshot(cx);
 562        let language = buffer.language_at(self.range.start);
 563        let language_name = if let Some(language) = language.as_ref() {
 564            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 565                None
 566            } else {
 567                Some(language.name())
 568            }
 569        } else {
 570            None
 571        };
 572
 573        let language_name = language_name.as_ref();
 574        let start = buffer.point_to_buffer_offset(self.range.start);
 575        let end = buffer.point_to_buffer_offset(self.range.end);
 576        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 577            let (start_buffer, start_buffer_offset) = start;
 578            let (end_buffer, end_buffer_offset) = end;
 579            if start_buffer.remote_id() == end_buffer.remote_id() {
 580                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 581            } else {
 582                anyhow::bail!("invalid transformation range");
 583            }
 584        } else {
 585            anyhow::bail!("invalid transformation range");
 586        };
 587
 588        let prompt = self
 589            .builder
 590            .generate_inline_transformation_prompt(
 591                user_prompt,
 592                language_name,
 593                buffer,
 594                range.start.0..range.end.0,
 595            )
 596            .context("generating content prompt")?;
 597
 598        let temperature = AgentSettings::temperature_for_model(model, cx);
 599
 600        Ok(cx.spawn(async move |_cx| {
 601            let mut request_message = LanguageModelRequestMessage {
 602                role: Role::User,
 603                content: Vec::new(),
 604                cache: false,
 605                reasoning_details: None,
 606            };
 607
 608            if let Some(context) = context_task.await {
 609                context.add_to_request_message(&mut request_message);
 610            }
 611
 612            request_message.content.push(prompt.into());
 613
 614            LanguageModelRequest {
 615                thread_id: None,
 616                prompt_id: None,
 617                intent: Some(CompletionIntent::InlineAssist),
 618                mode: None,
 619                tools: Vec::new(),
 620                tool_choice: None,
 621                stop: Vec::new(),
 622                temperature,
 623                messages: vec![request_message],
 624                thinking_allowed: false,
 625            }
 626        }))
 627    }
 628
 629    pub fn handle_stream(
 630        &mut self,
 631        model: Arc<dyn LanguageModel>,
 632        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 633        cx: &mut Context<Self>,
 634    ) -> Task<()> {
 635        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 636        let session_id = self.session_id;
 637        let model_telemetry_id = model.telemetry_id();
 638        let model_provider_id = model.provider_id().to_string();
 639        let start_time = Instant::now();
 640
 641        // Make a new snapshot and re-resolve anchor in case the document was modified.
 642        // This can happen often if the editor loses focus and is saved + reformatted,
 643        // as in https://github.com/zed-industries/zed/issues/39088
 644        self.snapshot = self.buffer.read(cx).snapshot(cx);
 645        self.range = self.snapshot.anchor_after(self.range.start)
 646            ..self.snapshot.anchor_after(self.range.end);
 647
 648        let snapshot = self.snapshot.clone();
 649        let selected_text = snapshot
 650            .text_for_range(self.range.start..self.range.end)
 651            .collect::<Rope>();
 652
 653        self.selected_text = Some(selected_text.to_string());
 654
 655        let selection_start = self.range.start.to_point(&snapshot);
 656
 657        // Start with the indentation of the first line in the selection
 658        let mut suggested_line_indent = snapshot
 659            .suggested_indents(selection_start.row..=selection_start.row, cx)
 660            .into_values()
 661            .next()
 662            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 663
 664        // If the first line in the selection does not have indentation, check the following lines
 665        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 666            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 667                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 668                // Prefer tabs if a line in the selection uses tabs as indentation
 669                if line_indent.kind == IndentKind::Tab {
 670                    suggested_line_indent.kind = IndentKind::Tab;
 671                    break;
 672                }
 673            }
 674        }
 675
 676        let language_name = {
 677            let multibuffer = self.buffer.read(cx);
 678            let snapshot = multibuffer.snapshot(cx);
 679            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 680            ranges
 681                .first()
 682                .and_then(|(buffer, _, _)| buffer.language())
 683                .map(|language| language.name())
 684        };
 685
 686        self.diff = Diff::default();
 687        self.status = CodegenStatus::Pending;
 688        let mut edit_start = self.range.start.to_offset(&snapshot);
 689        let completion = Arc::new(Mutex::new(String::new()));
 690        let completion_clone = completion.clone();
 691
 692        cx.notify();
 693        cx.spawn(async move |codegen, cx| {
 694            let stream = stream.await;
 695
 696            let token_usage = stream
 697                .as_ref()
 698                .ok()
 699                .map(|stream| stream.last_token_usage.clone());
 700            let message_id = stream
 701                .as_ref()
 702                .ok()
 703                .and_then(|stream| stream.message_id.clone());
 704            let generate = async {
 705                let model_telemetry_id = model_telemetry_id.clone();
 706                let model_provider_id = model_provider_id.clone();
 707                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 708                let message_id = message_id.clone();
 709                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 710                    let anthropic_reporter = anthropic_reporter.clone();
 711                    let language_name = language_name.clone();
 712                    async move {
 713                        let mut response_latency = None;
 714                        let request_start = Instant::now();
 715                        let diff = async {
 716                            let chunks = StripInvalidSpans::new(
 717                                stream?.stream.map_err(|error| error.into()),
 718                            );
 719                            futures::pin_mut!(chunks);
 720
 721                            let mut diff = StreamingDiff::new(selected_text.to_string());
 722                            let mut line_diff = LineDiff::default();
 723
 724                            let mut new_text = String::new();
 725                            let mut base_indent = None;
 726                            let mut line_indent = None;
 727                            let mut first_line = true;
 728
 729                            while let Some(chunk) = chunks.next().await {
 730                                if response_latency.is_none() {
 731                                    response_latency = Some(request_start.elapsed());
 732                                }
 733                                let chunk = chunk?;
 734                                completion_clone.lock().push_str(&chunk);
 735
 736                                let mut lines = chunk.split('\n').peekable();
 737                                while let Some(line) = lines.next() {
 738                                    new_text.push_str(line);
 739                                    if line_indent.is_none()
 740                                        && let Some(non_whitespace_ch_ix) =
 741                                            new_text.find(|ch: char| !ch.is_whitespace())
 742                                    {
 743                                        line_indent = Some(non_whitespace_ch_ix);
 744                                        base_indent = base_indent.or(line_indent);
 745
 746                                        let line_indent = line_indent.unwrap();
 747                                        let base_indent = base_indent.unwrap();
 748                                        let indent_delta = line_indent as i32 - base_indent as i32;
 749                                        let mut corrected_indent_len = cmp::max(
 750                                            0,
 751                                            suggested_line_indent.len as i32 + indent_delta,
 752                                        )
 753                                            as usize;
 754                                        if first_line {
 755                                            corrected_indent_len = corrected_indent_len
 756                                                .saturating_sub(selection_start.column as usize);
 757                                        }
 758
 759                                        let indent_char = suggested_line_indent.char();
 760                                        let mut indent_buffer = [0; 4];
 761                                        let indent_str =
 762                                            indent_char.encode_utf8(&mut indent_buffer);
 763                                        new_text.replace_range(
 764                                            ..line_indent,
 765                                            &indent_str.repeat(corrected_indent_len),
 766                                        );
 767                                    }
 768
 769                                    if line_indent.is_some() {
 770                                        let char_ops = diff.push_new(&new_text);
 771                                        line_diff.push_char_operations(&char_ops, &selected_text);
 772                                        diff_tx
 773                                            .send((char_ops, line_diff.line_operations()))
 774                                            .await?;
 775                                        new_text.clear();
 776                                    }
 777
 778                                    if lines.peek().is_some() {
 779                                        let char_ops = diff.push_new("\n");
 780                                        line_diff.push_char_operations(&char_ops, &selected_text);
 781                                        diff_tx
 782                                            .send((char_ops, line_diff.line_operations()))
 783                                            .await?;
 784                                        if line_indent.is_none() {
 785                                            // Don't write out the leading indentation in empty lines on the next line
 786                                            // This is the case where the above if statement didn't clear the buffer
 787                                            new_text.clear();
 788                                        }
 789                                        line_indent = None;
 790                                        first_line = false;
 791                                    }
 792                                }
 793                            }
 794
 795                            let mut char_ops = diff.push_new(&new_text);
 796                            char_ops.extend(diff.finish());
 797                            line_diff.push_char_operations(&char_ops, &selected_text);
 798                            line_diff.finish(&selected_text);
 799                            diff_tx
 800                                .send((char_ops, line_diff.line_operations()))
 801                                .await?;
 802
 803                            anyhow::Ok(())
 804                        };
 805
 806                        let result = diff.await;
 807
 808                        let error_message = result.as_ref().err().map(|error| error.to_string());
 809                        telemetry::event!(
 810                            "Assistant Responded",
 811                            kind = "inline",
 812                            phase = "response",
 813                            session_id = session_id.to_string(),
 814                            model = model_telemetry_id,
 815                            model_provider = model_provider_id,
 816                            language_name = language_name.as_ref().map(|n| n.to_string()),
 817                            message_id = message_id.as_deref(),
 818                            response_latency = response_latency,
 819                            error_message = error_message.as_deref(),
 820                        );
 821
 822                        anthropic_reporter.report(language_model::AnthropicEventData {
 823                            completion_type: language_model::AnthropicCompletionType::Editor,
 824                            event: language_model::AnthropicEventType::Response,
 825                            language_name: language_name.map(|n| n.to_string()),
 826                            message_id,
 827                        });
 828
 829                        result?;
 830                        Ok(())
 831                    }
 832                });
 833
 834                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 835                    codegen.update(cx, |codegen, cx| {
 836                        codegen.last_equal_ranges.clear();
 837
 838                        let edits = char_ops
 839                            .into_iter()
 840                            .filter_map(|operation| match operation {
 841                                CharOperation::Insert { text } => {
 842                                    let edit_start = snapshot.anchor_after(edit_start);
 843                                    Some((edit_start..edit_start, text))
 844                                }
 845                                CharOperation::Delete { bytes } => {
 846                                    let edit_end = edit_start + bytes;
 847                                    let edit_range = snapshot.anchor_after(edit_start)
 848                                        ..snapshot.anchor_before(edit_end);
 849                                    edit_start = edit_end;
 850                                    Some((edit_range, String::new()))
 851                                }
 852                                CharOperation::Keep { bytes } => {
 853                                    let edit_end = edit_start + bytes;
 854                                    let edit_range = snapshot.anchor_after(edit_start)
 855                                        ..snapshot.anchor_before(edit_end);
 856                                    edit_start = edit_end;
 857                                    codegen.last_equal_ranges.push(edit_range);
 858                                    None
 859                                }
 860                            })
 861                            .collect::<Vec<_>>();
 862
 863                        if codegen.active {
 864                            codegen.apply_edits(edits.iter().cloned(), cx);
 865                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 866                        }
 867                        codegen.edits.extend(edits);
 868                        codegen.line_operations = line_ops;
 869                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 870
 871                        cx.notify();
 872                    })?;
 873                }
 874
 875                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 876                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 877                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 878                let batch_diff_task =
 879                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 880                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 881                line_based_stream_diff?;
 882
 883                anyhow::Ok(())
 884            };
 885
 886            let result = generate.await;
 887            let elapsed_time = start_time.elapsed().as_secs_f64();
 888
 889            codegen
 890                .update(cx, |this, cx| {
 891                    this.message_id = message_id;
 892                    this.last_equal_ranges.clear();
 893                    if let Err(error) = result {
 894                        this.status = CodegenStatus::Error(error);
 895                    } else {
 896                        this.status = CodegenStatus::Done;
 897                    }
 898                    this.elapsed_time = Some(elapsed_time);
 899                    this.completion = Some(completion.lock().clone());
 900                    if let Some(usage) = token_usage {
 901                        let usage = usage.lock();
 902                        telemetry::event!(
 903                            "Inline Assistant Completion",
 904                            model = model_telemetry_id,
 905                            model_provider = model_provider_id,
 906                            input_tokens = usage.input_tokens,
 907                            output_tokens = usage.output_tokens,
 908                        )
 909                    }
 910
 911                    cx.emit(CodegenEvent::Finished);
 912                    cx.notify();
 913                })
 914                .ok();
 915        })
 916    }
 917
 918    pub fn current_completion(&self) -> Option<String> {
 919        self.completion.clone()
 920    }
 921
 922    #[cfg(any(test, feature = "test-support"))]
 923    pub fn current_description(&self) -> Option<String> {
 924        self.description.clone()
 925    }
 926
 927    #[cfg(any(test, feature = "test-support"))]
 928    pub fn current_failure(&self) -> Option<String> {
 929        self.failure.clone()
 930    }
 931
 932    pub fn selected_text(&self) -> Option<&str> {
 933        self.selected_text.as_deref()
 934    }
 935
 936    pub fn stop(&mut self, cx: &mut Context<Self>) {
 937        self.last_equal_ranges.clear();
 938        if self.diff.is_empty() {
 939            self.status = CodegenStatus::Idle;
 940        } else {
 941            self.status = CodegenStatus::Done;
 942        }
 943        self.generation = Task::ready(());
 944        cx.emit(CodegenEvent::Finished);
 945        cx.notify();
 946    }
 947
 948    pub fn undo(&mut self, cx: &mut Context<Self>) {
 949        self.buffer.update(cx, |buffer, cx| {
 950            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 951                buffer.undo_transaction(transaction_id, cx);
 952                buffer.refresh_preview(cx);
 953            }
 954        });
 955    }
 956
 957    fn apply_edits(
 958        &mut self,
 959        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 960        cx: &mut Context<CodegenAlternative>,
 961    ) {
 962        let transaction = self.buffer.update(cx, |buffer, cx| {
 963            // Avoid grouping agent edits with user edits.
 964            buffer.finalize_last_transaction(cx);
 965            buffer.start_transaction(cx);
 966            buffer.edit(edits, None, cx);
 967            buffer.end_transaction(cx)
 968        });
 969
 970        if let Some(transaction) = transaction {
 971            if let Some(first_transaction) = self.transformation_transaction_id {
 972                // Group all agent edits into the first transaction.
 973                self.buffer.update(cx, |buffer, cx| {
 974                    buffer.merge_transactions(transaction, first_transaction, cx)
 975                });
 976            } else {
 977                self.transformation_transaction_id = Some(transaction);
 978                self.buffer
 979                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 980            }
 981        }
 982    }
 983
 984    fn reapply_line_based_diff(
 985        &mut self,
 986        line_operations: impl IntoIterator<Item = LineOperation>,
 987        cx: &mut Context<Self>,
 988    ) {
 989        let old_snapshot = self.snapshot.clone();
 990        let old_range = self.range.to_point(&old_snapshot);
 991        let new_snapshot = self.buffer.read(cx).snapshot(cx);
 992        let new_range = self.range.to_point(&new_snapshot);
 993
 994        let mut old_row = old_range.start.row;
 995        let mut new_row = new_range.start.row;
 996
 997        self.diff.deleted_row_ranges.clear();
 998        self.diff.inserted_row_ranges.clear();
 999        for operation in line_operations {
1000            match operation {
1001                LineOperation::Keep { lines } => {
1002                    old_row += lines;
1003                    new_row += lines;
1004                }
1005                LineOperation::Delete { lines } => {
1006                    let old_end_row = old_row + lines - 1;
1007                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1008
1009                    if let Some((_, last_deleted_row_range)) =
1010                        self.diff.deleted_row_ranges.last_mut()
1011                    {
1012                        if *last_deleted_row_range.end() + 1 == old_row {
1013                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1014                        } else {
1015                            self.diff
1016                                .deleted_row_ranges
1017                                .push((new_row, old_row..=old_end_row));
1018                        }
1019                    } else {
1020                        self.diff
1021                            .deleted_row_ranges
1022                            .push((new_row, old_row..=old_end_row));
1023                    }
1024
1025                    old_row += lines;
1026                }
1027                LineOperation::Insert { lines } => {
1028                    let new_end_row = new_row + lines - 1;
1029                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1030                    let end = new_snapshot.anchor_before(Point::new(
1031                        new_end_row,
1032                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1033                    ));
1034                    self.diff.inserted_row_ranges.push(start..end);
1035                    new_row += lines;
1036                }
1037            }
1038
1039            cx.notify();
1040        }
1041    }
1042
1043    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1044        let old_snapshot = self.snapshot.clone();
1045        let old_range = self.range.to_point(&old_snapshot);
1046        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1047        let new_range = self.range.to_point(&new_snapshot);
1048
1049        cx.spawn(async move |codegen, cx| {
1050            let (deleted_row_ranges, inserted_row_ranges) = cx
1051                .background_spawn(async move {
1052                    let old_text = old_snapshot
1053                        .text_for_range(
1054                            Point::new(old_range.start.row, 0)
1055                                ..Point::new(
1056                                    old_range.end.row,
1057                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1058                                ),
1059                        )
1060                        .collect::<String>();
1061                    let new_text = new_snapshot
1062                        .text_for_range(
1063                            Point::new(new_range.start.row, 0)
1064                                ..Point::new(
1065                                    new_range.end.row,
1066                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1067                                ),
1068                        )
1069                        .collect::<String>();
1070
1071                    let old_start_row = old_range.start.row;
1072                    let new_start_row = new_range.start.row;
1073                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1074                    let mut inserted_row_ranges = Vec::new();
1075                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1076                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1077                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1078                        if !old_rows.is_empty() {
1079                            deleted_row_ranges.push((
1080                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1081                                old_rows.start..=old_rows.end - 1,
1082                            ));
1083                        }
1084                        if !new_rows.is_empty() {
1085                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1086                            let new_end_row = new_rows.end - 1;
1087                            let end = new_snapshot.anchor_before(Point::new(
1088                                new_end_row,
1089                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1090                            ));
1091                            inserted_row_ranges.push(start..end);
1092                        }
1093                    }
1094                    (deleted_row_ranges, inserted_row_ranges)
1095                })
1096                .await;
1097
1098            codegen
1099                .update(cx, |codegen, cx| {
1100                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1101                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1102                    cx.notify();
1103                })
1104                .ok();
1105        })
1106    }
1107
1108    fn handle_completion(
1109        &mut self,
1110        model: Arc<dyn LanguageModel>,
1111        completion_stream: Task<
1112            Result<
1113                BoxStream<
1114                    'static,
1115                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1116                >,
1117                LanguageModelCompletionError,
1118            >,
1119        >,
1120        cx: &mut Context<Self>,
1121    ) -> Task<()> {
1122        self.diff = Diff::default();
1123        self.status = CodegenStatus::Pending;
1124
1125        cx.notify();
1126        // Leaving this in generation so that STOP equivalent events are respected even
1127        // while we're still pre-processing the completion event
1128        cx.spawn(async move |codegen, cx| {
1129            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1130                let _ = codegen.update(cx, |this, cx| {
1131                    this.status = status;
1132                    cx.emit(CodegenEvent::Finished);
1133                    cx.notify();
1134                });
1135            };
1136
1137            let mut completion_events = match completion_stream.await {
1138                Ok(events) => events,
1139                Err(err) => {
1140                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1141                    return;
1142                }
1143            };
1144
1145            enum ToolUseOutput {
1146                Rewrite {
1147                    text: String,
1148                    description: Option<String>,
1149                },
1150                Failure(String),
1151            }
1152
1153            enum ModelUpdate {
1154                Description(String),
1155                Failure(String),
1156            }
1157
1158            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1159            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1160                let mut chars_read_so_far = chars_read_so_far.lock();
1161                match tool_use.name.as_ref() {
1162                    "rewrite_section" => {
1163                        let Ok(input) =
1164                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1165                        else {
1166                            return None;
1167                        };
1168                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1169                        *chars_read_so_far = input.replacement_text.len();
1170                        Some(ToolUseOutput::Rewrite {
1171                            text,
1172                            description: None,
1173                        })
1174                    }
1175                    "failure_message" => {
1176                        let Ok(mut input) =
1177                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1178                        else {
1179                            return None;
1180                        };
1181                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1182                    }
1183                    _ => None,
1184                }
1185            };
1186
1187            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1188
1189            cx.spawn({
1190                let codegen = codegen.clone();
1191                async move |cx| {
1192                    while let Some(update) = message_rx.next().await {
1193                        let _ = codegen.update(cx, |this, _cx| match update {
1194                            ModelUpdate::Description(d) => this.description = Some(d),
1195                            ModelUpdate::Failure(f) => this.failure = Some(f),
1196                        });
1197                    }
1198                }
1199            })
1200            .detach();
1201
1202            let mut message_id = None;
1203            let mut first_text = None;
1204            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1205            let total_text = Arc::new(Mutex::new(String::new()));
1206
1207            loop {
1208                if let Some(first_event) = completion_events.next().await {
1209                    match first_event {
1210                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1211                            message_id = Some(id);
1212                        }
1213                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1214                            if let Some(output) = process_tool_use(tool_use) {
1215                                let (text, update) = match output {
1216                                    ToolUseOutput::Rewrite { text, description } => {
1217                                        (Some(text), description.map(ModelUpdate::Description))
1218                                    }
1219                                    ToolUseOutput::Failure(message) => {
1220                                        (None, Some(ModelUpdate::Failure(message)))
1221                                    }
1222                                };
1223                                if let Some(update) = update {
1224                                    let _ = message_tx.unbounded_send(update);
1225                                }
1226                                first_text = text;
1227                                if first_text.is_some() {
1228                                    break;
1229                                }
1230                            }
1231                        }
1232                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1233                            *last_token_usage.lock() = token_usage;
1234                        }
1235                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1236                            let mut lock = total_text.lock();
1237                            lock.push_str(&text);
1238                        }
1239                        Ok(e) => {
1240                            log::warn!("Unexpected event: {:?}", e);
1241                            break;
1242                        }
1243                        Err(e) => {
1244                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1245                            break;
1246                        }
1247                    }
1248                }
1249            }
1250
1251            let Some(first_text) = first_text else {
1252                finish_with_status(CodegenStatus::Done, cx);
1253                return;
1254            };
1255
1256            let move_last_token_usage = last_token_usage.clone();
1257
1258            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1259                completion_events.filter_map(move |e| {
1260                    let process_tool_use = process_tool_use.clone();
1261                    let last_token_usage = move_last_token_usage.clone();
1262                    let total_text = total_text.clone();
1263                    let mut message_tx = message_tx.clone();
1264                    async move {
1265                        match e {
1266                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1267                                let Some(output) = process_tool_use(tool_use) else {
1268                                    return None;
1269                                };
1270                                let (text, update) = match output {
1271                                    ToolUseOutput::Rewrite { text, description } => {
1272                                        (Some(text), description.map(ModelUpdate::Description))
1273                                    }
1274                                    ToolUseOutput::Failure(message) => {
1275                                        (None, Some(ModelUpdate::Failure(message)))
1276                                    }
1277                                };
1278                                if let Some(update) = update {
1279                                    let _ = message_tx.send(update).await;
1280                                }
1281                                text.map(Ok)
1282                            }
1283                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1284                                *last_token_usage.lock() = token_usage;
1285                                None
1286                            }
1287                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1288                                let mut lock = total_text.lock();
1289                                lock.push_str(&text);
1290                                None
1291                            }
1292                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1293                            e => {
1294                                log::error!("UNEXPECTED EVENT {:?}", e);
1295                                None
1296                            }
1297                        }
1298                    }
1299                }),
1300            ));
1301
1302            let language_model_text_stream = LanguageModelTextStream {
1303                message_id: message_id,
1304                stream: text_stream,
1305                last_token_usage,
1306            };
1307
1308            let Some(task) = codegen
1309                .update(cx, move |codegen, cx| {
1310                    codegen.handle_stream(model, async { Ok(language_model_text_stream) }, cx)
1311                })
1312                .ok()
1313            else {
1314                return;
1315            };
1316
1317            task.await;
1318        })
1319    }
1320}
1321
/// Events emitted by a codegen entity.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CodegenEvent {
    /// Emitted whenever a generation stops: after it completes or fails, and
    /// when it is stopped explicitly.
    Finished,
    /// Not emitted in the part of the file reviewed here; presumably signals
    /// that a transformation was undone — confirm at the emit site.
    Undone,
}
1327
1328struct StripInvalidSpans<T> {
1329    stream: T,
1330    stream_done: bool,
1331    buffer: String,
1332    first_line: bool,
1333    line_end: bool,
1334    starts_with_code_block: bool,
1335}
1336
1337impl<T> StripInvalidSpans<T>
1338where
1339    T: Stream<Item = Result<String>>,
1340{
1341    fn new(stream: T) -> Self {
1342        Self {
1343            stream,
1344            stream_done: false,
1345            buffer: String::new(),
1346            first_line: true,
1347            line_end: false,
1348            starts_with_code_block: false,
1349        }
1350    }
1351}
1352
1353impl<T> Stream for StripInvalidSpans<T>
1354where
1355    T: Stream<Item = Result<String>>,
1356{
1357    type Item = Result<String>;
1358
1359    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1360        const CODE_BLOCK_DELIMITER: &str = "```";
1361        const CURSOR_SPAN: &str = "<|CURSOR|>";
1362
1363        let this = unsafe { self.get_unchecked_mut() };
1364        loop {
1365            if !this.stream_done {
1366                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1367                match stream.as_mut().poll_next(cx) {
1368                    Poll::Ready(Some(Ok(chunk))) => {
1369                        this.buffer.push_str(&chunk);
1370                    }
1371                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1372                    Poll::Ready(None) => {
1373                        this.stream_done = true;
1374                    }
1375                    Poll::Pending => return Poll::Pending,
1376                }
1377            }
1378
1379            let mut chunk = String::new();
1380            let mut consumed = 0;
1381            if !this.buffer.is_empty() {
1382                let mut lines = this.buffer.split('\n').enumerate().peekable();
1383                while let Some((line_ix, line)) = lines.next() {
1384                    if line_ix > 0 {
1385                        this.first_line = false;
1386                    }
1387
1388                    if this.first_line {
1389                        let trimmed_line = line.trim();
1390                        if lines.peek().is_some() {
1391                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1392                                consumed += line.len() + 1;
1393                                this.starts_with_code_block = true;
1394                                continue;
1395                            }
1396                        } else if trimmed_line.is_empty()
1397                            || prefixes(CODE_BLOCK_DELIMITER)
1398                                .any(|prefix| trimmed_line.starts_with(prefix))
1399                        {
1400                            break;
1401                        }
1402                    }
1403
1404                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1405                    if lines.peek().is_some() {
1406                        if this.line_end {
1407                            chunk.push('\n');
1408                        }
1409
1410                        chunk.push_str(&line_without_cursor);
1411                        this.line_end = true;
1412                        consumed += line.len() + 1;
1413                    } else if this.stream_done {
1414                        if !this.starts_with_code_block
1415                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1416                        {
1417                            if this.line_end {
1418                                chunk.push('\n');
1419                            }
1420
1421                            chunk.push_str(line);
1422                        }
1423
1424                        consumed += line.len();
1425                    } else {
1426                        let trimmed_line = line.trim();
1427                        if trimmed_line.is_empty()
1428                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1429                            || prefixes(CODE_BLOCK_DELIMITER)
1430                                .any(|prefix| trimmed_line.ends_with(prefix))
1431                        {
1432                            break;
1433                        } else {
1434                            if this.line_end {
1435                                chunk.push('\n');
1436                                this.line_end = false;
1437                            }
1438
1439                            chunk.push_str(&line_without_cursor);
1440                            consumed += line.len();
1441                        }
1442                    }
1443                }
1444            }
1445
1446            this.buffer = this.buffer.split_off(consumed);
1447            if !chunk.is_empty() {
1448                return Poll::Ready(Some(Ok(chunk)));
1449            } else if this.stream_done {
1450                return Poll::Ready(None);
1451            }
1452        }
1453    }
1454}
1455
/// Yields every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string itself is never yielded, so `"abc"` produces `"a"` and
/// `"ab"`. Callers use this to detect a marker that has only been partially
/// streamed at the end of a line.
///
/// Empty and single-character inputs yield nothing. Prefixes are cut on
/// `char` boundaries, so multi-byte UTF-8 input cannot cause a slicing panic.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Every char boundary except the first (offset 0) ends a non-empty prefix;
    // the end of the string is not a `char_indices` offset, so the full text
    // is excluded automatically.
    text.char_indices()
        .skip(1)
        .map(move |(boundary, _)| &text[..boundary])
}
1459
/// Row-level description of a change, kept as two independent lists: rows that
/// were deleted and rows that were inserted.
///
/// The derived `Default` value has both lists empty.
#[derive(Default)]
pub struct Diff {
    /// Deleted row ranges, each paired with an [`Anchor`].
    // NOTE(review): the anchor presumably marks where the deleted rows used to
    // sit in the current buffer and the `u32` range indexes rows of the old
    // text — confirm against the code that populates this field.
    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges covering inserted rows.
    pub inserted_row_ranges: Vec<Range<Anchor>>,
}
1465
1466impl Diff {
1467    fn is_empty(&self) -> bool {
1468        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1469    }
1470}
1471
1472#[cfg(test)]
1473mod tests {
1474    use super::*;
1475    use futures::{
1476        Stream,
1477        stream::{self},
1478    };
1479    use gpui::TestAppContext;
1480    use indoc::indoc;
1481    use language::{Buffer, Point};
1482    use language_model::fake_provider::FakeLanguageModel;
1483    use language_model::{LanguageModelRegistry, TokenUsage};
1484    use languages::rust_lang;
1485    use rand::prelude::*;
1486    use settings::SettingsStore;
1487    use std::{future, sync::Arc};
1488
1489    #[gpui::test(iterations = 10)]
1490    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1491        init_test(cx);
1492
1493        let text = indoc! {"
1494            fn main() {
1495                let x = 0;
1496                for _ in 0..10 {
1497                    x += 1;
1498                }
1499            }
1500        "};
1501        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1502        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1503        let range = buffer.read_with(cx, |buffer, cx| {
1504            let snapshot = buffer.snapshot(cx);
1505            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1506        });
1507        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1508        let codegen = cx.new(|cx| {
1509            CodegenAlternative::new(
1510                buffer.clone(),
1511                range.clone(),
1512                true,
1513                prompt_builder,
1514                Uuid::new_v4(),
1515                cx,
1516            )
1517        });
1518
1519        let chunks_tx = simulate_response_stream(&codegen, cx);
1520
1521        let mut new_text = concat!(
1522            "       let mut x = 0;\n",
1523            "       while x < 10 {\n",
1524            "           x += 1;\n",
1525            "       }",
1526        );
1527        while !new_text.is_empty() {
1528            let max_len = cmp::min(new_text.len(), 10);
1529            let len = rng.random_range(1..=max_len);
1530            let (chunk, suffix) = new_text.split_at(len);
1531            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1532            new_text = suffix;
1533            cx.background_executor.run_until_parked();
1534        }
1535        drop(chunks_tx);
1536        cx.background_executor.run_until_parked();
1537
1538        assert_eq!(
1539            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1540            indoc! {"
1541                fn main() {
1542                    let mut x = 0;
1543                    while x < 10 {
1544                        x += 1;
1545                    }
1546                }
1547            "}
1548        );
1549    }
1550
1551    #[gpui::test(iterations = 10)]
1552    async fn test_autoindent_when_generating_past_indentation(
1553        cx: &mut TestAppContext,
1554        mut rng: StdRng,
1555    ) {
1556        init_test(cx);
1557
1558        let text = indoc! {"
1559            fn main() {
1560                le
1561            }
1562        "};
1563        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1564        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1565        let range = buffer.read_with(cx, |buffer, cx| {
1566            let snapshot = buffer.snapshot(cx);
1567            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1568        });
1569        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1570        let codegen = cx.new(|cx| {
1571            CodegenAlternative::new(
1572                buffer.clone(),
1573                range.clone(),
1574                true,
1575                prompt_builder,
1576                Uuid::new_v4(),
1577                cx,
1578            )
1579        });
1580
1581        let chunks_tx = simulate_response_stream(&codegen, cx);
1582
1583        cx.background_executor.run_until_parked();
1584
1585        let mut new_text = concat!(
1586            "t mut x = 0;\n",
1587            "while x < 10 {\n",
1588            "    x += 1;\n",
1589            "}", //
1590        );
1591        while !new_text.is_empty() {
1592            let max_len = cmp::min(new_text.len(), 10);
1593            let len = rng.random_range(1..=max_len);
1594            let (chunk, suffix) = new_text.split_at(len);
1595            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1596            new_text = suffix;
1597            cx.background_executor.run_until_parked();
1598        }
1599        drop(chunks_tx);
1600        cx.background_executor.run_until_parked();
1601
1602        assert_eq!(
1603            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1604            indoc! {"
1605                fn main() {
1606                    let mut x = 0;
1607                    while x < 10 {
1608                        x += 1;
1609                    }
1610                }
1611            "}
1612        );
1613    }
1614
1615    #[gpui::test(iterations = 10)]
1616    async fn test_autoindent_when_generating_before_indentation(
1617        cx: &mut TestAppContext,
1618        mut rng: StdRng,
1619    ) {
1620        init_test(cx);
1621
1622        let text = concat!(
1623            "fn main() {\n",
1624            "  \n",
1625            "}\n" //
1626        );
1627        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1628        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1629        let range = buffer.read_with(cx, |buffer, cx| {
1630            let snapshot = buffer.snapshot(cx);
1631            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1632        });
1633        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1634        let codegen = cx.new(|cx| {
1635            CodegenAlternative::new(
1636                buffer.clone(),
1637                range.clone(),
1638                true,
1639                prompt_builder,
1640                Uuid::new_v4(),
1641                cx,
1642            )
1643        });
1644
1645        let chunks_tx = simulate_response_stream(&codegen, cx);
1646
1647        cx.background_executor.run_until_parked();
1648
1649        let mut new_text = concat!(
1650            "let mut x = 0;\n",
1651            "while x < 10 {\n",
1652            "    x += 1;\n",
1653            "}", //
1654        );
1655        while !new_text.is_empty() {
1656            let max_len = cmp::min(new_text.len(), 10);
1657            let len = rng.random_range(1..=max_len);
1658            let (chunk, suffix) = new_text.split_at(len);
1659            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1660            new_text = suffix;
1661            cx.background_executor.run_until_parked();
1662        }
1663        drop(chunks_tx);
1664        cx.background_executor.run_until_parked();
1665
1666        assert_eq!(
1667            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1668            indoc! {"
1669                fn main() {
1670                    let mut x = 0;
1671                    while x < 10 {
1672                        x += 1;
1673                    }
1674                }
1675            "}
1676        );
1677    }
1678
1679    #[gpui::test(iterations = 10)]
1680    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1681        init_test(cx);
1682
1683        let text = indoc! {"
1684            func main() {
1685            \tx := 0
1686            \tfor i := 0; i < 10; i++ {
1687            \t\tx++
1688            \t}
1689            }
1690        "};
1691        let buffer = cx.new(|cx| Buffer::local(text, cx));
1692        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1693        let range = buffer.read_with(cx, |buffer, cx| {
1694            let snapshot = buffer.snapshot(cx);
1695            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1696        });
1697        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1698        let codegen = cx.new(|cx| {
1699            CodegenAlternative::new(
1700                buffer.clone(),
1701                range.clone(),
1702                true,
1703                prompt_builder,
1704                Uuid::new_v4(),
1705                cx,
1706            )
1707        });
1708
1709        let chunks_tx = simulate_response_stream(&codegen, cx);
1710        let new_text = concat!(
1711            "func main() {\n",
1712            "\tx := 0\n",
1713            "\tfor x < 10 {\n",
1714            "\t\tx++\n",
1715            "\t}", //
1716        );
1717        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1718        drop(chunks_tx);
1719        cx.background_executor.run_until_parked();
1720
1721        assert_eq!(
1722            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1723            indoc! {"
1724                func main() {
1725                \tx := 0
1726                \tfor x < 10 {
1727                \t\tx++
1728                \t}
1729                }
1730            "}
1731        );
1732    }
1733
1734    #[gpui::test]
1735    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1736        init_test(cx);
1737
1738        let text = indoc! {"
1739            fn main() {
1740                let x = 0;
1741            }
1742        "};
1743        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1744        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1745        let range = buffer.read_with(cx, |buffer, cx| {
1746            let snapshot = buffer.snapshot(cx);
1747            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1748        });
1749        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1750        let codegen = cx.new(|cx| {
1751            CodegenAlternative::new(
1752                buffer.clone(),
1753                range.clone(),
1754                false,
1755                prompt_builder,
1756                Uuid::new_v4(),
1757                cx,
1758            )
1759        });
1760
1761        let chunks_tx = simulate_response_stream(&codegen, cx);
1762        chunks_tx
1763            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1764            .unwrap();
1765        drop(chunks_tx);
1766        cx.run_until_parked();
1767
1768        // The codegen is inactive, so the buffer doesn't get modified.
1769        assert_eq!(
1770            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1771            text
1772        );
1773
1774        // Activating the codegen applies the changes.
1775        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1776        assert_eq!(
1777            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1778            indoc! {"
1779                fn main() {
1780                    let mut x = 0;
1781                    x += 1;
1782                }
1783            "}
1784        );
1785
1786        // Deactivating the codegen undoes the changes.
1787        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1788        cx.run_until_parked();
1789        assert_eq!(
1790            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1791            text
1792        );
1793    }
1794
1795    #[gpui::test]
1796    async fn test_strip_invalid_spans_from_codeblock() {
1797        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1798        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1799        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1800        assert_chunks(
1801            "```html\n```js\nLorem ipsum dolor\n```\n```",
1802            "```js\nLorem ipsum dolor\n```",
1803        )
1804        .await;
1805        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1806        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1807        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1808        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1809
1810        async fn assert_chunks(text: &str, expected_text: &str) {
1811            for chunk_size in 1..=text.len() {
1812                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1813                    .map(|chunk| chunk.unwrap())
1814                    .collect::<String>()
1815                    .await;
1816                assert_eq!(
1817                    actual_text, expected_text,
1818                    "failed to strip invalid spans, chunk size: {}",
1819                    chunk_size
1820                );
1821            }
1822        }
1823
1824        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1825            stream::iter(
1826                text.chars()
1827                    .collect::<Vec<_>>()
1828                    .chunks(size)
1829                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
1830                    .collect::<Vec<_>>(),
1831            )
1832        }
1833    }
1834
1835    fn init_test(cx: &mut TestAppContext) {
1836        cx.update(LanguageModelRegistry::test);
1837        cx.set_global(cx.update(SettingsStore::test));
1838    }
1839
1840    fn simulate_response_stream(
1841        codegen: &Entity<CodegenAlternative>,
1842        cx: &mut TestAppContext,
1843    ) -> mpsc::UnboundedSender<String> {
1844        let (chunks_tx, chunks_rx) = mpsc::unbounded();
1845        let model = Arc::new(FakeLanguageModel::default());
1846        codegen.update(cx, |codegen, cx| {
1847            codegen.generation = codegen.handle_stream(
1848                model,
1849                future::ready(Ok(LanguageModelTextStream {
1850                    message_id: None,
1851                    stream: chunks_rx.map(Ok).boxed(),
1852                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1853                })),
1854                cx,
1855            );
1856        });
1857        chunks_tx
1858    }
1859}