inline_assistant.rs

   1use crate::{
   2    assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
   3    AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
   4    CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, RequestType, StreamingDiff,
   5};
   6use anyhow::{anyhow, Context as _, Result};
   7use client::{telemetry::Telemetry, ErrorExt};
   8use collections::{hash_map, HashMap, HashSet, VecDeque};
   9use editor::{
  10    actions::{MoveDown, MoveUp, SelectAll},
  11    display_map::{
  12        BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
  13        ToDisplayPoint,
  14    },
  15    Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
  16    EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
  17    ToOffset as _, ToPoint,
  18};
  19use feature_flags::{FeatureFlagAppExt as _, ZedPro};
  20use fs::Fs;
  21use futures::{
  22    channel::mpsc,
  23    future::{BoxFuture, LocalBoxFuture},
  24    join, SinkExt, Stream, StreamExt,
  25};
  26use gpui::{
  27    anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
  28    FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
  29    TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
  30};
  31use language::{Buffer, IndentKind, Point, Selection, TransactionId};
  32use language_model::{
  33    logging::report_assistant_event, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
  34    LanguageModelRequestMessage, LanguageModelTextStream, Role,
  35};
  36use multi_buffer::MultiBufferRow;
  37use parking_lot::Mutex;
  38use project::{CodeAction, ProjectTransaction};
  39use rope::Rope;
  40use settings::{Settings, SettingsStore};
  41use smol::future::FutureExt;
  42use std::{
  43    cmp,
  44    future::{self, Future},
  45    iter, mem,
  46    ops::{Range, RangeInclusive},
  47    pin::Pin,
  48    sync::Arc,
  49    task::{self, Poll},
  50    time::{Duration, Instant},
  51};
  52use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
  53use terminal_view::terminal_panel::TerminalPanel;
  54use text::{OffsetRangeExt, ToPoint as _};
  55use theme::ThemeSettings;
  56use ui::{prelude::*, text_for_action, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
  57use util::{RangeExt, ResultExt};
  58use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
  59
  60pub fn init(
  61    fs: Arc<dyn Fs>,
  62    prompt_builder: Arc<PromptBuilder>,
  63    telemetry: Arc<Telemetry>,
  64    cx: &mut AppContext,
  65) {
  66    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
  67    cx.observe_new_views(|_, cx| {
  68        let workspace = cx.view().clone();
  69        InlineAssistant::update_global(cx, |inline_assistant, cx| {
  70            inline_assistant.register_workspace(&workspace, cx)
  71        })
  72    })
  73    .detach();
  74}
  75
/// Length limit for the prompt history.
/// NOTE(review): not referenced in this part of the file; presumably caps
/// `InlineAssistant::prompt_history` — confirm at its use site.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
  77
/// Global bookkeeping for inline assists. Every live assist is indexed by its id, by the
/// editor it belongs to, and by the group it was created in.
pub struct InlineAssistant {
    /// Id given to the next assist that is created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id given to the next assist group that is created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// All live assists, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor state (the editor's assist ids, its scroll lock, ...), keyed by a weak
    /// editor handle.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists created by the same request (e.g. one `assist` call spanning several
    /// selections) share a group.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// NOTE(review): not touched in this part of the file; presumably the codegen of
    /// assists that were confirmed — verify at its use sites.
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Prompt history; a copy is handed to every new `PromptEditor`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    telemetry: Arc<Telemetry>,
    fs: Arc<dyn Fs>,
}
  90
// Being a GPUI global lets `init` install the assistant with `cx.set_global` and lets
// callers reach it through `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
  92
  93impl InlineAssistant {
  94    pub fn new(
  95        fs: Arc<dyn Fs>,
  96        prompt_builder: Arc<PromptBuilder>,
  97        telemetry: Arc<Telemetry>,
  98    ) -> Self {
  99        Self {
 100            next_assist_id: InlineAssistId::default(),
 101            next_assist_group_id: InlineAssistGroupId::default(),
 102            assists: HashMap::default(),
 103            assists_by_editor: HashMap::default(),
 104            assist_groups: HashMap::default(),
 105            confirmed_assists: HashMap::default(),
 106            prompt_history: VecDeque::default(),
 107            prompt_builder,
 108            telemetry,
 109            fs,
 110        }
 111    }
 112
 113    pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
 114        cx.subscribe(workspace, |workspace, event, cx| {
 115            Self::update_global(cx, |this, cx| {
 116                this.handle_workspace_event(workspace, event, cx)
 117            });
 118        })
 119        .detach();
 120
 121        let workspace = workspace.downgrade();
 122        cx.observe_global::<SettingsStore>(move |cx| {
 123            let Some(workspace) = workspace.upgrade() else {
 124                return;
 125            };
 126            let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
 127                return;
 128            };
 129            let enabled = AssistantSettings::get_global(cx).enabled;
 130            terminal_panel.update(cx, |terminal_panel, cx| {
 131                terminal_panel.asssistant_enabled(enabled, cx)
 132            });
 133        })
 134        .detach();
 135    }
 136
 137    fn handle_workspace_event(
 138        &mut self,
 139        workspace: View<Workspace>,
 140        event: &workspace::Event,
 141        cx: &mut WindowContext,
 142    ) {
 143        match event {
 144            workspace::Event::UserSavedItem { item, .. } => {
 145                // When the user manually saves an editor, automatically accepts all finished transformations.
 146                if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
 147                    if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
 148                        for assist_id in editor_assists.assist_ids.clone() {
 149                            let assist = &self.assists[&assist_id];
 150                            if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
 151                                self.finish_assist(assist_id, false, cx)
 152                            }
 153                        }
 154                    }
 155                }
 156            }
 157            workspace::Event::ItemAdded { item } => {
 158                self.register_workspace_item(&workspace, item.as_ref(), cx);
 159            }
 160            _ => (),
 161        }
 162    }
 163
 164    fn register_workspace_item(
 165        &mut self,
 166        workspace: &View<Workspace>,
 167        item: &dyn ItemHandle,
 168        cx: &mut WindowContext,
 169    ) {
 170        if let Some(editor) = item.act_as::<Editor>(cx) {
 171            editor.update(cx, |editor, cx| {
 172                editor.push_code_action_provider(
 173                    Arc::new(AssistantCodeActionProvider {
 174                        editor: cx.view().downgrade(),
 175                        workspace: workspace.downgrade(),
 176                    }),
 177                    cx,
 178                );
 179            });
 180        }
 181    }
 182
 183    pub fn assist(
 184        &mut self,
 185        editor: &View<Editor>,
 186        workspace: Option<WeakView<Workspace>>,
 187        assistant_panel: Option<&View<AssistantPanel>>,
 188        initial_prompt: Option<String>,
 189        cx: &mut WindowContext,
 190    ) {
 191        let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
 192            (
 193                editor.buffer().read(cx).snapshot(cx),
 194                editor.selections.all::<Point>(cx),
 195            )
 196        });
 197
 198        let mut selections = Vec::<Selection<Point>>::new();
 199        let mut newest_selection = None;
 200        for mut selection in initial_selections {
 201            if selection.end > selection.start {
 202                selection.start.column = 0;
 203                // If the selection ends at the start of the line, we don't want to include it.
 204                if selection.end.column == 0 {
 205                    selection.end.row -= 1;
 206                }
 207                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
 208            }
 209
 210            if let Some(prev_selection) = selections.last_mut() {
 211                if selection.start <= prev_selection.end {
 212                    prev_selection.end = selection.end;
 213                    continue;
 214                }
 215            }
 216
 217            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
 218            if selection.id > latest_selection.id {
 219                *latest_selection = selection.clone();
 220            }
 221            selections.push(selection);
 222        }
 223        let newest_selection = newest_selection.unwrap();
 224
 225        let mut codegen_ranges = Vec::new();
 226        for (excerpt_id, buffer, buffer_range) in
 227            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
 228                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
 229            }))
 230        {
 231            let start = Anchor {
 232                buffer_id: Some(buffer.remote_id()),
 233                excerpt_id,
 234                text_anchor: buffer.anchor_before(buffer_range.start),
 235            };
 236            let end = Anchor {
 237                buffer_id: Some(buffer.remote_id()),
 238                excerpt_id,
 239                text_anchor: buffer.anchor_after(buffer_range.end),
 240            };
 241            codegen_ranges.push(start..end);
 242
 243            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
 244                self.telemetry.report_assistant_event(AssistantEvent {
 245                    conversation_id: None,
 246                    kind: AssistantKind::Inline,
 247                    phase: AssistantPhase::Invoked,
 248                    message_id: None,
 249                    model: model.telemetry_id(),
 250                    model_provider: model.provider_id().to_string(),
 251                    response_latency: None,
 252                    error_message: None,
 253                    language_name: buffer.language().map(|language| language.name().to_proto()),
 254                });
 255            }
 256        }
 257
 258        let assist_group_id = self.next_assist_group_id.post_inc();
 259        let prompt_buffer =
 260            cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
 261        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 262
 263        let mut assists = Vec::new();
 264        let mut assist_to_focus = None;
 265        for range in codegen_ranges {
 266            let assist_id = self.next_assist_id.post_inc();
 267            let codegen = cx.new_model(|cx| {
 268                Codegen::new(
 269                    editor.read(cx).buffer().clone(),
 270                    range.clone(),
 271                    None,
 272                    self.telemetry.clone(),
 273                    self.prompt_builder.clone(),
 274                    cx,
 275                )
 276            });
 277
 278            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 279            let prompt_editor = cx.new_view(|cx| {
 280                PromptEditor::new(
 281                    assist_id,
 282                    gutter_dimensions.clone(),
 283                    self.prompt_history.clone(),
 284                    prompt_buffer.clone(),
 285                    codegen.clone(),
 286                    editor,
 287                    assistant_panel,
 288                    workspace.clone(),
 289                    self.fs.clone(),
 290                    cx,
 291                )
 292            });
 293
 294            if assist_to_focus.is_none() {
 295                let focus_assist = if newest_selection.reversed {
 296                    range.start.to_point(&snapshot) == newest_selection.start
 297                } else {
 298                    range.end.to_point(&snapshot) == newest_selection.end
 299                };
 300                if focus_assist {
 301                    assist_to_focus = Some(assist_id);
 302                }
 303            }
 304
 305            let [prompt_block_id, end_block_id] =
 306                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 307
 308            assists.push((
 309                assist_id,
 310                range,
 311                prompt_editor,
 312                prompt_block_id,
 313                end_block_id,
 314            ));
 315        }
 316
 317        let editor_assists = self
 318            .assists_by_editor
 319            .entry(editor.downgrade())
 320            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 321        let mut assist_group = InlineAssistGroup::new();
 322        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
 323            self.assists.insert(
 324                assist_id,
 325                InlineAssist::new(
 326                    assist_id,
 327                    assist_group_id,
 328                    assistant_panel.is_some(),
 329                    editor,
 330                    &prompt_editor,
 331                    prompt_block_id,
 332                    end_block_id,
 333                    range,
 334                    prompt_editor.read(cx).codegen.clone(),
 335                    workspace.clone(),
 336                    cx,
 337                ),
 338            );
 339            assist_group.assist_ids.push(assist_id);
 340            editor_assists.assist_ids.push(assist_id);
 341        }
 342        self.assist_groups.insert(assist_group_id, assist_group);
 343
 344        if let Some(assist_id) = assist_to_focus {
 345            self.focus_assist(assist_id, cx);
 346        }
 347    }
 348
 349    #[allow(clippy::too_many_arguments)]
 350    pub fn suggest_assist(
 351        &mut self,
 352        editor: &View<Editor>,
 353        mut range: Range<Anchor>,
 354        initial_prompt: String,
 355        initial_transaction_id: Option<TransactionId>,
 356        focus: bool,
 357        workspace: Option<WeakView<Workspace>>,
 358        assistant_panel: Option<&View<AssistantPanel>>,
 359        cx: &mut WindowContext,
 360    ) -> InlineAssistId {
 361        let assist_group_id = self.next_assist_group_id.post_inc();
 362        let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
 363        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 364
 365        let assist_id = self.next_assist_id.post_inc();
 366
 367        let buffer = editor.read(cx).buffer().clone();
 368        {
 369            let snapshot = buffer.read(cx).read(cx);
 370            range.start = range.start.bias_left(&snapshot);
 371            range.end = range.end.bias_right(&snapshot);
 372        }
 373
 374        let codegen = cx.new_model(|cx| {
 375            Codegen::new(
 376                editor.read(cx).buffer().clone(),
 377                range.clone(),
 378                initial_transaction_id,
 379                self.telemetry.clone(),
 380                self.prompt_builder.clone(),
 381                cx,
 382            )
 383        });
 384
 385        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 386        let prompt_editor = cx.new_view(|cx| {
 387            PromptEditor::new(
 388                assist_id,
 389                gutter_dimensions.clone(),
 390                self.prompt_history.clone(),
 391                prompt_buffer.clone(),
 392                codegen.clone(),
 393                editor,
 394                assistant_panel,
 395                workspace.clone(),
 396                self.fs.clone(),
 397                cx,
 398            )
 399        });
 400
 401        let [prompt_block_id, end_block_id] =
 402            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 403
 404        let editor_assists = self
 405            .assists_by_editor
 406            .entry(editor.downgrade())
 407            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 408
 409        let mut assist_group = InlineAssistGroup::new();
 410        self.assists.insert(
 411            assist_id,
 412            InlineAssist::new(
 413                assist_id,
 414                assist_group_id,
 415                assistant_panel.is_some(),
 416                editor,
 417                &prompt_editor,
 418                prompt_block_id,
 419                end_block_id,
 420                range,
 421                prompt_editor.read(cx).codegen.clone(),
 422                workspace.clone(),
 423                cx,
 424            ),
 425        );
 426        assist_group.assist_ids.push(assist_id);
 427        editor_assists.assist_ids.push(assist_id);
 428        self.assist_groups.insert(assist_group_id, assist_group);
 429
 430        if focus {
 431            self.focus_assist(assist_id, cx);
 432        }
 433
 434        assist_id
 435    }
 436
 437    fn insert_assist_blocks(
 438        &self,
 439        editor: &View<Editor>,
 440        range: &Range<Anchor>,
 441        prompt_editor: &View<PromptEditor>,
 442        cx: &mut WindowContext,
 443    ) -> [CustomBlockId; 2] {
 444        let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
 445            prompt_editor
 446                .editor
 447                .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
 448        });
 449        let assist_blocks = vec![
 450            BlockProperties {
 451                style: BlockStyle::Sticky,
 452                placement: BlockPlacement::Above(range.start),
 453                height: prompt_editor_height,
 454                render: build_assist_editor_renderer(prompt_editor),
 455                priority: 0,
 456            },
 457            BlockProperties {
 458                style: BlockStyle::Sticky,
 459                placement: BlockPlacement::Below(range.end),
 460                height: 0,
 461                render: Box::new(|cx| {
 462                    v_flex()
 463                        .h_full()
 464                        .w_full()
 465                        .border_t_1()
 466                        .border_color(cx.theme().status().info_border)
 467                        .into_any_element()
 468                }),
 469                priority: 0,
 470            },
 471        ];
 472
 473        editor.update(cx, |editor, cx| {
 474            let block_ids = editor.insert_blocks(assist_blocks, None, cx);
 475            [block_ids[0], block_ids[1]]
 476        })
 477    }
 478
    /// Called when the prompt editor of `assist_id` gains focus.
    ///
    /// Does nothing if the assist has no decorations. Otherwise it:
    /// - marks the assist as its group's active assist;
    /// - if the group is still linked, has every prompt editor in the group keep showing its
    ///   cursor while unfocused;
    /// - records a scroll lock (the prompt row's distance from the viewport top) when the
    ///   prompt block is currently visible, and clears any existing lock when it is not.
    ///   `handle_editor_change` uses the lock to keep the prompt at that distance.
    ///
    /// Panics (via `unwrap`) if the assist's group or editor entry is missing, or if the
    /// prompt block has no row.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        // These mutable borrows are of different fields than `self.assists`, so they
        // coexist with `assist`.
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        assist
            .editor
            .update(cx, |editor, cx| {
                // Visible row span of the editor, in (fractional) display rows.
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            // The editor is weakly held; if it is gone there is nothing to lock.
            .ok();
    }
 519
 520    fn handle_prompt_editor_focus_out(
 521        &mut self,
 522        assist_id: InlineAssistId,
 523        cx: &mut WindowContext,
 524    ) {
 525        let assist = &self.assists[&assist_id];
 526        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 527        if assist_group.active_assist_id == Some(assist_id) {
 528            assist_group.active_assist_id = None;
 529            if assist_group.linked {
 530                for assist_id in &assist_group.assist_ids {
 531                    if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 532                        decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 533                            prompt_editor.set_show_cursor_when_unfocused(false, cx)
 534                        });
 535                    }
 536                }
 537            }
 538        }
 539    }
 540
 541    fn handle_prompt_editor_event(
 542        &mut self,
 543        prompt_editor: View<PromptEditor>,
 544        event: &PromptEditorEvent,
 545        cx: &mut WindowContext,
 546    ) {
 547        let assist_id = prompt_editor.read(cx).id;
 548        match event {
 549            PromptEditorEvent::StartRequested => {
 550                self.start_assist(assist_id, cx);
 551            }
 552            PromptEditorEvent::StopRequested => {
 553                self.stop_assist(assist_id, cx);
 554            }
 555            PromptEditorEvent::ConfirmRequested => {
 556                self.finish_assist(assist_id, false, cx);
 557            }
 558            PromptEditorEvent::CancelRequested => {
 559                self.finish_assist(assist_id, true, cx);
 560            }
 561            PromptEditorEvent::DismissRequested => {
 562                self.dismiss_assist(assist_id, cx);
 563            }
 564        }
 565    }
 566
 567    fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 568        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 569            return;
 570        };
 571
 572        if editor.read(cx).selections.count() == 1 {
 573            let (selection, buffer) = editor.update(cx, |editor, cx| {
 574                (
 575                    editor.selections.newest::<usize>(cx),
 576                    editor.buffer().read(cx).snapshot(cx),
 577                )
 578            });
 579            for assist_id in &editor_assists.assist_ids {
 580                let assist = &self.assists[assist_id];
 581                let assist_range = assist.range.to_offset(&buffer);
 582                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
 583                {
 584                    if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
 585                        self.dismiss_assist(*assist_id, cx);
 586                    } else {
 587                        self.finish_assist(*assist_id, false, cx);
 588                    }
 589
 590                    return;
 591                }
 592            }
 593        }
 594
 595        cx.propagate();
 596    }
 597
 598    fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 599        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 600            return;
 601        };
 602
 603        if editor.read(cx).selections.count() == 1 {
 604            let (selection, buffer) = editor.update(cx, |editor, cx| {
 605                (
 606                    editor.selections.newest::<usize>(cx),
 607                    editor.buffer().read(cx).snapshot(cx),
 608                )
 609            });
 610            let mut closest_assist_fallback = None;
 611            for assist_id in &editor_assists.assist_ids {
 612                let assist = &self.assists[assist_id];
 613                let assist_range = assist.range.to_offset(&buffer);
 614                if assist.decorations.is_some() {
 615                    if assist_range.contains(&selection.start)
 616                        && assist_range.contains(&selection.end)
 617                    {
 618                        self.focus_assist(*assist_id, cx);
 619                        return;
 620                    } else {
 621                        let distance_from_selection = assist_range
 622                            .start
 623                            .abs_diff(selection.start)
 624                            .min(assist_range.start.abs_diff(selection.end))
 625                            + assist_range
 626                                .end
 627                                .abs_diff(selection.start)
 628                                .min(assist_range.end.abs_diff(selection.end));
 629                        match closest_assist_fallback {
 630                            Some((_, old_distance)) => {
 631                                if distance_from_selection < old_distance {
 632                                    closest_assist_fallback =
 633                                        Some((assist_id, distance_from_selection));
 634                                }
 635                            }
 636                            None => {
 637                                closest_assist_fallback = Some((assist_id, distance_from_selection))
 638                            }
 639                        }
 640                    }
 641                }
 642            }
 643
 644            if let Some((&assist_id, _)) = closest_assist_fallback {
 645                self.focus_assist(assist_id, cx);
 646            }
 647        }
 648
 649        cx.propagate();
 650    }
 651
 652    fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
 653        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
 654            for assist_id in editor_assists.assist_ids.clone() {
 655                self.finish_assist(assist_id, true, cx);
 656            }
 657        }
 658    }
 659
 660    fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 661        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 662            return;
 663        };
 664        let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
 665            return;
 666        };
 667        let assist = &self.assists[&scroll_lock.assist_id];
 668        let Some(decorations) = assist.decorations.as_ref() else {
 669            return;
 670        };
 671
 672        editor.update(cx, |editor, cx| {
 673            let scroll_position = editor.scroll_position(cx);
 674            let target_scroll_top = editor
 675                .row_for_block(decorations.prompt_block_id, cx)
 676                .unwrap()
 677                .0 as f32
 678                - scroll_lock.distance_from_top;
 679            if target_scroll_top != scroll_position.y {
 680                editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
 681            }
 682        });
 683    }
 684
    /// Reacts to events from an editor that hosts inline assists. Editors without assist
    /// state are ignored.
    ///
    /// - `Edited`: finishes (with `undo = false`) every assist whose codegen status is
    ///   `Error` or `Done` and whose range overlaps a range edited by that transaction.
    /// - `ScrollPositionChanged`: with an active scroll lock, releases the lock if the locked
    ///   prompt block is no longer at the recorded distance from the viewport top.
    /// - `SelectionsChanged`: releases the scroll lock unless one of this editor's prompt
    ///   editors currently has focus.
    fn handle_editor_event(
        &mut self,
        editor: View<Editor>,
        event: &EditorEvent,
        cx: &mut WindowContext,
    ) {
        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
            return;
        };

        match event {
            EditorEvent::Edited { transaction_id } => {
                let buffer = editor.read(cx).buffer().read(cx);
                let edited_ranges =
                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
                let snapshot = buffer.snapshot(cx);

                // Iterate over a copy: `finish_assist` removes ids from this editor's list.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if matches!(
                        assist.codegen.read(cx).status(cx),
                        CodegenStatus::Error(_) | CodegenStatus::Done
                    ) {
                        let assist_range = assist.range.to_offset(&snapshot);
                        if edited_ranges
                            .iter()
                            .any(|range| range.overlaps(&assist_range))
                        {
                            self.finish_assist(assist_id, false, cx);
                        }
                    }
                }
            }
            EditorEvent::ScrollPositionChanged { .. } => {
                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                    let assist = &self.assists[&scroll_lock.assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        let distance_from_top = editor.update(cx, |editor, cx| {
                            let scroll_top = editor.scroll_position(cx).y;
                            let prompt_row = editor
                                .row_for_block(decorations.prompt_block_id, cx)
                                .unwrap()
                                .0 as f32;
                            prompt_row - scroll_top
                        });

                        // The prompt moved relative to the viewport: stop enforcing the lock.
                        if distance_from_top != scroll_lock.distance_from_top {
                            editor_assists.scroll_lock = None;
                        }
                    }
                }
            }
            EditorEvent::SelectionsChanged { .. } => {
                // Keep the lock while any of this editor's prompt editors is focused.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                            return;
                        }
                    }
                }

                editor_assists.scroll_lock = None;
            }
            _ => {}
        }
    }
 752
    /// Finishes an inline assist, rejecting its changes when `undo` is `true` and
    /// accepting them otherwise.
    ///
    /// If the assist belongs to a group that is still linked, the group is unlinked
    /// and every member is finished instead. Otherwise:
    /// - the assist's decorations are dismissed;
    /// - it is removed from its group and from its editor's bookkeeping;
    /// - an accepted/rejected telemetry event is reported if a language model is active;
    /// - its codegen is undone (`undo == true`), or its active alternative is recorded
    ///   in `confirmed_assists`.
    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                // `unlink_assist_group` clears `linked`, so these recursive calls
                // fall through to the single-assist path below.
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            // Remove the assist from its group, dropping the group once it is empty.
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            // Remove the assist from its editor's list. If it was the last one, the
            // editor's entry (and with it the highlight-update channel) is dropped,
            // so highlights are refreshed directly. Otherwise a refresh is requested
            // through the channel.
            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            // The active alternative supplies the telemetry message id and is what
            // gets stored on confirmation.
            let active_alternative = assist.codegen.read(cx).active_alternative().clone();
            let message_id = active_alternative.read(cx).message_id.clone();

            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
                // Language of the first buffer covered by the assist range, if the
                // editor is still alive.
                let language_name = assist.editor.upgrade().and_then(|editor| {
                    let multibuffer = editor.read(cx).buffer().read(cx);
                    let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
                    ranges
                        .first()
                        .and_then(|(buffer, _, _)| buffer.read(cx).language())
                        .map(|language| language.name())
                });
                report_assistant_event(
                    AssistantEvent {
                        conversation_id: None,
                        kind: AssistantKind::Inline,
                        message_id,
                        phase: if undo {
                            AssistantPhase::Rejected
                        } else {
                            AssistantPhase::Accepted
                        },
                        model: model.telemetry_id(),
                        model_provider: model.provider_id().to_string(),
                        response_latency: None,
                        error_message: None,
                        language_name: language_name.map(|name| name.to_proto()),
                    },
                    Some(self.telemetry.clone()),
                    cx.http_client(),
                    model.api_key(cx),
                    cx.background_executor(),
                );
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, active_alternative);
            }
        }
    }
 831
    /// Removes an assist's editor decorations but keeps the assist itself
    /// registered. The decorations are its prompt block, end block and any
    /// removed-line blocks.
    ///
    /// Returns `false` without doing anything else if:
    /// - the assist is unknown;
    /// - its editor has been dropped; or
    /// - it has no decorations (for example because it was already dismissed).
    ///
    /// Returns `true` once the decorations have been removed.
    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return false;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return false;
        };
        // `take` leaves `None` behind. A later dismiss therefore returns `false`
        // here, and other code sees the assist as undecorated.
        let Some(decorations) = assist.decorations.take() else {
            return false;
        };

        editor.update(cx, |editor, cx| {
            let mut to_remove = decorations.removed_line_block_ids;
            to_remove.insert(decorations.prompt_block_id);
            to_remove.insert(decorations.end_block_id);
            editor.remove_blocks(to_remove, None, cx);
        });

        // If focus was inside the prompt being removed, hand it to the next
        // decorated assist in the group (or back to the editor).
        if decorations
            .prompt_editor
            .focus_handle(cx)
            .contains_focused(cx)
        {
            self.focus_next_assist(assist_id, cx);
        }

        // Release this editor's scroll lock if this assist held it, and request a
        // highlight refresh.
        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
            if editor_assists
                .scroll_lock
                .as_ref()
                .map_or(false, |lock| lock.assist_id == assist_id)
            {
                editor_assists.scroll_lock = None;
            }
            editor_assists.highlight_updates.send(()).ok();
        }

        true
    }
 871
 872    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 873        let Some(assist) = self.assists.get(&assist_id) else {
 874            return;
 875        };
 876
 877        let assist_group = &self.assist_groups[&assist.group_id];
 878        let assist_ix = assist_group
 879            .assist_ids
 880            .iter()
 881            .position(|id| *id == assist_id)
 882            .unwrap();
 883        let assist_ids = assist_group
 884            .assist_ids
 885            .iter()
 886            .skip(assist_ix + 1)
 887            .chain(assist_group.assist_ids.iter().take(assist_ix));
 888
 889        for assist_id in assist_ids {
 890            let assist = &self.assists[assist_id];
 891            if assist.decorations.is_some() {
 892                self.focus_assist(*assist_id, cx);
 893                return;
 894            }
 895        }
 896
 897        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
 898    }
 899
 900    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 901        let Some(assist) = self.assists.get(&assist_id) else {
 902            return;
 903        };
 904
 905        if let Some(decorations) = assist.decorations.as_ref() {
 906            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 907                prompt_editor.editor.update(cx, |editor, cx| {
 908                    editor.focus(cx);
 909                    editor.select_all(&SelectAll, cx);
 910                })
 911            });
 912        }
 913
 914        self.scroll_to_assist(assist_id, cx);
 915    }
 916
 917    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 918        let Some(assist) = self.assists.get(&assist_id) else {
 919            return;
 920        };
 921        let Some(editor) = assist.editor.upgrade() else {
 922            return;
 923        };
 924
 925        let position = assist.range.start;
 926        editor.update(cx, |editor, cx| {
 927            editor.change_selections(None, cx, |selections| {
 928                selections.select_anchor_ranges([position..position])
 929            });
 930
 931            let mut scroll_target_top;
 932            let mut scroll_target_bottom;
 933            if let Some(decorations) = assist.decorations.as_ref() {
 934                scroll_target_top = editor
 935                    .row_for_block(decorations.prompt_block_id, cx)
 936                    .unwrap()
 937                    .0 as f32;
 938                scroll_target_bottom = editor
 939                    .row_for_block(decorations.end_block_id, cx)
 940                    .unwrap()
 941                    .0 as f32;
 942            } else {
 943                let snapshot = editor.snapshot(cx);
 944                let start_row = assist
 945                    .range
 946                    .start
 947                    .to_display_point(&snapshot.display_snapshot)
 948                    .row();
 949                scroll_target_top = start_row.0 as f32;
 950                scroll_target_bottom = scroll_target_top + 1.;
 951            }
 952            scroll_target_top -= editor.vertical_scroll_margin() as f32;
 953            scroll_target_bottom += editor.vertical_scroll_margin() as f32;
 954
 955            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
 956            let scroll_top = editor.scroll_position(cx).y;
 957            let scroll_bottom = scroll_top + height_in_lines;
 958
 959            if scroll_target_top < scroll_top {
 960                editor.set_scroll_position(point(0., scroll_target_top), cx);
 961            } else if scroll_target_bottom > scroll_bottom {
 962                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
 963                    editor
 964                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
 965                } else {
 966                    editor.set_scroll_position(point(0., scroll_target_top), cx);
 967                }
 968            }
 969        });
 970    }
 971
 972    fn unlink_assist_group(
 973        &mut self,
 974        assist_group_id: InlineAssistGroupId,
 975        cx: &mut WindowContext,
 976    ) -> Vec<InlineAssistId> {
 977        let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
 978        assist_group.linked = false;
 979        for assist_id in &assist_group.assist_ids {
 980            let assist = self.assists.get_mut(assist_id).unwrap();
 981            if let Some(editor_decorations) = assist.decorations.as_ref() {
 982                editor_decorations
 983                    .prompt_editor
 984                    .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
 985            }
 986        }
 987        assist_group.assist_ids.clone()
 988    }
 989
 990    pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 991        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
 992            assist
 993        } else {
 994            return;
 995        };
 996
 997        let assist_group_id = assist.group_id;
 998        if self.assist_groups[&assist_group_id].linked {
 999            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1000                self.start_assist(assist_id, cx);
1001            }
1002            return;
1003        }
1004
1005        let Some(user_prompt) = assist.user_prompt(cx) else {
1006            return;
1007        };
1008
1009        self.prompt_history.retain(|prompt| *prompt != user_prompt);
1010        self.prompt_history.push_back(user_prompt.clone());
1011        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1012            self.prompt_history.pop_front();
1013        }
1014
1015        let assistant_panel_context = assist.assistant_panel_context(cx);
1016
1017        assist
1018            .codegen
1019            .update(cx, |codegen, cx| {
1020                codegen.start(user_prompt, assistant_panel_context, cx)
1021            })
1022            .log_err();
1023    }
1024
1025    pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1026        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1027            assist
1028        } else {
1029            return;
1030        };
1031
1032        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1033    }
1034
    /// Recomputes every inline-assist highlight for `editor` and applies it.
    ///
    /// For each assist registered on the editor, this gathers:
    /// - "pending" gutter ranges, from the codegen's edit position (or the assist
    ///   start if there is none) to the assist end;
    /// - "transformed" gutter ranges, from the assist start to the edit position;
    /// - text ranges to fade out, as reported by `last_equal_ranges`;
    /// - inserted-row ranges from the diff, only for assists that still have
    ///   decorations.
    ///
    /// Empty gutter ranges are skipped. The text and gutter ranges are passed
    /// through `merge_ranges`. Any highlight kind left with no ranges is cleared.
    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        // An editor without registered assists yields an empty list, which leads
        // to every highlight kind below being cleared.
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer(cx).read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());

                // From the edit position (or the assist start) to the assist end;
                // kept only when it covers at least one character.
                let pending_range =
                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                // From the assist start to the edit position; kept only when non-empty.
                if let Some(edit_position) = codegen.edit_position(cx) {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                // Inserted rows are highlighted only while the assist is decorated.
                if assist.decorations.is_some() {
                    inserted_row_ranges
                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            // Row highlights are always cleared and rebuilt from scratch.
            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    cx.theme().status().info_background,
                    false,
                    cx,
                );
            }
        });
    }
