inline_assistant.rs

   1use crate::{
   2    prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
   3    LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
   4};
   5use anyhow::Result;
   6use client::telemetry::Telemetry;
   7use collections::{hash_map, HashMap, HashSet, VecDeque};
   8use editor::{
   9    actions::{MoveDown, MoveUp},
  10    display_map::{BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle},
  11    scroll::{Autoscroll, AutoscrollStrategy},
  12    Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
  13    MultiBufferSnapshot, ToOffset, ToPoint,
  14};
  15use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  16use gpui::{
  17    AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
  18    Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
  19    ViewContext, WeakView, WhiteSpace, WindowContext,
  20};
  21use language::{Point, TransactionId};
  22use multi_buffer::MultiBufferRow;
  23use parking_lot::Mutex;
  24use rope::Rope;
  25use settings::Settings;
  26use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
  27use theme::ThemeSettings;
  28use ui::{prelude::*, Tooltip};
  29use workspace::{notifications::NotificationId, Toast, Workspace};
  30
  31pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
  32    cx.set_global(InlineAssistant::new(telemetry));
  33}
  34
/// Maximum number of distinct prompts kept in the prompt history that prompt editors
/// navigate with up/down.
const PROMPT_HISTORY_MAX_LEN: usize = 20;

/// Global coordinator for inline assists: creates them, tracks the ones still in
/// flight, and maintains the shared prompt history.
pub struct InlineAssistant {
    /// Id that the next call to `assist` will hand out.
    next_assist_id: InlineAssistId,
    /// Every assist that has been started and not yet finished.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Ids of pending assists grouped by the editor they target, in creation order.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Recently confirmed prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry handle passed to each `Codegen`.
    telemetry: Option<Arc<Telemetry>>,
}

/// Pending assists that belong to a single editor.
struct EditorPendingAssists {
    /// Window active when the first assist for this editor was registered; used to
    /// restrict `cancel_last_inline_assist` to the current window.
    window: AnyWindowHandle,
    /// Pending assist ids, oldest first.
    assist_ids: Vec<InlineAssistId>,
}

/// Allows the assistant to be stored as a gpui global (see `init`).
impl Global for InlineAssistant {}
  51
impl InlineAssistant {
    /// Creates an assistant with no pending assists and an empty prompt history.
    ///
    /// `telemetry` is handed to every `Codegen` created by `assist`.
    pub fn new(telemetry: Arc<Telemetry>) -> Self {
        Self {
            next_assist_id: InlineAssistId::default(),
            pending_assists: HashMap::default(),
            pending_assist_ids_by_editor: HashMap::default(),
            prompt_history: VecDeque::default(),
            telemetry: Some(telemetry),
        }
    }

