inline_assistant.rs

   1use crate::{
   2    humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
   3    CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
   4};
   5use anyhow::{anyhow, Context as _, Result};
   6use client::{telemetry::Telemetry, ErrorExt};
   7use collections::{hash_map, HashMap, HashSet, VecDeque};
   8use editor::{
   9    actions::{MoveDown, MoveUp, SelectAll},
  10    display_map::{
  11        BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
  12        ToDisplayPoint,
  13    },
  14    Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
  15    ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
  16};
  17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
  18use fs::Fs;
  19use futures::{
  20    channel::mpsc,
  21    future::{BoxFuture, LocalBoxFuture},
  22    stream::{self, BoxStream},
  23    SinkExt, Stream, StreamExt,
  24};
  25use gpui::{
  26    anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
  27    FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
  28    UpdateGlobal, View, ViewContext, WeakView, WindowContext,
  29};
  30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
  31use language_model::{
  32    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
  33};
  34use multi_buffer::MultiBufferRow;
  35use parking_lot::Mutex;
  36use rope::Rope;
  37use settings::Settings;
  38use smol::future::FutureExt;
  39use std::{
  40    cmp,
  41    future::{self, Future},
  42    mem,
  43    ops::{Range, RangeInclusive},
  44    pin::Pin,
  45    sync::Arc,
  46    task::{self, Poll},
  47    time::{Duration, Instant},
  48};
  49use theme::ThemeSettings;
  50use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
  51use util::{RangeExt, ResultExt};
  52use workspace::{notifications::NotificationId, Toast, Workspace};
  53
  54pub fn init(
  55    fs: Arc<dyn Fs>,
  56    prompt_builder: Arc<PromptBuilder>,
  57    telemetry: Arc<Telemetry>,
  58    cx: &mut AppContext,
  59) {
  60    cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
  61}
  62
/// Presumably the maximum number of entries retained in `InlineAssistant::prompt_history`.
/// NOTE(review): the code that enforces this cap is not in the visible part of this file — confirm.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
  64
/// App-wide state for inline assists, installed as a GPUI [`Global`] by [`init`].
///
/// Tracks every live assist, the per-editor bookkeeping for editors hosting
/// assists, and the groups in which assists were created.
pub struct InlineAssistant {
    // Id to hand to the next assist; advanced with `post_inc`.
    next_assist_id: InlineAssistId,
    // Id to hand to the next assist group; advanced with `post_inc`.
    next_assist_group_id: InlineAssistGroupId,
    // Every live assist by id; entries are removed in `finish_assist`.
    assists: HashMap<InlineAssistId, InlineAssist>,
    // Per-editor state (assist ids, scroll lock, highlight-update sender), keyed by
    // a weak editor handle; an entry is removed once its last assist is finished.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    // Assists created together; a group is removed once its last assist is finished.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    // History handed (cloned) to every new `PromptEditor`.
    // NOTE(review): the code that fills/caps it is not in this chunk — confirm.
    prompt_history: VecDeque<String>,
    // Passed to every `Codegen` created by this assistant.
    prompt_builder: Arc<PromptBuilder>,
    // Passed to every `Codegen`; always `Some` when constructed via `new`.
    telemetry: Option<Arc<Telemetry>>,
    // Filesystem handle passed to every `PromptEditor`.
    fs: Arc<dyn Fs>,
}

