1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 join,
23 stream::{self, BoxStream},
24 SinkExt, Stream, StreamExt,
25};
26use gpui::{
27 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
28 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
29 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
30};
31use language::{Buffer, IndentKind, Point, Selection, TransactionId};
32use language_model::{
33 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
34};
35use multi_buffer::MultiBufferRow;
36use parking_lot::Mutex;
37use rope::Rope;
38use settings::Settings;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(
56 fs: Arc<dyn Fs>,
57 prompt_builder: Arc<PromptBuilder>,
58 telemetry: Arc<Telemetry>,
59 cx: &mut AppContext,
60) {
61 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
62 cx.observe_new_views(|_, cx| {
63 let workspace = cx.view().clone();
64 InlineAssistant::update_global(cx, |inline_assistant, cx| {
65 inline_assistant.register_workspace(&workspace, cx)
66 })
67 })
68 .detach();
69}
70
/// Maximum number of distinct prompts kept in `InlineAssistant::prompt_history`. When a new
/// prompt pushes the history past this length, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
72
/// Coordinator for inline assists, stored as a gpui global.
///
/// It owns every live assist, groups assists that were created together, tracks per-editor
/// state, and remembers confirmed assists so they can still be undone.
pub struct InlineAssistant {
    /// Id handed to the next assist that is created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group that is created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Assists that have not been finished yet; `finish_assist` removes entries from here.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping: ids of the assists in that editor, the scroll lock, and the
    /// channel used to request highlight updates.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists created by a single `assist` / `suggest_assist` call. While a group is linked,
    /// starting or finishing one member acts on all of them.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Watch channel per observed assist, used to broadcast its `AssistStatus`. The entry is
    /// dropped when the assist is finished.
    assist_observations: HashMap<
        InlineAssistId,
        (
            async_watch::Sender<AssistStatus>,
            async_watch::Receiver<AssistStatus>,
        ),
    >,
    /// Codegens of assists that were finished without undo, kept so `undo_assist` can still
    /// revert them.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    /// Recently submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` created for an assist.
    prompt_builder: Arc<PromptBuilder>,
    /// Passed to every `Codegen`; always `Some` when constructed through `new`.
    telemetry: Option<Arc<Telemetry>>,
    /// Passed to every `PromptEditor`.
    fs: Arc<dyn Fs>,
}
92
/// Lifecycle states of an inline assist, as broadcast over the watch channels kept in
/// `InlineAssistant::assist_observations` (for example, `start_assist` sends `Started` and
/// `stop_assist` sends `Stopped`).
pub enum AssistStatus {
    Idle,
    Started,
    Stopped,
    Finished,
}

impl AssistStatus {
    /// Returns `true` once the assist is in a terminal state (`Stopped` or `Finished`).
    pub fn is_done(&self) -> bool {
        match self {
            Self::Idle | Self::Started => false,
            Self::Stopped | Self::Finished => true,
        }
    }
}
105
// Lets `init` install the assistant with `cx.set_global` and reach it later through
// `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
107
108impl InlineAssistant {
109 pub fn new(
110 fs: Arc<dyn Fs>,
111 prompt_builder: Arc<PromptBuilder>,
112 telemetry: Arc<Telemetry>,
113 ) -> Self {
114 Self {
115 next_assist_id: InlineAssistId::default(),
116 next_assist_group_id: InlineAssistGroupId::default(),
117 assists: HashMap::default(),
118 assists_by_editor: HashMap::default(),
119 assist_groups: HashMap::default(),
120 assist_observations: HashMap::default(),
121 confirmed_assists: HashMap::default(),
122 prompt_history: VecDeque::default(),
123 prompt_builder,
124 telemetry: Some(telemetry),
125 fs,
126 }
127 }
128
129 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
130 cx.subscribe(workspace, |_, event, cx| {
131 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
132 })
133 .detach();
134 }
135
136 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
137 // When the user manually saves an editor, automatically accepts all finished transformations.
138 if let workspace::Event::UserSavedItem { item, .. } = event {
139 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
140 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
141 for assist_id in editor_assists.assist_ids.clone() {
142 let assist = &self.assists[&assist_id];
143 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
144 self.finish_assist(assist_id, false, cx)
145 }
146 }
147 }
148 }
149 }
150 }
151
152 pub fn assist(
153 &mut self,
154 editor: &View<Editor>,
155 workspace: Option<WeakView<Workspace>>,
156 assistant_panel: Option<&View<AssistantPanel>>,
157 initial_prompt: Option<String>,
158 cx: &mut WindowContext,
159 ) {
160 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
161
162 let mut selections = Vec::<Selection<Point>>::new();
163 let mut newest_selection = None;
164 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
165 if selection.end > selection.start {
166 selection.start.column = 0;
167 // If the selection ends at the start of the line, we don't want to include it.
168 if selection.end.column == 0 {
169 selection.end.row -= 1;
170 }
171 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
172 }
173
174 if let Some(prev_selection) = selections.last_mut() {
175 if selection.start <= prev_selection.end {
176 prev_selection.end = selection.end;
177 continue;
178 }
179 }
180
181 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
182 if selection.id > latest_selection.id {
183 *latest_selection = selection.clone();
184 }
185 selections.push(selection);
186 }
187 let newest_selection = newest_selection.unwrap();
188
189 let mut codegen_ranges = Vec::new();
190 for (excerpt_id, buffer, buffer_range) in
191 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
192 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
193 }))
194 {
195 let start = Anchor {
196 buffer_id: Some(buffer.remote_id()),
197 excerpt_id,
198 text_anchor: buffer.anchor_before(buffer_range.start),
199 };
200 let end = Anchor {
201 buffer_id: Some(buffer.remote_id()),
202 excerpt_id,
203 text_anchor: buffer.anchor_after(buffer_range.end),
204 };
205 codegen_ranges.push(start..end);
206 }
207
208 let assist_group_id = self.next_assist_group_id.post_inc();
209 let prompt_buffer =
210 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
211 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
212
213 let mut assists = Vec::new();
214 let mut assist_to_focus = None;
215 for range in codegen_ranges {
216 let assist_id = self.next_assist_id.post_inc();
217 let codegen = cx.new_model(|cx| {
218 Codegen::new(
219 editor.read(cx).buffer().clone(),
220 range.clone(),
221 None,
222 self.telemetry.clone(),
223 self.prompt_builder.clone(),
224 cx,
225 )
226 });
227
228 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
229 let prompt_editor = cx.new_view(|cx| {
230 PromptEditor::new(
231 assist_id,
232 gutter_dimensions.clone(),
233 self.prompt_history.clone(),
234 prompt_buffer.clone(),
235 codegen.clone(),
236 editor,
237 assistant_panel,
238 workspace.clone(),
239 self.fs.clone(),
240 cx,
241 )
242 });
243
244 if assist_to_focus.is_none() {
245 let focus_assist = if newest_selection.reversed {
246 range.start.to_point(&snapshot) == newest_selection.start
247 } else {
248 range.end.to_point(&snapshot) == newest_selection.end
249 };
250 if focus_assist {
251 assist_to_focus = Some(assist_id);
252 }
253 }
254
255 let [prompt_block_id, end_block_id] =
256 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
257
258 assists.push((
259 assist_id,
260 range,
261 prompt_editor,
262 prompt_block_id,
263 end_block_id,
264 ));
265 }
266
267 let editor_assists = self
268 .assists_by_editor
269 .entry(editor.downgrade())
270 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
271 let mut assist_group = InlineAssistGroup::new();
272 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
273 self.assists.insert(
274 assist_id,
275 InlineAssist::new(
276 assist_id,
277 assist_group_id,
278 assistant_panel.is_some(),
279 editor,
280 &prompt_editor,
281 prompt_block_id,
282 end_block_id,
283 range,
284 prompt_editor.read(cx).codegen.clone(),
285 workspace.clone(),
286 cx,
287 ),
288 );
289 assist_group.assist_ids.push(assist_id);
290 editor_assists.assist_ids.push(assist_id);
291 }
292 self.assist_groups.insert(assist_group_id, assist_group);
293
294 if let Some(assist_id) = assist_to_focus {
295 self.focus_assist(assist_id, cx);
296 }
297 }
298
299 #[allow(clippy::too_many_arguments)]
300 pub fn suggest_assist(
301 &mut self,
302 editor: &View<Editor>,
303 mut range: Range<Anchor>,
304 initial_prompt: String,
305 initial_transaction_id: Option<TransactionId>,
306 workspace: Option<WeakView<Workspace>>,
307 assistant_panel: Option<&View<AssistantPanel>>,
308 cx: &mut WindowContext,
309 ) -> InlineAssistId {
310 let assist_group_id = self.next_assist_group_id.post_inc();
311 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
312 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
313
314 let assist_id = self.next_assist_id.post_inc();
315
316 let buffer = editor.read(cx).buffer().clone();
317 {
318 let snapshot = buffer.read(cx).read(cx);
319 range.start = range.start.bias_left(&snapshot);
320 range.end = range.end.bias_right(&snapshot);
321 }
322
323 let codegen = cx.new_model(|cx| {
324 Codegen::new(
325 editor.read(cx).buffer().clone(),
326 range.clone(),
327 initial_transaction_id,
328 self.telemetry.clone(),
329 self.prompt_builder.clone(),
330 cx,
331 )
332 });
333
334 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
335 let prompt_editor = cx.new_view(|cx| {
336 PromptEditor::new(
337 assist_id,
338 gutter_dimensions.clone(),
339 self.prompt_history.clone(),
340 prompt_buffer.clone(),
341 codegen.clone(),
342 editor,
343 assistant_panel,
344 workspace.clone(),
345 self.fs.clone(),
346 cx,
347 )
348 });
349
350 let [prompt_block_id, end_block_id] =
351 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
352
353 let editor_assists = self
354 .assists_by_editor
355 .entry(editor.downgrade())
356 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
357
358 let mut assist_group = InlineAssistGroup::new();
359 self.assists.insert(
360 assist_id,
361 InlineAssist::new(
362 assist_id,
363 assist_group_id,
364 assistant_panel.is_some(),
365 editor,
366 &prompt_editor,
367 prompt_block_id,
368 end_block_id,
369 range,
370 prompt_editor.read(cx).codegen.clone(),
371 workspace.clone(),
372 cx,
373 ),
374 );
375 assist_group.assist_ids.push(assist_id);
376 editor_assists.assist_ids.push(assist_id);
377 self.assist_groups.insert(assist_group_id, assist_group);
378 assist_id
379 }
380
381 fn insert_assist_blocks(
382 &self,
383 editor: &View<Editor>,
384 range: &Range<Anchor>,
385 prompt_editor: &View<PromptEditor>,
386 cx: &mut WindowContext,
387 ) -> [CustomBlockId; 2] {
388 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
389 prompt_editor
390 .editor
391 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
392 });
393 let assist_blocks = vec![
394 BlockProperties {
395 style: BlockStyle::Sticky,
396 position: range.start,
397 height: prompt_editor_height,
398 render: build_assist_editor_renderer(prompt_editor),
399 disposition: BlockDisposition::Above,
400 priority: 0,
401 },
402 BlockProperties {
403 style: BlockStyle::Sticky,
404 position: range.end,
405 height: 0,
406 render: Box::new(|cx| {
407 v_flex()
408 .h_full()
409 .w_full()
410 .border_t_1()
411 .border_color(cx.theme().status().info_border)
412 .into_any_element()
413 }),
414 disposition: BlockDisposition::Below,
415 priority: 0,
416 },
417 ];
418
419 editor.update(cx, |editor, cx| {
420 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
421 [block_ids[0], block_ids[1]]
422 })
423 }
424
    /// Called when an assist's prompt editor gains focus.
    ///
    /// Marks the assist as its group's active assist. If the group is linked, every member's
    /// prompt editor is told to keep showing its cursor while unfocused. Then, if the prompt
    /// block's row is inside the visible viewport, records a scroll lock with the block's
    /// distance from the top; `handle_editor_change` uses it to keep the block in place.
    /// Otherwise the lock is cleared.
    ///
    /// Does nothing when the assist's decorations have already been removed. Panics if
    /// `assist_id`, its group or its editor entry is unknown.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        // Disjoint field borrows: `assists` (shared), `assist_groups` and `assists_by_editor`
        // (mutable) are held at the same time.
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                // Only lock scrolling when the prompt is currently on screen.
                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            // The host editor may already have been released; nothing to lock then.
            .ok();
    }
