// inline_assistant.rs

   1use crate::context::attach_context_to_message;
   2use crate::context_picker::ContextPicker;
   3use crate::context_store::ContextStore;
   4use crate::context_strip::ContextStrip;
   5use crate::thread_store::ThreadStore;
   6use crate::{
   7    assistant_settings::AssistantSettings,
   8    prompts::PromptBuilder,
   9    streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
  10    terminal_inline_assistant::TerminalInlineAssistant,
  11    CycleNextInlineAssist, CyclePreviousInlineAssist, ToggleInlineAssist,
  12};
  13use crate::{AssistantPanel, ToggleContextPicker};
  14use anyhow::{Context as _, Result};
  15use client::{telemetry::Telemetry, ErrorExt};
  16use collections::{hash_map, HashMap, HashSet, VecDeque};
  17use editor::{
  18    actions::{MoveDown, MoveUp, SelectAll},
  19    display_map::{
  20        BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
  21        ToDisplayPoint,
  22    },
  23    Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
  24    EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
  25    ToOffset as _, ToPoint,
  26};
  27use feature_flags::{FeatureFlagAppExt as _, ZedPro};
  28use fs::Fs;
  29use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
  30use gpui::{
  31    anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
  32    FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
  33    Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView,
  34    WindowContext,
  35};
  36use language::{Buffer, IndentKind, Point, Selection, TransactionId};
  37use language_model::{
  38    LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
  39    LanguageModelTextStream, Role,
  40};
  41use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
  42use language_models::report_assistant_event;
  43use multi_buffer::MultiBufferRow;
  44use parking_lot::Mutex;
  45use project::{CodeAction, ProjectTransaction};
  46use rope::Rope;
  47use settings::{update_settings_file, Settings, SettingsStore};
  48use smol::future::FutureExt;
  49use std::{
  50    cmp,
  51    future::Future,
  52    iter, mem,
  53    ops::{Range, RangeInclusive},
  54    pin::Pin,
  55    rc::Rc,
  56    sync::Arc,
  57    task::{self, Poll},
  58    time::Instant,
  59};
  60use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
  61use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
  62use text::{OffsetRangeExt, ToPoint as _};
  63use theme::ThemeSettings;
  64use ui::{
  65    prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip,
  66};
  67use util::{RangeExt, ResultExt};
  68use workspace::{dock::Panel, ShowConfiguration};
  69use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
  70
  71pub fn init(
  72    fs: Arc<dyn Fs>,
  73    prompt_builder: Arc<PromptBuilder>,
  74    telemetry: Arc<Telemetry>,
  75    cx: &mut AppContext,
  76) {
  77    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
  78    cx.observe_new_views(|workspace: &mut Workspace, cx| {
  79        workspace.register_action(InlineAssistant::toggle_inline_assist);
  80
  81        let workspace = cx.view().clone();
  82        InlineAssistant::update_global(cx, |inline_assistant, cx| {
  83            inline_assistant.register_workspace(&workspace, cx)
  84        })
  85    })
  86    .detach();
  87}
  88
/// Presumably the maximum number of entries kept in
/// `InlineAssistant::prompt_history`.
/// NOTE(review): the code that enforces this cap is outside this view — confirm.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
  90
/// The view an inline assist is started in, as resolved by
/// `InlineAssistant::resolve_inline_assist_target`.
enum InlineAssistTarget {
    /// A code editor; the assist is handled by [`InlineAssistant`].
    Editor(View<Editor>),
    /// A terminal; the assist is handled by [`TerminalInlineAssistant`].
    Terminal(View<TerminalView>),
}
  95
/// Global coordinator for inline assists in editors.
///
/// Owns every live assist, indexes them by editor and by group, and holds the
/// shared services (prompt builder, telemetry, filesystem) used to create them.
pub struct InlineAssistant {
    /// Id handed to the next assist that is created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group that is created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// All live assists, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock), keyed by a weak editor handle.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists that were created together (e.g. one per excerpt range) share a group.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    // NOTE(review): populated outside this view; presumably the codegen
    // alternatives of assists that were accepted — confirm.
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Previously submitted prompts; cloned into every new prompt editor.
    prompt_history: VecDeque<String>,
    /// Handed to each `Codegen` this assistant creates.
    prompt_builder: Arc<PromptBuilder>,
    /// Receives assistant telemetry events (e.g. the `Invoked` phase).
    telemetry: Arc<Telemetry>,
    /// Handed to each `PromptEditor` this assistant creates.
    fs: Arc<dyn Fs>,
}
 108
