inline_assistant.rs

   1use crate::{
   2    prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
   3    LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
   4};
   5use anyhow::Result;
   6use client::telemetry::Telemetry;
   7use collections::{hash_map, HashMap, HashSet, VecDeque};
   8use editor::{
   9    actions::{MoveDown, MoveUp},
  10    display_map::{
  11        BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
  12    },
  13    scroll::{Autoscroll, AutoscrollStrategy},
  14    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
  15    MultiBufferSnapshot, ToOffset, ToPoint,
  16};
  17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  18use gpui::{
  19    AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
  20    Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
  21    ViewContext, WeakView, WhiteSpace, WindowContext,
  22};
  23use language::{Point, TransactionId};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use rope::Rope;
  27use settings::Settings;
  28use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
  29use theme::ThemeSettings;
  30use ui::{prelude::*, Tooltip};
  31use workspace::{notifications::NotificationId, Toast, Workspace};
  32
  33pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
  34    cx.set_global(InlineAssistant::new(telemetry));
  35}
  36
/// Maximum number of confirmed prompts kept in the inline assistant's prompt history; the
/// oldest entry is dropped once this is exceeded.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
  38
/// App-wide coordinator for inline assists across all editors.
///
/// Stored as a gpui global (see [`init`]). Every assist started through `assist` receives a
/// fresh [`InlineAssistId`] and stays tracked here until it is finished.
pub struct InlineAssistant {
    /// Id handed out to the next assist (advanced with `InlineAssistId::post_inc`).
    next_assist_id: InlineAssistId,
    /// Assists that have been started and not yet finished, keyed by id.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Per host editor: the window it was registered from and its pending assist ids, in the
    /// order they were started.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Previously confirmed prompts, oldest first, without duplicates, capped at
    /// `PROMPT_HISTORY_MAX_LEN` entries.
    prompt_history: VecDeque<String>,
    /// Telemetry forwarded to each codegen model; `new` always stores `Some`.
    telemetry: Option<Arc<Telemetry>>,
}
  46
/// Bookkeeping for the pending inline assists of a single host editor.
struct EditorPendingAssists {
    /// Window the editor's first assist was started from; `cancel_last_inline_assist` only
    /// considers editors registered from the current window.
    window: AnyWindowHandle,
    /// Pending assist ids for this editor, in the order they were started.
    assist_ids: Vec<InlineAssistId>,
}
  51
// Lets the assistant be stored with `cx.set_global` and reached via
// `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
  53
impl InlineAssistant {
    /// Creates an assistant with no pending assists and an empty prompt history.
    ///
    /// `telemetry` is stored and handed to every `Codegen` created by [`Self::assist`].
    pub fn new(telemetry: Arc<Telemetry>) -> Self {
        Self {
            next_assist_id: InlineAssistId::default(),
            pending_assists: HashMap::default(),
            pending_assist_ids_by_editor: HashMap::default(),
            prompt_history: VecDeque::default(),
            telemetry: Some(telemetry),
        }
    }