465
466 fn handle_prompt_editor_focus_out(
467 &mut self,
468 assist_id: InlineAssistId,
469 cx: &mut WindowContext,
470 ) {
471 let assist = &self.assists[&assist_id];
472 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
473 if assist_group.active_assist_id == Some(assist_id) {
474 assist_group.active_assist_id = None;
475 if assist_group.linked {
476 for assist_id in &assist_group.assist_ids {
477 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
478 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
479 prompt_editor.set_show_cursor_when_unfocused(false, cx)
480 });
481 }
482 }
483 }
484 }
485 }
486
487 fn handle_prompt_editor_event(
488 &mut self,
489 prompt_editor: View<PromptEditor>,
490 event: &PromptEditorEvent,
491 cx: &mut WindowContext,
492 ) {
493 let assist_id = prompt_editor.read(cx).id;
494 match event {
495 PromptEditorEvent::StartRequested => {
496 self.start_assist(assist_id, cx);
497 }
498 PromptEditorEvent::StopRequested => {
499 self.stop_assist(assist_id, cx);
500 }
501 PromptEditorEvent::ConfirmRequested => {
502 self.finish_assist(assist_id, false, cx);
503 }
504 PromptEditorEvent::CancelRequested => {
505 self.finish_assist(assist_id, true, cx);
506 }
507 PromptEditorEvent::DismissRequested => {
508 self.dismiss_assist(assist_id, cx);
509 }
510 }
511 }
512
    /// Handles a newline in the host editor; presumably bound to its newline action (the
    /// registration is not visible here).
    ///
    /// If there is exactly one selection and it lies inside an assist's range, the action is
    /// consumed. The assist is dismissed if its codegen is still `Pending`, otherwise it is
    /// finished and its changes kept. In every other case the action propagates.
    ///
    /// NOTE(review): when the editor has no tracked assists this returns without
    /// `cx.propagate()`, which swallows the action. Presumably the handler is only active while
    /// assists exist; confirm.
    fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
            return;
        };

        let editor = editor.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let buffer = editor.buffer().read(cx).snapshot(cx);
            for assist_id in &editor_assists.assist_ids {
                let assist = &self.assists[assist_id];
                let assist_range = assist.range.to_offset(&buffer);
                // `Range::contains` is end-exclusive, so a cursor exactly at the range end
                // does not count as inside.
                if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
                {
                    if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
                        self.dismiss_assist(*assist_id, cx);
                    } else {
                        self.finish_assist(*assist_id, false, cx);
                    }

                    // Returning right after the `&mut self` call is what lets the borrows of
                    // `editor_assists` and `editor` end under NLL.
                    return;
                }
            }
        }

        cx.propagate();
    }
540
    /// Handles a cancel in the host editor; presumably bound to its cancel action (the
    /// registration is not visible here).
    ///
    /// With exactly one selection, this only considers assists that still have decorations:
    /// - If the selection lies inside one of them, that assist's prompt is focused and the
    ///   action is consumed.
    /// - Otherwise the nearest such assist is focused. Distance is the sum of each range
    ///   endpoint's distance to the nearer selection endpoint, in buffer offsets. The action
    ///   then still propagates.
    ///
    /// As in `handle_editor_newline`, an editor without tracked assists returns without
    /// propagating.
    fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
        let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
            return;
        };

        let editor = editor.read(cx);
        if editor.selections.count() == 1 {
            let selection = editor.selections.newest::<usize>(cx);
            let buffer = editor.buffer().read(cx).snapshot(cx);
            // `(assist_id, distance)` of the closest decorated assist seen so far.
            let mut closest_assist_fallback = None;
            for assist_id in &editor_assists.assist_ids {
                let assist = &self.assists[assist_id];
                let assist_range = assist.range.to_offset(&buffer);
                if assist.decorations.is_some() {
                    if assist_range.contains(&selection.start)
                        && assist_range.contains(&selection.end)
                    {
                        self.focus_assist(*assist_id, cx);
                        return;
                    } else {
                        let distance_from_selection = assist_range
                            .start
                            .abs_diff(selection.start)
                            .min(assist_range.start.abs_diff(selection.end))
                            + assist_range
                                .end
                                .abs_diff(selection.start)
                                .min(assist_range.end.abs_diff(selection.end));
                        match closest_assist_fallback {
                            Some((_, old_distance)) => {
                                if distance_from_selection < old_distance {
                                    closest_assist_fallback =
                                        Some((assist_id, distance_from_selection));
                                }
                            }
                            None => {
                                closest_assist_fallback = Some((assist_id, distance_from_selection))
                            }
                        }
                    }
                }
            }

            // Copy the id out first so the borrow of `editor_assists` ends before the
            // `&mut self` call.
            if let Some((&assist_id, _)) = closest_assist_fallback {
                self.focus_assist(assist_id, cx);
            }
        }

        cx.propagate();
    }
591
592 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
593 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
594 for assist_id in editor_assists.assist_ids.clone() {
595 self.finish_assist(assist_id, true, cx);
596 }
597 }
598 }
599
600 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
601 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
602 return;
603 };
604 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
605 return;
606 };
607 let assist = &self.assists[&scroll_lock.assist_id];
608 let Some(decorations) = assist.decorations.as_ref() else {
609 return;
610 };
611
612 editor.update(cx, |editor, cx| {
613 let scroll_position = editor.scroll_position(cx);
614 let target_scroll_top = editor
615 .row_for_block(decorations.prompt_block_id, cx)
616 .unwrap()
617 .0 as f32
618 - scroll_lock.distance_from_top;
619 if target_scroll_top != scroll_position.y {
620 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
621 }
622 });
623 }
624
625 fn handle_editor_event(
626 &mut self,
627 editor: View<Editor>,
628 event: &EditorEvent,
629 cx: &mut WindowContext,
630 ) {
631 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
632 return;
633 };
634
635 match event {
636 EditorEvent::Edited { transaction_id } => {
637 let buffer = editor.read(cx).buffer().read(cx);
638 let edited_ranges =
639 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
640 let snapshot = buffer.snapshot(cx);
641
642 for assist_id in editor_assists.assist_ids.clone() {
643 let assist = &self.assists[&assist_id];
644 if matches!(
645 assist.codegen.read(cx).status,
646 CodegenStatus::Error(_) | CodegenStatus::Done
647 ) {
648 let assist_range = assist.range.to_offset(&snapshot);
649 if edited_ranges
650 .iter()
651 .any(|range| range.overlaps(&assist_range))
652 {
653 self.finish_assist(assist_id, false, cx);
654 }
655 }
656 }
657 }
658 EditorEvent::ScrollPositionChanged { .. } => {
659 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
660 let assist = &self.assists[&scroll_lock.assist_id];
661 if let Some(decorations) = assist.decorations.as_ref() {
662 let distance_from_top = editor.update(cx, |editor, cx| {
663 let scroll_top = editor.scroll_position(cx).y;
664 let prompt_row = editor
665 .row_for_block(decorations.prompt_block_id, cx)
666 .unwrap()
667 .0 as f32;
668 prompt_row - scroll_top
669 });
670
671 if distance_from_top != scroll_lock.distance_from_top {
672 editor_assists.scroll_lock = None;
673 }
674 }
675 }
676 }
677 EditorEvent::SelectionsChanged { .. } => {
678 for assist_id in editor_assists.assist_ids.clone() {
679 let assist = &self.assists[&assist_id];
680 if let Some(decorations) = assist.decorations.as_ref() {
681 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
682 return;
683 }
684 }
685 }
686
687 editor_assists.scroll_lock = None;
688 }
689 _ => {}
690 }
691 }
692
    /// Finishes an assist and removes it from all bookkeeping.
    ///
    /// If the assist's group is still linked, the group is unlinked and every member is
    /// finished instead, recursively, with the same `undo` flag. Otherwise:
    /// 1. Its decorations are removed (`dismiss_assist`).
    /// 2. It is removed from `assists`, from its group (an emptied group is dropped) and from
    ///    its editor entry. An emptied editor entry is dropped and the editor's highlights are
    ///    cleared; otherwise a highlight update is requested.
    /// 3. With `undo` the codegen's changes are reverted. Without it the codegen is kept in
    ///    `confirmed_assists`, so `undo_assist` can still revert it later.
    ///
    /// The assist's status channel is dropped in all cases.
    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    // No entry means no highlight channel anymore, so clear highlights directly.
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, assist.codegen);
            }
        }

        // Remove the assist from the status updates map
        self.assist_observations.remove(&assist_id);
    }
