1use crate::{
2 assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
3 AssistantPanel, AssistantPanelEvent, CharOperation, LineDiff, LineOperation, ModelSelector,
4 StreamingDiff,
5};
6use anyhow::{anyhow, Context as _, Result};
7use client::{telemetry::Telemetry, ErrorExt};
8use collections::{hash_map, HashMap, HashSet, VecDeque};
9use editor::{
10 actions::{MoveDown, MoveUp, SelectAll},
11 display_map::{
12 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
13 ToDisplayPoint,
14 },
15 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
16 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
17};
18use feature_flags::{FeatureFlagAppExt as _, ZedPro};
19use fs::Fs;
20use futures::{
21 channel::mpsc,
22 future::{BoxFuture, LocalBoxFuture},
23 join,
24 stream::{self, BoxStream},
25 SinkExt, Stream, StreamExt,
26};
27use gpui::{
28 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
29 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
30 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
31};
32use language::{Buffer, IndentKind, Point, Selection, TransactionId};
33use language_model::{
34 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
35};
36use multi_buffer::MultiBufferRow;
37use parking_lot::Mutex;
38use rope::Rope;
39use settings::{Settings, SettingsStore};
40use smol::future::FutureExt;
41use std::{
42 cmp,
43 future::{self, Future},
44 mem,
45 ops::{Range, RangeInclusive},
46 pin::Pin,
47 sync::Arc,
48 task::{self, Poll},
49 time::{Duration, Instant},
50};
51use terminal_view::terminal_panel::TerminalPanel;
52use theme::ThemeSettings;
53use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
54use util::{RangeExt, ResultExt};
55use workspace::{notifications::NotificationId, Toast, Workspace};
56
57pub fn init(
58 fs: Arc<dyn Fs>,
59 prompt_builder: Arc<PromptBuilder>,
60 telemetry: Arc<Telemetry>,
61 cx: &mut AppContext,
62) {
63 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
64 cx.observe_new_views(|_, cx| {
65 let workspace = cx.view().clone();
66 InlineAssistant::update_global(cx, |inline_assistant, cx| {
67 inline_assistant.register_workspace(&workspace, cx)
68 })
69 })
70 .detach();
71}
72
/// Maximum number of prompts retained in `InlineAssistant::prompt_history`.
/// When a new prompt pushes the history past this length, `start_assist` drops
/// the oldest entry.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
74
/// Global coordinator for inline assists across all editors.
///
/// Tracks every live assist, the editor each one belongs to, how assists are
/// grouped, per-assist status channels, accepted assists that can still be
/// undone, and a bounded history of submitted prompts.
pub struct InlineAssistant {
    /// Id handed to the next assist created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Assists that have not been finished yet; `finish_assist` removes entries.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, highlight-update sender),
    /// keyed by a weak handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assist groups; a linked group is unlinked and then started/finished member by member.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Watch channels on which `AssistStatus` updates are sent for observed assists.
    assist_observations: HashMap<
        InlineAssistId,
        (
            async_watch::Sender<AssistStatus>,
            async_watch::Receiver<AssistStatus>,
        ),
    >,
    /// Codegen models of accepted assists, kept so `undo_assist` can revert them.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    /// Submitted prompts, oldest first, deduplicated and capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Prompt builder handed to every `Codegen` this assistant creates.
    prompt_builder: Arc<PromptBuilder>,
    /// Telemetry sink; `new` always stores `Some`.
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle forwarded to every `PromptEditor`.
    fs: Arc<dyn Fs>,
}
94
/// Status of an inline assist as published on its observation channel.
///
/// `start_assist` sends `Started` and `stop_assist` sends `Stopped`. `Idle` and
/// `Finished` are not sent from the `InlineAssistant` methods shown here;
/// presumably they are the initial and completed states — confirm at their
/// producers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AssistStatus {
    Idle,
    Started,
    Stopped,
    Finished,
}

