1use crate::{
2 prompts::generate_content_prompt, AssistantPanel, CompletionProvider, Hunk,
3 LanguageModelRequest, LanguageModelRequestMessage, Role, StreamingDiff,
4};
5use anyhow::{Context as _, Result};
6use client::telemetry::Telemetry;
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
18use gpui::{
19 point, AppContext, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global,
20 HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View,
21 ViewContext, WeakView, WhiteSpace, WindowContext,
22};
23use language::{Buffer, Point, Selection, TransactionId};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use rope::Rope;
27use settings::Settings;
28use similar::TextDiff;
29use std::{
30 cmp, mem,
31 ops::{Range, RangeInclusive},
32 pin::Pin,
33 sync::Arc,
34 task::{self, Poll},
35 time::Instant,
36};
37use theme::ThemeSettings;
38use ui::{prelude::*, Tooltip};
39use util::RangeExt;
40use workspace::{notifications::NotificationId, Toast, Workspace};
41
42pub fn init(telemetry: Arc<Telemetry>, cx: &mut AppContext) {
43 cx.set_global(InlineAssistant::new(telemetry));
44}
45
/// Maximum number of distinct prompts kept in `InlineAssistant::prompt_history`.
/// Once a push exceeds this, the oldest entry is dropped (see `start_assist`).
const PROMPT_HISTORY_MAX_LEN: usize = 20;
47
/// Global state for every inline assist, across all editors.
///
/// Tracks the live assists by id, per-editor bookkeeping (scroll lock, highlight
/// refresh channel, subscriptions), the groups assists were created in, and a shared
/// history of submitted prompts.
pub struct InlineAssistant {
    /// Id handed out to the next assist created by `assist`.
    next_assist_id: InlineAssistId,
    /// Id handed out to the next group created by `assist`.
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet (dismissed assists stay here,
    /// just without decorations).
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor state, keyed by a weak handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists created together by a single `assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry sink cloned into each `Codegen`; `new` always sets it to `Some`.
    telemetry: Option<Arc<Telemetry>>,
}
57
// Lets `init` store the assistant as a GPUI global, reachable later through
// `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
59
60impl InlineAssistant {
61 pub fn new(telemetry: Arc<Telemetry>) -> Self {
62 Self {
63 next_assist_id: InlineAssistId::default(),
64 next_assist_group_id: InlineAssistGroupId::default(),
65 assists: HashMap::default(),
66 assists_by_editor: HashMap::default(),
67 assist_groups: HashMap::default(),
68 prompt_history: VecDeque::default(),
69 telemetry: Some(telemetry),
70 }
71 }
72
    /// Opens inline assists in `editor` for its current selections.
    ///
    /// Non-empty selections are widened to whole lines; a selection ending at column 0
    /// does not include that last line. Overlapping selections are merged. Each
    /// resulting range is split at excerpt boundaries, and every piece gets:
    /// - its own `Codegen`,
    /// - its own `PromptEditor`, rendered in a block above the range,
    /// - a one-line bordered block below the range.
    ///
    /// All prompt editors created by one call share a single prompt buffer, and the
    /// assists belong to one new, linked `InlineAssistGroup`.
    ///
    /// Focus goes to the first assist whose range edge matches the newest selection:
    /// its start when that selection is reversed, its end otherwise.
    ///
    /// `workspace` is passed to the prompt editors and assists. `include_context` is
    /// stored on each assist; `start_assist` reads it to decide whether to include the
    /// assistant panel's active context.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        include_context: bool,
        cx: &mut WindowContext,
    ) {
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // Normalize the selections (line-expand, then merge overlaps) and remember the
        // newest one (highest id) so it can be focused afterwards.
        let mut selections = Vec::<Selection<Point>>::new();
        let mut newest_selection = None;
        for mut selection in editor.read(cx).selections.all::<Point>(cx) {
            if selection.end > selection.start {
                selection.start.column = 0;
                // If the selection ends at the start of the line, we don't want to include it.
                // (`end > start` with `end.column == 0` implies `end.row > start.row`,
                // so this decrement cannot underflow.)
                if selection.end.column == 0 {
                    selection.end.row -= 1;
                }
                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
            }

            // NOTE(review): a selection merged into its predecessor here is never compared
            // for "newest" below. If the newest selection gets merged, the focus check
            // further down may match no range — confirm this is acceptable.
            if let Some(prev_selection) = selections.last_mut() {
                if selection.start <= prev_selection.end {
                    prev_selection.end = selection.end;
                    continue;
                }
            }

            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
            if selection.id > latest_selection.id {
                *latest_selection = selection.clone();
            }
            selections.push(selection);
        }
        // The first loop iteration always sets this. This assumes the editor reports
        // at least one selection — TODO confirm.
        let newest_selection = newest_selection.unwrap();

        // Split the selections at excerpt boundaries. Each piece becomes one codegen
        // range, anchored inside its excerpt's buffer.
        let mut codegen_ranges = Vec::new();
        for (excerpt_id, buffer, buffer_range) in
            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
            }))
        {
            let start = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_before(buffer_range.start),
            };
            let end = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_after(buffer_range.end),
            };
            codegen_ranges.push(start..end);
        }

        let assist_group_id = self.next_assist_group_id.post_inc();
        // One prompt buffer, shared by every prompt editor created in this call.
        let prompt_buffer = cx.new_model(|cx| Buffer::local("", cx));
        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));

        let mut assists = Vec::new();
        let mut assist_blocks = Vec::new();
        let mut assist_to_focus = None;
        for range in codegen_ranges {
            let assist_id = self.next_assist_id.post_inc();
            let codegen = cx.new_model(|cx| {
                Codegen::new(
                    editor.read(cx).buffer().clone(),
                    range.clone(),
                    self.telemetry.clone(),
                    cx,
                )
            });

            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
            let prompt_editor = cx.new_view(|cx| {
                PromptEditor::new(
                    assist_id,
                    gutter_dimensions.clone(),
                    self.prompt_history.clone(),
                    prompt_buffer.clone(),
                    codegen.clone(),
                    workspace.clone(),
                    cx,
                )
            });

            // Pick the first range whose edge matches the newest selection.
            if assist_to_focus.is_none() {
                let focus_assist = if newest_selection.reversed {
                    range.start.to_point(&snapshot) == newest_selection.start
                } else {
                    range.end.to_point(&snapshot) == newest_selection.end
                };
                if focus_assist {
                    assist_to_focus = Some(assist_id);
                }
            }

            // Each assist gets exactly two blocks, pushed in this order: the prompt
            // editor above the range, then a one-line border below it. The
            // `chunks_exact(2)` pairing further down depends on this order.
            assist_blocks.push(BlockProperties {
                style: BlockStyle::Sticky,
                position: range.start,
                height: prompt_editor.read(cx).height_in_lines,
                render: build_assist_editor_renderer(&prompt_editor),
                disposition: BlockDisposition::Above,
            });
            assist_blocks.push(BlockProperties {
                style: BlockStyle::Sticky,
                position: range.end,
                height: 1,
                render: Box::new(|cx| {
                    v_flex()
                        .h_full()
                        .w_full()
                        .border_t_1()
                        .border_color(cx.theme().status().info_border)
                        .into_any_element()
                }),
                disposition: BlockDisposition::Below,
            });
            assists.push((assist_id, prompt_editor));
        }

        let assist_block_ids = editor.update(cx, |editor, cx| {
            editor.insert_blocks(assist_blocks, None, cx)
        });

        let editor_assists = self
            .assists_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
        let mut assist_group = InlineAssistGroup::new();
        // `block_ids[0]` is the prompt block and `block_ids[1]` the end block, matching
        // the push order above.
        for ((assist_id, prompt_editor), block_ids) in
            assists.into_iter().zip(assist_block_ids.chunks_exact(2))
        {
            self.assists.insert(
                assist_id,
                InlineAssist::new(
                    assist_id,
                    assist_group_id,
                    include_context,
                    editor,
                    &prompt_editor,
                    block_ids[0],
                    block_ids[1],
                    prompt_editor.read(cx).codegen.clone(),
                    workspace.clone(),
                    cx,
                ),
            );
            assist_group.assist_ids.push(assist_id);
            editor_assists.assist_ids.push(assist_id);
        }
        self.assist_groups.insert(assist_group_id, assist_group);

        if let Some(assist_id) = assist_to_focus {
            self.focus_assist(assist_id, cx);
        }
    }
230
    /// Reacts to an assist's prompt editor gaining focus.
    ///
    /// - Makes the assist its group's active assist.
    /// - If the group is still linked, tells every prompt editor in the group to keep
    ///   showing its cursor while unfocused.
    /// - If the prompt block's row is currently visible, records a scroll lock with the
    ///   prompt's distance from the viewport top (`handle_editor_change` uses it to
    ///   restore that distance). Otherwise clears the editor's scroll lock.
    ///
    /// Returns early if the assist was already dismissed (no decorations). Panics if
    /// `assist_id` is unknown.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        // Disjoint fields of `self`: these mutable borrows coexist with `assist`.
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            // The editor is held weakly; if it has been dropped, there is nothing to do.
            .ok();
    }