739
740 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
741 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
742 return false;
743 };
744 codegen.update(cx, |this, cx| this.undo(cx));
745 true
746 }
747
    /// Removes an assist's on-screen decorations without finishing it: the prompt block, the
    /// end block and any deleted-line blocks.
    ///
    /// If the prompt editor contained focus, focus moves to the next decorated assist in the
    /// group, or back to the host editor. A scroll lock targeting this assist is cleared and
    /// a highlight update is requested.
    ///
    /// Returns `false` if the assist is unknown, its editor has been released, or it was
    /// already dismissed.
    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return false;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return false;
        };
        // Taking the decorations marks the assist as dismissed. It also makes
        // `focus_next_assist` below skip this assist.
        let Some(decorations) = assist.decorations.take() else {
            return false;
        };

        editor.update(cx, |editor, cx| {
            let mut to_remove = decorations.removed_line_block_ids;
            to_remove.insert(decorations.prompt_block_id);
            to_remove.insert(decorations.end_block_id);
            editor.remove_blocks(to_remove, None, cx);
        });

        if decorations
            .prompt_editor
            .focus_handle(cx)
            .contains_focused(cx)
        {
            self.focus_next_assist(assist_id, cx);
        }

        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
            if editor_assists
                .scroll_lock
                .as_ref()
                .map_or(false, |lock| lock.assist_id == assist_id)
            {
                editor_assists.scroll_lock = None;
            }
            editor_assists.highlight_updates.send(()).ok();
        }

        true
    }
787
    /// Focuses the next assist in `assist_id`'s group that still has decorations. The search
    /// goes in group order, starting after `assist_id` and wrapping around, and excludes
    /// `assist_id` itself.
    ///
    /// If no member qualifies, the host editor is focused instead. Does nothing for an unknown
    /// `assist_id`. Panics if `assist_id` is missing from its own group's id list.
    fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let Some(assist) = self.assists.get(&assist_id) else {
            return;
        };

        let assist_group = &self.assist_groups[&assist.group_id];
        let assist_ix = assist_group
            .assist_ids
            .iter()
            .position(|id| *id == assist_id)
            .unwrap();
        // Members after this one, then the ones before it.
        let assist_ids = assist_group
            .assist_ids
            .iter()
            .skip(assist_ix + 1)
            .chain(assist_group.assist_ids.iter().take(assist_ix));

        for assist_id in assist_ids {
            let assist = &self.assists[assist_id];
            if assist.decorations.is_some() {
                self.focus_assist(*assist_id, cx);
                return;
            }
        }

        // The host editor may already be released; ignore that case.
        assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
    }
815
816 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
817 let Some(assist) = self.assists.get(&assist_id) else {
818 return;
819 };
820
821 if let Some(decorations) = assist.decorations.as_ref() {
822 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
823 prompt_editor.editor.update(cx, |editor, cx| {
824 editor.focus(cx);
825 editor.select_all(&SelectAll, cx);
826 })
827 });
828 }
829
830 self.scroll_to_assist(assist_id, cx);
831 }
832
    /// Moves the host editor's cursor to the start of the assist's range and scrolls so the
    /// assist is visible.
    ///
    /// The scroll target depends on whether the assist is still decorated:
    /// - With decorations, it spans from the prompt block's row to the end block's row.
    /// - Without them, it is the single row of the range start.
    ///
    /// The target is padded on both sides by the editor's vertical scroll margin. If it starts
    /// above the viewport, the editor scrolls up to its top. If it ends below the viewport, the
    /// editor aligns its bottom with the viewport bottom when it fits, and otherwise aligns its
    /// top. Scrolling always resets the horizontal position to 0.
    ///
    /// Panics if a decorated assist's block has no row (`row_for_block` returns `None`).
    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let Some(assist) = self.assists.get(&assist_id) else {
            return;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return;
        };

        let position = assist.range.start;
        editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([position..position])
            });

            let mut scroll_target_top;
            let mut scroll_target_bottom;
            if let Some(decorations) = assist.decorations.as_ref() {
                scroll_target_top = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;
                scroll_target_bottom = editor
                    .row_for_block(decorations.end_block_id, cx)
                    .unwrap()
                    .0 as f32;
            } else {
                let snapshot = editor.snapshot(cx);
                let start_row = assist
                    .range
                    .start
                    .to_display_point(&snapshot.display_snapshot)
                    .row();
                scroll_target_top = start_row.0 as f32;
                scroll_target_bottom = scroll_target_top + 1.;
            }
            scroll_target_top -= editor.vertical_scroll_margin() as f32;
            scroll_target_bottom += editor.vertical_scroll_margin() as f32;

            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + height_in_lines;

            if scroll_target_top < scroll_top {
                editor.set_scroll_position(point(0., scroll_target_top), cx);
            } else if scroll_target_bottom > scroll_bottom {
                // Show the bottom edge if the whole target fits; otherwise favor its top.
                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
                    editor
                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
                } else {
                    editor.set_scroll_position(point(0., scroll_target_top), cx);
                }
            }
        });
    }
887
888 fn unlink_assist_group(
889 &mut self,
890 assist_group_id: InlineAssistGroupId,
891 cx: &mut WindowContext,
892 ) -> Vec<InlineAssistId> {
893 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
894 assist_group.linked = false;
895 for assist_id in &assist_group.assist_ids {
896 let assist = self.assists.get_mut(assist_id).unwrap();
897 if let Some(editor_decorations) = assist.decorations.as_ref() {
898 editor_decorations
899 .prompt_editor
900 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
901 }
902 }
903 assist_group.assist_ids.clone()
904 }
905
906 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
907 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
908 assist
909 } else {
910 return;
911 };
912
913 let assist_group_id = assist.group_id;
914 if self.assist_groups[&assist_group_id].linked {
915 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
916 self.start_assist(assist_id, cx);
917 }
918 return;
919 }
920
921 let Some(user_prompt) = assist.user_prompt(cx) else {
922 return;
923 };
924
925 self.prompt_history.retain(|prompt| *prompt != user_prompt);
926 self.prompt_history.push_back(user_prompt.clone());
927 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
928 self.prompt_history.pop_front();
929 }
930
931 let assistant_panel_context = assist.assistant_panel_context(cx);
932
933 assist
934 .codegen
935 .update(cx, |codegen, cx| {
936 codegen.start(
937 assist.range.clone(),
938 user_prompt,
939 assistant_panel_context,
940 cx,
941 )
942 })
943 .log_err();
944
945 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
946 tx.send(AssistStatus::Started).ok();
947 }
948 }
949
950 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
951 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
952 assist
953 } else {
954 return;
955 };
956
957 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
958
959 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
960 tx.send(AssistStatus::Stopped).ok();
961 }
962 }
963
964 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
965 if let Some(assist) = self.assists.get(&assist_id) {
966 match &assist.codegen.read(cx).status {
967 CodegenStatus::Idle => InlineAssistStatus::Idle,
968 CodegenStatus::Pending => InlineAssistStatus::Pending,
969 CodegenStatus::Done => InlineAssistStatus::Done,
970 CodegenStatus::Error(_) => InlineAssistStatus::Error,
971 }
972 } else if self.confirmed_assists.contains_key(&assist_id) {
973 InlineAssistStatus::Confirmed
974 } else {
975 InlineAssistStatus::Canceled
976 }
977 }
978
    /// Recomputes and applies every inline-assist highlight in `editor`.
    ///
    /// For each live assist in the editor, this collects:
    /// - the codegen's `last_equal_ranges`, faded out in the text;
    /// - the "pending" range, from the edit position (or the range start before any edits) to
    ///   the range end, shown in the gutter with the info background color;
    /// - the "transformed" range, from the range start to the edit position, shown in the
    ///   gutter with the info color;
    /// - the rows inserted by the codegen's diff, highlighted only while the assist still has
    ///   decorations.
    ///
    /// Empty gutter ranges are skipped, and overlapping ranges are merged. Each highlight kind
    /// is cleared when there is nothing to show, so this also clears highlights for an editor
    /// with no remaining assists.
    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        // Fallback so an editor without an entry still runs the clearing logic below.
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer.read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

                let pending_range =
                    codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                if let Some(edit_position) = codegen.edit_position {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                if assist.decorations.is_some() {
                    inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            // Marker types that key each gutter highlight layer independently.
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            // Row highlights are always rebuilt from scratch.
            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    Some(cx.theme().status().info_background),
                    false,
                    cx,
                );
            }
        });
    }
