inline_assistant.rs

   1use crate::{
   2    assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
   3    AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
   4    CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, StreamingDiff,
   5};
   6use anyhow::{anyhow, Context as _, Result};
   7use client::{telemetry::Telemetry, ErrorExt};
   8use collections::{hash_map, HashMap, HashSet, VecDeque};
   9use editor::{
  10    actions::{MoveDown, MoveUp, SelectAll},
  11    display_map::{
  12        BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
  13        ToDisplayPoint,
  14    },
  15    Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
  16    EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
  17    ToOffset as _, ToPoint,
  18};
  19use feature_flags::{FeatureFlagAppExt as _, ZedPro};
  20use fs::Fs;
  21use futures::{
  22    channel::mpsc,
  23    future::{BoxFuture, LocalBoxFuture},
  24    join,
  25    stream::{self, BoxStream},
  26    SinkExt, Stream, StreamExt,
  27};
  28use gpui::{
  29    anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
  30    FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
  31    TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
  32};
  33use language::{Buffer, IndentKind, Point, Selection, TransactionId};
  34use language_model::{
  35    LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
  36};
  37use multi_buffer::MultiBufferRow;
  38use parking_lot::Mutex;
  39use project::{CodeAction, ProjectTransaction};
  40use rope::Rope;
  41use settings::{Settings, SettingsStore};
  42use smol::future::FutureExt;
  43use std::{
  44    cmp,
  45    future::{self, Future},
  46    iter, mem,
  47    ops::{Range, RangeInclusive},
  48    pin::Pin,
  49    sync::Arc,
  50    task::{self, Poll},
  51    time::{Duration, Instant},
  52};
  53use terminal_view::terminal_panel::TerminalPanel;
  54use text::{OffsetRangeExt, ToPoint as _};
  55use theme::ThemeSettings;
  56use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
  57use util::{RangeExt, ResultExt};
  58use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
  59
  60pub fn init(
  61    fs: Arc<dyn Fs>,
  62    prompt_builder: Arc<PromptBuilder>,
  63    telemetry: Arc<Telemetry>,
  64    cx: &mut AppContext,
  65) {
  66    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
  67    cx.observe_new_views(|_, cx| {
  68        let workspace = cx.view().clone();
  69        InlineAssistant::update_global(cx, |inline_assistant, cx| {
  70            inline_assistant.register_workspace(&workspace, cx)
  71        })
  72    })
  73    .detach();
  74}
  75
  76const PROMPT_HISTORY_MAX_LEN: usize = 20;
  77
  78pub struct InlineAssistant {
  79    next_assist_id: InlineAssistId,
  80    next_assist_group_id: InlineAssistGroupId,
  81    assists: HashMap<InlineAssistId, InlineAssist>,
  82    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
  83    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
  84    assist_observations: HashMap<
  85        InlineAssistId,
  86        (
  87            async_watch::Sender<AssistStatus>,
  88            async_watch::Receiver<AssistStatus>,
  89        ),
  90    >,
  91    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
  92    prompt_history: VecDeque<String>,
  93    prompt_builder: Arc<PromptBuilder>,
  94    telemetry: Option<Arc<Telemetry>>,
  95    fs: Arc<dyn Fs>,
  96}
  97
  98pub enum AssistStatus {
  99    Idle,
 100    Started,
 101    Stopped,
 102    Finished,
 103}
 104
 105impl AssistStatus {
 106    pub fn is_done(&self) -> bool {
 107        matches!(self, Self::Stopped | Self::Finished)
 108    }
 109}
 110
 111impl Global for InlineAssistant {}
 112
 113impl InlineAssistant {
 114    pub fn new(
 115        fs: Arc<dyn Fs>,
 116        prompt_builder: Arc<PromptBuilder>,
 117        telemetry: Arc<Telemetry>,
 118    ) -> Self {
 119        Self {
 120            next_assist_id: InlineAssistId::default(),
 121            next_assist_group_id: InlineAssistGroupId::default(),
 122            assists: HashMap::default(),
 123            assists_by_editor: HashMap::default(),
 124            assist_groups: HashMap::default(),
 125            assist_observations: HashMap::default(),
 126            confirmed_assists: HashMap::default(),
 127            prompt_history: VecDeque::default(),
 128            prompt_builder,
 129            telemetry: Some(telemetry),
 130            fs,
 131        }
 132    }
 133
 134    pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
 135        cx.subscribe(workspace, |workspace, event, cx| {
 136            Self::update_global(cx, |this, cx| {
 137                this.handle_workspace_event(workspace, event, cx)
 138            });
 139        })
 140        .detach();
 141
 142        let workspace = workspace.downgrade();
 143        cx.observe_global::<SettingsStore>(move |cx| {
 144            let Some(workspace) = workspace.upgrade() else {
 145                return;
 146            };
 147            let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
 148                return;
 149            };
 150            let enabled = AssistantSettings::get_global(cx).enabled;
 151            terminal_panel.update(cx, |terminal_panel, cx| {
 152                terminal_panel.asssistant_enabled(enabled, cx)
 153            });
 154        })
 155        .detach();
 156    }
 157
 158    fn handle_workspace_event(
 159        &mut self,
 160        workspace: View<Workspace>,
 161        event: &workspace::Event,
 162        cx: &mut WindowContext,
 163    ) {
 164        match event {
 165            workspace::Event::UserSavedItem { item, .. } => {
 166                // When the user manually saves an editor, automatically accepts all finished transformations.
 167                if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
 168                    if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
 169                        for assist_id in editor_assists.assist_ids.clone() {
 170                            let assist = &self.assists[&assist_id];
 171                            if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
 172                                self.finish_assist(assist_id, false, cx)
 173                            }
 174                        }
 175                    }
 176                }
 177            }
 178            workspace::Event::ItemAdded { item } => {
 179                self.register_workspace_item(&workspace, item.as_ref(), cx);
 180            }
 181            _ => (),
 182        }
 183    }
 184
 185    fn register_workspace_item(
 186        &mut self,
 187        workspace: &View<Workspace>,
 188        item: &dyn ItemHandle,
 189        cx: &mut WindowContext,
 190    ) {
 191        if let Some(editor) = item.act_as::<Editor>(cx) {
 192            editor.update(cx, |editor, cx| {
 193                editor.push_code_action_provider(
 194                    Arc::new(AssistantCodeActionProvider {
 195                        editor: cx.view().downgrade(),
 196                        workspace: workspace.downgrade(),
 197                    }),
 198                    cx,
 199                );
 200            });
 201        }
 202    }
 203
 204    pub fn assist(
 205        &mut self,
 206        editor: &View<Editor>,
 207        workspace: Option<WeakView<Workspace>>,
 208        assistant_panel: Option<&View<AssistantPanel>>,
 209        initial_prompt: Option<String>,
 210        cx: &mut WindowContext,
 211    ) {
 212        if let Some(telemetry) = self.telemetry.as_ref() {
 213            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
 214                telemetry.report_assistant_event(
 215                    None,
 216                    telemetry_events::AssistantKind::Inline,
 217                    telemetry_events::AssistantPhase::Invoked,
 218                    model.telemetry_id(),
 219                    None,
 220                    None,
 221                );
 222            }
 223        }
 224        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 225
 226        let mut selections = Vec::<Selection<Point>>::new();
 227        let mut newest_selection = None;
 228        for mut selection in editor.read(cx).selections.all::<Point>(cx) {
 229            if selection.end > selection.start {
 230                selection.start.column = 0;
 231                // If the selection ends at the start of the line, we don't want to include it.
 232                if selection.end.column == 0 {
 233                    selection.end.row -= 1;
 234                }
 235                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
 236            }
 237
 238            if let Some(prev_selection) = selections.last_mut() {
 239                if selection.start <= prev_selection.end {
 240                    prev_selection.end = selection.end;
 241                    continue;
 242                }
 243            }
 244
 245            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
 246            if selection.id > latest_selection.id {
 247                *latest_selection = selection.clone();
 248            }
 249            selections.push(selection);
 250        }
 251        let newest_selection = newest_selection.unwrap();
 252
 253        let mut codegen_ranges = Vec::new();
 254        for (excerpt_id, buffer, buffer_range) in
 255            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
 256                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
 257            }))
 258        {
 259            let start = Anchor {
 260                buffer_id: Some(buffer.remote_id()),
 261                excerpt_id,
 262                text_anchor: buffer.anchor_before(buffer_range.start),
 263            };
 264            let end = Anchor {
 265                buffer_id: Some(buffer.remote_id()),
 266                excerpt_id,
 267                text_anchor: buffer.anchor_after(buffer_range.end),
 268            };
 269            codegen_ranges.push(start..end);
 270        }
 271
 272        let assist_group_id = self.next_assist_group_id.post_inc();
 273        let prompt_buffer =
 274            cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
 275        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 276
 277        let mut assists = Vec::new();
 278        let mut assist_to_focus = None;
 279        for range in codegen_ranges {
 280            let assist_id = self.next_assist_id.post_inc();
 281            let codegen = cx.new_model(|cx| {
 282                Codegen::new(
 283                    editor.read(cx).buffer().clone(),
 284                    range.clone(),
 285                    None,
 286                    self.telemetry.clone(),
 287                    self.prompt_builder.clone(),
 288                    cx,
 289                )
 290            });
 291
 292            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 293            let prompt_editor = cx.new_view(|cx| {
 294                PromptEditor::new(
 295                    assist_id,
 296                    gutter_dimensions.clone(),
 297                    self.prompt_history.clone(),
 298                    prompt_buffer.clone(),
 299                    codegen.clone(),
 300                    editor,
 301                    assistant_panel,
 302                    workspace.clone(),
 303                    self.fs.clone(),
 304                    cx,
 305                )
 306            });
 307
 308            if assist_to_focus.is_none() {
 309                let focus_assist = if newest_selection.reversed {
 310                    range.start.to_point(&snapshot) == newest_selection.start
 311                } else {
 312                    range.end.to_point(&snapshot) == newest_selection.end
 313                };
 314                if focus_assist {
 315                    assist_to_focus = Some(assist_id);
 316                }
 317            }
 318
 319            let [prompt_block_id, end_block_id] =
 320                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 321
 322            assists.push((
 323                assist_id,
 324                range,
 325                prompt_editor,
 326                prompt_block_id,
 327                end_block_id,
 328            ));
 329        }
 330
 331        let editor_assists = self
 332            .assists_by_editor
 333            .entry(editor.downgrade())
 334            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 335        let mut assist_group = InlineAssistGroup::new();
 336        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
 337            self.assists.insert(
 338                assist_id,
 339                InlineAssist::new(
 340                    assist_id,
 341                    assist_group_id,
 342                    assistant_panel.is_some(),
 343                    editor,
 344                    &prompt_editor,
 345                    prompt_block_id,
 346                    end_block_id,
 347                    range,
 348                    prompt_editor.read(cx).codegen.clone(),
 349                    workspace.clone(),
 350                    cx,
 351                ),
 352            );
 353            assist_group.assist_ids.push(assist_id);
 354            editor_assists.assist_ids.push(assist_id);
 355        }
 356        self.assist_groups.insert(assist_group_id, assist_group);
 357
 358        if let Some(assist_id) = assist_to_focus {
 359            self.focus_assist(assist_id, cx);
 360        }
 361    }
 362
 363    #[allow(clippy::too_many_arguments)]
 364    pub fn suggest_assist(
 365        &mut self,
 366        editor: &View<Editor>,
 367        mut range: Range<Anchor>,
 368        initial_prompt: String,
 369        initial_transaction_id: Option<TransactionId>,
 370        focus: bool,
 371        workspace: Option<WeakView<Workspace>>,
 372        assistant_panel: Option<&View<AssistantPanel>>,
 373        cx: &mut WindowContext,
 374    ) -> InlineAssistId {
 375        let assist_group_id = self.next_assist_group_id.post_inc();
 376        let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
 377        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 378
 379        let assist_id = self.next_assist_id.post_inc();
 380
 381        let buffer = editor.read(cx).buffer().clone();
 382        {
 383            let snapshot = buffer.read(cx).read(cx);
 384            range.start = range.start.bias_left(&snapshot);
 385            range.end = range.end.bias_right(&snapshot);
 386        }
 387
 388        let codegen = cx.new_model(|cx| {
 389            Codegen::new(
 390                editor.read(cx).buffer().clone(),
 391                range.clone(),
 392                initial_transaction_id,
 393                self.telemetry.clone(),
 394                self.prompt_builder.clone(),
 395                cx,
 396            )
 397        });
 398
 399        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 400        let prompt_editor = cx.new_view(|cx| {
 401            PromptEditor::new(
 402                assist_id,
 403                gutter_dimensions.clone(),
 404                self.prompt_history.clone(),
 405                prompt_buffer.clone(),
 406                codegen.clone(),
 407                editor,
 408                assistant_panel,
 409                workspace.clone(),
 410                self.fs.clone(),
 411                cx,
 412            )
 413        });
 414
 415        let [prompt_block_id, end_block_id] =
 416            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 417
 418        let editor_assists = self
 419            .assists_by_editor
 420            .entry(editor.downgrade())
 421            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 422
 423        let mut assist_group = InlineAssistGroup::new();
 424        self.assists.insert(
 425            assist_id,
 426            InlineAssist::new(
 427                assist_id,
 428                assist_group_id,
 429                assistant_panel.is_some(),
 430                editor,
 431                &prompt_editor,
 432                prompt_block_id,
 433                end_block_id,
 434                range,
 435                prompt_editor.read(cx).codegen.clone(),
 436                workspace.clone(),
 437                cx,
 438            ),
 439        );
 440        assist_group.assist_ids.push(assist_id);
 441        editor_assists.assist_ids.push(assist_id);
 442        self.assist_groups.insert(assist_group_id, assist_group);
 443
 444        if focus {
 445            self.focus_assist(assist_id, cx);
 446        }
 447
 448        assist_id
 449    }
 450
 451    fn insert_assist_blocks(
 452        &self,
 453        editor: &View<Editor>,
 454        range: &Range<Anchor>,
 455        prompt_editor: &View<PromptEditor>,
 456        cx: &mut WindowContext,
 457    ) -> [CustomBlockId; 2] {
 458        let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
 459            prompt_editor
 460                .editor
 461                .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
 462        });
 463        let assist_blocks = vec![
 464            BlockProperties {
 465                style: BlockStyle::Sticky,
 466                position: range.start,
 467                height: prompt_editor_height,
 468                render: build_assist_editor_renderer(prompt_editor),
 469                disposition: BlockDisposition::Above,
 470                priority: 0,
 471            },
 472            BlockProperties {
 473                style: BlockStyle::Sticky,
 474                position: range.end,
 475                height: 0,
 476                render: Box::new(|cx| {
 477                    v_flex()
 478                        .h_full()
 479                        .w_full()
 480                        .border_t_1()
 481                        .border_color(cx.theme().status().info_border)
 482                        .into_any_element()
 483                }),
 484                disposition: BlockDisposition::Below,
 485                priority: 0,
 486            },
 487        ];
 488
 489        editor.update(cx, |editor, cx| {
 490            let block_ids = editor.insert_blocks(assist_blocks, None, cx);
 491            [block_ids[0], block_ids[1]]
 492        })
 493    }
 494
 495    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 496        let assist = &self.assists[&assist_id];
 497        let Some(decorations) = assist.decorations.as_ref() else {
 498            return;
 499        };
 500        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 501        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
 502
 503        assist_group.active_assist_id = Some(assist_id);
 504        if assist_group.linked {
 505            for assist_id in &assist_group.assist_ids {
 506                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 507                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 508                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
 509                    });
 510                }
 511            }
 512        }
 513
 514        assist
 515            .editor
 516            .update(cx, |editor, cx| {
 517                let scroll_top = editor.scroll_position(cx).y;
 518                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
 519                let prompt_row = editor
 520                    .row_for_block(decorations.prompt_block_id, cx)
 521                    .unwrap()
 522                    .0 as f32;
 523
 524                if (scroll_top..scroll_bottom).contains(&prompt_row) {
 525                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
 526                        assist_id,
 527                        distance_from_top: prompt_row - scroll_top,
 528                    });
 529                } else {
 530                    editor_assists.scroll_lock = None;
 531                }
 532            })
 533            .ok();
 534    }
 535
 536    fn handle_prompt_editor_focus_out(
 537        &mut self,
 538        assist_id: InlineAssistId,
 539        cx: &mut WindowContext,
 540    ) {
 541        let assist = &self.assists[&assist_id];
 542        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 543        if assist_group.active_assist_id == Some(assist_id) {
 544            assist_group.active_assist_id = None;
 545            if assist_group.linked {
 546                for assist_id in &assist_group.assist_ids {
 547                    if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 548                        decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 549                            prompt_editor.set_show_cursor_when_unfocused(false, cx)
 550                        });
 551                    }
 552                }
 553            }
 554        }
 555    }
 556
 557    fn handle_prompt_editor_event(
 558        &mut self,
 559        prompt_editor: View<PromptEditor>,
 560        event: &PromptEditorEvent,
 561        cx: &mut WindowContext,
 562    ) {
 563        let assist_id = prompt_editor.read(cx).id;
 564        match event {
 565            PromptEditorEvent::StartRequested => {
 566                self.start_assist(assist_id, cx);
 567            }
 568            PromptEditorEvent::StopRequested => {
 569                self.stop_assist(assist_id, cx);
 570            }
 571            PromptEditorEvent::ConfirmRequested => {
 572                self.finish_assist(assist_id, false, cx);
 573            }
 574            PromptEditorEvent::CancelRequested => {
 575                self.finish_assist(assist_id, true, cx);
 576            }
 577            PromptEditorEvent::DismissRequested => {
 578                self.dismiss_assist(assist_id, cx);
 579            }
 580        }
 581    }
 582
 583    fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 584        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 585            return;
 586        };
 587
 588        let editor = editor.read(cx);
 589        if editor.selections.count() == 1 {
 590            let selection = editor.selections.newest::<usize>(cx);
 591            let buffer = editor.buffer().read(cx).snapshot(cx);
 592            for assist_id in &editor_assists.assist_ids {
 593                let assist = &self.assists[assist_id];
 594                let assist_range = assist.range.to_offset(&buffer);
 595                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
 596                {
 597                    if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
 598                        self.dismiss_assist(*assist_id, cx);
 599                    } else {
 600                        self.finish_assist(*assist_id, false, cx);
 601                    }
 602
 603                    return;
 604                }
 605            }
 606        }
 607
 608        cx.propagate();
 609    }
 610
 611    fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 612        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 613            return;
 614        };
 615
 616        let editor = editor.read(cx);
 617        if editor.selections.count() == 1 {
 618            let selection = editor.selections.newest::<usize>(cx);
 619            let buffer = editor.buffer().read(cx).snapshot(cx);
 620            let mut closest_assist_fallback = None;
 621            for assist_id in &editor_assists.assist_ids {
 622                let assist = &self.assists[assist_id];
 623                let assist_range = assist.range.to_offset(&buffer);
 624                if assist.decorations.is_some() {
 625                    if assist_range.contains(&selection.start)
 626                        && assist_range.contains(&selection.end)
 627                    {
 628                        self.focus_assist(*assist_id, cx);
 629                        return;
 630                    } else {
 631                        let distance_from_selection = assist_range
 632                            .start
 633                            .abs_diff(selection.start)
 634                            .min(assist_range.start.abs_diff(selection.end))
 635                            + assist_range
 636                                .end
 637                                .abs_diff(selection.start)
 638                                .min(assist_range.end.abs_diff(selection.end));
 639                        match closest_assist_fallback {
 640                            Some((_, old_distance)) => {
 641                                if distance_from_selection < old_distance {
 642                                    closest_assist_fallback =
 643                                        Some((assist_id, distance_from_selection));
 644                                }
 645                            }
 646                            None => {
 647                                closest_assist_fallback = Some((assist_id, distance_from_selection))
 648                            }
 649                        }
 650                    }
 651                }
 652            }
 653
 654            if let Some((&assist_id, _)) = closest_assist_fallback {
 655                self.focus_assist(assist_id, cx);
 656            }
 657        }
 658
 659        cx.propagate();
 660    }
 661
 662    fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
 663        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
 664            for assist_id in editor_assists.assist_ids.clone() {
 665                self.finish_assist(assist_id, true, cx);
 666            }
 667        }
 668    }
 669
 670    fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 671        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 672            return;
 673        };
 674        let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
 675            return;
 676        };
 677        let assist = &self.assists[&scroll_lock.assist_id];
 678        let Some(decorations) = assist.decorations.as_ref() else {
 679            return;
 680        };
 681
 682        editor.update(cx, |editor, cx| {
 683            let scroll_position = editor.scroll_position(cx);
 684            let target_scroll_top = editor
 685                .row_for_block(decorations.prompt_block_id, cx)
 686                .unwrap()
 687                .0 as f32
 688                - scroll_lock.distance_from_top;
 689            if target_scroll_top != scroll_position.y {
 690                editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
 691            }
 692        });
 693    }
 694
 695    fn handle_editor_event(
 696        &mut self,
 697        editor: View<Editor>,
 698        event: &EditorEvent,
 699        cx: &mut WindowContext,
 700    ) {
 701        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
 702            return;
 703        };
 704
 705        match event {
 706            EditorEvent::Edited { transaction_id } => {
 707                let buffer = editor.read(cx).buffer().read(cx);
 708                let edited_ranges =
 709                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
 710                let snapshot = buffer.snapshot(cx);
 711
 712                for assist_id in editor_assists.assist_ids.clone() {
 713                    let assist = &self.assists[&assist_id];
 714                    if matches!(
 715                        assist.codegen.read(cx).status(cx),
 716                        CodegenStatus::Error(_) | CodegenStatus::Done
 717                    ) {
 718                        let assist_range = assist.range.to_offset(&snapshot);
 719                        if edited_ranges
 720                            .iter()
 721                            .any(|range| range.overlaps(&assist_range))
 722                        {
 723                            self.finish_assist(assist_id, false, cx);
 724                        }
 725                    }
 726                }
 727            }
 728            EditorEvent::ScrollPositionChanged { .. } => {
 729                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
 730                    let assist = &self.assists[&scroll_lock.assist_id];
 731                    if let Some(decorations) = assist.decorations.as_ref() {
 732                        let distance_from_top = editor.update(cx, |editor, cx| {
 733                            let scroll_top = editor.scroll_position(cx).y;
 734                            let prompt_row = editor
 735                                .row_for_block(decorations.prompt_block_id, cx)
 736                                .unwrap()
 737                                .0 as f32;
 738                            prompt_row - scroll_top
 739                        });
 740
 741                        if distance_from_top != scroll_lock.distance_from_top {
 742                            editor_assists.scroll_lock = None;
 743                        }
 744                    }
 745                }
 746            }
 747            EditorEvent::SelectionsChanged { .. } => {
 748                for assist_id in editor_assists.assist_ids.clone() {
 749                    let assist = &self.assists[&assist_id];
 750                    if let Some(decorations) = assist.decorations.as_ref() {
 751                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
 752                            return;
 753                        }
 754                    }
 755                }
 756
 757                editor_assists.scroll_lock = None;
 758            }
 759            _ => {}
 760        }
 761    }
 762
 763    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
 764        if let Some(telemetry) = self.telemetry.as_ref() {
 765            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
 766                telemetry.report_assistant_event(
 767                    None,
 768                    telemetry_events::AssistantKind::Inline,
 769                    if undo {
 770                        telemetry_events::AssistantPhase::Rejected
 771                    } else {
 772                        telemetry_events::AssistantPhase::Accepted
 773                    },
 774                    model.telemetry_id(),
 775                    None,
 776                    None,
 777                );
 778            }
 779        }
 780        if let Some(assist) = self.assists.get(&assist_id) {
 781            let assist_group_id = assist.group_id;
 782            if self.assist_groups[&assist_group_id].linked {
 783                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
 784                    self.finish_assist(assist_id, undo, cx);
 785                }
 786                return;
 787            }
 788        }
 789
 790        self.dismiss_assist(assist_id, cx);
 791
 792        if let Some(assist) = self.assists.remove(&assist_id) {
 793            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
 794            {
 795                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
 796                if entry.get().assist_ids.is_empty() {
 797                    entry.remove();
 798                }
 799            }
 800
 801            if let hash_map::Entry::Occupied(mut entry) =
 802                self.assists_by_editor.entry(assist.editor.clone())
 803            {
 804                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
 805                if entry.get().assist_ids.is_empty() {
 806                    entry.remove();
 807                    if let Some(editor) = assist.editor.upgrade() {
 808                        self.update_editor_highlights(&editor, cx);
 809                    }
 810                } else {
 811                    entry.get().highlight_updates.send(()).ok();
 812                }
 813            }
 814
 815            if undo {
 816                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
 817            } else {
 818                let confirmed_alternative = assist.codegen.read(cx).active_alternative().clone();
 819                self.confirmed_assists
 820                    .insert(assist_id, confirmed_alternative);
 821            }
 822        }
 823
 824        // Remove the assist from the status updates map
 825        self.assist_observations.remove(&assist_id);
 826    }
 827
 828    pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
 829        let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
 830            return false;
 831        };
 832        codegen.update(cx, |this, cx| this.undo(cx));
 833        true
 834    }
 835
 836    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
 837        let Some(assist) = self.assists.get_mut(&assist_id) else {
 838            return false;
 839        };
 840        let Some(editor) = assist.editor.upgrade() else {
 841            return false;
 842        };
 843        let Some(decorations) = assist.decorations.take() else {
 844            return false;
 845        };
 846
 847        editor.update(cx, |editor, cx| {
 848            let mut to_remove = decorations.removed_line_block_ids;
 849            to_remove.insert(decorations.prompt_block_id);
 850            to_remove.insert(decorations.end_block_id);
 851            editor.remove_blocks(to_remove, None, cx);
 852        });
 853
 854        if decorations
 855            .prompt_editor
 856            .focus_handle(cx)
 857            .contains_focused(cx)
 858        {
 859            self.focus_next_assist(assist_id, cx);
 860        }
 861
 862        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
 863            if editor_assists
 864                .scroll_lock
 865                .as_ref()
 866                .map_or(false, |lock| lock.assist_id == assist_id)
 867            {
 868                editor_assists.scroll_lock = None;
 869            }
 870            editor_assists.highlight_updates.send(()).ok();
 871        }
 872
 873        true
 874    }
 875
 876    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 877        let Some(assist) = self.assists.get(&assist_id) else {
 878            return;
 879        };
 880
 881        let assist_group = &self.assist_groups[&assist.group_id];
 882        let assist_ix = assist_group
 883            .assist_ids
 884            .iter()
 885            .position(|id| *id == assist_id)
 886            .unwrap();
 887        let assist_ids = assist_group
 888            .assist_ids
 889            .iter()
 890            .skip(assist_ix + 1)
 891            .chain(assist_group.assist_ids.iter().take(assist_ix));
 892
 893        for assist_id in assist_ids {
 894            let assist = &self.assists[assist_id];
 895            if assist.decorations.is_some() {
 896                self.focus_assist(*assist_id, cx);
 897                return;
 898            }
 899        }
 900
 901        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
 902    }
 903
 904    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 905        let Some(assist) = self.assists.get(&assist_id) else {
 906            return;
 907        };
 908
 909        if let Some(decorations) = assist.decorations.as_ref() {
 910            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 911                prompt_editor.editor.update(cx, |editor, cx| {
 912                    editor.focus(cx);
 913                    editor.select_all(&SelectAll, cx);
 914                })
 915            });
 916        }
 917
 918        self.scroll_to_assist(assist_id, cx);
 919    }
 920
 921    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 922        let Some(assist) = self.assists.get(&assist_id) else {
 923            return;
 924        };
 925        let Some(editor) = assist.editor.upgrade() else {
 926            return;
 927        };
 928
 929        let position = assist.range.start;
 930        editor.update(cx, |editor, cx| {
 931            editor.change_selections(None, cx, |selections| {
 932                selections.select_anchor_ranges([position..position])
 933            });
 934
 935            let mut scroll_target_top;
 936            let mut scroll_target_bottom;
 937            if let Some(decorations) = assist.decorations.as_ref() {
 938                scroll_target_top = editor
 939                    .row_for_block(decorations.prompt_block_id, cx)
 940                    .unwrap()
 941                    .0 as f32;
 942                scroll_target_bottom = editor
 943                    .row_for_block(decorations.end_block_id, cx)
 944                    .unwrap()
 945                    .0 as f32;
 946            } else {
 947                let snapshot = editor.snapshot(cx);
 948                let start_row = assist
 949                    .range
 950                    .start
 951                    .to_display_point(&snapshot.display_snapshot)
 952                    .row();
 953                scroll_target_top = start_row.0 as f32;
 954                scroll_target_bottom = scroll_target_top + 1.;
 955            }
 956            scroll_target_top -= editor.vertical_scroll_margin() as f32;
 957            scroll_target_bottom += editor.vertical_scroll_margin() as f32;
 958
 959            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
 960            let scroll_top = editor.scroll_position(cx).y;
 961            let scroll_bottom = scroll_top + height_in_lines;
 962
 963            if scroll_target_top < scroll_top {
 964                editor.set_scroll_position(point(0., scroll_target_top), cx);
 965            } else if scroll_target_bottom > scroll_bottom {
 966                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
 967                    editor
 968                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
 969                } else {
 970                    editor.set_scroll_position(point(0., scroll_target_top), cx);
 971                }
 972            }
 973        });
 974    }
 975
 976    fn unlink_assist_group(
 977        &mut self,
 978        assist_group_id: InlineAssistGroupId,
 979        cx: &mut WindowContext,
 980    ) -> Vec<InlineAssistId> {
 981        let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
 982        assist_group.linked = false;
 983        for assist_id in &assist_group.assist_ids {
 984            let assist = self.assists.get_mut(assist_id).unwrap();
 985            if let Some(editor_decorations) = assist.decorations.as_ref() {
 986                editor_decorations
 987                    .prompt_editor
 988                    .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
 989            }
 990        }
 991        assist_group.assist_ids.clone()
 992    }
 993
 994    pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 995        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
 996            assist
 997        } else {
 998            return;
 999        };
