inline_assistant.rs

   1use crate::{
   2    Assistant, AssistantPanel, AssistantPanelEvent, CycleNextInlineAssist,
   3    CyclePreviousInlineAssist,
   4};
   5use anyhow::{Context as _, Result, anyhow};
   6use assistant_context_editor::{RequestType, humanize_token_count};
   7use assistant_settings::AssistantSettings;
   8use client::{ErrorExt, telemetry::Telemetry};
   9use collections::{HashMap, HashSet, VecDeque, hash_map};
  10use editor::{
  11    Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
  12    EditorStyle, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint,
  13    actions::{MoveDown, MoveUp, SelectAll},
  14    display_map::{
  15        BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, EditorMargins,
  16        RenderBlock, ToDisplayPoint,
  17    },
  18};
  19use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
  20use fs::Fs;
  21use futures::{
  22    SinkExt, Stream, StreamExt, TryStreamExt as _,
  23    channel::mpsc,
  24    future::{BoxFuture, LocalBoxFuture},
  25    join,
  26};
  27use gpui::{
  28    AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
  29    Focusable, FontWeight, Global, HighlightStyle, Subscription, Task, TextStyle, UpdateGlobal,
  30    WeakEntity, Window, anchored, deferred, point,
  31};
  32use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
  33use language_model::{
  34    ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
  35    LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
  36};
  37use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
  38use multi_buffer::MultiBufferRow;
  39use parking_lot::Mutex;
  40use project::{CodeAction, LspAction, ProjectTransaction};
  41use prompt_store::PromptBuilder;
  42use rope::Rope;
  43use settings::{Settings, SettingsStore, update_settings_file};
  44use smol::future::FutureExt;
  45use std::{
  46    cmp,
  47    future::{self, Future},
  48    iter, mem,
  49    ops::{Range, RangeInclusive},
  50    pin::Pin,
  51    rc::Rc,
  52    sync::Arc,
  53    task::{self, Poll},
  54    time::{Duration, Instant},
  55};
  56use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
  57use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
  58use terminal_view::terminal_panel::TerminalPanel;
  59use text::{OffsetRangeExt, ToPoint as _};
  60use theme::ThemeSettings;
  61use ui::{
  62    CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, Tooltip, prelude::*, text_for_action,
  63};
  64use util::{RangeExt, ResultExt};
  65use workspace::{ItemHandle, Toast, Workspace, notifications::NotificationId};
  66
  67pub fn init(
  68    fs: Arc<dyn Fs>,
  69    prompt_builder: Arc<PromptBuilder>,
  70    telemetry: Arc<Telemetry>,
  71    cx: &mut App,
  72) {
  73    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
  74    // Don't register now that the Agent is released.
  75    if false {
  76        cx.observe_new(|_, window, cx| {
  77            let Some(window) = window else {
  78                return;
  79            };
  80            let workspace = cx.entity().clone();
  81            InlineAssistant::update_global(cx, |inline_assistant, cx| {
  82                inline_assistant.register_workspace(&workspace, window, cx)
  83            });
  84        })
  85        .detach();
  86    }
  87}
  88
/// Intended upper bound on the number of entries kept in `InlineAssistant::prompt_history`.
/// NOTE(review): the code that enforces this is outside this chunk — confirm where it is applied.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
  90
/// Global state for inline (in-editor) assists: every live assist, how assists are grouped,
/// which editor each one belongs to, and the shared prompt history.
///
/// NOTE(review): fields are dropped in declaration order; keep the order stable.
pub struct InlineAssistant {
    /// Id handed out (via `post_inc`) to the next assist that is created.
    next_assist_id: InlineAssistId,
    /// Id handed out (via `post_inc`) to the next assist group that is created.
    next_assist_group_id: InlineAssistGroupId,
    /// All tracked assists, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (the editor's assist ids and its scroll lock), keyed by a weak
    /// editor handle.
    assists_by_editor: HashMap<WeakEntity<Editor>, EditorInlineAssists>,
    /// Assists created together (e.g. one per selection in `assist`) share one group.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Codegen alternatives of confirmed assists.
    /// NOTE(review): populated/consumed outside this chunk — confirm purpose.
    confirmed_assists: HashMap<InlineAssistId, Entity<CodegenAlternative>>,
    /// Previously submitted prompts; a copy is handed to every new `PromptEditor`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    telemetry: Arc<Telemetry>,
    fs: Arc<dyn Fs>,
}

