buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
  10use futures::{
  11    SinkExt, Stream, StreamExt, TryStreamExt as _,
  12    channel::mpsc,
  13    future::{LocalBoxFuture, Shared},
  14    join,
  15    stream::BoxStream,
  16};
  17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  18use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  19use language_model::{
  20    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  23    LanguageModelToolUse, Role, TokenUsage,
  24};
  25use multi_buffer::MultiBufferRow;
  26use parking_lot::Mutex;
  27use prompt_store::PromptBuilder;
  28use rope::Rope;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings as _;
  32use smol::future::FutureExt;
  33use std::{
  34    cmp,
  35    future::Future,
  36    iter,
  37    ops::{Range, RangeInclusive},
  38    pin::Pin,
  39    sync::Arc,
  40    task::{self, Poll},
  41    time::Instant,
  42};
  43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  44
// Input payload of the `failure_message` tool offered in `build_request_tools`.
// NOTE: the `///` comments below are not just documentation. schemars'
// `JsonSchema` derive turns doc comments into schema `description`s, and this
// schema is sent as the tool's `input_schema`, so the wording presumably
// reaches the model. Treat edits to it as prompt changes.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `serde(default)`: a tool call that omits `message` deserializes to "".
    #[serde(default)]
    pub message: String,
}
  54
// Input payload of the `rewrite_section` tool offered in `build_request_tools`.
// NOTE: as with `FailureMessageInput`, the `///` comments feed the
// `JsonSchema` derive and become schema descriptions in the tool's
// `input_schema`, so changing them presumably changes what the model sees.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `serde(default)`: a tool call that omits `replacement_text` deserializes to "".
    #[serde(default)]
    pub replacement_text: String,
}
  64
/// Inline-assist generation over one range of a multibuffer.
///
/// Holds one [`CodegenAlternative`] per model (the primary model plus any
/// configured inline alternative models) and tracks which one is shown.
// Field order is kept as-is: it is also the drop order of the subscriptions
// and entity handles below.
pub struct BufferCodegen {
    /// All alternatives; `start` drives index 0 with the primary model.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently shown.
    pub active_alternative: usize,
    /// Every index passed to `activate` so far (not reset by `start`).
    seen_alternatives: HashSet<usize>,
    /// Observe + event-forwarding subscriptions on the active alternative.
    subscriptions: Vec<Subscription>,
    /// Multibuffer being edited; shared with every alternative.
    buffer: Entity<MultiBuffer>,
    /// Range every alternative rewrites.
    range: Range<Anchor>,
    /// Transaction undone (at most once) by `undo`.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder handed to every alternative.
    builder: Arc<PromptBuilder>,
    /// `true` when `range` was empty (in offsets) at construction time.
    pub is_insertion: bool,
    /// Passed to every alternative and exposed via `session_id()`.
    session_id: Uuid,
}
  77
  78impl BufferCodegen {
  79    pub fn new(
  80        buffer: Entity<MultiBuffer>,
  81        range: Range<Anchor>,
  82        initial_transaction_id: Option<TransactionId>,
  83        session_id: Uuid,
  84        builder: Arc<PromptBuilder>,
  85        cx: &mut Context<Self>,
  86    ) -> Self {
  87        let codegen = cx.new(|cx| {
  88            CodegenAlternative::new(
  89                buffer.clone(),
  90                range.clone(),
  91                false,
  92                builder.clone(),
  93                session_id,
  94                cx,
  95            )
  96        });
  97        let mut this = Self {
  98            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
  99            alternatives: vec![codegen],
 100            active_alternative: 0,
 101            seen_alternatives: HashSet::default(),
 102            subscriptions: Vec::new(),
 103            buffer,
 104            range,
 105            initial_transaction_id,
 106            builder,
 107            session_id,
 108        };
 109        this.activate(0, cx);
 110        this
 111    }
 112
 113    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 114        let codegen = self.active_alternative().clone();
 115        self.subscriptions.clear();
 116        self.subscriptions
 117            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 118        self.subscriptions
 119            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 120    }
 121
 122    pub fn active_completion(&self, cx: &App) -> Option<String> {
 123        self.active_alternative().read(cx).current_completion()
 124    }
 125
 126    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 127        &self.alternatives[self.active_alternative]
 128    }
 129
 130    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 131        self.active_alternative().read(cx).language_name(cx)
 132    }
 133
 134    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 135        &self.active_alternative().read(cx).status
 136    }
 137
 138    pub fn alternative_count(&self, cx: &App) -> usize {
 139        LanguageModelRegistry::read_global(cx)
 140            .inline_alternative_models()
 141            .len()
 142            + 1
 143    }
 144
 145    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 146        let next_active_ix = if self.active_alternative == 0 {
 147            self.alternatives.len() - 1
 148        } else {
 149            self.active_alternative - 1
 150        };
 151        self.activate(next_active_ix, cx);
 152    }
 153
 154    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 155        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 156        self.activate(next_active_ix, cx);
 157    }
 158
 159    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 160        self.active_alternative()
 161            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 162        self.seen_alternatives.insert(index);
 163        self.active_alternative = index;
 164        self.active_alternative()
 165            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 166        self.subscribe_to_alternative(cx);
 167        cx.notify();
 168    }
 169
 170    pub fn start(
 171        &mut self,
 172        primary_model: Arc<dyn LanguageModel>,
 173        user_prompt: String,
 174        context_task: Shared<Task<Option<LoadedContext>>>,
 175        cx: &mut Context<Self>,
 176    ) -> Result<()> {
 177        let alternative_models = LanguageModelRegistry::read_global(cx)
 178            .inline_alternative_models()
 179            .to_vec();
 180
 181        self.active_alternative()
 182            .update(cx, |alternative, cx| alternative.undo(cx));
 183        self.activate(0, cx);
 184        self.alternatives.truncate(1);
 185
 186        for _ in 0..alternative_models.len() {
 187            self.alternatives.push(cx.new(|cx| {
 188                CodegenAlternative::new(
 189                    self.buffer.clone(),
 190                    self.range.clone(),
 191                    false,
 192                    self.builder.clone(),
 193                    self.session_id,
 194                    cx,
 195                )
 196            }));
 197        }
 198
 199        for (model, alternative) in iter::once(primary_model)
 200            .chain(alternative_models)
 201            .zip(&self.alternatives)
 202        {
 203            alternative.update(cx, |alternative, cx| {
 204                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 205            })?;
 206        }
 207
 208        Ok(())
 209    }
 210
 211    pub fn stop(&mut self, cx: &mut Context<Self>) {
 212        for codegen in &self.alternatives {
 213            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 214        }
 215    }
 216
 217    pub fn undo(&mut self, cx: &mut Context<Self>) {
 218        self.active_alternative()
 219            .update(cx, |codegen, cx| codegen.undo(cx));
 220
 221        self.buffer.update(cx, |buffer, cx| {
 222            if let Some(transaction_id) = self.initial_transaction_id.take() {
 223                buffer.undo_transaction(transaction_id, cx);
 224                buffer.refresh_preview(cx);
 225            }
 226        });
 227    }
 228
 229    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 230        self.active_alternative().read(cx).buffer.clone()
 231    }
 232
 233    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 234        self.active_alternative().read(cx).old_buffer.clone()
 235    }
 236
 237    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 238        self.active_alternative().read(cx).snapshot.clone()
 239    }
 240
 241    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 242        self.active_alternative().read(cx).edit_position
 243    }
 244
 245    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 246        &self.active_alternative().read(cx).diff
 247    }
 248
 249    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 250        self.active_alternative().read(cx).last_equal_ranges()
 251    }
 252
 253    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 254        self.active_alternative().read(cx).selected_text()
 255    }
 256
 257    pub fn session_id(&self) -> Uuid {
 258        self.session_id
 259    }
 260}
 261