271
272 fn handle_prompt_editor_focus_out(
273 &mut self,
274 assist_id: InlineAssistId,
275 cx: &mut WindowContext,
276 ) {
277 let assist = &self.assists[&assist_id];
278 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
279 if assist_group.active_assist_id == Some(assist_id) {
280 assist_group.active_assist_id = None;
281 if assist_group.linked {
282 for assist_id in &assist_group.assist_ids {
283 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
284 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
285 prompt_editor.set_show_cursor_when_unfocused(false, cx)
286 });
287 }
288 }
289 }
290 }
291 }
292
293 fn handle_prompt_editor_event(
294 &mut self,
295 prompt_editor: View<PromptEditor>,
296 event: &PromptEditorEvent,
297 cx: &mut WindowContext,
298 ) {
299 let assist_id = prompt_editor.read(cx).id;
300 match event {
301 PromptEditorEvent::StartRequested => {
302 self.start_assist(assist_id, cx);
303 }
304 PromptEditorEvent::StopRequested => {
305 self.stop_assist(assist_id, cx);
306 }
307 PromptEditorEvent::ConfirmRequested => {
308 self.finish_assist(assist_id, false, cx);
309 }
310 PromptEditorEvent::CancelRequested => {
311 self.finish_assist(assist_id, true, cx);
312 }
313 PromptEditorEvent::DismissRequested => {
314 self.dismiss_assist(assist_id, cx);
315 }
316 PromptEditorEvent::Resized { height_in_lines } => {
317 self.resize_assist(assist_id, *height_in_lines, cx);
318 }
319 }
320 }
321
322 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
323 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
324 return;
325 };
326
327 let editor = editor.read(cx);
328 if editor.selections.count() == 1 {
329 let selection = editor.selections.newest::<usize>(cx);
330 let buffer = editor.buffer().read(cx).snapshot(cx);
331 for assist_id in &editor_assists.assist_ids {
332 let assist = &self.assists[assist_id];
333 let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
334 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
335 {
336 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
337 self.dismiss_assist(*assist_id, cx);
338 } else {
339 self.finish_assist(*assist_id, false, cx);
340 }
341
342 return;
343 }
344 }
345 }
346
347 cx.propagate();
348 }
349
350 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
351 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
352 return;
353 };
354
355 let editor = editor.read(cx);
356 if editor.selections.count() == 1 {
357 let selection = editor.selections.newest::<usize>(cx);
358 let buffer = editor.buffer().read(cx).snapshot(cx);
359 for assist_id in &editor_assists.assist_ids {
360 let assist = &self.assists[assist_id];
361 let assist_range = assist.codegen.read(cx).range.to_offset(&buffer);
362 if assist.decorations.is_some()
363 && assist_range.contains(&selection.start)
364 && assist_range.contains(&selection.end)
365 {
366 self.focus_assist(*assist_id, cx);
367 return;
368 }
369 }
370 }
371
372 cx.propagate();
373 }
374
375 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
376 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
377 return;
378 };
379 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
380 return;
381 };
382 let assist = &self.assists[&scroll_lock.assist_id];
383 let Some(decorations) = assist.decorations.as_ref() else {
384 return;
385 };
386
387 editor.update(cx, |editor, cx| {
388 let scroll_position = editor.scroll_position(cx);
389 let target_scroll_top = editor
390 .row_for_block(decorations.prompt_block_id, cx)
391 .unwrap()
392 .0 as f32
393 - scroll_lock.distance_from_top;
394 if target_scroll_top != scroll_position.y {
395 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
396 }
397 });
398 }
399
    /// Reacts to events from an editor that has inline assists.
    ///
    /// - `Saved`: finishes every assist whose codegen is `Done`.
    /// - `Edited`: finishes every `Done` or `Error` assist whose range overlaps a range
    ///   edited by that transaction.
    /// - `ScrollPositionChanged`: drops the scroll lock once the locked prompt is no
    ///   longer the recorded distance from the viewport top.
    /// - `SelectionsChanged`: drops the scroll lock, unless one of this editor's prompt
    ///   editors is focused.
    fn handle_editor_event(
        &mut self,
        editor: View<Editor>,
        event: &EditorEvent,
        cx: &mut WindowContext,
    ) {
        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
            return;
        };

        match event {
            EditorEvent::Saved => {
                // Iterate over a copy: `finish_assist` mutates the id list.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let CodegenStatus::Done = &assist.codegen.read(cx).status {
                        self.finish_assist(assist_id, false, cx)
                    }
                }
            }
            EditorEvent::Edited { transaction_id } => {
                let buffer = editor.read(cx).buffer().read(cx);
                let edited_ranges =
                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
                let snapshot = buffer.snapshot(cx);

                // Iterate over a copy: `finish_assist` mutates the id list.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if matches!(
                        assist.codegen.read(cx).status,
                        CodegenStatus::Error(_) | CodegenStatus::Done
                    ) {
                        let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot);
                        if edited_ranges
                            .iter()
                            .any(|range| range.overlaps(&assist_range))
                        {
                            self.finish_assist(assist_id, false, cx);
                        }
                    }
                }
            }
            EditorEvent::ScrollPositionChanged { .. } => {
                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                    let assist = &self.assists[&scroll_lock.assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        let distance_from_top = editor.update(cx, |editor, cx| {
                            let scroll_top = editor.scroll_position(cx).y;
                            let prompt_row = editor
                                .row_for_block(decorations.prompt_block_id, cx)
                                .unwrap()
                                .0 as f32;
                            prompt_row - scroll_top
                        });

                        // Any drift means the user scrolled away, so release the lock.
                        if distance_from_top != scroll_lock.distance_from_top {
                            editor_assists.scroll_lock = None;
                        }
                    }
                }
            }
            EditorEvent::SelectionsChanged { .. } => {
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                            return;
                        }
                    }
                }

                editor_assists.scroll_lock = None;
            }
            _ => {}
        }
    }
475
    /// Finishes an assist and removes all of its bookkeeping.
    ///
    /// If the assist's group is still linked, the group is unlinked and every member
    /// is finished instead (recursively, with the same `undo`).
    ///
    /// Otherwise the assist:
    /// - is dismissed (decorations removed),
    /// - is removed from `assists`, its group and its editor's id list, with empty
    ///   groups and editor entries dropped,
    /// - has its editor's highlights refreshed,
    /// - has its generated edits undone when `undo` is true.
    fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    // Removing the entry drops its highlight task, so clear the
                    // highlights synchronously here.
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            }
        }
    }
517
    /// Removes an assist's decorations (prompt, end and deleted-line blocks) from its
    /// editor while keeping the assist itself.
    ///
    /// If the prompt editor held focus, focus moves to the next assist in the group,
    /// or back to the editor. Clears the editor's scroll lock if it pointed at this
    /// assist, and schedules a highlight refresh.
    ///
    /// Returns `true` if decorations were removed. Returns `false` if the assist is
    /// unknown, its editor was dropped, or it was already dismissed.
    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return false;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return false;
        };
        let Some(decorations) = assist.decorations.take() else {
            return false;
        };

        editor.update(cx, |editor, cx| {
            let mut to_remove = decorations.removed_line_block_ids;
            to_remove.insert(decorations.prompt_block_id);
            to_remove.insert(decorations.end_block_id);
            editor.remove_blocks(to_remove, None, cx);
        });

        if decorations
            .prompt_editor
            .focus_handle(cx)
            .contains_focused(cx)
        {
            self.focus_next_assist(assist_id, cx);
        }

        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
            if editor_assists
                .scroll_lock
                .as_ref()
                .map_or(false, |lock| lock.assist_id == assist_id)
            {
                editor_assists.scroll_lock = None;
            }
            editor_assists.highlight_updates.send(()).ok();
        }

        true
    }
557
    /// Moves focus to the next member of the assist's group that still has
    /// decorations, wrapping around from the end to the start of the group. If there
    /// is none, the assist's editor is focused. Does nothing for an unknown id.
    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let Some(assist) = self.assists.get(&assist_id) else {
            return;
        };

        let assist_group = &self.assist_groups[&assist.group_id];
        let assist_ix = assist_group
            .assist_ids
            .iter()
            .position(|id| *id == assist_id)
            .unwrap();
        // Members after this one, then the ones before it (this one is excluded).
        let assist_ids = assist_group
            .assist_ids
            .iter()
            .skip(assist_ix + 1)
            .chain(assist_group.assist_ids.iter().take(assist_ix));

        for assist_id in assist_ids {
            let assist = &self.assists[assist_id];
            if assist.decorations.is_some() {
                self.focus_assist(*assist_id, cx);
                return;
            }
        }

        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
    }
585
    /// Focuses an assist and scrolls it into view.
    ///
    /// - If decorations exist, focuses the prompt editor and selects all of its text.
    /// - Places the editor's cursor at the start of the assist's range.
    /// - Scrolls vertically so the span from the prompt block to the end block (or
    ///   just the range's first row when dismissed), padded by the vertical scroll
    ///   margin, is visible. The horizontal scroll is reset to 0.
    ///
    /// Panics if `assist_id` is unknown; does nothing if the editor was dropped.
    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(editor) = assist.editor.upgrade() else {
            return;
        };

        if let Some(decorations) = assist.decorations.as_ref() {
            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                prompt_editor.editor.update(cx, |editor, cx| {
                    editor.focus(cx);
                    editor.select_all(&SelectAll, cx);
                })
            });
        }

        let position = assist.codegen.read(cx).range.start;
        editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([position..position])
            });

            let mut scroll_target_top;
            let mut scroll_target_bottom;
            if let Some(decorations) = assist.decorations.as_ref() {
                scroll_target_top = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;
                scroll_target_bottom = editor
                    .row_for_block(decorations.end_block_id, cx)
                    .unwrap()
                    .0 as f32;
            } else {
                let snapshot = editor.snapshot(cx);
                let codegen = assist.codegen.read(cx);
                let start_row = codegen
                    .range
                    .start
                    .to_display_point(&snapshot.display_snapshot)
                    .row();
                scroll_target_top = start_row.0 as f32;
                scroll_target_bottom = scroll_target_top + 1.;
            }
            scroll_target_top -= editor.vertical_scroll_margin() as f32;
            scroll_target_bottom += editor.vertical_scroll_margin() as f32;

            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + height_in_lines;

            // Target starts above the viewport: align its top. Target ends below the
            // viewport: align its bottom if it fits, otherwise align its top.
            if scroll_target_top < scroll_top {
                editor.set_scroll_position(point(0., scroll_target_top), cx);
            } else if scroll_target_bottom > scroll_bottom {
                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
                    editor
                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
                } else {
                    editor.set_scroll_position(point(0., scroll_target_top), cx);
                }
            }
        });
    }