impl Global for InlineAssistant {}
 105
 106impl InlineAssistant {
 107    pub fn new(
 108        fs: Arc<dyn Fs>,
 109        prompt_builder: Arc<PromptBuilder>,
 110        telemetry: Arc<Telemetry>,
 111    ) -> Self {
 112        Self {
 113            next_assist_id: InlineAssistId::default(),
 114            next_assist_group_id: InlineAssistGroupId::default(),
 115            assists: HashMap::default(),
 116            assists_by_editor: HashMap::default(),
 117            assist_groups: HashMap::default(),
 118            confirmed_assists: HashMap::default(),
 119            prompt_history: VecDeque::default(),
 120            prompt_builder,
 121            telemetry,
 122            fs,
 123        }
 124    }
 125
 126    pub fn register_workspace(
 127        &mut self,
 128        workspace: &Entity<Workspace>,
 129        window: &mut Window,
 130        cx: &mut App,
 131    ) {
 132        window
 133            .subscribe(workspace, cx, |workspace, event, window, cx| {
 134                Self::update_global(cx, |this, cx| {
 135                    this.handle_workspace_event(workspace, event, window, cx)
 136                });
 137            })
 138            .detach();
 139
 140        let workspace = workspace.downgrade();
 141        cx.observe_global::<SettingsStore>(move |cx| {
 142            let Some(workspace) = workspace.upgrade() else {
 143                return;
 144            };
 145            let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
 146                return;
 147            };
 148            let enabled = AssistantSettings::get_global(cx).enabled;
 149            terminal_panel.update(cx, |terminal_panel, cx| {
 150                terminal_panel.set_assistant_enabled(enabled, cx)
 151            });
 152        })
 153        .detach();
 154    }
 155
 156    fn handle_workspace_event(
 157        &mut self,
 158        workspace: Entity<Workspace>,
 159        event: &workspace::Event,
 160        window: &mut Window,
 161        cx: &mut App,
 162    ) {
 163        match event {
 164            workspace::Event::UserSavedItem { item, .. } => {
 165                // When the user manually saves an editor, automatically accepts all finished transformations.
 166                if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
 167                    if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
 168                        for assist_id in editor_assists.assist_ids.clone() {
 169                            let assist = &self.assists[&assist_id];
 170                            if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
 171                                self.finish_assist(assist_id, false, window, cx)
 172                            }
 173                        }
 174                    }
 175                }
 176            }
 177            workspace::Event::ItemAdded { item } => {
 178                self.register_workspace_item(&workspace, item.as_ref(), window, cx);
 179            }
 180            _ => (),
 181        }
 182    }
 183
 184    fn register_workspace_item(
 185        &mut self,
 186        workspace: &Entity<Workspace>,
 187        item: &dyn ItemHandle,
 188        window: &mut Window,
 189        cx: &mut App,
 190    ) {
 191        let is_assistant2_enabled = true;
 192
 193        if let Some(editor) = item.act_as::<Editor>(cx) {
 194            editor.update(cx, |editor, cx| {
 195                if is_assistant2_enabled {
 196                    editor.remove_code_action_provider(
 197                        ASSISTANT_CODE_ACTION_PROVIDER_ID.into(),
 198                        window,
 199                        cx,
 200                    );
 201                } else {
 202                    editor.add_code_action_provider(
 203                        Rc::new(AssistantCodeActionProvider {
 204                            editor: cx.entity().downgrade(),
 205                            workspace: workspace.downgrade(),
 206                        }),
 207                        window,
 208                        cx,
 209                    );
 210                }
 211            });
 212        }
 213    }
 214
 215    pub fn assist(
 216        &mut self,
 217        editor: &Entity<Editor>,
 218        workspace: Option<WeakEntity<Workspace>>,
 219        assistant_panel: Option<&Entity<AssistantPanel>>,
 220        initial_prompt: Option<String>,
 221        window: &mut Window,
 222        cx: &mut App,
 223    ) {
 224        let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
 225            (
 226                editor.snapshot(window, cx),
 227                editor.selections.all::<Point>(cx),
 228            )
 229        });
 230
 231        let mut selections = Vec::<Selection<Point>>::new();
 232        let mut newest_selection = None;
 233        for mut selection in initial_selections {
 234            if selection.end > selection.start {
 235                selection.start.column = 0;
 236                // If the selection ends at the start of the line, we don't want to include it.
 237                if selection.end.column == 0 {
 238                    selection.end.row -= 1;
 239                }
 240                selection.end.column = snapshot
 241                    .buffer_snapshot
 242                    .line_len(MultiBufferRow(selection.end.row));
 243            } else if let Some(fold) =
 244                snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
 245            {
 246                selection.start = fold.range().start;
 247                selection.end = fold.range().end;
 248                if MultiBufferRow(selection.end.row) < snapshot.buffer_snapshot.max_row() {
 249                    let chars = snapshot
 250                        .buffer_snapshot
 251                        .chars_at(Point::new(selection.end.row + 1, 0));
 252
 253                    for c in chars {
 254                        if c == '\n' {
 255                            break;
 256                        }
 257                        if c.is_whitespace() {
 258                            continue;
 259                        }
 260                        if snapshot
 261                            .language_at(selection.end)
 262                            .is_some_and(|language| language.config().brackets.is_closing_brace(c))
 263                        {
 264                            selection.end.row += 1;
 265                            selection.end.column = snapshot
 266                                .buffer_snapshot
 267                                .line_len(MultiBufferRow(selection.end.row));
 268                        }
 269                    }
 270                }
 271            }
 272
 273            if let Some(prev_selection) = selections.last_mut() {
 274                if selection.start <= prev_selection.end {
 275                    prev_selection.end = selection.end;
 276                    continue;
 277                }
 278            }
 279
 280            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
 281            if selection.id > latest_selection.id {
 282                *latest_selection = selection.clone();
 283            }
 284            selections.push(selection);
 285        }
 286        let snapshot = &snapshot.buffer_snapshot;
 287        let newest_selection = newest_selection.unwrap();
 288
 289        let mut codegen_ranges = Vec::new();
 290        for (buffer, buffer_range, excerpt_id) in
 291            snapshot.ranges_to_buffer_ranges(selections.iter().map(|selection| {
 292                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
 293            }))
 294        {
 295            let start = buffer.anchor_before(buffer_range.start);
 296            let end = buffer.anchor_after(buffer_range.end);
 297
 298            codegen_ranges.push(Anchor::range_in_buffer(
 299                excerpt_id,
 300                buffer.remote_id(),
 301                start..end,
 302            ));
 303
 304            if let Some(ConfiguredModel { model, .. }) =
 305                LanguageModelRegistry::read_global(cx).default_model()
 306            {
 307                self.telemetry.report_assistant_event(AssistantEventData {
 308                    conversation_id: None,
 309                    kind: AssistantKind::Inline,
 310                    phase: AssistantPhase::Invoked,
 311                    message_id: None,
 312                    model: model.telemetry_id(),
 313                    model_provider: model.provider_id().to_string(),
 314                    response_latency: None,
 315                    error_message: None,
 316                    language_name: buffer.language().map(|language| language.name().to_proto()),
 317                });
 318            }
 319        }
 320
 321        let assist_group_id = self.next_assist_group_id.post_inc();
 322        let prompt_buffer = cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
 323        let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 324
 325        let mut assists = Vec::new();
 326        let mut assist_to_focus = None;
 327        for range in codegen_ranges {
 328            let assist_id = self.next_assist_id.post_inc();
 329            let codegen = cx.new(|cx| {
 330                Codegen::new(
 331                    editor.read(cx).buffer().clone(),
 332                    range.clone(),
 333                    None,
 334                    self.telemetry.clone(),
 335                    self.prompt_builder.clone(),
 336                    cx,
 337                )
 338            });
 339
 340            let editor_margins = Arc::new(Mutex::new(EditorMargins::default()));
 341            let prompt_editor = cx.new(|cx| {
 342                PromptEditor::new(
 343                    assist_id,
 344                    editor_margins,
 345                    self.prompt_history.clone(),
 346                    prompt_buffer.clone(),
 347                    codegen.clone(),
 348                    editor,
 349                    assistant_panel,
 350                    workspace.clone(),
 351                    self.fs.clone(),
 352                    window,
 353                    cx,
 354                )
 355            });
 356
 357            if assist_to_focus.is_none() {
 358                let focus_assist = if newest_selection.reversed {
 359                    range.start.to_point(&snapshot) == newest_selection.start
 360                } else {
 361                    range.end.to_point(&snapshot) == newest_selection.end
 362                };
 363                if focus_assist {
 364                    assist_to_focus = Some(assist_id);
 365                }
 366            }
 367
 368            let [prompt_block_id, end_block_id] =
 369                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 370
 371            assists.push((
 372                assist_id,
 373                range,
 374                prompt_editor,
 375                prompt_block_id,
 376                end_block_id,
 377            ));
 378        }
 379
 380        let editor_assists = self
 381            .assists_by_editor
 382            .entry(editor.downgrade())
 383            .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx));
 384        let mut assist_group = InlineAssistGroup::new();
 385        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
 386            self.assists.insert(
 387                assist_id,
 388                InlineAssist::new(
 389                    assist_id,
 390                    assist_group_id,
 391                    assistant_panel.is_some(),
 392                    editor,
 393                    &prompt_editor,
 394                    prompt_block_id,
 395                    end_block_id,
 396                    range,
 397                    prompt_editor.read(cx).codegen.clone(),
 398                    workspace.clone(),
 399                    window,
 400                    cx,
 401                ),
 402            );
 403            assist_group.assist_ids.push(assist_id);
 404            editor_assists.assist_ids.push(assist_id);
 405        }
 406        self.assist_groups.insert(assist_group_id, assist_group);
 407
 408        if let Some(assist_id) = assist_to_focus {
 409            self.focus_assist(assist_id, window, cx);
 410        }
 411    }
 412
 413    pub fn suggest_assist(
 414        &mut self,
 415        editor: &Entity<Editor>,
 416        mut range: Range<Anchor>,
 417        initial_prompt: String,
 418        initial_transaction_id: Option<TransactionId>,
 419        focus: bool,
 420        workspace: Option<WeakEntity<Workspace>>,
 421        assistant_panel: Option<&Entity<AssistantPanel>>,
 422        window: &mut Window,
 423        cx: &mut App,
 424    ) -> InlineAssistId {
 425        let assist_group_id = self.next_assist_group_id.post_inc();
 426        let prompt_buffer = cx.new(|cx| Buffer::local(&initial_prompt, cx));
 427        let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 428
 429        let assist_id = self.next_assist_id.post_inc();
 430
 431        let buffer = editor.read(cx).buffer().clone();
 432        {
 433            let snapshot = buffer.read(cx).read(cx);
 434            range.start = range.start.bias_left(&snapshot);
 435            range.end = range.end.bias_right(&snapshot);
 436        }
 437
 438        let codegen = cx.new(|cx| {
 439            Codegen::new(
 440                editor.read(cx).buffer().clone(),
 441                range.clone(),
 442                initial_transaction_id,
 443                self.telemetry.clone(),
 444                self.prompt_builder.clone(),
 445                cx,
 446            )
 447        });
 448
 449        let editor_margins = Arc::new(Mutex::new(EditorMargins::default()));
 450        let prompt_editor = cx.new(|cx| {
 451            PromptEditor::new(
 452                assist_id,
 453                editor_margins,
 454                self.prompt_history.clone(),
 455                prompt_buffer.clone(),
 456                codegen.clone(),
 457                editor,
 458                assistant_panel,
 459                workspace.clone(),
 460                self.fs.clone(),
 461                window,
 462                cx,
 463            )
 464        });
 465
 466        let [prompt_block_id, end_block_id] =
 467            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 468
 469        let editor_assists = self
 470            .assists_by_editor
 471            .entry(editor.downgrade())
 472            .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx));
 473
 474        let mut assist_group = InlineAssistGroup::new();
 475        self.assists.insert(
 476            assist_id,
 477            InlineAssist::new(
 478                assist_id,
 479                assist_group_id,
 480                assistant_panel.is_some(),
 481                editor,
 482                &prompt_editor,
 483                prompt_block_id,
 484                end_block_id,
 485                range,
 486                prompt_editor.read(cx).codegen.clone(),
 487                workspace.clone(),
 488                window,
 489                cx,
 490            ),
 491        );
 492        assist_group.assist_ids.push(assist_id);
 493        editor_assists.assist_ids.push(assist_id);
 494        self.assist_groups.insert(assist_group_id, assist_group);
 495
 496        if focus {
 497            self.focus_assist(assist_id, window, cx);
 498        }
 499
 500        assist_id
 501    }
 502
 503    fn insert_assist_blocks(
 504        &self,
 505        editor: &Entity<Editor>,
 506        range: &Range<Anchor>,
 507        prompt_editor: &Entity<PromptEditor>,
 508        cx: &mut App,
 509    ) -> [CustomBlockId; 2] {
 510        let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
 511            prompt_editor
 512                .editor
 513                .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
 514        });
 515        let assist_blocks = vec![
 516            BlockProperties {
 517                style: BlockStyle::Sticky,
 518                placement: BlockPlacement::Above(range.start),
 519                height: Some(prompt_editor_height),
 520                render: build_assist_editor_renderer(prompt_editor),
 521                priority: 0,
 522                render_in_minimap: false,
 523            },
 524            BlockProperties {
 525                style: BlockStyle::Sticky,
 526                placement: BlockPlacement::Below(range.end),
 527                height: None,
 528                render: Arc::new(|cx| {
 529                    v_flex()
 530                        .h_full()
 531                        .w_full()
 532                        .border_t_1()
 533                        .border_color(cx.theme().status().info_border)
 534                        .into_any_element()
 535                }),
 536                priority: 0,
 537                render_in_minimap: false,
 538            },
 539        ];
 540
 541        editor.update(cx, |editor, cx| {
 542            let block_ids = editor.insert_blocks(assist_blocks, None, cx);
 543            [block_ids[0], block_ids[1]]
 544        })
 545    }
 546
 547    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut App) {
 548        let assist = &self.assists[&assist_id];
 549        let Some(decorations) = assist.decorations.as_ref() else {
 550            return;
 551        };
 552        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 553        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
 554
 555        assist_group.active_assist_id = Some(assist_id);
 556        if assist_group.linked {
 557            for assist_id in &assist_group.assist_ids {
 558                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 559                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 560                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
 561                    });
 562                }
 563            }
 564        }
 565
 566        assist
 567            .editor
 568            .update(cx, |editor, cx| {
 569                let scroll_top = editor.scroll_position(cx).y;
 570                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
 571                let prompt_row = editor
 572                    .row_for_block(decorations.prompt_block_id, cx)
 573                    .unwrap()
 574                    .0 as f32;
 575
 576                if (scroll_top..scroll_bottom).contains(&prompt_row) {
 577                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
 578                        assist_id,
 579                        distance_from_top: prompt_row - scroll_top,
 580                    });
 581                } else {
 582                    editor_assists.scroll_lock = None;
 583                }
 584            })
 585            .ok();
 586    }
 587
 588    fn handle_prompt_editor_focus_out(&mut self, assist_id: InlineAssistId, cx: &mut App) {
 589        let assist = &self.assists[&assist_id];
 590        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 591        if assist_group.active_assist_id == Some(assist_id) {
 592            assist_group.active_assist_id = None;
 593            if assist_group.linked {
 594                for assist_id in &assist_group.assist_ids {
 595                    if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 596                        decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 597                            prompt_editor.set_show_cursor_when_unfocused(false, cx)
 598                        });
 599                    }
 600                }
 601            }
 602        }
 603    }
 604
 605    fn handle_prompt_editor_event(
 606        &mut self,
 607        prompt_editor: Entity<PromptEditor>,
 608        event: &PromptEditorEvent,
 609        window: &mut Window,
 610        cx: &mut App,
 611    ) {
 612        let assist_id = prompt_editor.read(cx).id;
 613        match event {
 614            PromptEditorEvent::StartRequested => {
 615                self.start_assist(assist_id, window, cx);
 616            }
 617            PromptEditorEvent::StopRequested => {
 618                self.stop_assist(assist_id, cx);
 619            }
 620            PromptEditorEvent::ConfirmRequested => {
 621                self.finish_assist(assist_id, false, window, cx);
 622            }
 623            PromptEditorEvent::CancelRequested => {
 624                self.finish_assist(assist_id, true, window, cx);
 625            }
 626            PromptEditorEvent::DismissRequested => {
 627                self.dismiss_assist(assist_id, window, cx);
 628            }
 629        }
 630    }
 631
 632    fn handle_editor_newline(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
 633        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 634            return;
 635        };
 636
 637        if editor.read(cx).selections.count() == 1 {
 638            let (selection, buffer) = editor.update(cx, |editor, cx| {
 639                (
 640                    editor.selections.newest::<usize>(cx),
 641                    editor.buffer().read(cx).snapshot(cx),
 642                )
 643            });
 644            for assist_id in &editor_assists.assist_ids {
 645                let assist = &self.assists[assist_id];
 646                let assist_range = assist.range.to_offset(&buffer);
 647                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
 648                {
 649                    if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
 650                        self.dismiss_assist(*assist_id, window, cx);
 651                    } else {
 652                        self.finish_assist(*assist_id, false, window, cx);
 653                    }
 654
 655                    return;
 656                }
 657            }
 658        }
 659
 660        cx.propagate();
 661    }
 662
 663    fn handle_editor_cancel(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
 664        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 665            return;
 666        };
 667
 668        if editor.read(cx).selections.count() == 1 {
 669            let (selection, buffer) = editor.update(cx, |editor, cx| {
 670                (
 671                    editor.selections.newest::<usize>(cx),
 672                    editor.buffer().read(cx).snapshot(cx),
 673                )
 674            });
 675            let mut closest_assist_fallback = None;
 676            for assist_id in &editor_assists.assist_ids {
 677                let assist = &self.assists[assist_id];
 678                let assist_range = assist.range.to_offset(&buffer);
 679                if assist.decorations.is_some() {
 680                    if assist_range.contains(&selection.start)
 681                        && assist_range.contains(&selection.end)
 682                    {
 683                        self.focus_assist(*assist_id, window, cx);
 684                        return;
 685                    } else {
 686                        let distance_from_selection = assist_range
 687                            .start
 688                            .abs_diff(selection.start)
 689                            .min(assist_range.start.abs_diff(selection.end))
 690                            + assist_range
 691                                .end
 692                                .abs_diff(selection.start)
 693                                .min(assist_range.end.abs_diff(selection.end));
 694                        match closest_assist_fallback {
 695                            Some((_, old_distance)) => {
 696                                if distance_from_selection < old_distance {
 697                                    closest_assist_fallback =
 698                                        Some((assist_id, distance_from_selection));
 699                                }
 700                            }
 701                            None => {
 702                                closest_assist_fallback = Some((assist_id, distance_from_selection))
 703                            }
 704                        }
 705                    }
 706                }
 707            }
 708
 709            if let Some((&assist_id, _)) = closest_assist_fallback {
 710                self.focus_assist(assist_id, window, cx);
 711            }
 712        }
 713
 714        cx.propagate();
 715    }
 716
 717    fn handle_editor_release(
 718        &mut self,
 719        editor: WeakEntity<Editor>,
 720        window: &mut Window,
 721        cx: &mut App,
 722    ) {
 723        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
 724            for assist_id in editor_assists.assist_ids.clone() {
 725                self.finish_assist(assist_id, true, window, cx);
 726            }
 727        }
 728    }
 729
 730    fn handle_editor_change(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
 731        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 732            return;
 733        };
 734        let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
 735            return;
 736        };
 737        let assist = &self.assists[&scroll_lock.assist_id];
 738        let Some(decorations) = assist.decorations.as_ref() else {
 739            return;
 740        };
 741
 742        editor.update(cx, |editor, cx| {
 743            let scroll_position = editor.scroll_position(cx);
 744            let target_scroll_top = editor
 745                .row_for_block(decorations.prompt_block_id, cx)
 746                .unwrap()
 747                .0 as f32
 748                - scroll_lock.distance_from_top;
 749            if target_scroll_top != scroll_position.y {
 750                editor.set_scroll_position(point(scroll_position.x, target_scroll_top), window, cx);
 751            }
 752        });
 753    }
 754
 755    fn handle_editor_event(
 756        &mut self,
 757        editor: Entity<Editor>,
 758        event: &EditorEvent,
 759        window: &mut Window,
 760        cx: &mut App,
 761    ) {
 762        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
 763            return;
 764        };
 765
 766        match event {
 767            EditorEvent::Edited { transaction_id } => {
 768                let buffer = editor.read(cx).buffer().read(cx);
 769                let edited_ranges =
 770                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
 771                let snapshot = buffer.snapshot(cx);
 772
 773                for assist_id in editor_assists.assist_ids.clone() {
 774                    let assist = &self.assists[&assist_id];
 775                    if matches!(
 776                        assist.codegen.read(cx).status(cx),
 777                        CodegenStatus::Error(_) | CodegenStatus::Done
 778                    ) {
 779                        let assist_range = assist.range.to_offset(&snapshot);
 780                        if edited_ranges
 781                            .iter()
 782                            .any(|range| range.overlaps(&assist_range))
 783                        {
 784                            self.finish_assist(assist_id, false, window, cx);
 785                        }
 786                    }
 787                }
 788            }
 789            EditorEvent::ScrollPositionChanged { .. } => {
 790                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
 791                    let assist = &self.assists[&scroll_lock.assist_id];
 792                    if let Some(decorations) = assist.decorations.as_ref() {
 793                        let distance_from_top = editor.update(cx, |editor, cx| {
 794                            let scroll_top = editor.scroll_position(cx).y;
 795                            let prompt_row = editor
 796                                .row_for_block(decorations.prompt_block_id, cx)
 797                                .unwrap()
 798                                .0 as f32;
 799                            prompt_row - scroll_top
 800                        });
 801
 802                        if distance_from_top != scroll_lock.distance_from_top {
 803                            editor_assists.scroll_lock = None;
 804                        }
 805                    }
 806                }
 807            }
 808            EditorEvent::SelectionsChanged { .. } => {
 809                for assist_id in editor_assists.assist_ids.clone() {
 810                    let assist = &self.assists[&assist_id];
 811                    if let Some(decorations) = assist.decorations.as_ref() {
 812                        if decorations
 813                            .prompt_editor
 814                            .focus_handle(cx)
 815                            .is_focused(window)
 816                        {
 817                            return;
 818                        }
 819                    }
 820                }
 821
 822                editor_assists.scroll_lock = None;
 823            }
 824            _ => {}
 825        }
 826    }
 827
    /// Finishes an inline assist.
    ///
    /// Finishing means: remove its decorations and bookkeeping, report an accept/reject
    /// telemetry event, and then either undo its changes or record it as confirmed.
    ///
    /// If the assist belongs to a group that is still linked, the group is unlinked and
    /// every assist in it is finished with the same `undo` flag instead.
    ///
    /// * `assist_id` - the assist to finish; unknown ids are tolerated (nothing is removed).
    /// * `undo` - when `true`, the codegen's changes are undone and the event is reported as
    ///   `Rejected`. When `false`, the active alternative is stored in `confirmed_assists`
    ///   and the event is reported as `Accepted`.
    pub fn finish_assist(
        &mut self,
        assist_id: InlineAssistId,
        undo: bool,
        window: &mut Window,
        cx: &mut App,
    ) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                // Fan out to every member of the linked group. The group is unlinked first,
                // so each recursive call skips this branch.
                for assist_id in self.unlink_assist_group(assist_group_id, window, cx) {
                    self.finish_assist(assist_id, undo, window, cx);
                }
                return;
            }
        }

        // Tear down the UI blocks while the assist is still in `self.assists`
        // (`dismiss_assist` looks it up there).
        self.dismiss_assist(assist_id, window, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            // Drop the assist from its group, discarding the group once it is empty.
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            // Drop the assist from its editor's list. If it was the last one, the entry is
            // removed first, so the direct highlight update sees no assists and clears
            // everything. Otherwise a refresh is queued through `highlight_updates`.
            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            let active_alternative = assist.codegen.read(cx).active_alternative().clone();
            let message_id = active_alternative.read(cx).message_id.clone();

            // Telemetry is only reported when a default model is configured; the event is
            // attributed to that model.
            if let Some(ConfiguredModel { model, .. }) =
                LanguageModelRegistry::read_global(cx).default_model()
            {
                // Language of the first buffer overlapped by the assist's range, if the
                // editor is still alive and that buffer has a language.
                let language_name = assist.editor.upgrade().and_then(|editor| {
                    let multibuffer = editor.read(cx).buffer().read(cx);
                    let multibuffer_snapshot = multibuffer.snapshot(cx);
                    let ranges = multibuffer_snapshot.range_to_buffer_ranges(assist.range.clone());
                    ranges
                        .first()
                        .and_then(|(buffer, _, _)| buffer.language())
                        .map(|language| language.name())
                });
                report_assistant_event(
                    AssistantEventData {
                        conversation_id: None,
                        kind: AssistantKind::Inline,
                        message_id,
                        phase: if undo {
                            AssistantPhase::Rejected
                        } else {
                            AssistantPhase::Accepted
                        },
                        model: model.telemetry_id(),
                        model_provider: model.provider_id().to_string(),
                        response_latency: None,
                        error_message: None,
                        language_name: language_name.map(|name| name.to_proto()),
                    },
                    Some(self.telemetry.clone()),
                    cx.http_client(),
                    model.api_key(cx),
                    cx.background_executor(),
                );
            }

            // Rejected assists revert their edits; accepted ones remember which
            // alternative was kept.
            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, active_alternative);
            }
        }
    }
 915
 916    fn dismiss_assist(
 917        &mut self,
 918        assist_id: InlineAssistId,
 919        window: &mut Window,
 920        cx: &mut App,
 921    ) -> bool {
 922        let Some(assist) = self.assists.get_mut(&assist_id) else {
 923            return false;
 924        };
 925        let Some(editor) = assist.editor.upgrade() else {
 926            return false;
 927        };
 928        let Some(decorations) = assist.decorations.take() else {
 929            return false;
 930        };
 931
 932        editor.update(cx, |editor, cx| {
 933            let mut to_remove = decorations.removed_line_block_ids;
 934            to_remove.insert(decorations.prompt_block_id);
 935            to_remove.insert(decorations.end_block_id);
 936            editor.remove_blocks(to_remove, None, cx);
 937        });
 938
 939        if decorations
 940            .prompt_editor
 941            .focus_handle(cx)
 942            .contains_focused(window, cx)
 943        {
 944            self.focus_next_assist(assist_id, window, cx);
 945        }
 946
 947        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
 948            if editor_assists
 949                .scroll_lock
 950                .as_ref()
 951                .map_or(false, |lock| lock.assist_id == assist_id)
 952            {
 953                editor_assists.scroll_lock = None;
 954            }
 955            editor_assists.highlight_updates.send(()).ok();
 956        }
 957
 958        true
 959    }
 960
 961    fn focus_next_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
 962        let Some(assist) = self.assists.get(&assist_id) else {
 963            return;
 964        };
 965
 966        let assist_group = &self.assist_groups[&assist.group_id];
 967        let assist_ix = assist_group
 968            .assist_ids
 969            .iter()
 970            .position(|id| *id == assist_id)
 971            .unwrap();
 972        let assist_ids = assist_group
 973            .assist_ids
 974            .iter()
 975            .skip(assist_ix + 1)
 976            .chain(assist_group.assist_ids.iter().take(assist_ix));
 977
 978        for assist_id in assist_ids {
 979            let assist = &self.assists[assist_id];
 980            if assist.decorations.is_some() {
 981                self.focus_assist(*assist_id, window, cx);
 982                return;
 983            }
 984        }
 985
 986        assist
 987            .editor
 988            .update(cx, |editor, cx| window.focus(&editor.focus_handle(cx)))
 989            .ok();
 990    }
 991
 992    fn focus_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
 993        let Some(assist) = self.assists.get(&assist_id) else {
 994            return;
 995        };
 996
 997        if let Some(decorations) = assist.decorations.as_ref() {
 998            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 999                prompt_editor.editor.update(cx, |editor, cx| {
1000                    window.focus(&editor.focus_handle(cx));
1001                    editor.select_all(&SelectAll, window, cx);
1002                })
1003            });
1004        }
1005
1006        self.scroll_to_assist(assist_id, window, cx);
1007    }
1008
1009    pub fn scroll_to_assist(
1010        &mut self,
1011        assist_id: InlineAssistId,
1012        window: &mut Window,
1013        cx: &mut App,
1014    ) {
1015        let Some(assist) = self.assists.get(&assist_id) else {
1016            return;
1017        };
1018        let Some(editor) = assist.editor.upgrade() else {
1019            return;
1020        };
1021
1022        let position = assist.range.start;
1023        editor.update(cx, |editor, cx| {
1024            editor.change_selections(None, window, cx, |selections| {
1025                selections.select_anchor_ranges([position..position])
1026            });
1027
1028            let mut scroll_target_top;
1029            let mut scroll_target_bottom;
1030            if let Some(decorations) = assist.decorations.as_ref() {
1031                scroll_target_top = editor
1032                    .row_for_block(decorations.prompt_block_id, cx)
1033                    .unwrap()
1034                    .0 as f32;
1035                scroll_target_bottom = editor
1036                    .row_for_block(decorations.end_block_id, cx)
1037                    .unwrap()
1038                    .0 as f32;
1039            } else {
1040                let snapshot = editor.snapshot(window, cx);
1041                let start_row = assist
1042                    .range
1043                    .start
1044                    .to_display_point(&snapshot.display_snapshot)
1045                    .row();
1046                scroll_target_top = start_row.0 as f32;
1047                scroll_target_bottom = scroll_target_top + 1.;
1048            }
1049            scroll_target_top -= editor.vertical_scroll_margin() as f32;
1050            scroll_target_bottom += editor.vertical_scroll_margin() as f32;
1051
1052            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
1053            let scroll_top = editor.scroll_position(cx).y;
1054            let scroll_bottom = scroll_top + height_in_lines;
1055
1056            if scroll_target_top < scroll_top {
1057                editor.set_scroll_position(point(0., scroll_target_top), window, cx);
1058            } else if scroll_target_bottom > scroll_bottom {
1059                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
1060                    editor.set_scroll_position(
1061                        point(0., scroll_target_bottom - height_in_lines),
1062                        window,
1063                        cx,
1064                    );
1065                } else {
1066                    editor.set_scroll_position(point(0., scroll_target_top), window, cx);
1067                }
1068            }
1069        });
1070    }
1071
1072    fn unlink_assist_group(
1073        &mut self,
1074        assist_group_id: InlineAssistGroupId,
1075        window: &mut Window,
1076        cx: &mut App,
1077    ) -> Vec<InlineAssistId> {
1078        let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1079        assist_group.linked = false;
1080        for assist_id in &assist_group.assist_ids {
1081            let assist = self.assists.get_mut(assist_id).unwrap();
1082            if let Some(editor_decorations) = assist.decorations.as_ref() {
1083                editor_decorations
1084                    .prompt_editor
1085                    .update(cx, |prompt_editor, cx| prompt_editor.unlink(window, cx));
1086            }
1087        }
1088        assist_group.assist_ids.clone()
1089    }
1090
1091    pub fn start_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
1092        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1093            assist
1094        } else {
1095            return;
1096        };
1097
1098        let assist_group_id = assist.group_id;
1099        if self.assist_groups[&assist_group_id].linked {
1100            for assist_id in self.unlink_assist_group(assist_group_id, window, cx) {
1101                self.start_assist(assist_id, window, cx);
1102            }
1103            return;
1104        }
1105
1106        let Some(user_prompt) = assist.user_prompt(cx) else {
1107            return;
1108        };
1109
1110        self.prompt_history.retain(|prompt| *prompt != user_prompt);
1111        self.prompt_history.push_back(user_prompt.clone());
1112        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1113            self.prompt_history.pop_front();
1114        }
1115
1116        let assistant_panel_context = assist.assistant_panel_context(cx);
1117
1118        assist
1119            .codegen
1120            .update(cx, |codegen, cx| {
1121                codegen.start(user_prompt, assistant_panel_context, cx)
1122            })
1123            .log_err();
1124    }
1125
1126    pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut App) {
1127        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1128            assist
1129        } else {
1130            return;
1131        };
1132
1133        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1134    }
1135
    /// Recomputes and applies all inline-assist highlights for `editor`.
    ///
    /// For every assist tracked for the editor, this collects:
    /// - foreground ranges from `codegen.last_equal_ranges` (rendered with `fade_out`);
    /// - a gutter "pending" range, from the codegen's edit position (or the assist start
    ///   when there is none) to the end of the assist range;
    /// - a gutter "transformed" range, from the assist start to the edit position;
    /// - inserted row ranges from the codegen diff, only while the assist has decorations.
    ///
    /// Gutter ranges are only kept when they are non-empty in buffer offsets. Foreground and
    /// gutter ranges are passed through `merge_ranges`. Each highlight kind is cleared when
    /// it has no ranges, so an editor with no tracked assists ends up with all of these
    /// highlights cleared.
    fn update_editor_highlights(&self, editor: &Entity<Editor>, cx: &mut App) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        // Stand-in list used when the editor has no tracked assists.
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer(cx).read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());

                // Region not yet processed by codegen: edit position (or start) .. end.
                let pending_range =
                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                // Region already processed by codegen: start .. edit position.
                if let Some(edit_position) = codegen.edit_position(cx) {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                if assist.decorations.is_some() {
                    inserted_row_ranges
                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            // Uninhabited marker types, used only as keys for the two gutter highlight kinds.
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            // Row highlights are always rebuilt from scratch.
            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    cx.theme().status().info_background,
                    Default::default(),
                    cx,
                );
            }
        });
    }