// `BufferCodegen` re-emits every `CodegenEvent` raised by its active
// alternative (wired up in `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 263
/// One model's generation attempt over a range of the multibuffer.
// Field order is kept as-is: it is also the drop order of the task,
// subscription, and entity handles below.
pub struct CodegenAlternative {
    /// Multibuffer this alternative edits.
    buffer: Entity<MultiBuffer>,
    /// Standalone copy of the underlying buffer, made in `new`.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot; refreshed when `handle_stream` begins.
    snapshot: MultiBufferSnapshot,
    /// Set to the start of `range` (biased right) by `start`.
    edit_position: Option<Anchor>,
    /// Range being rewritten; its anchors are re-resolved in `handle_stream`.
    range: Range<Anchor>,
    // NOTE(review): presumably the unchanged ranges from the latest diff — confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding this alternative's edits. It is undone by `start` and
    /// on deactivation, and cleared when the buffer reports it undone.
    transformation_transaction_id: Option<TransactionId>,
    /// Set to `Pending` when streaming begins; consulted by `set_active`.
    status: CodegenStatus,
    /// In-flight generation task; replaced by `Task::ready(())` on external undo.
    generation: Task<()>,
    /// Reset to default when streaming begins.
    diff: Diff,
    /// Subscription to `buffer` events, handled by `handle_buffer_event`.
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative is currently shown; toggled by `set_active`.
    active: bool,
    /// Edits re-applied via `apply_edits` when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line-based diff ops re-applied on activation while generation is pending.
    line_operations: Vec<LineOperation>,
    elapsed_time: Option<f64>,
    completion: Option<String>,
    /// Text of `range` captured when streaming begins.
    selected_text: Option<String>,
    pub message_id: Option<String>,
    session_id: Uuid,
    /// Model explanation; cleared at the start of each generation.
    pub description: Option<String>,
    pub failure: Option<String>,
}
 288