// Registered with `cx.set_global` in `init`, so action and event handlers can
// reach the single instance through `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
 110
 111impl InlineAssistant {
 112    pub fn new(
 113        fs: Arc<dyn Fs>,
 114        prompt_builder: Arc<PromptBuilder>,
 115        telemetry: Arc<Telemetry>,
 116    ) -> Self {
 117        Self {
 118            next_assist_id: InlineAssistId::default(),
 119            next_assist_group_id: InlineAssistGroupId::default(),
 120            assists: HashMap::default(),
 121            assists_by_editor: HashMap::default(),
 122            assist_groups: HashMap::default(),
 123            confirmed_assists: HashMap::default(),
 124            prompt_history: VecDeque::default(),
 125            prompt_builder,
 126            telemetry,
 127            fs,
 128        }
 129    }
 130
 131    pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
 132        cx.subscribe(workspace, |workspace, event, cx| {
 133            Self::update_global(cx, |this, cx| {
 134                this.handle_workspace_event(workspace, event, cx)
 135            });
 136        })
 137        .detach();
 138
 139        let workspace = workspace.downgrade();
 140        cx.observe_global::<SettingsStore>(move |cx| {
 141            let Some(workspace) = workspace.upgrade() else {
 142                return;
 143            };
 144            let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
 145                return;
 146            };
 147            let enabled = AssistantSettings::get_global(cx).enabled;
 148            terminal_panel.update(cx, |terminal_panel, cx| {
 149                terminal_panel.asssistant_enabled(enabled, cx)
 150            });
 151        })
 152        .detach();
 153    }
 154
 155    fn handle_workspace_event(
 156        &mut self,
 157        workspace: View<Workspace>,
 158        event: &workspace::Event,
 159        cx: &mut WindowContext,
 160    ) {
 161        match event {
 162            workspace::Event::UserSavedItem { item, .. } => {
 163                // When the user manually saves an editor, automatically accepts all finished transformations.
 164                if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
 165                    if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
 166                        for assist_id in editor_assists.assist_ids.clone() {
 167                            let assist = &self.assists[&assist_id];
 168                            if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
 169                                self.finish_assist(assist_id, false, cx)
 170                            }
 171                        }
 172                    }
 173                }
 174            }
 175            workspace::Event::ItemAdded { item } => {
 176                self.register_workspace_item(&workspace, item.as_ref(), cx);
 177            }
 178            _ => (),
 179        }
 180    }
 181
 182    fn register_workspace_item(
 183        &mut self,
 184        workspace: &View<Workspace>,
 185        item: &dyn ItemHandle,
 186        cx: &mut WindowContext,
 187    ) {
 188        if let Some(editor) = item.act_as::<Editor>(cx) {
 189            editor.update(cx, |editor, cx| {
 190                let thread_store = workspace
 191                    .read(cx)
 192                    .panel::<AssistantPanel>(cx)
 193                    .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
 194
 195                editor.push_code_action_provider(
 196                    Rc::new(AssistantCodeActionProvider {
 197                        editor: cx.view().downgrade(),
 198                        workspace: workspace.downgrade(),
 199                        thread_store,
 200                    }),
 201                    cx,
 202                );
 203            });
 204        }
 205    }
 206
 207    pub fn toggle_inline_assist(
 208        workspace: &mut Workspace,
 209        _action: &ToggleInlineAssist,
 210        cx: &mut ViewContext<Workspace>,
 211    ) {
 212        let settings = AssistantSettings::get_global(cx);
 213        if !settings.enabled {
 214            return;
 215        }
 216
 217        let Some(inline_assist_target) = Self::resolve_inline_assist_target(workspace, cx) else {
 218            return;
 219        };
 220
 221        let is_authenticated = || {
 222            LanguageModelRegistry::read_global(cx)
 223                .active_provider()
 224                .map_or(false, |provider| provider.is_authenticated(cx))
 225        };
 226
 227        let thread_store = workspace
 228            .panel::<AssistantPanel>(cx)
 229            .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
 230
 231        let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
 232            InlineAssistTarget::Editor(active_editor) => {
 233                InlineAssistant::update_global(cx, |assistant, cx| {
 234                    assistant.assist(&active_editor, cx.view().downgrade(), thread_store, cx)
 235                })
 236            }
 237            InlineAssistTarget::Terminal(active_terminal) => {
 238                TerminalInlineAssistant::update_global(cx, |assistant, cx| {
 239                    assistant.assist(&active_terminal, cx.view().downgrade(), thread_store, cx)
 240                })
 241            }
 242        };
 243
 244        if is_authenticated() {
 245            handle_assist(cx);
 246        } else {
 247            cx.spawn(|_workspace, mut cx| async move {
 248                let Some(task) = cx.update(|cx| {
 249                    LanguageModelRegistry::read_global(cx)
 250                        .active_provider()
 251                        .map_or(None, |provider| Some(provider.authenticate(cx)))
 252                })?
 253                else {
 254                    let answer = cx
 255                        .prompt(
 256                            gpui::PromptLevel::Warning,
 257                            "No language model provider configured",
 258                            None,
 259                            &["Configure", "Cancel"],
 260                        )
 261                        .await
 262                        .ok();
 263                    if let Some(answer) = answer {
 264                        if answer == 0 {
 265                            cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
 266                                .ok();
 267                        }
 268                    }
 269                    return Ok(());
 270                };
 271                task.await?;
 272
 273                anyhow::Ok(())
 274            })
 275            .detach_and_log_err(cx);
 276
 277            if is_authenticated() {
 278                handle_assist(cx);
 279            }
 280        }
 281    }
 282
    /// Starts an inline assist for the current selections of `editor`.
    ///
    /// Non-empty selections are widened to whole lines, and selections that then
    /// touch or overlap the previous one are merged. One assist is created per
    /// excerpt range covered by the resulting selections. All of them share a
    /// single prompt buffer and a single assist group. The first assist whose
    /// range edge matches the newest selection's head is focused.
    ///
    /// An `Invoked` telemetry event is reported for each excerpt range when a
    /// language model is active.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: WeakView<Workspace>,
        thread_store: Option<WeakModel<ThreadStore>>,
        cx: &mut WindowContext,
    ) {
        let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
            (
                editor.buffer().read(cx).snapshot(cx),
                editor.selections.all::<Point>(cx),
            )
        });

        // Normalize the selections: expand non-empty ones to whole lines, then merge any
        // selection that starts at or before the end of the previous one.
        let mut selections = Vec::<Selection<Point>>::new();
        let mut newest_selection = None;
        for mut selection in initial_selections {
            if selection.end > selection.start {
                selection.start.column = 0;
                // If the selection ends at the start of the line, we don't want to include it.
                // (Cannot underflow: end > start with end.column == 0 implies end.row > start.row.)
                if selection.end.column == 0 {
                    selection.end.row -= 1;
                }
                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
            }

            if let Some(prev_selection) = selections.last_mut() {
                if selection.start <= prev_selection.end {
                    prev_selection.end = selection.end;
                    // NOTE(review): a merged selection is never considered for
                    // `newest_selection`. If it was the newest one, focus may not land on
                    // the merged assist — confirm this is intended.
                    continue;
                }
            }

            // Track the selection with the highest id (the newest one).
            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
            if selection.id > latest_selection.id {
                *latest_selection = selection.clone();
            }
            selections.push(selection);
        }
        // Assumes the editor always reports at least one selection; panics otherwise.
        let newest_selection = newest_selection.unwrap();

        // One buffer-anchored range per excerpt touched by the normalized selections.
        let mut codegen_ranges = Vec::new();
        for (excerpt_id, buffer, buffer_range) in
            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
            }))
        {
            let start = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_before(buffer_range.start),
            };
            let end = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_after(buffer_range.end),
            };
            codegen_ranges.push(start..end);

            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
                self.telemetry.report_assistant_event(AssistantEvent {
                    conversation_id: None,
                    kind: AssistantKind::Inline,
                    phase: AssistantPhase::Invoked,
                    message_id: None,
                    model: model.telemetry_id(),
                    model_provider: model.provider_id().to_string(),
                    response_latency: None,
                    error_message: None,
                    language_name: buffer.language().map(|language| language.name().to_proto()),
                });
            }
        }

        // Every assist created by this call shares one group and one prompt buffer.
        let assist_group_id = self.next_assist_group_id.post_inc();
        let prompt_buffer = cx.new_model(|cx| {
            MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
        });

        let mut assists = Vec::new();
        let mut assist_to_focus = None;
        for range in codegen_ranges {
            let assist_id = self.next_assist_id.post_inc();
            // Each assist gets its own context store and codegen.
            let context_store = cx.new_model(|_cx| ContextStore::new());
            let codegen = cx.new_model(|cx| {
                Codegen::new(
                    editor.read(cx).buffer().clone(),
                    range.clone(),
                    None,
                    context_store.clone(),
                    self.telemetry.clone(),
                    self.prompt_builder.clone(),
                    cx,
                )
            });

            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
            let prompt_editor = cx.new_view(|cx| {
                PromptEditor::new(
                    assist_id,
                    gutter_dimensions.clone(),
                    self.prompt_history.clone(),
                    prompt_buffer.clone(),
                    codegen.clone(),
                    self.fs.clone(),
                    context_store,
                    workspace.clone(),
                    thread_store.clone(),
                    cx,
                )
            });

            // Focus the first assist whose range edge coincides with the newest selection's
            // head: its start when the selection is reversed, its end otherwise.
            if assist_to_focus.is_none() {
                let focus_assist = if newest_selection.reversed {
                    range.start.to_point(&snapshot) == newest_selection.start
                } else {
                    range.end.to_point(&snapshot) == newest_selection.end
                };
                if focus_assist {
                    assist_to_focus = Some(assist_id);
                }
            }

            // Prompt block above the range and a zero-height divider block below it.
            let [prompt_block_id, end_block_id] =
                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

            assists.push((
                assist_id,
                range,
                prompt_editor,
                prompt_block_id,
                end_block_id,
            ));
        }

        // Register the new assists with the editor's bookkeeping and with their group.
        let editor_assists = self
            .assists_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
        let mut assist_group = InlineAssistGroup::new();
        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
            self.assists.insert(
                assist_id,
                InlineAssist::new(
                    assist_id,
                    assist_group_id,
                    editor,
                    &prompt_editor,
                    prompt_block_id,
                    end_block_id,
                    range,
                    prompt_editor.read(cx).codegen.clone(),
                    workspace.clone(),
                    cx,
                ),
            );
            assist_group.assist_ids.push(assist_id);
            editor_assists.assist_ids.push(assist_id);
        }
        self.assist_groups.insert(assist_group_id, assist_group);

        if let Some(assist_id) = assist_to_focus {
            self.focus_assist(assist_id, cx);
        }
    }
 448
    /// Creates a single inline assist over `range`, pre-filled with `initial_prompt`.
    ///
    /// The range's start is biased left and its end biased right, presumably so
    /// that text inserted at either boundary ends up inside the assist range.
    /// `initial_transaction_id` is forwarded to the assist's `Codegen`. The assist
    /// gets its own new group and is focused only when `focus` is true.
    ///
    /// Returns the id of the created assist.
    // The argument list mirrors the independent pieces of state a suggested assist is
    // seeded with.
    #[allow(clippy::too_many_arguments)]
    pub fn suggest_assist(
        &mut self,
        editor: &View<Editor>,
        mut range: Range<Anchor>,
        initial_prompt: String,
        initial_transaction_id: Option<TransactionId>,
        focus: bool,
        workspace: WeakView<Workspace>,
        thread_store: Option<WeakModel<ThreadStore>>,
        cx: &mut WindowContext,
    ) -> InlineAssistId {
        let assist_group_id = self.next_assist_group_id.post_inc();
        let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));

        let assist_id = self.next_assist_id.post_inc();

        let buffer = editor.read(cx).buffer().clone();
        // Scoped so the snapshot borrow ends before `cx` is used again below.
        {
            let snapshot = buffer.read(cx).read(cx);
            range.start = range.start.bias_left(&snapshot);
            range.end = range.end.bias_right(&snapshot);
        }

        let context_store = cx.new_model(|_cx| ContextStore::new());

        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                range.clone(),
                initial_transaction_id,
                context_store.clone(),
                self.telemetry.clone(),
                self.prompt_builder.clone(),
                cx,
            )
        });

        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let prompt_editor = cx.new_view(|cx| {
            PromptEditor::new(
                assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                prompt_buffer.clone(),
                codegen.clone(),
                self.fs.clone(),
                context_store,
                workspace.clone(),
                thread_store,
                cx,
            )
        });

        // Prompt block above the range and a zero-height divider block below it.
        let [prompt_block_id, end_block_id] =
            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

        let editor_assists = self
            .assists_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));

        // A suggested assist always forms a group of its own.
        let mut assist_group = InlineAssistGroup::new();
        self.assists.insert(
            assist_id,
            InlineAssist::new(
                assist_id,
                assist_group_id,
                editor,
                &prompt_editor,
                prompt_block_id,
                end_block_id,
                range,
                prompt_editor.read(cx).codegen.clone(),
                workspace.clone(),
                cx,
            ),
        );
        assist_group.assist_ids.push(assist_id);
        editor_assists.assist_ids.push(assist_id);
        self.assist_groups.insert(assist_group_id, assist_group);

        if focus {
            self.focus_assist(assist_id, cx);
        }

        assist_id
    }
 538
 539    fn insert_assist_blocks(
 540        &self,
 541        editor: &View<Editor>,
 542        range: &Range<Anchor>,
 543        prompt_editor: &View<PromptEditor>,
 544        cx: &mut WindowContext,
 545    ) -> [CustomBlockId; 2] {
 546        let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
 547            prompt_editor
 548                .editor
 549                .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
 550        });
 551        let assist_blocks = vec![
 552            BlockProperties {
 553                style: BlockStyle::Sticky,
 554                placement: BlockPlacement::Above(range.start),
 555                height: prompt_editor_height,
 556                render: build_assist_editor_renderer(prompt_editor),
 557                priority: 0,
 558            },
 559            BlockProperties {
 560                style: BlockStyle::Sticky,
 561                placement: BlockPlacement::Below(range.end),
 562                height: 0,
 563                render: Arc::new(|cx| {
 564                    v_flex()
 565                        .h_full()
 566                        .w_full()
 567                        .border_t_1()
 568                        .border_color(cx.theme().status().info_border)
 569                        .into_any_element()
 570                }),
 571                priority: 0,
 572            },
 573        ];
 574
 575        editor.update(cx, |editor, cx| {
 576            let block_ids = editor.insert_blocks(assist_blocks, None, cx);
 577            [block_ids[0], block_ids[1]]
 578        })
 579    }
 580
    /// Handles an assist's prompt editor gaining focus.
    ///
    /// Makes the assist its group's active assist. For linked groups, every
    /// member's prompt editor is also told to keep showing its cursor while
    /// unfocused. If the assist's prompt block is currently within the visible
    /// rows, the editor's scroll lock is set to keep it at the same distance from
    /// the top; otherwise the scroll lock is cleared. Does nothing for an assist
    /// without decorations.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        // Disjoint-field borrows of `self`; both entries are expected to exist for a live
        // assist (unwrap panics otherwise).
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        // `assist.editor` is a weak handle; the update presumably fails if the editor has
        // been released, and that failure is ignored via `.ok()`.
        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                // NOTE(review): panics if `row_for_block` returns `None` for the prompt
                // block — confirm that cannot happen here.
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            .ok();
    }
 621
 622    fn handle_prompt_editor_focus_out(
 623        &mut self,
 624        assist_id: InlineAssistId,
 625        cx: &mut WindowContext,
 626    ) {
 627        let assist = &self.assists[&assist_id];
 628        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 629        if assist_group.active_assist_id == Some(assist_id) {
 630            assist_group.active_assist_id = None;
 631            if assist_group.linked {
 632                for assist_id in &assist_group.assist_ids {
 633                    if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 634                        decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 635                            prompt_editor.set_show_cursor_when_unfocused(false, cx)
 636                        });
 637                    }
 638                }
 639            }
 640        }
 641    }
 642
 643    fn handle_prompt_editor_event(
 644        &mut self,
 645        prompt_editor: View<PromptEditor>,
 646        event: &PromptEditorEvent,
 647        cx: &mut WindowContext,
 648    ) {
 649        let assist_id = prompt_editor.read(cx).id;
 650        match event {
 651            PromptEditorEvent::StartRequested => {
 652                self.start_assist(assist_id, cx);
 653            }
 654            PromptEditorEvent::StopRequested => {
 655                self.stop_assist(assist_id, cx);
 656            }
 657            PromptEditorEvent::ConfirmRequested => {
 658                self.finish_assist(assist_id, false, cx);
 659            }
 660            PromptEditorEvent::CancelRequested => {
 661                self.finish_assist(assist_id, true, cx);
 662            }
 663            PromptEditorEvent::DismissRequested => {
 664                self.dismiss_assist(assist_id, cx);
 665            }
 666        }
 667    }
 668
 669    fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 670        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 671            return;
 672        };
 673
 674        if editor.read(cx).selections.count() == 1 {
 675            let (selection, buffer) = editor.update(cx, |editor, cx| {
 676                (
 677                    editor.selections.newest::<usize>(cx),
 678                    editor.buffer().read(cx).snapshot(cx),
 679                )
 680            });
 681            for assist_id in &editor_assists.assist_ids {
 682                let assist = &self.assists[assist_id];
 683                let assist_range = assist.range.to_offset(&buffer);
 684                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
 685                {
 686                    if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
 687                        self.dismiss_assist(*assist_id, cx);
 688                    } else {
 689                        self.finish_assist(*assist_id, false, cx);
 690                    }
 691
 692                    return;
 693                }
 694            }
 695        }
 696
 697        cx.propagate();
 698    }
 699
 700    fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 701        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 702            return;
 703        };
 704
 705        if editor.read(cx).selections.count() == 1 {
 706            let (selection, buffer) = editor.update(cx, |editor, cx| {
 707                (
 708                    editor.selections.newest::<usize>(cx),
 709                    editor.buffer().read(cx).snapshot(cx),
 710                )
 711            });
 712            let mut closest_assist_fallback = None;
 713            for assist_id in &editor_assists.assist_ids {
 714                let assist = &self.assists[assist_id];
 715                let assist_range = assist.range.to_offset(&buffer);
 716                if assist.decorations.is_some() {
 717                    if assist_range.contains(&selection.start)
 718                        && assist_range.contains(&selection.end)
 719                    {
 720                        self.focus_assist(*assist_id, cx);
 721                        return;
 722                    } else {
 723                        let distance_from_selection = assist_range
 724                            .start
 725                            .abs_diff(selection.start)
 726                            .min(assist_range.start.abs_diff(selection.end))
 727                            + assist_range
 728                                .end
 729                                .abs_diff(selection.start)
 730                                .min(assist_range.end.abs_diff(selection.end));
 731                        match closest_assist_fallback {
 732                            Some((_, old_distance)) => {
 733                                if distance_from_selection < old_distance {
 734                                    closest_assist_fallback =
 735                                        Some((assist_id, distance_from_selection));
 736                                }
 737                            }
 738                            None => {
 739                                closest_assist_fallback = Some((assist_id, distance_from_selection))
 740                            }
 741                        }
 742                    }
 743                }
 744            }
 745
 746            if let Some((&assist_id, _)) = closest_assist_fallback {
 747                self.focus_assist(assist_id, cx);
 748            }
 749        }
 750
 751        cx.propagate();
 752    }
 753
 754    fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
 755        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
 756            for assist_id in editor_assists.assist_ids.clone() {
 757                self.finish_assist(assist_id, true, cx);
 758            }
 759        }
 760    }
 761
 762    fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 763        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 764            return;
 765        };
 766        let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
 767            return;
 768        };
 769        let assist = &self.assists[&scroll_lock.assist_id];
 770        let Some(decorations) = assist.decorations.as_ref() else {
 771            return;
 772        };
 773
 774        editor.update(cx, |editor, cx| {
 775            let scroll_position = editor.scroll_position(cx);
 776            let target_scroll_top = editor
 777                .row_for_block(decorations.prompt_block_id, cx)
 778                .unwrap()
 779                .0 as f32
 780                - scroll_lock.distance_from_top;
 781            if target_scroll_top != scroll_position.y {
 782                editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
 783            }
 784        });
 785    }
 786
 787    fn handle_editor_event(
 788        &mut self,
 789        editor: View<Editor>,
 790        event: &EditorEvent,
 791        cx: &mut WindowContext,
 792    ) {
 793        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
 794            return;
 795        };
 796
 797        match event {
 798            EditorEvent::Edited { transaction_id } => {
 799                let buffer = editor.read(cx).buffer().read(cx);
 800                let edited_ranges =
 801                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
 802                let snapshot = buffer.snapshot(cx);
 803
 804                for assist_id in editor_assists.assist_ids.clone() {
 805                    let assist = &self.assists[&assist_id];
 806                    if matches!(
 807                        assist.codegen.read(cx).status(cx),
 808                        CodegenStatus::Error(_) | CodegenStatus::Done
 809                    ) {
 810                        let assist_range = assist.range.to_offset(&snapshot);
 811                        if edited_ranges
 812                            .iter()
 813                            .any(|range| range.overlaps(&assist_range))
 814                        {
 815                            self.finish_assist(assist_id, false, cx);
 816                        }
 817                    }
 818                }
 819            }
 820            EditorEvent::ScrollPositionChanged { .. } => {
 821                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
 822                    let assist = &self.assists[&scroll_lock.assist_id];
 823                    if let Some(decorations) = assist.decorations.as_ref() {
 824                        let distance_from_top = editor.update(cx, |editor, cx| {
 825                            let scroll_top = editor.scroll_position(cx).y;
 826                            let prompt_row = editor
 827                                .row_for_block(decorations.prompt_block_id, cx)
 828                                .unwrap()
 829                                .0 as f32;
 830                            prompt_row - scroll_top
 831                        });
 832
 833                        if distance_from_top != scroll_lock.distance_from_top {
 834                            editor_assists.scroll_lock = None;
 835                        }
 836                    }
 837                }
 838            }
 839            EditorEvent::SelectionsChanged { .. } => {
 840                for assist_id in editor_assists.assist_ids.clone() {
 841                    let assist = &self.assists[&assist_id];
 842                    if let Some(decorations) = assist.decorations.as_ref() {
 843                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
 844                            return;
 845                        }
 846                    }
 847                }
 848
 849                editor_assists.scroll_lock = None;
 850            }
 851            _ => {}
 852        }
 853    }
 854
    /// Ends the given inline assist and tears down its UI and bookkeeping.
    ///
    /// If the assist's group is still linked, the group is unlinked and every member
    /// is finished (recursively) with the same `undo` flag instead. Otherwise:
    /// - the assist's decorations are dismissed;
    /// - the assist is removed from `assists`, from its group, and from its editor's
    ///   entry, and the group or editor entry is dropped once it becomes empty;
    /// - a telemetry event is reported if an active language model is set. It is
    ///   reported as `Rejected` when `undo` is true and `Accepted` otherwise;
    /// - with `undo` the codegen's changes are undone, otherwise the active
    ///   alternative is recorded in `confirmed_assists`.
    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            // A linked group is finished as a unit: unlink it, then finish each member.
            if self.assist_groups[&assist_group_id].linked {
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        // Removes the assist's blocks, if they are still shown.
        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            // Drop the assist from its group, and drop the group once it is empty.
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            // Drop the assist from its editor's list. If it was the last one, the
            // editor's entry (which owns the highlight-update task) is removed, so the
            // highlights are refreshed directly. Otherwise that task is notified.
            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            let active_alternative = assist.codegen.read(cx).active_alternative().clone();
            let message_id = active_alternative.read(cx).message_id.clone();

            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
                // Language of the first buffer covered by the assist's range, if any.
                let language_name = assist.editor.upgrade().and_then(|editor| {
                    let multibuffer = editor.read(cx).buffer().read(cx);
                    let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
                    ranges
                        .first()
                        .and_then(|(buffer, _, _)| buffer.read(cx).language())
                        .map(|language| language.name())
                });
                report_assistant_event(
                    AssistantEvent {
                        conversation_id: None,
                        kind: AssistantKind::Inline,
                        message_id,
                        phase: if undo {
                            AssistantPhase::Rejected
                        } else {
                            AssistantPhase::Accepted
                        },
                        model: model.telemetry_id(),
                        model_provider: model.provider_id().to_string(),
                        response_latency: None,
                        error_message: None,
                        language_name: language_name.map(|name| name.to_proto()),
                    },
                    Some(self.telemetry.clone()),
                    cx.http_client(),
                    model.api_key(cx),
                    cx.background_executor(),
                );
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, active_alternative);
            }
        }
    }
 933
 934    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
 935        let Some(assist) = self.assists.get_mut(&assist_id) else {
 936            return false;
 937        };
 938        let Some(editor) = assist.editor.upgrade() else {
 939            return false;
 940        };
 941        let Some(decorations) = assist.decorations.take() else {
 942            return false;
 943        };
 944
 945        editor.update(cx, |editor, cx| {
 946            let mut to_remove = decorations.removed_line_block_ids;
 947            to_remove.insert(decorations.prompt_block_id);
 948            to_remove.insert(decorations.end_block_id);
 949            editor.remove_blocks(to_remove, None, cx);
 950        });
 951
 952        if decorations
 953            .prompt_editor
 954            .focus_handle(cx)
 955            .contains_focused(cx)
 956        {
 957            self.focus_next_assist(assist_id, cx);
 958        }
 959
 960        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
 961            if editor_assists
 962                .scroll_lock
 963                .as_ref()
 964                .map_or(false, |lock| lock.assist_id == assist_id)
 965            {
 966                editor_assists.scroll_lock = None;
 967            }
 968            editor_assists.highlight_updates.send(()).ok();
 969        }
 970
 971        true
 972    }
 973
 974    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 975        let Some(assist) = self.assists.get(&assist_id) else {
 976            return;
 977        };
 978
 979        let assist_group = &self.assist_groups[&assist.group_id];
 980        let assist_ix = assist_group
 981            .assist_ids
 982            .iter()
 983            .position(|id| *id == assist_id)
 984            .unwrap();
 985        let assist_ids = assist_group
 986            .assist_ids
 987            .iter()
 988            .skip(assist_ix + 1)
 989            .chain(assist_group.assist_ids.iter().take(assist_ix));
 990
 991        for assist_id in assist_ids {
 992            let assist = &self.assists[assist_id];
 993            if assist.decorations.is_some() {
 994                self.focus_assist(*assist_id, cx);
 995                return;
 996            }
 997        }
 998
 999        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