648
649 fn resize_assist(
650 &mut self,
651 assist_id: InlineAssistId,
652 height_in_lines: u8,
653 cx: &mut WindowContext,
654 ) {
655 if let Some(assist) = self.assists.get_mut(&assist_id) {
656 if let Some(editor) = assist.editor.upgrade() {
657 if let Some(decorations) = assist.decorations.as_ref() {
658 let mut new_blocks = HashMap::default();
659 new_blocks.insert(
660 decorations.prompt_block_id,
661 (
662 Some(height_in_lines),
663 build_assist_editor_renderer(&decorations.prompt_editor),
664 ),
665 );
666 editor.update(cx, |editor, cx| {
667 editor
668 .display_map
669 .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
670 });
671 }
672 }
673 }
674 }
675
676 fn unlink_assist_group(
677 &mut self,
678 assist_group_id: InlineAssistGroupId,
679 cx: &mut WindowContext,
680 ) -> Vec<InlineAssistId> {
681 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
682 assist_group.linked = false;
683 for assist_id in &assist_group.assist_ids {
684 let assist = self.assists.get_mut(assist_id).unwrap();
685 if let Some(editor_decorations) = assist.decorations.as_ref() {
686 editor_decorations
687 .prompt_editor
688 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
689 }
690 }
691 assist_group.assist_ids.clone()
692 }
693
    /// Starts (or restarts) code generation for an assist.
    ///
    /// If the assist's group is still linked, the group is unlinked and every member is
    /// started instead. Otherwise:
    /// - any previously generated edits are undone;
    /// - the prompt text is read from the prompt editor (returning early if the assist
    ///   was dismissed);
    /// - the prompt is moved to the back of `prompt_history`, which is trimmed to
    ///   `PROMPT_HISTORY_MAX_LEN`;
    /// - a completion request is built from the optional assistant-panel context plus
    ///   a generated content prompt, and passed to `Codegen::start`.
    ///
    /// If the assist's range does not map into a single buffer, the assist is finished
    /// instead. Does nothing for an unknown id or a dropped editor.
    fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
            assist
        } else {
            return;
        };

        let assist_group_id = assist.group_id;
        if self.assist_groups[&assist_group_id].linked {
            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                self.start_assist(assist_id, cx);
            }
            return;
        }

        // Restarting: revert whatever a previous run generated.
        assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));

        let Some(user_prompt) = assist
            .decorations
            .as_ref()
            .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx))
        else {
            return;
        };

        // Optional conversation context from the workspace's assistant panel.
        let context = if assist.include_context {
            assist.workspace.as_ref().and_then(|workspace| {
                let workspace = workspace.upgrade()?.read(cx);
                let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
                assistant_panel.read(cx).active_context(cx)
            })
        } else {
            None
        };

        let editor = if let Some(editor) = assist.editor.upgrade() {
            editor
        } else {
            return;
        };

        // The worktree root names, joined with "/", are passed to the prompt generator.
        let project_name = assist.workspace.as_ref().and_then(|workspace| {
            let workspace = workspace.upgrade()?;
            Some(
                workspace
                    .read(cx)
                    .project()
                    .read(cx)
                    .worktree_root_names(cx)
                    .collect::<Vec<&str>>()
                    .join("/"),
            )
        });

        // Deduplicate, then append as the most recent entry.
        self.prompt_history.retain(|prompt| *prompt != user_prompt);
        self.prompt_history.push_back(user_prompt.clone());
        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
            self.prompt_history.pop_front();
        }

        // The prompt generator works on a single buffer, so the assist range must start
        // and end in the same buffer. Otherwise the assist is finished.
        let codegen = assist.codegen.clone();
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        let range = codegen.read(cx).range.clone();
        let start = snapshot.point_to_buffer_offset(range.start);
        let end = snapshot.point_to_buffer_offset(range.end);
        let (buffer, range) = if let Some((start, end)) = start.zip(end) {
            let (start_buffer, start_buffer_offset) = start;
            let (end_buffer, end_buffer_offset) = end;
            if start_buffer.remote_id() == end_buffer.remote_id() {
                (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
            } else {
                self.finish_assist(assist_id, false, cx);
                return;
            }
        } else {
            self.finish_assist(assist_id, false, cx);
            return;
        };

        let language = buffer.language_at(range.start);
        let language_name = if let Some(language) = language.as_ref() {
            if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
                None
            } else {
                Some(language.name())
            }
        } else {
            None
        };

        // Higher Temperature increases the randomness of model outputs.
        // If Markdown or No Language is Known, increase the randomness for more creative output
        // If Code, decrease temperature to get more deterministic outputs
        let temperature = if let Some(language) = language_name.clone() {
            if language.as_ref() == "Markdown" {
                1.0
            } else {
                0.5
            }
        } else {
            1.0
        };

        let prompt = cx.background_executor().spawn(async move {
            let language_name = language_name.as_deref();
            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
        });

        // Context messages (if any) come first; the generated prompt is appended last.
        let mut messages = Vec::new();
        if let Some(context) = context {
            let request = context.read(cx).to_completion_request(cx);
            messages = request.messages;
        }
        let model = CompletionProvider::global(cx).model();

        cx.spawn(|mut cx| async move {
            let prompt = prompt.await?;

            messages.push(LanguageModelRequestMessage {
                role: Role::User,
                content: prompt,
            });

            let request = LanguageModelRequest {
                model,
                messages,
                stop: vec!["|END|>".to_string()],
                temperature,
            };

            codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);
    }
829
830 fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
831 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
832 assist
833 } else {
834 return;
835 };
836
837 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
838 }
839
840 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
841 let mut gutter_pending_ranges = Vec::new();
842 let mut gutter_transformed_ranges = Vec::new();
843 let mut foreground_ranges = Vec::new();
844 let mut inserted_row_ranges = Vec::new();
845 let empty_assist_ids = Vec::new();
846 let assist_ids = self
847 .assists_by_editor
848 .get(&editor.downgrade())
849 .map_or(&empty_assist_ids, |editor_assists| {
850 &editor_assists.assist_ids
851 });
852
853 for assist_id in assist_ids {
854 if let Some(assist) = self.assists.get(assist_id) {
855 let codegen = assist.codegen.read(cx);
856 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
857
858 if codegen.edit_position != codegen.range.end {
859 gutter_pending_ranges.push(codegen.edit_position..codegen.range.end);
860 }
861
862 if codegen.range.start != codegen.edit_position {
863 gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position);
864 }
865
866 if assist.decorations.is_some() {
867 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
868 }
869 }
870 }
871
872 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
873 merge_ranges(&mut foreground_ranges, &snapshot);
874 merge_ranges(&mut gutter_pending_ranges, &snapshot);
875 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
876 editor.update(cx, |editor, cx| {
877 enum GutterPendingRange {}
878 if gutter_pending_ranges.is_empty() {
879 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
880 } else {
881 editor.highlight_gutter::<GutterPendingRange>(
882 &gutter_pending_ranges,
883 |cx| cx.theme().status().info_background,
884 cx,
885 )
886 }
887
888 enum GutterTransformedRange {}
889 if gutter_transformed_ranges.is_empty() {
890 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
891 } else {
892 editor.highlight_gutter::<GutterTransformedRange>(
893 &gutter_transformed_ranges,
894 |cx| cx.theme().status().info,
895 cx,
896 )
897 }
898
899 if foreground_ranges.is_empty() {
900 editor.clear_highlights::<InlineAssist>(cx);
901 } else {
902 editor.highlight_text::<InlineAssist>(
903 foreground_ranges,
904 HighlightStyle {
905 fade_out: Some(0.6),
906 ..Default::default()
907 },
908 cx,
909 );
910 }
911
912 editor.clear_row_highlights::<InlineAssist>();
913 for row_range in inserted_row_ranges {
914 editor.highlight_rows::<InlineAssist>(
915 row_range,
916 Some(cx.theme().status().info_background),
917 false,
918 cx,
919 );
920 }
921 });
922 }
923
924 fn update_editor_blocks(
925 &mut self,
926 editor: &View<Editor>,
927 assist_id: InlineAssistId,
928 cx: &mut WindowContext,
929 ) {
930 let Some(assist) = self.assists.get_mut(&assist_id) else {
931 return;
932 };
933 let Some(decorations) = assist.decorations.as_mut() else {
934 return;
935 };
936
937 let codegen = assist.codegen.read(cx);
938 let old_snapshot = codegen.snapshot.clone();
939 let old_buffer = codegen.old_buffer.clone();
940 let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
941
942 editor.update(cx, |editor, cx| {
943 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
944 editor.remove_blocks(old_blocks, None, cx);
945
946 let mut new_blocks = Vec::new();
947 for (new_row, old_row_range) in deleted_row_ranges {
948 let (_, buffer_start) = old_snapshot
949 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
950 .unwrap();
951 let (_, buffer_end) = old_snapshot
952 .point_to_buffer_offset(Point::new(
953 *old_row_range.end(),
954 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
955 ))
956 .unwrap();
957
958 let deleted_lines_editor = cx.new_view(|cx| {
959 let multi_buffer = cx.new_model(|_| {
960 MultiBuffer::without_headers(0, language::Capability::ReadOnly)
961 });
962 multi_buffer.update(cx, |multi_buffer, cx| {
963 multi_buffer.push_excerpts(
964 old_buffer.clone(),
965 Some(ExcerptRange {
966 context: buffer_start..buffer_end,
967 primary: None,
968 }),
969 cx,
970 );
971 });
972
973 enum DeletedLines {}
974 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
975 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
976 editor.set_show_wrap_guides(false, cx);
977 editor.set_show_gutter(false, cx);
978 editor.scroll_manager.set_forbid_vertical_scroll(true);
979 editor.set_read_only(true);
980 editor.highlight_rows::<DeletedLines>(
981 Anchor::min()..=Anchor::max(),
982 Some(cx.theme().status().deleted_background),
983 false,
984 cx,
985 );
986 editor
987 });
988
989 let height = deleted_lines_editor
990 .update(cx, |editor, cx| editor.max_point(cx).row().0 as u8 + 1);
991 new_blocks.push(BlockProperties {
992 position: new_row,
993 height,
994 style: BlockStyle::Flex,
995 render: Box::new(move |cx| {
996 div()
997 .bg(cx.theme().status().deleted_background)
998 .size_full()
999 .pl(cx.gutter_dimensions.full_width())
1000 .child(deleted_lines_editor.clone())
1001 .into_any_element()
1002 }),
1003 disposition: BlockDisposition::Above,
1004 });
1005 }
1006
1007 decorations.removed_line_block_ids = editor
1008 .insert_blocks(new_blocks, None, cx)
1009 .into_iter()
1010 .collect();
1011 })
1012 }
1013}
1014
/// Per-editor bookkeeping kept by `InlineAssistant` for editors that contain assists.
struct EditorInlineAssists {
    /// Ids of this editor's assists that have not been finished.
    assist_ids: Vec<InlineAssistId>,
    /// When set, `handle_editor_change` keeps the locked assist's prompt block at a
    /// fixed distance from the viewport top.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here wakes `_update_highlights`, which then calls
    /// `InlineAssistant::update_editor_highlights` for this editor.
    highlight_updates: async_watch::Sender<()>,
    /// Task that listens on `highlight_updates`; stored so its lifetime is tied to
    /// this struct.
    _update_highlights: Task<Result<()>>,
    /// Editor observer, event subscription and `Newline`/`Cancel` action registrations.
    _subscriptions: Vec<gpui::Subscription>,
}

