buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4
   5use client::telemetry::Telemetry;
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
  10use futures::{
  11    SinkExt, Stream, StreamExt, TryStreamExt as _,
  12    channel::mpsc,
  13    future::{LocalBoxFuture, Shared},
  14    join,
  15    stream::BoxStream,
  16};
  17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  18use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
  19use language_model::{
  20    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  23    LanguageModelToolUse, Role, TokenUsage, report_assistant_event,
  24};
  25use multi_buffer::MultiBufferRow;
  26use parking_lot::Mutex;
  27use prompt_store::PromptBuilder;
  28use rope::Rope;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings as _;
  32use smol::future::FutureExt;
  33use std::{
  34    cmp,
  35    future::Future,
  36    iter,
  37    ops::{Range, RangeInclusive},
  38    pin::Pin,
  39    sync::Arc,
  40    task::{self, Poll},
  41    time::Instant,
  42};
  43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  44use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
  45use ui::SharedString;
  46
// NOTE: this struct is the input type of the `failure_message` tool offered in
// `build_request_tools`. schemars turns the `///` doc comments below into the
// schema descriptions that are sent to the model, so treat them as prompt text.
// Plain `//` comments such as this one are not part of the generated schema.
/// Use this tool to provide a message to the user when you're unable to complete a task.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    // `#[serde(default)]`: if the model omits the field it deserializes as "".
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    ///
    /// The message may use markdown formatting if you wish.
    #[serde(default)]
    pub message: String,
}
  56
// NOTE: this struct is the input type of the `rewrite_section` tool offered in
// `build_request_tools`. schemars turns the `///` doc comments below into the
// schema descriptions that are sent to the model, so treat them as prompt text.
// Plain `//` comments such as this one are not part of the generated schema.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    // `#[serde(default)]`: if the model omits the field it deserializes as "".
    /// A brief description of the edit you have made.
    ///
    /// The description may use markdown formatting if you wish.
    /// This is optional - if the edit is simple or obvious, you should leave it empty.
    #[serde(default)]
    pub description: String,

    // `#[serde(default)]`: if the model omits the field it deserializes as "".
    /// The text to replace the section with.
    #[serde(default)]
    pub replacement_text: String,
}
  71
/// Runs an inline transformation of `range` in `buffer`, owning one
/// [`CodegenAlternative`] per model and keeping exactly one of them active
/// (see `activate`, which deactivates the previous alternative and activates
/// the new one).
pub struct BufferCodegen {
    /// One alternative per model. Index 0 is driven by the primary model passed
    /// to `start`; the rest follow the registry's inline alternative models in order.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative that is currently active.
    pub active_alternative: usize,
    /// Every index that has been passed to `activate`.
    seen_alternatives: HashSet<usize>,
    /// Observe/subscribe handles on the active alternative; rebuilt by
    /// `subscribe_to_alternative` whenever the active alternative changes.
    subscriptions: Vec<Subscription>,
    /// Multibuffer being edited; every alternative is created against it.
    buffer: Entity<MultiBuffer>,
    /// Target range; every alternative is created with it.
    range: Range<Anchor>,
    /// Transaction undone (at most once, since it is `take`n) by `undo`, after
    /// the active alternative's own undo.
    initial_transaction_id: Option<TransactionId>,
    /// Passed to every alternative created by `new` and `start`.
    telemetry: Arc<Telemetry>,
    /// Passed to every alternative created by `new` and `start`.
    builder: Arc<PromptBuilder>,
    /// True when `range` was empty (in offsets) when this codegen was created.
    pub is_insertion: bool,
}
  84
  85impl BufferCodegen {
  86    pub fn new(
  87        buffer: Entity<MultiBuffer>,
  88        range: Range<Anchor>,
  89        initial_transaction_id: Option<TransactionId>,
  90        telemetry: Arc<Telemetry>,
  91        builder: Arc<PromptBuilder>,
  92        cx: &mut Context<Self>,
  93    ) -> Self {
  94        let codegen = cx.new(|cx| {
  95            CodegenAlternative::new(
  96                buffer.clone(),
  97                range.clone(),
  98                false,
  99                Some(telemetry.clone()),
 100                builder.clone(),
 101                cx,
 102            )
 103        });
 104        let mut this = Self {
 105            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 106            alternatives: vec![codegen],
 107            active_alternative: 0,
 108            seen_alternatives: HashSet::default(),
 109            subscriptions: Vec::new(),
 110            buffer,
 111            range,
 112            initial_transaction_id,
 113            telemetry,
 114            builder,
 115        };
 116        this.activate(0, cx);
 117        this
 118    }
 119
 120    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 121        let codegen = self.active_alternative().clone();
 122        self.subscriptions.clear();
 123        self.subscriptions
 124            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 125        self.subscriptions
 126            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 127    }
 128
 129    pub fn active_completion(&self, cx: &App) -> Option<String> {
 130        self.active_alternative().read(cx).current_completion()
 131    }
 132
 133    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 134        &self.alternatives[self.active_alternative]
 135    }
 136
 137    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 138        &self.active_alternative().read(cx).status
 139    }
 140
 141    pub fn alternative_count(&self, cx: &App) -> usize {
 142        LanguageModelRegistry::read_global(cx)
 143            .inline_alternative_models()
 144            .len()
 145            + 1
 146    }
 147
 148    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 149        let next_active_ix = if self.active_alternative == 0 {
 150            self.alternatives.len() - 1
 151        } else {
 152            self.active_alternative - 1
 153        };
 154        self.activate(next_active_ix, cx);
 155    }
 156
 157    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 158        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 159        self.activate(next_active_ix, cx);
 160    }
 161
 162    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 163        self.active_alternative()
 164            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 165        self.seen_alternatives.insert(index);
 166        self.active_alternative = index;
 167        self.active_alternative()
 168            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 169        self.subscribe_to_alternative(cx);
 170        cx.notify();
 171    }
 172
 173    pub fn start(
 174        &mut self,
 175        primary_model: Arc<dyn LanguageModel>,
 176        user_prompt: String,
 177        context_task: Shared<Task<Option<LoadedContext>>>,
 178        cx: &mut Context<Self>,
 179    ) -> Result<()> {
 180        let alternative_models = LanguageModelRegistry::read_global(cx)
 181            .inline_alternative_models()
 182            .to_vec();
 183
 184        self.active_alternative()
 185            .update(cx, |alternative, cx| alternative.undo(cx));
 186        self.activate(0, cx);
 187        self.alternatives.truncate(1);
 188
 189        for _ in 0..alternative_models.len() {
 190            self.alternatives.push(cx.new(|cx| {
 191                CodegenAlternative::new(
 192                    self.buffer.clone(),
 193                    self.range.clone(),
 194                    false,
 195                    Some(self.telemetry.clone()),
 196                    self.builder.clone(),
 197                    cx,
 198                )
 199            }));
 200        }
 201
 202        for (model, alternative) in iter::once(primary_model)
 203            .chain(alternative_models)
 204            .zip(&self.alternatives)
 205        {
 206            alternative.update(cx, |alternative, cx| {
 207                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 208            })?;
 209        }
 210
 211        Ok(())
 212    }
 213
 214    pub fn stop(&mut self, cx: &mut Context<Self>) {
 215        for codegen in &self.alternatives {
 216            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 217        }
 218    }
 219
 220    pub fn undo(&mut self, cx: &mut Context<Self>) {
 221        self.active_alternative()
 222            .update(cx, |codegen, cx| codegen.undo(cx));
 223
 224        self.buffer.update(cx, |buffer, cx| {
 225            if let Some(transaction_id) = self.initial_transaction_id.take() {
 226                buffer.undo_transaction(transaction_id, cx);
 227                buffer.refresh_preview(cx);
 228            }
 229        });
 230    }
 231
 232    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 233        self.active_alternative().read(cx).buffer.clone()
 234    }
 235
 236    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 237        self.active_alternative().read(cx).old_buffer.clone()
 238    }
 239
 240    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 241        self.active_alternative().read(cx).snapshot.clone()
 242    }
 243
 244    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 245        self.active_alternative().read(cx).edit_position
 246    }
 247
 248    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 249        &self.active_alternative().read(cx).diff
 250    }
 251
 252    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 253        self.active_alternative().read(cx).last_equal_ranges()
 254    }
 255
 256    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 257        self.active_alternative().read(cx).selected_text()
 258    }
 259}
 260