1000    }
1001
1002    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1003        let Some(assist) = self.assists.get(&assist_id) else {
1004            return;
1005        };
1006
1007        if let Some(decorations) = assist.decorations.as_ref() {
1008            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
1009                prompt_editor.editor.update(cx, |editor, cx| {
1010                    editor.focus(cx);
1011                    editor.select_all(&SelectAll, cx);
1012                })
1013            });
1014        }
1015
1016        self.scroll_to_assist(assist_id, cx);
1017    }
1018
1019    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1020        let Some(assist) = self.assists.get(&assist_id) else {
1021            return;
1022        };
1023        let Some(editor) = assist.editor.upgrade() else {
1024            return;
1025        };
1026
1027        let position = assist.range.start;
1028        editor.update(cx, |editor, cx| {
1029            editor.change_selections(None, cx, |selections| {
1030                selections.select_anchor_ranges([position..position])
1031            });
1032
1033            let mut scroll_target_top;
1034            let mut scroll_target_bottom;
1035            if let Some(decorations) = assist.decorations.as_ref() {
1036                scroll_target_top = editor
1037                    .row_for_block(decorations.prompt_block_id, cx)
1038                    .unwrap()
1039                    .0 as f32;
1040                scroll_target_bottom = editor
1041                    .row_for_block(decorations.end_block_id, cx)
1042                    .unwrap()
1043                    .0 as f32;
1044            } else {
1045                let snapshot = editor.snapshot(cx);
1046                let start_row = assist
1047                    .range
1048                    .start
1049                    .to_display_point(&snapshot.display_snapshot)
1050                    .row();
1051                scroll_target_top = start_row.0 as f32;
1052                scroll_target_bottom = scroll_target_top + 1.;
1053            }
1054            scroll_target_top -= editor.vertical_scroll_margin() as f32;
1055            scroll_target_bottom += editor.vertical_scroll_margin() as f32;
1056
1057            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
1058            let scroll_top = editor.scroll_position(cx).y;
1059            let scroll_bottom = scroll_top + height_in_lines;
1060
1061            if scroll_target_top < scroll_top {
1062                editor.set_scroll_position(point(0., scroll_target_top), cx);
1063            } else if scroll_target_bottom > scroll_bottom {
1064                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
1065                    editor
1066                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
1067                } else {
1068                    editor.set_scroll_position(point(0., scroll_target_top), cx);
1069                }
1070            }
1071        });
1072    }
1073
1074    fn unlink_assist_group(
1075        &mut self,
1076        assist_group_id: InlineAssistGroupId,
1077        cx: &mut WindowContext,
1078    ) -> Vec<InlineAssistId> {
1079        let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1080        assist_group.linked = false;
1081        for assist_id in &assist_group.assist_ids {
1082            let assist = self.assists.get_mut(assist_id).unwrap();
1083            if let Some(editor_decorations) = assist.decorations.as_ref() {
1084                editor_decorations
1085                    .prompt_editor
1086                    .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
1087            }
1088        }
1089        assist_group.assist_ids.clone()
1090    }
1091
1092    pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1093        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1094            assist
1095        } else {
1096            return;
1097        };
1098
1099        let assist_group_id = assist.group_id;
1100        if self.assist_groups[&assist_group_id].linked {
1101            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1102                self.start_assist(assist_id, cx);
1103            }
1104            return;
1105        }
1106
1107        let Some(user_prompt) = assist.user_prompt(cx) else {
1108            return;
1109        };
1110
1111        self.prompt_history.retain(|prompt| *prompt != user_prompt);
1112        self.prompt_history.push_back(user_prompt.clone());
1113        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1114            self.prompt_history.pop_front();
1115        }
1116
1117        assist
1118            .codegen
1119            .update(cx, |codegen, cx| codegen.start(user_prompt, cx))
1120            .log_err();
1121    }
1122
1123    pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1124        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1125            assist
1126        } else {
1127            return;
1128        };
1129
1130        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1131    }
1132
1133    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1134        let mut gutter_pending_ranges = Vec::new();
1135        let mut gutter_transformed_ranges = Vec::new();
1136        let mut foreground_ranges = Vec::new();
1137        let mut inserted_row_ranges = Vec::new();
1138        let empty_assist_ids = Vec::new();
1139        let assist_ids = self
1140            .assists_by_editor
1141            .get(&editor.downgrade())
1142            .map_or(&empty_assist_ids, |editor_assists| {
1143                &editor_assists.assist_ids
1144            });
1145
1146        for assist_id in assist_ids {
1147            if let Some(assist) = self.assists.get(assist_id) {
1148                let codegen = assist.codegen.read(cx);
1149                let buffer = codegen.buffer(cx).read(cx).read(cx);
1150                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1151
1152                let pending_range =
1153                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1154                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1155                    gutter_pending_ranges.push(pending_range);
1156                }
1157
1158                if let Some(edit_position) = codegen.edit_position(cx) {
1159                    let edited_range = assist.range.start..edit_position;
1160                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1161                        gutter_transformed_ranges.push(edited_range);
1162                    }
1163                }
1164
1165                if assist.decorations.is_some() {
1166                    inserted_row_ranges
1167                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1168                }
1169            }
1170        }
1171
1172        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1173        merge_ranges(&mut foreground_ranges, &snapshot);
1174        merge_ranges(&mut gutter_pending_ranges, &snapshot);
1175        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1176        editor.update(cx, |editor, cx| {
1177            enum GutterPendingRange {}
1178            if gutter_pending_ranges.is_empty() {
1179                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1180            } else {
1181                editor.highlight_gutter::<GutterPendingRange>(
1182                    &gutter_pending_ranges,
1183                    |cx| cx.theme().status().info_background,
1184                    cx,
1185                )
1186            }
1187
1188            enum GutterTransformedRange {}
1189            if gutter_transformed_ranges.is_empty() {
1190                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1191            } else {
1192                editor.highlight_gutter::<GutterTransformedRange>(
1193                    &gutter_transformed_ranges,
1194                    |cx| cx.theme().status().info,
1195                    cx,
1196                )
1197            }
1198
1199            if foreground_ranges.is_empty() {
1200                editor.clear_highlights::<InlineAssist>(cx);
1201            } else {
1202                editor.highlight_text::<InlineAssist>(
1203                    foreground_ranges,
1204                    HighlightStyle {
1205                        fade_out: Some(0.6),
1206                        ..Default::default()
1207                    },
1208                    cx,
1209                );
1210            }
1211
1212            editor.clear_row_highlights::<InlineAssist>();
1213            for row_range in inserted_row_ranges {
1214                editor.highlight_rows::<InlineAssist>(
1215                    row_range,
1216                    cx.theme().status().info_background,
1217                    false,
1218                    cx,
1219                );
1220            }
1221        });
1222    }
1223
    /// Rebuilds the read-only "deleted lines" blocks displayed for an assist.
    ///
    /// The blocks previously recorded in `decorations.removed_line_block_ids` are
    /// removed. Then, for each entry of the codegen diff's `deleted_row_ranges`, a
    /// read-only, gutter-less, non-wrapping editor is created. It shows those rows of
    /// the codegen's `old_buffer` and is inserted as a block above the paired
    /// position. The new block ids are stored back into the decorations.
    /// Does nothing if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the previous deleted-line blocks; they are rebuilt from the diff below.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Convert the inclusive old row range into offsets within the old
                // buffer: from the start of the first row to the end of the last row.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only editor over just the deleted text, with every row
                // highlighted in the theme's "deleted" background.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        false,
                        cx,
                    );
                    editor
                });

                // Block height in lines: number of rows in the deleted-lines editor.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height,
                    style: BlockStyle::Flex,
                    render: Arc::new(move |cx| {
                        div()
                            .block_mouse_down()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                });
            }

            // Remember the new block ids so the next rebuild (or dismissal) can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1316
