1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp},
10 display_map::{
11 BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
12 },
13 scroll::{Autoscroll, AutoscrollStrategy},
14 Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
15 MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
18use gpui::{
19 AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
20 Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
21 ViewContext, WeakView, WhiteSpace, WindowContext,
22};
23use language::{Point, TransactionId};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use rope::Rope;
27use settings::Settings;
28use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
29use theme::ThemeSettings;
30use ui::{prelude::*, Tooltip};
31use workspace::{notifications::NotificationId, Toast, Workspace};
32
33pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
34 cx.set_global(InlineAssistant::new(telemetry));
35}
36
/// Maximum number of confirmed prompts kept in the inline assistant's history;
/// once exceeded, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
38
39pub struct InlineAssistant {
40 next_assist_id: InlineAssistId,
41 pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
42 pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
43 prompt_history: VecDeque<String>,
44 telemetry: Option<Arc<Telemetry>>,
45}
46
47struct EditorPendingAssists {
48 window: AnyWindowHandle,
49 assist_ids: Vec<InlineAssistId>,
50}
51
// Registered as a gpui global: installed by `init` and reached through
// `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
53
impl InlineAssistant {
    /// Creates an assistant with no pending assists and an empty prompt history.
    pub fn new(telemetry: Arc<Telemetry>) -> Self {
        Self {
            next_assist_id: InlineAssistId::default(),
            pending_assists: HashMap::default(),
            pending_assist_ids_by_editor: HashMap::default(),
            prompt_history: VecDeque::default(),
            telemetry: Some(telemetry),
        }
    }

