buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use client::telemetry::Telemetry;
   5use cloud_llm_client::CompletionIntent;
   6use collections::HashSet;
   7use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   8use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
   9use futures::{
  10    SinkExt, Stream, StreamExt, TryStreamExt as _,
  11    channel::mpsc,
  12    future::{LocalBoxFuture, Shared},
  13    join,
  14};
  15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  16use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
  17use language_model::{
  18    LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
  19    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
  20    report_assistant_event,
  21};
  22use multi_buffer::MultiBufferRow;
  23use parking_lot::Mutex;
  24use prompt_store::PromptBuilder;
  25use rope::Rope;
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use smol::future::FutureExt;
  29use std::{
  30    cmp,
  31    future::Future,
  32    iter,
  33    ops::{Range, RangeInclusive},
  34    pin::Pin,
  35    sync::Arc,
  36    task::{self, Poll},
  37    time::Instant,
  38};
  39use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  40use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
  41use ui::SharedString;
  42
// NOTE: `schemars` copies the `///` doc comments below into the JSON schema that is sent to
// the model as the `failure_message` tool's input schema (see `build_request_v2`), so they
// are model-facing prompt text rather than maintainer documentation. Keep maintainer notes
// in plain `//` comments like this one so they do not leak into the schema.
/// Use this tool to provide a message to the user when you're unable to complete a task.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    ///
    /// The message may use markdown formatting if you wish.
    pub message: String,
}
  51
  52/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
  53#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  54pub struct RewriteSectionInput {
  55    /// A brief description of the edit you have made.
  56    ///
  57    /// The description may use markdown formatting if you wish.
  58    /// This is optional - if the edit is simple or obvious, you should leave it empty.
  59    pub description: String,
  60
  61    /// The text to replace the section with.
  62    pub replacement_text: String,
  63}
  64
  65pub struct BufferCodegen {
  66    alternatives: Vec<Entity<CodegenAlternative>>,
  67    pub active_alternative: usize,
  68    seen_alternatives: HashSet<usize>,
  69    subscriptions: Vec<Subscription>,
  70    buffer: Entity<MultiBuffer>,
  71    range: Range<Anchor>,
  72    initial_transaction_id: Option<TransactionId>,
  73    telemetry: Arc<Telemetry>,
  74    builder: Arc<PromptBuilder>,
  75    pub is_insertion: bool,
  76}
  77
  78impl BufferCodegen {
  79    pub fn new(
  80        buffer: Entity<MultiBuffer>,
  81        range: Range<Anchor>,
  82        initial_transaction_id: Option<TransactionId>,
  83        telemetry: Arc<Telemetry>,
  84        builder: Arc<PromptBuilder>,
  85        cx: &mut Context<Self>,
  86    ) -> Self {
  87        let codegen = cx.new(|cx| {
  88            CodegenAlternative::new(
  89                buffer.clone(),
  90                range.clone(),
  91                false,
  92                Some(telemetry.clone()),
  93                builder.clone(),
  94                cx,
  95            )
  96        });
  97        let mut this = Self {
  98            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
  99            alternatives: vec![codegen],
 100            active_alternative: 0,
 101            seen_alternatives: HashSet::default(),
 102            subscriptions: Vec::new(),
 103            buffer,
 104            range,
 105            initial_transaction_id,
 106            telemetry,
 107            builder,
 108        };
 109        this.activate(0, cx);
 110        this
 111    }
 112
 113    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 114        let codegen = self.active_alternative().clone();
 115        self.subscriptions.clear();
 116        self.subscriptions
 117            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 118        self.subscriptions
 119            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 120    }
 121
 122    pub fn active_completion(&self, cx: &App) -> Option<String> {
 123        self.active_alternative().read(cx).current_completion()
 124    }
 125
 126    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 127        &self.alternatives[self.active_alternative]
 128    }
 129
 130    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 131        &self.active_alternative().read(cx).status
 132    }
 133
 134    pub fn alternative_count(&self, cx: &App) -> usize {
 135        LanguageModelRegistry::read_global(cx)
 136            .inline_alternative_models()
 137            .len()
 138            + 1
 139    }
 140
 141    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 142        let next_active_ix = if self.active_alternative == 0 {
 143            self.alternatives.len() - 1
 144        } else {
 145            self.active_alternative - 1
 146        };
 147        self.activate(next_active_ix, cx);
 148    }
 149
 150    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 151        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 152        self.activate(next_active_ix, cx);
 153    }
 154
 155    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 156        self.active_alternative()
 157            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 158        self.seen_alternatives.insert(index);
 159        self.active_alternative = index;
 160        self.active_alternative()
 161            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 162        self.subscribe_to_alternative(cx);
 163        cx.notify();
 164    }
 165
 166    pub fn start(
 167        &mut self,
 168        primary_model: Arc<dyn LanguageModel>,
 169        user_prompt: String,
 170        context_task: Shared<Task<Option<LoadedContext>>>,
 171        cx: &mut Context<Self>,
 172    ) -> Result<()> {
 173        let alternative_models = LanguageModelRegistry::read_global(cx)
 174            .inline_alternative_models()
 175            .to_vec();
 176
 177        self.active_alternative()
 178            .update(cx, |alternative, cx| alternative.undo(cx));
 179        self.activate(0, cx);
 180        self.alternatives.truncate(1);
 181
 182        for _ in 0..alternative_models.len() {
 183            self.alternatives.push(cx.new(|cx| {
 184                CodegenAlternative::new(
 185                    self.buffer.clone(),
 186                    self.range.clone(),
 187                    false,
 188                    Some(self.telemetry.clone()),
 189                    self.builder.clone(),
 190                    cx,
 191                )
 192            }));
 193        }
 194
 195        for (model, alternative) in iter::once(primary_model)
 196            .chain(alternative_models)
 197            .zip(&self.alternatives)
 198        {
 199            alternative.update(cx, |alternative, cx| {
 200                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 201            })?;
 202        }
 203
 204        Ok(())
 205    }
 206
 207    pub fn stop(&mut self, cx: &mut Context<Self>) {
 208        for codegen in &self.alternatives {
 209            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 210        }
 211    }
 212
 213    pub fn undo(&mut self, cx: &mut Context<Self>) {
 214        self.active_alternative()
 215            .update(cx, |codegen, cx| codegen.undo(cx));
 216
 217        self.buffer.update(cx, |buffer, cx| {
 218            if let Some(transaction_id) = self.initial_transaction_id.take() {
 219                buffer.undo_transaction(transaction_id, cx);
 220                buffer.refresh_preview(cx);
 221            }
 222        });
 223    }
 224
 225    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 226        self.active_alternative().read(cx).buffer.clone()
 227    }
 228
 229    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 230        self.active_alternative().read(cx).old_buffer.clone()
 231    }
 232
 233    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 234        self.active_alternative().read(cx).snapshot.clone()
 235    }
 236
 237    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 238        self.active_alternative().read(cx).edit_position
 239    }
 240
 241    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 242        &self.active_alternative().read(cx).diff
 243    }
 244
 245    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 246        self.active_alternative().read(cx).last_equal_ranges()
 247    }
 248
 249    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 250        self.active_alternative().read(cx).selected_text()
 251    }
 252}
 253
 254impl EventEmitter<CodegenEvent> for BufferCodegen {}
 255
 256pub struct CodegenAlternative {
 257    buffer: Entity<MultiBuffer>,
 258    old_buffer: Entity<Buffer>,
 259    snapshot: MultiBufferSnapshot,
 260    edit_position: Option<Anchor>,
 261    range: Range<Anchor>,
 262    last_equal_ranges: Vec<Range<Anchor>>,
 263    transformation_transaction_id: Option<TransactionId>,
 264    status: CodegenStatus,
 265    generation: Task<()>,
 266    diff: Diff,
 267    telemetry: Option<Arc<Telemetry>>,
 268    _subscription: gpui::Subscription,
 269    builder: Arc<PromptBuilder>,
 270    active: bool,
 271    edits: Vec<(Range<Anchor>, String)>,
 272    line_operations: Vec<LineOperation>,
 273    elapsed_time: Option<f64>,
 274    completion: Option<String>,
 275    selected_text: Option<String>,
 276    pub message_id: Option<String>,
 277    pub model_explanation: Option<SharedString>,
 278}
 279
 280impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 281
 282impl CodegenAlternative {
 283    pub fn new(
 284        buffer: Entity<MultiBuffer>,
 285        range: Range<Anchor>,
 286        active: bool,
 287        telemetry: Option<Arc<Telemetry>>,
 288        builder: Arc<PromptBuilder>,
 289        cx: &mut Context<Self>,
 290    ) -> Self {
 291        let snapshot = buffer.read(cx).snapshot(cx);
 292
 293        let (old_buffer, _, _) = snapshot
 294            .range_to_buffer_ranges(range.clone())
 295            .pop()
 296            .unwrap();
 297        let old_buffer = cx.new(|cx| {
 298            let text = old_buffer.as_rope().clone();
 299            let line_ending = old_buffer.line_ending();
 300            let language = old_buffer.language().cloned();
 301            let language_registry = buffer
 302                .read(cx)
 303                .buffer(old_buffer.remote_id())
 304                .unwrap()
 305                .read(cx)
 306                .language_registry();
 307
 308            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 309            buffer.set_language(language, cx);
 310            if let Some(language_registry) = language_registry {
 311                buffer.set_language_registry(language_registry);
 312            }
 313            buffer
 314        });
 315
 316        Self {
 317            buffer: buffer.clone(),
 318            old_buffer,
 319            edit_position: None,
 320            message_id: None,
 321            snapshot,
 322            last_equal_ranges: Default::default(),
 323            transformation_transaction_id: None,
 324            status: CodegenStatus::Idle,
 325            generation: Task::ready(()),
 326            diff: Diff::default(),
 327            telemetry,
 328            builder,
 329            active: active,
 330            edits: Vec::new(),
 331            line_operations: Vec::new(),
 332            range,
 333            elapsed_time: None,
 334            completion: None,
 335            selected_text: None,
 336            model_explanation: None,
 337            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 338        }
 339    }
 340
 341    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 342        if active != self.active {
 343            self.active = active;
 344
 345            if self.active {
 346                let edits = self.edits.clone();
 347                self.apply_edits(edits, cx);
 348                if matches!(self.status, CodegenStatus::Pending) {
 349                    let line_operations = self.line_operations.clone();
 350                    self.reapply_line_based_diff(line_operations, cx);
 351                } else {
 352                    self.reapply_batch_diff(cx).detach();
 353                }
 354            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 355                self.buffer.update(cx, |buffer, cx| {
 356                    buffer.undo_transaction(transaction_id, cx);
 357                    buffer.forget_transaction(transaction_id, cx);
 358                });
 359            }
 360        }
 361    }
 362
 363    fn handle_buffer_event(
 364        &mut self,
 365        _buffer: Entity<MultiBuffer>,
 366        event: &multi_buffer::Event,
 367        cx: &mut Context<Self>,
 368    ) {
 369        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 370            && self.transformation_transaction_id == Some(*transaction_id)
 371        {
 372            self.transformation_transaction_id = None;
 373            self.generation = Task::ready(());
 374            cx.emit(CodegenEvent::Undone);
 375        }
 376    }
 377
 378    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 379        &self.last_equal_ranges
 380    }
 381
 382    pub fn start(
 383        &mut self,
 384        user_prompt: String,
 385        context_task: Shared<Task<Option<LoadedContext>>>,
 386        model: Arc<dyn LanguageModel>,
 387        cx: &mut Context<Self>,
 388    ) -> Result<()> {
 389        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 390            self.buffer.update(cx, |buffer, cx| {
 391                buffer.undo_transaction(transformation_transaction_id, cx);
 392            });
 393        }
 394
 395        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 396
 397        let api_key = model.api_key(cx);
 398        let telemetry_id = model.telemetry_id();
 399        let provider_id = model.provider_id();
 400
 401        if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
 402            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 403            let tool_use =
 404                cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
 405            self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
 406        } else {
 407            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 408                if user_prompt.trim().to_lowercase() == "delete" {
 409                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 410                } else {
 411                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 412                    cx.spawn(async move |_, cx| {
 413                        Ok(model.stream_completion_text(request.await, cx).await?)
 414                    })
 415                    .boxed_local()
 416                };
 417            self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
 418        }
 419
 420        Ok(())
 421    }
 422
 423    fn build_request_v2(
 424        &self,
 425        model: &Arc<dyn LanguageModel>,
 426        user_prompt: String,
 427        context_task: Shared<Task<Option<LoadedContext>>>,
 428        cx: &mut App,
 429    ) -> Result<Task<LanguageModelRequest>> {
 430        let buffer = self.buffer.read(cx).snapshot(cx);
 431        let language = buffer.language_at(self.range.start);
 432        let language_name = if let Some(language) = language.as_ref() {
 433            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 434                None
 435            } else {
 436                Some(language.name())
 437            }
 438        } else {
 439            None
 440        };
 441
 442        let language_name = language_name.as_ref();
 443        let start = buffer.point_to_buffer_offset(self.range.start);
 444        let end = buffer.point_to_buffer_offset(self.range.end);
 445        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 446            let (start_buffer, start_buffer_offset) = start;
 447            let (end_buffer, end_buffer_offset) = end;
 448            if start_buffer.remote_id() == end_buffer.remote_id() {
 449                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 450            } else {
 451                anyhow::bail!("invalid transformation range");
 452            }
 453        } else {
 454            anyhow::bail!("invalid transformation range");
 455        };
 456
 457        let system_prompt = self
 458            .builder
 459            .generate_inline_transformation_prompt_v2(
 460                language_name,
 461                buffer,
 462                range.start.0..range.end.0,
 463            )
 464            .context("generating content prompt")?;
 465
 466        let temperature = AgentSettings::temperature_for_model(model, cx);
 467
 468        let tool_input_format = model.tool_input_format();
 469
 470        Ok(cx.spawn(async move |_cx| {
 471            let mut messages = vec![LanguageModelRequestMessage {
 472                role: Role::System,
 473                content: vec![system_prompt.into()],
 474                cache: false,
 475                reasoning_details: None,
 476            }];
 477
 478            let mut user_message = LanguageModelRequestMessage {
 479                role: Role::User,
 480                content: Vec::new(),
 481                cache: false,
 482                reasoning_details: None,
 483            };
 484
 485            if let Some(context) = context_task.await {
 486                context.add_to_request_message(&mut user_message);
 487            }
 488
 489            user_message.content.push(user_prompt.into());
 490            messages.push(user_message);
 491
 492            let tools = vec![
 493                LanguageModelRequestTool {
 494                    name: "rewrite_section".to_string(),
 495                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 496                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 497                },
 498                LanguageModelRequestTool {
 499                    name: "failure_message".to_string(),
 500                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 501                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 502                },
 503            ];
 504
 505            LanguageModelRequest {
 506                thread_id: None,
 507                prompt_id: None,
 508                intent: Some(CompletionIntent::InlineAssist),
 509                mode: None,
 510                tools,
 511                tool_choice: None,
 512                stop: Vec::new(),
 513                temperature,
 514                messages,
 515                thinking_allowed: false,
 516            }
 517        }))
 518    }
 519
 520    fn build_request(
 521        &self,
 522        model: &Arc<dyn LanguageModel>,
 523        user_prompt: String,
 524        context_task: Shared<Task<Option<LoadedContext>>>,
 525        cx: &mut App,
 526    ) -> Result<Task<LanguageModelRequest>> {
 527        if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
 528            return self.build_request_v2(model, user_prompt, context_task, cx);
 529        }
 530
 531        let buffer = self.buffer.read(cx).snapshot(cx);
 532        let language = buffer.language_at(self.range.start);
 533        let language_name = if let Some(language) = language.as_ref() {
 534            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 535                None
 536            } else {
 537                Some(language.name())
 538            }
 539        } else {
 540            None
 541        };
 542
 543        let language_name = language_name.as_ref();
 544        let start = buffer.point_to_buffer_offset(self.range.start);
 545        let end = buffer.point_to_buffer_offset(self.range.end);
 546        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 547            let (start_buffer, start_buffer_offset) = start;
 548            let (end_buffer, end_buffer_offset) = end;
 549            if start_buffer.remote_id() == end_buffer.remote_id() {
 550                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 551            } else {
 552                anyhow::bail!("invalid transformation range");
 553            }
 554        } else {
 555            anyhow::bail!("invalid transformation range");
 556        };
 557
 558        let prompt = self
 559            .builder
 560            .generate_inline_transformation_prompt(
 561                user_prompt,
 562                language_name,
 563                buffer,
 564                range.start.0..range.end.0,
 565            )
 566            .context("generating content prompt")?;
 567
 568        let temperature = AgentSettings::temperature_for_model(model, cx);
 569
 570        Ok(cx.spawn(async move |_cx| {
 571            let mut request_message = LanguageModelRequestMessage {
 572                role: Role::User,
 573                content: Vec::new(),
 574                cache: false,
 575                reasoning_details: None,
 576            };
 577
 578            if let Some(context) = context_task.await {
 579                context.add_to_request_message(&mut request_message);
 580            }
 581
 582            request_message.content.push(prompt.into());
 583
 584            LanguageModelRequest {
 585                thread_id: None,
 586                prompt_id: None,
 587                intent: Some(CompletionIntent::InlineAssist),
 588                mode: None,
 589                tools: Vec::new(),
 590                tool_choice: None,
 591                stop: Vec::new(),
 592                temperature,
 593                messages: vec![request_message],
 594                thinking_allowed: false,
 595            }
 596        }))
 597    }
 598
 599    pub fn handle_stream(
 600        &mut self,
 601        model_telemetry_id: String,
 602        model_provider_id: String,
 603        model_api_key: Option<String>,
 604        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 605        cx: &mut Context<Self>,
 606    ) {
 607        let start_time = Instant::now();
 608
 609        // Make a new snapshot and re-resolve anchor in case the document was modified.
 610        // This can happen often if the editor loses focus and is saved + reformatted,
 611        // as in https://github.com/zed-industries/zed/issues/39088
 612        self.snapshot = self.buffer.read(cx).snapshot(cx);
 613        self.range = self.snapshot.anchor_after(self.range.start)
 614            ..self.snapshot.anchor_after(self.range.end);
 615
 616        let snapshot = self.snapshot.clone();
 617        let selected_text = snapshot
 618            .text_for_range(self.range.start..self.range.end)
 619            .collect::<Rope>();
 620
 621        self.selected_text = Some(selected_text.to_string());
 622
 623        let selection_start = self.range.start.to_point(&snapshot);
 624
 625        // Start with the indentation of the first line in the selection
 626        let mut suggested_line_indent = snapshot
 627            .suggested_indents(selection_start.row..=selection_start.row, cx)
 628            .into_values()
 629            .next()
 630            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 631
 632        // If the first line in the selection does not have indentation, check the following lines
 633        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 634            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 635                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 636                // Prefer tabs if a line in the selection uses tabs as indentation
 637                if line_indent.kind == IndentKind::Tab {
 638                    suggested_line_indent.kind = IndentKind::Tab;
 639                    break;
 640                }
 641            }
 642        }
 643
 644        let http_client = cx.http_client();
 645        let telemetry = self.telemetry.clone();
 646        let language_name = {
 647            let multibuffer = self.buffer.read(cx);
 648            let snapshot = multibuffer.snapshot(cx);
 649            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 650            ranges
 651                .first()
 652                .and_then(|(buffer, _, _)| buffer.language())
 653                .map(|language| language.name())
 654        };
 655
 656        self.diff = Diff::default();
 657        self.status = CodegenStatus::Pending;
 658        let mut edit_start = self.range.start.to_offset(&snapshot);
 659        let completion = Arc::new(Mutex::new(String::new()));
 660        let completion_clone = completion.clone();
 661
 662        self.generation = cx.spawn(async move |codegen, cx| {
 663            let stream = stream.await;
 664
 665            let token_usage = stream
 666                .as_ref()
 667                .ok()
 668                .map(|stream| stream.last_token_usage.clone());
 669            let message_id = stream
 670                .as_ref()
 671                .ok()
 672                .and_then(|stream| stream.message_id.clone());
 673            let generate = async {
 674                let model_telemetry_id = model_telemetry_id.clone();
 675                let model_provider_id = model_provider_id.clone();
 676                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 677                let executor = cx.background_executor().clone();
 678                let message_id = message_id.clone();
 679                let line_based_stream_diff: Task<anyhow::Result<()>> =
 680                    cx.background_spawn(async move {
 681                        let mut response_latency = None;
 682                        let request_start = Instant::now();
 683                        let diff = async {
 684                            let chunks = StripInvalidSpans::new(
 685                                stream?.stream.map_err(|error| error.into()),
 686                            );
 687                            futures::pin_mut!(chunks);
 688                            let mut diff = StreamingDiff::new(selected_text.to_string());
 689                            let mut line_diff = LineDiff::default();
 690
 691                            let mut new_text = String::new();
 692                            let mut base_indent = None;
 693                            let mut line_indent = None;
 694                            let mut first_line = true;
 695
 696                            while let Some(chunk) = chunks.next().await {
 697                                if response_latency.is_none() {
 698                                    response_latency = Some(request_start.elapsed());
 699                                }
 700                                let chunk = chunk?;
 701                                completion_clone.lock().push_str(&chunk);
 702
 703                                let mut lines = chunk.split('\n').peekable();
 704                                while let Some(line) = lines.next() {
 705                                    new_text.push_str(line);
 706                                    if line_indent.is_none()
 707                                        && let Some(non_whitespace_ch_ix) =
 708                                            new_text.find(|ch: char| !ch.is_whitespace())
 709                                    {
 710                                        line_indent = Some(non_whitespace_ch_ix);
 711                                        base_indent = base_indent.or(line_indent);
 712
 713                                        let line_indent = line_indent.unwrap();
 714                                        let base_indent = base_indent.unwrap();
 715                                        let indent_delta = line_indent as i32 - base_indent as i32;
 716                                        let mut corrected_indent_len = cmp::max(
 717                                            0,
 718                                            suggested_line_indent.len as i32 + indent_delta,
 719                                        )
 720                                            as usize;
 721                                        if first_line {
 722                                            corrected_indent_len = corrected_indent_len
 723                                                .saturating_sub(selection_start.column as usize);
 724                                        }
 725
 726                                        let indent_char = suggested_line_indent.char();
 727                                        let mut indent_buffer = [0; 4];
 728                                        let indent_str =
 729                                            indent_char.encode_utf8(&mut indent_buffer);
 730                                        new_text.replace_range(
 731                                            ..line_indent,
 732                                            &indent_str.repeat(corrected_indent_len),
 733                                        );
 734                                    }
 735
 736                                    if line_indent.is_some() {
 737                                        let char_ops = diff.push_new(&new_text);
 738                                        line_diff.push_char_operations(&char_ops, &selected_text);
 739                                        diff_tx
 740                                            .send((char_ops, line_diff.line_operations()))
 741                                            .await?;
 742                                        new_text.clear();
 743                                    }
 744
 745                                    if lines.peek().is_some() {
 746                                        let char_ops = diff.push_new("\n");
 747                                        line_diff.push_char_operations(&char_ops, &selected_text);
 748                                        diff_tx
 749                                            .send((char_ops, line_diff.line_operations()))
 750                                            .await?;
 751                                        if line_indent.is_none() {
 752                                            // Don't write out the leading indentation in empty lines on the next line
 753                                            // This is the case where the above if statement didn't clear the buffer
 754                                            new_text.clear();
 755                                        }
 756                                        line_indent = None;
 757                                        first_line = false;
 758                                    }
 759                                }
 760                            }
 761
 762                            let mut char_ops = diff.push_new(&new_text);
 763                            char_ops.extend(diff.finish());
 764                            line_diff.push_char_operations(&char_ops, &selected_text);
 765                            line_diff.finish(&selected_text);
 766                            diff_tx
 767                                .send((char_ops, line_diff.line_operations()))
 768                                .await?;
 769
 770                            anyhow::Ok(())
 771                        };
 772
 773                        let result = diff.await;
 774
 775                        let error_message = result.as_ref().err().map(|error| error.to_string());
 776                        report_assistant_event(
 777                            AssistantEventData {
 778                                conversation_id: None,
 779                                message_id,
 780                                kind: AssistantKind::Inline,
 781                                phase: AssistantPhase::Response,
 782                                model: model_telemetry_id,
 783                                model_provider: model_provider_id,
 784                                response_latency,
 785                                error_message,
 786                                language_name: language_name.map(|name| name.to_proto()),
 787                            },
 788                            telemetry,
 789                            http_client,
 790                            model_api_key,
 791                            &executor,
 792                        );
 793
 794                        result?;
 795                        Ok(())
 796                    });
 797
 798                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 799                    codegen.update(cx, |codegen, cx| {
 800                        codegen.last_equal_ranges.clear();
 801
 802                        let edits = char_ops
 803                            .into_iter()
 804                            .filter_map(|operation| match operation {
 805                                CharOperation::Insert { text } => {
 806                                    let edit_start = snapshot.anchor_after(edit_start);
 807                                    Some((edit_start..edit_start, text))
 808                                }
 809                                CharOperation::Delete { bytes } => {
 810                                    let edit_end = edit_start + bytes;
 811                                    let edit_range = snapshot.anchor_after(edit_start)
 812                                        ..snapshot.anchor_before(edit_end);
 813                                    edit_start = edit_end;
 814                                    Some((edit_range, String::new()))
 815                                }
 816                                CharOperation::Keep { bytes } => {
 817                                    let edit_end = edit_start + bytes;
 818                                    let edit_range = snapshot.anchor_after(edit_start)
 819                                        ..snapshot.anchor_before(edit_end);
 820                                    edit_start = edit_end;
 821                                    codegen.last_equal_ranges.push(edit_range);
 822                                    None
 823                                }
 824                            })
 825                            .collect::<Vec<_>>();
 826
 827                        if codegen.active {
 828                            codegen.apply_edits(edits.iter().cloned(), cx);
 829                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 830                        }
 831                        codegen.edits.extend(edits);
 832                        codegen.line_operations = line_ops;
 833                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 834
 835                        cx.notify();
 836                    })?;
 837                }
 838
 839                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 840                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 841                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 842                let batch_diff_task =
 843                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 844                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 845                line_based_stream_diff?;
 846
 847                anyhow::Ok(())
 848            };
 849
 850            let result = generate.await;
 851            let elapsed_time = start_time.elapsed().as_secs_f64();
 852
 853            codegen
 854                .update(cx, |this, cx| {
 855                    this.message_id = message_id;
 856                    this.last_equal_ranges.clear();
 857                    if let Err(error) = result {
 858                        this.status = CodegenStatus::Error(error);
 859                    } else {
 860                        this.status = CodegenStatus::Done;
 861                    }
 862                    this.elapsed_time = Some(elapsed_time);
 863                    this.completion = Some(completion.lock().clone());
 864                    if let Some(usage) = token_usage {
 865                        let usage = usage.lock();
 866                        telemetry::event!(
 867                            "Inline Assistant Completion",
 868                            model = model_telemetry_id,
 869                            model_provider = model_provider_id,
 870                            input_tokens = usage.input_tokens,
 871                            output_tokens = usage.output_tokens,
 872                        )
 873                    }
 874
 875                    cx.emit(CodegenEvent::Finished);
 876                    cx.notify();
 877                })
 878                .ok();
 879        });
 880        cx.notify();
 881    }
 882
 883    pub fn current_completion(&self) -> Option<String> {
 884        self.completion.clone()
 885    }
 886
 887    pub fn selected_text(&self) -> Option<&str> {
 888        self.selected_text.as_deref()
 889    }
 890
 891    pub fn stop(&mut self, cx: &mut Context<Self>) {
 892        self.last_equal_ranges.clear();
 893        if self.diff.is_empty() {
 894            self.status = CodegenStatus::Idle;
 895        } else {
 896            self.status = CodegenStatus::Done;
 897        }
 898        self.generation = Task::ready(());
 899        cx.emit(CodegenEvent::Finished);
 900        cx.notify();
 901    }
 902
 903    pub fn undo(&mut self, cx: &mut Context<Self>) {
 904        self.buffer.update(cx, |buffer, cx| {
 905            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 906                buffer.undo_transaction(transaction_id, cx);
 907                buffer.refresh_preview(cx);
 908            }
 909        });
 910    }
 911
 912    fn apply_edits(
 913        &mut self,
 914        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 915        cx: &mut Context<CodegenAlternative>,
 916    ) {
 917        let transaction = self.buffer.update(cx, |buffer, cx| {
 918            // Avoid grouping agent edits with user edits.
 919            buffer.finalize_last_transaction(cx);
 920            buffer.start_transaction(cx);
 921            buffer.edit(edits, None, cx);
 922            buffer.end_transaction(cx)
 923        });
 924
 925        if let Some(transaction) = transaction {
 926            if let Some(first_transaction) = self.transformation_transaction_id {
 927                // Group all agent edits into the first transaction.
 928                self.buffer.update(cx, |buffer, cx| {
 929                    buffer.merge_transactions(transaction, first_transaction, cx)
 930                });
 931            } else {
 932                self.transformation_transaction_id = Some(transaction);
 933                self.buffer
 934                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 935            }
 936        }
 937    }
 938
 939    fn reapply_line_based_diff(
 940        &mut self,
 941        line_operations: impl IntoIterator<Item = LineOperation>,
 942        cx: &mut Context<Self>,
 943    ) {
 944        let old_snapshot = self.snapshot.clone();
 945        let old_range = self.range.to_point(&old_snapshot);
 946        let new_snapshot = self.buffer.read(cx).snapshot(cx);
 947        let new_range = self.range.to_point(&new_snapshot);
 948
 949        let mut old_row = old_range.start.row;
 950        let mut new_row = new_range.start.row;
 951
 952        self.diff.deleted_row_ranges.clear();
 953        self.diff.inserted_row_ranges.clear();
 954        for operation in line_operations {
 955            match operation {
 956                LineOperation::Keep { lines } => {
 957                    old_row += lines;
 958                    new_row += lines;
 959                }
 960                LineOperation::Delete { lines } => {
 961                    let old_end_row = old_row + lines - 1;
 962                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
 963
 964                    if let Some((_, last_deleted_row_range)) =
 965                        self.diff.deleted_row_ranges.last_mut()
 966                    {
 967                        if *last_deleted_row_range.end() + 1 == old_row {
 968                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
 969                        } else {
 970                            self.diff
 971                                .deleted_row_ranges
 972                                .push((new_row, old_row..=old_end_row));
 973                        }
 974                    } else {
 975                        self.diff
 976                            .deleted_row_ranges
 977                            .push((new_row, old_row..=old_end_row));
 978                    }
 979
 980                    old_row += lines;
 981                }
 982                LineOperation::Insert { lines } => {
 983                    let new_end_row = new_row + lines - 1;
 984                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
 985                    let end = new_snapshot.anchor_before(Point::new(
 986                        new_end_row,
 987                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
 988                    ));
 989                    self.diff.inserted_row_ranges.push(start..end);
 990                    new_row += lines;
 991                }
 992            }
 993
 994            cx.notify();
 995        }
 996    }
 997
 998    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
 999        let old_snapshot = self.snapshot.clone();
