inline_assistant.rs

   1use crate::{
   2    prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
   3    LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
   4};
   5use anyhow::Result;
   6use client::telemetry::Telemetry;
   7use collections::{hash_map, HashMap, HashSet, VecDeque};
   8use editor::{
   9    actions::{MoveDown, MoveUp},
  10    display_map::{
  11        BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
  12    },
  13    scroll::{Autoscroll, AutoscrollStrategy},
  14    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
  15    MultiBufferSnapshot, ToOffset, ToPoint,
  16};
  17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  18use gpui::{
  19    AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
  20    Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
  21    ViewContext, WeakView, WhiteSpace, WindowContext,
  22};
  23use language::{Point, TransactionId};
  24use multi_buffer::MultiBufferRow;
  25use parking_lot::Mutex;
  26use rope::Rope;
  27use settings::Settings;
  28use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
  29use theme::ThemeSettings;
  30use ui::{prelude::*, Tooltip};
  31use workspace::{notifications::NotificationId, Toast, Workspace};
  32
  33pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
  34    cx.set_global(InlineAssistant::new(telemetry));
  35}
  36
  37const PROMPT_HISTORY_MAX_LEN: usize = 20;
  38
  39pub struct InlineAssistant {
  40    next_assist_id: InlineAssistId,
  41    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
  42    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
  43    prompt_history: VecDeque<String>,
  44    telemetry: Option<Arc<Telemetry>>,
  45}
  46
  47struct EditorPendingAssists {
  48    window: AnyWindowHandle,
  49    assist_ids: Vec<InlineAssistId>,
  50}
  51
  52impl Global for InlineAssistant {}
  53
  54impl InlineAssistant {
  55    pub fn new(telemetry: Arc<Telemetry>) -> Self {
  56        Self {
  57            next_assist_id: InlineAssistId::default(),
  58            pending_assists: HashMap::default(),
  59            pending_assist_ids_by_editor: HashMap::default(),
  60            prompt_history: VecDeque::default(),
  61            telemetry: Some(telemetry),
  62        }
  63    }
  64
  65    pub fn assist(
  66        &mut self,
  67        editor: &View<Editor>,
  68        workspace: Option<WeakView<Workspace>>,
  69        include_context: bool,
  70        cx: &mut WindowContext,
  71    ) {
  72        let selection = editor.read(cx).selections.newest_anchor().clone();
  73        if selection.start.excerpt_id != selection.end.excerpt_id {
  74            return;
  75        }
  76        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
  77
  78        // Extend the selection to the start and the end of the line.
  79        let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
  80        if point_selection.end > point_selection.start {
  81            point_selection.start.column = 0;
  82            // If the selection ends at the start of the line, we don't want to include it.
  83            if point_selection.end.column == 0 {
  84                point_selection.end.row -= 1;
  85            }
  86            point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
  87        }
  88
  89        let codegen_kind = if point_selection.start == point_selection.end {
  90            CodegenKind::Generate {
  91                position: snapshot.anchor_after(point_selection.start),
  92            }
  93        } else {
  94            CodegenKind::Transform {
  95                range: snapshot.anchor_before(point_selection.start)
  96                    ..snapshot.anchor_after(point_selection.end),
  97            }
  98        };
  99
 100        let inline_assist_id = self.next_assist_id.post_inc();
 101        let codegen = cx.new_model(|cx| {
 102            Codegen::new(
 103                editor.read(cx).buffer().clone(),
 104                codegen_kind,
 105                self.telemetry.clone(),
 106                cx,
 107            )
 108        });
 109
 110        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
 111        let inline_assist_editor = cx.new_view(|cx| {
 112            InlineAssistEditor::new(
 113                inline_assist_id,
 114                gutter_dimensions.clone(),
 115                self.prompt_history.clone(),
 116                codegen.clone(),
 117                cx,
 118            )
 119        });
 120        let block_id = editor.update(cx, |editor, cx| {
 121            editor.change_selections(None, cx, |selections| {
 122                selections.select_anchor_ranges([selection.head()..selection.head()])
 123            });
 124            editor.insert_blocks(
 125                [BlockProperties {
 126                    style: BlockStyle::Sticky,
 127                    position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
 128                    height: inline_assist_editor.read(cx).height_in_lines,
 129                    render: build_inline_assist_editor_renderer(
 130                        &inline_assist_editor,
 131                        gutter_dimensions,
 132                    ),
 133                    disposition: if selection.reversed {
 134                        BlockDisposition::Above
 135                    } else {
 136                        BlockDisposition::Below
 137                    },
 138                }],
 139                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
 140                cx,
 141            )[0]
 142        });
 143
 144        self.pending_assists.insert(
 145            inline_assist_id,
 146            PendingInlineAssist {
 147                include_context,
 148                editor: editor.downgrade(),
 149                inline_assist_editor: Some((block_id, inline_assist_editor.clone())),
 150                codegen: codegen.clone(),
 151                workspace,
 152                _subscriptions: vec![
 153                    cx.subscribe(&inline_assist_editor, |inline_assist_editor, event, cx| {
 154                        InlineAssistant::update_global(cx, |this, cx| {
 155                            this.handle_inline_assistant_event(inline_assist_editor, event, cx)
 156                        })
 157                    }),
 158                    cx.subscribe(editor, {
 159                        let inline_assist_editor = inline_assist_editor.downgrade();
 160                        move |editor, event, cx| {
 161                            if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
 162                                if let EditorEvent::SelectionsChanged { local } = event {
 163                                    if *local
 164                                        && inline_assist_editor
 165                                            .focus_handle(cx)
 166                                            .contains_focused(cx)
 167                                    {
 168                                        cx.focus_view(&editor);
 169                                    }
 170                                }
 171                            }
 172                        }
 173                    }),
 174                    cx.observe(&codegen, {
 175                        let editor = editor.downgrade();
 176                        move |_, cx| {
 177                            if let Some(editor) = editor.upgrade() {
 178                                InlineAssistant::update_global(cx, |this, cx| {
 179                                    this.update_highlights_for_editor(&editor, cx);
 180                                })
 181                            }
 182                        }
 183                    }),
 184                    cx.subscribe(&codegen, move |codegen, event, cx| {
 185                        InlineAssistant::update_global(cx, |this, cx| match event {
 186                            CodegenEvent::Undone => {
 187                                this.finish_inline_assist(inline_assist_id, false, cx)
 188                            }
 189                            CodegenEvent::Finished => {
 190                                let pending_assist = if let Some(pending_assist) =
 191                                    this.pending_assists.get(&inline_assist_id)
 192                                {
 193                                    pending_assist
 194                                } else {
 195                                    return;
 196                                };
 197
 198                                let error = codegen
 199                                    .read(cx)
 200                                    .error()
 201                                    .map(|error| format!("Inline assistant error: {}", error));
 202                                if let Some(error) = error {
 203                                    if pending_assist.inline_assist_editor.is_none() {
 204                                        if let Some(workspace) = pending_assist
 205                                            .workspace
 206                                            .as_ref()
 207                                            .and_then(|workspace| workspace.upgrade())
 208                                        {
 209                                            workspace.update(cx, |workspace, cx| {
 210                                                struct InlineAssistantError;
 211
 212                                                let id = NotificationId::identified::<
 213                                                    InlineAssistantError,
 214                                                >(
 215                                                    inline_assist_id.0
 216                                                );
 217
 218                                                workspace.show_toast(Toast::new(id, error), cx);
 219                                            })
 220                                        }
 221
 222                                        this.finish_inline_assist(inline_assist_id, false, cx);
 223                                    }
 224                                } else {
 225                                    this.finish_inline_assist(inline_assist_id, false, cx);
 226                                }
 227                            }
 228                        })
 229                    }),
 230                ],
 231            },
 232        );
 233
 234        self.pending_assist_ids_by_editor
 235            .entry(editor.downgrade())
 236            .or_insert_with(|| EditorPendingAssists {
 237                window: cx.window_handle(),
 238                assist_ids: Vec::new(),
 239            })
 240            .assist_ids
 241            .push(inline_assist_id);
 242        self.update_highlights_for_editor(editor, cx);
 243    }
 244
 245    fn handle_inline_assistant_event(
 246        &mut self,
 247        inline_assist_editor: View<InlineAssistEditor>,
 248        event: &InlineAssistEditorEvent,
 249        cx: &mut WindowContext,
 250    ) {
 251        let assist_id = inline_assist_editor.read(cx).id;
 252        match event {
 253            InlineAssistEditorEvent::Confirmed { prompt } => {
 254                self.confirm_inline_assist(assist_id, prompt, cx);
 255            }
 256            InlineAssistEditorEvent::Canceled => {
 257                self.finish_inline_assist(assist_id, true, cx);
 258            }
 259            InlineAssistEditorEvent::Dismissed => {
 260                self.hide_inline_assist(assist_id, cx);
 261            }
 262            InlineAssistEditorEvent::Resized { height_in_lines } => {
 263                self.resize_inline_assist(assist_id, *height_in_lines, cx);
 264            }
 265        }
 266    }
 267
 268    pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
 269        for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
 270            if pending_assists.window == cx.window_handle() {
 271                if let Some(editor) = editor.upgrade() {
 272                    if editor.read(cx).is_focused(cx) {
 273                        if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
 274                            self.finish_inline_assist(assist_id, true, cx);
 275                            return true;
 276                        }
 277                    }
 278                }
 279            }
 280        }
 281        false
 282    }
 283
 284    fn finish_inline_assist(
 285        &mut self,
 286        assist_id: InlineAssistId,
 287        undo: bool,
 288        cx: &mut WindowContext,
 289    ) {
 290        self.hide_inline_assist(assist_id, cx);
 291
 292        if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
 293            if let hash_map::Entry::Occupied(mut entry) = self
 294                .pending_assist_ids_by_editor
 295                .entry(pending_assist.editor.clone())
 296            {
 297                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
 298                if entry.get().assist_ids.is_empty() {
 299                    entry.remove();
 300                }
 301            }
 302
 303            if let Some(editor) = pending_assist.editor.upgrade() {
 304                self.update_highlights_for_editor(&editor, cx);
 305
 306                if undo {
 307                    pending_assist
 308                        .codegen
 309                        .update(cx, |codegen, cx| codegen.undo(cx));
 310                }
 311            }
 312        }
 313    }
 314
 315    fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
 316        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
 317            if let Some(editor) = pending_assist.editor.upgrade() {
 318                if let Some((block_id, inline_assist_editor)) =
 319                    pending_assist.inline_assist_editor.take()
 320                {
 321                    editor.update(cx, |editor, cx| {
 322                        editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
 323                        if inline_assist_editor.focus_handle(cx).contains_focused(cx) {
 324                            editor.focus(cx);
 325                        }
 326                    });
 327                }
 328            }
 329        }
 330    }
 331
 332    fn resize_inline_assist(
 333        &mut self,
 334        assist_id: InlineAssistId,
 335        height_in_lines: u8,
 336        cx: &mut WindowContext,
 337    ) {
 338        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
 339            if let Some(editor) = pending_assist.editor.upgrade() {
 340                if let Some((block_id, inline_assist_editor)) =
 341                    pending_assist.inline_assist_editor.as_ref()
 342                {
 343                    let gutter_dimensions = inline_assist_editor.read(cx).gutter_dimensions.clone();
 344                    let mut new_blocks = HashMap::default();
 345                    new_blocks.insert(
 346                        *block_id,
 347                        (
 348                            Some(height_in_lines),
 349                            build_inline_assist_editor_renderer(
 350                                inline_assist_editor,
 351                                gutter_dimensions,
 352                            ),
 353                        ),
 354                    );
 355                    editor.update(cx, |editor, cx| {
 356                        editor
 357                            .display_map
 358                            .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
 359                    });
 360                }
 361            }
 362        }
 363    }
 364
 365    fn confirm_inline_assist(
 366        &mut self,
 367        assist_id: InlineAssistId,
 368        user_prompt: &str,
 369        cx: &mut WindowContext,
 370    ) {
 371        let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
 372        {
 373            pending_assist
 374        } else {
 375            return;
 376        };
 377
 378        let context = if pending_assist.include_context {
 379            pending_assist.workspace.as_ref().and_then(|workspace| {
 380                let workspace = workspace.upgrade()?.read(cx);
 381                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
 382                assistant_panel.read(cx).active_context(cx)
 383            })
 384        } else {
 385            None
 386        };
 387
 388        let editor = if let Some(editor) = pending_assist.editor.upgrade() {
 389            editor
 390        } else {
 391            return;
 392        };
 393
 394        let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
 395            let workspace = workspace.upgrade()?;
 396            Some(
 397                workspace
 398                    .read(cx)
 399                    .project()
 400                    .read(cx)
 401                    .worktree_root_names(cx)
 402                    .collect::<Vec<&str>>()
 403                    .join("/"),
 404            )
 405        });
 406
 407        self.prompt_history.retain(|prompt| prompt != user_prompt);
 408        self.prompt_history.push_back(user_prompt.into());
 409        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
 410            self.prompt_history.pop_front();
 411        }
 412
 413        let codegen = pending_assist.codegen.clone();
 414        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 415        let range = codegen.read(cx).range();
 416        let start = snapshot.point_to_buffer_offset(range.start);
 417        let end = snapshot.point_to_buffer_offset(range.end);
 418        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
 419            let (start_buffer, start_buffer_offset) = start;
 420            let (end_buffer, end_buffer_offset) = end;
 421            if start_buffer.remote_id() == end_buffer.remote_id() {
 422                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
 423            } else {
 424                self.finish_inline_assist(assist_id, false, cx);
 425                return;
 426            }
 427        } else {
 428            self.finish_inline_assist(assist_id, false, cx);
 429            return;
 430        };
 431
 432        let language = buffer.language_at(range.start);
 433        let language_name = if let Some(language) = language.as_ref() {
 434            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
 435                None
 436            } else {
 437                Some(language.name())
 438            }
 439        } else {
 440            None
 441        };
 442
 443        // Higher Temperature increases the randomness of model outputs.
 444        // If Markdown or No Language is Known, increase the randomness for more creative output
 445        // If Code, decrease temperature to get more deterministic outputs
 446        let temperature = if let Some(language) = language_name.clone() {
 447            if language.as_ref() == "Markdown" {
 448                1.0
 449            } else {
 450                0.5
 451            }
 452        } else {
 453            1.0
 454        };
 455
 456        let user_prompt = user_prompt.to_string();
 457
 458        let prompt = cx.background_executor().spawn(async move {
 459            let language_name = language_name.as_deref();
 460            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
 461        });
 462
 463        let mut messages = Vec::new();
 464        if let Some(context) = context {
 465            let request = context.read(cx).to_completion_request(cx);
 466            messages = request.messages;
 467        }
 468        let model = CompletionProvider::global(cx).model();
 469
 470        cx.spawn(|mut cx| async move {
 471            let prompt = prompt.await?;
 472
 473            messages.push(LanguageModelRequestMessage {
 474                role: Role::User,
 475                content: prompt,
 476            });
 477
 478            let request = LanguageModelRequest {
 479                model,
 480                messages,
 481                stop: vec!["|END|>".to_string()],
 482                temperature,
 483            };
 484
 485            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
 486            anyhow::Ok(())
 487        })
 488        .detach_and_log_err(cx);
 489    }
 490
 491    fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
 492        let mut background_ranges = Vec::new();
 493        let mut foreground_ranges = Vec::new();
 494        let empty_inline_assist_ids = Vec::new();
 495        let inline_assist_ids = self
 496            .pending_assist_ids_by_editor
 497            .get(&editor.downgrade())
 498            .map_or(&empty_inline_assist_ids, |pending_assists| {
 499                &pending_assists.assist_ids
 500            });
 501
 502        for inline_assist_id in inline_assist_ids {
 503            if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
 504                let codegen = pending_assist.codegen.read(cx);
 505                background_ranges.push(codegen.range());
 506                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
 507            }
 508        }
 509
 510        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
 511        merge_ranges(&mut background_ranges, &snapshot);
 512        merge_ranges(&mut foreground_ranges, &snapshot);
 513        editor.update(cx, |editor, cx| {
 514            if background_ranges.is_empty() {
 515                editor.clear_background_highlights::<PendingInlineAssist>(cx);
 516            } else {
 517                editor.highlight_background::<PendingInlineAssist>(
 518                    &background_ranges,
 519                    |theme| theme.editor_active_line_background, // TODO use the appropriate color
 520                    cx,
 521                );
 522            }
 523
 524            if foreground_ranges.is_empty() {
 525                editor.clear_highlights::<PendingInlineAssist>(cx);
 526            } else {
 527                editor.highlight_text::<PendingInlineAssist>(
 528                    foreground_ranges,
 529                    HighlightStyle {
 530                        fade_out: Some(0.6),
 531                        ..Default::default()
 532                    },
 533                    cx,
 534                );
 535            }
 536        });
 537    }
 538}
 539
 540fn build_inline_assist_editor_renderer(
 541    editor: &View<InlineAssistEditor>,
 542    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 543) -> RenderBlock {
 544    let editor = editor.clone();
 545    Box::new(move |cx: &mut BlockContext| {
 546        *gutter_dimensions.lock() = *cx.gutter_dimensions;
 547        editor.clone().into_any_element()
 548    })
 549}
 550
 /// Identifier of one inline assist, allocated sequentially by the `InlineAssistant`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
 561
 /// Events an `InlineAssistEditor` emits to the `InlineAssistant` that owns it.
