1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 join,
23 stream::{self, BoxStream},
24 SinkExt, Stream, StreamExt,
25};
26use gpui::{
27 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
28 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
29 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
30};
31use language::{Buffer, IndentKind, Point, Selection, TransactionId};
32use language_model::{
33 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
34};
35use multi_buffer::MultiBufferRow;
36use parking_lot::Mutex;
37use rope::Rope;
38use settings::Settings;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(
56 fs: Arc<dyn Fs>,
57 prompt_builder: Arc<PromptBuilder>,
58 telemetry: Arc<Telemetry>,
59 cx: &mut AppContext,
60) {
61 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
62 cx.observe_new_views(|_, cx| {
63 let workspace = cx.view().clone();
64 InlineAssistant::update_global(cx, |inline_assistant, cx| {
65 inline_assistant.register_workspace(&workspace, cx)
66 })
67 })
68 .detach();
69}
70
71const PROMPT_HISTORY_MAX_LEN: usize = 20;
72
73pub struct InlineAssistant {
74 next_assist_id: InlineAssistId,
75 next_assist_group_id: InlineAssistGroupId,
76 assists: HashMap<InlineAssistId, InlineAssist>,
77 assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
78 assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
79 assist_observations:
80 HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
81 confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
82 prompt_history: VecDeque<String>,
83 prompt_builder: Arc<PromptBuilder>,
84 telemetry: Option<Arc<Telemetry>>,
85 fs: Arc<dyn Fs>,
86}
87
88impl Global for InlineAssistant {}
89
90impl InlineAssistant {
91 pub fn new(
92 fs: Arc<dyn Fs>,
93 prompt_builder: Arc<PromptBuilder>,
94 telemetry: Arc<Telemetry>,
95 ) -> Self {
96 Self {
97 next_assist_id: InlineAssistId::default(),
98 next_assist_group_id: InlineAssistGroupId::default(),
99 assists: HashMap::default(),
100 assists_by_editor: HashMap::default(),
101 assist_groups: HashMap::default(),
102 assist_observations: HashMap::default(),
103 confirmed_assists: HashMap::default(),
104 prompt_history: VecDeque::default(),
105 prompt_builder,
106 telemetry: Some(telemetry),
107 fs,
108 }
109 }
110
111 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
112 cx.subscribe(workspace, |_, event, cx| {
113 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
114 })
115 .detach();
116 }
117
118 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
119 // When the user manually saves an editor, automatically accepts all finished transformations.
120 if let workspace::Event::UserSavedItem { item, .. } = event {
121 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
122 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
123 for assist_id in editor_assists.assist_ids.clone() {
124 let assist = &self.assists[&assist_id];
125 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
126 self.finish_assist(assist_id, false, cx)
127 }
128 }
129 }
130 }
131 }
132 }
133
134 pub fn assist(
135 &mut self,
136 editor: &View<Editor>,
137 workspace: Option<WeakView<Workspace>>,
138 assistant_panel: Option<&View<AssistantPanel>>,
139 initial_prompt: Option<String>,
140 cx: &mut WindowContext,
141 ) {
142 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
143
144 let mut selections = Vec::<Selection<Point>>::new();
145 let mut newest_selection = None;
146 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
147 if selection.end > selection.start {
148 selection.start.column = 0;
149 // If the selection ends at the start of the line, we don't want to include it.
150 if selection.end.column == 0 {
151 selection.end.row -= 1;
152 }
153 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
154 }
155
156 if let Some(prev_selection) = selections.last_mut() {
157 if selection.start <= prev_selection.end {
158 prev_selection.end = selection.end;
159 continue;
160 }
161 }
162
163 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
164 if selection.id > latest_selection.id {
165 *latest_selection = selection.clone();
166 }
167 selections.push(selection);
168 }
169 let newest_selection = newest_selection.unwrap();
170
171 let mut codegen_ranges = Vec::new();
172 for (excerpt_id, buffer, buffer_range) in
173 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
174 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
175 }))
176 {
177 let start = Anchor {
178 buffer_id: Some(buffer.remote_id()),
179 excerpt_id,
180 text_anchor: buffer.anchor_before(buffer_range.start),
181 };
182 let end = Anchor {
183 buffer_id: Some(buffer.remote_id()),
184 excerpt_id,
185 text_anchor: buffer.anchor_after(buffer_range.end),
186 };
187 codegen_ranges.push(start..end);
188 }
189
190 let assist_group_id = self.next_assist_group_id.post_inc();
191 let prompt_buffer =
192 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
193 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
194
195 let mut assists = Vec::new();
196 let mut assist_to_focus = None;
197 for range in codegen_ranges {
198 let assist_id = self.next_assist_id.post_inc();
199 let codegen = cx.new_model(|cx| {
200 Codegen::new(
201 editor.read(cx).buffer().clone(),
202 range.clone(),
203 None,
204 self.telemetry.clone(),
205 self.prompt_builder.clone(),
206 cx,
207 )
208 });
209
210 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
211 let prompt_editor = cx.new_view(|cx| {
212 PromptEditor::new(
213 assist_id,
214 gutter_dimensions.clone(),
215 self.prompt_history.clone(),
216 prompt_buffer.clone(),
217 codegen.clone(),
218 editor,
219 assistant_panel,
220 workspace.clone(),
221 self.fs.clone(),
222 cx,
223 )
224 });
225
226 if assist_to_focus.is_none() {
227 let focus_assist = if newest_selection.reversed {
228 range.start.to_point(&snapshot) == newest_selection.start
229 } else {
230 range.end.to_point(&snapshot) == newest_selection.end
231 };
232 if focus_assist {
233 assist_to_focus = Some(assist_id);
234 }
235 }
236
237 let [prompt_block_id, end_block_id] =
238 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
239
240 assists.push((
241 assist_id,
242 range,
243 prompt_editor,
244 prompt_block_id,
245 end_block_id,
246 ));
247 }
248
249 let editor_assists = self
250 .assists_by_editor
251 .entry(editor.downgrade())
252 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
253 let mut assist_group = InlineAssistGroup::new();
254 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
255 self.assists.insert(
256 assist_id,
257 InlineAssist::new(
258 assist_id,
259 assist_group_id,
260 assistant_panel.is_some(),
261 editor,
262 &prompt_editor,
263 prompt_block_id,
264 end_block_id,
265 range,
266 prompt_editor.read(cx).codegen.clone(),
267 workspace.clone(),
268 cx,
269 ),
270 );
271 assist_group.assist_ids.push(assist_id);
272 editor_assists.assist_ids.push(assist_id);
273 }
274 self.assist_groups.insert(assist_group_id, assist_group);
275
276 if let Some(assist_id) = assist_to_focus {
277 self.focus_assist(assist_id, cx);
278 }
279 }
280
281 #[allow(clippy::too_many_arguments)]
282 pub fn suggest_assist(
283 &mut self,
284 editor: &View<Editor>,
285 mut range: Range<Anchor>,
286 initial_prompt: String,
287 initial_transaction_id: Option<TransactionId>,
288 workspace: Option<WeakView<Workspace>>,
289 assistant_panel: Option<&View<AssistantPanel>>,
290 cx: &mut WindowContext,
291 ) -> InlineAssistId {
292 let assist_group_id = self.next_assist_group_id.post_inc();
293 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
294 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
295
296 let assist_id = self.next_assist_id.post_inc();
297
298 let buffer = editor.read(cx).buffer().clone();
299 {
300 let snapshot = buffer.read(cx).read(cx);
301 range.start = range.start.bias_left(&snapshot);
302 range.end = range.end.bias_right(&snapshot);
303 }
304
305 let codegen = cx.new_model(|cx| {
306 Codegen::new(
307 editor.read(cx).buffer().clone(),
308 range.clone(),
309 initial_transaction_id,
310 self.telemetry.clone(),
311 self.prompt_builder.clone(),
312 cx,
313 )
314 });
315
316 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
317 let prompt_editor = cx.new_view(|cx| {
318 PromptEditor::new(
319 assist_id,
320 gutter_dimensions.clone(),
321 self.prompt_history.clone(),
322 prompt_buffer.clone(),
323 codegen.clone(),
324 editor,
325 assistant_panel,
326 workspace.clone(),
327 self.fs.clone(),
328 cx,
329 )
330 });
331
332 let [prompt_block_id, end_block_id] =
333 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
334
335 let editor_assists = self
336 .assists_by_editor
337 .entry(editor.downgrade())
338 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
339
340 let mut assist_group = InlineAssistGroup::new();
341 self.assists.insert(
342 assist_id,
343 InlineAssist::new(
344 assist_id,
345 assist_group_id,
346 assistant_panel.is_some(),
347 editor,
348 &prompt_editor,
349 prompt_block_id,
350 end_block_id,
351 range,
352 prompt_editor.read(cx).codegen.clone(),
353 workspace.clone(),
354 cx,
355 ),
356 );
357 assist_group.assist_ids.push(assist_id);
358 editor_assists.assist_ids.push(assist_id);
359 self.assist_groups.insert(assist_group_id, assist_group);
360 assist_id
361 }
362
363 fn insert_assist_blocks(
364 &self,
365 editor: &View<Editor>,
366 range: &Range<Anchor>,
367 prompt_editor: &View<PromptEditor>,
368 cx: &mut WindowContext,
369 ) -> [CustomBlockId; 2] {
370 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
371 prompt_editor
372 .editor
373 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
374 });
375 let assist_blocks = vec![
376 BlockProperties {
377 style: BlockStyle::Sticky,
378 position: range.start,
379 height: prompt_editor_height,
380 render: build_assist_editor_renderer(prompt_editor),
381 disposition: BlockDisposition::Above,
382 priority: 0,
383 },
384 BlockProperties {
385 style: BlockStyle::Sticky,
386 position: range.end,
387 height: 0,
388 render: Box::new(|cx| {
389 v_flex()
390 .h_full()
391 .w_full()
392 .border_t_1()
393 .border_color(cx.theme().status().info_border)
394 .into_any_element()
395 }),
396 disposition: BlockDisposition::Below,
397 priority: 0,
398 },
399 ];
400
401 editor.update(cx, |editor, cx| {
402 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
403 [block_ids[0], block_ids[1]]
404 })
405 }
406
407 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
408 let assist = &self.assists[&assist_id];
409 let Some(decorations) = assist.decorations.as_ref() else {
410 return;
411 };
412 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
413 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
414
415 assist_group.active_assist_id = Some(assist_id);
416 if assist_group.linked {
417 for assist_id in &assist_group.assist_ids {
418 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
419 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
420 prompt_editor.set_show_cursor_when_unfocused(true, cx)
421 });
422 }
423 }
424 }
425
426 assist
427 .editor
428 .update(cx, |editor, cx| {
429 let scroll_top = editor.scroll_position(cx).y;
430 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
431 let prompt_row = editor
432 .row_for_block(decorations.prompt_block_id, cx)
433 .unwrap()
434 .0 as f32;
435
436 if (scroll_top..scroll_bottom).contains(&prompt_row) {
437 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
438 assist_id,
439 distance_from_top: prompt_row - scroll_top,
440 });
441 } else {
442 editor_assists.scroll_lock = None;
443 }
444 })
445 .ok();
446 }
447
448 fn handle_prompt_editor_focus_out(
449 &mut self,
450 assist_id: InlineAssistId,
451 cx: &mut WindowContext,
452 ) {
453 let assist = &self.assists[&assist_id];
454 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
455 if assist_group.active_assist_id == Some(assist_id) {
456 assist_group.active_assist_id = None;
457 if assist_group.linked {
458 for assist_id in &assist_group.assist_ids {
459 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
460 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
461 prompt_editor.set_show_cursor_when_unfocused(false, cx)
462 });
463 }
464 }
465 }
466 }
467 }
468
469 fn handle_prompt_editor_event(
470 &mut self,
471 prompt_editor: View<PromptEditor>,
472 event: &PromptEditorEvent,
473 cx: &mut WindowContext,
474 ) {
475 let assist_id = prompt_editor.read(cx).id;
476 match event {
477 PromptEditorEvent::StartRequested => {
478 self.start_assist(assist_id, cx);
479 }
480 PromptEditorEvent::StopRequested => {
481 self.stop_assist(assist_id, cx);
482 }
483 PromptEditorEvent::ConfirmRequested => {
484 self.finish_assist(assist_id, false, cx);
485 }
486 PromptEditorEvent::CancelRequested => {
487 self.finish_assist(assist_id, true, cx);
488 }
489 PromptEditorEvent::DismissRequested => {
490 self.dismiss_assist(assist_id, cx);
491 }
492 }
493 }
494
495 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
496 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
497 return;
498 };
499
500 let editor = editor.read(cx);
501 if editor.selections.count() == 1 {
502 let selection = editor.selections.newest::<usize>(cx);
503 let buffer = editor.buffer().read(cx).snapshot(cx);
504 for assist_id in &editor_assists.assist_ids {
505 let assist = &self.assists[assist_id];
506 let assist_range = assist.range.to_offset(&buffer);
507 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
508 {
509 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
510 self.dismiss_assist(*assist_id, cx);
511 } else {
512 self.finish_assist(*assist_id, false, cx);
513 }
514
515 return;
516 }
517 }
518 }
519
520 cx.propagate();
521 }
522
523 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
524 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
525 return;
526 };
527
528 let editor = editor.read(cx);
529 if editor.selections.count() == 1 {
530 let selection = editor.selections.newest::<usize>(cx);
531 let buffer = editor.buffer().read(cx).snapshot(cx);
532 let mut closest_assist_fallback = None;
533 for assist_id in &editor_assists.assist_ids {
534 let assist = &self.assists[assist_id];
535 let assist_range = assist.range.to_offset(&buffer);
536 if assist.decorations.is_some() {
537 if assist_range.contains(&selection.start)
538 && assist_range.contains(&selection.end)
539 {
540 self.focus_assist(*assist_id, cx);
541 return;
542 } else {
543 let distance_from_selection = assist_range
544 .start
545 .abs_diff(selection.start)
546 .min(assist_range.start.abs_diff(selection.end))
547 + assist_range
548 .end
549 .abs_diff(selection.start)
550 .min(assist_range.end.abs_diff(selection.end));
551 match closest_assist_fallback {
552 Some((_, old_distance)) => {
553 if distance_from_selection < old_distance {
554 closest_assist_fallback =
555 Some((assist_id, distance_from_selection));
556 }
557 }
558 None => {
559 closest_assist_fallback = Some((assist_id, distance_from_selection))
560 }
561 }
562 }
563 }
564 }
565
566 if let Some((&assist_id, _)) = closest_assist_fallback {
567 self.focus_assist(assist_id, cx);
568 }
569 }
570
571 cx.propagate();
572 }
573
574 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
575 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
576 for assist_id in editor_assists.assist_ids.clone() {
577 self.finish_assist(assist_id, true, cx);
578 }
579 }
580 }
581
582 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
583 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
584 return;
585 };
586 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
587 return;
588 };
589 let assist = &self.assists[&scroll_lock.assist_id];
590 let Some(decorations) = assist.decorations.as_ref() else {
591 return;
592 };
593
594 editor.update(cx, |editor, cx| {
595 let scroll_position = editor.scroll_position(cx);
596 let target_scroll_top = editor
597 .row_for_block(decorations.prompt_block_id, cx)
598 .unwrap()
599 .0 as f32
600 - scroll_lock.distance_from_top;
601 if target_scroll_top != scroll_position.y {
602 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
603 }
604 });
605 }
606
607 fn handle_editor_event(
608 &mut self,
609 editor: View<Editor>,
610 event: &EditorEvent,
611 cx: &mut WindowContext,
612 ) {
613 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
614 return;
615 };
616
617 match event {
618 EditorEvent::Edited { transaction_id } => {
619 let buffer = editor.read(cx).buffer().read(cx);
620 let edited_ranges =
621 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
622 let snapshot = buffer.snapshot(cx);
623
624 for assist_id in editor_assists.assist_ids.clone() {
625 let assist = &self.assists[&assist_id];
626 if matches!(
627 assist.codegen.read(cx).status,
628 CodegenStatus::Error(_) | CodegenStatus::Done
629 ) {
630 let assist_range = assist.range.to_offset(&snapshot);
631 if edited_ranges
632 .iter()
633 .any(|range| range.overlaps(&assist_range))
634 {
635 self.finish_assist(assist_id, false, cx);
636 }
637 }
638 }
639 }
640 EditorEvent::ScrollPositionChanged { .. } => {
641 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
642 let assist = &self.assists[&scroll_lock.assist_id];
643 if let Some(decorations) = assist.decorations.as_ref() {
644 let distance_from_top = editor.update(cx, |editor, cx| {
645 let scroll_top = editor.scroll_position(cx).y;
646 let prompt_row = editor
647 .row_for_block(decorations.prompt_block_id, cx)
648 .unwrap()
649 .0 as f32;
650 prompt_row - scroll_top
651 });
652
653 if distance_from_top != scroll_lock.distance_from_top {
654 editor_assists.scroll_lock = None;
655 }
656 }
657 }
658 }
659 EditorEvent::SelectionsChanged { .. } => {
660 for assist_id in editor_assists.assist_ids.clone() {
661 let assist = &self.assists[&assist_id];
662 if let Some(decorations) = assist.decorations.as_ref() {
663 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
664 return;
665 }
666 }
667 }
668
669 editor_assists.scroll_lock = None;
670 }
671 _ => {}
672 }
673 }
674
675 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
676 if let Some(assist) = self.assists.get(&assist_id) {
677 let assist_group_id = assist.group_id;
678 if self.assist_groups[&assist_group_id].linked {
679 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
680 self.finish_assist(assist_id, undo, cx);
681 }
682 return;
683 }
684 }
685
686 self.dismiss_assist(assist_id, cx);
687
688 if let Some(assist) = self.assists.remove(&assist_id) {
689 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
690 {
691 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
692 if entry.get().assist_ids.is_empty() {
693 entry.remove();
694 }
695 }
696
697 if let hash_map::Entry::Occupied(mut entry) =
698 self.assists_by_editor.entry(assist.editor.clone())
699 {
700 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
701 if entry.get().assist_ids.is_empty() {
702 entry.remove();
703 if let Some(editor) = assist.editor.upgrade() {
704 self.update_editor_highlights(&editor, cx);
705 }
706 } else {
707 entry.get().highlight_updates.send(()).ok();
708 }
709 }
710
711 if undo {
712 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
713 } else {
714 self.confirmed_assists.insert(assist_id, assist.codegen);
715 }
716 }
717
718 // Remove the assist from the status updates map
719 self.assist_observations.remove(&assist_id);
720 }
721
722 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
723 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
724 return false;
725 };
726 codegen.update(cx, |this, cx| this.undo(cx));
727 true
728 }
729
730 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
731 let Some(assist) = self.assists.get_mut(&assist_id) else {
732 return false;
733 };
734 let Some(editor) = assist.editor.upgrade() else {
735 return false;
736 };
737 let Some(decorations) = assist.decorations.take() else {
738 return false;
739 };
740
741 editor.update(cx, |editor, cx| {
742 let mut to_remove = decorations.removed_line_block_ids;
743 to_remove.insert(decorations.prompt_block_id);
744 to_remove.insert(decorations.end_block_id);
745 editor.remove_blocks(to_remove, None, cx);
746 });
747
748 if decorations
749 .prompt_editor
750 .focus_handle(cx)
751 .contains_focused(cx)
752 {
753 self.focus_next_assist(assist_id, cx);
754 }
755
756 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
757 if editor_assists
758 .scroll_lock
759 .as_ref()
760 .map_or(false, |lock| lock.assist_id == assist_id)
761 {
762 editor_assists.scroll_lock = None;
763 }
764 editor_assists.highlight_updates.send(()).ok();
765 }
766
767 true
768 }
769
770 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
771 let Some(assist) = self.assists.get(&assist_id) else {
772 return;
773 };
774
775 let assist_group = &self.assist_groups[&assist.group_id];
776 let assist_ix = assist_group
777 .assist_ids
778 .iter()
779 .position(|id| *id == assist_id)
780 .unwrap();
781 let assist_ids = assist_group
782 .assist_ids
783 .iter()
784 .skip(assist_ix + 1)
785 .chain(assist_group.assist_ids.iter().take(assist_ix));
786
787 for assist_id in assist_ids {
788 let assist = &self.assists[assist_id];
789 if assist.decorations.is_some() {
790 self.focus_assist(*assist_id, cx);
791 return;
792 }
793 }
794
795 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
796 }
797
798 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
799 let Some(assist) = self.assists.get(&assist_id) else {
800 return;
801 };
802
803 if let Some(decorations) = assist.decorations.as_ref() {
804 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
805 prompt_editor.editor.update(cx, |editor, cx| {
806 editor.focus(cx);
807 editor.select_all(&SelectAll, cx);
808 })
809 });
810 }
811
812 self.scroll_to_assist(assist_id, cx);
813 }
814
815 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
816 let Some(assist) = self.assists.get(&assist_id) else {
817 return;
818 };
819 let Some(editor) = assist.editor.upgrade() else {
820 return;
821 };
822
823 let position = assist.range.start;
824 editor.update(cx, |editor, cx| {
825 editor.change_selections(None, cx, |selections| {
826 selections.select_anchor_ranges([position..position])
827 });
828
829 let mut scroll_target_top;
830 let mut scroll_target_bottom;
831 if let Some(decorations) = assist.decorations.as_ref() {
832 scroll_target_top = editor
833 .row_for_block(decorations.prompt_block_id, cx)
834 .unwrap()
835 .0 as f32;
836 scroll_target_bottom = editor
837 .row_for_block(decorations.end_block_id, cx)
838 .unwrap()
839 .0 as f32;
840 } else {
841 let snapshot = editor.snapshot(cx);
842 let start_row = assist
843 .range
844 .start
845 .to_display_point(&snapshot.display_snapshot)
846 .row();
847 scroll_target_top = start_row.0 as f32;
848 scroll_target_bottom = scroll_target_top + 1.;
849 }
850 scroll_target_top -= editor.vertical_scroll_margin() as f32;
851 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
852
853 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
854 let scroll_top = editor.scroll_position(cx).y;
855 let scroll_bottom = scroll_top + height_in_lines;
856
857 if scroll_target_top < scroll_top {
858 editor.set_scroll_position(point(0., scroll_target_top), cx);
859 } else if scroll_target_bottom > scroll_bottom {
860 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
861 editor
862 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
863 } else {
864 editor.set_scroll_position(point(0., scroll_target_top), cx);
865 }
866 }
867 });
868 }
869
870 fn unlink_assist_group(
871 &mut self,
872 assist_group_id: InlineAssistGroupId,
873 cx: &mut WindowContext,
874 ) -> Vec<InlineAssistId> {
875 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
876 assist_group.linked = false;
877 for assist_id in &assist_group.assist_ids {
878 let assist = self.assists.get_mut(assist_id).unwrap();
879 if let Some(editor_decorations) = assist.decorations.as_ref() {
880 editor_decorations
881 .prompt_editor
882 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
883 }
884 }
885 assist_group.assist_ids.clone()
886 }
887
888 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
889 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
890 assist
891 } else {
892 return;
893 };
894
895 let assist_group_id = assist.group_id;
896 if self.assist_groups[&assist_group_id].linked {
897 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
898 self.start_assist(assist_id, cx);
899 }
900 return;
901 }
902
903 let Some(user_prompt) = assist.user_prompt(cx) else {
904 return;
905 };
906
907 self.prompt_history.retain(|prompt| *prompt != user_prompt);
908 self.prompt_history.push_back(user_prompt.clone());
909 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
910 self.prompt_history.pop_front();
911 }
912
913 let assistant_panel_context = assist.assistant_panel_context(cx);
914
915 assist
916 .codegen
917 .update(cx, |codegen, cx| {
918 codegen.start(
919 assist.range.clone(),
920 user_prompt,
921 assistant_panel_context,
922 cx,
923 )
924 })
925 .log_err();
926
927 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
928 tx.send(()).ok();
929 }
930 }
931
932 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
933 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
934 assist
935 } else {
936 return;
937 };
938
939 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
940
941 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
942 tx.send(()).ok();
943 }
944 }
945
946 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
947 if let Some(assist) = self.assists.get(&assist_id) {
948 match &assist.codegen.read(cx).status {
949 CodegenStatus::Idle => InlineAssistStatus::Idle,
950 CodegenStatus::Pending => InlineAssistStatus::Pending,
951 CodegenStatus::Done => InlineAssistStatus::Done,
952 CodegenStatus::Error(_) => InlineAssistStatus::Error,
953 }
954 } else if self.confirmed_assists.contains_key(&assist_id) {
955 InlineAssistStatus::Confirmed
956 } else {
957 InlineAssistStatus::Canceled
958 }
959 }
960
961 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
962 let mut gutter_pending_ranges = Vec::new();
963 let mut gutter_transformed_ranges = Vec::new();
964 let mut foreground_ranges = Vec::new();
965 let mut inserted_row_ranges = Vec::new();
966 let empty_assist_ids = Vec::new();
967 let assist_ids = self
968 .assists_by_editor
969 .get(&editor.downgrade())
970 .map_or(&empty_assist_ids, |editor_assists| {
971 &editor_assists.assist_ids
972 });
973
974 for assist_id in assist_ids {
975 if let Some(assist) = self.assists.get(assist_id) {
976 let codegen = assist.codegen.read(cx);
977 let buffer = codegen.buffer.read(cx).read(cx);
978 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
979
980 let pending_range =
981 codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
982 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
983 gutter_pending_ranges.push(pending_range);
984 }
985
986 if let Some(edit_position) = codegen.edit_position {
987 let edited_range = assist.range.start..edit_position;
988 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
989 gutter_transformed_ranges.push(edited_range);
990 }
991 }
992
993 if assist.decorations.is_some() {
994 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
995 }
996 }
997 }
998
999 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1000 merge_ranges(&mut foreground_ranges, &snapshot);
1001 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1002 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1003 editor.update(cx, |editor, cx| {
1004 enum GutterPendingRange {}
1005 if gutter_pending_ranges.is_empty() {
1006 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1007 } else {
1008 editor.highlight_gutter::<GutterPendingRange>(
1009 &gutter_pending_ranges,
1010 |cx| cx.theme().status().info_background,
1011 cx,
1012 )
1013 }
1014
1015 enum GutterTransformedRange {}
1016 if gutter_transformed_ranges.is_empty() {
1017 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1018 } else {
1019 editor.highlight_gutter::<GutterTransformedRange>(
1020 &gutter_transformed_ranges,
1021 |cx| cx.theme().status().info,
1022 cx,
1023 )
1024 }
1025
1026 if foreground_ranges.is_empty() {
1027 editor.clear_highlights::<InlineAssist>(cx);
1028 } else {
1029 editor.highlight_text::<InlineAssist>(
1030 foreground_ranges,
1031 HighlightStyle {
1032 fade_out: Some(0.6),
1033 ..Default::default()
1034 },
1035 cx,
1036 );
1037 }
1038
1039 editor.clear_row_highlights::<InlineAssist>();
1040 for row_range in inserted_row_ranges {
1041 editor.highlight_rows::<InlineAssist>(
1042 row_range,
1043 Some(cx.theme().status().info_background),
1044 false,
1045 cx,
1046 );
1047 }
1048 });
1049 }
1050
    /// Rebuilds the blocks that display lines deleted by an assist's transformation.
    ///
    /// Removes the previously inserted deleted-line blocks. Then, for every entry in
    /// the codegen diff's `deleted_row_ranges`, it creates a read-only, gutterless
    /// editor over the corresponding rows of `codegen.old_buffer`. That editor is
    /// inserted as a block above the recorded anchor. Does nothing if the assist is
    /// unknown or has no decorations.
    ///
    /// Panics if a deleted row range cannot be mapped to a buffer offset in the
    /// codegen snapshot (the `unwrap`s below).
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Clone the codegen state we need so that `cx` is no longer borrowed
        // when `editor.update` takes it mutably below.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks from the previous update before inserting new ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Map the deleted rows (in the old snapshot) to buffer offsets
                // spanning from the start of the first row to the end of the last.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only editor showing just the deleted text, highlighted
                // with the theme's deleted background.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(false);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // The block is exactly as tall as the deleted-lines editor's line count.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            // Indent past the host editor's gutter.
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1143
1144 pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
1145 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1146 rx.clone()
1147 } else {
1148 let (tx, rx) = async_watch::channel(());
1149 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1150 rx
1151 }
1152 }
1153}
1154
/// Coarse status of an inline assist.
///
/// Unlike the private `CodegenStatus`, this enum carries no error payload, so it
/// can be freely copied and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` for [`InlineAssistStatus::Pending`].
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// Returns `true` for [`InlineAssistStatus::Confirmed`].
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// Returns `true` for [`InlineAssistStatus::Done`].
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1175
/// Per-editor inline-assist state (stored in `InlineAssistant::assists_by_editor`).
///
/// Field order is significant: fields are dropped in declaration order, which
/// tears down the highlight task before the editor subscriptions.
struct EditorInlineAssists {
    // Ids of the inline assists associated with this editor.
    assist_ids: Vec<InlineAssistId>,
    // NOTE(review): presumably keeps the viewport pinned relative to one assist
    // while it is active — confirm against the scrolling code.
    scroll_lock: Option<InlineAssistScrollLock>,
    // Signaling this wakes `_update_highlights`, which recomputes the editor's
    // assist highlights.
    highlight_updates: async_watch::Sender<()>,
    // Task that calls `InlineAssistant::update_editor_highlights` on each signal.
    _update_highlights: Task<Result<()>>,
    // Editor release/change/event observers and the Newline/Cancel action hooks.
    _subscriptions: Vec<gpui::Subscription>,
}
1183
/// Scroll anchoring for an editor's inline assists.
///
/// NOTE(review): the fields are consumed outside this view; presumably the
/// editor is scrolled so that `assist_id` stays `distance_from_top` (in lines)
/// from the top of the viewport — confirm units and usage.
struct InlineAssistScrollLock {
    assist_id: InlineAssistId,
    distance_from_top: f32,
}
1188
1189impl EditorInlineAssists {
1190 #[allow(clippy::too_many_arguments)]
1191 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1192 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1193 Self {
1194 assist_ids: Vec::new(),
1195 scroll_lock: None,
1196 highlight_updates: highlight_updates_tx,
1197 _update_highlights: cx.spawn(|mut cx| {
1198 let editor = editor.downgrade();
1199 async move {
1200 while let Ok(()) = highlight_updates_rx.changed().await {
1201 let editor = editor.upgrade().context("editor was dropped")?;
1202 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1203 assistant.update_editor_highlights(&editor, cx);
1204 })?;
1205 }
1206 Ok(())
1207 }
1208 }),
1209 _subscriptions: vec![
1210 cx.observe_release(editor, {
1211 let editor = editor.downgrade();
1212 |_, cx| {
1213 InlineAssistant::update_global(cx, |this, cx| {
1214 this.handle_editor_release(editor, cx);
1215 })
1216 }
1217 }),
1218 cx.observe(editor, move |editor, cx| {
1219 InlineAssistant::update_global(cx, |this, cx| {
1220 this.handle_editor_change(editor, cx)
1221 })
1222 }),
1223 cx.subscribe(editor, move |editor, event, cx| {
1224 InlineAssistant::update_global(cx, |this, cx| {
1225 this.handle_editor_event(editor, event, cx)
1226 })
1227 }),
1228 editor.update(cx, |editor, cx| {
1229 let editor_handle = cx.view().downgrade();
1230 editor.register_action(
1231 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1232 InlineAssistant::update_global(cx, |this, cx| {
1233 if let Some(editor) = editor_handle.upgrade() {
1234 this.handle_editor_newline(editor, cx)
1235 }
1236 })
1237 },
1238 )
1239 }),
1240 editor.update(cx, |editor, cx| {
1241 let editor_handle = cx.view().downgrade();
1242 editor.register_action(
1243 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1244 InlineAssistant::update_global(cx, |this, cx| {
1245 if let Some(editor) = editor_handle.upgrade() {
1246 this.handle_editor_cancel(editor, cx)
1247 }
1248 })
1249 },
1250 )
1251 }),
1252 ],
1253 }
1254 }
1255}
1256
1257struct InlineAssistGroup {
1258 assist_ids: Vec<InlineAssistId>,
1259 linked: bool,
1260 active_assist_id: Option<InlineAssistId>,
1261}
1262
1263impl InlineAssistGroup {
1264 fn new() -> Self {
1265 Self {
1266 assist_ids: Vec::new(),
1267 linked: true,
1268 active_assist_id: None,
1269 }
1270 }
1271}
1272
1273fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1274 let editor = editor.clone();
1275 Box::new(move |cx: &mut BlockContext| {
1276 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1277 editor.clone().into_any_element()
1278 })
1279}
1280
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1291
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, next)
    }
}
1302
/// Requests emitted by a `PromptEditor` for the inline assistant to act on.
///
/// Derives `Debug` (like `CodegenEvent`) so events can be logged and inspected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PromptEditorEvent {
    /// Start (or restart) the transformation.
    StartRequested,
    /// Interrupt a pending transformation.
    StopRequested,
    /// Accept the finished transformation.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted when confirming while a transformation is still pending.
    DismissRequested,
}
1310
/// The prompt UI rendered in a block above an inline assist.
///
/// Field order is significant: fields are dropped in declaration order.
struct PromptEditor {
    // The assist this prompt belongs to (used to look it up for token counting).
    id: InlineAssistId,
    // Passed to the model selector.
    fs: Arc<dyn Fs>,
    // The text editor holding the prompt; replaced by `unlink`.
    editor: View<Editor>,
    // Set on every prompt edit; cleared when codegen reaches Done or Error.
    // Decides between the "restart" and "confirm" actions.
    edited_since_done: bool,
    // Written by the block renderer with the host editor's gutter dimensions and
    // read during render to size the left column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    // Previous prompts, navigated with MoveUp/MoveDown.
    prompt_history: VecDeque<String>,
    // Index of the history entry currently shown, or `None` when showing the
    // prompt being typed.
    prompt_history_ix: Option<usize>,
    // The prompt typed before navigating into history; restored when moving
    // past the newest entry.
    pending_prompt: String,
    codegen: Model<Codegen>,
    // Observes `codegen` and calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    // Subscriptions to `editor`; rebuilt when `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    // Debounced token-count task; replaced on each recount.
    pending_token_count: Task<Result<()>>,
    // Most recently computed token counts, if any.
    token_counts: Option<TokenCounts>,
    // Parent-editor and assistant-panel subscriptions that trigger recounts.
    _token_count_subscriptions: Vec<Subscription>,
    // Used to focus the assistant panel when the token count is clicked.
    workspace: Option<WeakView<Workspace>>,
    // Whether the rate-limit popover is open.
    show_rate_limit_notice: bool,
}
1329
/// Token usage of an inline assist request, displayed next to the prompt.
///
/// This is a small public value type, so it derives the usual comparison,
/// debugging, and default traits in addition to `Copy`.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct TokenCounts {
    /// Total tokens for the request; compared against the model's maximum.
    total: usize,
    /// Portion of `total` that comes from the assistant panel context.
    assistant_panel: usize,
}
1335
// `PromptEditor` emits `PromptEditorEvent`s; `InlineAssist::new` subscribes to them.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1337
// Renders the inline prompt row. The left column is gutter-aligned and holds the
// model selector plus any error indicator. The prompt editor sits in the middle.
// The token count and status-dependent action buttons are on the right.
impl Render for PromptEditor {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        // Copied in from the host editor by the block renderer.
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        let buttons = match status {
            // Not started: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Running: cancel, or stop (interrupt) the transformation. While
            // pending, `menu::Cancel` maps to stop, so only the stop tooltip
            // advertises that binding.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished: restart after an error or a prompt edit, otherwise confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Left column, as wide as the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // On error, add an indicator. Zed Pro users who hit the rate
                    // limit get a toggleable notice popover; other errors show
                    // their message in a tooltip.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1518
// Focusing a `PromptEditor` focuses its inner prompt text editor.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1524
1525impl PromptEditor {
1526 const MAX_LINES: u8 = 8;
1527
1528 #[allow(clippy::too_many_arguments)]
1529 fn new(
1530 id: InlineAssistId,
1531 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1532 prompt_history: VecDeque<String>,
1533 prompt_buffer: Model<MultiBuffer>,
1534 codegen: Model<Codegen>,
1535 parent_editor: &View<Editor>,
1536 assistant_panel: Option<&View<AssistantPanel>>,
1537 workspace: Option<WeakView<Workspace>>,
1538 fs: Arc<dyn Fs>,
1539 cx: &mut ViewContext<Self>,
1540 ) -> Self {
1541 let prompt_editor = cx.new_view(|cx| {
1542 let mut editor = Editor::new(
1543 EditorMode::AutoHeight {
1544 max_lines: Self::MAX_LINES as usize,
1545 },
1546 prompt_buffer,
1547 None,
1548 false,
1549 cx,
1550 );
1551 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1552 // Since the prompt editors for all inline assistants are linked,
1553 // always show the cursor (even when it isn't focused) because
1554 // typing in one will make what you typed appear in all of them.
1555 editor.set_show_cursor_when_unfocused(true, cx);
1556 editor.set_placeholder_text("Add a prompt…", cx);
1557 editor
1558 });
1559
1560 let mut token_count_subscriptions = Vec::new();
1561 token_count_subscriptions
1562 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1563 if let Some(assistant_panel) = assistant_panel {
1564 token_count_subscriptions
1565 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1566 }
1567
1568 let mut this = Self {
1569 id,
1570 editor: prompt_editor,
1571 edited_since_done: false,
1572 gutter_dimensions,
1573 prompt_history,
1574 prompt_history_ix: None,
1575 pending_prompt: String::new(),
1576 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1577 editor_subscriptions: Vec::new(),
1578 codegen,
1579 fs,
1580 pending_token_count: Task::ready(Ok(())),
1581 token_counts: None,
1582 _token_count_subscriptions: token_count_subscriptions,
1583 workspace,
1584 show_rate_limit_notice: false,
1585 };
1586 this.count_tokens(cx);
1587 this.subscribe_to_editor(cx);
1588 this
1589 }
1590
1591 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1592 self.editor_subscriptions.clear();
1593 self.editor_subscriptions
1594 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1595 }
1596
1597 fn set_show_cursor_when_unfocused(
1598 &mut self,
1599 show_cursor_when_unfocused: bool,
1600 cx: &mut ViewContext<Self>,
1601 ) {
1602 self.editor.update(cx, |editor, cx| {
1603 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1604 });
1605 }
1606
1607 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1608 let prompt = self.prompt(cx);
1609 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1610 self.editor = cx.new_view(|cx| {
1611 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1612 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1613 editor.set_placeholder_text("Add a prompt…", cx);
1614 editor.set_text(prompt, cx);
1615 if focus {
1616 editor.focus(cx);
1617 }
1618 editor
1619 });
1620 self.subscribe_to_editor(cx);
1621 }
1622
1623 fn prompt(&self, cx: &AppContext) -> String {
1624 self.editor.read(cx).text(cx)
1625 }
1626
1627 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1628 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1629 if self.show_rate_limit_notice {
1630 cx.focus_view(&self.editor);
1631 }
1632 cx.notify();
1633 }
1634
1635 fn handle_parent_editor_event(
1636 &mut self,
1637 _: View<Editor>,
1638 event: &EditorEvent,
1639 cx: &mut ViewContext<Self>,
1640 ) {
1641 if let EditorEvent::BufferEdited { .. } = event {
1642 self.count_tokens(cx);
1643 }
1644 }
1645
1646 fn handle_assistant_panel_event(
1647 &mut self,
1648 _: View<AssistantPanel>,
1649 event: &AssistantPanelEvent,
1650 cx: &mut ViewContext<Self>,
1651 ) {
1652 let AssistantPanelEvent::ContextEdited { .. } = event;
1653 self.count_tokens(cx);
1654 }
1655
1656 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1657 let assist_id = self.id;
1658 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1659 cx.background_executor().timer(Duration::from_secs(1)).await;
1660 let token_count = cx
1661 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1662 let assist = inline_assistant
1663 .assists
1664 .get(&assist_id)
1665 .context("assist not found")?;
1666 anyhow::Ok(assist.count_tokens(cx))
1667 })??
1668 .await?;
1669
1670 this.update(&mut cx, |this, cx| {
1671 this.token_counts = Some(token_count);
1672 cx.notify();
1673 })
1674 })
1675 }
1676
1677 fn handle_prompt_editor_events(
1678 &mut self,
1679 _: View<Editor>,
1680 event: &EditorEvent,
1681 cx: &mut ViewContext<Self>,
1682 ) {
1683 match event {
1684 EditorEvent::Edited { .. } => {
1685 let prompt = self.editor.read(cx).text(cx);
1686 if self
1687 .prompt_history_ix
1688 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1689 {
1690 self.prompt_history_ix.take();
1691 self.pending_prompt = prompt;
1692 }
1693
1694 self.edited_since_done = true;
1695 cx.notify();
1696 }
1697 EditorEvent::BufferEdited => {
1698 self.count_tokens(cx);
1699 }
1700 EditorEvent::Blurred => {
1701 if self.show_rate_limit_notice {
1702 self.show_rate_limit_notice = false;
1703 cx.notify();
1704 }
1705 }
1706 _ => {}
1707 }
1708 }
1709
1710 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1711 match &self.codegen.read(cx).status {
1712 CodegenStatus::Idle => {
1713 self.editor
1714 .update(cx, |editor, _| editor.set_read_only(false));
1715 }
1716 CodegenStatus::Pending => {
1717 self.editor
1718 .update(cx, |editor, _| editor.set_read_only(true));
1719 }
1720 CodegenStatus::Done => {
1721 self.edited_since_done = false;
1722 self.editor
1723 .update(cx, |editor, _| editor.set_read_only(false));
1724 }
1725 CodegenStatus::Error(error) => {
1726 if cx.has_flag::<ZedPro>()
1727 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1728 && !dismissed_rate_limit_notice()
1729 {
1730 self.show_rate_limit_notice = true;
1731 cx.notify();
1732 }
1733
1734 self.edited_since_done = false;
1735 self.editor
1736 .update(cx, |editor, _| editor.set_read_only(false));
1737 }
1738 }
1739 }
1740
1741 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1742 match &self.codegen.read(cx).status {
1743 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1744 cx.emit(PromptEditorEvent::CancelRequested);
1745 }
1746 CodegenStatus::Pending => {
1747 cx.emit(PromptEditorEvent::StopRequested);
1748 }
1749 }
1750 }
1751
1752 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1753 match &self.codegen.read(cx).status {
1754 CodegenStatus::Idle => {
1755 cx.emit(PromptEditorEvent::StartRequested);
1756 }
1757 CodegenStatus::Pending => {
1758 cx.emit(PromptEditorEvent::DismissRequested);
1759 }
1760 CodegenStatus::Done | CodegenStatus::Error(_) => {
1761 if self.edited_since_done {
1762 cx.emit(PromptEditorEvent::StartRequested);
1763 } else {
1764 cx.emit(PromptEditorEvent::ConfirmRequested);
1765 }
1766 }
1767 }
1768 }
1769
1770 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1771 if let Some(ix) = self.prompt_history_ix {
1772 if ix > 0 {
1773 self.prompt_history_ix = Some(ix - 1);
1774 let prompt = self.prompt_history[ix - 1].as_str();
1775 self.editor.update(cx, |editor, cx| {
1776 editor.set_text(prompt, cx);
1777 editor.move_to_beginning(&Default::default(), cx);
1778 });
1779 }
1780 } else if !self.prompt_history.is_empty() {
1781 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1782 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1783 self.editor.update(cx, |editor, cx| {
1784 editor.set_text(prompt, cx);
1785 editor.move_to_beginning(&Default::default(), cx);
1786 });
1787 }
1788 }
1789
1790 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1791 if let Some(ix) = self.prompt_history_ix {
1792 if ix < self.prompt_history.len() - 1 {
1793 self.prompt_history_ix = Some(ix + 1);
1794 let prompt = self.prompt_history[ix + 1].as_str();
1795 self.editor.update(cx, |editor, cx| {
1796 editor.set_text(prompt, cx);
1797 editor.move_to_end(&Default::default(), cx)
1798 });
1799 } else {
1800 self.prompt_history_ix = None;
1801 let prompt = self.pending_prompt.as_str();
1802 self.editor.update(cx, |editor, cx| {
1803 editor.set_text(prompt, cx);
1804 editor.move_to_end(&Default::default(), cx)
1805 });
1806 }
1807 }
1808 }
1809
1810 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1811 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1812 let token_counts = self.token_counts?;
1813 let max_token_count = model.max_token_count();
1814
1815 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1816 let token_count_color = if remaining_tokens <= 0 {
1817 Color::Error
1818 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1819 Color::Warning
1820 } else {
1821 Color::Muted
1822 };
1823
1824 let mut token_count = h_flex()
1825 .id("token_count")
1826 .gap_0p5()
1827 .child(
1828 Label::new(humanize_token_count(token_counts.total))
1829 .size(LabelSize::Small)
1830 .color(token_count_color),
1831 )
1832 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1833 .child(
1834 Label::new(humanize_token_count(max_token_count))
1835 .size(LabelSize::Small)
1836 .color(Color::Muted),
1837 );
1838 if let Some(workspace) = self.workspace.clone() {
1839 token_count = token_count
1840 .tooltip(move |cx| {
1841 Tooltip::with_meta(
1842 format!(
1843 "Tokens Used ({} from the Assistant Panel)",
1844 humanize_token_count(token_counts.assistant_panel)
1845 ),
1846 None,
1847 "Click to open the Assistant Panel",
1848 cx,
1849 )
1850 })
1851 .cursor_pointer()
1852 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1853 .on_click(move |_, cx| {
1854 cx.stop_propagation();
1855 workspace
1856 .update(cx, |workspace, cx| {
1857 workspace.focus_panel::<AssistantPanel>(cx)
1858 })
1859 .ok();
1860 });
1861 } else {
1862 token_count = token_count
1863 .cursor_default()
1864 .tooltip(|cx| Tooltip::text("Tokens used", cx));
1865 }
1866
1867 Some(token_count)
1868 }
1869
1870 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1871 let settings = ThemeSettings::get_global(cx);
1872 let text_style = TextStyle {
1873 color: if self.editor.read(cx).read_only(cx) {
1874 cx.theme().colors().text_disabled
1875 } else {
1876 cx.theme().colors().text
1877 },
1878 font_family: settings.ui_font.family.clone(),
1879 font_features: settings.ui_font.features.clone(),
1880 font_fallbacks: settings.ui_font.fallbacks.clone(),
1881 font_size: rems(0.875).into(),
1882 font_weight: settings.ui_font.weight,
1883 line_height: relative(1.3),
1884 ..Default::default()
1885 };
1886 EditorElement::new(
1887 &self.editor,
1888 EditorStyle {
1889 background: cx.theme().colors().editor_background,
1890 local_player: cx.theme().players().local(),
1891 text: text_style,
1892 ..Default::default()
1893 },
1894 )
1895 }
1896
1897 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1898 Popover::new().child(
1899 v_flex()
1900 .occlude()
1901 .p_2()
1902 .child(
1903 Label::new("Out of Tokens")
1904 .size(LabelSize::Small)
1905 .weight(FontWeight::BOLD),
1906 )
1907 .child(Label::new(
1908 "Try Zed Pro for higher limits, a wider range of models, and more.",
1909 ))
1910 .child(
1911 h_flex()
1912 .justify_between()
1913 .child(CheckboxWithLabel::new(
1914 "dont-show-again",
1915 Label::new("Don't show again"),
1916 if dismissed_rate_limit_notice() {
1917 ui::Selection::Selected
1918 } else {
1919 ui::Selection::Unselected
1920 },
1921 |selection, cx| {
1922 let is_dismissed = match selection {
1923 ui::Selection::Unselected => false,
1924 ui::Selection::Indeterminate => return,
1925 ui::Selection::Selected => true,
1926 };
1927
1928 set_rate_limit_notice_dismissed(is_dismissed, cx)
1929 },
1930 ))
1931 .child(
1932 h_flex()
1933 .gap_2()
1934 .child(
1935 Button::new("dismiss", "Dismiss")
1936 .style(ButtonStyle::Transparent)
1937 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1938 )
1939 .child(Button::new("more-info", "More Info").on_click(
1940 |_event, cx| {
1941 cx.dispatch_action(Box::new(
1942 zed_actions::OpenAccountSettings,
1943 ))
1944 },
1945 )),
1946 ),
1947 ),
1948 )
1949 }
1950}
1951
/// Key-value store key whose presence records that the user chose
/// "Don't show again" on the rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1953
1954fn dismissed_rate_limit_notice() -> bool {
1955 db::kvp::KEY_VALUE_STORE
1956 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1957 .log_err()
1958 .map_or(false, |s| s.is_some())
1959}
1960
1961fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1962 db::write_and_log(cx, move || async move {
1963 if is_dismissed {
1964 db::kvp::KEY_VALUE_STORE
1965 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1966 .await
1967 } else {
1968 db::kvp::KEY_VALUE_STORE
1969 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1970 .await
1971 }
1972 })
1973}
1974
/// Bookkeeping for a single inline assist.
///
/// Field order is significant: fields are dropped in declaration order.
struct InlineAssist {
    // Group this assist belongs to.
    group_id: InlineAssistGroupId,
    // Range of the host buffer the assist operates on (used for token counting).
    range: Range<Anchor>,
    // Host editor, held weakly.
    editor: WeakView<Editor>,
    // Prompt/deleted-line blocks shown in the host editor. When `None`, a
    // finished assist is completed immediately and errors surface as a toast.
    decorations: Option<InlineAssistDecorations>,
    codegen: Model<Codegen>,
    // Focus, prompt-editor, and codegen subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
    // Used for error toasts and to look up the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    // Whether to include the assistant panel's active context in requests.
    include_context: bool,
}
1985
1986impl InlineAssist {
1987 #[allow(clippy::too_many_arguments)]
1988 fn new(
1989 assist_id: InlineAssistId,
1990 group_id: InlineAssistGroupId,
1991 include_context: bool,
1992 editor: &View<Editor>,
1993 prompt_editor: &View<PromptEditor>,
1994 prompt_block_id: CustomBlockId,
1995 end_block_id: CustomBlockId,
1996 range: Range<Anchor>,
1997 codegen: Model<Codegen>,
1998 workspace: Option<WeakView<Workspace>>,
1999 cx: &mut WindowContext,
2000 ) -> Self {
2001 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2002 InlineAssist {
2003 group_id,
2004 include_context,
2005 editor: editor.downgrade(),
2006 decorations: Some(InlineAssistDecorations {
2007 prompt_block_id,
2008 prompt_editor: prompt_editor.clone(),
2009 removed_line_block_ids: HashSet::default(),
2010 end_block_id,
2011 }),
2012 range,
2013 codegen: codegen.clone(),
2014 workspace: workspace.clone(),
2015 _subscriptions: vec![
2016 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2017 InlineAssistant::update_global(cx, |this, cx| {
2018 this.handle_prompt_editor_focus_in(assist_id, cx)
2019 })
2020 }),
2021 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2022 InlineAssistant::update_global(cx, |this, cx| {
2023 this.handle_prompt_editor_focus_out(assist_id, cx)
2024 })
2025 }),
2026 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2027 InlineAssistant::update_global(cx, |this, cx| {
2028 this.handle_prompt_editor_event(prompt_editor, event, cx)
2029 })
2030 }),
2031 cx.observe(&codegen, {
2032 let editor = editor.downgrade();
2033 move |_, cx| {
2034 if let Some(editor) = editor.upgrade() {
2035 InlineAssistant::update_global(cx, |this, cx| {
2036 if let Some(editor_assists) =
2037 this.assists_by_editor.get(&editor.downgrade())
2038 {
2039 editor_assists.highlight_updates.send(()).ok();
2040 }
2041
2042 this.update_editor_blocks(&editor, assist_id, cx);
2043 })
2044 }
2045 }
2046 }),
2047 cx.subscribe(&codegen, move |codegen, event, cx| {
2048 InlineAssistant::update_global(cx, |this, cx| match event {
2049 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2050 CodegenEvent::Finished => {
2051 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2052 assist
2053 } else {
2054 return;
2055 };
2056
2057 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
2058 if assist.decorations.is_none() {
2059 if let Some(workspace) = assist
2060 .workspace
2061 .as_ref()
2062 .and_then(|workspace| workspace.upgrade())
2063 {
2064 let error = format!("Inline assistant error: {}", error);
2065 workspace.update(cx, |workspace, cx| {
2066 struct InlineAssistantError;
2067
2068 let id =
2069 NotificationId::identified::<InlineAssistantError>(
2070 assist_id.0,
2071 );
2072
2073 workspace.show_toast(Toast::new(id, error), cx);
2074 })
2075 }
2076 }
2077 }
2078
2079 if assist.decorations.is_none() {
2080 this.finish_assist(assist_id, false, cx);
2081 } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2082 tx.0.send(()).ok();
2083 }
2084 }
2085 })
2086 }),
2087 ],
2088 }
2089 }
2090
2091 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2092 let decorations = self.decorations.as_ref()?;
2093 Some(decorations.prompt_editor.read(cx).prompt(cx))
2094 }
2095
2096 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2097 if self.include_context {
2098 let workspace = self.workspace.as_ref()?;
2099 let workspace = workspace.upgrade()?.read(cx);
2100 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2101 Some(
2102 assistant_panel
2103 .read(cx)
2104 .active_context(cx)?
2105 .read(cx)
2106 .to_completion_request(cx),
2107 )
2108 } else {
2109 None
2110 }
2111 }
2112
2113 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2114 let Some(user_prompt) = self.user_prompt(cx) else {
2115 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2116 };
2117 let assistant_panel_context = self.assistant_panel_context(cx);
2118 self.codegen.read(cx).count_tokens(
2119 self.range.clone(),
2120 user_prompt,
2121 assistant_panel_context,
2122 cx,
2123 )
2124 }
2125}
2126
/// Blocks an inline assist adds to its host editor.
struct InlineAssistDecorations {
    // Block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    prompt_editor: View<PromptEditor>,
    // Blocks showing deleted lines; rebuilt by `update_editor_blocks`.
    removed_line_block_ids: HashSet<CustomBlockId>,
    // NOTE(review): presumably the block marking the end of the assist range — confirm.
    end_block_id: CustomBlockId,
}
2133
/// Events emitted by `Codegen`.
///
/// The enum is a payload-free public value type, so it is also `Copy` and
/// comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CodegenEvent {
    /// The generation finished (successfully or with an error).
    Finished,
    /// The generated transformation was undone.
    Undone,
}
2139
/// Generates and applies the transformation for one inline assist.
///
/// Field order is significant: fields are dropped in declaration order.
pub struct Codegen {
    // Multi-buffer containing the transformed range; its events are observed
    // through `_subscription`.
    buffer: Model<MultiBuffer>,
    // Standalone copy of the original buffer (text, line ending, language),
    // taken in `new`; used to display deleted lines.
    old_buffer: Model<Buffer>,
    // Snapshot of `buffer` taken when this `Codegen` was created.
    snapshot: MultiBufferSnapshot,
    // NOTE(review): written outside this view; presumably tracks where streamed
    // edits are being applied — confirm.
    edit_position: Option<Anchor>,
    // NOTE(review): written outside this view; presumably ranges left unchanged
    // by the last diff — confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    // Transaction id supplied by the caller at construction.
    initial_transaction_id: Option<TransactionId>,
    // Presumably the transaction holding the generated edits; cleared when that
    // transaction is undone.
    transformation_transaction_id: Option<TransactionId>,
    // Current progress; drives the prompt editor's UI.
    status: CodegenStatus,
    // In-flight generation task; replaced with a ready task on undo.
    generation: Task<()>,
    // Row-level diff used for highlights and deleted-line blocks.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    // Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    // Prompt builder supplied at construction.
    builder: Arc<PromptBuilder>,
}
2155
2156enum CodegenStatus {
2157 Idle,
2158 Pending,
2159 Done,
2160 Error(anyhow::Error),
2161}
2162
2163#[derive(Default)]
2164struct Diff {
2165 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2166 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2167}
2168
2169impl Diff {
2170 fn is_empty(&self) -> bool {
2171 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2172 }
2173}
2174
/// Lets `Codegen` emit [`CodegenEvent`]s (`Finished`, `Undone`) to its subscribers.
impl EventEmitter<CodegenEvent> for Codegen {}
2176
2177impl Codegen {
2178 pub fn new(
2179 buffer: Model<MultiBuffer>,
2180 range: Range<Anchor>,
2181 initial_transaction_id: Option<TransactionId>,
2182 telemetry: Option<Arc<Telemetry>>,
2183 builder: Arc<PromptBuilder>,
2184 cx: &mut ModelContext<Self>,
2185 ) -> Self {
2186 let snapshot = buffer.read(cx).snapshot(cx);
2187
2188 let (old_buffer, _, _) = buffer
2189 .read(cx)
2190 .range_to_buffer_ranges(range.clone(), cx)
2191 .pop()
2192 .unwrap();
2193 let old_buffer = cx.new_model(|cx| {
2194 let old_buffer = old_buffer.read(cx);
2195 let text = old_buffer.as_rope().clone();
2196 let line_ending = old_buffer.line_ending();
2197 let language = old_buffer.language().cloned();
2198 let language_registry = old_buffer.language_registry();
2199
2200 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2201 buffer.set_language(language, cx);
2202 if let Some(language_registry) = language_registry {
2203 buffer.set_language_registry(language_registry)
2204 }
2205 buffer
2206 });
2207
2208 Self {
2209 buffer: buffer.clone(),
2210 old_buffer,
2211 edit_position: None,
2212 snapshot,
2213 last_equal_ranges: Default::default(),
2214 transformation_transaction_id: None,
2215 status: CodegenStatus::Idle,
2216 generation: Task::ready(()),
2217 diff: Diff::default(),
2218 telemetry,
2219 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2220 initial_transaction_id,
2221 builder,
2222 }
2223 }
2224
2225 fn handle_buffer_event(
2226 &mut self,
2227 _buffer: Model<MultiBuffer>,
2228 event: &multi_buffer::Event,
2229 cx: &mut ModelContext<Self>,
2230 ) {
2231 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2232 if self.transformation_transaction_id == Some(*transaction_id) {
2233 self.transformation_transaction_id = None;
2234 self.generation = Task::ready(());
2235 cx.emit(CodegenEvent::Undone);
2236 }
2237 }
2238 }
2239
2240 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2241 &self.last_equal_ranges
2242 }
2243
2244 pub fn count_tokens(
2245 &self,
2246 edit_range: Range<Anchor>,
2247 user_prompt: String,
2248 assistant_panel_context: Option<LanguageModelRequest>,
2249 cx: &AppContext,
2250 ) -> BoxFuture<'static, Result<TokenCounts>> {
2251 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2252 let request =
2253 self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
2254 match request {
2255 Ok(request) => {
2256 let total_count = model.count_tokens(request.clone(), cx);
2257 let assistant_panel_count = assistant_panel_context
2258 .map(|context| model.count_tokens(context, cx))
2259 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2260
2261 async move {
2262 Ok(TokenCounts {
2263 total: total_count.await?,
2264 assistant_panel: assistant_panel_count.await?,
2265 })
2266 }
2267 .boxed()
2268 }
2269 Err(error) => futures::future::ready(Err(error)).boxed(),
2270 }
2271 } else {
2272 future::ready(Err(anyhow!("no active model"))).boxed()
2273 }
2274 }
2275
2276 pub fn start(
2277 &mut self,
2278 edit_range: Range<Anchor>,
2279 user_prompt: String,
2280 assistant_panel_context: Option<LanguageModelRequest>,
2281 cx: &mut ModelContext<Self>,
2282 ) -> Result<()> {
2283 let model = LanguageModelRegistry::read_global(cx)
2284 .active_model()
2285 .context("no active model")?;
2286
2287 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2288 self.buffer.update(cx, |buffer, cx| {
2289 buffer.undo_transaction(transformation_transaction_id, cx);
2290 });
2291 }
2292
2293 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2294
2295 let telemetry_id = model.telemetry_id();
2296 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2297 .trim()
2298 .to_lowercase()
2299 == "delete"
2300 {
2301 async { Ok(stream::empty().boxed()) }.boxed_local()
2302 } else {
2303 let request =
2304 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
2305
2306 let chunks =
2307 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2308 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2309 };
2310 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2311 Ok(())
2312 }
2313
2314 fn build_request(
2315 &self,
2316 user_prompt: String,
2317 assistant_panel_context: Option<LanguageModelRequest>,
2318 edit_range: Range<Anchor>,
2319 cx: &AppContext,
2320 ) -> Result<LanguageModelRequest> {
2321 let buffer = self.buffer.read(cx).snapshot(cx);
2322 let language = buffer.language_at(edit_range.start);
2323 let language_name = if let Some(language) = language.as_ref() {
2324 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2325 None
2326 } else {
2327 Some(language.name())
2328 }
2329 } else {
2330 None
2331 };
2332
2333 // Higher Temperature increases the randomness of model outputs.
2334 // If Markdown or No Language is Known, increase the randomness for more creative output
2335 // If Code, decrease temperature to get more deterministic outputs
2336 let temperature = if let Some(language) = language_name.clone() {
2337 if language.as_ref() == "Markdown" {
2338 1.0
2339 } else {
2340 0.5
2341 }
2342 } else {
2343 1.0
2344 };
2345
2346 let language_name = language_name.as_deref();
2347 let start = buffer.point_to_buffer_offset(edit_range.start);
2348 let end = buffer.point_to_buffer_offset(edit_range.end);
2349 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2350 let (start_buffer, start_buffer_offset) = start;
2351 let (end_buffer, end_buffer_offset) = end;
2352 if start_buffer.remote_id() == end_buffer.remote_id() {
2353 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2354 } else {
2355 return Err(anyhow::anyhow!("invalid transformation range"));
2356 }
2357 } else {
2358 return Err(anyhow::anyhow!("invalid transformation range"));
2359 };
2360
2361 let prompt = self
2362 .builder
2363 .generate_content_prompt(user_prompt, language_name, buffer, range)
2364 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2365
2366 let mut messages = Vec::new();
2367 if let Some(context_request) = assistant_panel_context {
2368 messages = context_request.messages;
2369 }
2370
2371 messages.push(LanguageModelRequestMessage {
2372 role: Role::User,
2373 content: vec![prompt.into()],
2374 cache: false,
2375 });
2376
2377 Ok(LanguageModelRequest {
2378 messages,
2379 stop: vec!["|END|>".to_string()],
2380 temperature,
2381 })
2382 }
2383
    /// Applies a streamed model response to the buffer as it arrives.
    ///
    /// `stream` resolves to the response chunks. On the background executor the chunks are:
    /// - filtered through `StripInvalidSpans`;
    /// - re-indented relative to the selection's suggested indentation;
    /// - diffed against the text originally covered by `edit_range`.
    ///
    /// Each resulting batch is applied on the foreground. All assistant edits are merged into one
    /// transaction, and each batch refreshes `self.diff` through `reapply_line_based_diff`. When
    /// streaming ends, `reapply_batch_diff` recomputes the diff with a regular line diff.
    /// Telemetry, if configured, is reported with the response latency and any error message.
    ///
    /// `self.status` is `Pending` while this runs and becomes `Done` or `Error` afterwards. At that
    /// point `CodegenEvent::Finished` is emitted.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        // The original text being transformed; the "old" side of the streaming diff.
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Byte offset in the original snapshot where the next character operation applies;
        // advanced by `Delete` and `Keep` operations below.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // Bounded channel carrying (character operations, line operations) from the
                    // background diffing task to the foreground loop that edits the buffer.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current line not yet pushed into the diff.
                                let mut new_text = String::new();
                                // Leading whitespace (in bytes) of the first non-blank response
                                // line; later lines are re-indented relative to it.
                                let mut base_indent = None;
                                // Leading whitespace of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first chunk (or error).
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                // Keep the line's indentation relative to the
                                                // first line, rebased onto the suggested indent.
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line is inserted at the selection
                                                // start, so the columns already before it count.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Once the indentation is settled, stream the line's
                                        // text into the diff.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // A following segment means this line ended with '\n'.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation of an
                                                // empty line. This is the case where the block above
                                                // did not clear the buffer.
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush any remaining text and let both diffs emit their final
                                // operations.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each diff batch to the buffer as it arrives.
                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            let transaction = codegen.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    char_ops
                                        .into_iter()
                                        .filter_map(|operation| match operation {
                                            CharOperation::Insert { text } => {
                                                let edit_start = snapshot.anchor_after(edit_start);
                                                Some((edit_start..edit_start, text))
                                            }
                                            CharOperation::Delete { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                Some((edit_range, String::new()))
                                            }
                                            CharOperation::Keep { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                codegen.last_equal_ranges.push(edit_range);
                                                None
                                            }
                                        }),
                                    None,
                                    cx,
                                );
                                codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) =
                                    codegen.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    codegen.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    codegen.transformation_transaction_id = Some(transaction);
                                    codegen.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
                        codegen.reapply_batch_diff(edit_range.clone(), cx)
                    })?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                // Record the outcome; `.ok()` ignores the case where the model was released.
                codegen
                    .update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
