1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp},
10 display_map::{BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle},
11 scroll::{Autoscroll, AutoscrollStrategy},
12 Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
13 MultiBufferSnapshot, ToOffset, ToPoint,
14};
15use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
16use gpui::{
17 AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
18 Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
19 ViewContext, WeakView, WhiteSpace, WindowContext,
20};
21use language::{Point, TransactionId};
22use multi_buffer::MultiBufferRow;
23use parking_lot::Mutex;
24use rope::Rope;
25use settings::Settings;
26use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
27use theme::ThemeSettings;
28use ui::{prelude::*, Tooltip};
29use workspace::{notifications::NotificationId, Toast, Workspace};
30
31pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
32 cx.set_global(InlineAssistant::new(telemetry));
33}
34
/// Maximum number of prompts retained in `InlineAssistant::prompt_history`.
/// When confirming a prompt pushes the history past this length, the oldest
/// entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
36
/// Global state for inline assists: tracks every assist that has been started
/// and not yet finished, across all editors.
pub struct InlineAssistant {
    /// Id that will be handed to the next assist started via `assist`.
    next_assist_id: InlineAssistId,
    /// Per-assist state, keyed by assist id.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Ids of the pending assists of each editor, together with the window
    /// the first of them was started in.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Previously confirmed prompts, oldest first, capped at
    /// `PROMPT_HISTORY_MAX_LEN` entries.
    prompt_history: VecDeque<String>,
    /// Telemetry handle passed on to every `Codegen` this assistant creates.
    telemetry: Option<Arc<Telemetry>>,
}
44
/// The pending assists belonging to one editor.
struct EditorPendingAssists {
    /// Window handle recorded when the first assist for the editor was registered.
    window: AnyWindowHandle,
    /// Ids of the editor's pending assists, in the order they were started.
    assist_ids: Vec<InlineAssistId>,
}
49
// Marks `InlineAssistant` as a gpui global; it is installed with `cx.set_global`
// in `init` and accessed through `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
51
52impl InlineAssistant {
53 pub fn new(telemetry: Arc<Telemetry>) -> Self {
54 Self {
55 next_assist_id: InlineAssistId::default(),
56 pending_assists: HashMap::default(),
57 pending_assist_ids_by_editor: HashMap::default(),
58 prompt_history: VecDeque::default(),
59 telemetry: Some(telemetry),
60 }
61 }
62
    /// Starts a new inline assist for `editor`'s newest selection.
    ///
    /// Does nothing when the selection starts and ends in different excerpts.
    /// Otherwise this:
    /// - expands a non-empty selection to whole lines and turns it into a
    ///   `CodegenKind::Transform`; an empty selection becomes a
    ///   `CodegenKind::Generate` at the cursor,
    /// - creates the `Codegen` model and the `InlineAssistEditor` prompt view,
    /// - collapses the editor selection to its head and inserts the prompt as a
    ///   block at the head's row,
    /// - registers the assist, with its subscriptions, in `pending_assists` and
    ///   `pending_assist_ids_by_editor`, and refreshes the editor highlights.
    ///
    /// `workspace` is used later to show error toasts and to look up the project
    /// name and assistant panel; `include_conversation` controls whether the
    /// active conversation is sent along when the prompt is confirmed.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        include_conversation: bool,
        cx: &mut WindowContext,
    ) {
        let selection = editor.read(cx).selections.newest_anchor().clone();
        // A selection spanning multiple excerpts is not supported.
        if selection.start.excerpt_id != selection.end.excerpt_id {
            return;
        }
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // Extend the selection to the start and the end of the line.
        let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
        if point_selection.end > point_selection.start {
            point_selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            if point_selection.end.column == 0 {
                point_selection.end.row -= 1;
            }
            point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
        }

        // An empty selection means "generate at the cursor"; anything else is
        // a transformation of the selected lines.
        let codegen_kind = if point_selection.start == point_selection.end {
            CodegenKind::Generate {
                position: snapshot.anchor_after(point_selection.start),
            }
        } else {
            CodegenKind::Transform {
                range: snapshot.anchor_before(point_selection.start)
                    ..snapshot.anchor_after(point_selection.end),
            }
        };

        let inline_assist_id = self.next_assist_id.post_inc();
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                codegen_kind,
                self.telemetry.clone(),
                cx,
            )
        });

        // Shared between the block's render callback (writer) and the prompt
        // view (reader) so the prompt can align itself with the editor gutter.
        let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
        let inline_assistant = cx.new_view(|cx| {
            InlineAssistEditor::new(
                inline_assist_id,
                measurements.clone(),
                self.prompt_history.clone(),
                codegen.clone(),
                cx,
            )
        });
        let block_id = editor.update(cx, |editor, cx| {
            // Collapse the selection to its head.
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([selection.head()..selection.head()])
            });
            editor.insert_blocks(
                [BlockProperties {
                    style: BlockStyle::Flex,
                    position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
                    height: 2,
                    render: Box::new({
                        let inline_assistant = inline_assistant.clone();
                        move |cx: &mut BlockContext| {
                            // Publish the current gutter dimensions for the prompt view.
                            *measurements.lock() = *cx.gutter_dimensions;
                            inline_assistant.clone().into_any_element()
                        }
                    }),
                    // Place the prompt on the side of the selection where the head is.
                    disposition: if selection.reversed {
                        BlockDisposition::Above
                    } else {
                        BlockDisposition::Below
                    },
                }],
                Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
                cx,
            )[0]
        });

        self.pending_assists.insert(
            inline_assist_id,
            PendingInlineAssist {
                include_conversation,
                editor: editor.downgrade(),
                inline_assistant: Some((block_id, inline_assistant.clone())),
                codegen: codegen.clone(),
                workspace,
                _subscriptions: vec![
                    // Forward prompt view events (confirm / cancel / dismiss).
                    cx.subscribe(&inline_assistant, |inline_assistant, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_inline_assistant_event(inline_assistant, event, cx)
                        })
                    }),
                    // When the host editor's selection changes locally while the
                    // prompt has focus, move focus back to the host editor.
                    cx.subscribe(editor, {
                        let inline_assistant = inline_assistant.downgrade();
                        move |editor, event, cx| {
                            if let Some(inline_assistant) = inline_assistant.upgrade() {
                                if let EditorEvent::SelectionsChanged { local } = event {
                                    if *local
                                        && inline_assistant.focus_handle(cx).contains_focused(cx)
                                    {
                                        cx.focus_view(&editor);
                                    }
                                }
                            }
                        }
                    }),
                    // Recompute highlights whenever the codegen notifies.
                    cx.observe(&codegen, {
                        let editor = editor.downgrade();
                        move |_, cx| {
                            if let Some(editor) = editor.upgrade() {
                                InlineAssistant::update_global(cx, |this, cx| {
                                    this.update_highlights_for_editor(&editor, cx);
                                })
                            }
                        }
                    }),
                    cx.subscribe(&codegen, move |codegen, event, cx| {
                        InlineAssistant::update_global(cx, |this, cx| match event {
                            // The transaction was already undone, so finish
                            // without undoing again.
                            CodegenEvent::Undone => {
                                this.finish_inline_assist(inline_assist_id, false, cx)
                            }
                            CodegenEvent::Finished => {
                                let pending_assist = if let Some(pending_assist) =
                                    this.pending_assists.get(&inline_assist_id)
                                {
                                    pending_assist
                                } else {
                                    return;
                                };

                                let error = codegen
                                    .read(cx)
                                    .error()
                                    .map(|error| format!("Inline assistant error: {}", error));
                                if let Some(error) = error {
                                    // With the prompt block still visible the assist is
                                    // left pending (the prompt view renders the error
                                    // itself). Only when the prompt has been hidden is
                                    // the error surfaced as a toast and the assist finished.
                                    if pending_assist.inline_assistant.is_none() {
                                        if let Some(workspace) = pending_assist
                                            .workspace
                                            .as_ref()
                                            .and_then(|workspace| workspace.upgrade())
                                        {
                                            workspace.update(cx, |workspace, cx| {
                                                struct InlineAssistantError;

                                                let id = NotificationId::identified::<
                                                    InlineAssistantError,
                                                >(
                                                    inline_assist_id.0
                                                );

                                                workspace.show_toast(Toast::new(id, error), cx);
                                            })
                                        }

                                        this.finish_inline_assist(inline_assist_id, false, cx);
                                    }
                                } else {
                                    // Successful generation: finish and keep the edits.
                                    this.finish_inline_assist(inline_assist_id, false, cx);
                                }
                            }
                        })
                    }),
                ],
            },
        );

        // Index the assist by editor; the window handle is recorded only when
        // the editor's entry is first created.
        self.pending_assist_ids_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorPendingAssists {
                window: cx.window_handle(),
                assist_ids: Vec::new(),
            })
            .assist_ids
            .push(inline_assist_id);
        self.update_highlights_for_editor(editor, cx);
    }