    /// Starts an inline assist at the newest selection of `editor`.
    ///
    /// A non-empty selection is widened to whole lines and becomes a
    /// `CodegenKind::Transform`. If the widened range is empty, a
    /// `CodegenKind::Generate` is created at the cursor instead. A prompt editor is
    /// inserted as a block at the selection head's row, and the assist is recorded
    /// as pending. Selections that span two excerpts are ignored.
    ///
    /// `workspace` is used later to look up the assistant panel and project names
    /// and to show error toasts. `include_conversation` controls whether the active
    /// assistant-panel conversation is sent along with the prompt on confirmation.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        include_conversation: bool,
        cx: &mut WindowContext,
    ) {
        let selection = editor.read(cx).selections.newest_anchor().clone();
        // Only selections contained in a single excerpt are supported.
        if selection.start.excerpt_id != selection.end.excerpt_id {
            return;
        }
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // Extend the selection to the start and the end of the line.
        let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
        if point_selection.end > point_selection.start {
            point_selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            // `end > start` together with `end.column == 0` implies `end.row > start.row`,
            // so this decrement cannot underflow.
            if point_selection.end.column == 0 {
                point_selection.end.row -= 1;
            }
            point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
        }

        // Empty range: generate at the cursor. Otherwise: rewrite the selected lines.
        let codegen_kind = if point_selection.start == point_selection.end {
            CodegenKind::Generate {
                position: snapshot.anchor_after(point_selection.start),
            }
        } else {
            CodegenKind::Transform {
                range: snapshot.anchor_before(point_selection.start)
                    ..snapshot.anchor_after(point_selection.end),
            }
        };

        let inline_assist_id = self.next_assist_id.post_inc();
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                codegen_kind,
                self.telemetry.clone(),
                cx,
            )
        });

        // Shared with the block renderer, which writes the host editor's gutter
        // dimensions here so the prompt can line up with the text.
        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let inline_assist_editor = cx.new_view(|cx| {
            InlineAssistEditor::new(
                inline_assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                codegen.clone(),
                cx,
            )
        });
        // Collapse the selection to its head, then insert the prompt block at the
        // head's row: above it for reversed selections, below it otherwise.
        let block_id = editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([selection.head()..selection.head()])
            });
            editor.insert_blocks(
                [BlockProperties {
                    style: BlockStyle::Sticky,
                    position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
                    height: inline_assist_editor.read(cx).height_in_lines,
                    render: build_inline_assist_editor_renderer(
                        &inline_assist_editor,
                        gutter_dimensions,
                    ),
                    disposition: if selection.reversed {
                        BlockDisposition::Above
                    } else {
                        BlockDisposition::Below
                    },
                }],
                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
                cx,
            )[0]
        });

        self.pending_assists.insert(
            inline_assist_id,
            PendingInlineAssist {
                include_conversation,
                editor: editor.downgrade(),
                inline_assist_editor: Some((block_id, inline_assist_editor.clone())),
                codegen: codegen.clone(),
                workspace,
                _subscriptions: vec![
                    // Route prompt events (confirm / cancel / dismiss / resize) to the assistant.
                    cx.subscribe(&inline_assist_editor, |inline_assist_editor, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_inline_assistant_event(inline_assist_editor, event, cx)
                        })
                    }),
                    // If the user changes the host editor's selection while the prompt is
                    // focused, move focus back to the host editor.
                    cx.subscribe(editor, {
                        let inline_assist_editor = inline_assist_editor.downgrade();
                        move |editor, event, cx| {
                            if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
                                if let EditorEvent::SelectionsChanged { local } = event {
                                    if *local
                                        && inline_assist_editor
                                            .focus_handle(cx)
                                            .contains_focused(cx)
                                    {
                                        cx.focus_view(&editor);
                                    }
                                }
                            }
                        }
                    }),
                    // Refresh the editor's highlights whenever the codegen changes.
                    cx.observe(&codegen, {
                        let editor = editor.downgrade();
                        move |_, cx| {
                            if let Some(editor) = editor.upgrade() {
                                InlineAssistant::update_global(cx, |this, cx| {
                                    this.update_highlights_for_editor(&editor, cx);
                                })
                            }
                        }
                    }),
                    cx.subscribe(&codegen, move |codegen, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| match event {
                            // The generated edits were already undone, so finish without
                            // undoing again.
                            CodegenEvent::Undone => {
                                this.finish_inline_assist(inline_assist_id, false, cx)
                            }
                            CodegenEvent::Finished => {
                                let pending_assist = if let Some(pending_assist) =
                                    this.pending_assists.get(&inline_assist_id)
                                {
                                    pending_assist
                                } else {
                                    return;
                                };

                                let error = codegen
                                    .read(cx)
                                    .error()
                                    .map(|error| format!("Inline assistant error: {}", error));
                                if let Some(error) = error {
                                    // While the prompt is visible, the error is rendered in its
                                    // gutter and the assist stays pending, so the user can edit
                                    // and confirm again. Once the prompt has been hidden, report
                                    // the error as a toast and finish the assist.
                                    if pending_assist.inline_assist_editor.is_none() {
                                        if let Some(workspace) = pending_assist
                                            .workspace
                                            .as_ref()
                                            .and_then(|workspace| workspace.upgrade())
                                        {
                                            workspace.update(cx, |workspace, cx| {
                                                struct InlineAssistantError;

                                                let id = NotificationId::identified::<
                                                    InlineAssistantError,
                                                >(
                                                    inline_assist_id.0
                                                );

                                                workspace.show_toast(Toast::new(id, error), cx);
                                            })
                                        }

                                        this.finish_inline_assist(inline_assist_id, false, cx);
                                    }
                                } else {
                                    // Success: keep the generated edits and retire the assist.
                                    this.finish_inline_assist(inline_assist_id, false, cx);
                                }
                            }
                        })
                    }),
                ],
            },
        );

        self.pending_assist_ids_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorPendingAssists {
                window: cx.window_handle(),
                assist_ids: Vec::new(),
            })
            .assist_ids
            .push(inline_assist_id);
        self.update_highlights_for_editor(editor, cx);
    }

    /// Dispatches an event emitted by a prompt view.
    ///
    /// Confirm starts generation, cancel finishes the assist and undoes its edits,
    /// dismiss hides the prompt while leaving the assist pending, and resize updates
    /// the prompt block's height.
    fn handle_inline_assistant_event(
        &mut self,
        inline_assist_editor: View<InlineAssistEditor>,
        event: &InlineAssistEditorEvent,
        cx: &mut WindowContext,
    ) {
        let assist_id = inline_assist_editor.read(cx).id;
        match event {
            InlineAssistEditorEvent::Confirmed { prompt } => {
                self.confirm_inline_assist(assist_id, prompt, cx);
            }
            InlineAssistEditorEvent::Canceled => {
                self.finish_inline_assist(assist_id, true, cx);
            }
            InlineAssistEditorEvent::Dismissed => {
                self.hide_inline_assist(assist_id, cx);
            }
            InlineAssistEditorEvent::Resized { height_in_lines } => {
                self.resize_inline_assist(assist_id, *height_in_lines, cx);
            }
        }
    }

    /// Cancels, and undoes, the most recently started assist in the focused editor
    /// of the current window.
    ///
    /// Returns `true` if an assist was canceled.
    pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
        for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
            if pending_assists.window == cx.window_handle() {
                if let Some(editor) = editor.upgrade() {
                    if editor.read(cx).is_focused(cx) {
                        if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
                            // Returning immediately ends the borrow of the map before the
                            // iterator could be used again.
                            self.finish_inline_assist(assist_id, true, cx);
                            return true;
                        }
                    }
                }
            }
        }
        false
    }

    /// Ends assist `assist_id`.
    ///
    /// Hides its prompt, forgets its bookkeeping, and refreshes the editor's
    /// highlights. When `undo` is true and the editor is still alive, it also
    /// reverts the codegen's edits. Unknown ids are ignored.
    fn finish_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        undo: bool,
        cx: &mut WindowContext,
    ) {
        self.hide_inline_assist(assist_id, cx);

        if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self
                .pending_assist_ids_by_editor
                .entry(pending_assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let Some(editor) = pending_assist.editor.upgrade() {
                self.update_highlights_for_editor(&editor, cx);

                if undo {
                    pending_assist
                        .codegen
                        .update(cx, |codegen, cx| codegen.undo(cx));
                }
            }
        }
    }

    /// Removes the prompt block of `assist_id`, returning focus to the host editor
    /// if the prompt had it.
    ///
    /// The assist itself stays pending, and its codegen keeps running.
    fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
            if let Some(editor) = pending_assist.editor.upgrade() {
                if let Some((block_id, inline_assist_editor)) =
                    pending_assist.inline_assist_editor.take()
                {
                    editor.update(cx, |editor, cx| {
                        editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
                        if inline_assist_editor.focus_handle(cx).contains_focused(cx) {
                            editor.focus(cx);
                        }
                    });
                }
            }
        }
    }

    /// Replaces the prompt block of `assist_id` with one `height_in_lines` tall.
    fn resize_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        height_in_lines: u8,
        cx: &mut WindowContext,
    ) {
        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
            if let Some(editor) = pending_assist.editor.upgrade() {
                if let Some((block_id, inline_assist_editor)) =
                    pending_assist.inline_assist_editor.as_ref()
                {
                    let gutter_dimensions = inline_assist_editor.read(cx).gutter_dimensions.clone();
                    let mut new_blocks = HashMap::default();
                    new_blocks.insert(
                        *block_id,
                        (
                            Some(height_in_lines),
                            build_inline_assist_editor_renderer(
                                inline_assist_editor,
                                gutter_dimensions,
                            ),
                        ),
                    );
                    editor.update(cx, |editor, cx| {
                        editor
                            .display_map
                            .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
                    });
                }
            }
        }
    }

    /// Submits `user_prompt` for assist `assist_id`.
    ///
    /// Records the prompt in the history and builds a completion request. When
    /// requested, the request is prefixed with the active assistant-panel
    /// conversation's messages. It then starts the codegen. The call does nothing
    /// if the assist or its editor is gone. It finishes the assist, without undo,
    /// if the assist's range does not lie within a single buffer.
    fn confirm_inline_assist(
        &mut self,
        assist_id: InlineAssistId,
        user_prompt: &str,
        cx: &mut WindowContext,
    ) {
        let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
        {
            pending_assist
        } else {
            return;
        };

        let conversation = if pending_assist.include_conversation {
            pending_assist.workspace.as_ref().and_then(|workspace| {
                let workspace = workspace.upgrade()?.read(cx);
                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
                assistant_panel.read(cx).active_conversation(cx)
            })
        } else {
            None
        };

        let editor = if let Some(editor) = pending_assist.editor.upgrade() {
            editor
        } else {
            return;
        };

        // Worktree root names joined with "/", passed to the prompt template.
        let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?;
            Some(
                workspace
                    .read(cx)
                    .project()
                    .read(cx)
                    .worktree_root_names(cx)
                    .collect::<Vec<&str>>()
                    .join("/"),
            )
        });

        // Move this prompt to the newest position in the history, de-duplicated and capped.
        self.prompt_history.retain(|prompt| prompt != user_prompt);
        self.prompt_history.push_back(user_prompt.into());
        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
            self.prompt_history.pop_front();
        }

        // Resolve the assist's range to a single underlying buffer and offset range.
        let codegen = pending_assist.codegen.clone();
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        let range = codegen.read(cx).range();
        let start = snapshot.point_to_buffer_offset(range.start);
        let end = snapshot.point_to_buffer_offset(range.end);
        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
            let (start_buffer, start_buffer_offset) = start;
            let (end_buffer, end_buffer_offset) = end;
            if start_buffer.remote_id() == end_buffer.remote_id() {
                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
            } else {
                self.finish_inline_assist(assist_id, false, cx);
                return;
            }
        } else {
            self.finish_inline_assist(assist_id, false, cx);
            return;
        };

        // Plain text is treated the same as an unknown language.
        let language = buffer.language_at(range.start);
        let language_name = if let Some(language) = language.as_ref() {
            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                None
            } else {
                Some(language.name())
            }
        } else {
            None
        };

        // Higher Temperature increases the randomness of model outputs.
        // If Markdown or No Language is Known, increase the randomness for more creative output
        // If Code, decrease temperature to get more deterministic outputs
        let temperature = if let Some(language) = language_name.clone() {
            if language.as_ref() == "Markdown" {
                1.0
            } else {
                0.5
            }
        } else {
            1.0
        };

        let user_prompt = user_prompt.to_string();

        // Build the prompt text off the main thread.
        let prompt = cx.background_executor().spawn(async move {
            let language_name = language_name.as_deref();
            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
        });

        let mut messages = Vec::new();
        if let Some(conversation) = conversation {
            let request = conversation.read(cx).to_completion_request(cx);
            messages = request.messages;
        }
        let model = CompletionProvider::global(cx).model();

        cx.spawn(|mut cx| async move {
            let prompt = prompt.await?;

            messages.push(LanguageModelRequestMessage {
                role: Role::User,
                content: prompt,
            });

            let request = LanguageModelRequest {
                model,
                messages,
                // NOTE(review): presumably the end marker requested by the prompt
                // template in `generate_content_prompt`; confirm it stays in sync.
                stop: vec!["|END|>".to_string()],
                temperature,
            };

            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }

    /// Recomputes the inline-assist highlights for `editor`.
    ///
    /// The background highlight covers each pending assist's range. The faded text
    /// highlight covers the ranges the diff has kept unchanged so far. Both
    /// highlights are cleared when their range lists are empty.
    fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut background_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let empty_inline_assist_ids = Vec::new();
        let inline_assist_ids = self
            .pending_assist_ids_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_inline_assist_ids, |pending_assists| {
                &pending_assists.assist_ids
            });

        for inline_assist_id in inline_assist_ids {
            if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
                let codegen = pending_assist.codegen.read(cx);
                background_ranges.push(codegen.range());
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut background_ranges, &snapshot);
        merge_ranges(&mut foreground_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            if background_ranges.is_empty() {
                editor.clear_background_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_background::<PendingInlineAssist>(
                    &background_ranges,
                    |theme| theme.editor_active_line_background, // TODO use the appropriate color
                    cx,
                );
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<PendingInlineAssist>(cx);
            } else {
                editor.highlight_text::<PendingInlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }
        });
    }
}
539
540fn build_inline_assist_editor_renderer(
541 editor: &View<InlineAssistEditor>,
542 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
543) -> RenderBlock {
544 let editor = editor.clone();
545 Box::new(move |cx: &mut BlockContext| {
546 *gutter_dimensions.lock() = *cx.gutter_dimensions;
547 editor.clone().into_any_element()
548 })
549}
550
/// Identifier of an inline assist; ids are handed out in increasing order.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one, like a postfix
    /// `id++`.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