// Alternatives emit `CodegenEvent`s (e.g. `CodegenEvent::Undone` from
// `handle_buffer_event`); `BufferCodegen` forwards the active one's events.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 290
 291impl CodegenAlternative {
 292    pub fn new(
 293        buffer: Entity<MultiBuffer>,
 294        range: Range<Anchor>,
 295        active: bool,
 296        builder: Arc<PromptBuilder>,
 297        session_id: Uuid,
 298        cx: &mut Context<Self>,
 299    ) -> Self {
 300        let snapshot = buffer.read(cx).snapshot(cx);
 301
 302        let (old_buffer, _, _) = snapshot
 303            .range_to_buffer_ranges(range.clone())
 304            .pop()
 305            .unwrap();
 306        let old_buffer = cx.new(|cx| {
 307            let text = old_buffer.as_rope().clone();
 308            let line_ending = old_buffer.line_ending();
 309            let language = old_buffer.language().cloned();
 310            let language_registry = buffer
 311                .read(cx)
 312                .buffer(old_buffer.remote_id())
 313                .unwrap()
 314                .read(cx)
 315                .language_registry();
 316
 317            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 318            buffer.set_language(language, cx);
 319            if let Some(language_registry) = language_registry {
 320                buffer.set_language_registry(language_registry);
 321            }
 322            buffer
 323        });
 324
 325        Self {
 326            buffer: buffer.clone(),
 327            old_buffer,
 328            edit_position: None,
 329            message_id: None,
 330            snapshot,
 331            last_equal_ranges: Default::default(),
 332            transformation_transaction_id: None,
 333            status: CodegenStatus::Idle,
 334            generation: Task::ready(()),
 335            diff: Diff::default(),
 336            builder,
 337            active: active,
 338            edits: Vec::new(),
 339            line_operations: Vec::new(),
 340            range,
 341            elapsed_time: None,
 342            completion: None,
 343            selected_text: None,
 344            session_id,
 345            description: None,
 346            failure: None,
 347            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 348        }
 349    }
 350
 351    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 352        self.old_buffer
 353            .read(cx)
 354            .language()
 355            .map(|language| language.name())
 356    }
 357
 358    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 359        if active != self.active {
 360            self.active = active;
 361
 362            if self.active {
 363                let edits = self.edits.clone();
 364                self.apply_edits(edits, cx);
 365                if matches!(self.status, CodegenStatus::Pending) {
 366                    let line_operations = self.line_operations.clone();
 367                    self.reapply_line_based_diff(line_operations, cx);
 368                } else {
 369                    self.reapply_batch_diff(cx).detach();
 370                }
 371            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 372                self.buffer.update(cx, |buffer, cx| {
 373                    buffer.undo_transaction(transaction_id, cx);
 374                    buffer.forget_transaction(transaction_id, cx);
 375                });
 376            }
 377        }
 378    }
 379
 380    fn handle_buffer_event(
 381        &mut self,
 382        _buffer: Entity<MultiBuffer>,
 383        event: &multi_buffer::Event,
 384        cx: &mut Context<Self>,
 385    ) {
 386        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 387            && self.transformation_transaction_id == Some(*transaction_id)
 388        {
 389            self.transformation_transaction_id = None;
 390            self.generation = Task::ready(());
 391            cx.emit(CodegenEvent::Undone);
 392        }
 393    }
 394
 395    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 396        &self.last_equal_ranges
 397    }
 398
 399    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 400        model.supports_streaming_tools()
 401            && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
 402            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 403    }
 404
 405    pub fn start(
 406        &mut self,
 407        user_prompt: String,
 408        context_task: Shared<Task<Option<LoadedContext>>>,
 409        model: Arc<dyn LanguageModel>,
 410        cx: &mut Context<Self>,
 411    ) -> Result<()> {
 412        // Clear the model explanation since the user has started a new generation.
 413        self.description = None;
 414
 415        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 416            self.buffer.update(cx, |buffer, cx| {
 417                buffer.undo_transaction(transformation_transaction_id, cx);
 418            });
 419        }
 420
 421        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 422
 423        if Self::use_streaming_tools(model.as_ref(), cx) {
 424            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 425            let completion_events = cx.spawn({
 426                let model = model.clone();
 427                async move |_, cx| model.stream_completion(request.await, cx).await
 428            });
 429            self.generation = self.handle_completion(model, completion_events, cx);
 430        } else {
 431            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 432                if user_prompt.trim().to_lowercase() == "delete" {
 433                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 434                } else {
 435                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 436                    cx.spawn({
 437                        let model = model.clone();
 438                        async move |_, cx| {
 439                            Ok(model.stream_completion_text(request.await, cx).await?)
 440                        }
 441                    })
 442                    .boxed_local()
 443                };
 444            self.generation =
 445                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 446        }
 447
 448        Ok(())
 449    }
 450
 451    fn build_request_tools(
 452        &self,
 453        model: &Arc<dyn LanguageModel>,
 454        user_prompt: String,
 455        context_task: Shared<Task<Option<LoadedContext>>>,
 456        cx: &mut App,
 457    ) -> Result<Task<LanguageModelRequest>> {
 458        let buffer = self.buffer.read(cx).snapshot(cx);
 459        let language = buffer.language_at(self.range.start);
 460        let language_name = if let Some(language) = language.as_ref() {
 461            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 462                None
 463            } else {
 464                Some(language.name())
 465            }
 466        } else {
 467            None
 468        };
 469
 470        let language_name = language_name.as_ref();
 471        let start = buffer.point_to_buffer_offset(self.range.start);
 472        let end = buffer.point_to_buffer_offset(self.range.end);
 473        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 474            let (start_buffer, start_buffer_offset) = start;
 475            let (end_buffer, end_buffer_offset) = end;
 476            if start_buffer.remote_id() == end_buffer.remote_id() {
 477                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 478            } else {
 479                anyhow::bail!("invalid transformation range");
 480            }
 481        } else {
 482            anyhow::bail!("invalid transformation range");
 483        };
 484
 485        let system_prompt = self
 486            .builder
 487            .generate_inline_transformation_prompt_tools(
 488                language_name,
 489                buffer,
 490                range.start.0..range.end.0,
 491            )
 492            .context("generating content prompt")?;
 493
 494        let temperature = AgentSettings::temperature_for_model(model, cx);
 495
 496        let tool_input_format = model.tool_input_format();
 497        let tool_choice = model
 498            .supports_tool_choice(LanguageModelToolChoice::Any)
 499            .then_some(LanguageModelToolChoice::Any);
 500
 501        Ok(cx.spawn(async move |_cx| {
 502            let mut messages = vec![LanguageModelRequestMessage {
 503                role: Role::System,
 504                content: vec![system_prompt.into()],
 505                cache: false,
 506                reasoning_details: None,
 507            }];
 508
 509            let mut user_message = LanguageModelRequestMessage {
 510                role: Role::User,
 511                content: Vec::new(),
 512                cache: false,
 513                reasoning_details: None,
 514            };
 515
 516            if let Some(context) = context_task.await {
 517                context.add_to_request_message(&mut user_message);
 518            }
 519
 520            user_message.content.push(user_prompt.into());
 521            messages.push(user_message);
 522
 523            let tools = vec![
 524                LanguageModelRequestTool {
 525                    name: "rewrite_section".to_string(),
 526                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 527                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 528                },
 529                LanguageModelRequestTool {
 530                    name: "failure_message".to_string(),
 531                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 532                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 533                },
 534            ];
 535
 536            LanguageModelRequest {
 537                thread_id: None,
 538                prompt_id: None,
 539                intent: Some(CompletionIntent::InlineAssist),
 540                mode: None,
 541                tools,
 542                tool_choice,
 543                stop: Vec::new(),
 544                temperature,
 545                messages,
 546                thinking_allowed: false,
 547            }
 548        }))
 549    }
 550
 551    fn build_request(
 552        &self,
 553        model: &Arc<dyn LanguageModel>,
 554        user_prompt: String,
 555        context_task: Shared<Task<Option<LoadedContext>>>,
 556        cx: &mut App,
 557    ) -> Result<Task<LanguageModelRequest>> {
 558        if Self::use_streaming_tools(model.as_ref(), cx) {
 559            return self.build_request_tools(model, user_prompt, context_task, cx);
 560        }
 561
 562        let buffer = self.buffer.read(cx).snapshot(cx);
 563        let language = buffer.language_at(self.range.start);
 564        let language_name = if let Some(language) = language.as_ref() {
 565            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 566                None
 567            } else {
 568                Some(language.name())
 569            }
 570        } else {
 571            None
 572        };
 573
 574        let language_name = language_name.as_ref();
 575        let start = buffer.point_to_buffer_offset(self.range.start);
 576        let end = buffer.point_to_buffer_offset(self.range.end);
 577        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 578            let (start_buffer, start_buffer_offset) = start;
 579            let (end_buffer, end_buffer_offset) = end;
 580            if start_buffer.remote_id() == end_buffer.remote_id() {
 581                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 582            } else {
 583                anyhow::bail!("invalid transformation range");
 584            }
 585        } else {
 586            anyhow::bail!("invalid transformation range");
 587        };
 588
 589        let prompt = self
 590            .builder
 591            .generate_inline_transformation_prompt(
 592                user_prompt,
 593                language_name,
 594                buffer,
 595                range.start.0..range.end.0,
 596            )
 597            .context("generating content prompt")?;
 598
 599        let temperature = AgentSettings::temperature_for_model(model, cx);
 600
 601        Ok(cx.spawn(async move |_cx| {
 602            let mut request_message = LanguageModelRequestMessage {
 603                role: Role::User,
 604                content: Vec::new(),
 605                cache: false,
 606                reasoning_details: None,
 607            };
 608
 609            if let Some(context) = context_task.await {
 610                context.add_to_request_message(&mut request_message);
 611            }
 612
 613            request_message.content.push(prompt.into());
 614
 615            LanguageModelRequest {
 616                thread_id: None,
 617                prompt_id: None,
 618                intent: Some(CompletionIntent::InlineAssist),
 619                mode: None,
 620                tools: Vec::new(),
 621                tool_choice: None,
 622                stop: Vec::new(),
 623                temperature,
 624                messages: vec![request_message],
 625                thinking_allowed: false,
 626            }
 627        }))
 628    }
 629
 630    pub fn handle_stream(
 631        &mut self,
 632        model: Arc<dyn LanguageModel>,
 633        strip_invalid_spans: bool,
 634        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 635        cx: &mut Context<Self>,
 636    ) -> Task<()> {
 637        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 638        let session_id = self.session_id;
 639        let model_telemetry_id = model.telemetry_id();
 640        let model_provider_id = model.provider_id().to_string();
 641        let start_time = Instant::now();
 642
 643        // Make a new snapshot and re-resolve anchor in case the document was modified.
 644        // This can happen often if the editor loses focus and is saved + reformatted,
 645        // as in https://github.com/zed-industries/zed/issues/39088
 646        self.snapshot = self.buffer.read(cx).snapshot(cx);
 647        self.range = self.snapshot.anchor_after(self.range.start)
 648            ..self.snapshot.anchor_after(self.range.end);
 649
 650        let snapshot = self.snapshot.clone();
 651        let selected_text = snapshot
 652            .text_for_range(self.range.start..self.range.end)
 653            .collect::<Rope>();
 654
 655        self.selected_text = Some(selected_text.to_string());
 656
 657        let selection_start = self.range.start.to_point(&snapshot);
 658
 659        // Start with the indentation of the first line in the selection
 660        let mut suggested_line_indent = snapshot
 661            .suggested_indents(selection_start.row..=selection_start.row, cx)
 662            .into_values()
 663            .next()
 664            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 665
 666        // If the first line in the selection does not have indentation, check the following lines
 667        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 668            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 669                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 670                // Prefer tabs if a line in the selection uses tabs as indentation
 671                if line_indent.kind == IndentKind::Tab {
 672                    suggested_line_indent.kind = IndentKind::Tab;
 673                    break;
 674                }
 675            }
 676        }
 677
 678        let language_name = {
 679            let multibuffer = self.buffer.read(cx);
 680            let snapshot = multibuffer.snapshot(cx);
 681            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 682            ranges
 683                .first()
 684                .and_then(|(buffer, _, _)| buffer.language())
 685                .map(|language| language.name())
 686        };
 687
 688        self.diff = Diff::default();
 689        self.status = CodegenStatus::Pending;
 690        let mut edit_start = self.range.start.to_offset(&snapshot);
 691        let completion = Arc::new(Mutex::new(String::new()));
 692        let completion_clone = completion.clone();
 693
 694        cx.notify();
 695        cx.spawn(async move |codegen, cx| {
 696            let stream = stream.await;
 697
 698            let token_usage = stream
 699                .as_ref()
 700                .ok()
 701                .map(|stream| stream.last_token_usage.clone());
 702            let message_id = stream
 703                .as_ref()
 704                .ok()
 705                .and_then(|stream| stream.message_id.clone());
 706            let generate = async {
 707                let model_telemetry_id = model_telemetry_id.clone();
 708                let model_provider_id = model_provider_id.clone();
 709                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 710                let message_id = message_id.clone();
 711                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 712                    let anthropic_reporter = anthropic_reporter.clone();
 713                    let language_name = language_name.clone();
 714                    async move {
 715                        let mut response_latency = None;
 716                        let request_start = Instant::now();
 717                        let diff = async {
 718                            let raw_stream = stream?.stream.map_err(|error| error.into());
 719
 720                            let stripped;
 721                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 722                                if strip_invalid_spans {
 723                                    stripped = StripInvalidSpans::new(raw_stream);
 724                                    Box::pin(stripped)
 725                                } else {
 726                                    Box::pin(raw_stream)
 727                                };
 728
 729                            let mut diff = StreamingDiff::new(selected_text.to_string());
 730                            let mut line_diff = LineDiff::default();
 731
 732                            let mut new_text = String::new();
 733                            let mut base_indent = None;
 734                            let mut line_indent = None;
 735                            let mut first_line = true;
 736
 737                            while let Some(chunk) = chunks.next().await {
 738                                if response_latency.is_none() {
 739                                    response_latency = Some(request_start.elapsed());
 740                                }
 741                                let chunk = chunk?;
 742                                completion_clone.lock().push_str(&chunk);
 743
 744                                let mut lines = chunk.split('\n').peekable();
 745                                while let Some(line) = lines.next() {
 746                                    new_text.push_str(line);
 747                                    if line_indent.is_none()
 748                                        && let Some(non_whitespace_ch_ix) =
 749                                            new_text.find(|ch: char| !ch.is_whitespace())
 750                                    {
 751                                        line_indent = Some(non_whitespace_ch_ix);
 752                                        base_indent = base_indent.or(line_indent);
 753
 754                                        let line_indent = line_indent.unwrap();
 755                                        let base_indent = base_indent.unwrap();
 756                                        let indent_delta = line_indent as i32 - base_indent as i32;
 757                                        let mut corrected_indent_len = cmp::max(
 758                                            0,
 759                                            suggested_line_indent.len as i32 + indent_delta,
 760                                        )
 761                                            as usize;
 762                                        if first_line {
 763                                            corrected_indent_len = corrected_indent_len
 764                                                .saturating_sub(selection_start.column as usize);
 765                                        }
 766
 767                                        let indent_char = suggested_line_indent.char();
 768                                        let mut indent_buffer = [0; 4];
 769                                        let indent_str =
 770                                            indent_char.encode_utf8(&mut indent_buffer);
 771                                        new_text.replace_range(
 772                                            ..line_indent,
 773                                            &indent_str.repeat(corrected_indent_len),
 774                                        );
 775                                    }
 776
 777                                    if line_indent.is_some() {
 778                                        let char_ops = diff.push_new(&new_text);
 779                                        line_diff.push_char_operations(&char_ops, &selected_text);
 780                                        diff_tx
 781                                            .send((char_ops, line_diff.line_operations()))
 782                                            .await?;
 783                                        new_text.clear();
 784                                    }
 785
 786                                    if lines.peek().is_some() {
 787                                        let char_ops = diff.push_new("\n");
 788                                        line_diff.push_char_operations(&char_ops, &selected_text);
 789                                        diff_tx
 790                                            .send((char_ops, line_diff.line_operations()))
 791                                            .await?;
 792                                        if line_indent.is_none() {
 793                                            // Don't write out the leading indentation in empty lines on the next line
 794                                            // This is the case where the above if statement didn't clear the buffer
 795                                            new_text.clear();
 796                                        }
 797                                        line_indent = None;
 798                                        first_line = false;
 799                                    }
 800                                }
 801                            }
 802
 803                            let mut char_ops = diff.push_new(&new_text);
 804                            char_ops.extend(diff.finish());
 805                            line_diff.push_char_operations(&char_ops, &selected_text);
 806                            line_diff.finish(&selected_text);
 807                            diff_tx
 808                                .send((char_ops, line_diff.line_operations()))
 809                                .await?;
 810
 811                            anyhow::Ok(())
 812                        };
 813
 814                        let result = diff.await;
 815
 816                        let error_message = result.as_ref().err().map(|error| error.to_string());
 817                        telemetry::event!(
 818                            "Assistant Responded",
 819                            kind = "inline",
 820                            phase = "response",
 821                            session_id = session_id.to_string(),
 822                            model = model_telemetry_id,
 823                            model_provider = model_provider_id,
 824                            language_name = language_name.as_ref().map(|n| n.to_string()),
 825                            message_id = message_id.as_deref(),
 826                            response_latency = response_latency,
 827                            error_message = error_message.as_deref(),
 828                        );
 829
 830                        anthropic_reporter.report(language_model::AnthropicEventData {
 831                            completion_type: language_model::AnthropicCompletionType::Editor,
 832                            event: language_model::AnthropicEventType::Response,
 833                            language_name: language_name.map(|n| n.to_string()),
 834                            message_id,
 835                        });
 836
 837                        result?;
 838                        Ok(())
 839                    }
 840                });
 841
 842                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 843                    codegen.update(cx, |codegen, cx| {
 844                        codegen.last_equal_ranges.clear();
 845
 846                        let edits = char_ops
 847                            .into_iter()
 848                            .filter_map(|operation| match operation {
 849                                CharOperation::Insert { text } => {
 850                                    let edit_start = snapshot.anchor_after(edit_start);
 851                                    Some((edit_start..edit_start, text))
 852                                }
 853                                CharOperation::Delete { bytes } => {
 854                                    let edit_end = edit_start + bytes;
 855                                    let edit_range = snapshot.anchor_after(edit_start)
 856                                        ..snapshot.anchor_before(edit_end);
 857                                    edit_start = edit_end;
 858                                    Some((edit_range, String::new()))
 859                                }
 860                                CharOperation::Keep { bytes } => {
 861                                    let edit_end = edit_start + bytes;
 862                                    let edit_range = snapshot.anchor_after(edit_start)
 863                                        ..snapshot.anchor_before(edit_end);
 864                                    edit_start = edit_end;
 865                                    codegen.last_equal_ranges.push(edit_range);
 866                                    None
 867                                }
 868                            })
 869                            .collect::<Vec<_>>();
 870
 871                        if codegen.active {
 872                            codegen.apply_edits(edits.iter().cloned(), cx);
 873                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 874                        }
 875                        codegen.edits.extend(edits);
 876                        codegen.line_operations = line_ops;
 877                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 878
 879                        cx.notify();
 880                    })?;
 881                }
 882
 883                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 884                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 885                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 886                let batch_diff_task =
 887                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 888                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 889                line_based_stream_diff?;
 890
 891                anyhow::Ok(())
 892            };
 893
 894            let result = generate.await;
 895            let elapsed_time = start_time.elapsed().as_secs_f64();
 896
 897            codegen
 898                .update(cx, |this, cx| {
 899                    this.message_id = message_id;
 900                    this.last_equal_ranges.clear();
 901                    if let Err(error) = result {
 902                        this.status = CodegenStatus::Error(error);
 903                    } else {
 904                        this.status = CodegenStatus::Done;
 905                    }
 906                    this.elapsed_time = Some(elapsed_time);
 907                    this.completion = Some(completion.lock().clone());
 908                    if let Some(usage) = token_usage {
 909                        let usage = usage.lock();
 910                        telemetry::event!(
 911                            "Inline Assistant Completion",
 912                            model = model_telemetry_id,
 913                            model_provider = model_provider_id,
 914                            input_tokens = usage.input_tokens,
 915                            output_tokens = usage.output_tokens,
 916                        )
 917                    }
 918
 919                    cx.emit(CodegenEvent::Finished);
 920                    cx.notify();
 921                })
 922                .ok();
 923        })
 924    }
 925
 926    pub fn current_completion(&self) -> Option<String> {
 927        self.completion.clone()
 928    }
 929
 930    #[cfg(any(test, feature = "test-support"))]
 931    pub fn current_description(&self) -> Option<String> {
 932        self.description.clone()
 933    }
 934
 935    #[cfg(any(test, feature = "test-support"))]
 936    pub fn current_failure(&self) -> Option<String> {
 937        self.failure.clone()
 938    }
 939
 940    pub fn selected_text(&self) -> Option<&str> {
 941        self.selected_text.as_deref()
 942    }
 943
 944    pub fn stop(&mut self, cx: &mut Context<Self>) {
 945        self.last_equal_ranges.clear();
 946        if self.diff.is_empty() {
 947            self.status = CodegenStatus::Idle;
 948        } else {
 949            self.status = CodegenStatus::Done;
 950        }
 951        self.generation = Task::ready(());
 952        cx.emit(CodegenEvent::Finished);
 953        cx.notify();
 954    }
 955
 956    pub fn undo(&mut self, cx: &mut Context<Self>) {
 957        self.buffer.update(cx, |buffer, cx| {
 958            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 959                buffer.undo_transaction(transaction_id, cx);
 960                buffer.refresh_preview(cx);
 961            }
 962        });
 963    }
 964
 965    fn apply_edits(
 966        &mut self,
 967        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 968        cx: &mut Context<CodegenAlternative>,
 969    ) {
 970        let transaction = self.buffer.update(cx, |buffer, cx| {
 971            // Avoid grouping agent edits with user edits.
 972            buffer.finalize_last_transaction(cx);
 973            buffer.start_transaction(cx);
 974            buffer.edit(edits, None, cx);
 975            buffer.end_transaction(cx)
 976        });
 977
 978        if let Some(transaction) = transaction {
 979            if let Some(first_transaction) = self.transformation_transaction_id {
 980                // Group all agent edits into the first transaction.
 981                self.buffer.update(cx, |buffer, cx| {
 982                    buffer.merge_transactions(transaction, first_transaction, cx)
 983                });
 984            } else {
 985                self.transformation_transaction_id = Some(transaction);
 986                self.buffer
 987                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 988            }
 989        }
 990    }
 991
 992    fn reapply_line_based_diff(
 993        &mut self,
 994        line_operations: impl IntoIterator<Item = LineOperation>,
 995        cx: &mut Context<Self>,
 996    ) {
 997        let old_snapshot = self.snapshot.clone();
 998        let old_range = self.range.to_point(&old_snapshot);
 999        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1000        let new_range = self.range.to_point(&new_snapshot);
1001
1002        let mut old_row = old_range.start.row;
1003        let mut new_row = new_range.start.row;
1004
1005        self.diff.deleted_row_ranges.clear();
1006        self.diff.inserted_row_ranges.clear();
1007        for operation in line_operations {
1008            match operation {
1009                LineOperation::Keep { lines } => {
1010                    old_row += lines;
1011                    new_row += lines;
1012                }
1013                LineOperation::Delete { lines } => {
1014                    let old_end_row = old_row + lines - 1;
1015                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1016
1017                    if let Some((_, last_deleted_row_range)) =
1018                        self.diff.deleted_row_ranges.last_mut()
1019                    {
1020                        if *last_deleted_row_range.end() + 1 == old_row {
1021                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1022                        } else {
1023                            self.diff
1024                                .deleted_row_ranges
1025                                .push((new_row, old_row..=old_end_row));
1026                        }
1027                    } else {
1028                        self.diff
1029                            .deleted_row_ranges
1030                            .push((new_row, old_row..=old_end_row));
1031                    }
1032
1033                    old_row += lines;
1034                }
1035                LineOperation::Insert { lines } => {
1036                    let new_end_row = new_row + lines - 1;
1037                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1038                    let end = new_snapshot.anchor_before(Point::new(
1039                        new_end_row,
1040                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1041                    ));
1042                    self.diff.inserted_row_ranges.push(start..end);
1043                    new_row += lines;
1044                }
1045            }
1046
1047            cx.notify();
1048        }
1049    }
1050
1051    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1052        let old_snapshot = self.snapshot.clone();
1053        let old_range = self.range.to_point(&old_snapshot);
1054        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1055        let new_range = self.range.to_point(&new_snapshot);
1056
1057        cx.spawn(async move |codegen, cx| {
1058            let (deleted_row_ranges, inserted_row_ranges) = cx
1059                .background_spawn(async move {
1060                    let old_text = old_snapshot
1061                        .text_for_range(
1062                            Point::new(old_range.start.row, 0)
1063                                ..Point::new(
1064                                    old_range.end.row,
1065                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1066                                ),
1067                        )
1068                        .collect::<String>();
1069                    let new_text = new_snapshot
1070                        .text_for_range(
1071                            Point::new(new_range.start.row, 0)
1072                                ..Point::new(
1073                                    new_range.end.row,
1074                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1075                                ),
1076                        )
1077                        .collect::<String>();
1078
1079                    let old_start_row = old_range.start.row;
1080                    let new_start_row = new_range.start.row;
1081                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1082                    let mut inserted_row_ranges = Vec::new();
1083                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1084                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1085                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1086                        if !old_rows.is_empty() {
1087                            deleted_row_ranges.push((
1088                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1089                                old_rows.start..=old_rows.end - 1,
1090                            ));
1091                        }
1092                        if !new_rows.is_empty() {
1093                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1094                            let new_end_row = new_rows.end - 1;
1095                            let end = new_snapshot.anchor_before(Point::new(
1096                                new_end_row,
1097                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1098                            ));
1099                            inserted_row_ranges.push(start..end);
1100                        }
1101                    }
1102                    (deleted_row_ranges, inserted_row_ranges)
1103                })
1104                .await;
1105
1106            codegen
1107                .update(cx, |codegen, cx| {
1108                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1109                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1110                    cx.notify();
1111                })
1112                .ok();
1113        })
1114    }
1115
1116    fn handle_completion(
1117        &mut self,
1118        model: Arc<dyn LanguageModel>,
1119        completion_stream: Task<
1120            Result<
1121                BoxStream<
1122                    'static,
1123                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1124                >,
1125                LanguageModelCompletionError,
1126            >,
1127        >,
1128        cx: &mut Context<Self>,
1129    ) -> Task<()> {
1130        self.diff = Diff::default();
1131        self.status = CodegenStatus::Pending;
1132
1133        cx.notify();
1134        // Leaving this in generation so that STOP equivalent events are respected even
1135        // while we're still pre-processing the completion event
1136        cx.spawn(async move |codegen, cx| {
1137            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1138                let _ = codegen.update(cx, |this, cx| {
1139                    this.status = status;
1140                    cx.emit(CodegenEvent::Finished);
1141                    cx.notify();
1142                });
1143            };
1144
1145            let mut completion_events = match completion_stream.await {
1146                Ok(events) => events,
1147                Err(err) => {
1148                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1149                    return;
1150                }
1151            };
1152
1153            enum ToolUseOutput {
1154                Rewrite {
1155                    text: String,
1156                    description: Option<String>,
1157                },
1158                Failure(String),
1159            }
1160
1161            enum ModelUpdate {
1162                Description(String),
1163                Failure(String),
1164            }
1165
1166            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1167            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1168                let mut chars_read_so_far = chars_read_so_far.lock();
1169                match tool_use.name.as_ref() {
1170                    "rewrite_section" => {
1171                        let Ok(input) =
1172                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1173                        else {
1174                            return None;
1175                        };
1176                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1177                        *chars_read_so_far = input.replacement_text.len();
1178                        Some(ToolUseOutput::Rewrite {
1179                            text,
1180                            description: None,
1181                        })
1182                    }
1183                    "failure_message" => {
1184                        let Ok(mut input) =
1185                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1186                        else {
1187                            return None;
1188                        };
1189                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1190                    }
1191                    _ => None,
1192                }
1193            };
1194
1195            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1196
1197            cx.spawn({
1198                let codegen = codegen.clone();
1199                async move |cx| {
1200                    while let Some(update) = message_rx.next().await {
1201                        let _ = codegen.update(cx, |this, _cx| match update {
1202                            ModelUpdate::Description(d) => this.description = Some(d),
1203                            ModelUpdate::Failure(f) => this.failure = Some(f),
1204                        });
1205                    }
1206                }
1207            })
1208            .detach();
1209
1210            let mut message_id = None;
1211            let mut first_text = None;
1212            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1213            let total_text = Arc::new(Mutex::new(String::new()));
1214
1215            loop {
1216                if let Some(first_event) = completion_events.next().await {
1217                    match first_event {
1218                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1219                            message_id = Some(id);
1220                        }
1221                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1222                            if let Some(output) = process_tool_use(tool_use) {
1223                                let (text, update) = match output {
1224                                    ToolUseOutput::Rewrite { text, description } => {
1225                                        (Some(text), description.map(ModelUpdate::Description))
1226                                    }
1227                                    ToolUseOutput::Failure(message) => {
1228                                        (None, Some(ModelUpdate::Failure(message)))
1229                                    }
1230                                };
1231                                if let Some(update) = update {
1232                                    let _ = message_tx.unbounded_send(update);
1233                                }
1234                                first_text = text;
1235                                if first_text.is_some() {
1236                                    break;
1237                                }
1238                            }
1239                        }
1240                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1241                            *last_token_usage.lock() = token_usage;
1242                        }
1243                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1244                            let mut lock = total_text.lock();
1245                            lock.push_str(&text);
1246                        }
1247                        Ok(e) => {
1248                            log::warn!("Unexpected event: {:?}", e);
1249                            break;
1250                        }
1251                        Err(e) => {
1252                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1253                            break;
1254                        }
1255                    }
1256                }
1257            }
1258
1259            let Some(first_text) = first_text else {
1260                finish_with_status(CodegenStatus::Done, cx);
1261                return;
1262            };
1263
1264            let move_last_token_usage = last_token_usage.clone();
1265
1266            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1267                completion_events.filter_map(move |e| {
1268                    let process_tool_use = process_tool_use.clone();
1269                    let last_token_usage = move_last_token_usage.clone();
1270                    let total_text = total_text.clone();
1271                    let mut message_tx = message_tx.clone();
1272                    async move {
1273                        match e {
1274                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1275                                let Some(output) = process_tool_use(tool_use) else {
1276                                    return None;
1277                                };
1278                                let (text, update) = match output {
1279                                    ToolUseOutput::Rewrite { text, description } => {
1280                                        (Some(text), description.map(ModelUpdate::Description))
1281                                    }
1282                                    ToolUseOutput::Failure(message) => {
1283                                        (None, Some(ModelUpdate::Failure(message)))
1284                                    }
1285                                };
1286                                if let Some(update) = update {
1287                                    let _ = message_tx.send(update).await;
1288                                }
1289                                text.map(Ok)
1290                            }
1291                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1292                                *last_token_usage.lock() = token_usage;
1293                                None
1294                            }
1295                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1296                                let mut lock = total_text.lock();
1297                                lock.push_str(&text);
1298                                None
1299                            }
1300                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1301                            e => {
1302                                log::error!("UNEXPECTED EVENT {:?}", e);
1303                                None
1304                            }
1305                        }
1306                    }
1307                }),
1308            ));
1309
1310            let language_model_text_stream = LanguageModelTextStream {
1311                message_id: message_id,
1312                stream: text_stream,
1313                last_token_usage,
1314            };
1315
1316            let Some(task) = codegen
1317                .update(cx, move |codegen, cx| {
1318                    codegen.handle_stream(
1319                        model,
1320                        /* strip_invalid_spans: */ false,
1321                        async { Ok(language_model_text_stream) },
1322                        cx,
1323                    )
1324                })
1325                .ok()
1326            else {
1327                return;
1328            };
1329
1330            task.await;
1331        })
1332    }
1333}
1334
/// Notifications emitted by a codegen entity to its subscribers.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation has ended (completed, failed, or was stopped).
    Finished,
    /// A previously applied generation was reverted.
    Undone,
}
1340
/// Stream adapter that removes a wrapping Markdown code fence and cursor markers from
/// model output before it reaches the buffer.
struct StripInvalidSpans<T> {
    /// The wrapped text stream.
    stream: T,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// Text received but not yet emitted.
    buffer: String,
    /// True while no newline has been seen yet.
    first_line: bool,
    /// True when an emitted line still owes its trailing newline.
    line_end: bool,
    /// True when the output opened with a code fence that was dropped.
    starts_with_code_block: bool,
}
1349
1350impl<T> StripInvalidSpans<T>
1351where
1352    T: Stream<Item = Result<String>>,
1353{
1354    fn new(stream: T) -> Self {
1355        Self {
1356            stream,
1357            stream_done: false,
1358            buffer: String::new(),
1359            first_line: true,
1360            line_end: false,
1361            starts_with_code_block: false,
1362        }
1363    }
1364}
1365
1366impl<T> Stream for StripInvalidSpans<T>
1367where
1368    T: Stream<Item = Result<String>>,
1369{
1370    type Item = Result<String>;
1371
1372    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1373        const CODE_BLOCK_DELIMITER: &str = "```";
1374        const CURSOR_SPAN: &str = "<|CURSOR|>";
1375
1376        let this = unsafe { self.get_unchecked_mut() };
1377        loop {
1378            if !this.stream_done {
1379                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1380                match stream.as_mut().poll_next(cx) {
1381                    Poll::Ready(Some(Ok(chunk))) => {
1382                        this.buffer.push_str(&chunk);
1383                    }
1384                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1385                    Poll::Ready(None) => {
1386                        this.stream_done = true;
1387                    }
1388                    Poll::Pending => return Poll::Pending,
1389                }
1390            }
1391
1392            let mut chunk = String::new();
1393            let mut consumed = 0;
1394            if !this.buffer.is_empty() {
1395                let mut lines = this.buffer.split('\n').enumerate().peekable();
1396                while let Some((line_ix, line)) = lines.next() {
1397                    if line_ix > 0 {
1398                        this.first_line = false;
1399                    }
1400
1401                    if this.first_line {
1402                        let trimmed_line = line.trim();
1403                        if lines.peek().is_some() {
1404                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1405                                consumed += line.len() + 1;
1406                                this.starts_with_code_block = true;
1407                                continue;
1408                            }
1409                        } else if trimmed_line.is_empty()
1410                            || prefixes(CODE_BLOCK_DELIMITER)
1411                                .any(|prefix| trimmed_line.starts_with(prefix))
1412                        {
1413                            break;
1414                        }
1415                    }
1416
1417                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1418                    if lines.peek().is_some() {
1419                        if this.line_end {
1420                            chunk.push('\n');
1421                        }
1422
1423                        chunk.push_str(&line_without_cursor);
1424                        this.line_end = true;
1425                        consumed += line.len() + 1;
1426                    } else if this.stream_done {
1427                        if !this.starts_with_code_block
1428                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1429                        {
1430                            if this.line_end {
1431                                chunk.push('\n');
1432                            }
1433
1434                            chunk.push_str(line);
1435                        }
1436
1437                        consumed += line.len();
1438                    } else {
1439                        let trimmed_line = line.trim();
1440                        if trimmed_line.is_empty()
1441                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1442                            || prefixes(CODE_BLOCK_DELIMITER)
1443                                .any(|prefix| trimmed_line.ends_with(prefix))
1444                        {
1445                            break;
1446                        } else {
1447                            if this.line_end {
1448                                chunk.push('\n');
1449                                this.line_end = false;
1450                            }
1451
1452                            chunk.push_str(&line_without_cursor);
1453                            consumed += line.len();
1454                        }
1455                    }
1456                }
1457            }
1458
1459            this.buffer = this.buffer.split_off(consumed);
1460            if !chunk.is_empty() {
1461                return Poll::Ready(Some(Ok(chunk)));
1462            } else if this.stream_done {
1463                return Poll::Ready(None);
1464            }
1465        }
1466    }
1467}
1468
/// Returns every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string itself is never yielded, so `"```"` produces `"`"` and `"``"`.
/// Callers in this file match these prefixes against the edge of a partially received
/// line to decide whether a code-block delimiter or cursor marker might still be arriving.
///
/// Prefixes always end on a `char` boundary, so multi-byte text never causes a slicing
/// panic. An empty (or single-character) `text` yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Every char start after the first is the exclusive end of a non-empty proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1472
/// Row ranges that were deleted and inserted, tracked with buffer [`Anchor`]s.
///
/// `Default` produces a diff with no deleted and no inserted ranges.
#[derive(Default)]
pub struct Diff {
    /// Deleted rows, each recorded as an [`Anchor`] paired with an inclusive `u32` row range.
    ///
    /// NOTE(review): presumably the anchor marks where the deleted rows sat in the current
    /// buffer and the row range indexes the old text — confirm against the code that fills this.
    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted rows, recorded as anchor ranges.
    pub inserted_row_ranges: Vec<Range<Anchor>>,
}
1478
1479impl Diff {
1480    fn is_empty(&self) -> bool {
1481        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1482    }
1483}
1484
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Stream, stream};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{LanguageModelRegistry, TokenUsage};
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Rewrites an indented selection with model output whose indentation does not match
    /// the buffer, and checks the result is re-indented to fit the surrounding block.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let codegen = new_codegen(&buffer, Point::new(1, 0)..Point::new(4, 5), true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);
        send_in_random_chunks(
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at a cursor placed after a partial token (`le`) on an indented line. The
    /// continuation and the following lines must land at the enclosing block's indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let codegen = new_codegen(&buffer, Point::new(1, 6)..Point::new(1, 6), true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        send_in_random_chunks(
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at the end of a whitespace-only line that is indented less than the block
    /// requires (2 spaces). Every generated line must end up at the block's 4-space indent.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let codegen = new_codegen(&buffer, Point::new(1, 2)..Point::new(1, 2), true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        send_in_random_chunks(
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Replaces a tab-indented selection (no language attached) and checks that the
    /// generated text keeps tab indentation rather than being converted to spaces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let codegen = new_codegen(&buffer, Point::new(0, 0)..Point::new(4, 2), true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative buffers generated edits without touching the buffer.
    /// Activating it applies them, and deactivating it undoes them.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let codegen = new_codegen(&buffer, Point::new(1, 0)..Point::new(1, 14), false, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(buffer_text(&buffer, cx), text);

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(buffer_text(&buffer, cx), text);
    }

    /// Feeds each input through `StripInvalidSpans` at every chunk size. Checks that the
    /// leading/trailing code fences and cursor markers are stripped the same way regardless
    /// of how the stream is split.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test language-model registry and settings store globals.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Builds a [`CodegenAlternative`] over `range` of `buffer`.
    ///
    /// The range start is anchored with `anchor_before` and the end with `anchor_after`.
    /// `active` is forwarded as the constructor's activity flag; passing `false` creates an
    /// alternative whose edits are not applied to the buffer until it is activated.
    fn new_codegen(
        buffer: &Entity<MultiBuffer>,
        range: Range<Point>,
        active: bool,
        cx: &mut TestAppContext,
    ) -> Entity<CodegenAlternative> {
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range,
                active,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        })
    }

    /// Sends `text` through `chunks_tx` in randomly sized pieces of 1 to 10 bytes.
    ///
    /// The background executor runs until parked after each piece, so the codegen observes
    /// every intermediate state. `text` is expected to be ASCII because it is split on
    /// byte offsets.
    fn send_in_random_chunks(
        mut text: &str,
        chunks_tx: &mpsc::UnboundedSender<String>,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
    }

    /// Returns the full current text of `buffer`.
    fn buffer_text(buffer: &Entity<MultiBuffer>, cx: &mut TestAppContext) -> String {
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text())
    }

    /// Replaces `codegen`'s generation with one driven by a fake model. The model's text
    /// stream is fed from the returned sender, and dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                /* strip_invalid_spans: */ false,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }
}
1872    }
1873}