    /// Starts an inline assist for the newest selection of `editor`.
    ///
    /// Does nothing if that selection spans more than one excerpt. A non-empty selection is
    /// expanded to whole lines and becomes a `CodegenKind::Transform` over that range; an empty
    /// selection becomes a `CodegenKind::Generate` at the cursor. The editor's selection is
    /// collapsed to its head, and a sticky block hosting an [`InlineAssistEditor`] prompt is
    /// inserted at the start of the head's row (above it for a reversed selection, below it
    /// otherwise).
    ///
    /// * `workspace` - used when the prompt is confirmed (assistant panel lookup, project name)
    ///   and to show an error toast if generation fails after the prompt block was hidden.
    /// * `include_conversation` - when `true`, confirming the prompt prepends the messages of
    ///   the assistant panel's active conversation to the completion request.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        include_conversation: bool,
        cx: &mut WindowContext,
    ) {
        let selection = editor.read(cx).selections.newest_anchor().clone();
        // Only selections contained in a single excerpt are supported.
        if selection.start.excerpt_id != selection.end.excerpt_id {
            return;
        }
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // Extend the selection to the start and the end of the line.
        let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
        if point_selection.end > point_selection.start {
            point_selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            // (Cannot underflow: end > start with end.column == 0 implies end.row > start.row.)
            if point_selection.end.column == 0 {
                point_selection.end.row -= 1;
            }
            point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
        }

        // An empty selection generates new text at the cursor; a non-empty one transforms the
        // line-expanded range.
        let codegen_kind = if point_selection.start == point_selection.end {
            CodegenKind::Generate {
                position: snapshot.anchor_after(point_selection.start),
            }
        } else {
            CodegenKind::Transform {
                range: snapshot.anchor_before(point_selection.start)
                    ..snapshot.anchor_after(point_selection.end),
            }
        };

        let inline_assist_id = self.next_assist_id.post_inc();
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                codegen_kind,
                self.telemetry.clone(),
                cx,
            )
        });

        // Written by the block renderer with the host editor's gutter size on every render and
        // read by the prompt view to line itself up with that gutter.
        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let inline_assist_editor = cx.new_view(|cx| {
            InlineAssistEditor::new(
                inline_assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                codegen.clone(),
                cx,
            )
        });
        // Collapse the selection to its head, then insert the prompt block on the head's row.
        let block_id = editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([selection.head()..selection.head()])
            });
            editor.insert_blocks(
                [BlockProperties {
                    style: BlockStyle::Sticky,
                    position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
                    height: inline_assist_editor.read(cx).height_in_lines,
                    render: build_inline_assist_editor_renderer(
                        &inline_assist_editor,
                        gutter_dimensions,
                    ),
                    disposition: if selection.reversed {
                        BlockDisposition::Above
                    } else {
                        BlockDisposition::Below
                    },
                }],
                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
                cx,
            )[0]
        });

        self.pending_assists.insert(
            inline_assist_id,
            PendingInlineAssist {
                include_conversation,
                editor: editor.downgrade(),
                inline_assist_editor: Some((block_id, inline_assist_editor.clone())),
                codegen: codegen.clone(),
                workspace,
                _subscriptions: vec![
                    // Forward prompt events (confirm / cancel / dismiss / resize) to the global
                    // assistant.
                    cx.subscribe(&inline_assist_editor, |inline_assist_editor, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_inline_assistant_event(inline_assist_editor, event, cx)
                        })
                    }),
                    // When the host editor's selection changes locally while the prompt holds
                    // focus, hand focus back to the host editor.
                    cx.subscribe(editor, {
                        let inline_assist_editor = inline_assist_editor.downgrade();
                        move |editor, event, cx| {
                            if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
                                if let EditorEvent::SelectionsChanged { local } = event {
                                    if *local
                                        && inline_assist_editor
                                            .focus_handle(cx)
                                            .contains_focused(cx)
                                    {
                                        cx.focus_view(&editor);
                                    }
                                }
                            }
                        }
                    }),
                    // Refresh the host editor's highlights whenever the codegen notifies.
                    cx.observe(&codegen, {
                        let editor = editor.downgrade();
                        move |_, cx| {
                            if let Some(editor) = editor.upgrade() {
                                InlineAssistant::update_global(cx, |this, cx| {
                                    this.update_highlights_for_editor(&editor, cx);
                                })
                            }
                        }
                    }),
                    cx.subscribe(&codegen, move |codegen, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| match event {
                            // NOTE(review): presumably emitted once the codegen's edits have been
                            // undone; the assist ends without undoing again.
                            CodegenEvent::Undone => {
                                this.finish_inline_assist(inline_assist_id, false, cx)
                            }
                            CodegenEvent::Finished => {
                                let pending_assist = if let Some(pending_assist) =
                                    this.pending_assists.get(&inline_assist_id)
                                {
                                    pending_assist
                                } else {
                                    return;
                                };

                                let error = codegen
                                    .read(cx)
                                    .error()
                                    .map(|error| format!("Inline assistant error: {}", error));
                                if let Some(error) = error {
                                    // While the prompt block is still shown, the error appears
                                    // next to the prompt (`InlineAssistEditor::render`) and the
                                    // assist stays open. Once the block has been hidden, report
                                    // the error as a workspace toast and end the assist.
                                    if pending_assist.inline_assist_editor.is_none() {
                                        if let Some(workspace) = pending_assist
                                            .workspace
                                            .as_ref()
                                            .and_then(|workspace| workspace.upgrade())
                                        {
                                            workspace.update(cx, |workspace, cx| {
                                                struct InlineAssistantError;

                                                let id = NotificationId::identified::<
                                                    InlineAssistantError,
                                                >(
                                                    inline_assist_id.0
                                                );

                                                workspace.show_toast(Toast::new(id, error), cx);
                                            })
                                        }

                                        this.finish_inline_assist(inline_assist_id, false, cx);
                                    }
                                } else {
                                    // Success: end the assist and keep the generated edits.
                                    this.finish_inline_assist(inline_assist_id, false, cx);
                                }
                            }
                        })
                    }),
                ],
            },
        );

        // Track the assist per editor, in start order, for cancellation and highlighting.
        self.pending_assist_ids_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorPendingAssists {
                window: cx.window_handle(),
                assist_ids: Vec::new(),
            })
            .assist_ids
            .push(inline_assist_id);
        self.update_highlights_for_editor(editor, cx);
    }

    /// Routes an event emitted by an assist's prompt view to the matching operation on the
    /// assist identified by that view's id.
    fn handle_inline_assistant_event(
        &mut self,
        inline_assist_editor: View<InlineAssistEditor>,
        event: &InlineAssistEditorEvent,
        cx: &mut WindowContext,
    ) {
        let assist_id = inline_assist_editor.read(cx).id;
        match event {
            InlineAssistEditorEvent::Confirmed { prompt } => {
                self.confirm_inline_assist(assist_id, prompt, cx);
            }
            InlineAssistEditorEvent::Canceled => {
                self.finish_inline_assist(assist_id, true, cx);
            }
            InlineAssistEditorEvent::Dismissed => {
                self.hide_inline_assist(assist_id, cx);
            }
            InlineAssistEditorEvent::Resized { height_in_lines } => {
                self.resize_inline_assist(assist_id, *height_in_lines, cx);
            }
        }
    }

    /// Cancels the most recently started assist of the focused editor in the current window,
    /// undoing its generated edits.
    ///
    /// Returns `true` if an assist was canceled and `false` if no focused editor registered
    /// from this window has a pending assist.
    pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
        for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
            if pending_assists.window == cx.window_handle() {
                if let Some(editor) = editor.upgrade() {
                    if editor.read(cx).is_focused(cx) {
                        if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
                            // Returning right after this call is what lets the loop's borrow of
                            // `pending_assist_ids_by_editor` end before it is mutated.
                            self.finish_inline_assist(assist_id, true, cx);
                            return true;
                        }
                    }
                }
            }
        }
        false
    }

    /// Ends the assist `assist_id`.
    ///
    /// Hides its prompt block and removes it from all bookkeeping. If the host editor is still
    /// alive, this also refreshes that editor's highlights and, when `undo` is set, undoes the
    /// codegen's edits. Unknown ids are ignored.
    fn finish_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        undo: bool,
        cx: &mut WindowContext,
    ) {
        // Hide first: hiding needs the assist to still be present in `pending_assists`.
        self.hide_inline_assist(assist_id, cx);

        if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self
                .pending_assist_ids_by_editor
                .entry(pending_assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let Some(editor) = pending_assist.editor.upgrade() {
                self.update_highlights_for_editor(&editor, cx);

                if undo {
                    pending_assist
                        .codegen
                        .update(cx, |codegen, cx| codegen.undo(cx));
                }
            }
        }
    }

    /// Removes the prompt block of `assist_id` from its host editor, leaving the assist itself
    /// (its codegen and highlights) pending.
    ///
    /// If the prompt had focus, focus returns to the host editor. No-op if the block was already
    /// hidden, the assist is unknown, or the editor has been dropped.
    fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
            if let Some(editor) = pending_assist.editor.upgrade() {
                if let Some((block_id, inline_assist_editor)) =
                    pending_assist.inline_assist_editor.take()
                {
                    editor.update(cx, |editor, cx| {
                        editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
                        if inline_assist_editor.focus_handle(cx).contains_focused(cx) {
                            editor.focus(cx);
                        }
                    });
                }
            }
        }
    }

    /// Replaces the prompt block of `assist_id` with one `height_in_lines` lines tall, keeping
    /// the same block id and prompt view.
    ///
    /// No-op if the block is hidden, the assist is unknown, or the editor has been dropped.
    fn resize_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        height_in_lines: u8,
        cx: &mut WindowContext,
    ) {
        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
            if let Some(editor) = pending_assist.editor.upgrade() {
                if let Some((block_id, inline_assist_editor)) =
                    pending_assist.inline_assist_editor.as_ref()
                {
                    let gutter_dimensions = inline_assist_editor.read(cx).gutter_dimensions.clone();
                    let mut new_blocks = HashMap::default();
                    new_blocks.insert(
                        *block_id,
                        (
                            Some(height_in_lines),
                            build_inline_assist_editor_renderer(
                                inline_assist_editor,
                                gutter_dimensions,
                            ),
                        ),
                    );
                    editor.update(cx, |editor, cx| {
                        editor
                            .display_map
                            .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
                    });
                }
            }
        }
    }

    /// Submits `user_prompt` for assist `assist_id` and starts code generation.
    ///
    /// This records the prompt in the history: any older duplicate is dropped, the prompt is
    /// pushed as the newest entry, and the history is trimmed to `PROMPT_HISTORY_MAX_LEN`. It
    /// then builds the model prompt on the background executor, using the language at the start
    /// of the codegen range. When the assist was started with `include_conversation`, the active
    /// conversation's messages are prepended. Finally it starts the codegen.
    ///
    /// Returns silently if the assist is unknown or its editor has been dropped. If the codegen
    /// range does not map into a single buffer, the assist is finished (without undo) instead.
    fn confirm_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        user_prompt: &str,
        cx: &mut WindowContext,
    ) {
        let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
        {
            pending_assist
        } else {
            return;
        };

        // Optional context: the assistant panel's active conversation.
        let conversation = if pending_assist.include_conversation {
            pending_assist.workspace.as_ref().and_then(|workspace| {
                let workspace = workspace.upgrade()?.read(cx);
                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
                assistant_panel.read(cx).active_conversation(cx)
            })
        } else {
            None
        };

        let editor = if let Some(editor) = pending_assist.editor.upgrade() {
            editor
        } else {
            return;
        };

        // Worktree root names joined with "/", passed to the prompt as the project name.
        let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?;
            Some(
                workspace
                    .read(cx)
                    .project()
                    .read(cx)
                    .worktree_root_names(cx)
                    .collect::<Vec<&str>>()
                    .join("/"),
            )
        });

        // Move this prompt to the newest history slot and cap the history length.
        self.prompt_history.retain(|prompt| prompt != user_prompt);
        self.prompt_history.push_back(user_prompt.into());
        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
            self.prompt_history.pop_front();
        }

        let codegen = pending_assist.codegen.clone();
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        let range = codegen.read(cx).range();
        // The prompt is built from a single buffer, so give up (without undo) unless both ends
        // of the range resolve into the same buffer.
        let start = snapshot.point_to_buffer_offset(range.start);
        let end = snapshot.point_to_buffer_offset(range.end);
        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
            let (start_buffer, start_buffer_offset) = start;
            let (end_buffer, end_buffer_offset) = end;
            if start_buffer.remote_id() == end_buffer.remote_id() {
                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
            } else {
                self.finish_inline_assist(assist_id, false, cx);
                return;
            }
        } else {
            self.finish_inline_assist(assist_id, false, cx);
            return;
        };

        // Plain text is treated the same as having no language.
        let language = buffer.language_at(range.start);
        let language_name = if let Some(language) = language.as_ref() {
            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                None
            } else {
                Some(language.name())
            }
        } else {
            None
        };

        // Higher Temperature increases the randomness of model outputs.
        // If Markdown or No Language is Known, increase the randomness for more creative output
        // If Code, decrease temperature to get more deterministic outputs
        let temperature = if let Some(language) = language_name.clone() {
            if language.as_ref() == "Markdown" {
                1.0
            } else {
                0.5
            }
        } else {
            1.0
        };

        let user_prompt = user_prompt.to_string();

        // Build the model prompt on the background executor.
        let prompt = cx.background_executor().spawn(async move {
            let language_name = language_name.as_deref();
            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
        });

        // Conversation messages (if any) come first; the generated prompt is appended below as
        // the final user message.
        let mut messages = Vec::new();
        if let Some(conversation) = conversation {
            let request = conversation.read(cx).to_completion_request(cx);
            messages = request.messages;
        }
        let model = CompletionProvider::global(cx).model();

        cx.spawn(|mut cx| async move {
            let prompt = prompt.await?;

            messages.push(LanguageModelRequestMessage {
                role: Role::User,
                content: prompt,
            });

            // NOTE(review): the stop sequence presumably matches an end marker that
            // `generate_content_prompt` asks the model to emit - confirm.
            let request = LanguageModelRequest {
                model,
                messages,
                stop: vec!["|END|>".to_string()],
                temperature,
            };

            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    /// Recomputes the inline-assist highlights of `editor` from all of its pending assists.
    ///
    /// Each codegen's `range()` gets a background highlight, and each codegen's
    /// `last_equal_ranges()` are drawn with faded-out text. Both highlight kinds are cleared
    /// when there is nothing to show.
    fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut background_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        // Fallback when the editor has no pending assists.
        let empty_inline_assist_ids = Vec::new();
        let inline_assist_ids = self
            .pending_assist_ids_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_inline_assist_ids, |pending_assists| {
                &pending_assists.assist_ids
            });

        for inline_assist_id in inline_assist_ids {
            if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
                let codegen = pending_assist.codegen.read(cx);
                background_ranges.push(codegen.range());
                // NOTE(review): presumably the parts of the original text the streamed output
                // has matched so far - confirm in `Codegen`.
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut background_ranges, &snapshot);
        merge_ranges(&mut foreground_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            if background_ranges.is_empty() {
                editor.clear_background_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_background::<PendingInlineAssist>(
                    &background_ranges,
                    |theme| theme.editor_active_line_background, // TODO use the appropriate color
                    cx,
                );
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_text::<PendingInlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }
        });
    }
}
 539
 540fn build_inline_assist_editor_renderer(
 541    editor: &View<InlineAssistEditor>,
 542    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 543) -> RenderBlock {
 544    let editor = editor.clone();
 545    Box::new(move |cx: &mut BlockContext| {
 546        *gutter_dimensions.lock() = *cx.gutter_dimensions;
 547        editor.clone().into_any_element()
 548    })
 549}
 550
 551#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 552struct InlineAssistId(usize);
 553
 554impl InlineAssistId {
 555    fn post_inc(&mut self) -> InlineAssistId {
 556        let id = *self;
 557        self.0 += 1;
 558        id
 559    }
 560}
 561
