buffer_codegen.rs

   1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
   2use agent_settings::AgentSettings;
   3use anyhow::{Context as _, Result};
   4use collections::HashSet;
   5use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
   6use futures::{
   7    SinkExt, Stream, StreamExt, TryStreamExt as _,
   8    channel::mpsc,
   9    future::{LocalBoxFuture, Shared},
  10    join,
  11    stream::BoxStream,
  12};
  13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
  14use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
  15use language_model::{
  16    CompletionIntent, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  17    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  18    LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
  19    LanguageModelToolUse, Role, TokenUsage,
  20};
  21use language_models::provider::anthropic::telemetry::{
  22    AnthropicCompletionType, AnthropicEventData, AnthropicEventReporter, AnthropicEventType,
  23};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use prompt_store::PromptBuilder;
  27use rope::Rope;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings as _;
  31use smol::future::FutureExt;
  32use std::{
  33    cmp,
  34    future::Future,
  35    iter,
  36    ops::{Range, RangeInclusive},
  37    pin::Pin,
  38    sync::Arc,
  39    task::{self, Poll},
  40    time::Instant,
  41};
  42use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  43use uuid::Uuid;
  44
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
// NOTE: the `///` doc comments on this type and its field are part of the tool
// definition, not just documentation. The `JsonSchema` derive turns them into
// schema descriptions, and that schema is sent to the model as the input schema
// of the `failure_message` tool (see `build_request_tools`). Editing them changes
// what the model is told. Plain `//` comments such as this one are not included.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `#[serde(default)]` lets input that omits `message` still deserialize,
    // yielding an empty string.
    #[serde(default)]
    pub message: String,
}
  54
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
// NOTE: these `///` doc comments are emitted into the JSON schema that
// `build_request_tools` registers as the `rewrite_section` tool's input schema.
// Treat them as model-facing prompt text, not ordinary documentation.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // Missing `replacement_text` deserializes to "" instead of failing.
    #[serde(default)]
    pub replacement_text: String,
}
  64
/// Drives inline-assist generation for one range of a multibuffer, possibly
/// across several models at once. Each model gets its own
/// [`CodegenAlternative`]; exactly one alternative is active at a time.
pub struct BufferCodegen {
    /// One alternative per model. Index 0 is started with the primary model;
    /// the rest use the registry's inline alternative models, in order.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently shown.
    pub active_alternative: usize,
    /// Indices of every alternative that has been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; replaced on every activation.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being transformed; handed to each new alternative.
    buffer: Entity<MultiBuffer>,
    /// Target range; handed to each new alternative.
    range: Range<Anchor>,
    /// Transaction supplied at construction; undone (once) by `undo`.
    initial_transaction_id: Option<TransactionId>,
    /// Prompt builder shared with every alternative.
    builder: Arc<PromptBuilder>,
    /// `true` when `range` was empty at construction time.
    pub is_insertion: bool,
    /// Inline-assist session id shared with every alternative.
    session_id: Uuid,
}
  77
/// Tool name under which the [`RewriteSectionInput`] schema is offered to the model.
pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
/// Tool name under which the [`FailureMessageInput`] schema is offered to the model.
pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
  80
  81impl BufferCodegen {
  82    pub fn new(
  83        buffer: Entity<MultiBuffer>,
  84        range: Range<Anchor>,
  85        initial_transaction_id: Option<TransactionId>,
  86        session_id: Uuid,
  87        builder: Arc<PromptBuilder>,
  88        cx: &mut Context<Self>,
  89    ) -> Self {
  90        let codegen = cx.new(|cx| {
  91            CodegenAlternative::new(
  92                buffer.clone(),
  93                range.clone(),
  94                false,
  95                builder.clone(),
  96                session_id,
  97                cx,
  98            )
  99        });
 100        let mut this = Self {
 101            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
 102            alternatives: vec![codegen],
 103            active_alternative: 0,
 104            seen_alternatives: HashSet::default(),
 105            subscriptions: Vec::new(),
 106            buffer,
 107            range,
 108            initial_transaction_id,
 109            builder,
 110            session_id,
 111        };
 112        this.activate(0, cx);
 113        this
 114    }
 115
 116    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
 117        let codegen = self.active_alternative().clone();
 118        self.subscriptions.clear();
 119        self.subscriptions
 120            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
 121        self.subscriptions
 122            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
 123    }
 124
 125    pub fn active_completion(&self, cx: &App) -> Option<String> {
 126        self.active_alternative().read(cx).current_completion()
 127    }
 128
 129    pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
 130        &self.alternatives[self.active_alternative]
 131    }
 132
 133    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 134        self.active_alternative().read(cx).language_name(cx)
 135    }
 136
 137    pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
 138        &self.active_alternative().read(cx).status
 139    }
 140
 141    pub fn alternative_count(&self, cx: &App) -> usize {
 142        LanguageModelRegistry::read_global(cx)
 143            .inline_alternative_models()
 144            .len()
 145            + 1
 146    }
 147
 148    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
 149        let next_active_ix = if self.active_alternative == 0 {
 150            self.alternatives.len() - 1
 151        } else {
 152            self.active_alternative - 1
 153        };
 154        self.activate(next_active_ix, cx);
 155    }
 156
 157    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
 158        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
 159        self.activate(next_active_ix, cx);
 160    }
 161
 162    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
 163        self.active_alternative()
 164            .update(cx, |codegen, cx| codegen.set_active(false, cx));
 165        self.seen_alternatives.insert(index);
 166        self.active_alternative = index;
 167        self.active_alternative()
 168            .update(cx, |codegen, cx| codegen.set_active(true, cx));
 169        self.subscribe_to_alternative(cx);
 170        cx.notify();
 171    }
 172
 173    pub fn start(
 174        &mut self,
 175        primary_model: Arc<dyn LanguageModel>,
 176        user_prompt: String,
 177        context_task: Shared<Task<Option<LoadedContext>>>,
 178        cx: &mut Context<Self>,
 179    ) -> Result<()> {
 180        let alternative_models = LanguageModelRegistry::read_global(cx)
 181            .inline_alternative_models()
 182            .to_vec();
 183
 184        self.active_alternative()
 185            .update(cx, |alternative, cx| alternative.undo(cx));
 186        self.activate(0, cx);
 187        self.alternatives.truncate(1);
 188
 189        for _ in 0..alternative_models.len() {
 190            self.alternatives.push(cx.new(|cx| {
 191                CodegenAlternative::new(
 192                    self.buffer.clone(),
 193                    self.range.clone(),
 194                    false,
 195                    self.builder.clone(),
 196                    self.session_id,
 197                    cx,
 198                )
 199            }));
 200        }
 201
 202        for (model, alternative) in iter::once(primary_model)
 203            .chain(alternative_models)
 204            .zip(&self.alternatives)
 205        {
 206            alternative.update(cx, |alternative, cx| {
 207                alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
 208            })?;
 209        }
 210
 211        Ok(())
 212    }
 213
 214    pub fn stop(&mut self, cx: &mut Context<Self>) {
 215        for codegen in &self.alternatives {
 216            codegen.update(cx, |codegen, cx| codegen.stop(cx));
 217        }
 218    }
 219
 220    pub fn undo(&mut self, cx: &mut Context<Self>) {
 221        self.active_alternative()
 222            .update(cx, |codegen, cx| codegen.undo(cx));
 223
 224        self.buffer.update(cx, |buffer, cx| {
 225            if let Some(transaction_id) = self.initial_transaction_id.take() {
 226                buffer.undo_transaction(transaction_id, cx);
 227                buffer.refresh_preview(cx);
 228            }
 229        });
 230    }
 231
 232    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
 233        self.active_alternative().read(cx).buffer.clone()
 234    }
 235
 236    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
 237        self.active_alternative().read(cx).old_buffer.clone()
 238    }
 239
 240    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
 241        self.active_alternative().read(cx).snapshot.clone()
 242    }
 243
 244    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
 245        self.active_alternative().read(cx).edit_position
 246    }
 247
 248    pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
 249        &self.active_alternative().read(cx).diff
 250    }
 251
 252    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
 253        self.active_alternative().read(cx).last_equal_ranges()
 254    }
 255
 256    pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
 257        self.active_alternative().read(cx).selected_text()
 258    }
 259
 260    pub fn session_id(&self) -> Uuid {
 261        self.session_id
 262    }
 263}
 264