impl AssistStatus {
    /// Returns `true` for the terminal states, `Stopped` and `Finished`.
    pub fn is_done(&self) -> bool {
        matches!(self, Self::Stopped | Self::Finished)
    }
}
107
// Marks `InlineAssistant` as a GPUI global so it can be installed with
// `set_global` and reached through `update_global` (both used in this file).
impl Global for InlineAssistant {}
109
110impl InlineAssistant {
111 pub fn new(
112 fs: Arc<dyn Fs>,
113 prompt_builder: Arc<PromptBuilder>,
114 telemetry: Arc<Telemetry>,
115 ) -> Self {
116 Self {
117 next_assist_id: InlineAssistId::default(),
118 next_assist_group_id: InlineAssistGroupId::default(),
119 assists: HashMap::default(),
120 assists_by_editor: HashMap::default(),
121 assist_groups: HashMap::default(),
122 assist_observations: HashMap::default(),
123 confirmed_assists: HashMap::default(),
124 prompt_history: VecDeque::default(),
125 prompt_builder,
126 telemetry: Some(telemetry),
127 fs,
128 }
129 }
130
131 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
132 cx.subscribe(workspace, |_, event, cx| {
133 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
134 })
135 .detach();
136
137 let workspace = workspace.downgrade();
138 cx.observe_global::<SettingsStore>(move |cx| {
139 let Some(workspace) = workspace.upgrade() else {
140 return;
141 };
142 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
143 return;
144 };
145 let enabled = AssistantSettings::get_global(cx).enabled;
146 terminal_panel.update(cx, |terminal_panel, cx| {
147 terminal_panel.asssistant_enabled(enabled, cx)
148 });
149 })
150 .detach();
151 }
152
153 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
154 // When the user manually saves an editor, automatically accepts all finished transformations.
155 if let workspace::Event::UserSavedItem { item, .. } = event {
156 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
157 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
158 for assist_id in editor_assists.assist_ids.clone() {
159 let assist = &self.assists[&assist_id];
160 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
161 self.finish_assist(assist_id, false, cx)
162 }
163 }
164 }
165 }
166 }
167 }
168
169 pub fn assist(
170 &mut self,
171 editor: &View<Editor>,
172 workspace: Option<WeakView<Workspace>>,
173 assistant_panel: Option<&View<AssistantPanel>>,
174 initial_prompt: Option<String>,
175 cx: &mut WindowContext,
176 ) {
177 if let Some(telemetry) = self.telemetry.as_ref() {
178 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
179 telemetry.report_assistant_event(
180 None,
181 telemetry_events::AssistantKind::Inline,
182 telemetry_events::AssistantPhase::Invoked,
183 model.telemetry_id(),
184 None,
185 None,
186 );
187 }
188 }
189 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
190
191 let mut selections = Vec::<Selection<Point>>::new();
192 let mut newest_selection = None;
193 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
194 if selection.end > selection.start {
195 selection.start.column = 0;
196 // If the selection ends at the start of the line, we don't want to include it.
197 if selection.end.column == 0 {
198 selection.end.row -= 1;
199 }
200 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
201 }
202
203 if let Some(prev_selection) = selections.last_mut() {
204 if selection.start <= prev_selection.end {
205 prev_selection.end = selection.end;
206 continue;
207 }
208 }
209
210 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
211 if selection.id > latest_selection.id {
212 *latest_selection = selection.clone();
213 }
214 selections.push(selection);
215 }
216 let newest_selection = newest_selection.unwrap();
217
218 let mut codegen_ranges = Vec::new();
219 for (excerpt_id, buffer, buffer_range) in
220 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
221 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
222 }))
223 {
224 let start = Anchor {
225 buffer_id: Some(buffer.remote_id()),
226 excerpt_id,
227 text_anchor: buffer.anchor_before(buffer_range.start),
228 };
229 let end = Anchor {
230 buffer_id: Some(buffer.remote_id()),
231 excerpt_id,
232 text_anchor: buffer.anchor_after(buffer_range.end),
233 };
234 codegen_ranges.push(start..end);
235 }
236
237 let assist_group_id = self.next_assist_group_id.post_inc();
238 let prompt_buffer =
239 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
240 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
241
242 let mut assists = Vec::new();
243 let mut assist_to_focus = None;
244 for range in codegen_ranges {
245 let assist_id = self.next_assist_id.post_inc();
246 let codegen = cx.new_model(|cx| {
247 Codegen::new(
248 editor.read(cx).buffer().clone(),
249 range.clone(),
250 None,
251 self.telemetry.clone(),
252 self.prompt_builder.clone(),
253 cx,
254 )
255 });
256
257 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
258 let prompt_editor = cx.new_view(|cx| {
259 PromptEditor::new(
260 assist_id,
261 gutter_dimensions.clone(),
262 self.prompt_history.clone(),
263 prompt_buffer.clone(),
264 codegen.clone(),
265 editor,
266 assistant_panel,
267 workspace.clone(),
268 self.fs.clone(),
269 cx,
270 )
271 });
272
273 if assist_to_focus.is_none() {
274 let focus_assist = if newest_selection.reversed {
275 range.start.to_point(&snapshot) == newest_selection.start
276 } else {
277 range.end.to_point(&snapshot) == newest_selection.end
278 };
279 if focus_assist {
280 assist_to_focus = Some(assist_id);
281 }
282 }
283
284 let [prompt_block_id, end_block_id] =
285 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
286
287 assists.push((
288 assist_id,
289 range,
290 prompt_editor,
291 prompt_block_id,
292 end_block_id,
293 ));
294 }
295
296 let editor_assists = self
297 .assists_by_editor
298 .entry(editor.downgrade())
299 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
300 let mut assist_group = InlineAssistGroup::new();
301 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
302 self.assists.insert(
303 assist_id,
304 InlineAssist::new(
305 assist_id,
306 assist_group_id,
307 assistant_panel.is_some(),
308 editor,
309 &prompt_editor,
310 prompt_block_id,
311 end_block_id,
312 range,
313 prompt_editor.read(cx).codegen.clone(),
314 workspace.clone(),
315 cx,
316 ),
317 );
318 assist_group.assist_ids.push(assist_id);
319 editor_assists.assist_ids.push(assist_id);
320 }
321 self.assist_groups.insert(assist_group_id, assist_group);
322
323 if let Some(assist_id) = assist_to_focus {
324 self.focus_assist(assist_id, cx);
325 }
326 }
327
328 #[allow(clippy::too_many_arguments)]
329 pub fn suggest_assist(
330 &mut self,
331 editor: &View<Editor>,
332 mut range: Range<Anchor>,
333 initial_prompt: String,
334 initial_transaction_id: Option<TransactionId>,
335 workspace: Option<WeakView<Workspace>>,
336 assistant_panel: Option<&View<AssistantPanel>>,
337 cx: &mut WindowContext,
338 ) -> InlineAssistId {
339 let assist_group_id = self.next_assist_group_id.post_inc();
340 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
341 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
342
343 let assist_id = self.next_assist_id.post_inc();
344
345 let buffer = editor.read(cx).buffer().clone();
346 {
347 let snapshot = buffer.read(cx).read(cx);
348 range.start = range.start.bias_left(&snapshot);
349 range.end = range.end.bias_right(&snapshot);
350 }
351
352 let codegen = cx.new_model(|cx| {
353 Codegen::new(
354 editor.read(cx).buffer().clone(),
355 range.clone(),
356 initial_transaction_id,
357 self.telemetry.clone(),
358 self.prompt_builder.clone(),
359 cx,
360 )
361 });
362
363 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
364 let prompt_editor = cx.new_view(|cx| {
365 PromptEditor::new(
366 assist_id,
367 gutter_dimensions.clone(),
368 self.prompt_history.clone(),
369 prompt_buffer.clone(),
370 codegen.clone(),
371 editor,
372 assistant_panel,
373 workspace.clone(),
374 self.fs.clone(),
375 cx,
376 )
377 });
378
379 let [prompt_block_id, end_block_id] =
380 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
381
382 let editor_assists = self
383 .assists_by_editor
384 .entry(editor.downgrade())
385 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
386
387 let mut assist_group = InlineAssistGroup::new();
388 self.assists.insert(
389 assist_id,
390 InlineAssist::new(
391 assist_id,
392 assist_group_id,
393 assistant_panel.is_some(),
394 editor,
395 &prompt_editor,
396 prompt_block_id,
397 end_block_id,
398 range,
399 prompt_editor.read(cx).codegen.clone(),
400 workspace.clone(),
401 cx,
402 ),
403 );
404 assist_group.assist_ids.push(assist_id);
405 editor_assists.assist_ids.push(assist_id);
406 self.assist_groups.insert(assist_group_id, assist_group);
407 assist_id
408 }
409
410 fn insert_assist_blocks(
411 &self,
412 editor: &View<Editor>,
413 range: &Range<Anchor>,
414 prompt_editor: &View<PromptEditor>,
415 cx: &mut WindowContext,
416 ) -> [CustomBlockId; 2] {
417 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
418 prompt_editor
419 .editor
420 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
421 });
422 let assist_blocks = vec![
423 BlockProperties {
424 style: BlockStyle::Sticky,
425 position: range.start,
426 height: prompt_editor_height,
427 render: build_assist_editor_renderer(prompt_editor),
428 disposition: BlockDisposition::Above,
429 priority: 0,
430 },
431 BlockProperties {
432 style: BlockStyle::Sticky,
433 position: range.end,
434 height: 0,
435 render: Box::new(|cx| {
436 v_flex()
437 .h_full()
438 .w_full()
439 .border_t_1()
440 .border_color(cx.theme().status().info_border)
441 .into_any_element()
442 }),
443 disposition: BlockDisposition::Below,
444 priority: 0,
445 },
446 ];
447
448 editor.update(cx, |editor, cx| {
449 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
450 [block_ids[0], block_ids[1]]
451 })
452 }
453
454 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
455 let assist = &self.assists[&assist_id];
456 let Some(decorations) = assist.decorations.as_ref() else {
457 return;
458 };
459 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
460 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
461
462 assist_group.active_assist_id = Some(assist_id);
463 if assist_group.linked {
464 for assist_id in &assist_group.assist_ids {
465 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
466 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
467 prompt_editor.set_show_cursor_when_unfocused(true, cx)
468 });
469 }
470 }
471 }
472
473 assist
474 .editor
475 .update(cx, |editor, cx| {
476 let scroll_top = editor.scroll_position(cx).y;
477 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
478 let prompt_row = editor
479 .row_for_block(decorations.prompt_block_id, cx)
480 .unwrap()
481 .0 as f32;
482
483 if (scroll_top..scroll_bottom).contains(&prompt_row) {
484 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
485 assist_id,
486 distance_from_top: prompt_row - scroll_top,
487 });
488 } else {
489 editor_assists.scroll_lock = None;
490 }
491 })
492 .ok();
493 }
494
495 fn handle_prompt_editor_focus_out(
496 &mut self,
497 assist_id: InlineAssistId,
498 cx: &mut WindowContext,
499 ) {
500 let assist = &self.assists[&assist_id];
501 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
502 if assist_group.active_assist_id == Some(assist_id) {
503 assist_group.active_assist_id = None;
504 if assist_group.linked {
505 for assist_id in &assist_group.assist_ids {
506 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
507 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
508 prompt_editor.set_show_cursor_when_unfocused(false, cx)
509 });
510 }
511 }
512 }
513 }
514 }
515
516 fn handle_prompt_editor_event(
517 &mut self,
518 prompt_editor: View<PromptEditor>,
519 event: &PromptEditorEvent,
520 cx: &mut WindowContext,
521 ) {
522 let assist_id = prompt_editor.read(cx).id;
523 match event {
524 PromptEditorEvent::StartRequested => {
525 self.start_assist(assist_id, cx);
526 }
527 PromptEditorEvent::StopRequested => {
528 self.stop_assist(assist_id, cx);
529 }
530 PromptEditorEvent::ConfirmRequested => {
531 self.finish_assist(assist_id, false, cx);
532 }
533 PromptEditorEvent::CancelRequested => {
534 self.finish_assist(assist_id, true, cx);
535 }
536 PromptEditorEvent::DismissRequested => {
537 self.dismiss_assist(assist_id, cx);
538 }
539 }
540 }
541
542 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
543 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
544 return;
545 };
546
547 let editor = editor.read(cx);
548 if editor.selections.count() == 1 {
549 let selection = editor.selections.newest::<usize>(cx);
550 let buffer = editor.buffer().read(cx).snapshot(cx);
551 for assist_id in &editor_assists.assist_ids {
552 let assist = &self.assists[assist_id];
553 let assist_range = assist.range.to_offset(&buffer);
554 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
555 {
556 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
557 self.dismiss_assist(*assist_id, cx);
558 } else {
559 self.finish_assist(*assist_id, false, cx);
560 }
561
562 return;
563 }
564 }
565 }
566
567 cx.propagate();
568 }
569
570 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
571 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
572 return;
573 };
574
575 let editor = editor.read(cx);
576 if editor.selections.count() == 1 {
577 let selection = editor.selections.newest::<usize>(cx);
578 let buffer = editor.buffer().read(cx).snapshot(cx);
579 let mut closest_assist_fallback = None;
580 for assist_id in &editor_assists.assist_ids {
581 let assist = &self.assists[assist_id];
582 let assist_range = assist.range.to_offset(&buffer);
583 if assist.decorations.is_some() {
584 if assist_range.contains(&selection.start)
585 && assist_range.contains(&selection.end)
586 {
587 self.focus_assist(*assist_id, cx);
588 return;
589 } else {
590 let distance_from_selection = assist_range
591 .start
592 .abs_diff(selection.start)
593 .min(assist_range.start.abs_diff(selection.end))
594 + assist_range
595 .end
596 .abs_diff(selection.start)
597 .min(assist_range.end.abs_diff(selection.end));
598 match closest_assist_fallback {
599 Some((_, old_distance)) => {
600 if distance_from_selection < old_distance {
601 closest_assist_fallback =
602 Some((assist_id, distance_from_selection));
603 }
604 }
605 None => {
606 closest_assist_fallback = Some((assist_id, distance_from_selection))
607 }
608 }
609 }
610 }
611 }
612
613 if let Some((&assist_id, _)) = closest_assist_fallback {
614 self.focus_assist(assist_id, cx);
615 }
616 }
617
618 cx.propagate();
619 }
620
621 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
622 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
623 for assist_id in editor_assists.assist_ids.clone() {
624 self.finish_assist(assist_id, true, cx);
625 }
626 }
627 }
628
629 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
630 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
631 return;
632 };
633 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
634 return;
635 };
636 let assist = &self.assists[&scroll_lock.assist_id];
637 let Some(decorations) = assist.decorations.as_ref() else {
638 return;
639 };
640
641 editor.update(cx, |editor, cx| {
642 let scroll_position = editor.scroll_position(cx);
643 let target_scroll_top = editor
644 .row_for_block(decorations.prompt_block_id, cx)
645 .unwrap()
646 .0 as f32
647 - scroll_lock.distance_from_top;
648 if target_scroll_top != scroll_position.y {
649 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
650 }
651 });
652 }
653
654 fn handle_editor_event(
655 &mut self,
656 editor: View<Editor>,
657 event: &EditorEvent,
658 cx: &mut WindowContext,
659 ) {
660 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
661 return;
662 };
663
664 match event {
665 EditorEvent::Edited { transaction_id } => {
666 let buffer = editor.read(cx).buffer().read(cx);
667 let edited_ranges =
668 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
669 let snapshot = buffer.snapshot(cx);
670
671 for assist_id in editor_assists.assist_ids.clone() {
672 let assist = &self.assists[&assist_id];
673 if matches!(
674 assist.codegen.read(cx).status,
675 CodegenStatus::Error(_) | CodegenStatus::Done
676 ) {
677 let assist_range = assist.range.to_offset(&snapshot);
678 if edited_ranges
679 .iter()
680 .any(|range| range.overlaps(&assist_range))
681 {
682 self.finish_assist(assist_id, false, cx);
683 }
684 }
685 }
686 }
687 EditorEvent::ScrollPositionChanged { .. } => {
688 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
689 let assist = &self.assists[&scroll_lock.assist_id];
690 if let Some(decorations) = assist.decorations.as_ref() {
691 let distance_from_top = editor.update(cx, |editor, cx| {
692 let scroll_top = editor.scroll_position(cx).y;
693 let prompt_row = editor
694 .row_for_block(decorations.prompt_block_id, cx)
695 .unwrap()
696 .0 as f32;
697 prompt_row - scroll_top
698 });
699
700 if distance_from_top != scroll_lock.distance_from_top {
701 editor_assists.scroll_lock = None;
702 }
703 }
704 }
705 }
706 EditorEvent::SelectionsChanged { .. } => {
707 for assist_id in editor_assists.assist_ids.clone() {
708 let assist = &self.assists[&assist_id];
709 if let Some(decorations) = assist.decorations.as_ref() {
710 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
711 return;
712 }
713 }
714 }
715
716 editor_assists.scroll_lock = None;
717 }
718 _ => {}
719 }
720 }
721
722 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
723 if let Some(telemetry) = self.telemetry.as_ref() {
724 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
725 telemetry.report_assistant_event(
726 None,
727 telemetry_events::AssistantKind::Inline,
728 if undo {
729 telemetry_events::AssistantPhase::Rejected
730 } else {
731 telemetry_events::AssistantPhase::Accepted
732 },
733 model.telemetry_id(),
734 None,
735 None,
736 );
737 }
738 }
739 if let Some(assist) = self.assists.get(&assist_id) {
740 let assist_group_id = assist.group_id;
741 if self.assist_groups[&assist_group_id].linked {
742 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
743 self.finish_assist(assist_id, undo, cx);
744 }
745 return;
746 }
747 }
748
749 self.dismiss_assist(assist_id, cx);
750
751 if let Some(assist) = self.assists.remove(&assist_id) {
752 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
753 {
754 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
755 if entry.get().assist_ids.is_empty() {
756 entry.remove();
757 }
758 }
759
760 if let hash_map::Entry::Occupied(mut entry) =
761 self.assists_by_editor.entry(assist.editor.clone())
762 {
763 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
764 if entry.get().assist_ids.is_empty() {
765 entry.remove();
766 if let Some(editor) = assist.editor.upgrade() {
767 self.update_editor_highlights(&editor, cx);
768 }
769 } else {
770 entry.get().highlight_updates.send(()).ok();
771 }
772 }
773
774 if undo {
775 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
776 } else {
777 self.confirmed_assists.insert(assist_id, assist.codegen);
778 }
779 }
780
781 // Remove the assist from the status updates map
782 self.assist_observations.remove(&assist_id);
783 }
784
785 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
786 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
787 return false;
788 };
789 codegen.update(cx, |this, cx| this.undo(cx));
790 true
791 }
792
793 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
794 let Some(assist) = self.assists.get_mut(&assist_id) else {
795 return false;
796 };
797 let Some(editor) = assist.editor.upgrade() else {
798 return false;
799 };
800 let Some(decorations) = assist.decorations.take() else {
801 return false;
802 };
803
804 editor.update(cx, |editor, cx| {
805 let mut to_remove = decorations.removed_line_block_ids;
806 to_remove.insert(decorations.prompt_block_id);
807 to_remove.insert(decorations.end_block_id);
808 editor.remove_blocks(to_remove, None, cx);
809 });
810
811 if decorations
812 .prompt_editor
813 .focus_handle(cx)
814 .contains_focused(cx)
815 {
816 self.focus_next_assist(assist_id, cx);
817 }
818
819 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
820 if editor_assists
821 .scroll_lock
822 .as_ref()
823 .map_or(false, |lock| lock.assist_id == assist_id)
824 {
825 editor_assists.scroll_lock = None;
826 }
827 editor_assists.highlight_updates.send(()).ok();
828 }
829
830 true
831 }
832
833 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
834 let Some(assist) = self.assists.get(&assist_id) else {
835 return;
836 };
837
838 let assist_group = &self.assist_groups[&assist.group_id];
839 let assist_ix = assist_group
840 .assist_ids
841 .iter()
842 .position(|id| *id == assist_id)
843 .unwrap();
844 let assist_ids = assist_group
845 .assist_ids
846 .iter()
847 .skip(assist_ix + 1)
848 .chain(assist_group.assist_ids.iter().take(assist_ix));
849
850 for assist_id in assist_ids {
851 let assist = &self.assists[assist_id];
852 if assist.decorations.is_some() {
853 self.focus_assist(*assist_id, cx);
854 return;
855 }
856 }
857
858 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
859 }
860
861 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
862 let Some(assist) = self.assists.get(&assist_id) else {
863 return;
864 };
865
866 if let Some(decorations) = assist.decorations.as_ref() {
867 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
868 prompt_editor.editor.update(cx, |editor, cx| {
869 editor.focus(cx);
870 editor.select_all(&SelectAll, cx);
871 })
872 });
873 }
874
875 self.scroll_to_assist(assist_id, cx);
876 }
877
878 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
879 let Some(assist) = self.assists.get(&assist_id) else {
880 return;
881 };
882 let Some(editor) = assist.editor.upgrade() else {
883 return;
884 };
885
886 let position = assist.range.start;
887 editor.update(cx, |editor, cx| {
888 editor.change_selections(None, cx, |selections| {
889 selections.select_anchor_ranges([position..position])
890 });
891
892 let mut scroll_target_top;
893 let mut scroll_target_bottom;
894 if let Some(decorations) = assist.decorations.as_ref() {
895 scroll_target_top = editor
896 .row_for_block(decorations.prompt_block_id, cx)
897 .unwrap()
898 .0 as f32;
899 scroll_target_bottom = editor
900 .row_for_block(decorations.end_block_id, cx)
901 .unwrap()
902 .0 as f32;
903 } else {
904 let snapshot = editor.snapshot(cx);
905 let start_row = assist
906 .range
907 .start
908 .to_display_point(&snapshot.display_snapshot)
909 .row();
910 scroll_target_top = start_row.0 as f32;
911 scroll_target_bottom = scroll_target_top + 1.;
912 }
913 scroll_target_top -= editor.vertical_scroll_margin() as f32;
914 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
915
916 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
917 let scroll_top = editor.scroll_position(cx).y;
918 let scroll_bottom = scroll_top + height_in_lines;
919
920 if scroll_target_top < scroll_top {
921 editor.set_scroll_position(point(0., scroll_target_top), cx);
922 } else if scroll_target_bottom > scroll_bottom {
923 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
924 editor
925 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
926 } else {
927 editor.set_scroll_position(point(0., scroll_target_top), cx);
928 }
929 }
930 });
931 }
932
933 fn unlink_assist_group(
934 &mut self,
935 assist_group_id: InlineAssistGroupId,
936 cx: &mut WindowContext,
937 ) -> Vec<InlineAssistId> {
938 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
939 assist_group.linked = false;
940 for assist_id in &assist_group.assist_ids {
941 let assist = self.assists.get_mut(assist_id).unwrap();
942 if let Some(editor_decorations) = assist.decorations.as_ref() {
943 editor_decorations
944 .prompt_editor
945 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
946 }
947 }
948 assist_group.assist_ids.clone()
949 }
950
951 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
952 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
953 assist
954 } else {
955 return;
956 };
957
958 let assist_group_id = assist.group_id;
959 if self.assist_groups[&assist_group_id].linked {
960 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
961 self.start_assist(assist_id, cx);
962 }
963 return;
964 }
965
966 let Some(user_prompt) = assist.user_prompt(cx) else {
967 return;
968 };
969
970 self.prompt_history.retain(|prompt| *prompt != user_prompt);
971 self.prompt_history.push_back(user_prompt.clone());
972 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
973 self.prompt_history.pop_front();
974 }
975
976 let assistant_panel_context = assist.assistant_panel_context(cx);
977
978 assist
979 .codegen
980 .update(cx, |codegen, cx| {
981 codegen.start(
982 assist.range.clone(),
983 user_prompt,
984 assistant_panel_context,
985 cx,
986 )
987 })
988 .log_err();
989
990 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
991 tx.send(AssistStatus::Started).ok();
992 }
993 }
994
995 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
996 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
997 assist
998 } else {
999 return;
1000 };
1001
1002 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1003
1004 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
1005 tx.send(AssistStatus::Stopped).ok();
1006 }
1007 }
1008
1009 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
1010 if let Some(assist) = self.assists.get(&assist_id) {
1011 match &assist.codegen.read(cx).status {
1012 CodegenStatus::Idle => InlineAssistStatus::Idle,
1013 CodegenStatus::Pending => InlineAssistStatus::Pending,
1014 CodegenStatus::Done => InlineAssistStatus::Done,
1015 CodegenStatus::Error(_) => InlineAssistStatus::Error,
1016 }
1017 } else if self.confirmed_assists.contains_key(&assist_id) {
1018 InlineAssistStatus::Confirmed
1019 } else {
1020 InlineAssistStatus::Canceled
1021 }
1022 }
1023
1024 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1025 let mut gutter_pending_ranges = Vec::new();
1026 let mut gutter_transformed_ranges = Vec::new();
1027 let mut foreground_ranges = Vec::new();
1028 let mut inserted_row_ranges = Vec::new();
1029 let empty_assist_ids = Vec::new();
1030 let assist_ids = self
1031 .assists_by_editor
1032 .get(&editor.downgrade())
1033 .map_or(&empty_assist_ids, |editor_assists| {
1034 &editor_assists.assist_ids
1035 });
1036
1037 for assist_id in assist_ids {
1038 if let Some(assist) = self.assists.get(assist_id) {
1039 let codegen = assist.codegen.read(cx);
1040 let buffer = codegen.buffer.read(cx).read(cx);
1041 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
1042
1043 let pending_range =
1044 codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
1045 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1046 gutter_pending_ranges.push(pending_range);
1047 }
1048
1049 if let Some(edit_position) = codegen.edit_position {
1050 let edited_range = assist.range.start..edit_position;
1051 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1052 gutter_transformed_ranges.push(edited_range);
1053 }
1054 }
1055
1056 if assist.decorations.is_some() {
1057 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
1058 }
1059 }
1060 }
1061
1062 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1063 merge_ranges(&mut foreground_ranges, &snapshot);
1064 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1065 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1066 editor.update(cx, |editor, cx| {
1067 enum GutterPendingRange {}
1068 if gutter_pending_ranges.is_empty() {
1069 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1070 } else {
1071 editor.highlight_gutter::<GutterPendingRange>(
1072 &gutter_pending_ranges,
1073 |cx| cx.theme().status().info_background,
1074 cx,
1075 )
1076 }
1077
1078 enum GutterTransformedRange {}
1079 if gutter_transformed_ranges.is_empty() {
1080 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1081 } else {
1082 editor.highlight_gutter::<GutterTransformedRange>(
1083 &gutter_transformed_ranges,
1084 |cx| cx.theme().status().info,
1085 cx,
1086 )
1087 }
1088
1089 if foreground_ranges.is_empty() {
1090 editor.clear_highlights::<InlineAssist>(cx);
1091 } else {
1092 editor.highlight_text::<InlineAssist>(
1093 foreground_ranges,
1094 HighlightStyle {
1095 fade_out: Some(0.6),
1096 ..Default::default()
1097 },
1098 cx,
1099 );
1100 }
1101
1102 editor.clear_row_highlights::<InlineAssist>();
1103 for row_range in inserted_row_ranges {
1104 editor.highlight_rows::<InlineAssist>(
1105 row_range,
1106 Some(cx.theme().status().info_background),
1107 false,
1108 cx,
1109 );
1110 }
1111 });
1112 }
1113
    /// Rebuilds the editor blocks that display the lines removed by an assist's
    /// generated diff.
    ///
    /// Does nothing if `assist_id` is unknown or the assist has no decorations.
    /// Otherwise every "removed lines" block inserted by a previous call is
    /// removed. Then one new block is inserted per entry in
    /// `codegen.diff.deleted_row_ranges`, and the new block ids are stored back
    /// into the decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Copy what is needed out of the codegen up front: `codegen` borrows `cx`,
        // which `editor.update` below needs mutably.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks created by the previous update before inserting fresh ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            // `new_row` is the anchor where the block is placed; `old_row_range`
            // is an inclusive span of deleted rows in `old_snapshot`.
            for (new_row, old_row_range) in deleted_row_ranges {
                // Translate the deleted rows into an offset range of the old buffer.
                // NOTE(review): these unwraps panic if the rows are not inside
                // `old_snapshot`; presumably the diff always yields in-range rows — confirm.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only, gutterless, non-scrolling editor that shows just the
                // deleted text of `old_buffer`, with every row tinted as deleted.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // The block is as tall as the embedded editor has rows.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        // Left padding equals the host gutter width, presumably so the
                        // deleted text lines up with the main editor's text column.
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new ids so the next update can remove these blocks.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1206
1207 pub fn observe_assist(
1208 &mut self,
1209 assist_id: InlineAssistId,
1210 ) -> async_watch::Receiver<AssistStatus> {
1211 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1212 rx.clone()
1213 } else {
1214 let (tx, rx) = async_watch::channel(AssistStatus::Idle);
1215 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1216 rx
1217 }
1218 }
1219}
1220
/// Externally visible lifecycle state of an inline assist.
///
/// `Idle`, `Pending`, `Done` and `Error` mirror the assist's codegen status.
/// `Confirmed` and `Canceled` describe assists that are no longer active.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// True while the assist's codegen is still running.
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// True once the assist has been confirmed.
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// True when codegen finished successfully and the assist is still active.
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1243
/// Inline-assist state kept per editor (the assistant looks these up by the
/// editor's weak handle).
///
/// Field order is also drop order; keep the task and subscriptions last.
struct EditorInlineAssists {
    /// Assists currently attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional lock that ties this editor's scrolling to one assist.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending on this wakes `_update_highlights`, which recomputes the editor's highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Loop that re-applies highlights each time `highlight_updates` fires.
    _update_highlights: Task<Result<()>>,
    /// Keeps the editor observers, event subscriptions and action registrations alive.
    _subscriptions: Vec<gpui::Subscription>,
}