1068
    /// Rebuilds the "deleted lines" blocks displayed for an inline assist.
    ///
    /// Removes the removed-line blocks previously inserted into `editor`, then, for every
    /// entry in the assist's `codegen.diff.deleted_row_ranges`, creates a read-only editor
    /// over the corresponding rows of `codegen.old_buffer` and inserts it as a block above
    /// the recorded anchor. Returns early if the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Clone what is needed out of the codegen model so the shared borrow of `cx` ends
        // before `editor.update` needs it mutably.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // The blocks are rebuilt from scratch on every call, so drop the previous set.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Translate the deleted rows (expressed in `old_snapshot` rows) into offsets in
                // the underlying buffer. NOTE(review): these unwraps panic if a row does not
                // map to a buffer; presumably the rows always come from this same snapshot.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only, non-scrolling, gutterless editor showing only the deleted text,
                // with every row highlighted in the "deleted" color.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(false);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // The block is as tall as the number of rows in the deleted-lines editor.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new block ids so the next call can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1161
1162 pub fn observe_assist(
1163 &mut self,
1164 assist_id: InlineAssistId,
1165 ) -> async_watch::Receiver<AssistStatus> {
1166 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1167 rx.clone()
1168 } else {
1169 let (tx, rx) = async_watch::channel(AssistStatus::Idle);
1170 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1171 rx
1172 }
1173 }
1174}
1175
/// Status of an inline assist.
///
/// This is a public, field-less enum, so it derives the cheap standard traits. Callers can
/// then log it (`Debug`), copy it (`Copy`) and compare it directly (`PartialEq`/`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}
1184
1185impl InlineAssistStatus {
1186 pub(crate) fn is_pending(&self) -> bool {
1187 matches!(self, Self::Pending)
1188 }
1189 pub(crate) fn is_confirmed(&self) -> bool {
1190 matches!(self, Self::Confirmed)
1191 }
1192 pub(crate) fn is_done(&self) -> bool {
1193 matches!(self, Self::Done)
1194 }
1195}
1196
/// Per-editor state for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Ids of the assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one assist (see `InlineAssistScrollLock`).
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` wakes the `_update_highlights` task, which recomputes this editor's
    /// assist highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Task that listens on the receiving end of `highlight_updates`.
    _update_highlights: Task<Result<()>>,
    /// Keeps the editor observers, event subscriptions and action registrations alive.
    _subscriptions: Vec<gpui::Subscription>,
}
1204
/// A scroll lock associated with a particular inline assist.
struct InlineAssistScrollLock {
    /// The assist this lock belongs to.
    assist_id: InlineAssistId,
    // NOTE(review): presumably the assist's vertical distance from the top of the viewport;
    // the unit (rows vs. pixels) is not visible in this part of the file.
    distance_from_top: f32,
}
1209
1210impl EditorInlineAssists {
1211 #[allow(clippy::too_many_arguments)]
1212 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1213 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1214 Self {
1215 assist_ids: Vec::new(),
1216 scroll_lock: None,
1217 highlight_updates: highlight_updates_tx,
1218 _update_highlights: cx.spawn(|mut cx| {
1219 let editor = editor.downgrade();
1220 async move {
1221 while let Ok(()) = highlight_updates_rx.changed().await {
1222 let editor = editor.upgrade().context("editor was dropped")?;
1223 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1224 assistant.update_editor_highlights(&editor, cx);
1225 })?;
1226 }
1227 Ok(())
1228 }
1229 }),
1230 _subscriptions: vec![
1231 cx.observe_release(editor, {
1232 let editor = editor.downgrade();
1233 |_, cx| {
1234 InlineAssistant::update_global(cx, |this, cx| {
1235 this.handle_editor_release(editor, cx);
1236 })
1237 }
1238 }),
1239 cx.observe(editor, move |editor, cx| {
1240 InlineAssistant::update_global(cx, |this, cx| {
1241 this.handle_editor_change(editor, cx)
1242 })
1243 }),
1244 cx.subscribe(editor, move |editor, event, cx| {
1245 InlineAssistant::update_global(cx, |this, cx| {
1246 this.handle_editor_event(editor, event, cx)
1247 })
1248 }),
1249 editor.update(cx, |editor, cx| {
1250 let editor_handle = cx.view().downgrade();
1251 editor.register_action(
1252 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1253 InlineAssistant::update_global(cx, |this, cx| {
1254 if let Some(editor) = editor_handle.upgrade() {
1255 this.handle_editor_newline(editor, cx)
1256 }
1257 })
1258 },
1259 )
1260 }),
1261 editor.update(cx, |editor, cx| {
1262 let editor_handle = cx.view().downgrade();
1263 editor.register_action(
1264 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1265 InlineAssistant::update_global(cx, |this, cx| {
1266 if let Some(editor) = editor_handle.upgrade() {
1267 this.handle_editor_cancel(editor, cx)
1268 }
1269 })
1270 },
1271 )
1272 }),
1273 ],
1274 }
1275 }
1276}
1277
/// A group of related inline assists.
struct InlineAssistGroup {
    /// Ids of the assists in this group.
    assist_ids: Vec<InlineAssistId>,
    // Starts out `true` (see `InlineAssistGroup::new`). NOTE(review): presumably records
    // whether the group's prompt editors are still linked (cf. `PromptEditor::unlink`) —
    // confirm.
    linked: bool,
    /// The assist currently marked active in this group, if any.
    active_assist_id: Option<InlineAssistId>,
}
1283
1284impl InlineAssistGroup {
1285 fn new() -> Self {
1286 Self {
1287 assist_ids: Vec::new(),
1288 linked: true,
1289 active_assist_id: None,
1290 }
1291 }
1292}
1293
1294fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1295 let editor = editor.clone();
1296 Box::new(move |cx: &mut BlockContext| {
1297 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1298 editor.clone().into_any_element()
1299 })
1300}
1301
/// Identifier of an inline assist. `post_inc` yields successive ids starting from the
/// default value `0`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
1304
1305impl InlineAssistId {
1306 fn post_inc(&mut self) -> InlineAssistId {
1307 let id = *self;
1308 self.0 += 1;
1309 id
1310 }
1311}
1312
/// Identifier of an inline assist group. `post_inc` yields successive ids starting from the
/// default value `0`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);
1315
1316impl InlineAssistGroupId {
1317 fn post_inc(&mut self) -> InlineAssistGroupId {
1318 let id = *self;
1319 self.0 += 1;
1320 id
1321 }
1322}
1323
/// Requests a [`PromptEditor`] emits, from its buttons and its `confirm`/`cancel` action
/// handlers, for its owner to act on.
enum PromptEditorEvent {
    /// Start (or restart) the transformation.
    StartRequested,
    /// Interrupt the transformation currently in progress.
    StopRequested,
    /// Accept the generated result.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted by `confirm` while generation is still pending.
    DismissRequested,
}
1331
/// The prompt UI rendered inside the host editor for one inline assist.
struct PromptEditor {
    /// The assist this prompt belongs to.
    id: InlineAssistId,
    /// Passed to the `ModelSelector`.
    fs: Arc<dyn Fs>,
    /// Editor holding the prompt text; `unlink` replaces it with a standalone editor.
    editor: View<Editor>,
    /// Set on every prompt edit; cleared when codegen reaches `Done` or `Error`.
    edited_since_done: bool,
    /// Host editor gutter dimensions, written by the block renderer and read in `render`.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Earlier prompts, navigable with `MoveUp`/`MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` being shown, or `None` when showing `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// Text typed by the user, restored when navigating past the end of the history.
    pending_prompt: String,
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`, re-created whenever `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// Debounced token-count task; reassigning it replaces the previous one.
    pending_token_count: Task<Result<()>>,
    /// Latest token counts, if any have been computed.
    token_counts: Option<TokenCounts>,
    /// Subscriptions that trigger token recounts (parent editor, assistant panel).
    _token_count_subscriptions: Vec<Subscription>,
    /// Used to focus the Assistant Panel from the token counter.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}
