1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp},
10 display_map::{
11 BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
12 },
13 scroll::{Autoscroll, AutoscrollStrategy},
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorStyle, ExcerptRange,
15 GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
18use gpui::{
19 AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
20 Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
21 ViewContext, WeakView, WhiteSpace, WindowContext,
22};
23use language::{Buffer, Point, TransactionId};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use rope::Rope;
27use settings::Settings;
28use similar::TextDiff;
29use std::{
30 cmp, future, mem,
31 ops::{Range, RangeInclusive},
32 sync::Arc,
33 time::Instant,
34};
35use theme::ThemeSettings;
36use ui::{prelude::*, Tooltip};
37use workspace::{notifications::NotificationId, Toast, Workspace};
38
39pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
40 cx.set_global(InlineAssistant::new(telemetry));
41}
42
/// Maximum number of prompts retained in `InlineAssistant::prompt_history`; once exceeded,
/// the oldest prompt is evicted.
const PROMPT_HISTORY_MAX_LEN: usize = 20;

/// Global bookkeeping for inline assists: every pending assist, the editors they belong to,
/// and the prompt history shared by all assist prompt editors.
pub struct InlineAssistant {
    /// Id that will be handed to the next assist created by `assist`.
    next_assist_id: InlineAssistId,
    /// Assists that have been created and not yet finished, keyed by id.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Pending assist ids grouped by the (weakly held) editor they were opened in.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Previously submitted prompts, oldest at the front, newest at the back.
    prompt_history: VecDeque<String>,
    /// Telemetry handle forwarded to each `Codegen`.
    telemetry: Option<Arc<Telemetry>>,
}

/// The pending assists of a single editor.
struct EditorPendingAssists {
    /// Window the editor was in when its first pending assist was created.
    window: AnyWindowHandle,
    /// Pending assist ids in creation order; the last one is the most recent.
    assist_ids: Vec<InlineAssistId>,
}