1125
    /// Rebuilds the blocks that display lines deleted by an assist's generated diff.
    ///
    /// Steps:
    /// 1. Remove any previously inserted deleted-line blocks.
    /// 2. For each deleted row range in the codegen diff, build a read-only editor
    ///    (no gutter, no soft wrap) over those rows of the codegen's old buffer,
    ///    highlighted with the "deleted" background.
    /// 3. Insert each one as a block above the corresponding new row.
    /// 4. Store the new block ids in `removed_line_block_ids`.
    ///
    /// Does nothing if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Map the deleted rows (in `old_snapshot` coordinates) to a buffer
                // offset range from the start of the first row to the end of the last.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only, non-scrollable editor showing just the deleted text.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        false,
                        cx,
                    );
                    editor
                });

                // The block is as many lines tall as the deleted-lines editor has rows.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                });
            }

            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1217}
1218
/// Per-editor bookkeeping for the inline assists attached to one editor.
struct EditorInlineAssists {
    /// Ids of the assists currently attached to the editor.
    assist_ids: Vec<InlineAssistId>,
    /// The assist whose prompt position is being tracked while scrolling, if any.
    /// Cleared by certain editor scroll/selection events and when that assist is
    /// dismissed.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` schedules a highlight refresh in `_update_highlights`.
    highlight_updates: async_watch::Sender<()>,
    /// Task that recomputes the editor's highlights on each `highlight_updates` signal.
    _update_highlights: Task<Result<()>>,
    /// Held for their effect: the editor observers, the event subscription and the
    /// action handlers registered in `EditorInlineAssists::new`.
    _subscriptions: Vec<gpui::Subscription>,
}
1226
/// Tracks where an assist's prompt block is expected to sit relative to the top
/// of the viewport.
struct InlineAssistScrollLock {
    /// The assist whose prompt block the lock follows.
    assist_id: InlineAssistId,
    /// Expected number of rows between the top of the viewport and the prompt
    /// block. The lock is released when the editor observes a different distance.
    distance_from_top: f32,
}
1231
1232impl EditorInlineAssists {
1233    #[allow(clippy::too_many_arguments)]
1234    fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1235        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1236        Self {
1237            assist_ids: Vec::new(),
1238            scroll_lock: None,
1239            highlight_updates: highlight_updates_tx,
1240            _update_highlights: cx.spawn(|mut cx| {
1241                let editor = editor.downgrade();
1242                async move {
1243                    while let Ok(()) = highlight_updates_rx.changed().await {
1244                        let editor = editor.upgrade().context("editor was dropped")?;
1245                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
1246                            assistant.update_editor_highlights(&editor, cx);
1247                        })?;
1248                    }
1249                    Ok(())
1250                }
1251            }),
1252            _subscriptions: vec![
1253                cx.observe_release(editor, {
1254                    let editor = editor.downgrade();
1255                    |_, cx| {
1256                        InlineAssistant::update_global(cx, |this, cx| {
1257                            this.handle_editor_release(editor, cx);
1258                        })
1259                    }
1260                }),
1261                cx.observe(editor, move |editor, cx| {
1262                    InlineAssistant::update_global(cx, |this, cx| {
1263                        this.handle_editor_change(editor, cx)
1264                    })
1265                }),
1266                cx.subscribe(editor, move |editor, event, cx| {
1267                    InlineAssistant::update_global(cx, |this, cx| {
1268                        this.handle_editor_event(editor, event, cx)
1269                    })
1270                }),
1271                editor.update(cx, |editor, cx| {
1272                    let editor_handle = cx.view().downgrade();
1273                    editor.register_action(
1274                        move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1275                            InlineAssistant::update_global(cx, |this, cx| {
1276                                if let Some(editor) = editor_handle.upgrade() {
1277                                    this.handle_editor_newline(editor, cx)
1278                                }
1279                            })
1280                        },
1281                    )
1282                }),
1283                editor.update(cx, |editor, cx| {
1284                    let editor_handle = cx.view().downgrade();
1285                    editor.register_action(
1286                        move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1287                            InlineAssistant::update_global(cx, |this, cx| {
1288                                if let Some(editor) = editor_handle.upgrade() {
1289                                    this.handle_editor_cancel(editor, cx)
1290                                }
1291                            })
1292                        },
1293                    )
1294                }),
1295            ],
1296        }
1297    }
1298}
1299
/// A set of inline assists that are started and finished together while linked.
struct InlineAssistGroup {
    /// Member assists; `focus_next_assist` cycles through them in this order.
    assist_ids: Vec<InlineAssistId>,
    /// While `true`, calling `start_assist` or `finish_assist` on any member unlinks
    /// the group and applies the call to every member. Cleared by
    /// `unlink_assist_group`.
    linked: bool,
    /// NOTE(review): the group's currently active member, if any. It is maintained
    /// outside this section; confirm its exact meaning there.
    active_assist_id: Option<InlineAssistId>,
}
1305
1306impl InlineAssistGroup {
1307    fn new() -> Self {
1308        Self {
1309            assist_ids: Vec::new(),
1310            linked: true,
1311            active_assist_id: None,
1312        }
1313    }
1314}
1315
1316fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1317    let editor = editor.clone();
1318    Box::new(move |cx: &mut BlockContext| {
1319        *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1320        editor.clone().into_any_element()
1321    })
1322}
1323
/// Identifier of an inline assist; fresh ids are handed out sequentially via
/// `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
1326
1327impl InlineAssistId {
1328    fn post_inc(&mut self) -> InlineAssistId {
1329        let id = *self;
1330        self.0 += 1;
1331        id
1332    }
1333}
1334
/// Identifier of an `InlineAssistGroup`; fresh ids are handed out sequentially via
/// `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);
1337
1338impl InlineAssistGroupId {
1339    fn post_inc(&mut self) -> InlineAssistGroupId {
1340        let id = *self;
1341        self.0 += 1;
1342        id
1343    }
1344}
1345
/// Requests a `PromptEditor` emits for its assist.
enum PromptEditorEvent {
    /// Start (or restart) generation, e.g. from the "start" or "restart" button.
    StartRequested,
    /// Interrupt an in-flight generation, e.g. from the "stop" button.
    StopRequested,
    /// Accept the generated result, e.g. from the "confirm" button.
    ConfirmRequested,
    /// Cancel the assist, e.g. from a "cancel" button.
    CancelRequested,
    /// Dismiss the assist. NOTE(review): the emitter is not visible in this section.
    DismissRequested,
}
1353
/// The prompt input rendered inside an assist's prompt block. It also renders the
/// surrounding controls: the model selector and the start/stop/confirm/cancel
/// buttons.
struct PromptEditor {
    /// The assist this prompt belongs to.
    id: InlineAssistId,
    /// Filesystem handle, passed to the `ModelSelector` when rendering.
    fs: Arc<dyn Fs>,
    /// The editor holding the prompt text.
    editor: View<Editor>,
    /// Once generation is done, `true` makes the "restart" button render instead
    /// of "confirm".
    edited_since_done: bool,
    /// Gutter dimensions of the host editor. The block renderer copies them in;
    /// `render` uses them to size the leading control column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts. NOTE(review): presumably browsed with the
    /// `move_up` / `move_down` actions; confirm.
    prompt_history: VecDeque<String>,
    /// NOTE(review): presumably the current position in `prompt_history` while
    /// browsing it.
    prompt_history_ix: Option<usize>,
    /// NOTE(review): presumably the unsent prompt text saved while browsing history.
    pending_prompt: String,
    /// The code generator for this assist; its status decides which buttons render.
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// NOTE(review): presumably subscriptions on `editor`; confirm.
    editor_subscriptions: Vec<Subscription>,
    /// NOTE(review): presumably the in-flight token-counting task.
    pending_token_count: Task<Result<()>>,
    /// Most recently computed token counts, if any.
    token_counts: Option<TokenCounts>,
    _token_count_subscriptions: Vec<Subscription>,
    /// NOTE(review): presumably the workspace hosting this prompt, if any.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit notice is expanded; toggled by the rate-limit error button.
    show_rate_limit_notice: bool,
}
1372
/// Token counts associated with a prompt.
/// NOTE(review): presumably `total` covers the whole request and `assistant_panel`
/// the part contributed by the assistant panel context; confirm where they are
/// computed.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    total: usize,
    assistant_panel: usize,
}
1378
/// Lets `PromptEditor` emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1380
1381impl Render for PromptEditor {
1382    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1383        let gutter_dimensions = *self.gutter_dimensions.lock();
1384        let codegen = self.codegen.read(cx);
1385
1386        let mut buttons = Vec::new();
1387        if codegen.alternative_count(cx) > 1 {
1388            buttons.push(self.render_cycle_controls(cx));
1389        }
1390
1391        let status = codegen.status(cx);
1392        buttons.extend(match status {
1393            CodegenStatus::Idle => {
1394                vec![
1395                    IconButton::new("cancel", IconName::Close)
1396                        .icon_color(Color::Muted)
1397                        .shape(IconButtonShape::Square)
1398                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1399                        .on_click(
1400                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1401                        )
1402                        .into_any_element(),
1403                    IconButton::new("start", IconName::SparkleAlt)
1404                        .icon_color(Color::Muted)
1405                        .shape(IconButtonShape::Square)
1406                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
1407                        .on_click(
1408                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
1409                        )
1410                        .into_any_element(),
1411                ]
1412            }
1413            CodegenStatus::Pending => {
1414                vec![
1415                    IconButton::new("cancel", IconName::Close)
1416                        .icon_color(Color::Muted)
1417                        .shape(IconButtonShape::Square)
1418                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
1419                        .on_click(
1420                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1421                        )
1422                        .into_any_element(),
1423                    IconButton::new("stop", IconName::Stop)
1424                        .icon_color(Color::Error)
1425                        .shape(IconButtonShape::Square)
1426                        .tooltip(|cx| {
1427                            Tooltip::with_meta(
1428                                "Interrupt Transformation",
1429                                Some(&menu::Cancel),
1430                                "Changes won't be discarded",
1431                                cx,
1432                            )
1433                        })
1434                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
1435                        .into_any_element(),
1436                ]
1437            }
1438            CodegenStatus::Error(_) | CodegenStatus::Done => {
1439                vec![
1440                    IconButton::new("cancel", IconName::Close)
1441                        .icon_color(Color::Muted)
1442                        .shape(IconButtonShape::Square)
1443                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1444                        .on_click(
1445                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1446                        )
1447                        .into_any_element(),
1448                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
1449                        IconButton::new("restart", IconName::RotateCw)
1450                            .icon_color(Color::Info)
1451                            .shape(IconButtonShape::Square)
1452                            .tooltip(|cx| {
1453                                Tooltip::with_meta(
1454                                    "Restart Transformation",
1455                                    Some(&menu::Confirm),
1456                                    "Changes will be discarded",
1457                                    cx,
1458                                )
1459                            })
1460                            .on_click(cx.listener(|_, _, cx| {
1461                                cx.emit(PromptEditorEvent::StartRequested);
1462                            }))
1463                            .into_any_element()
1464                    } else {
1465                        IconButton::new("confirm", IconName::Check)
1466                            .icon_color(Color::Info)
1467                            .shape(IconButtonShape::Square)
1468                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
1469                            .on_click(cx.listener(|_, _, cx| {
1470                                cx.emit(PromptEditorEvent::ConfirmRequested);
1471                            }))
1472                            .into_any_element()
1473                    },
1474                ]
1475            }
1476        });
1477
1478        h_flex()
1479            .key_context("PromptEditor")
1480            .bg(cx.theme().colors().editor_background)
1481            .border_y_1()
1482            .border_color(cx.theme().status().info_border)
1483            .size_full()
1484            .py(cx.line_height() / 2.5)
1485            .on_action(cx.listener(Self::confirm))
1486            .on_action(cx.listener(Self::cancel))
1487            .on_action(cx.listener(Self::move_up))
1488            .on_action(cx.listener(Self::move_down))
1489            .capture_action(cx.listener(Self::cycle_prev))
1490            .capture_action(cx.listener(Self::cycle_next))
1491            .child(
1492                h_flex()
1493                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1494                    .justify_center()
1495                    .gap_2()
1496                    .child(
1497                        ModelSelector::new(
1498                            self.fs.clone(),
1499                            IconButton::new("context", IconName::SettingsAlt)
1500                                .shape(IconButtonShape::Square)
1501                                .icon_size(IconSize::Small)
1502                                .icon_color(Color::Muted)
1503                                .tooltip(move |cx| {
1504                                    Tooltip::with_meta(
1505                                        format!(
1506                                            "Using {}",
1507                                            LanguageModelRegistry::read_global(cx)
1508                                                .active_model()
1509                                                .map(|model| model.name().0)
1510                                                .unwrap_or_else(|| "No model selected".into()),
1511                                        ),
1512                                        None,
1513                                        "Change Model",
1514                                        cx,
1515                                    )
1516                                }),
1517                        )
1518                        .with_info_text(
1519                            "Inline edits use context\n\
1520                            from the currently selected\n\
1521                            assistant panel tab.",
1522                        ),
1523                    )
1524                    .map(|el| {
1525                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
1526                            return el;
1527                        };
1528
1529                        let error_message = SharedString::from(error.to_string());
1530                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
1531                            && cx.has_flag::<ZedPro>()
1532                        {
1533                            el.child(
1534                                v_flex()
1535                                    .child(
1536                                        IconButton::new("rate-limit-error", IconName::XCircle)
1537                                            .selected(self.show_rate_limit_notice)
1538                                            .shape(IconButtonShape::Square)
1539                                            .icon_size(IconSize::Small)
1540                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1541                                    )
1542                                    .children(self.show_rate_limit_notice.then(|| {
1543                                        deferred(
1544                                            anchored()
1545                                                .position_mode(gpui::AnchoredPositionMode::Local)
1546                                                .position(point(px(0.), px(24.)))
1547                                                .anchor(gpui::AnchorCorner::TopLeft)
1548                                                .child(self.render_rate_limit_notice(cx)),
1549                                        )
1550                                    })),
1551                            )
1552                        } else {
1553                            el.child(
1554                                div()
1555                                    .id("error")
1556                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
1557                                    .child(
1558                                        Icon::new(IconName::XCircle)
1559                                            .size(IconSize::Small)
1560                                            .color(Color::Error),
1561                                    ),
1562                            )
1563                        }
1564                    }),
1565            )
1566            .child(div().flex_1().child(self.render_prompt_editor(cx)))
1567            .child(
1568                h_flex()
1569                    .gap_2()
1570                    .pr_6()
1571                    .children(self.render_token_count(cx))
1572                    .children(buttons),
1573            )
1574    }
1575}
1576
1577impl FocusableView for PromptEditor {
1578    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1579        self.editor.focus_handle(cx)
1580    }
1581}
1582
1583impl PromptEditor {
    /// Line cap passed as `max_lines` when creating the auto-height prompt
    /// editor (both in `new` and when the editor is recreated by `unlink`).
    const MAX_LINES: u8 = 8;
1585
1586    #[allow(clippy::too_many_arguments)]
1587    fn new(
1588        id: InlineAssistId,
1589        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1590        prompt_history: VecDeque<String>,
1591        prompt_buffer: Model<MultiBuffer>,
1592        codegen: Model<Codegen>,
1593        parent_editor: &View<Editor>,
1594        assistant_panel: Option<&View<AssistantPanel>>,
1595        workspace: Option<WeakView<Workspace>>,
1596        fs: Arc<dyn Fs>,
1597        cx: &mut ViewContext<Self>,
1598    ) -> Self {
1599        let prompt_editor = cx.new_view(|cx| {
1600            let mut editor = Editor::new(
1601                EditorMode::AutoHeight {
1602                    max_lines: Self::MAX_LINES as usize,
1603                },
1604                prompt_buffer,
1605                None,
1606                false,
1607                cx,
1608            );
1609            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1610            // Since the prompt editors for all inline assistants are linked,
1611            // always show the cursor (even when it isn't focused) because
1612            // typing in one will make what you typed appear in all of them.
1613            editor.set_show_cursor_when_unfocused(true, cx);
1614            editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), cx), cx);
1615            editor
1616        });
1617
1618        let mut token_count_subscriptions = Vec::new();
1619        token_count_subscriptions
1620            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1621        if let Some(assistant_panel) = assistant_panel {
1622            token_count_subscriptions
1623                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1624        }
1625
1626        let mut this = Self {
1627            id,
1628            editor: prompt_editor,
1629            edited_since_done: false,
1630            gutter_dimensions,
1631            prompt_history,
1632            prompt_history_ix: None,
1633            pending_prompt: String::new(),
1634            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1635            editor_subscriptions: Vec::new(),
1636            codegen,
1637            fs,
1638            pending_token_count: Task::ready(Ok(())),
1639            token_counts: None,
1640            _token_count_subscriptions: token_count_subscriptions,
1641            workspace,
1642            show_rate_limit_notice: false,
1643        };
1644        this.count_tokens(cx);
1645        this.subscribe_to_editor(cx);
1646        this
1647    }
1648
1649    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1650        self.editor_subscriptions.clear();
1651        self.editor_subscriptions
1652            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1653    }
1654
1655    fn set_show_cursor_when_unfocused(
1656        &mut self,
1657        show_cursor_when_unfocused: bool,
1658        cx: &mut ViewContext<Self>,
1659    ) {
1660        self.editor.update(cx, |editor, cx| {
1661            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1662        });
1663    }
1664
1665    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1666        let prompt = self.prompt(cx);
1667        let focus = self.editor.focus_handle(cx).contains_focused(cx);
1668        self.editor = cx.new_view(|cx| {
1669            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1670            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1671            editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx), cx), cx);
1672            editor.set_placeholder_text("Add a prompt…", cx);
1673            editor.set_text(prompt, cx);
1674            if focus {
1675                editor.focus(cx);
1676            }
1677            editor
1678        });
1679        self.subscribe_to_editor(cx);
1680    }
1681
1682    fn placeholder_text(codegen: &Codegen, cx: &WindowContext) -> String {
1683        let context_keybinding = text_for_action(&crate::ToggleFocus, cx)
1684            .map(|keybinding| format!("{keybinding} for context"))
1685            .unwrap_or_default();
1686
1687        let action = if codegen.is_insertion {
1688            "Generate"
1689        } else {
1690            "Transform"
1691        };
1692
1693        format!("{action}{context_keybinding} • ↓↑ for history")
1694    }
1695
1696    fn prompt(&self, cx: &AppContext) -> String {
1697        self.editor.read(cx).text(cx)
1698    }
1699
1700    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1701        self.show_rate_limit_notice = !self.show_rate_limit_notice;
1702        if self.show_rate_limit_notice {
1703            cx.focus_view(&self.editor);
1704        }
1705        cx.notify();
1706    }
1707
1708    fn handle_parent_editor_event(
1709        &mut self,
1710        _: View<Editor>,
1711        event: &EditorEvent,
1712        cx: &mut ViewContext<Self>,
1713    ) {
1714        if let EditorEvent::BufferEdited { .. } = event {
1715            self.count_tokens(cx);
1716        }
1717    }
1718
1719    fn handle_assistant_panel_event(
1720        &mut self,
1721        _: View<AssistantPanel>,
1722        event: &AssistantPanelEvent,
1723        cx: &mut ViewContext<Self>,
1724    ) {
1725        let AssistantPanelEvent::ContextEdited { .. } = event;
1726        self.count_tokens(cx);
1727    }
1728
1729    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1730        let assist_id = self.id;
1731        self.pending_token_count = cx.spawn(|this, mut cx| async move {
1732            cx.background_executor().timer(Duration::from_secs(1)).await;
1733            let token_count = cx
1734                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1735                    let assist = inline_assistant
1736                        .assists
1737                        .get(&assist_id)
1738                        .context("assist not found")?;
1739                    anyhow::Ok(assist.count_tokens(cx))
1740                })??
1741                .await?;
1742
1743            this.update(&mut cx, |this, cx| {
1744                this.token_counts = Some(token_count);
1745                cx.notify();
1746            })
1747        })
1748    }
1749
1750    fn handle_prompt_editor_events(
1751        &mut self,
1752        _: View<Editor>,
1753        event: &EditorEvent,
1754        cx: &mut ViewContext<Self>,
1755    ) {
1756        match event {
1757            EditorEvent::Edited { .. } => {
1758                if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1759                    workspace
1760                        .update(cx, |workspace, cx| {
1761                            let is_via_ssh = workspace
1762                                .project()
1763                                .update(cx, |project, _| project.is_via_ssh());
1764
1765                            workspace
1766                                .client()
1767                                .telemetry()
1768                                .log_edit_event("inline assist", is_via_ssh);
1769                        })
1770                        .log_err();
1771                }
1772                let prompt = self.editor.read(cx).text(cx);
1773                if self
1774                    .prompt_history_ix
1775                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
1776                {
1777                    self.prompt_history_ix.take();
1778                    self.pending_prompt = prompt;
1779                }
1780
1781                self.edited_since_done = true;
1782                cx.notify();
1783            }
1784            EditorEvent::BufferEdited => {
1785                self.count_tokens(cx);
1786            }
1787            EditorEvent::Blurred => {
1788                if self.show_rate_limit_notice {
1789                    self.show_rate_limit_notice = false;
1790                    cx.notify();
1791                }
1792            }
1793            _ => {}
1794        }
1795    }
1796
1797    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1798        match self.codegen.read(cx).status(cx) {
1799            CodegenStatus::Idle => {
1800                self.editor
1801                    .update(cx, |editor, _| editor.set_read_only(false));
1802            }
1803            CodegenStatus::Pending => {
1804                self.editor
1805                    .update(cx, |editor, _| editor.set_read_only(true));
1806            }
1807            CodegenStatus::Done => {
1808                self.edited_since_done = false;
1809                self.editor
1810                    .update(cx, |editor, _| editor.set_read_only(false));
1811            }
1812            CodegenStatus::Error(error) => {
1813                if cx.has_flag::<ZedPro>()
1814                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
1815                    && !dismissed_rate_limit_notice()
1816                {
1817                    self.show_rate_limit_notice = true;
1818                    cx.notify();
1819                }
1820
1821                self.edited_since_done = false;
1822                self.editor
1823                    .update(cx, |editor, _| editor.set_read_only(false));
1824            }
1825        }
1826    }
1827
1828    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1829        match self.codegen.read(cx).status(cx) {
1830            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1831                cx.emit(PromptEditorEvent::CancelRequested);
1832            }
1833            CodegenStatus::Pending => {
1834                cx.emit(PromptEditorEvent::StopRequested);
1835            }
1836        }
1837    }
1838
1839    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1840        match self.codegen.read(cx).status(cx) {
1841            CodegenStatus::Idle => {
1842                cx.emit(PromptEditorEvent::StartRequested);
1843            }
1844            CodegenStatus::Pending => {
1845                cx.emit(PromptEditorEvent::DismissRequested);
1846            }
1847            CodegenStatus::Done => {
1848                if self.edited_since_done {
1849                    cx.emit(PromptEditorEvent::StartRequested);
1850                } else {
1851                    cx.emit(PromptEditorEvent::ConfirmRequested);
1852                }
1853            }
1854            CodegenStatus::Error(_) => {
1855                cx.emit(PromptEditorEvent::StartRequested);
1856            }
1857        }
1858    }
1859
1860    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1861        if let Some(ix) = self.prompt_history_ix {
1862            if ix > 0 {
1863                self.prompt_history_ix = Some(ix - 1);
1864                let prompt = self.prompt_history[ix - 1].as_str();
1865                self.editor.update(cx, |editor, cx| {
1866                    editor.set_text(prompt, cx);
1867                    editor.move_to_beginning(&Default::default(), cx);
1868                });
1869            }
1870        } else if !self.prompt_history.is_empty() {
1871            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1872            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1873            self.editor.update(cx, |editor, cx| {
1874                editor.set_text(prompt, cx);
1875                editor.move_to_beginning(&Default::default(), cx);
1876            });
1877        }
1878    }
1879
1880    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1881        if let Some(ix) = self.prompt_history_ix {
1882            if ix < self.prompt_history.len() - 1 {
1883                self.prompt_history_ix = Some(ix + 1);
1884                let prompt = self.prompt_history[ix + 1].as_str();
1885                self.editor.update(cx, |editor, cx| {
1886                    editor.set_text(prompt, cx);
1887                    editor.move_to_end(&Default::default(), cx)
1888                });
1889            } else {
1890                self.prompt_history_ix = None;
1891                let prompt = self.pending_prompt.as_str();
1892                self.editor.update(cx, |editor, cx| {
1893                    editor.set_text(prompt, cx);
1894                    editor.move_to_end(&Default::default(), cx)
1895                });
1896            }
1897        }
1898    }
1899
1900    fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1901        self.codegen
1902            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1903    }
1904
1905    fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1906        self.codegen
1907            .update(cx, |codegen, cx| codegen.cycle_next(cx));
1908    }
1909
1910    fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1911        let codegen = self.codegen.read(cx);
1912        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1913
1914        h_flex()
1915            .child(
1916                IconButton::new("previous", IconName::ChevronLeft)
1917                    .icon_color(Color::Muted)
1918                    .disabled(disabled)
1919                    .shape(IconButtonShape::Square)
1920                    .tooltip({
1921                        let focus_handle = self.editor.focus_handle(cx);
1922                        move |cx| {
1923                            Tooltip::for_action_in(
1924                                "Previous Alternative",
1925                                &CyclePreviousInlineAssist,
1926                                &focus_handle,
1927                                cx,
1928                            )
1929                        }
1930                    })
1931                    .on_click(cx.listener(|this, _, cx| {
1932                        this.codegen
1933                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
1934                    })),
1935            )
1936            .child(
1937                Label::new(format!(
1938                    "{}/{}",
1939                    codegen.active_alternative + 1,
1940                    codegen.alternative_count(cx)
1941                ))
1942                .size(LabelSize::Small)
1943                .color(if disabled {
1944                    Color::Disabled
1945                } else {
1946                    Color::Muted
1947                }),
1948            )
1949            .child(
1950                IconButton::new("next", IconName::ChevronRight)
1951                    .icon_color(Color::Muted)
1952                    .disabled(disabled)
1953                    .shape(IconButtonShape::Square)
1954                    .tooltip({
1955                        let focus_handle = self.editor.focus_handle(cx);
1956                        move |cx| {
1957                            Tooltip::for_action_in(
1958                                "Next Alternative",
1959                                &CycleNextInlineAssist,
1960                                &focus_handle,
1961                                cx,
1962                            )
1963                        }
1964                    })
1965                    .on_click(cx.listener(|this, _, cx| {
1966                        this.codegen
1967                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
1968                    })),
1969            )
1970            .into_any_element()
1971    }
1972
1973    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1974        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1975        let token_counts = self.token_counts?;
1976        let max_token_count = model.max_token_count();
1977
1978        let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1979        let token_count_color = if remaining_tokens <= 0 {
1980            Color::Error
1981        } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1982            Color::Warning
1983        } else {
1984            Color::Muted
1985        };
1986
1987        let mut token_count = h_flex()
1988            .id("token_count")
1989            .gap_0p5()
1990            .child(
1991                Label::new(humanize_token_count(token_counts.total))
1992                    .size(LabelSize::Small)
1993                    .color(token_count_color),
1994            )
1995            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1996            .child(
1997                Label::new(humanize_token_count(max_token_count))
1998                    .size(LabelSize::Small)
1999                    .color(Color::Muted),
2000            );
2001        if let Some(workspace) = self.workspace.clone() {
2002            token_count = token_count
2003                .tooltip(move |cx| {
2004                    Tooltip::with_meta(
2005                        format!(
2006                            "Tokens Used ({} from the Assistant Panel)",
2007                            humanize_token_count(token_counts.assistant_panel)
2008                        ),
2009                        None,
2010                        "Click to open the Assistant Panel",
2011                        cx,
2012                    )
2013                })
2014                .cursor_pointer()
2015                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
2016                .on_click(move |_, cx| {
2017                    cx.stop_propagation();
2018                    workspace
2019                        .update(cx, |workspace, cx| {
2020                            workspace.focus_panel::<AssistantPanel>(cx)
2021                        })
2022                        .ok();
2023                });
2024        } else {
2025            token_count = token_count
2026                .cursor_default()
2027                .tooltip(|cx| Tooltip::text("Tokens used", cx));
2028        }
2029
2030        Some(token_count)
2031    }
2032
2033    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2034        let settings = ThemeSettings::get_global(cx);
2035        let text_style = TextStyle {
2036            color: if self.editor.read(cx).read_only(cx) {
2037                cx.theme().colors().text_disabled
2038            } else {
2039                cx.theme().colors().text
2040            },
2041            font_family: settings.buffer_font.family.clone(),
2042            font_fallbacks: settings.buffer_font.fallbacks.clone(),
2043            font_size: settings.buffer_font_size.into(),
2044            font_weight: settings.buffer_font.weight,
2045            line_height: relative(settings.buffer_line_height.value()),
2046            ..Default::default()
2047        };
2048        EditorElement::new(
2049            &self.editor,
2050            EditorStyle {
2051                background: cx.theme().colors().editor_background,
2052                local_player: cx.theme().players().local(),
2053                text: text_style,
2054                ..Default::default()
2055            },
2056        )
2057    }
2058
2059    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2060        Popover::new().child(
2061            v_flex()
2062                .occlude()
2063                .p_2()
2064                .child(
2065                    Label::new("Out of Tokens")
2066                        .size(LabelSize::Small)
2067                        .weight(FontWeight::BOLD),
2068                )
2069                .child(Label::new(
2070                    "Try Zed Pro for higher limits, a wider range of models, and more.",
2071                ))
2072                .child(
2073                    h_flex()
2074                        .justify_between()
2075                        .child(CheckboxWithLabel::new(
2076                            "dont-show-again",
2077                            Label::new("Don't show again"),
2078                            if dismissed_rate_limit_notice() {
2079                                ui::Selection::Selected
2080                            } else {
2081                                ui::Selection::Unselected
2082                            },
2083                            |selection, cx| {
2084                                let is_dismissed = match selection {
2085                                    ui::Selection::Unselected => false,
2086                                    ui::Selection::Indeterminate => return,
2087                                    ui::Selection::Selected => true,
2088                                };
2089
2090                                set_rate_limit_notice_dismissed(is_dismissed, cx)
2091                            },
2092                        ))
2093                        .child(
2094                            h_flex()
2095                                .gap_2()
2096                                .child(
2097                                    Button::new("dismiss", "Dismiss")
2098                                        .style(ButtonStyle::Transparent)
2099                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2100                                )
2101                                .child(Button::new("more-info", "More Info").on_click(
2102                                    |_event, cx| {
2103                                        cx.dispatch_action(Box::new(
2104                                            zed_actions::OpenAccountSettings,
2105                                        ))
2106                                    },
2107                                )),
2108                        ),
2109                ),
2110        )
2111    }
2112}
2113
/// Key-value-store key whose mere presence records that the user asked not to
/// see the rate-limit notice again. It is written with value "1" when set and
/// deleted when cleared.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2115
2116fn dismissed_rate_limit_notice() -> bool {
2117    db::kvp::KEY_VALUE_STORE
2118        .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2119        .log_err()
2120        .map_or(false, |s| s.is_some())
2121}
2122
2123fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2124    db::write_and_log(cx, move || async move {
2125        if is_dismissed {
2126            db::kvp::KEY_VALUE_STORE
2127                .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2128                .await
2129        } else {
2130            db::kvp::KEY_VALUE_STORE
2131                .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2132                .await
2133        }
2134    })
2135}
2136
/// State for a single inline assist: the targeted range in an editor, the
/// codegen driving it, and its inline decorations.
///
/// Field order also fixes drop order, so it is left as is.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchor range in the host editor's buffer that this assist targets.
    range: Range<Anchor>,
    /// Weak handle to the editor hosting the assist.
    editor: WeakView<Editor>,
    /// Prompt editor and block decorations shown inline.
    ///
    /// When `None`, the assist has no prompt editor: `user_prompt` yields
    /// `None`, and codegen errors are surfaced as a workspace toast instead.
    decorations: Option<InlineAssistDecorations>,
    /// Codegen model producing the edits for this assist.
    codegen: Model<Codegen>,
    /// Kept only so the subscriptions created in `InlineAssist::new` stay alive.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to look up the assistant panel and to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// Whether `assistant_panel_context` returns the active assistant-panel
    /// context (as a chat completion request) or `None`.
    include_context: bool,
}
2147
2148impl InlineAssist {
    /// Builds the state for one inline assist and wires up its event routing.
    ///
    /// The assist starts with decorations made from `prompt_editor`,
    /// `prompt_block_id` and `end_block_id`, and an empty set of removed-line
    /// blocks.
    ///
    /// The subscriptions stored in `_subscriptions` forward the following to the
    /// global `InlineAssistant`:
    /// - focus in/out of the prompt editor;
    /// - events emitted by the prompt editor;
    /// - codegen changes (highlight refresh plus editor block updates);
    /// - codegen events (`Undone` / `Finished`).
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: CustomBlockId,
        end_block_id: CustomBlockId,
        range: Range<Anchor>,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            range,
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                // Route prompt-editor focus changes to the global assistant.
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                // Route PromptEditorEvents (start/stop/confirm/...) to the global assistant.
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                // On every codegen change, while the host editor is still alive:
                // signal that editor's highlight updater (if the editor is
                // tracked), then refresh this assist's blocks in it.
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            // The assist may already have been removed.
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // Without decorations there is no prompt editor to
                            // display an error in, so report it as a workspace
                            // toast keyed by this assist's id.
                            if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::composite::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // Assists without decorations are finished as soon as
                            // codegen finishes (successfully or not).
                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        }
    }