/// Records which assist an editor's scroll position is locked to.
struct InlineAssistScrollLock {
    assist_id: InlineAssistId,
    // NOTE(review): presumably the assist's distance from the top of the
    // viewport; units are not visible here — confirm against the scroll code.
    distance_from_top: f32,
}
1256
1257impl EditorInlineAssists {
1258 #[allow(clippy::too_many_arguments)]
1259 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1260 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1261 Self {
1262 assist_ids: Vec::new(),
1263 scroll_lock: None,
1264 highlight_updates: highlight_updates_tx,
1265 _update_highlights: cx.spawn(|mut cx| {
1266 let editor = editor.downgrade();
1267 async move {
1268 while let Ok(()) = highlight_updates_rx.changed().await {
1269 let editor = editor.upgrade().context("editor was dropped")?;
1270 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1271 assistant.update_editor_highlights(&editor, cx);
1272 })?;
1273 }
1274 Ok(())
1275 }
1276 }),
1277 _subscriptions: vec![
1278 cx.observe_release(editor, {
1279 let editor = editor.downgrade();
1280 |_, cx| {
1281 InlineAssistant::update_global(cx, |this, cx| {
1282 this.handle_editor_release(editor, cx);
1283 })
1284 }
1285 }),
1286 cx.observe(editor, move |editor, cx| {
1287 InlineAssistant::update_global(cx, |this, cx| {
1288 this.handle_editor_change(editor, cx)
1289 })
1290 }),
1291 cx.subscribe(editor, move |editor, event, cx| {
1292 InlineAssistant::update_global(cx, |this, cx| {
1293 this.handle_editor_event(editor, event, cx)
1294 })
1295 }),
1296 editor.update(cx, |editor, cx| {
1297 let editor_handle = cx.view().downgrade();
1298 editor.register_action(
1299 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1300 InlineAssistant::update_global(cx, |this, cx| {
1301 if let Some(editor) = editor_handle.upgrade() {
1302 this.handle_editor_newline(editor, cx)
1303 }
1304 })
1305 },
1306 )
1307 }),
1308 editor.update(cx, |editor, cx| {
1309 let editor_handle = cx.view().downgrade();
1310 editor.register_action(
1311 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1312 InlineAssistant::update_global(cx, |this, cx| {
1313 if let Some(editor) = editor_handle.upgrade() {
1314 this.handle_editor_cancel(editor, cx)
1315 }
1316 })
1317 },
1318 )
1319 }),
1320 ],
1321 }
1322 }
1323}
1324
1325struct InlineAssistGroup {
1326 assist_ids: Vec<InlineAssistId>,
1327 linked: bool,
1328 active_assist_id: Option<InlineAssistId>,
1329}
1330
1331impl InlineAssistGroup {
1332 fn new() -> Self {
1333 Self {
1334 assist_ids: Vec::new(),
1335 linked: true,
1336 active_assist_id: None,
1337 }
1338 }
1339}
1340
1341fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1342 let editor = editor.clone();
1343 Box::new(move |cx: &mut BlockContext| {
1344 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1345 editor.clone().into_any_element()
1346 })
1347}
1348
/// Identifier of a single inline assist, handed out sequentially via `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Postfix increment: advances `self` to the next id and returns the previous value.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        mem::replace(self, successor)
    }
}