243
244 fn handle_inline_assistant_event(
245 &mut self,
246 inline_assistant: View<InlineAssistEditor>,
247 event: &InlineAssistEditorEvent,
248 cx: &mut WindowContext,
249 ) {
250 let assist_id = inline_assistant.read(cx).id;
251 match event {
252 InlineAssistEditorEvent::Confirmed { prompt } => {
253 self.confirm_inline_assist(assist_id, prompt, cx);
254 }
255 InlineAssistEditorEvent::Canceled => {
256 self.finish_inline_assist(assist_id, true, cx);
257 }
258 InlineAssistEditorEvent::Dismissed => {
259 self.hide_inline_assist(assist_id, cx);
260 }
261 }
262 }
263
264 pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
265 for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
266 if pending_assists.window == cx.window_handle() {
267 if let Some(editor) = editor.upgrade() {
268 if editor.read(cx).is_focused(cx) {
269 if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
270 self.finish_inline_assist(assist_id, true, cx);
271 return true;
272 }
273 }
274 }
275 }
276 }
277 false
278 }
279
280 fn finish_inline_assist(
281 &mut self,
282 assist_id: InlineAssistId,
283 undo: bool,
284 cx: &mut WindowContext,
285 ) {
286 self.hide_inline_assist(assist_id, cx);
287
288 if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
289 if let hash_map::Entry::Occupied(mut entry) = self
290 .pending_assist_ids_by_editor
291 .entry(pending_assist.editor.clone())
292 {
293 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
294 if entry.get().assist_ids.is_empty() {
295 entry.remove();
296 }
297 }
298
299 if let Some(editor) = pending_assist.editor.upgrade() {
300 self.update_highlights_for_editor(&editor, cx);
301
302 if undo {
303 pending_assist
304 .codegen
305 .update(cx, |codegen, cx| codegen.undo(cx));
306 }
307 }
308 }
309 }
310
311 fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
312 if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
313 if let Some(editor) = pending_assist.editor.upgrade() {
314 if let Some((block_id, inline_assistant)) = pending_assist.inline_assistant.take() {
315 editor.update(cx, |editor, cx| {
316 editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
317 if inline_assistant.focus_handle(cx).contains_focused(cx) {
318 editor.focus(cx);
319 }
320 });
321 }
322 }
323 }
324 }
325
326 fn confirm_inline_assist(
327 &mut self,
328 assist_id: InlineAssistId,
329 user_prompt: &str,
330 cx: &mut WindowContext,
331 ) {
332 let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
333 {
334 pending_assist
335 } else {
336 return;
337 };
338
339 let conversation = if pending_assist.include_conversation {
340 pending_assist.workspace.as_ref().and_then(|workspace| {
341 let workspace = workspace.upgrade()?.read(cx);
342 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
343 assistant_panel.read(cx).active_conversation(cx)
344 })
345 } else {
346 None
347 };
348
349 let editor = if let Some(editor) = pending_assist.editor.upgrade() {
350 editor
351 } else {
352 return;
353 };
354
355 let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
356 let workspace = workspace.upgrade()?;
357 Some(
358 workspace
359 .read(cx)
360 .project()
361 .read(cx)
362 .worktree_root_names(cx)
363 .collect::<Vec<&str>>()
364 .join("/"),
365 )
366 });
367
368 self.prompt_history.retain(|prompt| prompt != user_prompt);
369 self.prompt_history.push_back(user_prompt.into());
370 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
371 self.prompt_history.pop_front();
372 }
373
374 let codegen = pending_assist.codegen.clone();
375 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
376 let range = codegen.read(cx).range();
377 let start = snapshot.point_to_buffer_offset(range.start);
378 let end = snapshot.point_to_buffer_offset(range.end);
379 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
380 let (start_buffer, start_buffer_offset) = start;
381 let (end_buffer, end_buffer_offset) = end;
382 if start_buffer.remote_id() == end_buffer.remote_id() {
383 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
384 } else {
385 self.finish_inline_assist(assist_id, false, cx);
386 return;
387 }
388 } else {
389 self.finish_inline_assist(assist_id, false, cx);
390 return;
391 };
392
393 let language = buffer.language_at(range.start);
394 let language_name = if let Some(language) = language.as_ref() {
395 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
396 None
397 } else {
398 Some(language.name())
399 }
400 } else {
401 None
402 };
403
404 // Higher Temperature increases the randomness of model outputs.
405 // If Markdown or No Language is Known, increase the randomness for more creative output
406 // If Code, decrease temperature to get more deterministic outputs
407 let temperature = if let Some(language) = language_name.clone() {
408 if language.as_ref() == "Markdown" {
409 1.0
410 } else {
411 0.5
412 }
413 } else {
414 1.0
415 };
416
417 let user_prompt = user_prompt.to_string();
418
419 let prompt = cx.background_executor().spawn(async move {
420 let language_name = language_name.as_deref();
421 generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
422 });
423
424 let mut messages = Vec::new();
425 if let Some(conversation) = conversation {
426 let request = conversation.read(cx).to_completion_request(cx);
427 messages = request.messages;
428 }
429 let model = CompletionProvider::global(cx).model();
430
431 cx.spawn(|mut cx| async move {
432 let prompt = prompt.await?;
433
434 messages.push(LanguageModelRequestMessage {
435 role: Role::User,
436 content: prompt,
437 });
438
439 let request = LanguageModelRequest {
440 model,
441 messages,
442 stop: vec!["|END|>".to_string()],
443 temperature,
444 };
445
446 codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
447 anyhow::Ok(())
448 })
449 .detach_and_log_err(cx);
450 }
451
452 fn update_highlights_for_editor(&self, editor: &View<Editor>, cx: &mut WindowContext) {
453 let mut background_ranges = Vec::new();
454 let mut foreground_ranges = Vec::new();
455 let empty_inline_assist_ids = Vec::new();
456 let inline_assist_ids = self
457 .pending_assist_ids_by_editor
458 .get(&editor.downgrade())
459 .map_or(&empty_inline_assist_ids, |pending_assists| {
460 &pending_assists.assist_ids
461 });
462
463 for inline_assist_id in inline_assist_ids {
464 if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
465 let codegen = pending_assist.codegen.read(cx);
466 background_ranges.push(codegen.range());
467 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
468 }
469 }
470
471 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
472 merge_ranges(&mut background_ranges, &snapshot);
473 merge_ranges(&mut foreground_ranges, &snapshot);
474 editor.update(cx, |editor, cx| {
475 if background_ranges.is_empty() {
476 editor.clear_background_highlights::<PendingInlineAssist>(cx);
477 } else {
478 editor.highlight_background::<PendingInlineAssist>(
479 &background_ranges,
480 |theme| theme.editor_active_line_background, // TODO use the appropriate color
481 cx,
482 );
483 }
484
485 if foreground_ranges.is_empty() {
486 editor.clear_highlights::<PendingInlineAssist>(cx);
487 } else {
488 editor.highlight_text::<PendingInlineAssist>(
489 foreground_ranges,
490 HighlightStyle {
491 fade_out: Some(0.6),
492 ..Default::default()
493 },
494 cx,
495 );
496 }
497 });
498 }
499}
500
/// Identifier of a single inline assist, allocated sequentially through
/// [`InlineAssistId::post_inc`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let current = *self;
        *self = InlineAssistId(current.0 + 1);
        current
    }
}
511
/// Events emitted by the prompt view (`InlineAssistEditor`).
enum InlineAssistEditorEvent {
    /// The prompt was confirmed; carries the prompt text at that moment.
    Confirmed { prompt: String },
    /// The cancel action was triggered on the prompt.
    Canceled,
    /// Confirm was triggered again after the prompt had already been confirmed.
    Dismissed,
}
517
/// The prompt view rendered as a block inside the host editor for one assist.
struct InlineAssistEditor {
    /// Id of the assist this prompt belongs to.
    id: InlineAssistId,
    /// Single-line editor the prompt is typed into.
    prompt_editor: View<Editor>,
    /// Set once the prompt has been confirmed; reset when the prompt editor
    /// becomes editable again after the codegen returns to idle.
    confirmed: bool,
    /// Gutter dimensions of the host editor, written by the block's render
    /// callback and read in `render` to align the prompt.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Copy of the assistant's prompt history taken when the view was created.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` of the entry currently shown, or `None`
    /// when the pending prompt is shown.
    prompt_history_ix: Option<usize>,
    /// Prompt text restored when navigating down past the end of the history.
    pending_prompt: String,
    /// Codegen whose state (idle / error) this view reflects.
    codegen: Model<Codegen>,
    /// Subscriptions created in `new` (codegen observer, prompt editor events).
    _subscriptions: Vec<Subscription>,
}
529
// Allows `InlineAssistEditor` to emit `InlineAssistEditorEvent`s via `cx.emit`.
impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
531
532impl Render for InlineAssistEditor {
533 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
534 let gutter_dimensions = *self.gutter_dimensions.lock();
535 let icon_size = IconSize::default();
536 h_flex()
537 .w_full()
538 .py_2()
539 .border_y_1()
540 .border_color(cx.theme().colors().border)
541 .bg(cx.theme().colors().editor_background)
542 .on_action(cx.listener(Self::confirm))
543 .on_action(cx.listener(Self::cancel))
544 .on_action(cx.listener(Self::move_up))
545 .on_action(cx.listener(Self::move_down))
546 .child(
547 h_flex()
548 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
549 .pr(gutter_dimensions.fold_area_width())
550 .justify_end()
551 .children(if let Some(error) = self.codegen.read(cx).error() {
552 let error_message = SharedString::from(error.to_string());
553 Some(
554 div()
555 .id("error")
556 .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
557 .child(
558 Icon::new(IconName::XCircle)
559 .size(icon_size)
560 .color(Color::Error),
561 ),
562 )
563 } else {
564 None
565 }),
566 )
567 .child(h_flex().flex_1().child(self.render_prompt_editor(cx)))
568 }
569}
570
impl FocusableView for InlineAssistEditor {
    /// The view's focus handle is the focus handle of its prompt editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.prompt_editor.focus_handle(cx)
    }
}
576
577impl InlineAssistEditor {
578 #[allow(clippy::too_many_arguments)]
579 fn new(
580 id: InlineAssistId,
581 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
582 prompt_history: VecDeque<String>,
583 codegen: Model<Codegen>,
584 cx: &mut ViewContext<Self>,
585 ) -> Self {
586 let prompt_editor = cx.new_view(|cx| {
587 let mut editor = Editor::single_line(cx);
588 let placeholder = match codegen.read(cx).kind() {
589 CodegenKind::Transform { .. } => "Enter transformation prompt…",
590 CodegenKind::Generate { .. } => "Enter generation prompt…",
591 };
592 editor.set_placeholder_text(placeholder, cx);
593 editor
594 });
595 cx.focus_view(&prompt_editor);
596
597 let subscriptions = vec![
598 cx.observe(&codegen, Self::handle_codegen_changed),
599 cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
600 ];
601
602 Self {
603 id,
604 prompt_editor,
605 confirmed: false,
606 gutter_dimensions,
607 prompt_history,
608 prompt_history_ix: None,
609 pending_prompt: String::new(),
610 codegen,
611 _subscriptions: subscriptions,
612 }
613 }
614
615 fn handle_prompt_editor_events(
616 &mut self,
617 _: View<Editor>,
618 event: &EditorEvent,
619 cx: &mut ViewContext<Self>,
620 ) {
621 if let EditorEvent::Edited = event {
622 self.pending_prompt = self.prompt_editor.read(cx).text(cx);
623 cx.notify();
624 }
625 }
626
627 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
628 let is_read_only = !self.codegen.read(cx).idle();
629 self.prompt_editor.update(cx, |editor, cx| {
630 let was_read_only = editor.read_only(cx);
631 if was_read_only != is_read_only {
632 if is_read_only {
633 editor.set_read_only(true);
634 } else {
635 self.confirmed = false;
636 editor.set_read_only(false);
637 }
638 }
639 });
640 cx.notify();
641 }
642
    /// Action handler for `editor::actions::Cancel`: emits
    /// `InlineAssistEditorEvent::Canceled`.
    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
        cx.emit(InlineAssistEditorEvent::Canceled);
    }
