1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp},
10 display_map::{
11 BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
12 },
13 scroll::{Autoscroll, AutoscrollStrategy},
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorStyle, ExcerptRange,
15 GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
18use gpui::{
19 AnyWindowHandle, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight,
20 Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
21 ViewContext, WeakView, WhiteSpace, WindowContext,
22};
23use language::{Buffer, Point, TransactionId};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use rope::Rope;
27use settings::Settings;
28use similar::TextDiff;
29use std::{
30 cmp, future, mem,
31 ops::{Range, RangeInclusive},
32 sync::Arc,
33 time::Instant,
34};
35use theme::ThemeSettings;
36use ui::{prelude::*, Tooltip};
37use workspace::{notifications::NotificationId, Toast, Workspace};
38
39pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
40 cx.set_global(InlineAssistant::new(telemetry));
41}
42
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`; once a
/// push exceeds it, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;

/// Global state for inline assists across all editors (installed by `init`).
pub struct InlineAssistant {
    /// Id handed to the next assist created by `assist` (advanced via `post_inc`).
    next_assist_id: InlineAssistId,
    /// Every assist that has not been finished yet, keyed by its id.
    pending_assists: HashMap<InlineAssistId, PendingInlineAssist>,
    /// Pending assist ids grouped by the editor they were started in.
    pending_assist_ids_by_editor: HashMap<WeakView<Editor>, EditorPendingAssists>,
    /// Previously submitted prompts, oldest first, without duplicates.
    prompt_history: VecDeque<String>,
    /// Forwarded to every `Codegen` this assistant creates.
    telemetry: Option<Arc<Telemetry>>,
}

/// The pending inline assists that belong to a single editor.
struct EditorPendingAssists {
    /// Window that was current when the first of these assists was started;
    /// `cancel_last_inline_assist` only considers editors in the current window.
    window: AnyWindowHandle,
    /// Assist ids in creation order, so the last element is the most recent.
    assist_ids: Vec<InlineAssistId>,
}

impl Global for InlineAssistant {}
59
60impl InlineAssistant {
61 pub fn new(telemetry: Arc<Telemetry>) -> Self {
62 Self {
63 next_assist_id: InlineAssistId::default(),
64 pending_assists: HashMap::default(),
65 pending_assist_ids_by_editor: HashMap::default(),
66 prompt_history: VecDeque::default(),
67 telemetry: Some(telemetry),
68 }
69 }
70
/// Starts a new inline assist anchored at the newest selection of `editor`.
///
/// A non-empty selection is widened to whole lines and becomes a
/// `CodegenKind::Transform` range; an empty selection becomes a
/// `CodegenKind::Generate` position. A prompt-editor block is inserted above
/// the range and a one-line border block below it, and the assist is
/// registered in `pending_assists` and `pending_assist_ids_by_editor`.
///
/// `include_context` is stored on the assist; `start_inline_assist` uses it to
/// decide whether the assistant panel's active context is sent along.
/// Does nothing if the newest selection spans more than one excerpt.
pub fn assist(
    &mut self,
    editor: &View<Editor>,
    workspace: Option<WeakView<Workspace>>,
    include_context: bool,
    cx: &mut WindowContext,
) {
    let selection = editor.read(cx).selections.newest_anchor().clone();
    // An assist must stay within a single excerpt of the multibuffer.
    if selection.start.excerpt_id != selection.end.excerpt_id {
        return;
    }
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

    // Extend the selection to the start and the end of the line.
    let mut point_selection = selection.map(|selection| selection.to_point(&snapshot));
    if point_selection.end > point_selection.start {
        point_selection.start.column = 0;
        // If the selection ends at the start of the line, we don't want to include it.
        if point_selection.end.column == 0 {
            point_selection.end.row -= 1;
        }
        point_selection.end.column = snapshot.line_len(MultiBufferRow(point_selection.end.row));
    }

    // Empty selection: generate new text at the cursor. Otherwise: transform
    // the selected lines.
    let codegen_kind = if point_selection.start == point_selection.end {
        CodegenKind::Generate {
            position: snapshot.anchor_after(point_selection.start),
        }
    } else {
        CodegenKind::Transform {
            range: snapshot.anchor_before(point_selection.start)
                ..snapshot.anchor_after(point_selection.end),
        }
    };

    let inline_assist_id = self.next_assist_id.post_inc();
    let codegen = cx.new_model(|cx| {
        Codegen::new(
            editor.read(cx).buffer().clone(),
            codegen_kind,
            self.telemetry.clone(),
            cx,
        )
    });

    // Shared between the block renderer (which writes the host editor's gutter
    // dimensions into it) and the prompt editor (which reads them for layout).
    let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
    let prompt_editor = cx.new_view(|cx| {
        InlineAssistEditor::new(
            inline_assist_id,
            gutter_dimensions.clone(),
            self.prompt_history.clone(),
            codegen.clone(),
            workspace.clone(),
            cx,
        )
    });
    let (prompt_block_id, end_block_id) = editor.update(cx, |editor, cx| {
        let start_anchor = snapshot.anchor_before(point_selection.start);
        let end_anchor = snapshot.anchor_after(point_selection.end);
        // Collapse the host editor's selection to the start of the assist range.
        editor.change_selections(Some(Autoscroll::newest()), cx, |selections| {
            selections.select_anchor_ranges([start_anchor..start_anchor])
        });
        let block_ids = editor.insert_blocks(
            [
                // The prompt editor, rendered above the first line of the range.
                BlockProperties {
                    style: BlockStyle::Sticky,
                    position: start_anchor,
                    height: prompt_editor.read(cx).height_in_lines,
                    render: build_inline_assist_editor_renderer(
                        &prompt_editor,
                        gutter_dimensions,
                    ),
                    disposition: BlockDisposition::Above,
                },
                // A one-line top border marking the end of the range.
                BlockProperties {
                    style: BlockStyle::Sticky,
                    position: end_anchor,
                    height: 1,
                    render: Box::new(|cx| {
                        v_flex()
                            .h_full()
                            .w_full()
                            .border_t_1()
                            .border_color(cx.theme().status().info_border)
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Below,
                },
            ],
            Some(Autoscroll::Strategy(AutoscrollStrategy::Newest)),
            cx,
        );
        // Assumes ids come back in the order the properties were given — TODO confirm.
        (block_ids[0], block_ids[1])
    });

    self.pending_assists.insert(
        inline_assist_id,
        PendingInlineAssist {
            include_context,
            editor: editor.downgrade(),
            editor_decorations: Some(PendingInlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            codegen: codegen.clone(),
            workspace,
            _subscriptions: vec![
                // Forward prompt-editor events to the global assistant.
                cx.subscribe(&prompt_editor, |inline_assist_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_inline_assistant_event(inline_assist_editor, event, cx)
                    })
                }),
                // A local selection change in the host editor while the prompt
                // editor is focused moves focus back to the host editor.
                cx.subscribe(editor, {
                    let inline_assist_editor = prompt_editor.downgrade();
                    move |editor, event, cx| {
                        if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
                            if let EditorEvent::SelectionsChanged { local } = event {
                                if *local
                                    && inline_assist_editor
                                        .focus_handle(cx)
                                        .contains_focused(cx)
                                {
                                    cx.focus_view(&editor);
                                }
                            }
                        }
                    }
                }),
                // Any codegen change refreshes the highlights and deleted-line blocks.
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                this.update_editor_highlights(&editor, cx);
                                this.update_editor_blocks(&editor, inline_assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => {
                            this.finish_inline_assist(inline_assist_id, false, cx)
                        }
                        CodegenEvent::Finished => {
                            let pending_assist = if let Some(pending_assist) =
                                this.pending_assists.get(&inline_assist_id)
                            {
                                pending_assist
                            } else {
                                return;
                            };

                            // With the decorations hidden there is no prompt UI to show
                            // the error in, so report it through a workspace toast.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if pending_assist.editor_decorations.is_none() {
                                    if let Some(workspace) = pending_assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error =
                                            format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id = NotificationId::identified::<
                                                InlineAssistantError,
                                            >(
                                                inline_assist_id.0
                                            );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // Once the decorations are gone (e.g. after `Dismissed`), the
                            // assist is finished here, keeping the generated text.
                            if pending_assist.editor_decorations.is_none() {
                                this.finish_inline_assist(inline_assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        },
    );

    self.pending_assist_ids_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorPendingAssists {
            window: cx.window_handle(),
            assist_ids: Vec::new(),
        })
        .assist_ids
        .push(inline_assist_id);
    self.update_editor_highlights(editor, cx);
}
270
271 fn handle_inline_assistant_event(
272 &mut self,
273 inline_assist_editor: View<InlineAssistEditor>,
274 event: &InlineAssistEditorEvent,
275 cx: &mut WindowContext,
276 ) {
277 let assist_id = inline_assist_editor.read(cx).id;
278 match event {
279 InlineAssistEditorEvent::Started => {
280 self.start_inline_assist(assist_id, cx);
281 }
282 InlineAssistEditorEvent::Stopped => {
283 self.stop_inline_assist(assist_id, cx);
284 }
285 InlineAssistEditorEvent::Confirmed => {
286 self.finish_inline_assist(assist_id, false, cx);
287 }
288 InlineAssistEditorEvent::Canceled => {
289 self.finish_inline_assist(assist_id, true, cx);
290 }
291 InlineAssistEditorEvent::Dismissed => {
292 self.hide_inline_assist_decorations(assist_id, cx);
293 }
294 InlineAssistEditorEvent::Resized { height_in_lines } => {
295 self.resize_inline_assist(assist_id, *height_in_lines, cx);
296 }
297 }
298 }
299
300 pub fn cancel_last_inline_assist(&mut self, cx: &mut WindowContext) -> bool {
301 for (editor, pending_assists) in &self.pending_assist_ids_by_editor {
302 if pending_assists.window == cx.window_handle() {
303 if let Some(editor) = editor.upgrade() {
304 if editor.read(cx).is_focused(cx) {
305 if let Some(assist_id) = pending_assists.assist_ids.last().copied() {
306 self.finish_inline_assist(assist_id, true, cx);
307 return true;
308 }
309 }
310 }
311 }
312 }
313 false
314 }
315
316 fn finish_inline_assist(
317 &mut self,
318 assist_id: InlineAssistId,
319 undo: bool,
320 cx: &mut WindowContext,
321 ) {
322 self.hide_inline_assist_decorations(assist_id, cx);
323
324 if let Some(pending_assist) = self.pending_assists.remove(&assist_id) {
325 if let hash_map::Entry::Occupied(mut entry) = self
326 .pending_assist_ids_by_editor
327 .entry(pending_assist.editor.clone())
328 {
329 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
330 if entry.get().assist_ids.is_empty() {
331 entry.remove();
332 }
333 }
334
335 if let Some(editor) = pending_assist.editor.upgrade() {
336 self.update_editor_highlights(&editor, cx);
337
338 if undo {
339 pending_assist
340 .codegen
341 .update(cx, |codegen, cx| codegen.undo(cx));
342 }
343 }
344 }
345 }
346
347 fn hide_inline_assist_decorations(
348 &mut self,
349 assist_id: InlineAssistId,
350 cx: &mut WindowContext,
351 ) -> bool {
352 let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else {
353 return false;
354 };
355 let Some(editor) = pending_assist.editor.upgrade() else {
356 return false;
357 };
358 let Some(decorations) = pending_assist.editor_decorations.take() else {
359 return false;
360 };
361
362 editor.update(cx, |editor, cx| {
363 let mut to_remove = decorations.removed_line_block_ids;
364 to_remove.insert(decorations.prompt_block_id);
365 to_remove.insert(decorations.end_block_id);
366 editor.remove_blocks(to_remove, None, cx);
367 if decorations
368 .prompt_editor
369 .focus_handle(cx)
370 .contains_focused(cx)
371 {
372 editor.focus(cx);
373 }
374 });
375
376 self.update_editor_highlights(&editor, cx);
377 true
378 }
379
380 fn resize_inline_assist(
381 &mut self,
382 assist_id: InlineAssistId,
383 height_in_lines: u8,
384 cx: &mut WindowContext,
385 ) {
386 if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
387 if let Some(editor) = pending_assist.editor.upgrade() {
388 if let Some(decorations) = pending_assist.editor_decorations.as_ref() {
389 let gutter_dimensions =
390 decorations.prompt_editor.read(cx).gutter_dimensions.clone();
391 let mut new_blocks = HashMap::default();
392 new_blocks.insert(
393 decorations.prompt_block_id,
394 (
395 Some(height_in_lines),
396 build_inline_assist_editor_renderer(
397 &decorations.prompt_editor,
398 gutter_dimensions,
399 ),
400 ),
401 );
402 editor.update(cx, |editor, cx| {
403 editor
404 .display_map
405 .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
406 });
407 }
408 }
409 }
410 }
411
/// Starts (or restarts) generation for `assist_id` with the prompt currently
/// typed into its prompt editor.
///
/// The output of any previous run is undone first. The prompt is then recorded
/// in `prompt_history`, and a request is built from it. When `include_context`
/// is set, the messages of the assistant panel's active context come before the
/// prompt. The request is handed to the assist's `Codegen`. If the assist
/// range does not resolve into a single buffer, the assist is finished
/// without generating.
fn start_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
    {
        pending_assist
    } else {
        return;
    };

    // Discard the text produced by a previous run before generating again.
    pending_assist
        .codegen
        .update(cx, |codegen, cx| codegen.undo(cx));

    // Without decorations there is no prompt editor to read the prompt from.
    let Some(user_prompt) = pending_assist
        .editor_decorations
        .as_ref()
        .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
    else {
        return;
    };

    let context = if pending_assist.include_context {
        pending_assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?.read(cx);
            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
            assistant_panel.read(cx).active_context(cx)
        })
    } else {
        None
    };

    let editor = if let Some(editor) = pending_assist.editor.upgrade() {
        editor
    } else {
        return;
    };

    // Worktree root names joined with "/", passed to the prompt builder.
    let project_name = pending_assist.workspace.as_ref().and_then(|workspace| {
        let workspace = workspace.upgrade()?;
        Some(
            workspace
                .read(cx)
                .project()
                .read(cx)
                .worktree_root_names(cx)
                .collect::<Vec<&str>>()
                .join("/"),
        )
    });

    // Move the prompt to the newest end of the history (deduplicated) and cap
    // the history length.
    self.prompt_history.retain(|prompt| *prompt != user_prompt);
    self.prompt_history.push_back(user_prompt.clone());
    if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
        self.prompt_history.pop_front();
    }

    // Resolve the assist range into a single underlying buffer; a range whose
    // ends fall in different buffers (or do not resolve) ends the assist.
    let codegen = pending_assist.codegen.clone();
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    let range = codegen.read(cx).range();
    let start = snapshot.point_to_buffer_offset(range.start);
    let end = snapshot.point_to_buffer_offset(range.end);
    let (buffer, range) = if let Some((start, end)) = start.zip(end) {
        let (start_buffer, start_buffer_offset) = start;
        let (end_buffer, end_buffer_offset) = end;
        if start_buffer.remote_id() == end_buffer.remote_id() {
            (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
        } else {
            self.finish_inline_assist(assist_id, false, cx);
            return;
        }
    } else {
        self.finish_inline_assist(assist_id, false, cx);
        return;
    };

    // Plain text is treated the same as having no language.
    let language = buffer.language_at(range.start);
    let language_name = if let Some(language) = language.as_ref() {
        if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
            None
        } else {
            Some(language.name())
        }
    } else {
        None
    };

    // Higher Temperature increases the randomness of model outputs.
    // If Markdown or No Language is Known, increase the randomness for more creative output
    // If Code, decrease temperature to get more deterministic outputs
    let temperature = if let Some(language) = language_name.clone() {
        if language.as_ref() == "Markdown" {
            1.0
        } else {
            0.5
        }
    } else {
        1.0
    };

    // Build the prompt text on the background executor.
    let prompt = cx.background_executor().spawn(async move {
        let language_name = language_name.as_deref();
        generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
    });

    // Messages from the active assistant context (if any) precede the prompt.
    let mut messages = Vec::new();
    if let Some(context) = context {
        let request = context.read(cx).to_completion_request(cx);
        messages = request.messages;
    }
    let model = CompletionProvider::global(cx).model();

    cx.spawn(|mut cx| async move {
        let prompt = prompt.await?;

        messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: prompt,
        });

        let request = LanguageModelRequest {
            model,
            messages,
            // NOTE(review): presumably an end marker the prompt template asks the
            // model to emit — confirm against `generate_content_prompt`.
            stop: vec!["|END|>".to_string()],
            temperature,
        };

        codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
        anyhow::Ok(())
    })
    // Failures (prompt building, dropped codegen) are only logged.
    .detach_and_log_err(cx);
}
542
543 fn stop_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
544 let pending_assist = if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id)
545 {
546 pending_assist
547 } else {
548 return;
549 };
550
551 pending_assist
552 .codegen
553 .update(cx, |codegen, cx| codegen.stop(cx));
554 }
555
/// Recomputes every inline-assist highlight of `editor` from the assists
/// currently pending in it: gutter highlights, faded text and inserted-row
/// backgrounds. Highlight kinds with no ranges are cleared.
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
    let mut gutter_pending_ranges = Vec::new();
    let mut gutter_transformed_ranges = Vec::new();
    let mut foreground_ranges = Vec::new();
    let mut inserted_row_ranges = Vec::new();
    // Fallback so the loop below also runs (zero times) for editors without assists.
    let empty_inline_assist_ids = Vec::new();
    let inline_assist_ids = self
        .pending_assist_ids_by_editor
        .get(&editor.downgrade())
        .map_or(&empty_inline_assist_ids, |pending_assists| {
            &pending_assists.assist_ids
        });

    for inline_assist_id in inline_assist_ids {
        if let Some(pending_assist) = self.pending_assists.get(inline_assist_id) {
            let codegen = pending_assist.codegen.read(cx);
            // Presumably text the diff found unchanged; rendered faded below.
            foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

            // `edit_position` presumably marks how far generation has progressed:
            // the part after it is still pending, the part before it is done.
            if codegen.edit_position != codegen.range().end {
                gutter_pending_ranges.push(codegen.edit_position..codegen.range().end);
            }

            if codegen.range().start != codegen.edit_position {
                gutter_transformed_ranges.push(codegen.range().start..codegen.edit_position);
            }

            // Inserted rows are only highlighted while the assist's UI is shown.
            if pending_assist.editor_decorations.is_some() {
                inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
            }
        }
    }

    // `merge_ranges` is defined elsewhere; presumably it coalesces overlapping ranges.
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    merge_ranges(&mut foreground_ranges, &snapshot);
    merge_ranges(&mut gutter_pending_ranges, &snapshot);
    merge_ranges(&mut gutter_transformed_ranges, &snapshot);
    editor.update(cx, |editor, cx| {
        // Marker types keying the two independent gutter highlight layers.
        enum GutterPendingRange {}
        if gutter_pending_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterPendingRange>(cx);
        } else {
            editor.highlight_gutter::<GutterPendingRange>(
                &gutter_pending_ranges,
                |cx| cx.theme().status().info_background,
                cx,
            )
        }

        enum GutterTransformedRange {}
        if gutter_transformed_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
        } else {
            editor.highlight_gutter::<GutterTransformedRange>(
                &gutter_transformed_ranges,
                |cx| cx.theme().status().info,
                cx,
            )
        }

        if foreground_ranges.is_empty() {
            editor.clear_highlights::<PendingInlineAssist>(cx);
        } else {
            editor.highlight_text::<PendingInlineAssist>(
                foreground_ranges,
                HighlightStyle {
                    fade_out: Some(0.6),
                    ..Default::default()
                },
                cx,
            );
        }

        // Row highlights are rebuilt from scratch on every call.
        editor.clear_row_highlights::<PendingInlineAssist>();
        for row_range in inserted_row_ranges {
            editor.highlight_rows::<PendingInlineAssist>(
                row_range,
                Some(cx.theme().status().info_background),
                false,
                cx,
            );
        }
    });
}
639
640 fn update_editor_blocks(
641 &mut self,
642 editor: &View<Editor>,
643 assist_id: InlineAssistId,
644 cx: &mut WindowContext,
645 ) {
646 let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) else {
647 return;
648 };
649 let Some(decorations) = pending_assist.editor_decorations.as_mut() else {
650 return;
651 };
652
653 let codegen = pending_assist.codegen.read(cx);
654 let old_snapshot = codegen.snapshot.clone();
655 let old_buffer = codegen.old_buffer.clone();
656 let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
657
658 editor.update(cx, |editor, cx| {
659 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
660 editor.remove_blocks(old_blocks, None, cx);
661
662 let mut new_blocks = Vec::new();
663 for (new_row, old_row_range) in deleted_row_ranges {
664 let (_, buffer_start) = old_snapshot
665 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
666 .unwrap();
667 let (_, buffer_end) = old_snapshot
668 .point_to_buffer_offset(Point::new(
669 *old_row_range.end(),
670 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
671 ))
672 .unwrap();
673
674 let deleted_lines_editor = cx.new_view(|cx| {
675 let multi_buffer = cx.new_model(|_| {
676 MultiBuffer::without_headers(0, language::Capability::ReadOnly)
677 });
678 multi_buffer.update(cx, |multi_buffer, cx| {
679 multi_buffer.push_excerpts(
680 old_buffer.clone(),
681 Some(ExcerptRange {
682 context: buffer_start..buffer_end,
683 primary: None,
684 }),
685 cx,
686 );
687 });
688
689 enum DeletedLines {}
690 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
691 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
692 editor.set_show_wrap_guides(false, cx);
693 editor.set_show_gutter(false, cx);
694 editor.scroll_manager.set_forbid_vertical_scroll(true);
695 editor.set_read_only(true);
696 editor.highlight_rows::<DeletedLines>(
697 Anchor::min()..=Anchor::max(),
698 Some(cx.theme().status().deleted_background),
699 false,
700 cx,
701 );
702 editor
703 });
704
705 let height = deleted_lines_editor
706 .update(cx, |editor, cx| editor.max_point(cx).row().0 as u8 + 1);
707 new_blocks.push(BlockProperties {
708 position: new_row,
709 height,
710 style: BlockStyle::Flex,
711 render: Box::new(move |cx| {
712 div()
713 .bg(cx.theme().status().deleted_background)
714 .size_full()
715 .pl(cx.gutter_dimensions.full_width())
716 .child(deleted_lines_editor.clone())
717 .into_any_element()
718 }),
719 disposition: BlockDisposition::Above,
720 });
721 }
722
723 decorations.removed_line_block_ids = editor
724 .insert_blocks(new_blocks, None, cx)
725 .into_iter()
726 .collect();
727 })
728 }
729}
730
731fn build_inline_assist_editor_renderer(
732 editor: &View<InlineAssistEditor>,
733 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
734) -> RenderBlock {
735 let editor = editor.clone();
736 Box::new(move |cx: &mut BlockContext| {
737 *gutter_dimensions.lock() = *cx.gutter_dimensions;
738 editor.clone().into_any_element()
739 })
740}
741
/// Identifier of one inline assist; handed out sequentially from zero.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
752
/// Events an `InlineAssistEditor` emits to the `InlineAssistant` that owns it.
enum InlineAssistEditorEvent {
    /// Start, or restart, generation with the current prompt.
    Started,
    /// Interrupt an in-progress generation; text produced so far is kept.
    Stopped,
    /// Accept the generated text and finish the assist.
    Confirmed,
    /// Undo the generated text and finish the assist.
    Canceled,
    /// Hide the prompt UI while generation keeps running.
    Dismissed,
    /// The prompt editor needs a different block height.
    Resized {
        /// New height of the prompt block, in lines.
        height_in_lines: u8,
    },
}
761
/// The prompt UI rendered in a block above an inline assist's range.
struct InlineAssistEditor {
    /// Id of the assist this prompt UI belongs to.
    id: InlineAssistId,
    /// Height last reported through `InlineAssistEditorEvent::Resized`, in lines.
    height_in_lines: u8,
    /// Editor holding the prompt text typed by the user.
    prompt_editor: View<Editor>,
    /// Whether the prompt was edited since codegen last reached `Done`/`Error`;
    /// selects between the "restart" and "confirm" actions.
    edited_since_done: bool,
    /// Host editor gutter dimensions, written by the block renderer.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Copy of the assistant's prompt history taken when this UI was created.
    prompt_history: VecDeque<String>,
    /// Position in `prompt_history` while navigating it with up/down, if any.
    prompt_history_ix: Option<usize>,
    /// Prompt text restored when navigating down past the newest history entry.
    pending_prompt: String,
    codegen: Model<Codegen>,
    workspace: Option<WeakView<Workspace>>,
    _subscriptions: Vec<Subscription>,
}

impl EventEmitter<InlineAssistEditorEvent> for InlineAssistEditor {}
777
impl Render for InlineAssistEditor {
    /// Renders the prompt row. The row has a leading area as wide as the host
    /// editor's gutter (context button and error icon), the prompt editor, and
    /// trailing buttons that depend on the codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();

        // Trailing buttons per status:
        // - Idle: start, cancel
        // - Pending: stop, cancel
        // - Done / Error: restart (if edited since done) or confirm, cancel
        let buttons = match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("start", IconName::Sparkle)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Started)),
                        ),
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Stopped)),
                        ),
                    // While pending, `menu::Cancel` stops rather than cancels (see
                    // `cancel`), so this tooltip names no keybinding.
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    if self.edited_since_done {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .icon_size(IconSize::XSmall)
                            .size(ButtonSize::None)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(InlineAssistEditorEvent::Started);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .size(ButtonSize::None)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(InlineAssistEditorEvent::Confirmed);
                            }))
                    },
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(InlineAssistEditorEvent::Canceled)),
                        ),
                ]
            }
        };

        v_flex().h_full().w_full().justify_end().child(
            h_flex()
                .bg(cx.theme().colors().editor_background)
                .border_y_1()
                .border_color(cx.theme().status().info_border)
                .py_1p5()
                .w_full()
                .on_action(cx.listener(Self::confirm))
                .on_action(cx.listener(Self::cancel))
                .on_action(cx.listener(Self::move_up))
                .on_action(cx.listener(Self::move_down))
                .child(
                    // Leading area sized from the host editor's gutter so the prompt
                    // text lines up with the buffer text.
                    h_flex()
                        .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                        // .pr(gutter_dimensions.fold_area_width())
                        .justify_center()
                        .gap_2()
                        // Context button: opens the assistant panel; its tooltip shows
                        // the active context's token count when one is available.
                        .children(self.workspace.clone().map(|workspace| {
                            IconButton::new("context", IconName::Context)
                                .size(ButtonSize::None)
                                .icon_size(IconSize::XSmall)
                                .icon_color(Color::Muted)
                                .on_click({
                                    let workspace = workspace.clone();
                                    cx.listener(move |_, _, cx| {
                                        workspace
                                            .update(cx, |workspace, cx| {
                                                workspace.focus_panel::<AssistantPanel>(cx);
                                            })
                                            .ok();
                                    })
                                })
                                .tooltip(move |cx| {
                                    let token_count = workspace.upgrade().and_then(|workspace| {
                                        let panel =
                                            workspace.read(cx).panel::<AssistantPanel>(cx)?;
                                        let context = panel.read(cx).active_context(cx)?;
                                        context.read(cx).token_count()
                                    });
                                    if let Some(token_count) = token_count {
                                        Tooltip::with_meta(
                                            format!(
                                                "{} Additional Context Tokens from Assistant",
                                                token_count
                                            ),
                                            Some(&crate::ToggleFocus),
                                            "Click to open…",
                                            cx,
                                        )
                                    } else {
                                        Tooltip::for_action(
                                            "Toggle Assistant Panel",
                                            &crate::ToggleFocus,
                                            cx,
                                        )
                                    }
                                })
                        }))
                        // Error icon whose tooltip carries the codegen error message.
                        .children(
                            if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
                                let error_message = SharedString::from(error.to_string());
                                Some(
                                    div()
                                        .id("error")
                                        .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                        .child(
                                            Icon::new(IconName::XCircle)
                                                .size(IconSize::Small)
                                                .color(Color::Error),
                                        ),
                                )
                            } else {
                                None
                            },
                        ),
                )
                .child(div().flex_1().child(self.render_prompt_editor(cx)))
                .child(h_flex().gap_2().pr_4().children(buttons)),
        )
    }
}
947
948impl FocusableView for InlineAssistEditor {
949 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
950 self.prompt_editor.focus_handle(cx)
951 }
952}
953
954impl InlineAssistEditor {
955 const MAX_LINES: u8 = 8;
956
957 #[allow(clippy::too_many_arguments)]
958 fn new(
959 id: InlineAssistId,
960 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
961 prompt_history: VecDeque<String>,
962 codegen: Model<Codegen>,
963 workspace: Option<WeakView<Workspace>>,
964 cx: &mut ViewContext<Self>,
965 ) -> Self {
966 let prompt_editor = cx.new_view(|cx| {
967 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
968 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
969 editor.set_placeholder_text("Add a prompt…", cx);
970 editor
971 });
972 cx.focus_view(&prompt_editor);
973
974 let subscriptions = vec![
975 cx.observe(&codegen, Self::handle_codegen_changed),
976 cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
977 cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
978 ];
979
980 let mut this = Self {
981 id,
982 height_in_lines: 1,
983 prompt_editor,
984 edited_since_done: false,
985 gutter_dimensions,
986 prompt_history,
987 prompt_history_ix: None,
988 pending_prompt: String::new(),
989 codegen,
990 workspace,
991 _subscriptions: subscriptions,
992 };
993 this.count_lines(cx);
994 this
995 }
996
997 fn prompt(&self, cx: &AppContext) -> String {
998 self.prompt_editor.read(cx).text(cx)
999 }
1000
1001 fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
1002 let height_in_lines = cmp::max(
1003 2, // Make the editor at least two lines tall, to account for padding and buttons.
1004 cmp::min(
1005 self.prompt_editor
1006 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
1007 Self::MAX_LINES as u32,
1008 ),
1009 ) as u8;
1010
1011 if height_in_lines != self.height_in_lines {
1012 self.height_in_lines = height_in_lines;
1013 cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
1014 }
1015 }
1016
1017 fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
1018 self.count_lines(cx);
1019 }
1020
1021 fn handle_prompt_editor_events(
1022 &mut self,
1023 _: View<Editor>,
1024 event: &EditorEvent,
1025 cx: &mut ViewContext<Self>,
1026 ) {
1027 match event {
1028 EditorEvent::Edited => {
1029 self.edited_since_done = true;
1030 self.pending_prompt = self.prompt_editor.read(cx).text(cx);
1031 cx.notify();
1032 }
1033 EditorEvent::Blurred => {
1034 if let CodegenStatus::Idle = &self.codegen.read(cx).status {
1035 let assistant_panel_is_focused = self
1036 .workspace
1037 .as_ref()
1038 .and_then(|workspace| {
1039 let panel =
1040 workspace.upgrade()?.read(cx).panel::<AssistantPanel>(cx)?;
1041 Some(panel.focus_handle(cx).contains_focused(cx))
1042 })
1043 .unwrap_or(false);
1044
1045 if !assistant_panel_is_focused {
1046 cx.emit(InlineAssistEditorEvent::Canceled);
1047 }
1048 }
1049 }
1050 _ => {}
1051 }
1052 }
1053
1054 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1055 match &self.codegen.read(cx).status {
1056 CodegenStatus::Idle => {
1057 self.prompt_editor
1058 .update(cx, |editor, _| editor.set_read_only(false));
1059 }
1060 CodegenStatus::Pending => {
1061 self.prompt_editor
1062 .update(cx, |editor, _| editor.set_read_only(true));
1063 }
1064 CodegenStatus::Done | CodegenStatus::Error(_) => {
1065 self.edited_since_done = false;
1066 self.prompt_editor
1067 .update(cx, |editor, _| editor.set_read_only(false));
1068 }
1069 }
1070 }
1071
1072 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1073 match &self.codegen.read(cx).status {
1074 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1075 cx.emit(InlineAssistEditorEvent::Canceled);
1076 }
1077 CodegenStatus::Pending => {
1078 cx.emit(InlineAssistEditorEvent::Stopped);
1079 }
1080 }
1081 }
1082
1083 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1084 match &self.codegen.read(cx).status {
1085 CodegenStatus::Idle => {
1086 cx.emit(InlineAssistEditorEvent::Started);
1087 }
1088 CodegenStatus::Pending => {
1089 cx.emit(InlineAssistEditorEvent::Dismissed);
1090 }
1091 CodegenStatus::Done | CodegenStatus::Error(_) => {
1092 if self.edited_since_done {
1093 cx.emit(InlineAssistEditorEvent::Started);
1094 } else {
1095 cx.emit(InlineAssistEditorEvent::Confirmed);
1096 }
1097 }
1098 }
1099 }
1100
1101 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1102 if let Some(ix) = self.prompt_history_ix {
1103 if ix > 0 {
1104 self.prompt_history_ix = Some(ix - 1);
1105 let prompt = self.prompt_history[ix - 1].clone();
1106 self.set_prompt(&prompt, cx);
1107 }
1108 } else if !self.prompt_history.is_empty() {
1109 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1110 let prompt = self.prompt_history[self.prompt_history.len() - 1].clone();
1111 self.set_prompt(&prompt, cx);
1112 }
1113 }
1114
1115 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1116 if let Some(ix) = self.prompt_history_ix {
1117 if ix < self.prompt_history.len() - 1 {
1118 self.prompt_history_ix = Some(ix + 1);
1119 let prompt = self.prompt_history[ix + 1].clone();
1120 self.set_prompt(&prompt, cx);
1121 } else {
1122 self.prompt_history_ix = None;
1123 let pending_prompt = self.pending_prompt.clone();
1124 self.set_prompt(&pending_prompt, cx);
1125 }
1126 }
1127 }
1128
1129 fn set_prompt(&mut self, prompt: &str, cx: &mut ViewContext<Self>) {
1130 self.prompt_editor.update(cx, |editor, cx| {
1131 editor.buffer().update(cx, |buffer, cx| {
1132 let len = buffer.len(cx);
1133 buffer.edit([(0..len, prompt)], None, cx);
1134 });
1135 });
1136 }
1137
1138 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1139 let settings = ThemeSettings::get_global(cx);
1140 let text_style = TextStyle {
1141 color: if self.prompt_editor.read(cx).read_only(cx) {
1142 cx.theme().colors().text_disabled
1143 } else {
1144 cx.theme().colors().text
1145 },
1146 font_family: settings.ui_font.family.clone(),
1147 font_features: settings.ui_font.features.clone(),
1148 font_size: rems(0.875).into(),
1149 font_weight: FontWeight::NORMAL,
1150 font_style: FontStyle::Normal,
1151 line_height: relative(1.3),
1152 background_color: None,
1153 underline: None,
1154 strikethrough: None,
1155 white_space: WhiteSpace::Normal,
1156 };
1157 EditorElement::new(
1158 &self.prompt_editor,
1159 EditorStyle {
1160 background: cx.theme().colors().editor_background,
1161 local_player: cx.theme().players().local(),
1162 text: text_style,
1163 ..Default::default()
1164 },
1165 )
1166 }
1167}
1168
/// Bookkeeping for one inline assist that has not been finished yet.
///
/// NOTE(review): this struct is created and consumed outside this part of the
/// file; the field notes below describe only what the field types show.
struct PendingInlineAssist {
    /// Weak handle to the editor this assist belongs to.
    editor: WeakView<Editor>,
    /// Prompt/diff blocks currently inserted into the editor, if any.
    editor_decorations: Option<PendingInlineAssistDecorations>,
    /// Generator that streams model output into the buffer.
    codegen: Model<Codegen>,
    /// Held only to keep these subscriptions alive while the assist is pending.
    _subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Presumably whether extra (assistant panel) context is added to the
    /// request — TODO confirm against the code that builds the prompt.
    include_context: bool,
}
1177
/// Editor blocks inserted to display a pending inline assist.
struct PendingInlineAssistDecorations {
    /// Block presumably hosting `prompt_editor` — TODO confirm.
    prompt_block_id: BlockId,
    /// The prompt input view of this assist.
    prompt_editor: View<InlineAssistEditor>,
    /// Blocks presumably showing lines removed by the generated diff — TODO confirm.
    removed_line_block_ids: HashSet<BlockId>,
    /// Block presumably marking the end of the assist region — TODO confirm.
    end_block_id: BlockId,
}
1184
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended: it completed, failed, or was stopped via `Codegen::stop`.
    Finished,
    /// The buffer transaction holding this codegen's edits was undone.
    Undone,
}
1190
/// What a [`Codegen`] should produce.
#[derive(Clone)]
pub enum CodegenKind {
    /// Replace the text in `range` with generated text.
    Transform { range: Range<Anchor> },
    /// Insert generated text at `position` (handled as an empty range there).
    Generate { position: Anchor },
}
1196
1197impl CodegenKind {
1198 fn range(&self, snapshot: &MultiBufferSnapshot) -> Range<Anchor> {
1199 match self {
1200 CodegenKind::Transform { range } => range.clone(),
1201 CodegenKind::Generate { position } => position.bias_left(snapshot)..*position,
1202 }
1203 }
1204}
1205
/// Streams a completion from the global `CompletionProvider` into a
/// multibuffer as diff-driven edits, and tracks a line-level diff of the result.
pub struct Codegen {
    /// Multibuffer that generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Separate local copy of the underlying buffer, taken when this `Codegen`
    /// was created (before any generated edits).
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at creation; `kind` is resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Rewrite a range, or insert at a position.
    kind: CodegenKind,
    /// Position just after the most recently applied hunk.
    edit_position: Anchor,
    /// Ranges the most recently applied batch of hunks left unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction created by the generation; later assistant edits are
    /// merged into it.
    transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// Task driving the current generation; replacing it drops (cancels) it.
    generation: Task<()>,
    /// Line-level diff between `snapshot` and the current buffer contents.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to notice when `transaction_id`
    /// is undone.
    _subscription: gpui::Subscription,
}
1220
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// Created; `start` has not been called yet.
    Idle,
    /// A generation started by `start` is still streaming.
    Pending,
    /// The generation completed successfully or was stopped via `stop`.
    Done,
    /// The generation failed with this error.
    Error(anyhow::Error),
}
1227
/// Line-level diff between the original snapshot and the current buffer,
/// recomputed on a background task by `Codegen::update_diff`.
#[derive(Default)]
struct Diff {
    /// In-flight recomputation, if any.
    task: Option<Task<()>>,
    /// Set when another recomputation was requested while `task` was running.
    should_update: bool,
    /// Deleted lines: an anchor in the current buffer where they used to be,
    /// paired with their inclusive row range in the original snapshot.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted lines, as inclusive anchor ranges in the current buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
1235
// Lets `Codegen` emit `CodegenEvent`s to subscribers via `cx.emit`.
impl EventEmitter<CodegenEvent> for Codegen {}
1237
1238impl Codegen {
1239 pub fn new(
1240 buffer: Model<MultiBuffer>,
1241 kind: CodegenKind,
1242 telemetry: Option<Arc<Telemetry>>,
1243 cx: &mut ModelContext<Self>,
1244 ) -> Self {
1245 let snapshot = buffer.read(cx).snapshot(cx);
1246
1247 let (old_buffer, _, _) = buffer
1248 .read(cx)
1249 .range_to_buffer_ranges(kind.range(&snapshot), cx)
1250 .pop()
1251 .unwrap();
1252 let old_buffer = cx.new_model(|cx| {
1253 let old_buffer = old_buffer.read(cx);
1254 let text = old_buffer.as_rope().clone();
1255 let line_ending = old_buffer.line_ending();
1256 let language = old_buffer.language().cloned();
1257 let language_registry = old_buffer.language_registry();
1258
1259 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
1260 buffer.set_language(language, cx);
1261 if let Some(language_registry) = language_registry {
1262 buffer.set_language_registry(language_registry)
1263 }
1264 buffer
1265 });
1266
1267 Self {
1268 buffer: buffer.clone(),
1269 old_buffer,
1270 edit_position: kind.range(&snapshot).start,
1271 snapshot,
1272 kind,
1273 last_equal_ranges: Default::default(),
1274 transaction_id: Default::default(),
1275 status: CodegenStatus::Idle,
1276 generation: Task::ready(()),
1277 diff: Diff::default(),
1278 telemetry,
1279 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
1280 }
1281 }
1282
1283 fn handle_buffer_event(
1284 &mut self,
1285 _buffer: Model<MultiBuffer>,
1286 event: &multi_buffer::Event,
1287 cx: &mut ModelContext<Self>,
1288 ) {
1289 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
1290 if self.transaction_id == Some(*transaction_id) {
1291 self.transaction_id = None;
1292 self.generation = Task::ready(());
1293 cx.emit(CodegenEvent::Undone);
1294 }
1295 }
1296 }
1297
1298 pub fn range(&self) -> Range<Anchor> {
1299 self.kind.range(&self.snapshot)
1300 }
1301
1302 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
1303 &self.last_equal_ranges
1304 }
1305
1306 pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
1307 let range = self.range();
1308 let snapshot = self.snapshot.clone();
1309 let selected_text = snapshot
1310 .text_for_range(range.start..range.end)
1311 .collect::<Rope>();
1312
1313 let selection_start = range.start.to_point(&snapshot);
1314 let suggested_line_indent = snapshot
1315 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
1316 .into_values()
1317 .next()
1318 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
1319
1320 let model_telemetry_id = prompt.model.telemetry_id();
1321 let response = CompletionProvider::global(cx).complete(prompt);
1322 let telemetry = self.telemetry.clone();
1323 self.edit_position = range.start;
1324 self.diff = Diff::default();
1325 self.status = CodegenStatus::Pending;
1326 self.generation = cx.spawn(|this, mut cx| {
1327 async move {
1328 let generate = async {
1329 let mut edit_start = range.start.to_offset(&snapshot);
1330
1331 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
1332 let diff: Task<anyhow::Result<()>> =
1333 cx.background_executor().spawn(async move {
1334 let mut response_latency = None;
1335 let request_start = Instant::now();
1336 let diff = async {
1337 let chunks = strip_invalid_spans_from_codeblock(response.await?);
1338 futures::pin_mut!(chunks);
1339 let mut diff = StreamingDiff::new(selected_text.to_string());
1340
1341 let mut new_text = String::new();
1342 let mut base_indent = None;
1343 let mut line_indent = None;
1344 let mut first_line = true;
1345
1346 while let Some(chunk) = chunks.next().await {
1347 if response_latency.is_none() {
1348 response_latency = Some(request_start.elapsed());
1349 }
1350 let chunk = chunk?;
1351
1352 let mut lines = chunk.split('\n').peekable();
1353 while let Some(line) = lines.next() {
1354 new_text.push_str(line);
1355 if line_indent.is_none() {
1356 if let Some(non_whitespace_ch_ix) =
1357 new_text.find(|ch: char| !ch.is_whitespace())
1358 {
1359 line_indent = Some(non_whitespace_ch_ix);
1360 base_indent = base_indent.or(line_indent);
1361
1362 let line_indent = line_indent.unwrap();
1363 let base_indent = base_indent.unwrap();
1364 let indent_delta =
1365 line_indent as i32 - base_indent as i32;
1366 let mut corrected_indent_len = cmp::max(
1367 0,
1368 suggested_line_indent.len as i32 + indent_delta,
1369 )
1370 as usize;
1371 if first_line {
1372 corrected_indent_len = corrected_indent_len
1373 .saturating_sub(
1374 selection_start.column as usize,
1375 );
1376 }
1377
1378 let indent_char = suggested_line_indent.char();
1379 let mut indent_buffer = [0; 4];
1380 let indent_str =
1381 indent_char.encode_utf8(&mut indent_buffer);
1382 new_text.replace_range(
1383 ..line_indent,
1384 &indent_str.repeat(corrected_indent_len),
1385 );
1386 }
1387 }
1388
1389 if line_indent.is_some() {
1390 hunks_tx.send(diff.push_new(&new_text)).await?;
1391 new_text.clear();
1392 }
1393
1394 if lines.peek().is_some() {
1395 hunks_tx.send(diff.push_new("\n")).await?;
1396 line_indent = None;
1397 first_line = false;
1398 }
1399 }
1400 }
1401 hunks_tx.send(diff.push_new(&new_text)).await?;
1402 hunks_tx.send(diff.finish()).await?;
1403
1404 anyhow::Ok(())
1405 };
1406
1407 let result = diff.await;
1408
1409 let error_message =
1410 result.as_ref().err().map(|error| error.to_string());
1411 if let Some(telemetry) = telemetry {
1412 telemetry.report_assistant_event(
1413 None,
1414 telemetry_events::AssistantKind::Inline,
1415 model_telemetry_id,
1416 response_latency,
1417 error_message,
1418 );
1419 }
1420
1421 result?;
1422 Ok(())
1423 });
1424
1425 while let Some(hunks) = hunks_rx.next().await {
1426 this.update(&mut cx, |this, cx| {
1427 this.last_equal_ranges.clear();
1428
1429 let transaction = this.buffer.update(cx, |buffer, cx| {
1430 // Avoid grouping assistant edits with user edits.
1431 buffer.finalize_last_transaction(cx);
1432
1433 buffer.start_transaction(cx);
1434 buffer.edit(
1435 hunks.into_iter().filter_map(|hunk| match hunk {
1436 Hunk::Insert { text } => {
1437 let edit_start = snapshot.anchor_after(edit_start);
1438 Some((edit_start..edit_start, text))
1439 }
1440 Hunk::Remove { len } => {
1441 let edit_end = edit_start + len;
1442 let edit_range = snapshot.anchor_after(edit_start)
1443 ..snapshot.anchor_before(edit_end);
1444 edit_start = edit_end;
1445 Some((edit_range, String::new()))
1446 }
1447 Hunk::Keep { len } => {
1448 let edit_end = edit_start + len;
1449 let edit_range = snapshot.anchor_after(edit_start)
1450 ..snapshot.anchor_before(edit_end);
1451 edit_start = edit_end;
1452 this.last_equal_ranges.push(edit_range);
1453 None
1454 }
1455 }),
1456 None,
1457 cx,
1458 );
1459 this.edit_position = snapshot.anchor_after(edit_start);
1460
1461 buffer.end_transaction(cx)
1462 });
1463
1464 if let Some(transaction) = transaction {
1465 if let Some(first_transaction) = this.transaction_id {
1466 // Group all assistant edits into the first transaction.
1467 this.buffer.update(cx, |buffer, cx| {
1468 buffer.merge_transactions(
1469 transaction,
1470 first_transaction,
1471 cx,
1472 )
1473 });
1474 } else {
1475 this.transaction_id = Some(transaction);
1476 this.buffer.update(cx, |buffer, cx| {
1477 buffer.finalize_last_transaction(cx)
1478 });
1479 }
1480 }
1481
1482 this.update_diff(cx);
1483 cx.notify();
1484 })?;
1485 }
1486
1487 diff.await?;
1488
1489 anyhow::Ok(())
1490 };
1491
1492 let result = generate.await;
1493 this.update(&mut cx, |this, cx| {
1494 this.last_equal_ranges.clear();
1495 if let Err(error) = result {
1496 this.status = CodegenStatus::Error(error);
1497 } else {
1498 this.status = CodegenStatus::Done;
1499 }
1500 cx.emit(CodegenEvent::Finished);
1501 cx.notify();
1502 })
1503 .ok();
1504 }
1505 });
1506 cx.notify();
1507 }
1508
1509 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
1510 self.last_equal_ranges.clear();
1511 self.status = CodegenStatus::Done;
1512 self.generation = Task::ready(());
1513 cx.emit(CodegenEvent::Finished);
1514 cx.notify();
1515 }
1516
1517 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
1518 if let Some(transaction_id) = self.transaction_id.take() {
1519 self.buffer
1520 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
1521 }
1522 }
1523
1524 fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
1525 if self.diff.task.is_some() {
1526 self.diff.should_update = true;
1527 } else {
1528 self.diff.should_update = false;
1529
1530 let old_snapshot = self.snapshot.clone();
1531 let old_range = self.range().to_point(&old_snapshot);
1532 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1533 let new_range = self.range().to_point(&new_snapshot);
1534
1535 self.diff.task = Some(cx.spawn(|this, mut cx| async move {
1536 let (deleted_row_ranges, inserted_row_ranges) = cx
1537 .background_executor()
1538 .spawn(async move {
1539 let old_text = old_snapshot
1540 .text_for_range(
1541 Point::new(old_range.start.row, 0)
1542 ..Point::new(
1543 old_range.end.row,
1544 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1545 ),
1546 )
1547 .collect::<String>();
1548 let new_text = new_snapshot
1549 .text_for_range(
1550 Point::new(new_range.start.row, 0)
1551 ..Point::new(
1552 new_range.end.row,
1553 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1554 ),
1555 )
1556 .collect::<String>();
1557
1558 let mut old_row = old_range.start.row;
1559 let mut new_row = new_range.start.row;
1560 let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());
1561
1562 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1563 let mut inserted_row_ranges = Vec::new();
1564 for change in diff.iter_all_changes() {
1565 let line_count = change.value().lines().count() as u32;
1566 match change.tag() {
1567 similar::ChangeTag::Equal => {
1568 old_row += line_count;
1569 new_row += line_count;
1570 }
1571 similar::ChangeTag::Delete => {
1572 let old_end_row = old_row + line_count - 1;
1573 let new_row =
1574 new_snapshot.anchor_before(Point::new(new_row, 0));
1575
1576 if let Some((_, last_deleted_row_range)) =
1577 deleted_row_ranges.last_mut()
1578 {
1579 if *last_deleted_row_range.end() + 1 == old_row {
1580 *last_deleted_row_range =
1581 *last_deleted_row_range.start()..=old_end_row;
1582 } else {
1583 deleted_row_ranges
1584 .push((new_row, old_row..=old_end_row));
1585 }
1586 } else {
1587 deleted_row_ranges.push((new_row, old_row..=old_end_row));
1588 }
1589
1590 old_row += line_count;
1591 }
1592 similar::ChangeTag::Insert => {
1593 let new_end_row = new_row + line_count - 1;
1594 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1595 let end = new_snapshot.anchor_before(Point::new(
1596 new_end_row,
1597 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1598 ));
1599 inserted_row_ranges.push(start..=end);
1600 new_row += line_count;
1601 }
1602 }
1603 }
1604
1605 (deleted_row_ranges, inserted_row_ranges)
1606 })
1607 .await;
1608
1609 this.update(&mut cx, |this, cx| {
1610 this.diff.deleted_row_ranges = deleted_row_ranges;
1611 this.diff.inserted_row_ranges = inserted_row_ranges;
1612 this.diff.task = None;
1613 if this.diff.should_update {
1614 this.update_diff(cx);
1615 }
1616 cx.notify();
1617 })
1618 .ok();
1619 }));
1620 }
1621 }
1622}
1623
/// Cleans up a streamed model response before it is applied to the buffer.
///
/// The following are removed even when they are split across chunks:
/// - a leading markdown code-fence line (three backticks and anything up to
///   the first newline);
/// - when such a fence was present, a trailing newline optionally followed by
///   up to three backticks (and a newline after the third);
/// - a leading `<|S|` / `<|S|>` start marker and a trailing `|E|>` end marker.
///
/// Text that might be the start of one of these markers is held back in an
/// internal buffer until enough input has arrived to decide. Output chunk
/// boundaries therefore need not match input chunk boundaries. Errors from
/// `stream` are passed through unchanged.
///
/// NOTE(review): this is built on `filter_map`, which has no end-of-stream
/// hook, so anything still held back when `stream` ends is never emitted. That
/// is how a trailing closing fence gets dropped. It also drops a final tail
/// that merely ends in `|`, `|E` or `|E|`. A response that begins with `<|` but
/// not `<|S|` is held back indefinitely. Confirm these are intended.
fn strip_invalid_spans_from_codeblock(
    stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
    // True until a leading fence line is removed or processed text contains a
    // newline; a leading code fence is only recognized while this is set.
    let mut first_line = true;
    // Received text that has not been emitted yet.
    let mut buffer = String::new();
    let mut starts_with_markdown_codeblock = false;
    let mut includes_start_or_end_span = false;
    stream.filter_map(move |chunk| {
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(err) => return future::ready(Some(Err(err))),
        };
        buffer.push_str(&chunk);

        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
            // A complete start marker followed by more text: drop the marker.
            includes_start_or_end_span = true;

            buffer = buffer
                .strip_prefix("<|S|>")
                .or_else(|| buffer.strip_prefix("<|S|"))
                .unwrap_or(&buffer)
                .to_string();
        } else if buffer.ends_with("|E|>") {
            includes_start_or_end_span = true;
        } else if buffer.starts_with("<|")
            || buffer.starts_with("<|S")
            || buffer.starts_with("<|S|")
            || buffer.ends_with('|')
            || buffer.ends_with("|E")
            || buffer.ends_with("|E|")
        {
            // Possibly an incomplete start/end marker: wait for more input.
            return future::ready(None);
        }

        if first_line {
            // Could still become an opening fence: wait for more input.
            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                return future::ready(None);
            } else if buffer.starts_with("```") {
                starts_with_markdown_codeblock = true;
                // Drop the whole fence line (including any language tag) once
                // its newline has arrived.
                if let Some(newline_ix) = buffer.find('\n') {
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    return future::ready(None);
                }
            }
        }

        // `text` is `buffer` minus any suffix that may belong to a closing
        // fence or end marker; that suffix stays buffered below.
        let mut text = buffer.to_string();
        if starts_with_markdown_codeblock {
            text = text
                .strip_suffix("\n```\n")
                .or_else(|| text.strip_suffix("\n```"))
                .or_else(|| text.strip_suffix("\n``"))
                .or_else(|| text.strip_suffix("\n`"))
                .or_else(|| text.strip_suffix('\n'))
                .unwrap_or(&text)
                .to_string();
        }

        // NOTE(review): the two `strip_prefix` arms shorten `text` at the front,
        // but `buffer` is split below using only `text.len()`. A matching
        // prefix is therefore still emitted, and the same number of bytes is
        // held back from the end instead. Confirm this is intended.
        if includes_start_or_end_span {
            text = text
                .strip_suffix("|E|>")
                .or_else(|| text.strip_suffix("E|>"))
                .or_else(|| text.strip_prefix("|>"))
                .or_else(|| text.strip_prefix('>'))
                .unwrap_or(&text)
                .to_string();
        };

        if text.contains('\n') {
            first_line = false;
        }

        // Emit the first `text.len()` bytes of `buffer`; keep the rest for
        // the next chunk.
        let remainder = buffer.split_off(text.len());
        let result = if buffer.is_empty() {
            None
        } else {
            Some(Ok(buffer.clone()))
        };

        buffer = remainder;
        future::ready(result)
    })
}
1709
1710fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
1711 ranges.sort_unstable_by(|a, b| {
1712 a.start
1713 .cmp(&b.start, buffer)
1714 .then_with(|| b.end.cmp(&a.end, buffer))
1715 });
1716
1717 let mut ix = 0;
1718 while ix + 1 < ranges.len() {
1719 let b = ranges[ix + 1].clone();
1720 let a = &mut ranges[ix];
1721 if a.end.cmp(&b.start, buffer).is_gt() {
1722 if a.end.cmp(&b.end, buffer).is_lt() {
1723 a.end = b.end;
1724 }
1725 ranges.remove(ix + 1);
1726 } else {
1727 ix += 1;
1728 }
1729 }
1730}
1731
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Buffer contents every autoindent test expects after generation finishes.
    const EXPECTED_MAIN: &str = indoc! {"
        fn main() {
            let mut x = 0;
            while x < 10 {
                x += 1;
            }
        }
    "};

    /// Wraps `text` in a Rust-language buffer inside a singleton multibuffer.
    fn rust_buffer(text: &str, cx: &mut TestAppContext) -> Model<MultiBuffer> {
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        cx.new_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Streams `text` through `provider` in random chunks of 1..=10 bytes.
    /// The executor settles after every chunk; then the completion is finished
    /// and the executor settles once more.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &TestAppContext,
    ) {
        while !text.is_empty() {
            let len = rng.gen_range(1..=cmp::min(text.len(), 10));
            let (chunk, rest) = text.split_at(len);
            provider.send_completion(chunk.into());
            text = rest;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let buffer = rust_buffer(
            indoc! {"
                fn main() {
                    let x = 0;
                    for _ in 0..10 {
                        x += 1;
                    }
                }
            "},
            cx,
        );
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });
        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_completion(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_MAIN
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_buffer(
            indoc! {"
                fn main() {
                    le
                }
            "},
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            buffer.snapshot(cx).anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });
        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_completion(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_MAIN
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_buffer(
            concat!(
                "fn main() {\n",
                "  \n",
                "}\n" //
            ),
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            buffer.snapshot(cx).anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });
        codegen.update(cx, |codegen, cx| {
            codegen.start(LanguageModelRequest::default(), cx)
        });

        stream_completion(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_MAIN
        );
    }

    /// Splits `text` into `size`-character chunks and returns them as a stream.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        let pieces = text
            .chars()
            .collect::<Vec<_>>()
            .chunks(size)
            .map(|chunk| Ok(chunk.iter().collect::<String>()))
            .collect::<Vec<_>>();
        stream::iter(pieces)
    }

    /// Runs `text` through `strip_invalid_spans_from_codeblock` in
    /// `chunk_size`-character chunks and concatenates the output.
    async fn strip(text: &str, chunk_size: usize) -> String {
        strip_invalid_spans_from_codeblock(chunks(text, chunk_size))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];
        for (input, expected) in cases {
            assert_eq!(strip(input, 2).await, expected, "input: {input:?}");
        }
    }

    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_indents_query(
                r#"
                (call_expression) @indent
                (field_expression) @indent
                (_ "(" ")" @end) @indent
                (_ "{" "}" @end) @indent
                "#,
            )
            .unwrap()
    }
}