/// Identifier of a group of inline assists, handed out sequentially via `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Postfix increment: advances `self` to the next id and returns the previous value.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let successor = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, successor)
    }
}
1370
/// Requests a `PromptEditor` emits in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation: the start/restart button, or
    /// `Confirm` while idle, after an error, or after editing a finished prompt.
    StartRequested,
    /// Interrupt a running transformation: the stop button, or `Cancel` while pending.
    StopRequested,
    /// Accept the result: the confirm button, or `Confirm` on an unedited `Done` state.
    ConfirmRequested,
    /// The cancel button, or `Cancel` in any non-pending state.
    CancelRequested,
    /// `Confirm` pressed while codegen is still pending.
    DismissRequested,
}

/// The inline prompt UI shown above an inline assist.
///
/// Field order is also drop order (subscriptions and tasks are fields).
struct PromptEditor {
    id: InlineAssistId,
    /// Passed to the `ModelSelector` in the prompt row.
    fs: Arc<dyn Fs>,
    /// The text editor the user types the prompt into; replaced by `unlink`.
    editor: View<Editor>,
    /// Set on every prompt edit; reset when codegen reaches `Done` or `Error`.
    edited_since_done: bool,
    /// Host editor gutter size, written by the block renderer and read during render.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Earlier prompts, navigable with `MoveUp` / `MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index of the history entry being shown, or `None` while showing `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// What the user typed before navigating into the history.
    pending_prompt: String,
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`; rebuilt whenever `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// Debounced token-count task; replaced on each `count_tokens` call.
    pending_token_count: Task<Result<()>>,
    /// Last computed token counts, if any.
    token_counts: Option<TokenCounts>,
    _token_count_subscriptions: Vec<Subscription>,
    /// Used to focus the assistant panel when the token count is clicked.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the "Out of Tokens" popover is visible.
    show_rate_limit_notice: bool,
}

/// Token usage displayed next to the prompt.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    /// Count displayed against the active model's maximum token count.
    total: usize,
    /// Tokens attributed to the assistant panel context (shown in the tooltip).
    assistant_panel: usize,
}

impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1405
impl Render for PromptEditor {
    /// Renders the prompt row. From left to right it contains: a gutter-width
    /// column with the model selector and any error indicator, the prompt text
    /// editor, then the token count and status-dependent action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The trailing buttons depend on the codegen state.
        let buttons = match status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                vec![
                    // Plain-text tooltip here: while pending, the `Cancel` action
                    // stops generation instead (see `PromptEditor::cancel`).
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // After an error, or once the prompt was edited, offer a restart;
                    // otherwise offer to confirm the finished result.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Leading column sized from the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SettingsAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator. Rate-limit errors for Zed Pro users get a
                    // clickable icon that toggles the "Out of Tokens" popover;
                    // other errors show the message in a tooltip.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1586
1587impl FocusableView for PromptEditor {
1588 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1589 self.editor.focus_handle(cx)
1590 }
1591}
1592
impl PromptEditor {
    /// Maximum number of lines the prompt editor grows to (its `AutoHeight` limit).
    const MAX_LINES: u8 = 8;

    /// Creates the prompt editor for assist `id`.
    ///
    /// `prompt_buffer` backs the text editor. Token counts are recomputed when
    /// `parent_editor`'s buffer is edited and, if given, when the assistant
    /// panel's context is edited. An initial count is started immediately.
    #[allow(clippy::too_many_arguments)] // Every argument is a distinct piece of assist state.
    fn new(
        id: InlineAssistId,
        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
        prompt_history: VecDeque<String>,
        prompt_buffer: Model<MultiBuffer>,
        codegen: Model<Codegen>,
        parent_editor: &View<Editor>,
        assistant_panel: Option<&View<AssistantPanel>>,
        workspace: Option<WeakView<Workspace>>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let prompt_editor = cx.new_view(|cx| {
            let mut editor = Editor::new(
                EditorMode::AutoHeight {
                    max_lines: Self::MAX_LINES as usize,
                },
                prompt_buffer,
                None,
                false,
                cx,
            );
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            // Since the prompt editors for all inline assistants are linked,
            // always show the cursor (even when it isn't focused) because
            // typing in one will make what you typed appear in all of them.
            editor.set_show_cursor_when_unfocused(true, cx);
            editor.set_placeholder_text("Add a prompt…", cx);
            editor
        });

        let mut token_count_subscriptions = Vec::new();
        token_count_subscriptions
            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
        if let Some(assistant_panel) = assistant_panel {
            token_count_subscriptions
                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
        }

        let mut this = Self {
            id,
            editor: prompt_editor,
            edited_since_done: false,
            gutter_dimensions,
            prompt_history,
            prompt_history_ix: None,
            pending_prompt: String::new(),
            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
            editor_subscriptions: Vec::new(),
            codegen,
            fs,
            pending_token_count: Task::ready(Ok(())),
            token_counts: None,
            _token_count_subscriptions: token_count_subscriptions,
            workspace,
            show_rate_limit_notice: false,
        };
        this.count_tokens(cx);
        this.subscribe_to_editor(cx);
        this
    }

    /// (Re)subscribes to events of the current `self.editor`, dropping older subscriptions.
    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
        self.editor_subscriptions.clear();
        self.editor_subscriptions
            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
    }

    /// Forwards the "show cursor when unfocused" setting to the text editor.
    fn set_show_cursor_when_unfocused(
        &mut self,
        show_cursor_when_unfocused: bool,
        cx: &mut ViewContext<Self>,
    ) {
        self.editor.update(cx, |editor, cx| {
            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
        });
    }