1000        let old_range = self.range.to_point(&old_snapshot);
1001        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1002        let new_range = self.range.to_point(&new_snapshot);
1003
1004        cx.spawn(async move |codegen, cx| {
1005            let (deleted_row_ranges, inserted_row_ranges) = cx
1006                .background_spawn(async move {
1007                    let old_text = old_snapshot
1008                        .text_for_range(
1009                            Point::new(old_range.start.row, 0)
1010                                ..Point::new(
1011                                    old_range.end.row,
1012                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1013                                ),
1014                        )
1015                        .collect::<String>();
1016                    let new_text = new_snapshot
1017                        .text_for_range(
1018                            Point::new(new_range.start.row, 0)
1019                                ..Point::new(
1020                                    new_range.end.row,
1021                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1022                                ),
1023                        )
1024                        .collect::<String>();
1025
1026                    let old_start_row = old_range.start.row;
1027                    let new_start_row = new_range.start.row;
1028                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1029                    let mut inserted_row_ranges = Vec::new();
1030                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1031                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1032                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1033                        if !old_rows.is_empty() {
1034                            deleted_row_ranges.push((
1035                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1036                                old_rows.start..=old_rows.end - 1,
1037                            ));
1038                        }
1039                        if !new_rows.is_empty() {
1040                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1041                            let new_end_row = new_rows.end - 1;
1042                            let end = new_snapshot.anchor_before(Point::new(
1043                                new_end_row,
1044                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1045                            ));
1046                            inserted_row_ranges.push(start..end);
1047                        }
1048                    }
1049                    (deleted_row_ranges, inserted_row_ranges)
1050                })
1051                .await;
1052
1053            codegen
1054                .update(cx, |codegen, cx| {
1055                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1056                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1057                    cx.notify();
1058                })
1059                .ok();
1060        })
1061    }
1062
1063    fn handle_tool_use(
1064        &mut self,
1065        _telemetry_id: String,
1066        _provider_id: String,
1067        _api_key: Option<String>,
1068        tool_use: impl 'static
1069        + Future<
1070            Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
1071        >,
1072        cx: &mut Context<Self>,
1073    ) {
1074        self.diff = Diff::default();
1075        self.status = CodegenStatus::Pending;
1076
1077        self.generation = cx.spawn(async move |codegen, cx| {
1078            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1079                let _ = codegen.update(cx, |this, cx| {
1080                    this.status = status;
1081                    cx.emit(CodegenEvent::Finished);
1082                    cx.notify();
1083                });
1084            };
1085
1086            let tool_use = tool_use.await;
1087
1088            match tool_use {
1089                Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
1090                    // Parse the input JSON into RewriteSectionInput
1091                    match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
1092                        Ok(input) => {
1093                            // Store the description if non-empty
1094                            let description = if !input.description.trim().is_empty() {
1095                                Some(input.description.clone())
1096                            } else {
1097                                None
1098                            };
1099
1100                            // Apply the replacement text to the buffer and compute diff
1101                            let batch_diff_task = codegen
1102                                .update(cx, |this, cx| {
1103                                    this.model_explanation = description.map(Into::into);
1104                                    let range = this.range.clone();
1105                                    this.apply_edits(
1106                                        std::iter::once((range, input.replacement_text)),
1107                                        cx,
1108                                    );
1109                                    this.reapply_batch_diff(cx)
1110                                })
1111                                .ok();
1112
1113                            // Wait for the diff computation to complete
1114                            if let Some(diff_task) = batch_diff_task {
1115                                diff_task.await;
1116                            }
1117
1118                            finish_with_status(CodegenStatus::Done, cx);
1119                            return;
1120                        }
1121                        Err(e) => {
1122                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1123                            return;
1124                        }
1125                    }
1126                }
1127                Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
1128                    // Handle failure message tool use
1129                    match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
1130                        Ok(input) => {
1131                            let _ = codegen.update(cx, |this, _cx| {
1132                                // Store the failure message as the tool description
1133                                this.model_explanation = Some(input.message.into());
1134                            });
1135                            finish_with_status(CodegenStatus::Done, cx);
1136                            return;
1137                        }
1138                        Err(e) => {
1139                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1140                            return;
1141                        }
1142                    }
1143                }
1144                Ok(_tool_use) => {
1145                    // Unexpected tool.
1146                    finish_with_status(CodegenStatus::Done, cx);
1147                    return;
1148                }
1149                Err(e) => {
1150                    finish_with_status(CodegenStatus::Error(e.into()), cx);
1151                    return;
1152                }
1153            }
1154        });
1155        cx.notify();
1156    }
1157}
1158
/// Events emitted by a codegen alternative.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// Generation has ended: it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): not emitted in this part of the file; presumably signals that the
    /// generated edits were undone — confirm at the emitting call site.
    Undone,
}
1164
/// Stream adapter that removes model artifacts from a stream of text chunks: a wrapping
/// Markdown code fence and `<|CURSOR|>` markers.
struct StripInvalidSpans<T> {
    /// Inner stream producing the raw text chunks.
    stream: T,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted yet.
    buffer: String,
    /// True until the first newline of the response has been seen.
    first_line: bool,
    /// Whether a complete line has been emitted whose trailing newline is still owed to the
    /// output.
    line_end: bool,
    /// Whether the response opened with a code fence. If so, a closing fence on the final
    /// line is dropped.
    starts_with_code_block: bool,
}
1173
1174impl<T> StripInvalidSpans<T>
1175where
1176    T: Stream<Item = Result<String>>,
1177{
1178    fn new(stream: T) -> Self {
1179        Self {
1180            stream,
1181            stream_done: false,
1182            buffer: String::new(),
1183            first_line: true,
1184            line_end: false,
1185            starts_with_code_block: false,
1186        }
1187    }
1188}
1189
1190impl<T> Stream for StripInvalidSpans<T>
1191where
1192    T: Stream<Item = Result<String>>,
1193{
1194    type Item = Result<String>;
1195
1196    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1197        const CODE_BLOCK_DELIMITER: &str = "```";
1198        const CURSOR_SPAN: &str = "<|CURSOR|>";
1199
1200        let this = unsafe { self.get_unchecked_mut() };
1201        loop {
1202            if !this.stream_done {
1203                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1204                match stream.as_mut().poll_next(cx) {
1205                    Poll::Ready(Some(Ok(chunk))) => {
1206                        this.buffer.push_str(&chunk);
1207                    }
1208                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1209                    Poll::Ready(None) => {
1210                        this.stream_done = true;
1211                    }
1212                    Poll::Pending => return Poll::Pending,
1213                }
1214            }
1215
1216            let mut chunk = String::new();
1217            let mut consumed = 0;
1218            if !this.buffer.is_empty() {
1219                let mut lines = this.buffer.split('\n').enumerate().peekable();
1220                while let Some((line_ix, line)) = lines.next() {
1221                    if line_ix > 0 {
1222                        this.first_line = false;
1223                    }
1224
1225                    if this.first_line {
1226                        let trimmed_line = line.trim();
1227                        if lines.peek().is_some() {
1228                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1229                                consumed += line.len() + 1;
1230                                this.starts_with_code_block = true;
1231                                continue;
1232                            }
1233                        } else if trimmed_line.is_empty()
1234                            || prefixes(CODE_BLOCK_DELIMITER)
1235                                .any(|prefix| trimmed_line.starts_with(prefix))
1236                        {
1237                            break;
1238                        }
1239                    }
1240
1241                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1242                    if lines.peek().is_some() {
1243                        if this.line_end {
1244                            chunk.push('\n');
1245                        }
1246
1247                        chunk.push_str(&line_without_cursor);
1248                        this.line_end = true;
1249                        consumed += line.len() + 1;
1250                    } else if this.stream_done {
1251                        if !this.starts_with_code_block
1252                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1253                        {
1254                            if this.line_end {
1255                                chunk.push('\n');
1256                            }
1257
1258                            chunk.push_str(line);
1259                        }
1260
1261                        consumed += line.len();
1262                    } else {
1263                        let trimmed_line = line.trim();
1264                        if trimmed_line.is_empty()
1265                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1266                            || prefixes(CODE_BLOCK_DELIMITER)
1267                                .any(|prefix| trimmed_line.ends_with(prefix))
1268                        {
1269                            break;
1270                        } else {
1271                            if this.line_end {
1272                                chunk.push('\n');
1273                                this.line_end = false;
1274                            }
1275
1276                            chunk.push_str(&line_without_cursor);
1277                            consumed += line.len();
1278                        }
1279                    }
1280                }
1281            }
1282
1283            this.buffer = this.buffer.split_off(consumed);
1284            if !chunk.is_empty() {
1285                return Poll::Ready(Some(Ok(chunk)));
1286            } else if this.stream_done {
1287                return Poll::Ready(None);
1288            }
1289        }
1290    }
1291}
1292
/// Returns every non-empty proper prefix of `text`, shortest first (`"abc"` yields `"a"`
/// then `"ab"`).
///
/// `text` itself is not included, so an empty or single-character `text` yields nothing.
/// Prefixes always end on a `char` boundary, so non-ASCII input is handled without panicking.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1296
1297#[derive(Default)]
1298pub struct Diff {
1299    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1300    pub inserted_row_ranges: Vec<Range<Anchor>>,
1301}
1302
1303impl Diff {
1304    fn is_empty(&self) -> bool {
1305        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1306    }
1307}
1308
1309#[cfg(test)]
1310mod tests {
1311    use super::*;
1312    use futures::{
1313        Stream,
1314        stream::{self},
1315    };
1316    use gpui::TestAppContext;
1317    use indoc::indoc;
1318    use language::{Buffer, Point};
1319    use language_model::{LanguageModelRegistry, TokenUsage};
1320    use languages::rust_lang;
1321    use rand::prelude::*;
1322    use settings::SettingsStore;
1323    use std::{future, sync::Arc};
1324
1325    #[gpui::test(iterations = 10)]
1326    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1327        init_test(cx);
1328
1329        let text = indoc! {"
1330            fn main() {
1331                let x = 0;
1332                for _ in 0..10 {
1333                    x += 1;
1334                }
1335            }
1336        "};
1337        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1338        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1339        let range = buffer.read_with(cx, |buffer, cx| {
1340            let snapshot = buffer.snapshot(cx);
1341            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1342        });
1343        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1344        let codegen = cx.new(|cx| {
1345            CodegenAlternative::new(
1346                buffer.clone(),
1347                range.clone(),
1348                true,
1349                None,
1350                prompt_builder,
1351                cx,
1352            )
1353        });
1354
1355        let chunks_tx = simulate_response_stream(&codegen, cx);
1356
1357        let mut new_text = concat!(
1358            "       let mut x = 0;\n",
1359            "       while x < 10 {\n",
1360            "           x += 1;\n",
1361            "       }",
1362        );
1363        while !new_text.is_empty() {
1364            let max_len = cmp::min(new_text.len(), 10);
1365            let len = rng.random_range(1..=max_len);
1366            let (chunk, suffix) = new_text.split_at(len);
1367            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1368            new_text = suffix;
1369            cx.background_executor.run_until_parked();
1370        }
1371        drop(chunks_tx);
1372        cx.background_executor.run_until_parked();
1373
1374        assert_eq!(
1375            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1376            indoc! {"
1377                fn main() {
1378                    let mut x = 0;
1379                    while x < 10 {
1380                        x += 1;
1381                    }
1382                }
1383            "}
1384        );
1385    }
1386
1387    #[gpui::test(iterations = 10)]
1388    async fn test_autoindent_when_generating_past_indentation(
1389        cx: &mut TestAppContext,
1390        mut rng: StdRng,
1391    ) {
1392        init_test(cx);
1393
1394        let text = indoc! {"
1395            fn main() {
1396                le
1397            }
1398        "};
1399        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1400        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1401        let range = buffer.read_with(cx, |buffer, cx| {
1402            let snapshot = buffer.snapshot(cx);
1403            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1404        });
1405        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1406        let codegen = cx.new(|cx| {
1407            CodegenAlternative::new(
1408                buffer.clone(),
1409                range.clone(),
1410                true,
1411                None,
1412                prompt_builder,
1413                cx,
1414            )
1415        });
1416
1417        let chunks_tx = simulate_response_stream(&codegen, cx);
1418
1419        cx.background_executor.run_until_parked();
1420
1421        let mut new_text = concat!(
1422            "t mut x = 0;\n",
1423            "while x < 10 {\n",
1424            "    x += 1;\n",
1425            "}", //
1426        );
1427        while !new_text.is_empty() {
1428            let max_len = cmp::min(new_text.len(), 10);
1429            let len = rng.random_range(1..=max_len);
1430            let (chunk, suffix) = new_text.split_at(len);
1431            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1432            new_text = suffix;
1433            cx.background_executor.run_until_parked();
1434        }
1435        drop(chunks_tx);
1436        cx.background_executor.run_until_parked();
1437
1438        assert_eq!(
1439            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1440            indoc! {"
1441                fn main() {
1442                    let mut x = 0;
1443                    while x < 10 {
1444                        x += 1;
1445                    }
1446                }
1447            "}
1448        );
1449    }
1450
1451    #[gpui::test(iterations = 10)]
1452    async fn test_autoindent_when_generating_before_indentation(
1453        cx: &mut TestAppContext,
1454        mut rng: StdRng,
1455    ) {
1456        init_test(cx);
1457
1458        let text = concat!(
1459            "fn main() {\n",
1460            "  \n",
1461            "}\n" //
1462        );
1463        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1464        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1465        let range = buffer.read_with(cx, |buffer, cx| {
1466            let snapshot = buffer.snapshot(cx);
1467            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1468        });
1469        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1470        let codegen = cx.new(|cx| {
1471            CodegenAlternative::new(
1472                buffer.clone(),
1473                range.clone(),
1474                true,
1475                None,
1476                prompt_builder,
1477                cx,
1478            )
1479        });
1480
1481        let chunks_tx = simulate_response_stream(&codegen, cx);
1482
1483        cx.background_executor.run_until_parked();
1484
1485        let mut new_text = concat!(
1486            "let mut x = 0;\n",
1487            "while x < 10 {\n",
1488            "    x += 1;\n",
1489            "}", //
1490        );
1491        while !new_text.is_empty() {
1492            let max_len = cmp::min(new_text.len(), 10);
1493            let len = rng.random_range(1..=max_len);
1494            let (chunk, suffix) = new_text.split_at(len);
1495            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1496            new_text = suffix;
1497            cx.background_executor.run_until_parked();
1498        }
1499        drop(chunks_tx);
1500        cx.background_executor.run_until_parked();
1501
1502        assert_eq!(
1503            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1504            indoc! {"
1505                fn main() {
1506                    let mut x = 0;
1507                    while x < 10 {
1508                        x += 1;
1509                    }
1510                }
1511            "}
1512        );
1513    }
1514
1515    #[gpui::test(iterations = 10)]
1516    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1517        init_test(cx);
1518
1519        let text = indoc! {"
1520            func main() {
1521            \tx := 0
1522            \tfor i := 0; i < 10; i++ {
1523            \t\tx++
1524            \t}
1525            }
1526        "};
1527        let buffer = cx.new(|cx| Buffer::local(text, cx));
1528        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1529        let range = buffer.read_with(cx, |buffer, cx| {
1530            let snapshot = buffer.snapshot(cx);
1531            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1532        });
1533        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1534        let codegen = cx.new(|cx| {
1535            CodegenAlternative::new(
1536                buffer.clone(),
1537                range.clone(),
1538                true,
1539                None,
1540                prompt_builder,
1541                cx,
1542            )
1543        });
1544
1545        let chunks_tx = simulate_response_stream(&codegen, cx);
1546        let new_text = concat!(
1547            "func main() {\n",
1548            "\tx := 0\n",
1549            "\tfor x < 10 {\n",
1550            "\t\tx++\n",
1551            "\t}", //
1552        );
1553        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1554        drop(chunks_tx);
1555        cx.background_executor.run_until_parked();
1556
1557        assert_eq!(
1558            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1559            indoc! {"
1560                func main() {
1561                \tx := 0
1562                \tfor x < 10 {
1563                \t\tx++
1564                \t}
1565                }
1566            "}
1567        );
1568    }
1569
1570    #[gpui::test]
1571    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1572        init_test(cx);
1573
1574        let text = indoc! {"
1575            fn main() {
1576                let x = 0;
1577            }
1578        "};
1579        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1580        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1581        let range = buffer.read_with(cx, |buffer, cx| {
1582            let snapshot = buffer.snapshot(cx);
1583            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1584        });
1585        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1586        let codegen = cx.new(|cx| {
1587            CodegenAlternative::new(
1588                buffer.clone(),
1589                range.clone(),
1590                false,
1591                None,
1592                prompt_builder,
1593                cx,
1594            )
1595        });
1596
1597        let chunks_tx = simulate_response_stream(&codegen, cx);
1598        chunks_tx
1599            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1600            .unwrap();
1601        drop(chunks_tx);
1602        cx.run_until_parked();
1603
1604        // The codegen is inactive, so the buffer doesn't get modified.
1605        assert_eq!(
1606            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1607            text
1608        );
1609
1610        // Activating the codegen applies the changes.
1611        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1612        assert_eq!(
1613            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1614            indoc! {"
1615                fn main() {
1616                    let mut x = 0;
1617                    x += 1;
1618                }
1619            "}
1620        );
1621
1622        // Deactivating the codegen undoes the changes.
1623        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1624        cx.run_until_parked();
1625        assert_eq!(
1626            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1627            text
1628        );
1629    }
1630
1631    #[gpui::test]
1632    async fn test_strip_invalid_spans_from_codeblock() {
1633        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1634        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1635        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1636        assert_chunks(
1637            "```html\n```js\nLorem ipsum dolor\n```\n```",
1638            "```js\nLorem ipsum dolor\n```",
1639        )
1640        .await;
1641        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1642        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1643        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1644        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1645
1646        async fn assert_chunks(text: &str, expected_text: &str) {
1647            for chunk_size in 1..=text.len() {
1648                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1649                    .map(|chunk| chunk.unwrap())
1650                    .collect::<String>()
1651                    .await;
1652                assert_eq!(
1653                    actual_text, expected_text,
1654                    "failed to strip invalid spans, chunk size: {}",
1655                    chunk_size
1656                );
1657            }
1658        }
1659
1660        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1661            stream::iter(
1662                text.chars()
1663                    .collect::<Vec<_>>()
1664                    .chunks(size)
1665                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
1666                    .collect::<Vec<_>>(),
1667            )
1668        }
1669    }
1670
1671    fn init_test(cx: &mut TestAppContext) {
1672        cx.update(LanguageModelRegistry::test);
1673        cx.set_global(cx.update(SettingsStore::test));
1674    }
1675
1676    fn simulate_response_stream(
1677        codegen: &Entity<CodegenAlternative>,
1678        cx: &mut TestAppContext,
1679    ) -> mpsc::UnboundedSender<String> {
1680        let (chunks_tx, chunks_rx) = mpsc::unbounded();
1681        codegen.update(cx, |codegen, cx| {
1682            codegen.handle_stream(
1683                String::new(),
1684                String::new(),
1685                None,
1686                future::ready(Ok(LanguageModelTextStream {
1687                    message_id: None,
1688                    stream: chunks_rx.map(Ok).boxed(),
1689                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1690                })),
1691                cx,
1692            );
1693        });
1694        chunks_tx
1695    }
1696}