1350
/// Token usage of an inline assist request, as displayed by the prompt's token counter.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    /// Total tokens in the request (compared against the model's maximum).
    total: usize,
    /// Tokens attributed to the assistant panel context (shown in the counter's tooltip).
    assistant_panel: usize,
}
1356
// Allows `PromptEditor` to emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1358
/// Renders the prompt row, from left to right:
/// - a gutter-width area holding the model selector and, after a failed run, an error
///   indicator;
/// - the prompt editor;
/// - the token counter and status-dependent action buttons.
impl Render for PromptEditor {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        // Copy out the gutter dimensions shared with the block renderer.
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // Right-hand action buttons, chosen by the current codegen status.
        let buttons = match status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                // While generating, the second button interrupts instead of starting.
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // After an error, or once the prompt has been edited since the last run,
                    // offer a restart; otherwise offer to confirm the result.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Left area, sized to line up with the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator: ZedPro users hitting the rate limit get a toggleable
                    // notice; any other error shows an icon whose tooltip is the message.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1539
/// Focusing the prompt editor view focuses the inner text editor.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1545
1546impl PromptEditor {
    /// Maximum number of lines the auto-height prompt editor grows to.
    const MAX_LINES: u8 = 8;
1548
    /// Creates the prompt editor for assist `id`.
    ///
    /// Steps:
    /// - Builds an auto-height editor (at most `MAX_LINES` lines) over `prompt_buffer`.
    /// - Subscribes to `parent_editor` and, if given, `assistant_panel`, so token counts are
    ///   refreshed when they change.
    /// - Observes `codegen` to follow its status.
    /// - Starts an initial token count and subscribes to the prompt editor's own events.
    // Ten parameters exceed clippy's default `too_many_arguments` limit.
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: InlineAssistId,
        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
        prompt_history: VecDeque<String>,
        prompt_buffer: Model<MultiBuffer>,
        codegen: Model<Codegen>,
        parent_editor: &View<Editor>,
        assistant_panel: Option<&View<AssistantPanel>>,
        workspace: Option<WeakView<Workspace>>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let prompt_editor = cx.new_view(|cx| {
            let mut editor = Editor::new(
                EditorMode::AutoHeight {
                    max_lines: Self::MAX_LINES as usize,
                },
                prompt_buffer,
                None,
                false,
                cx,
            );
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            // Since the prompt editors for all inline assistants are linked,
            // always show the cursor (even when it isn't focused) because
            // typing in one will make what you typed appear in all of them.
            editor.set_show_cursor_when_unfocused(true, cx);
            editor.set_placeholder_text("Add a prompt…", cx);
            editor
        });

        // Edits in the parent editor and in the assistant panel's context both change the
        // request, so either one triggers a token recount.
        let mut token_count_subscriptions = Vec::new();
        token_count_subscriptions
            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
        if let Some(assistant_panel) = assistant_panel {
            token_count_subscriptions
                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
        }

        let mut this = Self {
            id,
            editor: prompt_editor,
            edited_since_done: false,
            gutter_dimensions,
            prompt_history,
            prompt_history_ix: None,
            pending_prompt: String::new(),
            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
            editor_subscriptions: Vec::new(),
            codegen,
            fs,
            pending_token_count: Task::ready(Ok(())),
            token_counts: None,
            _token_count_subscriptions: token_count_subscriptions,
            workspace,
            show_rate_limit_notice: false,
        };
        this.count_tokens(cx);
        this.subscribe_to_editor(cx);
        this
    }
1611
1612 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1613 self.editor_subscriptions.clear();
1614 self.editor_subscriptions
1615 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1616 }
1617
1618 fn set_show_cursor_when_unfocused(
1619 &mut self,
1620 show_cursor_when_unfocused: bool,
1621 cx: &mut ViewContext<Self>,
1622 ) {
1623 self.editor.update(cx, |editor, cx| {
1624 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1625 });
1626 }
1627
1628 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1629 let prompt = self.prompt(cx);
1630 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1631 self.editor = cx.new_view(|cx| {
1632 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1633 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1634 editor.set_placeholder_text("Add a prompt…", cx);
1635 editor.set_text(prompt, cx);
1636 if focus {
1637 editor.focus(cx);
1638 }
1639 editor
1640 });
1641 self.subscribe_to_editor(cx);
1642 }
1643
1644 fn prompt(&self, cx: &AppContext) -> String {
1645 self.editor.read(cx).text(cx)
1646 }
1647
1648 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1649 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1650 if self.show_rate_limit_notice {
1651 cx.focus_view(&self.editor);
1652 }
1653 cx.notify();
1654 }
1655
1656 fn handle_parent_editor_event(
1657 &mut self,
1658 _: View<Editor>,
1659 event: &EditorEvent,
1660 cx: &mut ViewContext<Self>,
1661 ) {
1662 if let EditorEvent::BufferEdited { .. } = event {
1663 self.count_tokens(cx);
1664 }
1665 }
1666
1667 fn handle_assistant_panel_event(
1668 &mut self,
1669 _: View<AssistantPanel>,
1670 event: &AssistantPanelEvent,
1671 cx: &mut ViewContext<Self>,
1672 ) {
1673 let AssistantPanelEvent::ContextEdited { .. } = event;
1674 self.count_tokens(cx);
1675 }
1676
1677 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1678 let assist_id = self.id;
1679 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1680 cx.background_executor().timer(Duration::from_secs(1)).await;
1681 let token_count = cx
1682 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1683 let assist = inline_assistant
1684 .assists
1685 .get(&assist_id)
1686 .context("assist not found")?;
1687 anyhow::Ok(assist.count_tokens(cx))
1688 })??
1689 .await?;
1690
1691 this.update(&mut cx, |this, cx| {
1692 this.token_counts = Some(token_count);
1693 cx.notify();
1694 })
1695 })
1696 }
1697
1698 fn handle_prompt_editor_events(
1699 &mut self,
1700 _: View<Editor>,
1701 event: &EditorEvent,
1702 cx: &mut ViewContext<Self>,
1703 ) {
1704 match event {
1705 EditorEvent::Edited { .. } => {
1706 let prompt = self.editor.read(cx).text(cx);
1707 if self
1708 .prompt_history_ix
1709 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1710 {
1711 self.prompt_history_ix.take();
1712 self.pending_prompt = prompt;
1713 }
1714
1715 self.edited_since_done = true;
1716 cx.notify();
1717 }
1718 EditorEvent::BufferEdited => {
1719 self.count_tokens(cx);
1720 }
1721 EditorEvent::Blurred => {
1722 if self.show_rate_limit_notice {
1723 self.show_rate_limit_notice = false;
1724 cx.notify();
1725 }
1726 }
1727 _ => {}
1728 }
1729 }
1730
1731 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1732 match &self.codegen.read(cx).status {
1733 CodegenStatus::Idle => {
1734 self.editor
1735 .update(cx, |editor, _| editor.set_read_only(false));
1736 }
1737 CodegenStatus::Pending => {
1738 self.editor
1739 .update(cx, |editor, _| editor.set_read_only(true));
1740 }
1741 CodegenStatus::Done => {
1742 self.edited_since_done = false;
1743 self.editor
1744 .update(cx, |editor, _| editor.set_read_only(false));
1745 }
1746 CodegenStatus::Error(error) => {
1747 if cx.has_flag::<ZedPro>()
1748 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1749 && !dismissed_rate_limit_notice()
1750 {
1751 self.show_rate_limit_notice = true;
1752 cx.notify();
1753 }
1754
1755 self.edited_since_done = false;
1756 self.editor
1757 .update(cx, |editor, _| editor.set_read_only(false));
1758 }
1759 }
1760 }
1761
1762 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1763 match &self.codegen.read(cx).status {
1764 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1765 cx.emit(PromptEditorEvent::CancelRequested);
1766 }
1767 CodegenStatus::Pending => {
1768 cx.emit(PromptEditorEvent::StopRequested);
1769 }
1770 }
1771 }
1772
1773 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1774 match &self.codegen.read(cx).status {
1775 CodegenStatus::Idle => {
1776 cx.emit(PromptEditorEvent::StartRequested);
1777 }
1778 CodegenStatus::Pending => {
1779 cx.emit(PromptEditorEvent::DismissRequested);
1780 }
1781 CodegenStatus::Done | CodegenStatus::Error(_) => {
1782 if self.edited_since_done {
1783 cx.emit(PromptEditorEvent::StartRequested);
1784 } else {
1785 cx.emit(PromptEditorEvent::ConfirmRequested);
1786 }
1787 }
1788 }
1789 }
1790
1791 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1792 if let Some(ix) = self.prompt_history_ix {
1793 if ix > 0 {
1794 self.prompt_history_ix = Some(ix - 1);
1795 let prompt = self.prompt_history[ix - 1].as_str();
1796 self.editor.update(cx, |editor, cx| {
1797 editor.set_text(prompt, cx);
1798 editor.move_to_beginning(&Default::default(), cx);
1799 });
1800 }
1801 } else if !self.prompt_history.is_empty() {
1802 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1803 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1804 self.editor.update(cx, |editor, cx| {
1805 editor.set_text(prompt, cx);
1806 editor.move_to_beginning(&Default::default(), cx);
1807 });
1808 }
1809 }
1810
1811 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1812 if let Some(ix) = self.prompt_history_ix {
1813 if ix < self.prompt_history.len() - 1 {
1814 self.prompt_history_ix = Some(ix + 1);
1815 let prompt = self.prompt_history[ix + 1].as_str();
1816 self.editor.update(cx, |editor, cx| {
1817 editor.set_text(prompt, cx);
1818 editor.move_to_end(&Default::default(), cx)
1819 });
1820 } else {
1821 self.prompt_history_ix = None;
1822 let prompt = self.pending_prompt.as_str();
1823 self.editor.update(cx, |editor, cx| {
1824 editor.set_text(prompt, cx);
1825 editor.move_to_end(&Default::default(), cx)
1826 });
1827 }
1828 }
1829 }
1830
1831 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1832 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1833 let token_counts = self.token_counts?;
1834 let max_token_count = model.max_token_count();
1835
1836 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1837 let token_count_color = if remaining_tokens <= 0 {
1838 Color::Error
1839 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1840 Color::Warning
1841 } else {
1842 Color::Muted
1843 };
1844
1845 let mut token_count = h_flex()
1846 .id("token_count")
1847 .gap_0p5()
1848 .child(
1849 Label::new(humanize_token_count(token_counts.total))
1850 .size(LabelSize::Small)
1851 .color(token_count_color),
1852 )
1853 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1854 .child(
1855 Label::new(humanize_token_count(max_token_count))
1856 .size(LabelSize::Small)
1857 .color(Color::Muted),
1858 );
1859 if let Some(workspace) = self.workspace.clone() {
1860 token_count = token_count
1861 .tooltip(move |cx| {
1862 Tooltip::with_meta(
1863 format!(
1864 "Tokens Used ({} from the Assistant Panel)",
1865 humanize_token_count(token_counts.assistant_panel)
1866 ),
1867 None,
1868 "Click to open the Assistant Panel",
1869 cx,
1870 )
1871 })
1872 .cursor_pointer()
1873 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1874 .on_click(move |_, cx| {
1875 cx.stop_propagation();
1876 workspace
1877 .update(cx, |workspace, cx| {
1878 workspace.focus_panel::<AssistantPanel>(cx)
1879 })
1880 .ok();
1881 });
1882 } else {
1883 token_count = token_count
1884 .cursor_default()
1885 .tooltip(|cx| Tooltip::text("Tokens used", cx));
1886 }
1887
1888 Some(token_count)
1889 }
1890
1891 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1892 let settings = ThemeSettings::get_global(cx);
1893 let text_style = TextStyle {
1894 color: if self.editor.read(cx).read_only(cx) {
1895 cx.theme().colors().text_disabled
1896 } else {
1897 cx.theme().colors().text
1898 },
1899 font_family: settings.ui_font.family.clone(),
1900 font_features: settings.ui_font.features.clone(),
1901 font_fallbacks: settings.ui_font.fallbacks.clone(),
1902 font_size: rems(0.875).into(),
1903 font_weight: settings.ui_font.weight,
1904 line_height: relative(1.3),
1905 ..Default::default()
1906 };
1907 EditorElement::new(
1908 &self.editor,
1909 EditorStyle {
1910 background: cx.theme().colors().editor_background,
1911 local_player: cx.theme().players().local(),
1912 text: text_style,
1913 ..Default::default()
1914 },
1915 )
1916 }
1917
    /// Renders the "Out of Tokens" popover. It contains:
    /// - an upsell message;
    /// - a "Don't show again" checkbox persisted via `set_rate_limit_notice_dismissed`;
    /// - a "Dismiss" button that toggles the notice;
    /// - a "More Info" button that dispatches `zed_actions::OpenAccountSettings`.
    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        Popover::new().child(
            v_flex()
                .occlude()
                .p_2()
                .child(
                    Label::new("Out of Tokens")
                        .size(LabelSize::Small)
                        .weight(FontWeight::BOLD),
                )
                .child(Label::new(
                    "Try Zed Pro for higher limits, a wider range of models, and more.",
                ))
                .child(
                    h_flex()
                        .justify_between()
                        .child(CheckboxWithLabel::new(
                            "dont-show-again",
                            Label::new("Don't show again"),
                            // Reflect the persisted choice.
                            if dismissed_rate_limit_notice() {
                                ui::Selection::Selected
                            } else {
                                ui::Selection::Unselected
                            },
                            |selection, cx| {
                                // An indeterminate state carries no choice; ignore it.
                                let is_dismissed = match selection {
                                    ui::Selection::Unselected => false,
                                    ui::Selection::Indeterminate => return,
                                    ui::Selection::Selected => true,
                                };

                                set_rate_limit_notice_dismissed(is_dismissed, cx)
                            },
                        ))
                        .child(
                            h_flex()
                                .gap_2()
                                .child(
                                    Button::new("dismiss", "Dismiss")
                                        .style(ButtonStyle::Transparent)
                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                )
                                .child(Button::new("more-info", "More Info").on_click(
                                    |_event, cx| {
                                        cx.dispatch_action(Box::new(
                                            zed_actions::OpenAccountSettings,
                                        ))
                                    },
                                )),
                        ),
                ),
        )
    }
