1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp},
10 display_map::{
11 BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
12 },
13 scroll::{Autoscroll, AutoscrollStrategy},
14 Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
15 MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
18use gpui::{
19 AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
20 Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
21 ViewContext, WeakView, WhiteSpace, WindowContext,
22};
23use language::{Point, TransactionId};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use rope::Rope;
27use settings::Settings;
28use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
29use theme::ThemeSettings;
30use ui::{prelude::*, Tooltip};
31use workspace::{notifications::NotificationId, Toast, Workspace};
32
33pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
34 cx.set_global(InlineAssistant::new(telemetry));
35}
36
/// Maximum number of confirmed prompts kept in the prompt history; once it is
/// exceeded, the oldest prompt is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;

/// Global coordinator for inline assists across all editors.
///
/// Stored as a GPUI global (installed by [`init`]). Each assist pairs a prompt
/// editor block with a [`Codegen`] that streams edits into the target editor's buffer.
pub struct InlineAssistant {
    /// Id handed out to the next assist created by [`InlineAssistant::assist`].
    next_assist_id: InlineAssistId,
    /// Every assist that has not been finished yet, keyed by id.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Ids of the pending assists of each editor, in creation order. An editor's
    /// entry is removed once it has no pending assists left.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Previously confirmed prompts, oldest first, without duplicates.
    prompt_history: VecDeque<String>,
    /// Telemetry handed to each [`Codegen`]; always `Some` when built via `new`.
    telemetry: Option<Arc<Telemetry>>,
}

/// The pending assists that belong to a single editor.
struct EditorPendingAssists {
    /// Window the editor was in when its first pending assist was registered; used
    /// to restrict `cancel_last_inline_assist` to the current window.
    window: AnyWindowHandle,
    /// Pending assist ids, oldest first; the last element is the most recent assist.
    assist_ids: Vec<InlineAssistId>,
}