646
647 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
648 if self.confirmed {
649 cx.emit(InlineAssistEditorEvent::Dismissed);
650 } else {
651 let prompt = self.prompt_editor.read(cx).text(cx);
652 self.prompt_editor
653 .update(cx, |editor, _cx| editor.set_read_only(true));
654 cx.emit(InlineAssistEditorEvent::Confirmed { prompt });
655 self.confirmed = true;
656 cx.notify();
657 }
658 }
659
660 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
661 if let Some(ix) = self.prompt_history_ix {
662 if ix > 0 {
663 self.prompt_history_ix = Some(ix - 1);
664 let prompt = self.prompt_history[ix - 1].clone();
665 self.set_prompt(&prompt, cx);
666 }
667 } else if !self.prompt_history.is_empty() {
668 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
669 let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
670 self.set_prompt(&prompt, cx);
671 }
672 }
673
674 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
675 if let Some(ix) = self.prompt_history_ix {
676 if ix < self.prompt_history.len() - 1 {
677 self.prompt_history_ix = Some(ix + 1);
678 let prompt = self.prompt_history[ix + 1].clone();
679 self.set_prompt(&prompt, cx);
680 } else {
681 self.prompt_history_ix = None;
682 let pending_prompt = self.pending_prompt.clone();
683 self.set_prompt(&pending_prompt, cx);
684 }
685 }
686 }
687
688 fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
689 self.prompt_editor.update(cx, |editor, cx| {
690 editor.buffer().update(cx, |buffer, cx| {
691 let len = buffer.len(cx);
692 buffer.edit([(0..len, prompt)], None, cx);
693 });
694 });
695 }
696
697 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
698 let settings = ThemeSettings::get_global(cx);
699 let text_style = TextStyle {
700 color: if self.prompt_editor.read(cx).read_only(cx) {
701 cx.theme().colors().text_disabled
702 } else {
703 cx.theme().colors().text
704 },
705 font_family: settings.ui_font.family.clone(),
706 font_features: settings.ui_font.features.clone(),
707 font_size: rems(0.875).into(),
708 font_weight: FontWeight::NORMAL,
709 font_style: FontStyle::Normal,
710 line_height: relative(1.3),
711 background_color: None,
712 underline: None,
713 strikethrough: None,
714 white_space: WhiteSpace::Normal,
715 };
716 EditorElement::new(
717 &self.prompt_editor,
718 EditorStyle {
719 background: cx.theme().colors().editor_background,
720 local_player: cx.theme().players().local(),
721 text: text_style,
722 ..Default::default()
723 },
724 )
725 }
726}
727
/// State kept by `InlineAssistant` for an assist that has been started and
/// not yet finished.
struct PendingInlineAssist {
    /// Editor the assist operates on.
    editor: WeakView<Editor>,
    /// Block id and prompt view while the prompt is shown in the editor;
    /// `None` once the prompt has been hidden.
    inline_assistant: Option<(BlockId, View<InlineAssistEditor>)>,
    /// Codegen that produces and applies the edits.
    codegen: Model<Codegen>,
    /// Subscriptions created in `InlineAssistant::assist` for this assist.
    _subscriptions: Vec<Subscription>,
    /// Workspace used for error toasts, the project name and the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the active conversation's messages are sent with the prompt.
    include_conversation: bool,
}
736
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation ended, successfully or with an error (see `Codegen::error`).
    Finished,
    /// The transaction holding the codegen's edits was undone in the buffer.
    Undone,
}
742
/// What a [`Codegen`] operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text inside `range`.
    Transform { range: Range<Anchor> },
    /// Generate new text at `position`.
    Generate { position: Anchor },
}
748
/// Streams a model completion into a multi-buffer as a sequence of diffed edits.
pub struct Codegen {
    /// Multi-buffer the generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` taken when the codegen was created; ranges and
    /// offsets used during generation are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this codegen transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges the most recently applied batch of hunks kept unchanged;
    /// cleared before each batch and when the generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction created by this codegen; later transactions are
    /// merged into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error of the last generation, if it failed.
    error: Option<anyhow::Error>,
    /// Task driving the current generation; replaced by a ready task on undo.
    generation: Task<()>,
    /// `false` between `start` and the end of the generation.
    idle: bool,
    /// Telemetry handle used to report the outcome of each generation.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer`'s events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
}
761
// Allows `Codegen` to emit `CodegenEvent`s via `cx.emit`.
impl EventEmitter<CodegenEvent> for Codegen {}
763
764impl Codegen {
765 pub fn new(
766 buffer: Model<MultiBuffer>,
767 kind: CodegenKind,
768 telemetry: Option<Arc<Telemetry>>,
769 cx: &mut ModelContext<Self>,
770 ) -> Self {
771 let snapshot = buffer.read(cx).snapshot(cx);
772 Self {
773 buffer: buffer.clone(),
774 snapshot,
775 kind,
776 last_equal_ranges: Default::default(),
777 transaction_id: Default::default(),
778 error: Default::default(),
779 idle: true,
780 generation: Task::ready(()),
781 telemetry,
782 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
783 }
784 }
785
786 fn handle_buffer_event(
787 &mut self,
788 _buffer: Model<MultiBuffer>,
789 event: &multi_buffer::Event,
790 cx: &mut ModelContext<Self>,
791 ) {
792 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
793 if self.transaction_id == Some(*transaction_id) {
794 self.transaction_id = None;
795 self.generation = Task::ready(());
796 cx.emit(CodegenEvent::Undone);
797 }
798 }
799 }
800
    /// Returns the range this codegen operates on: the transform range, or
    /// for a generation the range from the left-biased position to the
    /// position itself.
    pub fn range(&self) -> Range<Anchor> {
        match &self.kind {
            CodegenKind::Transform { range } => range.clone(),
            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
        }
    }
