buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use uuid::Uuid;
   5
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
  10use futures::{
  11    SinkExt, Stream, StreamExt, TryStreamExt as _,
  12    channel::mpsc,
  13    future::{LocalBoxFuture, Shared},
  14    join,
  15    stream::BoxStream,
  16};
  17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  18use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  19use language_model::{
  20    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  21    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  22    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  23    LanguageModelToolUse, Role, TokenUsage,
  24};
  25use multi_buffer::MultiBufferRow;
  26use parking_lot::Mutex;
  27use prompt_store::PromptBuilder;
  28use rope::Rope;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings as _;
  32use smol::future::FutureExt;
  33use std::{
  34    cmp,
  35    future::Future,
  36    iter,
  37    ops::{Range, RangeInclusive},
  38    pin::Pin,
  39    sync::Arc,
  40    task::{self, Poll},
  41    time::Instant,
  42};
  43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  44use ui::SharedString;
  45
  46/// Use this tool to provide a message to the user when you're unable to complete a task.
  47#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  48pub struct FailureMessageInput {
  49    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
  50    ///
  51    /// The message may use markdown formatting if you wish.
  52    #[serde(default)]
  53    pub message: String,
  54}
  55
  56/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
  57#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  58pub struct RewriteSectionInput {
  59    /// A brief description of the edit you have made.
  60    ///
  61    /// The description may use markdown formatting if you wish.
  62    /// This is optional - if the edit is simple or obvious, you should leave it empty.
  63    #[serde(default)]
  64    pub description: String,
  65
  66    /// The text to replace the section with.
  67    #[serde(default)]
  68    pub replacement_text: String,
  69}
  70
  71pub struct BufferCodegen {
  72    alternatives: Vec<Entity<CodegenAlternative>>,
  73    pub active_alternative: usize,
  74    seen_alternatives: HashSet<usize>,
  75    subscriptions: Vec<Subscription>,
  76    buffer: Entity<MultiBuffer>,
  77    range: Range<Anchor>,
  78    initial_transaction_id: Option<TransactionId>,
  79    builder: Arc<PromptBuilder>,
  80    pub is_insertion: bool,
  81    session_id: Uuid,
  82}
  83
  84impl BufferCodegen {
  85    pub fn new(
  86        buffer: Entity<MultiBuffer>,
  87        range: Range<Anchor>,
  88        initial_transaction_id: Option<TransactionId>,
  89        session_id: Uuid,
  90        builder: Arc<PromptBuilder>,
  91        cx: &mut Context<Self>,
  92    ) -> Self {
  93        let codegen = cx.new(|cx| {
  94            CodegenAlternative::new(
  95                buffer.clone(),
  96                range.clone(),
  97                false,
  98                builder.clone(),
  99                session_id,
 100                cx,
 101            )
 102        });
 103        let mut this = Self {
 104            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 105            alternatives: vec![codegen],
 106            active_alternative: 0,
 107            seen_alternatives: HashSet::default(),
 108            subscriptions: Vec::new(),
 109            buffer,
 110            range,
 111            initial_transaction_id,
 112            builder,
 113            session_id,
 114        };
 115        this.activate(0, cx);
 116        this
 117    }
 118
 119    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 120        let codegen = self.active_alternative().clone();
 121        self.subscriptions.clear();
 122        self.subscriptions
 123            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 124        self.subscriptions
 125            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 126    }
 127
 128    pub fn active_completion(&self, cx: &App) -> Option<String> {
 129        self.active_alternative().read(cx).current_completion()
 130    }
 131
 132    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 133        &self.alternatives[self.active_alternative]
 134    }
 135
 136    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 137        self.active_alternative().read(cx).language_name(cx)
 138    }
 139
 140    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 141        &self.active_alternative().read(cx).status
 142    }
 143
 144    pub fn alternative_count(&self, cx: &App) -> usize {
 145        LanguageModelRegistry::read_global(cx)
 146            .inline_alternative_models()
 147            .len()
 148            + 1
 149    }
 150
 151    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 152        let next_active_ix = if self.active_alternative == 0 {
 153            self.alternatives.len() - 1
 154        } else {
 155            self.active_alternative - 1
 156        };
 157        self.activate(next_active_ix, cx);
 158    }
 159
 160    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 161        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 162        self.activate(next_active_ix, cx);
 163    }
 164
 165    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 166        self.active_alternative()
 167            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 168        self.seen_alternatives.insert(index);
 169        self.active_alternative = index;
 170        self.active_alternative()
 171            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 172        self.subscribe_to_alternative(cx);
 173        cx.notify();
 174    }
 175
 176    pub fn start(
 177        &mut self,
 178        primary_model: Arc<dyn LanguageModel>,
 179        user_prompt: String,
 180        context_task: Shared<Task<Option<LoadedContext>>>,
 181        cx: &mut Context<Self>,
 182    ) -> Result<()> {
 183        let alternative_models = LanguageModelRegistry::read_global(cx)
 184            .inline_alternative_models()
 185            .to_vec();
 186
 187        self.active_alternative()
 188            .update(cx, |alternative, cx| alternative.undo(cx));
 189        self.activate(0, cx);
 190        self.alternatives.truncate(1);
 191
 192        for _ in 0..alternative_models.len() {
 193            self.alternatives.push(cx.new(|cx| {
 194                CodegenAlternative::new(
 195                    self.buffer.clone(),
 196                    self.range.clone(),
 197                    false,
 198                    self.builder.clone(),
 199                    self.session_id,
 200                    cx,
 201                )
 202            }));
 203        }
 204
 205        for (model, alternative) in iter::once(primary_model)
 206            .chain(alternative_models)
 207            .zip(&self.alternatives)
 208        {
 209            alternative.update(cx, |alternative, cx| {
 210                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 211            })?;
 212        }
 213
 214        Ok(())
 215    }
 216
 217    pub fn stop(&mut self, cx: &mut Context<Self>) {
 218        for codegen in &self.alternatives {
 219            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 220        }
 221    }
 222
 223    pub fn undo(&mut self, cx: &mut Context<Self>) {
 224        self.active_alternative()
 225            .update(cx, |codegen, cx| codegen.undo(cx));
 226
 227        self.buffer.update(cx, |buffer, cx| {
 228            if let Some(transaction_id) = self.initial_transaction_id.take() {
 229                buffer.undo_transaction(transaction_id, cx);
 230                buffer.refresh_preview(cx);
 231            }
 232        });
 233    }
 234
 235    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 236        self.active_alternative().read(cx).buffer.clone()
 237    }
 238
 239    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 240        self.active_alternative().read(cx).old_buffer.clone()
 241    }
 242
 243    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 244        self.active_alternative().read(cx).snapshot.clone()
 245    }
 246
 247    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 248        self.active_alternative().read(cx).edit_position
 249    }
 250
 251    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 252        &self.active_alternative().read(cx).diff
 253    }
 254
 255    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 256        self.active_alternative().read(cx).last_equal_ranges()
 257    }
 258
 259    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 260        self.active_alternative().read(cx).selected_text()
 261    }
 262
 263    pub fn session_id(&self) -> Uuid {
 264        self.session_id
 265    }
 266}
 267
 268impl EventEmitter<CodegenEvent> for BufferCodegen {}
 269
 270pub struct CodegenAlternative {
 271    buffer: Entity<MultiBuffer>,
 272    old_buffer: Entity<Buffer>,
 273    snapshot: MultiBufferSnapshot,
 274    edit_position: Option<Anchor>,
 275    range: Range<Anchor>,
 276    last_equal_ranges: Vec<Range<Anchor>>,
 277    transformation_transaction_id: Option<TransactionId>,
 278    status: CodegenStatus,
 279    generation: Task<()>,
 280    diff: Diff,
 281    _subscription: gpui::Subscription,
 282    builder: Arc<PromptBuilder>,
 283    active: bool,
 284    edits: Vec<(Range<Anchor>, String)>,
 285    line_operations: Vec<LineOperation>,
 286    elapsed_time: Option<f64>,
 287    completion: Option<String>,
 288    selected_text: Option<String>,
 289    pub message_id: Option<String>,
 290    pub model_explanation: Option<SharedString>,
 291    session_id: Uuid,
 292}
 293
 294impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 295
 296impl CodegenAlternative {
 297    pub fn new(
 298        buffer: Entity<MultiBuffer>,
 299        range: Range<Anchor>,
 300        active: bool,
 301        builder: Arc<PromptBuilder>,
 302        session_id: Uuid,
 303        cx: &mut Context<Self>,
 304    ) -> Self {
 305        let snapshot = buffer.read(cx).snapshot(cx);
 306
 307        let (old_buffer, _, _) = snapshot
 308            .range_to_buffer_ranges(range.clone())
 309            .pop()
 310            .unwrap();
 311        let old_buffer = cx.new(|cx| {
 312            let text = old_buffer.as_rope().clone();
 313            let line_ending = old_buffer.line_ending();
 314            let language = old_buffer.language().cloned();
 315            let language_registry = buffer
 316                .read(cx)
 317                .buffer(old_buffer.remote_id())
 318                .unwrap()
 319                .read(cx)
 320                .language_registry();
 321
 322            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 323            buffer.set_language(language, cx);
 324            if let Some(language_registry) = language_registry {
 325                buffer.set_language_registry(language_registry);
 326            }
 327            buffer
 328        });
 329
 330        Self {
 331            buffer: buffer.clone(),
 332            old_buffer,
 333            edit_position: None,
 334            message_id: None,
 335            snapshot,
 336            last_equal_ranges: Default::default(),
 337            transformation_transaction_id: None,
 338            status: CodegenStatus::Idle,
 339            generation: Task::ready(()),
 340            diff: Diff::default(),
 341            builder,
 342            active: active,
 343            edits: Vec::new(),
 344            line_operations: Vec::new(),
 345            range,
 346            elapsed_time: None,
 347            completion: None,
 348            selected_text: None,
 349            model_explanation: None,
 350            session_id,
 351            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 352        }
 353    }
 354
 355    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 356        self.old_buffer
 357            .read(cx)
 358            .language()
 359            .map(|language| language.name())
 360    }
 361
 362    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 363        if active != self.active {
 364            self.active = active;
 365
 366            if self.active {
 367                let edits = self.edits.clone();
 368                self.apply_edits(edits, cx);
 369                if matches!(self.status, CodegenStatus::Pending) {
 370                    let line_operations = self.line_operations.clone();
 371                    self.reapply_line_based_diff(line_operations, cx);
 372                } else {
 373                    self.reapply_batch_diff(cx).detach();
 374                }
 375            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 376                self.buffer.update(cx, |buffer, cx| {
 377                    buffer.undo_transaction(transaction_id, cx);
 378                    buffer.forget_transaction(transaction_id, cx);
 379                });
 380            }
 381        }
 382    }
 383
 384    fn handle_buffer_event(
 385        &mut self,
 386        _buffer: Entity<MultiBuffer>,
 387        event: &multi_buffer::Event,
 388        cx: &mut Context<Self>,
 389    ) {
 390        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 391            && self.transformation_transaction_id == Some(*transaction_id)
 392        {
 393            self.transformation_transaction_id = None;
 394            self.generation = Task::ready(());
 395            cx.emit(CodegenEvent::Undone);
 396        }
 397    }
 398
 399    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 400        &self.last_equal_ranges
 401    }
 402
 403    fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 404        model.supports_streaming_tools()
 405            && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
 406            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 407    }
 408
 409    pub fn start(
 410        &mut self,
 411        user_prompt: String,
 412        context_task: Shared<Task<Option<LoadedContext>>>,
 413        model: Arc<dyn LanguageModel>,
 414        cx: &mut Context<Self>,
 415    ) -> Result<()> {
 416        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 417            self.buffer.update(cx, |buffer, cx| {
 418                buffer.undo_transaction(transformation_transaction_id, cx);
 419            });
 420        }
 421
 422        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 423
 424        if Self::use_streaming_tools(model.as_ref(), cx) {
 425            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 426            let completion_events = cx.spawn({
 427                let model = model.clone();
 428                async move |_, cx| model.stream_completion(request.await, cx).await
 429            });
 430            self.generation = self.handle_completion(model, completion_events, cx);
 431        } else {
 432            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 433                if user_prompt.trim().to_lowercase() == "delete" {
 434                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 435                } else {
 436                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 437                    cx.spawn({
 438                        let model = model.clone();
 439                        async move |_, cx| {
 440                            Ok(model.stream_completion_text(request.await, cx).await?)
 441                        }
 442                    })
 443                    .boxed_local()
 444                };
 445            self.generation = self.handle_stream(model, stream, cx);
 446        }
 447
 448        Ok(())
 449    }
 450
 451    fn build_request_tools(
 452        &self,
 453        model: &Arc<dyn LanguageModel>,
 454        user_prompt: String,
 455        context_task: Shared<Task<Option<LoadedContext>>>,
 456        cx: &mut App,
 457    ) -> Result<Task<LanguageModelRequest>> {
 458        let buffer = self.buffer.read(cx).snapshot(cx);
 459        let language = buffer.language_at(self.range.start);
 460        let language_name = if let Some(language) = language.as_ref() {
 461            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 462                None
 463            } else {
 464                Some(language.name())
 465            }
 466        } else {
 467            None
 468        };
 469
 470        let language_name = language_name.as_ref();
 471        let start = buffer.point_to_buffer_offset(self.range.start);
 472        let end = buffer.point_to_buffer_offset(self.range.end);
 473        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 474            let (start_buffer, start_buffer_offset) = start;
 475            let (end_buffer, end_buffer_offset) = end;
 476            if start_buffer.remote_id() == end_buffer.remote_id() {
 477                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 478            } else {
 479                anyhow::bail!("invalid transformation range");
 480            }
 481        } else {
 482            anyhow::bail!("invalid transformation range");
 483        };
 484
 485        let system_prompt = self
 486            .builder
 487            .generate_inline_transformation_prompt_tools(
 488                language_name,
 489                buffer,
 490                range.start.0..range.end.0,
 491            )
 492            .context("generating content prompt")?;
 493
 494        let temperature = AgentSettings::temperature_for_model(model, cx);
 495
 496        let tool_input_format = model.tool_input_format();
 497        let tool_choice = model
 498            .supports_tool_choice(LanguageModelToolChoice::Any)
 499            .then_some(LanguageModelToolChoice::Any);
 500
 501        Ok(cx.spawn(async move |_cx| {
 502            let mut messages = vec![LanguageModelRequestMessage {
 503                role: Role::System,
 504                content: vec![system_prompt.into()],
 505                cache: false,
 506                reasoning_details: None,
 507            }];
 508
 509            let mut user_message = LanguageModelRequestMessage {
 510                role: Role::User,
 511                content: Vec::new(),
 512                cache: false,
 513                reasoning_details: None,
 514            };
 515
 516            if let Some(context) = context_task.await {
 517                context.add_to_request_message(&mut user_message);
 518            }
 519
 520            user_message.content.push(user_prompt.into());
 521            messages.push(user_message);
 522
 523            let tools = vec![
 524                LanguageModelRequestTool {
 525                    name: "rewrite_section".to_string(),
 526                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 527                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 528                },
 529                LanguageModelRequestTool {
 530                    name: "failure_message".to_string(),
 531                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 532                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 533                },
 534            ];
 535
 536            LanguageModelRequest {
 537                thread_id: None,
 538                prompt_id: None,
 539                intent: Some(CompletionIntent::InlineAssist),
 540                mode: None,
 541                tools,
 542                tool_choice,
 543                stop: Vec::new(),
 544                temperature,
 545                messages,
 546                thinking_allowed: false,
 547            }
 548        }))
 549    }
 550
 551    fn build_request(
 552        &self,
 553        model: &Arc<dyn LanguageModel>,
 554        user_prompt: String,
 555        context_task: Shared<Task<Option<LoadedContext>>>,
 556        cx: &mut App,
 557    ) -> Result<Task<LanguageModelRequest>> {
 558        if Self::use_streaming_tools(model.as_ref(), cx) {
 559            return self.build_request_tools(model, user_prompt, context_task, cx);
 560        }
 561
 562        let buffer = self.buffer.read(cx).snapshot(cx);
 563        let language = buffer.language_at(self.range.start);
 564        let language_name = if let Some(language) = language.as_ref() {
 565            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 566                None
 567            } else {
 568                Some(language.name())
 569            }
 570        } else {
 571            None
 572        };
 573
 574        let language_name = language_name.as_ref();
 575        let start = buffer.point_to_buffer_offset(self.range.start);
 576        let end = buffer.point_to_buffer_offset(self.range.end);
 577        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 578            let (start_buffer, start_buffer_offset) = start;
 579            let (end_buffer, end_buffer_offset) = end;
 580            if start_buffer.remote_id() == end_buffer.remote_id() {
 581                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 582            } else {
 583                anyhow::bail!("invalid transformation range");
 584            }
 585        } else {
 586            anyhow::bail!("invalid transformation range");
 587        };
 588
 589        let prompt = self
 590            .builder
 591            .generate_inline_transformation_prompt(
 592                user_prompt,
 593                language_name,
 594                buffer,
 595                range.start.0..range.end.0,
 596            )
 597            .context("generating content prompt")?;
 598
 599        let temperature = AgentSettings::temperature_for_model(model, cx);
 600
 601        Ok(cx.spawn(async move |_cx| {
 602            let mut request_message = LanguageModelRequestMessage {
 603                role: Role::User,
 604                content: Vec::new(),
 605                cache: false,
 606                reasoning_details: None,
 607            };
 608
 609            if let Some(context) = context_task.await {
 610                context.add_to_request_message(&mut request_message);
 611            }
 612
 613            request_message.content.push(prompt.into());
 614
 615            LanguageModelRequest {
 616                thread_id: None,
 617                prompt_id: None,
 618                intent: Some(CompletionIntent::InlineAssist),
 619                mode: None,
 620                tools: Vec::new(),
 621                tool_choice: None,
 622                stop: Vec::new(),
 623                temperature,
 624                messages: vec![request_message],
 625                thinking_allowed: false,
 626            }
 627        }))
 628    }
 629
 630    pub fn handle_stream(
 631        &mut self,
 632        model: Arc<dyn LanguageModel>,
 633        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 634        cx: &mut Context<Self>,
 635    ) -> Task<()> {
 636        let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
 637        let session_id = self.session_id;
 638        let model_telemetry_id = model.telemetry_id();
 639        let model_provider_id = model.provider_id().to_string();
 640        let start_time = Instant::now();
 641
 642        // Make a new snapshot and re-resolve anchor in case the document was modified.
 643        // This can happen often if the editor loses focus and is saved + reformatted,
 644        // as in https://github.com/zed-industries/zed/issues/39088
 645        self.snapshot = self.buffer.read(cx).snapshot(cx);
 646        self.range = self.snapshot.anchor_after(self.range.start)
 647            ..self.snapshot.anchor_after(self.range.end);
 648
 649        let snapshot = self.snapshot.clone();
 650        let selected_text = snapshot
 651            .text_for_range(self.range.start..self.range.end)
 652            .collect::<Rope>();
 653
 654        self.selected_text = Some(selected_text.to_string());
 655
 656        let selection_start = self.range.start.to_point(&snapshot);
 657
 658        // Start with the indentation of the first line in the selection
 659        let mut suggested_line_indent = snapshot
 660            .suggested_indents(selection_start.row..=selection_start.row, cx)
 661            .into_values()
 662            .next()
 663            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 664
 665        // If the first line in the selection does not have indentation, check the following lines
 666        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 667            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 668                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 669                // Prefer tabs if a line in the selection uses tabs as indentation
 670                if line_indent.kind == IndentKind::Tab {
 671                    suggested_line_indent.kind = IndentKind::Tab;
 672                    break;
 673                }
 674            }
 675        }
 676
 677        let language_name = {
 678            let multibuffer = self.buffer.read(cx);
 679            let snapshot = multibuffer.snapshot(cx);
 680            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
 681            ranges
 682                .first()
 683                .and_then(|(buffer, _, _)| buffer.language())
 684                .map(|language| language.name())
 685        };
 686
 687        self.diff = Diff::default();
 688        self.status = CodegenStatus::Pending;
 689        let mut edit_start = self.range.start.to_offset(&snapshot);
 690        let completion = Arc::new(Mutex::new(String::new()));
 691        let completion_clone = completion.clone();
 692
 693        cx.notify();
 694        cx.spawn(async move |codegen, cx| {
 695            let stream = stream.await;
 696
 697            let token_usage = stream
 698                .as_ref()
 699                .ok()
 700                .map(|stream| stream.last_token_usage.clone());
 701            let message_id = stream
 702                .as_ref()
 703                .ok()
 704                .and_then(|stream| stream.message_id.clone());
 705            let generate = async {
 706                let model_telemetry_id = model_telemetry_id.clone();
 707                let model_provider_id = model_provider_id.clone();
 708                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 709                let message_id = message_id.clone();
 710                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 711                    let anthropic_reporter = anthropic_reporter.clone();
 712                    let language_name = language_name.clone();
 713                    async move {
 714                        let mut response_latency = None;
 715                        let request_start = Instant::now();
 716                        let diff = async {
 717                            let chunks = StripInvalidSpans::new(
 718                                stream?.stream.map_err(|error| error.into()),
 719                            );
 720                            futures::pin_mut!(chunks);
 721
 722                            let mut diff = StreamingDiff::new(selected_text.to_string());
 723                            let mut line_diff = LineDiff::default();
 724
 725                            let mut new_text = String::new();
 726                            let mut base_indent = None;
 727                            let mut line_indent = None;
 728                            let mut first_line = true;
 729
 730                            while let Some(chunk) = chunks.next().await {
 731                                if response_latency.is_none() {
 732                                    response_latency = Some(request_start.elapsed());
 733                                }
 734                                let chunk = chunk?;
 735                                completion_clone.lock().push_str(&chunk);
 736
 737                                let mut lines = chunk.split('\n').peekable();
 738                                while let Some(line) = lines.next() {
 739                                    new_text.push_str(line);
 740                                    if line_indent.is_none()
 741                                        && let Some(non_whitespace_ch_ix) =
 742                                            new_text.find(|ch: char| !ch.is_whitespace())
 743                                    {
 744                                        line_indent = Some(non_whitespace_ch_ix);
 745                                        base_indent = base_indent.or(line_indent);
 746
 747                                        let line_indent = line_indent.unwrap();
 748                                        let base_indent = base_indent.unwrap();
 749                                        let indent_delta = line_indent as i32 - base_indent as i32;
 750                                        let mut corrected_indent_len = cmp::max(
 751                                            0,
 752                                            suggested_line_indent.len as i32 + indent_delta,
 753                                        )
 754                                            as usize;
 755                                        if first_line {
 756                                            corrected_indent_len = corrected_indent_len
 757                                                .saturating_sub(selection_start.column as usize);
 758                                        }
 759
 760                                        let indent_char = suggested_line_indent.char();
 761                                        let mut indent_buffer = [0; 4];
 762                                        let indent_str =
 763                                            indent_char.encode_utf8(&mut indent_buffer);
 764                                        new_text.replace_range(
 765                                            ..line_indent,
 766                                            &indent_str.repeat(corrected_indent_len),
 767                                        );
 768                                    }
 769
 770                                    if line_indent.is_some() {
 771                                        let char_ops = diff.push_new(&new_text);
 772                                        line_diff.push_char_operations(&char_ops, &selected_text);
 773                                        diff_tx
 774                                            .send((char_ops, line_diff.line_operations()))
 775                                            .await?;
 776                                        new_text.clear();
 777                                    }
 778
 779                                    if lines.peek().is_some() {
 780                                        let char_ops = diff.push_new("\n");
 781                                        line_diff.push_char_operations(&char_ops, &selected_text);
 782                                        diff_tx
 783                                            .send((char_ops, line_diff.line_operations()))
 784                                            .await?;
 785                                        if line_indent.is_none() {
 786                                            // Don't write out the leading indentation in empty lines on the next line
 787                                            // This is the case where the above if statement didn't clear the buffer
 788                                            new_text.clear();
 789                                        }
 790                                        line_indent = None;
 791                                        first_line = false;
 792                                    }
 793                                }
 794                            }
 795
 796                            let mut char_ops = diff.push_new(&new_text);
 797                            char_ops.extend(diff.finish());
 798                            line_diff.push_char_operations(&char_ops, &selected_text);
 799                            line_diff.finish(&selected_text);
 800                            diff_tx
 801                                .send((char_ops, line_diff.line_operations()))
 802                                .await?;
 803
 804                            anyhow::Ok(())
 805                        };
 806
 807                        let result = diff.await;
 808
 809                        let error_message = result.as_ref().err().map(|error| error.to_string());
 810                        telemetry::event!(
 811                            "Assistant Responded",
 812                            kind = "inline",
 813                            phase = "response",
 814                            session_id = session_id.to_string(),
 815                            model = model_telemetry_id,
 816                            model_provider = model_provider_id,
 817                            language_name = language_name.as_ref().map(|n| n.to_string()),
 818                            message_id = message_id.as_deref(),
 819                            response_latency = response_latency,
 820                            error_message = error_message.as_deref(),
 821                        );
 822
 823                        anthropic_reporter.report(language_model::AnthropicEventData {
 824                            completion_type: language_model::AnthropicCompletionType::Editor,
 825                            event: language_model::AnthropicEventType::Response,
 826                            language_name: language_name.map(|n| n.to_string()),
 827                            message_id,
 828                        });
 829
 830                        result?;
 831                        Ok(())
 832                    }
 833                });
 834
 835                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 836                    codegen.update(cx, |codegen, cx| {
 837                        codegen.last_equal_ranges.clear();
 838
 839                        let edits = char_ops
 840                            .into_iter()
 841                            .filter_map(|operation| match operation {
 842                                CharOperation::Insert { text } => {
 843                                    let edit_start = snapshot.anchor_after(edit_start);
 844                                    Some((edit_start..edit_start, text))
 845                                }
 846                                CharOperation::Delete { bytes } => {
 847                                    let edit_end = edit_start + bytes;
 848                                    let edit_range = snapshot.anchor_after(edit_start)
 849                                        ..snapshot.anchor_before(edit_end);
 850                                    edit_start = edit_end;
 851                                    Some((edit_range, String::new()))
 852                                }
 853                                CharOperation::Keep { bytes } => {
 854                                    let edit_end = edit_start + bytes;
 855                                    let edit_range = snapshot.anchor_after(edit_start)
 856                                        ..snapshot.anchor_before(edit_end);
 857                                    edit_start = edit_end;
 858                                    codegen.last_equal_ranges.push(edit_range);
 859                                    None
 860                                }
 861                            })
 862                            .collect::<Vec<_>>();
 863
 864                        if codegen.active {
 865                            codegen.apply_edits(edits.iter().cloned(), cx);
 866                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 867                        }
 868                        codegen.edits.extend(edits);
 869                        codegen.line_operations = line_ops;
 870                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 871
 872                        cx.notify();
 873                    })?;
 874                }
 875
 876                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 877                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 878                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 879                let batch_diff_task =
 880                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 881                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 882                line_based_stream_diff?;
 883
 884                anyhow::Ok(())
 885            };
 886
 887            let result = generate.await;
 888            let elapsed_time = start_time.elapsed().as_secs_f64();
 889
 890            codegen
 891                .update(cx, |this, cx| {
 892                    this.message_id = message_id;
 893                    this.last_equal_ranges.clear();
 894                    if let Err(error) = result {
 895                        this.status = CodegenStatus::Error(error);
 896                    } else {
 897                        this.status = CodegenStatus::Done;
 898                    }
 899                    this.elapsed_time = Some(elapsed_time);
 900                    this.completion = Some(completion.lock().clone());
 901                    if let Some(usage) = token_usage {
 902                        let usage = usage.lock();
 903                        telemetry::event!(
 904                            "Inline Assistant Completion",
 905                            model = model_telemetry_id,
 906                            model_provider = model_provider_id,
 907                            input_tokens = usage.input_tokens,
 908                            output_tokens = usage.output_tokens,
 909                        )
 910                    }
 911
 912                    cx.emit(CodegenEvent::Finished);
 913                    cx.notify();
 914                })
 915                .ok();
 916        })
 917    }
 918
 919    pub fn current_completion(&self) -> Option<String> {
 920        self.completion.clone()
 921    }
 922
 923    pub fn selected_text(&self) -> Option<&str> {
 924        self.selected_text.as_deref()
 925    }
 926
 927    pub fn stop(&mut self, cx: &mut Context<Self>) {
 928        self.last_equal_ranges.clear();
 929        if self.diff.is_empty() {
 930            self.status = CodegenStatus::Idle;
 931        } else {
 932            self.status = CodegenStatus::Done;
 933        }
 934        self.generation = Task::ready(());
 935        cx.emit(CodegenEvent::Finished);
 936        cx.notify();
 937    }
 938
 939    pub fn undo(&mut self, cx: &mut Context<Self>) {
 940        self.buffer.update(cx, |buffer, cx| {
 941            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 942                buffer.undo_transaction(transaction_id, cx);
 943                buffer.refresh_preview(cx);
 944            }
 945        });
 946    }
 947
 948    fn apply_edits(
 949        &mut self,
 950        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 951        cx: &mut Context<CodegenAlternative>,
 952    ) {
 953        let transaction = self.buffer.update(cx, |buffer, cx| {
 954            // Avoid grouping agent edits with user edits.
 955            buffer.finalize_last_transaction(cx);
 956            buffer.start_transaction(cx);
 957            buffer.edit(edits, None, cx);
 958            buffer.end_transaction(cx)
 959        });
 960
 961        if let Some(transaction) = transaction {
 962            if let Some(first_transaction) = self.transformation_transaction_id {
 963                // Group all agent edits into the first transaction.
 964                self.buffer.update(cx, |buffer, cx| {
 965                    buffer.merge_transactions(transaction, first_transaction, cx)
 966                });
 967            } else {
 968                self.transformation_transaction_id = Some(transaction);
 969                self.buffer
 970                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 971            }
 972        }
 973    }
 974
 975    fn reapply_line_based_diff(
 976        &mut self,
 977        line_operations: impl IntoIterator<Item = LineOperation>,
 978        cx: &mut Context<Self>,
 979    ) {
 980        let old_snapshot = self.snapshot.clone();
 981        let old_range = self.range.to_point(&old_snapshot);
 982        let new_snapshot = self.buffer.read(cx).snapshot(cx);
 983        let new_range = self.range.to_point(&new_snapshot);
 984
 985        let mut old_row = old_range.start.row;
 986        let mut new_row = new_range.start.row;
 987
 988        self.diff.deleted_row_ranges.clear();
 989        self.diff.inserted_row_ranges.clear();
 990        for operation in line_operations {
 991            match operation {
 992                LineOperation::Keep { lines } => {
 993                    old_row += lines;
 994                    new_row += lines;
 995                }
 996                LineOperation::Delete { lines } => {
 997                    let old_end_row = old_row + lines - 1;
 998                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
 999
1000                    if let Some((_, last_deleted_row_range)) =
1001                        self.diff.deleted_row_ranges.last_mut()
1002                    {
1003                        if *last_deleted_row_range.end() + 1 == old_row {
1004                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1005                        } else {
1006                            self.diff
1007                                .deleted_row_ranges
1008                                .push((new_row, old_row..=old_end_row));
1009                        }
1010                    } else {
1011                        self.diff
1012                            .deleted_row_ranges
1013                            .push((new_row, old_row..=old_end_row));
1014                    }
1015
1016                    old_row += lines;
1017                }
1018                LineOperation::Insert { lines } => {
1019                    let new_end_row = new_row + lines - 1;
1020                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1021                    let end = new_snapshot.anchor_before(Point::new(
1022                        new_end_row,
1023                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1024                    ));
1025                    self.diff.inserted_row_ranges.push(start..end);
1026                    new_row += lines;
1027                }
1028            }
1029
1030            cx.notify();
1031        }
1032    }
1033
1034    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1035        let old_snapshot = self.snapshot.clone();
1036        let old_range = self.range.to_point(&old_snapshot);
1037        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1038        let new_range = self.range.to_point(&new_snapshot);
1039
1040        cx.spawn(async move |codegen, cx| {
1041            let (deleted_row_ranges, inserted_row_ranges) = cx
1042                .background_spawn(async move {
1043                    let old_text = old_snapshot
1044                        .text_for_range(
1045                            Point::new(old_range.start.row, 0)
1046                                ..Point::new(
1047                                    old_range.end.row,
1048                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1049                                ),
1050                        )
1051                        .collect::<String>();
1052                    let new_text = new_snapshot
1053                        .text_for_range(
1054                            Point::new(new_range.start.row, 0)
1055                                ..Point::new(
1056                                    new_range.end.row,
1057                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1058                                ),
1059                        )
1060                        .collect::<String>();
1061
1062                    let old_start_row = old_range.start.row;
1063                    let new_start_row = new_range.start.row;
1064                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1065                    let mut inserted_row_ranges = Vec::new();
1066                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1067                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1068                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1069                        if !old_rows.is_empty() {
1070                            deleted_row_ranges.push((
1071                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1072                                old_rows.start..=old_rows.end - 1,
1073                            ));
1074                        }
1075                        if !new_rows.is_empty() {
1076                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1077                            let new_end_row = new_rows.end - 1;
1078                            let end = new_snapshot.anchor_before(Point::new(
1079                                new_end_row,
1080                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1081                            ));
1082                            inserted_row_ranges.push(start..end);
1083                        }
1084                    }
1085                    (deleted_row_ranges, inserted_row_ranges)
1086                })
1087                .await;
1088
1089            codegen
1090                .update(cx, |codegen, cx| {
1091                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1092                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1093                    cx.notify();
1094                })
1095                .ok();
1096        })
1097    }
1098
1099    fn handle_completion(
1100        &mut self,
1101        model: Arc<dyn LanguageModel>,
1102        completion_stream: Task<
1103            Result<
1104                BoxStream<
1105                    'static,
1106                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1107                >,
1108                LanguageModelCompletionError,
1109            >,
1110        >,
1111        cx: &mut Context<Self>,
1112    ) -> Task<()> {
1113        self.diff = Diff::default();
1114        self.status = CodegenStatus::Pending;
1115
1116        cx.notify();
1117        // Leaving this in generation so that STOP equivalent events are respected even
1118        // while we're still pre-processing the completion event
1119        cx.spawn(async move |codegen, cx| {
1120            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1121                let _ = codegen.update(cx, |this, cx| {
1122                    this.status = status;
1123                    cx.emit(CodegenEvent::Finished);
1124                    cx.notify();
1125                });
1126            };
1127
1128            let mut completion_events = match completion_stream.await {
1129                Ok(events) => events,
1130                Err(err) => {
1131                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1132                    return;
1133                }
1134            };
1135
1136            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1137            let tool_to_text_and_message =
1138                move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
1139                    let mut chars_read_so_far = chars_read_so_far.lock();
1140                    match tool_use.name.as_ref() {
1141                        "rewrite_section" => {
1142                            let Ok(mut input) =
1143                                serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1144                            else {
1145                                return (None, None);
1146                            };
1147                            let value = input.replacement_text[*chars_read_so_far..].to_string();
1148                            *chars_read_so_far = input.replacement_text.len();
1149                            (Some(value), Some(std::mem::take(&mut input.description)))
1150                        }
1151                        "failure_message" => {
1152                            let Ok(mut input) =
1153                                serde_json::from_value::<FailureMessageInput>(tool_use.input)
1154                            else {
1155                                return (None, None);
1156                            };
1157                            (None, Some(std::mem::take(&mut input.message)))
1158                        }
1159                        _ => (None, None),
1160                    }
1161                };
1162
1163            let mut message_id = None;
1164            let mut first_text = None;
1165            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1166            let total_text = Arc::new(Mutex::new(String::new()));
1167
1168            loop {
1169                if let Some(first_event) = completion_events.next().await {
1170                    match first_event {
1171                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1172                            message_id = Some(id);
1173                        }
1174                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1175                            if matches!(
1176                                tool_use.name.as_ref(),
1177                                "rewrite_section" | "failure_message"
1178                            ) =>
1179                        {
1180                            let is_complete = tool_use.is_input_complete;
1181                            let (text, message) = tool_to_text_and_message(tool_use);
1182                            // Only update the model explanation if the tool use is complete.
1183                            // Otherwise the UI element bounces around as it's updated.
1184                            if is_complete {
1185                                let _ = codegen.update(cx, |this, _cx| {
1186                                    this.model_explanation = message.map(Into::into);
1187                                });
1188                            }
1189                            first_text = text;
1190                            if first_text.is_some() {
1191                                break;
1192                            }
1193                        }
1194                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1195                            *last_token_usage.lock() = token_usage;
1196                        }
1197                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1198                            let mut lock = total_text.lock();
1199                            lock.push_str(&text);
1200                        }
1201                        Ok(e) => {
1202                            log::warn!("Unexpected event: {:?}", e);
1203                            break;
1204                        }
1205                        Err(e) => {
1206                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1207                            break;
1208                        }
1209                    }
1210                }
1211            }
1212
1213            let Some(first_text) = first_text else {
1214                finish_with_status(CodegenStatus::Done, cx);
1215                return;
1216            };
1217
1218            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
1219
1220            cx.spawn({
1221                let codegen = codegen.clone();
1222                async move |cx| {
1223                    while let Some(message) = message_rx.next().await {
1224                        let _ = codegen.update(cx, |this, _cx| {
1225                            this.model_explanation = message;
1226                        });
1227                    }
1228                }
1229            })
1230            .detach();
1231
1232            let move_last_token_usage = last_token_usage.clone();
1233
1234            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1235                completion_events.filter_map(move |e| {
1236                    let tool_to_text_and_message = tool_to_text_and_message.clone();
1237                    let last_token_usage = move_last_token_usage.clone();
1238                    let total_text = total_text.clone();
1239                    let mut message_tx = message_tx.clone();
1240                    async move {
1241                        match e {
1242                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1243                                if matches!(
1244                                    tool_use.name.as_ref(),
1245                                    "rewrite_section" | "failure_message"
1246                                ) =>
1247                            {
1248                                let is_complete = tool_use.is_input_complete;
1249                                let (text, message) = tool_to_text_and_message(tool_use);
1250                                if is_complete {
1251                                    // Again only send the message when complete to not get a bouncing UI element.
1252                                    let _ = message_tx.send(message.map(Into::into)).await;
1253                                }
1254                                text.map(Ok)
1255                            }
1256                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1257                                *last_token_usage.lock() = token_usage;
1258                                None
1259                            }
1260                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1261                                let mut lock = total_text.lock();
1262                                lock.push_str(&text);
1263                                None
1264                            }
1265                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1266                            e => {
1267                                log::error!("UNEXPECTED EVENT {:?}", e);
1268                                None
1269                            }
1270                        }
1271                    }
1272                }),
1273            ));
1274
1275            let language_model_text_stream = LanguageModelTextStream {
1276                message_id: message_id,
1277                stream: text_stream,
1278                last_token_usage,
1279            };
1280
1281            let Some(task) = codegen
1282                .update(cx, move |codegen, cx| {
1283                    codegen.handle_stream(model, async { Ok(language_model_text_stream) }, cx)
1284                })
1285                .ok()
1286            else {
1287                return;
1288            };
1289
1290            task.await;
1291        })
1292    }
1293}
1294
/// Events emitted by a codegen alternative to its observers.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Emitted when a generation ends: it completed, failed, or was stopped via `stop`.
    Finished,
    /// Presumably signals that generated edits were undone.
    /// NOTE(review): not emitted anywhere in this part of the file; confirm the emitter.
    Undone,
}
1300
1301struct StripInvalidSpans<T> {
1302    stream: T,
1303    stream_done: bool,
1304    buffer: String,
1305    first_line: bool,
1306    line_end: bool,
1307    starts_with_code_block: bool,
1308}
1309
1310impl<T> StripInvalidSpans<T>
1311where
1312    T: Stream<Item = Result<String>>,
1313{
1314    fn new(stream: T) -> Self {
1315        Self {
1316            stream,
1317            stream_done: false,
1318            buffer: String::new(),
1319            first_line: true,
1320            line_end: false,
1321            starts_with_code_block: false,
1322        }
1323    }
1324}
1325
1326impl<T> Stream for StripInvalidSpans<T>
1327where
1328    T: Stream<Item = Result<String>>,
1329{
1330    type Item = Result<String>;
1331
1332    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1333        const CODE_BLOCK_DELIMITER: &str = "```";
1334        const CURSOR_SPAN: &str = "<|CURSOR|>";
1335
1336        let this = unsafe { self.get_unchecked_mut() };
1337        loop {
1338            if !this.stream_done {
1339                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1340                match stream.as_mut().poll_next(cx) {
1341                    Poll::Ready(Some(Ok(chunk))) => {
1342                        this.buffer.push_str(&chunk);
1343                    }
1344                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1345                    Poll::Ready(None) => {
1346                        this.stream_done = true;
1347                    }
1348                    Poll::Pending => return Poll::Pending,
1349                }
1350            }
1351
1352            let mut chunk = String::new();
1353            let mut consumed = 0;
1354            if !this.buffer.is_empty() {
1355                let mut lines = this.buffer.split('\n').enumerate().peekable();
1356                while let Some((line_ix, line)) = lines.next() {
1357                    if line_ix > 0 {
1358                        this.first_line = false;
1359                    }
1360
1361                    if this.first_line {
1362                        let trimmed_line = line.trim();
1363                        if lines.peek().is_some() {
1364                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1365                                consumed += line.len() + 1;
1366                                this.starts_with_code_block = true;
1367                                continue;
1368                            }
1369                        } else if trimmed_line.is_empty()
1370                            || prefixes(CODE_BLOCK_DELIMITER)
1371                                .any(|prefix| trimmed_line.starts_with(prefix))
1372                        {
1373                            break;
1374                        }
1375                    }
1376
1377                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1378                    if lines.peek().is_some() {
1379                        if this.line_end {
1380                            chunk.push('\n');
1381                        }
1382
1383                        chunk.push_str(&line_without_cursor);
1384                        this.line_end = true;
1385                        consumed += line.len() + 1;
1386                    } else if this.stream_done {
1387                        if !this.starts_with_code_block
1388                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1389                        {
1390                            if this.line_end {
1391                                chunk.push('\n');
1392                            }
1393
1394                            chunk.push_str(line);
1395                        }
1396
1397                        consumed += line.len();
1398                    } else {
1399                        let trimmed_line = line.trim();
1400                        if trimmed_line.is_empty()
1401                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1402                            || prefixes(CODE_BLOCK_DELIMITER)
1403                                .any(|prefix| trimmed_line.ends_with(prefix))
1404                        {
1405                            break;
1406                        } else {
1407                            if this.line_end {
1408                                chunk.push('\n');
1409                                this.line_end = false;
1410                            }
1411
1412                            chunk.push_str(&line_without_cursor);
1413                            consumed += line.len();
1414                        }
1415                    }
1416                }
1417            }
1418
1419            this.buffer = this.buffer.split_off(consumed);
1420            if !chunk.is_empty() {
1421                return Poll::Ready(Some(Ok(chunk)));
1422            } else if this.stream_done {
1423                return Poll::Ready(None);
1424            }
1425        }
1426    }
1427}
1428
/// Yields every non-empty proper prefix of `text`, shortest first.
///
/// For example, `prefixes("```")` yields `` "`" `` then `` "``" ``; the full string is not
/// included. Prefixes end on char boundaries, so non-ASCII input does not panic. Empty input
/// yields nothing; the previous `text.len() - 1` bound underflowed in that case.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1432
1433#[derive(Default)]
1434pub struct Diff {
1435    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1436    pub inserted_row_ranges: Vec<Range<Anchor>>,
1437}
1438
1439impl Diff {
1440    fn is_empty(&self) -> bool {
1441        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1442    }
1443}
1444
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{Stream, stream};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{LanguageModelRegistry, TokenUsage};
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Replacing a selected, indented block with differently indented
    /// generated text re-indents the result to match the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = anchor_range(&buffer, Point::new(1, 0), Point::new(4, 5), cx);
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        send_in_random_chunks(&chunks_tx, new_text, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor placed after existing indentation and a partial
    /// word completes that line and indents the following generated lines.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = anchor_range(&buffer, Point::new(1, 6), Point::new(1, 6), cx);
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(&chunks_tx, new_text, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor on a line with less than the expected
    /// indentation still yields correctly indented output.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = anchor_range(&buffer, Point::new(1, 2), Point::new(1, 2), cx);
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(&chunks_tx, new_text, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Tab-indented text generated over a tab-indented selection keeps its
    /// tabs. The text is sent as one chunk, but the iterations still vary
    /// executor scheduling.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = anchor_range(&buffer, Point::new(0, 0), Point::new(4, 2), cx);
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative leaves the buffer untouched. Activating it
    /// applies the generated text, and deactivating it reverts the buffer.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = anchor_range(&buffer, Point::new(1, 0), Point::new(1, 14), cx);
        let codegen = new_codegen(&buffer, range, false, cx);

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` removes code-fence wrapper lines and cursor markers
    /// for every possible chunking of the input stream.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the test language-model registry and settings store globals.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    /// Builds the anchor range `start..end` in `buffer`. The start is biased
    /// before and the end after.
    fn anchor_range(
        buffer: &Entity<MultiBuffer>,
        start: Point,
        end: Point,
        cx: &mut TestAppContext,
    ) -> Range<Anchor> {
        buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(start)..snapshot.anchor_after(end)
        })
    }

    /// Creates a `CodegenAlternative` over `range` of `buffer`, using a fresh
    /// prompt builder and a random id.
    fn new_codegen(
        buffer: &Entity<MultiBuffer>,
        range: Range<Anchor>,
        active: bool,
        cx: &mut TestAppContext,
    ) -> Entity<CodegenAlternative> {
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range,
                active,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        })
    }

    /// Sends `text` through `chunks_tx` in chunks of 1 to 10 bytes, with sizes
    /// chosen by `rng`. The background executor settles after every chunk.
    /// `text` is split at byte offsets, so callers pass ASCII text.
    fn send_in_random_chunks(
        chunks_tx: &mpsc::UnboundedSender<String>,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
    }

    /// Points `codegen` at a fake model whose text stream is fed by the
    /// returned sender. Dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }
}