impl Global for InlineAssistant {}
  78
  79impl InlineAssistant {
  80    pub fn new(
  81        fs: Arc<dyn Fs>,
  82        prompt_builder: Arc<PromptBuilder>,
  83        telemetry: Arc<Telemetry>,
  84    ) -> Self {
  85        Self {
  86            next_assist_id: InlineAssistId::default(),
  87            next_assist_group_id: InlineAssistGroupId::default(),
  88            assists: HashMap::default(),
  89            assists_by_editor: HashMap::default(),
  90            assist_groups: HashMap::default(),
  91            prompt_history: VecDeque::default(),
  92            prompt_builder,
  93            telemetry: Some(telemetry),
  94            fs,
  95        }
  96    }
  97
  98    pub fn assist(
  99        &mut self,
 100        editor: &View<Editor>,
 101        workspace: Option<WeakView<Workspace>>,
 102        assistant_panel: Option<&View<AssistantPanel>>,
 103        initial_prompt: Option<String>,
 104        cx: &mut WindowContext,
 105    ) {
 106        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 107
 108        let mut selections = Vec::<Selection<Point>>::new();
 109        let mut newest_selection = None;
 110        for mut selection in editor.read(cx).selections.all::<Point>(cx) {
 111            if selection.end > selection.start {
 112                selection.start.column = 0;
 113                // If the selection ends at the start of the line, we don't want to include it.
 114                if selection.end.column == 0 {
 115                    selection.end.row -= 1;
 116                }
 117                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
 118            }
 119
 120            if let Some(prev_selection) = selections.last_mut() {
 121                if selection.start <= prev_selection.end {
 122                    prev_selection.end = selection.end;
 123                    continue;
 124                }
 125            }
 126
 127            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
 128            if selection.id > latest_selection.id {
 129                *latest_selection = selection.clone();
 130            }
 131            selections.push(selection);
 132        }
 133        let newest_selection = newest_selection.unwrap();
 134
 135        let mut codegen_ranges = Vec::new();
 136        for (excerpt_id, buffer, buffer_range) in
 137            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
 138                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
 139            }))
 140        {
 141            let start = Anchor {
 142                buffer_id: Some(buffer.remote_id()),
 143                excerpt_id,
 144                text_anchor: buffer.anchor_before(buffer_range.start),
 145            };
 146            let end = Anchor {
 147                buffer_id: Some(buffer.remote_id()),
 148                excerpt_id,
 149                text_anchor: buffer.anchor_after(buffer_range.end),
 150            };
 151            codegen_ranges.push(start..end);
 152        }
 153
 154        let assist_group_id = self.next_assist_group_id.post_inc();
 155        let prompt_buffer =
 156            cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
 157        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 158
 159        let mut assists = Vec::new();
 160        let mut assist_to_focus = None;
 161        for range in codegen_ranges {
 162            let assist_id = self.next_assist_id.post_inc();
 163            let codegen = cx.new_model(|cx| {
 164                Codegen::new(
 165                    editor.read(cx).buffer().clone(),
 166                    range.clone(),
 167                    None,
 168                    self.telemetry.clone(),
 169                    self.prompt_builder.clone(),
 170                    cx,
 171                )
 172            });
 173
 174            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 175            let prompt_editor = cx.new_view(|cx| {
 176                PromptEditor::new(
 177                    assist_id,
 178                    gutter_dimensions.clone(),
 179                    self.prompt_history.clone(),
 180                    prompt_buffer.clone(),
 181                    codegen.clone(),
 182                    editor,
 183                    assistant_panel,
 184                    workspace.clone(),
 185                    self.fs.clone(),
 186                    cx,
 187                )
 188            });
 189
 190            if assist_to_focus.is_none() {
 191                let focus_assist = if newest_selection.reversed {
 192                    range.start.to_point(&snapshot) == newest_selection.start
 193                } else {
 194                    range.end.to_point(&snapshot) == newest_selection.end
 195                };
 196                if focus_assist {
 197                    assist_to_focus = Some(assist_id);
 198                }
 199            }
 200
 201            let [prompt_block_id, end_block_id] =
 202                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 203
 204            assists.push((
 205                assist_id,
 206                range,
 207                prompt_editor,
 208                prompt_block_id,
 209                end_block_id,
 210            ));
 211        }
 212
 213        let editor_assists = self
 214            .assists_by_editor
 215            .entry(editor.downgrade())
 216            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 217        let mut assist_group = InlineAssistGroup::new();
 218        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
 219            self.assists.insert(
 220                assist_id,
 221                InlineAssist::new(
 222                    assist_id,
 223                    assist_group_id,
 224                    assistant_panel.is_some(),
 225                    editor,
 226                    &prompt_editor,
 227                    prompt_block_id,
 228                    end_block_id,
 229                    range,
 230                    prompt_editor.read(cx).codegen.clone(),
 231                    workspace.clone(),
 232                    cx,
 233                ),
 234            );
 235            assist_group.assist_ids.push(assist_id);
 236            editor_assists.assist_ids.push(assist_id);
 237        }
 238        self.assist_groups.insert(assist_group_id, assist_group);
 239
 240        if let Some(assist_id) = assist_to_focus {
 241            self.focus_assist(assist_id, cx);
 242        }
 243    }
 244
 245    #[allow(clippy::too_many_arguments)]
 246    pub fn suggest_assist(
 247        &mut self,
 248        editor: &View<Editor>,
 249        mut range: Range<Anchor>,
 250        initial_prompt: String,
 251        initial_transaction_id: Option<TransactionId>,
 252        workspace: Option<WeakView<Workspace>>,
 253        assistant_panel: Option<&View<AssistantPanel>>,
 254        cx: &mut WindowContext,
 255    ) -> InlineAssistId {
 256        let assist_group_id = self.next_assist_group_id.post_inc();
 257        let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
 258        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
 259
 260        let assist_id = self.next_assist_id.post_inc();
 261
 262        let buffer = editor.read(cx).buffer().clone();
 263        {
 264            let snapshot = buffer.read(cx).read(cx);
 265            range.start = range.start.bias_left(&snapshot);
 266            range.end = range.end.bias_right(&snapshot);
 267        }
 268
 269        let codegen = cx.new_model(|cx| {
 270            Codegen::new(
 271                editor.read(cx).buffer().clone(),
 272                range.clone(),
 273                initial_transaction_id,
 274                self.telemetry.clone(),
 275                self.prompt_builder.clone(),
 276                cx,
 277            )
 278        });
 279
 280        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 281        let prompt_editor = cx.new_view(|cx| {
 282            PromptEditor::new(
 283                assist_id,
 284                gutter_dimensions.clone(),
 285                self.prompt_history.clone(),
 286                prompt_buffer.clone(),
 287                codegen.clone(),
 288                editor,
 289                assistant_panel,
 290                workspace.clone(),
 291                self.fs.clone(),
 292                cx,
 293            )
 294        });
 295
 296        let [prompt_block_id, end_block_id] =
 297            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
 298
 299        let editor_assists = self
 300            .assists_by_editor
 301            .entry(editor.downgrade())
 302            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
 303
 304        let mut assist_group = InlineAssistGroup::new();
 305        self.assists.insert(
 306            assist_id,
 307            InlineAssist::new(
 308                assist_id,
 309                assist_group_id,
 310                assistant_panel.is_some(),
 311                editor,
 312                &prompt_editor,
 313                prompt_block_id,
 314                end_block_id,
 315                range,
 316                prompt_editor.read(cx).codegen.clone(),
 317                workspace.clone(),
 318                cx,
 319            ),
 320        );
 321        assist_group.assist_ids.push(assist_id);
 322        editor_assists.assist_ids.push(assist_id);
 323        self.assist_groups.insert(assist_group_id, assist_group);
 324        assist_id
 325    }
 326
 327    fn insert_assist_blocks(
 328        &self,
 329        editor: &View<Editor>,
 330        range: &Range<Anchor>,
 331        prompt_editor: &View<PromptEditor>,
 332        cx: &mut WindowContext,
 333    ) -> [CustomBlockId; 2] {
 334        let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
 335            prompt_editor
 336                .editor
 337                .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
 338        });
 339        let assist_blocks = vec![
 340            BlockProperties {
 341                style: BlockStyle::Sticky,
 342                position: range.start,
 343                height: prompt_editor_height,
 344                render: build_assist_editor_renderer(prompt_editor),
 345                disposition: BlockDisposition::Above,
 346            },
 347            BlockProperties {
 348                style: BlockStyle::Sticky,
 349                position: range.end,
 350                height: 0,
 351                render: Box::new(|cx| {
 352                    v_flex()
 353                        .h_full()
 354                        .w_full()
 355                        .border_t_1()
 356                        .border_color(cx.theme().status().info_border)
 357                        .into_any_element()
 358                }),
 359                disposition: BlockDisposition::Below,
 360            },
 361        ];
 362
 363        editor.update(cx, |editor, cx| {
 364            let block_ids = editor.insert_blocks(assist_blocks, None, cx);
 365            [block_ids[0], block_ids[1]]
 366        })
 367    }
 368
 369    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 370        let assist = &self.assists[&assist_id];
 371        let Some(decorations) = assist.decorations.as_ref() else {
 372            return;
 373        };
 374        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 375        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
 376
 377        assist_group.active_assist_id = Some(assist_id);
 378        if assist_group.linked {
 379            for assist_id in &assist_group.assist_ids {
 380                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 381                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 382                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
 383                    });
 384                }
 385            }
 386        }
 387
 388        assist
 389            .editor
 390            .update(cx, |editor, cx| {
 391                let scroll_top = editor.scroll_position(cx).y;
 392                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
 393                let prompt_row = editor
 394                    .row_for_block(decorations.prompt_block_id, cx)
 395                    .unwrap()
 396                    .0 as f32;
 397
 398                if (scroll_top..scroll_bottom).contains(&prompt_row) {
 399                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
 400                        assist_id,
 401                        distance_from_top: prompt_row - scroll_top,
 402                    });
 403                } else {
 404                    editor_assists.scroll_lock = None;
 405                }
 406            })
 407            .ok();
 408    }
 409
 410    fn handle_prompt_editor_focus_out(
 411        &mut self,
 412        assist_id: InlineAssistId,
 413        cx: &mut WindowContext,
 414    ) {
 415        let assist = &self.assists[&assist_id];
 416        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
 417        if assist_group.active_assist_id == Some(assist_id) {
 418            assist_group.active_assist_id = None;
 419            if assist_group.linked {
 420                for assist_id in &assist_group.assist_ids {
 421                    if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
 422                        decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 423                            prompt_editor.set_show_cursor_when_unfocused(false, cx)
 424                        });
 425                    }
 426                }
 427            }
 428        }
 429    }
 430
 431    fn handle_prompt_editor_event(
 432        &mut self,
 433        prompt_editor: View<PromptEditor>,
 434        event: &PromptEditorEvent,
 435        cx: &mut WindowContext,
 436    ) {
 437        let assist_id = prompt_editor.read(cx).id;
 438        match event {
 439            PromptEditorEvent::StartRequested => {
 440                self.start_assist(assist_id, cx);
 441            }
 442            PromptEditorEvent::StopRequested => {
 443                self.stop_assist(assist_id, cx);
 444            }
 445            PromptEditorEvent::ConfirmRequested => {
 446                self.finish_assist(assist_id, false, cx);
 447            }
 448            PromptEditorEvent::CancelRequested => {
 449                self.finish_assist(assist_id, true, cx);
 450            }
 451            PromptEditorEvent::DismissRequested => {
 452                self.dismiss_assist(assist_id, cx);
 453            }
 454        }
 455    }
 456
 457    fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 458        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 459            return;
 460        };
 461
 462        let editor = editor.read(cx);
 463        if editor.selections.count() == 1 {
 464            let selection = editor.selections.newest::<usize>(cx);
 465            let buffer = editor.buffer().read(cx).snapshot(cx);
 466            for assist_id in &editor_assists.assist_ids {
 467                let assist = &self.assists[assist_id];
 468                let assist_range = assist.range.to_offset(&buffer);
 469                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
 470                {
 471                    if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
 472                        self.dismiss_assist(*assist_id, cx);
 473                    } else {
 474                        self.finish_assist(*assist_id, false, cx);
 475                    }
 476
 477                    return;
 478                }
 479            }
 480        }
 481
 482        cx.propagate();
 483    }
 484
 485    fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 486        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 487            return;
 488        };
 489
 490        let editor = editor.read(cx);
 491        if editor.selections.count() == 1 {
 492            let selection = editor.selections.newest::<usize>(cx);
 493            let buffer = editor.buffer().read(cx).snapshot(cx);
 494            for assist_id in &editor_assists.assist_ids {
 495                let assist = &self.assists[assist_id];
 496                let assist_range = assist.range.to_offset(&buffer);
 497                if assist.decorations.is_some()
 498                    && assist_range.contains(&selection.start)
 499                    && assist_range.contains(&selection.end)
 500                {
 501                    self.focus_assist(*assist_id, cx);
 502                    return;
 503                }
 504            }
 505        }
 506
 507        cx.propagate();
 508    }
 509
 510    fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
 511        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
 512            for assist_id in editor_assists.assist_ids.clone() {
 513                self.finish_assist(assist_id, true, cx);
 514            }
 515        }
 516    }
 517
 518    fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
 519        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
 520            return;
 521        };
 522        let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
 523            return;
 524        };
 525        let assist = &self.assists[&scroll_lock.assist_id];
 526        let Some(decorations) = assist.decorations.as_ref() else {
 527            return;
 528        };
 529
 530        editor.update(cx, |editor, cx| {
 531            let scroll_position = editor.scroll_position(cx);
 532            let target_scroll_top = editor
 533                .row_for_block(decorations.prompt_block_id, cx)
 534                .unwrap()
 535                .0 as f32
 536                - scroll_lock.distance_from_top;
 537            if target_scroll_top != scroll_position.y {
 538                editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
 539            }
 540        });
 541    }
 542
 543    fn handle_editor_event(
 544        &mut self,
 545        editor: View<Editor>,
 546        event: &EditorEvent,
 547        cx: &mut WindowContext,
 548    ) {
 549        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
 550            return;
 551        };
 552
 553        match event {
 554            EditorEvent::Saved => {
 555                for assist_id in editor_assists.assist_ids.clone() {
 556                    let assist = &self.assists[&assist_id];
 557                    if let CodegenStatus::Done = &assist.codegen.read(cx).status {
 558                        self.finish_assist(assist_id, false, cx)
 559                    }
 560                }
 561            }
 562            EditorEvent::Edited { transaction_id } => {
 563                let buffer = editor.read(cx).buffer().read(cx);
 564                let edited_ranges =
 565                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
 566                let snapshot = buffer.snapshot(cx);
 567
 568                for assist_id in editor_assists.assist_ids.clone() {
 569                    let assist = &self.assists[&assist_id];
 570                    if matches!(
 571                        assist.codegen.read(cx).status,
 572                        CodegenStatus::Error(_) | CodegenStatus::Done
 573                    ) {
 574                        let assist_range = assist.range.to_offset(&snapshot);
 575                        if edited_ranges
 576                            .iter()
 577                            .any(|range| range.overlaps(&assist_range))
 578                        {
 579                            self.finish_assist(assist_id, false, cx);
 580                        }
 581                    }
 582                }
 583            }
 584            EditorEvent::ScrollPositionChanged { .. } => {
 585                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
 586                    let assist = &self.assists[&scroll_lock.assist_id];
 587                    if let Some(decorations) = assist.decorations.as_ref() {
 588                        let distance_from_top = editor.update(cx, |editor, cx| {
 589                            let scroll_top = editor.scroll_position(cx).y;
 590                            let prompt_row = editor
 591                                .row_for_block(decorations.prompt_block_id, cx)
 592                                .unwrap()
 593                                .0 as f32;
 594                            prompt_row - scroll_top
 595                        });
 596
 597                        if distance_from_top != scroll_lock.distance_from_top {
 598                            editor_assists.scroll_lock = None;
 599                        }
 600                    }
 601                }
 602            }
 603            EditorEvent::SelectionsChanged { .. } => {
 604                for assist_id in editor_assists.assist_ids.clone() {
 605                    let assist = &self.assists[&assist_id];
 606                    if let Some(decorations) = assist.decorations.as_ref() {
 607                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
 608                            return;
 609                        }
 610                    }
 611                }
 612
 613                editor_assists.scroll_lock = None;
 614            }
 615            _ => {}
 616        }
 617    }
 618
 619    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
 620        if let Some(assist) = self.assists.get(&assist_id) {
 621            let assist_group_id = assist.group_id;
 622            if self.assist_groups[&assist_group_id].linked {
 623                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
 624                    self.finish_assist(assist_id, undo, cx);
 625                }
 626                return;
 627            }
 628        }
 629
 630        self.dismiss_assist(assist_id, cx);
 631
 632        if let Some(assist) = self.assists.remove(&assist_id) {
 633            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
 634            {
 635                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
 636                if entry.get().assist_ids.is_empty() {
 637                    entry.remove();
 638                }
 639            }
 640
 641            if let hash_map::Entry::Occupied(mut entry) =
 642                self.assists_by_editor.entry(assist.editor.clone())
 643            {
 644                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
 645                if entry.get().assist_ids.is_empty() {
 646                    entry.remove();
 647                    if let Some(editor) = assist.editor.upgrade() {
 648                        self.update_editor_highlights(&editor, cx);
 649                    }
 650                } else {
 651                    entry.get().highlight_updates.send(()).ok();
 652                }
 653            }
 654
 655            if undo {
 656                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
 657            }
 658        }
 659    }
 660
 661    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
 662        let Some(assist) = self.assists.get_mut(&assist_id) else {
 663            return false;
 664        };
 665        let Some(editor) = assist.editor.upgrade() else {
 666            return false;
 667        };
 668        let Some(decorations) = assist.decorations.take() else {
 669            return false;
 670        };
 671
 672        editor.update(cx, |editor, cx| {
 673            let mut to_remove = decorations.removed_line_block_ids;
 674            to_remove.insert(decorations.prompt_block_id);
 675            to_remove.insert(decorations.end_block_id);
 676            editor.remove_blocks(to_remove, None, cx);
 677        });
 678
 679        if decorations
 680            .prompt_editor
 681            .focus_handle(cx)
 682            .contains_focused(cx)
 683        {
 684            self.focus_next_assist(assist_id, cx);
 685        }
 686
 687        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
 688            if editor_assists
 689                .scroll_lock
 690                .as_ref()
 691                .map_or(false, |lock| lock.assist_id == assist_id)
 692            {
 693                editor_assists.scroll_lock = None;
 694            }
 695            editor_assists.highlight_updates.send(()).ok();
 696        }
 697
 698        true
 699    }
 700
 701    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 702        let Some(assist) = self.assists.get(&assist_id) else {
 703            return;
 704        };
 705
 706        let assist_group = &self.assist_groups[&assist.group_id];
 707        let assist_ix = assist_group
 708            .assist_ids
 709            .iter()
 710            .position(|id| *id == assist_id)
 711            .unwrap();
 712        let assist_ids = assist_group
 713            .assist_ids
 714            .iter()
 715            .skip(assist_ix + 1)
 716            .chain(assist_group.assist_ids.iter().take(assist_ix));
 717
 718        for assist_id in assist_ids {
 719            let assist = &self.assists[assist_id];
 720            if assist.decorations.is_some() {
 721                self.focus_assist(*assist_id, cx);
 722                return;
 723            }
 724        }
 725
 726        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
 727    }
 728
 729    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 730        let Some(assist) = self.assists.get(&assist_id) else {
 731            return;
 732        };
 733
 734        if let Some(decorations) = assist.decorations.as_ref() {
 735            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
 736                prompt_editor.editor.update(cx, |editor, cx| {
 737                    editor.focus(cx);
 738                    editor.select_all(&SelectAll, cx);
 739                })
 740            });
 741        }
 742
 743        self.scroll_to_assist(assist_id, cx);
 744    }
 745
 746    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 747        let Some(assist) = self.assists.get(&assist_id) else {
 748            return;
 749        };
 750        let Some(editor) = assist.editor.upgrade() else {
 751            return;
 752        };
 753
 754        let position = assist.range.start;
 755        editor.update(cx, |editor, cx| {
 756            editor.change_selections(None, cx, |selections| {
 757                selections.select_anchor_ranges([position..position])
 758            });
 759
 760            let mut scroll_target_top;
 761            let mut scroll_target_bottom;
 762            if let Some(decorations) = assist.decorations.as_ref() {
 763                scroll_target_top = editor
 764                    .row_for_block(decorations.prompt_block_id, cx)
 765                    .unwrap()
 766                    .0 as f32;
 767                scroll_target_bottom = editor
 768                    .row_for_block(decorations.end_block_id, cx)
 769                    .unwrap()
 770                    .0 as f32;
 771            } else {
 772                let snapshot = editor.snapshot(cx);
 773                let start_row = assist
 774                    .range
 775                    .start
 776                    .to_display_point(&snapshot.display_snapshot)
 777                    .row();
 778                scroll_target_top = start_row.0 as f32;
 779                scroll_target_bottom = scroll_target_top + 1.;
 780            }
 781            scroll_target_top -= editor.vertical_scroll_margin() as f32;
 782            scroll_target_bottom += editor.vertical_scroll_margin() as f32;
 783
 784            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
 785            let scroll_top = editor.scroll_position(cx).y;
 786            let scroll_bottom = scroll_top + height_in_lines;
 787
 788            if scroll_target_top < scroll_top {
 789                editor.set_scroll_position(point(0., scroll_target_top), cx);
 790            } else if scroll_target_bottom > scroll_bottom {
 791                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
 792                    editor
 793                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
 794                } else {
 795                    editor.set_scroll_position(point(0., scroll_target_top), cx);
 796                }
 797            }
 798        });
 799    }
 800
 801    fn unlink_assist_group(
 802        &mut self,
 803        assist_group_id: InlineAssistGroupId,
 804        cx: &mut WindowContext,
 805    ) -> Vec<InlineAssistId> {
 806        let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
 807        assist_group.linked = false;
 808        for assist_id in &assist_group.assist_ids {
 809            let assist = self.assists.get_mut(assist_id).unwrap();
 810            if let Some(editor_decorations) = assist.decorations.as_ref() {
 811                editor_decorations
 812                    .prompt_editor
 813                    .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
 814            }
 815        }
 816        assist_group.assist_ids.clone()
 817    }
 818
 819    pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 820        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
 821            assist
 822        } else {
 823            return;
 824        };
 825
 826        let assist_group_id = assist.group_id;
 827        if self.assist_groups[&assist_group_id].linked {
 828            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
 829                self.start_assist(assist_id, cx);
 830            }
 831            return;
 832        }
 833
 834        let Some(user_prompt) = assist.user_prompt(cx) else {
 835            return;
 836        };
 837
 838        self.prompt_history.retain(|prompt| *prompt != user_prompt);
 839        self.prompt_history.push_back(user_prompt.clone());
 840        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
 841            self.prompt_history.pop_front();
 842        }
 843
 844        let assistant_panel_context = assist.assistant_panel_context(cx);
 845
 846        assist
 847            .codegen
 848            .update(cx, |codegen, cx| {
 849                codegen.start(
 850                    assist.range.clone(),
 851                    user_prompt,
 852                    assistant_panel_context,
 853                    cx,
 854                )
 855            })
 856            .log_err();
 857    }
 858
 859    pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 860        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
 861            assist
 862        } else {
 863            return;
 864        };
 865
 866        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
 867    }
 868
 869    pub fn status_for_assist(
 870        &self,
 871        assist_id: InlineAssistId,
 872        cx: &WindowContext,
 873    ) -> Option<CodegenStatus> {
 874        let assist = self.assists.get(&assist_id)?;
 875        match &assist.codegen.read(cx).status {
 876            CodegenStatus::Idle => Some(CodegenStatus::Idle),
 877            CodegenStatus::Pending => Some(CodegenStatus::Pending),
 878            CodegenStatus::Done => Some(CodegenStatus::Done),
 879            CodegenStatus::Error(error) => Some(CodegenStatus::Error(anyhow!("{:?}", error))),
 880        }
 881    }
 882
 883    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
 884        let mut gutter_pending_ranges = Vec::new();
 885        let mut gutter_transformed_ranges = Vec::new();
 886        let mut foreground_ranges = Vec::new();
 887        let mut inserted_row_ranges = Vec::new();
 888        let empty_assist_ids = Vec::new();
 889        let assist_ids = self
 890            .assists_by_editor
 891            .get(&editor.downgrade())
 892            .map_or(&empty_assist_ids, |editor_assists| {
 893                &editor_assists.assist_ids
 894            });
 895
 896        for assist_id in assist_ids {
 897            if let Some(assist) = self.assists.get(assist_id) {
 898                let codegen = assist.codegen.read(cx);
 899                let buffer = codegen.buffer.read(cx).read(cx);
 900                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
 901
 902                let pending_range =
 903                    codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
 904                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
 905                    gutter_pending_ranges.push(pending_range);
 906                }
 907
 908                if let Some(edit_position) = codegen.edit_position {
 909                    let edited_range = assist.range.start..edit_position;
 910                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
 911                        gutter_transformed_ranges.push(edited_range);
 912                    }
 913                }
 914
 915                if assist.decorations.is_some() {
 916                    inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
 917                }
 918            }
 919        }
 920
 921        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 922        merge_ranges(&mut foreground_ranges, &snapshot);
 923        merge_ranges(&mut gutter_pending_ranges, &snapshot);
 924        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
 925        editor.update(cx, |editor, cx| {
 926            enum GutterPendingRange {}
 927            if gutter_pending_ranges.is_empty() {
 928                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
 929            } else {
 930                editor.highlight_gutter::<GutterPendingRange>(
 931                    &gutter_pending_ranges,
 932                    |cx| cx.theme().status().info_background,
 933                    cx,
 934                )
 935            }
 936
 937            enum GutterTransformedRange {}
 938            if gutter_transformed_ranges.is_empty() {
 939                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
 940            } else {
 941                editor.highlight_gutter::<GutterTransformedRange>(
 942                    &gutter_transformed_ranges,
 943                    |cx| cx.theme().status().info,
 944                    cx,
 945                )
 946            }
 947
 948            if foreground_ranges.is_empty() {
 949                editor.clear_highlights::<InlineAssist>(cx);
 950            } else {
 951                editor.highlight_text::<InlineAssist>(
 952                    foreground_ranges,
 953                    HighlightStyle {
 954                        fade_out: Some(0.6),
 955                        ..Default::default()
 956                    },
 957                    cx,
 958                );
 959            }
 960
 961            editor.clear_row_highlights::<InlineAssist>();
 962            for row_range in inserted_row_ranges {
 963                editor.highlight_rows::<InlineAssist>(
 964                    row_range,
 965                    Some(cx.theme().status().info_background),
 966                    false,
 967                    cx,
 968                );
 969            }
 970        });
 971    }
 972
    /// Rebuilds the blocks that display the lines removed by `assist_id`'s
    /// code generation.
    ///
    /// First, every previously inserted "removed lines" block is removed. Then,
    /// for each `(new_row, old_row_range)` in the codegen diff's
    /// `deleted_row_ranges`, the method inserts a block above `new_row`. The
    /// block holds a read-only editor showing those rows of `codegen.old_buffer`
    /// (presumably the buffer contents before generation — confirm), tinted with
    /// the theme's `deleted_background`.
    ///
    /// Does nothing if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Copy out the codegen state needed below, so nothing borrowed from
        // `cx` is still live when `editor.update` takes `cx` mutably.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks inserted by the previous update before adding new ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Convert the deleted rows of `old_snapshot` into offsets in the
                // underlying buffer: from the start of the first row to the end
                // of the last row (hence `line_len`).
                // NOTE(review): the `unwrap`s assume these rows always exist in
                // `old_snapshot`.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only editor over just that excerpt of `old_buffer`:
                // no soft wrap, wrap guides, or gutter; vertical scrolling
                // forbidden; every row highlighted as deleted.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // Block height in lines: one more than the last row reported
                // by that editor's `max_point`.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        // Left padding equal to the host's full gutter width so
                        // the deleted text lines up with the host's text column.
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1063}
1064
/// Bookkeeping for the inline assists hosted in a single editor.
struct EditorInlineAssists {
    /// Ids of the inline assists in this editor.
    assist_ids: Vec<InlineAssistId>,
    /// NOTE(review): not used in this part of the file; presumably keeps the
    /// view positioned relative to an assist while scrolling — confirm.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here wakes `_update_highlights`, which re-runs
    /// `InlineAssistant::update_editor_highlights` for this editor.
    highlight_updates: async_watch::Sender<()>,
    /// Loop that applies highlight updates. It finishes when the channel
    /// closes, or with an error if the editor was dropped by the time an
    /// update arrives.
    _update_highlights: Task<Result<()>>,
    /// Editor observers and action registrations created in `new`. They are
    /// only held here (assumes dropping a `Subscription` unregisters it).
    _subscriptions: Vec<gpui::Subscription>,
}
1072
/// NOTE(review): not used in this part of the file. Judging by the field
/// names, it records an assist and a vertical distance from the top of the
/// viewport — confirm against the scroll-handling code.
struct InlineAssistScrollLock {
    assist_id: InlineAssistId,
    distance_from_top: f32,
}
1077
1078impl EditorInlineAssists {
1079    #[allow(clippy::too_many_arguments)]
1080    fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1081        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1082        Self {
1083            assist_ids: Vec::new(),
1084            scroll_lock: None,
1085            highlight_updates: highlight_updates_tx,
1086            _update_highlights: cx.spawn(|mut cx| {
1087                let editor = editor.downgrade();
1088                async move {
1089                    while let Ok(()) = highlight_updates_rx.changed().await {
1090                        let editor = editor.upgrade().context("editor was dropped")?;
1091                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
1092                            assistant.update_editor_highlights(&editor, cx);
1093                        })?;
1094                    }
1095                    Ok(())
1096                }
1097            }),
1098            _subscriptions: vec![
1099                cx.observe_release(editor, {
1100                    let editor = editor.downgrade();
1101                    |_, cx| {
1102                        InlineAssistant::update_global(cx, |this, cx| {
1103                            this.handle_editor_release(editor, cx);
1104                        })
1105                    }
1106                }),
1107                cx.observe(editor, move |editor, cx| {
1108                    InlineAssistant::update_global(cx, |this, cx| {
1109                        this.handle_editor_change(editor, cx)
1110                    })
1111                }),
1112                cx.subscribe(editor, move |editor, event, cx| {
1113                    InlineAssistant::update_global(cx, |this, cx| {
1114                        this.handle_editor_event(editor, event, cx)
1115                    })
1116                }),
1117                editor.update(cx, |editor, cx| {
1118                    let editor_handle = cx.view().downgrade();
1119                    editor.register_action(
1120                        move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1121                            InlineAssistant::update_global(cx, |this, cx| {
1122                                if let Some(editor) = editor_handle.upgrade() {
1123                                    this.handle_editor_newline(editor, cx)
1124                                }
1125                            })
1126                        },
1127                    )
1128                }),
1129                editor.update(cx, |editor, cx| {
1130                    let editor_handle = cx.view().downgrade();
1131                    editor.register_action(
1132                        move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1133                            InlineAssistant::update_global(cx, |this, cx| {
1134                                if let Some(editor) = editor_handle.upgrade() {
1135                                    this.handle_editor_cancel(editor, cx)
1136                                }
1137                            })
1138                        },
1139                    )
1140                }),
1141            ],
1142        }
1143    }
1144}
1145
1146struct InlineAssistGroup {
1147    assist_ids: Vec<InlineAssistId>,
1148    linked: bool,
1149    active_assist_id: Option<InlineAssistId>,
1150}
1151
1152impl InlineAssistGroup {
1153    fn new() -> Self {
1154        Self {
1155            assist_ids: Vec::new(),
1156            linked: true,
1157            active_assist_id: None,
1158        }
1159    }
1160}
1161
1162fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1163    let editor = editor.clone();
1164    Box::new(move |cx: &mut BlockContext| {
1165        *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1166        editor.clone().into_any_element()
1167    })
1168}
1169
/// NOTE(review): not constructed or consumed in this part of the file. The
/// variant names suggest whether a newline is inserted before or after the
/// generated text — confirm with the code that uses it.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    NewlineBefore,
    NewlineAfter,
}
1175
/// Identifier of a single inline assist. Ids are handed out sequentially by
/// `post_inc`, starting from the `Default` value `0`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one
    /// (post-increment, like `i++`).
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1186
/// Identifier of an `InlineAssistGroup`. Ids are handed out sequentially by
/// `post_inc`, starting from the `Default` value `0`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one
    /// (post-increment, like `i++`).
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, next)
    }
}
1197
/// Requests emitted by a `PromptEditor` to its subscribers.
enum PromptEditorEvent {
    /// Start (or restart) generation — emitted, e.g., by the "start" and "restart" buttons.
    StartRequested,
    /// Interrupt a running generation — emitted, e.g., by the "stop" button.
    StopRequested,
    /// Accept the finished result — emitted, e.g., by the "confirm" button.
    ConfirmRequested,
    /// Cancel the assist — emitted, e.g., by the "cancel" buttons.
    CancelRequested,
    /// NOTE(review): not emitted in this part of the file; presumably dismisses
    /// the prompt — confirm.
    DismissRequested,
}
1205
/// The prompt UI shown in an inline assist's prompt block: a text editor for
/// the prompt, plus a model selector, a token count, and buttons that depend on
/// the codegen status.
struct PromptEditor {
    /// The inline assist this prompt belongs to.
    id: InlineAssistId,
    /// Passed to the `ModelSelector` built in `render`.
    fs: Arc<dyn Fs>,
    /// The editor the prompt is typed into; replaced with a standalone editor by `unlink`.
    editor: View<Editor>,
    /// When set and codegen is `Done`, `render` offers "restart" instead of
    /// "confirm". NOTE(review): presumably set when the prompt is edited after
    /// generation finished — the setter is not visible here.
    edited_since_done: bool,
    /// The host editor's gutter dimensions. The prompt block's renderer
    /// (`build_assist_editor_renderer`) writes them, and `render` reads them to
    /// size the left column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// NOTE(review): presumably earlier prompts navigated with the move up/down
    /// actions. `prompt_history_ix` would be the current position and
    /// `pending_prompt` the unsent text — those handlers are not visible here.
    prompt_history: VecDeque<String>,
    prompt_history_ix: Option<usize>,
    pending_prompt: String,
    /// Code generation state for the assist; its status selects the buttons shown.
    codegen: Model<Codegen>,
    /// Observation of `codegen` that calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`'s events; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// The in-flight, debounced token count started by `count_tokens`.
    pending_token_count: Task<Result<()>>,
    /// The most recent completed token count, if any.
    token_count: Option<usize>,
    /// Subscriptions to the parent editor and assistant panel that trigger recounts.
    _token_count_subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit notice popover is open; toggled by `toggle_rate_limit_notice`.
    show_rate_limit_notice: bool,
}

impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1226
impl Render for PromptEditor {
    /// Renders the prompt row in three parts:
    /// - a left column sized to the host editor's gutter, holding the model
    ///   selector and, on error, an error indicator;
    /// - the prompt editor itself;
    /// - a right column with the token count and the action buttons for the
    ///   current codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // Action buttons per codegen status:
        // - Idle: cancel + start.
        // - Pending: cancel + stop (per its tooltip, stopping keeps changes).
        // - Error/Done: cancel + restart (on error or when `edited_since_done`
        //   is set; per its tooltip, restarting discards changes), otherwise
        //   cancel + confirm.
        let buttons = match status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                vec![
                    // While pending, `menu::Cancel` is shown on the stop button,
                    // so the cancel button's tooltip is plain text.
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            // Keyboard actions handled while focus is inside the prompt row.
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Left column: as wide as the host's gutter plus half its margin.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator, only when codegen failed.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // Rate-limit errors (with the ZedPro flag) get a
                        // clickable icon that toggles a notice anchored below
                        // it; any other error is an icon with the message as
                        // its tooltip.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            // Middle: the prompt editor takes the remaining width.
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            // Right column: token count followed by the status-dependent buttons.
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1407
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt text editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1413
1414impl PromptEditor {
    /// Maximum height of the prompt editor in lines: the `max_lines` of its
    /// auto-height editor, both at creation (`new`) and after `unlink`.
    const MAX_LINES: u8 = 8;
1416
1417    #[allow(clippy::too_many_arguments)]
1418    fn new(
1419        id: InlineAssistId,
1420        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1421        prompt_history: VecDeque<String>,
1422        prompt_buffer: Model<MultiBuffer>,
1423        codegen: Model<Codegen>,
1424        parent_editor: &View<Editor>,
1425        assistant_panel: Option<&View<AssistantPanel>>,
1426        workspace: Option<WeakView<Workspace>>,
1427        fs: Arc<dyn Fs>,
1428        cx: &mut ViewContext<Self>,
1429    ) -> Self {
1430        let prompt_editor = cx.new_view(|cx| {
1431            let mut editor = Editor::new(
1432                EditorMode::AutoHeight {
1433                    max_lines: Self::MAX_LINES as usize,
1434                },
1435                prompt_buffer,
1436                None,
1437                false,
1438                cx,
1439            );
1440            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1441            // Since the prompt editors for all inline assistants are linked,
1442            // always show the cursor (even when it isn't focused) because
1443            // typing in one will make what you typed appear in all of them.
1444            editor.set_show_cursor_when_unfocused(true, cx);
1445            editor.set_placeholder_text("Add a prompt…", cx);
1446            editor
1447        });
1448
1449        let mut token_count_subscriptions = Vec::new();
1450        token_count_subscriptions
1451            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1452        if let Some(assistant_panel) = assistant_panel {
1453            token_count_subscriptions
1454                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1455        }
1456
1457        let mut this = Self {
1458            id,
1459            editor: prompt_editor,
1460            edited_since_done: false,
1461            gutter_dimensions,
1462            prompt_history,
1463            prompt_history_ix: None,
1464            pending_prompt: String::new(),
1465            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1466            editor_subscriptions: Vec::new(),
1467            codegen,
1468            fs,
1469            pending_token_count: Task::ready(Ok(())),
1470            token_count: None,
1471            _token_count_subscriptions: token_count_subscriptions,
1472            workspace,
1473            show_rate_limit_notice: false,
1474        };
1475        this.count_tokens(cx);
1476        this.subscribe_to_editor(cx);
1477        this
1478    }
1479
1480    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1481        self.editor_subscriptions.clear();
1482        self.editor_subscriptions
1483            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1484    }
1485
1486    fn set_show_cursor_when_unfocused(
1487        &mut self,
1488        show_cursor_when_unfocused: bool,
1489        cx: &mut ViewContext<Self>,
1490    ) {
1491        self.editor.update(cx, |editor, cx| {
1492            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1493        });
1494    }
1495
1496    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1497        let prompt = self.prompt(cx);
1498        let focus = self.editor.focus_handle(cx).contains_focused(cx);
1499        self.editor = cx.new_view(|cx| {
1500            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1501            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1502            editor.set_placeholder_text("Add a prompt…", cx);
1503            editor.set_text(prompt, cx);
1504            if focus {
1505                editor.focus(cx);
1506            }
1507            editor
1508        });
1509        self.subscribe_to_editor(cx);
1510    }
1511
1512    fn prompt(&self, cx: &AppContext) -> String {
1513        self.editor.read(cx).text(cx)
1514    }
1515
1516    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1517        self.show_rate_limit_notice = !self.show_rate_limit_notice;
1518        if self.show_rate_limit_notice {
1519            cx.focus_view(&self.editor);
1520        }
1521        cx.notify();
1522    }
1523
1524    fn handle_parent_editor_event(
1525        &mut self,
1526        _: View<Editor>,
1527        event: &EditorEvent,
1528        cx: &mut ViewContext<Self>,
1529    ) {
1530        if let EditorEvent::BufferEdited { .. } = event {
1531            self.count_tokens(cx);
1532        }
1533    }
1534
1535    fn handle_assistant_panel_event(
1536        &mut self,
1537        _: View<AssistantPanel>,
1538        event: &AssistantPanelEvent,
1539        cx: &mut ViewContext<Self>,
1540    ) {
1541        let AssistantPanelEvent::ContextEdited { .. } = event;
1542        self.count_tokens(cx);
1543    }
1544
1545    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1546        let assist_id = self.id;
1547        self.pending_token_count = cx.spawn(|this, mut cx| async move {
1548            cx.background_executor().timer(Duration::from_secs(1)).await;
1549            let token_count = cx
1550                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1551                    let assist = inline_assistant
1552                        .assists
1553                        .get(&assist_id)
1554                        .context("assist not found")?;
1555                    anyhow::Ok(assist.count_tokens(cx))
1556                })??
1557                .await?;
1558
1559            this.update(&mut cx, |this, cx| {
1560                this.token_count = Some(token_count);
1561                cx.notify();
1562            })
1563        })
1564    }
1565
1566    fn handle_prompt_editor_events(
1567        &mut self,
1568        _: View<Editor>,
1569        event: &EditorEvent,
1570        cx: &mut ViewContext<Self>,
1571    ) {
1572        match event {
1573            EditorEvent::Edited { .. } => {
1574                let prompt = self.editor.read(cx).text(cx);
1575                if self
1576                    .prompt_history_ix
1577                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
1578                {
1579                    self.prompt_history_ix.take();
1580                    self.pending_prompt = prompt;
1581                }
1582
1583                self.edited_since_done = true;
1584                cx.notify();
1585            }
1586            EditorEvent::BufferEdited => {
1587                self.count_tokens(cx);
1588            }
1589            EditorEvent::Blurred => {
1590                if self.show_rate_limit_notice {
1591                    self.show_rate_limit_notice = false;
1592                    cx.notify();
1593                }
1594            }
1595            _ => {}
1596        }
1597    }
1598
1599    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1600        match &self.codegen.read(cx).status {
1601            CodegenStatus::Idle => {
1602                self.editor
1603                    .update(cx, |editor, _| editor.set_read_only(false));
1604            }
1605            CodegenStatus::Pending => {
1606                self.editor
1607                    .update(cx, |editor, _| editor.set_read_only(true));
1608            }
1609            CodegenStatus::Done => {
1610                self.edited_since_done = false;
1611                self.editor
1612                    .update(cx, |editor, _| editor.set_read_only(false));
1613            }
1614            CodegenStatus::Error(error) => {
1615                if cx.has_flag::<ZedPro>()
1616                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
1617                    && !dismissed_rate_limit_notice()
1618                {
1619                    self.show_rate_limit_notice = true;
1620                    cx.notify();
1621                }
1622
1623                self.edited_since_done = false;
1624                self.editor
1625                    .update(cx, |editor, _| editor.set_read_only(false));
1626            }
1627        }
1628    }
1629
1630    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1631        match &self.codegen.read(cx).status {
1632            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1633                cx.emit(PromptEditorEvent::CancelRequested);
1634            }
1635            CodegenStatus::Pending => {
1636                cx.emit(PromptEditorEvent::StopRequested);
1637            }
1638        }
1639    }
1640
1641    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1642        match &self.codegen.read(cx).status {
1643            CodegenStatus::Idle => {
1644                cx.emit(PromptEditorEvent::StartRequested);
1645            }
1646            CodegenStatus::Pending => {
1647                cx.emit(PromptEditorEvent::DismissRequested);
1648            }
1649            CodegenStatus::Done | CodegenStatus::Error(_) => {
1650                if self.edited_since_done {
1651                    cx.emit(PromptEditorEvent::StartRequested);
1652                } else {
1653                    cx.emit(PromptEditorEvent::ConfirmRequested);
1654                }
1655            }
1656        }
1657    }
1658
1659    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1660        if let Some(ix) = self.prompt_history_ix {
1661            if ix > 0 {
1662                self.prompt_history_ix = Some(ix - 1);
1663                let prompt = self.prompt_history[ix - 1].as_str();
1664                self.editor.update(cx, |editor, cx| {
1665                    editor.set_text(prompt, cx);
1666                    editor.move_to_beginning(&Default::default(), cx);
1667                });
1668            }
1669        } else if !self.prompt_history.is_empty() {
1670            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1671            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1672            self.editor.update(cx, |editor, cx| {
1673                editor.set_text(prompt, cx);
1674                editor.move_to_beginning(&Default::default(), cx);
1675            });
1676        }
1677    }
1678
1679    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1680        if let Some(ix) = self.prompt_history_ix {
1681            if ix < self.prompt_history.len() - 1 {
1682                self.prompt_history_ix = Some(ix + 1);
1683                let prompt = self.prompt_history[ix + 1].as_str();
1684                self.editor.update(cx, |editor, cx| {
1685                    editor.set_text(prompt, cx);
1686                    editor.move_to_end(&Default::default(), cx)
1687                });
1688            } else {
1689                self.prompt_history_ix = None;
1690                let prompt = self.pending_prompt.as_str();
1691                self.editor.update(cx, |editor, cx| {
1692                    editor.set_text(prompt, cx);
1693                    editor.move_to_end(&Default::default(), cx)
1694                });
1695            }
1696        }
1697    }
1698
1699    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1700        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1701        let token_count = self.token_count?;
1702        let max_token_count = model.max_token_count();
1703
1704        let remaining_tokens = max_token_count as isize - token_count as isize;
1705        let token_count_color = if remaining_tokens <= 0 {
1706            Color::Error
1707        } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1708            Color::Warning
1709        } else {
1710            Color::Muted
1711        };
1712
1713        let mut token_count = h_flex()
1714            .id("token_count")
1715            .gap_0p5()
1716            .child(
1717                Label::new(humanize_token_count(token_count))
1718                    .size(LabelSize::Small)
1719                    .color(token_count_color),
1720            )
1721            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1722            .child(
1723                Label::new(humanize_token_count(max_token_count))
1724                    .size(LabelSize::Small)
1725                    .color(Color::Muted),
1726            );
1727        if let Some(workspace) = self.workspace.clone() {
1728            token_count = token_count
1729                .tooltip(|cx| {
1730                    Tooltip::with_meta(
1731                        "Tokens Used by Inline Assistant",
1732                        None,
1733                        "Click to Open Assistant Panel",
1734                        cx,
1735                    )
1736                })
1737                .cursor_pointer()
1738                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1739                .on_click(move |_, cx| {
1740                    cx.stop_propagation();
1741                    workspace
1742                        .update(cx, |workspace, cx| {
1743                            workspace.focus_panel::<AssistantPanel>(cx)
1744                        })
1745                        .ok();
1746                });
1747        } else {
1748            token_count = token_count
1749                .cursor_default()
1750                .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1751        }
1752
1753        Some(token_count)
1754    }
1755
1756    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1757        let settings = ThemeSettings::get_global(cx);
1758        let text_style = TextStyle {
1759            color: if self.editor.read(cx).read_only(cx) {
1760                cx.theme().colors().text_disabled
1761            } else {
1762                cx.theme().colors().text
1763            },
1764            font_family: settings.ui_font.family.clone(),
1765            font_features: settings.ui_font.features.clone(),
1766            font_fallbacks: settings.ui_font.fallbacks.clone(),
1767            font_size: rems(0.875).into(),
1768            font_weight: settings.ui_font.weight,
1769            line_height: relative(1.3),
1770            ..Default::default()
1771        };
1772        EditorElement::new(
1773            &self.editor,
1774            EditorStyle {
1775                background: cx.theme().colors().editor_background,
1776                local_player: cx.theme().players().local(),
1777                text: text_style,
1778                ..Default::default()
1779            },
1780        )
1781    }
1782
1783    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1784        Popover::new().child(
1785            v_flex()
1786                .occlude()
1787                .p_2()
1788                .child(
1789                    Label::new("Out of Tokens")
1790                        .size(LabelSize::Small)
1791                        .weight(FontWeight::BOLD),
1792                )
1793                .child(Label::new(
1794                    "Try Zed Pro for higher limits, a wider range of models, and more.",
1795                ))
1796                .child(
1797                    h_flex()
1798                        .justify_between()
1799                        .child(CheckboxWithLabel::new(
1800                            "dont-show-again",
1801                            Label::new("Don't show again"),
1802                            if dismissed_rate_limit_notice() {
1803                                ui::Selection::Selected
1804                            } else {
1805                                ui::Selection::Unselected
1806                            },
1807                            |selection, cx| {
1808                                let is_dismissed = match selection {
1809                                    ui::Selection::Unselected => false,
1810                                    ui::Selection::Indeterminate => return,
1811                                    ui::Selection::Selected => true,
1812                                };
1813
1814                                set_rate_limit_notice_dismissed(is_dismissed, cx)
1815                            },
1816                        ))
1817                        .child(
1818                            h_flex()
1819                                .gap_2()
1820                                .child(
1821                                    Button::new("dismiss", "Dismiss")
1822                                        .style(ButtonStyle::Transparent)
1823                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1824                                )
1825                                .child(Button::new("more-info", "More Info").on_click(
1826                                    |_event, cx| {
1827                                        cx.dispatch_action(Box::new(
1828                                            zed_actions::OpenAccountSettings,
1829                                        ))
1830                                    },
1831                                )),
1832                        ),
1833                ),
1834        )
1835    }
1836}
1837
/// Key-value-store key recording that the user checked "Don't show again" on
/// the rate-limit notice. Only the key's presence matters; its value is unused.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1839
1840fn dismissed_rate_limit_notice() -> bool {
1841    db::kvp::KEY_VALUE_STORE
1842        .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1843        .log_err()
1844        .map_or(false, |s| s.is_some())
1845}
1846
1847fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1848    db::write_and_log(cx, move || async move {
1849        if is_dismissed {
1850            db::kvp::KEY_VALUE_STORE
1851                .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1852                .await
1853        } else {
1854            db::kvp::KEY_VALUE_STORE
1855                .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1856                .await
1857        }
1858    })
1859}
1860
/// Bookkeeping for a single inline assist tracked by the global `InlineAssistant`.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchor range in the target editor that the assist transforms.
    range: Range<Anchor>,
    /// Editor the assist operates in, held weakly.
    editor: WeakView<Editor>,
    /// Prompt/end blocks and prompt editor. Starts as `Some`. When `None`, a
    /// finished generation finishes the assist immediately and reports errors
    /// through a workspace toast instead of the prompt UI.
    decorations: Option<InlineAssistDecorations>,
    /// Model that generates and applies the transformation.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor, and codegen subscriptions, kept alive for as long
    /// as this assist exists.
    _subscriptions: Vec<Subscription>,
    /// Used to reach the assistant panel context and to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