807
    /// Returns whether this codegen transforms a range or generates at a position.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
811
    /// Returns the ranges the most recently applied batch of hunks left unchanged.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }
815
    /// Returns `true` when no generation is in progress.
    pub fn idle(&self) -> bool {
        self.idle
    }
819
    /// Returns the error of the last generation, if it failed.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
823
    /// Starts a generation for `prompt` and streams the result into the buffer.
    ///
    /// The completion is requested from the global `CompletionProvider`. A
    /// background task cleans the streamed chunks with
    /// `strip_invalid_spans_from_codeblock`, re-indents each generated line,
    /// diffs the text against the originally selected text and sends the
    /// resulting hunks over a channel. The spawned foreground task applies
    /// each batch of hunks to the buffer, folding all of its transactions
    /// into the first one.
    ///
    /// When the generation ends, `idle` becomes `true`, any error is stored in
    /// `error`, and `CodegenEvent::Finished` is emitted. Before returning,
    /// this method clears the previous error and marks the codegen as busy.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        // All offsets and anchors below are resolved against the snapshot
        // captured when the codegen was created.
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        // Indentation for the first selected row: the suggested indent if one
        // is available, otherwise the row's current indent.
        let selection_start = range.start.to_point(&snapshot);
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        let model_telemetry_id = prompt.model.telemetry_id();
        let response = CompletionProvider::global(cx).complete(prompt);
        let telemetry = self.telemetry.clone();
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset (in `snapshot`) up to which hunks have been consumed.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line not yet pushed into the diff.
                                let mut new_text = String::new();
                                // Indent of the first non-blank generated line; the
                                // indents of later lines are taken relative to it.
                                let mut base_indent = None;
                                // Indent of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first streamed item.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                // Keep the line's indent relative to the
                                                // base line, rebased onto the indent
                                                // suggested for the selection.
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // On the first line, the columns before the
                                                // selection start are subtracted from the
                                                // indent to insert.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                // Replace the generated leading whitespace
                                                // with the corrected indent.
                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Text is only pushed once the line's indent is
                                        // known; whitespace-only prefixes stay buffered.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so this piece ended
                                        // with a newline in the chunk.
                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                // Flush whatever is still buffered, then finish the diff.
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the outcome, including any error message.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insertions happen at the current position and
                                        // do not advance `edit_start`.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Removals delete `len` units of the original text.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept text is not edited; its range is recorded
                                        // in `last_equal_ranges`.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // Remember the first transaction; later ones are
                                    // merged into it.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // Surface any error produced by the background diff task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        // Clear any previous error and mark the codegen as busy.
        self.error.take();
        self.idle = false;
        cx.notify();
    }