1971}
1972
/// Key-value-store key whose presence records that the user dismissed the rate-limit notice
/// permanently.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1974
1975fn dismissed_rate_limit_notice() -> bool {
1976 db::kvp::KEY_VALUE_STORE
1977 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1978 .log_err()
1979 .map_or(false, |s| s.is_some())
1980}
1981
1982fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1983 db::write_and_log(cx, move || async move {
1984 if is_dismissed {
1985 db::kvp::KEY_VALUE_STORE
1986 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1987 .await
1988 } else {
1989 db::kvp::KEY_VALUE_STORE
1990 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1991 .await
1992 }
1993 })
1994}
1995
/// State for a single inline assist.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Range of the host buffer being transformed.
    range: Range<Anchor>,
    /// The host editor.
    editor: WeakView<Editor>,
    /// Prompt editor and inserted blocks. When `None`, the assist finishes as soon as
    /// generation finishes, and errors are reported as a workspace toast.
    decorations: Option<InlineAssistDecorations>,
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions created in `InlineAssist::new`.
    _subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
2006
2007impl InlineAssist {
    /// Creates the state for one inline assist.
    ///
    /// Stores the range, codegen model and decorations (prompt editor and its block ids).
    /// It then subscribes to the prompt editor's focus changes and events, and to `codegen`
    /// notifications and events, forwarding each to the global `InlineAssistant`.
    // Eleven parameters exceed clippy's default `too_many_arguments` limit.
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: CustomBlockId,
        end_block_id: CustomBlockId,
        range: Range<Anchor>,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            range,
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                // On any codegen change, ask the editor's highlight task to refresh, then
                // rebuild the deleted-line blocks.
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // Without decorations there is no prompt editor to show the error
                            // in, so report it as a workspace toast instead.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::identified::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // Undecorated assists finish with generation; decorated ones only
                            // notify status observers.
                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            } else if let Some(tx) = this.assist_observations.get(&assist_id) {
                                tx.0.send(AssistStatus::Finished).ok();
                            }
                        }
                    })
                }),
            ],
        }
    }