/// Events emitted by [`InlineAssistEditor`] and handled by
/// `InlineAssistant::handle_inline_assistant_event`.
enum InlineAssistEditorEvent {
    /// The prompt was submitted (first `menu::Confirm`); starts code generation for `prompt`.
    Confirmed { prompt: String },
    /// `editor::actions::Cancel` was dispatched; the assist is finished and its edits undone.
    Canceled,
    /// `menu::Confirm` was dispatched again after confirming; the prompt block is hidden.
    Dismissed,
    /// The prompt's height changed; the hosting block is resized to `height_in_lines`.
    Resized { height_in_lines: u8 },
}
 568
/// Prompt view for one inline assist, rendered inside a block of the host editor.
struct InlineAssistEditor {
    /// The assist this prompt belongs to.
    id: InlineAssistId,
    /// Current height of the hosting block in lines, maintained by `count_lines`.
    height_in_lines: u8,
    /// Auto-height editor holding the prompt text.
    prompt_editor: View<Editor>,
    /// Set when the prompt is confirmed; while set, another confirm dismisses the block.
    /// Cleared in `handle_codegen_changed` when the prompt becomes editable again.
    confirmed: bool,
    /// Host editor gutter size, written by the block renderer and read in `render`.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Snapshot of previously confirmed prompts, used for Up/Down navigation.
    prompt_history: VecDeque<String>,
    /// Index of the history entry currently shown, or `None` when not browsing the history.
    prompt_history_ix: Option<usize>,
    /// Prompt text captured on `EditorEvent::Edited`; restored when navigating down past the
    /// newest history entry.
    pending_prompt: String,
    /// Codegen model driving this assist; observed to toggle the prompt's read-only state.
    codegen: Model<Codegen>,
    _subscriptions: Vec<Subscription>,
}
 581