561
/// Events emitted by the prompt view for the [`InlineAssistant`] to act on.
enum InlineAssistEditorEvent {
    /// The prompt's desired height changed.
    Resized { height_in_lines: u8 },
    /// The user submitted `prompt`.
    Confirmed { prompt: String },
    /// The user confirmed again after submitting: hide the prompt, keep the assist.
    Dismissed,
    /// The user canceled the assist.
    Canceled,
}
568
569struct InlineAssistEditor {
570 id: InlineAssistId,
571 height_in_lines: u8,
572 prompt_editor: View<Editor>,
573 confirmed: bool,
574 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
575 prompt_history: VecDeque<String>,
576 prompt_history_ix: Option<usize>,
577 pending_prompt: String,
578 codegen: Model<Codegen>,
579 _subscriptions: Vec<Subscription>,
580}
581
// Lets the prompt view emit `InlineAssistEditorEvent`s to its subscribers.
impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
583
584impl Render for InlineAssistEditor {
585 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
586 let gutter_dimensions = *self.gutter_dimensions.lock();
587 let icon_size = IconSize::default();
588 h_flex()
589 .w_full()
590 .py_1p5()
591 .border_y_1()
592 .border_color(cx.theme().colors().border)
593 .bg(cx.theme().colors().editor_background)
594 .on_action(cx.listener(Self::confirm))
595 .on_action(cx.listener(Self::cancel))
596 .on_action(cx.listener(Self::move_up))
597 .on_action(cx.listener(Self::move_down))
598 .child(
599 h_flex()
600 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
601 .pr(gutter_dimensions.fold_area_width())
602 .justify_end()
603 .children(if let Some(error) = self.codegen.read(cx).error() {
604 let error_message = SharedString::from(error.to_string());
605 Some(
606 div()
607 .id("error")
608 .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
609 .child(
610 Icon::new(IconName::XCircle)
611 .size(icon_size)
612 .color(Color::Error),
613 ),
614 )
615 } else {
616 None
617 }),
618 )
619 .child(div().flex_1().child(self.render_prompt_editor(cx)))
620 }
621}
622
// Focusing the prompt view means focusing its embedded prompt editor.
impl FocusableView for InlineAssistEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.prompt_editor.focus_handle(cx)
    }
}
628
629impl InlineAssistEditor {
630 const MAX_LINES: u8 = 8;
631
632 #[allow(clippy::too_many_arguments)]
633 fn new(
634 id: InlineAssistId,
635 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
636 prompt_history: VecDeque<String>,
637 codegen: Model<Codegen>,
638 cx: &mut ViewContext<Self>,
639 ) -> Self {
640 let prompt_editor = cx.new_view(|cx| {
641 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
642 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
643 let placeholder = match codegen.read(cx).kind() {
644 CodegenKind::Transform { .. } => "Enter transformation prompt…",
645 CodegenKind::Generate { .. } => "Enter generation prompt…",
646 };
647 editor.set_placeholder_text(placeholder, cx);
648 editor
649 });
650 cx.focus_view(&prompt_editor);
651
652 let subscriptions = vec![
653 cx.observe(&codegen, Self::handle_codegen_changed),
654 cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
655 cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
656 ];
657
658 let mut this = Self {
659 id,
660 height_in_lines: 1,
661 prompt_editor,
662 confirmed: false,
663 gutter_dimensions,
664 prompt_history,
665 prompt_history_ix: None,
666 pending_prompt: String::new(),
667 codegen,
668 _subscriptions: subscriptions,
669 };
670 this.count_lines(cx);
671 this
672 }
673
674 fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
675 let height_in_lines = cmp::max(
676 2, // Make the editor at least two lines tall, to account for padding.
677 cmp::min(
678 self.prompt_editor
679 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
680 Self::MAX_LINES as u32,
681 ),
682 ) as u8;
683
684 if height_in_lines != self.height_in_lines {
685 self.height_in_lines = height_in_lines;
686 cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
687 }
688 }
689
690 fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
691 self.count_lines(cx);
692 }
693
694 fn handle_prompt_editor_events(
695 &mut self,
696 _: View<Editor>,
697 event: &EditorEvent,
698 cx: &mut ViewContext<Self>,
699 ) {
700 if let EditorEvent::Edited = event {
701 self.pending_prompt = self.prompt_editor.read(cx).text(cx);
702 cx.notify();
703 }
704 }
705
706 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
707 let is_read_only = !self.codegen.read(cx).idle();
708 self.prompt_editor.update(cx, |editor, cx| {
709 let was_read_only = editor.read_only(cx);
710 if was_read_only != is_read_only {
711 if is_read_only {
712 editor.set_read_only(true);
713 } else {
714 self.confirmed = false;
715 editor.set_read_only(false);
716 }
717 }
718 });
719 cx.notify();
720 }
721
722 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
723 cx.emit(InlineAssistEditorEvent::Canceled);
724 }
725
726 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
727 if self.confirmed {
728 cx.emit(InlineAssistEditorEvent::Dismissed);
729 } else {
730 let prompt = self.prompt_editor.read(cx).text(cx);
731 self.prompt_editor
732 .update(cx, |editor, _cx| editor.set_read_only(true));
733 cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
734 self.confirmed = true;
735 cx.notify();
736 }
737 }
738
739 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
740 if let Some(ix) = self.prompt_history_ix {
741 if ix > 0 {
742 self.prompt_history_ix = Some(ix - 1);
743 let prompt = self.prompt_history[ix - 1].clone();
744 self.set_prompt(&prompt, cx);
745 }
746 } else if !self.prompt_history.is_empty() {
747 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
748 let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
749 self.set_prompt(&prompt, cx);
750 }
751 }
752
753 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
754 if let Some(ix) = self.prompt_history_ix {
755 if ix < self.prompt_history.len() - 1 {
756 self.prompt_history_ix = Some(ix + 1);
757 let prompt = self.prompt_history[ix + 1].clone();
758 self.set_prompt(&prompt, cx);
759 } else {
760 self.prompt_history_ix = None;
761 let pending_prompt = self.pending_prompt.clone();
762 self.set_prompt(&pending_prompt, cx);
763 }
764 }
765 }
766
767 fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
768 self.prompt_editor.update(cx, |editor, cx| {
769 editor.buffer().update(cx, |buffer, cx| {
770 let len = buffer.len(cx);
771 buffer.edit([(0..len, prompt)], None, cx);
772 });
773 });
774 }
775
776 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
777 let settings = ThemeSettings::get_global(cx);
778 let text_style = TextStyle {
779 color: if self.prompt_editor.read(cx).read_only(cx) {
780 cx.theme().colors().text_disabled
781 } else {
782 cx.theme().colors().text
783 },
784 font_family: settings.ui_font.family.clone(),
785 font_features: settings.ui_font.features.clone(),
786 font_size: rems(0.875).into(),
787 font_weight: FontWeight::NORMAL,
788 font_style: FontStyle::Normal,
789 line_height: relative(1.3),
790 background_color: None,
791 underline: None,
792 strikethrough: None,
793 white_space: WhiteSpace::Normal,
794 };
795 EditorElement::new(
796 &self.prompt_editor,
797 EditorStyle {
798 background: cx.theme().colors().editor_background,
799 local_player: cx.theme().players().local(),
800 text: text_style,
801 ..Default::default()
802 },
803 )
804 }
805}
806
807struct PendingInlineAssist {
808 editor: WeakView<Editor>,
809 inline_assist_editor: Option<(BlockId, View<InlineAssistEditor>)>,
810 codegen: Model<Codegen>,
811 _subscriptions: Vec<Subscription>,
812 workspace: Option<WeakView<Workspace>>,
813 include_conversation: bool,
814}
815
/// Notifications emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation ended, either successfully or with an error (see `Codegen::error`).
    Finished,
    /// The codegen's buffer transaction was undone.
    Undone,
}
821
822#[derive(Clone)]
823pub enum CodegenKind {
824 Transform { range: Range<Anchor> },
825 Generate { position: Anchor },
826}
827
828pub struct Codegen {
829 buffer: Model<MultiBuffer>,
830 snapshot: MultiBufferSnapshot,
831 kind: CodegenKind,
832 last_equal_ranges: Vec<Range<Anchor>>,
833 transaction_id: Option<TransactionId>,
834 error: Option<anyhow::Error>,
835 generation: Task<()>,
836 idle: bool,
837 telemetry: Option<Arc<Telemetry>>,
838 _subscription: gpui::Subscription,
839}
840
// Lets a `Codegen` model emit `CodegenEvent`s to its subscribers.
impl EventEmitter<CodegenEvent> for Codegen {}
842
impl Codegen {
    /// Creates a codegen for `kind` over `buffer`.
    ///
    /// A buffer snapshot is captured here and reused by `range` and `start`. The
    /// codegen subscribes to the buffer so it can tell when its own transaction is
    /// undone (see `handle_buffer_event`).
    pub fn new(
        buffer: Model<MultiBuffer>,
        kind: CodegenKind,
        telemetry: Option<Arc<Telemetry>>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let snapshot = buffer.read(cx).snapshot(cx);
        Self {
            buffer: buffer.clone(),
            snapshot,
            kind,
            last_equal_ranges: Default::default(),
            transaction_id: Default::default(),
            error: Default::default(),
            idle: true,
            generation: Task::ready(()),
            telemetry,
            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
        }
    }