    /// Starts a new inline assist for the newest selection in `editor`.
    ///
    /// A non-empty selection is widened to whole lines and becomes a
    /// `CodegenKind::Transform` target. An empty selection becomes a
    /// `CodegenKind::Generate` insertion point. A prompt editor is inserted as a block
    /// above the selection (when it is reversed) or below it. The assist is tracked in
    /// `pending_assists` until it is finished.
    ///
    /// Does nothing when the selection spans more than one excerpt.
    ///
    /// * `workspace` - used for error toasts and for conversation/project context; may be `None`.
    /// * `include_conversation` - when true, the assistant panel's active conversation is
    ///   sent along with the prompt once it is confirmed.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        include_conversation: bool,
        cx: &mut WindowContext,
    ) {
        let selection = editor.read(cx).selections.newest_anchor().clone();
        if selection.start.excerpt_id != selection.end.excerpt_id {
            return;
        }
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // Extend the selection to the start and the end of the line.
        let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
        if point_selection.end > point_selection.start {
            point_selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            // `end > start` together with `end.column == 0` implies `end.row > start.row`,
            // so this decrement cannot underflow.
            if point_selection.end.column == 0 {
                point_selection.end.row -= 1;
            }
            point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
        }

        // An empty selection generates new text at the cursor; otherwise the selected
        // lines are transformed.
        let codegen_kind = if point_selection.start == point_selection.end {
            CodegenKind::Generate {
                position: snapshot.anchor_after(point_selection.start),
            }
        } else {
            CodegenKind::Transform {
                range: snapshot.anchor_before(point_selection.start)
                    ..snapshot.anchor_after(point_selection.end),
            }
        };

        let inline_assist_id = self.next_assist_id.post_inc();
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                codegen_kind,
                self.telemetry.clone(),
                cx,
            )
        });

        // Shared with the block's render callback, which records the host editor's gutter
        // dimensions so the prompt can align itself with them.
        let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
        let inline_assistant = cx.new_view(|cx| {
            InlineAssistEditor::new(
                inline_assist_id,
                measurements.clone(),
                self.prompt_history.clone(),
                codegen.clone(),
                cx,
            )
        });
        let block_id = editor.update(cx, |editor, cx| {
            // Collapse the selection to its head before inserting the prompt block.
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([selection.head()..selection.head()])
            });
            editor.insert_blocks(
                [BlockProperties {
                    style: BlockStyle::Flex,
                    position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
                    height: 2,
                    render: Box::new({
                        let inline_assistant = inline_assistant.clone();
                        move |cx: &mut BlockContext| {
                            *measurements.lock() = *cx.gutter_dimensions;
                            inline_assistant.clone().into_any_element()
                        }
                    }),
                    disposition: if selection.reversed {
                        BlockDisposition::Above
                    } else {
                        BlockDisposition::Below
                    },
                }],
                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
                cx,
            )[0]
        });

        self.pending_assists.insert(
            inline_assist_id,
            PendingInlineAssist {
                include_conversation,
                editor: editor.downgrade(),
                inline_assistant: Some((block_id, inline_assistant.clone())),
                codegen: codegen.clone(),
                workspace,
                _subscriptions: vec![
                    // Route prompt editor events (confirm / cancel / dismiss) to this global.
                    cx.subscribe(&inline_assistant, |inline_assistant, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_inline_assistant_event(inline_assistant, event, cx)
                        })
                    }),
                    // When the host editor's selection changes locally while the prompt has
                    // focus, move focus back to the host editor.
                    cx.subscribe(editor, {
                        let inline_assistant = inline_assistant.downgrade();
                        move |editor, event, cx| {
                            if let Some(inline_assistant) = inline_assistant.upgrade() {
                                if let EditorEvent::SelectionsChanged { local } = event {
                                    if *local
                                        && inline_assistant.focus_handle(cx).contains_focused(cx)
                                    {
                                        cx.focus_view(&editor);
                                    }
                                }
                            }
                        }
                    }),
                    // Refresh this editor's assist highlights whenever the codegen notifies.
                    cx.observe(&codegen, {
                        let editor = editor.downgrade();
                        move |_, cx| {
                            if let Some(editor) = editor.upgrade() {
                                InlineAssistant::update_global(cx, |this, cx| {
                                    this.update_highlights_for_editor(&editor, cx);
                                })
                            }
                        }
                    }),
                    // Finish the assist when its edit is undone or generation completes.
                    cx.subscribe(&codegen, move |codegen, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| match event {
                            CodegenEvent::Undone => {
                                this.finish_inline_assist(inline_assist_id, false, cx)
                            }
                            CodegenEvent::Finished => {
                                let pending_assist = if let Some(pending_assist) =
                                    this.pending_assists.get(&inline_assist_id)
                                {
                                    pending_assist
                                } else {
                                    return;
                                };

                                let error = codegen
                                    .read(cx)
                                    .error()
                                    .map(|error| format!("Inline assistant error: {}", error));
                                if let Some(error) = error {
                                    // While the prompt is still shown, keep the assist alive
                                    // so the prompt can display the error. If the prompt was
                                    // already hidden, report the error as a toast instead.
                                    if pending_assist.inline_assistant.is_none() {
                                        if let Some(workspace) = pending_assist
                                            .workspace
                                            .as_ref()
                                            .and_then(|workspace| workspace.upgrade())
                                        {
                                            workspace.update(cx, |workspace, cx| {
                                                struct InlineAssistantError;

                                                // Keyed by assist id so each assist gets its
                                                // own notification.
                                                let id = NotificationId::identified::<
                                                    InlineAssistantError,
                                                >(
                                                    inline_assist_id.0
                                                );

                                                workspace.show_toast(Toast::new(id, error), cx);
                                            })
                                        }

                                        this.finish_inline_assist(inline_assist_id, false, cx);
                                    }
                                } else {
                                    this.finish_inline_assist(inline_assist_id, false, cx);
                                }
                            }
                        })
                    }),
                ],
            },
        );

        self.pending_assist_ids_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorPendingAssists {
                window: cx.window_handle(),
                assist_ids: Vec::new(),
            })
            .assist_ids
            .push(inline_assist_id);
        self.update_highlights_for_editor(editor, cx);
    }

    /// Dispatches an event emitted by a prompt editor.
    ///
    /// `Confirmed` starts generation, `Canceled` finishes the assist and undoes its edits,
    /// and `Dismissed` only hides the prompt while leaving the assist pending.
    fn handle_inline_assistant_event(
        &mut self,
        inline_assistant: View<InlineAssistEditor>,
        event: &InlineAssistEditorEvent,
        cx: &mut WindowContext,
    ) {
        let assist_id = inline_assistant.read(cx).id;
        match event {
            InlineAssistEditorEvent::Confirmed { prompt } => {
                self.confirm_inline_assist(assist_id, prompt, cx);
            }
            InlineAssistEditorEvent::Canceled => {
                self.finish_inline_assist(assist_id, true, cx);
            }
            InlineAssistEditorEvent::Dismissed => {
                self.hide_inline_assist(assist_id, cx);
            }
        }
    }

    /// Cancels, with undo, the most recently created pending assist of the focused
    /// editor in the current window.
    ///
    /// Returns `true` if an assist was canceled and `false` if there was none.
    pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
        for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
            if pending_assists.window == cx.window_handle() {
                if let Some(editor) = editor.upgrade() {
                    if editor.read(cx).is_focused(cx) {
                        if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
                            // Returning immediately after the mutable call keeps the borrow
                            // of `pending_assist_ids_by_editor` from overlapping with it.
                            self.finish_inline_assist(assist_id, true, cx);
                            return true;
                        }
                    }
                }
            }
        }
        false
    }

    /// Ends an assist by hiding its prompt, forgetting it, and refreshing highlights.
    ///
    /// When `undo` is true and the editor is still alive, the codegen's edits are also
    /// reverted. Unknown ids only have their prompt-hiding step attempted.
    fn finish_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        undo: bool,
        cx: &mut WindowContext,
    ) {
        self.hide_inline_assist(assist_id, cx);

        if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
            // Drop the id from its editor's list, and the whole entry once it is empty.
            if let hash_map::Entry::Occupied(mut entry) = self
                .pending_assist_ids_by_editor
                .entry(pending_assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let Some(editor) = pending_assist.editor.upgrade() {
                self.update_highlights_for_editor(&editor, cx);

                if undo {
                    pending_assist
                        .codegen
                        .update(cx, |codegen, cx| codegen.undo(cx));
                }
            }
        }
    }

    /// Removes an assist's prompt block from its editor, leaving the assist pending.
    ///
    /// If the prompt had focus, focus returns to the host editor. Calling this again is
    /// a no-op because the block handle is taken out of the pending assist.
    fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
            if let Some(editor) = pending_assist.editor.upgrade() {
                if let Some((block_id, inline_assistant)) = pending_assist.inline_assistant.take() {
                    editor.update(cx, |editor, cx| {
                        editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
                        if inline_assistant.focus_handle(cx).contains_focused(cx) {
                            editor.focus(cx);
                        }
                    });
                }
            }
        }
    }

    /// Builds the model request for a confirmed prompt and starts code generation.
    ///
    /// Steps:
    /// 1. Optionally take the active assistant-panel conversation as context.
    /// 2. Record `user_prompt` in the de-duplicated, length-capped prompt history.
    /// 3. Map the codegen range to a single buffer. The assist is finished instead when
    ///    the range cannot be mapped or spans two buffers.
    /// 4. Pick a temperature from the language at the range start.
    /// 5. Render the prompt on the background executor and start the codegen.
    ///
    /// Errors from the spawned task are only logged.
    fn confirm_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        user_prompt: &str,
        cx: &mut WindowContext,
    ) {
        let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
        {
            pending_assist
        } else {
            return;
        };

        let conversation = if pending_assist.include_conversation {
            pending_assist.workspace.as_ref().and_then(|workspace| {
                let workspace = workspace.upgrade()?.read(cx);
                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
                assistant_panel.read(cx).active_conversation(cx)
            })
        } else {
            None
        };

        let editor = if let Some(editor) = pending_assist.editor.upgrade() {
            editor
        } else {
            return;
        };

        // Worktree root names joined with "/", used as the project name in the prompt.
        let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?;
            Some(
                workspace
                    .read(cx)
                    .project()
                    .read(cx)
                    .worktree_root_names(cx)
                    .collect::<Vec<&str>>()
                    .join("/"),
            )
        });

        // Move the prompt to the newest position, dropping older duplicates and the
        // oldest entry once the cap is exceeded.
        self.prompt_history.retain(|prompt| prompt != user_prompt);
        self.prompt_history.push_back(user_prompt.into());
        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
            self.prompt_history.pop_front();
        }

        let codegen = pending_assist.codegen.clone();
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        let range = codegen.read(cx).range();
        let start = snapshot.point_to_buffer_offset(range.start);
        let end = snapshot.point_to_buffer_offset(range.end);
        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
            let (start_buffer, start_buffer_offset) = start;
            let (end_buffer, end_buffer_offset) = end;
            if start_buffer.remote_id() == end_buffer.remote_id() {
                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
            } else {
                self.finish_inline_assist(assist_id, false, cx);
                return;
            }
        } else {
            self.finish_inline_assist(assist_id, false, cx);
            return;
        };

        // Plain text is treated the same as "no language".
        let language = buffer.language_at(range.start);
        let language_name = if let Some(language) = language.as_ref() {
            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                None
            } else {
                Some(language.name())
            }
        } else {
            None
        };

        // Higher Temperature increases the randomness of model outputs.
        // If Markdown or No Language is Known, increase the randomness for more creative output
        // If Code, decrease temperature to get more deterministic outputs
        let temperature = if let Some(language) = language_name.clone() {
            if language.as_ref() == "Markdown" {
                1.0
            } else {
                0.5
            }
        } else {
            1.0
        };

        let user_prompt = user_prompt.to_string();

        let prompt = cx.background_executor().spawn(async move {
            let language_name = language_name.as_deref();
            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
        });

        // Conversation messages (if any) come first; the generated prompt is appended as
        // the final user message.
        let mut messages = Vec::new();
        if let Some(conversation) = conversation {
            let request = conversation.read(cx).to_completion_request(cx);
            messages = request.messages;
        }
        let model = CompletionProvider::global(cx).model();

        cx.spawn(|mut cx| async move {
            let prompt = prompt.await?;

            messages.push(LanguageModelRequestMessage {
                role: Role::User,
                content: prompt,
            });

            // NOTE(review): "|END|>" is presumably the end marker requested by
            // `generate_content_prompt` — confirm against that template.
            let request = LanguageModelRequest {
                model,
                messages,
                stop: vec!["|END|>".to_string()],
                temperature,
            };

            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    /// Recomputes the assist highlights for `editor` from all of its pending assists.
    ///
    /// Each codegen's full `range()` gets a background highlight. The ranges from
    /// `last_equal_ranges()` are drawn faded (`fade_out: 0.6`). Overlapping ranges are
    /// merged, and highlights are cleared when there is nothing left to show.
    fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut background_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let empty_inline_assist_ids = Vec::new();
        let inline_assist_ids = self
            .pending_assist_ids_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_inline_assist_ids, |pending_assists| {
                &pending_assists.assist_ids
            });

        for inline_assist_id in inline_assist_ids {
            if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
                let codegen = pending_assist.codegen.read(cx);
                background_ranges.push(codegen.range());
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut background_ranges, &snapshot);
        merge_ranges(&mut foreground_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            if background_ranges.is_empty() {
                editor.clear_background_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_background::<PendingInlineAssist>(
                    &background_ranges,
                    |theme| theme.editor_active_line_background, // TODO use the appropriate color
                    cx,
                );
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_text::<PendingInlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }
        });
    }
}
 500
 501#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 502struct InlineAssistId(usize);
 503
 504impl InlineAssistId {
 505    fn post_inc(&mut self) -> InlineAssistId {
 506        let id = *self;
 507        self.0 += 1;
 508        id
 509    }
 510}
 511