/// Allows `InlineAssistant` to be stored as a GPUI global.
impl Global for InlineAssistant {}
53
54impl InlineAssistant {
55 pub fn new(telemetry: Arc<Telemetry>) -> Self {
56 Self {
57 next_assist_id: InlineAssistId::default(),
58 pending_assists: HashMap::default(),
59 pending_assist_ids_by_editor: HashMap::default(),
60 prompt_history: VecDeque::default(),
61 telemetry: Some(telemetry),
62 }
63 }
64
65 pub fn assist(
66 &mut self,
67 editor: &View<Editor>,
68 workspace: Option<WeakView<Workspace>>,
69 include_context: bool,
70 cx: &mut WindowContext,
71 ) {
72 let selection = editor.read(cx).selections.newest_anchor().clone();
73 if selection.start.excerpt_id != selection.end.excerpt_id {
74 return;
75 }
76 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
77
78 // Extend the selection to the start and the end of the line.
79 let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
80 if point_selection.end > point_selection.start {
81 point_selection.start.column = 0;
82 // If the selection ends at the start of the line, we don't want to include it.
83 if point_selection.end.column == 0 {
84 point_selection.end.row -= 1;
85 }
86 point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
87 }
88
89 let codegen_kind = if point_selection.start == point_selection.end {
90 CodegenKind::Generate {
91 position: snapshot.anchor_after(point_selection.start),
92 }
93 } else {
94 CodegenKind::Transform {
95 range: snapshot.anchor_before(point_selection.start)
96 ..snapshot.anchor_after(point_selection.end),
97 }
98 };
99
100 let inline_assist_id = self.next_assist_id.post_inc();
101 let codegen = cx.new_model(|cx| {
102 Codegen::new(
103 editor.read(cx).buffer().clone(),
104 codegen_kind,
105 self.telemetry.clone(),
106 cx,
107 )
108 });
109
110 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
111 let inline_assist_editor = cx.new_view(|cx| {
112 InlineAssistEditor::new(
113 inline_assist_id,
114 gutter_dimensions.clone(),
115 self.prompt_history.clone(),
116 codegen.clone(),
117 cx,
118 )
119 });
120 let block_id = editor.update(cx, |editor, cx| {
121 editor.change_selections(None, cx, |selections| {
122 selections.select_anchor_ranges([selection.head()..selection.head()])
123 });
124 editor.insert_blocks(
125 [BlockProperties {
126 style: BlockStyle::Sticky,
127 position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
128 height: inline_assist_editor.read(cx).height_in_lines,
129 render: build_inline_assist_editor_renderer(
130 &inline_assist_editor,
131 gutter_dimensions,
132 ),
133 disposition: if selection.reversed {
134 BlockDisposition::Above
135 } else {
136 BlockDisposition::Below
137 },
138 }],
139 Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
140 cx,
141 )[0]
142 });
143
144 self.pending_assists.insert(
145 inline_assist_id,
146 PendingInlineAssist {
147 include_context,
148 editor: editor.downgrade(),
149 inline_assist_editor: Some((block_id, inline_assist_editor.clone())),
150 codegen: codegen.clone(),
151 workspace,
152 _subscriptions: vec![
153 cx.subscribe(&inline_assist_editor, |inline_assist_editor, event, cx| {
154 InlineAssistant::update_global(cx, |this, cx| {
155 this.handle_inline_assistant_event(inline_assist_editor, event, cx)
156 })
157 }),
158 cx.subscribe(editor, {
159 let inline_assist_editor = inline_assist_editor.downgrade();
160 move |editor, event, cx| {
161 if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
162 if let EditorEvent::SelectionsChanged { local } = event {
163 if *local
164 && inline_assist_editor
165 .focus_handle(cx)
166 .contains_focused(cx)
167 {
168 cx.focus_view(&editor);
169 }
170 }
171 }
172 }
173 }),
174 cx.observe(&codegen, {
175 let editor = editor.downgrade();
176 move |_, cx| {
177 if let Some(editor) = editor.upgrade() {
178 InlineAssistant::update_global(cx, |this, cx| {
179 this.update_highlights_for_editor(&editor, cx);
180 })
181 }
182 }
183 }),
184 cx.subscribe(&codegen, move |codegen, event, cx| {
185 InlineAssistant::update_global(cx, |this, cx| match event {
186 CodegenEvent::Undone => {
187 this.finish_inline_assist(inline_assist_id, false, cx)
188 }
189 CodegenEvent::Finished => {
190 let pending_assist = if let Some(pending_assist) =
191 this.pending_assists.get(&inline_assist_id)
192 {
193 pending_assist
194 } else {
195 return;
196 };
197
198 let error = codegen
199 .read(cx)
200 .error()
201 .map(|error| format!("Inline assistant error: {}", error));
202 if let Some(error) = error {
203 if pending_assist.inline_assist_editor.is_none() {
204 if let Some(workspace) = pending_assist
205 .workspace
206 .as_ref()
207 .and_then(|workspace| workspace.upgrade())
208 {
209 workspace.update(cx, |workspace, cx| {
210 struct InlineAssistantError;
211
212 let id = NotificationId::identified::<
213 InlineAssistantError,
214 >(
215 inline_assist_id.0
216 );
217
218 workspace.show_toast(Toast::new(id, error), cx);
219 })
220 }
221
222 this.finish_inline_assist(inline_assist_id, false, cx);
223 }
224 } else {
225 this.finish_inline_assist(inline_assist_id, false, cx);
226 }
227 }
228 })
229 }),
230 ],
231 },
232 );
233
234 self.pending_assist_ids_by_editor
235 .entry(editor.downgrade())
236 .or_insert_with(|| EditorPendingAssists {
237 window: cx.window_handle(),
238 assist_ids: Vec::new(),
239 })
240 .assist_ids
241 .push(inline_assist_id);
242 self.update_highlights_for_editor(editor, cx);
243 }
244
245 fn handle_inline_assistant_event(
246 &mut self,
247 inline_assist_editor: View<InlineAssistEditor>,
248 event: &InlineAssistEditorEvent,
249 cx: &mut WindowContext,
250 ) {
251 let assist_id = inline_assist_editor.read(cx).id;
252 match event {
253 InlineAssistEditorEvent::Confirmed { prompt } => {
254 self.confirm_inline_assist(assist_id, prompt, cx);
255 }
256 InlineAssistEditorEvent::Canceled => {
257 self.finish_inline_assist(assist_id, true, cx);
258 }
259 InlineAssistEditorEvent::Dismissed => {
260 self.hide_inline_assist(assist_id, cx);
261 }
262 InlineAssistEditorEvent::Resized { height_in_lines } => {
263 self.resize_inline_assist(assist_id, *height_in_lines, cx);
264 }
265 }
266 }
267
268 pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
269 for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
270 if pending_assists.window == cx.window_handle() {
271 if let Some(editor) = editor.upgrade() {
272 if editor.read(cx).is_focused(cx) {
273 if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
274 self.finish_inline_assist(assist_id, true, cx);
275 return true;
276 }
277 }
278 }
279 }
280 }
281 false
282 }
283
284 fn finish_inline_assist(
285 &mut self,
286 assist_id: InlineAssistId,
287 undo: bool,
288 cx: &mut WindowContext,
289 ) {
290 self.hide_inline_assist(assist_id, cx);
291
292 if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
293 if let hash_map::Entry::Occupied(mut entry) = self
294 .pending_assist_ids_by_editor
295 .entry(pending_assist.editor.clone())
296 {
297 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
298 if entry.get().assist_ids.is_empty() {
299 entry.remove();
300 }
301 }
302
303 if let Some(editor) = pending_assist.editor.upgrade() {
304 self.update_highlights_for_editor(&editor, cx);
305
306 if undo {
307 pending_assist
308 .codegen
309 .update(cx, |codegen, cx| codegen.undo(cx));
310 }
311 }
312 }
313 }
314
315 fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
316 if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
317 if let Some(editor) = pending_assist.editor.upgrade() {
318 if let Some((block_id, inline_assist_editor)) =
319 pending_assist.inline_assist_editor.take()
320 {
321 editor.update(cx, |editor, cx| {
322 editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
323 if inline_assist_editor.focus_handle(cx).contains_focused(cx) {
324 editor.focus(cx);
325 }
326 });
327 }
328 }
329 }
330 }
331
332 fn resize_inline_assist(
333 &mut self,
334 assist_id: InlineAssistId,
335 height_in_lines: u8,
336 cx: &mut WindowContext,
337 ) {
338 if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
339 if let Some(editor) = pending_assist.editor.upgrade() {
340 if let Some((block_id, inline_assist_editor)) =
341 pending_assist.inline_assist_editor.as_ref()
342 {
343 let gutter_dimensions = inline_assist_editor.read(cx).gutter_dimensions.clone();
344 let mut new_blocks = HashMap::default();
345 new_blocks.insert(
346 *block_id,
347 (
348 Some(height_in_lines),
349 build_inline_assist_editor_renderer(
350 inline_assist_editor,
351 gutter_dimensions,
352 ),
353 ),
354 );
355 editor.update(cx, |editor, cx| {
356 editor
357 .display_map
358 .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
359 });
360 }
361 }
362 }
363 }
364
365 fn confirm_inline_assist(
366 &mut self,
367 assist_id: InlineAssistId,
368 user_prompt: &str,
369 cx: &mut WindowContext,
370 ) {
371 let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
372 {
373 pending_assist
374 } else {
375 return;
376 };
377
378 let context = if pending_assist.include_context {
379 pending_assist.workspace.as_ref().and_then(|workspace| {
380 let workspace = workspace.upgrade()?.read(cx);
381 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
382 assistant_panel.read(cx).active_context(cx)
383 })
384 } else {
385 None
386 };
387
388 let editor = if let Some(editor) = pending_assist.editor.upgrade() {
389 editor
390 } else {
391 return;
392 };
393
394 let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
395 let workspace = workspace.upgrade()?;
396 Some(
397 workspace
398 .read(cx)
399 .project()
400 .read(cx)
401 .worktree_root_names(cx)
402 .collect::<Vec<&str>>()
403 .join("/"),
404 )
405 });
406
407 self.prompt_history.retain(|prompt| prompt != user_prompt);
408 self.prompt_history.push_back(user_prompt.into());
409 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
410 self.prompt_history.pop_front();
411 }
412
413 let codegen = pending_assist.codegen.clone();
414 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
415 let range = codegen.read(cx).range();
416 let start = snapshot.point_to_buffer_offset(range.start);
417 let end = snapshot.point_to_buffer_offset(range.end);
418 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
419 let (start_buffer, start_buffer_offset) = start;
420 let (end_buffer, end_buffer_offset) = end;
421 if start_buffer.remote_id() == end_buffer.remote_id() {
422 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
423 } else {
424 self.finish_inline_assist(assist_id, false, cx);
425 return;
426 }
427 } else {
428 self.finish_inline_assist(assist_id, false, cx);
429 return;
430 };
431
432 let language = buffer.language_at(range.start);
433 let language_name = if let Some(language) = language.as_ref() {
434 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
435 None
436 } else {
437 Some(language.name())
438 }
439 } else {
440 None
441 };
442
443 // Higher Temperature increases the randomness of model outputs.
444 // If Markdown or No Language is Known, increase the randomness for more creative output
445 // If Code, decrease temperature to get more deterministic outputs
446 let temperature = if let Some(language) = language_name.clone() {
447 if language.as_ref() == "Markdown" {
448 1.0
449 } else {
450 0.5
451 }
452 } else {
453 1.0
454 };
455
456 let user_prompt = user_prompt.to_string();
457
458 let prompt = cx.background_executor().spawn(async move {
459 let language_name = language_name.as_deref();
460 generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
461 });
462
463 let mut messages = Vec::new();
464 if let Some(context) = context {
465 let request = context.read(cx).to_completion_request(cx);
466 messages = request.messages;
467 }
468 let model = CompletionProvider::global(cx).model();
469
470 cx.spawn(|mut cx| async move {
471 let prompt = prompt.await?;
472
473 messages.push(LanguageModelRequestMessage {
474 role: Role::User,
475 content: prompt,
476 });
477
478 let request = LanguageModelRequest {
479 model,
480 messages,
481 stop: vec!["|END|>".to_string()],
482 temperature,
483 };
484
485 codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
486 anyhow::Ok(())
487 })
488 .detach_and_log_err(cx);
489 }
490
491 fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
492 let mut background_ranges = Vec::new();
493 let mut foreground_ranges = Vec::new();
494 let empty_inline_assist_ids = Vec::new();
495 let inline_assist_ids = self
496 .pending_assist_ids_by_editor
497 .get(&editor.downgrade())
498 .map_or(&empty_inline_assist_ids, |pending_assists| {
499 &pending_assists.assist_ids
500 });
501
502 for inline_assist_id in inline_assist_ids {
503 if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
504 let codegen = pending_assist.codegen.read(cx);
505 background_ranges.push(codegen.range());
506 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
507 }
508 }
509
510 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
511 merge_ranges(&mut background_ranges, &snapshot);
512 merge_ranges(&mut foreground_ranges, &snapshot);
513 editor.update(cx, |editor, cx| {
514 if background_ranges.is_empty() {
515 editor.clear_background_highlights::<PendingInlineAssist>(cx);
516 } else {
517 editor.highlight_background::<PendingInlineAssist>(
518 &background_ranges,
519 |theme| theme.editor_active_line_background, // TODO use the appropriate color
520 cx,
521 );
522 }
523
524 if foreground_ranges.is_empty() {
525 editor.clear_highlights::<PendingInlineAssist>(cx);
526 } else {
527 editor.highlight_text::<PendingInlineAssist>(
528 foreground_ranges,
529 HighlightStyle {
530 fade_out: Some(0.6),
531 ..Default::default()
532 },
533 cx,
534 );
535 }
536 });
537 }
538}
539
540fn build_inline_assist_editor_renderer(
541 editor: &View<InlineAssistEditor>,
542 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
543) -> RenderBlock {
544 let editor = editor.clone();
545 Box::new(move |cx: &mut BlockContext| {
546 *gutter_dimensions.lock() = *cx.gutter_dimensions;
547 editor.clone().into_any_element()
548 })
549}
550
/// Identifier of an inline assist, unique within one [`InlineAssistant`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
561
/// Events a prompt view ([`InlineAssistEditor`]) emits for the [`InlineAssistant`].
enum InlineAssistEditorEvent {
    /// The user submitted `prompt`; the assistant starts generating.
    Confirmed { prompt: String },
    /// The user cancelled, or the prompt lost focus before being confirmed; the
    /// assistant finishes the assist and undoes its edits.
    Canceled,
    /// Confirm was pressed again on an already-confirmed prompt; the assistant hides
    /// the prompt block but keeps the assist pending.
    Dismissed,
    /// The prompt editor's height changed to `height_in_lines`.
    Resized { height_in_lines: u8 },
}
568
/// The prompt view shown in a block next to the text an inline assist operates on.
struct InlineAssistEditor {
    /// Assist this prompt belongs to.
    id: InlineAssistId,
    /// Current height of the prompt block, in lines (kept in sync by `count_lines`).
    height_in_lines: u8,
    /// Single-buffer editor the user types the prompt into.
    prompt_editor: View<Editor>,
    /// Whether the current prompt has been submitted; reset when codegen goes idle.
    confirmed: bool,
    /// Host editor's gutter dimensions, written by the block renderer on each render.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Copy of the assistant's prompt history at creation time, oldest first.
    prompt_history: VecDeque<String>,
    /// Index of the history entry currently shown, if the user is browsing history.
    prompt_history_ix: Option<usize>,
    /// Text the user typed, restored after browsing past the newest history entry.
    pending_prompt: String,
    /// Code generator driven by this prompt.
    codegen: Model<Codegen>,
    /// Keeps the codegen and prompt-editor subscriptions alive.
    _subscriptions: Vec<Subscription>,
}