2638
2639 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2640 self.last_equal_ranges.clear();
2641 if self.diff.is_empty() {
2642 self.status = CodegenStatus::Idle;
2643 } else {
2644 self.status = CodegenStatus::Done;
2645 }
2646 self.generation = Task::ready(());
2647 cx.emit(CodegenEvent::Finished);
2648 cx.notify();
2649 }
2650
2651 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2652 self.buffer.update(cx, |buffer, cx| {
2653 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2654 buffer.undo_transaction(transaction_id, cx);
2655 buffer.refresh_preview(cx);
2656 }
2657
2658 if let Some(transaction_id) = self.initial_transaction_id.take() {
2659 buffer.undo_transaction(transaction_id, cx);
2660 buffer.refresh_preview(cx);
2661 }
2662 });
2663 }
2664
2665 fn reapply_line_based_diff(
2666 &mut self,
2667 edit_range: Range<Anchor>,
2668 line_operations: Vec<LineOperation>,
2669 cx: &mut ModelContext<Self>,
2670 ) {
2671 let old_snapshot = self.snapshot.clone();
2672 let old_range = edit_range.to_point(&old_snapshot);
2673 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2674 let new_range = edit_range.to_point(&new_snapshot);
2675
2676 let mut old_row = old_range.start.row;
2677 let mut new_row = new_range.start.row;
2678
2679 self.diff.deleted_row_ranges.clear();
2680 self.diff.inserted_row_ranges.clear();
2681 for operation in line_operations {
2682 match operation {
2683 LineOperation::Keep { lines } => {
2684 old_row += lines;
2685 new_row += lines;
2686 }
2687 LineOperation::Delete { lines } => {
2688 let old_end_row = old_row + lines - 1;
2689 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2690
2691 if let Some((_, last_deleted_row_range)) =
2692 self.diff.deleted_row_ranges.last_mut()
2693 {
2694 if *last_deleted_row_range.end() + 1 == old_row {
2695 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2696 } else {
2697 self.diff
2698 .deleted_row_ranges
2699 .push((new_row, old_row..=old_end_row));
2700 }
2701 } else {
2702 self.diff
2703 .deleted_row_ranges
2704 .push((new_row, old_row..=old_end_row));
2705 }
2706
2707 old_row += lines;
2708 }
2709 LineOperation::Insert { lines } => {
2710 let new_end_row = new_row + lines - 1;
2711 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2712 let end = new_snapshot.anchor_before(Point::new(
2713 new_end_row,
2714 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2715 ));
2716 self.diff.inserted_row_ranges.push(start..=end);
2717 new_row += lines;
2718 }
2719 }
2720
2721 cx.notify();
2722 }
2723 }
2724
2725 fn reapply_batch_diff(
2726 &mut self,
2727 edit_range: Range<Anchor>,
2728 cx: &mut ModelContext<Self>,
2729 ) -> Task<()> {
2730 let old_snapshot = self.snapshot.clone();
2731 let old_range = edit_range.to_point(&old_snapshot);
2732 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2733 let new_range = edit_range.to_point(&new_snapshot);
2734
2735 cx.spawn(|codegen, mut cx| async move {
2736 let (deleted_row_ranges, inserted_row_ranges) = cx
2737 .background_executor()
2738 .spawn(async move {
2739 let old_text = old_snapshot
2740 .text_for_range(
2741 Point::new(old_range.start.row, 0)
2742 ..Point::new(
2743 old_range.end.row,
2744 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
2745 ),
2746 )
2747 .collect::<String>();
2748 let new_text = new_snapshot
2749 .text_for_range(
2750 Point::new(new_range.start.row, 0)
2751 ..Point::new(
2752 new_range.end.row,
2753 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
2754 ),
2755 )
2756 .collect::<String>();
2757
2758 let mut old_row = old_range.start.row;
2759 let mut new_row = new_range.start.row;
2760 let batch_diff =
2761 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
2762
2763 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
2764 let mut inserted_row_ranges = Vec::new();
2765 for change in batch_diff.iter_all_changes() {
2766 let line_count = change.value().lines().count() as u32;
2767 match change.tag() {
2768 similar::ChangeTag::Equal => {
2769 old_row += line_count;
2770 new_row += line_count;
2771 }
2772 similar::ChangeTag::Delete => {
2773 let old_end_row = old_row + line_count - 1;
2774 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2775
2776 if let Some((_, last_deleted_row_range)) =
2777 deleted_row_ranges.last_mut()
2778 {
2779 if *last_deleted_row_range.end() + 1 == old_row {
2780 *last_deleted_row_range =
2781 *last_deleted_row_range.start()..=old_end_row;
2782 } else {
2783 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2784 }
2785 } else {
2786 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2787 }
2788
2789 old_row += line_count;
2790 }
2791 similar::ChangeTag::Insert => {
2792 let new_end_row = new_row + line_count - 1;
2793 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2794 let end = new_snapshot.anchor_before(Point::new(
2795 new_end_row,
2796 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2797 ));
2798 inserted_row_ranges.push(start..=end);
2799 new_row += line_count;
2800 }
2801 }
2802 }
2803
2804 (deleted_row_ranges, inserted_row_ranges)
2805 })
2806 .await;
2807
2808 codegen
2809 .update(&mut cx, |codegen, cx| {
2810 codegen.diff.deleted_row_ranges = deleted_row_ranges;
2811 codegen.diff.inserted_row_ranges = inserted_row_ranges;
2812 cx.notify();
2813 })
2814 .ok();
2815 })
2816 }
2817}
2818
/// Stream adapter that cleans a model response before it is applied to the buffer.
///
/// Wraps a stream of text chunks. It removes a leading code fence line, a closing fence at the
/// end of a fenced response, and every `<|CURSOR|>` marker.
struct StripInvalidSpans<T> {
    /// The wrapped chunk stream.
    stream: T,
    /// Set once the wrapped stream has returned `None`.
    stream_done: bool,
    /// Received text that has not been emitted yet.
    buffer: String,
    /// Whether the first line of the response is still being processed.
    first_line: bool,
    /// Whether a newline is owed before the next emitted text; line breaks are emitted lazily so
    /// that a trailing closing fence can be dropped together with its preceding newline.
    line_end: bool,
    /// Whether the response opened with a code fence, in which case a closing fence on the last
    /// line is dropped.
    starts_with_code_block: bool,
}
2827
2828impl<T> StripInvalidSpans<T>
2829where
2830 T: Stream<Item = Result<String>>,
2831{
2832 fn new(stream: T) -> Self {
2833 Self {
2834 stream,
2835 stream_done: false,
2836 buffer: String::new(),
2837 first_line: true,
2838 line_end: false,
2839 starts_with_code_block: false,
2840 }
2841 }
2842}
2843
2844impl<T> Stream for StripInvalidSpans<T>
2845where
2846 T: Stream<Item = Result<String>>,
2847{
2848 type Item = Result<String>;
2849
2850 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2851 const CODE_BLOCK_DELIMITER: &str = "```";
2852 const CURSOR_SPAN: &str = "<|CURSOR|>";
2853
2854 let this = unsafe { self.get_unchecked_mut() };
2855 loop {
2856 if !this.stream_done {
2857 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2858 match stream.as_mut().poll_next(cx) {
2859 Poll::Ready(Some(Ok(chunk))) => {
2860 this.buffer.push_str(&chunk);
2861 }
2862 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2863 Poll::Ready(None) => {
2864 this.stream_done = true;
2865 }
2866 Poll::Pending => return Poll::Pending,
2867 }
2868 }
2869
2870 let mut chunk = String::new();
2871 let mut consumed = 0;
2872 if !this.buffer.is_empty() {
2873 let mut lines = this.buffer.split('\n').enumerate().peekable();
2874 while let Some((line_ix, line)) = lines.next() {
2875 if line_ix > 0 {
2876 this.first_line = false;
2877 }
2878
2879 if this.first_line {
2880 let trimmed_line = line.trim();
2881 if lines.peek().is_some() {
2882 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2883 consumed += line.len() + 1;
2884 this.starts_with_code_block = true;
2885 continue;
2886 }
2887 } else if trimmed_line.is_empty()
2888 || prefixes(CODE_BLOCK_DELIMITER)
2889 .any(|prefix| trimmed_line.starts_with(prefix))
2890 {
2891 break;
2892 }
2893 }
2894
2895 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2896 if lines.peek().is_some() {
2897 if this.line_end {
2898 chunk.push('\n');
2899 }
2900
2901 chunk.push_str(&line_without_cursor);
2902 this.line_end = true;
2903 consumed += line.len() + 1;
2904 } else if this.stream_done {
2905 if !this.starts_with_code_block
2906 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2907 {
2908 if this.line_end {
2909 chunk.push('\n');
2910 }
2911
2912 chunk.push_str(&line);
2913 }
2914
2915 consumed += line.len();
2916 } else {
2917 let trimmed_line = line.trim();
2918 if trimmed_line.is_empty()
2919 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2920 || prefixes(CODE_BLOCK_DELIMITER)
2921 .any(|prefix| trimmed_line.ends_with(prefix))
2922 {
2923 break;
2924 } else {
2925 if this.line_end {
2926 chunk.push('\n');
2927 this.line_end = false;
2928 }
2929
2930 chunk.push_str(&line_without_cursor);
2931 consumed += line.len();
2932 }
2933 }
2934 }
2935 }
2936
2937 this.buffer = this.buffer.split_off(consumed);
2938 if !chunk.is_empty() {
2939 return Poll::Ready(Some(Ok(chunk)));
2940 } else if this.stream_done {
2941 return Poll::Ready(None);
2942 }
2943 }
2944 }
2945}
2946
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is excluded, so `prefixes("abc")` yields `"a"` and `"ab"`. Prefixes
/// always end on `char` boundaries, which makes this safe for non-ASCII text. An empty string
/// (which previously underflowed `text.len() - 1`) yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2950
2951fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2952 ranges.sort_unstable_by(|a, b| {
2953 a.start
2954 .cmp(&b.start, buffer)
2955 .then_with(|| b.end.cmp(&a.end, buffer))
2956 });
2957
2958 let mut ix = 0;
2959 while ix + 1 < ranges.len() {
2960 let b = ranges[ix + 1].clone();
2961 let a = &mut ranges[ix];
2962 if a.end.cmp(&b.start, buffer).is_gt() {
2963 if a.end.cmp(&b.end, buffer).is_lt() {
2964 a.end = b.end;
2965 }
2966 ranges.remove(ix + 1);
2967 } else {
2968 ix += 1;
2969 }
2970 }
2971}
2972
2973#[cfg(test)]
2974mod tests {
2975 use super::*;
2976 use futures::stream::{self};
2977 use gpui::{Context, TestAppContext};
2978 use indoc::indoc;
2979 use language::{
2980 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2981 Point,
2982 };
2983 use language_model::LanguageModelRegistry;
2984 use rand::prelude::*;
2985 use serde::Serialize;
2986 use settings::SettingsStore;
2987 use std::{future, sync::Arc};
2988
2989 #[derive(Serialize)]
2990 pub struct DummyCompletionRequest {
2991 pub name: String,
2992 }
2993
2994 #[gpui::test(iterations = 10)]
2995 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
2996 cx.set_global(cx.update(SettingsStore::test));
2997 cx.update(language_model::LanguageModelRegistry::test);
2998 cx.update(language_settings::init);
2999
3000 let text = indoc! {"
3001 fn main() {
3002 let x = 0;
3003 for _ in 0..10 {
3004 x += 1;
3005 }
3006 }
3007 "};
3008 let buffer =
3009 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3010 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3011 let range = buffer.read_with(cx, |buffer, cx| {
3012 let snapshot = buffer.snapshot(cx);
3013 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
3014 });
3015 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3016 let codegen = cx.new_model(|cx| {
3017 Codegen::new(
3018 buffer.clone(),
3019 range.clone(),
3020 None,
3021 None,
3022 prompt_builder,
3023 cx,
3024 )
3025 });
3026
3027 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3028 codegen.update(cx, |codegen, cx| {
3029 codegen.handle_stream(
3030 String::new(),
3031 range,
3032 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3033 cx,
3034 )
3035 });
3036
3037 let mut new_text = concat!(
3038 " let mut x = 0;\n",
3039 " while x < 10 {\n",
3040 " x += 1;\n",
3041 " }",
3042 );
3043 while !new_text.is_empty() {
3044 let max_len = cmp::min(new_text.len(), 10);
3045 let len = rng.gen_range(1..=max_len);
3046 let (chunk, suffix) = new_text.split_at(len);
3047 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3048 new_text = suffix;
3049 cx.background_executor.run_until_parked();
3050 }
3051 drop(chunks_tx);
3052 cx.background_executor.run_until_parked();
3053
3054 assert_eq!(
3055 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3056 indoc! {"
3057 fn main() {
3058 let mut x = 0;
3059 while x < 10 {
3060 x += 1;
3061 }
3062 }
3063 "}
3064 );
3065 }
3066
3067 #[gpui::test(iterations = 10)]
3068 async fn test_autoindent_when_generating_past_indentation(
3069 cx: &mut TestAppContext,
3070 mut rng: StdRng,
3071 ) {
3072 cx.set_global(cx.update(SettingsStore::test));
3073 cx.update(language_settings::init);
3074
3075 let text = indoc! {"
3076 fn main() {
3077 le
3078 }
3079 "};
3080 let buffer =
3081 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3082 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3083 let range = buffer.read_with(cx, |buffer, cx| {
3084 let snapshot = buffer.snapshot(cx);
3085 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
3086 });
3087 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3088 let codegen = cx.new_model(|cx| {
3089 Codegen::new(
3090 buffer.clone(),
3091 range.clone(),
3092 None,
3093 None,
3094 prompt_builder,
3095 cx,
3096 )
3097 });
3098
3099 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3100 codegen.update(cx, |codegen, cx| {
3101 codegen.handle_stream(
3102 String::new(),
3103 range.clone(),
3104 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3105 cx,
3106 )
3107 });
3108
3109 cx.background_executor.run_until_parked();
3110
3111 let mut new_text = concat!(
3112 "t mut x = 0;\n",
3113 "while x < 10 {\n",
3114 " x += 1;\n",
3115 "}", //
3116 );
3117 while !new_text.is_empty() {
3118 let max_len = cmp::min(new_text.len(), 10);
3119 let len = rng.gen_range(1..=max_len);
3120 let (chunk, suffix) = new_text.split_at(len);
3121 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3122 new_text = suffix;
3123 cx.background_executor.run_until_parked();
3124 }
3125 drop(chunks_tx);
3126 cx.background_executor.run_until_parked();
3127
3128 assert_eq!(
3129 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3130 indoc! {"
3131 fn main() {
3132 let mut x = 0;
3133 while x < 10 {
3134 x += 1;
3135 }
3136 }
3137 "}
3138 );
3139 }
3140
3141 #[gpui::test(iterations = 10)]
3142 async fn test_autoindent_when_generating_before_indentation(
3143 cx: &mut TestAppContext,
3144 mut rng: StdRng,
3145 ) {
3146 cx.update(LanguageModelRegistry::test);
3147 cx.set_global(cx.update(SettingsStore::test));
3148 cx.update(language_settings::init);
3149
3150 let text = concat!(
3151 "fn main() {\n",
3152 " \n",
3153 "}\n" //
3154 );
3155 let buffer =
3156 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3157 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3158 let range = buffer.read_with(cx, |buffer, cx| {
3159 let snapshot = buffer.snapshot(cx);
3160 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
3161 });
3162 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3163 let codegen = cx.new_model(|cx| {
3164 Codegen::new(
3165 buffer.clone(),
3166 range.clone(),
3167 None,
3168 None,
3169 prompt_builder,
3170 cx,
3171 )
3172 });
3173
3174 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3175 codegen.update(cx, |codegen, cx| {
3176 codegen.handle_stream(
3177 String::new(),
3178 range.clone(),
3179 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3180 cx,
3181 )
3182 });
3183
3184 cx.background_executor.run_until_parked();
3185
3186 let mut new_text = concat!(
3187 "let mut x = 0;\n",
3188 "while x < 10 {\n",
3189 " x += 1;\n",
3190 "}", //
3191 );
3192 while !new_text.is_empty() {
3193 let max_len = cmp::min(new_text.len(), 10);
3194 let len = rng.gen_range(1..=max_len);
3195 let (chunk, suffix) = new_text.split_at(len);
3196 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3197 new_text = suffix;
3198 cx.background_executor.run_until_parked();
3199 }
3200 drop(chunks_tx);
3201 cx.background_executor.run_until_parked();
3202
3203 assert_eq!(
3204 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3205 indoc! {"
3206 fn main() {
3207 let mut x = 0;
3208 while x < 10 {
3209 x += 1;
3210 }
3211 }
3212 "}
3213 );
3214 }
3215
3216 #[gpui::test(iterations = 10)]
3217 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
3218 cx.update(LanguageModelRegistry::test);
3219 cx.set_global(cx.update(SettingsStore::test));
3220 cx.update(language_settings::init);
3221
3222 let text = indoc! {"
3223 func main() {
3224 \tx := 0
3225 \tfor i := 0; i < 10; i++ {
3226 \t\tx++
3227 \t}
3228 }
3229 "};
3230 let buffer = cx.new_model(|cx| Buffer::local(text, cx));
3231 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3232 let range = buffer.read_with(cx, |buffer, cx| {
3233 let snapshot = buffer.snapshot(cx);
3234 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
3235 });
3236 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3237 let codegen = cx.new_model(|cx| {
3238 Codegen::new(
3239 buffer.clone(),
3240 range.clone(),
3241 None,
3242 None,
3243 prompt_builder,
3244 cx,
3245 )
3246 });
3247
3248 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3249 codegen.update(cx, |codegen, cx| {
3250 codegen.handle_stream(
3251 String::new(),
3252 range.clone(),
3253 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3254 cx,
3255 )
3256 });
3257
3258 let new_text = concat!(
3259 "func main() {\n",
3260 "\tx := 0\n",
3261 "\tfor x < 10 {\n",
3262 "\t\tx++\n",
3263 "\t}", //
3264 );
3265 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
3266 drop(chunks_tx);
3267 cx.background_executor.run_until_parked();
3268
3269 assert_eq!(
3270 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3271 indoc! {"
3272 func main() {
3273 \tx := 0
3274 \tfor x < 10 {
3275 \t\tx++
3276 \t}
3277 }
3278 "}
3279 );
3280 }
3281
3282 #[gpui::test]
3283 async fn test_strip_invalid_spans_from_codeblock() {
3284 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3285 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3286 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3287 assert_chunks(
3288 "```html\n```js\nLorem ipsum dolor\n```\n```",
3289 "```js\nLorem ipsum dolor\n```",
3290 )
3291 .await;
3292 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3293 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3294 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3295 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3296
3297 async fn assert_chunks(text: &str, expected_text: &str) {
3298 for chunk_size in 1..=text.len() {
3299 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3300 .map(|chunk| chunk.unwrap())
3301 .collect::<String>()
3302 .await;
3303 assert_eq!(
3304 actual_text, expected_text,
3305 "failed to strip invalid spans, chunk size: {}",
3306 chunk_size
3307 );
3308 }
3309 }
3310
3311 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3312 stream::iter(
3313 text.chars()
3314 .collect::<Vec<_>>()
3315 .chunks(size)
3316 .map(|chunk| Ok(chunk.iter().collect::<String>()))
3317 .collect::<Vec<_>>(),
3318 )
3319 }
3320 }
3321
3322 fn rust_lang() -> Language {
3323 Language::new(
3324 LanguageConfig {
3325 name: "Rust".into(),
3326 matcher: LanguageMatcher {
3327 path_suffixes: vec!["rs".to_string()],
3328 ..Default::default()
3329 },
3330 ..Default::default()
3331 },
3332 Some(tree_sitter_rust::language()),
3333 )
3334 .with_indents_query(
3335 r#"
3336 (call_expression) @indent
3337 (field_expression) @indent
3338 (_ "(" ")" @end) @indent
3339 (_ "{" "}" @end) @indent
3340 "#,
3341 )
3342 .unwrap()
3343 }
3344}