    /// Replaces the text editor with a fresh `Editor::auto_height` editor
    /// holding the current prompt text, keeping focus if it was focused.
    /// NOTE(review): presumably the new editor has its own buffer, so it no
    /// longer mirrors the other linked prompt editors — confirm.
    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
        let prompt = self.prompt(cx);
        let focus = self.editor.focus_handle(cx).contains_focused(cx);
        self.editor = cx.new_view(|cx| {
            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            editor.set_placeholder_text("Add a prompt…", cx);
            editor.set_text(prompt, cx);
            if focus {
                editor.focus(cx);
            }
            editor
        });
        self.subscribe_to_editor(cx);
    }

    /// Current prompt text.
    fn prompt(&self, cx: &AppContext) -> String {
        self.editor.read(cx).text(cx)
    }

    /// Shows or hides the "Out of Tokens" popover.
    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
        self.show_rate_limit_notice = !self.show_rate_limit_notice;
        if self.show_rate_limit_notice {
            // Focus the prompt so that blurring it later hides the popover
            // (see the `Blurred` arm of `handle_prompt_editor_events`).
            cx.focus_view(&self.editor);
        }
        cx.notify();
    }

    /// Recounts tokens whenever the parent editor's buffer is edited.
    fn handle_parent_editor_event(
        &mut self,
        _: View<Editor>,
        event: &EditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        if let EditorEvent::BufferEdited { .. } = event {
            self.count_tokens(cx);
        }
    }

    /// Recounts tokens when the assistant panel's context is edited.
    fn handle_assistant_panel_event(
        &mut self,
        _: View<AssistantPanel>,
        event: &AssistantPanelEvent,
        cx: &mut ViewContext<Self>,
    ) {
        // Irrefutable: `ContextEdited` is matched here as the only event variant.
        let AssistantPanelEvent::ContextEdited { .. } = event;
        self.count_tokens(cx);
    }

    /// Schedules a token count for this assist after a one-second delay.
    ///
    /// Each call replaces `pending_token_count`. Dropping the previous task
    /// presumably cancels it, which debounces rapid edits — confirm against
    /// gpui `Task` semantics. Fails if the assist is no longer registered.
    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
        let assist_id = self.id;
        self.pending_token_count = cx.spawn(|this, mut cx| async move {
            cx.background_executor().timer(Duration::from_secs(1)).await;
            let token_count = cx
                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
                    let assist = inline_assistant
                        .assists
                        .get(&assist_id)
                        .context("assist not found")?;
                    anyhow::Ok(assist.count_tokens(cx))
                })??
                .await?;

            this.update(&mut cx, |this, cx| {
                this.token_counts = Some(token_count);
                cx.notify();
            })
        })
    }

    /// Reacts to events from the prompt text editor.
    fn handle_prompt_editor_events(
        &mut self,
        _: View<Editor>,
        event: &EditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            EditorEvent::Edited { .. } => {
                let prompt = self.editor.read(cx).text(cx);
                // Leave history navigation unless the text still equals the
                // history entry being shown (history navigation itself edits the text).
                if self
                    .prompt_history_ix
                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
                {
                    self.prompt_history_ix.take();
                    self.pending_prompt = prompt;
                }

                self.edited_since_done = true;
                cx.notify();
            }
            EditorEvent::BufferEdited => {
                self.count_tokens(cx);
            }
            EditorEvent::Blurred => {
                if self.show_rate_limit_notice {
                    self.show_rate_limit_notice = false;
                    cx.notify();
                }
            }
            _ => {}
        }
    }

    /// Syncs the prompt editor with the codegen status: read-only while pending,
    /// editable otherwise. On `Done`/`Error` it resets `edited_since_done`, and
    /// it may open the rate-limit notice for Zed Pro users who did not dismiss it.
    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
            CodegenStatus::Pending => {
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(true));
            }
            CodegenStatus::Done => {
                self.edited_since_done = false;
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
            CodegenStatus::Error(error) => {
                if cx.has_flag::<ZedPro>()
                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
                    && !dismissed_rate_limit_notice()
                {
                    self.show_rate_limit_notice = true;
                    cx.notify();
                }

                self.edited_since_done = false;
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
        }
    }

    /// `Cancel` action: stops generation while pending, otherwise requests cancellation.
    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
                cx.emit(PromptEditorEvent::CancelRequested);
            }
            CodegenStatus::Pending => {
                cx.emit(PromptEditorEvent::StopRequested);
            }
        }
    }

    /// `Confirm` action: start when idle, dismiss while pending, restart after an
    /// error or after editing a finished prompt, and otherwise confirm the result.
    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                cx.emit(PromptEditorEvent::StartRequested);
            }
            CodegenStatus::Pending => {
                cx.emit(PromptEditorEvent::DismissRequested);
            }
            CodegenStatus::Done => {
                if self.edited_since_done {
                    cx.emit(PromptEditorEvent::StartRequested);
                } else {
                    cx.emit(PromptEditorEvent::ConfirmRequested);
                }
            }
            CodegenStatus::Error(_) => {
                cx.emit(PromptEditorEvent::StartRequested);
            }
        }
    }

    /// Shows the previous (older) history entry, starting from the newest one.
    ///
    /// `prompt_history_ix` is updated before `set_text`, so the resulting
    /// `Edited` event sees matching text and keeps the history index.
    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
        if let Some(ix) = self.prompt_history_ix {
            if ix > 0 {
                self.prompt_history_ix = Some(ix - 1);
                let prompt = self.prompt_history[ix - 1].as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_beginning(&Default::default(), cx);
                });
            }
        } else if !self.prompt_history.is_empty() {
            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
            self.editor.update(cx, |editor, cx| {
                editor.set_text(prompt, cx);
                editor.move_to_beginning(&Default::default(), cx);
            });
        }
    }

    /// Shows the next (newer) history entry. Past the newest entry it restores
    /// the prompt typed before navigating. Does nothing when not navigating history.
    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
        if let Some(ix) = self.prompt_history_ix {
            // `prompt_history_ix` is only `Some` for a non-empty history, so `len() - 1`
            // cannot underflow.
            if ix < self.prompt_history.len() - 1 {
                self.prompt_history_ix = Some(ix + 1);
                let prompt = self.prompt_history[ix + 1].as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_end(&Default::default(), cx)
                });
            } else {
                self.prompt_history_ix = None;
                let prompt = self.pending_prompt.as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_end(&Default::default(), cx)
                });
            }
        }
    }

    /// Renders "used / max" tokens for the active model.
    ///
    /// The count is red at or over the limit, yellow at 80% or more, and muted
    /// otherwise. Returns `None` if there is no active model or no count yet.
    /// With a workspace, clicking focuses the assistant panel.
    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
        let model = LanguageModelRegistry::read_global(cx).active_model()?;
        let token_counts = self.token_counts?;
        let max_token_count = model.max_token_count();

        let remaining_tokens = max_token_count as isize - token_counts.total as isize;
        let token_count_color = if remaining_tokens <= 0 {
            Color::Error
        } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
            Color::Warning
        } else {
            Color::Muted
        };

        let mut token_count = h_flex()
            .id("token_count")
            .gap_0p5()
            .child(
                Label::new(humanize_token_count(token_counts.total))
                    .size(LabelSize::Small)
                    .color(token_count_color),
            )
            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
            .child(
                Label::new(humanize_token_count(max_token_count))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
            );
        if let Some(workspace) = self.workspace.clone() {
            token_count = token_count
                .tooltip(move |cx| {
                    Tooltip::with_meta(
                        format!(
                            "Tokens Used ({} from the Assistant Panel)",
                            humanize_token_count(token_counts.assistant_panel)
                        ),
                        None,
                        "Click to open the Assistant Panel",
                        cx,
                    )
                })
                .cursor_pointer()
                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
                .on_click(move |_, cx| {
                    cx.stop_propagation();
                    workspace
                        .update(cx, |workspace, cx| {
                            workspace.focus_panel::<AssistantPanel>(cx)
                        })
                        .ok();
                });
        } else {
            token_count = token_count
                .cursor_default()
                .tooltip(|cx| Tooltip::text("Tokens used", cx));
        }

        Some(token_count)
    }

    /// Renders the prompt text editor with the buffer font; text uses the
    /// disabled color while the editor is read-only.
    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: if self.editor.read(cx).read_only(cx) {
                cx.theme().colors().text_disabled
            } else {
                cx.theme().colors().text
            },
            font_family: settings.buffer_font.family.clone(),
            font_fallbacks: settings.buffer_font.fallbacks.clone(),
            font_size: settings.buffer_font_size.into(),
            font_weight: settings.buffer_font.weight,
            line_height: relative(settings.buffer_line_height.value()),
            ..Default::default()
        };
        EditorElement::new(
            &self.editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    /// Renders the "Out of Tokens" popover. It has a "Don't show again"
    /// checkbox persisted via `set_rate_limit_notice_dismissed`, a Dismiss
    /// button, and a "More Info" button that opens account settings.
    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        Popover::new().child(
            v_flex()
                .occlude()
                .p_2()
                .child(
                    Label::new("Out of Tokens")
                        .size(LabelSize::Small)
                        .weight(FontWeight::BOLD),
                )
                .child(Label::new(
                    "Try Zed Pro for higher limits, a wider range of models, and more.",
                ))
                .child(
                    h_flex()
                        .justify_between()
                        .child(CheckboxWithLabel::new(
                            "dont-show-again",
                            Label::new("Don't show again"),
                            if dismissed_rate_limit_notice() {
                                ui::Selection::Selected
                            } else {
                                ui::Selection::Unselected
                            },
                            |selection, cx| {
                                let is_dismissed = match selection {
                                    ui::Selection::Unselected => false,
                                    ui::Selection::Indeterminate => return,
                                    ui::Selection::Selected => true,
                                };

                                set_rate_limit_notice_dismissed(is_dismissed, cx)
                            },
                        ))
                        .child(
                            h_flex()
                                .gap_2()
                                .child(
                                    Button::new("dismiss", "Dismiss")
                                        .style(ButtonStyle::Transparent)
                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                )
                                .child(Button::new("more-info", "More Info").on_click(
                                    |_event, cx| {
                                        cx.dispatch_action(Box::new(
                                            zed_actions::OpenAccountSettings,
                                        ))
                                    },
                                )),
                        ),
                ),
        )
    }
}
2021
2022const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2023
2024fn dismissed_rate_limit_notice() -> bool {
2025 db::kvp::KEY_VALUE_STORE
2026 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2027 .log_err()
2028 .map_or(false, |s| s.is_some())
2029}
2030
2031fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2032 db::write_and_log(cx, move || async move {
2033 if is_dismissed {
2034 db::kvp::KEY_VALUE_STORE
2035 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2036 .await
2037 } else {
2038 db::kvp::KEY_VALUE_STORE
2039 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2040 .await
2041 }
2042 })
2043}
2044
/// A single inline assist: the target range in an editor, its codegen, and its UI.
///
/// Field order is also drop order (`_subscriptions` is a field).
struct InlineAssist {
    group_id: InlineAssistGroupId,
    /// Range of the editor's buffer being transformed.
    range: Range<Anchor>,
    editor: WeakView<Editor>,
    /// Prompt editor and blocks. `None` marks an assist without prompt UI
    /// (its errors are shown as workspace toasts instead).
    decorations: Option<InlineAssistDecorations>,
    codegen: Model<Codegen>,
    _subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
2055
2056impl InlineAssist {
2057 #[allow(clippy::too_many_arguments)]
2058 fn new(
2059 assist_id: InlineAssistId,
2060 group_id: InlineAssistGroupId,
2061 include_context: bool,
2062 editor: &View<Editor>,
2063 prompt_editor: &View<PromptEditor>,
2064 prompt_block_id: CustomBlockId,
2065 end_block_id: CustomBlockId,
2066 range: Range<Anchor>,
2067 codegen: Model<Codegen>,
2068 workspace: Option<WeakView<Workspace>>,
2069 cx: &mut WindowContext,
2070 ) -> Self {
2071 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2072 InlineAssist {
2073 group_id,
2074 include_context,
2075 editor: editor.downgrade(),
2076 decorations: Some(InlineAssistDecorations {
2077 prompt_block_id,
2078 prompt_editor: prompt_editor.clone(),
2079 removed_line_block_ids: HashSet::default(),
2080 end_block_id,
2081 }),
2082 range,
2083 codegen: codegen.clone(),
2084 workspace: workspace.clone(),
2085 _subscriptions: vec![
2086 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2087 InlineAssistant::update_global(cx, |this, cx| {
2088 this.handle_prompt_editor_focus_in(assist_id, cx)
2089 })
2090 }),
2091 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2092 InlineAssistant::update_global(cx, |this, cx| {
2093 this.handle_prompt_editor_focus_out(assist_id, cx)
2094 })
2095 }),
2096 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2097 InlineAssistant::update_global(cx, |this, cx| {
2098 this.handle_prompt_editor_event(prompt_editor, event, cx)
2099 })
2100 }),
2101 cx.observe(&codegen, {
2102 let editor = editor.downgrade();
2103 move |_, cx| {
2104 if let Some(editor) = editor.upgrade() {
2105 InlineAssistant::update_global(cx, |this, cx| {
2106 if let Some(editor_assists) =
2107 this.assists_by_editor.get(&editor.downgrade())
2108 {
2109 editor_assists.highlight_updates.send(()).ok();
2110 }
2111
2112 this.update_editor_blocks(&editor, assist_id, cx);
2113 })
2114 }
2115 }
2116 }),
2117 cx.subscribe(&codegen, move |codegen, event, cx| {
2118 InlineAssistant::update_global(cx, |this, cx| match event {
2119 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2120 CodegenEvent::Finished => {
2121 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2122 assist
2123 } else {
2124 return;
2125 };
2126
2127 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
2128 if assist.decorations.is_none() {
2129 if let Some(workspace) = assist
2130 .workspace
2131 .as_ref()
2132 .and_then(|workspace| workspace.upgrade())
2133 {
2134 let error = format!("Inline assistant error: {}", error);
2135 workspace.update(cx, |workspace, cx| {
2136 struct InlineAssistantError;
2137
2138 let id =
2139 NotificationId::identified::<InlineAssistantError>(
2140 assist_id.0,
2141 );
2142
2143 workspace.show_toast(Toast::new(id, error), cx);
2144 })
2145 }
2146 }
2147 }
2148
2149 if assist.decorations.is_none() {
2150 this.finish_assist(assist_id, false, cx);
2151 } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2152 tx.0.send(AssistStatus::Finished).ok();
2153 }
2154 }
2155 })
2156 }),
2157 ],
2158 }
2159 }
2160
2161 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2162 let decorations = self.decorations.as_ref()?;
2163 Some(decorations.prompt_editor.read(cx).prompt(cx))
2164 }
2165
2166 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2167 if self.include_context {
2168 let workspace = self.workspace.as_ref()?;
2169 let workspace = workspace.upgrade()?.read(cx);
2170 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2171 Some(
2172 assistant_panel
2173 .read(cx)
2174 .active_context(cx)?
2175 .read(cx)
2176 .to_completion_request(cx),
2177 )
2178 } else {
2179 None
2180 }
2181 }
2182
2183 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2184 let Some(user_prompt) = self.user_prompt(cx) else {
2185 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2186 };
2187 let assistant_panel_context = self.assistant_panel_context(cx);
2188 self.codegen.read(cx).count_tokens(
2189 self.range.clone(),
2190 user_prompt,
2191 assistant_panel_context,
2192 cx,
2193 )
2194 }
2195}
2196
/// Editor decorations owned by an inline assist.
///
/// An assist without decorations has no prompt editor, and therefore no user prompt.
struct InlineAssistDecorations {
    /// Id of the custom block hosting the prompt editor (presumably rendered above the range).
    prompt_block_id: CustomBlockId,
    /// Editor in which the user types the assist prompt.
    prompt_editor: View<PromptEditor>,
    /// Ids of custom blocks presumably used to show lines removed by the generation.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the custom block presumably marking the end of the assist range.
    end_block_id: CustomBlockId,
}
2203
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended: it completed, failed, or was stopped via `Codegen::stop`.
    Finished,
    /// The buffer transaction containing the generated edits was undone.
    Undone,
}
2209
/// Drives one inline-assist generation over a multi-buffer.
///
/// It streams model output into the buffer, groups the resulting edits into a single
/// transaction, and records a line-level diff of what changed.
pub struct Codegen {
    /// Multi-buffer that generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the edited buffer as it was when this `Codegen` was created.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction; the baseline for edits and diffs.
    snapshot: MultiBufferSnapshot,
    /// Position reached by the streamed edits so far, once a generation has started.
    edit_position: Option<Anchor>,
    /// Ranges left unchanged ("kept") by the most recently applied batch of edits.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Caller-supplied transaction that `undo` reverts after the generated edits.
    initial_transaction_id: Option<TransactionId>,
    /// Transaction into which all edits of the current generation are merged.
    transformation_transaction_id: Option<TransactionId>,
    /// Current lifecycle state.
    status: CodegenStatus,
    /// Task running the current generation; replaced with a ready task to stop it.
    generation: Task<()>,
    /// Deleted and inserted rows of the current generation.
    diff: Diff,
    /// Destination for response telemetry, if any.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to notice when the generated edits are undone.
    _subscription: gpui::Subscription,
    /// Builds the content prompt sent to the model.
    builder: Arc<PromptBuilder>,
}
2225
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// Initial state, or the generation was stopped before any diff was recorded.
    Idle,
    /// A generation is in progress.
    Pending,
    /// Generation finished, or was stopped after recording changes.
    Done,
    /// Generation failed with this error.
    Error(anyhow::Error),
}
2232
/// Line-level changes between the baseline snapshot and the generated text.
#[derive(Default)]
struct Diff {
    /// Runs of deleted lines. Each entry holds an anchor in the new text (the start of the
    /// row where the deletion happened) and the deleted row numbers in the baseline snapshot.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted lines, as anchors from the start of the first inserted row to the end of the
    /// last inserted row in the new text.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2238
2239impl Diff {
2240 fn is_empty(&self) -> bool {
2241 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2242 }
2243}
2244
// Lets `Codegen` emit `CodegenEvent`s (`Finished`, `Undone`) to subscribers via `cx.emit`.
impl EventEmitter<CodegenEvent> for Codegen {}
2246
2247impl Codegen {
2248 pub fn new(
2249 buffer: Model<MultiBuffer>,
2250 range: Range<Anchor>,
2251 initial_transaction_id: Option<TransactionId>,
2252 telemetry: Option<Arc<Telemetry>>,
2253 builder: Arc<PromptBuilder>,
2254 cx: &mut ModelContext<Self>,
2255 ) -> Self {
2256 let snapshot = buffer.read(cx).snapshot(cx);
2257
2258 let (old_buffer, _, _) = buffer
2259 .read(cx)
2260 .range_to_buffer_ranges(range.clone(), cx)
2261 .pop()
2262 .unwrap();
2263 let old_buffer = cx.new_model(|cx| {
2264 let old_buffer = old_buffer.read(cx);
2265 let text = old_buffer.as_rope().clone();
2266 let line_ending = old_buffer.line_ending();
2267 let language = old_buffer.language().cloned();
2268 let language_registry = old_buffer.language_registry();
2269
2270 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2271 buffer.set_language(language, cx);
2272 if let Some(language_registry) = language_registry {
2273 buffer.set_language_registry(language_registry)
2274 }
2275 buffer
2276 });
2277
2278 Self {
2279 buffer: buffer.clone(),
2280 old_buffer,
2281 edit_position: None,
2282 snapshot,
2283 last_equal_ranges: Default::default(),
2284 transformation_transaction_id: None,
2285 status: CodegenStatus::Idle,
2286 generation: Task::ready(()),
2287 diff: Diff::default(),
2288 telemetry,
2289 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2290 initial_transaction_id,
2291 builder,
2292 }
2293 }
2294
2295 fn handle_buffer_event(
2296 &mut self,
2297 _buffer: Model<MultiBuffer>,
2298 event: &multi_buffer::Event,
2299 cx: &mut ModelContext<Self>,
2300 ) {
2301 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2302 if self.transformation_transaction_id == Some(*transaction_id) {
2303 self.transformation_transaction_id = None;
2304 self.generation = Task::ready(());
2305 cx.emit(CodegenEvent::Undone);
2306 }
2307 }
2308 }
2309
2310 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2311 &self.last_equal_ranges
2312 }
2313
2314 pub fn count_tokens(
2315 &self,
2316 edit_range: Range<Anchor>,
2317 user_prompt: String,
2318 assistant_panel_context: Option<LanguageModelRequest>,
2319 cx: &AppContext,
2320 ) -> BoxFuture<'static, Result<TokenCounts>> {
2321 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2322 let request =
2323 self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
2324 match request {
2325 Ok(request) => {
2326 let total_count = model.count_tokens(request.clone(), cx);
2327 let assistant_panel_count = assistant_panel_context
2328 .map(|context| model.count_tokens(context, cx))
2329 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2330
2331 async move {
2332 Ok(TokenCounts {
2333 total: total_count.await?,
2334 assistant_panel: assistant_panel_count.await?,
2335 })
2336 }
2337 .boxed()
2338 }
2339 Err(error) => futures::future::ready(Err(error)).boxed(),
2340 }
2341 } else {
2342 future::ready(Err(anyhow!("no active model"))).boxed()
2343 }
2344 }
2345
2346 pub fn start(
2347 &mut self,
2348 edit_range: Range<Anchor>,
2349 user_prompt: String,
2350 assistant_panel_context: Option<LanguageModelRequest>,
2351 cx: &mut ModelContext<Self>,
2352 ) -> Result<()> {
2353 let model = LanguageModelRegistry::read_global(cx)
2354 .active_model()
2355 .context("no active model")?;
2356
2357 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2358 self.buffer.update(cx, |buffer, cx| {
2359 buffer.undo_transaction(transformation_transaction_id, cx);
2360 });
2361 }
2362
2363 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2364
2365 let telemetry_id = model.telemetry_id();
2366 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2367 .trim()
2368 .to_lowercase()
2369 == "delete"
2370 {
2371 async { Ok(stream::empty().boxed()) }.boxed_local()
2372 } else {
2373 let request =
2374 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
2375
2376 let chunks =
2377 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
2378 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2379 };
2380 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2381 Ok(())
2382 }
2383
2384 fn build_request(
2385 &self,
2386 user_prompt: String,
2387 assistant_panel_context: Option<LanguageModelRequest>,
2388 edit_range: Range<Anchor>,
2389 cx: &AppContext,
2390 ) -> Result<LanguageModelRequest> {
2391 let buffer = self.buffer.read(cx).snapshot(cx);
2392 let language = buffer.language_at(edit_range.start);
2393 let language_name = if let Some(language) = language.as_ref() {
2394 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2395 None
2396 } else {
2397 Some(language.name())
2398 }
2399 } else {
2400 None
2401 };
2402
2403 let language_name = language_name.as_ref();
2404 let start = buffer.point_to_buffer_offset(edit_range.start);
2405 let end = buffer.point_to_buffer_offset(edit_range.end);
2406 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2407 let (start_buffer, start_buffer_offset) = start;
2408 let (end_buffer, end_buffer_offset) = end;
2409 if start_buffer.remote_id() == end_buffer.remote_id() {
2410 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2411 } else {
2412 return Err(anyhow::anyhow!("invalid transformation range"));
2413 }
2414 } else {
2415 return Err(anyhow::anyhow!("invalid transformation range"));
2416 };
2417
2418 let prompt = self
2419 .builder
2420 .generate_content_prompt(user_prompt, language_name, buffer, range)
2421 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2422
2423 let mut messages = Vec::new();
2424 if let Some(context_request) = assistant_panel_context {
2425 messages = context_request.messages;
2426 }
2427
2428 messages.push(LanguageModelRequestMessage {
2429 role: Role::User,
2430 content: vec![prompt.into()],
2431 cache: false,
2432 });
2433
2434 Ok(LanguageModelRequest {
2435 messages,
2436 tools: Vec::new(),
2437 stop: Vec::new(),
2438 temperature: 1.,
2439 })
2440 }
2441
2442 pub fn handle_stream(
2443 &mut self,
2444 model_telemetry_id: String,
2445 edit_range: Range<Anchor>,
2446 stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
2447 cx: &mut ModelContext<Self>,
2448 ) {
2449 let snapshot = self.snapshot.clone();
2450 let selected_text = snapshot
2451 .text_for_range(edit_range.start..edit_range.end)
2452 .collect::<Rope>();
2453
2454 let selection_start = edit_range.start.to_point(&snapshot);
2455
2456 // Start with the indentation of the first line in the selection
2457 let mut suggested_line_indent = snapshot
2458 .suggested_indents(selection_start.row..=selection_start.row, cx)
2459 .into_values()
2460 .next()
2461 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2462
2463 // If the first line in the selection does not have indentation, check the following lines
2464 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2465 for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
2466 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2467 // Prefer tabs if a line in the selection uses tabs as indentation
2468 if line_indent.kind == IndentKind::Tab {
2469 suggested_line_indent.kind = IndentKind::Tab;
2470 break;
2471 }
2472 }
2473 }
2474
2475 let telemetry = self.telemetry.clone();
2476 self.diff = Diff::default();
2477 self.status = CodegenStatus::Pending;
2478 let mut edit_start = edit_range.start.to_offset(&snapshot);
2479 self.generation = cx.spawn(|codegen, mut cx| {
2480 async move {
2481 let chunks = stream.await;
2482 let generate = async {
2483 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2484 let line_based_stream_diff: Task<anyhow::Result<()>> =
2485 cx.background_executor().spawn(async move {
2486 let mut response_latency = None;
2487 let request_start = Instant::now();
2488 let diff = async {
2489 let chunks = StripInvalidSpans::new(chunks?);
2490 futures::pin_mut!(chunks);
2491 let mut diff = StreamingDiff::new(selected_text.to_string());
2492 let mut line_diff = LineDiff::default();
2493
2494 let mut new_text = String::new();
2495 let mut base_indent = None;
2496 let mut line_indent = None;
2497 let mut first_line = true;
2498
2499 while let Some(chunk) = chunks.next().await {
2500 if response_latency.is_none() {
2501 response_latency = Some(request_start.elapsed());
2502 }
2503 let chunk = chunk?;
2504
2505 let mut lines = chunk.split('\n').peekable();
2506 while let Some(line) = lines.next() {
2507 new_text.push_str(line);
2508 if line_indent.is_none() {
2509 if let Some(non_whitespace_ch_ix) =
2510 new_text.find(|ch: char| !ch.is_whitespace())
2511 {
2512 line_indent = Some(non_whitespace_ch_ix);
2513 base_indent = base_indent.or(line_indent);
2514
2515 let line_indent = line_indent.unwrap();
2516 let base_indent = base_indent.unwrap();
2517 let indent_delta =
2518 line_indent as i32 - base_indent as i32;
2519 let mut corrected_indent_len = cmp::max(
2520 0,
2521 suggested_line_indent.len as i32 + indent_delta,
2522 )
2523 as usize;
2524 if first_line {
2525 corrected_indent_len = corrected_indent_len
2526 .saturating_sub(
2527 selection_start.column as usize,
2528 );
2529 }
2530
2531 let indent_char = suggested_line_indent.char();
2532 let mut indent_buffer = [0; 4];
2533 let indent_str =
2534 indent_char.encode_utf8(&mut indent_buffer);
2535 new_text.replace_range(
2536 ..line_indent,
2537 &indent_str.repeat(corrected_indent_len),
2538 );
2539 }
2540 }
2541
2542 if line_indent.is_some() {
2543 let char_ops = diff.push_new(&new_text);
2544 line_diff
2545 .push_char_operations(&char_ops, &selected_text);
2546 diff_tx
2547 .send((char_ops, line_diff.line_operations()))
2548 .await?;
2549 new_text.clear();
2550 }
2551
2552 if lines.peek().is_some() {
2553 let char_ops = diff.push_new("\n");
2554 line_diff
2555 .push_char_operations(&char_ops, &selected_text);
2556 diff_tx
2557 .send((char_ops, line_diff.line_operations()))
2558 .await?;
2559 if line_indent.is_none() {
2560 // Don't write out the leading indentation in empty lines on the next line
2561 // This is the case where the above if statement didn't clear the buffer
2562 new_text.clear();
2563 }
2564 line_indent = None;
2565 first_line = false;
2566 }
2567 }
2568 }
2569
2570 let mut char_ops = diff.push_new(&new_text);
2571 char_ops.extend(diff.finish());
2572 line_diff.push_char_operations(&char_ops, &selected_text);
2573 line_diff.finish(&selected_text);
2574 diff_tx
2575 .send((char_ops, line_diff.line_operations()))
2576 .await?;
2577
2578 anyhow::Ok(())
2579 };
2580
2581 let result = diff.await;
2582
2583 let error_message =
2584 result.as_ref().err().map(|error| error.to_string());
2585 if let Some(telemetry) = telemetry {
2586 telemetry.report_assistant_event(
2587 None,
2588 telemetry_events::AssistantKind::Inline,
2589 telemetry_events::AssistantPhase::Response,
2590 model_telemetry_id,
2591 response_latency,
2592 error_message,
2593 );
2594 }
2595
2596 result?;
2597 Ok(())
2598 });
2599
2600 while let Some((char_ops, line_diff)) = diff_rx.next().await {
2601 codegen.update(&mut cx, |codegen, cx| {
2602 codegen.last_equal_ranges.clear();
2603
2604 let transaction = codegen.buffer.update(cx, |buffer, cx| {
2605 // Avoid grouping assistant edits with user edits.
2606 buffer.finalize_last_transaction(cx);
2607
2608 buffer.start_transaction(cx);
2609 buffer.edit(
2610 char_ops
2611 .into_iter()
2612 .filter_map(|operation| match operation {
2613 CharOperation::Insert { text } => {
2614 let edit_start = snapshot.anchor_after(edit_start);
2615 Some((edit_start..edit_start, text))
2616 }
2617 CharOperation::Delete { bytes } => {
2618 let edit_end = edit_start + bytes;
2619 let edit_range = snapshot.anchor_after(edit_start)
2620 ..snapshot.anchor_before(edit_end);
2621 edit_start = edit_end;
2622 Some((edit_range, String::new()))
2623 }
2624 CharOperation::Keep { bytes } => {
2625 let edit_end = edit_start + bytes;
2626 let edit_range = snapshot.anchor_after(edit_start)
2627 ..snapshot.anchor_before(edit_end);
2628 edit_start = edit_end;
2629 codegen.last_equal_ranges.push(edit_range);
2630 None
2631 }
2632 }),
2633 None,
2634 cx,
2635 );
2636 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
2637
2638 buffer.end_transaction(cx)
2639 });
2640
2641 if let Some(transaction) = transaction {
2642 if let Some(first_transaction) =
2643 codegen.transformation_transaction_id
2644 {
2645 // Group all assistant edits into the first transaction.
2646 codegen.buffer.update(cx, |buffer, cx| {
2647 buffer.merge_transactions(
2648 transaction,
2649 first_transaction,
2650 cx,
2651 )
2652 });
2653 } else {
2654 codegen.transformation_transaction_id = Some(transaction);
2655 codegen.buffer.update(cx, |buffer, cx| {
2656 buffer.finalize_last_transaction(cx)
2657 });
2658 }
2659 }
2660
2661 codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);
2662
2663 cx.notify();
2664 })?;
2665 }
2666
2667 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
2668 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
2669 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
2670 let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
2671 codegen.reapply_batch_diff(edit_range.clone(), cx)
2672 })?;
2673 let (line_based_stream_diff, ()) =
2674 join!(line_based_stream_diff, batch_diff_task);
2675 line_based_stream_diff?;
2676
2677 anyhow::Ok(())
2678 };
2679
2680 let result = generate.await;
2681 codegen
2682 .update(&mut cx, |this, cx| {
2683 this.last_equal_ranges.clear();
2684 if let Err(error) = result {
2685 this.status = CodegenStatus::Error(error);
2686 } else {
2687 this.status = CodegenStatus::Done;
2688 }
2689 cx.emit(CodegenEvent::Finished);
2690 cx.notify();
2691 })
2692 .ok();
2693 }
2694 });
2695 cx.notify();
2696 }
2697
2698 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2699 self.last_equal_ranges.clear();
2700 if self.diff.is_empty() {
2701 self.status = CodegenStatus::Idle;
2702 } else {
2703 self.status = CodegenStatus::Done;
2704 }
2705 self.generation = Task::ready(());
2706 cx.emit(CodegenEvent::Finished);
2707 cx.notify();
2708 }
2709
2710 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2711 self.buffer.update(cx, |buffer, cx| {
2712 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2713 buffer.undo_transaction(transaction_id, cx);
2714 buffer.refresh_preview(cx);
2715 }
2716
2717 if let Some(transaction_id) = self.initial_transaction_id.take() {
2718 buffer.undo_transaction(transaction_id, cx);
2719 buffer.refresh_preview(cx);
2720 }
2721 });
2722 }
2723
2724 fn reapply_line_based_diff(
2725 &mut self,
2726 edit_range: Range<Anchor>,
2727 line_operations: Vec<LineOperation>,
2728 cx: &mut ModelContext<Self>,
2729 ) {
2730 let old_snapshot = self.snapshot.clone();
2731 let old_range = edit_range.to_point(&old_snapshot);
2732 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2733 let new_range = edit_range.to_point(&new_snapshot);
2734
2735 let mut old_row = old_range.start.row;
2736 let mut new_row = new_range.start.row;
2737
2738 self.diff.deleted_row_ranges.clear();
2739 self.diff.inserted_row_ranges.clear();
2740 for operation in line_operations {
2741 match operation {
2742 LineOperation::Keep { lines } => {
2743 old_row += lines;
2744 new_row += lines;
2745 }
2746 LineOperation::Delete { lines } => {
2747 let old_end_row = old_row + lines - 1;
2748 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2749
2750 if let Some((_, last_deleted_row_range)) =
2751 self.diff.deleted_row_ranges.last_mut()
2752 {
2753 if *last_deleted_row_range.end() + 1 == old_row {
2754 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2755 } else {
2756 self.diff
2757 .deleted_row_ranges
2758 .push((new_row, old_row..=old_end_row));
2759 }
2760 } else {
2761 self.diff
2762 .deleted_row_ranges
2763 .push((new_row, old_row..=old_end_row));
2764 }
2765
2766 old_row += lines;
2767 }
2768 LineOperation::Insert { lines } => {
2769 let new_end_row = new_row + lines - 1;
2770 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2771 let end = new_snapshot.anchor_before(Point::new(
2772 new_end_row,
2773 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2774 ));
2775 self.diff.inserted_row_ranges.push(start..=end);
2776 new_row += lines;
2777 }
2778 }
2779
2780 cx.notify();
2781 }
2782 }
2783
2784 fn reapply_batch_diff(
2785 &mut self,
2786 edit_range: Range<Anchor>,
2787 cx: &mut ModelContext<Self>,
2788 ) -> Task<()> {
2789 let old_snapshot = self.snapshot.clone();
2790 let old_range = edit_range.to_point(&old_snapshot);
2791 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2792 let new_range = edit_range.to_point(&new_snapshot);
2793
2794 cx.spawn(|codegen, mut cx| async move {
2795 let (deleted_row_ranges, inserted_row_ranges) = cx
2796 .background_executor()
2797 .spawn(async move {
2798 let old_text = old_snapshot
2799 .text_for_range(
2800 Point::new(old_range.start.row, 0)
2801 ..Point::new(
2802 old_range.end.row,
2803 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
2804 ),
2805 )
2806 .collect::<String>();
2807 let new_text = new_snapshot
2808 .text_for_range(
2809 Point::new(new_range.start.row, 0)
2810 ..Point::new(
2811 new_range.end.row,
2812 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
2813 ),
2814 )
2815 .collect::<String>();
2816
2817 let mut old_row = old_range.start.row;
2818 let mut new_row = new_range.start.row;
2819 let batch_diff =
2820 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
2821
2822 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
2823 let mut inserted_row_ranges = Vec::new();
2824 for change in batch_diff.iter_all_changes() {
2825 let line_count = change.value().lines().count() as u32;
2826 match change.tag() {
2827 similar::ChangeTag::Equal => {
2828 old_row += line_count;
2829 new_row += line_count;
2830 }
2831 similar::ChangeTag::Delete => {
2832 let old_end_row = old_row + line_count - 1;
2833 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2834
2835 if let Some((_, last_deleted_row_range)) =
2836 deleted_row_ranges.last_mut()
2837 {
2838 if *last_deleted_row_range.end() + 1 == old_row {
2839 *last_deleted_row_range =
2840 *last_deleted_row_range.start()..=old_end_row;
2841 } else {
2842 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2843 }
2844 } else {
2845 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2846 }
2847
2848 old_row += line_count;
2849 }
2850 similar::ChangeTag::Insert => {
2851 let new_end_row = new_row + line_count - 1;
2852 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2853 let end = new_snapshot.anchor_before(Point::new(
2854 new_end_row,
2855 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2856 ));
2857 inserted_row_ranges.push(start..=end);
2858 new_row += line_count;
2859 }
2860 }
2861 }
2862
2863 (deleted_row_ranges, inserted_row_ranges)
2864 })
2865 .await;
2866
2867 codegen
2868 .update(&mut cx, |codegen, cx| {
2869 codegen.diff.deleted_row_ranges = deleted_row_ranges;
2870 codegen.diff.inserted_row_ranges = inserted_row_ranges;
2871 cx.notify();
2872 })
2873 .ok();
2874 })
2875 }
2876}
2877
/// Stream adapter that removes model artifacts from streamed text.
///
/// It drops a leading/trailing "```" code fence pair and every `<|CURSOR|>` marker.
struct StripInvalidSpans<T> {
    /// Inner stream of text chunks.
    stream: T,
    /// Set once the inner stream has returned `None`.
    stream_done: bool,
    /// Received text that has not been emitted or discarded yet.
    buffer: String,
    /// `true` until a newline has been seen in the input.
    first_line: bool,
    /// Whether a `'\n'` is owed before the next emitted text (newlines are emitted lazily).
    line_end: bool,
    /// Set when the first line was an opening "```" fence, so a closing fence is dropped.
    starts_with_code_block: bool,
}
2886
2887impl<T> StripInvalidSpans<T>
2888where
2889 T: Stream<Item = Result<String>>,
2890{
2891 fn new(stream: T) -> Self {
2892 Self {
2893 stream,
2894 stream_done: false,
2895 buffer: String::new(),
2896 first_line: true,
2897 line_end: false,
2898 starts_with_code_block: false,
2899 }
2900 }
2901}
2902
2903impl<T> Stream for StripInvalidSpans<T>
2904where
2905 T: Stream<Item = Result<String>>,
2906{
2907 type Item = Result<String>;
2908
2909 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2910 const CODE_BLOCK_DELIMITER: &str = "```";
2911 const CURSOR_SPAN: &str = "<|CURSOR|>";
2912
2913 let this = unsafe { self.get_unchecked_mut() };
2914 loop {
2915 if !this.stream_done {
2916 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2917 match stream.as_mut().poll_next(cx) {
2918 Poll::Ready(Some(Ok(chunk))) => {
2919 this.buffer.push_str(&chunk);
2920 }
2921 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2922 Poll::Ready(None) => {
2923 this.stream_done = true;
2924 }
2925 Poll::Pending => return Poll::Pending,
2926 }
2927 }
2928
2929 let mut chunk = String::new();
2930 let mut consumed = 0;
2931 if !this.buffer.is_empty() {
2932 let mut lines = this.buffer.split('\n').enumerate().peekable();
2933 while let Some((line_ix, line)) = lines.next() {
2934 if line_ix > 0 {
2935 this.first_line = false;
2936 }
2937
2938 if this.first_line {
2939 let trimmed_line = line.trim();
2940 if lines.peek().is_some() {
2941 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2942 consumed += line.len() + 1;
2943 this.starts_with_code_block = true;
2944 continue;
2945 }
2946 } else if trimmed_line.is_empty()
2947 || prefixes(CODE_BLOCK_DELIMITER)
2948 .any(|prefix| trimmed_line.starts_with(prefix))
2949 {
2950 break;
2951 }
2952 }
2953
2954 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2955 if lines.peek().is_some() {
2956 if this.line_end {
2957 chunk.push('\n');
2958 }
2959
2960 chunk.push_str(&line_without_cursor);
2961 this.line_end = true;
2962 consumed += line.len() + 1;
2963 } else if this.stream_done {
2964 if !this.starts_with_code_block
2965 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2966 {
2967 if this.line_end {
2968 chunk.push('\n');
2969 }
2970
2971 chunk.push_str(&line);
2972 }
2973
2974 consumed += line.len();
2975 } else {
2976 let trimmed_line = line.trim();
2977 if trimmed_line.is_empty()
2978 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2979 || prefixes(CODE_BLOCK_DELIMITER)
2980 .any(|prefix| trimmed_line.ends_with(prefix))
2981 {
2982 break;
2983 } else {
2984 if this.line_end {
2985 chunk.push('\n');
2986 this.line_end = false;
2987 }
2988
2989 chunk.push_str(&line_without_cursor);
2990 consumed += line.len();
2991 }
2992 }
2993 }
2994 }
2995
2996 this.buffer = this.buffer.split_off(consumed);
2997 if !chunk.is_empty() {
2998 return Poll::Ready(Some(Ok(chunk)));
2999 } else if this.stream_done {
3000 return Poll::Ready(None);
3001 }
3002 }
3003 }
3004}
3005
/// Yields every non-empty proper prefix of `text`, shortest first.
///
/// For example, "```" yields "`" and "``". Prefixes always end on a `char` boundary.
/// An empty or single-character `text` yields nothing, rather than panicking.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3009
3010fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3011 ranges.sort_unstable_by(|a, b| {
3012 a.start
3013 .cmp(&b.start, buffer)
3014 .then_with(|| b.end.cmp(&a.end, buffer))
3015 });
3016
3017 let mut ix = 0;
3018 while ix + 1 < ranges.len() {
3019 let b = ranges[ix + 1].clone();
3020 let a = &mut ranges[ix];
3021 if a.end.cmp(&b.start, buffer).is_gt() {
3022 if a.end.cmp(&b.end, buffer).is_lt() {
3023 a.end = b.end;
3024 }
3025 ranges.remove(ix + 1);
3026 } else {
3027 ix += 1;
3028 }
3029 }
3030}
3031
3032#[cfg(test)]
3033mod tests {
3034 use super::*;
3035 use futures::stream::{self};
3036 use gpui::{Context, TestAppContext};
3037 use indoc::indoc;
3038 use language::{
3039 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
3040 Point,
3041 };
3042 use language_model::LanguageModelRegistry;
3043 use rand::prelude::*;
3044 use serde::Serialize;
3045 use settings::SettingsStore;
3046 use std::{future, sync::Arc};
3047
3048 #[derive(Serialize)]
3049 pub struct DummyCompletionRequest {
3050 pub name: String,
3051 }
3052
3053 #[gpui::test(iterations = 10)]
3054 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
3055 cx.set_global(cx.update(SettingsStore::test));
3056 cx.update(language_model::LanguageModelRegistry::test);
3057 cx.update(language_settings::init);
3058
3059 let text = indoc! {"
3060 fn main() {
3061 let x = 0;
3062 for _ in 0..10 {
3063 x += 1;
3064 }
3065 }
3066 "};
3067 let buffer =
3068 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3069 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3070 let range = buffer.read_with(cx, |buffer, cx| {
3071 let snapshot = buffer.snapshot(cx);
3072 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
3073 });
3074 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3075 let codegen = cx.new_model(|cx| {
3076 Codegen::new(
3077 buffer.clone(),
3078 range.clone(),
3079 None,
3080 None,
3081 prompt_builder,
3082 cx,
3083 )
3084 });
3085
3086 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3087 codegen.update(cx, |codegen, cx| {
3088 codegen.handle_stream(
3089 String::new(),
3090 range,
3091 future::ready(Ok(chunks_rx.map(Ok).boxed())),
3092 cx,
3093 )
3094 });
3095
3096 let mut new_text = concat!(
3097 " let mut x = 0;\n",
3098 " while x < 10 {\n",
3099 " x += 1;\n",
3100 " }",
3101 );
3102 while !new_text.is_empty() {
3103 let max_len = cmp::min(new_text.len(), 10);
3104 let len = rng.gen_range(1..=max_len);
3105 let (chunk, suffix) = new_text.split_at(len);
3106 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3107 new_text = suffix;
3108 cx.background_executor.run_until_parked();
3109 }
3110 drop(chunks_tx);
3111 cx.background_executor.run_until_parked();
3112
3113 assert_eq!(
3114 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3115 indoc! {"
3116 fn main() {
3117 let mut x = 0;
3118 while x < 10 {
3119 x += 1;
3120 }
3121 }
3122 "}
3123 );
3124 }
3125
3126 #[gpui::test(iterations = 10)]
3127 async fn test_autoindent_when_generating_past_indentation(
3128 cx: &mut TestAppContext,
3129 mut rng: StdRng,
3130 ) {
3131 cx.set_global(cx.update(SettingsStore::test));
3132 cx.update(language_settings::init);
3133
3134 let text = indoc! {"
3135 fn main() {
3136 le
3137 }
3138 "};
3139 let buffer =
3140 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3141 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3142 let range = buffer.read_with(cx, |buffer, cx| {
3143 let snapshot = buffer.snapshot(cx);
3144 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
3145 });
3146 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3147 let codegen = cx.new_model(|cx| {
3148 Codegen::new(
3149 buffer.clone(),
3150 range.clone(),
3151 None,
3152 None,
3153 prompt_builder,
3154 cx,
3155 )
3156 });
3157
3158 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3159 codegen.update(cx, |codegen, cx| {
3160 codegen.handle_stream(
3161 String::new(),
3162 range.clone(),
3163 future::ready(Ok(chunks_rx.map(Ok).boxed())),
3164 cx,
3165 )
3166 });
3167
3168 cx.background_executor.run_until_parked();
3169
3170 let mut new_text = concat!(
3171 "t mut x = 0;\n",
3172 "while x < 10 {\n",
3173 " x += 1;\n",
3174 "}", //
3175 );
3176 while !new_text.is_empty() {
3177 let max_len = cmp::min(new_text.len(), 10);
3178 let len = rng.gen_range(1..=max_len);
3179 let (chunk, suffix) = new_text.split_at(len);
3180 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3181 new_text = suffix;
3182 cx.background_executor.run_until_parked();
3183 }
3184 drop(chunks_tx);
3185 cx.background_executor.run_until_parked();
3186
3187 assert_eq!(
3188 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3189 indoc! {"
3190 fn main() {
3191 let mut x = 0;
3192 while x < 10 {
3193 x += 1;
3194 }
3195 }
3196 "}
3197 );
3198 }
3199
3200 #[gpui::test(iterations = 10)]
3201 async fn test_autoindent_when_generating_before_indentation(
3202 cx: &mut TestAppContext,
3203 mut rng: StdRng,
3204 ) {
3205 cx.update(LanguageModelRegistry::test);
3206 cx.set_global(cx.update(SettingsStore::test));
3207 cx.update(language_settings::init);
3208
3209 let text = concat!(
3210 "fn main() {\n",
3211 " \n",
3212 "}\n" //
3213 );
3214 let buffer =
3215 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
3216 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3217 let range = buffer.read_with(cx, |buffer, cx| {
3218 let snapshot = buffer.snapshot(cx);
3219 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
3220 });
3221 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3222 let codegen = cx.new_model(|cx| {
3223 Codegen::new(
3224 buffer.clone(),
3225 range.clone(),
3226 None,
3227 None,
3228 prompt_builder,
3229 cx,
3230 )
3231 });
3232
3233 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3234 codegen.update(cx, |codegen, cx| {
3235 codegen.handle_stream(
3236 String::new(),
3237 range.clone(),
3238 future::ready(Ok(chunks_rx.map(Ok).boxed())),
3239 cx,
3240 )
3241 });
3242
3243 cx.background_executor.run_until_parked();
3244
3245 let mut new_text = concat!(
3246 "let mut x = 0;\n",
3247 "while x < 10 {\n",
3248 " x += 1;\n",
3249 "}", //
3250 );
3251 while !new_text.is_empty() {
3252 let max_len = cmp::min(new_text.len(), 10);
3253 let len = rng.gen_range(1..=max_len);
3254 let (chunk, suffix) = new_text.split_at(len);
3255 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3256 new_text = suffix;
3257 cx.background_executor.run_until_parked();
3258 }
3259 drop(chunks_tx);
3260 cx.background_executor.run_until_parked();
3261
3262 assert_eq!(
3263 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3264 indoc! {"
3265 fn main() {
3266 let mut x = 0;
3267 while x < 10 {
3268 x += 1;
3269 }
3270 }
3271 "}
3272 );
3273 }
3274
3275 #[gpui::test(iterations = 10)]
3276 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
3277 cx.update(LanguageModelRegistry::test);
3278 cx.set_global(cx.update(SettingsStore::test));
3279 cx.update(language_settings::init);
3280
3281 let text = indoc! {"
3282 func main() {
3283 \tx := 0
3284 \tfor i := 0; i < 10; i++ {
3285 \t\tx++
3286 \t}
3287 }
3288 "};
3289 let buffer = cx.new_model(|cx| Buffer::local(text, cx));
3290 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
3291 let range = buffer.read_with(cx, |buffer, cx| {
3292 let snapshot = buffer.snapshot(cx);
3293 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
3294 });
3295 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3296 let codegen = cx.new_model(|cx| {
3297 Codegen::new(
3298 buffer.clone(),
3299 range.clone(),
3300 None,
3301 None,
3302 prompt_builder,
3303 cx,
3304 )
3305 });
3306
3307 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3308 codegen.update(cx, |codegen, cx| {
3309 codegen.handle_stream(
3310 String::new(),
3311 range.clone(),
3312 future::ready(Ok(chunks_rx.map(Ok).boxed())),
3313 cx,
3314 )
3315 });
3316
3317 let new_text = concat!(
3318 "func main() {\n",
3319 "\tx := 0\n",
3320 "\tfor x < 10 {\n",
3321 "\t\tx++\n",
3322 "\t}", //
3323 );
3324 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
3325 drop(chunks_tx);
3326 cx.background_executor.run_until_parked();
3327
3328 assert_eq!(
3329 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3330 indoc! {"
3331 func main() {
3332 \tx := 0
3333 \tfor x < 10 {
3334 \t\tx++
3335 \t}
3336 }
3337 "}
3338 );
3339 }
3340
3341 #[gpui::test]
3342 async fn test_strip_invalid_spans_from_codeblock() {
3343 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
3344 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
3345 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
3346 assert_chunks(
3347 "```html\n```js\nLorem ipsum dolor\n```\n```",
3348 "```js\nLorem ipsum dolor\n```",
3349 )
3350 .await;
3351 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
3352 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
3353 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
3354 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3355
3356 async fn assert_chunks(text: &str, expected_text: &str) {
3357 for chunk_size in 1..=text.len() {
3358 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3359 .map(|chunk| chunk.unwrap())
3360 .collect::<String>()
3361 .await;
3362 assert_eq!(
3363 actual_text, expected_text,
3364 "failed to strip invalid spans, chunk size: {}",
3365 chunk_size
3366 );
3367 }
3368 }
3369
3370 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3371 stream::iter(
3372 text.chars()
3373 .collect::<Vec<_>>()
3374 .chunks(size)
3375 .map(|chunk| Ok(chunk.iter().collect::<String>()))
3376 .collect::<Vec<_>>(),
3377 )
3378 }
3379 }
3380
3381 fn rust_lang() -> Language {
3382 Language::new(
3383 LanguageConfig {
3384 name: "Rust".into(),
3385 matcher: LanguageMatcher {
3386 path_suffixes: vec!["rs".to_string()],
3387 ..Default::default()
3388 },
3389 ..Default::default()
3390 },
3391 Some(tree_sitter_rust::LANGUAGE.into()),
3392 )
3393 .with_indents_query(
3394 r#"
3395 (call_expression) @indent
3396 (field_expression) @indent
3397 (_ "(" ")" @end) @indent
3398 (_ "{" "}" @end) @indent
3399 "#,
3400 )
3401 .unwrap()
3402 }
3403}