1000
1001        let assist_group_id = assist.group_id;
1002        if self.assist_groups[&assist_group_id].linked {
1003            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1004                self.start_assist(assist_id, cx);
1005            }
1006            return;
1007        }
1008
1009        let Some(user_prompt) = assist.user_prompt(cx) else {
1010            return;
1011        };
1012
1013        self.prompt_history.retain(|prompt| *prompt != user_prompt);
1014        self.prompt_history.push_back(user_prompt.clone());
1015        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1016            self.prompt_history.pop_front();
1017        }
1018
1019        let assistant_panel_context = assist.assistant_panel_context(cx);
1020
1021        assist
1022            .codegen
1023            .update(cx, |codegen, cx| {
1024                codegen.start(user_prompt, assistant_panel_context, cx)
1025            })
1026            .log_err();
1027
1028        if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
1029            tx.send(AssistStatus::Started).ok();
1030        }
1031    }
1032
1033    pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1034        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1035            assist
1036        } else {
1037            return;
1038        };
1039
1040        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1041
1042        if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
1043            tx.send(AssistStatus::Stopped).ok();
1044        }
1045    }
1046
1047    pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
1048        if let Some(assist) = self.assists.get(&assist_id) {
1049            match assist.codegen.read(cx).status(cx) {
1050                CodegenStatus::Idle => InlineAssistStatus::Idle,
1051                CodegenStatus::Pending => InlineAssistStatus::Pending,
1052                CodegenStatus::Done => InlineAssistStatus::Done,
1053                CodegenStatus::Error(_) => InlineAssistStatus::Error,
1054            }
1055        } else if self.confirmed_assists.contains_key(&assist_id) {
1056            InlineAssistStatus::Confirmed
1057        } else {
1058            InlineAssistStatus::Canceled
1059        }
1060    }
1061
1062    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1063        let mut gutter_pending_ranges = Vec::new();
1064        let mut gutter_transformed_ranges = Vec::new();
1065        let mut foreground_ranges = Vec::new();
1066        let mut inserted_row_ranges = Vec::new();
1067        let empty_assist_ids = Vec::new();
1068        let assist_ids = self
1069            .assists_by_editor
1070            .get(&editor.downgrade())
1071            .map_or(&empty_assist_ids, |editor_assists| {
1072                &editor_assists.assist_ids
1073            });
1074
1075        for assist_id in assist_ids {
1076            if let Some(assist) = self.assists.get(assist_id) {
1077                let codegen = assist.codegen.read(cx);
1078                let buffer = codegen.buffer(cx).read(cx).read(cx);
1079                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1080
1081                let pending_range =
1082                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1083                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1084                    gutter_pending_ranges.push(pending_range);
1085                }
1086
1087                if let Some(edit_position) = codegen.edit_position(cx) {
1088                    let edited_range = assist.range.start..edit_position;
1089                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1090                        gutter_transformed_ranges.push(edited_range);
1091                    }
1092                }
1093
1094                if assist.decorations.is_some() {
1095                    inserted_row_ranges
1096                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1097                }
1098            }
1099        }
1100
1101        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1102        merge_ranges(&mut foreground_ranges, &snapshot);
1103        merge_ranges(&mut gutter_pending_ranges, &snapshot);
1104        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1105        editor.update(cx, |editor, cx| {
1106            enum GutterPendingRange {}
1107            if gutter_pending_ranges.is_empty() {
1108                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1109            } else {
1110                editor.highlight_gutter::<GutterPendingRange>(
1111                    &gutter_pending_ranges,
1112                    |cx| cx.theme().status().info_background,
1113                    cx,
1114                )
1115            }
1116
1117            enum GutterTransformedRange {}
1118            if gutter_transformed_ranges.is_empty() {
1119                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1120            } else {
1121                editor.highlight_gutter::<GutterTransformedRange>(
1122                    &gutter_transformed_ranges,
1123                    |cx| cx.theme().status().info,
1124                    cx,
1125                )
1126            }
1127
1128            if foreground_ranges.is_empty() {
1129                editor.clear_highlights::<InlineAssist>(cx);
1130            } else {
1131                editor.highlight_text::<InlineAssist>(
1132                    foreground_ranges,
1133                    HighlightStyle {
1134                        fade_out: Some(0.6),
1135                        ..Default::default()
1136                    },
1137                    cx,
1138                );
1139            }
1140
1141            editor.clear_row_highlights::<InlineAssist>();
1142            for row_range in inserted_row_ranges {
1143                editor.highlight_rows::<InlineAssist>(
1144                    row_range,
1145                    Some(cx.theme().status().info_background),
1146                    false,
1147                    cx,
1148                );
1149            }
1150        });
1151    }
1152
    /// Rebuilds the blocks that display the lines an assist has deleted.
    ///
    /// All blocks previously recorded in `removed_line_block_ids` are removed
    /// from `editor`. Then, for every entry in the codegen diff's
    /// `deleted_row_ranges`, a new block is inserted above the entry's position.
    /// Each block embeds a read-only editor showing the corresponding rows of
    /// the old buffer with a "deleted" background. The ids of the new blocks are
    /// stored back into `removed_line_block_ids`.
    ///
    /// Does nothing if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Gather everything needed from the codegen up front; the closure below
        // only uses these values.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Discard the blocks from the previous update.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Translate the inclusive row range of the old snapshot into
                // buffer offsets: start of the first row to end of the last row.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // Build a standalone, read-only editor over just the deleted
                // excerpt of the old buffer.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    // Marker type keying the row highlight of the nested editor.
                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    // Highlight every row of the nested editor as deleted.
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // The block is as tall as the nested editor has rows.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            // Left padding equal to the full gutter width.
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1245
1246    pub fn observe_assist(
1247        &mut self,
1248        assist_id: InlineAssistId,
1249    ) -> async_watch::Receiver<AssistStatus> {
1250        if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1251            rx.clone()
1252        } else {
1253            let (tx, rx) = async_watch::channel(AssistStatus::Idle);
1254            self.assist_observations.insert(assist_id, (tx, rx.clone()));
1255            rx
1256        }
1257    }
1258}
1259
/// Status of an inline assist as reported by `InlineAssistant::assist_status`.
///
/// `Idle`, `Pending`, `Done` and `Error` mirror the codegen status of an assist
/// that is still tracked. `Confirmed` and `Canceled` describe an assist that is
/// no longer tracked: `Confirmed` if a confirmed alternative is recorded for it,
/// `Canceled` otherwise.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` only for [`InlineAssistStatus::Pending`].
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// Returns `true` only for [`InlineAssistStatus::Confirmed`].
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// Returns `true` only for [`InlineAssistStatus::Done`].
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1282
/// Per-editor bookkeeping for the inline assists attached to one editor.
struct EditorInlineAssists {
    /// Ids of the assists currently attached to the editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one assist. It is cleared when selections
    /// change while no prompt editor is focused, when the locked distance no
    /// longer matches, or when the owning assist is dismissed.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here asks the `_update_highlights` task to recompute the
    /// editor's inline-assist highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that refreshes highlights whenever `highlight_updates`
    /// changes. It ends with an error once the editor has been dropped.
    _update_highlights: Task<Result<()>>,
    /// Subscriptions and action registrations on the editor, held so they stay
    /// alive as long as this struct.
    _subscriptions: Vec<gpui::Subscription>,
}
1290
/// Records which assist a scroll lock belongs to and the distance it preserves.
struct InlineAssistScrollLock {
    /// The assist whose prompt block the lock refers to.
    assist_id: InlineAssistId,
    /// Distance, in rows, between the editor's scroll top and the prompt block
    /// row (`prompt_row - scroll_top`). The lock is dropped when the current
    /// distance differs from this value.
    distance_from_top: f32,
}
1295
1296impl EditorInlineAssists {
1297    #[allow(clippy::too_many_arguments)]
1298    fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1299        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1300        Self {
1301            assist_ids: Vec::new(),
1302            scroll_lock: None,
1303            highlight_updates: highlight_updates_tx,
1304            _update_highlights: cx.spawn(|mut cx| {
1305                let editor = editor.downgrade();
1306                async move {
1307                    while let Ok(()) = highlight_updates_rx.changed().await {
1308                        let editor = editor.upgrade().context("editor was dropped")?;
1309                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
1310                            assistant.update_editor_highlights(&editor, cx);
1311                        })?;
1312                    }
1313                    Ok(())
1314                }
1315            }),
1316            _subscriptions: vec![
1317                cx.observe_release(editor, {
1318                    let editor = editor.downgrade();
1319                    |_, cx| {
1320                        InlineAssistant::update_global(cx, |this, cx| {
1321                            this.handle_editor_release(editor, cx);
1322                        })
1323                    }
1324                }),
1325                cx.observe(editor, move |editor, cx| {
1326                    InlineAssistant::update_global(cx, |this, cx| {
1327                        this.handle_editor_change(editor, cx)
1328                    })
1329                }),
1330                cx.subscribe(editor, move |editor, event, cx| {
1331                    InlineAssistant::update_global(cx, |this, cx| {
1332                        this.handle_editor_event(editor, event, cx)
1333                    })
1334                }),
1335                editor.update(cx, |editor, cx| {
1336                    let editor_handle = cx.view().downgrade();
1337                    editor.register_action(
1338                        move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1339                            InlineAssistant::update_global(cx, |this, cx| {
1340                                if let Some(editor) = editor_handle.upgrade() {
1341                                    this.handle_editor_newline(editor, cx)
1342                                }
1343                            })
1344                        },
1345                    )
1346                }),
1347                editor.update(cx, |editor, cx| {
1348                    let editor_handle = cx.view().downgrade();
1349                    editor.register_action(
1350                        move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1351                            InlineAssistant::update_global(cx, |this, cx| {
1352                                if let Some(editor) = editor_handle.upgrade() {
1353                                    this.handle_editor_cancel(editor, cx)
1354                                }
1355                            })
1356                        },
1357                    )
1358                }),
1359            ],
1360        }
1361    }
1362}
1363
1364struct InlineAssistGroup {
1365    assist_ids: Vec<InlineAssistId>,
1366    linked: bool,
1367    active_assist_id: Option<InlineAssistId>,
1368}
1369
1370impl InlineAssistGroup {
1371    fn new() -> Self {
1372        Self {
1373            assist_ids: Vec::new(),
1374            linked: true,
1375            active_assist_id: None,
1376        }
1377    }
1378}
1379
1380fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1381    let editor = editor.clone();
1382    Box::new(move |cx: &mut BlockContext| {
1383        *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1384        editor.clone().into_any_element()
1385    })
1386}
1387
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Advances the id by one and returns the value it had before.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1398
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances the id by one and returns the value it had before.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, next)
    }
}
1409
/// Events emitted by a `PromptEditor` to request an action on its assist.
enum PromptEditorEvent {
    /// Emitted by the "start" and "restart" buttons.
    StartRequested,
    /// Emitted by the "stop" button shown while codegen is pending.
    StopRequested,
    /// Emitted by the "confirm" button.
    ConfirmRequested,
    /// Emitted by the "cancel" button.
    CancelRequested,
    // NOTE(review): no emitter is visible in this part of the file; presumably
    // asks to hide the prompt without finishing the assist.
    DismissRequested,
}
1417
/// View rendered inside an editor block that lets the user enter the prompt
/// for an inline assist and control its codegen.
struct PromptEditor {
    // NOTE(review): presumably the assist this prompt editor belongs to.
    id: InlineAssistId,
    /// Passed to the `ModelSelector` when rendering.
    fs: Arc<dyn Fs>,
    /// The nested editor holding the prompt text; it is what gets focused when
    /// the assist is focused.
    editor: View<Editor>,
    /// When `true` and codegen is done, the "restart" button is shown in place
    /// of the "confirm" button.
    edited_since_done: bool,
    /// Gutter dimensions of the hosting block, refreshed by the block renderer
    /// on every render and read when laying out this view.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    prompt_history: VecDeque<String>,
    prompt_history_ix: Option<usize>,
    pending_prompt: String,
    /// Codegen whose status and alternative count determine which buttons are
    /// rendered.
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    editor_subscriptions: Vec<Subscription>,
    pending_token_count: Task<Result<()>>,
    token_counts: Option<TokenCounts>,
    _token_count_subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    show_rate_limit_notice: bool,
}
1436
/// Pair of token counts kept by a `PromptEditor`.
// NOTE(review): field meanings are inferred from their names only; presumably
// `total` is the whole request and `assistant_panel` the share contributed by
// assistant panel context. Confirm where the counts are computed.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    total: usize,
    assistant_panel: usize,
}
1442
// Lets `PromptEditor` emit `PromptEditorEvent`s (via `cx.emit`) to its subscribers.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1444
1445impl Render for PromptEditor {
1446    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1447        let gutter_dimensions = *self.gutter_dimensions.lock();
1448        let codegen = self.codegen.read(cx);
1449
1450        let mut buttons = Vec::new();
1451        if codegen.alternative_count(cx) > 1 {
1452            buttons.push(self.render_cycle_controls(cx));
1453        }
1454
1455        let status = codegen.status(cx);
1456        buttons.extend(match status {
1457            CodegenStatus::Idle => {
1458                vec![
1459                    IconButton::new("cancel", IconName::Close)
1460                        .icon_color(Color::Muted)
1461                        .shape(IconButtonShape::Square)
1462                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1463                        .on_click(
1464                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1465                        )
1466                        .into_any_element(),
1467                    IconButton::new("start", IconName::SparkleAlt)
1468                        .icon_color(Color::Muted)
1469                        .shape(IconButtonShape::Square)
1470                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
1471                        .on_click(
1472                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
1473                        )
1474                        .into_any_element(),
1475                ]
1476            }
1477            CodegenStatus::Pending => {
1478                vec![
1479                    IconButton::new("cancel", IconName::Close)
1480                        .icon_color(Color::Muted)
1481                        .shape(IconButtonShape::Square)
1482                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
1483                        .on_click(
1484                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1485                        )
1486                        .into_any_element(),
1487                    IconButton::new("stop", IconName::Stop)
1488                        .icon_color(Color::Error)
1489                        .shape(IconButtonShape::Square)
1490                        .tooltip(|cx| {
1491                            Tooltip::with_meta(
1492                                "Interrupt Transformation",
1493                                Some(&menu::Cancel),
1494                                "Changes won't be discarded",
1495                                cx,
1496                            )
1497                        })
1498                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
1499                        .into_any_element(),
1500                ]
1501            }
1502            CodegenStatus::Error(_) | CodegenStatus::Done => {
1503                vec![
1504                    IconButton::new("cancel", IconName::Close)
1505                        .icon_color(Color::Muted)
1506                        .shape(IconButtonShape::Square)
1507                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1508                        .on_click(
1509                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1510                        )
1511                        .into_any_element(),
1512                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
1513                        IconButton::new("restart", IconName::RotateCw)
1514                            .icon_color(Color::Info)
1515                            .shape(IconButtonShape::Square)
1516                            .tooltip(|cx| {
1517                                Tooltip::with_meta(
1518                                    "Restart Transformation",
1519                                    Some(&menu::Confirm),
1520                                    "Changes will be discarded",
1521                                    cx,
1522                                )
1523                            })
1524                            .on_click(cx.listener(|_, _, cx| {
1525                                cx.emit(PromptEditorEvent::StartRequested);
1526                            }))
1527                            .into_any_element()
1528                    } else {
1529                        IconButton::new("confirm", IconName::Check)
1530                            .icon_color(Color::Info)
1531                            .shape(IconButtonShape::Square)
1532                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
1533                            .on_click(cx.listener(|_, _, cx| {
1534                                cx.emit(PromptEditorEvent::ConfirmRequested);
1535                            }))
1536                            .into_any_element()
1537                    },
1538                ]
1539            }
1540        });
1541
1542        h_flex()
1543            .key_context("PromptEditor")
1544            .bg(cx.theme().colors().editor_background)
1545            .border_y_1()
1546            .border_color(cx.theme().status().info_border)
1547            .size_full()
1548            .py(cx.line_height() / 2.5)
1549            .on_action(cx.listener(Self::confirm))
1550            .on_action(cx.listener(Self::cancel))
1551            .on_action(cx.listener(Self::move_up))
1552            .on_action(cx.listener(Self::move_down))
1553            .capture_action(cx.listener(Self::cycle_prev))
1554            .capture_action(cx.listener(Self::cycle_next))
1555            .child(
1556                h_flex()
1557                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1558                    .justify_center()
1559                    .gap_2()
1560                    .child(
1561                        ModelSelector::new(
1562                            self.fs.clone(),
1563                            IconButton::new("context", IconName::SettingsAlt)
1564                                .shape(IconButtonShape::Square)
1565                                .icon_size(IconSize::Small)
1566                                .icon_color(Color::Muted)
1567                                .tooltip(move |cx| {
1568                                    Tooltip::with_meta(
1569                                        format!(
1570                                            "Using {}",
1571                                            LanguageModelRegistry::read_global(cx)
1572                                                .active_model()
1573                                                .map(|model| model.name().0)
1574                                                .unwrap_or_else(|| "No model selected".into()),
1575                                        ),
1576                                        None,
1577                                        "Change Model",
1578                                        cx,
1579                                    )
1580                                }),
1581                        )
1582                        .with_info_text(
1583                            "Inline edits use context\n\
1584                            from the currently selected\n\
1585                            assistant panel tab.",
1586                        ),
1587                    )
1588                    .map(|el| {
1589                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
1590                            return el;
1591                        };
1592
1593                        let error_message = SharedString::from(error.to_string());
1594                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
1595                            && cx.has_flag::<ZedPro>()
1596                        {
1597                            el.child(
1598                                v_flex()
1599                                    .child(
1600                                        IconButton::new("rate-limit-error", IconName::XCircle)
1601                                            .selected(self.show_rate_limit_notice)
1602                                            .shape(IconButtonShape::Square)
1603                                            .icon_size(IconSize::Small)
1604                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1605                                    )
1606                                    .children(self.show_rate_limit_notice.then(|| {
1607                                        deferred(
1608                                            anchored()
1609                                                .position_mode(gpui::AnchoredPositionMode::Local)
1610                                                .position(point(px(0.), px(24.)))
1611                                                .anchor(gpui::AnchorCorner::TopLeft)
1612                                                .child(self.render_rate_limit_notice(cx)),
1613                                        )
1614                                    })),
1615                            )
1616                        } else {
1617                            el.child(
1618                                div()
1619                                    .id("error")
1620                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
1621                                    .child(
1622                                        Icon::new(IconName::XCircle)
1623                                            .size(IconSize::Small)
1624                                            .color(Color::Error),
1625                                    ),
1626                            )
1627                        }
1628                    }),
1629            )
1630            .child(div().flex_1().child(self.render_prompt_editor(cx)))
1631            .child(
1632                h_flex()
1633                    .gap_2()
1634                    .pr_6()
1635                    .children(self.render_token_count(cx))
1636                    .children(buttons),
1637            )
1638    }
1639}
1640
1641impl FocusableView for PromptEditor {
1642    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1643        self.editor.focus_handle(cx)
1644    }
1645}
1646
1647impl PromptEditor {
    /// Maximum number of lines passed to the auto-height inner editor when it is created.
    const MAX_LINES: u8 = 8;
1649
1650    #[allow(clippy::too_many_arguments)]
1651    fn new(
1652        id: InlineAssistId,
1653        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1654        prompt_history: VecDeque<String>,
1655        prompt_buffer: Model<MultiBuffer>,
1656        codegen: Model<Codegen>,
1657        parent_editor: &View<Editor>,
1658        assistant_panel: Option<&View<AssistantPanel>>,
1659        workspace: Option<WeakView<Workspace>>,
1660        fs: Arc<dyn Fs>,
1661        cx: &mut ViewContext<Self>,
1662    ) -> Self {
1663        let prompt_editor = cx.new_view(|cx| {
1664            let mut editor = Editor::new(
1665                EditorMode::AutoHeight {
1666                    max_lines: Self::MAX_LINES as usize,
1667                },
1668                prompt_buffer,
1669                None,
1670                false,
1671                cx,
1672            );
1673            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1674            // Since the prompt editors for all inline assistants are linked,
1675            // always show the cursor (even when it isn't focused) because
1676            // typing in one will make what you typed appear in all of them.
1677            editor.set_show_cursor_when_unfocused(true, cx);
1678            editor.set_placeholder_text("Add a prompt…", cx);
1679            editor
1680        });
1681
1682        let mut token_count_subscriptions = Vec::new();
1683        token_count_subscriptions
1684            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1685        if let Some(assistant_panel) = assistant_panel {
1686            token_count_subscriptions
1687                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1688        }
1689
1690        let mut this = Self {
1691            id,
1692            editor: prompt_editor,
1693            edited_since_done: false,
1694            gutter_dimensions,
1695            prompt_history,
1696            prompt_history_ix: None,
1697            pending_prompt: String::new(),
1698            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1699            editor_subscriptions: Vec::new(),
1700            codegen,
1701            fs,
1702            pending_token_count: Task::ready(Ok(())),
1703            token_counts: None,
1704            _token_count_subscriptions: token_count_subscriptions,
1705            workspace,
1706            show_rate_limit_notice: false,
1707        };
1708        this.count_tokens(cx);
1709        this.subscribe_to_editor(cx);
1710        this
1711    }
1712
1713    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1714        self.editor_subscriptions.clear();
1715        self.editor_subscriptions
1716            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1717    }
1718
1719    fn set_show_cursor_when_unfocused(
1720        &mut self,
1721        show_cursor_when_unfocused: bool,
1722        cx: &mut ViewContext<Self>,
1723    ) {
1724        self.editor.update(cx, |editor, cx| {
1725            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1726        });
1727    }
1728
1729    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1730        let prompt = self.prompt(cx);
1731        let focus = self.editor.focus_handle(cx).contains_focused(cx);
1732        self.editor = cx.new_view(|cx| {
1733            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1734            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1735            editor.set_placeholder_text("Add a prompt…", cx);
1736            editor.set_text(prompt, cx);
1737            if focus {
1738                editor.focus(cx);
1739            }
1740            editor
1741        });
1742        self.subscribe_to_editor(cx);
1743    }
1744
1745    fn prompt(&self, cx: &AppContext) -> String {
1746        self.editor.read(cx).text(cx)
1747    }
1748
1749    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1750        self.show_rate_limit_notice = !self.show_rate_limit_notice;
1751        if self.show_rate_limit_notice {
1752            cx.focus_view(&self.editor);
1753        }
1754        cx.notify();
1755    }
1756
1757    fn handle_parent_editor_event(
1758        &mut self,
1759        _: View<Editor>,
1760        event: &EditorEvent,
1761        cx: &mut ViewContext<Self>,
1762    ) {
1763        if let EditorEvent::BufferEdited { .. } = event {
1764            self.count_tokens(cx);
1765        }
1766    }
1767
1768    fn handle_assistant_panel_event(
1769        &mut self,
1770        _: View<AssistantPanel>,
1771        event: &AssistantPanelEvent,
1772        cx: &mut ViewContext<Self>,
1773    ) {
1774        let AssistantPanelEvent::ContextEdited { .. } = event;
1775        self.count_tokens(cx);
1776    }
1777
1778    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1779        let assist_id = self.id;
1780        self.pending_token_count = cx.spawn(|this, mut cx| async move {
1781            cx.background_executor().timer(Duration::from_secs(1)).await;
1782            let token_count = cx
1783                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1784                    let assist = inline_assistant
1785                        .assists
1786                        .get(&assist_id)
1787                        .context("assist not found")?;
1788                    anyhow::Ok(assist.count_tokens(cx))
1789                })??
1790                .await?;
1791
1792            this.update(&mut cx, |this, cx| {
1793                this.token_counts = Some(token_count);
1794                cx.notify();
1795            })
1796        })
1797    }
1798
1799    fn handle_prompt_editor_events(
1800        &mut self,
1801        _: View<Editor>,
1802        event: &EditorEvent,
1803        cx: &mut ViewContext<Self>,
1804    ) {
1805        match event {
1806            EditorEvent::Edited { .. } => {
1807                let prompt = self.editor.read(cx).text(cx);
1808                if self
1809                    .prompt_history_ix
1810                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
1811                {
1812                    self.prompt_history_ix.take();
1813                    self.pending_prompt = prompt;
1814                }
1815
1816                self.edited_since_done = true;
1817                cx.notify();
1818            }
1819            EditorEvent::BufferEdited => {
1820                self.count_tokens(cx);
1821            }
1822            EditorEvent::Blurred => {
1823                if self.show_rate_limit_notice {
1824                    self.show_rate_limit_notice = false;
1825                    cx.notify();
1826                }
1827            }
1828            _ => {}
1829        }
1830    }
1831
1832    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1833        match self.codegen.read(cx).status(cx) {
1834            CodegenStatus::Idle => {
1835                self.editor
1836                    .update(cx, |editor, _| editor.set_read_only(false));
1837            }
1838            CodegenStatus::Pending => {
1839                self.editor
1840                    .update(cx, |editor, _| editor.set_read_only(true));
1841            }
1842            CodegenStatus::Done => {
1843                self.edited_since_done = false;
1844                self.editor
1845                    .update(cx, |editor, _| editor.set_read_only(false));
1846            }
1847            CodegenStatus::Error(error) => {
1848                if cx.has_flag::<ZedPro>()
1849                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
1850                    && !dismissed_rate_limit_notice()
1851                {
1852                    self.show_rate_limit_notice = true;
1853                    cx.notify();
1854                }
1855
1856                self.edited_since_done = false;
1857                self.editor
1858                    .update(cx, |editor, _| editor.set_read_only(false));
1859            }
1860        }
1861    }
1862
1863    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1864        match self.codegen.read(cx).status(cx) {
1865            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1866                cx.emit(PromptEditorEvent::CancelRequested);
1867            }
1868            CodegenStatus::Pending => {
1869                cx.emit(PromptEditorEvent::StopRequested);
1870            }
1871        }
1872    }
1873
1874    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1875        match self.codegen.read(cx).status(cx) {
1876            CodegenStatus::Idle => {
1877                cx.emit(PromptEditorEvent::StartRequested);
1878            }
1879            CodegenStatus::Pending => {
1880                cx.emit(PromptEditorEvent::DismissRequested);
1881            }
1882            CodegenStatus::Done => {
1883                if self.edited_since_done {
1884                    cx.emit(PromptEditorEvent::StartRequested);
1885                } else {
1886                    cx.emit(PromptEditorEvent::ConfirmRequested);
1887                }
1888            }
1889            CodegenStatus::Error(_) => {
1890                cx.emit(PromptEditorEvent::StartRequested);
1891            }
1892        }
1893    }
1894
1895    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1896        if let Some(ix) = self.prompt_history_ix {
1897            if ix > 0 {
1898                self.prompt_history_ix = Some(ix - 1);
1899                let prompt = self.prompt_history[ix - 1].as_str();
1900                self.editor.update(cx, |editor, cx| {
1901                    editor.set_text(prompt, cx);
1902                    editor.move_to_beginning(&Default::default(), cx);
1903                });
1904            }
1905        } else if !self.prompt_history.is_empty() {
1906            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1907            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1908            self.editor.update(cx, |editor, cx| {
1909                editor.set_text(prompt, cx);
1910                editor.move_to_beginning(&Default::default(), cx);
1911            });
1912        }
1913    }
1914
1915    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1916        if let Some(ix) = self.prompt_history_ix {
1917            if ix < self.prompt_history.len() - 1 {
1918                self.prompt_history_ix = Some(ix + 1);
1919                let prompt = self.prompt_history[ix + 1].as_str();
1920                self.editor.update(cx, |editor, cx| {
1921                    editor.set_text(prompt, cx);
1922                    editor.move_to_end(&Default::default(), cx)
1923                });
1924            } else {
1925                self.prompt_history_ix = None;
1926                let prompt = self.pending_prompt.as_str();
1927                self.editor.update(cx, |editor, cx| {
1928                    editor.set_text(prompt, cx);
1929                    editor.move_to_end(&Default::default(), cx)
1930                });
1931            }
1932        }
1933    }
1934
1935    fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1936        self.codegen
1937            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1938    }
1939
1940    fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1941        self.codegen
1942            .update(cx, |codegen, cx| codegen.cycle_next(cx));
1943    }
1944
1945    fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1946        let codegen = self.codegen.read(cx);
1947        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1948
1949        h_flex()
1950            .child(
1951                IconButton::new("previous", IconName::ChevronLeft)
1952                    .icon_color(Color::Muted)
1953                    .disabled(disabled)
1954                    .shape(IconButtonShape::Square)
1955                    .tooltip({
1956                        let focus_handle = self.editor.focus_handle(cx);
1957                        move |cx| {
1958                            Tooltip::for_action_in(
1959                                "Previous Alternative",
1960                                &CyclePreviousInlineAssist,
1961                                &focus_handle,
1962                                cx,
1963                            )
1964                        }
1965                    })
1966                    .on_click(cx.listener(|this, _, cx| {
1967                        this.codegen
1968                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
1969                    })),
1970            )
1971            .child(
1972                Label::new(format!(
1973                    "{}/{}",
1974                    codegen.active_alternative + 1,
1975                    codegen.alternative_count(cx)
1976                ))
1977                .size(LabelSize::Small)
1978                .color(if disabled {
1979                    Color::Disabled
1980                } else {
1981                    Color::Muted
1982                }),
1983            )
1984            .child(
1985                IconButton::new("next", IconName::ChevronRight)
1986                    .icon_color(Color::Muted)
1987                    .disabled(disabled)
1988                    .shape(IconButtonShape::Square)
1989                    .tooltip({
1990                        let focus_handle = self.editor.focus_handle(cx);
1991                        move |cx| {
1992                            Tooltip::for_action_in(
1993                                "Next Alternative",
1994                                &CycleNextInlineAssist,
1995                                &focus_handle,
1996                                cx,
1997                            )
1998                        }
1999                    })
2000                    .on_click(cx.listener(|this, _, cx| {
2001                        this.codegen
2002                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
2003                    })),
2004            )
2005            .into_any_element()
2006    }
2007
2008    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
2009        let model = LanguageModelRegistry::read_global(cx).active_model()?;
2010        let token_counts = self.token_counts?;
2011        let max_token_count = model.max_token_count();
2012
2013        let remaining_tokens = max_token_count as isize - token_counts.total as isize;
2014        let token_count_color = if remaining_tokens <= 0 {
2015            Color::Error
2016        } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
2017            Color::Warning
2018        } else {
2019            Color::Muted
2020        };
2021
2022        let mut token_count = h_flex()
2023            .id("token_count")
2024            .gap_0p5()
2025            .child(
2026                Label::new(humanize_token_count(token_counts.total))
2027                    .size(LabelSize::Small)
2028                    .color(token_count_color),
2029            )
2030            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
2031            .child(
2032                Label::new(humanize_token_count(max_token_count))
2033                    .size(LabelSize::Small)
2034                    .color(Color::Muted),
2035            );
2036        if let Some(workspace) = self.workspace.clone() {
2037            token_count = token_count
2038                .tooltip(move |cx| {
2039                    Tooltip::with_meta(
2040                        format!(
2041                            "Tokens Used ({} from the Assistant Panel)",
2042                            humanize_token_count(token_counts.assistant_panel)
2043                        ),
2044                        None,
2045                        "Click to open the Assistant Panel",
2046                        cx,
2047                    )
2048                })
2049                .cursor_pointer()
2050                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
2051                .on_click(move |_, cx| {
2052                    cx.stop_propagation();
2053                    workspace
2054                        .update(cx, |workspace, cx| {
2055                            workspace.focus_panel::<AssistantPanel>(cx)
2056                        })
2057                        .ok();
2058                });
2059        } else {
2060            token_count = token_count
2061                .cursor_default()
2062                .tooltip(|cx| Tooltip::text("Tokens used", cx));
2063        }
2064
2065        Some(token_count)
2066    }
2067
2068    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2069        let settings = ThemeSettings::get_global(cx);
2070        let text_style = TextStyle {
2071            color: if self.editor.read(cx).read_only(cx) {
2072                cx.theme().colors().text_disabled
2073            } else {
2074                cx.theme().colors().text
2075            },
2076            font_family: settings.buffer_font.family.clone(),
2077            font_fallbacks: settings.buffer_font.fallbacks.clone(),
2078            font_size: settings.buffer_font_size.into(),
2079            font_weight: settings.buffer_font.weight,
2080            line_height: relative(settings.buffer_line_height.value()),
2081            ..Default::default()
2082        };
2083        EditorElement::new(
2084            &self.editor,
2085            EditorStyle {
2086                background: cx.theme().colors().editor_background,
2087                local_player: cx.theme().players().local(),
2088                text: text_style,
2089                ..Default::default()
2090            },
2091        )
2092    }
2093
2094    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2095        Popover::new().child(
2096            v_flex()
2097                .occlude()
2098                .p_2()
2099                .child(
2100                    Label::new("Out of Tokens")
2101                        .size(LabelSize::Small)
2102                        .weight(FontWeight::BOLD),
2103                )
2104                .child(Label::new(
2105                    "Try Zed Pro for higher limits, a wider range of models, and more.",
2106                ))
2107                .child(
2108                    h_flex()
2109                        .justify_between()
2110                        .child(CheckboxWithLabel::new(
2111                            "dont-show-again",
2112                            Label::new("Don't show again"),
2113                            if dismissed_rate_limit_notice() {
2114                                ui::Selection::Selected
2115                            } else {
2116                                ui::Selection::Unselected
2117                            },
2118                            |selection, cx| {
2119                                let is_dismissed = match selection {
2120                                    ui::Selection::Unselected => false,
2121                                    ui::Selection::Indeterminate => return,
2122                                    ui::Selection::Selected => true,
2123                                };
2124
2125                                set_rate_limit_notice_dismissed(is_dismissed, cx)
2126                            },
2127                        ))
2128                        .child(
2129                            h_flex()
2130                                .gap_2()
2131                                .child(
2132                                    Button::new("dismiss", "Dismiss")
2133                                        .style(ButtonStyle::Transparent)
2134                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2135                                )
2136                                .child(Button::new("more-info", "More Info").on_click(
2137                                    |_event, cx| {
2138                                        cx.dispatch_action(Box::new(
2139                                            zed_actions::OpenAccountSettings,
2140                                        ))
2141                                    },
2142                                )),
2143                        ),
2144                ),
2145        )
2146    }
2147}
2148
/// Key in the key-value store under which dismissal of the rate limit notice is recorded.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2150
2151fn dismissed_rate_limit_notice() -> bool {
2152    db::kvp::KEY_VALUE_STORE
2153        .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2154        .log_err()
2155        .map_or(false, |s| s.is_some())
2156}
2157
2158fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2159    db::write_and_log(cx, move || async move {
2160        if is_dismissed {
2161            db::kvp::KEY_VALUE_STORE
2162                .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2163                .await
2164        } else {
2165            db::kvp::KEY_VALUE_STORE
2166                .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2167                .await
2168        }
2169    })
2170}
2171
/// State kept for a single inline assist.
struct InlineAssist {
    /// Group identifier supplied when the assist was created.
    group_id: InlineAssistGroupId,
    /// Anchor range supplied when the assist was created.
    // NOTE(review): presumably the span of the parent editor being transformed — confirm.
    range: Range<Anchor>,
    /// Weak handle to the editor the assist was created for.
    editor: WeakView<Editor>,
    /// Prompt editor and block ids attached to the assist. When this is `None`, a finished
    /// generation ends the assist right away and errors are reported as a workspace toast.
    decorations: Option<InlineAssistDecorations>,
    /// Codegen model whose status changes and events drive this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt editor and codegen subscriptions created in `new`, stored alongside
    /// the assist.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to show error toasts and to look up the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    /// Whether `assistant_panel_context` returns the assistant panel's active context.
    include_context: bool,
}
2182
2183impl InlineAssist {
2184    #[allow(clippy::too_many_arguments)]
2185    fn new(
2186        assist_id: InlineAssistId,
2187        group_id: InlineAssistGroupId,
2188        include_context: bool,
2189        editor: &View<Editor>,
2190        prompt_editor: &View<PromptEditor>,
2191        prompt_block_id: CustomBlockId,
2192        end_block_id: CustomBlockId,
2193        range: Range<Anchor>,
2194        codegen: Model<Codegen>,
2195        workspace: Option<WeakView<Workspace>>,
2196        cx: &mut WindowContext,
2197    ) -> Self {
2198        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2199        InlineAssist {
2200            group_id,
2201            include_context,
2202            editor: editor.downgrade(),
2203            decorations: Some(InlineAssistDecorations {
2204                prompt_block_id,
2205                prompt_editor: prompt_editor.clone(),
2206                removed_line_block_ids: HashSet::default(),
2207                end_block_id,
2208            }),
2209            range,
2210            codegen: codegen.clone(),
2211            workspace: workspace.clone(),
2212            _subscriptions: vec![
2213                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2214                    InlineAssistant::update_global(cx, |this, cx| {
2215                        this.handle_prompt_editor_focus_in(assist_id, cx)
2216                    })
2217                }),
2218                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2219                    InlineAssistant::update_global(cx, |this, cx| {
2220                        this.handle_prompt_editor_focus_out(assist_id, cx)
2221                    })
2222                }),
2223                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2224                    InlineAssistant::update_global(cx, |this, cx| {
2225                        this.handle_prompt_editor_event(prompt_editor, event, cx)
2226                    })
2227                }),
2228                cx.observe(&codegen, {
2229                    let editor = editor.downgrade();
2230                    move |_, cx| {
2231                        if let Some(editor) = editor.upgrade() {
2232                            InlineAssistant::update_global(cx, |this, cx| {
2233                                if let Some(editor_assists) =
2234                                    this.assists_by_editor.get(&editor.downgrade())
2235                                {
2236                                    editor_assists.highlight_updates.send(()).ok();
2237                                }
2238
2239                                this.update_editor_blocks(&editor, assist_id, cx);
2240                            })
2241                        }
2242                    }
2243                }),
2244                cx.subscribe(&codegen, move |codegen, event, cx| {
2245                    InlineAssistant::update_global(cx, |this, cx| match event {
2246                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2247                        CodegenEvent::Finished => {
2248                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
2249                                assist
2250                            } else {
2251                                return;
2252                            };
2253
2254                            if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2255                                if assist.decorations.is_none() {
2256                                    if let Some(workspace) = assist
2257                                        .workspace
2258                                        .as_ref()
2259                                        .and_then(|workspace| workspace.upgrade())
2260                                    {
2261                                        let error = format!("Inline assistant error: {}", error);
2262                                        workspace.update(cx, |workspace, cx| {
2263                                            struct InlineAssistantError;
2264
2265                                            let id =
2266                                                NotificationId::identified::<InlineAssistantError>(
2267                                                    assist_id.0,
2268                                                );
2269
2270                                            workspace.show_toast(Toast::new(id, error), cx);
2271                                        })
2272                                    }
2273                                }
2274                            }
2275
2276                            if assist.decorations.is_none() {
2277                                this.finish_assist(assist_id, false, cx);
2278                            } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2279                                tx.0.send(AssistStatus::Finished).ok();
2280                            }
2281                        }
2282                    })
2283                }),
2284            ],
2285        }
2286    }
2287
2288    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2289        let decorations = self.decorations.as_ref()?;
2290        Some(decorations.prompt_editor.read(cx).prompt(cx))
2291    }
2292
2293    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2294        if self.include_context {
2295            let workspace = self.workspace.as_ref()?;
2296            let workspace = workspace.upgrade()?.read(cx);
2297            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2298            Some(
2299                assistant_panel
2300                    .read(cx)
2301                    .active_context(cx)?
2302                    .read(cx)
2303                    .to_completion_request(cx),
2304            )
2305        } else {
2306            None
2307        }
2308    }
2309
2310    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2311        let Some(user_prompt) = self.user_prompt(cx) else {
2312            return future::ready(Err(anyhow!("no user prompt"))).boxed();
2313        };
2314        let assistant_panel_context = self.assistant_panel_context(cx);
2315        self.codegen
2316            .read(cx)
2317            .count_tokens(user_prompt, assistant_panel_context, cx)
2318    }
2319}
2320
/// Editor-side state held by an inline assist while it is decorated in the editor.
///
/// The owning assist stores this as an optional value; the code in this file checks whether it
/// is present to decide, for example, whether a user prompt can be read.
struct InlineAssistDecorations {
    // NOTE(review): presumably the custom block that hosts the prompt editor — confirm where
    // the decorations are created.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view; the user's prompt text is read from it.
    prompt_editor: View<PromptEditor>,
    // NOTE(review): presumably the blocks rendering lines removed by the assist — confirm.
    removed_line_block_ids: HashSet<CustomBlockId>,
    // NOTE(review): presumably a block placed at the end of the assisted range — confirm.
    end_block_id: CustomBlockId,
}
2327
/// Events emitted by [`CodegenAlternative`] and re-emitted by [`Codegen`] for its active
/// alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation ended: the streaming task completed (successfully or with an error) or the
    /// generation was stopped.
    Finished,
    /// The buffer reported that the transaction holding the generated edits was undone.
    Undone,
}
2333
/// Coordinates one or more [`CodegenAlternative`]s that generate a replacement for the same
/// range, and exposes whichever alternative is currently active.
pub struct Codegen {
    /// All alternatives. There is always at least one; `start` keeps only the first and then
    /// appends one more per configured inline alternative model.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative whose edits are applied to the buffer.
    active_alternative: usize,
    /// Observation and event subscriptions on the active alternative; rebuilt on every
    /// activation.
    subscriptions: Vec<Subscription>,
    /// The multi-buffer handed to every alternative.
    buffer: Model<MultiBuffer>,
    /// The range handed to every alternative.
    range: Range<Anchor>,
    /// A transaction that `undo` reverts (at most once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    /// Optional telemetry handle forwarded to every alternative.
    telemetry: Option<Arc<Telemetry>>,
    /// Prompt builder forwarded to every alternative.
    builder: Arc<PromptBuilder>,
}
2344
2345impl Codegen {
2346    pub fn new(
2347        buffer: Model<MultiBuffer>,
2348        range: Range<Anchor>,
2349        initial_transaction_id: Option<TransactionId>,
2350        telemetry: Option<Arc<Telemetry>>,
2351        builder: Arc<PromptBuilder>,
2352        cx: &mut ModelContext<Self>,
2353    ) -> Self {
2354        let codegen = cx.new_model(|cx| {
2355            CodegenAlternative::new(
2356                buffer.clone(),
2357                range.clone(),
2358                false,
2359                telemetry.clone(),
2360                builder.clone(),
2361                cx,
2362            )
2363        });
2364        let mut this = Self {
2365            alternatives: vec![codegen],
2366            active_alternative: 0,
2367            subscriptions: Vec::new(),
2368            buffer,
2369            range,
2370            initial_transaction_id,
2371            telemetry,
2372            builder,
2373        };
2374        this.activate(0, cx);
2375        this
2376    }
2377
2378    fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2379        let codegen = self.active_alternative().clone();
2380        self.subscriptions.clear();
2381        self.subscriptions
2382            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2383        self.subscriptions
2384            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2385    }
2386
2387    fn active_alternative(&self) -> &Model<CodegenAlternative> {
2388        &self.alternatives[self.active_alternative]
2389    }
2390
2391    fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2392        &self.active_alternative().read(cx).status
2393    }
2394
2395    fn alternative_count(&self, cx: &AppContext) -> usize {
2396        LanguageModelRegistry::read_global(cx)
2397            .inline_alternative_models()
2398            .len()
2399            + 1
2400    }
2401
2402    pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2403        let next_active_ix = if self.active_alternative == 0 {
2404            self.alternatives.len() - 1
2405        } else {
2406            self.active_alternative - 1
2407        };
2408        self.activate(next_active_ix, cx);
2409    }
2410
2411    pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2412        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2413        self.activate(next_active_ix, cx);
2414    }
2415
2416    fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2417        self.active_alternative()
2418            .update(cx, |codegen, cx| codegen.set_active(false, cx));
2419        self.active_alternative = index;
2420        self.active_alternative()
2421            .update(cx, |codegen, cx| codegen.set_active(true, cx));
2422        self.subscribe_to_alternative(cx);
2423        cx.notify();
2424    }
2425
2426    pub fn start(
2427        &mut self,
2428        user_prompt: String,
2429        assistant_panel_context: Option<LanguageModelRequest>,
2430        cx: &mut ModelContext<Self>,
2431    ) -> Result<()> {
2432        let alternative_models = LanguageModelRegistry::read_global(cx)
2433            .inline_alternative_models()
2434            .to_vec();
2435
2436        self.active_alternative()
2437            .update(cx, |alternative, cx| alternative.undo(cx));
2438        self.activate(0, cx);
2439        self.alternatives.truncate(1);
2440
2441        for _ in 0..alternative_models.len() {
2442            self.alternatives.push(cx.new_model(|cx| {
2443                CodegenAlternative::new(
2444                    self.buffer.clone(),
2445                    self.range.clone(),
2446                    false,
2447                    self.telemetry.clone(),
2448                    self.builder.clone(),
2449                    cx,
2450                )
2451            }));
2452        }
2453
2454        let primary_model = LanguageModelRegistry::read_global(cx)
2455            .active_model()
2456            .context("no active model")?;
2457
2458        for (model, alternative) in iter::once(primary_model)
2459            .chain(alternative_models)
2460            .zip(&self.alternatives)
2461        {
2462            alternative.update(cx, |alternative, cx| {
2463                alternative.start(
2464                    user_prompt.clone(),
2465                    assistant_panel_context.clone(),
2466                    model.clone(),
2467                    cx,
2468                )
2469            })?;
2470        }
2471
2472        Ok(())
2473    }
2474
2475    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2476        for codegen in &self.alternatives {
2477            codegen.update(cx, |codegen, cx| codegen.stop(cx));
2478        }
2479    }
2480
2481    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2482        self.active_alternative()
2483            .update(cx, |codegen, cx| codegen.undo(cx));
2484
2485        self.buffer.update(cx, |buffer, cx| {
2486            if let Some(transaction_id) = self.initial_transaction_id.take() {
2487                buffer.undo_transaction(transaction_id, cx);
2488                buffer.refresh_preview(cx);
2489            }
2490        });
2491    }
2492
2493    pub fn count_tokens(
2494        &self,
2495        user_prompt: String,
2496        assistant_panel_context: Option<LanguageModelRequest>,
2497        cx: &AppContext,
2498    ) -> BoxFuture<'static, Result<TokenCounts>> {
2499        self.active_alternative()
2500            .read(cx)
2501            .count_tokens(user_prompt, assistant_panel_context, cx)
2502    }
2503
2504    pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2505        self.active_alternative().read(cx).buffer.clone()
2506    }
2507
2508    pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2509        self.active_alternative().read(cx).old_buffer.clone()
2510    }
2511
2512    pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2513        self.active_alternative().read(cx).snapshot.clone()
2514    }
2515
2516    pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2517        self.active_alternative().read(cx).edit_position
2518    }
2519
2520    fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2521        &self.active_alternative().read(cx).diff
2522    }
2523
2524    pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2525        self.active_alternative().read(cx).last_equal_ranges()
2526    }
2527}
2528
// `Codegen` re-emits the `CodegenEvent`s of its currently active alternative.
impl EventEmitter<CodegenEvent> for Codegen {}
2530
/// One candidate transformation of a range, produced by streaming text from a language model.
///
/// Streamed edits are always recorded, but they are only applied to the buffer while the
/// alternative is active.
pub struct CodegenAlternative {
    /// The multi-buffer that edits are applied to.
    buffer: Model<MultiBuffer>,
    /// A local copy of the underlying buffer's text, line ending and language, made when the
    /// alternative is created.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction; streamed offsets are anchored against it.
    snapshot: MultiBufferSnapshot,
    /// Anchor after the last processed offset of the stream; `None` before the first start.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges that the most recent batch of streamed operations kept unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction that groups this alternative's edits, while they are applied.
    transformation_transaction_id: Option<TransactionId>,
    /// Where the current generation stands.
    status: CodegenStatus,
    /// The task driving the current generation; replaced by a ready task when the generation is
    /// stopped or its transaction is undone.
    generation: Task<()>,
    /// Row ranges deleted/inserted by the transformation.
    diff: Diff,
    /// Optional telemetry handle used to report the response event.
    telemetry: Option<Arc<Telemetry>>,
    /// Keeps the subscription to `buffer` events alive.
    _subscription: gpui::Subscription,
    /// Builds the content prompt for requests.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Edits accumulated from the stream; replayed when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// The line operations from the most recent batch of the stream.
    line_operations: Vec<LineOperation>,
}
2549
/// Lifecycle state of a [`CodegenAlternative`]'s generation.
enum CodegenStatus {
    /// No generation has run, or one was stopped while the diff was still empty.
    Idle,
    /// A generation is streaming.
    Pending,
    /// The generation completed successfully, or was stopped with a non-empty diff.
    Done,
    /// The generation failed with the contained error.
    Error(anyhow::Error),
}
2556
/// Row-level summary of what a transformation changed.
// NOTE(review): the code that fills these vectors is outside this section; the exact meaning
// of the tuple members below should be confirmed there.
#[derive(Default)]
struct Diff {
    /// Deleted row ranges, each paired with an anchor.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted row ranges, expressed as inclusive anchor ranges.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2562
2563impl Diff {
2564    fn is_empty(&self) -> bool {
2565        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2566    }
2567}
2568
// Emits `Finished` when a generation ends or is stopped, and `Undone` when its transaction is
// undone in the buffer.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2570
2571impl CodegenAlternative {
2572    pub fn new(
2573        buffer: Model<MultiBuffer>,
2574        range: Range<Anchor>,
2575        active: bool,
2576        telemetry: Option<Arc<Telemetry>>,
2577        builder: Arc<PromptBuilder>,
2578        cx: &mut ModelContext<Self>,
2579    ) -> Self {
2580        let snapshot = buffer.read(cx).snapshot(cx);
2581
2582        let (old_buffer, _, _) = buffer
2583            .read(cx)
2584            .range_to_buffer_ranges(range.clone(), cx)
2585            .pop()
2586            .unwrap();
2587        let old_buffer = cx.new_model(|cx| {
2588            let old_buffer = old_buffer.read(cx);
2589            let text = old_buffer.as_rope().clone();
2590            let line_ending = old_buffer.line_ending();
2591            let language = old_buffer.language().cloned();
2592            let language_registry = old_buffer.language_registry();
2593
2594            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2595            buffer.set_language(language, cx);
2596            if let Some(language_registry) = language_registry {
2597                buffer.set_language_registry(language_registry)
2598            }
2599            buffer
2600        });
2601
2602        Self {
2603            buffer: buffer.clone(),
2604            old_buffer,
2605            edit_position: None,
2606            snapshot,
2607            last_equal_ranges: Default::default(),
2608            transformation_transaction_id: None,
2609            status: CodegenStatus::Idle,
2610            generation: Task::ready(()),
2611            diff: Diff::default(),
2612            telemetry,
2613            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2614            builder,
2615            active,
2616            edits: Vec::new(),
2617            line_operations: Vec::new(),
2618            range,
2619        }
2620    }
2621
2622    fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2623        if active != self.active {
2624            self.active = active;
2625
2626            if self.active {
2627                let edits = self.edits.clone();
2628                self.apply_edits(edits, cx);
2629                if matches!(self.status, CodegenStatus::Pending) {
2630                    let line_operations = self.line_operations.clone();
2631                    self.reapply_line_based_diff(line_operations, cx);
2632                } else {
2633                    self.reapply_batch_diff(cx).detach();
2634                }
2635            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2636                self.buffer.update(cx, |buffer, cx| {
2637                    buffer.undo_transaction(transaction_id, cx);
2638                    buffer.forget_transaction(transaction_id, cx);
2639                });
2640            }
2641        }
2642    }
2643
2644    fn handle_buffer_event(
2645        &mut self,
2646        _buffer: Model<MultiBuffer>,
2647        event: &multi_buffer::Event,
2648        cx: &mut ModelContext<Self>,
2649    ) {
2650        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2651            if self.transformation_transaction_id == Some(*transaction_id) {
2652                self.transformation_transaction_id = None;
2653                self.generation = Task::ready(());
2654                cx.emit(CodegenEvent::Undone);
2655            }
2656        }
2657    }
2658
2659    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2660        &self.last_equal_ranges
2661    }
2662
2663    pub fn count_tokens(
2664        &self,
2665        user_prompt: String,
2666        assistant_panel_context: Option<LanguageModelRequest>,
2667        cx: &AppContext,
2668    ) -> BoxFuture<'static, Result<TokenCounts>> {
2669        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2670            let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2671            match request {
2672                Ok(request) => {
2673                    let total_count = model.count_tokens(request.clone(), cx);
2674                    let assistant_panel_count = assistant_panel_context
2675                        .map(|context| model.count_tokens(context, cx))
2676                        .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2677
2678                    async move {
2679                        Ok(TokenCounts {
2680                            total: total_count.await?,
2681                            assistant_panel: assistant_panel_count.await?,
2682                        })
2683                    }
2684                    .boxed()
2685                }
2686                Err(error) => futures::future::ready(Err(error)).boxed(),
2687            }
2688        } else {
2689            future::ready(Err(anyhow!("no active model"))).boxed()
2690        }
2691    }
2692
2693    pub fn start(
2694        &mut self,
2695        user_prompt: String,
2696        assistant_panel_context: Option<LanguageModelRequest>,
2697        model: Arc<dyn LanguageModel>,
2698        cx: &mut ModelContext<Self>,
2699    ) -> Result<()> {
2700        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2701            self.buffer.update(cx, |buffer, cx| {
2702                buffer.undo_transaction(transformation_transaction_id, cx);
2703            });
2704        }
2705
2706        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2707
2708        let telemetry_id = model.telemetry_id();
2709        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
2710            if user_prompt.trim().to_lowercase() == "delete" {
2711                async { Ok(stream::empty().boxed()) }.boxed_local()
2712            } else {
2713                let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2714
2715                let chunks = cx
2716                    .spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
2717                async move { Ok(chunks.await?.boxed()) }.boxed_local()
2718            };
2719        self.handle_stream(telemetry_id, chunks, cx);
2720        Ok(())
2721    }
2722
2723    fn build_request(
2724        &self,
2725        user_prompt: String,
2726        assistant_panel_context: Option<LanguageModelRequest>,
2727        cx: &AppContext,
2728    ) -> Result<LanguageModelRequest> {
2729        let buffer = self.buffer.read(cx).snapshot(cx);
2730        let language = buffer.language_at(self.range.start);
2731        let language_name = if let Some(language) = language.as_ref() {
2732            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2733                None
2734            } else {
2735                Some(language.name())
2736            }
2737        } else {
2738            None
2739        };
2740
2741        let language_name = language_name.as_ref();
2742        let start = buffer.point_to_buffer_offset(self.range.start);
2743        let end = buffer.point_to_buffer_offset(self.range.end);
2744        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2745            let (start_buffer, start_buffer_offset) = start;
2746            let (end_buffer, end_buffer_offset) = end;
2747            if start_buffer.remote_id() == end_buffer.remote_id() {
2748                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2749            } else {
2750                return Err(anyhow::anyhow!("invalid transformation range"));
2751            }
2752        } else {
2753            return Err(anyhow::anyhow!("invalid transformation range"));
2754        };
2755
2756        let prompt = self
2757            .builder
2758            .generate_content_prompt(user_prompt, language_name, buffer, range)
2759            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2760
2761        let mut messages = Vec::new();
2762        if let Some(context_request) = assistant_panel_context {
2763            messages = context_request.messages;
2764        }
2765
2766        messages.push(LanguageModelRequestMessage {
2767            role: Role::User,
2768            content: vec![prompt.into()],
2769            cache: false,
2770        });
2771
2772        Ok(LanguageModelRequest {
2773            messages,
2774            tools: Vec::new(),
2775            stop: Vec::new(),
2776            temperature: None,
2777        })
2778    }
2779
    /// Consumes a stream of generated text chunks and turns it into edits of the buffer.
    ///
    /// `stream` resolves to the chunk stream (or to an error). A background task re-indents the
    /// incoming text line by line and diffs it against the originally selected text; the
    /// resulting character and line operations are sent back over a channel and converted into
    /// anchored edits. Edits are always recorded in `self.edits`, but they are only applied to
    /// the buffer while this alternative is active.
    ///
    /// Resets `self.diff`, sets the status to `Pending` and stores the spawned task in
    /// `self.generation`. When the task ends, the status becomes `Done` or `Error` and
    /// [`CodegenEvent::Finished`] is emitted. If a telemetry handle is set, one response event
    /// (with the latency of the first chunk and any error message) is reported.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset (in `snapshot`) up to which the original text has been consumed by the diff.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // Carries (char operations, line operations) from the background diffing
                    // task to the loop below that applies them.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current line that has not been pushed into the
                                // diff yet.
                                let mut new_text = String::new();
                                let mut base_indent = None;
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first item of the stream.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        // The indentation of the current line is not known yet:
                                        // look for its first non-whitespace character.
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                // The first measured line defines the base indent
                                                // that later lines are compared against.
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // On the first line, the selection's start column
                                                // is subtracted from the computed indent.
                                                // NOTE(review): presumably because that much
                                                // text already precedes the selection — confirm.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                // Replace the generated leading whitespace with
                                                // the corrected indentation.
                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only push text once the line's indentation is settled.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so a newline ended this line.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush whatever is left and finish both diffs.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    telemetry_events::AssistantPhase::Response,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            // Convert character operations into anchored edits. Deletions and
                            // kept spans advance `edit_start`; insertions do not.
                            let edits = char_ops
                                .into_iter()
                                .filter_map(|operation| match operation {
                                    CharOperation::Insert { text } => {
                                        let edit_start = snapshot.anchor_after(edit_start);
                                        Some((edit_start..edit_start, text))
                                    }
                                    CharOperation::Delete { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        Some((edit_range, String::new()))
                                    }
                                    CharOperation::Keep { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        codegen.last_equal_ranges.push(edit_range);
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            if codegen.active {
                                codegen.apply_edits(edits.iter().cloned(), cx);
                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                            }
                            // Recorded even while inactive, so the edits can be replayed when
                            // this alternative is activated.
                            codegen.edits.extend(edits);
                            codegen.line_operations = line_ops;
                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task =
                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                codegen
                    .update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
3007
3008    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
3009        self.last_equal_ranges.clear();
3010        if self.diff.is_empty() {
3011            self.status = CodegenStatus::Idle;
3012        } else {
3013            self.status = CodegenStatus::Done;
3014        }
3015        self.generation = Task::ready(());
3016        cx.emit(CodegenEvent::Finished);
3017        cx.notify();
3018    }
3019
3020    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
3021        self.buffer.update(cx, |buffer, cx| {
3022            if let Some(transaction_id) = self.transformation_transaction_id.take() {
3023                buffer.undo_transaction(transaction_id, cx);
3024                buffer.refresh_preview(cx);
3025            }
3026        });
3027    }
3028
3029    fn apply_edits(
3030        &mut self,
3031        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3032        cx: &mut ModelContext<CodegenAlternative>,
3033    ) {
3034        let transaction = self.buffer.update(cx, |buffer, cx| {
3035            // Avoid grouping assistant edits with user edits.
3036            buffer.finalize_last_transaction(cx);
3037            buffer.start_transaction(cx);
3038            buffer.edit(edits, None, cx);
3039            buffer.end_transaction(cx)
3040        });
3041
3042        if let Some(transaction) = transaction {
3043            if let Some(first_transaction) = self.transformation_transaction_id {
3044                // Group all assistant edits into the first transaction.
3045                self.buffer.update(cx, |buffer, cx| {
3046                    buffer.merge_transactions(transaction, first_transaction, cx)
3047                });
3048            } else {
3049                self.transformation_transaction_id = Some(transaction);
3050                self.buffer
3051                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3052            }
3053        }
3054    }
3055
3056    fn reapply_line_based_diff(
3057        &mut self,
3058        line_operations: impl IntoIterator<Item = LineOperation>,
3059        cx: &mut ModelContext<Self>,
3060    ) {
3061        let old_snapshot = self.snapshot.clone();
3062        let old_range = self.range.to_point(&old_snapshot);
3063        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3064        let new_range = self.range.to_point(&new_snapshot);
3065
3066        let mut old_row = old_range.start.row;
3067        let mut new_row = new_range.start.row;
3068
3069        self.diff.deleted_row_ranges.clear();
3070        self.diff.inserted_row_ranges.clear();
3071        for operation in line_operations {
3072            match operation {
3073                LineOperation::Keep { lines } => {
3074                    old_row += lines;
3075                    new_row += lines;
3076                }
3077                LineOperation::Delete { lines } => {
3078                    let old_end_row = old_row + lines - 1;
3079                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3080
3081                    if let Some((_, last_deleted_row_range)) =
3082                        self.diff.deleted_row_ranges.last_mut()
3083                    {
3084                        if *last_deleted_row_range.end() + 1 == old_row {
3085                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3086                        } else {
3087                            self.diff
3088                                .deleted_row_ranges
3089                                .push((new_row, old_row..=old_end_row));
3090                        }
3091                    } else {
3092                        self.diff
3093                            .deleted_row_ranges
3094                            .push((new_row, old_row..=old_end_row));
3095                    }
3096
3097                    old_row += lines;
3098                }
3099                LineOperation::Insert { lines } => {
3100                    let new_end_row = new_row + lines - 1;
3101                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3102                    let end = new_snapshot.anchor_before(Point::new(
3103                        new_end_row,
3104                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
3105                    ));
3106                    self.diff.inserted_row_ranges.push(start..=end);
3107                    new_row += lines;
3108                }
3109            }
3110
3111            cx.notify();
3112        }
3113    }
3114
3115    fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3116        let old_snapshot = self.snapshot.clone();
3117        let old_range = self.range.to_point(&old_snapshot);
3118        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3119        let new_range = self.range.to_point(&new_snapshot);
3120
3121        cx.spawn(|codegen, mut cx| async move {
3122            let (deleted_row_ranges, inserted_row_ranges) = cx
3123                .background_executor()
3124                .spawn(async move {
3125                    let old_text = old_snapshot
3126                        .text_for_range(
3127                            Point::new(old_range.start.row, 0)
3128                                ..Point::new(
3129                                    old_range.end.row,
3130                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3131                                ),
3132                        )
3133                        .collect::<String>();
3134                    let new_text = new_snapshot
3135                        .text_for_range(
3136                            Point::new(new_range.start.row, 0)
3137                                ..Point::new(
3138                                    new_range.end.row,
3139                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3140                                ),
3141                        )
3142                        .collect::<String>();
3143
3144                    let mut old_row = old_range.start.row;
3145                    let mut new_row = new_range.start.row;
3146                    let batch_diff =
3147                        similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3148
3149                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3150                    let mut inserted_row_ranges = Vec::new();
3151                    for change in batch_diff.iter_all_changes() {
3152                        let line_count = change.value().lines().count() as u32;
3153                        match change.tag() {
3154                            similar::ChangeTag::Equal => {
3155                                old_row += line_count;
3156                                new_row += line_count;
3157                            }
3158                            similar::ChangeTag::Delete => {
3159                                let old_end_row = old_row + line_count - 1;
3160                                let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3161
3162                                if let Some((_, last_deleted_row_range)) =
3163                                    deleted_row_ranges.last_mut()
3164                                {
3165                                    if *last_deleted_row_range.end() + 1 == old_row {
3166                                        *last_deleted_row_range =
3167                                            *last_deleted_row_range.start()..=old_end_row;
3168                                    } else {
3169                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
3170                                    }
3171                                } else {
3172                                    deleted_row_ranges.push((new_row, old_row..=old_end_row));
3173                                }
3174
3175                                old_row += line_count;
3176                            }
3177                            similar::ChangeTag::Insert => {
3178                                let new_end_row = new_row + line_count - 1;
3179                                let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3180                                let end = new_snapshot.anchor_before(Point::new(
3181                                    new_end_row,
3182                                    new_snapshot.line_len(MultiBufferRow(new_end_row)),
3183                                ));
3184                                inserted_row_ranges.push(start..=end);
3185                                new_row += line_count;
3186                            }
3187                        }
3188                    }
3189
3190                    (deleted_row_ranges, inserted_row_ranges)
3191                })
3192                .await;
3193
3194            codegen
3195                .update(&mut cx, |codegen, cx| {
3196                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
3197                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
3198                    cx.notify();
3199                })
3200                .ok();
3201        })
3202    }
3203}
3204
/// Stream adapter that filters the text chunks produced by an inner stream.
///
/// Its `Stream` implementation drops a leading line that starts with a code
/// fence (three backticks), drops a trailing fence line when the input started
/// with one, and removes `<|CURSOR|>` markers (see `poll_next` for the exact
/// rules).
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been consumed yet.
    buffer: String,
    /// True until a line after the first one has been seen.
    first_line: bool,
    /// True when a complete line was forwarded and its newline is still owed;
    /// the newline is emitted right before the next forwarded text.
    line_end: bool,
    /// Set when the first line of input started with a code fence.
    starts_with_code_block: bool,
}
3213
3214impl<T> StripInvalidSpans<T>
3215where
3216    T: Stream<Item = Result<String>>,
3217{
3218    fn new(stream: T) -> Self {
3219        Self {
3220            stream,
3221            stream_done: false,
3222            buffer: String::new(),
3223            first_line: true,
3224            line_end: false,
3225            starts_with_code_block: false,
3226        }
3227    }
3228}
3229
3230impl<T> Stream for StripInvalidSpans<T>
3231where
3232    T: Stream<Item = Result<String>>,
3233{
3234    type Item = Result<String>;
3235
3236    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3237        const CODE_BLOCK_DELIMITER: &str = "```";
3238        const CURSOR_SPAN: &str = "<|CURSOR|>";
3239
3240        let this = unsafe { self.get_unchecked_mut() };
3241        loop {
3242            if !this.stream_done {
3243                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3244                match stream.as_mut().poll_next(cx) {
3245                    Poll::Ready(Some(Ok(chunk))) => {
3246                        this.buffer.push_str(&chunk);
3247                    }
3248                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3249                    Poll::Ready(None) => {
3250                        this.stream_done = true;
3251                    }
3252                    Poll::Pending => return Poll::Pending,
3253                }
3254            }
3255
3256            let mut chunk = String::new();
3257            let mut consumed = 0;
3258            if !this.buffer.is_empty() {
3259                let mut lines = this.buffer.split('\n').enumerate().peekable();
3260                while let Some((line_ix, line)) = lines.next() {
3261                    if line_ix > 0 {
3262                        this.first_line = false;
3263                    }
3264
3265                    if this.first_line {
3266                        let trimmed_line = line.trim();
3267                        if lines.peek().is_some() {
3268                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3269                                consumed += line.len() + 1;
3270                                this.starts_with_code_block = true;
3271                                continue;
3272                            }
3273                        } else if trimmed_line.is_empty()
3274                            || prefixes(CODE_BLOCK_DELIMITER)
3275                                .any(|prefix| trimmed_line.starts_with(prefix))
3276                        {
3277                            break;
3278                        }
3279                    }
3280
3281                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
3282                    if lines.peek().is_some() {
3283                        if this.line_end {
3284                            chunk.push('\n');
3285                        }
3286
3287                        chunk.push_str(&line_without_cursor);
3288                        this.line_end = true;
3289                        consumed += line.len() + 1;
3290                    } else if this.stream_done {
3291                        if !this.starts_with_code_block
3292                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3293                        {
3294                            if this.line_end {
3295                                chunk.push('\n');
3296                            }
3297
3298                            chunk.push_str(&line);
3299                        }
3300
3301                        consumed += line.len();
3302                    } else {
3303                        let trimmed_line = line.trim();
3304                        if trimmed_line.is_empty()
3305                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3306                            || prefixes(CODE_BLOCK_DELIMITER)
3307                                .any(|prefix| trimmed_line.ends_with(prefix))
3308                        {
3309                            break;
3310                        } else {
3311                            if this.line_end {
3312                                chunk.push('\n');
3313                                this.line_end = false;
3314                            }
3315
3316                            chunk.push_str(&line_without_cursor);
3317                            consumed += line.len();
3318                        }
3319                    }
3320                }
3321            }
3322
3323            this.buffer = this.buffer.split_off(consumed);
3324            if !chunk.is_empty() {
3325                return Poll::Ready(Some(Ok(chunk)));
3326            } else if this.stream_done {
3327                return Poll::Ready(None);
3328            }
3329        }
3330    }
3331}
3332
/// Code action provider that offers a "Fix with Assistant" action for ranges
/// containing diagnostics and, when applied, starts an inline assist.
struct AssistantCodeActionProvider {
    /// Editor the inline assist is started in when an action is applied.
    editor: WeakView<Editor>,
    /// Workspace used to look up the assistant panel when an action is applied.
    workspace: WeakView<Workspace>,
}
3337
impl CodeActionProvider for AssistantCodeActionProvider {
    /// Returns a single "Fix with Assistant" action when the line-expanded
    /// `range` contains at least one diagnostic, and no actions otherwise.
    ///
    /// The action's range covers the expanded input range, every diagnostic
    /// found in it, and the last symbol reported as containing each end of
    /// that range.
    fn code_actions(
        &self,
        buffer: &Model<Buffer>,
        range: Range<text::Anchor>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<CodeAction>>> {
        let snapshot = buffer.read(cx).snapshot();
        let mut range = range.to_point(&snapshot);

        // Expand the range to line boundaries.
        range.start.column = 0;
        range.end.column = snapshot.line_len(range.end.row);

        // Grow the range so that it fully covers every diagnostic it touches.
        let mut has_diagnostics = false;
        for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
            range.start = cmp::min(range.start, diagnostic.range.start);
            range.end = cmp::max(range.end, diagnostic.range.end);
            has_diagnostics = true;
        }
        if has_diagnostics {
            // Also include the last symbol containing the start of the range...
            if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
                if let Some(symbol) = symbols_containing_start.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            // ...and the last symbol containing the (possibly updated) end.
            if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
                if let Some(symbol) = symbols_containing_end.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            Task::ready(Ok(vec![CodeAction {
                // NOTE(review): presumably a placeholder id, since this action does
                // not originate from a language server. Confirm.
                server_id: language::LanguageServerId(0),
                range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
                lsp_action: lsp::CodeAction {
                    title: "Fix with Assistant".into(),
                    ..Default::default()
                },
            }]))
        } else {
            Task::ready(Ok(Vec::new()))
        }
    }