1317    fn resolve_inline_assist_target(
1318        workspace: &mut Workspace,
1319        cx: &mut WindowContext,
1320    ) -> Option<InlineAssistTarget> {
1321        if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) {
1322            if terminal_panel
1323                .read(cx)
1324                .focus_handle(cx)
1325                .contains_focused(cx)
1326            {
1327                if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
1328                    pane.read(cx)
1329                        .active_item()
1330                        .and_then(|t| t.downcast::<TerminalView>())
1331                }) {
1332                    return Some(InlineAssistTarget::Terminal(terminal_view));
1333                }
1334            }
1335        }
1336
1337        if let Some(workspace_editor) = workspace
1338            .active_item(cx)
1339            .and_then(|item| item.act_as::<Editor>(cx))
1340        {
1341            Some(InlineAssistTarget::Editor(workspace_editor))
1342        } else if let Some(terminal_view) = workspace
1343            .active_item(cx)
1344            .and_then(|item| item.act_as::<TerminalView>(cx))
1345        {
1346            Some(InlineAssistTarget::Terminal(terminal_view))
1347        } else {
1348            None
1349        }
1350    }
1351}
1352
/// Bookkeeping for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Ids of the assists currently attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, `InlineAssistant::handle_editor_change` keeps the locked assist's
    /// prompt block at a fixed distance from the top of the viewport. It is cleared by
    /// scroll/selection events in `handle_editor_event` and when the locked assist is
    /// dismissed.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here wakes `_update_highlights`, which recomputes the highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Task that calls `InlineAssistant::update_editor_highlights` on each notification.
    _update_highlights: Task<Result<()>>,
    /// Editor observations, event subscriptions, and action registrations, held for
    /// as long as this entry exists.
    _subscriptions: Vec<gpui::Subscription>,
}
1360
/// Pins an assist's prompt block at a fixed vertical offset within the viewport.
struct InlineAssistScrollLock {
    /// The assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Prompt block row minus the editor's scroll top, measured in (fractional) rows.
    distance_from_top: f32,
}
1365
1366impl EditorInlineAssists {
1367    #[allow(clippy::too_many_arguments)]
1368    fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1369        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1370        Self {
1371            assist_ids: Vec::new(),
1372            scroll_lock: None,
1373            highlight_updates: highlight_updates_tx,
1374            _update_highlights: cx.spawn(|mut cx| {
1375                let editor = editor.downgrade();
1376                async move {
1377                    while let Ok(()) = highlight_updates_rx.changed().await {
1378                        let editor = editor.upgrade().context("editor was dropped")?;
1379                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
1380                            assistant.update_editor_highlights(&editor, cx);
1381                        })?;
1382                    }
1383                    Ok(())
1384                }
1385            }),
1386            _subscriptions: vec![
1387                cx.observe_release(editor, {
1388                    let editor = editor.downgrade();
1389                    |_, cx| {
1390                        InlineAssistant::update_global(cx, |this, cx| {
1391                            this.handle_editor_release(editor, cx);
1392                        })
1393                    }
1394                }),
1395                cx.observe(editor, move |editor, cx| {
1396                    InlineAssistant::update_global(cx, |this, cx| {
1397                        this.handle_editor_change(editor, cx)
1398                    })
1399                }),
1400                cx.subscribe(editor, move |editor, event, cx| {
1401                    InlineAssistant::update_global(cx, |this, cx| {
1402                        this.handle_editor_event(editor, event, cx)
1403                    })
1404                }),
1405                editor.update(cx, |editor, cx| {
1406                    let editor_handle = cx.view().downgrade();
1407                    editor.register_action(
1408                        move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1409                            InlineAssistant::update_global(cx, |this, cx| {
1410                                if let Some(editor) = editor_handle.upgrade() {
1411                                    this.handle_editor_newline(editor, cx)
1412                                }
1413                            })
1414                        },
1415                    )
1416                }),
1417                editor.update(cx, |editor, cx| {
1418                    let editor_handle = cx.view().downgrade();
1419                    editor.register_action(
1420                        move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1421                            InlineAssistant::update_global(cx, |this, cx| {
1422                                if let Some(editor) = editor_handle.upgrade() {
1423                                    this.handle_editor_cancel(editor, cx)
1424                                }
1425                            })
1426                        },
1427                    )
1428                }),
1429            ],
1430        }
1431    }
1432}
1433
1434struct InlineAssistGroup {
1435    assist_ids: Vec<InlineAssistId>,
1436    linked: bool,
1437    active_assist_id: Option<InlineAssistId>,
1438}
1439
1440impl InlineAssistGroup {
1441    fn new() -> Self {
1442        Self {
1443            assist_ids: Vec::new(),
1444            linked: true,
1445            active_assist_id: None,
1446        }
1447    }
1448}
1449
1450fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1451    let editor = editor.clone();
1452    Arc::new(move |cx: &mut BlockContext| {
1453        *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1454        editor.clone().into_any_element()
1455    })
1456}
1457
/// Identifier of an inline assist; `Default` yields id 0.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
1468
/// Identifier of an inline assist group; `Default` yields id 0.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        *self = InlineAssistGroupId(current + 1);
        InlineAssistGroupId(current)
    }
}
1479
/// Requests emitted by a `PromptEditor` for its subscribers to act on.
enum PromptEditorEvent {
    /// Emitted by the "start" (transform) button while codegen is idle.
    StartRequested,
    /// Emitted by the "stop" button while codegen is pending.
    StopRequested,
    /// NOTE(review): not emitted in the visible code; presumably asks to accept the
    /// generated result — confirm.
    ConfirmRequested,
    /// Emitted by the "cancel" buttons.
    CancelRequested,
    /// NOTE(review): not emitted in the visible code; presumably asks to dismiss the
    /// prompt UI — confirm.
    DismissRequested,
}
1487
/// View rendered inside an inline assist's prompt block: the prompt input plus its
/// controls.
struct PromptEditor {
    /// Id of the assist this prompt editor belongs to (assumed from the type; set at
    /// construction, which is not visible here).
    id: InlineAssistId,
    /// Editor for the prompt text. `InlineAssistant::focus_assist` focuses it and
    /// selects all of its contents.
    editor: View<Editor>,
    /// Context strip view associated with the prompt.
    context_strip: View<ContextStrip>,
    /// Handle controlling the context picker popover menu.
    context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
    /// Language model selector view.
    language_model_selector: View<LanguageModelSelector>,
    /// Checked by `render`, together with a Done/Error status, when choosing which
    /// buttons to show.
    edited_since_done: bool,
    /// Host editor gutter dimensions. Written by `build_assist_editor_renderer` on
    /// every block render and read by `render`.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts (presumably for history navigation — confirm).
    prompt_history: VecDeque<String>,
    /// Presumably the current position in `prompt_history` while navigating — confirm.
    prompt_history_ix: Option<usize>,
    /// Presumably the in-progress prompt saved while browsing history — confirm.
    pending_prompt: String,
    /// Codegen model for this assist. `render` reads its status and alternative count.
    codegen: Model<Codegen>,
    /// Subscription to `codegen`, held for the lifetime of this view.
    _codegen_subscription: Subscription,
    /// Subscriptions related to `editor` (not created in the visible code).
    editor_subscriptions: Vec<Subscription>,
    /// Whether to show a rate-limit notice (usage not visible here).
    show_rate_limit_notice: bool,
}
1504
// Lets `PromptEditor` emit `PromptEditorEvent`s via `cx.emit`, as its buttons do in `render`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1506
impl Render for PromptEditor {
    /// Lays out the inline prompt as two rows:
    ///
    /// 1. A gutter-width column (model selector, plus an error indicator when codegen
    ///    failed), the prompt editor, and status-dependent action buttons.
    /// 2. A gutter-width spacer followed by the context strip.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        // Trailing action buttons; which ones are shown depends on the codegen status.
        let mut buttons = Vec::new();
        let codegen = self.codegen.read(cx);
        // Only show the "‹ n/m ›" cycler when there is more than one alternative.
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        let status = codegen.status(cx);
        buttons.extend(match status {
            // Not started: cancel the assist, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Generating: cancel the assist, or stop generation while keeping what has
            // been produced so far.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
                        .into_any_element(),
                ]
            }
            // Finished or failed: cancel, plus either restart (when the prompt was edited
            // since completion, or generation failed) or confirm the result. This mirrors
            // the branching in `confirm`.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                            .into_any_element()
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    },
                ]
            }
        });

        v_flex()
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            // Row 1: gutter-width controls, the prompt editor, and the action buttons.
            .child(
                h_flex()
                    .key_context("PromptEditor")
                    .bg(cx.theme().colors().editor_background)
                    .block_mouse_down()
                    .cursor(CursorStyle::Arrow)
                    .on_action(cx.listener(Self::toggle_context_picker))
                    .on_action(cx.listener(Self::confirm))
                    .on_action(cx.listener(Self::cancel))
                    .on_action(cx.listener(Self::move_up))
                    .on_action(cx.listener(Self::move_down))
                    // Cycling is registered in the capture phase rather than as a regular
                    // action. NOTE(review): presumably so the focused inner editor cannot
                    // handle these keystrokes first — confirm.
                    .capture_action(cx.listener(Self::cycle_prev))
                    .capture_action(cx.listener(Self::cycle_next))
                    .child(
                        // Column as wide as the host editor's gutter plus half its margin,
                        // holding the model selector and any error indicator.
                        h_flex()
                            .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                            .justify_center()
                            .gap_2()
                            .child(LanguageModelSelectorPopoverMenu::new(
                                self.language_model_selector.clone(),
                                IconButton::new("context", IconName::SettingsAlt)
                                    .shape(IconButtonShape::Square)
                                    .icon_size(IconSize::Small)
                                    .icon_color(Color::Muted)
                                    .tooltip(move |cx| {
                                        Tooltip::with_meta(
                                            format!(
                                                "Using {}",
                                                LanguageModelRegistry::read_global(cx)
                                                    .active_model()
                                                    .map(|model| model.name().0)
                                                    .unwrap_or_else(|| "No model selected".into()),
                                            ),
                                            None,
                                            "Change Model",
                                            cx,
                                        )
                                    }),
                            ))
                            .map(|el| {
                                // Append an error indicator only when codegen failed.
                                // Rate-limit errors for Zed Pro users get a toggleable
                                // notice popover; any other error gets an icon whose
                                // tooltip shows the error message.
                                let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx)
                                else {
                                    return el;
                                };

                                let error_message = SharedString::from(error.to_string());
                                if error.error_code() == proto::ErrorCode::RateLimitExceeded
                                    && cx.has_flag::<ZedPro>()
                                {
                                    el.child(
                                        v_flex()
                                            .child(
                                                IconButton::new(
                                                    "rate-limit-error",
                                                    IconName::XCircle,
                                                )
                                                .toggle_state(self.show_rate_limit_notice)
                                                .shape(IconButtonShape::Square)
                                                .icon_size(IconSize::Small)
                                                .on_click(
                                                    cx.listener(Self::toggle_rate_limit_notice),
                                                ),
                                            )
                                            .children(self.show_rate_limit_notice.then(|| {
                                                deferred(
                                                    anchored()
                                                        .position_mode(
                                                            gpui::AnchoredPositionMode::Local,
                                                        )
                                                        .position(point(px(0.), px(24.)))
                                                        .anchor(gpui::AnchorCorner::TopLeft)
                                                        .child(self.render_rate_limit_notice(cx)),
                                                )
                                            })),
                                    )
                                } else {
                                    el.child(
                                        div()
                                            .id("error")
                                            .tooltip(move |cx| {
                                                Tooltip::text(error_message.clone(), cx)
                                            })
                                            .child(
                                                Icon::new(IconName::XCircle)
                                                    .size(IconSize::Small)
                                                    .color(Color::Error),
                                            ),
                                    )
                                }
                            }),
                    )
                    .child(div().flex_1().child(self.render_editor(cx)))
                    .child(h_flex().gap_2().pr_6().children(buttons)),
            )
            // Row 2: an empty spacer of the same gutter width, so the context strip lines
            // up with the prompt editor in row 1.
            .child(
                h_flex()
                    .child(
                        h_flex()
                            .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                            .justify_center()
                            .gap_2(),
                    )
                    .child(self.context_strip.clone()),
            )
    }
}
1714
1715impl FocusableView for PromptEditor {
1716    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1717        self.editor.focus_handle(cx)
1718    }
1719}
1720
1721impl PromptEditor {
    /// Maximum height, in lines, of the auto-height prompt editor (used by both `new`
    /// and `unlink`).
    const MAX_LINES: u8 = 8;
1723
1724    #[allow(clippy::too_many_arguments)]
1725    fn new(
1726        id: InlineAssistId,
1727        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1728        prompt_history: VecDeque<String>,
1729        prompt_buffer: Model<MultiBuffer>,
1730        codegen: Model<Codegen>,
1731        fs: Arc<dyn Fs>,
1732        context_store: Model<ContextStore>,
1733        workspace: WeakView<Workspace>,
1734        thread_store: Option<WeakModel<ThreadStore>>,
1735        cx: &mut ViewContext<Self>,
1736    ) -> Self {
1737        let prompt_editor = cx.new_view(|cx| {
1738            let mut editor = Editor::new(
1739                EditorMode::AutoHeight {
1740                    max_lines: Self::MAX_LINES as usize,
1741                },
1742                prompt_buffer,
1743                None,
1744                false,
1745                cx,
1746            );
1747            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1748            // Since the prompt editors for all inline assistants are linked,
1749            // always show the cursor (even when it isn't focused) because
1750            // typing in one will make what you typed appear in all of them.
1751            editor.set_show_cursor_when_unfocused(true, cx);
1752            editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx)), cx);
1753            editor
1754        });
1755        let context_picker_menu_handle = PopoverMenuHandle::default();
1756
1757        let mut this = Self {
1758            id,
1759            editor: prompt_editor.clone(),
1760            context_strip: cx.new_view(|cx| {
1761                ContextStrip::new(
1762                    context_store,
1763                    workspace.clone(),
1764                    thread_store.clone(),
1765                    prompt_editor.focus_handle(cx),
1766                    context_picker_menu_handle.clone(),
1767                    cx,
1768                )
1769            }),
1770            context_picker_menu_handle,
1771            language_model_selector: cx.new_view(|cx| {
1772                let fs = fs.clone();
1773                LanguageModelSelector::new(
1774                    move |model, cx| {
1775                        update_settings_file::<AssistantSettings>(
1776                            fs.clone(),
1777                            cx,
1778                            move |settings, _| settings.set_model(model.clone()),
1779                        );
1780                    },
1781                    cx,
1782                )
1783            }),
1784            edited_since_done: false,
1785            gutter_dimensions,
1786            prompt_history,
1787            prompt_history_ix: None,
1788            pending_prompt: String::new(),
1789            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1790            editor_subscriptions: Vec::new(),
1791            codegen,
1792            show_rate_limit_notice: false,
1793        };
1794        this.subscribe_to_editor(cx);
1795        this
1796    }
1797
1798    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1799        self.editor_subscriptions.clear();
1800        self.editor_subscriptions
1801            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1802    }
1803
1804    fn set_show_cursor_when_unfocused(
1805        &mut self,
1806        show_cursor_when_unfocused: bool,
1807        cx: &mut ViewContext<Self>,
1808    ) {
1809        self.editor.update(cx, |editor, cx| {
1810            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1811        });
1812    }
1813
1814    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1815        let prompt = self.prompt(cx);
1816        let focus = self.editor.focus_handle(cx).contains_focused(cx);
1817        self.editor = cx.new_view(|cx| {
1818            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1819            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1820            editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx)), cx);
1821            editor.set_placeholder_text("Add a prompt…", cx);
1822            editor.set_text(prompt, cx);
1823            if focus {
1824                editor.focus(cx);
1825            }
1826            editor
1827        });
1828        self.subscribe_to_editor(cx);
1829    }
1830
1831    fn placeholder_text(codegen: &Codegen) -> String {
1832        let action = if codegen.is_insertion {
1833            "Generate"
1834        } else {
1835            "Transform"
1836        };
1837
1838        format!("{action}… ↓↑ for history")
1839    }
1840
1841    fn prompt(&self, cx: &AppContext) -> String {
1842        self.editor.read(cx).text(cx)
1843    }
1844
1845    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1846        self.show_rate_limit_notice = !self.show_rate_limit_notice;
1847        if self.show_rate_limit_notice {
1848            cx.focus_view(&self.editor);
1849        }
1850        cx.notify();
1851    }
1852
1853    fn handle_prompt_editor_events(
1854        &mut self,
1855        _: View<Editor>,
1856        event: &EditorEvent,
1857        cx: &mut ViewContext<Self>,
1858    ) {
1859        match event {
1860            EditorEvent::Edited { .. } => {
1861                if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1862                    workspace
1863                        .update(cx, |workspace, cx| {
1864                            let is_via_ssh = workspace
1865                                .project()
1866                                .update(cx, |project, _| project.is_via_ssh());
1867
1868                            workspace
1869                                .client()
1870                                .telemetry()
1871                                .log_edit_event("inline assist", is_via_ssh);
1872                        })
1873                        .log_err();
1874                }
1875                let prompt = self.editor.read(cx).text(cx);
1876                if self
1877                    .prompt_history_ix
1878                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
1879                {
1880                    self.prompt_history_ix.take();
1881                    self.pending_prompt = prompt;
1882                }
1883
1884                self.edited_since_done = true;
1885                cx.notify();
1886            }
1887            EditorEvent::Blurred => {
1888                if self.show_rate_limit_notice {
1889                    self.show_rate_limit_notice = false;
1890                    cx.notify();
1891                }
1892            }
1893            _ => {}
1894        }
1895    }
1896
1897    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1898        match self.codegen.read(cx).status(cx) {
1899            CodegenStatus::Idle => {
1900                self.editor
1901                    .update(cx, |editor, _| editor.set_read_only(false));
1902            }
1903            CodegenStatus::Pending => {
1904                self.editor
1905                    .update(cx, |editor, _| editor.set_read_only(true));
1906            }
1907            CodegenStatus::Done => {
1908                self.edited_since_done = false;
1909                self.editor
1910                    .update(cx, |editor, _| editor.set_read_only(false));
1911            }
1912            CodegenStatus::Error(error) => {
1913                if cx.has_flag::<ZedPro>()
1914                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
1915                    && !dismissed_rate_limit_notice()
1916                {
1917                    self.show_rate_limit_notice = true;
1918                    cx.notify();
1919                }
1920
1921                self.edited_since_done = false;
1922                self.editor
1923                    .update(cx, |editor, _| editor.set_read_only(false));
1924            }
1925        }
1926    }
1927
1928    fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext<Self>) {
1929        self.context_picker_menu_handle.toggle(cx);
1930    }
1931
1932    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1933        match self.codegen.read(cx).status(cx) {
1934            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1935                cx.emit(PromptEditorEvent::CancelRequested);
1936            }
1937            CodegenStatus::Pending => {
1938                cx.emit(PromptEditorEvent::StopRequested);
1939            }
1940        }
1941    }
1942
1943    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1944        match self.codegen.read(cx).status(cx) {
1945            CodegenStatus::Idle => {
1946                cx.emit(PromptEditorEvent::StartRequested);
1947            }
1948            CodegenStatus::Pending => {
1949                cx.emit(PromptEditorEvent::DismissRequested);
1950            }
1951            CodegenStatus::Done => {
1952                if self.edited_since_done {
1953                    cx.emit(PromptEditorEvent::StartRequested);
1954                } else {
1955                    cx.emit(PromptEditorEvent::ConfirmRequested);
1956                }
1957            }
1958            CodegenStatus::Error(_) => {
1959                cx.emit(PromptEditorEvent::StartRequested);
1960            }
1961        }
1962    }
1963
1964    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1965        if let Some(ix) = self.prompt_history_ix {
1966            if ix > 0 {
1967                self.prompt_history_ix = Some(ix - 1);
1968                let prompt = self.prompt_history[ix - 1].as_str();
1969                self.editor.update(cx, |editor, cx| {
1970                    editor.set_text(prompt, cx);
1971                    editor.move_to_beginning(&Default::default(), cx);
1972                });
1973            }
1974        } else if !self.prompt_history.is_empty() {
1975            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1976            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1977            self.editor.update(cx, |editor, cx| {
1978                editor.set_text(prompt, cx);
1979                editor.move_to_beginning(&Default::default(), cx);
1980            });
1981        }
1982    }
1983
1984    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1985        if let Some(ix) = self.prompt_history_ix {
1986            if ix < self.prompt_history.len() - 1 {
1987                self.prompt_history_ix = Some(ix + 1);
1988                let prompt = self.prompt_history[ix + 1].as_str();
1989                self.editor.update(cx, |editor, cx| {
1990                    editor.set_text(prompt, cx);
1991                    editor.move_to_end(&Default::default(), cx)
1992                });
1993            } else {
1994                self.prompt_history_ix = None;
1995                let prompt = self.pending_prompt.as_str();
1996                self.editor.update(cx, |editor, cx| {
1997                    editor.set_text(prompt, cx);
1998                    editor.move_to_end(&Default::default(), cx)
1999                });
2000            }
2001        }
2002    }
2003
2004    fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
2005        self.codegen
2006            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
2007    }
2008
2009    fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
2010        self.codegen
2011            .update(cx, |codegen, cx| codegen.cycle_next(cx));
2012    }
2013
    /// Renders the "‹ n/m ›" control for switching between alternative generations.
    ///
    /// Both arrows are disabled while codegen is idle. The previous arrow is also
    /// disabled at the first alternative, and the next arrow at the last one. When an
    /// arrow is enabled, its tooltip names the model it would switch to.
    fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
        let codegen = self.codegen.read(cx);
        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);

        let model_registry = LanguageModelRegistry::read_global(cx);
        let default_model = model_registry.active_model();
        let alternative_models = model_registry.inline_alternative_models();

        // Alternative 0 is the active model; alternatives 1..=len map to the registry's
        // inline alternative models. Any other index yields an empty name.
        let get_model_name = |index: usize| -> String {
            let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();

            match index {
                0 => default_model.as_ref().map_or_else(String::new, name),
                index if index <= alternative_models.len() => alternative_models
                    .get(index - 1)
                    .map_or_else(String::new, name),
                _ => String::new(),
            }
        };

        let total_models = alternative_models.len() + 1;

        // Without any alternative models there is nothing to cycle through.
        if total_models <= 1 {
            return div().into_any_element();
        }

        // NOTE(review): the arrow bounds below use the registry-derived `total_models`,
        // while the "n/m" label uses `codegen.alternative_count(cx)`. These presumably
        // agree — confirm.
        let current_index = codegen.active_alternative;
        // Wrapping neighbours. The arrows are disabled at either end and the tooltip only
        // shows a model name when enabled, so a wrapped-around name is never displayed.
        let prev_index = (current_index + total_models - 1) % total_models;
        let next_index = (current_index + 1) % total_models;

        let prev_model_name = get_model_name(prev_index);
        let next_model_name = get_model_name(next_index);

        h_flex()
            .child(
                IconButton::new("previous", IconName::ChevronLeft)
                    .icon_color(Color::Muted)
                    .disabled(disabled || current_index == 0)
                    .shape(IconButtonShape::Square)
                    .tooltip({
                        let focus_handle = self.editor.focus_handle(cx);
                        move |cx| {
                            cx.new_view(|cx| {
                                let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
                                    KeyBinding::for_action_in(
                                        &CyclePreviousInlineAssist,
                                        &focus_handle,
                                        cx,
                                    ),
                                );
                                if !disabled && current_index != 0 {
                                    tooltip = tooltip.meta(prev_model_name.clone());
                                }
                                tooltip
                            })
                            .into()
                        }
                    })
                    .on_click(cx.listener(|this, _, cx| {
                        this.codegen
                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
                    })),
            )
            .child(
                // 1-based position label, e.g. "2/3".
                Label::new(format!(
                    "{}/{}",
                    codegen.active_alternative + 1,
                    codegen.alternative_count(cx)
                ))
                .size(LabelSize::Small)
                .color(if disabled {
                    Color::Disabled
                } else {
                    Color::Muted
                }),
            )
            .child(
                IconButton::new("next", IconName::ChevronRight)
                    .icon_color(Color::Muted)
                    .disabled(disabled || current_index == total_models - 1)
                    .shape(IconButtonShape::Square)
                    .tooltip({
                        let focus_handle = self.editor.focus_handle(cx);
                        move |cx| {
                            cx.new_view(|cx| {
                                let mut tooltip = Tooltip::new("Next Alternative").key_binding(
                                    KeyBinding::for_action_in(
                                        &CycleNextInlineAssist,
                                        &focus_handle,
                                        cx,
                                    ),
                                );
                                if !disabled && current_index != total_models - 1 {
                                    tooltip = tooltip.meta(next_model_name.clone());
                                }
                                tooltip
                            })
                            .into()
                        }
                    })
                    .on_click(cx.listener(|this, _, cx| {
                        this.codegen
                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
                    })),
            )
            .into_any_element()
    }