2111
2112 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2113 let decorations = self.decorations.as_ref()?;
2114 Some(decorations.prompt_editor.read(cx).prompt(cx))
2115 }
2116
2117 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2118 if self.include_context {
2119 let workspace = self.workspace.as_ref()?;
2120 let workspace = workspace.upgrade()?.read(cx);
2121 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2122 Some(
2123 assistant_panel
2124 .read(cx)
2125 .active_context(cx)?
2126 .read(cx)
2127 .to_completion_request(cx),
2128 )
2129 } else {
2130 None
2131 }
2132 }
2133
2134 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2135 let Some(user_prompt) = self.user_prompt(cx) else {
2136 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2137 };
2138 let assistant_panel_context = self.assistant_panel_context(cx);
2139 self.codegen.read(cx).count_tokens(
2140 self.range.clone(),
2141 user_prompt,
2142 assistant_panel_context,
2143 cx,
2144 )
2145 }
2146}
2147
/// UI attached to the host editor for an inline assist.
struct InlineAssistDecorations {
    /// Block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    prompt_editor: View<PromptEditor>,
    /// Blocks showing deleted lines; rebuilt by `InlineAssistant::update_editor_blocks`.
    removed_line_block_ids: HashSet<CustomBlockId>,
    // NOTE(review): presumably the block marking the end of the assist range — confirm.
    end_block_id: CustomBlockId,
}
2154
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation finished (successfully or with an error; see `Codegen::status`).
    Finished,
    /// The generated change was undone; the owning assist is finished without confirming.
    Undone,
}
2160
/// Drives a single code transformation within a range of a multi-buffer.
pub struct Codegen {
    /// The multi-buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Copy of the original underlying buffer, taken in `Codegen::new`; used to display
    /// deleted lines.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `Codegen::new`; row coordinates in
    /// `diff.deleted_row_ranges` refer to it.
    snapshot: MultiBufferSnapshot,
    /// How far generation has progressed; highlighted as the "transformed" range.
    edit_position: Option<Anchor>,
    // NOTE(review): presumably ranges left unchanged by the latest diff — confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    initial_transaction_id: Option<TransactionId>,
    // NOTE(review): presumably the transaction containing the generated edits — confirm.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// The in-flight generation task.
    generation: Task<()>,
    /// Deleted/inserted rows rendered by the inline assistant.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
}
2176
2177enum CodegenStatus {
2178 Idle,
2179 Pending,
2180 Done,
2181 Error(anyhow::Error),
2182}
2183
2184#[derive(Default)]
2185struct Diff {
2186 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2187 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2188}
2189
2190impl Diff {
2191 fn is_empty(&self) -> bool {
2192 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2193 }
2194}
2195
2196impl EventEmitter<CodegenEvent> for Codegen {}
2197
2198impl Codegen {
2199 pub fn new(
2200 buffer: Model<MultiBuffer>,
2201 range: Range<Anchor>,
2202 initial_transaction_id: Option<TransactionId>,
2203 telemetry: Option<Arc<Telemetry>>,
2204 builder: Arc<PromptBuilder>,
2205 cx: &mut ModelContext<Self>,
2206 ) -> Self {
2207 let snapshot = buffer.read(cx).snapshot(cx);
2208
2209 let (old_buffer, _, _) = buffer
2210 .read(cx)
2211 .range_to_buffer_ranges(range.clone(), cx)
2212 .pop()
2213 .unwrap();
2214 let old_buffer = cx.new_model(|cx| {
2215 let old_buffer = old_buffer.read(cx);
2216 let text = old_buffer.as_rope().clone();
2217 let line_ending = old_buffer.line_ending();
2218 let language = old_buffer.language().cloned();
2219 let language_registry = old_buffer.language_registry();
2220
2221 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2222 buffer.set_language(language, cx);
2223 if let Some(language_registry) = language_registry {
2224 buffer.set_language_registry(language_registry)
2225 }
2226 buffer
2227 });
2228
2229 Self {
2230 buffer: buffer.clone(),
2231 old_buffer,
2232 edit_position: None,
2233 snapshot,
2234 last_equal_ranges: Default::default(),
2235 transformation_transaction_id: None,
2236 status: CodegenStatus::Idle,
2237 generation: Task::ready(()),
2238 diff: Diff::default(),
2239 telemetry,
2240 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2241 initial_transaction_id,
2242 builder,
2243 }
2244 }
2245
2246 fn handle_buffer_event(
2247 &mut self,
2248 _buffer: Model<MultiBuffer>,
2249 event: &multi_buffer::Event,
2250 cx: &mut ModelContext<Self>,
2251 ) {
2252 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2253 if self.transformation_transaction_id == Some(*transaction_id) {
2254 self.transformation_transaction_id = None;
2255 self.generation = Task::ready(());
2256 cx.emit(CodegenEvent::Undone);
2257 }
2258 }
2259 }
2260
2261 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2262 &self.last_equal_ranges
2263 }
2264
2265 pub fn count_tokens(
2266 &self,
2267 edit_range: Range<Anchor>,
2268 user_prompt: String,
2269 assistant_panel_context: Option<LanguageModelRequest>,
2270 cx: &AppContext,
2271 ) -> BoxFuture<'static, Result<TokenCounts>> {
2272 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2273 let request =
2274 self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
2275 match request {
2276 Ok(request) => {
2277 let total_count = model.count_tokens(request.clone(), cx);
2278 let assistant_panel_count = assistant_panel_context
2279 .map(|context| model.count_tokens(context, cx))
2280 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2281
2282 async move {
2283 Ok(TokenCounts {
2284 total: total_count.await?,
2285 assistant_panel: assistant_panel_count.await?,
2286 })
2287 }
2288 .boxed()
2289 }
2290 Err(error) => futures::future::ready(Err(error)).boxed(),
2291 }
2292 } else {
2293 future::ready(Err(anyhow!("no active model"))).boxed()
2294 }
2295 }
2296
2297 pub fn start(
2298 &mut self,
2299 edit_range: Range<Anchor>,
2300 user_prompt: String,
2301 assistant_panel_context: Option<LanguageModelRequest>,
2302 cx: &mut ModelContext<Self>,
2303 ) -> Result<()> {
2304 let model = LanguageModelRegistry::read_global(cx)
2305 .active_model()
2306 .context("no active model")?;
2307
2308 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2309 self.buffer.update(cx, |buffer, cx| {
2310 buffer.undo_transaction(transformation_transaction_id, cx);
2311 });
2312 }
2313
2314 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2315
2316 let telemetry_id = model.telemetry_id();
2317 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2318 .trim()
2319 .to_lowercase()
2320 == "delete"
2321 {
2322 async { Ok(stream::empty().boxed()) }.boxed_local()
2323 } else {
2324 let request =
2325 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
2326
2327 let chunks =
2328 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2329 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2330 };
2331 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2332 Ok(())
2333 }
2334
2335 fn build_request(
2336 &self,
2337 user_prompt: String,
2338 assistant_panel_context: Option<LanguageModelRequest>,
2339 edit_range: Range<Anchor>,
2340 cx: &AppContext,
2341 ) -> Result<LanguageModelRequest> {
2342 let buffer = self.buffer.read(cx).snapshot(cx);
2343 let language = buffer.language_at(edit_range.start);
2344 let language_name = if let Some(language) = language.as_ref() {
2345 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2346 None
2347 } else {
2348 Some(language.name())
2349 }
2350 } else {
2351 None
2352 };
2353
2354 // Higher Temperature increases the randomness of model outputs.
2355 // If Markdown or No Language is Known, increase the randomness for more creative output
2356 // If Code, decrease temperature to get more deterministic outputs
2357 let temperature = if let Some(language) = language_name.clone() {
2358 if language.as_ref() == "Markdown" {
2359 1.0
2360 } else {
2361 0.5
2362 }
2363 } else {
2364 1.0
2365 };
2366
2367 let language_name = language_name.as_deref();
2368 let start = buffer.point_to_buffer_offset(edit_range.start);
2369 let end = buffer.point_to_buffer_offset(edit_range.end);
2370 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2371 let (start_buffer, start_buffer_offset) = start;
2372 let (end_buffer, end_buffer_offset) = end;
2373 if start_buffer.remote_id() == end_buffer.remote_id() {
2374 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2375 } else {
2376 return Err(anyhow::anyhow!("invalid transformation range"));
2377 }
2378 } else {
2379 return Err(anyhow::anyhow!("invalid transformation range"));
2380 };
2381
2382 let prompt = self
2383 .builder
2384 .generate_content_prompt(user_prompt, language_name, buffer, range)
2385 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2386
2387 let mut messages = Vec::new();
2388 if let Some(context_request) = assistant_panel_context {
2389 messages = context_request.messages;
2390 }
2391
2392 messages.push(LanguageModelRequestMessage {
2393 role: Role::User,
2394 content: vec![prompt.into()],
2395 cache: false,
2396 });
2397
2398 Ok(LanguageModelRequest {
2399 messages,
2400 stop: vec!["|END|>".to_string()],
2401 temperature,
2402 })
2403 }
2404
    /// Consumes a stream of completion text and applies it to `edit_range` as it arrives.
    ///
    /// `stream` resolves to a stream of text chunks. A background task processes them
    /// in order:
    /// - passes them through `StripInvalidSpans`;
    /// - re-indents each line relative to the selection's indentation;
    /// - diffs the text against the originally selected text with a `StreamingDiff`.
    ///
    /// The background task sends the resulting operations over a bounded channel.
    /// The foreground side applies them to the buffer and merges every assistant
    /// edit into one transaction (`transformation_transaction_id`).
    ///
    /// When streaming ends:
    /// - a batch line diff replaces the streaming line diff;
    /// - telemetry is reported, if configured;
    /// - `status` becomes `Done` or `Error`;
    /// - `CodegenEvent::Finished` is emitted.
    ///
    /// The spawned work is stored in `self.generation`.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Byte offset, in the original snapshot, where the next character operation applies.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                // An error resolving the stream surfaces below via `chunks?`.
                let chunks = stream.await;
                let generate = async {
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    // Background half: clean, re-indent and diff the streamed text, then send
                    // (character operations, line operations) batches to the foreground.
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current line not yet pushed into the diff.
                                let mut new_text = String::new();
                                // Indentation of the first indented line the model produced.
                                let mut base_indent = None;
                                // Indentation of the current line, once known.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            // First non-whitespace character of this line: rewrite the
                                            // leading whitespace. The line keeps its offset from
                                            // `base_indent` but is rebased onto `suggested_line_indent`.
                                            // On the first line, the selection's start column is
                                            // subtracted.
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Flush only once the line's indentation has been decided.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // A newline followed this segment: emit it and reset per-line state.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush what remains and let both diffs emit their trailing operations.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Foreground half: apply each batch to the buffer. Offsets refer to the original
                    // snapshot. Keep/Delete advance `edit_start`; Insert writes at the current position.
                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            let transaction = codegen.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    char_ops
                                        .into_iter()
                                        .filter_map(|operation| match operation {
                                            CharOperation::Insert { text } => {
                                                let edit_start = snapshot.anchor_after(edit_start);
                                                Some((edit_start..edit_start, text))
                                            }
                                            CharOperation::Delete { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                Some((edit_range, String::new()))
                                            }
                                            CharOperation::Keep { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                codegen.last_equal_ranges.push(edit_range);
                                                None
                                            }
                                        }),
                                    None,
                                    cx,
                                );
                                codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) =
                                    codegen.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    codegen.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    codegen.transformation_transaction_id = Some(transaction);
                                    codegen.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
                        codegen.reapply_batch_diff(edit_range.clone(), cx)
                    })?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                // Publish the outcome. Errors from `update` are ignored.
                codegen
                    .update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
2659
2660 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2661 self.last_equal_ranges.clear();
2662 if self.diff.is_empty() {
2663 self.status = CodegenStatus::Idle;
2664 } else {
2665 self.status = CodegenStatus::Done;
2666 }
2667 self.generation = Task::ready(());
2668 cx.emit(CodegenEvent::Finished);
2669 cx.notify();
2670 }
2671
2672 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2673 self.buffer.update(cx, |buffer, cx| {
2674 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2675 buffer.undo_transaction(transaction_id, cx);
2676 buffer.refresh_preview(cx);
2677 }
2678
2679 if let Some(transaction_id) = self.initial_transaction_id.take() {
2680 buffer.undo_transaction(transaction_id, cx);
2681 buffer.refresh_preview(cx);
2682 }
2683 });
2684 }
2685
2686 fn reapply_line_based_diff(
2687 &mut self,
2688 edit_range: Range<Anchor>,
2689 line_operations: Vec<LineOperation>,
2690 cx: &mut ModelContext<Self>,
2691 ) {
2692 let old_snapshot = self.snapshot.clone();
2693 let old_range = edit_range.to_point(&old_snapshot);
2694 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2695 let new_range = edit_range.to_point(&new_snapshot);
2696
2697 let mut old_row = old_range.start.row;
2698 let mut new_row = new_range.start.row;
2699
2700 self.diff.deleted_row_ranges.clear();
2701 self.diff.inserted_row_ranges.clear();
2702 for operation in line_operations {
2703 match operation {
2704 LineOperation::Keep { lines } => {
2705 old_row += lines;
2706 new_row += lines;
2707 }
2708 LineOperation::Delete { lines } => {
2709 let old_end_row = old_row + lines - 1;
2710 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2711
2712 if let Some((_, last_deleted_row_range)) =
2713 self.diff.deleted_row_ranges.last_mut()
2714 {
2715 if *last_deleted_row_range.end() + 1 == old_row {
2716 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2717 } else {
2718 self.diff
2719 .deleted_row_ranges
2720 .push((new_row, old_row..=old_end_row));
2721 }
2722 } else {
2723 self.diff
2724 .deleted_row_ranges
2725 .push((new_row, old_row..=old_end_row));
2726 }
2727
2728 old_row += lines;
2729 }
2730 LineOperation::Insert { lines } => {
2731 let new_end_row = new_row + lines - 1;
2732 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2733 let end = new_snapshot.anchor_before(Point::new(
2734 new_end_row,
2735 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2736 ));
2737 self.diff.inserted_row_ranges.push(start..=end);
2738 new_row += lines;
2739 }
2740 }
2741
2742 cx.notify();
2743 }
2744 }
2745
    /// Replaces `self.diff` with a conventional line diff of the whole edit range.
    ///
    /// Takes the full lines covering `edit_range` in two places: the original snapshot
    /// and the current buffer. On the background executor it diffs them line by line
    /// with `similar::TextDiff`, coalescing adjacent deleted lines. Then it stores the
    /// resulting deleted and inserted row ranges and notifies observers. An error from
    /// updating the model (e.g. if it has been released) is ignored.
    fn reapply_batch_diff(
        &mut self,
        edit_range: Range<Anchor>,
        cx: &mut ModelContext<Self>,
    ) -> Task<()> {
        let old_snapshot = self.snapshot.clone();
        let old_range = edit_range.to_point(&old_snapshot);
        let new_snapshot = self.buffer.read(cx).snapshot(cx);
        let new_range = edit_range.to_point(&new_snapshot);

        cx.spawn(|codegen, mut cx| async move {
            let (deleted_row_ranges, inserted_row_ranges) = cx
                .background_executor()
                .spawn(async move {
                    // Whole lines spanned by the edit range, before and after the edit.
                    let old_text = old_snapshot
                        .text_for_range(
                            Point::new(old_range.start.row, 0)
                                ..Point::new(
                                    old_range.end.row,
                                    old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                                ),
                        )
                        .collect::<String>();
                    let new_text = new_snapshot
                        .text_for_range(
                            Point::new(new_range.start.row, 0)
                                ..Point::new(
                                    new_range.end.row,
                                    new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                                ),
                        )
                        .collect::<String>();

                    let mut old_row = old_range.start.row;
                    let mut new_row = new_range.start.row;
                    let batch_diff =
                        similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                    let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                    let mut inserted_row_ranges = Vec::new();
                    for change in batch_diff.iter_all_changes() {
                        let line_count = change.value().lines().count() as u32;
                        match change.tag() {
                            similar::ChangeTag::Equal => {
                                old_row += line_count;
                                new_row += line_count;
                            }
                            similar::ChangeTag::Delete => {
                                let old_end_row = old_row + line_count - 1;
                                let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));

                                // Extend the previous deletion when this one starts right after it.
                                if let Some((_, last_deleted_row_range)) =
                                    deleted_row_ranges.last_mut()
                                {
                                    if *last_deleted_row_range.end() + 1 == old_row {
                                        *last_deleted_row_range =
                                            *last_deleted_row_range.start()..=old_end_row;
                                    } else {
                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                    }
                                } else {
                                    deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                }

                                old_row += line_count;
                            }
                            similar::ChangeTag::Insert => {
                                let new_end_row = new_row + line_count - 1;
                                let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                                let end = new_snapshot.anchor_before(Point::new(
                                    new_end_row,
                                    new_snapshot.line_len(MultiBufferRow(new_end_row)),
                                ));
                                inserted_row_ranges.push(start..=end);
                                new_row += line_count;
                            }
                        }
                    }

                    (deleted_row_ranges, inserted_row_ranges)
                })
                .await;

            codegen
                .update(&mut cx, |codegen, cx| {
                    codegen.diff.deleted_row_ranges = deleted_row_ranges;
                    codegen.diff.inserted_row_ranges = inserted_row_ranges;
                    cx.notify();
                })
                .ok();
        })
    }