// `BufferCodegen` re-emits each `CodegenEvent` raised by its active alternative
// (the forwarding subscription is set up in `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
 266
/// A single candidate generation for an inline-assist request. A
/// [`BufferCodegen`] may own several; only the active one has its edits
/// applied to `buffer` (see `set_active`).
pub struct CodegenAlternative {
    /// The multibuffer being transformed.
    buffer: Entity<MultiBuffer>,
    /// Standalone local buffer holding a copy of the text, line ending, language
    /// and language registry of the buffer containing `range`, taken at creation.
    old_buffer: Entity<Buffer>,
    /// Snapshot of `buffer`; re-taken whenever a new stream is handled.
    snapshot: MultiBufferSnapshot,
    /// Set to the (right-biased) start of `range` when a generation starts.
    // NOTE(review): presumably advanced as generated text is applied — confirm.
    edit_position: Option<Anchor>,
    /// The range being rewritten; re-anchored against a fresh snapshot in `handle_stream`.
    range: Range<Anchor>,
    // NOTE(review): exposed via `last_equal_ranges()`; presumably ranges the
    // latest diff found unchanged — confirm where it is populated.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction containing this alternative's applied edits. Undone on
    /// deactivation or when a new generation starts; cleared if undone externally.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation status; set to `Pending` when a stream starts being handled.
    status: CodegenStatus,
    /// Task driving the current generation; replaced by a ready task when the
    /// transformation transaction is undone externally.
    generation: Task<()>,
    /// Diff state for this alternative; reset to default when a stream starts.
    diff: Diff,
    /// Keeps `handle_buffer_event` subscribed to `buffer`.
    _subscription: gpui::Subscription,
    /// Prompt builder used to construct model requests.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied (see `set_active`).
    active: bool,
    /// Recorded edits; re-applied via `apply_edits` when this alternative becomes active.
    edits: Vec<(Range<Anchor>, String)>,
    /// Recorded line-level diff operations; re-applied on activation while generation is pending.
    line_operations: Vec<LineOperation>,
    // NOTE(review): presumably the generation's duration (units unknown here) — confirm.
    elapsed_time: Option<f64>,
    // NOTE(review): presumably the accumulated completion text — confirm.
    completion: Option<String>,
    /// Text covered by `range`, captured when a stream starts being handled.
    selected_text: Option<String>,
    // NOTE(review): presumably the provider's message id for the response — confirm.
    pub message_id: Option<String>,
    /// Inline-assist session this alternative belongs to.
    session_id: Uuid,
    /// The model's explanation of its change; cleared when a new generation starts.
    pub description: Option<String>,
    // NOTE(review): presumably the message from a `failure_message` tool call — confirm.
    pub failure: Option<String>,
}
 291