/// Events emitted by [`InlineAssistEditor`] for the owning `InlineAssistant`.
enum InlineAssistEditorEvent {
    /// Confirm was pressed on an unconfirmed prompt; `prompt` is the editor text at that time.
    Confirmed { prompt: String },
    /// The cancel action was triggered.
    Canceled,
    /// Confirm was pressed again while the prompt was already confirmed.
    Dismissed,
}

/// Prompt UI rendered as a block inside the host editor for one inline assist.
struct InlineAssistEditor {
    /// Assist this prompt belongs to.
    id: InlineAssistId,
    /// Single-line editor holding the prompt text.
    prompt_editor: View<Editor>,
    /// Set when the prompt is confirmed; cleared when the codegen goes back to idle
    /// (see `handle_codegen_changed`).
    confirmed: bool,
    /// Host-editor gutter dimensions, written by the block's render callback and read
    /// when rendering to size the error-icon column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Copy of the global prompt history taken when this prompt was created, oldest first.
    prompt_history: VecDeque<String>,
    /// Position in `prompt_history` while navigating with up/down; `None` otherwise.
    prompt_history_ix: Option<usize>,
    /// Prompt text captured on `EditorEvent::Edited`; restored when navigating down past
    /// the newest history entry.
    pending_prompt: String,
    /// Code generation driven by this prompt.
    codegen: Model<Codegen>,
    /// Keeps the codegen observation and prompt-editor subscription alive.
    _subscriptions: Vec<Subscription>,
}

impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
 531
 532impl Render for InlineAssistEditor {
 533    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 534        let gutter_dimensions = *self.gutter_dimensions.lock();
 535        let icon_size = IconSize::default();
 536        h_flex()
 537            .w_full()
 538            .py_2()
 539            .border_y_1()
 540            .border_color(cx.theme().colors().border)
 541            .bg(cx.theme().colors().editor_background)
 542            .on_action(cx.listener(Self::confirm))
 543            .on_action(cx.listener(Self::cancel))
 544            .on_action(cx.listener(Self::move_up))
 545            .on_action(cx.listener(Self::move_down))
 546            .child(
 547                h_flex()
 548                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
 549                    .pr(gutter_dimensions.fold_area_width())
 550                    .justify_end()
 551                    .children(if let Some(error) = self.codegen.read(cx).error() {
 552                        let error_message = SharedString::from(error.to_string());
 553                        Some(
 554                            div()
 555                                .id("error")
 556                                .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
 557                                .child(
 558                                    Icon::new(IconName::XCircle)
 559                                        .size(icon_size)
 560                                        .color(Color::Error),
 561                                ),
 562                        )
 563                    } else {
 564                        None
 565                    }),
 566            )
 567            .child(h_flex().flex_1().child(self.render_prompt_editor(cx)))
 568    }
 569}
 570
 571impl FocusableView for InlineAssistEditor {
 572    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
 573        self.prompt_editor.focus_handle(cx)
 574    }
 575}
 576
 577impl InlineAssistEditor {
 578    #[allow(clippy::too_many_arguments)]
 579    fn new(
 580        id: InlineAssistId,
 581        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
 582        prompt_history: VecDeque<String>,
 583        codegen: Model<Codegen>,
 584        cx: &mut ViewContext<Self>,
 585    ) -> Self {
 586        let prompt_editor = cx.new_view(|cx| {
 587            let mut editor = Editor::single_line(cx);
 588            let placeholder = match codegen.read(cx).kind() {
 589                CodegenKind::Transform { .. } => "Enter transformation prompt…",
 590                CodegenKind::Generate { .. } => "Enter generation prompt…",
 591            };
 592            editor.set_placeholder_text(placeholder, cx);
 593            editor
 594        });
 595        cx.focus_view(&prompt_editor);
 596
 597        let subscriptions = vec![
 598            cx.observe(&codegen, Self::handle_codegen_changed),
 599            cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
 600        ];
 601
 602        Self {
 603            id,
 604            prompt_editor,
 605            confirmed: false,
 606            gutter_dimensions,
 607            prompt_history,
 608            prompt_history_ix: None,
 609            pending_prompt: String::new(),
 610            codegen,
 611            _subscriptions: subscriptions,
 612        }
 613    }
 614
 615    fn handle_prompt_editor_events(
 616        &mut self,
 617        _: View<Editor>,
 618        event: &EditorEvent,
 619        cx: &mut ViewContext<Self>,
 620    ) {
 621        if let EditorEvent::Edited = event {
 622            self.pending_prompt = self.prompt_editor.read(cx).text(cx);
 623            cx.notify();
 624        }
 625    }
 626
 627    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
 628        let is_read_only = !self.codegen.read(cx).idle();
 629        self.prompt_editor.update(cx, |editor, cx| {
 630            let was_read_only = editor.read_only(cx);
 631            if was_read_only != is_read_only {
 632                if is_read_only {
 633                    editor.set_read_only(true);
 634                } else {
 635                    self.confirmed = false;
 636                    editor.set_read_only(false);
 637                }
 638            }
 639        });
 640        cx.notify();
 641    }
 642
 643    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
 644        cx.emit(InlineAssistEditorEvent::Canceled);
 645    }
 646
 647    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
 648        if self.confirmed {
 649            cx.emit(InlineAssistEditorEvent::Dismissed);
 650        } else {
 651            let prompt = self.prompt_editor.read(cx).text(cx);
 652            self.prompt_editor
 653                .update(cx, |editor, _cx| editor.set_read_only(true));
 654            cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
 655            self.confirmed = true;
 656            cx.notify();
 657        }
 658    }
 659
 660    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
 661        if let Some(ix) = self.prompt_history_ix {
 662            if ix > 0 {
 663                self.prompt_history_ix = Some(ix - 1);
 664                let prompt = self.prompt_history[ix - 1].clone();
 665                self.set_prompt(&prompt, cx);
 666            }
 667        } else if !self.prompt_history.is_empty() {
 668            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
 669            let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
 670            self.set_prompt(&prompt, cx);
 671        }
 672    }
 673
 674    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
 675        if let Some(ix) = self.prompt_history_ix {
 676            if ix < self.prompt_history.len() - 1 {
 677                self.prompt_history_ix = Some(ix + 1);
 678                let prompt = self.prompt_history[ix + 1].clone();
 679                self.set_prompt(&prompt, cx);
 680            } else {
 681                self.prompt_history_ix = None;
 682                let pending_prompt = self.pending_prompt.clone();
 683                self.set_prompt(&pending_prompt, cx);
 684            }
 685        }
 686    }
 687
 688    fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
 689        self.prompt_editor.update(cx, |editor, cx| {
 690            editor.buffer().update(cx, |buffer, cx| {
 691                let len = buffer.len(cx);
 692                buffer.edit([(0..len, prompt)], None, cx);
 693            });
 694        });
 695    }
 696
 697    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
 698        let settings = ThemeSettings::get_global(cx);
 699        let text_style = TextStyle {
 700            color: if self.prompt_editor.read(cx).read_only(cx) {
 701                cx.theme().colors().text_disabled
 702            } else {
 703                cx.theme().colors().text
 704            },
 705            font_family: settings.ui_font.family.clone(),
 706            font_features: settings.ui_font.features.clone(),
 707            font_size: rems(0.875).into(),
 708            font_weight: FontWeight::NORMAL,
 709            font_style: FontStyle::Normal,
 710            line_height: relative(1.3),
 711            background_color: None,
 712            underline: None,
 713            strikethrough: None,
 714            white_space: WhiteSpace::Normal,
 715        };
 716        EditorElement::new(
 717            &self.prompt_editor,
 718            EditorStyle {
 719                background: cx.theme().colors().editor_background,
 720                local_player: cx.theme().players().local(),
 721                text: text_style,
 722                ..Default::default()
 723            },
 724        )
 725    }
 726}
 727