/// Pins an assist's prompt block at a fixed offset from the editor's scroll top.
struct InlineAssistScrollLock {
    /// Assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Prompt-block row minus scroll top (in rows), recorded when the lock was taken.
    distance_from_top: f32,
}
1027
1028impl EditorInlineAssists {
1029 #[allow(clippy::too_many_arguments)]
1030 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1031 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1032 Self {
1033 assist_ids: Vec::new(),
1034 scroll_lock: None,
1035 highlight_updates: highlight_updates_tx,
1036 _update_highlights: cx.spawn(|mut cx| {
1037 let editor = editor.downgrade();
1038 async move {
1039 while let Ok(()) = highlight_updates_rx.changed().await {
1040 let editor = editor.upgrade().context("editor was dropped")?;
1041 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1042 assistant.update_editor_highlights(&editor, cx);
1043 })?;
1044 }
1045 Ok(())
1046 }
1047 }),
1048 _subscriptions: vec![
1049 cx.observe(editor, move |editor, cx| {
1050 InlineAssistant::update_global(cx, |this, cx| {
1051 this.handle_editor_change(editor, cx)
1052 })
1053 }),
1054 cx.subscribe(editor, move |editor, event, cx| {
1055 InlineAssistant::update_global(cx, |this, cx| {
1056 this.handle_editor_event(editor, event, cx)
1057 })
1058 }),
1059 editor.update(cx, |editor, cx| {
1060 let editor_handle = cx.view().downgrade();
1061 editor.register_action(
1062 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1063 InlineAssistant::update_global(cx, |this, cx| {
1064 if let Some(editor) = editor_handle.upgrade() {
1065 this.handle_editor_newline(editor, cx)
1066 }
1067 })
1068 },
1069 )
1070 }),
1071 editor.update(cx, |editor, cx| {
1072 let editor_handle = cx.view().downgrade();
1073 editor.register_action(
1074 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1075 InlineAssistant::update_global(cx, |this, cx| {
1076 if let Some(editor) = editor_handle.upgrade() {
1077 this.handle_editor_cancel(editor, cx)
1078 }
1079 })
1080 },
1081 )
1082 }),
1083 ],
1084 }
1085 }
1086}
1087
/// A set of inline assists managed together (NOTE(review): presumably the assists
/// created by a single invocation — confirm with the code that builds groups).
struct InlineAssistGroup {
    /// Ids of the assists in this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group's prompt editors are still linked. Linked prompt editors
    /// show the same text: typing in one appears in all of them.
    linked: bool,
    /// The assist currently considered active within the group, if any.
    active_assist_id: Option<InlineAssistId>,
}
1093
1094impl InlineAssistGroup {
1095 fn new() -> Self {
1096 Self {
1097 assist_ids: Vec::new(),
1098 linked: true,
1099 active_assist_id: None,
1100 }
1101 }
1102}
1103
1104fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1105 let editor = editor.clone();
1106 Box::new(move |cx: &mut BlockContext| {
1107 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1108 editor.clone().into_any_element()
1109 })
1110}
1111
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: hands out the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        mem::replace(self, successor)
    }
}
1122
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Post-increment: hands out the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let successor = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, successor)
    }
}
1133
/// Requests emitted by a `PromptEditor` for its owner to act upon.
enum PromptEditorEvent {
    /// Start (or restart) code generation for the current prompt.
    StartRequested,
    /// Interrupt an in-progress generation, keeping the changes made so far.
    StopRequested,
    /// Accept the generated result.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted by `confirm` while a generation is still pending.
    DismissRequested,
    /// The prompt editor's height changed; `height_in_lines` is the new height.
    Resized { height_in_lines: u8 },
}
1142
/// The prompt input for one inline assist, rendered inside an editor block (see
/// `build_assist_editor_renderer`), plus the buttons that drive its `Codegen`.
struct PromptEditor {
    /// Id of the inline assist this prompt belongs to.
    id: InlineAssistId,
    /// Height last reported via `PromptEditorEvent::Resized`; `count_lines` keeps it
    /// between 2 and `MAX_LINES`.
    height_in_lines: u8,
    /// Text editor holding the prompt; `unlink` replaces it with a standalone editor.
    editor: View<Editor>,
    /// Set when the prompt is edited and cleared when codegen reaches `Done` or
    /// `Error`; chooses between the restart and confirm buttons.
    edited_since_done: bool,
    /// Host editor gutter size. The block renderer writes it; `render` reads it.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts, navigable with `MoveUp` / `MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index of the history entry being shown, or `None` when editing a fresh prompt.
    prompt_history_ix: Option<usize>,
    /// Text typed before browsing history. It is restored when moving down past the
    /// last entry.
    pending_prompt: String,
    /// Generator whose status drives the buttons and read-only state.
    codegen: Model<Codegen>,
    /// Workspace used by the context button to reach the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    /// Keeps `handle_codegen_changed` observing `codegen`.
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`. `subscribe_to_editor` rebuilds them whenever the
    /// editor is replaced.
    editor_subscriptions: Vec<Subscription>,
}
1157
// Lets owners subscribe to the `PromptEditorEvent`s this view emits.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1159
// Renders the prompt row. It has a gutter-aligned column with the context button
// and an optional error icon, the prompt text editor, and trailing buttons that
// depend on the current `CodegenStatus`.
impl Render for PromptEditor {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        // Copied from the host editor by the block renderer, so the leading column
        // can line up with the host editor's gutter.
        let gutter_dimensions = *self.gutter_dimensions.lock();

        let buttons = match &self.codegen.read(cx).status {
            // Nothing generated yet: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::Sparkle)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Generation in flight: cancel the assist, or stop streaming and keep
            // what has been generated so far.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .size(ButtonSize::None)
                        .icon_size(IconSize::XSmall)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished or failed: cancel, and either restart (the prompt was edited
            // since the last run) or confirm the result.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .size(ButtonSize::None)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .icon_size(IconSize::XSmall)
                            .size(ButtonSize::None)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .size(ButtonSize::None)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        v_flex().h_full().w_full().justify_end().child(
            h_flex()
                .bg(cx.theme().colors().editor_background)
                .border_y_1()
                .border_color(cx.theme().status().info_border)
                .py_1p5()
                .w_full()
                // Route keyboard actions within this row to the prompt editor's handlers.
                .on_action(cx.listener(Self::confirm))
                .on_action(cx.listener(Self::cancel))
                .on_action(cx.listener(Self::move_up))
                .on_action(cx.listener(Self::move_down))
                .child(
                    // Leading column sized to the host editor's gutter.
                    h_flex()
                        .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                        // .pr(gutter_dimensions.fold_area_width())
                        .justify_center()
                        .gap_2()
                        // Context button, only when a workspace is available: clicking
                        // focuses the assistant panel, and the tooltip shows the token
                        // count of the panel's active context when there is one.
                        .children(self.workspace.clone().map(|workspace| {
                            IconButton::new("context", IconName::Context)
                                .size(ButtonSize::None)
                                .icon_size(IconSize::XSmall)
                                .icon_color(Color::Muted)
                                .on_click({
                                    let workspace = workspace.clone();
                                    cx.listener(move |_, _, cx| {
                                        workspace
                                            .update(cx, |workspace, cx| {
                                                workspace.focus_panel::<AssistantPanel>(cx);
                                            })
                                            .ok();
                                    })
                                })
                                .tooltip(move |cx| {
                                    let token_count = workspace.upgrade().and_then(|workspace| {
                                        let panel =
                                            workspace.read(cx).panel::<AssistantPanel>(cx)?;
                                        let context = panel.read(cx).active_context(cx)?;
                                        context.read(cx).token_count()
                                    });
                                    if let Some(token_count) = token_count {
                                        Tooltip::with_meta(
                                            format!(
                                                "{} Additional Context Tokens from Assistant",
                                                token_count
                                            ),
                                            Some(&crate::ToggleFocus),
                                            "Click to open…",
                                            cx,
                                        )
                                    } else {
                                        Tooltip::for_action(
                                            "Toggle Assistant Panel",
                                            &crate::ToggleFocus,
                                            cx,
                                        )
                                    }
                                })
                        }))
                        // Error indicator with the error message as its tooltip.
                        .children(
                            if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
                                let error_message = SharedString::from(error.to_string());
                                Some(
                                    div()
                                        .id("error")
                                        .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                        .child(
                                            Icon::new(IconName::XCircle)
                                                .size(IconSize::Small)
                                                .color(Color::Error),
                                        ),
                                )
                            } else {
                                None
                            },
                        ),
                )
                .child(div().flex_1().child(self.render_prompt_editor(cx)))
                .child(h_flex().gap_2().pr_4().children(buttons)),
        )
    }
}
1329
// The prompt editor view delegates its focus handle to the inner text editor.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1335
1336impl PromptEditor {
1337 const MAX_LINES: u8 = 8;
1338
1339 fn new(
1340 id: InlineAssistId,
1341 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1342 prompt_history: VecDeque<String>,
1343 prompt_buffer: Model<MultiBuffer>,
1344 codegen: Model<Codegen>,
1345 workspace: Option<WeakView<Workspace>>,
1346 cx: &mut ViewContext<Self>,
1347 ) -> Self {
1348 let prompt_editor = cx.new_view(|cx| {
1349 let mut editor = Editor::new(
1350 EditorMode::AutoHeight {
1351 max_lines: Self::MAX_LINES as usize,
1352 },
1353 prompt_buffer,
1354 None,
1355 false,
1356 cx,
1357 );
1358 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1359 // Since the prompt editors for all inline assistants are linked,
1360 // always show the cursor (even when it isn't focused) because
1361 // typing in one will make what you typed appear in all of them.
1362 editor.set_show_cursor_when_unfocused(true, cx);
1363 editor.set_placeholder_text("Add a prompt…", cx);
1364 editor
1365 });
1366 let mut this = Self {
1367 id,
1368 height_in_lines: 1,
1369 editor: prompt_editor,
1370 edited_since_done: false,
1371 gutter_dimensions,
1372 prompt_history,
1373 prompt_history_ix: None,
1374 pending_prompt: String::new(),
1375 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1376 editor_subscriptions: Vec::new(),
1377 codegen,
1378 workspace,
1379 };
1380 this.count_lines(cx);
1381 this.subscribe_to_editor(cx);
1382 this
1383 }
1384
1385 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1386 self.editor_subscriptions.clear();
1387 self.editor_subscriptions
1388 .push(cx.observe(&self.editor, Self::handle_prompt_editor_changed));
1389 self.editor_subscriptions
1390 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1391 }
1392
1393 fn set_show_cursor_when_unfocused(
1394 &mut self,
1395 show_cursor_when_unfocused: bool,
1396 cx: &mut ViewContext<Self>,
1397 ) {
1398 self.editor.update(cx, |editor, cx| {
1399 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1400 });
1401 }
1402
1403 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1404 let prompt = self.prompt(cx);
1405 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1406 self.editor = cx.new_view(|cx| {
1407 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1408 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1409 editor.set_placeholder_text("Add a prompt…", cx);
1410 editor.set_text(prompt, cx);
1411 if focus {
1412 editor.focus(cx);
1413 }
1414 editor
1415 });
1416 self.subscribe_to_editor(cx);
1417 }
1418
1419 fn prompt(&self, cx: &AppContext) -> String {
1420 self.editor.read(cx).text(cx)
1421 }
1422
1423 fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
1424 let height_in_lines = cmp::max(
1425 2, // Make the editor at least two lines tall, to account for padding and buttons.
1426 cmp::min(
1427 self.editor
1428 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
1429 Self::MAX_LINES as u32,
1430 ),
1431 ) as u8;
1432
1433 if height_in_lines != self.height_in_lines {
1434 self.height_in_lines = height_in_lines;
1435 cx.emit(PromptEditorEvent::Resized { height_in_lines });
1436 }
1437 }
1438
1439 fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
1440 self.count_lines(cx);
1441 }
1442
1443 fn handle_prompt_editor_events(
1444 &mut self,
1445 _: View<Editor>,
1446 event: &EditorEvent,
1447 cx: &mut ViewContext<Self>,
1448 ) {
1449 match event {
1450 EditorEvent::Edited { .. } => {
1451 let prompt = self.editor.read(cx).text(cx);
1452 if self
1453 .prompt_history_ix
1454 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1455 {
1456 self.prompt_history_ix.take();
1457 self.pending_prompt = prompt;
1458 }
1459
1460 self.edited_since_done = true;
1461 cx.notify();
1462 }
1463 _ => {}
1464 }
1465 }
1466
1467 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1468 match &self.codegen.read(cx).status {
1469 CodegenStatus::Idle => {
1470 self.editor
1471 .update(cx, |editor, _| editor.set_read_only(false));
1472 }
1473 CodegenStatus::Pending => {
1474 self.editor
1475 .update(cx, |editor, _| editor.set_read_only(true));
1476 }
1477 CodegenStatus::Done | CodegenStatus::Error(_) => {
1478 self.edited_since_done = false;
1479 self.editor
1480 .update(cx, |editor, _| editor.set_read_only(false));
1481 }
1482 }
1483 }
1484
1485 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1486 match &self.codegen.read(cx).status {
1487 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1488 cx.emit(PromptEditorEvent::CancelRequested);
1489 }
1490 CodegenStatus::Pending => {
1491 cx.emit(PromptEditorEvent::StopRequested);
1492 }
1493 }
1494 }
1495
1496 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1497 match &self.codegen.read(cx).status {
1498 CodegenStatus::Idle => {
1499 cx.emit(PromptEditorEvent::StartRequested);
1500 }
1501 CodegenStatus::Pending => {
1502 cx.emit(PromptEditorEvent::DismissRequested);
1503 }
1504 CodegenStatus::Done | CodegenStatus::Error(_) => {
1505 if self.edited_since_done {
1506 cx.emit(PromptEditorEvent::StartRequested);
1507 } else {
1508 cx.emit(PromptEditorEvent::ConfirmRequested);
1509 }
1510 }
1511 }
1512 }
1513
1514 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1515 if let Some(ix) = self.prompt_history_ix {
1516 if ix > 0 {
1517 self.prompt_history_ix = Some(ix - 1);
1518 let prompt = self.prompt_history[ix - 1].as_str();
1519 self.editor.update(cx, |editor, cx| {
1520 editor.set_text(prompt, cx);
1521 editor.move_to_beginning(&Default::default(), cx);
1522 });
1523 }
1524 } else if !self.prompt_history.is_empty() {
1525 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1526 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1527 self.editor.update(cx, |editor, cx| {
1528 editor.set_text(prompt, cx);
1529 editor.move_to_beginning(&Default::default(), cx);
1530 });
1531 }
1532 }
1533
1534 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1535 if let Some(ix) = self.prompt_history_ix {
1536 if ix < self.prompt_history.len() - 1 {
1537 self.prompt_history_ix = Some(ix + 1);
1538 let prompt = self.prompt_history[ix + 1].as_str();
1539 self.editor.update(cx, |editor, cx| {
1540 editor.set_text(prompt, cx);
1541 editor.move_to_end(&Default::default(), cx)
1542 });
1543 } else {
1544 self.prompt_history_ix = None;
1545 let prompt = self.pending_prompt.as_str();
1546 self.editor.update(cx, |editor, cx| {
1547 editor.set_text(prompt, cx);
1548 editor.move_to_end(&Default::default(), cx)
1549 });
1550 }
1551 }
1552 }
1553
1554 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1555 let settings = ThemeSettings::get_global(cx);
1556 let text_style = TextStyle {
1557 color: if self.editor.read(cx).read_only(cx) {
1558 cx.theme().colors().text_disabled
1559 } else {
1560 cx.theme().colors().text
1561 },
1562 font_family: settings.ui_font.family.clone(),
1563 font_features: settings.ui_font.features.clone(),
1564 font_size: rems(0.875).into(),
1565 font_weight: FontWeight::NORMAL,
1566 font_style: FontStyle::Normal,
1567 line_height: relative(1.3),
1568 background_color: None,
1569 underline: None,
1570 strikethrough: None,
1571 white_space: WhiteSpace::Normal,
1572 };
1573 EditorElement::new(
1574 &self.editor,
1575 EditorStyle {
1576 background: cx.theme().colors().editor_background,
1577 local_player: cx.theme().players().local(),
1578 text: text_style,
1579 ..Default::default()
1580 },
1581 )
1582 }
1583}
1584
/// State of a single inline assist.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Editor the assist operates on. The handle is weak, so it does not keep the
    /// editor alive.
    editor: WeakView<Editor>,
    /// Prompt editor and blocks shown in the editor. `None` means the assist has no
    /// visible decorations (presumably they were removed earlier — confirm). In
    /// that case the codegen `Finished` handler finishes the assist directly.
    decorations: Option<InlineAssistDecorations>,
    /// Generator that streams the transformation into the buffer.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions created in `InlineAssist::new`.
    _subscriptions: Vec<Subscription>,
    /// Workspace used, for example, to show an error toast.
    workspace: Option<WeakView<Workspace>>,
    /// Whether additional context should be included in the request
    /// (NOTE(review): presumably assistant-panel context — confirm at the use site).
    include_context: bool,
}
1594
impl InlineAssist {
    /// Builds the state for inline assist `assist_id` in `editor` and subscribes
    /// to its prompt editor and codegen:
    ///
    /// - focus in/out on the prompt editor is forwarded to the assistant;
    /// - prompt editor events go to `InlineAssistant::handle_prompt_editor_event`;
    /// - every codegen change signals the editor's highlight updater and refreshes
    ///   this assist's blocks;
    /// - codegen `Undone` finishes the assist;
    /// - codegen `Finished` finishes the assist only if it has no decorations. In
    ///   that case a failed generation is also reported as a workspace toast.
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: BlockId,
        end_block_id: BlockId,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                cx.observe(&codegen, {
                    // Weak handle so this observer does not keep the editor alive.
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                // Wake the editor's highlight-refresh task.
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            // The assist may already have been removed.
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // With no decorations there is no prompt row to show the
                            // error in, so surface it as a toast instead.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::identified::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        }
    }
}
1696
/// Editor blocks displayed for an inline assist.
struct InlineAssistDecorations {
    /// Id of the prompt block (presumably rendered via `build_assist_editor_renderer`).
    prompt_block_id: BlockId,
    /// The prompt editor view shown for this assist.
    prompt_editor: View<PromptEditor>,
    /// Blocks inserted to display lines removed by the generation. They are
    /// replaced wholesale whenever those blocks are rebuilt.
    removed_line_block_ids: HashSet<BlockId>,
    /// Id of the block that marks the end of the assist.
    end_block_id: BlockId,
}
1703
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended: it completed, failed, or was stopped.
    Finished,
    /// The transaction holding this codegen's edits was undone.
    Undone,
}
1709
/// Streams a language-model completion into a range of a multi-buffer.
pub struct Codegen {
    /// Buffer the generated text is written into.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the original text of the underlying buffer covered by
    /// `range`, taken in `Codegen::new` with the same line ending, language and
    /// language registry.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `Codegen::new`. Streamed edits and the line
    /// diff are computed relative to it.
    snapshot: MultiBufferSnapshot,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Position in the original text reached by the most recently applied hunks.
    edit_position: Anchor,
    /// Ranges left unchanged by the most recent batch of hunks. Cleared on each
    /// batch and when generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction created by the generation; later edits are merged into it.
    transaction_id: Option<TransactionId>,
    /// Current lifecycle state.
    status: CodegenStatus,
    /// Task driving the current generation. Assigning a new task drops the previous one.
    generation: Task<()>,
    /// Line-level diff between the original and the current text of `range`.
    diff: Diff,
    /// Telemetry sink for assistant events, if enabled.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to detect undo of `transaction_id`.
    _subscription: gpui::Subscription,
}
1724
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// No generation has been started.
    Idle,
    /// A generation is streaming.
    Pending,
    /// The last generation completed or was stopped.
    Done,
    /// The last generation failed with this error.
    Error(anyhow::Error),
}
1731
/// Line diff between the text of the range before and after generation.
#[derive(Default)]
struct Diff {
    /// In-flight background diff computation, if any.
    task: Option<Task<()>>,
    /// Set when a recompute is requested while `task` is running. The recompute
    /// runs once that task completes.
    should_update: bool,
    /// Runs of deleted lines. Each entry holds an anchor in the current buffer
    /// where the lines were removed, and the row range they occupied in the
    /// original snapshot.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted lines, as anchor ranges in the current buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
1739
// Lets observers subscribe to `CodegenEvent`s.
impl EventEmitter<CodegenEvent> for Codegen {}
1741
1742impl Codegen {
1743 pub fn new(
1744 buffer: Model<MultiBuffer>,
1745 range: Range<Anchor>,
1746 telemetry: Option<Arc<Telemetry>>,
1747 cx: &mut ModelContext<Self>,
1748 ) -> Self {
1749 let snapshot = buffer.read(cx).snapshot(cx);
1750
1751 let (old_buffer, _, _) = buffer
1752 .read(cx)
1753 .range_to_buffer_ranges(range.clone(), cx)
1754 .pop()
1755 .unwrap();
1756 let old_buffer = cx.new_model(|cx| {
1757 let old_buffer = old_buffer.read(cx);
1758 let text = old_buffer.as_rope().clone();
1759 let line_ending = old_buffer.line_ending();
1760 let language = old_buffer.language().cloned();
1761 let language_registry = old_buffer.language_registry();
1762
1763 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
1764 buffer.set_language(language, cx);
1765 if let Some(language_registry) = language_registry {
1766 buffer.set_language_registry(language_registry)
1767 }
1768 buffer
1769 });
1770
1771 Self {
1772 buffer: buffer.clone(),
1773 old_buffer,
1774 edit_position: range.start,
1775 range,
1776 snapshot,
1777 last_equal_ranges: Default::default(),
1778 transaction_id: Default::default(),
1779 status: CodegenStatus::Idle,
1780 generation: Task::ready(()),
1781 diff: Diff::default(),
1782 telemetry,
1783 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
1784 }
1785 }
1786
1787 fn handle_buffer_event(
1788 &mut self,
1789 _buffer: Model<MultiBuffer>,
1790 event: &multi_buffer::Event,
1791 cx: &mut ModelContext<Self>,
1792 ) {
1793 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
1794 if self.transaction_id == Some(*transaction_id) {
1795 self.transaction_id = None;
1796 self.generation = Task::ready(());
1797 cx.emit(CodegenEvent::Undone);
1798 }
1799 }
1800 }
1801
1802 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
1803 &self.last_equal_ranges
1804 }
1805
1806 pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
1807 let range = self.range.clone();
1808 let snapshot = self.snapshot.clone();
1809 let selected_text = snapshot
1810 .text_for_range(range.start..range.end)
1811 .collect::<Rope>();
1812
1813 let selection_start = range.start.to_point(&snapshot);
1814 let suggested_line_indent = snapshot
1815 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
1816 .into_values()
1817 .next()
1818 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
1819
1820 let model_telemetry_id = prompt.model.telemetry_id();
1821 let response = CompletionProvider::global(cx).complete(prompt);
1822 let telemetry = self.telemetry.clone();
1823 self.edit_position = range.start;
1824 self.diff = Diff::default();
1825 self.status = CodegenStatus::Pending;
1826 self.generation = cx.spawn(|this, mut cx| {
1827 async move {
1828 let generate = async {
1829 let mut edit_start = range.start.to_offset(&snapshot);
1830
1831 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
1832 let diff: Task<anyhow::Result<()>> =
1833 cx.background_executor().spawn(async move {
1834 let mut response_latency = None;
1835 let request_start = Instant::now();
1836 let diff = async {
1837 let chunks = StripInvalidSpans::new(response.await?);
1838 futures::pin_mut!(chunks);
1839 let mut diff = StreamingDiff::new(selected_text.to_string());
1840
1841 let mut new_text = String::new();
1842 let mut base_indent = None;
1843 let mut line_indent = None;
1844 let mut first_line = true;
1845
1846 while let Some(chunk) = chunks.next().await {
1847 if response_latency.is_none() {
1848 response_latency = Some(request_start.elapsed());
1849 }
1850 let chunk = chunk?;
1851
1852 let mut lines = chunk.split('\n').peekable();
1853 while let Some(line) = lines.next() {
1854 new_text.push_str(line);
1855 if line_indent.is_none() {
1856 if let Some(non_whitespace_ch_ix) =
1857 new_text.find(|ch: char| !ch.is_whitespace())
1858 {
1859 line_indent = Some(non_whitespace_ch_ix);
1860 base_indent = base_indent.or(line_indent);
1861
1862 let line_indent = line_indent.unwrap();
1863 let base_indent = base_indent.unwrap();
1864 let indent_delta =
1865 line_indent as i32 - base_indent as i32;
1866 let mut corrected_indent_len = cmp::max(
1867 0,
1868 suggested_line_indent.len as i32 + indent_delta,
1869 )
1870 as usize;
1871 if first_line {
1872 corrected_indent_len = corrected_indent_len
1873 .saturating_sub(
1874 selection_start.column as usize,
1875 );
1876 }
1877
1878 let indent_char = suggested_line_indent.char();
1879 let mut indent_buffer = [0; 4];
1880 let indent_str =
1881 indent_char.encode_utf8(&mut indent_buffer);
1882 new_text.replace_range(
1883 ..line_indent,
1884 &indent_str.repeat(corrected_indent_len),
1885 );
1886 }
1887 }
1888
1889 if line_indent.is_some() {
1890 hunks_tx.send(diff.push_new(&new_text)).await?;
1891 new_text.clear();
1892 }
1893
1894 if lines.peek().is_some() {
1895 hunks_tx.send(diff.push_new("\n")).await?;
1896 line_indent = None;
1897 first_line = false;
1898 }
1899 }
1900 }
1901 hunks_tx.send(diff.push_new(&new_text)).await?;
1902 hunks_tx.send(diff.finish()).await?;
1903
1904 anyhow::Ok(())
1905 };
1906
1907 let result = diff.await;
1908
1909 let error_message =
1910 result.as_ref().err().map(|error| error.to_string());
1911 if let Some(telemetry) = telemetry {
1912 telemetry.report_assistant_event(
1913 None,
1914 telemetry_events::AssistantKind::Inline,
1915 model_telemetry_id,
1916 response_latency,
1917 error_message,
1918 );
1919 }
1920
1921 result?;
1922 Ok(())
1923 });
1924
1925 while let Some(hunks) = hunks_rx.next().await {
1926 this.update(&mut cx, |this, cx| {
1927 this.last_equal_ranges.clear();
1928
1929 let transaction = this.buffer.update(cx, |buffer, cx| {
1930 // Avoid grouping assistant edits with user edits.
1931 buffer.finalize_last_transaction(cx);
1932
1933 buffer.start_transaction(cx);
1934 buffer.edit(
1935 hunks.into_iter().filter_map(|hunk| match hunk {
1936 Hunk::Insert { text } => {
1937 let edit_start = snapshot.anchor_after(edit_start);
1938 Some((edit_start..edit_start, text))
1939 }
1940 Hunk::Remove { len } => {
1941 let edit_end = edit_start + len;
1942 let edit_range = snapshot.anchor_after(edit_start)
1943 ..snapshot.anchor_before(edit_end);
1944 edit_start = edit_end;
1945 Some((edit_range, String::new()))
1946 }
1947 Hunk::Keep { len } => {
1948 let edit_end = edit_start + len;
1949 let edit_range = snapshot.anchor_after(edit_start)
1950 ..snapshot.anchor_before(edit_end);
1951 edit_start = edit_end;
1952 this.last_equal_ranges.push(edit_range);
1953 None
1954 }
1955 }),
1956 None,
1957 cx,
1958 );
1959 this.edit_position = snapshot.anchor_after(edit_start);
1960
1961 buffer.end_transaction(cx)
1962 });
1963
1964 if let Some(transaction) = transaction {
1965 if let Some(first_transaction) = this.transaction_id {
1966 // Group all assistant edits into the first transaction.
1967 this.buffer.update(cx, |buffer, cx| {
1968 buffer.merge_transactions(
1969 transaction,
1970 first_transaction,
1971 cx,
1972 )
1973 });
1974 } else {
1975 this.transaction_id = Some(transaction);
1976 this.buffer.update(cx, |buffer, cx| {
1977 buffer.finalize_last_transaction(cx)
1978 });
1979 }
1980 }
1981
1982 this.update_diff(cx);
1983 cx.notify();
1984 })?;
1985 }
1986
1987 diff.await?;
1988
1989 anyhow::Ok(())
1990 };
1991
1992 let result = generate.await;
1993 this.update(&mut cx, |this, cx| {
1994 this.last_equal_ranges.clear();
1995 if let Err(error) = result {
1996 this.status = CodegenStatus::Error(error);
1997 } else {
1998 this.status = CodegenStatus::Done;
1999 }
2000 cx.emit(CodegenEvent::Finished);
2001 cx.notify();
2002 })
2003 .ok();
2004 }
2005 });
2006 cx.notify();
2007 }
2008
2009 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2010 self.last_equal_ranges.clear();
2011 self.status = CodegenStatus::Done;
2012 self.generation = Task::ready(());
2013 cx.emit(CodegenEvent::Finished);
2014 cx.notify();
2015 }
2016
2017 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2018 if let Some(transaction_id) = self.transaction_id.take() {
2019 self.buffer
2020 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
2021 }
2022 }
2023
2024 fn update_diff(&mut self, cx: &mut ModelContext<Self>) {
2025 if self.diff.task.is_some() {
2026 self.diff.should_update = true;
2027 } else {
2028 self.diff.should_update = false;
2029
2030 let old_snapshot = self.snapshot.clone();
2031 let old_range = self.range.to_point(&old_snapshot);
2032 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2033 let new_range = self.range.to_point(&new_snapshot);
2034
2035 self.diff.task = Some(cx.spawn(|this, mut cx| async move {
2036 let (deleted_row_ranges, inserted_row_ranges) = cx
2037 .background_executor()
2038 .spawn(async move {
2039 let old_text = old_snapshot
2040 .text_for_range(
2041 Point::new(old_range.start.row, 0)
2042 ..Point::new(
2043 old_range.end.row,
2044 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
2045 ),
2046 )
2047 .collect::<String>();
2048 let new_text = new_snapshot
2049 .text_for_range(
2050 Point::new(new_range.start.row, 0)
2051 ..Point::new(
2052 new_range.end.row,
2053 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
2054 ),
2055 )
2056 .collect::<String>();
2057
2058 let mut old_row = old_range.start.row;
2059 let mut new_row = new_range.start.row;
2060 let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());
2061
2062 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
2063 let mut inserted_row_ranges = Vec::new();
2064 for change in diff.iter_all_changes() {
2065 let line_count = change.value().lines().count() as u32;
2066 match change.tag() {
2067 similar::ChangeTag::Equal => {
2068 old_row += line_count;
2069 new_row += line_count;
2070 }
2071 similar::ChangeTag::Delete => {
2072 let old_end_row = old_row + line_count - 1;
2073 let new_row =
2074 new_snapshot.anchor_before(Point::new(new_row, 0));
2075
2076 if let Some((_, last_deleted_row_range)) =
2077 deleted_row_ranges.last_mut()
2078 {
2079 if *last_deleted_row_range.end() + 1 == old_row {
2080 *last_deleted_row_range =
2081 *last_deleted_row_range.start()..=old_end_row;
2082 } else {
2083 deleted_row_ranges
2084 .push((new_row, old_row..=old_end_row));
2085 }
2086 } else {
2087 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2088 }
2089
2090 old_row += line_count;
2091 }
2092 similar::ChangeTag::Insert => {
2093 let new_end_row = new_row + line_count - 1;
2094 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2095 let end = new_snapshot.anchor_before(Point::new(
2096 new_end_row,
2097 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2098 ));
2099 inserted_row_ranges.push(start..=end);
2100 new_row += line_count;
2101 }
2102 }
2103 }
2104
2105 (deleted_row_ranges, inserted_row_ranges)
2106 })
2107 .await;
2108
2109 this.update(&mut cx, |this, cx| {
2110 this.diff.deleted_row_ranges = deleted_row_ranges;
2111 this.diff.inserted_row_ranges = inserted_row_ranges;
2112 this.diff.task = None;
2113 if this.diff.should_update {
2114 this.update_diff(cx);
2115 }
2116 cx.notify();
2117 })
2118 .ok();
2119 }));
2120 }
2121 }
2122}
2123
/// Stream adapter that cleans up raw completion text before it is applied. It
/// does three things:
/// - drops an opening code fence on the first line;
/// - drops a closing fence on the last line when an opening one was dropped;
/// - removes `<|CURSOR|>` markers.
struct StripInvalidSpans<T> {
    /// Upstream stream of completion chunks.
    stream: T,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted yet.
    buffer: String,
    /// Still processing the first line of the response.
    first_line: bool,
    /// A complete line was emitted but its newline was not. The newline is written
    /// before the next emitted text, so a dropped trailing fence leaves no dangling
    /// newline.
    line_end: bool,
    /// The response's first line was a code fence that was dropped.
    starts_with_code_block: bool,
}
2132
2133impl<T> StripInvalidSpans<T>
2134where
2135 T: Stream<Item = Result<String>>,
2136{
2137 fn new(stream: T) -> Self {
2138 Self {
2139 stream,
2140 stream_done: false,
2141 buffer: String::new(),
2142 first_line: true,
2143 line_end: false,
2144 starts_with_code_block: false,
2145 }
2146 }
2147}
2148
2149impl<T> Stream for StripInvalidSpans<T>
2150where
2151 T: Stream<Item = Result<String>>,
2152{
2153 type Item = Result<String>;
2154
2155 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2156 const CODE_BLOCK_DELIMITER: &str = "```";
2157 const CURSOR_SPAN: &str = "<|CURSOR|>";
2158
2159 let this = unsafe { self.get_unchecked_mut() };
2160 loop {
2161 if !this.stream_done {
2162 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2163 match stream.as_mut().poll_next(cx) {
2164 Poll::Ready(Some(Ok(chunk))) => {
2165 this.buffer.push_str(&chunk);
2166 }
2167 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2168 Poll::Ready(None) => {
2169 this.stream_done = true;
2170 }
2171 Poll::Pending => return Poll::Pending,
2172 }
2173 }
2174
2175 let mut chunk = String::new();
2176 let mut consumed = 0;
2177 if !this.buffer.is_empty() {
2178 let mut lines = this.buffer.split('\n').enumerate().peekable();
2179 while let Some((line_ix, line)) = lines.next() {
2180 if line_ix > 0 {
2181 this.first_line = false;
2182 }
2183
2184 if this.first_line {
2185 let trimmed_line = line.trim();
2186 if lines.peek().is_some() {
2187 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2188 consumed += line.len() + 1;
2189 this.starts_with_code_block = true;
2190 continue;
2191 }
2192 } else if trimmed_line.is_empty()
2193 || prefixes(CODE_BLOCK_DELIMITER)
2194 .any(|prefix| trimmed_line.starts_with(prefix))
2195 {
2196 break;
2197 }
2198 }
2199
2200 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2201 if lines.peek().is_some() {
2202 if this.line_end {
2203 chunk.push('\n');
2204 }
2205
2206 chunk.push_str(&line_without_cursor);
2207 this.line_end = true;
2208 consumed += line.len() + 1;
2209 } else if this.stream_done {
2210 if !this.starts_with_code_block
2211 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2212 {
2213 if this.line_end {
2214 chunk.push('\n');
2215 }
2216
2217 chunk.push_str(&line);
2218 }
2219
2220 consumed += line.len();
2221 } else {
2222 let trimmed_line = line.trim();
2223 if trimmed_line.is_empty()
2224 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2225 || prefixes(CODE_BLOCK_DELIMITER)
2226 .any(|prefix| trimmed_line.ends_with(prefix))
2227 {
2228 break;
2229 } else {
2230 if this.line_end {
2231 chunk.push('\n');
2232 this.line_end = false;
2233 }
2234
2235 chunk.push_str(&line_without_cursor);
2236 consumed += line.len();
2237 }
2238 }
2239 }
2240 }
2241
2242 this.buffer = this.buffer.split_off(consumed);
2243 if !chunk.is_empty() {
2244 return Poll::Ready(Some(Ok(chunk)));
2245 } else if this.stream_done {
2246 return Poll::Ready(None);
2247 }
2248 }
2249 }
2250}
2251
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// `text` itself is not yielded, so a one-character (or empty) `text` yields
/// nothing. Prefixes are cut on `char` boundaries, so non-ASCII input never
/// panics, and an empty `text` no longer underflows `text.len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2255
2256fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2257 ranges.sort_unstable_by(|a, b| {
2258 a.start
2259 .cmp(&b.start, buffer)
2260 .then_with(|| b.end.cmp(&a.end, buffer))
2261 });
2262
2263 let mut ix = 0;
2264 while ix + 1 < ranges.len() {
2265 let b = ranges[ix + 1].clone();
2266 let a = &mut ranges[ix];
2267 if a.end.cmp(&b.start, buffer).is_gt() {
2268 if a.end.cmp(&b.end, buffer).is_lt() {
2269 a.end = b.end;
2270 }
2271 ranges.remove(ix + 1);
2272 } else {
2273 ix += 1;
2274 }
2275 }
2276}
2277
2278#[cfg(test)]
2279mod tests {
2280 use std::sync::Arc;
2281
2282 use crate::FakeCompletionProvider;
2283
2284 use super::*;
2285 use futures::stream::{self};
2286 use gpui::{Context, TestAppContext};
2287 use indoc::indoc;
2288 use language::{
2289 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2290 Point,
2291 };
2292 use rand::prelude::*;
2293 use serde::Serialize;
2294 use settings::SettingsStore;
2295
2296 #[derive(Serialize)]
2297 pub struct DummyCompletionRequest {
2298 pub name: String,
2299 }
2300
2301 #[gpui::test(iterations = 10)]
2302 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
2303 let provider = FakeCompletionProvider::default();
2304 cx.set_global(cx.update(SettingsStore::test));
2305 cx.set_global(CompletionProvider::Fake(provider.clone()));
2306 cx.update(language_settings::init);
2307
2308 let text = indoc! {"
2309 fn main() {
2310 let x = 0;
2311 for _ in 0..10 {
2312 x += 1;
2313 }
2314 }
2315 "};
2316 let buffer =
2317 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2318 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2319 let range = buffer.read_with(cx, |buffer, cx| {
2320 let snapshot = buffer.snapshot(cx);
2321 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
2322 });
2323 let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
2324
2325 let request = LanguageModelRequest::default();
2326 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
2327
2328 let mut new_text = concat!(
2329 " let mut x = 0;\n",
2330 " while x < 10 {\n",
2331 " x += 1;\n",
2332 " }",
2333 );
2334 while !new_text.is_empty() {
2335 let max_len = cmp::min(new_text.len(), 10);
2336 let len = rng.gen_range(1..=max_len);
2337 let (chunk, suffix) = new_text.split_at(len);
2338 provider.send_completion(chunk.into());
2339 new_text = suffix;
2340 cx.background_executor.run_until_parked();
2341 }
2342 provider.finish_completion();
2343 cx.background_executor.run_until_parked();
2344
2345 assert_eq!(
2346 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2347 indoc! {"
2348 fn main() {
2349 let mut x = 0;
2350 while x < 10 {
2351 x += 1;
2352 }
2353 }
2354 "}
2355 );
2356 }
2357
2358 #[gpui::test(iterations = 10)]
2359 async fn test_autoindent_when_generating_past_indentation(
2360 cx: &mut TestAppContext,
2361 mut rng: StdRng,
2362 ) {
2363 let provider = FakeCompletionProvider::default();
2364 cx.set_global(CompletionProvider::Fake(provider.clone()));
2365 cx.set_global(cx.update(SettingsStore::test));
2366 cx.update(language_settings::init);
2367
2368 let text = indoc! {"
2369 fn main() {
2370 le
2371 }
2372 "};
2373 let buffer =
2374 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2375 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2376 let range = buffer.read_with(cx, |buffer, cx| {
2377 let snapshot = buffer.snapshot(cx);
2378 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
2379 });
2380 let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
2381
2382 let request = LanguageModelRequest::default();
2383 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
2384
2385 let mut new_text = concat!(
2386 "t mut x = 0;\n",
2387 "while x < 10 {\n",
2388 " x += 1;\n",
2389 "}", //
2390 );
2391 while !new_text.is_empty() {
2392 let max_len = cmp::min(new_text.len(), 10);
2393 let len = rng.gen_range(1..=max_len);
2394 let (chunk, suffix) = new_text.split_at(len);
2395 provider.send_completion(chunk.into());
2396 new_text = suffix;
2397 cx.background_executor.run_until_parked();
2398 }
2399 provider.finish_completion();
2400 cx.background_executor.run_until_parked();
2401
2402 assert_eq!(
2403 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2404 indoc! {"
2405 fn main() {
2406 let mut x = 0;
2407 while x < 10 {
2408 x += 1;
2409 }
2410 }
2411 "}
2412 );
2413 }
2414
2415 #[gpui::test(iterations = 10)]
2416 async fn test_autoindent_when_generating_before_indentation(
2417 cx: &mut TestAppContext,
2418 mut rng: StdRng,
2419 ) {
2420 let provider = FakeCompletionProvider::default();
2421 cx.set_global(CompletionProvider::Fake(provider.clone()));
2422 cx.set_global(cx.update(SettingsStore::test));
2423 cx.update(language_settings::init);
2424
2425 let text = concat!(
2426 "fn main() {\n",
2427 " \n",
2428 "}\n" //
2429 );
2430 let buffer =
2431 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2432 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2433 let range = buffer.read_with(cx, |buffer, cx| {
2434 let snapshot = buffer.snapshot(cx);
2435 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
2436 });
2437 let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
2438
2439 let request = LanguageModelRequest::default();
2440 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
2441
2442 let mut new_text = concat!(
2443 "let mut x = 0;\n",
2444 "while x < 10 {\n",
2445 " x += 1;\n",
2446 "}", //
2447 );
2448 while !new_text.is_empty() {
2449 let max_len = cmp::min(new_text.len(), 10);
2450 let len = rng.gen_range(1..=max_len);
2451 let (chunk, suffix) = new_text.split_at(len);
2452 provider.send_completion(chunk.into());
2453 new_text = suffix;
2454 cx.background_executor.run_until_parked();
2455 }
2456 provider.finish_completion();
2457 cx.background_executor.run_until_parked();
2458
2459 assert_eq!(
2460 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2461 indoc! {"
2462 fn main() {
2463 let mut x = 0;
2464 while x < 10 {
2465 x += 1;
2466 }
2467 }
2468 "}
2469 );
2470 }
2471
2472 #[gpui::test]
2473 async fn test_strip_invalid_spans_from_codeblock() {
2474 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
2475 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
2476 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
2477 assert_chunks(
2478 "```html\n```js\nLorem ipsum dolor\n```\n```",
2479 "```js\nLorem ipsum dolor\n```",
2480 )
2481 .await;
2482 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
2483 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
2484 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
2485 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
2486
2487 async fn assert_chunks(text: &str, expected_text: &str) {
2488 for chunk_size in 1..=text.len() {
2489 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
2490 .map(|chunk| chunk.unwrap())
2491 .collect::<String>()
2492 .await;
2493 assert_eq!(
2494 actual_text, expected_text,
2495 "failed to strip invalid spans, chunk size: {}",
2496 chunk_size
2497 );
2498 }
2499 }
2500
2501 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
2502 stream::iter(
2503 text.chars()
2504 .collect::<Vec<_>>()
2505 .chunks(size)
2506 .map(|chunk| Ok(chunk.iter().collect::<String>()))
2507 .collect::<Vec<_>>(),
2508 )
2509 }
2510 }
2511
2512 fn rust_lang() -> Language {
2513 Language::new(
2514 LanguageConfig {
2515 name: "Rust".into(),
2516 matcher: LanguageMatcher {
2517 path_suffixes: vec!["rs".to_string()],
2518 ..Default::default()
2519 },
2520 ..Default::default()
2521 },
2522 Some(tree_sitter_rust::language()),
2523 )
2524 .with_indents_query(
2525 r#"
2526 (call_expression) @indent
2527 (field_expression) @indent
2528 (_ "(" ")" @end) @indent
2529 (_ "{" "}" @end) @indent
2530 "#,
2531 )
2532 .unwrap()
2533 }
2534}