1871
1872impl InlineAssist {
1873    #[allow(clippy::too_many_arguments)]
1874    fn new(
1875        assist_id: InlineAssistId,
1876        group_id: InlineAssistGroupId,
1877        include_context: bool,
1878        editor: &View<Editor>,
1879        prompt_editor: &View<PromptEditor>,
1880        prompt_block_id: CustomBlockId,
1881        end_block_id: CustomBlockId,
1882        range: Range<Anchor>,
1883        codegen: Model<Codegen>,
1884        workspace: Option<WeakView<Workspace>>,
1885        cx: &mut WindowContext,
1886    ) -> Self {
1887        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
1888        InlineAssist {
1889            group_id,
1890            include_context,
1891            editor: editor.downgrade(),
1892            decorations: Some(InlineAssistDecorations {
1893                prompt_block_id,
1894                prompt_editor: prompt_editor.clone(),
1895                removed_line_block_ids: HashSet::default(),
1896                end_block_id,
1897            }),
1898            range,
1899            codegen: codegen.clone(),
1900            workspace: workspace.clone(),
1901            _subscriptions: vec![
1902                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
1903                    InlineAssistant::update_global(cx, |this, cx| {
1904                        this.handle_prompt_editor_focus_in(assist_id, cx)
1905                    })
1906                }),
1907                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
1908                    InlineAssistant::update_global(cx, |this, cx| {
1909                        this.handle_prompt_editor_focus_out(assist_id, cx)
1910                    })
1911                }),
1912                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
1913                    InlineAssistant::update_global(cx, |this, cx| {
1914                        this.handle_prompt_editor_event(prompt_editor, event, cx)
1915                    })
1916                }),
1917                cx.observe(&codegen, {
1918                    let editor = editor.downgrade();
1919                    move |_, cx| {
1920                        if let Some(editor) = editor.upgrade() {
1921                            InlineAssistant::update_global(cx, |this, cx| {
1922                                if let Some(editor_assists) =
1923                                    this.assists_by_editor.get(&editor.downgrade())
1924                                {
1925                                    editor_assists.highlight_updates.send(()).ok();
1926                                }
1927
1928                                this.update_editor_blocks(&editor, assist_id, cx);
1929                            })
1930                        }
1931                    }
1932                }),
1933                cx.subscribe(&codegen, move |codegen, event, cx| {
1934                    InlineAssistant::update_global(cx, |this, cx| match event {
1935                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
1936                        CodegenEvent::Finished => {
1937                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
1938                                assist
1939                            } else {
1940                                return;
1941                            };
1942
1943                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
1944                                if assist.decorations.is_none() {
1945                                    if let Some(workspace) = assist
1946                                        .workspace
1947                                        .as_ref()
1948                                        .and_then(|workspace| workspace.upgrade())
1949                                    {
1950                                        let error = format!("Inline assistant error: {}", error);
1951                                        workspace.update(cx, |workspace, cx| {
1952                                            struct InlineAssistantError;
1953
1954                                            let id =
1955                                                NotificationId::identified::<InlineAssistantError>(
1956                                                    assist_id.0,
1957                                                );
1958
1959                                            workspace.show_toast(Toast::new(id, error), cx);
1960                                        })
1961                                    }
1962                                }
1963                            }
1964
1965                            if assist.decorations.is_none() {
1966                                this.finish_assist(assist_id, false, cx);
1967                            }
1968                        }
1969                    })
1970                }),
1971            ],
1972        }
1973    }
1974
1975    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
1976        let decorations = self.decorations.as_ref()?;
1977        Some(decorations.prompt_editor.read(cx).prompt(cx))
1978    }
1979
1980    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
1981        if self.include_context {
1982            let workspace = self.workspace.as_ref()?;
1983            let workspace = workspace.upgrade()?.read(cx);
1984            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
1985            Some(
1986                assistant_panel
1987                    .read(cx)
1988                    .active_context(cx)?
1989                    .read(cx)
1990                    .to_completion_request(cx),
1991            )
1992        } else {
1993            None
1994        }
1995    }
1996
1997    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
1998        let Some(user_prompt) = self.user_prompt(cx) else {
1999            return future::ready(Err(anyhow!("no user prompt"))).boxed();
2000        };
2001        let assistant_panel_context = self.assistant_panel_context(cx);
2002        self.codegen.read(cx).count_tokens(
2003            self.range.clone(),
2004            user_prompt,
2005            assistant_panel_context,
2006            cx,
2007        )
2008    }
2009}
2010
/// Editor-side UI attached to an inline assist while its prompt is shown.
struct InlineAssistDecorations {
    /// Custom block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view rendered inside `prompt_block_id`.
    prompt_editor: View<PromptEditor>,
    /// Custom blocks for removed lines. Starts empty; presumably populated when
    /// the diff shows deletions. TODO confirm.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Custom block presumably marking the end of the assist range. TODO confirm.
    end_block_id: CustomBlockId,
}
2017
/// Events emitted by a [`Codegen`] model.
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation finished. Subscribers inspect `status`, which may hold an
    /// `Error`.
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
2023
/// Drives a single inline transformation: it builds the model request,
/// streams the completion, and applies it to `buffer`.
pub struct Codegen {
    /// Multi-buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the original buffer's text, language and language
    /// registry, taken at construction.
    old_buffer: Model<Buffer>,
    /// Multi-buffer snapshot taken at construction.
    snapshot: MultiBufferSnapshot,
    /// Set by `start` to the (right-biased) start of the edit range; presumably
    /// advanced while streamed edits are applied.
    edit_position: Option<Anchor>,
    /// Exposed via `last_equal_ranges()`; presumably ranges the latest diff
    /// found unchanged. TODO confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that existed when the codegen was created, if any.
    initial_transaction_id: Option<TransactionId>,
    /// Transaction holding the applied transformation. It is undone when
    /// `start` runs again and cleared when undone externally.
    transformation_transaction_id: Option<TransactionId>,
    /// Current lifecycle state.
    status: CodegenStatus,
    /// In-flight generation task. It is replaced with `Task::ready(())` on
    /// undo, which drops the previous task.
    generation: Task<()>,
    /// Row-level summary of the current transformation.
    diff: Diff,
    /// Optional telemetry sink.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (used to detect undo).
    _subscription: gpui::Subscription,
    /// Renders the content prompt sent to the model.
    builder: Arc<PromptBuilder>,
}
2039
/// Lifecycle state of a [`Codegen`].
pub enum CodegenStatus {
    /// Initial state set by `Codegen::new`; no generation has run.
    Idle,
    /// A generation is in progress.
    Pending,
    /// The last generation completed.
    Done,
    /// The last generation failed with the contained error.
    Error(anyhow::Error),
}
2046
/// Row-level summary of the changes made by the current transformation.
#[derive(Default)]
struct Diff {
    /// Deleted rows. Each entry pairs an anchor with an inclusive row range;
    /// presumably the anchor is where the deletion is shown and the rows refer
    /// to the original text. TODO confirm.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted rows, as inclusive anchor ranges in the edited buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2052
2053impl Diff {
2054    fn is_empty(&self) -> bool {
2055        self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2056    }
2057}
2058
// Lets `Codegen` models emit `CodegenEvent`s (e.g. `Undone` from
// `handle_buffer_event`) to subscribers such as `InlineAssist::new`.
impl EventEmitter<CodegenEvent> for Codegen {}
2060
2061impl Codegen {
2062    pub fn new(
2063        buffer: Model<MultiBuffer>,
2064        range: Range<Anchor>,
2065        initial_transaction_id: Option<TransactionId>,
2066        telemetry: Option<Arc<Telemetry>>,
2067        builder: Arc<PromptBuilder>,
2068        cx: &mut ModelContext<Self>,
2069    ) -> Self {
2070        let snapshot = buffer.read(cx).snapshot(cx);
2071
2072        let (old_buffer, _, _) = buffer
2073            .read(cx)
2074            .range_to_buffer_ranges(range.clone(), cx)
2075            .pop()
2076            .unwrap();
2077        let old_buffer = cx.new_model(|cx| {
2078            let old_buffer = old_buffer.read(cx);
2079            let text = old_buffer.as_rope().clone();
2080            let line_ending = old_buffer.line_ending();
2081            let language = old_buffer.language().cloned();
2082            let language_registry = old_buffer.language_registry();
2083
2084            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2085            buffer.set_language(language, cx);
2086            if let Some(language_registry) = language_registry {
2087                buffer.set_language_registry(language_registry)
2088            }
2089            buffer
2090        });
2091
2092        Self {
2093            buffer: buffer.clone(),
2094            old_buffer,
2095            edit_position: None,
2096            snapshot,
2097            last_equal_ranges: Default::default(),
2098            transformation_transaction_id: None,
2099            status: CodegenStatus::Idle,
2100            generation: Task::ready(()),
2101            diff: Diff::default(),
2102            telemetry,
2103            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2104            initial_transaction_id,
2105            builder,
2106        }
2107    }
2108
2109    fn handle_buffer_event(
2110        &mut self,
2111        _buffer: Model<MultiBuffer>,
2112        event: &multi_buffer::Event,
2113        cx: &mut ModelContext<Self>,
2114    ) {
2115        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2116            if self.transformation_transaction_id == Some(*transaction_id) {
2117                self.transformation_transaction_id = None;
2118                self.generation = Task::ready(());
2119                cx.emit(CodegenEvent::Undone);
2120            }
2121        }
2122    }
2123
2124    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2125        &self.last_equal_ranges
2126    }
2127
2128    pub fn count_tokens(
2129        &self,
2130        edit_range: Range<Anchor>,
2131        user_prompt: String,
2132        assistant_panel_context: Option<LanguageModelRequest>,
2133        cx: &AppContext,
2134    ) -> BoxFuture<'static, Result<usize>> {
2135        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2136            let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2137            match request {
2138                Ok(request) => model.count_tokens(request, cx),
2139                Err(error) => futures::future::ready(Err(error)).boxed(),
2140            }
2141        } else {
2142            future::ready(Err(anyhow!("no active model"))).boxed()
2143        }
2144    }
2145
2146    pub fn start(
2147        &mut self,
2148        edit_range: Range<Anchor>,
2149        user_prompt: String,
2150        assistant_panel_context: Option<LanguageModelRequest>,
2151        cx: &mut ModelContext<Self>,
2152    ) -> Result<()> {
2153        let model = LanguageModelRegistry::read_global(cx)
2154            .active_model()
2155            .context("no active model")?;
2156
2157        if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2158            self.buffer.update(cx, |buffer, cx| {
2159                buffer.undo_transaction(transformation_transaction_id, cx)
2160            });
2161        }
2162
2163        self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2164
2165        let telemetry_id = model.telemetry_id();
2166        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2167            .trim()
2168            .to_lowercase()
2169            == "delete"
2170        {
2171            async { Ok(stream::empty().boxed()) }.boxed_local()
2172        } else {
2173            let request =
2174                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
2175
2176            let chunks =
2177                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2178            async move { Ok(chunks.await?.boxed()) }.boxed_local()
2179        };
2180        self.handle_stream(telemetry_id, edit_range, chunks, cx);
2181        Ok(())
2182    }
2183
2184    fn build_request(
2185        &self,
2186        user_prompt: String,
2187        assistant_panel_context: Option<LanguageModelRequest>,
2188        edit_range: Range<Anchor>,
2189        cx: &AppContext,
2190    ) -> Result<LanguageModelRequest> {
2191        let buffer = self.buffer.read(cx).snapshot(cx);
2192        let language = buffer.language_at(edit_range.start);
2193        let language_name = if let Some(language) = language.as_ref() {
2194            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2195                None
2196            } else {
2197                Some(language.name())
2198            }
2199        } else {
2200            None
2201        };
2202
2203        // Higher Temperature increases the randomness of model outputs.
2204        // If Markdown or No Language is Known, increase the randomness for more creative output
2205        // If Code, decrease temperature to get more deterministic outputs
2206        let temperature = if let Some(language) = language_name.clone() {
2207            if language.as_ref() == "Markdown" {
2208                1.0
2209            } else {
2210                0.5
2211            }
2212        } else {
2213            1.0
2214        };
2215
2216        let language_name = language_name.as_deref();
2217        let start = buffer.point_to_buffer_offset(edit_range.start);
2218        let end = buffer.point_to_buffer_offset(edit_range.end);
2219        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2220            let (start_buffer, start_buffer_offset) = start;
2221            let (end_buffer, end_buffer_offset) = end;
2222            if start_buffer.remote_id() == end_buffer.remote_id() {
2223                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2224            } else {
2225                return Err(anyhow::anyhow!("invalid transformation range"));
2226            }
2227        } else {
2228            return Err(anyhow::anyhow!("invalid transformation range"));
2229        };
2230        let prompt = self
2231            .builder
2232            .generate_content_prompt(user_prompt, language_name, buffer, range)
2233            .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2234
2235        let mut messages = Vec::new();
2236        if let Some(context_request) = assistant_panel_context {
2237            messages = context_request.messages;
2238        }
2239
2240        messages.push(LanguageModelRequestMessage {
2241            role: Role::User,
2242            content: prompt,
2243        });
2244
2245        Ok(LanguageModelRequest {
2246            messages,
2247            stop: vec!["|END|>".to_string()],
2248            temperature,
2249        })
2250    }
2251
2252    pub fn handle_stream(
2253        &mut self,
2254        model_telemetry_id: String,
2255        edit_range: Range<Anchor>,
2256        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
2257        cx: &mut ModelContext<Self>,
2258    ) {
2259        let snapshot = self.snapshot.clone();
2260        let selected_text = snapshot
2261            .text_for_range(edit_range.start..edit_range.end)
2262            .collect::<Rope>();
2263
2264        let selection_start = edit_range.start.to_point(&snapshot);
2265
2266        // Start with the indentation of the first line in the selection
2267        let mut suggested_line_indent = snapshot
2268            .suggested_indents(selection_start.row..=selection_start.row, cx)
2269            .into_values()
2270            .next()
2271            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2272
2273        // If the first line in the selection does not have indentation, check the following lines
2274        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2275            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
2276                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2277                // Prefer tabs if a line in the selection uses tabs as indentation
2278                if line_indent.kind == IndentKind::Tab {
2279                    suggested_line_indent.kind = IndentKind::Tab;
2280                    break;
2281                }
2282            }
2283        }
2284
2285        let telemetry = self.telemetry.clone();
2286        self.diff = Diff::default();
2287        self.status = CodegenStatus::Pending;
2288        let mut edit_start = edit_range.start.to_offset(&snapshot);
2289        self.generation = cx.spawn(|this, mut cx| {
2290            async move {
2291                let chunks = stream.await;
2292                let generate = async {
2293                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2294                    let diff: Task<anyhow::Result<()>> =
2295                        cx.background_executor().spawn(async move {
2296                            let mut response_latency = None;
2297                            let request_start = Instant::now();
2298                            let diff = async {
2299                                let chunks = StripInvalidSpans::new(chunks?);
2300                                futures::pin_mut!(chunks);
2301                                let mut diff = StreamingDiff::new(selected_text.to_string());
2302                                let mut line_diff = LineDiff::default();
2303
2304                                let mut new_text = String::new();
2305                                let mut base_indent = None;
2306                                let mut line_indent = None;
2307                                let mut first_line = true;
2308
2309                                while let Some(chunk) = chunks.next().await {
2310                                    if response_latency.is_none() {
2311                                        response_latency = Some(request_start.elapsed());
2312                                    }
2313                                    let chunk = chunk?;
2314
2315                                    let mut lines = chunk.split('\n').peekable();
2316                                    while let Some(line) = lines.next() {
2317                                        new_text.push_str(line);
2318                                        if line_indent.is_none() {
2319                                            if let Some(non_whitespace_ch_ix) =
2320                                                new_text.find(|ch: char| !ch.is_whitespace())
2321                                            {
2322                                                line_indent = Some(non_whitespace_ch_ix);
2323                                                base_indent = base_indent.or(line_indent);
2324
2325                                                let line_indent = line_indent.unwrap();
2326                                                let base_indent = base_indent.unwrap();
2327                                                let indent_delta =
2328                                                    line_indent as i32 - base_indent as i32;
2329                                                let mut corrected_indent_len = cmp::max(
2330                                                    0,
2331                                                    suggested_line_indent.len as i32 + indent_delta,
2332                                                )
2333                                                    as usize;
2334                                                if first_line {
2335                                                    corrected_indent_len = corrected_indent_len
2336                                                        .saturating_sub(
2337                                                            selection_start.column as usize,
2338                                                        );
2339                                                }
2340
2341                                                let indent_char = suggested_line_indent.char();
2342                                                let mut indent_buffer = [0; 4];
2343                                                let indent_str =
2344                                                    indent_char.encode_utf8(&mut indent_buffer);
2345                                                new_text.replace_range(
2346                                                    ..line_indent,
2347                                                    &indent_str.repeat(corrected_indent_len),
2348                                                );
2349                                            }
2350                                        }
2351
2352                                        if line_indent.is_some() {
2353                                            let char_ops = diff.push_new(&new_text);
2354                                            line_diff
2355                                                .push_char_operations(&char_ops, &selected_text);
2356                                            diff_tx
2357                                                .send((char_ops, line_diff.line_operations()))
2358                                                .await?;
2359                                            new_text.clear();
2360                                        }
2361
2362                                        if lines.peek().is_some() {
2363                                            let char_ops = diff.push_new("\n");
2364                                            line_diff
2365                                                .push_char_operations(&char_ops, &selected_text);
2366                                            diff_tx
2367                                                .send((char_ops, line_diff.line_operations()))
2368                                                .await?;
2369                                            if line_indent.is_none() {
2370                                                // Don't write out the leading indentation in empty lines on the next line
2371                                                // This is the case where the above if statement didn't clear the buffer
2372                                                new_text.clear();
2373                                            }
2374                                            line_indent = None;
2375                                            first_line = false;
2376                                        }
2377                                    }
2378                                }
2379
2380                                let mut char_ops = diff.push_new(&new_text);
2381                                char_ops.extend(diff.finish());
2382                                line_diff.push_char_operations(&char_ops, &selected_text);
2383                                line_diff.finish(&selected_text);
2384                                diff_tx
2385                                    .send((char_ops, line_diff.line_operations()))
2386                                    .await?;
2387
2388                                anyhow::Ok(())
2389                            };
2390
2391                            let result = diff.await;
2392
2393                            let error_message =
2394                                result.as_ref().err().map(|error| error.to_string());
2395                            if let Some(telemetry) = telemetry {
2396                                telemetry.report_assistant_event(
2397                                    None,
2398                                    telemetry_events::AssistantKind::Inline,
2399                                    model_telemetry_id,
2400                                    response_latency,
2401                                    error_message,
2402                                );
2403                            }
2404
2405                            result?;
2406                            Ok(())
2407                        });
2408
2409                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
2410                        this.update(&mut cx, |this, cx| {
2411                            this.last_equal_ranges.clear();
2412
2413                            let transaction = this.buffer.update(cx, |buffer, cx| {
2414                                // Avoid grouping assistant edits with user edits.
2415                                buffer.finalize_last_transaction(cx);
2416
2417                                buffer.start_transaction(cx);
2418                                buffer.edit(
2419                                    char_ops
2420                                        .into_iter()
2421                                        .filter_map(|operation| match operation {
2422                                            CharOperation::Insert { text } => {
2423                                                let edit_start = snapshot.anchor_after(edit_start);
2424                                                Some((edit_start..edit_start, text))
2425                                            }
2426                                            CharOperation::Delete { bytes } => {
2427                                                let edit_end = edit_start + bytes;
2428                                                let edit_range = snapshot.anchor_after(edit_start)
2429                                                    ..snapshot.anchor_before(edit_end);
2430                                                edit_start = edit_end;
2431                                                Some((edit_range, String::new()))
2432                                            }
2433                                            CharOperation::Keep { bytes } => {
2434                                                let edit_end = edit_start + bytes;
2435                                                let edit_range = snapshot.anchor_after(edit_start)
2436                                                    ..snapshot.anchor_before(edit_end);
2437                                                edit_start = edit_end;
2438                                                this.last_equal_ranges.push(edit_range);
2439                                                None
2440                                            }
2441                                        }),
2442                                    None,
2443                                    cx,
2444                                );
2445                                this.edit_position = Some(snapshot.anchor_after(edit_start));
2446
2447                                buffer.end_transaction(cx)
2448                            });
2449
2450                            if let Some(transaction) = transaction {
2451                                if let Some(first_transaction) = this.transformation_transaction_id
2452                                {
2453                                    // Group all assistant edits into the first transaction.
2454                                    this.buffer.update(cx, |buffer, cx| {
2455                                        buffer.merge_transactions(
2456                                            transaction,
2457                                            first_transaction,
2458                                            cx,
2459                                        )
2460                                    });
2461                                } else {
2462                                    this.transformation_transaction_id = Some(transaction);
2463                                    this.buffer.update(cx, |buffer, cx| {
2464                                        buffer.finalize_last_transaction(cx)
2465                                    });
2466                                }
2467                            }
2468
2469                            this.update_diff(edit_range.clone(), line_diff, cx);
2470
2471                            cx.notify();
2472                        })?;
2473                    }
2474
2475                    diff.await?;
2476
2477                    anyhow::Ok(())
2478                };
2479
2480                let result = generate.await;
2481                this.update(&mut cx, |this, cx| {
2482                    this.last_equal_ranges.clear();
2483                    if let Err(error) = result {
2484                        this.status = CodegenStatus::Error(error);
2485                    } else {
2486                        this.status = CodegenStatus::Done;
2487                    }
2488                    cx.emit(CodegenEvent::Finished);
2489                    cx.notify();
2490                })
2491                .ok();
2492            }
2493        });
2494        cx.notify();
2495    }
2496
2497    pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2498        self.last_equal_ranges.clear();
2499        if self.diff.is_empty() {
2500            self.status = CodegenStatus::Idle;
2501        } else {
2502            self.status = CodegenStatus::Done;
2503        }
2504        self.generation = Task::ready(());
2505        cx.emit(CodegenEvent::Finished);
2506        cx.notify();
2507    }
2508
2509    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2510        self.buffer.update(cx, |buffer, cx| {
2511            if let Some(transaction_id) = self.transformation_transaction_id.take() {
2512                buffer.undo_transaction(transaction_id, cx);
2513            }
2514
2515            if let Some(transaction_id) = self.initial_transaction_id.take() {
2516                buffer.undo_transaction(transaction_id, cx);
2517            }
2518        });
2519    }
2520
2521    fn update_diff(
2522        &mut self,
2523        edit_range: Range<Anchor>,
2524        line_operations: Vec<LineOperation>,
2525        cx: &mut ModelContext<Self>,
2526    ) {
2527        let old_snapshot = self.snapshot.clone();
2528        let old_range = edit_range.to_point(&old_snapshot);
2529        let new_snapshot = self.buffer.read(cx).snapshot(cx);
2530        let new_range = edit_range.to_point(&new_snapshot);
2531
2532        let mut old_row = old_range.start.row;
2533        let mut new_row = new_range.start.row;
2534
2535        self.diff.deleted_row_ranges.clear();
2536        self.diff.inserted_row_ranges.clear();
2537        for operation in line_operations {
2538            match operation {
2539                LineOperation::Keep { lines } => {
2540                    old_row += lines;
2541                    new_row += lines;
2542                }
2543                LineOperation::Delete { lines } => {
2544                    let old_end_row = old_row + lines - 1;
2545                    let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2546
2547                    if let Some((_, last_deleted_row_range)) =
2548                        self.diff.deleted_row_ranges.last_mut()
2549                    {
2550                        if *last_deleted_row_range.end() + 1 == old_row {
2551                            *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2552                        } else {
2553                            self.diff
2554                                .deleted_row_ranges
2555                                .push((new_row, old_row..=old_end_row));
2556                        }
2557                    } else {
2558                        self.diff
2559                            .deleted_row_ranges
2560                            .push((new_row, old_row..=old_end_row));
2561                    }
2562
2563                    old_row += lines;
2564                }
2565                LineOperation::Insert { lines } => {
2566                    let new_end_row = new_row + lines - 1;
2567                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2568                    let end = new_snapshot.anchor_before(Point::new(
2569                        new_end_row,
2570                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
2571                    ));
2572                    self.diff.inserted_row_ranges.push(start..=end);
2573                    new_row += lines;
2574                }
2575            }
2576
2577            cx.notify();
2578        }
2579    }
2580}
2581
/// Stream adapter that removes text the model should not emit verbatim: a
/// Markdown code fence wrapping the whole response, and `<|CURSOR|>` markers.
///
/// Incoming text is buffered only until it is clear whether it belongs to
/// such a span.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Received text that has not been emitted yet.
    buffer: String,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Whether output is still on its first line (a candidate opening fence).
    first_line: bool,
    /// Whether a newline is owed before the next emitted text.
    line_end: bool,
    /// Whether an opening code fence was stripped from the first line.
    starts_with_code_block: bool,
}
2590
2591impl<T> StripInvalidSpans<T>
2592where
2593    T: Stream<Item = Result<String>>,
2594{
2595    fn new(stream: T) -> Self {
2596        Self {
2597            stream,
2598            stream_done: false,
2599            buffer: String::new(),
2600            first_line: true,
2601            line_end: false,
2602            starts_with_code_block: false,
2603        }
2604    }
2605}
2606
2607impl<T> Stream for StripInvalidSpans<T>
2608where
2609    T: Stream<Item = Result<String>>,
2610{
2611    type Item = Result<String>;
2612
2613    fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2614        const CODE_BLOCK_DELIMITER: &str = "```";
2615        const CURSOR_SPAN: &str = "<|CURSOR|>";
2616
2617        let this = unsafe { self.get_unchecked_mut() };
2618        loop {
2619            if !this.stream_done {
2620                let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2621                match stream.as_mut().poll_next(cx) {
2622                    Poll::Ready(Some(Ok(chunk))) => {
2623                        this.buffer.push_str(&chunk);
2624                    }
2625                    Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2626                    Poll::Ready(None) => {
2627                        this.stream_done = true;
2628                    }
2629                    Poll::Pending => return Poll::Pending,
2630                }
2631            }
2632
2633            let mut chunk = String::new();
2634            let mut consumed = 0;
2635            if !this.buffer.is_empty() {
2636                let mut lines = this.buffer.split('\n').enumerate().peekable();
2637                while let Some((line_ix, line)) = lines.next() {
2638                    if line_ix > 0 {
2639                        this.first_line = false;
2640                    }
2641
2642                    if this.first_line {
2643                        let trimmed_line = line.trim();
2644                        if lines.peek().is_some() {
2645                            if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2646                                consumed += line.len() + 1;
2647                                this.starts_with_code_block = true;
2648                                continue;
2649                            }
2650                        } else if trimmed_line.is_empty()
2651                            || prefixes(CODE_BLOCK_DELIMITER)
2652                                .any(|prefix| trimmed_line.starts_with(prefix))
2653                        {
2654                            break;
2655                        }
2656                    }
2657
2658                    let line_without_cursor = line.replace(CURSOR_SPAN, "");
2659                    if lines.peek().is_some() {
2660                        if this.line_end {
2661                            chunk.push('\n');
2662                        }
2663
2664                        chunk.push_str(&line_without_cursor);
2665                        this.line_end = true;
2666                        consumed += line.len() + 1;
2667                    } else if this.stream_done {
2668                        if !this.starts_with_code_block
2669                            || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2670                        {
2671                            if this.line_end {
2672                                chunk.push('\n');
2673                            }
2674
2675                            chunk.push_str(&line);
2676                        }
2677
2678                        consumed += line.len();
2679                    } else {
2680                        let trimmed_line = line.trim();
2681                        if trimmed_line.is_empty()
2682                            || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2683                            || prefixes(CODE_BLOCK_DELIMITER)
2684                                .any(|prefix| trimmed_line.ends_with(prefix))
2685                        {
2686                            break;
2687                        } else {
2688                            if this.line_end {
2689                                chunk.push('\n');
2690                                this.line_end = false;
2691                            }
2692
2693                            chunk.push_str(&line_without_cursor);
2694                            consumed += line.len();
2695                        }
2696                    }
2697                }
2698            }
2699
2700            this.buffer = this.buffer.split_off(consumed);
2701            if !chunk.is_empty() {
2702                return Poll::Ready(Some(Ok(chunk)));
2703            } else if this.stream_done {
2704                return Poll::Ready(None);
2705            }
2706        }
2707    }
2708}
2709
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, the prefixes of "abc" are "a" and "ab"; `text` itself is not
/// included. Prefixes always end on a `char` boundary, so non-ASCII input does
/// not panic, and an empty `text` yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first marks the end of one proper prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2713
2714fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2715    ranges.sort_unstable_by(|a, b| {
2716        a.start
2717            .cmp(&b.start, buffer)
2718            .then_with(|| b.end.cmp(&a.end, buffer))
2719    });
2720
2721    let mut ix = 0;
2722    while ix + 1 < ranges.len() {
2723        let b = ranges[ix + 1].clone();
2724        let a = &mut ranges[ix];
2725        if a.end.cmp(&b.start, buffer).is_gt() {
2726            if a.end.cmp(&b.end, buffer).is_lt() {
2727                a.end = b.end;
2728            }
2729            ranges.remove(ix + 1);
2730        } else {
2731            ix += 1;
2732        }
2733    }
2734}
2735
2736#[cfg(test)]
2737mod tests {
2738    use super::*;
2739    use futures::stream::{self};
2740    use gpui::{Context, TestAppContext};
2741    use indoc::indoc;
2742    use language::{
2743        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2744        Point,
2745    };
2746    use language_model::LanguageModelRegistry;
2747    use rand::prelude::*;
2748    use serde::Serialize;
2749    use settings::SettingsStore;
2750    use std::{future, sync::Arc};
2751
2752    #[derive(Serialize)]
2753    pub struct DummyCompletionRequest {
2754        pub name: String,
2755    }
2756
2757    #[gpui::test(iterations = 10)]
2758    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
2759        cx.set_global(cx.update(SettingsStore::test));
2760        cx.update(language_model::LanguageModelRegistry::test);
2761        cx.update(language_settings::init);
2762
2763        let text = indoc! {"
2764            fn main() {
2765                let x = 0;
2766                for _ in 0..10 {
2767                    x += 1;
2768                }
2769            }
2770        "};
2771        let buffer =
2772            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2773        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2774        let range = buffer.read_with(cx, |buffer, cx| {
2775            let snapshot = buffer.snapshot(cx);
2776            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
2777        });
2778        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2779        let codegen = cx.new_model(|cx| {
2780            Codegen::new(
2781                buffer.clone(),
2782                range.clone(),
2783                None,
2784                None,
2785                prompt_builder,
2786                cx,
2787            )
2788        });
2789
2790        let (chunks_tx, chunks_rx) = mpsc::unbounded();
2791        codegen.update(cx, |codegen, cx| {
2792            codegen.handle_stream(
2793                String::new(),
2794                range,
2795                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2796                cx,
2797            )
2798        });
2799
2800        let mut new_text = concat!(
2801            "       let mut x = 0;\n",
2802            "       while x < 10 {\n",
2803            "           x += 1;\n",
2804            "       }",
2805        );
2806        while !new_text.is_empty() {
2807            let max_len = cmp::min(new_text.len(), 10);
2808            let len = rng.gen_range(1..=max_len);
2809            let (chunk, suffix) = new_text.split_at(len);
2810            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2811            new_text = suffix;
2812            cx.background_executor.run_until_parked();
2813        }
2814        drop(chunks_tx);
2815        cx.background_executor.run_until_parked();
2816
2817        assert_eq!(
2818            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2819            indoc! {"
2820                fn main() {
2821                    let mut x = 0;
2822                    while x < 10 {
2823                        x += 1;
2824                    }
2825                }
2826            "}
2827        );
2828    }
2829
2830    #[gpui::test(iterations = 10)]
2831    async fn test_autoindent_when_generating_past_indentation(
2832        cx: &mut TestAppContext,
2833        mut rng: StdRng,
2834    ) {
2835        cx.set_global(cx.update(SettingsStore::test));
2836        cx.update(language_settings::init);
2837
2838        let text = indoc! {"
2839            fn main() {
2840                le
2841            }
2842        "};
2843        let buffer =
2844            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2845        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2846        let range = buffer.read_with(cx, |buffer, cx| {
2847            let snapshot = buffer.snapshot(cx);
2848            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
2849        });
2850        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2851        let codegen = cx.new_model(|cx| {
2852            Codegen::new(
2853                buffer.clone(),
2854                range.clone(),
2855                None,
2856                None,
2857                prompt_builder,
2858                cx,
2859            )
2860        });
2861
2862        let (chunks_tx, chunks_rx) = mpsc::unbounded();
2863        codegen.update(cx, |codegen, cx| {
2864            codegen.handle_stream(
2865                String::new(),
2866                range.clone(),
2867                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2868                cx,
2869            )
2870        });
2871
2872        cx.background_executor.run_until_parked();
2873
2874        let mut new_text = concat!(
2875            "t mut x = 0;\n",
2876            "while x < 10 {\n",
2877            "    x += 1;\n",
2878            "}", //
2879        );
2880        while !new_text.is_empty() {
2881            let max_len = cmp::min(new_text.len(), 10);
2882            let len = rng.gen_range(1..=max_len);
2883            let (chunk, suffix) = new_text.split_at(len);
2884            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2885            new_text = suffix;
2886            cx.background_executor.run_until_parked();
2887        }
2888        drop(chunks_tx);
2889        cx.background_executor.run_until_parked();
2890
2891        assert_eq!(
2892            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2893            indoc! {"
2894                fn main() {
2895                    let mut x = 0;
2896                    while x < 10 {
2897                        x += 1;
2898                    }
2899                }
2900            "}
2901        );
2902    }
2903
2904    #[gpui::test(iterations = 10)]
2905    async fn test_autoindent_when_generating_before_indentation(
2906        cx: &mut TestAppContext,
2907        mut rng: StdRng,
2908    ) {
2909        cx.update(LanguageModelRegistry::test);
2910        cx.set_global(cx.update(SettingsStore::test));
2911        cx.update(language_settings::init);
2912
2913        let text = concat!(
2914            "fn main() {\n",
2915            "  \n",
2916            "}\n" //
2917        );
2918        let buffer =
2919            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2920        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2921        let range = buffer.read_with(cx, |buffer, cx| {
2922            let snapshot = buffer.snapshot(cx);
2923            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
2924        });
2925        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2926        let codegen = cx.new_model(|cx| {
2927            Codegen::new(
2928                buffer.clone(),
2929                range.clone(),
2930                None,
2931                None,
2932                prompt_builder,
2933                cx,
2934            )
2935        });
2936
2937        let (chunks_tx, chunks_rx) = mpsc::unbounded();
2938        codegen.update(cx, |codegen, cx| {
2939            codegen.handle_stream(
2940                String::new(),
2941                range.clone(),
2942                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2943                cx,
2944            )
2945        });
2946
2947        cx.background_executor.run_until_parked();
2948
2949        let mut new_text = concat!(
2950            "let mut x = 0;\n",
2951            "while x < 10 {\n",
2952            "    x += 1;\n",
2953            "}", //
2954        );
2955        while !new_text.is_empty() {
2956            let max_len = cmp::min(new_text.len(), 10);
2957            let len = rng.gen_range(1..=max_len);
2958            let (chunk, suffix) = new_text.split_at(len);
2959            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2960            new_text = suffix;
2961            cx.background_executor.run_until_parked();
2962        }
2963        drop(chunks_tx);
2964        cx.background_executor.run_until_parked();
2965
2966        assert_eq!(
2967            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2968            indoc! {"
2969                fn main() {
2970                    let mut x = 0;
2971                    while x < 10 {
2972                        x += 1;
2973                    }
2974                }
2975            "}
2976        );
2977    }
2978
2979    #[gpui::test(iterations = 10)]
2980    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
2981        cx.update(LanguageModelRegistry::test);
2982        cx.set_global(cx.update(SettingsStore::test));
2983        cx.update(language_settings::init);
2984
2985        let text = indoc! {"
2986            func main() {
2987            \tx := 0
2988            \tfor i := 0; i < 10; i++ {
2989            \t\tx++
2990            \t}
2991            }
2992        "};
2993        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
2994        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2995        let range = buffer.read_with(cx, |buffer, cx| {
2996            let snapshot = buffer.snapshot(cx);
2997            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
2998        });
2999        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3000        let codegen = cx.new_model(|cx| {
3001            Codegen::new(
3002                buffer.clone(),
3003                range.clone(),
3004                None,
3005                None,
3006                prompt_builder,
3007                cx,
3008            )
3009        });
3010
3011        let (chunks_tx, chunks_rx) = mpsc::unbounded();
3012        codegen.update(cx, |codegen, cx| {
3013            codegen.handle_stream(
3014                String::new(),
3015                range.clone(),
3016                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3017                cx,
3018            )
3019        });
3020
3021        let new_text = concat!(
3022            "func main() {\n",
3023            "\tx := 0\n",
3024            "\tfor x < 10 {\n",
3025            "\t\tx++\n",
3026            "\t}", //
3027        );
3028        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
3029        drop(chunks_tx);
3030        cx.background_executor.run_until_parked();
3031
3032        assert_eq!(
3033            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3034            indoc! {"
3035                func main() {
3036                \tx := 0
3037                \tfor x < 10 {
3038                \t\tx++
3039                \t}
3040                }
3041            "}
3042        );
3043    }
3044
3045    #[gpui::test]
3046    async fn test_strip_invalid_spans_from_codeblock() {
3047        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3048        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3049        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3050        assert_chunks(
3051            "```html\n```js\nLorem ipsum dolor\n```\n```",
3052            "```js\nLorem ipsum dolor\n```",
3053        )
3054        .await;
3055        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3056        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3057        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3058        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3059
3060        async fn assert_chunks(text: &str, expected_text: &str) {
3061            for chunk_size in 1..=text.len() {
3062                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3063                    .map(|chunk| chunk.unwrap())
3064                    .collect::<String>()
3065                    .await;
3066                assert_eq!(
3067                    actual_text, expected_text,
3068                    "failed to strip invalid spans, chunk size: {}",
3069                    chunk_size
3070                );
3071            }
3072        }
3073
3074        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3075            stream::iter(
3076                text.chars()
3077                    .collect::<Vec<_>>()
3078                    .chunks(size)
3079                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
3080                    .collect::<Vec<_>>(),
3081            )
3082        }
3083    }
3084
3085    fn rust_lang() -> Language {
3086        Language::new(
3087            LanguageConfig {
3088                name: "Rust".into(),
3089                matcher: LanguageMatcher {
3090                    path_suffixes: vec!["rs".to_string()],
3091                    ..Default::default()
3092                },
3093                ..Default::default()
3094            },
3095            Some(tree_sitter_rust::language()),
3096        )
3097        .with_indents_query(
3098            r#"
3099            (call_expression) @indent
3100            (field_expression) @indent
3101            (_ "(" ")" @end) @indent
3102            (_ "{" "}" @end) @indent
3103            "#,
3104        )
3105        .unwrap()
3106    }
3107}