2121
2122    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2123        Popover::new().child(
2124            v_flex()
2125                .occlude()
2126                .p_2()
2127                .child(
2128                    Label::new("Out of Tokens")
2129                        .size(LabelSize::Small)
2130                        .weight(FontWeight::BOLD),
2131                )
2132                .child(Label::new(
2133                    "Try Zed Pro for higher limits, a wider range of models, and more.",
2134                ))
2135                .child(
2136                    h_flex()
2137                        .justify_between()
2138                        .child(CheckboxWithLabel::new(
2139                            "dont-show-again",
2140                            Label::new("Don't show again"),
2141                            if dismissed_rate_limit_notice() {
2142                                ui::ToggleState::Selected
2143                            } else {
2144                                ui::ToggleState::Unselected
2145                            },
2146                            |selection, cx| {
2147                                let is_dismissed = match selection {
2148                                    ui::ToggleState::Unselected => false,
2149                                    ui::ToggleState::Indeterminate => return,
2150                                    ui::ToggleState::Selected => true,
2151                                };
2152
2153                                set_rate_limit_notice_dismissed(is_dismissed, cx)
2154                            },
2155                        ))
2156                        .child(
2157                            h_flex()
2158                                .gap_2()
2159                                .child(
2160                                    Button::new("dismiss", "Dismiss")
2161                                        .style(ButtonStyle::Transparent)
2162                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2163                                )
2164                                .child(Button::new("more-info", "More Info").on_click(
2165                                    |_event, cx| {
2166                                        cx.dispatch_action(Box::new(
2167                                            zed_actions::OpenAccountSettings,
2168                                        ))
2169                                    },
2170                                )),
2171                        ),
2172                ),
2173        )
2174    }
2175
2176    fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
2177        let font_size = TextSize::Default.rems(cx);
2178        let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
2179
2180        v_flex()
2181            .key_context("MessageEditor")
2182            .size_full()
2183            .gap_2()
2184            .p_2()
2185            .bg(cx.theme().colors().editor_background)
2186            .child({
2187                let settings = ThemeSettings::get_global(cx);
2188                let text_style = TextStyle {
2189                    color: cx.theme().colors().editor_foreground,
2190                    font_family: settings.ui_font.family.clone(),
2191                    font_features: settings.ui_font.features.clone(),
2192                    font_size: font_size.into(),
2193                    font_weight: settings.ui_font.weight,
2194                    line_height: line_height.into(),
2195                    ..Default::default()
2196                };
2197
2198                EditorElement::new(
2199                    &self.editor,
2200                    EditorStyle {
2201                        background: cx.theme().colors().editor_background,
2202                        local_player: cx.theme().players().local(),
2203                        text: text_style,
2204                        ..Default::default()
2205                    },
2206                )
2207            })
2208            .into_any_element()
2209    }
2210}
2211
/// Key-value store key whose presence records that the user opted out of the
/// rate-limit notice ("Don't show again").
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2213
2214fn dismissed_rate_limit_notice() -> bool {
2215    db::kvp::KEY_VALUE_STORE
2216        .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2217        .log_err()
2218        .map_or(false, |s| s.is_some())
2219}
2220
2221fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2222    db::write_and_log(cx, move || async move {
2223        if is_dismissed {
2224            db::kvp::KEY_VALUE_STORE
2225                .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2226                .await
2227        } else {
2228            db::kvp::KEY_VALUE_STORE
2229                .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2230                .await
2231        }
2232    })
2233}
2234
/// State kept for a single inline assist.
///
/// It records the range the assist targets and the editor and workspace it belongs to (both
/// held weakly). It also holds the blocks rendered for the assist and the [`Codegen`] that
/// produces its edits.
pub struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Multi-buffer anchor range the assist operates on.
    range: Range<Anchor>,
    /// Editor the assist was opened in. Held weakly so the assist does not keep it alive.
    editor: WeakView<Editor>,
    /// Blocks inserted into the editor for this assist. When this is `None`, a finished
    /// generation finishes the assist outright (see the `CodegenEvent::Finished` handler
    /// installed by `InlineAssist::new`).
    decorations: Option<InlineAssistDecorations>,
    /// Code generation driving this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions created in `InlineAssist::new`.
    /// They are stored only so they stay registered for the assist's lifetime.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to show an error toast when an undecorated assist fails.
    workspace: WeakView<Workspace>,
}
2244
2245impl InlineAssist {
2246    #[allow(clippy::too_many_arguments)]
2247    fn new(
2248        assist_id: InlineAssistId,
2249        group_id: InlineAssistGroupId,
2250        editor: &View<Editor>,
2251        prompt_editor: &View<PromptEditor>,
2252        prompt_block_id: CustomBlockId,
2253        end_block_id: CustomBlockId,
2254        range: Range<Anchor>,
2255        codegen: Model<Codegen>,
2256        workspace: WeakView<Workspace>,
2257        cx: &mut WindowContext,
2258    ) -> Self {
2259        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2260        InlineAssist {
2261            group_id,
2262            editor: editor.downgrade(),
2263            decorations: Some(InlineAssistDecorations {
2264                prompt_block_id,
2265                prompt_editor: prompt_editor.clone(),
2266                removed_line_block_ids: HashSet::default(),
2267                end_block_id,
2268            }),
2269            range,
2270            codegen: codegen.clone(),
2271            workspace: workspace.clone(),
2272            _subscriptions: vec![
2273                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2274                    InlineAssistant::update_global(cx, |this, cx| {
2275                        this.handle_prompt_editor_focus_in(assist_id, cx)
2276                    })
2277                }),
2278                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2279                    InlineAssistant::update_global(cx, |this, cx| {
2280                        this.handle_prompt_editor_focus_out(assist_id, cx)
2281                    })
2282                }),
2283                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2284                    InlineAssistant::update_global(cx, |this, cx| {
2285                        this.handle_prompt_editor_event(prompt_editor, event, cx)
2286                    })
2287                }),
2288                cx.observe(&codegen, {
2289                    let editor = editor.downgrade();
2290                    move |_, cx| {
2291                        if let Some(editor) = editor.upgrade() {
2292                            InlineAssistant::update_global(cx, |this, cx| {
2293                                if let Some(editor_assists) =
2294                                    this.assists_by_editor.get(&editor.downgrade())
2295                                {
2296                                    editor_assists.highlight_updates.send(()).ok();
2297                                }
2298
2299                                this.update_editor_blocks(&editor, assist_id, cx);
2300                            })
2301                        }
2302                    }
2303                }),
2304                cx.subscribe(&codegen, move |codegen, event, cx| {
2305                    InlineAssistant::update_global(cx, |this, cx| match event {
2306                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2307                        CodegenEvent::Finished => {
2308                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
2309                                assist
2310                            } else {
2311                                return;
2312                            };
2313
2314                            if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2315                                if assist.decorations.is_none() {
2316                                    if let Some(workspace) = assist.workspace.upgrade() {
2317                                        let error = format!("Inline assistant error: {}", error);
2318                                        workspace.update(cx, |workspace, cx| {
2319                                            struct InlineAssistantError;
2320
2321                                            let id =
2322                                                NotificationId::composite::<InlineAssistantError>(
2323                                                    assist_id.0,
2324                                                );
2325
2326                                            workspace.show_toast(Toast::new(id, error), cx);
2327                                        })
2328                                    }
2329                                }
2330                            }
2331
2332                            if assist.decorations.is_none() {
2333                                this.finish_assist(assist_id, false, cx);
2334                            }
2335                        }
2336                    })
2337                }),
2338            ],
2339        }
2340    }
2341
2342    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2343        let decorations = self.decorations.as_ref()?;
2344        Some(decorations.prompt_editor.read(cx).prompt(cx))
2345    }
2346}
2347
/// Editor blocks rendered for an inline assist.
struct InlineAssistDecorations {
    /// Block that hosts the prompt editor.
    prompt_block_id: CustomBlockId,
    /// Editor in which the user types the assist prompt.
    prompt_editor: View<PromptEditor>,
    /// Blocks for removed lines (NOTE(review): inferred from the name; populated outside this
    /// chunk).
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Block marking the end of the assist range (NOTE(review): inferred from the name).
    end_block_id: CustomBlockId,
}
2354
/// Events emitted by a [`CodegenAlternative`].
///
/// [`Codegen`] re-emits them for its active alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation has finished. The status may be a success or an error, so check it.
    Finished,
    /// The alternative's transformation transaction was undone in the buffer.
    Undone,
}
2360
/// Drives inline code generation for one assist.
///
/// The prompt is run against the active model and every configured inline alternative model.
/// The user can cycle between the resulting alternatives. Only the active alternative keeps
/// its edits applied, and this model forwards that alternative's notifications and
/// [`CodegenEvent`]s.
pub struct Codegen {
    /// One generation per model. Index 0 is the active (primary) model's alternative.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative whose edits are applied.
    active_alternative: usize,
    /// Indices of every alternative that has been activated.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative. Replaced whenever the active alternative changes.
    subscriptions: Vec<Subscription>,
    /// Multi-buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Transaction rolled back, at most once, by `undo` after the active alternative's edits.
    initial_transaction_id: Option<TransactionId>,
    /// Context store handed to every alternative.
    context_store: Model<ContextStore>,
    /// Telemetry handle handed to every alternative.
    telemetry: Arc<Telemetry>,
    /// Prompt builder handed to every alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2374
2375impl Codegen {
2376    pub fn new(
2377        buffer: Model<MultiBuffer>,
2378        range: Range<Anchor>,
2379        initial_transaction_id: Option<TransactionId>,
2380        context_store: Model<ContextStore>,
2381        telemetry: Arc<Telemetry>,
2382        builder: Arc<PromptBuilder>,
2383        cx: &mut ModelContext<Self>,
2384    ) -> Self {
2385        let codegen = cx.new_model(|cx| {
2386            CodegenAlternative::new(
2387                buffer.clone(),
2388                range.clone(),
2389                false,
2390                Some(context_store.clone()),
2391                Some(telemetry.clone()),
2392                builder.clone(),
2393                cx,
2394            )
2395        });
2396        let mut this = Self {
2397            is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2398            alternatives: vec![codegen],
2399            active_alternative: 0,
2400            seen_alternatives: HashSet::default(),
2401            subscriptions: Vec::new(),
2402            buffer,
2403            range,
2404            initial_transaction_id,
2405            context_store,
2406            telemetry,
2407            builder,
2408        };
2409        this.activate(0, cx);
2410        this
2411    }
2412
2413    fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2414        let codegen = self.active_alternative().clone();
2415        self.subscriptions.clear();
2416        self.subscriptions
2417            .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2418        self.subscriptions
2419            .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2420    }
2421
2422    fn active_alternative(&self) -> &Model<CodegenAlternative> {
2423        &self.alternatives[self.active_alternative]
2424    }
2425
2426    fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2427        &self.active_alternative().read(cx).status
2428    }
2429
2430    fn alternative_count(&self, cx: &AppContext) -> usize {
2431        LanguageModelRegistry::read_global(cx)
2432            .inline_alternative_models()
2433            .len()
2434            + 1
2435    }
2436
2437    pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2438        let next_active_ix = if self.active_alternative == 0 {
2439            self.alternatives.len() - 1
2440        } else {
2441            self.active_alternative - 1
2442        };
2443        self.activate(next_active_ix, cx);
2444    }
2445
2446    pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2447        let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2448        self.activate(next_active_ix, cx);
2449    }
2450
2451    fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2452        self.active_alternative()
2453            .update(cx, |codegen, cx| codegen.set_active(false, cx));
2454        self.seen_alternatives.insert(index);
2455        self.active_alternative = index;
2456        self.active_alternative()
2457            .update(cx, |codegen, cx| codegen.set_active(true, cx));
2458        self.subscribe_to_alternative(cx);
2459        cx.notify();
2460    }
2461
2462    pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
2463        let alternative_models = LanguageModelRegistry::read_global(cx)
2464            .inline_alternative_models()
2465            .to_vec();
2466
2467        self.active_alternative()
2468            .update(cx, |alternative, cx| alternative.undo(cx));
2469        self.activate(0, cx);
2470        self.alternatives.truncate(1);
2471
2472        for _ in 0..alternative_models.len() {
2473            self.alternatives.push(cx.new_model(|cx| {
2474                CodegenAlternative::new(
2475                    self.buffer.clone(),
2476                    self.range.clone(),
2477                    false,
2478                    Some(self.context_store.clone()),
2479                    Some(self.telemetry.clone()),
2480                    self.builder.clone(),
2481                    cx,
2482                )
2483            }));
2484        }
2485
2486        let primary_model = LanguageModelRegistry::read_global(cx)
2487            .active_model()
2488            .context("no active model")?;
2489
2490        for (model, alternative) in iter::once(primary_model)
2491            .chain(alternative_models)
2492            .zip(&self.alternatives)
2493        {
2494            alternative.update(cx, |alternative, cx| {
2495                alternative.start(user_prompt.clone(), model.clone(), cx)
2496            })?;
2497        }
2498
2499        Ok(())
2500    }
2501
2502    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2503        for codegen in &self.alternatives {
2504            codegen.update(cx, |codegen, cx| codegen.stop(cx));
2505        }
2506    }
2507
2508    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2509        self.active_alternative()
2510            .update(cx, |codegen, cx| codegen.undo(cx));
2511
2512        self.buffer.update(cx, |buffer, cx| {
2513            if let Some(transaction_id) = self.initial_transaction_id.take() {
2514                buffer.undo_transaction(transaction_id, cx);
2515                buffer.refresh_preview(cx);
2516            }
2517        });
2518    }
2519
2520    pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2521        self.active_alternative().read(cx).buffer.clone()
2522    }
2523
2524    pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2525        self.active_alternative().read(cx).old_buffer.clone()
2526    }
2527
2528    pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2529        self.active_alternative().read(cx).snapshot.clone()
2530    }
2531
2532    pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2533        self.active_alternative().read(cx).edit_position
2534    }
2535
2536    fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2537        &self.active_alternative().read(cx).diff
2538    }
2539
2540    pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2541        self.active_alternative().read(cx).last_equal_ranges()
2542    }
2543}
2544
2545impl EventEmitter<CodegenEvent> for Codegen {}
2546
/// One model's attempt at an inline transformation of `range` within `buffer`.
///
/// Several alternatives may target the same buffer. Only the one marked `active` keeps its
/// edits applied.
pub struct CodegenAlternative {
    /// Multi-buffer the generated edits are written into.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the underlying buffer's text, line ending, language and language
    /// registry, taken when the alternative was created.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at creation. Streamed edits are anchored against it.
    snapshot: MultiBufferSnapshot,
    /// Set to the right-biased start of `range` by `start` (NOTE(review): presumably tracks
    /// where streamed text is being written).
    edit_position: Option<Anchor>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Cleared on every streamed diff update (NOTE(review): presumably the ranges the model
    /// left unchanged; confirm where it is refilled).
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Buffer transaction holding this alternative's edits, if any are applied.
    transformation_transaction_id: Option<TransactionId>,
    /// Progress of the current generation.
    status: CodegenStatus,
    /// Task processing the current response stream. Replaced with a ready task when the
    /// transformation is undone.
    generation: Task<()>,
    /// Row-level diff between the original and the generated text.
    diff: Diff,
    /// Context attached to outgoing requests, when present.
    context_store: Option<Model<ContextStore>>,
    /// Telemetry handle passed to assistant event reporting, when present.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to detect the transformation being undone.
    _subscription: gpui::Subscription,
    /// Builds the transformation prompt.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to `buffer`.
    active: bool,
    /// Recorded edits, re-applied whenever the alternative becomes active again.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line-level diff operations, replayed on activation while generation is still pending.
    line_operations: Vec<LineOperation>,
    /// Request most recently built and sent by `start`.
    request: Option<LanguageModelRequest>,
    /// NOTE(review): not written in this chunk; presumably the generation duration.
    elapsed_time: Option<f64>,
    /// NOTE(review): not written in this chunk; presumably the full response text.
    completion: Option<String>,
    /// NOTE(review): not written in this chunk; presumably the provider's response message id.
    message_id: Option<String>,
}
2570
/// Lifecycle of a [`CodegenAlternative`]'s generation.
enum CodegenStatus {
    /// No generation has been started yet (the state after `CodegenAlternative::new`).
    Idle,
    /// A response stream is being processed.
    Pending,
    /// The generation completed.
    Done,
    /// The generation failed with the contained error.
    Error(anyhow::Error),
}
2577
/// Row-level summary of how the generated text differs from the original selection.
#[derive(Default)]
struct Diff {
    /// Deleted rows: an anchor in the edited buffer plus an inclusive row range
    /// (NOTE(review): rows presumably refer to `old_buffer`; confirm where this is filled).
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges covering the rows inserted by the generation.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2583
2584impl Diff {
2585    fn is_empty(&self) -> bool {
2586        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2587    }
2588}
2589
2590impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2591
2592impl CodegenAlternative {
2593    pub fn new(
2594        buffer: Model<MultiBuffer>,
2595        range: Range<Anchor>,
2596        active: bool,
2597        context_store: Option<Model<ContextStore>>,
2598        telemetry: Option<Arc<Telemetry>>,
2599        builder: Arc<PromptBuilder>,
2600        cx: &mut ModelContext<Self>,
2601    ) -> Self {
2602        let snapshot = buffer.read(cx).snapshot(cx);
2603
2604        let (old_buffer, _, _) = buffer
2605            .read(cx)
2606            .range_to_buffer_ranges(range.clone(), cx)
2607            .pop()
2608            .unwrap();
2609        let old_buffer = cx.new_model(|cx| {
2610            let old_buffer = old_buffer.read(cx);
2611            let text = old_buffer.as_rope().clone();
2612            let line_ending = old_buffer.line_ending();
2613            let language = old_buffer.language().cloned();
2614            let language_registry = old_buffer.language_registry();
2615
2616            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2617            buffer.set_language(language, cx);
2618            if let Some(language_registry) = language_registry {
2619                buffer.set_language_registry(language_registry)
2620            }
2621            buffer
2622        });
2623
2624        Self {
2625            buffer: buffer.clone(),
2626            old_buffer,
2627            edit_position: None,
2628            message_id: None,
2629            snapshot,
2630            last_equal_ranges: Default::default(),
2631            transformation_transaction_id: None,
2632            status: CodegenStatus::Idle,
2633            generation: Task::ready(()),
2634            diff: Diff::default(),
2635            context_store,
2636            telemetry,
2637            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2638            builder,
2639            active,
2640            edits: Vec::new(),
2641            line_operations: Vec::new(),
2642            range,
2643            request: None,
2644            elapsed_time: None,
2645            completion: None,
2646        }
2647    }
2648
2649    fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2650        if active != self.active {
2651            self.active = active;
2652
2653            if self.active {
2654                let edits = self.edits.clone();
2655                self.apply_edits(edits, cx);
2656                if matches!(self.status, CodegenStatus::Pending) {
2657                    let line_operations = self.line_operations.clone();
2658                    self.reapply_line_based_diff(line_operations, cx);
2659                } else {
2660                    self.reapply_batch_diff(cx).detach();
2661                }
2662            } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2663                self.buffer.update(cx, |buffer, cx| {
2664                    buffer.undo_transaction(transaction_id, cx);
2665                    buffer.forget_transaction(transaction_id, cx);
2666                });
2667            }
2668        }
2669    }
2670
2671    fn handle_buffer_event(
2672        &mut self,
2673        _buffer: Model<MultiBuffer>,
2674        event: &multi_buffer::Event,
2675        cx: &mut ModelContext<Self>,
2676    ) {
2677        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2678            if self.transformation_transaction_id == Some(*transaction_id) {
2679                self.transformation_transaction_id = None;
2680                self.generation = Task::ready(());
2681                cx.emit(CodegenEvent::Undone);
2682            }
2683        }
2684    }
2685
2686    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2687        &self.last_equal_ranges
2688    }
2689
2690    pub fn start(
2691        &mut self,
2692        user_prompt: String,
2693        model: Arc<dyn LanguageModel>,
2694        cx: &mut ModelContext<Self>,
2695    ) -> Result<()> {
2696        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2697            self.buffer.update(cx, |buffer, cx| {
2698                buffer.undo_transaction(transformation_transaction_id, cx);
2699            });
2700        }
2701
2702        self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2703
2704        let api_key = model.api_key(cx);
2705        let telemetry_id = model.telemetry_id();
2706        let provider_id = model.provider_id();
2707        let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2708            if user_prompt.trim().to_lowercase() == "delete" {
2709                async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2710            } else {
2711                let request = self.build_request(user_prompt, cx)?;
2712                self.request = Some(request.clone());
2713
2714                cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2715                    .boxed_local()
2716            };
2717        self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2718        Ok(())
2719    }
2720
2721    fn build_request(
2722        &self,
2723        user_prompt: String,
2724        cx: &mut AppContext,
2725    ) -> Result<LanguageModelRequest> {
2726        let buffer = self.buffer.read(cx).snapshot(cx);
2727        let language = buffer.language_at(self.range.start);
2728        let language_name = if let Some(language) = language.as_ref() {
2729            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2730                None
2731            } else {
2732                Some(language.name())
2733            }
2734        } else {
2735            None
2736        };
2737
2738        let language_name = language_name.as_ref();
2739        let start = buffer.point_to_buffer_offset(self.range.start);
2740        let end = buffer.point_to_buffer_offset(self.range.end);
2741        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2742            let (start_buffer, start_buffer_offset) = start;
2743            let (end_buffer, end_buffer_offset) = end;
2744            if start_buffer.remote_id() == end_buffer.remote_id() {
2745                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2746            } else {
2747                return Err(anyhow::anyhow!("invalid transformation range"));
2748            }
2749        } else {
2750            return Err(anyhow::anyhow!("invalid transformation range"));
2751        };
2752
2753        let prompt = self
2754            .builder
2755            .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2756            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2757
2758        let mut request_message = LanguageModelRequestMessage {
2759            role: Role::User,
2760            content: Vec::new(),
2761            cache: false,
2762        };
2763
2764        if let Some(context_store) = &self.context_store {
2765            let context = context_store.update(cx, |this, _cx| this.context().clone());
2766            attach_context_to_message(&mut request_message, context);
2767        }
2768
2769        request_message.content.push(prompt.into());
2770
2771        Ok(LanguageModelRequest {
2772            tools: Vec::new(),
2773            stop: Vec::new(),
2774            temperature: None,
2775            messages: vec![request_message],
2776        })
2777    }
2778
2779    pub fn handle_stream(
2780        &mut self,
2781        model_telemetry_id: String,
2782        model_provider_id: String,
2783        model_api_key: Option<String>,
2784        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
2785        cx: &mut ModelContext<Self>,
2786    ) {
2787        let start_time = Instant::now();
2788        let snapshot = self.snapshot.clone();
2789        let selected_text = snapshot
2790            .text_for_range(self.range.start..self.range.end)
2791            .collect::<Rope>();
2792
2793        let selection_start = self.range.start.to_point(&snapshot);
2794
2795        // Start with the indentation of the first line in the selection
2796        let mut suggested_line_indent = snapshot
2797            .suggested_indents(selection_start.row..=selection_start.row, cx)
2798            .into_values()
2799            .next()
2800            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2801
2802        // If the first line in the selection does not have indentation, check the following lines
2803        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2804            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
2805                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2806                // Prefer tabs if a line in the selection uses tabs as indentation
2807                if line_indent.kind == IndentKind::Tab {
2808                    suggested_line_indent.kind = IndentKind::Tab;
2809                    break;
2810                }
2811            }
2812        }
2813
2814        let http_client = cx.http_client().clone();
2815        let telemetry = self.telemetry.clone();
2816        let language_name = {
2817            let multibuffer = self.buffer.read(cx);
2818            let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
2819            ranges
2820                .first()
2821                .and_then(|(buffer, _, _)| buffer.read(cx).language())
2822                .map(|language| language.name())
2823        };
2824
2825        self.diff = Diff::default();
2826        self.status = CodegenStatus::Pending;
2827        let mut edit_start = self.range.start.to_offset(&snapshot);
2828        let completion = Arc::new(Mutex::new(String::new()));
2829        let completion_clone = completion.clone();
2830
2831        self.generation = cx.spawn(|codegen, mut cx| {
2832            async move {
2833                let stream = stream.await;
2834                let message_id = stream
2835                    .as_ref()
2836                    .ok()
2837                    .and_then(|stream| stream.message_id.clone());
2838                let generate = async {
2839                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2840                    let executor = cx.background_executor().clone();
2841                    let message_id = message_id.clone();
2842                    let line_based_stream_diff: Task<anyhow::Result<()>> =
2843                        cx.background_executor().spawn(async move {
2844                            let mut response_latency = None;
2845                            let request_start = Instant::now();
2846                            let diff = async {
2847                                let chunks = StripInvalidSpans::new(stream?.stream);
2848                                futures::pin_mut!(chunks);
2849                                let mut diff = StreamingDiff::new(selected_text.to_string());
2850                                let mut line_diff = LineDiff::default();
2851
2852                                let mut new_text = String::new();
2853                                let mut base_indent = None;
2854                                let mut line_indent = None;
2855                                let mut first_line = true;
2856
2857                                while let Some(chunk) = chunks.next().await {
2858                                    if response_latency.is_none() {
2859                                        response_latency = Some(request_start.elapsed());
2860                                    }
2861                                    let chunk = chunk?;
2862                                    completion_clone.lock().push_str(&chunk);
2863
2864                                    let mut lines = chunk.split('\n').peekable();
2865                                    while let Some(line) = lines.next() {
2866                                        new_text.push_str(line);
2867                                        if line_indent.is_none() {
2868                                            if let Some(non_whitespace_ch_ix) =
2869                                                new_text.find(|ch: char| !ch.is_whitespace())
2870                                            {
2871                                                line_indent = Some(non_whitespace_ch_ix);
2872                                                base_indent = base_indent.or(line_indent);
2873
2874                                                let line_indent = line_indent.unwrap();
2875                                                let base_indent = base_indent.unwrap();
2876                                                let indent_delta =
2877                                                    line_indent as i32 - base_indent as i32;
2878                                                let mut corrected_indent_len = cmp::max(
2879                                                    0,
2880                                                    suggested_line_indent.len as i32 + indent_delta,
2881                                                )
2882                                                    as usize;
2883                                                if first_line {
2884                                                    corrected_indent_len = corrected_indent_len
2885                                                        .saturating_sub(
2886                                                            selection_start.column as usize,
2887                                                        );
2888                                                }
2889
2890                                                let indent_char = suggested_line_indent.char();
2891                                                let mut indent_buffer = [0; 4];
2892                                                let indent_str =
2893                                                    indent_char.encode_utf8(&mut indent_buffer);
2894                                                new_text.replace_range(
2895                                                    ..line_indent,
2896                                                    &indent_str.repeat(corrected_indent_len),
2897                                                );
2898                                            }
2899                                        }
2900
2901                                        if line_indent.is_some() {
2902                                            let char_ops = diff.push_new(&new_text);
2903                                            line_diff
2904                                                .push_char_operations(&char_ops, &selected_text);
2905                                            diff_tx
2906                                                .send((char_ops, line_diff.line_operations()))
2907                                                .await?;
2908                                            new_text.clear();
2909                                        }
2910
2911                                        if lines.peek().is_some() {
2912                                            let char_ops = diff.push_new("\n");
2913                                            line_diff
2914                                                .push_char_operations(&char_ops, &selected_text);
2915                                            diff_tx
2916                                                .send((char_ops, line_diff.line_operations()))
2917                                                .await?;
2918                                            if line_indent.is_none() {
2919                                                // Don't write out the leading indentation in empty lines on the next line
2920                                                // This is the case where the above if statement didn't clear the buffer
2921                                                new_text.clear();
2922                                            }
2923                                            line_indent = None;
2924                                            first_line = false;
2925                                        }
2926                                    }
2927                                }
2928
2929                                let mut char_ops = diff.push_new(&new_text);
2930                                char_ops.extend(diff.finish());
2931                                line_diff.push_char_operations(&char_ops, &selected_text);
2932                                line_diff.finish(&selected_text);
2933                                diff_tx
2934                                    .send((char_ops, line_diff.line_operations()))
2935                                    .await?;
2936
2937                                anyhow::Ok(())
2938                            };
2939
2940                            let result = diff.await;
2941
2942                            let error_message =
2943                                result.as_ref().err().map(|error| error.to_string());
2944                            report_assistant_event(
2945                                AssistantEvent {
2946                                    conversation_id: None,
2947                                    message_id,
2948                                    kind: AssistantKind::Inline,
2949                                    phase: AssistantPhase::Response,
2950                                    model: model_telemetry_id,
2951                                    model_provider: model_provider_id.to_string(),
2952                                    response_latency,
2953                                    error_message,
2954                                    language_name: language_name.map(|name| name.to_proto()),
2955                                },
2956                                telemetry,
2957                                http_client,
2958                                model_api_key,
2959                                &executor,
2960                            );
2961
2962                            result?;
2963                            Ok(())
2964                        });
2965
2966                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
2967                        codegen.update(&mut cx, |codegen, cx| {
2968                            codegen.last_equal_ranges.clear();
2969
2970                            let edits = char_ops
2971                                .into_iter()
2972                                .filter_map(|operation| match operation {
2973                                    CharOperation::Insert { text } => {
2974                                        let edit_start = snapshot.anchor_after(edit_start);
2975                                        Some((edit_start..edit_start, text))
2976                                    }
2977                                    CharOperation::Delete { bytes } => {
2978                                        let edit_end = edit_start + bytes;
2979                                        let edit_range = snapshot.anchor_after(edit_start)
2980                                            ..snapshot.anchor_before(edit_end);
2981                                        edit_start = edit_end;
2982                                        Some((edit_range, String::new()))
2983                                    }
2984                                    CharOperation::Keep { bytes } => {
2985                                        let edit_end = edit_start + bytes;
2986                                        let edit_range = snapshot.anchor_after(edit_start)
2987                                            ..snapshot.anchor_before(edit_end);
2988                                        edit_start = edit_end;
2989                                        codegen.last_equal_ranges.push(edit_range);
2990                                        None
2991                                    }
2992                                })
2993                                .collect::<Vec<_>>();
2994
2995                            if codegen.active {
2996                                codegen.apply_edits(edits.iter().cloned(), cx);
2997                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
2998                            }
2999                            codegen.edits.extend(edits);
3000                            codegen.line_operations = line_ops;
3001                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));
3002
3003                            cx.notify();
3004                        })?;
3005                    }
3006
3007                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
3008                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
3009                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
3010                    let batch_diff_task =
3011                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
3012                    let (line_based_stream_diff, ()) =
3013                        join!(line_based_stream_diff, batch_diff_task);
3014                    line_based_stream_diff?;
3015
3016                    anyhow::Ok(())
3017                };
3018
3019                let result = generate.await;
3020                let elapsed_time = start_time.elapsed().as_secs_f64();
3021
3022                codegen
3023                    .update(&mut cx, |this, cx| {
3024                        this.message_id = message_id;
3025                        this.last_equal_ranges.clear();
3026                        if let Err(error) = result {
3027                            this.status = CodegenStatus::Error(error);
3028                        } else {
3029                            this.status = CodegenStatus::Done;
3030                        }
3031                        this.elapsed_time = Some(elapsed_time);
3032                        this.completion = Some(completion.lock().clone());
3033                        cx.emit(CodegenEvent::Finished);
3034                        cx.notify();
3035                    })
3036                    .ok();
3037            }
3038        });
3039        cx.notify();
3040    }
3041
3042    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
3043        self.last_equal_ranges.clear();
3044        if self.diff.is_empty() {
3045            self.status = CodegenStatus::Idle;
3046        } else {
3047            self.status = CodegenStatus::Done;
3048        }
3049        self.generation = Task::ready(());
3050        cx.emit(CodegenEvent::Finished);
3051        cx.notify();
3052    }
3053
3054    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
3055        self.buffer.update(cx, |buffer, cx| {
3056            if let Some(transaction_id) = self.transformation_transaction_id.take() {
3057                buffer.undo_transaction(transaction_id, cx);
3058                buffer.refresh_preview(cx);
3059            }
3060        });
3061    }
3062
3063    fn apply_edits(
3064        &mut self,
3065        edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3066        cx: &mut ModelContext<CodegenAlternative>,
3067    ) {
3068        let transaction = self.buffer.update(cx, |buffer, cx| {
3069            // Avoid grouping assistant edits with user edits.
3070            buffer.finalize_last_transaction(cx);
3071            buffer.start_transaction(cx);
3072            buffer.edit(edits, None, cx);
3073            buffer.end_transaction(cx)
3074        });
3075
3076        if let Some(transaction) = transaction {
3077            if let Some(first_transaction) = self.transformation_transaction_id {
3078                // Group all assistant edits into the first transaction.
3079                self.buffer.update(cx, |buffer, cx| {
3080                    buffer.merge_transactions(transaction, first_transaction, cx)
3081                });
3082            } else {
3083                self.transformation_transaction_id = Some(transaction);
3084                self.buffer
3085                    .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3086            }
3087        }
3088    }
3089
3090    fn reapply_line_based_diff(
3091        &mut self,
3092        line_operations: impl IntoIterator<Item = LineOperation>,
3093        cx: &mut ModelContext<Self>,
3094    ) {
3095        let old_snapshot = self.snapshot.clone();
3096        let old_range = self.range.to_point(&old_snapshot);
3097        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3098        let new_range = self.range.to_point(&new_snapshot);
3099
3100        let mut old_row = old_range.start.row;
3101        let mut new_row = new_range.start.row;
3102
3103        self.diff.deleted_row_ranges.clear();
3104        self.diff.inserted_row_ranges.clear();
3105        for operation in line_operations {
3106            match operation {
3107                LineOperation::Keep { lines } => {
3108                    old_row += lines;
3109                    new_row += lines;
3110                }
3111                LineOperation::Delete { lines } => {
3112                    let old_end_row = old_row + lines - 1;
3113                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3114
3115                    if let Some((_, last_deleted_row_range)) =
3116                        self.diff.deleted_row_ranges.last_mut()
3117                    {
3118                        if *last_deleted_row_range.end() + 1 == old_row {
3119                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3120                        } else {
3121                            self.diff
3122                                .deleted_row_ranges
3123                                .push((new_row, old_row..=old_end_row));
3124                        }
3125                    } else {
3126                        self.diff
3127                            .deleted_row_ranges
3128                            .push((new_row, old_row..=old_end_row));
3129                    }
3130
3131                    old_row += lines;
3132                }
3133                LineOperation::Insert { lines } => {
3134                    let new_end_row = new_row + lines - 1;
3135                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3136                    let end = new_snapshot.anchor_before(Point::new(
3137                        new_end_row,
3138                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
3139                    ));
3140                    self.diff.inserted_row_ranges.push(start..end);
3141                    new_row += lines;
3142                }
3143            }
3144
3145            cx.notify();
3146        }
3147    }
3148
3149    fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3150        let old_snapshot = self.snapshot.clone();
3151        let old_range = self.range.to_point(&old_snapshot);
3152        let new_snapshot = self.buffer.read(cx).snapshot(cx);
3153        let new_range = self.range.to_point(&new_snapshot);
3154
3155        cx.spawn(|codegen, mut cx| async move {
3156            let (deleted_row_ranges, inserted_row_ranges) = cx
3157                .background_executor()
3158                .spawn(async move {
3159                    let old_text = old_snapshot
3160                        .text_for_range(
3161                            Point::new(old_range.start.row, 0)
3162                                ..Point::new(
3163                                    old_range.end.row,
3164                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3165                                ),
3166                        )
3167                        .collect::<String>();
3168                    let new_text = new_snapshot
3169                        .text_for_range(
3170                            Point::new(new_range.start.row, 0)
3171                                ..Point::new(
3172                                    new_range.end.row,
3173                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3174                                ),
3175                        )
3176                        .collect::<String>();
3177
3178                    let mut old_row = old_range.start.row;
3179                    let mut new_row = new_range.start.row;
3180                    let batch_diff =
3181                        similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3182
3183                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3184                    let mut inserted_row_ranges = Vec::new();
3185                    for change in batch_diff.iter_all_changes() {
3186                        let line_count = change.value().lines().count() as u32;
3187                        match change.tag() {
3188                            similar::ChangeTag::Equal => {
3189                                old_row += line_count;
3190                                new_row += line_count;
3191                            }
3192                            similar::ChangeTag::Delete => {
3193                                let old_end_row = old_row + line_count - 1;
3194                                let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3195
3196                                if let Some((_, last_deleted_row_range)) =
3197                                    deleted_row_ranges.last_mut()
3198                                {
3199                                    if *last_deleted_row_range.end() + 1 == old_row {
3200                                        *last_deleted_row_range =
3201                                            *last_deleted_row_range.start()..=old_end_row;
3202                                    } else {
3203                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
3204                                    }
3205                                } else {
3206                                    deleted_row_ranges.push((new_row, old_row..=old_end_row));
3207                                }
3208
3209                                old_row += line_count;
3210                            }
3211                            similar::ChangeTag::Insert => {
3212                                let new_end_row = new_row + line_count - 1;
3213                                let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3214                                let end = new_snapshot.anchor_before(Point::new(
3215                                    new_end_row,
3216                                    new_snapshot.line_len(MultiBufferRow(new_end_row)),
3217                                ));
3218                                inserted_row_ranges.push(start..end);
3219                                new_row += line_count;
3220                            }
3221                        }
3222                    }
3223
3224                    (deleted_row_ranges, inserted_row_ranges)
3225                })
3226                .await;
3227
3228            codegen
3229                .update(&mut cx, |codegen, cx| {
3230                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
3231                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
3232                    cx.notify();
3233                })
3234                .ok();
3235        })
3236    }
3237}
3238
/// Stream adapter that cleans up model output before it is applied.
///
/// It drops a leading code-fence line (and, when one was dropped, a closing
/// fence on the final line) and strips `<|CURSOR|>` markers. Text that could
/// still turn into either is buffered until it can be decided.
struct StripInvalidSpans<T> {
    /// The wrapped chunk stream.
    stream: T,
    /// Received text that has not been emitted yet.
    buffer: String,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// True until a line after the first newline has been processed.
    first_line: bool,
    /// Whether a `'\n'` is owed before the next emitted text.
    line_end: bool,
    /// Whether the response opened with a code fence.
    starts_with_code_block: bool,
}
3247
3248impl<T> StripInvalidSpans<T>
3249where
3250    T: Stream<Item = Result<String>>,
3251{
3252    fn new(stream: T) -> Self {
3253        Self {
3254            stream,
3255            stream_done: false,
3256            buffer: String::new(),
3257            first_line: true,
3258            line_end: false,
3259            starts_with_code_block: false,
3260        }
3261    }
3262}
3263
3264impl<T> Stream for StripInvalidSpans<T>
3265where
3266    T: Stream<Item = Result<String>>,
3267{
3268    type Item = Result<String>;
3269
3270    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3271        const CODE_BLOCK_DELIMITER: &str = "```";
3272        const CURSOR_SPAN: &str = "<|CURSOR|>";
3273
3274        let this = unsafe { self.get_unchecked_mut() };
3275        loop {
3276            if !this.stream_done {
3277                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3278                match stream.as_mut().poll_next(cx) {
3279                    Poll::Ready(Some(Ok(chunk))) => {
3280                        this.buffer.push_str(&chunk);
3281                    }
3282                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3283                    Poll::Ready(None) => {
3284                        this.stream_done = true;
3285                    }
3286                    Poll::Pending => return Poll::Pending,
3287                }
3288            }
3289
3290            let mut chunk = String::new();
3291            let mut consumed = 0;
3292            if !this.buffer.is_empty() {
3293                let mut lines = this.buffer.split('\n').enumerate().peekable();
3294                while let Some((line_ix, line)) = lines.next() {
3295                    if line_ix > 0 {
3296                        this.first_line = false;
3297                    }
3298
3299                    if this.first_line {
3300                        let trimmed_line = line.trim();
3301                        if lines.peek().is_some() {
3302                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3303                                consumed += line.len() + 1;
3304                                this.starts_with_code_block = true;
3305                                continue;
3306                            }
3307                        } else if trimmed_line.is_empty()
3308                            || prefixes(CODE_BLOCK_DELIMITER)
3309                                .any(|prefix| trimmed_line.starts_with(prefix))
3310                        {
3311                            break;
3312                        }
3313                    }
3314
3315                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
3316                    if lines.peek().is_some() {
3317                        if this.line_end {
3318                            chunk.push('\n');
3319                        }
3320
3321                        chunk.push_str(&line_without_cursor);
3322                        this.line_end = true;
3323                        consumed += line.len() + 1;
3324                    } else if this.stream_done {
3325                        if !this.starts_with_code_block
3326                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3327                        {
3328                            if this.line_end {
3329                                chunk.push('\n');
3330                            }
3331
3332                            chunk.push_str(&line);
3333                        }
3334
3335                        consumed += line.len();
3336                    } else {
3337                        let trimmed_line = line.trim();
3338                        if trimmed_line.is_empty()
3339                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3340                            || prefixes(CODE_BLOCK_DELIMITER)
3341                                .any(|prefix| trimmed_line.ends_with(prefix))
3342                        {
3343                            break;
3344                        } else {
3345                            if this.line_end {
3346                                chunk.push('\n');
3347                                this.line_end = false;
3348                            }
3349
3350                            chunk.push_str(&line_without_cursor);
3351                            consumed += line.len();
3352                        }
3353                    }
3354                }
3355            }
3356
3357            this.buffer = this.buffer.split_off(consumed);
3358            if !chunk.is_empty() {
3359                return Poll::Ready(Some(Ok(chunk)));
3360            } else if this.stream_done {
3361                return Poll::Ready(None);
3362            }
3363        }
3364    }
3365}
3366
3367struct AssistantCodeActionProvider {
3368    editor: WeakView<Editor>,
3369    workspace: WeakView<Workspace>,
3370    thread_store: Option<WeakModel<ThreadStore>>,
3371}
3372
3373impl CodeActionProvider for AssistantCodeActionProvider {
3374    fn code_actions(
3375        &self,
3376        buffer: &Model<Buffer>,
3377        range: Range<text::Anchor>,
3378        cx: &mut WindowContext,
3379    ) -> Task<Result<Vec<CodeAction>>> {
3380        if !AssistantSettings::get_global(cx).enabled {
3381            return Task::ready(Ok(Vec::new()));
3382        }
3383
3384        let snapshot = buffer.read(cx).snapshot();
3385        let mut range = range.to_point(&snapshot);
3386
3387        // Expand the range to line boundaries.
3388        range.start.column = 0;
3389        range.end.column = snapshot.line_len(range.end.row);
3390
3391        let mut has_diagnostics = false;
3392        for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3393            range.start = cmp::min(range.start, diagnostic.range.start);
3394            range.end = cmp::max(range.end, diagnostic.range.end);
3395            has_diagnostics = true;
3396        }
3397        if has_diagnostics {
3398            if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3399                if let Some(symbol) = symbols_containing_start.last() {
3400                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3401                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3402                }
3403            }
3404
3405            if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3406                if let Some(symbol) = symbols_containing_end.last() {
3407                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3408                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3409                }
3410            }
3411
3412            Task::ready(Ok(vec![CodeAction {
3413                server_id: language::LanguageServerId(0),
3414                range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3415                lsp_action: lsp::CodeAction {
3416                    title: "Fix with Assistant".into(),
3417                    ..Default::default()
3418                },
3419            }]))
3420        } else {
3421            Task::ready(Ok(Vec::new()))
3422        }
3423    }
3424
3425    fn apply_code_action(
3426        &self,
3427        buffer: Model<Buffer>,
3428        action: CodeAction,
3429        excerpt_id: ExcerptId,
3430        _push_to_history: bool,
3431        cx: &mut WindowContext,
3432    ) -> Task<Result<ProjectTransaction>> {
3433        let editor = self.editor.clone();
3434        let workspace = self.workspace.clone();
3435        let thread_store = self.thread_store.clone();
3436        cx.spawn(|mut cx| async move {
3437            let editor = editor.upgrade().context("editor was released")?;
3438            let range = editor
3439                .update(&mut cx, |editor, cx| {
3440                    editor.buffer().update(cx, |multibuffer, cx| {
3441                        let buffer = buffer.read(cx);
3442                        let multibuffer_snapshot = multibuffer.read(cx);
3443
3444                        let old_context_range =
3445                            multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3446                        let mut new_context_range = old_context_range.clone();
3447                        if action
3448                            .range
3449                            .start
3450                            .cmp(&old_context_range.start, buffer)
3451                            .is_lt()
3452                        {
3453                            new_context_range.start = action.range.start;
3454                        }
3455                        if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3456                            new_context_range.end = action.range.end;
3457                        }
3458                        drop(multibuffer_snapshot);
3459
3460                        if new_context_range != old_context_range {
3461                            multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3462                        }
3463
3464                        let multibuffer_snapshot = multibuffer.read(cx);
3465                        Some(
3466                            multibuffer_snapshot
3467                                .anchor_in_excerpt(excerpt_id, action.range.start)?
3468                                ..multibuffer_snapshot
3469                                    .anchor_in_excerpt(excerpt_id, action.range.end)?,
3470                        )
3471                    })
3472                })?
3473                .context("invalid range")?;
3474
3475            cx.update_global(|assistant: &mut InlineAssistant, cx| {
3476                let assist_id = assistant.suggest_assist(
3477                    &editor,
3478                    range,
3479                    "Fix Diagnostics".into(),
3480                    None,
3481                    true,
3482                    workspace,
3483                    thread_store,
3484                    cx,
3485                );
3486                assistant.start_assist(assist_id, cx);
3487            })?;
3488
3489            Ok(ProjectTransaction::default())
3490        })
3491    }
3492}
3493
/// Returns every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string is not included, so `prefixes("```")` yields "`" and "``".
/// Each prefix ends on a `char` boundary. An empty or single-character `text`
/// yields nothing. The previous `0..text.len() - 1` form underflowed on an empty
/// string and could slice inside a multi-byte character.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char boundary after the first character is the end of a proper prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3497
3498fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3499    ranges.sort_unstable_by(|a, b| {
3500        a.start
3501            .cmp(&b.start, buffer)
3502            .then_with(|| b.end.cmp(&a.end, buffer))
3503    });
3504
3505    let mut ix = 0;
3506    while ix + 1 < ranges.len() {
3507        let b = ranges[ix + 1].clone();
3508        let a = &mut ranges[ix];
3509        if a.end.cmp(&b.start, buffer).is_gt() {
3510            if a.end.cmp(&b.end, buffer).is_lt() {
3511                a.end = b.end;
3512            }
3513            ranges.remove(ix + 1);
3514        } else {
3515            ix += 1;
3516        }
3517    }
3518}
3519
3520#[cfg(test)]
3521mod tests {
3522    use super::*;
3523    use futures::stream::{self};
3524    use gpui::{Context, TestAppContext};
3525    use indoc::indoc;
3526    use language::{
3527        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
3528        Point,
3529    };
3530    use language_model::LanguageModelRegistry;
3531    use rand::prelude::*;
3532    use serde::Serialize;
3533    use settings::SettingsStore;
3534    use std::{future, sync::Arc};
3535
3536    #[derive(Serialize)]
3537    pub struct DummyCompletionRequest {
3538        pub name: String,
3539    }
3540
3541    #[gpui::test(iterations = 10)]
3542    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
3543        cx.set_global(cx.update(SettingsStore::test));
3544        cx.update(language_model::LanguageModelRegistry::test);
3545        cx.update(language_settings::init);
3546
3547        let text = indoc! {"
3548            fn main() {
3549                let x = 0;
3550                for _ in 0..10 {
3551                    x += 1;
3552                }
3553            }
3554        "};
3555        let buffer =
3556            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3557        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3558        let range = buffer.read_with(cx, |buffer, cx| {
3559            let snapshot = buffer.snapshot(cx);
3560            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
3561        });
3562        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3563        let codegen = cx.new_model(|cx| {
3564            CodegenAlternative::new(
3565                buffer.clone(),
3566                range.clone(),
3567                true,
3568                None,
3569                None,
3570                prompt_builder,
3571                cx,
3572            )
3573        });
3574
3575        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
3576
3577        let mut new_text = concat!(
3578            "       let mut x = 0;\n",
3579            "       while x < 10 {\n",
3580            "           x += 1;\n",
3581            "       }",
3582        );
3583        while !new_text.is_empty() {
3584            let max_len = cmp::min(new_text.len(), 10);
3585            let len = rng.gen_range(1..=max_len);
3586            let (chunk, suffix) = new_text.split_at(len);
3587            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3588            new_text = suffix;
3589            cx.background_executor.run_until_parked();
3590        }
3591        drop(chunks_tx);
3592        cx.background_executor.run_until_parked();
3593
3594        assert_eq!(
3595            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3596            indoc! {"
3597                fn main() {
3598                    let mut x = 0;
3599                    while x < 10 {
3600                        x += 1;
3601                    }
3602                }
3603            "}
3604        );
3605    }
3606
3607    #[gpui::test(iterations = 10)]
3608    async fn test_autoindent_when_generating_past_indentation(
3609        cx: &mut TestAppContext,
3610        mut rng: StdRng,
3611    ) {
3612        cx.set_global(cx.update(SettingsStore::test));
3613        cx.update(language_settings::init);
3614
3615        let text = indoc! {"
3616            fn main() {
3617                le
3618            }
3619        "};
3620        let buffer =
3621            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3622        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3623        let range = buffer.read_with(cx, |buffer, cx| {
3624            let snapshot = buffer.snapshot(cx);
3625            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
3626        });
3627        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3628        let codegen = cx.new_model(|cx| {
3629            CodegenAlternative::new(
3630                buffer.clone(),
3631                range.clone(),
3632                true,
3633                None,
3634                None,
3635                prompt_builder,
3636                cx,
3637            )
3638        });
3639
3640        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
3641
3642        cx.background_executor.run_until_parked();
3643
3644        let mut new_text = concat!(
3645            "t mut x = 0;\n",
3646            "while x < 10 {\n",
3647            "    x += 1;\n",
3648            "}", //
3649        );
3650        while !new_text.is_empty() {
3651            let max_len = cmp::min(new_text.len(), 10);
3652            let len = rng.gen_range(1..=max_len);
3653            let (chunk, suffix) = new_text.split_at(len);
3654            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3655            new_text = suffix;
3656            cx.background_executor.run_until_parked();
3657        }
3658        drop(chunks_tx);
3659        cx.background_executor.run_until_parked();
3660
3661        assert_eq!(
3662            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3663            indoc! {"
3664                fn main() {
3665                    let mut x = 0;
3666                    while x < 10 {
3667                        x += 1;
3668                    }
3669                }
3670            "}
3671        );
3672    }
3673
    // Generates into an empty range at the end of a line whose existing indentation
    // ("  ", two spaces) is shallower than the four spaces used inside `main` in the
    // expected output, and checks that the streamed lines end up indented correctly.
    //
    // `rng` is supplied by `#[gpui::test]`; it presumably gets a different seed on
    // each of the 10 iterations, so the chunk boundaries below vary between runs.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        // Attach the test Rust language so its indents query drives auto-indentation.
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range (insertion point) at row 1, column 2: right after the two spaces.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        // NOTE(review): the `true` argument presumably marks this codegen as active,
        // so streamed edits are written to the buffer. Confirm against `CodegenAlternative::new`.
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        // The simulated model response. Its first line carries no indentation.
        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the response in random chunks of 1..=10 bytes. Let the executor
        // settle after each chunk so every partial state is processed.
        // `split_at` is byte-indexed, which is safe here because the text is ASCII.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender closes the channel, which ends the simulated stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        // The generated lines sit one level (four spaces) inside `main`, and the
        // loop body sits one level further in.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
3741
    // Checks that tab-indented code stays tab-indented. The selection covers a
    // tab-indented Go-like `func main` body in a buffer with no language attached.
    // The expected output keeps the response's tab indentation unchanged.
    //
    // NOTE(review): this test takes no `rng`, so `iterations = 10` presumably only
    // re-runs it under different executor seeds. Confirm the repetition is wanted.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        // Plain buffer: no language is attached.
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of the buffer through the end of the "\t}" line
        // (row 4, column 2). The final "}" line stays outside the range.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        // NOTE(review): `true` presumably creates the codegen active, so edits reach
        // the buffer. Confirm against `CodegenAlternative::new`.
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        // The whole response, already tab-indented, is sent as one chunk. Then the
        // stream is closed by dropping the sender.
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        // The selected region is replaced by the response with its tabs intact.
        // The trailing "}" line is untouched.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }
3799
    // Verifies the active/inactive lifecycle of a `CodegenAlternative`:
    // - while inactive, a streamed response leaves the buffer untouched;
    // - `set_active(true)` applies the already-received result;
    // - `set_active(false)` reverts the buffer to its original text.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Range covering the whole "    let x = 0;" line (row 1, columns 0..14).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        // Created with `false`: the codegen starts out inactive.
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        // Send the complete two-line response, then close the stream.
        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        // No `run_until_parked` here: the assertion relies on `set_active(true)`
        // applying the edits synchronously. The unindented second line of the
        // response ends up indented to match the first.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }
3864
3865    #[gpui::test]
3866    async fn test_strip_invalid_spans_from_codeblock() {
3867        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3868        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3869        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3870        assert_chunks(
3871            "```html\n```js\nLorem ipsum dolor\n```\n```",
3872            "```js\nLorem ipsum dolor\n```",
3873        )
3874        .await;
3875        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3876        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3877        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3878        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3879
3880        async fn assert_chunks(text: &str, expected_text: &str) {
3881            for chunk_size in 1..=text.len() {
3882                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3883                    .map(|chunk| chunk.unwrap())
3884                    .collect::<String>()
3885                    .await;
3886                assert_eq!(
3887                    actual_text, expected_text,
3888                    "failed to strip invalid spans, chunk size: {}",
3889                    chunk_size
3890                );
3891            }
3892        }
3893
3894        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3895            stream::iter(
3896                text.chars()
3897                    .collect::<Vec<_>>()
3898                    .chunks(size)
3899                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
3900                    .collect::<Vec<_>>(),
3901            )
3902        }
3903    }
3904
3905    fn simulate_response_stream(
3906        codegen: Model<CodegenAlternative>,
3907        cx: &mut TestAppContext,
3908    ) -> mpsc::UnboundedSender<String> {
3909        let (chunks_tx, chunks_rx) = mpsc::unbounded();
3910        codegen.update(cx, |codegen, cx| {
3911            codegen.handle_stream(
3912                String::new(),
3913                String::new(),
3914                None,
3915                future::ready(Ok(LanguageModelTextStream {
3916                    message_id: None,
3917                    stream: chunks_rx.map(Ok).boxed(),
3918                })),
3919                cx,
3920            );
3921        });
3922        chunks_tx
3923    }
3924
3925    fn rust_lang() -> Language {
3926        Language::new(
3927            LanguageConfig {
3928                name: "Rust".into(),
3929                matcher: LanguageMatcher {
3930                    path_suffixes: vec!["rs".to_string()],
3931                    ..Default::default()
3932                },
3933                ..Default::default()
3934            },
3935            Some(tree_sitter_rust::LANGUAGE.into()),
3936        )
3937        .with_indents_query(
3938            r#"
3939            (call_expression) @indent
3940            (field_expression) @indent
3941            (_ "(" ")" @end) @indent
3942            (_ "{" "}" @end) @indent
3943            "#,
3944        )
3945        .unwrap()
3946    }
3947}