1022
    /// Undoes the transaction holding this codegen's edits, if there is one.
    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
        if let Some(transaction_id) = self.transaction_id {
            self.buffer
                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
        }
    }
1029}
1030
/// Filters a stream of completion chunks, removing wrapper text from the output:
/// a leading `<|S|>` / `<|S|` marker, a trailing `|E|>` marker, and a Markdown
/// code fence (an opening line starting with three backticks, plus the closing
/// fence) when the response starts with one.
///
/// Chunks are accumulated in an internal buffer. Text that might still turn
/// out to be part of a marker or fence is held back until later chunks
/// disambiguate it; everything else is emitted. Errors from the input stream
/// are passed through unchanged.
///
/// NOTE(review): text still held back when the input stream ends is never
/// emitted by this function — confirm that is acceptable for callers.
fn strip_invalid_spans_from_codeblock(
    stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
    // True until the first line of the response has been dealt with.
    let mut first_line = true;
    // Text received but not yet emitted.
    let mut buffer = String::new();
    let mut starts_with_markdown_codeblock = false;
    let mut includes_start_or_end_span = false;
    stream.filter_map(move |chunk| {
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(err) => return future::ready(Some(Err(err))),
        };
        buffer.push_str(&chunk);

        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
            // A complete start marker is at the front of the buffer: drop it.
            includes_start_or_end_span = true;

            buffer = buffer
                .strip_prefix("<|S|>")
                .or_else(|| buffer.strip_prefix("<|S|"))
                .unwrap_or(&buffer)
                .to_string();
        } else if buffer.ends_with("|E|>") {
            includes_start_or_end_span = true;
        } else if buffer.starts_with("<|")
            || buffer.starts_with("<|S")
            || buffer.starts_with("<|S|")
            || buffer.ends_with('|')
            || buffer.ends_with("|E")
            || buffer.ends_with("|E|")
        {
            // The buffer may begin or end with a partial marker: wait for more text.
            return future::ready(None);
        }

        if first_line {
            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                // Possibly the beginning of an opening code fence: wait.
                return future::ready(None);
            } else if buffer.starts_with("```") {
                starts_with_markdown_codeblock = true;
                if let Some(newline_ix) = buffer.find('\n') {
                    // Drop the whole opening fence line, including its newline.
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    // The fence line is not complete yet.
                    return future::ready(None);
                }
            }
        }

        // `text` is the part of the buffer considered safe to emit now.
        let mut text = buffer.to_string();
        if starts_with_markdown_codeblock {
            // Hold back a trailing newline and any (partial) closing fence.
            text = text
                .strip_suffix("\n```\n")
                .or_else(|| text.strip_suffix("\n```"))
                .or_else(|| text.strip_suffix("\n``"))
                .or_else(|| text.strip_suffix("\n`"))
                .or_else(|| text.strip_suffix('\n'))
                .unwrap_or(&text)
                .to_string();
        }

        if includes_start_or_end_span {
            // NOTE(review): the two `strip_prefix` arms shorten `text` from the
            // front, while only `text.len()` is used below to split `buffer`, so
            // in those cases the emitted prefix still contains the stripped
            // characters — confirm this is intended.
            text = text
                .strip_suffix("|E|>")
                .or_else(|| text.strip_suffix("E|>"))
                .or_else(|| text.strip_prefix("|>"))
                .or_else(|| text.strip_prefix('>'))
                .unwrap_or(&text)
                .to_string();
        };

        if text.contains('\n') {
            first_line = false;
        }

        // Emit the first `text.len()` bytes of the buffer and keep the rest
        // for the next chunk.
        let remainder = buffer.split_off(text.len());
        let result = if buffer.is_empty() {
            None
        } else {
            Some(Ok(buffer.clone()))
        };

        buffer = remainder;
        future::ready(result)
    })
}
1116
1117fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1118 ranges.sort_unstable_by(|a, b| {
1119 a.start
1120 .cmp(&b.start, buffer)
1121 .then_with(|| b.end.cmp(&a.end, buffer))
1122 });
1123
1124 let mut ix = 0;
1125 while ix + 1 < ranges.len() {
1126 let b = ranges[ix + 1].clone();
1127 let a = &mut ranges[ix];
1128 if a.end.cmp(&b.start, buffer).is_gt() {
1129 if a.end.cmp(&b.end, buffer).is_lt() {
1130 a.end = b.end;
1131 }
1132 ranges.remove(ix + 1);
1133 } else {
1134 ix += 1;
1135 }
1136 }
1137}
1138
1139#[cfg(test)]
1140mod tests {
1141 use std::sync::Arc;
1142
1143 use crate::FakeCompletionProvider;
1144
1145 use super::*;
1146 use futures::stream::{self};
1147 use gpui::{Context, TestAppContext};
1148 use indoc::indoc;
1149 use language::{
1150 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
1151 Point,
1152 };
1153 use rand::prelude::*;
1154 use serde::Serialize;
1155 use settings::SettingsStore;
1156
/// Minimal serializable request payload carrying only a `name`.
///
/// NOTE(review): no test in this module refers to this type — confirm whether
/// it is still needed.
#[derive(Serialize)]
pub struct DummyCompletionRequest {
    /// Name carried by the request.
    pub name: String,
}
1161
1162 #[gpui::test(iterations = 10)]
1163 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1164 let provider = FakeCompletionProvider::default();
1165 cx.set_global(cx.update(SettingsStore::test));
1166 cx.set_global(CompletionProvider::Fake(provider.clone()));
1167 cx.update(language_settings::init);
1168
1169 let text = indoc! {"
1170 fn main() {
1171 let x = 0;
1172 for _ in 0..10 {
1173 x += 1;
1174 }
1175 }
1176 "};
1177 let buffer =
1178 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1179 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1180 let range = buffer.read_with(cx, |buffer, cx| {
1181 let snapshot = buffer.snapshot(cx);
1182 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1183 });
1184 let codegen = cx.new_model(|cx| {
1185 Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
1186 });
1187
1188 let request = LanguageModelRequest::default();
1189 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
1190
1191 let mut new_text = concat!(
1192 " let mut x = 0;\n",
1193 " while x < 10 {\n",
1194 " x += 1;\n",
1195 " }",
1196 );
1197 while !new_text.is_empty() {
1198 let max_len = cmp::min(new_text.len(), 10);
1199 let len = rng.gen_range(1..=max_len);
1200 let (chunk, suffix) = new_text.split_at(len);
1201 provider.send_completion(chunk.into());
1202 new_text = suffix;
1203 cx.background_executor.run_until_parked();
1204 }
1205 provider.finish_completion();
1206 cx.background_executor.run_until_parked();
1207
1208 assert_eq!(
1209 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1210 indoc! {"
1211 fn main() {
1212 let mut x = 0;
1213 while x < 10 {
1214 x += 1;
1215 }
1216 }
1217 "}
1218 );
1219 }
1220
1221 #[gpui::test(iterations = 10)]
1222 async fn test_autoindent_when_generating_past_indentation(
1223 cx: &mut TestAppContext,
1224 mut rng: StdRng,
1225 ) {
1226 let provider = FakeCompletionProvider::default();
1227 cx.set_global(CompletionProvider::Fake(provider.clone()));
1228 cx.set_global(cx.update(SettingsStore::test));
1229 cx.update(language_settings::init);
1230
1231 let text = indoc! {"
1232 fn main() {
1233 le
1234 }
1235 "};
1236 let buffer =
1237 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1238 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1239 let position = buffer.read_with(cx, |buffer, cx| {
1240 let snapshot = buffer.snapshot(cx);
1241 snapshot.anchor_before(Point::new(1, 6))
1242 });
1243 let codegen = cx.new_model(|cx| {
1244 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
1245 });
1246
1247 let request = LanguageModelRequest::default();
1248 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
1249
1250 let mut new_text = concat!(
1251 "t mut x = 0;\n",
1252 "while x < 10 {\n",
1253 " x += 1;\n",
1254 "}", //
1255 );
1256 while !new_text.is_empty() {
1257 let max_len = cmp::min(new_text.len(), 10);
1258 let len = rng.gen_range(1..=max_len);
1259 let (chunk, suffix) = new_text.split_at(len);
1260 provider.send_completion(chunk.into());
1261 new_text = suffix;
1262 cx.background_executor.run_until_parked();
1263 }
1264 provider.finish_completion();
1265 cx.background_executor.run_until_parked();
1266
1267 assert_eq!(
1268 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1269 indoc! {"
1270 fn main() {
1271 let mut x = 0;
1272 while x < 10 {
1273 x += 1;
1274 }
1275 }
1276 "}
1277 );
1278 }
1279
1280 #[gpui::test(iterations = 10)]
1281 async fn test_autoindent_when_generating_before_indentation(
1282 cx: &mut TestAppContext,
1283 mut rng: StdRng,
1284 ) {
1285 let provider = FakeCompletionProvider::default();
1286 cx.set_global(CompletionProvider::Fake(provider.clone()));
1287 cx.set_global(cx.update(SettingsStore::test));
1288 cx.update(language_settings::init);
1289
1290 let text = concat!(
1291 "fn main() {\n",
1292 " \n",
1293 "}\n" //
1294 );
1295 let buffer =
1296 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1297 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1298 let position = buffer.read_with(cx, |buffer, cx| {
1299 let snapshot = buffer.snapshot(cx);
1300 snapshot.anchor_before(Point::new(1, 2))
1301 });
1302 let codegen = cx.new_model(|cx| {
1303 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
1304 });
1305
1306 let request = LanguageModelRequest::default();
1307 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
1308
1309 let mut new_text = concat!(
1310 "let mut x = 0;\n",
1311 "while x < 10 {\n",
1312 " x += 1;\n",
1313 "}", //
1314 );
1315 while !new_text.is_empty() {
1316 let max_len = cmp::min(new_text.len(), 10);
1317 let len = rng.gen_range(1..=max_len);
1318 let (chunk, suffix) = new_text.split_at(len);
1319 provider.send_completion(chunk.into());
1320 new_text = suffix;
1321 cx.background_executor.run_until_parked();
1322 }
1323 provider.finish_completion();
1324 cx.background_executor.run_until_parked();
1325
1326 assert_eq!(
1327 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1328 indoc! {"
1329 fn main() {
1330 let mut x = 0;
1331 while x < 10 {
1332 x += 1;
1333 }
1334 }
1335 "}
1336 );
1337 }
1338
1339 #[gpui::test]
1340 async fn test_strip_invalid_spans_from_codeblock() {
1341 assert_eq!(
1342 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
1343 .map(|chunk| chunk.unwrap())
1344 .collect::<String>()
1345 .await,
1346 "Lorem ipsum dolor"
1347 );
1348 assert_eq!(
1349 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
1350 .map(|chunk| chunk.unwrap())
1351 .collect::<String>()
1352 .await,
1353 "Lorem ipsum dolor"
1354 );
1355 assert_eq!(
1356 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
1357 .map(|chunk| chunk.unwrap())
1358 .collect::<String>()
1359 .await,
1360 "Lorem ipsum dolor"
1361 );
1362 assert_eq!(
1363 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
1364 .map(|chunk| chunk.unwrap())
1365 .collect::<String>()
1366 .await,
1367 "Lorem ipsum dolor"
1368 );
1369 assert_eq!(
1370 strip_invalid_spans_from_codeblock(chunks(
1371 "```html\n```js\nLorem ipsum dolor\n```\n```",
1372 2
1373 ))
1374 .map(|chunk| chunk.unwrap())
1375 .collect::<String>()
1376 .await,
1377 "```js\nLorem ipsum dolor\n```"
1378 );
1379 assert_eq!(
1380 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
1381 .map(|chunk| chunk.unwrap())
1382 .collect::<String>()
1383 .await,
1384 "``\nLorem ipsum dolor\n```"
1385 );
1386 assert_eq!(
1387 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
1388 .map(|chunk| chunk.unwrap())
1389 .collect::<String>()
1390 .await,
1391 "Lorem ipsum"
1392 );
1393
1394 assert_eq!(
1395 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
1396 .map(|chunk| chunk.unwrap())
1397 .collect::<String>()
1398 .await,
1399 "Lorem ipsum"
1400 );
1401
1402 assert_eq!(
1403 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
1404 .map(|chunk| chunk.unwrap())
1405 .collect::<String>()
1406 .await,
1407 "Lorem ipsum"
1408 );
1409 assert_eq!(
1410 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
1411 .map(|chunk| chunk.unwrap())
1412 .collect::<String>()
1413 .await,
1414 "Lorem ipsum"
1415 );
1416 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1417 stream::iter(
1418 text.chars()
1419 .collect::<Vec<_>>()
1420 .chunks(size)
1421 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1422 .collect::<Vec<_>>(),
1423 )
1424 }
1425 }
1426
1427 fn rust_lang() -> Language {
1428 Language::new(
1429 LanguageConfig {
1430 name: "Rust".into(),
1431 matcher: LanguageMatcher {
1432 path_suffixes: vec!["rs".to_string()],
1433 ..Default::default()
1434 },
1435 ..Default::default()
1436 },
1437 Some(tree_sitter_rust::language()),
1438 )
1439 .with_indents_query(
1440 r#"
1441 (call_expression) @indent
1442 (field_expression) @indent
1443 (_ "(" ")" @end) @indent
1444 (_ "{" "}" @end) @indent
1445 "#,
1446 )
1447 .unwrap()
1448 }
1449}