enum InlineAssistEditorEvent {
    /// The user submitted `prompt` for generation.
    Confirmed { prompt: String },
    /// The assist was cancelled, explicitly or by blurring a prompt that was not confirmed.
    Canceled,
    /// Confirm was pressed again on an already-confirmed prompt, asking to hide the editor.
    Dismissed,
    /// The prompt editor's height changed to `height_in_lines` lines.
    Resized { height_in_lines: u8 },
}
 568
 569struct InlineAssistEditor {
 570    id: InlineAssistId,
 571    height_in_lines: u8,
 572    prompt_editor: View<Editor>,
 573    confirmed: bool,
 574    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 575    prompt_history: VecDeque<String>,
 576    prompt_history_ix: Option<usize>,
 577    pending_prompt: String,
 578    codegen: Model<Codegen>,
 579    _subscriptions: Vec<Subscription>,
 580}
 581
 582impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
 583
 584impl Render for InlineAssistEditor {
 585    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 586        let gutter_dimensions = *self.gutter_dimensions.lock();
 587        let icon_size = IconSize::default();
 588        h_flex()
 589            .w_full()
 590            .py_1p5()
 591            .border_y_1()
 592            .border_color(cx.theme().colors().border)
 593            .bg(cx.theme().colors().editor_background)
 594            .on_action(cx.listener(Self::confirm))
 595            .on_action(cx.listener(Self::cancel))
 596            .on_action(cx.listener(Self::move_up))
 597            .on_action(cx.listener(Self::move_down))
 598            .child(
 599                h_flex()
 600                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
 601                    .pr(gutter_dimensions.fold_area_width())
 602                    .justify_end()
 603                    .children(if let Some(error) = self.codegen.read(cx).error() {
 604                        let error_message = SharedString::from(error.to_string());
 605                        Some(
 606                            div()
 607                                .id("error")
 608                                .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
 609                                .child(
 610                                    Icon::new(IconName::XCircle)
 611                                        .size(icon_size)
 612                                        .color(Color::Error),
 613                                ),
 614                        )
 615                    } else {
 616                        None
 617                    }),
 618            )
 619            .child(div().flex_1().child(self.render_prompt_editor(cx)))
 620    }
 621}
 622
 623impl FocusableView for InlineAssistEditor {
 624    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
 625        self.prompt_editor.focus_handle(cx)
 626    }
 627}
 628
 629impl InlineAssistEditor {
 630    const MAX_LINES: u8 = 8;
 631
 632    #[allow(clippy::too_many_arguments)]
 633    fn new(
 634        id: InlineAssistId,
 635        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 636        prompt_history: VecDeque<String>,
 637        codegen: Model<Codegen>,
 638        cx: &mut ViewContext<Self>,
 639    ) -> Self {
 640        let prompt_editor = cx.new_view(|cx| {
 641            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
 642            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
 643            let placeholder = match codegen.read(cx).kind() {
 644                CodegenKind::Transform { .. } => "Enter transformation prompt…",
 645                CodegenKind::Generate { .. } => "Enter generation prompt…",
 646            };
 647            editor.set_placeholder_text(placeholder, cx);
 648            editor
 649        });
 650        cx.focus_view(&prompt_editor);
 651
 652        let subscriptions = vec![
 653            cx.observe(&codegen, Self::handle_codegen_changed),
 654            cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
 655            cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
 656        ];
 657
 658        let mut this = Self {
 659            id,
 660            height_in_lines: 1,
 661            prompt_editor,
 662            confirmed: false,
 663            gutter_dimensions,
 664            prompt_history,
 665            prompt_history_ix: None,
 666            pending_prompt: String::new(),
 667            codegen,
 668            _subscriptions: subscriptions,
 669        };
 670        this.count_lines(cx);
 671        this
 672    }
 673
 674    fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
 675        let height_in_lines = cmp::max(
 676            2, // Make the editor at least two lines tall, to account for padding.
 677            cmp::min(
 678                self.prompt_editor
 679                    .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
 680                Self::MAX_LINES as u32,
 681            ),
 682        ) as u8;
 683
 684        if height_in_lines != self.height_in_lines {
 685            self.height_in_lines = height_in_lines;
 686            cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
 687        }
 688    }
 689
 690    fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
 691        self.count_lines(cx);
 692    }
 693
 694    fn handle_prompt_editor_events(
 695        &mut self,
 696        _: View<Editor>,
 697        event: &EditorEvent,
 698        cx: &mut ViewContext<Self>,
 699    ) {
 700        match event {
 701            EditorEvent::Edited => {
 702                self.pending_prompt = self.prompt_editor.read(cx).text(cx);
 703                cx.notify();
 704            }
 705            EditorEvent::Blurred => {
 706                if !self.confirmed {
 707                    cx.emit(InlineAssistEditorEvent::Canceled);
 708                }
 709            }
 710            _ => {}
 711        }
 712    }
 713
 714    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
 715        let is_read_only = !self.codegen.read(cx).idle();
 716        self.prompt_editor.update(cx, |editor, cx| {
 717            let was_read_only = editor.read_only(cx);
 718            if was_read_only != is_read_only {
 719                if is_read_only {
 720                    editor.set_read_only(true);
 721                } else {
 722                    self.confirmed = false;
 723                    editor.set_read_only(false);
 724                }
 725            }
 726        });
 727        cx.notify();
 728    }
 729
 730    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
 731        cx.emit(InlineAssistEditorEvent::Canceled);
 732    }
 733
 734    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 735        if self.confirmed {
 736            cx.emit(InlineAssistEditorEvent::Dismissed);
 737        } else {
 738            let prompt = self.prompt_editor.read(cx).text(cx);
 739            self.prompt_editor
 740                .update(cx, |editor, _cx| editor.set_read_only(true));
 741            cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
 742            self.confirmed = true;
 743            cx.notify();
 744        }
 745    }
 746
 747    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
 748        if let Some(ix) = self.prompt_history_ix {
 749            if ix > 0 {
 750                self.prompt_history_ix = Some(ix - 1);
 751                let prompt = self.prompt_history[ix - 1].clone();
 752                self.set_prompt(&prompt, cx);
 753            }
 754        } else if !self.prompt_history.is_empty() {
 755            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
 756            let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
 757            self.set_prompt(&prompt, cx);
 758        }
 759    }
 760
 761    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
 762        if let Some(ix) = self.prompt_history_ix {
 763            if ix < self.prompt_history.len() - 1 {
 764                self.prompt_history_ix = Some(ix + 1);
 765                let prompt = self.prompt_history[ix + 1].clone();
 766                self.set_prompt(&prompt, cx);
 767            } else {
 768                self.prompt_history_ix = None;
 769                let pending_prompt = self.pending_prompt.clone();
 770                self.set_prompt(&pending_prompt, cx);
 771            }
 772        }
 773    }
 774
 775    fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
 776        self.prompt_editor.update(cx, |editor, cx| {
 777            editor.buffer().update(cx, |buffer, cx| {
 778                let len = buffer.len(cx);
 779                buffer.edit([(0..len, prompt)], None, cx);
 780            });
 781        });
 782    }
 783
 784    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 785        let settings = ThemeSettings::get_global(cx);
 786        let text_style = TextStyle {
 787            color: if self.prompt_editor.read(cx).read_only(cx) {
 788                cx.theme().colors().text_disabled
 789            } else {
 790                cx.theme().colors().text
 791            },
 792            font_family: settings.ui_font.family.clone(),
 793            font_features: settings.ui_font.features.clone(),
 794            font_size: rems(0.875).into(),
 795            font_weight: FontWeight::NORMAL,
 796            font_style: FontStyle::Normal,
 797            line_height: relative(1.3),
 798            background_color: None,
 799            underline: None,
 800            strikethrough: None,
 801            white_space: WhiteSpace::Normal,
 802        };
 803        EditorElement::new(
 804            &self.prompt_editor,
 805            EditorStyle {
 806                background: cx.theme().colors().editor_background,
 807                local_player: cx.theme().players().local(),
 808                text: text_style,
 809                ..Default::default()
 810            },
 811        )
 812    }
 813}
 814