2250
2251    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2252        let decorations = self.decorations.as_ref()?;
2253        Some(decorations.prompt_editor.read(cx).prompt(cx))
2254    }
2255
2256    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2257        if self.include_context {
2258            let workspace = self.workspace.as_ref()?;
2259            let workspace = workspace.upgrade()?.read(cx);
2260            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2261            Some(
2262                assistant_panel
2263                    .read(cx)
2264                    .active_context(cx)?
2265                    .read(cx)
2266                    .to_completion_request(RequestType::Chat, cx),
2267            )
2268        } else {
2269            None
2270        }
2271    }
2272
2273    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2274        let Some(user_prompt) = self.user_prompt(cx) else {
2275            return future::ready(Err(anyhow!("no user prompt"))).boxed();
2276        };
2277        let assistant_panel_context = self.assistant_panel_context(cx);
2278        self.codegen
2279            .read(cx)
2280            .count_tokens(user_prompt, assistant_panel_context, cx)
2281    }
2282}
2283
/// Editor blocks and the prompt editor that make up an inline assist's on-screen decorations.
struct InlineAssistDecorations {
    /// Id of the custom block that hosts the prompt editor (NOTE(review): inferred from name — confirm).
    prompt_block_id: CustomBlockId,
    /// Editor whose text is read as the user's prompt for this assist.
    prompt_editor: View<PromptEditor>,
    /// Ids of custom blocks associated with removed lines (presumably showing deleted text — confirm).
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the custom block at the end of the assist (NOTE(review): placement not visible here).
    end_block_id: CustomBlockId,
}
2290
/// Events emitted by [`CodegenAlternative`] and forwarded by [`Codegen`] from its active
/// alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation ended, either successfully or with an error (see the codegen's status).
    Finished,
    /// The transaction holding the generated edits was undone in the buffer.
    Undone,
}
2296
/// Drives an inline transformation of `range` in `buffer`.
///
/// Holds one [`CodegenAlternative`] per model. All state queries are answered from the
/// alternative that is currently active.
pub struct Codegen {
    /// One alternative per model; index 0 runs the primary (active) model.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently marked active.
    active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; rebuilt whenever it changes.
    subscriptions: Vec<Subscription>,
    buffer: Model<MultiBuffer>,
    range: Range<Anchor>,
    /// Transaction undone (at most once) by `undo`, if any.
    initial_transaction_id: Option<TransactionId>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2309
2310impl Codegen {
2311    pub fn new(
2312        buffer: Model<MultiBuffer>,
2313        range: Range<Anchor>,
2314        initial_transaction_id: Option<TransactionId>,
2315        telemetry: Arc<Telemetry>,
2316        builder: Arc<PromptBuilder>,
2317        cx: &mut ModelContext<Self>,
2318    ) -> Self {
2319        let codegen = cx.new_model(|cx| {
2320            CodegenAlternative::new(
2321                buffer.clone(),
2322                range.clone(),
2323                false,
2324                Some(telemetry.clone()),
2325                builder.clone(),
2326                cx,
2327            )
2328        });
2329        let mut this = Self {
2330            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2331            alternatives: vec![codegen],
2332            active_alternative: 0,
2333            seen_alternatives: HashSet::default(),
2334            subscriptions: Vec::new(),
2335            buffer,
2336            range,
2337            initial_transaction_id,
2338            telemetry,
2339            builder,
2340        };
2341        this.activate(0, cx);
2342        this
2343    }
2344
2345    fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2346        let codegen = self.active_alternative().clone();
2347        self.subscriptions.clear();
2348        self.subscriptions
2349            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2350        self.subscriptions
2351            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2352    }
2353
2354    fn active_alternative(&self) -> &Model<CodegenAlternative> {
2355        &self.alternatives[self.active_alternative]
2356    }
2357
2358    fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2359        &self.active_alternative().read(cx).status
2360    }
2361
2362    fn alternative_count(&self, cx: &AppContext) -> usize {
2363        LanguageModelRegistry::read_global(cx)
2364            .inline_alternative_models()
2365            .len()
2366            + 1
2367    }
2368
2369    pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2370        let next_active_ix = if self.active_alternative == 0 {
2371            self.alternatives.len() - 1
2372        } else {
2373            self.active_alternative - 1
2374        };
2375        self.activate(next_active_ix, cx);
2376    }
2377
2378    pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2379        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2380        self.activate(next_active_ix, cx);
2381    }
2382
2383    fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2384        self.active_alternative()
2385            .update(cx, |codegen, cx| codegen.set_active(false, cx));
2386        self.seen_alternatives.insert(index);
2387        self.active_alternative = index;
2388        self.active_alternative()
2389            .update(cx, |codegen, cx| codegen.set_active(true, cx));
2390        self.subscribe_to_alternative(cx);
2391        cx.notify();
2392    }
2393
2394    pub fn start(
2395        &mut self,
2396        user_prompt: String,
2397        assistant_panel_context: Option<LanguageModelRequest>,
2398        cx: &mut ModelContext<Self>,
2399    ) -> Result<()> {
2400        let alternative_models = LanguageModelRegistry::read_global(cx)
2401            .inline_alternative_models()
2402            .to_vec();
2403
2404        self.active_alternative()
2405            .update(cx, |alternative, cx| alternative.undo(cx));
2406        self.activate(0, cx);
2407        self.alternatives.truncate(1);
2408
2409        for _ in 0..alternative_models.len() {
2410            self.alternatives.push(cx.new_model(|cx| {
2411                CodegenAlternative::new(
2412                    self.buffer.clone(),
2413                    self.range.clone(),
2414                    false,
2415                    Some(self.telemetry.clone()),
2416                    self.builder.clone(),
2417                    cx,
2418                )
2419            }));
2420        }
2421
2422        let primary_model = LanguageModelRegistry::read_global(cx)
2423            .active_model()
2424            .context("no active model")?;
2425
2426        for (model, alternative) in iter::once(primary_model)
2427            .chain(alternative_models)
2428            .zip(&self.alternatives)
2429        {
2430            alternative.update(cx, |alternative, cx| {
2431                alternative.start(
2432                    user_prompt.clone(),
2433                    assistant_panel_context.clone(),
2434                    model.clone(),
2435                    cx,
2436                )
2437            })?;
2438        }
2439
2440        Ok(())
2441    }
2442
2443    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2444        for codegen in &self.alternatives {
2445            codegen.update(cx, |codegen, cx| codegen.stop(cx));
2446        }
2447    }
2448
2449    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2450        self.active_alternative()
2451            .update(cx, |codegen, cx| codegen.undo(cx));
2452
2453        self.buffer.update(cx, |buffer, cx| {
2454            if let Some(transaction_id) = self.initial_transaction_id.take() {
2455                buffer.undo_transaction(transaction_id, cx);
2456                buffer.refresh_preview(cx);
2457            }
2458        });
2459    }
2460
2461    pub fn count_tokens(
2462        &self,
2463        user_prompt: String,
2464        assistant_panel_context: Option<LanguageModelRequest>,
2465        cx: &AppContext,
2466    ) -> BoxFuture<'static, Result<TokenCounts>> {
2467        self.active_alternative()
2468            .read(cx)
2469            .count_tokens(user_prompt, assistant_panel_context, cx)
2470    }
2471
2472    pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2473        self.active_alternative().read(cx).buffer.clone()
2474    }
2475
2476    pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2477        self.active_alternative().read(cx).old_buffer.clone()
2478    }
2479
2480    pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2481        self.active_alternative().read(cx).snapshot.clone()
2482    }
2483
2484    pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2485        self.active_alternative().read(cx).edit_position
2486    }
2487
2488    fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2489        &self.active_alternative().read(cx).diff
2490    }
2491
2492    pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2493        self.active_alternative().read(cx).last_equal_ranges()
2494    }
2495}
2496
// Lets `Codegen` emit `CodegenEvent`s, which it forwards from its active alternative.
impl EventEmitter<CodegenEvent> for Codegen {}
2498
/// One model's attempt at transforming `range` in `buffer`.
///
/// Edits are streamed into the buffer while the alternative is active. They are recorded so
/// they can be replayed when the alternative is re-activated.
pub struct CodegenAlternative {
    /// The multibuffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Local copy of the source buffer (text, line ending, language, registry) taken at creation.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot taken at creation; streamed edit offsets are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Position reached by the streamed output, if a generation has started.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the most recently applied diff chunk.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction into which applied edits were recorded; undone on deactivate/restart.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// Task running the current generation; overwritten (dropping the old task) on restart/undo.
    generation: Task<()>,
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are applied to the buffer.
    active: bool,
    /// Edits recorded while streaming; replayed by `set_active(true)`.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest line-based diff operations; replayed while the status is `Pending`.
    line_operations: Vec<LineOperation>,
    /// The last request sent to the model, if any.
    request: Option<LanguageModelRequest>,
    /// Duration of the last generation, in seconds.
    elapsed_time: Option<f64>,
    /// Message id reported by the completion stream, if any.
    message_id: Option<String>,
}
2520
/// Lifecycle state of a [`CodegenAlternative`]'s generation.
enum CodegenStatus {
    /// No generation has been started yet.
    Idle,
    /// A generation is in progress.
    Pending,
    /// The last generation completed successfully.
    Done,
    /// The last generation failed with this error.
    Error(anyhow::Error),
}
2527
2528#[derive(Default)]
2529struct Diff {
2530    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2531    inserted_row_ranges: Vec<Range<Anchor>>,
2532}
2533
2534impl Diff {
2535    fn is_empty(&self) -> bool {
2536        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2537    }
2538}
2539
// Lets `CodegenAlternative` emit `CodegenEvent::Finished` / `CodegenEvent::Undone`.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2541
2542impl CodegenAlternative {
2543    pub fn new(
2544        buffer: Model<MultiBuffer>,
2545        range: Range<Anchor>,
2546        active: bool,
2547        telemetry: Option<Arc<Telemetry>>,
2548        builder: Arc<PromptBuilder>,
2549        cx: &mut ModelContext<Self>,
2550    ) -> Self {
2551        let snapshot = buffer.read(cx).snapshot(cx);
2552
2553        let (old_buffer, _, _) = buffer
2554            .read(cx)
2555            .range_to_buffer_ranges(range.clone(), cx)
2556            .pop()
2557            .unwrap();
2558        let old_buffer = cx.new_model(|cx| {
2559            let old_buffer = old_buffer.read(cx);
2560            let text = old_buffer.as_rope().clone();
2561            let line_ending = old_buffer.line_ending();
2562            let language = old_buffer.language().cloned();
2563            let language_registry = old_buffer.language_registry();
2564
2565            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2566            buffer.set_language(language, cx);
2567            if let Some(language_registry) = language_registry {
2568                buffer.set_language_registry(language_registry)
2569            }
2570            buffer
2571        });
2572
2573        Self {
2574            buffer: buffer.clone(),
2575            old_buffer,
2576            edit_position: None,
2577            message_id: None,
2578            snapshot,
2579            last_equal_ranges: Default::default(),
2580            transformation_transaction_id: None,
2581            status: CodegenStatus::Idle,
2582            generation: Task::ready(()),
2583            diff: Diff::default(),
2584            telemetry,
2585            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2586            builder,
2587            active,
2588            edits: Vec::new(),
2589            line_operations: Vec::new(),
2590            range,
2591            request: None,
2592            elapsed_time: None,
2593        }
2594    }
2595
2596    fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2597        if active != self.active {
2598            self.active = active;
2599
2600            if self.active {
2601                let edits = self.edits.clone();
2602                self.apply_edits(edits, cx);
2603                if matches!(self.status, CodegenStatus::Pending) {
2604                    let line_operations = self.line_operations.clone();
2605                    self.reapply_line_based_diff(line_operations, cx);
2606                } else {
2607                    self.reapply_batch_diff(cx).detach();
2608                }
2609            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2610                self.buffer.update(cx, |buffer, cx| {
2611                    buffer.undo_transaction(transaction_id, cx);
2612                    buffer.forget_transaction(transaction_id, cx);
2613                });
2614            }
2615        }
2616    }
2617
2618    fn handle_buffer_event(
2619        &mut self,
2620        _buffer: Model<MultiBuffer>,
2621        event: &multi_buffer::Event,
2622        cx: &mut ModelContext<Self>,
2623    ) {
2624        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2625            if self.transformation_transaction_id == Some(*transaction_id) {
2626                self.transformation_transaction_id = None;
2627                self.generation = Task::ready(());
2628                cx.emit(CodegenEvent::Undone);
2629            }
2630        }
2631    }
2632
2633    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2634        &self.last_equal_ranges
2635    }
2636
2637    pub fn count_tokens(
2638        &self,
2639        user_prompt: String,
2640        assistant_panel_context: Option<LanguageModelRequest>,
2641        cx: &AppContext,
2642    ) -> BoxFuture<'static, Result<TokenCounts>> {
2643        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2644            let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2645            match request {
2646                Ok(request) => {
2647                    let total_count = model.count_tokens(request.clone(), cx);
2648                    let assistant_panel_count = assistant_panel_context
2649                        .map(|context| model.count_tokens(context, cx))
2650                        .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2651
2652                    async move {
2653                        Ok(TokenCounts {
2654                            total: total_count.await?,
2655                            assistant_panel: assistant_panel_count.await?,
2656                        })
2657                    }
2658                    .boxed()
2659                }
2660                Err(error) => futures::future::ready(Err(error)).boxed(),
2661            }
2662        } else {
2663            future::ready(Err(anyhow!("no active model"))).boxed()
2664        }
2665    }
2666
2667    pub fn start(
2668        &mut self,
2669        user_prompt: String,
2670        assistant_panel_context: Option<LanguageModelRequest>,
2671        model: Arc<dyn LanguageModel>,
2672        cx: &mut ModelContext<Self>,
2673    ) -> Result<()> {
2674        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2675            self.buffer.update(cx, |buffer, cx| {
2676                buffer.undo_transaction(transformation_transaction_id, cx);
2677            });
2678        }
2679
2680        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2681
2682        let api_key = model.api_key(cx);
2683        let telemetry_id = model.telemetry_id();
2684        let provider_id = model.provider_id();
2685        let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2686            if user_prompt.trim().to_lowercase() == "delete" {
2687                async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2688            } else {
2689                let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2690                self.request = Some(request.clone());
2691
2692                cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2693                    .boxed_local()
2694            };
2695        self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2696        Ok(())
2697    }
2698
2699    fn build_request(
2700        &self,
2701        user_prompt: String,
2702        assistant_panel_context: Option<LanguageModelRequest>,
2703        cx: &AppContext,
2704    ) -> Result<LanguageModelRequest> {
2705        let buffer = self.buffer.read(cx).snapshot(cx);
2706        let language = buffer.language_at(self.range.start);
2707        let language_name = if let Some(language) = language.as_ref() {
2708            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2709                None
2710            } else {
2711                Some(language.name())
2712            }
2713        } else {
2714            None
2715        };
2716
2717        let language_name = language_name.as_ref();
2718        let start = buffer.point_to_buffer_offset(self.range.start);
2719        let end = buffer.point_to_buffer_offset(self.range.end);
2720        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2721            let (start_buffer, start_buffer_offset) = start;
2722            let (end_buffer, end_buffer_offset) = end;
2723            if start_buffer.remote_id() == end_buffer.remote_id() {
2724                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2725            } else {
2726                return Err(anyhow::anyhow!("invalid transformation range"));
2727            }
2728        } else {
2729            return Err(anyhow::anyhow!("invalid transformation range"));
2730        };
2731
2732        let prompt = self
2733            .builder
2734            .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2735            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2736
2737        let mut messages = Vec::new();
2738        if let Some(context_request) = assistant_panel_context {
2739            messages = context_request.messages;
2740        }
2741
2742        messages.push(LanguageModelRequestMessage {
2743            role: Role::User,
2744            content: vec![prompt.into()],
2745            cache: false,
2746        });
2747
2748        Ok(LanguageModelRequest {
2749            messages,
2750            tools: Vec::new(),
2751            stop: Vec::new(),
2752            temperature: None,
2753        })
2754    }
2755
2756    pub fn handle_stream(
2757        &mut self,
2758        model_telemetry_id: String,
2759        model_provider_id: String,
2760        model_api_key: Option<String>,
2761        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
2762        cx: &mut ModelContext<Self>,
2763    ) {
2764        let start_time = Instant::now();
2765        let snapshot = self.snapshot.clone();
2766        let selected_text = snapshot
2767            .text_for_range(self.range.start..self.range.end)
2768            .collect::<Rope>();
2769
2770        let selection_start = self.range.start.to_point(&snapshot);
2771
2772        // Start with the indentation of the first line in the selection
2773        let mut suggested_line_indent = snapshot
2774            .suggested_indents(selection_start.row..=selection_start.row, cx)
2775            .into_values()
2776            .next()
2777            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2778
2779        // If the first line in the selection does not have indentation, check the following lines
2780        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2781            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
2782                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2783                // Prefer tabs if a line in the selection uses tabs as indentation
2784                if line_indent.kind == IndentKind::Tab {
2785                    suggested_line_indent.kind = IndentKind::Tab;
2786                    break;
2787                }
2788            }
2789        }
2790
2791        let http_client = cx.http_client().clone();
2792        let telemetry = self.telemetry.clone();
2793        let language_name = {
2794            let multibuffer = self.buffer.read(cx);
2795            let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
2796            ranges
2797                .first()
2798                .and_then(|(buffer, _, _)| buffer.read(cx).language())
2799                .map(|language| language.name())
2800        };
2801
2802        self.diff = Diff::default();
2803        self.status = CodegenStatus::Pending;
2804        let mut edit_start = self.range.start.to_offset(&snapshot);
2805        self.generation = cx.spawn(|codegen, mut cx| {
2806            async move {
2807                let stream = stream.await;
2808                let message_id = stream
2809                    .as_ref()
2810                    .ok()
2811                    .and_then(|stream| stream.message_id.clone());
2812                let generate = async {
2813                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2814                    let executor = cx.background_executor().clone();
2815                    let message_id = message_id.clone();
2816                    let line_based_stream_diff: Task<anyhow::Result<()>> =
2817                        cx.background_executor().spawn(async move {
2818                            let mut response_latency = None;
2819                            let request_start = Instant::now();
2820                            let diff = async {
2821                                let chunks = StripInvalidSpans::new(stream?.stream);
2822                                futures::pin_mut!(chunks);
2823                                let mut diff = StreamingDiff::new(selected_text.to_string());
2824                                let mut line_diff = LineDiff::default();
2825
2826                                let mut new_text = String::new();
2827                                let mut base_indent = None;
2828                                let mut line_indent = None;
2829                                let mut first_line = true;
2830
2831                                while let Some(chunk) = chunks.next().await {
2832                                    if response_latency.is_none() {
2833                                        response_latency = Some(request_start.elapsed());
2834                                    }
2835                                    let chunk = chunk?;
2836
2837                                    let mut lines = chunk.split('\n').peekable();
2838                                    while let Some(line) = lines.next() {
2839                                        new_text.push_str(line);
2840                                        if line_indent.is_none() {
2841                                            if let Some(non_whitespace_ch_ix) =
2842                                                new_text.find(|ch: char| !ch.is_whitespace())
2843                                            {
2844                                                line_indent = Some(non_whitespace_ch_ix);
2845                                                base_indent = base_indent.or(line_indent);
2846
2847                                                let line_indent = line_indent.unwrap();
2848                                                let base_indent = base_indent.unwrap();
2849                                                let indent_delta =
2850                                                    line_indent as i32 - base_indent as i32;
2851                                                let mut corrected_indent_len = cmp::max(
2852                                                    0,
2853                                                    suggested_line_indent.len as i32 + indent_delta,
2854                                                )
2855                                                    as usize;
2856                                                if first_line {
2857                                                    corrected_indent_len = corrected_indent_len
2858                                                        .saturating_sub(
2859                                                            selection_start.column as usize,
2860                                                        );
2861                                                }
2862
2863                                                let indent_char = suggested_line_indent.char();
2864                                                let mut indent_buffer = [0; 4];
2865                                                let indent_str =
2866                                                    indent_char.encode_utf8(&mut indent_buffer);
2867                                                new_text.replace_range(
2868                                                    ..line_indent,
2869                                                    &indent_str.repeat(corrected_indent_len),
2870                                                );
2871                                            }
2872                                        }
2873
2874                                        if line_indent.is_some() {
2875                                            let char_ops = diff.push_new(&new_text);
2876                                            line_diff
2877                                                .push_char_operations(&char_ops, &selected_text);
2878                                            diff_tx
2879                                                .send((char_ops, line_diff.line_operations()))
2880                                                .await?;
2881                                            new_text.clear();
2882                                        }
2883
2884                                        if lines.peek().is_some() {
2885                                            let char_ops = diff.push_new("\n");
2886                                            line_diff
2887                                                .push_char_operations(&char_ops, &selected_text);
2888                                            diff_tx
2889                                                .send((char_ops, line_diff.line_operations()))
2890                                                .await?;
2891                                            if line_indent.is_none() {
2892                                                // Don't write out the leading indentation in empty lines on the next line
2893                                                // This is the case where the above if statement didn't clear the buffer
2894                                                new_text.clear();
2895                                            }
2896                                            line_indent = None;
2897                                            first_line = false;
2898                                        }
2899                                    }
2900                                }
2901
2902                                let mut char_ops = diff.push_new(&new_text);
2903                                char_ops.extend(diff.finish());
2904                                line_diff.push_char_operations(&char_ops, &selected_text);
2905                                line_diff.finish(&selected_text);
2906                                diff_tx
2907                                    .send((char_ops, line_diff.line_operations()))
2908                                    .await?;
2909
2910                                anyhow::Ok(())
2911                            };
2912
2913                            let result = diff.await;
2914
2915                            let error_message =
2916                                result.as_ref().err().map(|error| error.to_string());
2917                            report_assistant_event(
2918                                AssistantEvent {
2919                                    conversation_id: None,
2920                                    message_id,
2921                                    kind: AssistantKind::Inline,
2922                                    phase: AssistantPhase::Response,
2923                                    model: model_telemetry_id,
2924                                    model_provider: model_provider_id.to_string(),
2925                                    response_latency,
2926                                    error_message,
2927                                    language_name: language_name.map(|name| name.to_proto()),
2928                                },
2929                                telemetry,
2930                                http_client,
2931                                model_api_key,
2932                                &executor,
2933                            );
2934
2935                            result?;
2936                            Ok(())
2937                        });
2938
2939                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
2940                        codegen.update(&mut cx, |codegen, cx| {
2941                            codegen.last_equal_ranges.clear();
2942
2943                            let edits = char_ops
2944                                .into_iter()
2945                                .filter_map(|operation| match operation {
2946                                    CharOperation::Insert { text } => {
2947                                        let edit_start = snapshot.anchor_after(edit_start);
2948                                        Some((edit_start..edit_start, text))
2949                                    }
2950                                    CharOperation::Delete { bytes } => {
2951                                        let edit_end = edit_start + bytes;
2952                                        let edit_range = snapshot.anchor_after(edit_start)
2953                                            ..snapshot.anchor_before(edit_end);
2954                                        edit_start = edit_end;
2955                                        Some((edit_range, String::new()))
2956                                    }
2957                                    CharOperation::Keep { bytes } => {
2958                                        let edit_end = edit_start + bytes;
2959                                        let edit_range = snapshot.anchor_after(edit_start)
2960                                            ..snapshot.anchor_before(edit_end);
2961                                        edit_start = edit_end;
2962                                        codegen.last_equal_ranges.push(edit_range);
2963                                        None
2964                                    }
2965                                })
2966                                .collect::<Vec<_>>();
2967
2968                            if codegen.active {
2969                                codegen.apply_edits(edits.iter().cloned(), cx);
2970                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
2971                            }
2972                            codegen.edits.extend(edits);
2973                            codegen.line_operations = line_ops;
2974                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));
2975
2976                            cx.notify();
2977                        })?;
2978                    }
2979
2980                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
2981                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
2982                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
2983                    let batch_diff_task =
2984                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
2985                    let (line_based_stream_diff, ()) =
2986                        join!(line_based_stream_diff, batch_diff_task);
2987                    line_based_stream_diff?;
2988
2989                    anyhow::Ok(())
2990                };
2991
2992                let result = generate.await;
2993                let elapsed_time = start_time.elapsed().as_secs_f64();
2994
2995                codegen
2996                    .update(&mut cx, |this, cx| {
2997                        this.message_id = message_id;
2998                        this.last_equal_ranges.clear();
2999                        if let Err(error) = result {
3000                            this.status = CodegenStatus::Error(error);
3001                        } else {
3002                            this.status = CodegenStatus::Done;
3003                        }
3004                        this.elapsed_time = Some(elapsed_time);
3005                        cx.emit(CodegenEvent::Finished);
3006                        cx.notify();
3007                    })
3008                    .ok();
3009            }
3010        });
3011        cx.notify();
3012    }
3013
3014    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
3015        self.last_equal_ranges.clear();
3016        if self.diff.is_empty() {
3017            self.status = CodegenStatus::Idle;
3018        } else {
3019            self.status = CodegenStatus::Done;
3020        }
3021        self.generation = Task::ready(());
3022        cx.emit(CodegenEvent::Finished);
3023        cx.notify();
3024    }
3025
3026    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
3027        self.buffer.update(cx, |buffer, cx| {
3028            if let Some(transaction_id) = self.transformation_transaction_id.take() {
3029                buffer.undo_transaction(transaction_id, cx);
3030                buffer.refresh_preview(cx);
3031            }
3032        });
3033    }
3034
3035    fn apply_edits(
3036        &mut self,
3037        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3038        cx: &mut ModelContext<CodegenAlternative>,
3039    ) {
3040        let transaction = self.buffer.update(cx, |buffer, cx| {
3041            // Avoid grouping assistant edits with user edits.
3042            buffer.finalize_last_transaction(cx);
3043            buffer.start_transaction(cx);
3044            buffer.edit(edits, None, cx);
3045            buffer.end_transaction(cx)
3046        });
3047
3048        if let Some(transaction) = transaction {
3049            if let Some(first_transaction) = self.transformation_transaction_id {
3050                // Group all assistant edits into the first transaction.
3051                self.buffer.update(cx, |buffer, cx| {
3052                    buffer.merge_transactions(transaction, first_transaction, cx)
3053                });
3054            } else {
3055                self.transformation_transaction_id = Some(transaction);
3056                self.buffer
3057                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3058            }
3059        }
3060    }
3061
3062    fn reapply_line_based_diff(
3063        &mut self,
3064        line_operations: impl IntoIterator<Item = LineOperation>,
3065        cx: &mut ModelContext<Self>,
3066    ) {
3067        let old_snapshot = self.snapshot.clone();
3068        let old_range = self.range.to_point(&old_snapshot);
3069        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3070        let new_range = self.range.to_point(&new_snapshot);
3071
3072        let mut old_row = old_range.start.row;
3073        let mut new_row = new_range.start.row;
3074
3075        self.diff.deleted_row_ranges.clear();
3076        self.diff.inserted_row_ranges.clear();
3077        for operation in line_operations {
3078            match operation {
3079                LineOperation::Keep { lines } => {
3080                    old_row += lines;
3081                    new_row += lines;
3082                }
3083                LineOperation::Delete { lines } => {
3084                    let old_end_row = old_row + lines - 1;
3085                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3086
3087                    if let Some((_, last_deleted_row_range)) =
3088                        self.diff.deleted_row_ranges.last_mut()
3089                    {
3090                        if *last_deleted_row_range.end() + 1 == old_row {
3091                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3092                        } else {
3093                            self.diff
3094                                .deleted_row_ranges
3095                                .push((new_row, old_row..=old_end_row));
3096                        }
3097                    } else {
3098                        self.diff
3099                            .deleted_row_ranges
3100                            .push((new_row, old_row..=old_end_row));
3101                    }
3102
3103                    old_row += lines;
3104                }
3105                LineOperation::Insert { lines } => {
3106                    let new_end_row = new_row + lines - 1;
3107                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3108                    let end = new_snapshot.anchor_before(Point::new(
3109                        new_end_row,
3110                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
3111                    ));
3112                    self.diff.inserted_row_ranges.push(start..end);
3113                    new_row += lines;
3114                }
3115            }
3116
3117            cx.notify();
3118        }
3119    }
3120
    /// Recomputes `self.diff` with a whole-range line diff (`similar::TextDiff::from_lines`).
    ///
    /// Compares the full lines spanned by `self.range` in the original snapshot
    /// (`self.snapshot`) with the full lines it spans in the current buffer. The
    /// diff is computed on the background executor. The returned task then
    /// replaces `diff.deleted_row_ranges` / `diff.inserted_row_ranges` and
    /// notifies. A failed model update (e.g. if the codegen was released) is
    /// silently ignored via `.ok()`.
    fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
        let old_snapshot = self.snapshot.clone();
        let old_range = self.range.to_point(&old_snapshot);
        let new_snapshot = self.buffer.read(cx).snapshot(cx);
        let new_range = self.range.to_point(&new_snapshot);

        cx.spawn(|codegen, mut cx| async move {
            let (deleted_row_ranges, inserted_row_ranges) = cx
                .background_executor()
                .spawn(async move {
                    // Text of the old range, expanded to whole lines.
                    let old_text = old_snapshot
                        .text_for_range(
                            Point::new(old_range.start.row, 0)
                                ..Point::new(
                                    old_range.end.row,
                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                                ),
                        )
                        .collect::<String>();
                    // Text of the new range, expanded to whole lines.
                    let new_text = new_snapshot
                        .text_for_range(
                            Point::new(new_range.start.row, 0)
                                ..Point::new(
                                    new_range.end.row,
                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                                ),
                        )
                        .collect::<String>();

                    // Row cursors into the old and new snapshots, advanced as
                    // changes are consumed.
                    let mut old_row = old_range.start.row;
                    let mut new_row = new_range.start.row;
                    let batch_diff =
                        similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                    let mut inserted_row_ranges = Vec::new();
                    for change in batch_diff.iter_all_changes() {
                        // With `from_lines`, each change is presumably a single line
                        // (including its newline), so this is expected to be 1.
                        let line_count = change.value().lines().count() as u32;
                        match change.tag() {
                            similar::ChangeTag::Equal => {
                                old_row += line_count;
                                new_row += line_count;
                            }
                            similar::ChangeTag::Delete => {
                                // Deleted rows are anchored at the current new row;
                                // adjacent deletions are coalesced into one range.
                                let old_end_row = old_row + line_count - 1;
                                let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));

                                if let Some((_, last_deleted_row_range)) =
                                    deleted_row_ranges.last_mut()
                                {
                                    if *last_deleted_row_range.end() + 1 == old_row {
                                        *last_deleted_row_range =
                                            *last_deleted_row_range.start()..=old_end_row;
                                    } else {
                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                    }
                                } else {
                                    deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                }

                                old_row += line_count;
                            }
                            similar::ChangeTag::Insert => {
                                // Inserted rows span from the start of `new_row` to the
                                // end of `new_end_row`.
                                let new_end_row = new_row + line_count - 1;
                                let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                                let end = new_snapshot.anchor_before(Point::new(
                                    new_end_row,
                                    new_snapshot.line_len(MultiBufferRow(new_end_row)),
                                ));
                                inserted_row_ranges.push(start..end);
                                new_row += line_count;
                            }
                        }
                    }

                    (deleted_row_ranges, inserted_row_ranges)
                })
                .await;

            codegen
                .update(&mut cx, |codegen, cx| {
                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
                    cx.notify();
                })
                .ok();
        })
    }