/// Bookkeeping for one inline assist, from `InlineAssistant::assist` until it is finished.
///
/// Also used as the highlight key type for the assist highlights in the host editor.
struct PendingInlineAssist {
    /// Editor the assist operates on.
    editor: WeakView<Editor>,
    /// Prompt block id and view; `None` once the prompt has been hidden.
    inline_assistant: Option<(BlockId, View<InlineAssistEditor>)>,
    /// Code generation for this assist.
    codegen: Model<Codegen>,
    /// Subscriptions created in `assist`, held so they live as long as the assist
    /// (gpui unsubscribes when a `Subscription` is dropped).
    _subscriptions: Vec<Subscription>,
    /// Workspace used for error toasts and for looking up conversation/project context.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the active assistant-panel conversation is sent with the prompt.
    include_conversation: bool,
}
 736
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation run ended; listeners check `Codegen::error` to see whether it failed.
    Finished,
    /// The transaction produced by this codegen was undone in the buffer.
    Undone,
}

/// What a [`Codegen`] targets in the buffer.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text in `range` (created from a non-empty, line-extended selection).
    Transform { range: Range<Anchor> },
    /// Produce new text at `position` (created from an empty selection).
    Generate { position: Anchor },
}

/// State of a single model-driven edit to a multi-buffer.
pub struct Codegen {
    /// Buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Buffer snapshot captured in `Codegen::new`; used to resolve anchors in `range`.
    snapshot: MultiBufferSnapshot,
    /// Whether this codegen transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges reported for foreground highlighting; exact semantics are set by `start`
    /// (not fully visible here).
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction holding the generated edit, if any; cleared when it is undone.
    transaction_id: Option<TransactionId>,
    /// Error from the last generation, if any.
    error: Option<anyhow::Error>,
    /// Task for the current generation; replaced by `Task::ready(())` when the edit is
    /// undone, which drops the previous task.
    generation: Task<()>,
    /// Whether no generation is in progress; starts as `true`.
    idle: bool,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (used to detect undo).
    _subscription: gpui::Subscription,
}