/// State tracked for a single inline assist; values are stored in
/// `InlineAssistant::pending_assists`, keyed by `InlineAssistId`.
struct PendingInlineAssist {
    /// The editor this assist operates on, held weakly so the assist does not keep it alive.
    editor: WeakView<Editor>,
    /// Block id and view of the prompt editor for this assist, if present.
    /// NOTE(review): confirm when this becomes `None`.
    inline_assist_editor: Option<(BlockId, View<InlineAssistEditor>)>,
    /// Model that streams the completion and applies it to the buffer.
    codegen: Model<Codegen>,
    /// Never read; held only so the subscriptions live as long as this value.
    _subscriptions: Vec<Subscription>,
    /// Workspace associated with the assist, if any (held weakly).
    workspace: Option<WeakView<Workspace>>,
    /// NOTE(review): presumably controls whether extra context is added to the
    /// prompt; semantics are defined by code outside this chunk.
    include_context: bool,
}
 823
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation stopped, either successfully or with an error recorded in `Codegen::error`.
    Finished,
    /// The buffer transaction holding this codegen's edits was undone.
    Undone,
}
 829
/// What a [`Codegen`] does to its buffer.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the existing text in `range` with the model's response.
    Transform { range: Range<Anchor> },
    /// Insert the model's response at `position`.
    Generate { position: Anchor },
}
 835