3209}
3210
3211struct StripInvalidSpans<T> {
3212    stream: T,
3213    stream_done: bool,
3214    buffer: String,
3215    first_line: bool,
3216    line_end: bool,
3217    starts_with_code_block: bool,
3218}
3219
3220impl<T> StripInvalidSpans<T>
3221where
3222    T: Stream<Item = Result<String>>,
3223{
3224    fn new(stream: T) -> Self {
3225        Self {
3226            stream,
3227            stream_done: false,
3228            buffer: String::new(),
3229            first_line: true,
3230            line_end: false,
3231            starts_with_code_block: false,
3232        }
3233    }
3234}
3235
3236impl<T> Stream for StripInvalidSpans<T>
3237where
3238    T: Stream<Item = Result<String>>,
3239{
3240    type Item = Result<String>;
3241
3242    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3243        const CODE_BLOCK_DELIMITER: &str = "```";
3244        const CURSOR_SPAN: &str = "<|CURSOR|>";
3245
3246        let this = unsafe { self.get_unchecked_mut() };
3247        loop {
3248            if !this.stream_done {
3249                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3250                match stream.as_mut().poll_next(cx) {
3251                    Poll::Ready(Some(Ok(chunk))) => {
3252                        this.buffer.push_str(&chunk);
3253                    }
3254                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3255                    Poll::Ready(None) => {
3256                        this.stream_done = true;
3257                    }
3258                    Poll::Pending => return Poll::Pending,
3259                }
3260            }
3261
3262            let mut chunk = String::new();
3263            let mut consumed = 0;
3264            if !this.buffer.is_empty() {
3265                let mut lines = this.buffer.split('\n').enumerate().peekable();
3266                while let Some((line_ix, line)) = lines.next() {
3267                    if line_ix > 0 {
3268                        this.first_line = false;
3269                    }
3270
3271                    if this.first_line {
3272                        let trimmed_line = line.trim();
3273                        if lines.peek().is_some() {
3274                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3275                                consumed += line.len() + 1;
3276                                this.starts_with_code_block = true;
3277                                continue;
3278                            }
3279                        } else if trimmed_line.is_empty()
3280                            || prefixes(CODE_BLOCK_DELIMITER)
3281                                .any(|prefix| trimmed_line.starts_with(prefix))
3282                        {
3283                            break;
3284                        }
3285                    }
3286
3287                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
3288                    if lines.peek().is_some() {
3289                        if this.line_end {
3290                            chunk.push('\n');
3291                        }
3292
3293                        chunk.push_str(&line_without_cursor);
3294                        this.line_end = true;
3295                        consumed += line.len() + 1;
3296                    } else if this.stream_done {
3297                        if !this.starts_with_code_block
3298                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3299                        {
3300                            if this.line_end {
3301                                chunk.push('\n');
3302                            }
3303
3304                            chunk.push_str(&line);
3305                        }
3306
3307                        consumed += line.len();
3308                    } else {
3309                        let trimmed_line = line.trim();
3310                        if trimmed_line.is_empty()
3311                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3312                            || prefixes(CODE_BLOCK_DELIMITER)
3313                                .any(|prefix| trimmed_line.ends_with(prefix))
3314                        {
3315                            break;
3316                        } else {
3317                            if this.line_end {
3318                                chunk.push('\n');
3319                                this.line_end = false;
3320                            }
3321
3322                            chunk.push_str(&line_without_cursor);
3323                            consumed += line.len();
3324                        }
3325                    }
3326                }
3327            }
3328
3329            this.buffer = this.buffer.split_off(consumed);
3330            if !chunk.is_empty() {
3331                return Poll::Ready(Some(Ok(chunk)));
3332            } else if this.stream_done {
3333                return Poll::Ready(None);
3334            }
3335        }
3336    }
3337}
3338
/// Code action provider that offers a "Fix with Assistant" action for ranges
/// containing diagnostics, and starts an inline assist when it is applied.
struct AssistantCodeActionProvider {
    /// The editor in which assists are started when an action is applied.
    editor: WeakView<Editor>,
    /// Workspace used to look up the `AssistantPanel` when an action is applied.
    workspace: WeakView<Workspace>,
}
3343
3344impl CodeActionProvider for AssistantCodeActionProvider {
3345    fn code_actions(
3346        &self,
3347        buffer: &Model<Buffer>,
3348        range: Range<text::Anchor>,
3349        cx: &mut WindowContext,
3350    ) -> Task<Result<Vec<CodeAction>>> {
3351        if !AssistantSettings::get_global(cx).enabled {
3352            return Task::ready(Ok(Vec::new()));
3353        }
3354
3355        let snapshot = buffer.read(cx).snapshot();
3356        let mut range = range.to_point(&snapshot);
3357
3358        // Expand the range to line boundaries.
3359        range.start.column = 0;
3360        range.end.column = snapshot.line_len(range.end.row);
3361
3362        let mut has_diagnostics = false;
3363        for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3364            range.start = cmp::min(range.start, diagnostic.range.start);
3365            range.end = cmp::max(range.end, diagnostic.range.end);
3366            has_diagnostics = true;
3367        }
3368        if has_diagnostics {
3369            if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3370                if let Some(symbol) = symbols_containing_start.last() {
3371                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3372                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3373                }
3374            }
3375
3376            if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3377                if let Some(symbol) = symbols_containing_end.last() {
3378                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3379                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3380                }
3381            }
3382
3383            Task::ready(Ok(vec![CodeAction {
3384                server_id: language::LanguageServerId(0),
3385                range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3386                lsp_action: lsp::CodeAction {
3387                    title: "Fix with Assistant".into(),
3388                    ..Default::default()
3389                },
3390            }]))
3391        } else {
3392            Task::ready(Ok(Vec::new()))
3393        }
3394    }
3395
3396    fn apply_code_action(
3397        &self,
3398        buffer: Model<Buffer>,
3399        action: CodeAction,
3400        excerpt_id: ExcerptId,
3401        _push_to_history: bool,
3402        cx: &mut WindowContext,
3403    ) -> Task<Result<ProjectTransaction>> {
3404        let editor = self.editor.clone();
3405        let workspace = self.workspace.clone();
3406        cx.spawn(|mut cx| async move {
3407            let editor = editor.upgrade().context("editor was released")?;
3408            let range = editor
3409                .update(&mut cx, |editor, cx| {
3410                    editor.buffer().update(cx, |multibuffer, cx| {
3411                        let buffer = buffer.read(cx);
3412                        let multibuffer_snapshot = multibuffer.read(cx);
3413
3414                        let old_context_range =
3415                            multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3416                        let mut new_context_range = old_context_range.clone();
3417                        if action
3418                            .range
3419                            .start
3420                            .cmp(&old_context_range.start, buffer)
3421                            .is_lt()
3422                        {
3423                            new_context_range.start = action.range.start;
3424                        }
3425                        if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3426                            new_context_range.end = action.range.end;
3427                        }
3428                        drop(multibuffer_snapshot);
3429
3430                        if new_context_range != old_context_range {
3431                            multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3432                        }
3433
3434                        let multibuffer_snapshot = multibuffer.read(cx);
3435                        Some(
3436                            multibuffer_snapshot
3437                                .anchor_in_excerpt(excerpt_id, action.range.start)?
3438                                ..multibuffer_snapshot
3439                                    .anchor_in_excerpt(excerpt_id, action.range.end)?,
3440                        )
3441                    })
3442                })?
3443                .context("invalid range")?;
3444            let assistant_panel = workspace.update(&mut cx, |workspace, cx| {
3445                workspace
3446                    .panel::<AssistantPanel>(cx)
3447                    .context("assistant panel was released")
3448            })??;
3449
3450            cx.update_global(|assistant: &mut InlineAssistant, cx| {
3451                let assist_id = assistant.suggest_assist(
3452                    &editor,
3453                    range,
3454                    "Fix Diagnostics".into(),
3455                    None,
3456                    true,
3457                    Some(workspace),
3458                    Some(&assistant_panel),
3459                    cx,
3460                );
3461                assistant.start_assist(assist_id, cx);
3462            })?;
3463
3464            Ok(ProjectTransaction::default())
3465        })
3466    }
3467}
3468
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For `"abc"` this yields `"a"` and `"ab"`; the full string itself is not
/// included. Prefixes always end on a `char` boundary, so non-ASCII input does
/// not panic. An empty `text` yields nothing instead of underflowing
/// `text.len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(end, _)| &text[..end])
}
3472
3473fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3474    ranges.sort_unstable_by(|a, b| {
3475        a.start
3476            .cmp(&b.start, buffer)
3477            .then_with(|| b.end.cmp(&a.end, buffer))
3478    });
3479
3480    let mut ix = 0;
3481    while ix + 1 < ranges.len() {
3482        let b = ranges[ix + 1].clone();
3483        let a = &mut ranges[ix];
3484        if a.end.cmp(&b.start, buffer).is_gt() {
3485            if a.end.cmp(&b.end, buffer).is_lt() {
3486                a.end = b.end;
3487            }
3488            ranges.remove(ix + 1);
3489        } else {
3490            ix += 1;
3491        }
3492    }
3493}
3494
3495#[cfg(test)]
3496mod tests {
3497    use super::*;
3498    use futures::stream::{self};
3499    use gpui::{Context, TestAppContext};
3500    use indoc::indoc;
3501    use language::{
3502        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
3503        Point,
3504    };
3505    use language_model::LanguageModelRegistry;
3506    use rand::prelude::*;
3507    use serde::Serialize;
3508    use settings::SettingsStore;
3509    use std::{future, sync::Arc};
3510
    /// Minimal `Serialize`-able payload with a single `name` field.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
3515
    /// Streams a rewrite of rows 1-4 of a Rust `main` in random 1-10 byte chunks.
    /// The model text is indented with 7/11 spaces. The buffer must end up using
    /// the surrounding 4/8-space indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3580
    /// Generates at the end of a partially typed line (`    le`, empty range at
    /// column 6). The model continues the line and emits unindented follow-up
    /// lines, which must be indented to match the enclosing block.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3646
    /// Generates at column 2 of a whitespace-only line (`  `) inside `main`, i.e.
    /// before the expected indentation level. The unindented model output must
    /// end up indented by 4 spaces (8 inside the `while` body).
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3713
    /// Rewrites rows 0-4 of a tab-indented buffer that has no language attached,
    /// sending the whole response as a single chunk. The tab indentation of the
    /// model output must be preserved.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }
3770
3771    #[gpui::test]
3772    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
3773        cx.update(LanguageModelRegistry::test);
3774        cx.set_global(cx.update(SettingsStore::test));
3775        cx.update(language_settings::init);
3776
3777        let text = indoc! {"
3778            fn main() {
3779                let x = 0;
3780            }
3781        "};
3782        let buffer =
3783            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3784        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3785        let range = buffer.read_with(cx, |buffer, cx| {
3786            let snapshot = buffer.snapshot(cx);
3787            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
3788        });
3789        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3790        let codegen = cx.new_model(|cx| {
3791            CodegenAlternative::new(
3792                buffer.clone(),
3793                range.clone(),
3794                false,
3795                None,
3796                prompt_builder,
3797                cx,
3798            )
3799        });
3800
3801        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
3802        chunks_tx
3803            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
3804            .unwrap();
3805        drop(chunks_tx);
3806        cx.run_until_parked();
3807
3808        // The codegen is inactive, so the buffer doesn't get modified.
3809        assert_eq!(
3810            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3811            text
3812        );
3813
3814        // Activating the codegen applies the changes.
3815        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
3816        assert_eq!(
3817            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3818            indoc! {"
3819                fn main() {
3820                    let mut x = 0;
3821                    x += 1;
3822                }
3823            "}
3824        );
3825
3826        // Deactivating the codegen undoes the changes.
3827        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
3828        cx.run_until_parked();
3829        assert_eq!(
3830            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3831            text
3832        );
3833    }
3834
3835    #[gpui::test]
3836    async fn test_strip_invalid_spans_from_codeblock() {
3837        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3838        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3839        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3840        assert_chunks(
3841            "```html\n```js\nLorem ipsum dolor\n```\n```",
3842            "```js\nLorem ipsum dolor\n```",
3843        )
3844        .await;
3845        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3846        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3847        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3848        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3849
3850        async fn assert_chunks(text: &str, expected_text: &str) {
3851            for chunk_size in 1..=text.len() {
3852                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3853                    .map(|chunk| chunk.unwrap())
3854                    .collect::<String>()
3855                    .await;
3856                assert_eq!(
3857                    actual_text, expected_text,
3858                    "failed to strip invalid spans, chunk size: {}",
3859                    chunk_size
3860                );
3861            }
3862        }
3863
3864        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3865            stream::iter(
3866                text.chars()
3867                    .collect::<Vec<_>>()
3868                    .chunks(size)
3869                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
3870                    .collect::<Vec<_>>(),
3871            )
3872        }
3873    }
3874
3875    fn simulate_response_stream(
3876        codegen: Model<CodegenAlternative>,
3877        cx: &mut TestAppContext,
3878    ) -> mpsc::UnboundedSender<String> {
3879        let (chunks_tx, chunks_rx) = mpsc::unbounded();
3880        codegen.update(cx, |codegen, cx| {
3881            codegen.handle_stream(
3882                String::new(),
3883                String::new(),
3884                None,
3885                future::ready(Ok(LanguageModelTextStream {
3886                    message_id: None,
3887                    stream: chunks_rx.map(Ok).boxed(),
3888                })),
3889                cx,
3890            );
3891        });
3892        chunks_tx
3893    }
3894
3895    fn rust_lang() -> Language {
3896        Language::new(
3897            LanguageConfig {
3898                name: "Rust".into(),
3899                matcher: LanguageMatcher {
3900                    path_suffixes: vec!["rs".to_string()],
3901                    ..Default::default()
3902                },
3903                ..Default::default()
3904            },
3905            Some(tree_sitter_rust::LANGUAGE.into()),
3906        )
3907        .with_indents_query(
3908            r#"
3909            (call_expression) @indent
3910            (field_expression) @indent
3911            (_ "(" ")" @end) @indent
3912            (_ "{" "}" @end) @indent
3913            "#,
3914        )
3915        .unwrap()
3916    }
3917}