// Allows `cx.emit(InlineAssistEditorEvent::..)` from the prompt view and `cx.subscribe` on it.
impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
 583
 584impl Render for InlineAssistEditor {
 585    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 586        let gutter_dimensions = *self.gutter_dimensions.lock();
 587        let icon_size = IconSize::default();
 588        h_flex()
 589            .w_full()
 590            .py_1p5()
 591            .border_y_1()
 592            .border_color(cx.theme().colors().border)
 593            .bg(cx.theme().colors().editor_background)
 594            .on_action(cx.listener(Self::confirm))
 595            .on_action(cx.listener(Self::cancel))
 596            .on_action(cx.listener(Self::move_up))
 597            .on_action(cx.listener(Self::move_down))
 598            .child(
 599                h_flex()
 600                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
 601                    .pr(gutter_dimensions.fold_area_width())
 602                    .justify_end()
 603                    .children(if let Some(error) = self.codegen.read(cx).error() {
 604                        let error_message = SharedString::from(error.to_string());
 605                        Some(
 606                            div()
 607                                .id("error")
 608                                .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
 609                                .child(
 610                                    Icon::new(IconName::XCircle)
 611                                        .size(icon_size)
 612                                        .color(Color::Error),
 613                                ),
 614                        )
 615                    } else {
 616                        None
 617                    }),
 618            )
 619            .child(div().flex_1().child(self.render_prompt_editor(cx)))
 620    }
 621}
 622
 623impl FocusableView for InlineAssistEditor {
 624    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
 625        self.prompt_editor.focus_handle(cx)
 626    }
 627}
 628
 629impl InlineAssistEditor {
 630    const MAX_LINES: u8 = 8;
 631
 632    #[allow(clippy::too_many_arguments)]
 633    fn new(
 634        id: InlineAssistId,
 635        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 636        prompt_history: VecDeque<String>,
 637        codegen: Model<Codegen>,
 638        cx: &mut ViewContext<Self>,
 639    ) -> Self {
 640        let prompt_editor = cx.new_view(|cx| {
 641            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
 642            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
 643            let placeholder = match codegen.read(cx).kind() {
 644                CodegenKind::Transform { .. } => "Enter transformation prompt…",
 645                CodegenKind::Generate { .. } => "Enter generation prompt…",
 646            };
 647            editor.set_placeholder_text(placeholder, cx);
 648            editor
 649        });
 650        cx.focus_view(&prompt_editor);
 651
 652        let subscriptions = vec![
 653            cx.observe(&codegen, Self::handle_codegen_changed),
 654            cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
 655            cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
 656        ];
 657
 658        let mut this = Self {
 659            id,
 660            height_in_lines: 1,
 661            prompt_editor,
 662            confirmed: false,
 663            gutter_dimensions,
 664            prompt_history,
 665            prompt_history_ix: None,
 666            pending_prompt: String::new(),
 667            codegen,
 668            _subscriptions: subscriptions,
 669        };
 670        this.count_lines(cx);
 671        this
 672    }
 673
 674    fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
 675        let height_in_lines = cmp::max(
 676            2, // Make the editor at least two lines tall, to account for padding.
 677            cmp::min(
 678                self.prompt_editor
 679                    .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
 680                Self::MAX_LINES as u32,
 681            ),
 682        ) as u8;
 683
 684        if height_in_lines != self.height_in_lines {
 685            self.height_in_lines = height_in_lines;
 686            cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
 687        }
 688    }
 689
 690    fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
 691        self.count_lines(cx);
 692    }
 693
 694    fn handle_prompt_editor_events(
 695        &mut self,
 696        _: View<Editor>,
 697        event: &EditorEvent,
 698        cx: &mut ViewContext<Self>,
 699    ) {
 700        if let EditorEvent::Edited = event {
 701            self.pending_prompt = self.prompt_editor.read(cx).text(cx);
 702            cx.notify();
 703        }
 704    }
 705
 706    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
 707        let is_read_only = !self.codegen.read(cx).idle();
 708        self.prompt_editor.update(cx, |editor, cx| {
 709            let was_read_only = editor.read_only(cx);
 710            if was_read_only != is_read_only {
 711                if is_read_only {
 712                    editor.set_read_only(true);
 713                } else {
 714                    self.confirmed = false;
 715                    editor.set_read_only(false);
 716                }
 717            }
 718        });
 719        cx.notify();
 720    }
 721
 722    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
 723        cx.emit(InlineAssistEditorEvent::Canceled);
 724    }
 725
 726    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 727        if self.confirmed {
 728            cx.emit(InlineAssistEditorEvent::Dismissed);
 729        } else {
 730            let prompt = self.prompt_editor.read(cx).text(cx);
 731            self.prompt_editor
 732                .update(cx, |editor, _cx| editor.set_read_only(true));
 733            cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
 734            self.confirmed = true;
 735            cx.notify();
 736        }
 737    }
 738
 739    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
 740        if let Some(ix) = self.prompt_history_ix {
 741            if ix > 0 {
 742                self.prompt_history_ix = Some(ix - 1);
 743                let prompt = self.prompt_history[ix - 1].clone();
 744                self.set_prompt(&prompt, cx);
 745            }
 746        } else if !self.prompt_history.is_empty() {
 747            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
 748            let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
 749            self.set_prompt(&prompt, cx);
 750        }
 751    }
 752
 753    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
 754        if let Some(ix) = self.prompt_history_ix {
 755            if ix < self.prompt_history.len() - 1 {
 756                self.prompt_history_ix = Some(ix + 1);
 757                let prompt = self.prompt_history[ix + 1].clone();
 758                self.set_prompt(&prompt, cx);
 759            } else {
 760                self.prompt_history_ix = None;
 761                let pending_prompt = self.pending_prompt.clone();
 762                self.set_prompt(&pending_prompt, cx);
 763            }
 764        }
 765    }
 766
 767    fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
 768        self.prompt_editor.update(cx, |editor, cx| {
 769            editor.buffer().update(cx, |buffer, cx| {
 770                let len = buffer.len(cx);
 771                buffer.edit([(0..len, prompt)], None, cx);
 772            });
 773        });
 774    }
 775
 776    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 777        let settings = ThemeSettings::get_global(cx);
 778        let text_style = TextStyle {
 779            color: if self.prompt_editor.read(cx).read_only(cx) {
 780                cx.theme().colors().text_disabled
 781            } else {
 782                cx.theme().colors().text
 783            },
 784            font_family: settings.ui_font.family.clone(),
 785            font_features: settings.ui_font.features.clone(),
 786            font_size: rems(0.875).into(),
 787            font_weight: FontWeight::NORMAL,
 788            font_style: FontStyle::Normal,
 789            line_height: relative(1.3),
 790            background_color: None,
 791            underline: None,
 792            strikethrough: None,
 793            white_space: WhiteSpace::Normal,
 794        };
 795        EditorElement::new(
 796            &self.prompt_editor,
 797            EditorStyle {
 798                background: cx.theme().colors().editor_background,
 799                local_player: cx.theme().players().local(),
 800                text: text_style,
 801                ..Default::default()
 802            },
 803        )
 804    }
 805}
 806