1226
    /// Rebuilds the blocks that show lines deleted by an assist's codegen.
    ///
    /// The block ids in `decorations.removed_line_block_ids` are removed first. Then, for
    /// each `(new_row, old_row_range)` in the codegen diff's `deleted_row_ranges`, a
    /// read-only editor is built over the old buffer's text for those rows. Each such
    /// editor has vertical scrolling forbidden, no gutter and no soft wrap, and is
    /// highlighted with the theme's deleted background. It is inserted as a block above
    /// `new_row`. The new block ids are stored back into the decorations.
    ///
    /// Does nothing if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &Entity<Editor>,
        assist_id: InlineAssistId,
        window: &mut Window,
        cx: &mut App,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks from the previous pass before inserting new ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Buffer offsets spanning the whole deleted rows (start of the first row
                // through the end of the last). NOTE(review): these unwraps assume the old
                // rows always map into a buffer of `old_snapshot`.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                let deleted_lines_editor = cx.new(|cx| {
                    let multi_buffer =
                        cx.new(|_| MultiBuffer::without_headers(language::Capability::ReadOnly));
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange::new(buffer_start..buffer_end)),
                            cx,
                        );
                    });

                    // Marker type used as the row-highlight key for the deleted lines.
                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, window, cx);
                    editor.disable_scrollbars_and_minimap(cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_edit_predictions(Some(false), window, cx);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        Default::default(),
                        cx,
                    );
                    editor
                });

                // Block height in lines: one more than the last row index of the excerpt.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height: Some(height),
                    style: BlockStyle::Flex,
                    render: Arc::new(move |cx| {
                        div()
                            .block_mouse_down()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.window.line_height())
                            .pl(cx.margins.gutter.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                    render_in_minimap: false,
                });
            }

            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1318}
1319
/// Per-editor bookkeeping for the inline assists hosted in one editor.
struct EditorInlineAssists {
    /// Assists currently tracked for this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, `handle_editor_change` keeps the locked assist's prompt block at a fixed
    /// row distance from the top of the viewport.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here schedules a call to `update_editor_highlights` for this editor.
    highlight_updates: async_watch::Sender<()>,
    /// Task spawned in `EditorInlineAssists::new` that listens on `highlight_updates`;
    /// held here for the lifetime of this struct.
    _update_highlights: Task<Result<()>>,
    /// Subscriptions created in `EditorInlineAssists::new`: the editor's release, change
    /// and event observers, plus the `Newline` / `Cancel` action registrations.
    _subscriptions: Vec<gpui::Subscription>,
}
1327
/// Pins an assist's prompt block at a fixed row offset from the top of the viewport.
struct InlineAssistScrollLock {
    /// The assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Expected value of (prompt block display row − editor scroll top), in rows.
    distance_from_top: f32,
}
1332
impl EditorInlineAssists {
    /// Creates the bookkeeping for `editor` and wires it into `InlineAssistant`.
    ///
    /// Wiring done here:
    /// - spawns a task that recomputes highlights whenever `highlight_updates` is signalled;
    /// - forwards the editor's release, change notifications and events to the
    ///   corresponding `InlineAssistant` handlers;
    /// - registers `Newline` and `Cancel` actions on the editor that route to
    ///   `InlineAssistant` as well.
    fn new(editor: &Entity<Editor>, window: &mut Window, cx: &mut App) -> Self {
        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
        Self {
            assist_ids: Vec::new(),
            scroll_lock: None,
            highlight_updates: highlight_updates_tx,
            _update_highlights: cx.spawn({
                let editor = editor.downgrade();
                async move |cx| {
                    // Runs until `changed()` returns an error (presumably once the sender
                    // is dropped). It exits early with an error if the editor was dropped
                    // or the global update fails.
                    while let Ok(()) = highlight_updates_rx.changed().await {
                        let editor = editor.upgrade().context("editor was dropped")?;
                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
                            assistant.update_editor_highlights(&editor, cx);
                        })?;
                    }
                    Ok(())
                }
            }),
            _subscriptions: vec![
                cx.observe_release_in(editor, window, {
                    let editor = editor.downgrade();
                    |_, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_editor_release(editor, window, cx);
                        })
                    }
                }),
                window.observe(editor, cx, move |editor, window, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_editor_change(editor, window, cx)
                    })
                }),
                window.subscribe(editor, cx, move |editor, event, window, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_editor_event(editor, event, window, cx)
                    })
                }),
                // The action handlers hold a weak handle so the registration does not keep
                // the editor alive.
                editor.update(cx, |editor, cx| {
                    let editor_handle = cx.entity().downgrade();
                    editor.register_action(move |_: &editor::actions::Newline, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            if let Some(editor) = editor_handle.upgrade() {
                                this.handle_editor_newline(editor, window, cx)
                            }
                        })
                    })
                }),
                editor.update(cx, |editor, cx| {
                    let editor_handle = cx.entity().downgrade();
                    editor.register_action(move |_: &editor::actions::Cancel, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            if let Some(editor) = editor_handle.upgrade() {
                                this.handle_editor_cancel(editor, window, cx)
                            }
                        })
                    })
                }),
            ],
        }
    }
}
1395
/// A set of inline assists that can be started or finished together.
///
/// While `linked` is `true`, starting or finishing any member applies to every member
/// (see `start_assist` / `finish_assist`). The group is unlinked the first time that
/// happens.
struct InlineAssistGroup {
    /// Members of the group, in order (used for wrap-around focus navigation).
    assist_ids: Vec<InlineAssistId>,
    /// Whether actions on one member still fan out to the whole group.
    linked: bool,
    /// NOTE(review): not read in this chunk; presumably the currently active member.
    active_assist_id: Option<InlineAssistId>,
}
1401
1402impl InlineAssistGroup {
1403    fn new() -> Self {
1404        Self {
1405            assist_ids: Vec::new(),
1406            linked: true,
1407            active_assist_id: None,
1408        }
1409    }
1410}
1411
1412fn build_assist_editor_renderer(editor: &Entity<PromptEditor>) -> RenderBlock {
1413    let editor = editor.clone();
1414    Arc::new(move |cx: &mut BlockContext| {
1415        *editor.read(cx).editor_margins.lock() = *cx.margins;
1416        editor.clone().into_any_element()
1417    })
1418}
1419
/// Identifier of an inline assist. Counter-style: `post_inc` hands out successive values.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the following one (post-increment).
    fn post_inc(&mut self) -> InlineAssistId {
        let current = self.0;
        self.0 = current + 1;
        Self(current)
    }
}
1430
/// Identifier of an inline assist group. Counter-style: `post_inc` hands out successive
/// values.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the following one (post-increment).
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let current = self.0;
        self.0 = current + 1;
        Self(current)
    }
}
1441
/// Requests emitted by a `PromptEditor` to its subscribers.
enum PromptEditorEvent {
    /// Emitted by the "start" (Transform) button while the codegen is idle.
    StartRequested,
    /// Emitted by the "stop" button while generation is pending.
    StopRequested,
    /// NOTE(review): emitter not visible in this chunk; presumably accept/confirm.
    ConfirmRequested,
    /// Emitted by the "cancel" buttons.
    CancelRequested,
    /// NOTE(review): emitter not visible in this chunk.
    DismissRequested,
}
1449
/// The view rendered inside an inline assist's prompt block.
struct PromptEditor {
    /// Id of the assist this prompt belongs to.
    id: InlineAssistId,
    /// Text editor holding the user's prompt.
    editor: Entity<Editor>,
    /// NOTE(review): presumably the model picker shown next to the prompt.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Read by `render`: when set, the result must be re-run and the restart button is
    /// bound to `menu::Confirm`. Presumably set when the prompt is edited after generation
    /// finished — TODO confirm.
    edited_since_done: bool,
    /// Host editor margins. Written by the prompt block's renderer each time the block is
    /// rendered, and read in `render`.
    editor_margins: Arc<Mutex<EditorMargins>>,
    /// NOTE(review): presumably earlier prompts, navigated via `prompt_history_ix`, with
    /// `pending_prompt` holding the in-progress text while browsing.
    prompt_history: VecDeque<String>,
    prompt_history_ix: Option<usize>,
    pending_prompt: String,
    /// Code generation state for this assist; `render` reads its status and alternatives.
    codegen: Entity<Codegen>,
    /// Subscription to `codegen`, held for the lifetime of this view.
    _codegen_subscription: Subscription,
    editor_subscriptions: Vec<Subscription>,
    /// NOTE(review): presumably the in-flight token counting task and its latest result.
    pending_token_count: Task<Result<()>>,
    token_counts: Option<TokenCounts>,
    _token_count_subscriptions: Vec<Subscription>,
    workspace: Option<WeakEntity<Workspace>>,
    show_rate_limit_notice: bool,
}
1468
/// Token counts for a prompt.
///
/// NOTE(review): `total` is presumably the whole request and `assistant_panel` the share
/// that comes from assistant panel context — confirm against the producer.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    total: usize,
    assistant_panel: usize,
}

impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1476
1477impl Render for PromptEditor {
1478    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
1479        let editor_margins = *self.editor_margins.lock();
1480        let gutter_dimensions = editor_margins.gutter;
1481        let codegen = self.codegen.read(cx);
1482
1483        let mut buttons = Vec::new();
1484        if codegen.alternative_count(cx) > 1 {
1485            buttons.push(self.render_cycle_controls(cx));
1486        }
1487
1488        let status = codegen.status(cx);
1489        buttons.extend(match status {
1490            CodegenStatus::Idle => {
1491                vec![
1492                    IconButton::new("cancel", IconName::Close)
1493                        .icon_color(Color::Muted)
1494                        .shape(IconButtonShape::Square)
1495                        .tooltip(|window, cx| {
1496                            Tooltip::for_action("Cancel Assist", &menu::Cancel, window, cx)
1497                        })
1498                        .on_click(
1499                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1500                        )
1501                        .into_any_element(),
1502                    IconButton::new("start", IconName::SparkleAlt)
1503                        .icon_color(Color::Muted)
1504                        .shape(IconButtonShape::Square)
1505                        .tooltip(|window, cx| {
1506                            Tooltip::for_action("Transform", &menu::Confirm, window, cx)
1507                        })
1508                        .on_click(
1509                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
1510                        )
1511                        .into_any_element(),
1512                ]
1513            }
1514            CodegenStatus::Pending => {
1515                vec![
1516                    IconButton::new("cancel", IconName::Close)
1517                        .icon_color(Color::Muted)
1518                        .shape(IconButtonShape::Square)
1519                        .tooltip(Tooltip::text("Cancel Assist"))
1520                        .on_click(
1521                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1522                        )
1523                        .into_any_element(),
1524                    IconButton::new("stop", IconName::Stop)
1525                        .icon_color(Color::Error)
1526                        .shape(IconButtonShape::Square)
1527                        .tooltip(|window, cx| {
1528                            Tooltip::with_meta(
1529                                "Interrupt Transformation",
1530                                Some(&menu::Cancel),
1531                                "Changes won't be discarded",
1532                                window,
1533                                cx,
1534                            )
1535                        })
1536                        .on_click(
1537                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
1538                        )
1539                        .into_any_element(),
1540                ]
1541            }
1542            CodegenStatus::Error(_) | CodegenStatus::Done => {
1543                let must_rerun =
1544                    self.edited_since_done || matches!(status, CodegenStatus::Error(_));
1545                // when accept button isn't visible, then restart maps to confirm
1546                // when accept button is visible, then restart must be mapped to an alternate keyboard shortcut
1547                let restart_key: &dyn gpui::Action = if must_rerun {
1548                    &menu::Confirm
1549                } else {
1550                    &menu::Restart
1551                };
1552                vec![
1553                    IconButton::new("cancel", IconName::Close)
1554                        .icon_color(Color::Muted)
1555                        .shape(IconButtonShape::Square)
1556                        .tooltip(|window, cx| {
1557                            Tooltip::for_action("Cancel Assist", &menu::Cancel, window, cx)
1558                        })
1559                        .on_click(
1560                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1561                        )
1562                        .into_any_element(),
1563                    IconButton::new("restart", IconName::RotateCw)
1564                        .icon_color(Color::Muted)
1565                        .shape(IconButtonShape::Square)
1566                        .tooltip(|window, cx| {
1567                            Tooltip::with_meta(
1568                                "Regenerate Transformation",
1569                                Some(restart_key),
1570                                "Current change will be discarded",
1571                                window,
1572                                cx,
1573                            )
1574                        })
1575                        .on_click(cx.listener(|_, _, _, cx| {
1576                            cx.emit(PromptEditorEvent::StartRequested);
1577                        }))
1578                        .into_any_element(),
1579                    if !must_rerun {
1580                        IconButton::new("confirm", IconName::Check)
1581                            .icon_color(Color::Info)
1582                            .shape(IconButtonShape::Square)
1583                            .tooltip(|window, cx| {
1584                                Tooltip::for_action("Confirm Assist", &menu::Confirm, window, cx)
1585                            })
1586                            .on_click(cx.listener(|_, _, _, cx| {
1587                                cx.emit(PromptEditorEvent::ConfirmRequested);
1588                            }))
1589                            .into_any_element()
1590                    } else {
1591                        div().into_any_element()
1592                    },
1593                ]
1594            }
1595        });
1596
1597        h_flex()
1598            .key_context("PromptEditor")
1599            .bg(cx.theme().colors().editor_background)
1600            .block_mouse_down()
1601            .cursor(CursorStyle::Arrow)
1602            .border_y_1()
1603            .border_color(cx.theme().status().info_border)
1604            .size_full()
1605            .pr(editor_margins.right)
1606            .py(window.line_height() / 2.5)
1607            .on_action(cx.listener(Self::confirm))
1608            .on_action(cx.listener(Self::cancel))
1609            .on_action(cx.listener(Self::restart))
1610            .on_action(cx.listener(Self::move_up))
1611            .on_action(cx.listener(Self::move_down))
1612            .capture_action(cx.listener(Self::cycle_prev))
1613            .capture_action(cx.listener(Self::cycle_next))
1614            .child(
1615                h_flex()
1616                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1617                    .justify_center()
1618                    .gap_2()
1619                    .child(LanguageModelSelectorPopoverMenu::new(
1620                        self.language_model_selector.clone(),
1621                        IconButton::new("context", IconName::SettingsAlt)
1622                            .shape(IconButtonShape::Square)
1623                            .icon_size(IconSize::Small)
1624                            .icon_color(Color::Muted),
1625                        move |window, cx| {
1626                            Tooltip::with_meta(
1627                                format!(
1628                                    "Using {}",
1629                                    LanguageModelRegistry::read_global(cx)
1630                                        .default_model()
1631                                        .map(|default| default.model.name().0)
1632                                        .unwrap_or_else(|| "No model selected".into()),
1633                                ),
1634                                None,
1635                                "Change Model",
1636                                window,
1637                                cx,
1638                            )
1639                        },
1640                        gpui::Corner::TopRight,
1641                    ))
1642                    .map(|el| {
1643                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
1644                            return el;
1645                        };
1646
1647                        let error_message = SharedString::from(error.to_string());
1648                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
1649                            && cx.has_flag::<ZedProFeatureFlag>()
1650                        {
1651                            el.child(
1652                                v_flex()
1653                                    .child(
1654                                        IconButton::new("rate-limit-error", IconName::XCircle)
1655                                            .toggle_state(self.show_rate_limit_notice)
1656                                            .shape(IconButtonShape::Square)
1657                                            .icon_size(IconSize::Small)
1658                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1659                                    )
1660                                    .children(self.show_rate_limit_notice.then(|| {
1661                                        deferred(
1662                                            anchored()
1663                                                .position_mode(gpui::AnchoredPositionMode::Local)
1664                                                .position(point(px(0.), px(24.)))
1665                                                .anchor(gpui::Corner::TopLeft)
1666                                                .child(self.render_rate_limit_notice(cx)),
1667                                        )
1668                                    })),
1669                            )
1670                        } else {
1671                            el.child(
1672                                div()
1673                                    .id("error")
1674                                    .tooltip(Tooltip::text(error_message))
1675                                    .child(
1676                                        Icon::new(IconName::XCircle)
1677                                            .size(IconSize::Small)
1678                                            .color(Color::Error),
1679                                    ),
1680                            )
1681                        }
1682                    }),
1683            )
1684            .child(div().flex_1().child(self.render_prompt_editor(cx)))
1685            .child(
1686                h_flex()
1687                    .gap_2()
1688                    .pr(px(9.))
1689                    .children(self.render_token_count(cx))
1690                    .children(buttons),
1691            )
1692    }
1693}
1694
1695impl Focusable for PromptEditor {
1696    fn focus_handle(&self, cx: &App) -> FocusHandle {
1697        self.editor.focus_handle(cx)
1698    }
1699}
1700
1701impl PromptEditor {
    /// Height cap of the prompt editor, in lines. Passed as `max_lines` to
    /// `EditorMode::AutoHeight` in `new` and to `Editor::auto_height` in `unlink`.
    const MAX_LINES: u8 = 8;
1703
1704    fn new(
1705        id: InlineAssistId,
1706        editor_margins: Arc<Mutex<EditorMargins>>,
1707        prompt_history: VecDeque<String>,
1708        prompt_buffer: Entity<MultiBuffer>,
1709        codegen: Entity<Codegen>,
1710        parent_editor: &Entity<Editor>,
1711        assistant_panel: Option<&Entity<AssistantPanel>>,
1712        workspace: Option<WeakEntity<Workspace>>,
1713        fs: Arc<dyn Fs>,
1714        window: &mut Window,
1715        cx: &mut Context<Self>,
1716    ) -> Self {
1717        let prompt_editor = cx.new(|cx| {
1718            let mut editor = Editor::new(
1719                EditorMode::AutoHeight {
1720                    max_lines: Self::MAX_LINES as usize,
1721                },
1722                prompt_buffer,
1723                None,
1724                window,
1725                cx,
1726            );
1727            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1728            // Since the prompt editors for all inline assistants are linked,
1729            // always show the cursor (even when it isn't focused) because
1730            // typing in one will make what you typed appear in all of them.
1731            editor.set_show_cursor_when_unfocused(true, cx);
1732            editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), window, cx), cx);
1733            editor
1734        });
1735
1736        let mut token_count_subscriptions = Vec::new();
1737        token_count_subscriptions.push(cx.subscribe_in(
1738            parent_editor,
1739            window,
1740            Self::handle_parent_editor_event,
1741        ));
1742        if let Some(assistant_panel) = assistant_panel {
1743            token_count_subscriptions.push(cx.subscribe_in(
1744                assistant_panel,
1745                window,
1746                Self::handle_assistant_panel_event,
1747            ));
1748        }
1749
1750        let mut this = Self {
1751            id,
1752            editor: prompt_editor,
1753            language_model_selector: cx.new(|cx| {
1754                let fs = fs.clone();
1755                LanguageModelSelector::new(
1756                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
1757                    move |model, cx| {
1758                        update_settings_file::<AssistantSettings>(
1759                            fs.clone(),
1760                            cx,
1761                            move |settings, _| settings.set_model(model.clone()),
1762                        );
1763                    },
1764                    window,
1765                    cx,
1766                )
1767            }),
1768            edited_since_done: false,
1769            editor_margins,
1770            prompt_history,
1771            prompt_history_ix: None,
1772            pending_prompt: String::new(),
1773            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1774            editor_subscriptions: Vec::new(),
1775            codegen,
1776            pending_token_count: Task::ready(Ok(())),
1777            token_counts: None,
1778            _token_count_subscriptions: token_count_subscriptions,
1779            workspace,
1780            show_rate_limit_notice: false,
1781        };
1782        this.count_tokens(cx);
1783        this.subscribe_to_editor(window, cx);
1784        this
1785    }
1786
1787    fn subscribe_to_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1788        self.editor_subscriptions.clear();
1789        self.editor_subscriptions.push(cx.subscribe_in(
1790            &self.editor,
1791            window,
1792            Self::handle_prompt_editor_events,
1793        ));
1794    }
1795
1796    fn set_show_cursor_when_unfocused(
1797        &mut self,
1798        show_cursor_when_unfocused: bool,
1799        cx: &mut Context<Self>,
1800    ) {
1801        self.editor.update(cx, |editor, cx| {
1802            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1803        });
1804    }
1805
1806    fn unlink(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1807        let prompt = self.prompt(cx);
1808        let focus = self.editor.focus_handle(cx).contains_focused(window, cx);
1809        self.editor = cx.new(|cx| {
1810            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, window, cx);
1811            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1812            editor.set_placeholder_text("Add a prompt…", cx);
1813            editor.set_text(prompt, window, cx);
1814            if focus {
1815                window.focus(&editor.focus_handle(cx));
1816            }
1817            editor
1818        });
1819        self.subscribe_to_editor(window, cx);
1820    }
1821
1822    fn placeholder_text(codegen: &Codegen, window: &Window, cx: &App) -> String {
1823        let context_keybinding = text_for_action(&zed_actions::assistant::ToggleFocus, window, cx)
1824            .map(|keybinding| format!("{keybinding} for context"))
1825            .unwrap_or_default();
1826
1827        let action = if codegen.is_insertion {
1828            "Generate"
1829        } else {
1830            "Transform"
1831        };
1832
1833        format!("{action}{context_keybinding} • ↓↑ for history")
1834    }
1835
1836    fn prompt(&self, cx: &App) -> String {
1837        self.editor.read(cx).text(cx)
1838    }
1839
1840    fn toggle_rate_limit_notice(
1841        &mut self,
1842        _: &ClickEvent,
1843        window: &mut Window,
1844        cx: &mut Context<Self>,
1845    ) {
1846        self.show_rate_limit_notice = !self.show_rate_limit_notice;
1847        if self.show_rate_limit_notice {
1848            window.focus(&self.editor.focus_handle(cx));
1849        }
1850        cx.notify();
1851    }
1852
1853    fn handle_parent_editor_event(
1854        &mut self,
1855        _: &Entity<Editor>,
1856        event: &EditorEvent,
1857        _: &mut Window,
1858        cx: &mut Context<Self>,
1859    ) {
1860        if let EditorEvent::BufferEdited { .. } = event {
1861            self.count_tokens(cx);
1862        }
1863    }
1864
1865    fn handle_assistant_panel_event(
1866        &mut self,
1867        _: &Entity<AssistantPanel>,
1868        event: &AssistantPanelEvent,
1869        _: &mut Window,
1870        cx: &mut Context<Self>,
1871    ) {
1872        let AssistantPanelEvent::ContextEdited { .. } = event;
1873        self.count_tokens(cx);
1874    }
1875
1876    fn count_tokens(&mut self, cx: &mut Context<Self>) {
1877        let assist_id = self.id;
1878        self.pending_token_count = cx.spawn(async move |this, cx| {
1879            cx.background_executor().timer(Duration::from_secs(1)).await;
1880            let token_count = cx
1881                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1882                    let assist = inline_assistant
1883                        .assists
1884                        .get(&assist_id)
1885                        .context("assist not found")?;
1886                    anyhow::Ok(assist.count_tokens(cx))
1887                })??
1888                .await?;
1889
1890            this.update(cx, |this, cx| {
1891                this.token_counts = Some(token_count);
1892                cx.notify();
1893            })
1894        })
1895    }
1896
1897    fn handle_prompt_editor_events(
1898        &mut self,
1899        _: &Entity<Editor>,
1900        event: &EditorEvent,
1901        window: &mut Window,
1902        cx: &mut Context<Self>,
1903    ) {
1904        match event {
1905            EditorEvent::Edited { .. } => {
1906                if let Some(workspace) = window.root::<Workspace>().flatten() {
1907                    workspace.update(cx, |workspace, cx| {
1908                        let is_via_ssh = workspace
1909                            .project()
1910                            .update(cx, |project, _| project.is_via_ssh());
1911
1912                        workspace
1913                            .client()
1914                            .telemetry()
1915                            .log_edit_event("inline assist", is_via_ssh);
1916                    });
1917                }
1918                let prompt = self.editor.read(cx).text(cx);
1919                if self
1920                    .prompt_history_ix
1921                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
1922                {
1923                    self.prompt_history_ix.take();
1924                    self.pending_prompt = prompt;
1925                }
1926
1927                self.edited_since_done = true;
1928                cx.notify();
1929            }
1930            EditorEvent::BufferEdited => {
1931                self.count_tokens(cx);
1932            }
1933            EditorEvent::Blurred => {
1934                if self.show_rate_limit_notice {
1935                    self.show_rate_limit_notice = false;
1936                    cx.notify();
1937                }
1938            }
1939            _ => {}
1940        }
1941    }
1942
1943    fn handle_codegen_changed(&mut self, _: Entity<Codegen>, cx: &mut Context<Self>) {
1944        match self.codegen.read(cx).status(cx) {
1945            CodegenStatus::Idle => {
1946                self.editor
1947                    .update(cx, |editor, _| editor.set_read_only(false));
1948            }
1949            CodegenStatus::Pending => {
1950                self.editor
1951                    .update(cx, |editor, _| editor.set_read_only(true));
1952            }
1953            CodegenStatus::Done => {
1954                self.edited_since_done = false;
1955                self.editor
1956                    .update(cx, |editor, _| editor.set_read_only(false));
1957            }
1958            CodegenStatus::Error(error) => {
1959                if cx.has_flag::<ZedProFeatureFlag>()
1960                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
1961                    && !dismissed_rate_limit_notice()
1962                {
1963                    self.show_rate_limit_notice = true;
1964                    cx.notify();
1965                }
1966
1967                self.edited_since_done = false;
1968                self.editor
1969                    .update(cx, |editor, _| editor.set_read_only(false));
1970            }
1971        }
1972    }
1973
1974    fn restart(&mut self, _: &menu::Restart, _window: &mut Window, cx: &mut Context<Self>) {
1975        cx.emit(PromptEditorEvent::StartRequested);
1976    }
1977
1978    fn cancel(
1979        &mut self,
1980        _: &editor::actions::Cancel,
1981        _window: &mut Window,
1982        cx: &mut Context<Self>,
1983    ) {
1984        match self.codegen.read(cx).status(cx) {
1985            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1986                cx.emit(PromptEditorEvent::CancelRequested);
1987            }
1988            CodegenStatus::Pending => {
1989                cx.emit(PromptEditorEvent::StopRequested);
1990            }
1991        }
1992    }
1993
1994    fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
1995        match self.codegen.read(cx).status(cx) {
1996            CodegenStatus::Idle => {
1997                cx.emit(PromptEditorEvent::StartRequested);
1998            }
1999            CodegenStatus::Pending => {
2000                cx.emit(PromptEditorEvent::DismissRequested);
2001            }
2002            CodegenStatus::Done => {
2003                if self.edited_since_done {
2004                    cx.emit(PromptEditorEvent::StartRequested);
2005                } else {
2006                    cx.emit(PromptEditorEvent::ConfirmRequested);
2007                }
2008            }
2009            CodegenStatus::Error(_) => {
2010                cx.emit(PromptEditorEvent::StartRequested);
2011            }
2012        }
2013    }
2014
2015    fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
2016        if let Some(ix) = self.prompt_history_ix {
2017            if ix > 0 {
2018                self.prompt_history_ix = Some(ix - 1);
2019                let prompt = self.prompt_history[ix - 1].as_str();
2020                self.editor.update(cx, |editor, cx| {
2021                    editor.set_text(prompt, window, cx);
2022                    editor.move_to_beginning(&Default::default(), window, cx);
2023                });
2024            }
2025        } else if !self.prompt_history.is_empty() {
2026            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
2027            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
2028            self.editor.update(cx, |editor, cx| {
2029                editor.set_text(prompt, window, cx);
2030                editor.move_to_beginning(&Default::default(), window, cx);
2031            });
2032        }
2033    }
2034
2035    fn move_down(&mut self, _: &MoveDown, window: &mut Window, cx: &mut Context<Self>) {
2036        if let Some(ix) = self.prompt_history_ix {
2037            if ix < self.prompt_history.len() - 1 {
2038                self.prompt_history_ix = Some(ix + 1);
2039                let prompt = self.prompt_history[ix + 1].as_str();
2040                self.editor.update(cx, |editor, cx| {
2041                    editor.set_text(prompt, window, cx);
2042                    editor.move_to_end(&Default::default(), window, cx)
2043                });
2044            } else {
2045                self.prompt_history_ix = None;
2046                let prompt = self.pending_prompt.as_str();
2047                self.editor.update(cx, |editor, cx| {
2048                    editor.set_text(prompt, window, cx);
2049                    editor.move_to_end(&Default::default(), window, cx)
2050                });
2051            }
2052        }
2053    }
2054
2055    fn cycle_prev(
2056        &mut self,
2057        _: &CyclePreviousInlineAssist,
2058        _: &mut Window,
2059        cx: &mut Context<Self>,
2060    ) {
2061        self.codegen
2062            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
2063    }
2064
2065    fn cycle_next(&mut self, _: &CycleNextInlineAssist, _: &mut Window, cx: &mut Context<Self>) {
2066        self.codegen
2067            .update(cx, |codegen, cx| codegen.cycle_next(cx));
2068    }
2069
2070    fn render_cycle_controls(&self, cx: &Context<Self>) -> AnyElement {
2071        let codegen = self.codegen.read(cx);
2072        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
2073
2074        let model_registry = LanguageModelRegistry::read_global(cx);
2075        let default_model = model_registry.default_model().map(|default| default.model);
2076        let alternative_models = model_registry.inline_alternative_models();
2077
2078        let get_model_name = |index: usize| -> String {
2079            let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();
2080
2081            match index {
2082                0 => default_model.as_ref().map_or_else(String::new, name),
2083                index if index <= alternative_models.len() => alternative_models
2084                    .get(index - 1)
2085                    .map_or_else(String::new, name),
2086                _ => String::new(),
2087            }
2088        };
2089
2090        let total_models = alternative_models.len() + 1;
2091
2092        if total_models <= 1 {
2093            return div().into_any_element();
2094        }
2095
2096        let current_index = codegen.active_alternative;
2097        let prev_index = (current_index + total_models - 1) % total_models;
2098        let next_index = (current_index + 1) % total_models;
2099
2100        let prev_model_name = get_model_name(prev_index);
2101        let next_model_name = get_model_name(next_index);
2102
2103        h_flex()
2104            .child(
2105                IconButton::new("previous", IconName::ChevronLeft)
2106                    .icon_color(Color::Muted)
2107                    .disabled(disabled || current_index == 0)
2108                    .shape(IconButtonShape::Square)
2109                    .tooltip({
2110                        let focus_handle = self.editor.focus_handle(cx);
2111                        move |window, cx| {
2112                            cx.new(|cx| {
2113                                let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
2114                                    KeyBinding::for_action_in(
2115                                        &CyclePreviousInlineAssist,
2116                                        &focus_handle,
2117                                        window,
2118                                        cx,
2119                                    ),
2120                                );
2121                                if !disabled && current_index != 0 {
2122                                    tooltip = tooltip.meta(prev_model_name.clone());
2123                                }
2124                                tooltip
2125                            })
2126                            .into()
2127                        }
2128                    })
2129                    .on_click(cx.listener(|this, _, _, cx| {
2130                        this.codegen
2131                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
2132                    })),
2133            )
2134            .child(
2135                Label::new(format!(
2136                    "{}/{}",
2137                    codegen.active_alternative + 1,
2138                    codegen.alternative_count(cx)
2139                ))
2140                .size(LabelSize::Small)
2141                .color(if disabled {
2142                    Color::Disabled
2143                } else {
2144                    Color::Muted
2145                }),
2146            )
2147            .child(
2148                IconButton::new("next", IconName::ChevronRight)
2149                    .icon_color(Color::Muted)
2150                    .disabled(disabled || current_index == total_models - 1)
2151                    .shape(IconButtonShape::Square)
2152                    .tooltip({
2153                        let focus_handle = self.editor.focus_handle(cx);
2154                        move |window, cx| {
2155                            cx.new(|cx| {
2156                                let mut tooltip = Tooltip::new("Next Alternative").key_binding(
2157                                    KeyBinding::for_action_in(
2158                                        &CycleNextInlineAssist,
2159                                        &focus_handle,
2160                                        window,
2161                                        cx,
2162                                    ),
2163                                );
2164                                if !disabled && current_index != total_models - 1 {
2165                                    tooltip = tooltip.meta(next_model_name.clone());
2166                                }
2167                                tooltip
2168                            })
2169                            .into()
2170                        }
2171                    })
2172                    .on_click(cx.listener(|this, _, _, cx| {
2173                        this.codegen
2174                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
2175                    })),
2176            )
2177            .into_any_element()
2178    }
2179
2180    fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
2181        let model = LanguageModelRegistry::read_global(cx)
2182            .default_model()?
2183            .model;
2184        let token_counts = self.token_counts?;
2185        let max_token_count = model.max_token_count();
2186
2187        let remaining_tokens = max_token_count as isize - token_counts.total as isize;
2188        let token_count_color = if remaining_tokens <= 0 {
2189            Color::Error
2190        } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
2191            Color::Warning
2192        } else {
2193            Color::Muted
2194        };
2195
2196        let mut token_count = h_flex()
2197            .id("token_count")
2198            .gap_0p5()
2199            .child(
2200                Label::new(humanize_token_count(token_counts.total))
2201                    .size(LabelSize::Small)
2202                    .color(token_count_color),
2203            )
2204            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
2205            .child(
2206                Label::new(humanize_token_count(max_token_count))
2207                    .size(LabelSize::Small)
2208                    .color(Color::Muted),
2209            );
2210        if let Some(workspace) = self.workspace.clone() {
2211            token_count = token_count
2212                .tooltip(move |window, cx| {
2213                    Tooltip::with_meta(
2214                        format!(
2215                            "Tokens Used ({} from the Assistant Panel)",
2216                            humanize_token_count(token_counts.assistant_panel)
2217                        ),
2218                        None,
2219                        "Click to open the Assistant Panel",
2220                        window,
2221                        cx,
2222                    )
2223                })
2224                .cursor_pointer()
2225                .on_mouse_down(gpui::MouseButton::Left, |_, _, cx| cx.stop_propagation())
2226                .on_click(move |_, window, cx| {
2227                    cx.stop_propagation();
2228                    workspace
2229                        .update(cx, |workspace, cx| {
2230                            workspace.focus_panel::<AssistantPanel>(window, cx)
2231                        })
2232                        .ok();
2233                });
2234        } else {
2235            token_count = token_count
2236                .cursor_default()
2237                .tooltip(Tooltip::text("Tokens used"));
2238        }
2239
2240        Some(token_count)
2241    }
2242
2243    fn render_prompt_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
2244        let settings = ThemeSettings::get_global(cx);
2245        let text_style = TextStyle {
2246            color: if self.editor.read(cx).read_only(cx) {
2247                cx.theme().colors().text_disabled
2248            } else {
2249                cx.theme().colors().text
2250            },
2251            font_family: settings.buffer_font.family.clone(),
2252            font_fallbacks: settings.buffer_font.fallbacks.clone(),
2253            font_size: settings.buffer_font_size(cx).into(),
2254            font_weight: settings.buffer_font.weight,
2255            line_height: relative(settings.buffer_line_height.value()),
2256            ..Default::default()
2257        };
2258        EditorElement::new(
2259            &self.editor,
2260            EditorStyle {
2261                background: cx.theme().colors().editor_background,
2262                local_player: cx.theme().players().local(),
2263                text: text_style,
2264                ..Default::default()
2265            },
2266        )
2267    }
2268
2269    fn render_rate_limit_notice(&self, cx: &mut Context<Self>) -> impl IntoElement {
2270        Popover::new().child(
2271            v_flex()
2272                .occlude()
2273                .p_2()
2274                .child(
2275                    Label::new("Out of Tokens")
2276                        .size(LabelSize::Small)
2277                        .weight(FontWeight::BOLD),
2278                )
2279                .child(Label::new(
2280                    "Try Zed Pro for higher limits, a wider range of models, and more.",
2281                ))
2282                .child(
2283                    h_flex()
2284                        .justify_between()
2285                        .child(CheckboxWithLabel::new(
2286                            "dont-show-again",
2287                            Label::new("Don't show again"),
2288                            if dismissed_rate_limit_notice() {
2289                                ui::ToggleState::Selected
2290                            } else {
2291                                ui::ToggleState::Unselected
2292                            },
2293                            |selection, _, cx| {
2294                                let is_dismissed = match selection {
2295                                    ui::ToggleState::Unselected => false,
2296                                    ui::ToggleState::Indeterminate => return,
2297                                    ui::ToggleState::Selected => true,
2298                                };
2299
2300                                set_rate_limit_notice_dismissed(is_dismissed, cx)
2301                            },
2302                        ))
2303                        .child(
2304                            h_flex()
2305                                .gap_2()
2306                                .child(
2307                                    Button::new("dismiss", "Dismiss")
2308                                        .style(ButtonStyle::Transparent)
2309                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2310                                )
2311                                .child(Button::new("more-info", "More Info").on_click(
2312                                    |_event, window, cx| {
2313                                        window.dispatch_action(
2314                                            Box::new(zed_actions::OpenAccountSettings),
2315                                            cx,
2316                                        )
2317                                    },
2318                                )),
2319                        ),
2320                ),
2321        )
2322    }
2323}
2324
/// Key-value-store key recording that the user opted out of the rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2326
2327fn dismissed_rate_limit_notice() -> bool {
2328    db::kvp::KEY_VALUE_STORE
2329        .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2330        .log_err()
2331        .map_or(false, |s| s.is_some())
2332}
2333
2334fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut App) {
2335    db::write_and_log(cx, move || async move {
2336        if is_dismissed {
2337            db::kvp::KEY_VALUE_STORE
2338                .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2339                .await
2340        } else {
2341            db::kvp::KEY_VALUE_STORE
2342                .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2343                .await
2344        }
2345    })
2346}
2347
/// State for a single inline assist tracked by the global `InlineAssistant`.
struct InlineAssist {
    /// Identifier of the group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchor range this assist operates on.
    range: Range<Anchor>,
    /// Weak handle to the editor hosting the assist.
    editor: WeakEntity<Editor>,
    /// Editor blocks (prompt editor, end block, removed-line blocks) for this
    /// assist. When `None`, a `CodegenEvent::Finished` finishes the assist
    /// immediately and reports errors as a toast (see `InlineAssist::new`).
    decorations: Option<InlineAssistDecorations>,
    /// Code generation driving this assist.
    codegen: Entity<Codegen>,
    /// Keeps the focus, prompt-editor and codegen subscriptions from `new` alive.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to show error toasts and to look up the assistant panel.
    workspace: Option<WeakEntity<Workspace>>,
    /// Whether the assistant panel's active context is sent along with requests.
    include_context: bool,
}
2358
2359impl InlineAssist {
2360    fn new(
2361        assist_id: InlineAssistId,
2362        group_id: InlineAssistGroupId,
2363        include_context: bool,
2364        editor: &Entity<Editor>,
2365        prompt_editor: &Entity<PromptEditor>,
2366        prompt_block_id: CustomBlockId,
2367        end_block_id: CustomBlockId,
2368        range: Range<Anchor>,
2369        codegen: Entity<Codegen>,
2370        workspace: Option<WeakEntity<Workspace>>,
2371        window: &mut Window,
2372        cx: &mut App,
2373    ) -> Self {
2374        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2375        InlineAssist {
2376            group_id,
2377            include_context,
2378            editor: editor.downgrade(),
2379            decorations: Some(InlineAssistDecorations {
2380                prompt_block_id,
2381                prompt_editor: prompt_editor.clone(),
2382                removed_line_block_ids: HashSet::default(),
2383                end_block_id,
2384            }),
2385            range,
2386            codegen: codegen.clone(),
2387            workspace: workspace.clone(),
2388            _subscriptions: vec![
2389                window.on_focus_in(&prompt_editor_focus_handle, cx, move |_, cx| {
2390                    InlineAssistant::update_global(cx, |this, cx| {
2391                        this.handle_prompt_editor_focus_in(assist_id, cx)
2392                    })
2393                }),
2394                window.on_focus_out(&prompt_editor_focus_handle, cx, move |_, _, cx| {
2395                    InlineAssistant::update_global(cx, |this, cx| {
2396                        this.handle_prompt_editor_focus_out(assist_id, cx)
2397                    })
2398                }),
2399                window.subscribe(
2400                    prompt_editor,
2401                    cx,
2402                    move |prompt_editor, event, window, cx| {
2403                        InlineAssistant::update_global(cx, |this, cx| {
2404                            this.handle_prompt_editor_event(prompt_editor, event, window, cx)
2405                        })
2406                    },
2407                ),
2408                window.observe(&codegen, cx, {
2409                    let editor = editor.downgrade();
2410                    move |_, window, cx| {
2411                        if let Some(editor) = editor.upgrade() {
2412                            InlineAssistant::update_global(cx, |this, cx| {
2413                                if let Some(editor_assists) =
2414                                    this.assists_by_editor.get(&editor.downgrade())
2415                                {
2416                                    editor_assists.highlight_updates.send(()).ok();
2417                                }
2418
2419                                this.update_editor_blocks(&editor, assist_id, window, cx);
2420                            })
2421                        }
2422                    }
2423                }),
2424                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
2425                    InlineAssistant::update_global(cx, |this, cx| match event {
2426                        CodegenEvent::Undone => this.finish_assist(assist_id, false, window, cx),
2427                        CodegenEvent::Finished => {
2428                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
2429                                assist
2430                            } else {
2431                                return;
2432                            };
2433
2434                            if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2435                                if assist.decorations.is_none() {
2436                                    if let Some(workspace) = assist
2437                                        .workspace
2438                                        .as_ref()
2439                                        .and_then(|workspace| workspace.upgrade())
2440                                    {
2441                                        let error = format!("Inline assistant error: {}", error);
2442                                        workspace.update(cx, |workspace, cx| {
2443                                            struct InlineAssistantError;
2444
2445                                            let id =
2446                                                NotificationId::composite::<InlineAssistantError>(
2447                                                    assist_id.0,
2448                                                );
2449
2450                                            workspace.show_toast(Toast::new(id, error), cx);
2451                                        })
2452                                    }
2453                                }
2454                            }
2455
2456                            if assist.decorations.is_none() {
2457                                this.finish_assist(assist_id, false, window, cx);
2458                            }
2459                        }
2460                    })
2461                }),
2462            ],
2463        }
2464    }
2465
2466    fn user_prompt(&self, cx: &App) -> Option<String> {
2467        let decorations = self.decorations.as_ref()?;
2468        Some(decorations.prompt_editor.read(cx).prompt(cx))
2469    }
2470
2471    fn assistant_panel_context(&self, cx: &mut App) -> Option<LanguageModelRequest> {
2472        if self.include_context {
2473            let workspace = self.workspace.as_ref()?;
2474            let workspace = workspace.upgrade()?.read(cx);
2475            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2476            Some(
2477                assistant_panel
2478                    .read(cx)
2479                    .active_context(cx)?
2480                    .read(cx)
2481                    .to_completion_request(None, RequestType::Chat, cx),
2482            )
2483        } else {
2484            None
2485        }
2486    }
2487
2488    pub fn count_tokens(&self, cx: &mut App) -> BoxFuture<'static, Result<TokenCounts>> {
2489        let Some(user_prompt) = self.user_prompt(cx) else {
2490            return future::ready(Err(anyhow!("no user prompt"))).boxed();
2491        };
2492        let assistant_panel_context = self.assistant_panel_context(cx);
2493        self.codegen
2494            .read(cx)
2495            .count_tokens(user_prompt, assistant_panel_context, cx)
2496    }
2497}
2498
/// Editor blocks and the prompt editor shown for an inline assist.
struct InlineAssistDecorations {
    /// Custom block id of the prompt block.
    prompt_block_id: CustomBlockId,
    /// Editor in which the user types the prompt.
    prompt_editor: Entity<PromptEditor>,
    /// Custom blocks associated with removed lines; starts empty in `InlineAssist::new`.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Custom block id of the end block.
    end_block_id: CustomBlockId,
}
2505
/// Events emitted by [`CodegenAlternative`] and re-emitted by [`Codegen`] for
/// its active alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation ended; check the codegen's status to tell success from failure.
    Finished,
    /// The alternative's transformation transaction was undone in the buffer.
    Undone,
}
2511
/// Coordinates one [`CodegenAlternative`] per model for a single range; only
/// the active alternative's edits are applied to the buffer at a time.
pub struct Codegen {
    /// Primary alternative first, then one per inline alternative model.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the currently active alternative.
    active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; replaced on every activation.
    subscriptions: Vec<Subscription>,
    /// Multi-buffer that new alternatives are created for.
    buffer: Entity<MultiBuffer>,
    /// Range that new alternatives are created for.
    range: Range<Anchor>,
    /// Transaction undone (at most once) by `undo`.
    initial_transaction_id: Option<TransactionId>,
    /// Passed to every alternative.
    telemetry: Arc<Telemetry>,
    /// Prompt builder passed to every alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2524
2525impl Codegen {
2526    pub fn new(
2527        buffer: Entity<MultiBuffer>,
2528        range: Range<Anchor>,
2529        initial_transaction_id: Option<TransactionId>,
2530        telemetry: Arc<Telemetry>,
2531        builder: Arc<PromptBuilder>,
2532        cx: &mut Context<Self>,
2533    ) -> Self {
2534        let codegen = cx.new(|cx| {
2535            CodegenAlternative::new(
2536                buffer.clone(),
2537                range.clone(),
2538                false,
2539                Some(telemetry.clone()),
2540                builder.clone(),
2541                cx,
2542            )
2543        });
2544        let mut this = Self {
2545            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2546            alternatives: vec![codegen],
2547            active_alternative: 0,
2548            seen_alternatives: HashSet::default(),
2549            subscriptions: Vec::new(),
2550            buffer,
2551            range,
2552            initial_transaction_id,
2553            telemetry,
2554            builder,
2555        };
2556        this.activate(0, cx);
2557        this
2558    }
2559
2560    fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
2561        let codegen = self.active_alternative().clone();
2562        self.subscriptions.clear();
2563        self.subscriptions
2564            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2565        self.subscriptions
2566            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2567    }
2568
2569    fn active_alternative(&self) -> &Entity<CodegenAlternative> {
2570        &self.alternatives[self.active_alternative]
2571    }
2572
2573    fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
2574        &self.active_alternative().read(cx).status
2575    }
2576
2577    fn alternative_count(&self, cx: &App) -> usize {
2578        LanguageModelRegistry::read_global(cx)
2579            .inline_alternative_models()
2580            .len()
2581            + 1
2582    }
2583
2584    pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
2585        let next_active_ix = if self.active_alternative == 0 {
2586            self.alternatives.len() - 1
2587        } else {
2588            self.active_alternative - 1
2589        };
2590        self.activate(next_active_ix, cx);
2591    }
2592
2593    pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
2594        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2595        self.activate(next_active_ix, cx);
2596    }
2597
2598    fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
2599        self.active_alternative()
2600            .update(cx, |codegen, cx| codegen.set_active(false, cx));
2601        self.seen_alternatives.insert(index);
2602        self.active_alternative = index;
2603        self.active_alternative()
2604            .update(cx, |codegen, cx| codegen.set_active(true, cx));
2605        self.subscribe_to_alternative(cx);
2606        cx.notify();
2607    }
2608
2609    pub fn start(
2610        &mut self,
2611        user_prompt: String,
2612        assistant_panel_context: Option<LanguageModelRequest>,
2613        cx: &mut Context<Self>,
2614    ) -> Result<()> {
2615        let alternative_models = LanguageModelRegistry::read_global(cx)
2616            .inline_alternative_models()
2617            .to_vec();
2618
2619        self.active_alternative()
2620            .update(cx, |alternative, cx| alternative.undo(cx));
2621        self.activate(0, cx);
2622        self.alternatives.truncate(1);
2623
2624        for _ in 0..alternative_models.len() {
2625            self.alternatives.push(cx.new(|cx| {
2626                CodegenAlternative::new(
2627                    self.buffer.clone(),
2628                    self.range.clone(),
2629                    false,
2630                    Some(self.telemetry.clone()),
2631                    self.builder.clone(),
2632                    cx,
2633                )
2634            }));
2635        }
2636
2637        let primary_model = LanguageModelRegistry::read_global(cx)
2638            .default_model()
2639            .context("no active model")?
2640            .model;
2641
2642        for (model, alternative) in iter::once(primary_model)
2643            .chain(alternative_models)
2644            .zip(&self.alternatives)
2645        {
2646            alternative.update(cx, |alternative, cx| {
2647                alternative.start(
2648                    user_prompt.clone(),
2649                    assistant_panel_context.clone(),
2650                    model.clone(),
2651                    cx,
2652                )
2653            })?;
2654        }
2655
2656        Ok(())
2657    }
2658
2659    pub fn stop(&mut self, cx: &mut Context<Self>) {
2660        for codegen in &self.alternatives {
2661            codegen.update(cx, |codegen, cx| codegen.stop(cx));
2662        }
2663    }
2664
2665    pub fn undo(&mut self, cx: &mut Context<Self>) {
2666        self.active_alternative()
2667            .update(cx, |codegen, cx| codegen.undo(cx));
2668
2669        self.buffer.update(cx, |buffer, cx| {
2670            if let Some(transaction_id) = self.initial_transaction_id.take() {
2671                buffer.undo_transaction(transaction_id, cx);
2672                buffer.refresh_preview(cx);
2673            }
2674        });
2675    }
2676
2677    pub fn count_tokens(
2678        &self,
2679        user_prompt: String,
2680        assistant_panel_context: Option<LanguageModelRequest>,
2681        cx: &App,
2682    ) -> BoxFuture<'static, Result<TokenCounts>> {
2683        self.active_alternative()
2684            .read(cx)
2685            .count_tokens(user_prompt, assistant_panel_context, cx)
2686    }
2687
2688    pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
2689        self.active_alternative().read(cx).buffer.clone()
2690    }
2691
2692    pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
2693        self.active_alternative().read(cx).old_buffer.clone()
2694    }
2695
2696    pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
2697        self.active_alternative().read(cx).snapshot.clone()
2698    }
2699
2700    pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
2701        self.active_alternative().read(cx).edit_position
2702    }
2703
2704    fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
2705        &self.active_alternative().read(cx).diff
2706    }
2707
2708    pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
2709        self.active_alternative().read(cx).last_equal_ranges()
2710    }
2711}
2712
// `Codegen` re-emits the active alternative's events (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for Codegen {}
2714
/// One model's generation for a [`Codegen`] range, including the edits it
/// applies to the multi-buffer while it is active.
pub struct CodegenAlternative {
    /// Multi-buffer being edited.
    buffer: Entity<MultiBuffer>,
    /// Copy of the underlying buffer's text, line ending and language, taken at construction.
    old_buffer: Entity<Buffer>,
    /// Multi-buffer snapshot taken at construction.
    snapshot: MultiBufferSnapshot,
    /// Set to `range.start` (biased right) when `start` runs; presumably advanced
    /// while streaming — TODO confirm.
    edit_position: Option<Anchor>,
    /// Range this alternative transforms.
    range: Range<Anchor>,
    /// Exposed via `last_equal_ranges`; populated outside this view.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding the applied edits. It is undone when the alternative
    /// is deactivated or restarted.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation status; `Idle` after construction.
    status: CodegenStatus,
    /// Task running the current generation.
    generation: Task<()>,
    /// Reset at the start of each stream.
    diff: Diff,
    /// Optional telemetry handle.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to multi-buffer events, handled by `handle_buffer_event`.
    _subscription: gpui::Subscription,
    /// Builds the inline transformation prompt in `build_request`.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied (see `set_active`).
    active: bool,
    /// Edits re-applied via `apply_edits` when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations re-applied on activation while generation is still pending.
    line_operations: Vec<LineOperation>,
    /// Last request built by `start`.
    request: Option<LanguageModelRequest>,
    /// `None` at construction; presumably filled in once streaming completes — TODO confirm.
    elapsed_time: Option<f64>,
    /// `None` at construction; presumably the accumulated completion text — TODO confirm.
    completion: Option<String>,
    /// `None` at construction; presumably the model's response message id — TODO confirm.
    message_id: Option<String>,
}
2737
/// Lifecycle state of a [`CodegenAlternative`]'s generation.
enum CodegenStatus {
    /// Initial state after construction.
    Idle,
    /// Set when `handle_stream` begins consuming a response.
    Pending,
    /// Presumably set when the stream completes successfully — set outside this view.
    Done,
    /// Generation failed. For assists without decorations, the error is shown as a toast.
    Error(anyhow::Error),
}
2744
/// Row-level diff between the original and generated text; reset at the start
/// of each stream.
#[derive(Default)]
struct Diff {
    /// Deleted rows paired with an anchor; presumably the anchor marks where they
    /// were removed and the rows index `old_buffer` — TODO confirm.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Ranges of inserted text in the multi-buffer.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2750
2751impl Diff {
2752    fn is_empty(&self) -> bool {
2753        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2754    }
2755}
2756
// Emits `CodegenEvent::Undone` from `handle_buffer_event`; `Codegen` forwards these events.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2758
2759impl CodegenAlternative {
2760    pub fn new(
2761        multi_buffer: Entity<MultiBuffer>,
2762        range: Range<Anchor>,
2763        active: bool,
2764        telemetry: Option<Arc<Telemetry>>,
2765        builder: Arc<PromptBuilder>,
2766        cx: &mut Context<Self>,
2767    ) -> Self {
2768        let snapshot = multi_buffer.read(cx).snapshot(cx);
2769
2770        let (buffer, _, _) = snapshot
2771            .range_to_buffer_ranges(range.clone())
2772            .pop()
2773            .unwrap();
2774        let old_buffer = cx.new(|cx| {
2775            let text = buffer.as_rope().clone();
2776            let line_ending = buffer.line_ending();
2777            let language = buffer.language().cloned();
2778            let language_registry = multi_buffer
2779                .read(cx)
2780                .buffer(buffer.remote_id())
2781                .unwrap()
2782                .read(cx)
2783                .language_registry();
2784
2785            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2786            buffer.set_language(language, cx);
2787            if let Some(language_registry) = language_registry {
2788                buffer.set_language_registry(language_registry)
2789            }
2790            buffer
2791        });
2792
2793        Self {
2794            buffer: multi_buffer.clone(),
2795            old_buffer,
2796            edit_position: None,
2797            message_id: None,
2798            snapshot,
2799            last_equal_ranges: Default::default(),
2800            transformation_transaction_id: None,
2801            status: CodegenStatus::Idle,
2802            generation: Task::ready(()),
2803            diff: Diff::default(),
2804            telemetry,
2805            _subscription: cx.subscribe(&multi_buffer, Self::handle_buffer_event),
2806            builder,
2807            active,
2808            edits: Vec::new(),
2809            line_operations: Vec::new(),
2810            range,
2811            request: None,
2812            elapsed_time: None,
2813            completion: None,
2814        }
2815    }
2816
2817    fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
2818        if active != self.active {
2819            self.active = active;
2820
2821            if self.active {
2822                let edits = self.edits.clone();
2823                self.apply_edits(edits, cx);
2824                if matches!(self.status, CodegenStatus::Pending) {
2825                    let line_operations = self.line_operations.clone();
2826                    self.reapply_line_based_diff(line_operations, cx);
2827                } else {
2828                    self.reapply_batch_diff(cx).detach();
2829                }
2830            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2831                self.buffer.update(cx, |buffer, cx| {
2832                    buffer.undo_transaction(transaction_id, cx);
2833                    buffer.forget_transaction(transaction_id, cx);
2834                });
2835            }
2836        }
2837    }
2838
2839    fn handle_buffer_event(
2840        &mut self,
2841        _buffer: Entity<MultiBuffer>,
2842        event: &multi_buffer::Event,
2843        cx: &mut Context<Self>,
2844    ) {
2845        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2846            if self.transformation_transaction_id == Some(*transaction_id) {
2847                self.transformation_transaction_id = None;
2848                self.generation = Task::ready(());
2849                cx.emit(CodegenEvent::Undone);
2850            }
2851        }
2852    }
2853
2854    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2855        &self.last_equal_ranges
2856    }
2857
2858    pub fn count_tokens(
2859        &self,
2860        user_prompt: String,
2861        assistant_panel_context: Option<LanguageModelRequest>,
2862        cx: &App,
2863    ) -> BoxFuture<'static, Result<TokenCounts>> {
2864        if let Some(ConfiguredModel { model, .. }) =
2865            LanguageModelRegistry::read_global(cx).inline_assistant_model()
2866        {
2867            let request =
2868                self.build_request(&model, user_prompt, assistant_panel_context.clone(), cx);
2869            match request {
2870                Ok(request) => {
2871                    let total_count = model.count_tokens(request.clone(), cx);
2872                    let assistant_panel_count = assistant_panel_context
2873                        .map(|context| model.count_tokens(context, cx))
2874                        .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2875
2876                    async move {
2877                        Ok(TokenCounts {
2878                            total: total_count.await?,
2879                            assistant_panel: assistant_panel_count.await?,
2880                        })
2881                    }
2882                    .boxed()
2883                }
2884                Err(error) => futures::future::ready(Err(error)).boxed(),
2885            }
2886        } else {
2887            future::ready(Err(anyhow!("no active model"))).boxed()
2888        }
2889    }
2890
2891    pub fn start(
2892        &mut self,
2893        user_prompt: String,
2894        assistant_panel_context: Option<LanguageModelRequest>,
2895        model: Arc<dyn LanguageModel>,
2896        cx: &mut Context<Self>,
2897    ) -> Result<()> {
2898        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2899            self.buffer.update(cx, |buffer, cx| {
2900                buffer.undo_transaction(transformation_transaction_id, cx);
2901            });
2902        }
2903
2904        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2905
2906        let api_key = model.api_key(cx);
2907        let telemetry_id = model.telemetry_id();
2908        let provider_id = model.provider_id();
2909        let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2910            if user_prompt.trim().to_lowercase() == "delete" {
2911                async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2912            } else {
2913                let request =
2914                    self.build_request(&model, user_prompt, assistant_panel_context, cx)?;
2915                self.request = Some(request.clone());
2916
2917                cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
2918                    .boxed_local()
2919            };
2920        self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2921        Ok(())
2922    }
2923
2924    fn build_request(
2925        &self,
2926        model: &Arc<dyn LanguageModel>,
2927        user_prompt: String,
2928        assistant_panel_context: Option<LanguageModelRequest>,
2929        cx: &App,
2930    ) -> Result<LanguageModelRequest> {
2931        let buffer = self.buffer.read(cx).snapshot(cx);
2932        let language = buffer.language_at(self.range.start);
2933        let language_name = if let Some(language) = language.as_ref() {
2934            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2935                None
2936            } else {
2937                Some(language.name())
2938            }
2939        } else {
2940            None
2941        };
2942
2943        let language_name = language_name.as_ref();
2944        let start = buffer.point_to_buffer_offset(self.range.start);
2945        let end = buffer.point_to_buffer_offset(self.range.end);
2946        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2947            let (start_buffer, start_buffer_offset) = start;
2948            let (end_buffer, end_buffer_offset) = end;
2949            if start_buffer.remote_id() == end_buffer.remote_id() {
2950                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2951            } else {
2952                return Err(anyhow::anyhow!("invalid transformation range"));
2953            }
2954        } else {
2955            return Err(anyhow::anyhow!("invalid transformation range"));
2956        };
2957
2958        let prompt = self
2959            .builder
2960            .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2961            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2962
2963        let mut messages = Vec::new();
2964        if let Some(context_request) = assistant_panel_context {
2965            messages = context_request.messages;
2966        }
2967
2968        messages.push(LanguageModelRequestMessage {
2969            role: Role::User,
2970            content: vec![prompt.into()],
2971            cache: false,
2972        });
2973
2974        Ok(LanguageModelRequest {
2975            thread_id: None,
2976            prompt_id: None,
2977            mode: None,
2978            messages,
2979            tools: Vec::new(),
2980            stop: Vec::new(),
2981            temperature: AssistantSettings::temperature_for_model(&model, cx),
2982        })
2983    }
2984
2985    pub fn handle_stream(
2986        &mut self,
2987        model_telemetry_id: String,
2988        model_provider_id: String,
2989        model_api_key: Option<String>,
2990        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
2991        cx: &mut Context<Self>,
2992    ) {
2993        let start_time = Instant::now();
2994        let snapshot = self.snapshot.clone();
2995        let selected_text = snapshot
2996            .text_for_range(self.range.start..self.range.end)
2997            .collect::<Rope>();
2998
2999        let selection_start = self.range.start.to_point(&snapshot);
3000
3001        // Start with the indentation of the first line in the selection
3002        let mut suggested_line_indent = snapshot
3003            .suggested_indents(selection_start.row..=selection_start.row, cx)
3004            .into_values()
3005            .next()
3006            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
3007
3008        // If the first line in the selection does not have indentation, check the following lines
3009        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
3010            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
3011                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
3012                // Prefer tabs if a line in the selection uses tabs as indentation
3013                if line_indent.kind == IndentKind::Tab {
3014                    suggested_line_indent.kind = IndentKind::Tab;
3015                    break;
3016                }
3017            }
3018        }
3019
3020        let http_client = cx.http_client();
3021        let telemetry = self.telemetry.clone();
3022        let language_name = {
3023            let multibuffer = self.buffer.read(cx);
3024            let snapshot = multibuffer.snapshot(cx);
3025            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
3026            ranges
3027                .first()
3028                .and_then(|(buffer, _, _)| buffer.language())
3029                .map(|language| language.name())
3030        };
3031
3032        self.diff = Diff::default();
3033        self.status = CodegenStatus::Pending;
3034        let mut edit_start = self.range.start.to_offset(&snapshot);
3035        let completion = Arc::new(Mutex::new(String::new()));
3036        let completion_clone = completion.clone();
3037
3038        self.generation = cx.spawn(async move |codegen, cx| {
3039            let stream = stream.await;
3040            let message_id = stream
3041                .as_ref()
3042                .ok()
3043                .and_then(|stream| stream.message_id.clone());
3044            let generate = async {
3045                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
3046                let executor = cx.background_executor().clone();
3047                let message_id = message_id.clone();
3048                let line_based_stream_diff: Task<anyhow::Result<()>> =
3049                    cx.background_spawn(async move {
3050                        let mut response_latency = None;
3051                        let request_start = Instant::now();
3052                        let diff = async {
3053                            let chunks =
3054                                StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
3055                            futures::pin_mut!(chunks);
3056                            let mut diff = StreamingDiff::new(selected_text.to_string());
3057                            let mut line_diff = LineDiff::default();
3058
3059                            let mut new_text = String::new();
3060                            let mut base_indent = None;
3061                            let mut line_indent = None;
3062                            let mut first_line = true;
3063
3064                            while let Some(chunk) = chunks.next().await {
3065                                if response_latency.is_none() {
3066                                    response_latency = Some(request_start.elapsed());
3067                                }
3068                                let chunk = chunk?;
3069                                completion_clone.lock().push_str(&chunk);
3070
3071                                let mut lines = chunk.split('\n').peekable();
3072                                while let Some(line) = lines.next() {
3073                                    new_text.push_str(line);
3074                                    if line_indent.is_none() {
3075                                        if let Some(non_whitespace_ch_ix) =
3076                                            new_text.find(|ch: char| !ch.is_whitespace())
3077                                        {
3078                                            line_indent = Some(non_whitespace_ch_ix);
3079                                            base_indent = base_indent.or(line_indent);
3080
3081                                            let line_indent = line_indent.unwrap();
3082                                            let base_indent = base_indent.unwrap();
3083                                            let indent_delta =
3084                                                line_indent as i32 - base_indent as i32;
3085                                            let mut corrected_indent_len = cmp::max(
3086                                                0,
3087                                                suggested_line_indent.len as i32 + indent_delta,
3088                                            )
3089                                                as usize;
3090                                            if first_line {
3091                                                corrected_indent_len = corrected_indent_len
3092                                                    .saturating_sub(
3093                                                        selection_start.column as usize,
3094                                                    );
3095                                            }
3096
3097                                            let indent_char = suggested_line_indent.char();
3098                                            let mut indent_buffer = [0; 4];
3099                                            let indent_str =
3100                                                indent_char.encode_utf8(&mut indent_buffer);
3101                                            new_text.replace_range(
3102                                                ..line_indent,
3103                                                &indent_str.repeat(corrected_indent_len),
3104                                            );
3105                                        }
3106                                    }
3107
3108                                    if line_indent.is_some() {
3109                                        let char_ops = diff.push_new(&new_text);
3110                                        line_diff.push_char_operations(&char_ops, &selected_text);
3111                                        diff_tx
3112                                            .send((char_ops, line_diff.line_operations()))
3113                                            .await?;
3114                                        new_text.clear();
3115                                    }
3116
3117                                    if lines.peek().is_some() {
3118                                        let char_ops = diff.push_new("\n");
3119                                        line_diff.push_char_operations(&char_ops, &selected_text);
3120                                        diff_tx
3121                                            .send((char_ops, line_diff.line_operations()))
3122                                            .await?;
3123                                        if line_indent.is_none() {
3124                                            // Don't write out the leading indentation in empty lines on the next line
3125                                            // This is the case where the above if statement didn't clear the buffer
3126                                            new_text.clear();
3127                                        }
3128                                        line_indent = None;
3129                                        first_line = false;
3130                                    }
3131                                }
3132                            }
3133
3134                            let mut char_ops = diff.push_new(&new_text);
3135                            char_ops.extend(diff.finish());
3136                            line_diff.push_char_operations(&char_ops, &selected_text);
3137                            line_diff.finish(&selected_text);
3138                            diff_tx
3139                                .send((char_ops, line_diff.line_operations()))
3140                                .await?;
3141
3142                            anyhow::Ok(())
3143                        };
3144
3145                        let result = diff.await;
3146
3147                        let error_message = result.as_ref().err().map(|error| error.to_string());
3148                        report_assistant_event(
3149                            AssistantEventData {
3150                                conversation_id: None,
3151                                message_id,
3152                                kind: AssistantKind::Inline,
3153                                phase: AssistantPhase::Response,
3154                                model: model_telemetry_id,
3155                                model_provider: model_provider_id.to_string(),
3156                                response_latency,
3157                                error_message,
3158                                language_name: language_name.map(|name| name.to_proto()),
3159                            },
3160                            telemetry,
3161                            http_client,
3162                            model_api_key,
3163                            &executor,
3164                        );
3165
3166                        result?;
3167                        Ok(())
3168                    });
3169
3170                while let Some((char_ops, line_ops)) = diff_rx.next().await {
3171                    codegen.update(cx, |codegen, cx| {
3172                        codegen.last_equal_ranges.clear();
3173
3174                        let edits = char_ops
3175                            .into_iter()
3176                            .filter_map(|operation| match operation {
3177                                CharOperation::Insert { text } => {
3178                                    let edit_start = snapshot.anchor_after(edit_start);
3179                                    Some((edit_start..edit_start, text))
3180                                }
3181                                CharOperation::Delete { bytes } => {
3182                                    let edit_end = edit_start + bytes;
3183                                    let edit_range = snapshot.anchor_after(edit_start)
3184                                        ..snapshot.anchor_before(edit_end);
3185                                    edit_start = edit_end;
3186                                    Some((edit_range, String::new()))
3187                                }
3188                                CharOperation::Keep { bytes } => {
3189                                    let edit_end = edit_start + bytes;
3190                                    let edit_range = snapshot.anchor_after(edit_start)
3191                                        ..snapshot.anchor_before(edit_end);
3192                                    edit_start = edit_end;
3193                                    codegen.last_equal_ranges.push(edit_range);
3194                                    None
3195                                }
3196                            })
3197                            .collect::<Vec<_>>();
3198
3199                        if codegen.active {
3200                            codegen.apply_edits(edits.iter().cloned(), cx);
3201                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
3202                        }
3203                        codegen.edits.extend(edits);
3204                        codegen.line_operations = line_ops;
3205                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));
3206
3207                        cx.notify();
3208                    })?;
3209                }
3210
3211                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
3212                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
3213                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
3214                let batch_diff_task =
3215                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
3216                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
3217                line_based_stream_diff?;
3218
3219                anyhow::Ok(())
3220            };
3221
3222            let result = generate.await;
3223            let elapsed_time = start_time.elapsed().as_secs_f64();
3224
3225            codegen
3226                .update(cx, |this, cx| {
3227                    this.message_id = message_id;
3228                    this.last_equal_ranges.clear();
3229                    if let Err(error) = result {
3230                        this.status = CodegenStatus::Error(error);
3231                    } else {
3232                        this.status = CodegenStatus::Done;
3233                    }
3234                    this.elapsed_time = Some(elapsed_time);
3235                    this.completion = Some(completion.lock().clone());
3236                    cx.emit(CodegenEvent::Finished);
3237                    cx.notify();
3238                })
3239                .ok();
3240        });
3241        cx.notify();
3242    }
3243
3244    pub fn stop(&mut self, cx: &mut Context<Self>) {
3245        self.last_equal_ranges.clear();
3246        if self.diff.is_empty() {
3247            self.status = CodegenStatus::Idle;
3248        } else {
3249            self.status = CodegenStatus::Done;
3250        }
3251        self.generation = Task::ready(());
3252        cx.emit(CodegenEvent::Finished);
3253        cx.notify();
3254    }
3255
3256    pub fn undo(&mut self, cx: &mut Context<Self>) {
3257        self.buffer.update(cx, |buffer, cx| {
3258            if let Some(transaction_id) = self.transformation_transaction_id.take() {
3259                buffer.undo_transaction(transaction_id, cx);
3260                buffer.refresh_preview(cx);
3261            }
3262        });
3263    }
3264
3265    fn apply_edits(
3266        &mut self,
3267        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3268        cx: &mut Context<CodegenAlternative>,
3269    ) {
3270        let transaction = self.buffer.update(cx, |buffer, cx| {
3271            // Avoid grouping assistant edits with user edits.
3272            buffer.finalize_last_transaction(cx);
3273            buffer.start_transaction(cx);
3274            buffer.edit(edits, None, cx);
3275            buffer.end_transaction(cx)
3276        });
3277
3278        if let Some(transaction) = transaction {
3279            if let Some(first_transaction) = self.transformation_transaction_id {
3280                // Group all assistant edits into the first transaction.
3281                self.buffer.update(cx, |buffer, cx| {
3282                    buffer.merge_transactions(transaction, first_transaction, cx)
3283                });
3284            } else {
3285                self.transformation_transaction_id = Some(transaction);
3286                self.buffer
3287                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3288            }
3289        }
3290    }
3291
3292    fn reapply_line_based_diff(
3293        &mut self,
3294        line_operations: impl IntoIterator<Item = LineOperation>,
3295        cx: &mut Context<Self>,
3296    ) {
3297        let old_snapshot = self.snapshot.clone();
3298        let old_range = self.range.to_point(&old_snapshot);
3299        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3300        let new_range = self.range.to_point(&new_snapshot);
3301
3302        let mut old_row = old_range.start.row;
3303        let mut new_row = new_range.start.row;
3304
3305        self.diff.deleted_row_ranges.clear();
3306        self.diff.inserted_row_ranges.clear();
3307        for operation in line_operations {
3308            match operation {
3309                LineOperation::Keep { lines } => {
3310                    old_row += lines;
3311                    new_row += lines;
3312                }
3313                LineOperation::Delete { lines } => {
3314                    let old_end_row = old_row + lines - 1;
3315                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3316
3317                    if let Some((_, last_deleted_row_range)) =
3318                        self.diff.deleted_row_ranges.last_mut()
3319                    {
3320                        if *last_deleted_row_range.end() + 1 == old_row {
3321                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3322                        } else {
3323                            self.diff
3324                                .deleted_row_ranges
3325                                .push((new_row, old_row..=old_end_row));
3326                        }
3327                    } else {
3328                        self.diff
3329                            .deleted_row_ranges
3330                            .push((new_row, old_row..=old_end_row));
3331                    }
3332
3333                    old_row += lines;
3334                }
3335                LineOperation::Insert { lines } => {
3336                    let new_end_row = new_row + lines - 1;
3337                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3338                    let end = new_snapshot.anchor_before(Point::new(
3339                        new_end_row,
3340                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
3341                    ));
3342                    self.diff.inserted_row_ranges.push(start..end);
3343                    new_row += lines;
3344                }
3345            }
3346
3347            cx.notify();
3348        }
3349    }
3350
3351    fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
3352        let old_snapshot = self.snapshot.clone();
3353        let old_range = self.range.to_point(&old_snapshot);
3354        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3355        let new_range = self.range.to_point(&new_snapshot);
3356
3357        cx.spawn(async move |codegen, cx| {
3358            let (deleted_row_ranges, inserted_row_ranges) = cx
3359                .background_spawn(async move {
3360                    let old_text = old_snapshot
3361                        .text_for_range(
3362                            Point::new(old_range.start.row, 0)
3363                                ..Point::new(
3364                                    old_range.end.row,
3365                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3366                                ),
3367                        )
3368                        .collect::<String>();
3369                    let new_text = new_snapshot
3370                        .text_for_range(
3371                            Point::new(new_range.start.row, 0)
3372                                ..Point::new(
3373                                    new_range.end.row,
3374                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3375                                ),
3376                        )
3377                        .collect::<String>();
3378
3379                    let old_start_row = old_range.start.row;
3380                    let new_start_row = new_range.start.row;
3381                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3382                    let mut inserted_row_ranges = Vec::new();
3383                    for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
3384                        let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
3385                        let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
3386                        if !old_rows.is_empty() {
3387                            deleted_row_ranges.push((
3388                                new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
3389                                old_rows.start..=old_rows.end - 1,
3390                            ));
3391                        }
3392                        if !new_rows.is_empty() {
3393                            let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
3394                            let new_end_row = new_rows.end - 1;
3395                            let end = new_snapshot.anchor_before(Point::new(
3396                                new_end_row,
3397                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
3398                            ));
3399                            inserted_row_ranges.push(start..end);
3400                        }
3401                    }
3402                    (deleted_row_ranges, inserted_row_ranges)
3403                })
3404                .await;
3405
3406            codegen
3407                .update(cx, |codegen, cx| {
3408                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
3409                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
3410                    cx.notify();
3411                })
3412                .ok();
3413        })
3414    }
3415}
3416
/// Stream adapter that cleans raw model output chunk by chunk.
///
/// It removes `<|CURSOR|>` markers from the text. When the response's first line opens a
/// code fence, it drops that line and also drops a final line that ends with a fence.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// Text received but not yet emitted. It is held back while it could still turn into
    /// a marker or a fence once more input arrives.
    buffer: String,
    /// True until the first newline of the response has been processed.
    first_line: bool,
    /// A newline following an already-emitted line is pending. It is written before the
    /// next emitted text (presumably deferred so a closing fence line can be dropped
    /// together with its preceding newline).
    line_end: bool,
    /// Whether the response's first line was a code fence that was stripped.
    starts_with_code_block: bool,
}
3425
3426impl<T> StripInvalidSpans<T>
3427where
3428    T: Stream<Item = Result<String>>,
3429{
3430    fn new(stream: T) -> Self {
3431        Self {
3432            stream,
3433            stream_done: false,
3434            buffer: String::new(),
3435            first_line: true,
3436            line_end: false,
3437            starts_with_code_block: false,
3438        }
3439    }
3440}
3441
3442impl<T> Stream for StripInvalidSpans<T>
3443where
3444    T: Stream<Item = Result<String>>,
3445{
3446    type Item = Result<String>;
3447
3448    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3449        const CODE_BLOCK_DELIMITER: &str = "```";
3450        const CURSOR_SPAN: &str = "<|CURSOR|>";
3451
3452        let this = unsafe { self.get_unchecked_mut() };
3453        loop {
3454            if !this.stream_done {
3455                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3456                match stream.as_mut().poll_next(cx) {
3457                    Poll::Ready(Some(Ok(chunk))) => {
3458                        this.buffer.push_str(&chunk);
3459                    }
3460                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3461                    Poll::Ready(None) => {
3462                        this.stream_done = true;
3463                    }
3464                    Poll::Pending => return Poll::Pending,
3465                }
3466            }
3467
3468            let mut chunk = String::new();
3469            let mut consumed = 0;
3470            if !this.buffer.is_empty() {
3471                let mut lines = this.buffer.split('\n').enumerate().peekable();
3472                while let Some((line_ix, line)) = lines.next() {
3473                    if line_ix > 0 {
3474                        this.first_line = false;
3475                    }
3476
3477                    if this.first_line {
3478                        let trimmed_line = line.trim();
3479                        if lines.peek().is_some() {
3480                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3481                                consumed += line.len() + 1;
3482                                this.starts_with_code_block = true;
3483                                continue;
3484                            }
3485                        } else if trimmed_line.is_empty()
3486                            || prefixes(CODE_BLOCK_DELIMITER)
3487                                .any(|prefix| trimmed_line.starts_with(prefix))
3488                        {
3489                            break;
3490                        }
3491                    }
3492
3493                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
3494                    if lines.peek().is_some() {
3495                        if this.line_end {
3496                            chunk.push('\n');
3497                        }
3498
3499                        chunk.push_str(&line_without_cursor);
3500                        this.line_end = true;
3501                        consumed += line.len() + 1;
3502                    } else if this.stream_done {
3503                        if !this.starts_with_code_block
3504                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3505                        {
3506                            if this.line_end {
3507                                chunk.push('\n');
3508                            }
3509
3510                            chunk.push_str(&line);
3511                        }
3512
3513                        consumed += line.len();
3514                    } else {
3515                        let trimmed_line = line.trim();
3516                        if trimmed_line.is_empty()
3517                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3518                            || prefixes(CODE_BLOCK_DELIMITER)
3519                                .any(|prefix| trimmed_line.ends_with(prefix))
3520                        {
3521                            break;
3522                        } else {
3523                            if this.line_end {
3524                                chunk.push('\n');
3525                                this.line_end = false;
3526                            }
3527
3528                            chunk.push_str(&line_without_cursor);
3529                            consumed += line.len();
3530                        }
3531                    }
3532                }
3533            }
3534
3535            this.buffer = this.buffer.split_off(consumed);
3536            if !chunk.is_empty() {
3537                return Poll::Ready(Some(Ok(chunk)));
3538            } else if this.stream_done {
3539                return Poll::Ready(None);
3540            }
3541        }
3542    }
3543}
3544
/// Code action provider that offers "Fix with Assistant" for diagnostics.
///
/// Applying the action starts an inline assist in `editor`.
struct AssistantCodeActionProvider {
    /// Editor in which the inline assist is started.
    editor: WeakEntity<Editor>,
    /// Workspace used to look up the `AssistantPanel` when the action is applied.
    workspace: WeakEntity<Workspace>,
}