// Lets the assistant be stored with `cx.set_global` and accessed via `update_global`.
impl Global for InlineAssistant {}
59
60impl InlineAssistant {
61 pub fn new(telemetry: Arc<Telemetry>) -> Self {
62 Self {
63 next_assist_id: InlineAssistId::default(),
64 pending_assists: HashMap::default(),
65 pending_assist_ids_by_editor: HashMap::default(),
66 prompt_history: VecDeque::default(),
67 telemetry: Some(telemetry),
68 }
69 }
70
/// Starts a new inline assist on the newest selection of `editor`.
///
/// A non-empty selection is widened to whole lines and becomes a `CodegenKind::Transform`
/// over that range; an empty selection becomes a `CodegenKind::Generate` at the cursor.
/// A prompt-editor block is inserted above the range and a one-line bordered block below it,
/// and the assist is tracked in `pending_assists` until it is finished.
///
/// * `workspace` - used for error toasts, the project name, and assistant-panel context.
/// * `include_context` - stored on the pending assist; `start_inline_assist` reads it to decide
///   whether to include the assistant panel's active context in the request.
///
/// Does nothing when the newest selection spans more than one excerpt.
pub fn assist(
    &mut self,
    editor: &View<Editor>,
    workspace: Option<WeakView<Workspace>>,
    include_context: bool,
    cx: &mut WindowContext,
) {
    let selection = editor.read(cx).selections.newest_anchor().clone();
    // An assist must live inside a single excerpt.
    if selection.start.excerpt_id != selection.end.excerpt_id {
        return;
    }
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

    // Extend the selection to the start and the end of the line.
    let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
    if point_selection.end > point_selection.start {
        point_selection.start.column = 0;
        // If the selection ends at the start of the line, we don't want to include it.
        // (end > start with end.column == 0 implies end.row > start.row, so this cannot underflow.)
        if point_selection.end.column == 0 {
            point_selection.end.row -= 1;
        }
        point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
    }

    // Empty selection => generate new text at the cursor; otherwise rewrite the selected lines.
    let codegen_kind = if point_selection.start == point_selection.end {
        CodegenKind::Generate {
            position: snapshot.anchor_after(point_selection.start),
        }
    } else {
        CodegenKind::Transform {
            range: snapshot.anchor_before(point_selection.start)
                ..snapshot.anchor_after(point_selection.end),
        }
    };

    let inline_assist_id = self.next_assist_id.post_inc();
    let codegen = cx.new_model(|cx| {
        Codegen::new(
            editor.read(cx).buffer().clone(),
            codegen_kind,
            self.telemetry.clone(),
            cx,
        )
    });

    // Shared with the block renderer, which writes the editor's current gutter dimensions into it.
    let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
    let prompt_editor = cx.new_view(|cx| {
        InlineAssistEditor::new(
            inline_assist_id,
            gutter_dimensions.clone(),
            self.prompt_history.clone(),
            codegen.clone(),
            workspace.clone(),
            cx,
        )
    });
    // Collapse the selection to the range start and insert the prompt block (above the start)
    // and the closing border block (below the end).
    let (prompt_block_id, end_block_id) = editor.update(cx, |editor, cx| {
        let start_anchor = snapshot.anchor_before(point_selection.start);
        let end_anchor = snapshot.anchor_after(point_selection.end);
        editor.change_selections(Some(Autoscroll::newest()), cx, |selections| {
            selections.select_anchor_ranges([start_anchor..start_anchor])
        });
        let block_ids = editor.insert_blocks(
            [
                BlockProperties {
                    style: BlockStyle::Sticky,
                    position: start_anchor,
                    height: prompt_editor.read(cx).height_in_lines,
                    render: build_inline_assist_editor_renderer(
                        &prompt_editor,
                        gutter_dimensions,
                    ),
                    disposition: BlockDisposition::Above,
                },
                BlockProperties {
                    style: BlockStyle::Sticky,
                    position: end_anchor,
                    height: 1,
                    render: Box::new(|cx| {
                        v_flex()
                            .h_full()
                            .w_full()
                            .border_t_1()
                            .border_color(cx.theme().status().info_border)
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Below,
                },
            ],
            Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
            cx,
        );
        // Block ids are returned in the order the blocks were passed in.
        (block_ids[0], block_ids[1])
    });

    self.pending_assists.insert(
        inline_assist_id,
        PendingInlineAssist {
            include_context,
            editor: editor.downgrade(),
            editor_decorations: Some(PendingInlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            codegen: codegen.clone(),
            workspace,
            _subscriptions: vec![
                // Route prompt-editor events (start/stop/confirm/...) to the global assistant.
                cx.subscribe(&prompt_editor, |inline_assist_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_inline_assistant_event(inline_assist_editor, event, cx)
                    })
                }),
                // When the user changes the host editor's selection locally while the prompt
                // has focus, move focus back to the host editor.
                cx.subscribe(editor, {
                    let inline_assist_editor = prompt_editor.downgrade();
                    move |editor, event, cx| {
                        if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
                            if let EditorEvent::SelectionsChanged { local } = event {
                                if *local
                                    && inline_assist_editor
                                        .focus_handle(cx)
                                        .contains_focused(cx)
                                {
                                    cx.focus_view(&editor);
                                }
                            }
                        }
                    }
                }),
                // Any codegen change refreshes highlights and deleted-line blocks.
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                this.update_editor_highlights(&editor, cx);
                                this.update_editor_blocks(&editor, inline_assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => {
                            this.finish_inline_assist(inline_assist_id, false, cx)
                        }
                        CodegenEvent::Finished => {
                            let pending_assist = if let Some(pending_assist) =
                                this.pending_assists.get(&inline_assist_id)
                            {
                                pending_assist
                            } else {
                                return;
                            };

                            // With the prompt hidden (decorations gone), surface errors as
                            // a workspace toast instead.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if pending_assist.editor_decorations.is_none() {
                                    if let Some(workspace) = pending_assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error =
                                            format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id = NotificationId::identified::<
                                                InlineAssistantError,
                                            >(
                                                inline_assist_id.0
                                            );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // A dismissed (decoration-less) assist has nothing left to confirm,
                            // so finish it as soon as generation completes.
                            if pending_assist.editor_decorations.is_none() {
                                this.finish_inline_assist(inline_assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        },
    );

    self.pending_assist_ids_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorPendingAssists {
            window: cx.window_handle(),
            assist_ids: Vec::new(),
        })
        .assist_ids
        .push(inline_assist_id);
    self.update_editor_highlights(editor, cx);
}
270
271 fn handle_inline_assistant_event(
272 &mut self,
273 inline_assist_editor: View<InlineAssistEditor>,
274 event: &InlineAssistEditorEvent,
275 cx: &mut WindowContext,
276 ) {
277 let assist_id = inline_assist_editor.read(cx).id;
278 match event {
279 InlineAssistEditorEvent::Started => {
280 self.start_inline_assist(assist_id, cx);
281 }
282 InlineAssistEditorEvent::Stopped => {
283 self.stop_inline_assist(assist_id, cx);
284 }
285 InlineAssistEditorEvent::Confirmed => {
286 self.finish_inline_assist(assist_id, false, cx);
287 }
288 InlineAssistEditorEvent::Canceled => {
289 self.finish_inline_assist(assist_id, true, cx);
290 }
291 InlineAssistEditorEvent::Dismissed => {
292 self.hide_inline_assist_decorations(assist_id, cx);
293 }
294 InlineAssistEditorEvent::Resized { height_in_lines } => {
295 self.resize_inline_assist(assist_id, *height_in_lines, cx);
296 }
297 }
298 }
299
300 pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
301 for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
302 if pending_assists.window == cx.window_handle() {
303 if let Some(editor) = editor.upgrade() {
304 if editor.read(cx).is_focused(cx) {
305 if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
306 self.finish_inline_assist(assist_id, true, cx);
307 return true;
308 }
309 }
310 }
311 }
312 }
313 false
314 }
315
316 fn finish_inline_assist(
317 &mut self,
318 assist_id: InlineAssistId,
319 undo: bool,
320 cx: &mut WindowContext,
321 ) {
322 self.hide_inline_assist_decorations(assist_id, cx);
323
324 if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
325 if let hash_map::Entry::Occupied(mut entry) = self
326 .pending_assist_ids_by_editor
327 .entry(pending_assist.editor.clone())
328 {
329 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
330 if entry.get().assist_ids.is_empty() {
331 entry.remove();
332 }
333 }
334
335 if let Some(editor) = pending_assist.editor.upgrade() {
336 self.update_editor_highlights(&editor, cx);
337
338 if undo {
339 pending_assist
340 .codegen
341 .update(cx, |codegen, cx| codegen.undo(cx));
342 }
343 }
344 }
345 }
346
347 fn hide_inline_assist_decorations(
348 &mut self,
349 assist_id: InlineAssistId,
350 cx: &mut WindowContext,
351 ) -> bool {
352 let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else {
353 return false;
354 };
355 let Some(editor) = pending_assist.editor.upgrade() else {
356 return false;
357 };
358 let Some(decorations) = pending_assist.editor_decorations.take() else {
359 return false;
360 };
361
362 editor.update(cx, |editor, cx| {
363 let mut to_remove = decorations.removed_line_block_ids;
364 to_remove.insert(decorations.prompt_block_id);
365 to_remove.insert(decorations.end_block_id);
366 editor.remove_blocks(to_remove, None, cx);
367 if decorations
368 .prompt_editor
369 .focus_handle(cx)
370 .contains_focused(cx)
371 {
372 editor.focus(cx);
373 }
374 });
375
376 self.update_editor_highlights(&editor, cx);
377 true
378 }
379
380 fn resize_inline_assist(
381 &mut self,
382 assist_id: InlineAssistId,
383 height_in_lines: u8,
384 cx: &mut WindowContext,
385 ) {
386 if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
387 if let Some(editor) = pending_assist.editor.upgrade() {
388 if let Some(decorations) = pending_assist.editor_decorations.as_ref() {
389 let gutter_dimensions =
390 decorations.prompt_editor.read(cx).gutter_dimensions.clone();
391 let mut new_blocks = HashMap::default();
392 new_blocks.insert(
393 decorations.prompt_block_id,
394 (
395 Some(height_in_lines),
396 build_inline_assist_editor_renderer(
397 &decorations.prompt_editor,
398 gutter_dimensions,
399 ),
400 ),
401 );
402 editor.update(cx, |editor, cx| {
403 editor
404 .display_map
405 .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
406 });
407 }
408 }
409 }
410 }
411
/// Starts (or restarts) code generation for the pending assist `assist_id`.
///
/// Output from any previous run is undone first. The prompt text is read from the assist's
/// prompt editor and recorded in the prompt history. The request consists of the messages of
/// the assistant panel's active context (only when the assist was created with
/// `include_context`) followed by a user message produced by `generate_content_prompt` on the
/// background executor.
///
/// If the assist's range cannot be mapped into a single buffer, the assist is finished
/// without undo instead.
fn start_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
    {
        pending_assist
    } else {
        return;
    };

    // Discard whatever a previous run of this assist produced before generating again.
    pending_assist
        .codegen
        .update(cx, |codegen, cx| codegen.undo(cx));

    // Without decorations there is no prompt editor to read the prompt from.
    let Some(user_prompt) = pending_assist
        .editor_decorations
        .as_ref()
        .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
    else {
        return;
    };

    let context = if pending_assist.include_context {
        pending_assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?.read(cx);
            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
            assistant_panel.read(cx).active_context(cx)
        })
    } else {
        None
    };

    let editor = if let Some(editor) = pending_assist.editor.upgrade() {
        editor
    } else {
        return;
    };

    // Worktree root names joined with "/", passed to the prompt as the project name.
    let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
        let workspace = workspace.upgrade()?;
        Some(
            workspace
                .read(cx)
                .project()
                .read(cx)
                .worktree_root_names(cx)
                .collect::<Vec<&str>>()
                .join("/"),
        )
    });

    // Move this prompt to the newest end of the history (no duplicates), evicting the oldest
    // entry once the history grows past PROMPT_HISTORY_MAX_LEN.
    self.prompt_history.retain(|prompt| *prompt != user_prompt);
    self.prompt_history.push_back(user_prompt.clone());
    if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
        self.prompt_history.pop_front();
    }

    let codegen = pending_assist.codegen.clone();
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    let range = codegen.read(cx).range();
    // Both ends of the assist range must resolve into the same underlying buffer; otherwise
    // the assist is abandoned.
    let start = snapshot.point_to_buffer_offset(range.start);
    let end = snapshot.point_to_buffer_offset(range.end);
    let (buffer, range) = if let Some((start, end)) = start.zip(end) {
        let (start_buffer, start_buffer_offset) = start;
        let (end_buffer, end_buffer_offset) = end;
        if start_buffer.remote_id() == end_buffer.remote_id() {
            (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
        } else {
            self.finish_inline_assist(assist_id, false, cx);
            return;
        }
    } else {
        self.finish_inline_assist(assist_id, false, cx);
        return;
    };

    // Plain text is treated as "no language".
    let language = buffer.language_at(range.start);
    let language_name = if let Some(language) = language.as_ref() {
        if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
            None
        } else {
            Some(language.name())
        }
    } else {
        None
    };

    // Higher Temperature increases the randomness of model outputs.
    // If Markdown or No Language is Known, increase the randomness for more creative output
    // If Code, decrease temperature to get more deterministic outputs
    let temperature = if let Some(language) = language_name.clone() {
        if language.as_ref() == "Markdown" {
            1.0
        } else {
            0.5
        }
    } else {
        1.0
    };

    // Prompt construction runs off the main thread.
    let prompt = cx.background_executor().spawn(async move {
        let language_name = language_name.as_deref();
        generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
    });

    // Context messages (if any) come first; the generated prompt is appended as the final
    // user message once it is ready.
    let mut messages = Vec::new();
    if let Some(context) = context {
        let request = context.read(cx).to_completion_request(cx);
        messages = request.messages;
    }
    let model = CompletionProvider::global(cx).model();

    cx.spawn(|mut cx| async move {
        let prompt = prompt.await?;

        messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: prompt,
        });

        let request = LanguageModelRequest {
            model,
            messages,
            stop: vec!["|END|>".to_string()],
            temperature,
        };

        codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