    /// Starts a "Fix Diagnostics" inline assist over the action's range.
    ///
    /// If the action's range reaches beyond the excerpt's current context
    /// range, the excerpt is resized to include it first. The task resolves
    /// to a default `ProjectTransaction`, and fails if the editor was
    /// released, the range cannot be mapped into the excerpt, or the
    /// assistant panel is unavailable.
    fn apply_code_action(
        &self,
        buffer: Model<Buffer>,
        action: CodeAction,
        excerpt_id: ExcerptId,
        _push_to_history: bool,
        cx: &mut WindowContext,
    ) -> Task<Result<ProjectTransaction>> {
        let editor = self.editor.clone();
        let workspace = self.workspace.clone();
        cx.spawn(|mut cx| async move {
            let editor = editor.upgrade().context("editor was released")?;
            let range = editor
                .update(&mut cx, |editor, cx| {
                    editor.buffer().update(cx, |multibuffer, cx| {
                        let buffer = buffer.read(cx);
                        let multibuffer_snapshot = multibuffer.read(cx);

                        // Widen the excerpt's context range so it contains the action's range.
                        let old_context_range =
                            multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
                        let mut new_context_range = old_context_range.clone();
                        if action
                            .range
                            .start
                            .cmp(&old_context_range.start, buffer)
                            .is_lt()
                        {
                            new_context_range.start = action.range.start;
                        }
                        if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
                            new_context_range.end = action.range.end;
                        }
                        // The snapshot is released before the multibuffer is mutated below.
                        drop(multibuffer_snapshot);

                        if new_context_range != old_context_range {
                            multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
                        }

                        // Read a fresh snapshot to map the action's range into the excerpt.
                        let multibuffer_snapshot = multibuffer.read(cx);
                        Some(
                            multibuffer_snapshot
                                .anchor_in_excerpt(excerpt_id, action.range.start)?
                                ..multibuffer_snapshot
                                    .anchor_in_excerpt(excerpt_id, action.range.end)?,
                        )
                    })
                })?
                .context("invalid range")?;
            let assistant_panel = workspace.update(&mut cx, |workspace, cx| {
                workspace
                    .panel::<AssistantPanel>(cx)
                    .context("assistant panel was released")
            })??;

            // Register the assist with the global inline assistant and start it right away.
            cx.update_global(|assistant: &mut InlineAssistant, cx| {
                let assist_id = assistant.suggest_assist(
                    &editor,
                    range,
                    "Fix Diagnostics".into(),
                    None,
                    true,
                    Some(workspace),
                    Some(&assistant_panel),
                    cx,
                );
                assistant.start_assist(assist_id, cx);
            })?;

            Ok(ProjectTransaction::default())
        })
    }
}
3458
/// Yields every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is never yielded, so inputs shorter than two
/// characters produce nothing. Prefixes always end on a character boundary,
/// so empty and non-ASCII inputs are handled without panicking.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Every character after the first starts where one proper prefix ends.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3462
3463fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3464    ranges.sort_unstable_by(|a, b| {
3465        a.start
3466            .cmp(&b.start, buffer)
3467            .then_with(|| b.end.cmp(&a.end, buffer))
3468    });
3469
3470    let mut ix = 0;
3471    while ix + 1 < ranges.len() {
3472        let b = ranges[ix + 1].clone();
3473        let a = &mut ranges[ix];
3474        if a.end.cmp(&b.start, buffer).is_gt() {
3475            if a.end.cmp(&b.end, buffer).is_lt() {
3476                a.end = b.end;
3477            }
3478            ranges.remove(ix + 1);
3479        } else {
3480            ix += 1;
3481        }
3482    }
3483}
3484
3485#[cfg(test)]
3486mod tests {
3487    use super::*;
3488    use futures::stream::{self};
3489    use gpui::{Context, TestAppContext};
3490    use indoc::indoc;
3491    use language::{
3492        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
3493        Point,
3494    };
3495    use language_model::LanguageModelRegistry;
3496    use rand::prelude::*;
3497    use serde::Serialize;
3498    use settings::SettingsStore;
3499    use std::{future, sync::Arc};
3500
    // Serializable placeholder request carrying only a name.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
3505
    // Replaces rows 1 through 4 of a Rust buffer (the body of `main`) with
    // streamed text whose indentation differs from the buffer's, and checks
    // that the final buffer text is indented like the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Selection from the start of row 1 to column 5 of row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        // Feed the codegen from a channel so the test controls chunk boundaries.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the text in random pieces of 1 to 10 bytes, letting the executor
        // settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3577
    // Starts with an empty selection right after the partial token `le` on an
    // already indented line, streams unindented text that completes it, and
    // checks that the following lines end up indented like the first one.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range at row 1, column 6, i.e. directly after `le`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        // Feed the codegen from a channel so the test controls chunk boundaries.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Send the text in random pieces of 1 to 10 bytes, letting the executor
        // settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3650
    // Starts with an empty selection at the end of a line holding only two
    // spaces, streams unindented text, and checks that the result is indented
    // as expected for the body of `main`.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range at row 1, column 2, i.e. after the two spaces.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        // Feed the codegen from a channel so the test controls chunk boundaries.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Send the text in random pieces of 1 to 10 bytes, letting the executor
        // settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3724
    // Replaces a tab-indented selection in a buffer without a language with
    // tab-indented text delivered as a single chunk, and checks that the tabs
    // are kept as they are.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Selection from the start of the buffer to column 2 of row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        // The whole replacement is sent at once, then the stream is closed.
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }
3789
3790    #[gpui::test]
3791    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
3792        cx.update(LanguageModelRegistry::test);
3793        cx.set_global(cx.update(SettingsStore::test));
3794        cx.update(language_settings::init);
3795
3796        let text = indoc! {"
3797            fn main() {
3798                let x = 0;
3799            }
3800        "};
3801        let buffer =
3802            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3803        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3804        let range = buffer.read_with(cx, |buffer, cx| {
3805            let snapshot = buffer.snapshot(cx);
3806            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
3807        });
3808        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3809        let codegen = cx.new_model(|cx| {
3810            CodegenAlternative::new(
3811                buffer.clone(),
3812                range.clone(),
3813                false,
3814                None,
3815                prompt_builder,
3816                cx,
3817            )
3818        });
3819
3820        let (chunks_tx, chunks_rx) = mpsc::unbounded();
3821        codegen.update(cx, |codegen, cx| {
3822            codegen.handle_stream(
3823                String::new(),
3824                future::ready(Ok(chunks_rx.map(Ok).boxed())),
3825                cx,
3826            )
3827        });
3828
3829        chunks_tx
3830            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
3831            .unwrap();
3832        drop(chunks_tx);
3833        cx.run_until_parked();
3834
3835        // The codegen is inactive, so the buffer doesn't get modified.
3836        assert_eq!(
3837            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3838            text
3839        );
3840
3841        // Activating the codegen applies the changes.
3842        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
3843        assert_eq!(
3844            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3845            indoc! {"
3846                fn main() {
3847                    let mut x = 0;
3848                    x += 1;
3849                }
3850            "}
3851        );
3852
3853        // Deactivating the codegen undoes the changes.
3854        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
3855        cx.run_until_parked();
3856        assert_eq!(
3857            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3858            text
3859        );
3860    }
3861
3862    #[gpui::test]
3863    async fn test_strip_invalid_spans_from_codeblock() {
3864        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3865        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3866        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3867        assert_chunks(
3868            "```html\n```js\nLorem ipsum dolor\n```\n```",
3869            "```js\nLorem ipsum dolor\n```",
3870        )
3871        .await;
3872        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3873        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3874        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3875        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3876
3877        async fn assert_chunks(text: &str, expected_text: &str) {
3878            for chunk_size in 1..=text.len() {
3879                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3880                    .map(|chunk| chunk.unwrap())
3881                    .collect::<String>()
3882                    .await;
3883                assert_eq!(
3884                    actual_text, expected_text,
3885                    "failed to strip invalid spans, chunk size: {}",
3886                    chunk_size
3887                );
3888            }
3889        }
3890
3891        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3892            stream::iter(
3893                text.chars()
3894                    .collect::<Vec<_>>()
3895                    .chunks(size)
3896                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
3897                    .collect::<Vec<_>>(),
3898            )
3899        }
3900    }
3901
3902    fn rust_lang() -> Language {
3903        Language::new(
3904            LanguageConfig {
3905                name: "Rust".into(),
3906                matcher: LanguageMatcher {
3907                    path_suffixes: vec!["rs".to_string()],
3908                    ..Default::default()
3909                },
3910                ..Default::default()
3911            },
3912            Some(tree_sitter_rust::LANGUAGE.into()),
3913        )
3914        .with_indents_query(
3915            r#"
3916            (call_expression) @indent
3917            (field_expression) @indent
3918            (_ "(" ")" @end) @indent
3919            (_ "{" "}" @end) @indent
3920            "#,
3921        )
3922        .unwrap()
3923    }
3924}