/// Identifier returned by `CodeActionProvider::id` for this provider.
const ASSISTANT_CODE_ACTION_PROVIDER_ID: &str = "assistant";
3551
impl CodeActionProvider for AssistantCodeActionProvider {
    /// Identifier of this provider (`ASSISTANT_CODE_ACTION_PROVIDER_ID`).
    fn id(&self) -> Arc<str> {
        ASSISTANT_CODE_ACTION_PROVIDER_ID.into()
    }

    /// Returns a single "Fix with Assistant" action when diagnostics are found.
    ///
    /// The lookup uses the requested range widened to whole lines. If no diagnostics are
    /// found there, or `Assistant::enabled` is false, no actions are returned.
    ///
    /// The action's range is widened to cover every diagnostic found. It is then widened
    /// again to cover the last symbol returned by `symbols_containing` for the range's
    /// start and for its end (presumably so the assistant sees the surrounding code).
    fn code_actions(
        &self,
        buffer: &Entity<Buffer>,
        range: Range<text::Anchor>,
        _: &mut Window,
        cx: &mut App,
    ) -> Task<Result<Vec<CodeAction>>> {
        if !Assistant::enabled(cx) {
            return Task::ready(Ok(Vec::new()));
        }

        let snapshot = buffer.read(cx).snapshot();
        let mut range = range.to_point(&snapshot);

        // Expand the range to line boundaries.
        range.start.column = 0;
        range.end.column = snapshot.line_len(range.end.row);

        // Grow the range to include every diagnostic returned for it.
        let mut has_diagnostics = false;
        for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
            range.start = cmp::min(range.start, diagnostic.range.start);
            range.end = cmp::max(range.end, diagnostic.range.end);
            has_diagnostics = true;
        }
        if has_diagnostics {
            // Widen to the last listed symbol containing the start (presumably the
            // innermost one).
            if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
                if let Some(symbol) = symbols_containing_start.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            // Same for the (possibly already widened) end of the range.
            if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
                if let Some(symbol) = symbols_containing_end.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            // NOTE(review): server id 0 appears to be a placeholder, since this action does
            // not come from a language server; confirm with callers.
            Task::ready(Ok(vec![CodeAction {
                server_id: language::LanguageServerId(0),
                range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
                lsp_action: LspAction::Action(Box::new(lsp::CodeAction {
                    title: "Fix with Assistant".into(),
                    ..Default::default()
                })),
                resolved: true,
            }]))
        } else {
            Task::ready(Ok(Vec::new()))
        }
    }

    /// Starts an inline assist that fixes the diagnostics covered by `action.range`.
    ///
    /// If needed, the excerpt `excerpt_id` is first resized to contain the whole action
    /// range. The range is then mapped into the editor's multibuffer, and an assist with
    /// the prompt "Fix Diagnostics" is suggested and started immediately.
    ///
    /// No project edits are made here, so an empty `ProjectTransaction` is returned.
    /// `_push_to_history` is ignored.
    ///
    /// # Errors
    /// Fails in these cases:
    /// - the editor has been released;
    /// - the action range cannot be mapped into the excerpt ("invalid range");
    /// - the workspace update fails or no `AssistantPanel` is found;
    /// - updating the `InlineAssistant` global fails.
    fn apply_code_action(
        &self,
        buffer: Entity<Buffer>,
        action: CodeAction,
        excerpt_id: ExcerptId,
        _push_to_history: bool,
        window: &mut Window,
        cx: &mut App,
    ) -> Task<Result<ProjectTransaction>> {
        let editor = self.editor.clone();
        let workspace = self.workspace.clone();
        window.spawn(cx, async move |cx| {
            let editor = editor.upgrade().context("editor was released")?;
            let range = editor
                .update(cx, |editor, cx| {
                    editor.buffer().update(cx, |multibuffer, cx| {
                        let buffer = buffer.read(cx);
                        let multibuffer_snapshot = multibuffer.read(cx);

                        // Grow the excerpt's context range so it contains the action range.
                        let old_context_range =
                            multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
                        let mut new_context_range = old_context_range.clone();
                        if action
                            .range
                            .start
                            .cmp(&old_context_range.start, buffer)
                            .is_lt()
                        {
                            new_context_range.start = action.range.start;
                        }
                        if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
                            new_context_range.end = action.range.end;
                        }
                        // Release the snapshot read from `multibuffer` before resizing.
                        drop(multibuffer_snapshot);

                        if new_context_range != old_context_range {
                            multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
                        }

                        // Re-read so the anchors reflect the (possibly) resized excerpt.
                        let multibuffer_snapshot = multibuffer.read(cx);
                        Some(
                            multibuffer_snapshot
                                .anchor_in_excerpt(excerpt_id, action.range.start)?
                                ..multibuffer_snapshot
                                    .anchor_in_excerpt(excerpt_id, action.range.end)?,
                        )
                    })
                })?
                .context("invalid range")?;
            let assistant_panel = workspace.update(cx, |workspace, cx| {
                workspace
                    .panel::<AssistantPanel>(cx)
                    .context("assistant panel was released")
            })??;

            cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
                let assist_id = assistant.suggest_assist(
                    &editor,
                    range,
                    "Fix Diagnostics".into(),
                    None,
                    true,
                    Some(workspace),
                    Some(&assistant_panel),
                    window,
                    cx,
                );
                assistant.start_assist(assist_id, window, cx);
            })?;

            Ok(ProjectTransaction::default())
        })
    }
}
3684
/// Returns every non-empty proper prefix of `text`, shortest first, excluding `text`
/// itself. For example, "```" yields "`" and "``".
///
/// Prefixes are cut on `char` boundaries, so non-ASCII input never panics. An empty
/// `text` yields nothing instead of underflowing `len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first marks the end of one proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(ix, _)| &text[..ix])
}
3688
3689fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3690    ranges.sort_unstable_by(|a, b| {
3691        a.start
3692            .cmp(&b.start, buffer)
3693            .then_with(|| b.end.cmp(&a.end, buffer))
3694    });
3695
3696    let mut ix = 0;
3697    while ix + 1 < ranges.len() {
3698        let b = ranges[ix + 1].clone();
3699        let a = &mut ranges[ix];
3700        if a.end.cmp(&b.start, buffer).is_gt() {
3701            if a.end.cmp(&b.end, buffer).is_lt() {
3702                a.end = b.end;
3703            }
3704            ranges.remove(ix + 1);
3705        } else {
3706            ix += 1;
3707        }
3708    }
3709}
3710
/// Tests for inline-assist code generation: auto-indentation of streamed
/// model output, toggling a codegen alternative's active state, and stripping
/// invalid spans (code fences, cursor markers) from a response stream.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{
        Buffer, Language, LanguageConfig, LanguageMatcher, Point, language_settings,
        tree_sitter_rust,
    };
    use language_model::{LanguageModelRegistry, TokenUsage};
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Replaces rows 1-4 of `main` with model output indented by seven spaces
    /// and checks that the inserted lines are re-indented to match the block.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// With an empty selection just after `le` on row 1 (already indented),
    /// the model finishes the word and continues with unindented lines; the
    /// continuation lines must be indented to match the enclosing block.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        // Same global setup as the sibling codegen tests.
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// With the cursor on a line holding only two spaces (less than the
    /// block's four-space indentation), unindented model output must still be
    /// indented to match the enclosing block.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_in_random_chunks(new_text, chunks_tx, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a buffer with no language that is indented with tabs, tab-indented
    /// model output replacing rows 0-4 must keep its tabs unchanged.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// A codegen alternative created inactive must not touch the buffer.
    /// Activating it applies the generated edit, and deactivating it restores
    /// the original text.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// Feeds each input through `StripInvalidSpans` split into every chunk
    /// size from 1 to the input's length, and checks the stripped output is
    /// the same regardless of how the stream was chunked.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into a stream of `Ok` chunks of `size` characters
        /// each (the last chunk may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Starts `codegen` handling a response stream backed by an in-memory
    /// channel and returns the sending half.
    ///
    /// Every `String` sent becomes one chunk of the simulated model response;
    /// dropping the returned sender closes the channel and so ends the stream.
    fn simulate_response_stream(
        codegen: Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Sends `text` through `chunks_tx` in randomly sized pieces of at most
    /// 10 bytes, letting the background executor settle after each piece.
    /// It then drops the sender to end the stream and settles once more.
    ///
    /// Piece lengths are drawn from `rng`, so a seeded `StdRng` reproduces the
    /// same chunking. A piece is extended as needed so it never splits a
    /// multi-byte UTF-8 character (`str::split_at` would panic there).
    fn stream_in_random_chunks(
        mut text: &str,
        chunks_tx: mpsc::UnboundedSender<String>,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = text.len().min(10);
            let mut len = rng.gen_range(1..=max_len);
            // `text.len()` is always a char boundary, so this terminates.
            while !text.is_char_boundary(len) {
                len += 1;
            }
            let (chunk, rest) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = rest;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();
    }

    /// A minimal Rust `Language` for `.rs` files, with the tree-sitter Rust
    /// grammar and an indents query covering calls, field expressions, and
    /// parenthesized/braced nodes. The auto-indent tests above rely on it.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}