// Lets `BufferCodegen` emit `CodegenEvent`s; `subscribe_to_alternative`
// re-emits every event of the active alternative through this.
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 262
/// One model's attempt at the inline transformation. Its edits live in a
/// single buffer transaction that is undone when the alternative is
/// deactivated (`set_active(false)`) and replayed when it is activated again.
pub struct CodegenAlternative {
    /// Multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// Standalone local buffer holding the original text of the buffer that the
    /// range resolves to, captured in `new`.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot; taken in `new` and refreshed when streaming starts.
    snapshot: MultiBufferSnapshot,
    /// Set to the (right-biased) start of `range` when `start` is called.
    edit_position: Option<Anchor>,
    /// Target range; re-anchored against a fresh snapshot when streaming starts.
    range: Range<Anchor>,
    /// Returned by `last_equal_ranges()`.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding this alternative's edits. It is taken and undone by
    /// `start` and `set_active(false)`, and cleared when the buffer reports it undone.
    transformation_transaction_id: Option<TransactionId>,
    /// `Idle` after construction; set to `Pending` when streaming starts.
    status: CodegenStatus,
    /// Task driving the current generation; replaced by `Task::ready(())` when
    /// the transformation is undone externally.
    generation: Task<()>,
    /// Diff for the current generation; reset when streaming starts.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription routing multibuffer events to `handle_buffer_event`.
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative is the active one (see `set_active`).
    active: bool,
    /// Edits replayed through `apply_edits` when the alternative is re-activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations replayed through `reapply_line_based_diff` when the
    /// alternative is re-activated while still `Pending`.
    line_operations: Vec<LineOperation>,
    // NOTE(review): presumably the generation duration in seconds — confirm where it is set.
    elapsed_time: Option<f64>,
    // NOTE(review): presumably the full model output once streaming finishes — confirm.
    completion: Option<String>,
    /// Text of `range` captured when streaming starts.
    selected_text: Option<String>,
    // NOTE(review): likely the provider's message id for the response — confirm.
    pub message_id: Option<String>,
    // NOTE(review): likely the description/failure text from the tool flow — confirm.
    pub model_explanation: Option<SharedString>,
}
 286