/// Lets the prompt view emit [`InlineAssistEditorEvent`]s.
impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
583
584impl Render for InlineAssistEditor {
585 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
586 let gutter_dimensions = *self.gutter_dimensions.lock();
587 let icon_size = IconSize::default();
588 h_flex()
589 .w_full()
590 .py_1p5()
591 .border_y_1()
592 .border_color(cx.theme().colors().border)
593 .bg(cx.theme().colors().editor_background)
594 .on_action(cx.listener(Self::confirm))
595 .on_action(cx.listener(Self::cancel))
596 .on_action(cx.listener(Self::move_up))
597 .on_action(cx.listener(Self::move_down))
598 .child(
599 h_flex()
600 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
601 .pr(gutter_dimensions.fold_area_width())
602 .justify_end()
603 .children(if let Some(error) = self.codegen.read(cx).error() {
604 let error_message = SharedString::from(error.to_string());
605 Some(
606 div()
607 .id("error")
608 .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
609 .child(
610 Icon::new(IconName::XCircle)
611 .size(icon_size)
612 .color(Color::Error),
613 ),
614 )
615 } else {
616 None
617 }),
618 )
619 .child(div().flex_1().child(self.render_prompt_editor(cx)))
620 }
621}
622
623impl FocusableView for InlineAssistEditor {
624 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
625 self.prompt_editor.focus_handle(cx)
626 }
627}
628
629impl InlineAssistEditor {
630 const MAX_LINES: u8 = 8;
631
632 #[allow(clippy::too_many_arguments)]
633 fn new(
634 id: InlineAssistId,
635 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
636 prompt_history: VecDeque<String>,
637 codegen: Model<Codegen>,
638 cx: &mut ViewContext<Self>,
639 ) -> Self {
640 let prompt_editor = cx.new_view(|cx| {
641 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
642 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
643 let placeholder = match codegen.read(cx).kind() {
644 CodegenKind::Transform { .. } => "Enter transformation prompt…",
645 CodegenKind::Generate { .. } => "Enter generation prompt…",
646 };
647 editor.set_placeholder_text(placeholder, cx);
648 editor
649 });
650 cx.focus_view(&prompt_editor);
651
652 let subscriptions = vec![
653 cx.observe(&codegen, Self::handle_codegen_changed),
654 cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
655 cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
656 ];
657
658 let mut this = Self {
659 id,
660 height_in_lines: 1,
661 prompt_editor,
662 confirmed: false,
663 gutter_dimensions,
664 prompt_history,
665 prompt_history_ix: None,
666 pending_prompt: String::new(),
667 codegen,
668 _subscriptions: subscriptions,
669 };
670 this.count_lines(cx);
671 this
672 }
673
674 fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
675 let height_in_lines = cmp::max(
676 2, // Make the editor at least two lines tall, to account for padding.
677 cmp::min(
678 self.prompt_editor
679 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
680 Self::MAX_LINES as u32,
681 ),
682 ) as u8;
683
684 if height_in_lines != self.height_in_lines {
685 self.height_in_lines = height_in_lines;
686 cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
687 }
688 }
689
690 fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
691 self.count_lines(cx);
692 }
693
694 fn handle_prompt_editor_events(
695 &mut self,
696 _: View<Editor>,
697 event: &EditorEvent,
698 cx: &mut ViewContext<Self>,
699 ) {
700 match event {
701 EditorEvent::Edited => {
702 self.pending_prompt = self.prompt_editor.read(cx).text(cx);
703 cx.notify();
704 }
705 EditorEvent::Blurred => {
706 if !self.confirmed {
707 cx.emit(InlineAssistEditorEvent::Canceled);
708 }
709 }
710 _ => {}
711 }
712 }
713
714 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
715 let is_read_only = !self.codegen.read(cx).idle();
716 self.prompt_editor.update(cx, |editor, cx| {
717 let was_read_only = editor.read_only(cx);
718 if was_read_only != is_read_only {
719 if is_read_only {
720 editor.set_read_only(true);
721 } else {
722 self.confirmed = false;
723 editor.set_read_only(false);
724 }
725 }
726 });
727 cx.notify();
728 }
729
730 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
731 cx.emit(InlineAssistEditorEvent::Canceled);
732 }
733
734 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
735 if self.confirmed {
736 cx.emit(InlineAssistEditorEvent::Dismissed);
737 } else {
738 let prompt = self.prompt_editor.read(cx).text(cx);
739 self.prompt_editor
740 .update(cx, |editor, _cx| editor.set_read_only(true));
741 cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
742 self.confirmed = true;
743 cx.notify();
744 }
745 }
746
747 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
748 if let Some(ix) = self.prompt_history_ix {
749 if ix > 0 {
750 self.prompt_history_ix = Some(ix - 1);
751 let prompt = self.prompt_history[ix - 1].clone();
752 self.set_prompt(&prompt, cx);
753 }
754 } else if !self.prompt_history.is_empty() {
755 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
756 let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
757 self.set_prompt(&prompt, cx);
758 }
759 }
760
761 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
762 if let Some(ix) = self.prompt_history_ix {
763 if ix < self.prompt_history.len() - 1 {
764 self.prompt_history_ix = Some(ix + 1);
765 let prompt = self.prompt_history[ix + 1].clone();
766 self.set_prompt(&prompt, cx);
767 } else {
768 self.prompt_history_ix = None;
769 let pending_prompt = self.pending_prompt.clone();
770 self.set_prompt(&pending_prompt, cx);
771 }
772 }
773 }
774
775 fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
776 self.prompt_editor.update(cx, |editor, cx| {
777 editor.buffer().update(cx, |buffer, cx| {
778 let len = buffer.len(cx);
779 buffer.edit([(0..len, prompt)], None, cx);
780 });
781 });
782 }
783
784 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
785 let settings = ThemeSettings::get_global(cx);
786 let text_style = TextStyle {
787 color: if self.prompt_editor.read(cx).read_only(cx) {
788 cx.theme().colors().text_disabled
789 } else {
790 cx.theme().colors().text
791 },
792 font_family: settings.ui_font.family.clone(),
793 font_features: settings.ui_font.features.clone(),
794 font_size: rems(0.875).into(),
795 font_weight: FontWeight::NORMAL,
796 font_style: FontStyle::Normal,
797 line_height: relative(1.3),
798 background_color: None,
799 underline: None,
800 strikethrough: None,
801 white_space: WhiteSpace::Normal,
802 };
803 EditorElement::new(
804 &self.prompt_editor,
805 EditorStyle {
806 background: cx.theme().colors().editor_background,
807 local_player: cx.theme().players().local(),
808 text: text_style,
809 ..Default::default()
810 },
811 )
812 }
813}
814
/// Bookkeeping for one inline assist that has not been finished yet.
///
/// The type also serves as the key for the background and text highlights the
/// assistant applies to the target editor.
struct PendingInlineAssist {
    /// Editor whose buffer the assist edits.
    editor: WeakView<Editor>,
    /// Prompt block id and view; `None` once the prompt has been hidden.
    inline_assist_editor: Option<(BlockId, View<InlineAssistEditor>)>,
    /// Produces and applies the generated edits.
    codegen: Model<Codegen>,
    /// Keeps this assist's event subscriptions alive; dropped with the assist.
    _subscriptions: Vec<Subscription>,
    /// Workspace used for error toasts, the project name, and assistant-panel context.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is sent ahead of the prompt.
    include_context: bool,
}
823
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation run ended, successfully or not (see [`Codegen::error`]).
    Finished,
    /// The transaction holding the generated edits was undone.
    Undone,
}