    /// Handles the case where this codegen's transaction is undone.
    ///
    /// The transaction is forgotten, `generation` is replaced with a completed
    /// task (dropping any in-flight generation), and `CodegenEvent::Undone` is
    /// emitted.
    /// NOTE(review): `idle` and `last_equal_ranges` are not reset here, so an undo
    /// during streaming leaves `idle == false`. The assistant currently finishes
    /// the assist on `Undone`; confirm nothing else relies on `idle` afterwards.
    fn handle_buffer_event(
        &mut self,
        _buffer: Model<MultiBuffer>,
        event: &multi_buffer::Event,
        cx: &mut ModelContext<Self>,
    ) {
        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
            if self.transaction_id == Some(*transaction_id) {
                self.transaction_id = None;
                self.generation = Task::ready(());
                cx.emit(CodegenEvent::Undone);
            }
        }
    }

    /// The buffer range this codegen rewrites.
    ///
    /// For `Generate`, this is an empty range at `position`, whose start is a
    /// left-biased copy of the anchor.
    pub fn range(&self) -> Range<Anchor> {
        match &self.kind {
            CodegenKind::Transform { range } => range.clone(),
            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
        }
    }

    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }

    /// Ranges of original text that the diff kept unchanged in the batch being applied.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }

    /// `false` from `start` until the generation finishes.
    pub fn idle(&self) -> bool {
        self.idle
    }

    /// The error from the most recent generation, if it failed.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }

    /// Requests a completion for `prompt` and streams it into `range()` as a live diff.
    ///
    /// The processing proceeds as follows:
    /// - The response is filtered by `strip_invalid_spans_from_codeblock`.
    /// - Each line is re-indented relative to the suggested indent of the range's
    ///   first line.
    /// - On the background executor, the text is diffed against the original
    ///   selected text with `StreamingDiff`.
    /// - Each batch of hunks is applied as a buffer transaction. All transactions
    ///   are merged into the first one, so `undo` reverts the whole generation.
    ///
    /// Telemetry receives the response latency and any error. On completion, the
    /// codegen becomes idle, stores any error, and emits `CodegenEvent::Finished`.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Indent expected on the first line, used as the base for re-indenting the
        // generated lines. Falls back to the line's current indent.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        let model_telemetry_id = prompt.model.telemetry_id();
        let response = CompletionProvider::global(cx).complete(prompt);
        let telemetry = self.telemetry.clone();
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset in `snapshot` where the next hunk applies. Keep/Remove hunks
                    // advance it past original text; inserts do not.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Small bounded channel, so the diff task cannot run far ahead of
                    // the foreground applying edits.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line not yet pushed to the diff.
                                let mut new_text = String::new();
                                // Leading whitespace of the first non-blank generated line.
                                let mut base_indent = None;
                                // Leading whitespace of the current line, once known.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        // When the line's first non-whitespace character arrives,
                                        // replace its leading whitespace: keep its indent relative
                                        // to `base_indent`, rebased onto the suggested indent.
                                        // NOTE(review): a whitespace-only line is never flushed
                                        // below, so its whitespace stays in `new_text` and is
                                        // counted in the next line's indent. Confirm this is
                                        // intended.
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line is inserted at the selection start
                                                // column, which already provides that much indent.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Leading whitespace is buffered until the indent is known;
                                        // after that, text is streamed as it arrives.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        // Another split item means this chunk contained a
                                        // newline after `line`.
                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks on the foreground as one transaction.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    // Offsets refer to the captured `snapshot`; anchors created
                                    // in it are resolved by the buffer against its current text.
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                // `.ok()` discards a failed update (presumably because the codegen
                // model has already been released).
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        // Clear any error left over from a previous generation.
        self.error.take();
        self.idle = false;
        cx.notify();
    }

    /// Undoes the transaction holding this codegen's edits, if any were applied.
    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
        if let Some(transaction_id) = self.transaction_id {
            self.buffer
                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
        }
    }
}
1109
1110fn strip_invalid_spans_from_codeblock(
1111 stream: impl Stream<Item = Result<String>>,
1112) -> impl Stream<Item = Result<String>> {
1113 let mut first_line = true;
1114 let mut buffer = String::new();
1115 let mut starts_with_markdown_codeblock = false;
1116 let mut includes_start_or_end_span = false;
1117 stream.filter_map(move |chunk| {
1118 let chunk = match chunk {
1119 Ok(chunk) => chunk,
1120 Err(err) => return future::ready(Some(Err(err))),
1121 };
1122 buffer.push_str(&chunk);
1123
1124 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
1125 includes_start_or_end_span = true;
1126
1127 buffer = buffer
1128 .strip_prefix("<|S|>")
1129 .or_else(|| buffer.strip_prefix("<|S|"))
1130 .unwrap_or(&buffer)
1131 .to_string();
1132 } else if buffer.ends_with("|E|>") {
1133 includes_start_or_end_span = true;
1134 } else if buffer.starts_with("<|")
1135 || buffer.starts_with("<|S")
1136 || buffer.starts_with("<|S|")
1137 || buffer.ends_with('|')
1138 || buffer.ends_with("|E")
1139 || buffer.ends_with("|E|")
1140 {
1141 return future::ready(None);
1142 }
1143
1144 if first_line {
1145 if buffer.is_empty() || buffer == "`" || buffer == "``" {
1146 return future::ready(None);
1147 } else if buffer.starts_with("```") {
1148 starts_with_markdown_codeblock = true;
1149 if let Some(newline_ix) = buffer.find('\n') {
1150 buffer.replace_range(..newline_ix + 1, "");
1151 first_line = false;
1152 } else {
1153 return future::ready(None);
1154 }
1155 }
1156 }
1157
1158 let mut text = buffer.to_string();
1159 if starts_with_markdown_codeblock {
1160 text = text
1161 .strip_suffix("\n```\n")
1162 .or_else(|| text.strip_suffix("\n```"))
1163 .or_else(|| text.strip_suffix("\n``"))
1164 .or_else(|| text.strip_suffix("\n`"))
1165 .or_else(|| text.strip_suffix('\n'))
1166 .unwrap_or(&text)
1167 .to_string();
1168 }
1169
1170 if includes_start_or_end_span {
1171 text = text
1172 .strip_suffix("|E|>")
1173 .or_else(|| text.strip_suffix("E|>"))
1174 .or_else(|| text.strip_prefix("|>"))
1175 .or_else(|| text.strip_prefix('>'))
1176 .unwrap_or(&text)
1177 .to_string();
1178 };
1179
1180 if text.contains('\n') {
1181 first_line = false;
1182 }
1183
1184 let remainder = buffer.split_off(text.len());
1185 let result = if buffer.is_empty() {
1186 None
1187 } else {
1188 Some(Ok(buffer.clone()))
1189 };
1190
1191 buffer = remainder;
1192 future::ready(result)
1193 })
1194}
1195
1196fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1197 ranges.sort_unstable_by(|a, b| {
1198 a.start
1199 .cmp(&b.start, buffer)
1200 .then_with(|| b.end.cmp(&a.end, buffer))
1201 });
1202
1203 let mut ix = 0;
1204 while ix + 1 < ranges.len() {
1205 let b = ranges[ix + 1].clone();
1206 let a = &mut ranges[ix];
1207 if a.end.cmp(&b.start, buffer).is_gt() {
1208 if a.end.cmp(&b.end, buffer).is_lt() {
1209 a.end = b.end;
1210 }
1211 ranges.remove(ix + 1);
1212 } else {
1213 ix += 1;
1214 }
1215 }
1216}
1217
// Tests for `Codegen` (streaming a fake completion into a buffer that has a Rust language with an
// indents query) and for `strip_invalid_spans_from_codeblock`.
#[cfg(test)]
mod tests {
 use std::sync::Arc;