// Lets `CodegenAlternative` emit `CodegenEvent`s, e.g. `CodegenEvent::Undone`
// from `handle_buffer_event`.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 288
 289impl CodegenAlternative {
 290    pub fn new(
 291        buffer: Entity<MultiBuffer>,
 292        range: Range<Anchor>,
 293        active: bool,
 294        telemetry: Option<Arc<Telemetry>>,
 295        builder: Arc<PromptBuilder>,
 296        cx: &mut Context<Self>,
 297    ) -> Self {
 298        let snapshot = buffer.read(cx).snapshot(cx);
 299
 300        let (old_buffer, _, _) = snapshot
 301            .range_to_buffer_ranges(range.clone())
 302            .pop()
 303            .unwrap();
 304        let old_buffer = cx.new(|cx| {
 305            let text = old_buffer.as_rope().clone();
 306            let line_ending = old_buffer.line_ending();
 307            let language = old_buffer.language().cloned();
 308            let language_registry = buffer
 309                .read(cx)
 310                .buffer(old_buffer.remote_id())
 311                .unwrap()
 312                .read(cx)
 313                .language_registry();
 314
 315            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 316            buffer.set_language(language, cx);
 317            if let Some(language_registry) = language_registry {
 318                buffer.set_language_registry(language_registry);
 319            }
 320            buffer
 321        });
 322
 323        Self {
 324            buffer: buffer.clone(),
 325            old_buffer,
 326            edit_position: None,
 327            message_id: None,
 328            snapshot,
 329            last_equal_ranges: Default::default(),
 330            transformation_transaction_id: None,
 331            status: CodegenStatus::Idle,
 332            generation: Task::ready(()),
 333            diff: Diff::default(),
 334            telemetry,
 335            builder,
 336            active: active,
 337            edits: Vec::new(),
 338            line_operations: Vec::new(),
 339            range,
 340            elapsed_time: None,
 341            completion: None,
 342            selected_text: None,
 343            model_explanation: None,
 344            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 345        }
 346    }
 347
 348    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 349        if active != self.active {
 350            self.active = active;
 351
 352            if self.active {
 353                let edits = self.edits.clone();
 354                self.apply_edits(edits, cx);
 355                if matches!(self.status, CodegenStatus::Pending) {
 356                    let line_operations = self.line_operations.clone();
 357                    self.reapply_line_based_diff(line_operations, cx);
 358                } else {
 359                    self.reapply_batch_diff(cx).detach();
 360                }
 361            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 362                self.buffer.update(cx, |buffer, cx| {
 363                    buffer.undo_transaction(transaction_id, cx);
 364                    buffer.forget_transaction(transaction_id, cx);
 365                });
 366            }
 367        }
 368    }
 369
 370    fn handle_buffer_event(
 371        &mut self,
 372        _buffer: Entity<MultiBuffer>,
 373        event: &multi_buffer::Event,
 374        cx: &mut Context<Self>,
 375    ) {
 376        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 377            && self.transformation_transaction_id == Some(*transaction_id)
 378        {
 379            self.transformation_transaction_id = None;
 380            self.generation = Task::ready(());
 381            cx.emit(CodegenEvent::Undone);
 382        }
 383    }
 384
 385    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 386        &self.last_equal_ranges
 387    }
 388
 389    fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 390        model.supports_streaming_tools()
 391            && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
 392            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 393    }
 394
 395    pub fn start(
 396        &mut self,
 397        user_prompt: String,
 398        context_task: Shared<Task<Option<LoadedContext>>>,
 399        model: Arc<dyn LanguageModel>,
 400        cx: &mut Context<Self>,
 401    ) -> Result<()> {
 402        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 403            self.buffer.update(cx, |buffer, cx| {
 404                buffer.undo_transaction(transformation_transaction_id, cx);
 405            });
 406        }
 407
 408        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 409
 410        let api_key = model.api_key(cx);
 411        let telemetry_id = model.telemetry_id();
 412        let provider_id = model.provider_id();
 413
 414        if Self::use_streaming_tools(model.as_ref(), cx) {
 415            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 416            let completion_events =
 417                cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
 418            self.generation = self.handle_completion(
 419                telemetry_id,
 420                provider_id.to_string(),
 421                api_key,
 422                completion_events,
 423                cx,
 424            );
 425        } else {
 426            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 427                if user_prompt.trim().to_lowercase() == "delete" {
 428                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 429                } else {
 430                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 431                    cx.spawn(async move |_, cx| {
 432                        Ok(model.stream_completion_text(request.await, cx).await?)
 433                    })
 434                    .boxed_local()
 435                };
 436            self.generation =
 437                self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
 438        }
 439
 440        Ok(())
 441    }
 442
 443    fn build_request_tools(
 444        &self,
 445        model: &Arc<dyn LanguageModel>,
 446        user_prompt: String,
 447        context_task: Shared<Task<Option<LoadedContext>>>,
 448        cx: &mut App,
 449    ) -> Result<Task<LanguageModelRequest>> {
 450        let buffer = self.buffer.read(cx).snapshot(cx);
 451        let language = buffer.language_at(self.range.start);
 452        let language_name = if let Some(language) = language.as_ref() {
 453            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 454                None
 455            } else {
 456                Some(language.name())
 457            }
 458        } else {
 459            None
 460        };
 461
 462        let language_name = language_name.as_ref();
 463        let start = buffer.point_to_buffer_offset(self.range.start);
 464        let end = buffer.point_to_buffer_offset(self.range.end);
 465        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 466            let (start_buffer, start_buffer_offset) = start;
 467            let (end_buffer, end_buffer_offset) = end;
 468            if start_buffer.remote_id() == end_buffer.remote_id() {
 469                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 470            } else {
 471                anyhow::bail!("invalid transformation range");
 472            }
 473        } else {
 474            anyhow::bail!("invalid transformation range");
 475        };
 476
 477        let system_prompt = self
 478            .builder
 479            .generate_inline_transformation_prompt_tools(
 480                language_name,
 481                buffer,
 482                range.start.0..range.end.0,
 483            )
 484            .context("generating content prompt")?;
 485
 486        let temperature = AgentSettings::temperature_for_model(model, cx);
 487
 488        let tool_input_format = model.tool_input_format();
 489        let tool_choice = model
 490            .supports_tool_choice(LanguageModelToolChoice::Any)
 491            .then_some(LanguageModelToolChoice::Any);
 492
 493        Ok(cx.spawn(async move |_cx| {
 494            let mut messages = vec![LanguageModelRequestMessage {
 495                role: Role::System,
 496                content: vec![system_prompt.into()],
 497                cache: false,
 498                reasoning_details: None,
 499            }];
 500
 501            let mut user_message = LanguageModelRequestMessage {
 502                role: Role::User,
 503                content: Vec::new(),
 504                cache: false,
 505                reasoning_details: None,
 506            };
 507
 508            if let Some(context) = context_task.await {
 509                context.add_to_request_message(&mut user_message);
 510            }
 511
 512            user_message.content.push(user_prompt.into());
 513            messages.push(user_message);
 514
 515            let tools = vec![
 516                LanguageModelRequestTool {
 517                    name: "rewrite_section".to_string(),
 518                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 519                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 520                },
 521                LanguageModelRequestTool {
 522                    name: "failure_message".to_string(),
 523                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 524                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 525                },
 526            ];
 527
 528            LanguageModelRequest {
 529                thread_id: None,
 530                prompt_id: None,
 531                intent: Some(CompletionIntent::InlineAssist),
 532                mode: None,
 533                tools,
 534                tool_choice,
 535                stop: Vec::new(),
 536                temperature,
 537                messages,
 538                thinking_allowed: false,
 539            }
 540        }))
 541    }
 542
 543    fn build_request(
 544        &self,
 545        model: &Arc<dyn LanguageModel>,
 546        user_prompt: String,
 547        context_task: Shared<Task<Option<LoadedContext>>>,
 548        cx: &mut App,
 549    ) -> Result<Task<LanguageModelRequest>> {
 550        if Self::use_streaming_tools(model.as_ref(), cx) {
 551            return self.build_request_tools(model, user_prompt, context_task, cx);
 552        }
 553
 554        let buffer = self.buffer.read(cx).snapshot(cx);
 555        let language = buffer.language_at(self.range.start);
 556        let language_name = if let Some(language) = language.as_ref() {
 557            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 558                None
 559            } else {
 560                Some(language.name())
 561            }
 562        } else {
 563            None
 564        };
 565
 566        let language_name = language_name.as_ref();
 567        let start = buffer.point_to_buffer_offset(self.range.start);
 568        let end = buffer.point_to_buffer_offset(self.range.end);
 569        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 570            let (start_buffer, start_buffer_offset) = start;
 571            let (end_buffer, end_buffer_offset) = end;
 572            if start_buffer.remote_id() == end_buffer.remote_id() {
 573                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 574            } else {
 575                anyhow::bail!("invalid transformation range");
 576            }
 577        } else {
 578            anyhow::bail!("invalid transformation range");
 579        };
 580
 581        let prompt = self
 582            .builder
 583            .generate_inline_transformation_prompt(
 584                user_prompt,
 585                language_name,
 586                buffer,
 587                range.start.0..range.end.0,
 588            )
 589            .context("generating content prompt")?;
 590
 591        let temperature = AgentSettings::temperature_for_model(model, cx);
 592
 593        Ok(cx.spawn(async move |_cx| {
 594            let mut request_message = LanguageModelRequestMessage {
 595                role: Role::User,
 596                content: Vec::new(),
 597                cache: false,
 598                reasoning_details: None,
 599            };
 600
 601            if let Some(context) = context_task.await {
 602                context.add_to_request_message(&mut request_message);
 603            }
 604
 605            request_message.content.push(prompt.into());
 606
 607            LanguageModelRequest {
 608                thread_id: None,
 609                prompt_id: None,
 610                intent: Some(CompletionIntent::InlineAssist),
 611                mode: None,
 612                tools: Vec::new(),
 613                tool_choice: None,
 614                stop: Vec::new(),
 615                temperature,
 616                messages: vec![request_message],
 617                thinking_allowed: false,
 618            }
 619        }))
 620    }
 621
 622    pub fn handle_stream(
 623        &mut self,
 624        model_telemetry_id: String,
 625        model_provider_id: String,
 626        model_api_key: Option<String>,
 627        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 628        cx: &mut Context<Self>,
 629    ) -> Task<()> {
 630        let start_time = Instant::now();
 631
 632        // Make a new snapshot and re-resolve anchor in case the document was modified.
 633        // This can happen often if the editor loses focus and is saved + reformatted,
 634        // as in https://github.com/zed-industries/zed/issues/39088
 635        self.snapshot = self.buffer.read(cx).snapshot(cx);
 636        self.range = self.snapshot.anchor_after(self.range.start)
 637            ..self.snapshot.anchor_after(self.range.end);
 638
 639        let snapshot = self.snapshot.clone();
 640        let selected_text = snapshot
 641            .text_for_range(self.range.start..self.range.end)
 642            .collect::<Rope>();
 643
 644        self.selected_text = Some(selected_text.to_string());
 645
 646        let selection_start = self.range.start.to_point(&snapshot);
 647
 648        // Start with the indentation of the first line in the selection
 649        let mut suggested_line_indent = snapshot
 650            .suggested_indents(selection_start.row..=selection_start.row, cx)
 651            .into_values()
 652            .next()
 653            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 654
 655        // If the first line in the selection does not have indentation, check the following lines
 656        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 657            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 658                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 659                // Prefer tabs if a line in the selection uses tabs as indentation
 660                if line_indent.kind == IndentKind::Tab {
 661                    suggested_line_indent.kind = IndentKind::Tab;
 662                    break;
 663                }
 664            }
 665        }
 666
 667        let http_client = cx.http_client();
 668        let telemetry = self.telemetry.clone();
 669        let language_name = {
 670            let multibuffer = self.buffer.read(cx);
 671            let snapshot = multibuffer.snapshot(cx);
 672            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 673            ranges
 674                .first()
 675                .and_then(|(buffer, _, _)| buffer.language())
 676                .map(|language| language.name())
 677        };
 678
 679        self.diff = Diff::default();
 680        self.status = CodegenStatus::Pending;
 681        let mut edit_start = self.range.start.to_offset(&snapshot);
 682        let completion = Arc::new(Mutex::new(String::new()));
 683        let completion_clone = completion.clone();
 684
 685        cx.notify();
 686        cx.spawn(async move |codegen, cx| {
 687            let stream = stream.await;
 688
 689            let token_usage = stream
 690                .as_ref()
 691                .ok()
 692                .map(|stream| stream.last_token_usage.clone());
 693            let message_id = stream
 694                .as_ref()
 695                .ok()
 696                .and_then(|stream| stream.message_id.clone());
 697            let generate = async {
 698                let model_telemetry_id = model_telemetry_id.clone();
 699                let model_provider_id = model_provider_id.clone();
 700                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 701                let executor = cx.background_executor().clone();
 702                let message_id = message_id.clone();
 703                let line_based_stream_diff: Task<anyhow::Result<()>> =
 704                    cx.background_spawn(async move {
 705                        let mut response_latency = None;
 706                        let request_start = Instant::now();
 707                        let diff = async {
 708                            let chunks = StripInvalidSpans::new(
 709                                stream?.stream.map_err(|error| error.into()),
 710                            );
 711                            futures::pin_mut!(chunks);
 712
 713                            let mut diff = StreamingDiff::new(selected_text.to_string());
 714                            let mut line_diff = LineDiff::default();
 715
 716                            let mut new_text = String::new();
 717                            let mut base_indent = None;
 718                            let mut line_indent = None;
 719                            let mut first_line = true;
 720
 721                            while let Some(chunk) = chunks.next().await {
 722                                if response_latency.is_none() {
 723                                    response_latency = Some(request_start.elapsed());
 724                                }
 725                                let chunk = chunk?;
 726                                completion_clone.lock().push_str(&chunk);
 727
 728                                let mut lines = chunk.split('\n').peekable();
 729                                while let Some(line) = lines.next() {
 730                                    new_text.push_str(line);
 731                                    if line_indent.is_none()
 732                                        && let Some(non_whitespace_ch_ix) =
 733                                            new_text.find(|ch: char| !ch.is_whitespace())
 734                                    {
 735                                        line_indent = Some(non_whitespace_ch_ix);
 736                                        base_indent = base_indent.or(line_indent);
 737
 738                                        let line_indent = line_indent.unwrap();
 739                                        let base_indent = base_indent.unwrap();
 740                                        let indent_delta = line_indent as i32 - base_indent as i32;
 741                                        let mut corrected_indent_len = cmp::max(
 742                                            0,
 743                                            suggested_line_indent.len as i32 + indent_delta,
 744                                        )
 745                                            as usize;
 746                                        if first_line {
 747                                            corrected_indent_len = corrected_indent_len
 748                                                .saturating_sub(selection_start.column as usize);
 749                                        }
 750
 751                                        let indent_char = suggested_line_indent.char();
 752                                        let mut indent_buffer = [0; 4];
 753                                        let indent_str =
 754                                            indent_char.encode_utf8(&mut indent_buffer);
 755                                        new_text.replace_range(
 756                                            ..line_indent,
 757                                            &indent_str.repeat(corrected_indent_len),
 758                                        );
 759                                    }
 760
 761                                    if line_indent.is_some() {
 762                                        let char_ops = diff.push_new(&new_text);
 763                                        line_diff.push_char_operations(&char_ops, &selected_text);
 764                                        diff_tx
 765                                            .send((char_ops, line_diff.line_operations()))
 766                                            .await?;
 767                                        new_text.clear();
 768                                    }
 769
 770                                    if lines.peek().is_some() {
 771                                        let char_ops = diff.push_new("\n");
 772                                        line_diff.push_char_operations(&char_ops, &selected_text);
 773                                        diff_tx
 774                                            .send((char_ops, line_diff.line_operations()))
 775                                            .await?;
 776                                        if line_indent.is_none() {
 777                                            // Don't write out the leading indentation in empty lines on the next line
 778                                            // This is the case where the above if statement didn't clear the buffer
 779                                            new_text.clear();
 780                                        }
 781                                        line_indent = None;
 782                                        first_line = false;
 783                                    }
 784                                }
 785                            }
 786
 787                            let mut char_ops = diff.push_new(&new_text);
 788                            char_ops.extend(diff.finish());
 789                            line_diff.push_char_operations(&char_ops, &selected_text);
 790                            line_diff.finish(&selected_text);
 791                            diff_tx
 792                                .send((char_ops, line_diff.line_operations()))
 793                                .await?;
 794
 795                            anyhow::Ok(())
 796                        };
 797
 798                        let result = diff.await;
 799
 800                        let error_message = result.as_ref().err().map(|error| error.to_string());
 801                        report_assistant_event(
 802                            AssistantEventData {
 803                                conversation_id: None,
 804                                message_id,
 805                                kind: AssistantKind::Inline,
 806                                phase: AssistantPhase::Response,
 807                                model: model_telemetry_id,
 808                                model_provider: model_provider_id,
 809                                response_latency,
 810                                error_message,
 811                                language_name: language_name.map(|name| name.to_proto()),
 812                            },
 813                            telemetry,
 814                            http_client,
 815                            model_api_key,
 816                            &executor,
 817                        );
 818
 819                        result?;
 820                        Ok(())
 821                    });
 822
 823                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 824                    codegen.update(cx, |codegen, cx| {
 825                        codegen.last_equal_ranges.clear();
 826
 827                        let edits = char_ops
 828                            .into_iter()
 829                            .filter_map(|operation| match operation {
 830                                CharOperation::Insert { text } => {
 831                                    let edit_start = snapshot.anchor_after(edit_start);
 832                                    Some((edit_start..edit_start, text))
 833                                }
 834                                CharOperation::Delete { bytes } => {
 835                                    let edit_end = edit_start + bytes;
 836                                    let edit_range = snapshot.anchor_after(edit_start)
 837                                        ..snapshot.anchor_before(edit_end);
 838                                    edit_start = edit_end;
 839                                    Some((edit_range, String::new()))
 840                                }
 841                                CharOperation::Keep { bytes } => {
 842                                    let edit_end = edit_start + bytes;
 843                                    let edit_range = snapshot.anchor_after(edit_start)
 844                                        ..snapshot.anchor_before(edit_end);
 845                                    edit_start = edit_end;
 846                                    codegen.last_equal_ranges.push(edit_range);
 847                                    None
 848                                }
 849                            })
 850                            .collect::<Vec<_>>();
 851
 852                        if codegen.active {
 853                            codegen.apply_edits(edits.iter().cloned(), cx);
 854                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 855                        }
 856                        codegen.edits.extend(edits);
 857                        codegen.line_operations = line_ops;
 858                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 859
 860                        cx.notify();
 861                    })?;
 862                }
 863
 864                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 865                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 866                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 867                let batch_diff_task =
 868                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 869                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 870                line_based_stream_diff?;
 871
 872                anyhow::Ok(())
 873            };
 874
 875            let result = generate.await;
 876            let elapsed_time = start_time.elapsed().as_secs_f64();
 877
 878            codegen
 879                .update(cx, |this, cx| {
 880                    this.message_id = message_id;
 881                    this.last_equal_ranges.clear();
 882                    if let Err(error) = result {
 883                        this.status = CodegenStatus::Error(error);
 884                    } else {
 885                        this.status = CodegenStatus::Done;
 886                    }
 887                    this.elapsed_time = Some(elapsed_time);
 888                    this.completion = Some(completion.lock().clone());
 889                    if let Some(usage) = token_usage {
 890                        let usage = usage.lock();
 891                        telemetry::event!(
 892                            "Inline Assistant Completion",
 893                            model = model_telemetry_id,
 894                            model_provider = model_provider_id,
 895                            input_tokens = usage.input_tokens,
 896                            output_tokens = usage.output_tokens,
 897                        )
 898                    }
 899
 900                    cx.emit(CodegenEvent::Finished);
 901                    cx.notify();
 902                })
 903                .ok();
 904        })
 905    }
 906
 907    pub fn current_completion(&self) -> Option<String> {
 908        self.completion.clone()
 909    }
 910
 911    pub fn selected_text(&self) -> Option<&str> {
 912        self.selected_text.as_deref()
 913    }
 914
 915    pub fn stop(&mut self, cx: &mut Context<Self>) {
 916        self.last_equal_ranges.clear();
 917        if self.diff.is_empty() {
 918            self.status = CodegenStatus::Idle;
 919        } else {
 920            self.status = CodegenStatus::Done;
 921        }
 922        self.generation = Task::ready(());
 923        cx.emit(CodegenEvent::Finished);
 924        cx.notify();
 925    }
 926
 927    pub fn undo(&mut self, cx: &mut Context<Self>) {
 928        self.buffer.update(cx, |buffer, cx| {
 929            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 930                buffer.undo_transaction(transaction_id, cx);
 931                buffer.refresh_preview(cx);
 932            }
 933        });
 934    }
 935
 936    fn apply_edits(
 937        &mut self,
 938        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 939        cx: &mut Context<CodegenAlternative>,
 940    ) {
 941        let transaction = self.buffer.update(cx, |buffer, cx| {
 942            // Avoid grouping agent edits with user edits.
 943            buffer.finalize_last_transaction(cx);
 944            buffer.start_transaction(cx);
 945            buffer.edit(edits, None, cx);
 946            buffer.end_transaction(cx)
 947        });
 948
 949        if let Some(transaction) = transaction {
 950            if let Some(first_transaction) = self.transformation_transaction_id {
 951                // Group all agent edits into the first transaction.
 952                self.buffer.update(cx, |buffer, cx| {
 953                    buffer.merge_transactions(transaction, first_transaction, cx)
 954                });
 955            } else {
 956                self.transformation_transaction_id = Some(transaction);
 957                self.buffer
 958                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 959            }
 960        }
 961    }
 962
 963    fn reapply_line_based_diff(
 964        &mut self,
 965        line_operations: impl IntoIterator<Item = LineOperation>,
 966        cx: &mut Context<Self>,
 967    ) {
 968        let old_snapshot = self.snapshot.clone();
 969        let old_range = self.range.to_point(&old_snapshot);
 970        let new_snapshot = self.buffer.read(cx).snapshot(cx);
 971        let new_range = self.range.to_point(&new_snapshot);
 972
 973        let mut old_row = old_range.start.row;
 974        let mut new_row = new_range.start.row;
 975
 976        self.diff.deleted_row_ranges.clear();
 977        self.diff.inserted_row_ranges.clear();
 978        for operation in line_operations {
 979            match operation {
 980                LineOperation::Keep { lines } => {
 981                    old_row += lines;
 982                    new_row += lines;
 983                }
 984                LineOperation::Delete { lines } => {
 985                    let old_end_row = old_row + lines - 1;
 986                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
 987
 988                    if let Some((_, last_deleted_row_range)) =
 989                        self.diff.deleted_row_ranges.last_mut()
 990                    {
 991                        if *last_deleted_row_range.end() + 1 == old_row {
 992                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
 993                        } else {
 994                            self.diff
 995                                .deleted_row_ranges
 996                                .push((new_row, old_row..=old_end_row));
 997                        }
 998                    } else {
 999                        self.diff
1000                            .deleted_row_ranges
1001                            .push((new_row, old_row..=old_end_row));
1002                    }
1003
1004                    old_row += lines;
1005                }
1006                LineOperation::Insert { lines } => {
1007                    let new_end_row = new_row + lines - 1;
1008                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1009                    let end = new_snapshot.anchor_before(Point::new(
1010                        new_end_row,
1011                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1012                    ));
1013                    self.diff.inserted_row_ranges.push(start..end);
1014                    new_row += lines;
1015                }
1016            }
1017
1018            cx.notify();
1019        }
1020    }
1021
1022    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1023        let old_snapshot = self.snapshot.clone();
1024        let old_range = self.range.to_point(&old_snapshot);
1025        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1026        let new_range = self.range.to_point(&new_snapshot);
1027
1028        cx.spawn(async move |codegen, cx| {
1029            let (deleted_row_ranges, inserted_row_ranges) = cx
1030                .background_spawn(async move {
1031                    let old_text = old_snapshot
1032                        .text_for_range(
1033                            Point::new(old_range.start.row, 0)
1034                                ..Point::new(
1035                                    old_range.end.row,
1036                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1037                                ),
1038                        )
1039                        .collect::<String>();
1040                    let new_text = new_snapshot
1041                        .text_for_range(
1042                            Point::new(new_range.start.row, 0)
1043                                ..Point::new(
1044                                    new_range.end.row,
1045                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1046                                ),
1047                        )
1048                        .collect::<String>();
1049
1050                    let old_start_row = old_range.start.row;
1051                    let new_start_row = new_range.start.row;
1052                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1053                    let mut inserted_row_ranges = Vec::new();
1054                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1055                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1056                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1057                        if !old_rows.is_empty() {
1058                            deleted_row_ranges.push((
1059                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1060                                old_rows.start..=old_rows.end - 1,
1061                            ));
1062                        }
1063                        if !new_rows.is_empty() {
1064                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1065                            let new_end_row = new_rows.end - 1;
1066                            let end = new_snapshot.anchor_before(Point::new(
1067                                new_end_row,
1068                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1069                            ));
1070                            inserted_row_ranges.push(start..end);
1071                        }
1072                    }
1073                    (deleted_row_ranges, inserted_row_ranges)
1074                })
1075                .await;
1076
1077            codegen
1078                .update(cx, |codegen, cx| {
1079                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1080                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1081                    cx.notify();
1082                })
1083                .ok();
1084        })
1085    }
1086
1087    fn handle_completion(
1088        &mut self,
1089        telemetry_id: String,
1090        provider_id: String,
1091        api_key: Option<String>,
1092        completion_stream: Task<
1093            Result<
1094                BoxStream<
1095                    'static,
1096                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1097                >,
1098                LanguageModelCompletionError,
1099            >,
1100        >,
1101        cx: &mut Context<Self>,
1102    ) -> Task<()> {
1103        self.diff = Diff::default();
1104        self.status = CodegenStatus::Pending;
1105
1106        cx.notify();
1107        // Leaving this in generation so that STOP equivalent events are respected even
1108        // while we're still pre-processing the completion event
1109        cx.spawn(async move |codegen, cx| {
1110            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1111                let _ = codegen.update(cx, |this, cx| {
1112                    this.status = status;
1113                    cx.emit(CodegenEvent::Finished);
1114                    cx.notify();
1115                });
1116            };
1117
1118            let mut completion_events = match completion_stream.await {
1119                Ok(events) => events,
1120                Err(err) => {
1121                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1122                    return;
1123                }
1124            };
1125
1126            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1127            let tool_to_text_and_message =
1128                move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
1129                    let mut chars_read_so_far = chars_read_so_far.lock();
1130                    match tool_use.name.as_ref() {
1131                        "rewrite_section" => {
1132                            let Ok(mut input) =
1133                                serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1134                            else {
1135                                return (None, None);
1136                            };
1137                            let value = input.replacement_text[*chars_read_so_far..].to_string();
1138                            *chars_read_so_far = input.replacement_text.len();
1139                            (Some(value), Some(std::mem::take(&mut input.description)))
1140                        }
1141                        "failure_message" => {
1142                            let Ok(mut input) =
1143                                serde_json::from_value::<FailureMessageInput>(tool_use.input)
1144                            else {
1145                                return (None, None);
1146                            };
1147                            (None, Some(std::mem::take(&mut input.message)))
1148                        }
1149                        _ => (None, None),
1150                    }
1151                };
1152
1153            let mut message_id = None;
1154            let mut first_text = None;
1155            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1156            let total_text = Arc::new(Mutex::new(String::new()));
1157
1158            loop {
1159                if let Some(first_event) = completion_events.next().await {
1160                    match first_event {
1161                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1162                            message_id = Some(id);
1163                        }
1164                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1165                            if matches!(
1166                                tool_use.name.as_ref(),
1167                                "rewrite_section" | "failure_message"
1168                            ) =>
1169                        {
1170                            let is_complete = tool_use.is_input_complete;
1171                            let (text, message) = tool_to_text_and_message(tool_use);
1172                            // Only update the model explanation if the tool use is complete.
1173                            // Otherwise the UI element bounces around as it's updated.
1174                            if is_complete {
1175                                let _ = codegen.update(cx, |this, _cx| {
1176                                    this.model_explanation = message.map(Into::into);
1177                                });
1178                            }
1179                            first_text = text;
1180                            if first_text.is_some() {
1181                                break;
1182                            }
1183                        }
1184                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1185                            *last_token_usage.lock() = token_usage;
1186                        }
1187                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1188                            let mut lock = total_text.lock();
1189                            lock.push_str(&text);
1190                        }
1191                        Ok(e) => {
1192                            log::warn!("Unexpected event: {:?}", e);
1193                            break;
1194                        }
1195                        Err(e) => {
1196                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1197                            break;
1198                        }
1199                    }
1200                }
1201            }
1202
1203            let Some(first_text) = first_text else {
1204                finish_with_status(CodegenStatus::Done, cx);
1205                return;
1206            };
1207
1208            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
1209
1210            cx.spawn({
1211                let codegen = codegen.clone();
1212                async move |cx| {
1213                    while let Some(message) = message_rx.next().await {
1214                        let _ = codegen.update(cx, |this, _cx| {
1215                            this.model_explanation = message;
1216                        });
1217                    }
1218                }
1219            })
1220            .detach();
1221
1222            let move_last_token_usage = last_token_usage.clone();
1223
1224            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1225                completion_events.filter_map(move |e| {
1226                    let tool_to_text_and_message = tool_to_text_and_message.clone();
1227                    let last_token_usage = move_last_token_usage.clone();
1228                    let total_text = total_text.clone();
1229                    let mut message_tx = message_tx.clone();
1230                    async move {
1231                        match e {
1232                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1233                                if matches!(
1234                                    tool_use.name.as_ref(),
1235                                    "rewrite_section" | "failure_message"
1236                                ) =>
1237                            {
1238                                let is_complete = tool_use.is_input_complete;
1239                                let (text, message) = tool_to_text_and_message(tool_use);
1240                                if is_complete {
1241                                    // Again only send the message when complete to not get a bouncing UI element.
1242                                    let _ = message_tx.send(message.map(Into::into)).await;
1243                                }
1244                                text.map(Ok)
1245                            }
1246                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1247                                *last_token_usage.lock() = token_usage;
1248                                None
1249                            }
1250                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1251                                let mut lock = total_text.lock();
1252                                lock.push_str(&text);
1253                                None
1254                            }
1255                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1256                            e => {
1257                                log::error!("UNEXPECTED EVENT {:?}", e);
1258                                None
1259                            }
1260                        }
1261                    }
1262                }),
1263            ));
1264
1265            let language_model_text_stream = LanguageModelTextStream {
1266                message_id: message_id,
1267                stream: text_stream,
1268                last_token_usage,
1269            };
1270
1271            let Some(task) = codegen
1272                .update(cx, move |codegen, cx| {
1273                    codegen.handle_stream(
1274                        telemetry_id,
1275                        provider_id,
1276                        api_key,
1277                        async { Ok(language_model_text_stream) },
1278                        cx,
1279                    )
1280                })
1281                .ok()
1282            else {
1283                return;
1284            };
1285
1286            task.await;
1287        })
1288    }
1289}
1290
/// Notifications emitted by a codegen alternative to its observers.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation has ended: it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): presumably signals that generated edits were undone; it is not
    /// emitted in this section of the file — confirm at the emit sites.
    Undone,
}
1296
1297struct StripInvalidSpans<T> {
1298    stream: T,
1299    stream_done: bool,
1300    buffer: String,
1301    first_line: bool,
1302    line_end: bool,
1303    starts_with_code_block: bool,
1304}
1305
1306impl<T> StripInvalidSpans<T>
1307where
1308    T: Stream<Item = Result<String>>,
1309{
1310    fn new(stream: T) -> Self {
1311        Self {
1312            stream,
1313            stream_done: false,
1314            buffer: String::new(),
1315            first_line: true,
1316            line_end: false,
1317            starts_with_code_block: false,
1318        }
1319    }
1320}
1321
1322impl<T> Stream for StripInvalidSpans<T>
1323where
1324    T: Stream<Item = Result<String>>,
1325{
1326    type Item = Result<String>;
1327
1328    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1329        const CODE_BLOCK_DELIMITER: &str = "```";
1330        const CURSOR_SPAN: &str = "<|CURSOR|>";
1331
1332        let this = unsafe { self.get_unchecked_mut() };
1333        loop {
1334            if !this.stream_done {
1335                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1336                match stream.as_mut().poll_next(cx) {
1337                    Poll::Ready(Some(Ok(chunk))) => {
1338                        this.buffer.push_str(&chunk);
1339                    }
1340                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1341                    Poll::Ready(None) => {
1342                        this.stream_done = true;
1343                    }
1344                    Poll::Pending => return Poll::Pending,
1345                }
1346            }
1347
1348            let mut chunk = String::new();
1349            let mut consumed = 0;
1350            if !this.buffer.is_empty() {
1351                let mut lines = this.buffer.split('\n').enumerate().peekable();
1352                while let Some((line_ix, line)) = lines.next() {
1353                    if line_ix > 0 {
1354                        this.first_line = false;
1355                    }
1356
1357                    if this.first_line {
1358                        let trimmed_line = line.trim();
1359                        if lines.peek().is_some() {
1360                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1361                                consumed += line.len() + 1;
1362                                this.starts_with_code_block = true;
1363                                continue;
1364                            }
1365                        } else if trimmed_line.is_empty()
1366                            || prefixes(CODE_BLOCK_DELIMITER)
1367                                .any(|prefix| trimmed_line.starts_with(prefix))
1368                        {
1369                            break;
1370                        }
1371                    }
1372
1373                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1374                    if lines.peek().is_some() {
1375                        if this.line_end {
1376                            chunk.push('\n');
1377                        }
1378
1379                        chunk.push_str(&line_without_cursor);
1380                        this.line_end = true;
1381                        consumed += line.len() + 1;
1382                    } else if this.stream_done {
1383                        if !this.starts_with_code_block
1384                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1385                        {
1386                            if this.line_end {
1387                                chunk.push('\n');
1388                            }
1389
1390                            chunk.push_str(line);
1391                        }
1392
1393                        consumed += line.len();
1394                    } else {
1395                        let trimmed_line = line.trim();
1396                        if trimmed_line.is_empty()
1397                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1398                            || prefixes(CODE_BLOCK_DELIMITER)
1399                                .any(|prefix| trimmed_line.ends_with(prefix))
1400                        {
1401                            break;
1402                        } else {
1403                            if this.line_end {
1404                                chunk.push('\n');
1405                                this.line_end = false;
1406                            }
1407
1408                            chunk.push_str(&line_without_cursor);
1409                            consumed += line.len();
1410                        }
1411                    }
1412                }
1413            }
1414
1415            this.buffer = this.buffer.split_off(consumed);
1416            if !chunk.is_empty() {
1417                return Poll::Ready(Some(Ok(chunk)));
1418            } else if this.stream_done {
1419                return Poll::Ready(None);
1420            }
1421        }
1422    }
1423}
1424
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, the prefixes of "abc" are "a" and "ab"; the full string is never
/// yielded. Prefixes always end on `char` boundaries, so non-ASCII input is safe,
/// and an empty or single-character `text` yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first marks the end of a proper prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1428
1429#[derive(Default)]
1430pub struct Diff {
1431    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1432    pub inserted_row_ranges: Vec<Range<Anchor>>,
1433}
1434
1435impl Diff {
1436    fn is_empty(&self) -> bool {
1437        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1438    }
1439}
1440
/// Tests for streamed code generation into a buffer (`CodegenAlternative`):
/// auto-indentation, activation toggling, and `StripInvalidSpans` filtering.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        Stream,
        stream::{self},
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::{LanguageModelRegistry, TokenUsage};
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Replacement text streamed with a 7-space indent is re-indented to the
    /// 4-space indentation of the surrounding function body.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts mid-line, after the existing `le` at column 6. The
    /// streamed continuation and the lines that follow are indented to match
    /// the function body.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts at column 2 of a whitespace-only line, before the
    /// body's expected indentation. The result is still indented to 4 spaces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a buffer with no language and tab-indented text, streamed
    /// tab-indented output keeps its tabs.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// A codegen created inactive (third constructor argument `false`) leaves
    /// the buffer untouched. Activating it applies the generated edit, and
    /// deactivating it reverts the edit.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` removes wrapping ``` fences and `<|CURSOR|>` markers,
    /// and the output must not depend on how the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` at every chunk size from 1 to
        /// `text.len()` and checks each result equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into successful stream items of at most `size` chars each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test `LanguageModelRegistry` and a test `SettingsStore` global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Replaces `codegen.generation` with `handle_stream`, fed by a fresh
    /// unbounded channel, and returns the sending half.
    ///
    /// Each `String` sent becomes one `Ok` stream item. Dropping the sender
    /// ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Sends `text` through `chunks_tx` in randomly sized pieces of 1 to 10
    /// bytes, running the background executor until parked after each piece.
    /// It then drops the sender to end the stream and parks once more.
    ///
    /// Pieces are cut with `str::split_at`, so `text` must be ASCII (or at least
    /// every cut must land on a `char` boundary).
    fn stream_in_random_chunks(
        mut text: &str,
        chunks_tx: mpsc::UnboundedSender<String>,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();
    }
}
1827    }
1828}