2838}
2839
/// Stream adapter that cleans raw model output before it is diffed into the buffer.
///
/// It drops a leading code fence and a matching trailing fence, and removes
/// `<|CURSOR|>` markers, even when these are split across chunks.
struct StripInvalidSpans<T> {
    /// Inner stream of text chunks.
    stream: T,
    /// Whether `stream` has returned `None`.
    stream_done: bool,
    /// Received text not yet emitted; it may hold a partial marker or fence.
    buffer: String,
    /// Whether no newline has been seen yet, i.e. still on the first line.
    first_line: bool,
    /// Whether a line break is owed before the next emitted text.
    line_end: bool,
    /// Whether the output opened with a code fence that was stripped.
    starts_with_code_block: bool,
}
2848
2849impl<T> StripInvalidSpans<T>
2850where
2851 T: Stream<Item = Result<String>>,
2852{
2853 fn new(stream: T) -> Self {
2854 Self {
2855 stream,
2856 stream_done: false,
2857 buffer: String::new(),
2858 first_line: true,
2859 line_end: false,
2860 starts_with_code_block: false,
2861 }
2862 }
2863}
2864
2865impl<T> Stream for StripInvalidSpans<T>
2866where
2867 T: Stream<Item = Result<String>>,
2868{
2869 type Item = Result<String>;
2870
2871 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2872 const CODE_BLOCK_DELIMITER: &str = "```";
2873 const CURSOR_SPAN: &str = "<|CURSOR|>";
2874
2875 let this = unsafe { self.get_unchecked_mut() };
2876 loop {
2877 if !this.stream_done {
2878 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2879 match stream.as_mut().poll_next(cx) {
2880 Poll::Ready(Some(Ok(chunk))) => {
2881 this.buffer.push_str(&chunk);
2882 }
2883 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2884 Poll::Ready(None) => {
2885 this.stream_done = true;
2886 }
2887 Poll::Pending => return Poll::Pending,
2888 }
2889 }
2890
2891 let mut chunk = String::new();
2892 let mut consumed = 0;
2893 if !this.buffer.is_empty() {
2894 let mut lines = this.buffer.split('\n').enumerate().peekable();
2895 while let Some((line_ix, line)) = lines.next() {
2896 if line_ix > 0 {
2897 this.first_line = false;
2898 }
2899
2900 if this.first_line {
2901 let trimmed_line = line.trim();
2902 if lines.peek().is_some() {
2903 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2904 consumed += line.len() + 1;
2905 this.starts_with_code_block = true;
2906 continue;
2907 }
2908 } else if trimmed_line.is_empty()
2909 || prefixes(CODE_BLOCK_DELIMITER)
2910 .any(|prefix| trimmed_line.starts_with(prefix))
2911 {
2912 break;
2913 }
2914 }
2915
2916 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2917 if lines.peek().is_some() {
2918 if this.line_end {
2919 chunk.push('\n');
2920 }
2921
2922 chunk.push_str(&line_without_cursor);
2923 this.line_end = true;
2924 consumed += line.len() + 1;
2925 } else if this.stream_done {
2926 if !this.starts_with_code_block
2927 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2928 {
2929 if this.line_end {
2930 chunk.push('\n');
2931 }
2932
2933 chunk.push_str(&line);
2934 }
2935
2936 consumed += line.len();
2937 } else {
2938 let trimmed_line = line.trim();
2939 if trimmed_line.is_empty()
2940 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2941 || prefixes(CODE_BLOCK_DELIMITER)
2942 .any(|prefix| trimmed_line.ends_with(prefix))
2943 {
2944 break;
2945 } else {
2946 if this.line_end {
2947 chunk.push('\n');
2948 this.line_end = false;
2949 }
2950
2951 chunk.push_str(&line_without_cursor);
2952 consumed += line.len();
2953 }
2954 }
2955 }
2956 }
2957
2958 this.buffer = this.buffer.split_off(consumed);
2959 if !chunk.is_empty() {
2960 return Poll::Ready(Some(Ok(chunk)));
2961 } else if this.stream_done {
2962 return Poll::Ready(None);
2963 }
2964 }
2965 }
2966}
2967
/// Yields every non-empty proper prefix of `text`, shortest first, excluding `text` itself.
///
/// For example, `"abc"` yields `"a"` and then `"ab"`. This is used to detect a
/// delimiter (a code fence or cursor marker) that may have been split across
/// streamed chunks.
///
/// Prefixes always end on `char` boundaries, so non-ASCII input cannot cause a
/// slicing panic. An empty `text` yields nothing instead of underflowing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2971
2972fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2973 ranges.sort_unstable_by(|a, b| {
2974 a.start
2975 .cmp(&b.start, buffer)
2976 .then_with(|| b.end.cmp(&a.end, buffer))
2977 });
2978
2979 let mut ix = 0;
2980 while ix + 1 < ranges.len() {
2981 let b = ranges[ix + 1].clone();
2982 let a = &mut ranges[ix];
2983 if a.end.cmp(&b.start, buffer).is_gt() {
2984 if a.end.cmp(&b.end, buffer).is_lt() {
2985 a.end = b.end;
2986 }
2987 ranges.remove(ix + 1);
2988 } else {
2989 ix += 1;
2990 }
2991 }
2992}
2993
2994#[cfg(test)]
2995mod tests {
2996 use super::*;
2997 use futures::stream::{self};
2998 use gpui::{Context, TestAppContext};
2999 use indoc::indoc;
3000 use language::{
3001 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
3002 Point,
3003 };
3004 use language_model::LanguageModelRegistry;
3005 use rand::prelude::*;
3006 use serde::Serialize;
3007 use settings::SettingsStore;
3008 use std::{future, sync::Arc};
3009
    /// Serializable placeholder request type; not used by the tests in this section.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Streams a rewrite of a loop body in randomly sized chunks. Checks that the
    /// resulting text is indented to match the surrounding function body.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                range.clone(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range,
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the text in random chunks of 1..=10 bytes, letting the executor settle in between.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// The selection is an empty range right after `le` on an indented line. The
    /// streamed text continues that line, and the lines after it must be indented
    /// to match the surrounding block.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                range.clone(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// The selection is an empty range at column 2 of a whitespace-only line.
    /// Generated lines must receive the block's indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                range.clone(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// The selection is tab-indented and the buffer has no language. The generated
    /// text must keep tab indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                range.clone(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// `StripInvalidSpans` must strip leading/trailing code fences and cursor markers
    /// for every possible chunk size of the input.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` at every chunk size from 1 to
        /// `text.len()` characters and compares each result with `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into chunks of `size` characters (the last may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }
3342
3343 fn rust_lang() -> Language {
3344 Language::new(
3345 LanguageConfig {
3346 name: "Rust".into(),
3347 matcher: LanguageMatcher {
3348 path_suffixes: vec!["rs".to_string()],
3349 ..Default::default()
3350 },
3351 ..Default::default()
3352 },
3353 Some(tree_sitter_rust::language()),
3354 )
3355 .with_indents_query(
3356 r#"
3357 (call_expression) @indent
3358 (field_expression) @indent
3359 (_ "(" ")" @end) @indent
3360 (_ "{" "}" @end) @indent
3361 "#,
3362 )
3363 .unwrap()
3364 }
3365}