 use crate::FakeCompletionProvider;

 use super::*;
 use futures::stream::{self};
 use gpui::{Context, TestAppContext};
 use indoc::indoc;
 use language::{
 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
 Point,
 };
 use rand::prelude::*;
 use serde::Serialize;
 use settings::SettingsStore;

 // NOTE(review): not referenced by any test in this module; possibly unused. Confirm before
 // removing.
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
 pub name: String,
 }

 // Rewrites the range (1, 0)..(4, 5) with `CodegenKind::Transform`, streaming the fake
 // completion in randomly sized chunks, then checks the final buffer text.
 // `iterations = 10` presumably reruns the test with different RNG seeds (i.e. different
 // chunkings). TODO confirm against the `gpui::test` macro.
 #[gpui::test(iterations = 10)]
 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
 let provider = FakeCompletionProvider::default();
 cx.set_global(cx.update(SettingsStore::test));
 cx.set_global(CompletionProvider::Fake(provider.clone()));
 cx.update(language_settings::init);

 let text = indoc! {"
 fn main() {
 let x = 0;
 for _ in 0..10 {
 x += 1;
 }
 }
 "};
 let buffer =
 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
 let range = buffer.read_with(cx, |buffer, cx| {
 let snapshot = buffer.snapshot(cx);
 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
 });
 let codegen = cx.new_model(|cx| {
 Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
 });

 let request = LanguageModelRequest::default();
 codegen.update(cx, |codegen, cx| codegen.start(request, cx));

 let mut new_text = concat!(
 " let mut x = 0;\n",
 " while x < 10 {\n",
 " x += 1;\n",
 " }",
 );
 // Send the response in chunks of 1..=10 bytes (the text is ASCII, so every split lands on
 // a char boundary), running pending background work after each chunk.
 while !new_text.is_empty() {
 let max_len = cmp::min(new_text.len(), 10);
 let len = rng.gen_range(1..=max_len);
 let (chunk, suffix) = new_text.split_at(len);
 provider.send_completion(chunk.into());
 new_text = suffix;
 cx.background_executor.run_until_parked();
 }
 // End the fake completion and let pending work settle before checking the buffer.
 provider.finish_completion();
 cx.background_executor.run_until_parked();

 assert_eq!(
 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
 indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
 );
 }

 // Inserts streamed text at (1, 6) with `CodegenKind::Generate` and checks the final buffer
 // text.
 #[gpui::test(iterations = 10)]
 async fn test_autoindent_when_generating_past_indentation(
 cx: &mut TestAppContext,
 mut rng: StdRng,
 ) {
 let provider = FakeCompletionProvider::default();
 cx.set_global(CompletionProvider::Fake(provider.clone()));
 cx.set_global(cx.update(SettingsStore::test));
 cx.update(language_settings::init);

 let text = indoc! {"
 fn main() {
 le
 }
 "};
 let buffer =
 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
 let position = buffer.read_with(cx, |buffer, cx| {
 let snapshot = buffer.snapshot(cx);
 snapshot.anchor_before(Point::new(1, 6))
 });
 let codegen = cx.new_model(|cx| {
 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
 });

 let request = LanguageModelRequest::default();
 codegen.update(cx, |codegen, cx| codegen.start(request, cx));

 let mut new_text = concat!(
 "t mut x = 0;\n",
 "while x < 10 {\n",
 " x += 1;\n",
 "}", //
 );
 // Send the response in chunks of 1..=10 bytes, running pending background work after each.
 while !new_text.is_empty() {
 let max_len = cmp::min(new_text.len(), 10);
 let len = rng.gen_range(1..=max_len);
 let (chunk, suffix) = new_text.split_at(len);
 provider.send_completion(chunk.into());
 new_text = suffix;
 cx.background_executor.run_until_parked();
 }
 provider.finish_completion();
 cx.background_executor.run_until_parked();

 assert_eq!(
 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
 indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
 );
 }

 // Inserts streamed text at (1, 2) with `CodegenKind::Generate` and checks the final buffer
 // text.
 #[gpui::test(iterations = 10)]
 async fn test_autoindent_when_generating_before_indentation(
 cx: &mut TestAppContext,
 mut rng: StdRng,
 ) {
 let provider = FakeCompletionProvider::default();
 cx.set_global(CompletionProvider::Fake(provider.clone()));
 cx.set_global(cx.update(SettingsStore::test));
 cx.update(language_settings::init);

 let text = concat!(
 "fn main() {\n",
 " \n",
 "}\n" //
 );
 let buffer =
 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
 let position = buffer.read_with(cx, |buffer, cx| {
 let snapshot = buffer.snapshot(cx);
 snapshot.anchor_before(Point::new(1, 2))
 });
 let codegen = cx.new_model(|cx| {
 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
 });

 let request = LanguageModelRequest::default();
 codegen.update(cx, |codegen, cx| codegen.start(request, cx));

 let mut new_text = concat!(
 "let mut x = 0;\n",
 "while x < 10 {\n",
 " x += 1;\n",
 "}", //
 );
 // Send the response in chunks of 1..=10 bytes, running pending background work after each.
 while !new_text.is_empty() {
 let max_len = cmp::min(new_text.len(), 10);
 let len = rng.gen_range(1..=max_len);
 let (chunk, suffix) = new_text.split_at(len);
 provider.send_completion(chunk.into());
 new_text = suffix;
 cx.background_executor.run_until_parked();
 }
 provider.finish_completion();
 cx.background_executor.run_until_parked();

 assert_eq!(
 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
 indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
 );
 }

 // Each case feeds its input in 2-character chunks, so code fences and span markers are split
 // across chunk boundaries.
 #[gpui::test]
 async fn test_strip_invalid_spans_from_codeblock() {
 // Plain text passes through unchanged.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum dolor"
 );
 // An opening fence line is removed.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum dolor"
 );
 // A closing fence is removed too, with or without a trailing newline.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum dolor"
 );
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum dolor"
 );
 // Only the first fence line and the final closing fence are removed; inner fences stay.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks(
 "```html\n```js\nLorem ipsum dolor\n```\n```",
 2
 ))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "```js\nLorem ipsum dolor\n```"
 );
 // Two backticks do not open a fence, so nothing is stripped.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "``\nLorem ipsum dolor\n```"
 );
 // Start and end span markers are removed.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum"
 );

 // The start marker may carry a trailing `>`.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum"
 );

 // Markers inside a code fence are removed along with the fence.
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum"
 );
 assert_eq!(
 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
 .map(|chunk| chunk.unwrap())
 .collect::<String>()
 .await,
 "Lorem ipsum"
 );
 // Splits `text` into consecutive pieces of `size` chars (not bytes; the last piece may be
 // shorter) and yields them as a stream of `Ok` items.
 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
 stream::iter(
 text.chars()
 .collect::<Vec<_>>()
 .chunks(size)
 .map(|chunk| Ok(chunk.iter().collect::<String>()))
 .collect::<Vec<_>>(),
 )
 }
 }

 // A Rust `Language` whose indents query covers call expressions, field expressions, and
 // parenthesized/braced nodes. Used by the auto-indent tests above.
 fn rust_lang() -> Language {
 Language::new(
 LanguageConfig {
 name: "Rust".into(),
 matcher: LanguageMatcher {
 path_suffixes: vec!["rs".to_string()],
 ..Default::default()
 },
 ..Default::default()
 },
 Some(tree_sitter_rust::language()),
 )
 .with_indents_query(
 r#"
 (call_expression) @indent
 (field_expression) @indent
 (_ "(" ")" @end) @indent
 (_ "{" "}" @end) @indent
 "#,
 )
 .unwrap()
 }
}