542
543 fn stop_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
544 let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
545 {
546 pending_assist
547 } else {
548 return;
549 };
550
551 pending_assist
552 .codegen
553 .update(cx, |codegen, cx| codegen.stop(cx));
554 }
555
556 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
557 let mut gutter_pending_ranges = Vec::new();
558 let mut gutter_transformed_ranges = Vec::new();
559 let mut foreground_ranges = Vec::new();
560 let mut inserted_row_ranges = Vec::new();
561 let empty_inline_assist_ids = Vec::new();
562 let inline_assist_ids = self
563 .pending_assist_ids_by_editor
564 .get(&editor.downgrade())
565 .map_or(&empty_inline_assist_ids, |pending_assists| {
566 &pending_assists.assist_ids
567 });
568
569 for inline_assist_id in inline_assist_ids {
570 if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
571 let codegen = pending_assist.codegen.read(cx);
572 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
573
574 if codegen.edit_position != codegen.range().end {
575 gutter_pending_ranges.push(codegen.edit_position..codegen.range().end);
576 }
577
578 if codegen.range().start != codegen.edit_position {
579 gutter_transformed_ranges.push(codegen.range().start..codegen.edit_position);
580 }
581
582 if pending_assist.editor_decorations.is_some() {
583 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
584 }
585 }
586 }
587
588 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
589 merge_ranges(&mut foreground_ranges, &snapshot);
590 merge_ranges(&mut gutter_pending_ranges, &snapshot);
591 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
592 editor.update(cx, |editor, cx| {
593 enum GutterPendingRange {}
594 if gutter_pending_ranges.is_empty() {
595 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
596 } else {
597 editor.highlight_gutter::<GutterPendingRange>(
598 &gutter_pending_ranges,
599 |cx| cx.theme().status().info_background,
600 cx,
601 )
602 }
603
604 enum GutterTransformedRange {}
605 if gutter_transformed_ranges.is_empty() {
606 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
607 } else {
608 editor.highlight_gutter::<GutterTransformedRange>(
609 &gutter_transformed_ranges,
610 |cx| cx.theme().status().info,
611 cx,
612 )
613 }
614
615 if foreground_ranges.is_empty() {
616 editor.clear_highlights::<PendingInlineAssist>(cx);
617 } else {
618 editor.highlight_text::<PendingInlineAssist>(
619 foreground_ranges,
620 HighlightStyle {
621 fade_out: Some(0.6),
622 ..Default::default()
623 },
624 cx,
625 );
626 }
627
628 editor.clear_row_highlights::<PendingInlineAssist>();
629 for row_range in inserted_row_ranges {
630 editor.highlight_rows::<PendingInlineAssist>(
631 row_range,
632 Some(cx.theme().status().info_background),
633 false,
634 cx,
635 );
636 }
637 });
638 }
639
640 fn update_editor_blocks(
641 &mut self,
642 editor: &View<Editor>,
643 assist_id: InlineAssistId,
644 cx: &mut WindowContext,
645 ) {
646 let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else {
647 return;
648 };
649 let Some(decorations) = pending_assist.editor_decorations.as_mut() else {
650 return;
651 };
652
653 let codegen = pending_assist.codegen.read(cx);
654 let old_snapshot = codegen.snapshot.clone();
655 let old_buffer = codegen.old_buffer.clone();
656 let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
657
658 editor.update(cx, |editor, cx| {
659 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
660 editor.remove_blocks(old_blocks, None, cx);
661
662 let mut new_blocks = Vec::new();
663 for (new_row, old_row_range) in deleted_row_ranges {
664 let (_, buffer_start) = old_snapshot
665 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
666 .unwrap();
667 let (_, buffer_end) = old_snapshot
668 .point_to_buffer_offset(Point::new(
669 *old_row_range.end(),
670 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
671 ))
672 .unwrap();
673
674 let deleted_lines_editor = cx.new_view(|cx| {
675 let multi_buffer = cx.new_model(|_| {
676 MultiBuffer::without_headers(0, language::Capability::ReadOnly)
677 });
678 multi_buffer.update(cx, |multi_buffer, cx| {
679 multi_buffer.push_excerpts(
680 old_buffer.clone(),
681 Some(ExcerptRange {
682 context: buffer_start..buffer_end,
683 primary: None,
684 }),
685 cx,
686 );
687 });
688
689 enum DeletedLines {}
690 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
691 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
692 editor.set_show_wrap_guides(false, cx);
693 editor.set_show_gutter(false, cx);
694 editor.scroll_manager.set_forbid_vertical_scroll(true);
695 editor.set_read_only(true);
696 editor.highlight_rows::<DeletedLines>(
697 Anchor::min()..=Anchor::max(),
698 Some(cx.theme().status().deleted_background),
699 false,
700 cx,
701 );
702 editor
703 });
704
705 let height = deleted_lines_editor
706 .update(cx, |editor, cx| editor.max_point(cx).row().0 as u8 + 1);
707 new_blocks.push(BlockProperties {
708 position: new_row,
709 height,
710 style: BlockStyle::Flex,
711 render: Box::new(move |cx| {
712 div()
713 .bg(cx.theme().status().deleted_background)
714 .size_full()
715 .pl(cx.gutter_dimensions.full_width())
716 .child(deleted_lines_editor.clone())
717 .into_any_element()
718 }),
719 disposition: BlockDisposition::Above,
720 });
721 }
722
723 decorations.removed_line_block_ids = editor
724 .insert_blocks(new_blocks, None, cx)
725 .into_iter()
726 .collect();
727 })
728 }
729}
730
731fn build_inline_assist_editor_renderer(
732 editor: &View<InlineAssistEditor>,
733 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
734) -> RenderBlock {
735 let editor = editor.clone();
736 Box::new(move |cx: &mut BlockContext| {
737 *gutter_dimensions.lock() = *cx.gutter_dimensions;
738 editor.clone().into_any_element()
739 })
740}
741
/// Identifier of a single inline assist; ids are handed out sequentially starting at 0.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let current = *self;
        *self = InlineAssistId(current.0 + 1);
        current
    }
}
752
/// Events emitted by `InlineAssistEditor` and handled by
/// `InlineAssistant::handle_inline_assistant_event`.
enum InlineAssistEditorEvent {
    /// Start (or restart) generation for this assist.
    Started,
    /// Stop the in-flight generation.
    Stopped,
    /// Keep the generated changes and finish the assist.
    Confirmed,
    /// Finish the assist and undo its generated changes.
    Canceled,
    /// Hide the assist's decorations; emitted by `confirm` while generation is pending.
    Dismissed,
    /// The prompt editor's height changed to `height_in_lines` rows.
    Resized { height_in_lines: u8 },
}
761
/// The prompt UI rendered in a block above an inline assist's range.
struct InlineAssistEditor {
    /// The assist this prompt belongs to.
    id: InlineAssistId,
    /// Current height of the block, kept between 2 and `MAX_LINES` by `count_lines`.
    height_in_lines: u8,
    /// The text input holding the user's prompt.
    prompt_editor: View<Editor>,
    /// Whether the prompt was edited since generation last reached `Done`/`Error`;
    /// decides between "restart" and "confirm".
    edited_since_done: bool,
    /// Host editor gutter dimensions, written by the block renderer and read in `render`.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Snapshot of the global prompt history taken when this prompt was created.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` while browsing it with up/down; `None` when not browsing.
    prompt_history_ix: Option<usize>,
    /// The user's own (non-history) prompt text, restored when browsing past the newest entry.
    pending_prompt: String,
    /// The code generation driven by this prompt.
    codegen: Model<Codegen>,
    /// Workspace used to reach the assistant panel, if any.
    workspace: Option<WeakView<Workspace>>,
    _subscriptions: Vec<Subscription>,
}

impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
777
impl Render for InlineAssistEditor {
    /// Renders the prompt row: a gutter-aligned column (assistant-panel context button and
    /// error indicator), the prompt editor, and status-dependent action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();

        // Action buttons depend on the codegen status:
        // Idle -> start + cancel; Pending -> stop + cancel;
        // Done/Error -> restart (if the prompt was edited since) or confirm, plus cancel.
        let buttons = match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("start", IconName::Sparkle)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Started)),
                        ),
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Stopped)),
                        ),
                    // While pending, the Cancel action maps to "stop" (see `cancel`), so this
                    // button's tooltip shows no keybinding.
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    if self.edited_since_done {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .icon_size(IconSize::XSmall)
                            .size(ButtonSize::None)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(InlineAssistEditorEvent::Started);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .size(ButtonSize::None)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(InlineAssistEditorEvent::Confirmed);
                            }))
                    },
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
        };

        v_flex().h_full().w_full().justify_end().child(
            h_flex()
                .bg(cx.theme().colors().editor_background)
                .border_y_1()
                .border_color(cx.theme().status().info_border)
                .py_1p5()
                .w_full()
                .on_action(cx.listener(Self::confirm))
                .on_action(cx.listener(Self::cancel))
                .on_action(cx.listener(Self::move_up))
                .on_action(cx.listener(Self::move_down))
                .child(
                    // Left column sized to the host editor's gutter so controls line up with it.
                    h_flex()
                        .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                        // .pr(gutter_dimensions.fold_area_width())
                        .justify_center()
                        .gap_2()
                        .children(self.workspace.clone().map(|workspace| {
                            // Opens the assistant panel; its tooltip shows the panel context's
                            // token count when one is available.
                            IconButton::new("context", IconName::Context)
                                .size(ButtonSize::None)
                                .icon_size(IconSize::XSmall)
                                .icon_color(Color::Muted)
                                .on_click({
                                    let workspace = workspace.clone();
                                    cx.listener(move |_, _, cx| {
                                        workspace
                                            .update(cx, |workspace, cx| {
                                                workspace.focus_panel::<AssistantPanel>(cx);
                                            })
                                            .ok();
                                    })
                                })
                                .tooltip(move |cx| {
                                    let token_count = workspace.upgrade().and_then(|workspace| {
                                        let panel =
                                            workspace.read(cx).panel::<AssistantPanel>(cx)?;
                                        let context = panel.read(cx).active_context(cx)?;
                                        context.read(cx).token_count()
                                    });
                                    if let Some(token_count) = token_count {
                                        Tooltip::with_meta(
                                            format!(
                                                "{} Additional Context Tokens from Assistant",
                                                token_count
                                            ),
                                            Some(&crate::ToggleFocus),
                                            "Click to open…",
                                            cx,
                                        )
                                    } else {
                                        Tooltip::for_action(
                                            "Toggle Assistant Panel",
                                            &crate::ToggleFocus,
                                            cx,
                                        )
                                    }
                                })
                        }))
                        // Error indicator with the error message as its tooltip.
                        .children(
                            if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
                                let error_message = SharedString::from(error.to_string());
                                Some(
                                    div()
                                        .id("error")
                                        .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                        .child(
                                            Icon::new(IconName::XCircle)
                                                .size(IconSize::Small)
                                                .color(Color::Error),
                                        ),
                                )
                            } else {
                                None
                            },
                        ),
                )
                .child(div().flex_1().child(self.render_prompt_editor(cx)))
                .child(h_flex().gap_2().pr_4().children(buttons)),
        )
    }
}
947
948impl FocusableView for InlineAssistEditor {
949 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
950 self.prompt_editor.focus_handle(cx)
951 }
952}
953
954impl InlineAssistEditor {
/// Upper bound on the prompt editor's height, in lines (see `count_lines`).
const MAX_LINES: u8 = 8;
956
957 #[allow(clippy::too_many_arguments)]
958 fn new(
959 id: InlineAssistId,
960 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
961 prompt_history: VecDeque<String>,
962 codegen: Model<Codegen>,
963 workspace: Option<WeakView<Workspace>>,
964 cx: &mut ViewContext<Self>,
965 ) -> Self {
966 let prompt_editor = cx.new_view(|cx| {
967 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
968 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
969 editor.set_placeholder_text("Add a prompt…", cx);
970 editor
971 });
972 cx.focus_view(&prompt_editor);
973
974 let subscriptions = vec![
975 cx.observe(&codegen, Self::handle_codegen_changed),
976 cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
977 cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
978 ];
979
980 let mut this = Self {
981 id,
982 height_in_lines: 1,
983 prompt_editor,
984 edited_since_done: false,
985 gutter_dimensions,
986 prompt_history,
987 prompt_history_ix: None,
988 pending_prompt: String::new(),
989 codegen,
990 workspace,
991 _subscriptions: subscriptions,
992 };
993 this.count_lines(cx);
994 this
995 }
996
997 fn prompt(&self, cx: &AppContext) -> String {
998 self.prompt_editor.read(cx).text(cx)
999 }
1000
1001 fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
1002 let height_in_lines = cmp::max(
1003 2, // Make the editor at least two lines tall, to account for padding and buttons.
1004 cmp::min(
1005 self.prompt_editor
1006 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
1007 Self::MAX_LINES as u32,
1008 ),
1009 ) as u8;
1010
1011 if height_in_lines != self.height_in_lines {
1012 self.height_in_lines = height_in_lines;
1013 cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
1014 }
1015 }
1016
1017 fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
1018 self.count_lines(cx);
1019 }
1020
1021 fn handle_prompt_editor_events(
1022 &mut self,
1023 _: View<Editor>,
1024 event: &EditorEvent,
1025 cx: &mut ViewContext<Self>,
1026 ) {
1027 match event {
1028 EditorEvent::Edited => {
1029 let prompt = self.prompt_editor.read(cx).text(cx);
1030 if self
1031 .prompt_history_ix
1032 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1033 {
1034 self.prompt_history_ix.take();
1035 self.pending_prompt = prompt;
1036 }
1037
1038 self.edited_since_done = true;
1039 cx.notify();
1040 }
1041 EditorEvent::Blurred => {
1042 if let CodegenStatus::Idle = &self.codegen.read(cx).status {
1043 let assistant_panel_is_focused = self
1044 .workspace
1045 .as_ref()
1046 .and_then(|workspace| {
1047 let panel =
1048 workspace.upgrade()?.read(cx).panel::<AssistantPanel>(cx)?;
1049 Some(panel.focus_handle(cx).contains_focused(cx))
1050 })
1051 .unwrap_or(false);
1052
1053 if !assistant_panel_is_focused {
1054 cx.emit(InlineAssistEditorEvent::Canceled);
1055 }
1056 }
1057 }
1058 _ => {}
1059 }
1060 }
1061
1062 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1063 match &self.codegen.read(cx).status {
1064 CodegenStatus::Idle => {
1065 self.prompt_editor
1066 .update(cx, |editor, _| editor.set_read_only(false));
1067 }
1068 CodegenStatus::Pending => {
1069 self.prompt_editor
1070 .update(cx, |editor, _| editor.set_read_only(true));
1071 }
1072 CodegenStatus::Done | CodegenStatus::Error(_) => {
1073 self.edited_since_done = false;
1074 self.prompt_editor
1075 .update(cx, |editor, _| editor.set_read_only(false));
1076 }
1077 }
1078 }
1079
1080 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1081 match &self.codegen.read(cx).status {
1082 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1083 cx.emit(InlineAssistEditorEvent::Canceled);
1084 }
1085 CodegenStatus::Pending => {
1086 cx.emit(InlineAssistEditorEvent::Stopped);
1087 }
1088 }
1089 }
1090
1091 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1092 match &self.codegen.read(cx).status {
1093 CodegenStatus::Idle => {
1094 cx.emit(InlineAssistEditorEvent::Started);
1095 }
1096 CodegenStatus::Pending => {
1097 cx.emit(InlineAssistEditorEvent::Dismissed);
1098 }
1099 CodegenStatus::Done | CodegenStatus::Error(_) => {
1100 if self.edited_since_done {
1101 cx.emit(InlineAssistEditorEvent::Started);
1102 } else {
1103 cx.emit(InlineAssistEditorEvent::Confirmed);
1104 }
1105 }
1106 }
1107 }
1108
1109 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1110 if let Some(ix) = self.prompt_history_ix {
1111 if ix > 0 {
1112 self.prompt_history_ix = Some(ix - 1);
1113 let prompt = self.prompt_history[ix - 1].as_str();
1114 self.prompt_editor.update(cx, |editor, cx| {
1115 editor.set_text(prompt, cx);
1116 editor.move_to_beginning(&Default::default(), cx);
1117 });
1118 }
1119 } else if !self.prompt_history.is_empty() {
1120 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1121 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1122 self.prompt_editor.update(cx, |editor, cx| {
1123 editor.set_text(prompt, cx);
1124 editor.move_to_beginning(&Default::default(), cx);
1125 });
1126 }
1127 }
1128
1129 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1130 if let Some(ix) = self.prompt_history_ix {
1131 if ix < self.prompt_history.len() - 1 {
1132 self.prompt_history_ix = Some(ix + 1);
1133 let prompt = self.prompt_history[ix + 1].as_str();
1134 self.prompt_editor.update(cx, |editor, cx| {
1135 editor.set_text(prompt, cx);
1136 editor.move_to_end(&Default::default(), cx)
1137 });
1138 } else {
1139 self.prompt_history_ix = None;
1140 let prompt = self.pending_prompt.as_str();
1141 self.prompt_editor.update(cx, |editor, cx| {
1142 editor.set_text(prompt, cx);
1143 editor.move_to_end(&Default::default(), cx)
1144 });
1145 }
1146 }
1147 }
1148
1149 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1150 let settings = ThemeSettings::get_global(cx);
1151 let text_style = TextStyle {
1152 color: if self.prompt_editor.read(cx).read_only(cx) {
1153 cx.theme().colors().text_disabled
1154 } else {
1155 cx.theme().colors().text
1156 },
1157 font_family: settings.ui_font.family.clone(),
1158 font_features: settings.ui_font.features.clone(),
1159 font_size: rems(0.875).into(),
1160 font_weight: FontWeight::NORMAL,
1161 font_style: FontStyle::Normal,
1162 line_height: relative(1.3),
1163 background_color: None,
1164 underline: None,
1165 strikethrough: None,
1166 white_space: WhiteSpace::Normal,
1167 };
1168 EditorElement::new(
1169 &self.prompt_editor,
1170 EditorStyle {
1171 background: cx.theme().colors().editor_background,
1172 local_player: cx.theme().players().local(),
1173 text: text_style,
1174 ..Default::default()
1175 },
1176 )
1177 }
1178}
1179
/// State tracked for one pending inline assist.
struct PendingInlineAssist {
    /// Editor the assist was opened in, held weakly so the assist does not keep it alive.
    editor: WeakView<Editor>,
    /// Blocks and views inserted into `editor` for this assist, if any.
    editor_decorations: Option<PendingInlineAssistDecorations>,
    /// Model that performs the code generation for this assist.
    codegen: Model<Codegen>,
    /// Subscriptions kept alive for as long as this assist exists.
    _subscriptions: Vec<Subscription>,
    /// Workspace the assist belongs to, if known.
    workspace: Option<WeakView<Workspace>>,
    /// Whether extra context is added to the request — TODO confirm what is
    /// included at the request-building site.
    include_context: bool,
}
1188
/// Editor blocks (and the prompt view) inserted to display a pending inline assist.
struct PendingInlineAssistDecorations {
    /// Id of the block that presumably hosts `prompt_editor`.
    prompt_block_id: BlockId,
    /// The prompt editor view for this assist.
    prompt_editor: View<InlineAssistEditor>,
    /// Ids of blocks used to show removed lines.
    removed_line_block_ids: HashSet<BlockId>,
    /// Id of the block marking the end of the assist — placement to be confirmed
    /// at the insertion site.
    end_block_id: BlockId,
}
1195
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended: the stream completed or failed, or `Codegen::stop` was called.
    Finished,
    /// The buffer transaction containing the generated edits was undone.
    Undone,
}
1201
/// What a [`Codegen`] targets in the buffer.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the existing text in `range`.
    Transform { range: Range<Anchor> },
    /// Generate new text at `position` (an empty target range).
    Generate { position: Anchor },
}
1207
1208impl CodegenKind {
1209 fn range(&self, snapshot: &MultiBufferSnapshot) -> Range<Anchor> {
1210 match self {
1211 CodegenKind::Transform { range } => range.clone(),
1212 CodegenKind::Generate { position } => position.bias_left(snapshot)..*position,
1213 }
1214 }
1215}
1216
/// Streams a language-model completion into a buffer, diffing it incrementally
/// against the text originally in the target range.
pub struct Codegen {
    /// Buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Standalone local copy of the underlying buffer, taken at construction.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction; `kind` is resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Target of the generation (transform a range or generate at a position).
    kind: CodegenKind,
    /// How far the applied hunks have progressed through the original range;
    /// reset to the range start when a generation begins.
    edit_position: Anchor,
    /// Ranges that the latest batch of hunks left unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction into which all edits of the current generation are merged.
    transaction_id: Option<TransactionId>,
    /// Current lifecycle state.
    status: CodegenStatus,
    /// The in-flight generation task; `Task::ready(())` when nothing is running.
    generation: Task<()>,
    /// Line-level diff between `snapshot` and the current buffer contents.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to notice when `transaction_id` is undone.
    _subscription: gpui::Subscription,
}
1231
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// Created; no generation started yet.
    Idle,
    /// A generation is streaming.
    Pending,
    /// The last generation completed or was stopped.
    Done,
    /// The last generation failed with the contained error.
    Error(anyhow::Error),
}
1238
/// Line-level diff between a [`Codegen`]'s original snapshot and the current
/// buffer, recomputed in the background by `Codegen::update_diff`.
#[derive(Default)]
struct Diff {
    /// The running diff computation, if any.
    task: Option<Task<()>>,
    /// Set when a recompute was requested while `task` was still running.
    should_update: bool,
    /// Contiguous runs of deleted original rows, each paired with an anchor at
    /// the start of the row in the new buffer where the deletion occurred.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges covering inserted lines in the new buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
1246
// Lets `Codegen` emit `CodegenEvent`s (via `cx.emit`) to its subscribers.
impl EventEmitter<CodegenEvent> for Codegen {}
1248
1249impl Codegen {
1250 pub fn new(
1251 buffer: Model<MultiBuffer>,
1252 kind: CodegenKind,
1253 telemetry: Option<Arc<Telemetry>>,
1254 cx: &mut ModelContext<Self>,
1255 ) -> Self {
1256 let snapshot = buffer.read(cx).snapshot(cx);
1257
1258 let (old_buffer, _, _) = buffer
1259 .read(cx)
1260 .range_to_buffer_ranges(kind.range(&snapshot), cx)
1261 .pop()
1262 .unwrap();
1263 let old_buffer = cx.new_model(|cx| {
1264 let old_buffer = old_buffer.read(cx);
1265 let text = old_buffer.as_rope().clone();
1266 let line_ending = old_buffer.line_ending();
1267 let language = old_buffer.language().cloned();
1268 let language_registry = old_buffer.language_registry();
1269
1270 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
1271 buffer.set_language(language, cx);
1272 if let Some(language_registry) = language_registry {
1273 buffer.set_language_registry(language_registry)
1274 }
1275 buffer
1276 });
1277
1278 Self {
1279 buffer: buffer.clone(),
1280 old_buffer,
1281 edit_position: kind.range(&snapshot).start,
1282 snapshot,
1283 kind,
1284 last_equal_ranges: Default::default(),
1285 transaction_id: Default::default(),
1286 status: CodegenStatus::Idle,
1287 generation: Task::ready(()),
1288 diff: Diff::default(),
1289 telemetry,
1290 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
1291 }
1292 }
1293
1294 fn handle_buffer_event(
1295 &mut self,
1296 _buffer: Model<MultiBuffer>,
1297 event: &multi_buffer::Event,
1298 cx: &mut ModelContext<Self>,
1299 ) {
1300 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
1301 if self.transaction_id == Some(*transaction_id) {
1302 self.transaction_id = None;
1303 self.generation = Task::ready(());
1304 cx.emit(CodegenEvent::Undone);
1305 }
1306 }
1307 }
1308
1309 pub fn range(&self) -> Range<Anchor> {
1310 self.kind.range(&self.snapshot)
1311 }
1312
1313 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
1314 &self.last_equal_ranges
1315 }
1316
1317 pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
1318 let range = self.range();
1319 let snapshot = self.snapshot.clone();
1320 let selected_text = snapshot
1321 .text_for_range(range.start..range.end)
1322 .collect::<Rope>();
1323
1324 let selection_start = range.start.to_point(&snapshot);
1325 let suggested_line_indent = snapshot
1326 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
1327 .into_values()
1328 .next()
1329 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
1330
1331 let model_telemetry_id = prompt.model.telemetry_id();
1332 let response = CompletionProvider::global(cx).complete(prompt);
1333 let telemetry = self.telemetry.clone();
1334 self.edit_position = range.start;
1335 self.diff = Diff::default();
1336 self.status = CodegenStatus::Pending;
1337 self.generation = cx.spawn(|this, mut cx| {
1338 async move {
1339 let generate = async {
1340 let mut edit_start = range.start.to_offset(&snapshot);
1341
1342 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
1343 let diff: Task<anyhow::Result<()>> =
1344 cx.background_executor().spawn(async move {
1345 let mut response_latency = None;
1346 let request_start = Instant::now();
1347 let diff = async {
1348 let chunks = strip_invalid_spans_from_codeblock(response.await?);
1349 futures::pin_mut!(chunks);
1350 let mut diff = StreamingDiff::new(selected_text.to_string());
1351
1352 let mut new_text = String::new();
1353 let mut base_indent = None;
1354 let mut line_indent = None;
1355 let mut first_line = true;
1356
1357 while let Some(chunk) = chunks.next().await {
1358 if response_latency.is_none() {
1359 response_latency = Some(request_start.elapsed());
1360 }
1361 let chunk = chunk?;
1362
1363 let mut lines = chunk.split('\n').peekable();
1364 while let Some(line) = lines.next() {
1365 new_text.push_str(line);
1366 if line_indent.is_none() {
1367 if let Some(non_whitespace_ch_ix) =
1368 new_text.find(|ch: char| !ch.is_whitespace())
1369 {
1370 line_indent = Some(non_whitespace_ch_ix);
1371 base_indent = base_indent.or(line_indent);
1372
1373 let line_indent = line_indent.unwrap();
1374 let base_indent = base_indent.unwrap();
1375 let indent_delta =
1376 line_indent as i32 - base_indent as i32;
1377 let mut corrected_indent_len = cmp::max(
1378 0,
1379 suggested_line_indent.len as i32 + indent_delta,
1380 )
1381 as usize;
1382 if first_line {
1383 corrected_indent_len = corrected_indent_len
1384 .saturating_sub(
1385 selection_start.column as usize,
1386 );
1387 }
1388
1389 let indent_char = suggested_line_indent.char();
1390 let mut indent_buffer = [0; 4];
1391 let indent_str =
1392 indent_char.encode_utf8(&mut indent_buffer);
1393 new_text.replace_range(
1394 ..line_indent,
1395 &indent_str.repeat(corrected_indent_len),
1396 );
1397 }
1398 }
1399
1400 if line_indent.is_some() {
1401 hunks_tx.send(diff.push_new(&new_text)).await?;
1402 new_text.clear();
1403 }
1404
1405 if lines.peek().is_some() {
1406 hunks_tx.send(diff.push_new("\n")).await?;
1407 line_indent = None;
1408 first_line = false;
1409 }
1410 }
1411 }
1412 hunks_tx.send(diff.push_new(&new_text)).await?;
1413 hunks_tx.send(diff.finish()).await?;
1414
1415 anyhow::Ok(())
1416 };
1417
1418 let result = diff.await;
1419
1420 let error_message =
1421 result.as_ref().err().map(|error| error.to_string());
1422 if let Some(telemetry) = telemetry {
1423 telemetry.report_assistant_event(
1424 None,
1425 telemetry_events::AssistantKind::Inline,
1426 model_telemetry_id,
1427 response_latency,
1428 error_message,
1429 );
1430 }
1431
1432 result?;
1433 Ok(())
1434 });
1435
1436 while let Some(hunks) = hunks_rx.next().await {
1437 this.update(&mut cx, |this, cx| {
1438 this.last_equal_ranges.clear();
1439
1440 let transaction = this.buffer.update(cx, |buffer, cx| {
1441 // Avoid grouping assistant edits with user edits.
1442 buffer.finalize_last_transaction(cx);
1443
1444 buffer.start_transaction(cx);
1445 buffer.edit(
1446 hunks.into_iter().filter_map(|hunk| match hunk {
1447 Hunk::Insert { text } => {
1448 let edit_start = snapshot.anchor_after(edit_start);
1449 Some((edit_start..edit_start, text))
1450 }
1451 Hunk::Remove { len } => {
1452 let edit_end = edit_start + len;
1453 let edit_range = snapshot.anchor_after(edit_start)
1454 ..snapshot.anchor_before(edit_end);
1455 edit_start = edit_end;
1456 Some((edit_range, String::new()))
1457 }
1458 Hunk::Keep { len } => {
1459 let edit_end = edit_start + len;
1460 let edit_range = snapshot.anchor_after(edit_start)
1461 ..snapshot.anchor_before(edit_end);
1462 edit_start = edit_end;
1463 this.last_equal_ranges.push(edit_range);
1464 None
1465 }
1466 }),
1467 None,
1468 cx,
1469 );
1470 this.edit_position = snapshot.anchor_after(edit_start);
1471
1472 buffer.end_transaction(cx)
1473 });
1474
1475 if let Some(transaction) = transaction {
1476 if let Some(first_transaction) = this.transaction_id {
1477 // Group all assistant edits into the first transaction.
1478 this.buffer.update(cx, |buffer, cx| {
1479 buffer.merge_transactions(
1480 transaction,
1481 first_transaction,
1482 cx,
1483 )
1484 });
1485 } else {
1486 this.transaction_id = Some(transaction);
1487 this.buffer.update(cx, |buffer, cx| {
1488 buffer.finalize_last_transaction(cx)
1489 });
1490 }
1491 }
1492
1493 this.update_diff(cx);
1494 cx.notify();
1495 })?;
1496 }
1497
1498 diff.await?;
1499
1500 anyhow::Ok(())
1501 };
1502
1503 let result = generate.await;
1504 this.update(&mut cx, |this, cx| {
1505 this.last_equal_ranges.clear();
1506 if let Err(error) = result {
1507 this.status = CodegenStatus::Error(error);
1508 } else {
1509 this.status = CodegenStatus::Done;
1510 }
1511 cx.emit(CodegenEvent::Finished);
1512 cx.notify();
1513 })
1514 .ok();
1515 }
1516 });
1517 cx.notify();
1518 }
1519
1520 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
1521 self.last_equal_ranges.clear();
1522 self.status = CodegenStatus::Done;
1523 self.generation = Task::ready(());
1524 cx.emit(CodegenEvent::Finished);
1525 cx.notify();
1526 }
1527
1528 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
1529 if let Some(transaction_id) = self.transaction_id.take() {
1530 self.buffer
1531 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
1532 }
1533 }
1534
1535 fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
1536 if self.diff.task.is_some() {
1537 self.diff.should_update = true;
1538 } else {
1539 self.diff.should_update = false;
1540
1541 let old_snapshot = self.snapshot.clone();
1542 let old_range = self.range().to_point(&old_snapshot);
1543 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1544 let new_range = self.range().to_point(&new_snapshot);
1545
1546 self.diff.task = Some(cx.spawn(|this, mut cx| async move {
1547 let (deleted_row_ranges, inserted_row_ranges) = cx
1548 .background_executor()
1549 .spawn(async move {
1550 let old_text = old_snapshot
1551 .text_for_range(
1552 Point::new(old_range.start.row, 0)
1553 ..Point::new(
1554 old_range.end.row,
1555 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1556 ),
1557 )
1558 .collect::<String>();
1559 let new_text = new_snapshot
1560 .text_for_range(
1561 Point::new(new_range.start.row, 0)
1562 ..Point::new(
1563 new_range.end.row,
1564 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1565 ),
1566 )
1567 .collect::<String>();
1568
1569 let mut old_row = old_range.start.row;
1570 let mut new_row = new_range.start.row;
1571 let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());
1572
1573 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1574 let mut inserted_row_ranges = Vec::new();
1575 for change in diff.iter_all_changes() {
1576 let line_count = change.value().lines().count() as u32;
1577 match change.tag() {
1578 similar::ChangeTag::Equal => {
1579 old_row += line_count;
1580 new_row += line_count;
1581 }
1582 similar::ChangeTag::Delete => {
1583 let old_end_row = old_row + line_count - 1;
1584 let new_row =
1585 new_snapshot.anchor_before(Point::new(new_row, 0));
1586
1587 if let Some((_, last_deleted_row_range)) =
1588 deleted_row_ranges.last_mut()
1589 {
1590 if *last_deleted_row_range.end() + 1 == old_row {
1591 *last_deleted_row_range =
1592 *last_deleted_row_range.start()..=old_end_row;
1593 } else {
1594 deleted_row_ranges
1595 .push((new_row, old_row..=old_end_row));
1596 }
1597 } else {
1598 deleted_row_ranges.push((new_row, old_row..=old_end_row));
1599 }
1600
1601 old_row += line_count;
1602 }
1603 similar::ChangeTag::Insert => {
1604 let new_end_row = new_row + line_count - 1;
1605 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1606 let end = new_snapshot.anchor_before(Point::new(
1607 new_end_row,
1608 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1609 ));
1610 inserted_row_ranges.push(start..=end);
1611 new_row += line_count;
1612 }
1613 }
1614 }
1615
1616 (deleted_row_ranges, inserted_row_ranges)
1617 })
1618 .await;
1619
1620 this.update(&mut cx, |this, cx| {
1621 this.diff.deleted_row_ranges = deleted_row_ranges;
1622 this.diff.inserted_row_ranges = inserted_row_ranges;
1623 this.diff.task = None;
1624 if this.diff.should_update {
1625 this.update_diff(cx);
1626 }
1627 cx.notify();
1628 })
1629 .ok();
1630 }));
1631 }
1632 }
1633}
1634
1635fn strip_invalid_spans_from_codeblock(
1636 stream: impl Stream<Item = Result<String>>,
1637) -> impl Stream<Item = Result<String>> {
1638 let mut first_line = true;
1639 let mut buffer = String::new();
1640 let mut starts_with_markdown_codeblock = false;
1641 let mut includes_start_or_end_span = false;
1642 stream.filter_map(move |chunk| {
1643 let chunk = match chunk {
1644 Ok(chunk) => chunk,
1645 Err(err) => return future::ready(Some(Err(err))),
1646 };
1647 buffer.push_str(&chunk);
1648
1649 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
1650 includes_start_or_end_span = true;
1651
1652 buffer = buffer
1653 .strip_prefix("<|S|>")
1654 .or_else(|| buffer.strip_prefix("<|S|"))
1655 .unwrap_or(&buffer)
1656 .to_string();
1657 } else if buffer.ends_with("|E|>") {
1658 includes_start_or_end_span = true;
1659 } else if buffer.starts_with("<|")
1660 || buffer.starts_with("<|S")
1661 || buffer.starts_with("<|S|")
1662 || buffer.ends_with('|')
1663 || buffer.ends_with("|E")
1664 || buffer.ends_with("|E|")
1665 {
1666 return future::ready(None);
1667 }
1668
1669 if first_line {
1670 if buffer.is_empty() || buffer == "`" || buffer == "``" {
1671 return future::ready(None);
1672 } else if buffer.starts_with("```") {
1673 starts_with_markdown_codeblock = true;
1674 if let Some(newline_ix) = buffer.find('\n') {
1675 buffer.replace_range(..newline_ix + 1, "");
1676 first_line = false;
1677 } else {
1678 return future::ready(None);
1679 }
1680 }
1681 }
1682
1683 let mut text = buffer.to_string();
1684 if starts_with_markdown_codeblock {
1685 text = text
1686 .strip_suffix("\n```\n")
1687 .or_else(|| text.strip_suffix("\n```"))
1688 .or_else(|| text.strip_suffix("\n``"))
1689 .or_else(|| text.strip_suffix("\n`"))
1690 .or_else(|| text.strip_suffix('\n'))
1691 .unwrap_or(&text)
1692 .to_string();
1693 }
1694
1695 if includes_start_or_end_span {
1696 text = text
1697 .strip_suffix("|E|>")
1698 .or_else(|| text.strip_suffix("E|>"))
1699 .or_else(|| text.strip_prefix("|>"))
1700 .or_else(|| text.strip_prefix('>'))
1701 .unwrap_or(&text)
1702 .to_string();
1703 };
1704
1705 if text.contains('\n') {
1706 first_line = false;
1707 }
1708
1709 let remainder = buffer.split_off(text.len());
1710 let result = if buffer.is_empty() {
1711 None
1712 } else {
1713 Some(Ok(buffer.clone()))
1714 };
1715
1716 buffer = remainder;
1717 future::ready(result)
1718 })
1719}
1720
1721fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1722 ranges.sort_unstable_by(|a, b| {
1723 a.start
1724 .cmp(&b.start, buffer)
1725 .then_with(|| b.end.cmp(&a.end, buffer))
1726 });
1727
1728 let mut ix = 0;
1729 while ix + 1 < ranges.len() {
1730 let b = ranges[ix + 1].clone();
1731 let a = &mut ranges[ix];
1732 if a.end.cmp(&b.start, buffer).is_gt() {
1733 if a.end.cmp(&b.end, buffer).is_lt() {
1734 a.end = b.end;
1735 }
1736 ranges.remove(ix + 1);
1737 } else {
1738 ix += 1;
1739 }
1740 }
1741}
1742
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Sends `response` through the fake provider in random chunks of 1..=10
    /// bytes, letting the executor settle after each chunk, then finishes the
    /// completion and settles once more.
    fn stream_response(
        provider: &FakeCompletionProvider,
        response: &str,
        rng: &mut StdRng,
        cx: &TestAppContext,
    ) {
        let mut remaining = response;
        while !remaining.is_empty() {
            let len = rng.gen_range(1..=cmp::min(remaining.len(), 10));
            let (chunk, rest) = remaining.split_at(len);
            provider.send_completion(chunk.into());
            remaining = rest;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_response(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_response(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_response(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        /// Splits `text` into a stream of `size`-character chunks.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }

        // (model output, expected text after stripping)
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];

        for (input, expected) in cases {
            let output = strip_invalid_spans_from_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(output, expected, "input: {input:?}");
        }
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}