/// State tracked for one inline assist while it is pending (stored by
/// `InlineAssistant` in its `pending_assists` map, keyed by `InlineAssistId`).
struct PendingInlineAssist {
    /// The editor the assist was opened in (weak, so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// The block hosting the prompt editor and the prompt editor view itself.
    /// NOTE(review): presumably `None` once that block has been removed — confirm.
    inline_assist_editor: Option<(BlockId, View<InlineAssistEditor>)>,
    /// The model that streams the completion into the buffer.
    codegen: Model<Codegen>,
    /// Held only so these subscriptions stay alive as long as the assist does.
    _subscriptions: Vec<Subscription>,
    /// The workspace the assist belongs to, if known.
    workspace: Option<WeakView<Workspace>>,
    /// NOTE(review): presumably whether the assistant panel conversation is included
    /// in the request context — verify against the request-building code.
    include_conversation: bool,
}
 815
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation task ran to completion, successfully or not; check
    /// [`Codegen::error`] to tell which.
    Finished,
    /// The buffer transaction holding this generator's edits was undone.
    Undone,
}
 821
/// What a [`Codegen`] does with the model's output.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the existing text in `range`.
    Transform { range: Range<Anchor> },
    /// Insert newly generated text at `position`.
    Generate { position: Anchor },
}
 827