/// What an inline assist does to the buffer.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text in `range`.
    Transform { range: Range<Anchor> },
    /// Insert new text at `position`.
    Generate { position: Anchor },
}
835
/// Streams a language-model completion into a [`MultiBuffer`] as an incremental diff.
pub struct Codegen {
    /// Buffer receiving the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot captured in `Codegen::new`; diffs and edit anchors are computed against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this codegen transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges the latest diff hunks kept unchanged; shown faded by the assistant.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction all generated edits are merged into, once the first edit lands.
    transaction_id: Option<TransactionId>,
    /// Error from the last generation run, if it failed.
    error: Option<anyhow::Error>,
    /// The in-flight generation task; assigning a new task drops the previous one.
    generation: Task<()>,
    /// `false` while a generation run is in progress.
    idle: bool,
    /// Telemetry used to report each run's latency and error.
    telemetry: Option<Arc<Telemetry>>,
    /// Keeps the buffer-event subscription alive.
    _subscription: gpui::Subscription,
}

/// Lets `Codegen` emit [`CodegenEvent`]s.
impl EventEmitter<CodegenEvent> for Codegen {}
850
851impl Codegen {
852 pub fn new(
853 buffer: Model<MultiBuffer>,
854 kind: CodegenKind,
855 telemetry: Option<Arc<Telemetry>>,
856 cx: &mut ModelContext<Self>,
857 ) -> Self {
858 let snapshot = buffer.read(cx).snapshot(cx);
859 Self {
860 buffer: buffer.clone(),
861 snapshot,
862 kind,
863 last_equal_ranges: Default::default(),
864 transaction_id: Default::default(),
865 error: Default::default(),
866 idle: true,
867 generation: Task::ready(()),
868 telemetry,
869 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
870 }
871 }
872
873 fn handle_buffer_event(
874 &mut self,
875 _buffer: Model<MultiBuffer>,
876 event: &multi_buffer::Event,
877 cx: &mut ModelContext<Self>,
878 ) {
879 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
880 if self.transaction_id == Some(*transaction_id) {
881 self.transaction_id = None;
882 self.generation = Task::ready(());
883 cx.emit(CodegenEvent::Undone);
884 }
885 }
886 }
887
888 pub fn range(&self) -> Range<Anchor> {
889 match &self.kind {
890 CodegenKind::Transform { range } => range.clone(),
891 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
892 }
893 }
894
895 pub fn kind(&self) -> &CodegenKind {
896 &self.kind
897 }
898
899 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
900 &self.last_equal_ranges
901 }
902
903 pub fn idle(&self) -> bool {
904 self.idle
905 }
906
907 pub fn error(&self) -> Option<&anyhow::Error> {
908 self.error.as_ref()
909 }
910
911 pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
912 let range = self.range();
913 let snapshot = self.snapshot.clone();
914 let selected_text = snapshot
915 .text_for_range(range.start..range.end)
916 .collect::<Rope>();
917
918 let selection_start = range.start.to_point(&snapshot);
919 let suggested_line_indent = snapshot
920 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
921 .into_values()
922 .next()
923 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
924
925 let model_telemetry_id = prompt.model.telemetry_id();
926 let response = CompletionProvider::global(cx).complete(prompt);
927 let telemetry = self.telemetry.clone();
928 self.generation = cx.spawn(|this, mut cx| {
929 async move {
930 let generate = async {
931 let mut edit_start = range.start.to_offset(&snapshot);
932
933 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
934 let diff: Task<anyhow::Result<()>> =
935 cx.background_executor().spawn(async move {
936 let mut response_latency = None;
937 let request_start = Instant::now();
938 let diff = async {
939 let chunks = strip_invalid_spans_from_codeblock(response.await?);
940 futures::pin_mut!(chunks);
941 let mut diff = StreamingDiff::new(selected_text.to_string());
942
943 let mut new_text = String::new();
944 let mut base_indent = None;
945 let mut line_indent = None;
946 let mut first_line = true;
947
948 while let Some(chunk) = chunks.next().await {
949 if response_latency.is_none() {
950 response_latency = Some(request_start.elapsed());
951 }
952 let chunk = chunk?;
953
954 let mut lines = chunk.split('\n').peekable();
955 while let Some(line) = lines.next() {
956 new_text.push_str(line);
957 if line_indent.is_none() {
958 if let Some(non_whitespace_ch_ix) =
959 new_text.find(|ch: char| !ch.is_whitespace())
960 {
961 line_indent = Some(non_whitespace_ch_ix);
962 base_indent = base_indent.or(line_indent);
963
964 let line_indent = line_indent.unwrap();
965 let base_indent = base_indent.unwrap();
966 let indent_delta =
967 line_indent as i32 - base_indent as i32;
968 let mut corrected_indent_len = cmp::max(
969 0,
970 suggested_line_indent.len as i32 + indent_delta,
971 )
972 as usize;
973 if first_line {
974 corrected_indent_len = corrected_indent_len
975 .saturating_sub(
976 selection_start.column as usize,
977 );
978 }
979
980 let indent_char = suggested_line_indent.char();
981 let mut indent_buffer = [0; 4];
982 let indent_str =
983 indent_char.encode_utf8(&mut indent_buffer);
984 new_text.replace_range(
985 ..line_indent,
986 &indent_str.repeat(corrected_indent_len),
987 );
988 }
989 }
990
991 if line_indent.is_some() {
992 hunks_tx.send(diff.push_new(&new_text)).await?;
993 new_text.clear();
994 }
995
996 if lines.peek().is_some() {
997 hunks_tx.send(diff.push_new("\n")).await?;
998 line_indent = None;
999 first_line = false;
1000 }
1001 }
1002 }
1003 hunks_tx.send(diff.push_new(&new_text)).await?;
1004 hunks_tx.send(diff.finish()).await?;
1005
1006 anyhow::Ok(())
1007 };
1008
1009 let result = diff.await;
1010
1011 let error_message =
1012 result.as_ref().err().map(|error| error.to_string());
1013 if let Some(telemetry) = telemetry {
1014 telemetry.report_assistant_event(
1015 None,
1016 telemetry_events::AssistantKind::Inline,
1017 model_telemetry_id,
1018 response_latency,
1019 error_message,
1020 );
1021 }
1022
1023 result?;
1024 Ok(())
1025 });
1026
1027 while let Some(hunks) = hunks_rx.next().await {
1028 this.update(&mut cx, |this, cx| {
1029 this.last_equal_ranges.clear();
1030
1031 let transaction = this.buffer.update(cx, |buffer, cx| {
1032 // Avoid grouping assistant edits with user edits.
1033 buffer.finalize_last_transaction(cx);
1034
1035 buffer.start_transaction(cx);
1036 buffer.edit(
1037 hunks.into_iter().filter_map(|hunk| match hunk {
1038 Hunk::Insert { text } => {
1039 let edit_start = snapshot.anchor_after(edit_start);
1040 Some((edit_start..edit_start, text))
1041 }
1042 Hunk::Remove { len } => {
1043 let edit_end = edit_start + len;
1044 let edit_range = snapshot.anchor_after(edit_start)
1045 ..snapshot.anchor_before(edit_end);
1046 edit_start = edit_end;
1047 Some((edit_range, String::new()))
1048 }
1049 Hunk::Keep { len } => {
1050 let edit_end = edit_start + len;
1051 let edit_range = snapshot.anchor_after(edit_start)
1052 ..snapshot.anchor_before(edit_end);
1053 edit_start = edit_end;
1054 this.last_equal_ranges.push(edit_range);
1055 None
1056 }
1057 }),
1058 None,
1059 cx,
1060 );
1061
1062 buffer.end_transaction(cx)
1063 });
1064
1065 if let Some(transaction) = transaction {
1066 if let Some(first_transaction) = this.transaction_id {
1067 // Group all assistant edits into the first transaction.
1068 this.buffer.update(cx, |buffer, cx| {
1069 buffer.merge_transactions(
1070 transaction,
1071 first_transaction,
1072 cx,
1073 )
1074 });
1075 } else {
1076 this.transaction_id = Some(transaction);
1077 this.buffer.update(cx, |buffer, cx| {
1078 buffer.finalize_last_transaction(cx)
1079 });
1080 }
1081 }
1082
1083 cx.notify();
1084 })?;
1085 }
1086
1087 diff.await?;
1088
1089 anyhow::Ok(())
1090 };
1091
1092 let result = generate.await;
1093 this.update(&mut cx, |this, cx| {
1094 this.last_equal_ranges.clear();
1095 this.idle = true;
1096 if let Err(error) = result {
1097 this.error = Some(error);
1098 }
1099 cx.emit(CodegenEvent::Finished);
1100 cx.notify();
1101 })
1102 .ok();
1103 }
1104 });
1105 self.error.take();
1106 self.idle = false;
1107 cx.notify();
1108 }
1109
1110 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
1111 if let Some(transaction_id) = self.transaction_id {
1112 self.buffer
1113 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
1114 }
1115 }
1116}
1117
1118fn strip_invalid_spans_from_codeblock(
1119 stream: impl Stream<Item = Result<String>>,
1120) -> impl Stream<Item = Result<String>> {
1121 let mut first_line = true;
1122 let mut buffer = String::new();
1123 let mut starts_with_markdown_codeblock = false;
1124 let mut includes_start_or_end_span = false;
1125 stream.filter_map(move |chunk| {
1126 let chunk = match chunk {
1127 Ok(chunk) => chunk,
1128 Err(err) => return future::ready(Some(Err(err))),
1129 };
1130 buffer.push_str(&chunk);
1131
1132 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
1133 includes_start_or_end_span = true;
1134
1135 buffer = buffer
1136 .strip_prefix("<|S|>")
1137 .or_else(|| buffer.strip_prefix("<|S|"))
1138 .unwrap_or(&buffer)
1139 .to_string();
1140 } else if buffer.ends_with("|E|>") {
1141 includes_start_or_end_span = true;
1142 } else if buffer.starts_with("<|")
1143 || buffer.starts_with("<|S")
1144 || buffer.starts_with("<|S|")
1145 || buffer.ends_with('|')
1146 || buffer.ends_with("|E")
1147 || buffer.ends_with("|E|")
1148 {
1149 return future::ready(None);
1150 }
1151
1152 if first_line {
1153 if buffer.is_empty() || buffer == "`" || buffer == "``" {
1154 return future::ready(None);
1155 } else if buffer.starts_with("```") {
1156 starts_with_markdown_codeblock = true;
1157 if let Some(newline_ix) = buffer.find('\n') {
1158 buffer.replace_range(..newline_ix + 1, "");
1159 first_line = false;
1160 } else {
1161 return future::ready(None);
1162 }
1163 }
1164 }
1165
1166 let mut text = buffer.to_string();
1167 if starts_with_markdown_codeblock {
1168 text = text
1169 .strip_suffix("\n```\n")
1170 .or_else(|| text.strip_suffix("\n```"))
1171 .or_else(|| text.strip_suffix("\n``"))
1172 .or_else(|| text.strip_suffix("\n`"))
1173 .or_else(|| text.strip_suffix('\n'))
1174 .unwrap_or(&text)
1175 .to_string();
1176 }
1177
1178 if includes_start_or_end_span {
1179 text = text
1180 .strip_suffix("|E|>")
1181 .or_else(|| text.strip_suffix("E|>"))
1182 .or_else(|| text.strip_prefix("|>"))
1183 .or_else(|| text.strip_prefix('>'))
1184 .unwrap_or(&text)
1185 .to_string();
1186 };
1187
1188 if text.contains('\n') {
1189 first_line = false;
1190 }
1191
1192 let remainder = buffer.split_off(text.len());
1193 let result = if buffer.is_empty() {
1194 None
1195 } else {
1196 Some(Ok(buffer.clone()))
1197 };
1198
1199 buffer = remainder;
1200 future::ready(result)
1201 })
1202}
1203
1204fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1205 ranges.sort_unstable_by(|a, b| {
1206 a.start
1207 .cmp(&b.start, buffer)
1208 .then_with(|| b.end.cmp(&a.end, buffer))
1209 });
1210
1211 let mut ix = 0;
1212 while ix + 1 < ranges.len() {
1213 let b = ranges[ix + 1].clone();
1214 let a = &mut ranges[ix];
1215 if a.end.cmp(&b.start, buffer).is_gt() {
1216 if a.end.cmp(&b.end, buffer).is_lt() {
1217 a.end = b.end;
1218 }
1219 ranges.remove(ix + 1);
1220 } else {
1221 ix += 1;
1222 }
1223 }
1224}
1225
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream;
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use settings::SettingsStore;

    /// Rewrites lines 1..=4 of a small Rust function via `CodegenKind::Transform`, streaming
    /// the replacement in random chunks of 1..=10 bytes, and checks the final buffer text
    /// is indented to match the enclosing function.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates text at a cursor placed after existing indentation and a partial
    /// identifier (`le|`), and checks that the streamed continuation is indented to
    /// match the enclosing function.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates text at a cursor placed inside a whitespace-only line (before the
    /// expected indentation level), and checks that the streamed code is re-indented to
    /// match the enclosing function.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Feeds `strip_invalid_spans_from_codeblock` inputs split into 2-character chunks and
    /// checks that fences and span markers are removed while other text passes through.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        /// Splits `text` into a stream of `Ok` chunks of at most `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Builds a Rust `Language` with a minimal indents query, enough for the autoindent
    /// tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}