// Alternatives emit `CodegenEvent`s, e.g. `CodegenEvent::Undone` from
// `handle_buffer_event`.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
 293
 294impl CodegenAlternative {
 295    pub fn new(
 296        buffer: Entity<MultiBuffer>,
 297        range: Range<Anchor>,
 298        active: bool,
 299        builder: Arc<PromptBuilder>,
 300        session_id: Uuid,
 301        cx: &mut Context<Self>,
 302    ) -> Self {
 303        let snapshot = buffer.read(cx).snapshot(cx);
 304
 305        let (old_buffer, _, _) = snapshot
 306            .range_to_buffer_ranges(range.start..range.end)
 307            .pop()
 308            .unwrap();
 309        let old_buffer = cx.new(|cx| {
 310            let text = old_buffer.as_rope().clone();
 311            let line_ending = old_buffer.line_ending();
 312            let language = old_buffer.language().cloned();
 313            let language_registry = buffer
 314                .read(cx)
 315                .buffer(old_buffer.remote_id())
 316                .unwrap()
 317                .read(cx)
 318                .language_registry();
 319
 320            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
 321            buffer.set_language(language, cx);
 322            if let Some(language_registry) = language_registry {
 323                buffer.set_language_registry(language_registry);
 324            }
 325            buffer
 326        });
 327
 328        Self {
 329            buffer: buffer.clone(),
 330            old_buffer,
 331            edit_position: None,
 332            message_id: None,
 333            snapshot,
 334            last_equal_ranges: Default::default(),
 335            transformation_transaction_id: None,
 336            status: CodegenStatus::Idle,
 337            generation: Task::ready(()),
 338            diff: Diff::default(),
 339            builder,
 340            active: active,
 341            edits: Vec::new(),
 342            line_operations: Vec::new(),
 343            range,
 344            elapsed_time: None,
 345            completion: None,
 346            selected_text: None,
 347            session_id,
 348            description: None,
 349            failure: None,
 350            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 351        }
 352    }
 353
 354    pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
 355        self.old_buffer
 356            .read(cx)
 357            .language()
 358            .map(|language| language.name())
 359    }
 360
 361    pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
 362        if active != self.active {
 363            self.active = active;
 364
 365            if self.active {
 366                let edits = self.edits.clone();
 367                self.apply_edits(edits, cx);
 368                if matches!(self.status, CodegenStatus::Pending) {
 369                    let line_operations = self.line_operations.clone();
 370                    self.reapply_line_based_diff(line_operations, cx);
 371                } else {
 372                    self.reapply_batch_diff(cx).detach();
 373                }
 374            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
 375                self.buffer.update(cx, |buffer, cx| {
 376                    buffer.undo_transaction(transaction_id, cx);
 377                    buffer.forget_transaction(transaction_id, cx);
 378                });
 379            }
 380        }
 381    }
 382
 383    fn handle_buffer_event(
 384        &mut self,
 385        _buffer: Entity<MultiBuffer>,
 386        event: &multi_buffer::Event,
 387        cx: &mut Context<Self>,
 388    ) {
 389        if let multi_buffer::Event::TransactionUndone { transaction_id } = event
 390            && self.transformation_transaction_id == Some(*transaction_id)
 391        {
 392            self.transformation_transaction_id = None;
 393            self.generation = Task::ready(());
 394            cx.emit(CodegenEvent::Undone);
 395        }
 396    }
 397
 398    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 399        &self.last_equal_ranges
 400    }
 401
 402    pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
 403        model.supports_streaming_tools()
 404            && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
 405    }
 406
 407    pub fn start(
 408        &mut self,
 409        user_prompt: String,
 410        context_task: Shared<Task<Option<LoadedContext>>>,
 411        model: Arc<dyn LanguageModel>,
 412        cx: &mut Context<Self>,
 413    ) -> Result<()> {
 414        // Clear the model explanation since the user has started a new generation.
 415        self.description = None;
 416
 417        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
 418            self.buffer.update(cx, |buffer, cx| {
 419                buffer.undo_transaction(transformation_transaction_id, cx);
 420            });
 421        }
 422
 423        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
 424
 425        if Self::use_streaming_tools(model.as_ref(), cx) {
 426            let request = self.build_request(&model, user_prompt, context_task, cx)?;
 427            let completion_events = cx.spawn({
 428                let model = model.clone();
 429                async move |_, cx| model.stream_completion(request.await, cx).await
 430            });
 431            self.generation = self.handle_completion(model, completion_events, cx);
 432        } else {
 433            let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
 434                if user_prompt.trim().to_lowercase() == "delete" {
 435                    async { Ok(LanguageModelTextStream::default()) }.boxed_local()
 436                } else {
 437                    let request = self.build_request(&model, user_prompt, context_task, cx)?;
 438                    cx.spawn({
 439                        let model = model.clone();
 440                        async move |_, cx| {
 441                            Ok(model.stream_completion_text(request.await, cx).await?)
 442                        }
 443                    })
 444                    .boxed_local()
 445                };
 446            self.generation =
 447                self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
 448        }
 449
 450        Ok(())
 451    }
 452
 453    fn build_request_tools(
 454        &self,
 455        model: &Arc<dyn LanguageModel>,
 456        user_prompt: String,
 457        context_task: Shared<Task<Option<LoadedContext>>>,
 458        cx: &mut App,
 459    ) -> Result<Task<LanguageModelRequest>> {
 460        let buffer = self.buffer.read(cx).snapshot(cx);
 461        let language = buffer.language_at(self.range.start);
 462        let language_name = if let Some(language) = language.as_ref() {
 463            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 464                None
 465            } else {
 466                Some(language.name())
 467            }
 468        } else {
 469            None
 470        };
 471
 472        let language_name = language_name.as_ref();
 473        let start = buffer.point_to_buffer_offset(self.range.start);
 474        let end = buffer.point_to_buffer_offset(self.range.end);
 475        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 476            let (start_buffer, start_buffer_offset) = start;
 477            let (end_buffer, end_buffer_offset) = end;
 478            if start_buffer.remote_id() == end_buffer.remote_id() {
 479                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 480            } else {
 481                anyhow::bail!("invalid transformation range");
 482            }
 483        } else {
 484            anyhow::bail!("invalid transformation range");
 485        };
 486
 487        let system_prompt = self
 488            .builder
 489            .generate_inline_transformation_prompt_tools(
 490                language_name,
 491                buffer,
 492                range.start.0..range.end.0,
 493            )
 494            .context("generating content prompt")?;
 495
 496        let temperature = AgentSettings::temperature_for_model(model, cx);
 497
 498        let tool_input_format = model.tool_input_format();
 499        let tool_choice = model
 500            .supports_tool_choice(LanguageModelToolChoice::Any)
 501            .then_some(LanguageModelToolChoice::Any);
 502
 503        Ok(cx.spawn(async move |_cx| {
 504            let mut messages = vec![LanguageModelRequestMessage {
 505                role: Role::System,
 506                content: vec![system_prompt.into()],
 507                cache: false,
 508                reasoning_details: None,
 509            }];
 510
 511            let mut user_message = LanguageModelRequestMessage {
 512                role: Role::User,
 513                content: Vec::new(),
 514                cache: false,
 515                reasoning_details: None,
 516            };
 517
 518            if let Some(context) = context_task.await {
 519                context.add_to_request_message(&mut user_message);
 520            }
 521
 522            user_message.content.push(user_prompt.into());
 523            messages.push(user_message);
 524
 525            let tools = vec![
 526                LanguageModelRequestTool {
 527                    name: REWRITE_SECTION_TOOL_NAME.to_string(),
 528                    description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
 529                    input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
 530                    use_input_streaming: false,
 531                },
 532                LanguageModelRequestTool {
 533                    name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
 534                    description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
 535                    input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
 536                    use_input_streaming: false,
 537                },
 538            ];
 539
 540            LanguageModelRequest {
 541                thread_id: None,
 542                prompt_id: None,
 543                intent: Some(CompletionIntent::InlineAssist),
 544                tools,
 545                tool_choice,
 546                stop: Vec::new(),
 547                temperature,
 548                messages,
 549                thinking_allowed: false,
 550                thinking_effort: None,
 551                speed: None,
 552            }
 553        }))
 554    }
 555
 556    fn build_request(
 557        &self,
 558        model: &Arc<dyn LanguageModel>,
 559        user_prompt: String,
 560        context_task: Shared<Task<Option<LoadedContext>>>,
 561        cx: &mut App,
 562    ) -> Result<Task<LanguageModelRequest>> {
 563        if Self::use_streaming_tools(model.as_ref(), cx) {
 564            return self.build_request_tools(model, user_prompt, context_task, cx);
 565        }
 566
 567        let buffer = self.buffer.read(cx).snapshot(cx);
 568        let language = buffer.language_at(self.range.start);
 569        let language_name = if let Some(language) = language.as_ref() {
 570            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 571                None
 572            } else {
 573                Some(language.name())
 574            }
 575        } else {
 576            None
 577        };
 578
 579        let language_name = language_name.as_ref();
 580        let start = buffer.point_to_buffer_offset(self.range.start);
 581        let end = buffer.point_to_buffer_offset(self.range.end);
 582        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 583            let (start_buffer, start_buffer_offset) = start;
 584            let (end_buffer, end_buffer_offset) = end;
 585            if start_buffer.remote_id() == end_buffer.remote_id() {
 586                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 587            } else {
 588                anyhow::bail!("invalid transformation range");
 589            }
 590        } else {
 591            anyhow::bail!("invalid transformation range");
 592        };
 593
 594        let prompt = self
 595            .builder
 596            .generate_inline_transformation_prompt(
 597                user_prompt,
 598                language_name,
 599                buffer,
 600                range.start.0..range.end.0,
 601            )
 602            .context("generating content prompt")?;
 603
 604        let temperature = AgentSettings::temperature_for_model(model, cx);
 605
 606        Ok(cx.spawn(async move |_cx| {
 607            let mut request_message = LanguageModelRequestMessage {
 608                role: Role::User,
 609                content: Vec::new(),
 610                cache: false,
 611                reasoning_details: None,
 612            };
 613
 614            if let Some(context) = context_task.await {
 615                context.add_to_request_message(&mut request_message);
 616            }
 617
 618            request_message.content.push(prompt.into());
 619
 620            LanguageModelRequest {
 621                thread_id: None,
 622                prompt_id: None,
 623                intent: Some(CompletionIntent::InlineAssist),
 624                tools: Vec::new(),
 625                tool_choice: None,
 626                stop: Vec::new(),
 627                temperature,
 628                messages: vec![request_message],
 629                thinking_allowed: false,
 630                thinking_effort: None,
 631                speed: None,
 632            }
 633        }))
 634    }
 635
 636    pub fn handle_stream(
 637        &mut self,
 638        model: Arc<dyn LanguageModel>,
 639        strip_invalid_spans: bool,
 640        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
 641        cx: &mut Context<Self>,
 642    ) -> Task<()> {
 643        let anthropic_reporter = AnthropicEventReporter::new(&model, cx);
 644        let session_id = self.session_id;
 645        let model_telemetry_id = model.telemetry_id();
 646        let model_provider_id = model.provider_id().to_string();
 647        let start_time = Instant::now();
 648
 649        // Make a new snapshot and re-resolve anchor in case the document was modified.
 650        // This can happen often if the editor loses focus and is saved + reformatted,
 651        // as in https://github.com/zed-industries/zed/issues/39088
 652        self.snapshot = self.buffer.read(cx).snapshot(cx);
 653        self.range = self.snapshot.anchor_after(self.range.start)
 654            ..self.snapshot.anchor_after(self.range.end);
 655
 656        let snapshot = self.snapshot.clone();
 657        let selected_text = snapshot
 658            .text_for_range(self.range.start..self.range.end)
 659            .collect::<Rope>();
 660
 661        self.selected_text = Some(selected_text.to_string());
 662
 663        let selection_start = self.range.start.to_point(&snapshot);
 664
 665        // Start with the indentation of the first line in the selection
 666        let mut suggested_line_indent = snapshot
 667            .suggested_indents(selection_start.row..=selection_start.row, cx)
 668            .into_values()
 669            .next()
 670            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 671
 672        // If the first line in the selection does not have indentation, check the following lines
 673        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
 674            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
 675                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
 676                // Prefer tabs if a line in the selection uses tabs as indentation
 677                if line_indent.kind == IndentKind::Tab {
 678                    suggested_line_indent.kind = IndentKind::Tab;
 679                    break;
 680                }
 681            }
 682        }
 683
 684        let language_name = {
 685            let multibuffer = self.buffer.read(cx);
 686            let snapshot = multibuffer.snapshot(cx);
 687            let ranges = snapshot.range_to_buffer_ranges(self.range.start..self.range.end);
 688            ranges
 689                .first()
 690                .and_then(|(buffer, _, _)| buffer.language())
 691                .map(|language| language.name())
 692        };
 693
 694        self.diff = Diff::default();
 695        self.status = CodegenStatus::Pending;
 696        let mut edit_start = self.range.start.to_offset(&snapshot);
 697        let completion = Arc::new(Mutex::new(String::new()));
 698        let completion_clone = completion.clone();
 699
 700        cx.notify();
 701        cx.spawn(async move |codegen, cx| {
 702            let stream = stream.await;
 703
 704            let token_usage = stream
 705                .as_ref()
 706                .ok()
 707                .map(|stream| stream.last_token_usage.clone());
 708            let message_id = stream
 709                .as_ref()
 710                .ok()
 711                .and_then(|stream| stream.message_id.clone());
 712            let generate = async {
 713                let model_telemetry_id = model_telemetry_id.clone();
 714                let model_provider_id = model_provider_id.clone();
 715                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
 716                let message_id = message_id.clone();
 717                let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
 718                    let anthropic_reporter = anthropic_reporter.clone();
 719                    let language_name = language_name.clone();
 720                    async move {
 721                        let mut response_latency = None;
 722                        let request_start = Instant::now();
 723                        let diff = async {
 724                            let raw_stream = stream?.stream.map_err(|error| error.into());
 725
 726                            let stripped;
 727                            let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
 728                                if strip_invalid_spans {
 729                                    stripped = StripInvalidSpans::new(raw_stream);
 730                                    Box::pin(stripped)
 731                                } else {
 732                                    Box::pin(raw_stream)
 733                                };
 734
 735                            let mut diff = StreamingDiff::new(selected_text.to_string());
 736                            let mut line_diff = LineDiff::default();
 737
 738                            let mut new_text = String::new();
 739                            let mut base_indent = None;
 740                            let mut line_indent = None;
 741                            let mut first_line = true;
 742
 743                            while let Some(chunk) = chunks.next().await {
 744                                if response_latency.is_none() {
 745                                    response_latency = Some(request_start.elapsed());
 746                                }
 747                                let chunk = chunk?;
 748                                completion_clone.lock().push_str(&chunk);
 749
 750                                let mut lines = chunk.split('\n').peekable();
 751                                while let Some(line) = lines.next() {
 752                                    new_text.push_str(line);
 753                                    if line_indent.is_none()
 754                                        && let Some(non_whitespace_ch_ix) =
 755                                            new_text.find(|ch: char| !ch.is_whitespace())
 756                                    {
 757                                        line_indent = Some(non_whitespace_ch_ix);
 758                                        base_indent = base_indent.or(line_indent);
 759
 760                                        let line_indent = line_indent.unwrap();
 761                                        let base_indent = base_indent.unwrap();
 762                                        let indent_delta = line_indent as i32 - base_indent as i32;
 763                                        let mut corrected_indent_len = cmp::max(
 764                                            0,
 765                                            suggested_line_indent.len as i32 + indent_delta,
 766                                        )
 767                                            as usize;
 768                                        if first_line {
 769                                            corrected_indent_len = corrected_indent_len
 770                                                .saturating_sub(selection_start.column as usize);
 771                                        }
 772
 773                                        let indent_char = suggested_line_indent.char();
 774                                        let mut indent_buffer = [0; 4];
 775                                        let indent_str =
 776                                            indent_char.encode_utf8(&mut indent_buffer);
 777                                        new_text.replace_range(
 778                                            ..line_indent,
 779                                            &indent_str.repeat(corrected_indent_len),
 780                                        );
 781                                    }
 782
 783                                    if line_indent.is_some() {
 784                                        let char_ops = diff.push_new(&new_text);
 785                                        line_diff.push_char_operations(&char_ops, &selected_text);
 786                                        diff_tx
 787                                            .send((char_ops, line_diff.line_operations()))
 788                                            .await?;
 789                                        new_text.clear();
 790                                    }
 791
 792                                    if lines.peek().is_some() {
 793                                        let char_ops = diff.push_new("\n");
 794                                        line_diff.push_char_operations(&char_ops, &selected_text);
 795                                        diff_tx
 796                                            .send((char_ops, line_diff.line_operations()))
 797                                            .await?;
 798                                        if line_indent.is_none() {
 799                                            // Don't write out the leading indentation in empty lines on the next line
 800                                            // This is the case where the above if statement didn't clear the buffer
 801                                            new_text.clear();
 802                                        }
 803                                        line_indent = None;
 804                                        first_line = false;
 805                                    }
 806                                }
 807                            }
 808
 809                            let mut char_ops = diff.push_new(&new_text);
 810                            char_ops.extend(diff.finish());
 811                            line_diff.push_char_operations(&char_ops, &selected_text);
 812                            line_diff.finish(&selected_text);
 813                            diff_tx
 814                                .send((char_ops, line_diff.line_operations()))
 815                                .await?;
 816
 817                            anyhow::Ok(())
 818                        };
 819
 820                        let result = diff.await;
 821
 822                        let error_message = result.as_ref().err().map(|error| error.to_string());
 823                        telemetry::event!(
 824                            "Assistant Responded",
 825                            kind = "inline",
 826                            phase = "response",
 827                            session_id = session_id.to_string(),
 828                            model = model_telemetry_id,
 829                            model_provider = model_provider_id,
 830                            language_name = language_name.as_ref().map(|n| n.to_string()),
 831                            message_id = message_id.as_deref(),
 832                            response_latency = response_latency,
 833                            error_message = error_message.as_deref(),
 834                        );
 835
 836                        anthropic_reporter.report(AnthropicEventData {
 837                            completion_type: AnthropicCompletionType::Editor,
 838                            event: AnthropicEventType::Response,
 839                            language_name: language_name.map(|n| n.to_string()),
 840                            message_id,
 841                        });
 842
 843                        result?;
 844                        Ok(())
 845                    }
 846                });
 847
 848                while let Some((char_ops, line_ops)) = diff_rx.next().await {
 849                    codegen.update(cx, |codegen, cx| {
 850                        codegen.last_equal_ranges.clear();
 851
 852                        let edits = char_ops
 853                            .into_iter()
 854                            .filter_map(|operation| match operation {
 855                                CharOperation::Insert { text } => {
 856                                    let edit_start = snapshot.anchor_after(edit_start);
 857                                    Some((edit_start..edit_start, text))
 858                                }
 859                                CharOperation::Delete { bytes } => {
 860                                    let edit_end = edit_start + bytes;
 861                                    let edit_range = snapshot.anchor_after(edit_start)
 862                                        ..snapshot.anchor_before(edit_end);
 863                                    edit_start = edit_end;
 864                                    Some((edit_range, String::new()))
 865                                }
 866                                CharOperation::Keep { bytes } => {
 867                                    let edit_end = edit_start + bytes;
 868                                    let edit_range = snapshot.anchor_after(edit_start)
 869                                        ..snapshot.anchor_before(edit_end);
 870                                    edit_start = edit_end;
 871                                    codegen.last_equal_ranges.push(edit_range);
 872                                    None
 873                                }
 874                            })
 875                            .collect::<Vec<_>>();
 876
 877                        if codegen.active {
 878                            codegen.apply_edits(edits.iter().cloned(), cx);
 879                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
 880                        }
 881                        codegen.edits.extend(edits);
 882                        codegen.line_operations = line_ops;
 883                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
 884
 885                        cx.notify();
 886                    })?;
 887                }
 888
 889                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
 890                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
 891                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
 892                let batch_diff_task =
 893                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
 894                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
 895                line_based_stream_diff?;
 896
 897                anyhow::Ok(())
 898            };
 899
 900            let result = generate.await;
 901            let elapsed_time = start_time.elapsed().as_secs_f64();
 902
 903            codegen
 904                .update(cx, |this, cx| {
 905                    this.message_id = message_id;
 906                    this.last_equal_ranges.clear();
 907                    if let Err(error) = result {
 908                        this.status = CodegenStatus::Error(error);
 909                    } else {
 910                        this.status = CodegenStatus::Done;
 911                    }
 912                    this.elapsed_time = Some(elapsed_time);
 913                    this.completion = Some(completion.lock().clone());
 914                    if let Some(usage) = token_usage {
 915                        let usage = usage.lock();
 916                        telemetry::event!(
 917                            "Inline Assistant Completion",
 918                            model = model_telemetry_id,
 919                            model_provider = model_provider_id,
 920                            input_tokens = usage.input_tokens,
 921                            output_tokens = usage.output_tokens,
 922                        )
 923                    }
 924
 925                    cx.emit(CodegenEvent::Finished);
 926                    cx.notify();
 927                })
 928                .ok();
 929        })
 930    }
 931
 932    pub fn current_completion(&self) -> Option<String> {
 933        self.completion.clone()
 934    }
 935
 936    #[cfg(any(test, feature = "test-support"))]
 937    pub fn current_description(&self) -> Option<String> {
 938        self.description.clone()
 939    }
 940
 941    #[cfg(any(test, feature = "test-support"))]
 942    pub fn current_failure(&self) -> Option<String> {
 943        self.failure.clone()
 944    }
 945
    /// Returns the selected text stored on this alternative, borrowed as `&str`,
    /// or `None` if none was stored.
    pub fn selected_text(&self) -> Option<&str> {
        self.selected_text.as_deref()
    }
 949
 950    pub fn stop(&mut self, cx: &mut Context<Self>) {
 951        self.last_equal_ranges.clear();
 952        if self.diff.is_empty() {
 953            self.status = CodegenStatus::Idle;
 954        } else {
 955            self.status = CodegenStatus::Done;
 956        }
 957        self.generation = Task::ready(());
 958        cx.emit(CodegenEvent::Finished);
 959        cx.notify();
 960    }
 961
 962    pub fn undo(&mut self, cx: &mut Context<Self>) {
 963        self.buffer.update(cx, |buffer, cx| {
 964            if let Some(transaction_id) = self.transformation_transaction_id.take() {
 965                buffer.undo_transaction(transaction_id, cx);
 966                buffer.refresh_preview(cx);
 967            }
 968        });
 969    }
 970
 971    fn apply_edits(
 972        &mut self,
 973        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
 974        cx: &mut Context<CodegenAlternative>,
 975    ) {
 976        let transaction = self.buffer.update(cx, |buffer, cx| {
 977            // Avoid grouping agent edits with user edits.
 978            buffer.finalize_last_transaction(cx);
 979            buffer.start_transaction(cx);
 980            buffer.edit(edits, None, cx);
 981            buffer.end_transaction(cx)
 982        });
 983
 984        if let Some(transaction) = transaction {
 985            if let Some(first_transaction) = self.transformation_transaction_id {
 986                // Group all agent edits into the first transaction.
 987                self.buffer.update(cx, |buffer, cx| {
 988                    buffer.merge_transactions(transaction, first_transaction, cx)
 989                });
 990            } else {
 991                self.transformation_transaction_id = Some(transaction);
 992                self.buffer
 993                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
 994            }
 995        }
 996    }
 997
 998    fn reapply_line_based_diff(
 999        &mut self,
1000        line_operations: impl IntoIterator<Item = LineOperation>,
1001        cx: &mut Context<Self>,
1002    ) {
1003        let old_snapshot = self.snapshot.clone();
1004        let old_range = self.range.to_point(&old_snapshot);
1005        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1006        let new_range = self.range.to_point(&new_snapshot);
1007
1008        let mut old_row = old_range.start.row;
1009        let mut new_row = new_range.start.row;
1010
1011        self.diff.deleted_row_ranges.clear();
1012        self.diff.inserted_row_ranges.clear();
1013        for operation in line_operations {
1014            match operation {
1015                LineOperation::Keep { lines } => {
1016                    old_row += lines;
1017                    new_row += lines;
1018                }
1019                LineOperation::Delete { lines } => {
1020                    let old_end_row = old_row + lines - 1;
1021                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1022
1023                    if let Some((_, last_deleted_row_range)) =
1024                        self.diff.deleted_row_ranges.last_mut()
1025                    {
1026                        if *last_deleted_row_range.end() + 1 == old_row {
1027                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1028                        } else {
1029                            self.diff
1030                                .deleted_row_ranges
1031                                .push((new_row, old_row..=old_end_row));
1032                        }
1033                    } else {
1034                        self.diff
1035                            .deleted_row_ranges
1036                            .push((new_row, old_row..=old_end_row));
1037                    }
1038
1039                    old_row += lines;
1040                }
1041                LineOperation::Insert { lines } => {
1042                    let new_end_row = new_row + lines - 1;
1043                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1044                    let end = new_snapshot.anchor_before(Point::new(
1045                        new_end_row,
1046                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
1047                    ));
1048                    self.diff.inserted_row_ranges.push(start..end);
1049                    new_row += lines;
1050                }
1051            }
1052
1053            cx.notify();
1054        }
1055    }
1056
1057    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1058        let old_snapshot = self.snapshot.clone();
1059        let old_range = self.range.to_point(&old_snapshot);
1060        let new_snapshot = self.buffer.read(cx).snapshot(cx);
1061        let new_range = self.range.to_point(&new_snapshot);
1062
1063        cx.spawn(async move |codegen, cx| {
1064            let (deleted_row_ranges, inserted_row_ranges) = cx
1065                .background_spawn(async move {
1066                    let old_text = old_snapshot
1067                        .text_for_range(
1068                            Point::new(old_range.start.row, 0)
1069                                ..Point::new(
1070                                    old_range.end.row,
1071                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1072                                ),
1073                        )
1074                        .collect::<String>();
1075                    let new_text = new_snapshot
1076                        .text_for_range(
1077                            Point::new(new_range.start.row, 0)
1078                                ..Point::new(
1079                                    new_range.end.row,
1080                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1081                                ),
1082                        )
1083                        .collect::<String>();
1084
1085                    let old_start_row = old_range.start.row;
1086                    let new_start_row = new_range.start.row;
1087                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1088                    let mut inserted_row_ranges = Vec::new();
1089                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1090                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1091                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1092                        if !old_rows.is_empty() {
1093                            deleted_row_ranges.push((
1094                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1095                                old_rows.start..=old_rows.end - 1,
1096                            ));
1097                        }
1098                        if !new_rows.is_empty() {
1099                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1100                            let new_end_row = new_rows.end - 1;
1101                            let end = new_snapshot.anchor_before(Point::new(
1102                                new_end_row,
1103                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
1104                            ));
1105                            inserted_row_ranges.push(start..end);
1106                        }
1107                    }
1108                    (deleted_row_ranges, inserted_row_ranges)
1109                })
1110                .await;
1111
1112            codegen
1113                .update(cx, |codegen, cx| {
1114                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
1115                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
1116                    cx.notify();
1117                })
1118                .ok();
1119        })
1120    }
1121
1122    fn handle_completion(
1123        &mut self,
1124        model: Arc<dyn LanguageModel>,
1125        completion_stream: Task<
1126            Result<
1127                BoxStream<
1128                    'static,
1129                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1130                >,
1131                LanguageModelCompletionError,
1132            >,
1133        >,
1134        cx: &mut Context<Self>,
1135    ) -> Task<()> {
1136        self.diff = Diff::default();
1137        self.status = CodegenStatus::Pending;
1138
1139        cx.notify();
1140        // Leaving this in generation so that STOP equivalent events are respected even
1141        // while we're still pre-processing the completion event
1142        cx.spawn(async move |codegen, cx| {
1143            let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1144                let _ = codegen.update(cx, |this, cx| {
1145                    this.status = status;
1146                    cx.emit(CodegenEvent::Finished);
1147                    cx.notify();
1148                });
1149            };
1150
1151            let mut completion_events = match completion_stream.await {
1152                Ok(events) => events,
1153                Err(err) => {
1154                    finish_with_status(CodegenStatus::Error(err.into()), cx);
1155                    return;
1156                }
1157            };
1158
1159            enum ToolUseOutput {
1160                Rewrite {
1161                    text: String,
1162                    description: Option<String>,
1163                },
1164                Failure(String),
1165            }
1166
1167            enum ModelUpdate {
1168                Description(String),
1169                Failure(String),
1170            }
1171
1172            let chars_read_so_far = Arc::new(Mutex::new(0usize));
1173            let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1174                let mut chars_read_so_far = chars_read_so_far.lock();
1175                match tool_use.name.as_ref() {
1176                    REWRITE_SECTION_TOOL_NAME => {
1177                        let Ok(input) =
1178                            serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1179                        else {
1180                            return None;
1181                        };
1182                        let text = input.replacement_text[*chars_read_so_far..].to_string();
1183                        *chars_read_so_far = input.replacement_text.len();
1184                        Some(ToolUseOutput::Rewrite {
1185                            text,
1186                            description: None,
1187                        })
1188                    }
1189                    FAILURE_MESSAGE_TOOL_NAME => {
1190                        let Ok(mut input) =
1191                            serde_json::from_value::<FailureMessageInput>(tool_use.input)
1192                        else {
1193                            return None;
1194                        };
1195                        Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1196                    }
1197                    _ => None,
1198                }
1199            };
1200
1201            let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1202
1203            cx.spawn({
1204                let codegen = codegen.clone();
1205                async move |cx| {
1206                    while let Some(update) = message_rx.next().await {
1207                        let _ = codegen.update(cx, |this, _cx| match update {
1208                            ModelUpdate::Description(d) => this.description = Some(d),
1209                            ModelUpdate::Failure(f) => this.failure = Some(f),
1210                        });
1211                    }
1212                }
1213            })
1214            .detach();
1215
1216            let mut message_id = None;
1217            let mut first_text = None;
1218            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1219            let total_text = Arc::new(Mutex::new(String::new()));
1220
1221            loop {
1222                if let Some(first_event) = completion_events.next().await {
1223                    match first_event {
1224                        Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1225                            message_id = Some(id);
1226                        }
1227                        Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1228                            if let Some(output) = process_tool_use(tool_use) {
1229                                let (text, update) = match output {
1230                                    ToolUseOutput::Rewrite { text, description } => {
1231                                        (Some(text), description.map(ModelUpdate::Description))
1232                                    }
1233                                    ToolUseOutput::Failure(message) => {
1234                                        (None, Some(ModelUpdate::Failure(message)))
1235                                    }
1236                                };
1237                                if let Some(update) = update {
1238                                    let _ = message_tx.unbounded_send(update);
1239                                }
1240                                first_text = text;
1241                                if first_text.is_some() {
1242                                    break;
1243                                }
1244                            }
1245                        }
1246                        Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1247                            *last_token_usage.lock() = token_usage;
1248                        }
1249                        Ok(LanguageModelCompletionEvent::Text(text)) => {
1250                            let mut lock = total_text.lock();
1251                            lock.push_str(&text);
1252                        }
1253                        Ok(e) => {
1254                            log::warn!("Unexpected event: {:?}", e);
1255                            break;
1256                        }
1257                        Err(e) => {
1258                            finish_with_status(CodegenStatus::Error(e.into()), cx);
1259                            break;
1260                        }
1261                    }
1262                }
1263            }
1264
1265            let Some(first_text) = first_text else {
1266                finish_with_status(CodegenStatus::Done, cx);
1267                return;
1268            };
1269
1270            let move_last_token_usage = last_token_usage.clone();
1271
1272            let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1273                completion_events.filter_map(move |e| {
1274                    let process_tool_use = process_tool_use.clone();
1275                    let last_token_usage = move_last_token_usage.clone();
1276                    let total_text = total_text.clone();
1277                    let mut message_tx = message_tx.clone();
1278                    async move {
1279                        match e {
1280                            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1281                                let Some(output) = process_tool_use(tool_use) else {
1282                                    return None;
1283                                };
1284                                let (text, update) = match output {
1285                                    ToolUseOutput::Rewrite { text, description } => {
1286                                        (Some(text), description.map(ModelUpdate::Description))
1287                                    }
1288                                    ToolUseOutput::Failure(message) => {
1289                                        (None, Some(ModelUpdate::Failure(message)))
1290                                    }
1291                                };
1292                                if let Some(update) = update {
1293                                    let _ = message_tx.send(update).await;
1294                                }
1295                                text.map(Ok)
1296                            }
1297                            Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1298                                *last_token_usage.lock() = token_usage;
1299                                None
1300                            }
1301                            Ok(LanguageModelCompletionEvent::Text(text)) => {
1302                                let mut lock = total_text.lock();
1303                                lock.push_str(&text);
1304                                None
1305                            }
1306                            Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1307                            e => {
1308                                log::error!("UNEXPECTED EVENT {:?}", e);
1309                                None
1310                            }
1311                        }
1312                    }
1313                }),
1314            ));
1315
1316            let language_model_text_stream = LanguageModelTextStream {
1317                message_id: message_id,
1318                stream: text_stream,
1319                last_token_usage,
1320            };
1321
1322            let Some(task) = codegen
1323                .update(cx, move |codegen, cx| {
1324                    codegen.handle_stream(
1325                        model,
1326                        /* strip_invalid_spans: */ false,
1327                        async { Ok(language_model_text_stream) },
1328                        cx,
1329                    )
1330                })
1331                .ok()
1332            else {
1333                return;
1334            };
1335
1336            task.await;
1337        })
1338    }
1339}
1340
/// Notifications emitted by a codegen alternative to its observers.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation reached a terminal state: it completed, failed, or was stopped.
    Finished,
    /// NOTE(review): presumably signals that generated edits were undone; `undo`
    /// does not emit it, so confirm where it is raised.
    Undone,
}
1346
/// Stream adapter that removes `<|CURSOR|>` markers and a surrounding code fence from
/// a stream of text chunks.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Text received from `stream` that has not been emitted yet.
    buffer: String,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// True until a newline has been seen in the input.
    first_line: bool,
    /// A newline is owed and will be emitted before the next line's text.
    line_end: bool,
    /// The input opened with a code fence line, so a closing fence on the final line
    /// is dropped.
    starts_with_code_block: bool,
}
1355
1356impl<T> StripInvalidSpans<T>
1357where
1358    T: Stream<Item = Result<String>>,
1359{
1360    fn new(stream: T) -> Self {
1361        Self {
1362            stream,
1363            stream_done: false,
1364            buffer: String::new(),
1365            first_line: true,
1366            line_end: false,
1367            starts_with_code_block: false,
1368        }
1369    }
1370}
1371
1372impl<T> Stream for StripInvalidSpans<T>
1373where
1374    T: Stream<Item = Result<String>>,
1375{
1376    type Item = Result<String>;
1377
1378    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1379        const CODE_BLOCK_DELIMITER: &str = "```";
1380        const CURSOR_SPAN: &str = "<|CURSOR|>";
1381
1382        let this = unsafe { self.get_unchecked_mut() };
1383        loop {
1384            if !this.stream_done {
1385                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1386                match stream.as_mut().poll_next(cx) {
1387                    Poll::Ready(Some(Ok(chunk))) => {
1388                        this.buffer.push_str(&chunk);
1389                    }
1390                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1391                    Poll::Ready(None) => {
1392                        this.stream_done = true;
1393                    }
1394                    Poll::Pending => return Poll::Pending,
1395                }
1396            }
1397
1398            let mut chunk = String::new();
1399            let mut consumed = 0;
1400            if !this.buffer.is_empty() {
1401                let mut lines = this.buffer.split('\n').enumerate().peekable();
1402                while let Some((line_ix, line)) = lines.next() {
1403                    if line_ix > 0 {
1404                        this.first_line = false;
1405                    }
1406
1407                    if this.first_line {
1408                        let trimmed_line = line.trim();
1409                        if lines.peek().is_some() {
1410                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1411                                consumed += line.len() + 1;
1412                                this.starts_with_code_block = true;
1413                                continue;
1414                            }
1415                        } else if trimmed_line.is_empty()
1416                            || prefixes(CODE_BLOCK_DELIMITER)
1417                                .any(|prefix| trimmed_line.starts_with(prefix))
1418                        {
1419                            break;
1420                        }
1421                    }
1422
1423                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
1424                    if lines.peek().is_some() {
1425                        if this.line_end {
1426                            chunk.push('\n');
1427                        }
1428
1429                        chunk.push_str(&line_without_cursor);
1430                        this.line_end = true;
1431                        consumed += line.len() + 1;
1432                    } else if this.stream_done {
1433                        if !this.starts_with_code_block
1434                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1435                        {
1436                            if this.line_end {
1437                                chunk.push('\n');
1438                            }
1439
1440                            chunk.push_str(line);
1441                        }
1442
1443                        consumed += line.len();
1444                    } else {
1445                        let trimmed_line = line.trim();
1446                        if trimmed_line.is_empty()
1447                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1448                            || prefixes(CODE_BLOCK_DELIMITER)
1449                                .any(|prefix| trimmed_line.ends_with(prefix))
1450                        {
1451                            break;
1452                        } else {
1453                            if this.line_end {
1454                                chunk.push('\n');
1455                                this.line_end = false;
1456                            }
1457
1458                            chunk.push_str(&line_without_cursor);
1459                            consumed += line.len();
1460                        }
1461                    }
1462                }
1463            }
1464
1465            this.buffer = this.buffer.split_off(consumed);
1466            if !chunk.is_empty() {
1467                return Poll::Ready(Some(Ok(chunk)));
1468            } else if this.stream_done {
1469                return Poll::Ready(None);
1470            }
1471        }
1472    }
1473}
1474
/// Returns every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string itself is never yielded, so `"```"` produces `"`"` and
/// `"``"`. Callers in this file use it to detect a line that starts or ends
/// with a partially-received `CODE_BLOCK_DELIMITER` or `CURSOR_SPAN`.
///
/// Empty and single-character input yield nothing. Prefixes always end on a
/// `char` boundary, so non-ASCII input never causes a slicing panic.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Every char start after the first marks the exclusive end of a proper
    // prefix; slicing there is guaranteed to be on a char boundary.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1478