/// Streams a language-model completion into a [`MultiBuffer`], applying it as an
/// incremental diff against the text the assist targets.
pub struct Codegen {
    /// The buffer receiving the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` taken in `Codegen::new`; the target range and the original
    /// text are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this generator rewrites a range or inserts at a position.
    kind: CodegenKind,
    /// Ranges that the most recently applied batch of diff hunks kept unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction into which all edits of the current generation are merged;
    /// `None` before the first edit and after that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// The error of the last generation, if it failed.
    error: Option<anyhow::Error>,
    /// The task driving the current generation (initialised to `Task::ready(())`).
    generation: Task<()>,
    /// `false` from `start` until the spawned generation finishes.
    idle: bool,
    /// Where assistant telemetry is reported, if anywhere.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to notice when our transaction is undone.
    _subscription: gpui::Subscription,
}

// `Codegen` notifies observers through `CodegenEvent`s.
impl EventEmitter<CodegenEvent> for Codegen {}
 842
 843impl Codegen {
 844    pub fn new(
 845        buffer: Model<MultiBuffer>,
 846        kind: CodegenKind,
 847        telemetry: Option<Arc<Telemetry>>,
 848        cx: &mut ModelContext<Self>,
 849    ) -> Self {
 850        let snapshot = buffer.read(cx).snapshot(cx);
 851        Self {
 852            buffer: buffer.clone(),
 853            snapshot,
 854            kind,
 855            last_equal_ranges: Default::default(),
 856            transaction_id: Default::default(),
 857            error: Default::default(),
 858            idle: true,
 859            generation: Task::ready(()),
 860            telemetry,
 861            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 862        }
 863    }
 864
 865    fn handle_buffer_event(
 866        &mut self,
 867        _buffer: Model<MultiBuffer>,
 868        event: &multi_buffer::Event,
 869        cx: &mut ModelContext<Self>,
 870    ) {
 871        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 872            if self.transaction_id == Some(*transaction_id) {
 873                self.transaction_id = None;
 874                self.generation = Task::ready(());
 875                cx.emit(CodegenEvent::Undone);
 876            }
 877        }
 878    }
 879
 880    pub fn range(&self) -> Range<Anchor> {
 881        match &self.kind {
 882            CodegenKind::Transform { range } => range.clone(),
 883            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 884        }
 885    }
 886
 887    pub fn kind(&self) -> &CodegenKind {
 888        &self.kind
 889    }
 890
 891    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 892        &self.last_equal_ranges
 893    }
 894
 895    pub fn idle(&self) -> bool {
 896        self.idle
 897    }
 898
 899    pub fn error(&self) -> Option<&anyhow::Error> {
 900        self.error.as_ref()
 901    }
 902
 903    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
 904        let range = self.range();
 905        let snapshot = self.snapshot.clone();
 906        let selected_text = snapshot
 907            .text_for_range(range.start..range.end)
 908            .collect::<Rope>();
 909
 910        let selection_start = range.start.to_point(&snapshot);
 911        let suggested_line_indent = snapshot
 912            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
 913            .into_values()
 914            .next()
 915            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 916
 917        let model_telemetry_id = prompt.model.telemetry_id();
 918        let response = CompletionProvider::global(cx).complete(prompt);
 919        let telemetry = self.telemetry.clone();
 920        self.generation = cx.spawn(|this, mut cx| {
 921            async move {
 922                let generate = async {
 923                    let mut edit_start = range.start.to_offset(&snapshot);
 924
 925                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 926                    let diff: Task<anyhow::Result<()>> =
 927                        cx.background_executor().spawn(async move {
 928                            let mut response_latency = None;
 929                            let request_start = Instant::now();
 930                            let diff = async {
 931                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
 932                                futures::pin_mut!(chunks);
 933                                let mut diff = StreamingDiff::new(selected_text.to_string());
 934
 935                                let mut new_text = String::new();
 936                                let mut base_indent = None;
 937                                let mut line_indent = None;
 938                                let mut first_line = true;
 939
 940                                while let Some(chunk) = chunks.next().await {
 941                                    if response_latency.is_none() {
 942                                        response_latency = Some(request_start.elapsed());
 943                                    }
 944                                    let chunk = chunk?;
 945
 946                                    let mut lines = chunk.split('\n').peekable();
 947                                    while let Some(line) = lines.next() {
 948                                        new_text.push_str(line);
 949                                        if line_indent.is_none() {
 950                                            if let Some(non_whitespace_ch_ix) =
 951                                                new_text.find(|ch: char| !ch.is_whitespace())
 952                                            {
 953                                                line_indent = Some(non_whitespace_ch_ix);
 954                                                base_indent = base_indent.or(line_indent);
 955
 956                                                let line_indent = line_indent.unwrap();
 957                                                let base_indent = base_indent.unwrap();
 958                                                let indent_delta =
 959                                                    line_indent as i32 - base_indent as i32;
 960                                                let mut corrected_indent_len = cmp::max(
 961                                                    0,
 962                                                    suggested_line_indent.len as i32 + indent_delta,
 963                                                )
 964                                                    as usize;
 965                                                if first_line {
 966                                                    corrected_indent_len = corrected_indent_len
 967                                                        .saturating_sub(
 968                                                            selection_start.column as usize,
 969                                                        );
 970                                                }
 971
 972                                                let indent_char = suggested_line_indent.char();
 973                                                let mut indent_buffer = [0; 4];
 974                                                let indent_str =
 975                                                    indent_char.encode_utf8(&mut indent_buffer);
 976                                                new_text.replace_range(
 977                                                    ..line_indent,
 978                                                    &indent_str.repeat(corrected_indent_len),
 979                                                );
 980                                            }
 981                                        }
 982
 983                                        if line_indent.is_some() {
 984                                            hunks_tx.send(diff.push_new(&new_text)).await?;
 985                                            new_text.clear();
 986                                        }
 987
 988                                        if lines.peek().is_some() {
 989                                            hunks_tx.send(diff.push_new("\n")).await?;
 990                                            line_indent = None;
 991                                            first_line = false;
 992                                        }
 993                                    }
 994                                }
 995                                hunks_tx.send(diff.push_new(&new_text)).await?;
 996                                hunks_tx.send(diff.finish()).await?;
 997
 998                                anyhow::Ok(())
 999                            };
1000
1001                            let result = diff.await;
1002
1003                            let error_message =
1004                                result.as_ref().err().map(|error| error.to_string());
1005                            if let Some(telemetry) = telemetry {
1006                                telemetry.report_assistant_event(
1007                                    None,
1008                                    telemetry_events::AssistantKind::Inline,
1009                                    model_telemetry_id,
1010                                    response_latency,
1011                                    error_message,
1012                                );
1013                            }
1014
1015                            result?;
1016                            Ok(())
1017                        });
1018
1019                    while let Some(hunks) = hunks_rx.next().await {
1020                        this.update(&mut cx, |this, cx| {
1021                            this.last_equal_ranges.clear();
1022
1023                            let transaction = this.buffer.update(cx, |buffer, cx| {
1024                                // Avoid grouping assistant edits with user edits.
1025                                buffer.finalize_last_transaction(cx);
1026
1027                                buffer.start_transaction(cx);
1028                                buffer.edit(
1029                                    hunks.into_iter().filter_map(|hunk| match hunk {
1030                                        Hunk::Insert { text } => {
1031                                            let edit_start = snapshot.anchor_after(edit_start);
1032                                            Some((edit_start..edit_start, text))
1033                                        }
1034                                        Hunk::Remove { len } => {
1035                                            let edit_end = edit_start + len;
1036                                            let edit_range = snapshot.anchor_after(edit_start)
1037                                                ..snapshot.anchor_before(edit_end);
1038                                            edit_start = edit_end;
1039                                            Some((edit_range, String::new()))
1040                                        }
1041                                        Hunk::Keep { len } => {
1042                                            let edit_end = edit_start + len;
1043                                            let edit_range = snapshot.anchor_after(edit_start)
1044                                                ..snapshot.anchor_before(edit_end);
1045                                            edit_start = edit_end;
1046                                            this.last_equal_ranges.push(edit_range);
1047                                            None
1048                                        }
1049                                    }),
1050                                    None,
1051                                    cx,
1052                                );
1053
1054                                buffer.end_transaction(cx)
1055                            });
1056
1057                            if let Some(transaction) = transaction {
1058                                if let Some(first_transaction) = this.transaction_id {
1059                                    // Group all assistant edits into the first transaction.
1060                                    this.buffer.update(cx, |buffer, cx| {
1061                                        buffer.merge_transactions(
1062                                            transaction,
1063                                            first_transaction,
1064                                            cx,
1065                                        )
1066                                    });
1067                                } else {
1068                                    this.transaction_id = Some(transaction);
1069                                    this.buffer.update(cx, |buffer, cx| {
1070                                        buffer.finalize_last_transaction(cx)
1071                                    });
1072                                }
1073                            }
1074
1075                            cx.notify();
1076                        })?;
1077                    }
1078
1079                    diff.await?;
1080
1081                    anyhow::Ok(())
1082                };
1083
1084                let result = generate.await;
1085                this.update(&mut cx, |this, cx| {
1086                    this.last_equal_ranges.clear();
1087                    this.idle = true;
1088                    if let Err(error) = result {
1089                        this.error = Some(error);
1090                    }
1091                    cx.emit(CodegenEvent::Finished);
1092                    cx.notify();
1093                })
1094                .ok();
1095            }
1096        });
1097        self.error.take();
1098        self.idle = false;
1099        cx.notify();
1100    }
1101
1102    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
1103        if let Some(transaction_id) = self.transaction_id {
1104            self.buffer
1105                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
1106        }
1107    }
1108}
1109
1110fn strip_invalid_spans_from_codeblock(
1111    stream: impl Stream<Item = Result<String>>,
1112) -> impl Stream<Item = Result<String>> {
1113    let mut first_line = true;
1114    let mut buffer = String::new();
1115    let mut starts_with_markdown_codeblock = false;
1116    let mut includes_start_or_end_span = false;
1117    stream.filter_map(move |chunk| {
1118        let chunk = match chunk {
1119            Ok(chunk) => chunk,
1120            Err(err) => return future::ready(Some(Err(err))),
1121        };
1122        buffer.push_str(&chunk);
1123
1124        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
1125            includes_start_or_end_span = true;
1126
1127            buffer = buffer
1128                .strip_prefix("<|S|>")
1129                .or_else(|| buffer.strip_prefix("<|S|"))
1130                .unwrap_or(&buffer)
1131                .to_string();
1132        } else if buffer.ends_with("|E|>") {
1133            includes_start_or_end_span = true;
1134        } else if buffer.starts_with("<|")
1135            || buffer.starts_with("<|S")
1136            || buffer.starts_with("<|S|")
1137            || buffer.ends_with('|')
1138            || buffer.ends_with("|E")
1139            || buffer.ends_with("|E|")
1140        {
1141            return future::ready(None);
1142        }
1143
1144        if first_line {
1145            if buffer.is_empty() || buffer == "`" || buffer == "``" {
1146                return future::ready(None);
1147            } else if buffer.starts_with("```") {
1148                starts_with_markdown_codeblock = true;
1149                if let Some(newline_ix) = buffer.find('\n') {
1150                    buffer.replace_range(..newline_ix + 1, "");
1151                    first_line = false;
1152                } else {
1153                    return future::ready(None);
1154                }
1155            }
1156        }
1157
1158        let mut text = buffer.to_string();
1159        if starts_with_markdown_codeblock {
1160            text = text
1161                .strip_suffix("\n```\n")
1162                .or_else(|| text.strip_suffix("\n```"))
1163                .or_else(|| text.strip_suffix("\n``"))
1164                .or_else(|| text.strip_suffix("\n`"))
1165                .or_else(|| text.strip_suffix('\n'))
1166                .unwrap_or(&text)
1167                .to_string();
1168        }
1169
1170        if includes_start_or_end_span {
1171            text = text
1172                .strip_suffix("|E|>")
1173                .or_else(|| text.strip_suffix("E|>"))
1174                .or_else(|| text.strip_prefix("|>"))
1175                .or_else(|| text.strip_prefix('>'))
1176                .unwrap_or(&text)
1177                .to_string();
1178        };
1179
1180        if text.contains('\n') {
1181            first_line = false;
1182        }
1183
1184        let remainder = buffer.split_off(text.len());
1185        let result = if buffer.is_empty() {
1186            None
1187        } else {
1188            Some(Ok(buffer.clone()))
1189        };
1190
1191        buffer = remainder;
1192        future::ready(result)
1193    })
1194}
1195
1196fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1197    ranges.sort_unstable_by(|a, b| {
1198        a.start
1199            .cmp(&b.start, buffer)
1200            .then_with(|| b.end.cmp(&a.end, buffer))
1201    });
1202
1203    let mut ix = 0;
1204    while ix + 1 < ranges.len() {
1205        let b = ranges[ix + 1].clone();
1206        let a = &mut ranges[ix];
1207        if a.end.cmp(&b.start, buffer).is_gt() {
1208            if a.end.cmp(&b.end, buffer).is_lt() {
1209                a.end = b.end;
1210            }
1211            ranges.remove(ix + 1);
1212        } else {
1213            ix += 1;
1214        }
1215    }
1216}
1217
// Tests for streaming code generation: re-indentation performed by `Codegen::start`, and
// removal of markdown fences / span markers by `strip_invalid_spans_from_codeblock`.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    // NOTE(review): not referenced by any test in this module.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    // Transforming an indented range: the fake model replies with 7-space indentation,
    // streamed in random chunks of 1..=10 bytes, and the result must be re-indented to the
    // 4-space level of the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of `let x = 0;` through the `}` closing the `for` loop.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the response in randomly sized chunks, letting the executor settle in between.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generating in the middle of a partially typed line: the cursor is after `le`
    // (row 1, column 6), the completion finishes that word, and the following lines must
    // still land at the 4-space level.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generating on a line indented less than expected: the cursor is after two spaces
    // (row 1, column 2), and the unindented response must end up at the 4-space level.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // Plain text passes through untouched.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // An opening fence is removed, with or without a closing fence / trailing newline.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Only the first fence line and the final closing fence are stripped; inner fences
        // are kept.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        // A first line of only two backticks is not a fence, so nothing is stripped.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        // Start/end span markers are removed, alone or inside a fence.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        // Splits `text` into a stream of `Ok` chunks of at most `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // A minimal Rust language: tree-sitter grammar plus an indents query, so buffers in these
    // tests can compute suggested indentation.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}