/// Streams a language-model completion into a multi-buffer, either replacing a
/// range or inserting at a position (see [`CodegenKind`]).
pub struct Codegen {
    /// The buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Snapshot taken when the codegen was created; anchors are resolved against it.
    snapshot: MultiBufferSnapshot,
    kind: CodegenKind,
    /// Ranges that the most recent batch of hunks kept unchanged from the original text.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all edits from a generation are merged into, so they undo together.
    transaction_id: Option<TransactionId>,
    /// Error from the most recent generation, if it failed.
    error: Option<anyhow::Error>,
    /// Task driving the current generation; replacing it drops the previous task.
    generation: Task<()>,
    /// `false` from `start` until the generation completes.
    idle: bool,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to buffer events; never read, held to keep it alive.
    _subscription: gpui::Subscription,
}

/// Allows `Codegen` to emit [`CodegenEvent`]s to its subscribers.
impl EventEmitter<CodegenEvent> for Codegen {}
 850
 851impl Codegen {
 852    pub fn new(
 853        buffer: Model<MultiBuffer>,
 854        kind: CodegenKind,
 855        telemetry: Option<Arc<Telemetry>>,
 856        cx: &mut ModelContext<Self>,
 857    ) -> Self {
 858        let snapshot = buffer.read(cx).snapshot(cx);
 859        Self {
 860            buffer: buffer.clone(),
 861            snapshot,
 862            kind,
 863            last_equal_ranges: Default::default(),
 864            transaction_id: Default::default(),
 865            error: Default::default(),
 866            idle: true,
 867            generation: Task::ready(()),
 868            telemetry,
 869            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 870        }
 871    }
 872
 873    fn handle_buffer_event(
 874        &mut self,
 875        _buffer: Model<MultiBuffer>,
 876        event: &multi_buffer::Event,
 877        cx: &mut ModelContext<Self>,
 878    ) {
 879        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 880            if self.transaction_id == Some(*transaction_id) {
 881                self.transaction_id = None;
 882                self.generation = Task::ready(());
 883                cx.emit(CodegenEvent::Undone);
 884            }
 885        }
 886    }
 887
 888    pub fn range(&self) -> Range<Anchor> {
 889        match &self.kind {
 890            CodegenKind::Transform { range } => range.clone(),
 891            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 892        }
 893    }
 894
 895    pub fn kind(&self) -> &CodegenKind {
 896        &self.kind
 897    }
 898
 899    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 900        &self.last_equal_ranges
 901    }
 902
 903    pub fn idle(&self) -> bool {
 904        self.idle
 905    }
 906
 907    pub fn error(&self) -> Option<&anyhow::Error> {
 908        self.error.as_ref()
 909    }
 910
 911    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
 912        let range = self.range();
 913        let snapshot = self.snapshot.clone();
 914        let selected_text = snapshot
 915            .text_for_range(range.start..range.end)
 916            .collect::<Rope>();
 917
 918        let selection_start = range.start.to_point(&snapshot);
 919        let suggested_line_indent = snapshot
 920            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
 921            .into_values()
 922            .next()
 923            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 924
 925        let model_telemetry_id = prompt.model.telemetry_id();
 926        let response = CompletionProvider::global(cx).complete(prompt);
 927        let telemetry = self.telemetry.clone();
 928        self.generation = cx.spawn(|this, mut cx| {
 929            async move {
 930                let generate = async {
 931                    let mut edit_start = range.start.to_offset(&snapshot);
 932
 933                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 934                    let diff: Task<anyhow::Result<()>> =
 935                        cx.background_executor().spawn(async move {
 936                            let mut response_latency = None;
 937                            let request_start = Instant::now();
 938                            let diff = async {
 939                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
 940                                futures::pin_mut!(chunks);
 941                                let mut diff = StreamingDiff::new(selected_text.to_string());
 942
 943                                let mut new_text = String::new();
 944                                let mut base_indent = None;
 945                                let mut line_indent = None;
 946                                let mut first_line = true;
 947
 948                                while let Some(chunk) = chunks.next().await {
 949                                    if response_latency.is_none() {
 950                                        response_latency = Some(request_start.elapsed());
 951                                    }
 952                                    let chunk = chunk?;
 953
 954                                    let mut lines = chunk.split('\n').peekable();
 955                                    while let Some(line) = lines.next() {
 956                                        new_text.push_str(line);
 957                                        if line_indent.is_none() {
 958                                            if let Some(non_whitespace_ch_ix) =
 959                                                new_text.find(|ch: char| !ch.is_whitespace())
 960                                            {
 961                                                line_indent = Some(non_whitespace_ch_ix);
 962                                                base_indent = base_indent.or(line_indent);
 963
 964                                                let line_indent = line_indent.unwrap();
 965                                                let base_indent = base_indent.unwrap();
 966                                                let indent_delta =
 967                                                    line_indent as i32 - base_indent as i32;
 968                                                let mut corrected_indent_len = cmp::max(
 969                                                    0,
 970                                                    suggested_line_indent.len as i32 + indent_delta,
 971                                                )
 972                                                    as usize;
 973                                                if first_line {
 974                                                    corrected_indent_len = corrected_indent_len
 975                                                        .saturating_sub(
 976                                                            selection_start.column as usize,
 977                                                        );
 978                                                }
 979
 980                                                let indent_char = suggested_line_indent.char();
 981                                                let mut indent_buffer = [0; 4];
 982                                                let indent_str =
 983                                                    indent_char.encode_utf8(&mut indent_buffer);
 984                                                new_text.replace_range(
 985                                                    ..line_indent,
 986                                                    &indent_str.repeat(corrected_indent_len),
 987                                                );
 988                                            }
 989                                        }
 990
 991                                        if line_indent.is_some() {
 992                                            hunks_tx.send(diff.push_new(&new_text)).await?;
 993                                            new_text.clear();
 994                                        }
 995
 996                                        if lines.peek().is_some() {
 997                                            hunks_tx.send(diff.push_new("\n")).await?;
 998                                            line_indent = None;
 999                                            first_line = false;
1000                                        }
1001                                    }
1002                                }
1003                                hunks_tx.send(diff.push_new(&new_text)).await?;
1004                                hunks_tx.send(diff.finish()).await?;
1005
1006                                anyhow::Ok(())
1007                            };
1008
1009                            let result = diff.await;
1010
1011                            let error_message =
1012                                result.as_ref().err().map(|error| error.to_string());
1013                            if let Some(telemetry) = telemetry {
1014                                telemetry.report_assistant_event(
1015                                    None,
1016                                    telemetry_events::AssistantKind::Inline,
1017                                    model_telemetry_id,
1018                                    response_latency,
1019                                    error_message,
1020                                );
1021                            }
1022
1023                            result?;
1024                            Ok(())
1025                        });
1026
1027                    while let Some(hunks) = hunks_rx.next().await {
1028                        this.update(&mut cx, |this, cx| {
1029                            this.last_equal_ranges.clear();
1030
1031                            let transaction = this.buffer.update(cx, |buffer, cx| {
1032                                // Avoid grouping assistant edits with user edits.
1033                                buffer.finalize_last_transaction(cx);
1034
1035                                buffer.start_transaction(cx);
1036                                buffer.edit(
1037                                    hunks.into_iter().filter_map(|hunk| match hunk {
1038                                        Hunk::Insert { text } => {
1039                                            let edit_start = snapshot.anchor_after(edit_start);
1040                                            Some((edit_start..edit_start, text))
1041                                        }
1042                                        Hunk::Remove { len } => {
1043                                            let edit_end = edit_start + len;
1044                                            let edit_range = snapshot.anchor_after(edit_start)
1045                                                ..snapshot.anchor_before(edit_end);
1046                                            edit_start = edit_end;
1047                                            Some((edit_range, String::new()))
1048                                        }
1049                                        Hunk::Keep { len } => {
1050                                            let edit_end = edit_start + len;
1051                                            let edit_range = snapshot.anchor_after(edit_start)
1052                                                ..snapshot.anchor_before(edit_end);
1053                                            edit_start = edit_end;
1054                                            this.last_equal_ranges.push(edit_range);
1055                                            None
1056                                        }
1057                                    }),
1058                                    None,
1059                                    cx,
1060                                );
1061
1062                                buffer.end_transaction(cx)
1063                            });
1064
1065                            if let Some(transaction) = transaction {
1066                                if let Some(first_transaction) = this.transaction_id {
1067                                    // Group all assistant edits into the first transaction.
1068                                    this.buffer.update(cx, |buffer, cx| {
1069                                        buffer.merge_transactions(
1070                                            transaction,
1071                                            first_transaction,
1072                                            cx,
1073                                        )
1074                                    });
1075                                } else {
1076                                    this.transaction_id = Some(transaction);
1077                                    this.buffer.update(cx, |buffer, cx| {
1078                                        buffer.finalize_last_transaction(cx)
1079                                    });
1080                                }
1081                            }
1082
1083                            cx.notify();
1084                        })?;
1085                    }
1086
1087                    diff.await?;
1088
1089                    anyhow::Ok(())
1090                };
1091
1092                let result = generate.await;
1093                this.update(&mut cx, |this, cx| {
1094                    this.last_equal_ranges.clear();
1095                    this.idle = true;
1096                    if let Err(error) = result {
1097                        this.error = Some(error);
1098                    }
1099                    cx.emit(CodegenEvent::Finished);
1100                    cx.notify();
1101                })
1102                .ok();
1103            }
1104        });
1105        self.error.take();
1106        self.idle = false;
1107        cx.notify();
1108    }
1109
1110    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
1111        if let Some(transaction_id) = self.transaction_id {
1112            self.buffer
1113                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
1114        }
1115    }
1116}
1117
/// Cleans up a streamed model response before it is applied to the buffer.
///
/// Two kinds of wrapper are removed:
/// - A leading `<|S|>` / `<|S|` marker and a trailing `|E|>` (or `E|>`) marker.
///   NOTE(review): presumably the selection start/end markers used in the
///   prompt; confirm against the prompt template.
/// - A leading Markdown code fence (the whole first line, when it starts with
///   "```") together with a trailing "\n```" fence.
///
/// Incoming chunks are accumulated in an internal buffer, and text that could
/// be the beginning of a marker or fence is held back until more input
/// arrives. Errors from `stream` are passed through unchanged.
///
/// NOTE(review): there is no flush when the input ends, so any text still held
/// back at that point is never emitted. Examples are a trailing '|' or, after a
/// leading fence, a trailing newline.
fn strip_invalid_spans_from_codeblock(
    stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
    let mut first_line = true;
    let mut buffer = String::new();
    let mut starts_with_markdown_codeblock = false;
    let mut includes_start_or_end_span = false;
    stream.filter_map(move |chunk| {
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(err) => return future::ready(Some(Err(err))),
        };
        buffer.push_str(&chunk);

        // Strip a leading start marker once enough bytes have arrived to see past it.
        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
            includes_start_or_end_span = true;

            buffer = buffer
                .strip_prefix("<|S|>")
                .or_else(|| buffer.strip_prefix("<|S|"))
                .unwrap_or(&buffer)
                .to_string();
        } else if buffer.ends_with("|E|>") {
            // A complete end marker arrived; it is stripped from `text` below.
            includes_start_or_end_span = true;
        } else if buffer.starts_with("<|")
            || buffer.starts_with("<|S")
            || buffer.starts_with("<|S|")
            || buffer.ends_with('|')
            || buffer.ends_with("|E")
            || buffer.ends_with("|E|")
        {
            // Possibly a partial marker: emit nothing until more input arrives.
            // NOTE(review): the "<|S" / "<|S|" checks are subsumed by "<|", and a
            // buffer that starts with "<|" but not "<|S|" is held indefinitely.
            return future::ready(None);
        }

        // Until the first line is known, wait to see whether the response opens
        // with a code fence. If it does, drop that fence line (including any
        // language tag after the backticks).
        if first_line {
            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                return future::ready(None);
            } else if buffer.starts_with("```") {
                starts_with_markdown_codeblock = true;
                if let Some(newline_ix) = buffer.find('\n') {
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    return future::ready(None);
                }
            }
        }

        // `text` is the part of `buffer` that is safe to emit now.
        let mut text = buffer.to_string();
        if starts_with_markdown_codeblock {
            // Hold back a possible closing fence (or a prefix of one).
            text = text
                .strip_suffix("\n```\n")
                .or_else(|| text.strip_suffix("\n```"))
                .or_else(|| text.strip_suffix("\n``"))
                .or_else(|| text.strip_suffix("\n`"))
                .or_else(|| text.strip_suffix('\n'))
                .unwrap_or(&text)
                .to_string();
        }

        if includes_start_or_end_span {
            // Drop the end marker (or its tail) from what is emitted.
            // NOTE(review): the `strip_prefix` branches shorten `text` from the
            // front, but `split_off` below cuts `buffer` by length. Those bytes
            // are therefore held back from the end rather than removed from the
            // front, and the cut could land inside a multi-byte character.
            text = text
                .strip_suffix("|E|>")
                .or_else(|| text.strip_suffix("E|>"))
                .or_else(|| text.strip_prefix("|>"))
                .or_else(|| text.strip_prefix('>'))
                .unwrap_or(&text)
                .to_string();
        };

        if text.contains('\n') {
            first_line = false;
        }

        // Emit the first `text.len()` bytes of `buffer`; keep the rest for later.
        let remainder = buffer.split_off(text.len());
        let result = if buffer.is_empty() {
            None
        } else {
            Some(Ok(buffer.clone()))
        };

        buffer = remainder;
        future::ready(result)
    })
}
1203
1204fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1205    ranges.sort_unstable_by(|a, b| {
1206        a.start
1207            .cmp(&b.start, buffer)
1208            .then_with(|| b.end.cmp(&a.end, buffer))
1209    });
1210
1211    let mut ix = 0;
1212    while ix + 1 < ranges.len() {
1213        let b = ranges[ix + 1].clone();
1214        let a = &mut ranges[ix];
1215        if a.end.cmp(&b.start, buffer).is_gt() {
1216            if a.end.cmp(&b.end, buffer).is_lt() {
1217                a.end = b.end;
1218            }
1219            ranges.remove(ix + 1);
1220        } else {
1221            ix += 1;
1222        }
1223    }
1224}
1225
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Streams `text` through `provider` in random chunks of 1 to 10 bytes.
    ///
    /// The executor runs until parked after each chunk. The completion is then
    /// finished and the executor runs until parked once more. Splits happen at
    /// byte offsets, so `text` must be ASCII.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk.into());
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}