1479#[derive(Default)]
1480pub struct Diff {
1481    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1482    pub inserted_row_ranges: Vec<Range<Anchor>>,
1483}
1484
1485impl Diff {
1486    fn is_empty(&self) -> bool {
1487        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1488    }
1489}
1490
1491#[cfg(test)]
1492mod tests {
1493    use super::*;
1494    use futures::{
1495        Stream,
1496        stream::{self},
1497    };
1498    use gpui::TestAppContext;
1499    use indoc::indoc;
1500    use language::{Buffer, Point};
1501    use language_model::fake_provider::FakeLanguageModel;
1502    use language_model::{
1503        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
1504        LanguageModelToolUse, StopReason, TokenUsage,
1505    };
1506    use languages::rust_lang;
1507    use rand::prelude::*;
1508    use settings::SettingsStore;
1509    use std::{future, sync::Arc};
1510
    /// Rewrites the body of a Rust `fn main` with model output indented by 7/11
    /// spaces, streamed in random chunks, and checks the inserted lines end up
    /// at the buffer's 4/8-space indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select rows 1 through 4 (`let x = 0;` .. the inner closing `}`).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        // Deliberately mis-indented model output (7 and 11 spaces).
        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the response in random 1..=10 byte chunks, letting the executor
        // settle after each one. `split_at` takes a byte offset, which is safe
        // here because the text is pure ASCII.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender closes the channel, ending the response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1572
    /// The insertion point sits right after the partial token `le` on an
    /// indented line. The model completes it (`t mut x = 0;`) and emits its
    /// following lines without the line's indentation; the result must still be
    /// indented to match the surrounding `fn main` body.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range (pure insertion) at row 1, column 6: just after `le`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the completion in random 1..=10 byte chunks (ASCII, so byte
        // offsets are valid split points), settling the executor each time.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Close the channel so the response stream finishes.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1636
    /// The insertion point is at the end of a line holding only two spaces,
    /// which is less than the 4-space indent the body needs. Unindented model
    /// output must end up at the correct 4/8-space indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        // Built with `concat!` rather than `indoc!` so the whitespace-only
        // line keeps its exact two spaces.
        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range at row 1, column 2: after the existing two spaces.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the output in random 1..=10 byte chunks (ASCII, so byte
        // offsets are valid split points), settling the executor each time.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Close the channel so the response stream finishes.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1700
    /// Tab-indented (Go-style) text in a buffer with no language attached:
    /// rewriting the selection with tab-indented output must keep the tabs
    /// rather than converting them to spaces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        // Note: no `.with_language(...)` here, unlike the Rust-based tests.
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select rows 0 through 4 (`func main() {` .. `\t}`).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        // Whole response as a single chunk, then close the stream.
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }
1755
    /// A codegen alternative created inactive (`false` as the third argument)
    /// consumes its response without touching the buffer. Toggling it active
    /// applies the generated text, and toggling it back restores the original.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select all of row 1 (`    let x = 0;` is 14 characters).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        // (Asserted immediately, without running the executor in between.)
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }
1816
1817    // When not streaming tool calls, we strip backticks as part of parsing the model's
1818    // plain text response. This is a regression test for a bug where we stripped
1819    // backticks incorrectly.
1820    #[gpui::test]
1821    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
1822        init_test(cx);
1823        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
1824        let buffer = cx.new(|cx| Buffer::local("", cx));
1825        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1826        let range = buffer.read_with(cx, |buffer, cx| {
1827            let snapshot = buffer.snapshot(cx);
1828            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
1829        });
1830        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1831        let codegen = cx.new(|cx| {
1832            CodegenAlternative::new(
1833                buffer.clone(),
1834                range.clone(),
1835                true,
1836                prompt_builder,
1837                Uuid::new_v4(),
1838                cx,
1839            )
1840        });
1841
1842        let events_tx = simulate_tool_based_completion(&codegen, cx);
1843        let chunk_len = text.find('`').unwrap();
1844        events_tx
1845            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
1846            .unwrap();
1847        events_tx
1848            .unbounded_send(rewrite_tool_use("tool_2", &text, true))
1849            .unwrap();
1850        events_tx
1851            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
1852            .unwrap();
1853        drop(events_tx);
1854        cx.run_until_parked();
1855
1856        assert_eq!(
1857            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1858            text
1859        );
1860    }
1861
1862    #[gpui::test]
1863    async fn test_strip_invalid_spans_from_codeblock() {
1864        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1865        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1866        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1867        assert_chunks(
1868            "```html\n```js\nLorem ipsum dolor\n```\n```",
1869            "```js\nLorem ipsum dolor\n```",
1870        )
1871        .await;
1872        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1873        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1874        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1875        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1876
1877        async fn assert_chunks(text: &str, expected_text: &str) {
1878            for chunk_size in 1..=text.len() {
1879                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1880                    .map(|chunk| chunk.unwrap())
1881                    .collect::<String>()
1882                    .await;
1883                assert_eq!(
1884                    actual_text, expected_text,
1885                    "failed to strip invalid spans, chunk size: {}",
1886                    chunk_size
1887                );
1888            }
1889        }
1890
1891        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1892            stream::iter(
1893                text.chars()
1894                    .collect::<Vec<_>>()
1895                    .chunks(size)
1896                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
1897                    .collect::<Vec<_>>(),
1898            )
1899        }
1900    }
1901
1902    fn init_test(cx: &mut TestAppContext) {
1903        cx.update(LanguageModelRegistry::test);
1904        cx.set_global(cx.update(SettingsStore::test));
1905    }
1906
1907    fn simulate_response_stream(
1908        codegen: &Entity<CodegenAlternative>,
1909        cx: &mut TestAppContext,
1910    ) -> mpsc::UnboundedSender<String> {
1911        let (chunks_tx, chunks_rx) = mpsc::unbounded();
1912        let model = Arc::new(FakeLanguageModel::default());
1913        codegen.update(cx, |codegen, cx| {
1914            codegen.generation = codegen.handle_stream(
1915                model,
1916                /* strip_invalid_spans: */ false,
1917                future::ready(Ok(LanguageModelTextStream {
1918                    message_id: None,
1919                    stream: chunks_rx.map(Ok).boxed(),
1920                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1921                })),
1922                cx,
1923            );
1924        });
1925        chunks_tx
1926    }
1927
1928    fn simulate_tool_based_completion(
1929        codegen: &Entity<CodegenAlternative>,
1930        cx: &mut TestAppContext,
1931    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
1932        let (events_tx, events_rx) = mpsc::unbounded();
1933        let model = Arc::new(FakeLanguageModel::default());
1934        codegen.update(cx, |codegen, cx| {
1935            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
1936                as BoxStream<
1937                    'static,
1938                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1939                >));
1940            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
1941        });
1942        events_tx
1943    }
1944
1945    fn rewrite_tool_use(
1946        id: &str,
1947        replacement_text: &str,
1948        is_complete: bool,
1949    ) -> LanguageModelCompletionEvent {
1950        let input = RewriteSectionInput {
1951            replacement_text: replacement_text.into(),
1952        };
1953        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
1954            id: id.into(),
1955            name: REWRITE_SECTION_TOOL_NAME.into(),
1956            raw_input: serde_json::to_string(&input).unwrap(),
1957            input: serde_json::to_value(&input).unwrap(),
1958            is_input_complete: is_complete,
1959            thought_signature: None,
1960        })
1961    }
1962}