impl EventEmitter<CodegenEvent> for Codegen {}
 763
impl Codegen {
    /// Creates an idle `Codegen` targeting `buffer`.
    ///
    /// A snapshot of the buffer is captured now and used later by `range` and
    /// `start`. The model subscribes to the buffer's events so it can react when
    /// its transaction is undone.
    pub fn new(
        buffer: Model<MultiBuffer>,
        kind: CodegenKind,
        telemetry: Option<Arc<Telemetry>>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let snapshot = buffer.read(cx).snapshot(cx);
        Self {
            buffer: buffer.clone(),
            snapshot,
            kind,
            last_equal_ranges: Default::default(),
            transaction_id: Default::default(),
            error: Default::default(),
            idle: true,
            generation: Task::ready(()),
            telemetry,
            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
        }
    }

    /// Handles buffer events.
    ///
    /// When this codegen's transaction is undone, it forgets the transaction,
    /// replaces (drops) the in-flight generation task, and emits
    /// `CodegenEvent::Undone`.
    ///
    /// NOTE(review): `idle` is not reset here. If dropping the task cancels it,
    /// the code that sets `idle = true` never runs. Confirm callers don't rely on
    /// `idle()` after an undo.
    fn handle_buffer_event(
        &mut self,
        _buffer: Model<MultiBuffer>,
        event: &multi_buffer::Event,
        cx: &mut ModelContext<Self>,
    ) {
        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
            if self.transaction_id == Some(*transaction_id) {
                self.transaction_id = None;
                self.generation = Task::ready(());
                cx.emit(CodegenEvent::Undone);
            }
        }
    }

    /// Returns the range this codegen operates on.
    ///
    /// For `Transform` this is the stored range. For `Generate` it is an empty
    /// range at `position`, with the start re-biased to the left against the
    /// captured snapshot.
    pub fn range(&self) -> Range<Anchor> {
        match &self.kind {
            CodegenKind::Transform { range } => range.clone(),
            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
        }
    }

    /// Returns the kind this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }

    /// Returns the ranges left unchanged by the most recently applied batch of
    /// hunks. The list is empty once generation finishes.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }

    /// Returns `true` when no generation is in progress.
    pub fn idle(&self) -> bool {
        self.idle
    }

    /// Returns the error from the last generation, if it failed.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }

    /// Requests a completion for `prompt` and streams it into the buffer.
    ///
    /// Any previous generation task is replaced (and dropped). The work is split
    /// into two parts:
    /// - a background task that reads the completion, strips code fences and
    ///   span markers (`strip_invalid_spans_from_codeblock`), re-indents each line
    ///   relative to the suggested indent of the first selected line, and feeds
    ///   the text to a `StreamingDiff` against the originally selected text,
    ///   sending hunks over a channel;
    /// - a foreground loop that applies each batch of hunks in its own
    ///   transaction and merges every transaction after the first into the first
    ///   one, so the whole generation can be undone at once.
    ///
    /// When the work ends, `idle` is set, any error is stored, and
    /// `CodegenEvent::Finished` is emitted. If telemetry is configured, the
    /// background task reports the time to the first chunk and any error message.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // Original text of the target range; the model output is diffed against it.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Indent suggested for the first selected line, falling back to that
        // line's current indent.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        let model_telemetry_id = prompt.model.telemetry_id();
        let response = CompletionProvider::global(cx).complete(prompt);
        let telemetry = self.telemetry.clone();
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset in `snapshot` up to which hunks have been consumed.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Bounded channel: provides backpressure between the background
                    // diff and the foreground task that applies the hunks.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            // Time until the first chunk (or error) arrives.
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line not yet sent to the diff.
                                let mut new_text = String::new();
                                // Indent (byte length) of the first non-blank response line.
                                let mut base_indent = None;
                                // Indent of the current line, once its first
                                // non-whitespace character has arrived.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Keep the line's indent relative to the first
                                                // line, but rebase it on the suggested indent.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line continues from the selection
                                                // start column, which already supplies that
                                                // much indentation.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                // Replace the model's leading whitespace with
                                                // the corrected indentation.
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Leading whitespace is held back until the line's
                                        // indentation has been corrected.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        // Another segment follows, so this chunk contained a
                                        // newline: emit it and start a new line.
                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                // Flush any held-back text and finalize the diff.
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks to the buffer as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                // Translate hunks into edits against the snapshot.
                                // Remove/Keep advance `edit_start`; Insert does not.
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First batch: remember its transaction so later batches
                                    // can be merged into it and `undo` can target it.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // The channel closed; surface any error from the background task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        // Reset state for the new run.
        self.error.take();
        self.idle = false;
        cx.notify();
    }

    /// Undoes the transaction holding this generation's edits, if there is one.
    /// The resulting `TransactionUndone` event (presumably emitted by the buffer —
    /// confirm) is handled by `handle_buffer_event`.
    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
        if let Some(transaction_id) = self.transaction_id {
            self.buffer
                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
        }
    }
}
1030
/// Filters a streamed model response, removing wrapper syntax around the code:
/// - a leading `<|S|` / `<|S|>` start marker and a trailing `|E|>` end marker;
/// - an opening markdown fence line (three backticks plus an optional language
///   tag) and a closing fence.
///
/// Markers and fences can be split across chunks. Text that might be the start
/// of one is therefore kept in an internal buffer and emitted with a later
/// chunk. Errors from the input stream are passed through unchanged.
///
/// NOTE(review): `filter_map` offers no end-of-stream hook. Anything still held
/// when the input ends is never emitted. This is how closing fences and end
/// markers get dropped, but it also drops other held text. Examples are output
/// whose final chunk ends in `|`, or output starting with `<|` that is not a
/// `<|S|` marker. Confirm this is acceptable.
fn strip_invalid_spans_from_codeblock(
    stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
    let mut first_line = true;
    // Received text that has not been emitted yet.
    let mut buffer = String::new();
    let mut starts_with_markdown_codeblock = false;
    // Sticky once a start or end marker has been seen.
    let mut includes_start_or_end_span = false;
    stream.filter_map(move |chunk| {
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(err) => return future::ready(Some(Err(err))),
        };
        buffer.push_str(&chunk);

        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
            // A start marker followed by at least one more character: drop the
            // marker, with or without its closing `>`.
            includes_start_or_end_span = true;

            buffer = buffer
                .strip_prefix("<|S|>")
                .or_else(|| buffer.strip_prefix("<|S|"))
                .unwrap_or(&buffer)
                .to_string();
        } else if buffer.ends_with("|E|>") {
            // Complete end marker; it is held back from `text` below.
            includes_start_or_end_span = true;
        } else if buffer.starts_with("<|")
            || buffer.starts_with("<|S")
            || buffer.starts_with("<|S|")
            || buffer.ends_with('|')
            || buffer.ends_with("|E")
            || buffer.ends_with("|E|")
        {
            // Possibly a partial marker: wait for more input.
            return future::ready(None);
        }

        if first_line {
            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                // Possibly the start of an opening fence: wait for more input.
                return future::ready(None);
            } else if buffer.starts_with("```") {
                // Opening fence: drop the whole first line once it is complete.
                starts_with_markdown_codeblock = true;
                if let Some(newline_ix) = buffer.find('\n') {
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    return future::ready(None);
                }
            }
        }

        // `text` is the portion of `buffer` that is safe to emit now.
        let mut text = buffer.to_string();
        if starts_with_markdown_codeblock {
            // Hold back a possible (partial) closing fence or a trailing newline
            // that could begin one.
            text = text
                .strip_suffix("\n```\n")
                .or_else(|| text.strip_suffix("\n```"))
                .or_else(|| text.strip_suffix("\n``"))
                .or_else(|| text.strip_suffix("\n`"))
                .or_else(|| text.strip_suffix('\n'))
                .unwrap_or(&text)
                .to_string();
        }

        if includes_start_or_end_span {
            // Hold back the end marker, or a trailing part of it.
            // NOTE(review): the `strip_prefix` cases shorten `text` from the front.
            // `split_off(text.len())` below keeps the first `text.len()` bytes of
            // `buffer`, so those cases emit the `|>`/`>` and hold back the tail
            // instead. Confirm this is intended.
            text = text
                .strip_suffix("|E|>")
                .or_else(|| text.strip_suffix("E|>"))
                .or_else(|| text.strip_prefix("|>"))
                .or_else(|| text.strip_prefix('>'))
                .unwrap_or(&text)
                .to_string();
        };

        if text.contains('\n') {
            first_line = false;
        }

        // Emit the first `text.len()` bytes of `buffer` and keep the rest.
        let remainder = buffer.split_off(text.len());
        let result = if buffer.is_empty() {
            None
        } else {
            Some(Ok(buffer.clone()))
        };

        buffer = remainder;
        future::ready(result)
    })
}
1116
1117fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1118    ranges.sort_unstable_by(|a, b| {
1119        a.start
1120            .cmp(&b.start, buffer)
1121            .then_with(|| b.end.cmp(&a.end, buffer))
1122    });
1123
1124    let mut ix = 0;
1125    while ix + 1 < ranges.len() {
1126        let b = ranges[ix + 1].clone();
1127        let a = &mut ranges[ix];
1128        if a.end.cmp(&b.start, buffer).is_gt() {
1129            if a.end.cmp(&b.end, buffer).is_lt() {
1130                a.end = b.end;
1131            }
1132            ranges.remove(ix + 1);
1133        } else {
1134            ix += 1;
1135        }
1136    }
1137}
1138
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use settings::SettingsStore;

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response_stream(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response_stream(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response_stream(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Sends `text` to `provider` in random chunks of 1 to 10 bytes, letting the
    /// executor settle after each chunk, then finishes the completion and lets
    /// the executor settle once more. `text` must be ASCII so every split falls
    /// on a char boundary.
    fn simulate_response_stream(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk.into());
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}