1use crate::{
2 Assistant, AssistantPanel, AssistantPanelEvent, CycleNextInlineAssist,
3 CyclePreviousInlineAssist,
4};
5use anyhow::{Context as _, Result, anyhow};
6use assistant_context_editor::{RequestType, humanize_token_count};
7use assistant_settings::AssistantSettings;
8use client::{ErrorExt, telemetry::Telemetry};
9use collections::{HashMap, HashSet, VecDeque, hash_map};
10use editor::{
11 Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
12 EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
13 ToOffset as _, ToPoint,
14 actions::{MoveDown, MoveUp, SelectAll},
15 display_map::{
16 BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
17 ToDisplayPoint,
18 },
19};
20use feature_flags::{FeatureFlagAppExt as _, ZedProFeatureFlag};
21use fs::Fs;
22use futures::{
23 SinkExt, Stream, StreamExt, TryStreamExt as _,
24 channel::mpsc,
25 future::{BoxFuture, LocalBoxFuture},
26 join,
27};
28use gpui::{
29 AnyElement, App, ClickEvent, Context, CursorStyle, Entity, EventEmitter, FocusHandle,
30 Focusable, FontWeight, Global, HighlightStyle, Subscription, Task, TextStyle, UpdateGlobal,
31 WeakEntity, Window, anchored, deferred, point,
32};
33use language::{Buffer, IndentKind, Point, Selection, TransactionId, line_diff};
34use language_model::{
35 ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
36 LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
37};
38use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
39use multi_buffer::MultiBufferRow;
40use parking_lot::Mutex;
41use project::{CodeAction, LspAction, ProjectTransaction};
42use prompt_store::PromptBuilder;
43use rope::Rope;
44use settings::{Settings, SettingsStore, update_settings_file};
45use smol::future::FutureExt;
46use std::{
47 cmp,
48 future::{self, Future},
49 iter, mem,
50 ops::{Range, RangeInclusive},
51 pin::Pin,
52 rc::Rc,
53 sync::Arc,
54 task::{self, Poll},
55 time::{Duration, Instant},
56};
57use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
58use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
59use terminal_view::terminal_panel::TerminalPanel;
60use text::{OffsetRangeExt, ToPoint as _};
61use theme::ThemeSettings;
62use ui::{
63 CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, Tooltip, prelude::*, text_for_action,
64};
65use util::{RangeExt, ResultExt};
66use workspace::{ItemHandle, Toast, Workspace, notifications::NotificationId};
67
68pub fn init(
69 fs: Arc<dyn Fs>,
70 prompt_builder: Arc<PromptBuilder>,
71 telemetry: Arc<Telemetry>,
72 cx: &mut App,
73) {
74 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
75 // Don't register now that the Agent is released.
76 if false {
77 cx.observe_new(|_, window, cx| {
78 let Some(window) = window else {
79 return;
80 };
81 let workspace = cx.entity().clone();
82 InlineAssistant::update_global(cx, |inline_assistant, cx| {
83 inline_assistant.register_workspace(&workspace, window, cx)
84 });
85 })
86 .detach();
87 }
88}
89
/// Maximum number of prompts kept in the inline assistant's prompt history. When a new prompt
/// pushes the history past this length, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
91
/// Global bookkeeping for inline assists.
///
/// Holds per-assist state, assist groups, per-editor indexes, and the shared prompt history.
pub struct InlineAssistant {
    /// Id given to the next assist that is created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id given to the next assist group that is created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor state (assist ids, scroll lock, highlight-update channel). An editor's entry
    /// is removed once its last assist is finished.
    assists_by_editor: HashMap<WeakEntity<Editor>, EditorInlineAssists>,
    /// Assists created by the same invocation share a group. A group is removed once it is empty.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// For each assist finished without undo, its active codegen alternative.
    confirmed_assists: HashMap<InlineAssistId, Entity<CodegenAlternative>>,
    /// Recently submitted prompts, oldest first, deduplicated and capped at
    /// `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    telemetry: Arc<Telemetry>,
    fs: Arc<dyn Fs>,
}
104
// Stored as a GPUI global (see `init`) so it can be reached through `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
106
107impl InlineAssistant {
108 pub fn new(
109 fs: Arc<dyn Fs>,
110 prompt_builder: Arc<PromptBuilder>,
111 telemetry: Arc<Telemetry>,
112 ) -> Self {
113 Self {
114 next_assist_id: InlineAssistId::default(),
115 next_assist_group_id: InlineAssistGroupId::default(),
116 assists: HashMap::default(),
117 assists_by_editor: HashMap::default(),
118 assist_groups: HashMap::default(),
119 confirmed_assists: HashMap::default(),
120 prompt_history: VecDeque::default(),
121 prompt_builder,
122 telemetry,
123 fs,
124 }
125 }
126
127 pub fn register_workspace(
128 &mut self,
129 workspace: &Entity<Workspace>,
130 window: &mut Window,
131 cx: &mut App,
132 ) {
133 window
134 .subscribe(workspace, cx, |workspace, event, window, cx| {
135 Self::update_global(cx, |this, cx| {
136 this.handle_workspace_event(workspace, event, window, cx)
137 });
138 })
139 .detach();
140
141 let workspace = workspace.downgrade();
142 cx.observe_global::<SettingsStore>(move |cx| {
143 let Some(workspace) = workspace.upgrade() else {
144 return;
145 };
146 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
147 return;
148 };
149 let enabled = AssistantSettings::get_global(cx).enabled;
150 terminal_panel.update(cx, |terminal_panel, cx| {
151 terminal_panel.set_assistant_enabled(enabled, cx)
152 });
153 })
154 .detach();
155 }
156
157 fn handle_workspace_event(
158 &mut self,
159 workspace: Entity<Workspace>,
160 event: &workspace::Event,
161 window: &mut Window,
162 cx: &mut App,
163 ) {
164 match event {
165 workspace::Event::UserSavedItem { item, .. } => {
166 // When the user manually saves an editor, automatically accepts all finished transformations.
167 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
168 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
169 for assist_id in editor_assists.assist_ids.clone() {
170 let assist = &self.assists[&assist_id];
171 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
172 self.finish_assist(assist_id, false, window, cx)
173 }
174 }
175 }
176 }
177 }
178 workspace::Event::ItemAdded { item } => {
179 self.register_workspace_item(&workspace, item.as_ref(), window, cx);
180 }
181 _ => (),
182 }
183 }
184
185 fn register_workspace_item(
186 &mut self,
187 workspace: &Entity<Workspace>,
188 item: &dyn ItemHandle,
189 window: &mut Window,
190 cx: &mut App,
191 ) {
192 let is_assistant2_enabled = true;
193
194 if let Some(editor) = item.act_as::<Editor>(cx) {
195 editor.update(cx, |editor, cx| {
196 if is_assistant2_enabled {
197 editor.remove_code_action_provider(
198 ASSISTANT_CODE_ACTION_PROVIDER_ID.into(),
199 window,
200 cx,
201 );
202 } else {
203 editor.add_code_action_provider(
204 Rc::new(AssistantCodeActionProvider {
205 editor: cx.entity().downgrade(),
206 workspace: workspace.downgrade(),
207 }),
208 window,
209 cx,
210 );
211 }
212 });
213 }
214 }
215
216 pub fn assist(
217 &mut self,
218 editor: &Entity<Editor>,
219 workspace: Option<WeakEntity<Workspace>>,
220 assistant_panel: Option<&Entity<AssistantPanel>>,
221 initial_prompt: Option<String>,
222 window: &mut Window,
223 cx: &mut App,
224 ) {
225 let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
226 (
227 editor.snapshot(window, cx),
228 editor.selections.all::<Point>(cx),
229 )
230 });
231
232 let mut selections = Vec::<Selection<Point>>::new();
233 let mut newest_selection = None;
234 for mut selection in initial_selections {
235 if selection.end > selection.start {
236 selection.start.column = 0;
237 // If the selection ends at the start of the line, we don't want to include it.
238 if selection.end.column == 0 {
239 selection.end.row -= 1;
240 }
241 selection.end.column = snapshot
242 .buffer_snapshot
243 .line_len(MultiBufferRow(selection.end.row));
244 } else if let Some(fold) =
245 snapshot.crease_for_buffer_row(MultiBufferRow(selection.end.row))
246 {
247 selection.start = fold.range().start;
248 selection.end = fold.range().end;
249 if MultiBufferRow(selection.end.row) < snapshot.buffer_snapshot.max_row() {
250 let chars = snapshot
251 .buffer_snapshot
252 .chars_at(Point::new(selection.end.row + 1, 0));
253
254 for c in chars {
255 if c == '\n' {
256 break;
257 }
258 if c.is_whitespace() {
259 continue;
260 }
261 if snapshot
262 .language_at(selection.end)
263 .is_some_and(|language| language.config().brackets.is_closing_brace(c))
264 {
265 selection.end.row += 1;
266 selection.end.column = snapshot
267 .buffer_snapshot
268 .line_len(MultiBufferRow(selection.end.row));
269 }
270 }
271 }
272 }
273
274 if let Some(prev_selection) = selections.last_mut() {
275 if selection.start <= prev_selection.end {
276 prev_selection.end = selection.end;
277 continue;
278 }
279 }
280
281 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
282 if selection.id > latest_selection.id {
283 *latest_selection = selection.clone();
284 }
285 selections.push(selection);
286 }
287 let snapshot = &snapshot.buffer_snapshot;
288 let newest_selection = newest_selection.unwrap();
289
290 let mut codegen_ranges = Vec::new();
291 for (buffer, buffer_range, excerpt_id) in
292 snapshot.ranges_to_buffer_ranges(selections.iter().map(|selection| {
293 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
294 }))
295 {
296 let start = buffer.anchor_before(buffer_range.start);
297 let end = buffer.anchor_after(buffer_range.end);
298
299 codegen_ranges.push(Anchor::range_in_buffer(
300 excerpt_id,
301 buffer.remote_id(),
302 start..end,
303 ));
304
305 if let Some(ConfiguredModel { model, .. }) =
306 LanguageModelRegistry::read_global(cx).default_model()
307 {
308 self.telemetry.report_assistant_event(AssistantEventData {
309 conversation_id: None,
310 kind: AssistantKind::Inline,
311 phase: AssistantPhase::Invoked,
312 message_id: None,
313 model: model.telemetry_id(),
314 model_provider: model.provider_id().to_string(),
315 response_latency: None,
316 error_message: None,
317 language_name: buffer.language().map(|language| language.name().to_proto()),
318 });
319 }
320 }
321
322 let assist_group_id = self.next_assist_group_id.post_inc();
323 let prompt_buffer = cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
324 let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(prompt_buffer, cx));
325
326 let mut assists = Vec::new();
327 let mut assist_to_focus = None;
328 for range in codegen_ranges {
329 let assist_id = self.next_assist_id.post_inc();
330 let codegen = cx.new(|cx| {
331 Codegen::new(
332 editor.read(cx).buffer().clone(),
333 range.clone(),
334 None,
335 self.telemetry.clone(),
336 self.prompt_builder.clone(),
337 cx,
338 )
339 });
340
341 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
342 let prompt_editor = cx.new(|cx| {
343 PromptEditor::new(
344 assist_id,
345 gutter_dimensions.clone(),
346 self.prompt_history.clone(),
347 prompt_buffer.clone(),
348 codegen.clone(),
349 editor,
350 assistant_panel,
351 workspace.clone(),
352 self.fs.clone(),
353 window,
354 cx,
355 )
356 });
357
358 if assist_to_focus.is_none() {
359 let focus_assist = if newest_selection.reversed {
360 range.start.to_point(&snapshot) == newest_selection.start
361 } else {
362 range.end.to_point(&snapshot) == newest_selection.end
363 };
364 if focus_assist {
365 assist_to_focus = Some(assist_id);
366 }
367 }
368
369 let [prompt_block_id, end_block_id] =
370 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
371
372 assists.push((
373 assist_id,
374 range,
375 prompt_editor,
376 prompt_block_id,
377 end_block_id,
378 ));
379 }
380
381 let editor_assists = self
382 .assists_by_editor
383 .entry(editor.downgrade())
384 .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx));
385 let mut assist_group = InlineAssistGroup::new();
386 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
387 self.assists.insert(
388 assist_id,
389 InlineAssist::new(
390 assist_id,
391 assist_group_id,
392 assistant_panel.is_some(),
393 editor,
394 &prompt_editor,
395 prompt_block_id,
396 end_block_id,
397 range,
398 prompt_editor.read(cx).codegen.clone(),
399 workspace.clone(),
400 window,
401 cx,
402 ),
403 );
404 assist_group.assist_ids.push(assist_id);
405 editor_assists.assist_ids.push(assist_id);
406 }
407 self.assist_groups.insert(assist_group_id, assist_group);
408
409 if let Some(assist_id) = assist_to_focus {
410 self.focus_assist(assist_id, window, cx);
411 }
412 }
413
414 pub fn suggest_assist(
415 &mut self,
416 editor: &Entity<Editor>,
417 mut range: Range<Anchor>,
418 initial_prompt: String,
419 initial_transaction_id: Option<TransactionId>,
420 focus: bool,
421 workspace: Option<WeakEntity<Workspace>>,
422 assistant_panel: Option<&Entity<AssistantPanel>>,
423 window: &mut Window,
424 cx: &mut App,
425 ) -> InlineAssistId {
426 let assist_group_id = self.next_assist_group_id.post_inc();
427 let prompt_buffer = cx.new(|cx| Buffer::local(&initial_prompt, cx));
428 let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(prompt_buffer, cx));
429
430 let assist_id = self.next_assist_id.post_inc();
431
432 let buffer = editor.read(cx).buffer().clone();
433 {
434 let snapshot = buffer.read(cx).read(cx);
435 range.start = range.start.bias_left(&snapshot);
436 range.end = range.end.bias_right(&snapshot);
437 }
438
439 let codegen = cx.new(|cx| {
440 Codegen::new(
441 editor.read(cx).buffer().clone(),
442 range.clone(),
443 initial_transaction_id,
444 self.telemetry.clone(),
445 self.prompt_builder.clone(),
446 cx,
447 )
448 });
449
450 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
451 let prompt_editor = cx.new(|cx| {
452 PromptEditor::new(
453 assist_id,
454 gutter_dimensions.clone(),
455 self.prompt_history.clone(),
456 prompt_buffer.clone(),
457 codegen.clone(),
458 editor,
459 assistant_panel,
460 workspace.clone(),
461 self.fs.clone(),
462 window,
463 cx,
464 )
465 });
466
467 let [prompt_block_id, end_block_id] =
468 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
469
470 let editor_assists = self
471 .assists_by_editor
472 .entry(editor.downgrade())
473 .or_insert_with(|| EditorInlineAssists::new(&editor, window, cx));
474
475 let mut assist_group = InlineAssistGroup::new();
476 self.assists.insert(
477 assist_id,
478 InlineAssist::new(
479 assist_id,
480 assist_group_id,
481 assistant_panel.is_some(),
482 editor,
483 &prompt_editor,
484 prompt_block_id,
485 end_block_id,
486 range,
487 prompt_editor.read(cx).codegen.clone(),
488 workspace.clone(),
489 window,
490 cx,
491 ),
492 );
493 assist_group.assist_ids.push(assist_id);
494 editor_assists.assist_ids.push(assist_id);
495 self.assist_groups.insert(assist_group_id, assist_group);
496
497 if focus {
498 self.focus_assist(assist_id, window, cx);
499 }
500
501 assist_id
502 }
503
504 fn insert_assist_blocks(
505 &self,
506 editor: &Entity<Editor>,
507 range: &Range<Anchor>,
508 prompt_editor: &Entity<PromptEditor>,
509 cx: &mut App,
510 ) -> [CustomBlockId; 2] {
511 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
512 prompt_editor
513 .editor
514 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
515 });
516 let assist_blocks = vec![
517 BlockProperties {
518 style: BlockStyle::Sticky,
519 placement: BlockPlacement::Above(range.start),
520 height: Some(prompt_editor_height),
521 render: build_assist_editor_renderer(prompt_editor),
522 priority: 0,
523 },
524 BlockProperties {
525 style: BlockStyle::Sticky,
526 placement: BlockPlacement::Below(range.end),
527 height: None,
528 render: Arc::new(|cx| {
529 v_flex()
530 .h_full()
531 .w_full()
532 .border_t_1()
533 .border_color(cx.theme().status().info_border)
534 .into_any_element()
535 }),
536 priority: 0,
537 },
538 ];
539
540 editor.update(cx, |editor, cx| {
541 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
542 [block_ids[0], block_ids[1]]
543 })
544 }
545
    /// Handles an assist's prompt editor gaining focus.
    ///
    /// Does nothing if the assist has no decorations. Otherwise:
    /// - Marks `assist_id` as its group's active assist.
    /// - For a linked group, tells every decorated member's prompt editor to keep showing its
    ///   cursor while unfocused.
    /// - If the prompt block's row is within the editor's currently visible rows, installs a
    ///   scroll lock that records the block's distance from the top. Otherwise clears any
    ///   existing scroll lock.
    ///
    /// Panics if the assist, its group, or its editor entry is missing, or if the prompt block
    /// has no row. A released editor is ignored.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut App) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        // These borrow distinct fields of `self`, so they can coexist with `assist`.
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                // An unknown visible line count is treated as zero, so no lock is installed.
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            .ok();
    }
586
587 fn handle_prompt_editor_focus_out(&mut self, assist_id: InlineAssistId, cx: &mut App) {
588 let assist = &self.assists[&assist_id];
589 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
590 if assist_group.active_assist_id == Some(assist_id) {
591 assist_group.active_assist_id = None;
592 if assist_group.linked {
593 for assist_id in &assist_group.assist_ids {
594 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
595 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
596 prompt_editor.set_show_cursor_when_unfocused(false, cx)
597 });
598 }
599 }
600 }
601 }
602 }
603
604 fn handle_prompt_editor_event(
605 &mut self,
606 prompt_editor: Entity<PromptEditor>,
607 event: &PromptEditorEvent,
608 window: &mut Window,
609 cx: &mut App,
610 ) {
611 let assist_id = prompt_editor.read(cx).id;
612 match event {
613 PromptEditorEvent::StartRequested => {
614 self.start_assist(assist_id, window, cx);
615 }
616 PromptEditorEvent::StopRequested => {
617 self.stop_assist(assist_id, cx);
618 }
619 PromptEditorEvent::ConfirmRequested => {
620 self.finish_assist(assist_id, false, window, cx);
621 }
622 PromptEditorEvent::CancelRequested => {
623 self.finish_assist(assist_id, true, window, cx);
624 }
625 PromptEditorEvent::DismissRequested => {
626 self.dismiss_assist(assist_id, window, cx);
627 }
628 }
629 }
630
631 fn handle_editor_newline(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
632 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
633 return;
634 };
635
636 if editor.read(cx).selections.count() == 1 {
637 let (selection, buffer) = editor.update(cx, |editor, cx| {
638 (
639 editor.selections.newest::<usize>(cx),
640 editor.buffer().read(cx).snapshot(cx),
641 )
642 });
643 for assist_id in &editor_assists.assist_ids {
644 let assist = &self.assists[assist_id];
645 let assist_range = assist.range.to_offset(&buffer);
646 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
647 {
648 if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
649 self.dismiss_assist(*assist_id, window, cx);
650 } else {
651 self.finish_assist(*assist_id, false, window, cx);
652 }
653
654 return;
655 }
656 }
657 }
658
659 cx.propagate();
660 }
661
662 fn handle_editor_cancel(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
663 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
664 return;
665 };
666
667 if editor.read(cx).selections.count() == 1 {
668 let (selection, buffer) = editor.update(cx, |editor, cx| {
669 (
670 editor.selections.newest::<usize>(cx),
671 editor.buffer().read(cx).snapshot(cx),
672 )
673 });
674 let mut closest_assist_fallback = None;
675 for assist_id in &editor_assists.assist_ids {
676 let assist = &self.assists[assist_id];
677 let assist_range = assist.range.to_offset(&buffer);
678 if assist.decorations.is_some() {
679 if assist_range.contains(&selection.start)
680 && assist_range.contains(&selection.end)
681 {
682 self.focus_assist(*assist_id, window, cx);
683 return;
684 } else {
685 let distance_from_selection = assist_range
686 .start
687 .abs_diff(selection.start)
688 .min(assist_range.start.abs_diff(selection.end))
689 + assist_range
690 .end
691 .abs_diff(selection.start)
692 .min(assist_range.end.abs_diff(selection.end));
693 match closest_assist_fallback {
694 Some((_, old_distance)) => {
695 if distance_from_selection < old_distance {
696 closest_assist_fallback =
697 Some((assist_id, distance_from_selection));
698 }
699 }
700 None => {
701 closest_assist_fallback = Some((assist_id, distance_from_selection))
702 }
703 }
704 }
705 }
706 }
707
708 if let Some((&assist_id, _)) = closest_assist_fallback {
709 self.focus_assist(assist_id, window, cx);
710 }
711 }
712
713 cx.propagate();
714 }
715
716 fn handle_editor_release(
717 &mut self,
718 editor: WeakEntity<Editor>,
719 window: &mut Window,
720 cx: &mut App,
721 ) {
722 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
723 for assist_id in editor_assists.assist_ids.clone() {
724 self.finish_assist(assist_id, true, window, cx);
725 }
726 }
727 }
728
729 fn handle_editor_change(&mut self, editor: Entity<Editor>, window: &mut Window, cx: &mut App) {
730 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
731 return;
732 };
733 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
734 return;
735 };
736 let assist = &self.assists[&scroll_lock.assist_id];
737 let Some(decorations) = assist.decorations.as_ref() else {
738 return;
739 };
740
741 editor.update(cx, |editor, cx| {
742 let scroll_position = editor.scroll_position(cx);
743 let target_scroll_top = editor
744 .row_for_block(decorations.prompt_block_id, cx)
745 .unwrap()
746 .0 as f32
747 - scroll_lock.distance_from_top;
748 if target_scroll_top != scroll_position.y {
749 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), window, cx);
750 }
751 });
752 }
753
    /// Reacts to events from an editor that has inline assists. Editors without an entry are
    /// ignored.
    ///
    /// - `Edited`: each assist whose codegen status is `Error` or `Done` is finished, with its
    ///   edits kept, if the transaction edited a range overlapping the assist's range.
    /// - `ScrollPositionChanged`: if the prompt block of a scroll-locked, decorated assist is
    ///   no longer at the locked distance from the top, the lock is dropped.
    /// - `SelectionsChanged`: the scroll lock is dropped unless some decorated assist's prompt
    ///   editor is currently focused.
    fn handle_editor_event(
        &mut self,
        editor: Entity<Editor>,
        event: &EditorEvent,
        window: &mut Window,
        cx: &mut App,
    ) {
        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
            return;
        };

        match event {
            EditorEvent::Edited { transaction_id } => {
                let buffer = editor.read(cx).buffer().read(cx);
                let edited_ranges =
                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
                let snapshot = buffer.snapshot(cx);

                // Iterate over a copy because `finish_assist` mutates the id lists.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if matches!(
                        assist.codegen.read(cx).status(cx),
                        CodegenStatus::Error(_) | CodegenStatus::Done
                    ) {
                        let assist_range = assist.range.to_offset(&snapshot);
                        if edited_ranges
                            .iter()
                            .any(|range| range.overlaps(&assist_range))
                        {
                            self.finish_assist(assist_id, false, window, cx);
                        }
                    }
                }
            }
            EditorEvent::ScrollPositionChanged { .. } => {
                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                    let assist = &self.assists[&scroll_lock.assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        let distance_from_top = editor.update(cx, |editor, cx| {
                            let scroll_top = editor.scroll_position(cx).y;
                            let prompt_row = editor
                                .row_for_block(decorations.prompt_block_id, cx)
                                .unwrap()
                                .0 as f32;
                            prompt_row - scroll_top
                        });

                        // The prompt moved relative to the viewport, so release the lock.
                        if distance_from_top != scroll_lock.distance_from_top {
                            editor_assists.scroll_lock = None;
                        }
                    }
                }
            }
            EditorEvent::SelectionsChanged { .. } => {
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        // A focused prompt editor keeps the current scroll lock.
                        if decorations
                            .prompt_editor
                            .focus_handle(cx)
                            .is_focused(window)
                        {
                            return;
                        }
                    }
                }

                editor_assists.scroll_lock = None;
            }
            _ => {}
        }
    }
826
    /// Finishes an assist, either keeping its edits (`undo == false`) or undoing them
    /// (`undo == true`).
    ///
    /// If the assist belongs to a linked group, the group is unlinked and every member is
    /// finished the same way instead. Otherwise the assist:
    /// 1. has its decorations dismissed;
    /// 2. is removed from `assists`, from its group (the group is dropped when empty), and from
    ///    its editor's entry (the entry is dropped and highlights refreshed when empty;
    ///    otherwise a highlight update is signalled);
    /// 3. is reported to telemetry as `Rejected` or `Accepted`, if a default model is
    ///    configured;
    /// 4. has its codegen undone (`undo`), or its active alternative recorded in
    ///    `confirmed_assists`.
    ///
    /// Unknown ids are tolerated: every step is skipped except the (no-op) dismissal.
    pub fn finish_assist(
        &mut self,
        assist_id: InlineAssistId,
        undo: bool,
        window: &mut Window,
        cx: &mut App,
    ) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                // Once unlinked, the recursive calls take the non-linked path below.
                for assist_id in self.unlink_assist_group(assist_group_id, window, cx) {
                    self.finish_assist(assist_id, undo, window, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, window, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    // With the entry gone, refresh highlights directly rather than via the channel.
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            let active_alternative = assist.codegen.read(cx).active_alternative().clone();
            let message_id = active_alternative.read(cx).message_id.clone();

            if let Some(ConfiguredModel { model, .. }) =
                LanguageModelRegistry::read_global(cx).default_model()
            {
                // Language of the first buffer covered by the assist's range, if any.
                let language_name = assist.editor.upgrade().and_then(|editor| {
                    let multibuffer = editor.read(cx).buffer().read(cx);
                    let multibuffer_snapshot = multibuffer.snapshot(cx);
                    let ranges = multibuffer_snapshot.range_to_buffer_ranges(assist.range.clone());
                    ranges
                        .first()
                        .and_then(|(buffer, _, _)| buffer.language())
                        .map(|language| language.name())
                });
                report_assistant_event(
                    AssistantEventData {
                        conversation_id: None,
                        kind: AssistantKind::Inline,
                        message_id,
                        phase: if undo {
                            AssistantPhase::Rejected
                        } else {
                            AssistantPhase::Accepted
                        },
                        model: model.telemetry_id(),
                        model_provider: model.provider_id().to_string(),
                        response_latency: None,
                        error_message: None,
                        language_name: language_name.map(|name| name.to_proto()),
                    },
                    Some(self.telemetry.clone()),
                    cx.http_client(),
                    model.api_key(cx),
                    cx.background_executor(),
                );
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, active_alternative);
            }
        }
    }
914
915 fn dismiss_assist(
916 &mut self,
917 assist_id: InlineAssistId,
918 window: &mut Window,
919 cx: &mut App,
920 ) -> bool {
921 let Some(assist) = self.assists.get_mut(&assist_id) else {
922 return false;
923 };
924 let Some(editor) = assist.editor.upgrade() else {
925 return false;
926 };
927 let Some(decorations) = assist.decorations.take() else {
928 return false;
929 };
930
931 editor.update(cx, |editor, cx| {
932 let mut to_remove = decorations.removed_line_block_ids;
933 to_remove.insert(decorations.prompt_block_id);
934 to_remove.insert(decorations.end_block_id);
935 editor.remove_blocks(to_remove, None, cx);
936 });
937
938 if decorations
939 .prompt_editor
940 .focus_handle(cx)
941 .contains_focused(window, cx)
942 {
943 self.focus_next_assist(assist_id, window, cx);
944 }
945
946 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
947 if editor_assists
948 .scroll_lock
949 .as_ref()
950 .map_or(false, |lock| lock.assist_id == assist_id)
951 {
952 editor_assists.scroll_lock = None;
953 }
954 editor_assists.highlight_updates.send(()).ok();
955 }
956
957 true
958 }
959
960 fn focus_next_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
961 let Some(assist) = self.assists.get(&assist_id) else {
962 return;
963 };
964
965 let assist_group = &self.assist_groups[&assist.group_id];
966 let assist_ix = assist_group
967 .assist_ids
968 .iter()
969 .position(|id| *id == assist_id)
970 .unwrap();
971 let assist_ids = assist_group
972 .assist_ids
973 .iter()
974 .skip(assist_ix + 1)
975 .chain(assist_group.assist_ids.iter().take(assist_ix));
976
977 for assist_id in assist_ids {
978 let assist = &self.assists[assist_id];
979 if assist.decorations.is_some() {
980 self.focus_assist(*assist_id, window, cx);
981 return;
982 }
983 }
984
985 assist
986 .editor
987 .update(cx, |editor, cx| window.focus(&editor.focus_handle(cx)))
988 .ok();
989 }
990
991 fn focus_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
992 let Some(assist) = self.assists.get(&assist_id) else {
993 return;
994 };
995
996 if let Some(decorations) = assist.decorations.as_ref() {
997 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
998 prompt_editor.editor.update(cx, |editor, cx| {
999 window.focus(&editor.focus_handle(cx));
1000 editor.select_all(&SelectAll, window, cx);
1001 })
1002 });
1003 }
1004
1005 self.scroll_to_assist(assist_id, window, cx);
1006 }
1007
1008 pub fn scroll_to_assist(
1009 &mut self,
1010 assist_id: InlineAssistId,
1011 window: &mut Window,
1012 cx: &mut App,
1013 ) {
1014 let Some(assist) = self.assists.get(&assist_id) else {
1015 return;
1016 };
1017 let Some(editor) = assist.editor.upgrade() else {
1018 return;
1019 };
1020
1021 let position = assist.range.start;
1022 editor.update(cx, |editor, cx| {
1023 editor.change_selections(None, window, cx, |selections| {
1024 selections.select_anchor_ranges([position..position])
1025 });
1026
1027 let mut scroll_target_top;
1028 let mut scroll_target_bottom;
1029 if let Some(decorations) = assist.decorations.as_ref() {
1030 scroll_target_top = editor
1031 .row_for_block(decorations.prompt_block_id, cx)
1032 .unwrap()
1033 .0 as f32;
1034 scroll_target_bottom = editor
1035 .row_for_block(decorations.end_block_id, cx)
1036 .unwrap()
1037 .0 as f32;
1038 } else {
1039 let snapshot = editor.snapshot(window, cx);
1040 let start_row = assist
1041 .range
1042 .start
1043 .to_display_point(&snapshot.display_snapshot)
1044 .row();
1045 scroll_target_top = start_row.0 as f32;
1046 scroll_target_bottom = scroll_target_top + 1.;
1047 }
1048 scroll_target_top -= editor.vertical_scroll_margin() as f32;
1049 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
1050
1051 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
1052 let scroll_top = editor.scroll_position(cx).y;
1053 let scroll_bottom = scroll_top + height_in_lines;
1054
1055 if scroll_target_top < scroll_top {
1056 editor.set_scroll_position(point(0., scroll_target_top), window, cx);
1057 } else if scroll_target_bottom > scroll_bottom {
1058 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
1059 editor.set_scroll_position(
1060 point(0., scroll_target_bottom - height_in_lines),
1061 window,
1062 cx,
1063 );
1064 } else {
1065 editor.set_scroll_position(point(0., scroll_target_top), window, cx);
1066 }
1067 }
1068 });
1069 }
1070
1071 fn unlink_assist_group(
1072 &mut self,
1073 assist_group_id: InlineAssistGroupId,
1074 window: &mut Window,
1075 cx: &mut App,
1076 ) -> Vec<InlineAssistId> {
1077 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1078 assist_group.linked = false;
1079 for assist_id in &assist_group.assist_ids {
1080 let assist = self.assists.get_mut(assist_id).unwrap();
1081 if let Some(editor_decorations) = assist.decorations.as_ref() {
1082 editor_decorations
1083 .prompt_editor
1084 .update(cx, |prompt_editor, cx| prompt_editor.unlink(window, cx));
1085 }
1086 }
1087 assist_group.assist_ids.clone()
1088 }
1089
1090 pub fn start_assist(&mut self, assist_id: InlineAssistId, window: &mut Window, cx: &mut App) {
1091 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1092 assist
1093 } else {
1094 return;
1095 };
1096
1097 let assist_group_id = assist.group_id;
1098 if self.assist_groups[&assist_group_id].linked {
1099 for assist_id in self.unlink_assist_group(assist_group_id, window, cx) {
1100 self.start_assist(assist_id, window, cx);
1101 }
1102 return;
1103 }
1104
1105 let Some(user_prompt) = assist.user_prompt(cx) else {
1106 return;
1107 };
1108
1109 self.prompt_history.retain(|prompt| *prompt != user_prompt);
1110 self.prompt_history.push_back(user_prompt.clone());
1111 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1112 self.prompt_history.pop_front();
1113 }
1114
1115 let assistant_panel_context = assist.assistant_panel_context(cx);
1116
1117 assist
1118 .codegen
1119 .update(cx, |codegen, cx| {
1120 codegen.start(user_prompt, assistant_panel_context, cx)
1121 })
1122 .log_err();
1123 }
1124
1125 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut App) {
1126 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1127 assist
1128 } else {
1129 return;
1130 };
1131
1132 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1133 }
1134
1135 fn update_editor_highlights(&self, editor: &Entity<Editor>, cx: &mut App) {
1136 let mut gutter_pending_ranges = Vec::new();
1137 let mut gutter_transformed_ranges = Vec::new();
1138 let mut foreground_ranges = Vec::new();
1139 let mut inserted_row_ranges = Vec::new();
1140 let empty_assist_ids = Vec::new();
1141 let assist_ids = self
1142 .assists_by_editor
1143 .get(&editor.downgrade())
1144 .map_or(&empty_assist_ids, |editor_assists| {
1145 &editor_assists.assist_ids
1146 });
1147
1148 for assist_id in assist_ids {
1149 if let Some(assist) = self.assists.get(assist_id) {
1150 let codegen = assist.codegen.read(cx);
1151 let buffer = codegen.buffer(cx).read(cx).read(cx);
1152 foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1153
1154 let pending_range =
1155 codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1156 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1157 gutter_pending_ranges.push(pending_range);
1158 }
1159
1160 if let Some(edit_position) = codegen.edit_position(cx) {
1161 let edited_range = assist.range.start..edit_position;
1162 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1163 gutter_transformed_ranges.push(edited_range);
1164 }
1165 }
1166
1167 if assist.decorations.is_some() {
1168 inserted_row_ranges
1169 .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1170 }
1171 }
1172 }
1173
1174 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1175 merge_ranges(&mut foreground_ranges, &snapshot);
1176 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1177 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1178 editor.update(cx, |editor, cx| {
1179 enum GutterPendingRange {}
1180 if gutter_pending_ranges.is_empty() {
1181 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1182 } else {
1183 editor.highlight_gutter::<GutterPendingRange>(
1184 &gutter_pending_ranges,
1185 |cx| cx.theme().status().info_background,
1186 cx,
1187 )
1188 }
1189
1190 enum GutterTransformedRange {}
1191 if gutter_transformed_ranges.is_empty() {
1192 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1193 } else {
1194 editor.highlight_gutter::<GutterTransformedRange>(
1195 &gutter_transformed_ranges,
1196 |cx| cx.theme().status().info,
1197 cx,
1198 )
1199 }
1200
1201 if foreground_ranges.is_empty() {
1202 editor.clear_highlights::<InlineAssist>(cx);
1203 } else {
1204 editor.highlight_text::<InlineAssist>(
1205 foreground_ranges,
1206 HighlightStyle {
1207 fade_out: Some(0.6),
1208 ..Default::default()
1209 },
1210 cx,
1211 );
1212 }
1213
1214 editor.clear_row_highlights::<InlineAssist>();
1215 for row_range in inserted_row_ranges {
1216 editor.highlight_rows::<InlineAssist>(
1217 row_range,
1218 cx.theme().status().info_background,
1219 Default::default(),
1220 cx,
1221 );
1222 }
1223 });
1224 }
1225
    /// Rebuilds the blocks that display the lines removed by an assist's generated edit.
    ///
    /// Blocks inserted by a previous call, tracked in `decorations.removed_line_block_ids`,
    /// are removed first. Then, for each entry of the codegen diff's `deleted_row_ranges`, a
    /// read-only editor over the corresponding rows of the old buffer is created. It is
    /// inserted as a block placed above the associated row. Returns early if the assist is
    /// unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &Entity<Editor>,
        assist_id: InlineAssistId,
        window: &mut Window,
        cx: &mut App,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Read everything needed from the codegen up front; these values are moved into the
        // editor update below.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Remove the blocks created by the previous update before inserting new ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Convert the inclusive row range (column 0 of its first row through the end
                // of its last row) into buffer offsets. NOTE(review): the unwraps assume both
                // points always resolve to a buffer; confirm that invariant.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A stripped-down, read-only editor showing only the deleted text, with all of
                // its rows highlighted using the theme's deleted background.
                let deleted_lines_editor = cx.new(|cx| {
                    let multi_buffer =
                        cx.new(|_| MultiBuffer::without_headers(language::Capability::ReadOnly));
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange::new(buffer_start..buffer_end)),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, window, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_show_scrollbars(false, cx);
                    editor.set_read_only(true);
                    editor.set_show_edit_predictions(Some(false), window, cx);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        Default::default(),
                        cx,
                    );
                    editor
                });

                // The block is as tall as the number of rows in the deleted-lines editor and is
                // left-padded by the host editor's full gutter width.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height: Some(height),
                    style: BlockStyle::Flex,
                    render: Arc::new(move |cx| {
                        div()
                            .block_mouse_down()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.window.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1316}
1317
/// Per-editor bookkeeping for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Ids of the inline assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock for this editor (see `InlineAssistScrollLock`).
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Signalling this channel makes the `_update_highlights` task refresh the editor's
    /// highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Task that listens on `highlight_updates` and calls
    /// `InlineAssistant::update_editor_highlights`.
    _update_highlights: Task<Result<()>>,
    /// Observers and action registrations created in `new`; stored only to keep them alive.
    _subscriptions: Vec<gpui::Subscription>,
}
1325
/// Scroll lock state for an editor, tied to one inline assist.
struct InlineAssistScrollLock {
    /// The assist the lock is associated with.
    assist_id: InlineAssistId,
    /// NOTE(review): presumably the distance between the assist and the top of the viewport,
    /// in lines. Confirm the unit against the code that sets this field.
    distance_from_top: f32,
}
1330
impl EditorInlineAssists {
    /// Creates the bookkeeping for `editor` and connects it to the global `InlineAssistant`.
    ///
    /// This spawns a task that calls `InlineAssistant::update_editor_highlights` each time
    /// `highlight_updates` is signalled. The loop ends when the channel's `changed()` returns
    /// an error, and the task fails if the editor can no longer be upgraded. It also installs
    /// handlers that forward the editor's release, change notifications, events, and its
    /// `Newline` / `Cancel` actions to the matching `InlineAssistant` methods.
    fn new(editor: &Entity<Editor>, window: &mut Window, cx: &mut App) -> Self {
        let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
        Self {
            assist_ids: Vec::new(),
            scroll_lock: None,
            highlight_updates: highlight_updates_tx,
            _update_highlights: cx.spawn({
                // Hold a weak handle so this task does not keep the editor alive.
                let editor = editor.downgrade();
                async move |cx| {
                    while let Ok(()) = highlight_updates_rx.changed().await {
                        let editor = editor.upgrade().context("editor was dropped")?;
                        cx.update_global(|assistant: &mut InlineAssistant, cx| {
                            assistant.update_editor_highlights(&editor, cx);
                        })?;
                    }
                    Ok(())
                }
            }),
            _subscriptions: vec![
                // Clean up assistant state when the editor is released.
                cx.observe_release_in(editor, window, {
                    let editor = editor.downgrade();
                    |_, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            this.handle_editor_release(editor, window, cx);
                        })
                    }
                }),
                window.observe(editor, cx, move |editor, window, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_editor_change(editor, window, cx)
                    })
                }),
                window.subscribe(editor, cx, move |editor, event, window, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_editor_event(editor, event, window, cx)
                    })
                }),
                // Route the editor's Newline action through the inline assistant.
                editor.update(cx, |editor, cx| {
                    let editor_handle = cx.entity().downgrade();
                    editor.register_action(move |_: &editor::actions::Newline, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            if let Some(editor) = editor_handle.upgrade() {
                                this.handle_editor_newline(editor, window, cx)
                            }
                        })
                    })
                }),
                // Route the editor's Cancel action through the inline assistant.
                editor.update(cx, |editor, cx| {
                    let editor_handle = cx.entity().downgrade();
                    editor.register_action(move |_: &editor::actions::Cancel, window, cx| {
                        InlineAssistant::update_global(cx, |this, cx| {
                            if let Some(editor) = editor_handle.upgrade() {
                                this.handle_editor_cancel(editor, window, cx)
                            }
                        })
                    })
                }),
            ],
        }
    }
}
1393
1394struct InlineAssistGroup {
1395 assist_ids: Vec<InlineAssistId>,
1396 linked: bool,
1397 active_assist_id: Option<InlineAssistId>,
1398}
1399
1400impl InlineAssistGroup {
1401 fn new() -> Self {
1402 Self {
1403 assist_ids: Vec::new(),
1404 linked: true,
1405 active_assist_id: None,
1406 }
1407 }
1408}
1409
1410fn build_assist_editor_renderer(editor: &Entity<PromptEditor>) -> RenderBlock {
1411 let editor = editor.clone();
1412 Arc::new(move |cx: &mut BlockContext| {
1413 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1414 editor.clone().into_any_element()
1415 })
1416}
1417
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one (post-increment).
    fn post_inc(&mut self) -> InlineAssistId {
        let issued = InlineAssistId(self.0);
        self.0 += 1;
        issued
    }
}
1428
/// Identifier of an `InlineAssistGroup`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one (post-increment).
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let issued = InlineAssistGroupId(self.0);
        self.0 += 1;
        issued
    }
}
1439
/// Requests emitted by a `PromptEditor` in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start (or re-run) generation: start/restart buttons, `menu::Restart`, and `confirm`
    /// when idle, after an error, or when the prompt was edited after completion.
    StartRequested,
    /// Interrupt a running generation: stop button, or `cancel` while pending.
    StopRequested,
    /// Accept the finished result: check button, or `confirm` when done and unedited.
    ConfirmRequested,
    /// Cancel the assist: cancel button, or `cancel` when not pending.
    CancelRequested,
    /// Emitted by `confirm` while generation is still pending.
    DismissRequested,
}
1447
/// The prompt row shown above an inline assist: prompt text editor, model selector,
/// token count, and status-dependent action buttons.
struct PromptEditor {
    /// The assist this prompt belongs to (used to look it up when counting tokens).
    id: InlineAssistId,
    /// Text editor holding the prompt; replaced by a standalone editor in `unlink`.
    editor: Entity<Editor>,
    /// Popover for choosing the model; selections are written to the assistant settings.
    language_model_selector: Entity<LanguageModelSelector>,
    /// Set when the prompt is edited; reset when generation finishes or fails.
    edited_since_done: bool,
    /// Gutter dimensions shared with the block renderer, which updates them on each render.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts available for up/down navigation.
    prompt_history: VecDeque<String>,
    /// Index of the history entry currently shown, or `None` when not navigating history.
    prompt_history_ix: Option<usize>,
    /// Text typed before navigating history; restored when moving past the newest entry.
    pending_prompt: String,
    /// Code generator driven by this prompt.
    codegen: Entity<Codegen>,
    /// Observer of `codegen` status changes.
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`; recreated by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// The most recently scheduled token-count task.
    pending_token_count: Task<Result<()>>,
    /// Latest token counts, if a count has completed.
    token_counts: Option<TokenCounts>,
    /// Subscriptions to the parent editor and assistant panel that trigger recounts.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to open the assistant panel from the token count, if available.
    workspace: Option<WeakEntity<Workspace>>,
    /// Whether the rate-limit notice popover is visible.
    show_rate_limit_notice: bool,
}
1466
/// Token usage of an inline assist request, as shown next to the prompt.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    /// Total tokens used by the request.
    total: usize,
    /// The portion of `total` that comes from the assistant panel context.
    assistant_panel: usize,
}
1472
// Allows `PromptEditor` to emit `PromptEditorEvent`s (via `cx.emit`) to its subscribers.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1474
impl Render for PromptEditor {
    /// Renders the prompt row. It consists of a gutter-aligned column with the model selector
    /// and any error indicator, the prompt editor itself, and a trailing area with the token
    /// count and status-dependent action buttons.
    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let codegen = self.codegen.read(cx);

        // Controls for cycling between alternatives only appear when there is more than one.
        let mut buttons = Vec::new();
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        // The remaining buttons depend on the codegen status.
        let status = codegen.status(cx);
        buttons.extend(match status {
            // Not started yet: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|window, cx| {
                            Tooltip::for_action("Cancel Assist", &menu::Cancel, window, cx)
                        })
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|window, cx| {
                            Tooltip::for_action("Transform", &menu::Confirm, window, cx)
                        })
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Generating: cancel, or interrupt while keeping the changes made so far.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(Tooltip::text("Cancel Assist"))
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|window, cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                window,
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Finished or failed: cancel, regenerate, and (when allowed) accept the result.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                // After an error, or once the prompt was edited after completion, the result
                // can only be regenerated; the confirm button is then replaced by an empty div.
                let must_rerun =
                    self.edited_since_done || matches!(status, CodegenStatus::Error(_));
                // when accept button isn't visible, then restart maps to confirm
                // when accept button is visible, then restart must be mapped to an alternate keyboard shortcut
                let restart_key: &dyn gpui::Action = if must_rerun {
                    &menu::Confirm
                } else {
                    &menu::Restart
                };
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|window, cx| {
                            Tooltip::for_action("Cancel Assist", &menu::Cancel, window, cx)
                        })
                        .on_click(
                            cx.listener(|_, _, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("restart", IconName::RotateCw)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|window, cx| {
                            Tooltip::with_meta(
                                "Regenerate Transformation",
                                Some(restart_key),
                                "Current change will be discarded",
                                window,
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, _, cx| {
                            cx.emit(PromptEditorEvent::StartRequested);
                        }))
                        .into_any_element(),
                    if !must_rerun {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|window, cx| {
                                Tooltip::for_action("Confirm Assist", &menu::Confirm, window, cx)
                            })
                            .on_click(cx.listener(|_, _, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    } else {
                        div().into_any_element()
                    },
                ]
            }
        });

        h_flex()
            .key_context("PromptEditor")
            .bg(cx.theme().colors().editor_background)
            .block_mouse_down()
            .cursor(CursorStyle::Arrow)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(window.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::restart))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .capture_action(cx.listener(Self::cycle_prev))
            .capture_action(cx.listener(Self::cycle_next))
            .child(
                // Leading column, sized from the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(LanguageModelSelectorPopoverMenu::new(
                        self.language_model_selector.clone(),
                        IconButton::new("context", IconName::SettingsAlt)
                            .shape(IconButtonShape::Square)
                            .icon_size(IconSize::Small)
                            .icon_color(Color::Muted),
                        move |window, cx| {
                            Tooltip::with_meta(
                                format!(
                                    "Using {}",
                                    LanguageModelRegistry::read_global(cx)
                                        .default_model()
                                        .map(|default| default.model.name().0)
                                        .unwrap_or_else(|| "No model selected".into()),
                                ),
                                None,
                                "Change Model",
                                window,
                                cx,
                            )
                        },
                        gpui::Corner::TopRight,
                    ))
                    .map(|el| {
                        // Only failed generations get an error indicator.
                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // Rate-limit errors get a toggleable notice when `ZedProFeatureFlag`
                        // is enabled; every other error gets an icon with the message tooltip.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedProFeatureFlag>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .toggle_state(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::Corner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(Tooltip::text(error_message))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1690
impl Focusable for PromptEditor {
    /// Focusing the prompt row focuses its inner prompt text editor.
    fn focus_handle(&self, cx: &App) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1696
1697impl PromptEditor {
    /// Line limit passed as `max_lines` when creating the auto-height prompt editor.
    const MAX_LINES: u8 = 8;
1699
    /// Creates the prompt row for assist `id`.
    ///
    /// - `gutter_dimensions` is shared with the block renderer that positions this row.
    /// - `prompt_history` seeds up/down history navigation.
    /// - `prompt_buffer` backs the auto-height prompt editor.
    /// - `parent_editor` and `assistant_panel` are subscribed to so the token count is
    ///   refreshed when they change.
    /// - `workspace` lets the token count open the assistant panel.
    /// - `fs` is used to persist model selections into the assistant settings file.
    ///
    /// An initial token count is scheduled before returning.
    fn new(
        id: InlineAssistId,
        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
        prompt_history: VecDeque<String>,
        prompt_buffer: Entity<MultiBuffer>,
        codegen: Entity<Codegen>,
        parent_editor: &Entity<Editor>,
        assistant_panel: Option<&Entity<AssistantPanel>>,
        workspace: Option<WeakEntity<Workspace>>,
        fs: Arc<dyn Fs>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) -> Self {
        let prompt_editor = cx.new(|cx| {
            let mut editor = Editor::new(
                EditorMode::AutoHeight {
                    max_lines: Self::MAX_LINES as usize,
                },
                prompt_buffer,
                None,
                window,
                cx,
            );
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            // Since the prompt editors for all inline assistants are linked,
            // always show the cursor (even when it isn't focused) because
            // typing in one will make what you typed appear in all of them.
            editor.set_show_cursor_when_unfocused(true, cx);
            editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), window, cx), cx);
            editor
        });

        // Recount tokens whenever the parent editor's buffer or the assistant panel context
        // changes.
        let mut token_count_subscriptions = Vec::new();
        token_count_subscriptions.push(cx.subscribe_in(
            parent_editor,
            window,
            Self::handle_parent_editor_event,
        ));
        if let Some(assistant_panel) = assistant_panel {
            token_count_subscriptions.push(cx.subscribe_in(
                assistant_panel,
                window,
                Self::handle_assistant_panel_event,
            ));
        }

        let mut this = Self {
            id,
            editor: prompt_editor,
            language_model_selector: cx.new(|cx| {
                let fs = fs.clone();
                LanguageModelSelector::new(
                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                    // Persist the chosen model in the assistant settings.
                    move |model, cx| {
                        update_settings_file::<AssistantSettings>(
                            fs.clone(),
                            cx,
                            move |settings, _| settings.set_model(model.clone()),
                        );
                    },
                    window,
                    cx,
                )
            }),
            edited_since_done: false,
            gutter_dimensions,
            prompt_history,
            prompt_history_ix: None,
            pending_prompt: String::new(),
            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
            editor_subscriptions: Vec::new(),
            codegen,
            pending_token_count: Task::ready(Ok(())),
            token_counts: None,
            _token_count_subscriptions: token_count_subscriptions,
            workspace,
            show_rate_limit_notice: false,
        };
        this.count_tokens(cx);
        this.subscribe_to_editor(window, cx);
        this
    }
1782
1783 fn subscribe_to_editor(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1784 self.editor_subscriptions.clear();
1785 self.editor_subscriptions.push(cx.subscribe_in(
1786 &self.editor,
1787 window,
1788 Self::handle_prompt_editor_events,
1789 ));
1790 }
1791
1792 fn set_show_cursor_when_unfocused(
1793 &mut self,
1794 show_cursor_when_unfocused: bool,
1795 cx: &mut Context<Self>,
1796 ) {
1797 self.editor.update(cx, |editor, cx| {
1798 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1799 });
1800 }
1801
1802 fn unlink(&mut self, window: &mut Window, cx: &mut Context<Self>) {
1803 let prompt = self.prompt(cx);
1804 let focus = self.editor.focus_handle(cx).contains_focused(window, cx);
1805 self.editor = cx.new(|cx| {
1806 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, window, cx);
1807 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1808 editor.set_placeholder_text("Add a prompt…", cx);
1809 editor.set_text(prompt, window, cx);
1810 if focus {
1811 window.focus(&editor.focus_handle(cx));
1812 }
1813 editor
1814 });
1815 self.subscribe_to_editor(window, cx);
1816 }
1817
1818 fn placeholder_text(codegen: &Codegen, window: &Window, cx: &App) -> String {
1819 let context_keybinding = text_for_action(&zed_actions::assistant::ToggleFocus, window, cx)
1820 .map(|keybinding| format!(" • {keybinding} for context"))
1821 .unwrap_or_default();
1822
1823 let action = if codegen.is_insertion {
1824 "Generate"
1825 } else {
1826 "Transform"
1827 };
1828
1829 format!("{action}…{context_keybinding} • ↓↑ for history")
1830 }
1831
    /// Returns the current text of the prompt editor.
    fn prompt(&self, cx: &App) -> String {
        self.editor.read(cx).text(cx)
    }
1835
1836 fn toggle_rate_limit_notice(
1837 &mut self,
1838 _: &ClickEvent,
1839 window: &mut Window,
1840 cx: &mut Context<Self>,
1841 ) {
1842 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1843 if self.show_rate_limit_notice {
1844 window.focus(&self.editor.focus_handle(cx));
1845 }
1846 cx.notify();
1847 }
1848
1849 fn handle_parent_editor_event(
1850 &mut self,
1851 _: &Entity<Editor>,
1852 event: &EditorEvent,
1853 _: &mut Window,
1854 cx: &mut Context<Self>,
1855 ) {
1856 if let EditorEvent::BufferEdited { .. } = event {
1857 self.count_tokens(cx);
1858 }
1859 }
1860
1861 fn handle_assistant_panel_event(
1862 &mut self,
1863 _: &Entity<AssistantPanel>,
1864 event: &AssistantPanelEvent,
1865 _: &mut Window,
1866 cx: &mut Context<Self>,
1867 ) {
1868 let AssistantPanelEvent::ContextEdited { .. } = event;
1869 self.count_tokens(cx);
1870 }
1871
1872 fn count_tokens(&mut self, cx: &mut Context<Self>) {
1873 let assist_id = self.id;
1874 self.pending_token_count = cx.spawn(async move |this, cx| {
1875 cx.background_executor().timer(Duration::from_secs(1)).await;
1876 let token_count = cx
1877 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1878 let assist = inline_assistant
1879 .assists
1880 .get(&assist_id)
1881 .context("assist not found")?;
1882 anyhow::Ok(assist.count_tokens(cx))
1883 })??
1884 .await?;
1885
1886 this.update(cx, |this, cx| {
1887 this.token_counts = Some(token_count);
1888 cx.notify();
1889 })
1890 })
1891 }
1892
1893 fn handle_prompt_editor_events(
1894 &mut self,
1895 _: &Entity<Editor>,
1896 event: &EditorEvent,
1897 window: &mut Window,
1898 cx: &mut Context<Self>,
1899 ) {
1900 match event {
1901 EditorEvent::Edited { .. } => {
1902 if let Some(workspace) = window.root::<Workspace>().flatten() {
1903 workspace.update(cx, |workspace, cx| {
1904 let is_via_ssh = workspace
1905 .project()
1906 .update(cx, |project, _| project.is_via_ssh());
1907
1908 workspace
1909 .client()
1910 .telemetry()
1911 .log_edit_event("inline assist", is_via_ssh);
1912 });
1913 }
1914 let prompt = self.editor.read(cx).text(cx);
1915 if self
1916 .prompt_history_ix
1917 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1918 {
1919 self.prompt_history_ix.take();
1920 self.pending_prompt = prompt;
1921 }
1922
1923 self.edited_since_done = true;
1924 cx.notify();
1925 }
1926 EditorEvent::BufferEdited => {
1927 self.count_tokens(cx);
1928 }
1929 EditorEvent::Blurred => {
1930 if self.show_rate_limit_notice {
1931 self.show_rate_limit_notice = false;
1932 cx.notify();
1933 }
1934 }
1935 _ => {}
1936 }
1937 }
1938
1939 fn handle_codegen_changed(&mut self, _: Entity<Codegen>, cx: &mut Context<Self>) {
1940 match self.codegen.read(cx).status(cx) {
1941 CodegenStatus::Idle => {
1942 self.editor
1943 .update(cx, |editor, _| editor.set_read_only(false));
1944 }
1945 CodegenStatus::Pending => {
1946 self.editor
1947 .update(cx, |editor, _| editor.set_read_only(true));
1948 }
1949 CodegenStatus::Done => {
1950 self.edited_since_done = false;
1951 self.editor
1952 .update(cx, |editor, _| editor.set_read_only(false));
1953 }
1954 CodegenStatus::Error(error) => {
1955 if cx.has_flag::<ZedProFeatureFlag>()
1956 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1957 && !dismissed_rate_limit_notice()
1958 {
1959 self.show_rate_limit_notice = true;
1960 cx.notify();
1961 }
1962
1963 self.edited_since_done = false;
1964 self.editor
1965 .update(cx, |editor, _| editor.set_read_only(false));
1966 }
1967 }
1968 }
1969
    /// Handles `menu::Restart` by requesting that generation start again.
    fn restart(&mut self, _: &menu::Restart, _window: &mut Window, cx: &mut Context<Self>) {
        cx.emit(PromptEditorEvent::StartRequested);
    }
1973
1974 fn cancel(
1975 &mut self,
1976 _: &editor::actions::Cancel,
1977 _window: &mut Window,
1978 cx: &mut Context<Self>,
1979 ) {
1980 match self.codegen.read(cx).status(cx) {
1981 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1982 cx.emit(PromptEditorEvent::CancelRequested);
1983 }
1984 CodegenStatus::Pending => {
1985 cx.emit(PromptEditorEvent::StopRequested);
1986 }
1987 }
1988 }
1989
1990 fn confirm(&mut self, _: &menu::Confirm, _window: &mut Window, cx: &mut Context<Self>) {
1991 match self.codegen.read(cx).status(cx) {
1992 CodegenStatus::Idle => {
1993 cx.emit(PromptEditorEvent::StartRequested);
1994 }
1995 CodegenStatus::Pending => {
1996 cx.emit(PromptEditorEvent::DismissRequested);
1997 }
1998 CodegenStatus::Done => {
1999 if self.edited_since_done {
2000 cx.emit(PromptEditorEvent::StartRequested);
2001 } else {
2002 cx.emit(PromptEditorEvent::ConfirmRequested);
2003 }
2004 }
2005 CodegenStatus::Error(_) => {
2006 cx.emit(PromptEditorEvent::StartRequested);
2007 }
2008 }
2009 }
2010
2011 fn move_up(&mut self, _: &MoveUp, window: &mut Window, cx: &mut Context<Self>) {
2012 if let Some(ix) = self.prompt_history_ix {
2013 if ix > 0 {
2014 self.prompt_history_ix = Some(ix - 1);
2015 let prompt = self.prompt_history[ix - 1].as_str();
2016 self.editor.update(cx, |editor, cx| {
2017 editor.set_text(prompt, window, cx);
2018 editor.move_to_beginning(&Default::default(), window, cx);
2019 });
2020 }
2021 } else if !self.prompt_history.is_empty() {
2022 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
2023 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
2024 self.editor.update(cx, |editor, cx| {
2025 editor.set_text(prompt, window, cx);
2026 editor.move_to_beginning(&Default::default(), window, cx);
2027 });
2028 }
2029 }
2030
2031 fn move_down(&mut self, _: &MoveDown, window: &mut Window, cx: &mut Context<Self>) {
2032 if let Some(ix) = self.prompt_history_ix {
2033 if ix < self.prompt_history.len() - 1 {
2034 self.prompt_history_ix = Some(ix + 1);
2035 let prompt = self.prompt_history[ix + 1].as_str();
2036 self.editor.update(cx, |editor, cx| {
2037 editor.set_text(prompt, window, cx);
2038 editor.move_to_end(&Default::default(), window, cx)
2039 });
2040 } else {
2041 self.prompt_history_ix = None;
2042 let prompt = self.pending_prompt.as_str();
2043 self.editor.update(cx, |editor, cx| {
2044 editor.set_text(prompt, window, cx);
2045 editor.move_to_end(&Default::default(), window, cx)
2046 });
2047 }
2048 }
2049 }
2050
    /// Handles `CyclePreviousInlineAssist` by asking the codegen to cycle to the previous
    /// alternative.
    fn cycle_prev(
        &mut self,
        _: &CyclePreviousInlineAssist,
        _: &mut Window,
        cx: &mut Context<Self>,
    ) {
        self.codegen
            .update(cx, |codegen, cx| codegen.cycle_prev(cx));
    }

    /// Handles `CycleNextInlineAssist` by asking the codegen to cycle to the next alternative.
    fn cycle_next(&mut self, _: &CycleNextInlineAssist, _: &mut Window, cx: &mut Context<Self>) {
        self.codegen
            .update(cx, |codegen, cx| codegen.cycle_next(cx));
    }
2065
    /// Renders previous/next buttons and an `active/total` label for switching between
    /// generated alternatives.
    ///
    /// The controls are disabled while the codegen is idle, and each button is disabled at its
    /// end of the list. Tooltips name the model of the neighbouring alternative. Index 0 is the
    /// default model; indices 1.. map to the registry's inline alternative models. Returns an
    /// empty `div` when only one model is configured.
    fn render_cycle_controls(&self, cx: &Context<Self>) -> AnyElement {
        let codegen = self.codegen.read(cx);
        let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);

        let model_registry = LanguageModelRegistry::read_global(cx);
        let default_model = model_registry.default_model().map(|default| default.model);
        let alternative_models = model_registry.inline_alternative_models();

        // Maps an alternative index to a model name; out-of-range indices yield "".
        let get_model_name = |index: usize| -> String {
            let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();

            match index {
                0 => default_model.as_ref().map_or_else(String::new, name),
                index if index <= alternative_models.len() => alternative_models
                    .get(index - 1)
                    .map_or_else(String::new, name),
                _ => String::new(),
            }
        };

        let total_models = alternative_models.len() + 1;

        if total_models <= 1 {
            return div().into_any_element();
        }

        let current_index = codegen.active_alternative;
        let prev_index = (current_index + total_models - 1) % total_models;
        let next_index = (current_index + 1) % total_models;

        let prev_model_name = get_model_name(prev_index);
        let next_model_name = get_model_name(next_index);

        h_flex()
            .child(
                IconButton::new("previous", IconName::ChevronLeft)
                    .icon_color(Color::Muted)
                    .disabled(disabled || current_index == 0)
                    .shape(IconButtonShape::Square)
                    .tooltip({
                        let focus_handle = self.editor.focus_handle(cx);
                        move |window, cx| {
                            cx.new(|cx| {
                                let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
                                    KeyBinding::for_action_in(
                                        &CyclePreviousInlineAssist,
                                        &focus_handle,
                                        window,
                                        cx,
                                    ),
                                );
                                // Only name the target model when the button is usable.
                                if !disabled && current_index != 0 {
                                    tooltip = tooltip.meta(prev_model_name.clone());
                                }
                                tooltip
                            })
                            .into()
                        }
                    })
                    .on_click(cx.listener(|this, _, _, cx| {
                        this.codegen
                            .update(cx, |codegen, cx| codegen.cycle_prev(cx))
                    })),
            )
            .child(
                Label::new(format!(
                    "{}/{}",
                    codegen.active_alternative + 1,
                    codegen.alternative_count(cx)
                ))
                .size(LabelSize::Small)
                .color(if disabled {
                    Color::Disabled
                } else {
                    Color::Muted
                }),
            )
            .child(
                IconButton::new("next", IconName::ChevronRight)
                    .icon_color(Color::Muted)
                    .disabled(disabled || current_index == total_models - 1)
                    .shape(IconButtonShape::Square)
                    .tooltip({
                        let focus_handle = self.editor.focus_handle(cx);
                        move |window, cx| {
                            cx.new(|cx| {
                                let mut tooltip = Tooltip::new("Next Alternative").key_binding(
                                    KeyBinding::for_action_in(
                                        &CycleNextInlineAssist,
                                        &focus_handle,
                                        window,
                                        cx,
                                    ),
                                );
                                // Only name the target model when the button is usable.
                                if !disabled && current_index != total_models - 1 {
                                    tooltip = tooltip.meta(next_model_name.clone());
                                }
                                tooltip
                            })
                            .into()
                        }
                    })
                    .on_click(cx.listener(|this, _, _, cx| {
                        this.codegen
                            .update(cx, |codegen, cx| codegen.cycle_next(cx))
                    })),
            )
            .into_any_element()
    }
2175
    /// Renders `used / max` tokens for the default model.
    ///
    /// Returns `None` when there is no default model or no token count yet. The used count is
    /// red when no tokens remain, yellow (warning) at 80% or more of the maximum, and muted
    /// otherwise. With a workspace available, the tooltip shows how many tokens come from the
    /// assistant panel, and clicking focuses the assistant panel.
    fn render_token_count(&self, cx: &mut Context<Self>) -> Option<impl IntoElement> {
        let model = LanguageModelRegistry::read_global(cx)
            .default_model()?
            .model;
        let token_counts = self.token_counts?;
        let max_token_count = model.max_token_count();

        // Signed arithmetic so an over-budget request yields a negative remainder.
        let remaining_tokens = max_token_count as isize - token_counts.total as isize;
        let token_count_color = if remaining_tokens <= 0 {
            Color::Error
        } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
            Color::Warning
        } else {
            Color::Muted
        };

        let mut token_count = h_flex()
            .id("token_count")
            .gap_0p5()
            .child(
                Label::new(humanize_token_count(token_counts.total))
                    .size(LabelSize::Small)
                    .color(token_count_color),
            )
            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
            .child(
                Label::new(humanize_token_count(max_token_count))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
            );
        if let Some(workspace) = self.workspace.clone() {
            token_count = token_count
                .tooltip(move |window, cx| {
                    Tooltip::with_meta(
                        format!(
                            "Tokens Used ({} from the Assistant Panel)",
                            humanize_token_count(token_counts.assistant_panel)
                        ),
                        None,
                        "Click to open the Assistant Panel",
                        window,
                        cx,
                    )
                })
                .cursor_pointer()
                // Keep the mouse-down from reaching ancestors (e.g. the host editor).
                .on_mouse_down(gpui::MouseButton::Left, |_, _, cx| cx.stop_propagation())
                .on_click(move |_, window, cx| {
                    cx.stop_propagation();
                    workspace
                        .update(cx, |workspace, cx| {
                            workspace.focus_panel::<AssistantPanel>(window, cx)
                        })
                        .ok();
                });
        } else {
            token_count = token_count
                .cursor_default()
                .tooltip(Tooltip::text("Tokens used"));
        }

        Some(token_count)
    }
2238
2239 fn render_prompt_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
2240 let settings = ThemeSettings::get_global(cx);
2241 let text_style = TextStyle {
2242 color: if self.editor.read(cx).read_only(cx) {
2243 cx.theme().colors().text_disabled
2244 } else {
2245 cx.theme().colors().text
2246 },
2247 font_family: settings.buffer_font.family.clone(),
2248 font_fallbacks: settings.buffer_font.fallbacks.clone(),
2249 font_size: settings.buffer_font_size(cx).into(),
2250 font_weight: settings.buffer_font.weight,
2251 line_height: relative(settings.buffer_line_height.value()),
2252 ..Default::default()
2253 };
2254 EditorElement::new(
2255 &self.editor,
2256 EditorStyle {
2257 background: cx.theme().colors().editor_background,
2258 local_player: cx.theme().players().local(),
2259 text: text_style,
2260 ..Default::default()
2261 },
2262 )
2263 }
2264
2265 fn render_rate_limit_notice(&self, cx: &mut Context<Self>) -> impl IntoElement {
2266 Popover::new().child(
2267 v_flex()
2268 .occlude()
2269 .p_2()
2270 .child(
2271 Label::new("Out of Tokens")
2272 .size(LabelSize::Small)
2273 .weight(FontWeight::BOLD),
2274 )
2275 .child(Label::new(
2276 "Try Zed Pro for higher limits, a wider range of models, and more.",
2277 ))
2278 .child(
2279 h_flex()
2280 .justify_between()
2281 .child(CheckboxWithLabel::new(
2282 "dont-show-again",
2283 Label::new("Don't show again"),
2284 if dismissed_rate_limit_notice() {
2285 ui::ToggleState::Selected
2286 } else {
2287 ui::ToggleState::Unselected
2288 },
2289 |selection, _, cx| {
2290 let is_dismissed = match selection {
2291 ui::ToggleState::Unselected => false,
2292 ui::ToggleState::Indeterminate => return,
2293 ui::ToggleState::Selected => true,
2294 };
2295
2296 set_rate_limit_notice_dismissed(is_dismissed, cx)
2297 },
2298 ))
2299 .child(
2300 h_flex()
2301 .gap_2()
2302 .child(
2303 Button::new("dismiss", "Dismiss")
2304 .style(ButtonStyle::Transparent)
2305 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2306 )
2307 .child(Button::new("more-info", "More Info").on_click(
2308 |_event, window, cx| {
2309 window.dispatch_action(
2310 Box::new(zed_actions::OpenAccountSettings),
2311 cx,
2312 )
2313 },
2314 )),
2315 ),
2316 ),
2317 )
2318 }
2319}
2320
/// Key-value store key whose presence records that the user opted out of the
/// rate-limit ("Out of Tokens") notice. It is written with the value "1" and
/// deleted when the user opts back in.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2322
2323fn dismissed_rate_limit_notice() -> bool {
2324 db::kvp::KEY_VALUE_STORE
2325 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2326 .log_err()
2327 .map_or(false, |s| s.is_some())
2328}
2329
2330fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut App) {
2331 db::write_and_log(cx, move || async move {
2332 if is_dismissed {
2333 db::kvp::KEY_VALUE_STORE
2334 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2335 .await
2336 } else {
2337 db::kvp::KEY_VALUE_STORE
2338 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2339 .await
2340 }
2341 })
2342}
2343
/// State for a single inline assist within an editor.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchor range in the editor's buffer that this assist covers.
    range: Range<Anchor>,
    /// Editor hosting the assist (weak, so the assist does not keep it alive).
    editor: WeakEntity<Editor>,
    /// Prompt editor and block ids. When `None`, the assist is finished as soon as
    /// its codegen emits `CodegenEvent::Finished`, and errors are shown as a toast.
    decorations: Option<InlineAssistDecorations>,
    /// Generator producing the transformed text.
    codegen: Entity<Codegen>,
    /// Focus, event and observation subscriptions created in `InlineAssist::new`.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to reach the Assistant Panel and to show error toasts.
    workspace: Option<WeakEntity<Workspace>>,
    /// Whether requests should include the Assistant Panel's active context.
    include_context: bool,
}
2354
2355impl InlineAssist {
2356 fn new(
2357 assist_id: InlineAssistId,
2358 group_id: InlineAssistGroupId,
2359 include_context: bool,
2360 editor: &Entity<Editor>,
2361 prompt_editor: &Entity<PromptEditor>,
2362 prompt_block_id: CustomBlockId,
2363 end_block_id: CustomBlockId,
2364 range: Range<Anchor>,
2365 codegen: Entity<Codegen>,
2366 workspace: Option<WeakEntity<Workspace>>,
2367 window: &mut Window,
2368 cx: &mut App,
2369 ) -> Self {
2370 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2371 InlineAssist {
2372 group_id,
2373 include_context,
2374 editor: editor.downgrade(),
2375 decorations: Some(InlineAssistDecorations {
2376 prompt_block_id,
2377 prompt_editor: prompt_editor.clone(),
2378 removed_line_block_ids: HashSet::default(),
2379 end_block_id,
2380 }),
2381 range,
2382 codegen: codegen.clone(),
2383 workspace: workspace.clone(),
2384 _subscriptions: vec![
2385 window.on_focus_in(&prompt_editor_focus_handle, cx, move |_, cx| {
2386 InlineAssistant::update_global(cx, |this, cx| {
2387 this.handle_prompt_editor_focus_in(assist_id, cx)
2388 })
2389 }),
2390 window.on_focus_out(&prompt_editor_focus_handle, cx, move |_, _, cx| {
2391 InlineAssistant::update_global(cx, |this, cx| {
2392 this.handle_prompt_editor_focus_out(assist_id, cx)
2393 })
2394 }),
2395 window.subscribe(
2396 prompt_editor,
2397 cx,
2398 move |prompt_editor, event, window, cx| {
2399 InlineAssistant::update_global(cx, |this, cx| {
2400 this.handle_prompt_editor_event(prompt_editor, event, window, cx)
2401 })
2402 },
2403 ),
2404 window.observe(&codegen, cx, {
2405 let editor = editor.downgrade();
2406 move |_, window, cx| {
2407 if let Some(editor) = editor.upgrade() {
2408 InlineAssistant::update_global(cx, |this, cx| {
2409 if let Some(editor_assists) =
2410 this.assists_by_editor.get(&editor.downgrade())
2411 {
2412 editor_assists.highlight_updates.send(()).ok();
2413 }
2414
2415 this.update_editor_blocks(&editor, assist_id, window, cx);
2416 })
2417 }
2418 }
2419 }),
2420 window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
2421 InlineAssistant::update_global(cx, |this, cx| match event {
2422 CodegenEvent::Undone => this.finish_assist(assist_id, false, window, cx),
2423 CodegenEvent::Finished => {
2424 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2425 assist
2426 } else {
2427 return;
2428 };
2429
2430 if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2431 if assist.decorations.is_none() {
2432 if let Some(workspace) = assist
2433 .workspace
2434 .as_ref()
2435 .and_then(|workspace| workspace.upgrade())
2436 {
2437 let error = format!("Inline assistant error: {}", error);
2438 workspace.update(cx, |workspace, cx| {
2439 struct InlineAssistantError;
2440
2441 let id =
2442 NotificationId::composite::<InlineAssistantError>(
2443 assist_id.0,
2444 );
2445
2446 workspace.show_toast(Toast::new(id, error), cx);
2447 })
2448 }
2449 }
2450 }
2451
2452 if assist.decorations.is_none() {
2453 this.finish_assist(assist_id, false, window, cx);
2454 }
2455 }
2456 })
2457 }),
2458 ],
2459 }
2460 }
2461
2462 fn user_prompt(&self, cx: &App) -> Option<String> {
2463 let decorations = self.decorations.as_ref()?;
2464 Some(decorations.prompt_editor.read(cx).prompt(cx))
2465 }
2466
2467 fn assistant_panel_context(&self, cx: &mut App) -> Option<LanguageModelRequest> {
2468 if self.include_context {
2469 let workspace = self.workspace.as_ref()?;
2470 let workspace = workspace.upgrade()?.read(cx);
2471 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2472 Some(
2473 assistant_panel
2474 .read(cx)
2475 .active_context(cx)?
2476 .read(cx)
2477 .to_completion_request(None, RequestType::Chat, cx),
2478 )
2479 } else {
2480 None
2481 }
2482 }
2483
2484 pub fn count_tokens(&self, cx: &mut App) -> BoxFuture<'static, Result<TokenCounts>> {
2485 let Some(user_prompt) = self.user_prompt(cx) else {
2486 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2487 };
2488 let assistant_panel_context = self.assistant_panel_context(cx);
2489 self.codegen
2490 .read(cx)
2491 .count_tokens(user_prompt, assistant_panel_context, cx)
2492 }
2493}
2494
/// Editor UI attached to an inline assist: the prompt editor plus the custom
/// block ids created for it.
struct InlineAssistDecorations {
    /// Block id passed in alongside the prompt editor (presumably the block that renders it — TODO confirm).
    prompt_block_id: CustomBlockId,
    /// Editor holding the user's prompt; `InlineAssist::user_prompt` reads from it.
    prompt_editor: Entity<PromptEditor>,
    /// Blocks for removed lines; starts empty (presumably filled as the diff is rendered — TODO confirm).
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Second custom block id supplied at creation (name suggests it marks the end of the assist — confirm).
    end_block_id: CustomBlockId,
}
2501
/// Events emitted by [`CodegenAlternative`] and forwarded by [`Codegen`].
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation ended: the stream completed or failed, or it was stopped.
    Finished,
    /// The buffer undid the alternative's transformation transaction.
    Undone,
}
2507
/// Generates inline-assist output with one or more models and lets the user
/// cycle between the resulting alternatives.
pub struct Codegen {
    /// One alternative per model; index 0 is driven by the primary (default) model.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently applied to the buffer.
    active_alternative: usize,
    /// Indices of every alternative that has been activated.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions forwarding the active alternative's notifications and events.
    subscriptions: Vec<Subscription>,
    /// Buffer being transformed; handed to every new alternative.
    buffer: Entity<MultiBuffer>,
    /// Range being transformed; handed to every new alternative.
    range: Range<Anchor>,
    /// Transaction undone (once) by `undo`, after the active alternative's edits.
    initial_transaction_id: Option<TransactionId>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2520
2521impl Codegen {
2522 pub fn new(
2523 buffer: Entity<MultiBuffer>,
2524 range: Range<Anchor>,
2525 initial_transaction_id: Option<TransactionId>,
2526 telemetry: Arc<Telemetry>,
2527 builder: Arc<PromptBuilder>,
2528 cx: &mut Context<Self>,
2529 ) -> Self {
2530 let codegen = cx.new(|cx| {
2531 CodegenAlternative::new(
2532 buffer.clone(),
2533 range.clone(),
2534 false,
2535 Some(telemetry.clone()),
2536 builder.clone(),
2537 cx,
2538 )
2539 });
2540 let mut this = Self {
2541 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2542 alternatives: vec![codegen],
2543 active_alternative: 0,
2544 seen_alternatives: HashSet::default(),
2545 subscriptions: Vec::new(),
2546 buffer,
2547 range,
2548 initial_transaction_id,
2549 telemetry,
2550 builder,
2551 };
2552 this.activate(0, cx);
2553 this
2554 }
2555
2556 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
2557 let codegen = self.active_alternative().clone();
2558 self.subscriptions.clear();
2559 self.subscriptions
2560 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2561 self.subscriptions
2562 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2563 }
2564
2565 fn active_alternative(&self) -> &Entity<CodegenAlternative> {
2566 &self.alternatives[self.active_alternative]
2567 }
2568
2569 fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
2570 &self.active_alternative().read(cx).status
2571 }
2572
2573 fn alternative_count(&self, cx: &App) -> usize {
2574 LanguageModelRegistry::read_global(cx)
2575 .inline_alternative_models()
2576 .len()
2577 + 1
2578 }
2579
2580 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
2581 let next_active_ix = if self.active_alternative == 0 {
2582 self.alternatives.len() - 1
2583 } else {
2584 self.active_alternative - 1
2585 };
2586 self.activate(next_active_ix, cx);
2587 }
2588
2589 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
2590 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2591 self.activate(next_active_ix, cx);
2592 }
2593
2594 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
2595 self.active_alternative()
2596 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2597 self.seen_alternatives.insert(index);
2598 self.active_alternative = index;
2599 self.active_alternative()
2600 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2601 self.subscribe_to_alternative(cx);
2602 cx.notify();
2603 }
2604
2605 pub fn start(
2606 &mut self,
2607 user_prompt: String,
2608 assistant_panel_context: Option<LanguageModelRequest>,
2609 cx: &mut Context<Self>,
2610 ) -> Result<()> {
2611 let alternative_models = LanguageModelRegistry::read_global(cx)
2612 .inline_alternative_models()
2613 .to_vec();
2614
2615 self.active_alternative()
2616 .update(cx, |alternative, cx| alternative.undo(cx));
2617 self.activate(0, cx);
2618 self.alternatives.truncate(1);
2619
2620 for _ in 0..alternative_models.len() {
2621 self.alternatives.push(cx.new(|cx| {
2622 CodegenAlternative::new(
2623 self.buffer.clone(),
2624 self.range.clone(),
2625 false,
2626 Some(self.telemetry.clone()),
2627 self.builder.clone(),
2628 cx,
2629 )
2630 }));
2631 }
2632
2633 let primary_model = LanguageModelRegistry::read_global(cx)
2634 .default_model()
2635 .context("no active model")?
2636 .model;
2637
2638 for (model, alternative) in iter::once(primary_model)
2639 .chain(alternative_models)
2640 .zip(&self.alternatives)
2641 {
2642 alternative.update(cx, |alternative, cx| {
2643 alternative.start(
2644 user_prompt.clone(),
2645 assistant_panel_context.clone(),
2646 model.clone(),
2647 cx,
2648 )
2649 })?;
2650 }
2651
2652 Ok(())
2653 }
2654
2655 pub fn stop(&mut self, cx: &mut Context<Self>) {
2656 for codegen in &self.alternatives {
2657 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2658 }
2659 }
2660
2661 pub fn undo(&mut self, cx: &mut Context<Self>) {
2662 self.active_alternative()
2663 .update(cx, |codegen, cx| codegen.undo(cx));
2664
2665 self.buffer.update(cx, |buffer, cx| {
2666 if let Some(transaction_id) = self.initial_transaction_id.take() {
2667 buffer.undo_transaction(transaction_id, cx);
2668 buffer.refresh_preview(cx);
2669 }
2670 });
2671 }
2672
2673 pub fn count_tokens(
2674 &self,
2675 user_prompt: String,
2676 assistant_panel_context: Option<LanguageModelRequest>,
2677 cx: &App,
2678 ) -> BoxFuture<'static, Result<TokenCounts>> {
2679 self.active_alternative()
2680 .read(cx)
2681 .count_tokens(user_prompt, assistant_panel_context, cx)
2682 }
2683
2684 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
2685 self.active_alternative().read(cx).buffer.clone()
2686 }
2687
2688 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
2689 self.active_alternative().read(cx).old_buffer.clone()
2690 }
2691
2692 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
2693 self.active_alternative().read(cx).snapshot.clone()
2694 }
2695
2696 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
2697 self.active_alternative().read(cx).edit_position
2698 }
2699
2700 fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
2701 &self.active_alternative().read(cx).diff
2702 }
2703
2704 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
2705 self.active_alternative().read(cx).last_equal_ranges()
2706 }
2707}
2708
/// Lets `Codegen` re-emit the `CodegenEvent`s of its active alternative.
impl EventEmitter<CodegenEvent> for Codegen {}
2710
/// One model's attempt at an inline transformation, owned by a [`Codegen`].
pub struct CodegenAlternative {
    /// Multibuffer the generated edits are written into.
    buffer: Entity<MultiBuffer>,
    /// Standalone local copy of the original buffer's text, line ending and language.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot taken at creation; streamed edits are anchored against it.
    snapshot: MultiBufferSnapshot,
    /// Where streaming is currently writing; set on `start` and after every batch.
    edit_position: Option<Anchor>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the most recent diff batch.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all of this alternative's edits are merged into.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// Task driving the current stream; replaced with a ready task to drop it.
    generation: Task<()>,
    /// Deleted/inserted row ranges, rebuilt by the `reapply_*_diff` methods.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to `buffer`.
    active: bool,
    /// Every edit produced so far; replayed when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest line-based diff operations from the stream.
    line_operations: Vec<LineOperation>,
    /// Last request sent to the model (unset for the "delete" shortcut).
    request: Option<LanguageModelRequest>,
    /// Seconds from `handle_stream` to the end of generation.
    elapsed_time: Option<f64>,
    /// Raw completion text received from the model.
    completion: Option<String>,
    /// Message id reported by the completion stream, if any.
    message_id: Option<String>,
}
2733
/// Lifecycle of a [`CodegenAlternative`]'s generation.
enum CodegenStatus {
    /// Nothing generated: the initial state, or stopped before any diff existed.
    Idle,
    /// A completion stream is being consumed.
    Pending,
    /// Generation succeeded, or was stopped after producing a diff.
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}
2740
/// Row ranges deleted from and inserted into the buffer by an alternative.
#[derive(Default)]
struct Diff {
    /// Presumably an anchor at the deletion site plus the deleted rows of the old snapshot — TODO confirm.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges of inserted rows.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2746
2747impl Diff {
2748 fn is_empty(&self) -> bool {
2749 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2750 }
2751}
2752
/// `CodegenAlternative` emits `CodegenEvent::Finished` and `CodegenEvent::Undone`.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2754
2755impl CodegenAlternative {
2756 pub fn new(
2757 multi_buffer: Entity<MultiBuffer>,
2758 range: Range<Anchor>,
2759 active: bool,
2760 telemetry: Option<Arc<Telemetry>>,
2761 builder: Arc<PromptBuilder>,
2762 cx: &mut Context<Self>,
2763 ) -> Self {
2764 let snapshot = multi_buffer.read(cx).snapshot(cx);
2765
2766 let (buffer, _, _) = snapshot
2767 .range_to_buffer_ranges(range.clone())
2768 .pop()
2769 .unwrap();
2770 let old_buffer = cx.new(|cx| {
2771 let text = buffer.as_rope().clone();
2772 let line_ending = buffer.line_ending();
2773 let language = buffer.language().cloned();
2774 let language_registry = multi_buffer
2775 .read(cx)
2776 .buffer(buffer.remote_id())
2777 .unwrap()
2778 .read(cx)
2779 .language_registry();
2780
2781 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2782 buffer.set_language(language, cx);
2783 if let Some(language_registry) = language_registry {
2784 buffer.set_language_registry(language_registry)
2785 }
2786 buffer
2787 });
2788
2789 Self {
2790 buffer: multi_buffer.clone(),
2791 old_buffer,
2792 edit_position: None,
2793 message_id: None,
2794 snapshot,
2795 last_equal_ranges: Default::default(),
2796 transformation_transaction_id: None,
2797 status: CodegenStatus::Idle,
2798 generation: Task::ready(()),
2799 diff: Diff::default(),
2800 telemetry,
2801 _subscription: cx.subscribe(&multi_buffer, Self::handle_buffer_event),
2802 builder,
2803 active,
2804 edits: Vec::new(),
2805 line_operations: Vec::new(),
2806 range,
2807 request: None,
2808 elapsed_time: None,
2809 completion: None,
2810 }
2811 }
2812
2813 fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
2814 if active != self.active {
2815 self.active = active;
2816
2817 if self.active {
2818 let edits = self.edits.clone();
2819 self.apply_edits(edits, cx);
2820 if matches!(self.status, CodegenStatus::Pending) {
2821 let line_operations = self.line_operations.clone();
2822 self.reapply_line_based_diff(line_operations, cx);
2823 } else {
2824 self.reapply_batch_diff(cx).detach();
2825 }
2826 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2827 self.buffer.update(cx, |buffer, cx| {
2828 buffer.undo_transaction(transaction_id, cx);
2829 buffer.forget_transaction(transaction_id, cx);
2830 });
2831 }
2832 }
2833 }
2834
2835 fn handle_buffer_event(
2836 &mut self,
2837 _buffer: Entity<MultiBuffer>,
2838 event: &multi_buffer::Event,
2839 cx: &mut Context<Self>,
2840 ) {
2841 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2842 if self.transformation_transaction_id == Some(*transaction_id) {
2843 self.transformation_transaction_id = None;
2844 self.generation = Task::ready(());
2845 cx.emit(CodegenEvent::Undone);
2846 }
2847 }
2848 }
2849
2850 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2851 &self.last_equal_ranges
2852 }
2853
2854 pub fn count_tokens(
2855 &self,
2856 user_prompt: String,
2857 assistant_panel_context: Option<LanguageModelRequest>,
2858 cx: &App,
2859 ) -> BoxFuture<'static, Result<TokenCounts>> {
2860 if let Some(ConfiguredModel { model, .. }) =
2861 LanguageModelRegistry::read_global(cx).inline_assistant_model()
2862 {
2863 let request =
2864 self.build_request(&model, user_prompt, assistant_panel_context.clone(), cx);
2865 match request {
2866 Ok(request) => {
2867 let total_count = model.count_tokens(request.clone(), cx);
2868 let assistant_panel_count = assistant_panel_context
2869 .map(|context| model.count_tokens(context, cx))
2870 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2871
2872 async move {
2873 Ok(TokenCounts {
2874 total: total_count.await?,
2875 assistant_panel: assistant_panel_count.await?,
2876 })
2877 }
2878 .boxed()
2879 }
2880 Err(error) => futures::future::ready(Err(error)).boxed(),
2881 }
2882 } else {
2883 future::ready(Err(anyhow!("no active model"))).boxed()
2884 }
2885 }
2886
2887 pub fn start(
2888 &mut self,
2889 user_prompt: String,
2890 assistant_panel_context: Option<LanguageModelRequest>,
2891 model: Arc<dyn LanguageModel>,
2892 cx: &mut Context<Self>,
2893 ) -> Result<()> {
2894 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2895 self.buffer.update(cx, |buffer, cx| {
2896 buffer.undo_transaction(transformation_transaction_id, cx);
2897 });
2898 }
2899
2900 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2901
2902 let api_key = model.api_key(cx);
2903 let telemetry_id = model.telemetry_id();
2904 let provider_id = model.provider_id();
2905 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2906 if user_prompt.trim().to_lowercase() == "delete" {
2907 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2908 } else {
2909 let request =
2910 self.build_request(&model, user_prompt, assistant_panel_context, cx)?;
2911 self.request = Some(request.clone());
2912
2913 cx.spawn(async move |_, cx| model.stream_completion_text(request, &cx).await)
2914 .boxed_local()
2915 };
2916 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2917 Ok(())
2918 }
2919
2920 fn build_request(
2921 &self,
2922 model: &Arc<dyn LanguageModel>,
2923 user_prompt: String,
2924 assistant_panel_context: Option<LanguageModelRequest>,
2925 cx: &App,
2926 ) -> Result<LanguageModelRequest> {
2927 let buffer = self.buffer.read(cx).snapshot(cx);
2928 let language = buffer.language_at(self.range.start);
2929 let language_name = if let Some(language) = language.as_ref() {
2930 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2931 None
2932 } else {
2933 Some(language.name())
2934 }
2935 } else {
2936 None
2937 };
2938
2939 let language_name = language_name.as_ref();
2940 let start = buffer.point_to_buffer_offset(self.range.start);
2941 let end = buffer.point_to_buffer_offset(self.range.end);
2942 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2943 let (start_buffer, start_buffer_offset) = start;
2944 let (end_buffer, end_buffer_offset) = end;
2945 if start_buffer.remote_id() == end_buffer.remote_id() {
2946 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2947 } else {
2948 return Err(anyhow::anyhow!("invalid transformation range"));
2949 }
2950 } else {
2951 return Err(anyhow::anyhow!("invalid transformation range"));
2952 };
2953
2954 let prompt = self
2955 .builder
2956 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2957 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2958
2959 let mut messages = Vec::new();
2960 if let Some(context_request) = assistant_panel_context {
2961 messages = context_request.messages;
2962 }
2963
2964 messages.push(LanguageModelRequestMessage {
2965 role: Role::User,
2966 content: vec![prompt.into()],
2967 cache: false,
2968 });
2969
2970 Ok(LanguageModelRequest {
2971 thread_id: None,
2972 prompt_id: None,
2973 mode: None,
2974 messages,
2975 tools: Vec::new(),
2976 stop: Vec::new(),
2977 temperature: AssistantSettings::temperature_for_model(&model, cx),
2978 })
2979 }
2980
    /// Consumes a completion `stream` and writes the generated text into the buffer.
    ///
    /// Synchronously, this resets `self.diff`, sets the status to `Pending`, and
    /// replaces `self.generation` with a task that:
    /// - awaits `stream`. On a background task, it re-indents each streamed line
    ///   relative to the selection, diffs the result against the originally
    ///   selected text (`StreamingDiff` + `LineDiff`), and sends
    ///   `(char ops, line ops)` batches over a channel. When diffing ends, that
    ///   background task reports an inline `AssistantEventData` response event
    ///   (latency, error, language).
    /// - turns each received batch into anchored edits. The edits and the
    ///   line-based diff are applied only while this alternative is active, but
    ///   are always recorded in `self.edits` / `self.line_operations`.
    /// - reapplies a batch diff once the stream ends. It then stores the message
    ///   id, elapsed time and raw completion, sets the status to `Done` or
    ///   `Error`, and emits `CodegenEvent::Finished`.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut Context<Self>,
    ) {
        let start_time = Instant::now();
        let snapshot = self.snapshot.clone();
        // The originally selected text; the streaming diff is computed against it.
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client();
        let telemetry = self.telemetry.clone();
        // Language of the first buffer the range touches; only used for telemetry.
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let snapshot = multibuffer.snapshot(cx);
            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset into `snapshot` up to which char operations have been consumed;
        // advanced by `Delete` and `Keep`, left in place by `Insert`.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        // Raw completion text, appended by the background task and stored at the end.
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(async move |codegen, cx| {
            let stream = stream.await;
            let message_id = stream
                .as_ref()
                .ok()
                .and_then(|stream| stream.message_id.clone());
            let generate = async {
                // Bounded channel from the background diff task to the loop below.
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                let executor = cx.background_executor().clone();
                let message_id = message_id.clone();
                let line_based_stream_diff: Task<anyhow::Result<()>> =
                    cx.background_spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks =
                                StripInvalidSpans::new(stream?.stream.map_err(|e| e.into()));
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            // Text of the current line not yet pushed into the diff.
                            let mut new_text = String::new();
                            // Leading-whitespace width of the first indented line.
                            let mut base_indent = None;
                            // Leading-whitespace width of the current line, once known.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;
                                completion_clone.lock().push_str(&chunk);

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    // Once the line's first non-whitespace char arrives, replace the
                                    // model's leading whitespace with the suggested indent shifted by
                                    // this line's offset from `base_indent`. On the first line the
                                    // selection's start column is already indentation, so it is subtracted.
                                    if line_indent.is_none() {
                                        if let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                        {
                                            line_indent = Some(non_whitespace_ch_ix);
                                            base_indent = base_indent.or(line_indent);

                                            let line_indent = line_indent.unwrap();
                                            let base_indent = base_indent.unwrap();
                                            let indent_delta =
                                                line_indent as i32 - base_indent as i32;
                                            let mut corrected_indent_len = cmp::max(
                                                0,
                                                suggested_line_indent.len as i32 + indent_delta,
                                            )
                                                as usize;
                                            if first_line {
                                                corrected_indent_len = corrected_indent_len
                                                    .saturating_sub(
                                                        selection_start.column as usize,
                                                    );
                                            }

                                            let indent_char = suggested_line_indent.char();
                                            let mut indent_buffer = [0; 4];
                                            let indent_str =
                                                indent_char.encode_utf8(&mut indent_buffer);
                                            new_text.replace_range(
                                                ..line_indent,
                                                &indent_str.repeat(corrected_indent_len),
                                            );
                                        }
                                    }

                                    // Indentation settled: flush this line's text into the diffs.
                                    if line_indent.is_some() {
                                        let char_ops = diff.push_new(&new_text);
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        new_text.clear();
                                    }

                                    // A newline followed this segment within the chunk.
                                    if lines.peek().is_some() {
                                        let char_ops = diff.push_new("\n");
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        if line_indent.is_none() {
                                            // Don't write out the leading indentation in empty lines on the next line
                                            // This is the case where the above if statement didn't clear the buffer
                                            new_text.clear();
                                        }
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }

                            // Flush any trailing text and finalize both diffs.
                            let mut char_ops = diff.push_new(&new_text);
                            char_ops.extend(diff.finish());
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        let error_message = result.as_ref().err().map(|error| error.to_string());
                        report_assistant_event(
                            AssistantEventData {
                                conversation_id: None,
                                message_id,
                                kind: AssistantKind::Inline,
                                phase: AssistantPhase::Response,
                                model: model_telemetry_id,
                                model_provider: model_provider_id.to_string(),
                                response_latency,
                                error_message,
                                language_name: language_name.map(|name| name.to_proto()),
                            },
                            telemetry,
                            http_client,
                            model_api_key,
                            &executor,
                        );

                        result?;
                        Ok(())
                    });

                // Translate each batch of char operations into edits anchored in the
                // original snapshot; `Keep` spans become `last_equal_ranges`.
                while let Some((char_ops, line_ops)) = diff_rx.next().await {
                    codegen.update(cx, |codegen, cx| {
                        codegen.last_equal_ranges.clear();

                        let edits = char_ops
                            .into_iter()
                            .filter_map(|operation| match operation {
                                CharOperation::Insert { text } => {
                                    let edit_start = snapshot.anchor_after(edit_start);
                                    Some((edit_start..edit_start, text))
                                }
                                CharOperation::Delete { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    Some((edit_range, String::new()))
                                }
                                CharOperation::Keep { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    codegen.last_equal_ranges.push(edit_range);
                                    None
                                }
                            })
                            .collect::<Vec<_>>();

                        // Inactive alternatives only record edits; `set_active` replays them later.
                        if codegen.active {
                            codegen.apply_edits(edits.iter().cloned(), cx);
                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                        }
                        codegen.edits.extend(edits);
                        codegen.line_operations = line_ops;
                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                        cx.notify();
                    })?;
                }

                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                let batch_diff_task =
                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
                line_based_stream_diff?;

                anyhow::Ok(())
            };

            let result = generate.await;
            let elapsed_time = start_time.elapsed().as_secs_f64();

            // Record the outcome; `.ok()` ignores the case where the entity is gone.
            codegen
                .update(cx, |this, cx| {
                    this.message_id = message_id;
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    this.elapsed_time = Some(elapsed_time);
                    this.completion = Some(completion.lock().clone());
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
        });
        cx.notify();
    }
3239
3240 pub fn stop(&mut self, cx: &mut Context<Self>) {
3241 self.last_equal_ranges.clear();
3242 if self.diff.is_empty() {
3243 self.status = CodegenStatus::Idle;
3244 } else {
3245 self.status = CodegenStatus::Done;
3246 }
3247 self.generation = Task::ready(());
3248 cx.emit(CodegenEvent::Finished);
3249 cx.notify();
3250 }
3251
3252 pub fn undo(&mut self, cx: &mut Context<Self>) {
3253 self.buffer.update(cx, |buffer, cx| {
3254 if let Some(transaction_id) = self.transformation_transaction_id.take() {
3255 buffer.undo_transaction(transaction_id, cx);
3256 buffer.refresh_preview(cx);
3257 }
3258 });
3259 }
3260
3261 fn apply_edits(
3262 &mut self,
3263 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3264 cx: &mut Context<CodegenAlternative>,
3265 ) {
3266 let transaction = self.buffer.update(cx, |buffer, cx| {
3267 // Avoid grouping assistant edits with user edits.
3268 buffer.finalize_last_transaction(cx);
3269 buffer.start_transaction(cx);
3270 buffer.edit(edits, None, cx);
3271 buffer.end_transaction(cx)
3272 });
3273
3274 if let Some(transaction) = transaction {
3275 if let Some(first_transaction) = self.transformation_transaction_id {
3276 // Group all assistant edits into the first transaction.
3277 self.buffer.update(cx, |buffer, cx| {
3278 buffer.merge_transactions(transaction, first_transaction, cx)
3279 });
3280 } else {
3281 self.transformation_transaction_id = Some(transaction);
3282 self.buffer
3283 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3284 }
3285 }
3286 }
3287
3288 fn reapply_line_based_diff(
3289 &mut self,
3290 line_operations: impl IntoIterator<Item = LineOperation>,
3291 cx: &mut Context<Self>,
3292 ) {
3293 let old_snapshot = self.snapshot.clone();
3294 let old_range = self.range.to_point(&old_snapshot);
3295 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3296 let new_range = self.range.to_point(&new_snapshot);
3297
3298 let mut old_row = old_range.start.row;
3299 let mut new_row = new_range.start.row;
3300
3301 self.diff.deleted_row_ranges.clear();
3302 self.diff.inserted_row_ranges.clear();
3303 for operation in line_operations {
3304 match operation {
3305 LineOperation::Keep { lines } => {
3306 old_row += lines;
3307 new_row += lines;
3308 }
3309 LineOperation::Delete { lines } => {
3310 let old_end_row = old_row + lines - 1;
3311 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3312
3313 if let Some((_, last_deleted_row_range)) =
3314 self.diff.deleted_row_ranges.last_mut()
3315 {
3316 if *last_deleted_row_range.end() + 1 == old_row {
3317 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3318 } else {
3319 self.diff
3320 .deleted_row_ranges
3321 .push((new_row, old_row..=old_end_row));
3322 }
3323 } else {
3324 self.diff
3325 .deleted_row_ranges
3326 .push((new_row, old_row..=old_end_row));
3327 }
3328
3329 old_row += lines;
3330 }
3331 LineOperation::Insert { lines } => {
3332 let new_end_row = new_row + lines - 1;
3333 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3334 let end = new_snapshot.anchor_before(Point::new(
3335 new_end_row,
3336 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3337 ));
3338 self.diff.inserted_row_ranges.push(start..end);
3339 new_row += lines;
3340 }
3341 }
3342
3343 cx.notify();
3344 }
3345 }
3346
3347 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
3348 let old_snapshot = self.snapshot.clone();
3349 let old_range = self.range.to_point(&old_snapshot);
3350 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3351 let new_range = self.range.to_point(&new_snapshot);
3352
3353 cx.spawn(async move |codegen, cx| {
3354 let (deleted_row_ranges, inserted_row_ranges) = cx
3355 .background_spawn(async move {
3356 let old_text = old_snapshot
3357 .text_for_range(
3358 Point::new(old_range.start.row, 0)
3359 ..Point::new(
3360 old_range.end.row,
3361 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3362 ),
3363 )
3364 .collect::<String>();
3365 let new_text = new_snapshot
3366 .text_for_range(
3367 Point::new(new_range.start.row, 0)
3368 ..Point::new(
3369 new_range.end.row,
3370 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3371 ),
3372 )
3373 .collect::<String>();
3374
3375 let old_start_row = old_range.start.row;
3376 let new_start_row = new_range.start.row;
3377 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3378 let mut inserted_row_ranges = Vec::new();
3379 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
3380 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
3381 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
3382 if !old_rows.is_empty() {
3383 deleted_row_ranges.push((
3384 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
3385 old_rows.start..=old_rows.end - 1,
3386 ));
3387 }
3388 if !new_rows.is_empty() {
3389 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
3390 let new_end_row = new_rows.end - 1;
3391 let end = new_snapshot.anchor_before(Point::new(
3392 new_end_row,
3393 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3394 ));
3395 inserted_row_ranges.push(start..end);
3396 }
3397 }
3398 (deleted_row_ranges, inserted_row_ranges)
3399 })
3400 .await;
3401
3402 codegen
3403 .update(cx, |codegen, cx| {
3404 codegen.diff.deleted_row_ranges = deleted_row_ranges;
3405 codegen.diff.inserted_row_ranges = inserted_row_ranges;
3406 cx.notify();
3407 })
3408 .ok();
3409 })
3410 }
3411}
3412
/// Stream adapter that cleans model output before it is inserted into a buffer.
///
/// It removes `<|CURSOR|>` markers. When the response opens with a Markdown code fence, it
/// also drops that opening fence line and a final line ending with a fence.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Text received from `stream` that has not been emitted or dropped yet.
    buffer: String,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// True until a newline has been seen in the input.
    first_line: bool,
    /// True when a complete line was emitted and its `\n` is still owed to the output.
    line_end: bool,
    /// Set when the first line of the response was an opening code fence.
    starts_with_code_block: bool,
}
3421
3422impl<T> StripInvalidSpans<T>
3423where
3424 T: Stream<Item = Result<String>>,
3425{
3426 fn new(stream: T) -> Self {
3427 Self {
3428 stream,
3429 stream_done: false,
3430 buffer: String::new(),
3431 first_line: true,
3432 line_end: false,
3433 starts_with_code_block: false,
3434 }
3435 }
3436}
3437
3438impl<T> Stream for StripInvalidSpans<T>
3439where
3440 T: Stream<Item = Result<String>>,
3441{
3442 type Item = Result<String>;
3443
3444 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3445 const CODE_BLOCK_DELIMITER: &str = "```";
3446 const CURSOR_SPAN: &str = "<|CURSOR|>";
3447
3448 let this = unsafe { self.get_unchecked_mut() };
3449 loop {
3450 if !this.stream_done {
3451 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3452 match stream.as_mut().poll_next(cx) {
3453 Poll::Ready(Some(Ok(chunk))) => {
3454 this.buffer.push_str(&chunk);
3455 }
3456 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3457 Poll::Ready(None) => {
3458 this.stream_done = true;
3459 }
3460 Poll::Pending => return Poll::Pending,
3461 }
3462 }
3463
3464 let mut chunk = String::new();
3465 let mut consumed = 0;
3466 if !this.buffer.is_empty() {
3467 let mut lines = this.buffer.split('\n').enumerate().peekable();
3468 while let Some((line_ix, line)) = lines.next() {
3469 if line_ix > 0 {
3470 this.first_line = false;
3471 }
3472
3473 if this.first_line {
3474 let trimmed_line = line.trim();
3475 if lines.peek().is_some() {
3476 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3477 consumed += line.len() + 1;
3478 this.starts_with_code_block = true;
3479 continue;
3480 }
3481 } else if trimmed_line.is_empty()
3482 || prefixes(CODE_BLOCK_DELIMITER)
3483 .any(|prefix| trimmed_line.starts_with(prefix))
3484 {
3485 break;
3486 }
3487 }
3488
3489 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3490 if lines.peek().is_some() {
3491 if this.line_end {
3492 chunk.push('\n');
3493 }
3494
3495 chunk.push_str(&line_without_cursor);
3496 this.line_end = true;
3497 consumed += line.len() + 1;
3498 } else if this.stream_done {
3499 if !this.starts_with_code_block
3500 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3501 {
3502 if this.line_end {
3503 chunk.push('\n');
3504 }
3505
3506 chunk.push_str(&line);
3507 }
3508
3509 consumed += line.len();
3510 } else {
3511 let trimmed_line = line.trim();
3512 if trimmed_line.is_empty()
3513 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3514 || prefixes(CODE_BLOCK_DELIMITER)
3515 .any(|prefix| trimmed_line.ends_with(prefix))
3516 {
3517 break;
3518 } else {
3519 if this.line_end {
3520 chunk.push('\n');
3521 this.line_end = false;
3522 }
3523
3524 chunk.push_str(&line_without_cursor);
3525 consumed += line.len();
3526 }
3527 }
3528 }
3529 }
3530
3531 this.buffer = this.buffer.split_off(consumed);
3532 if !chunk.is_empty() {
3533 return Poll::Ready(Some(Ok(chunk)));
3534 } else if this.stream_done {
3535 return Poll::Ready(None);
3536 }
3537 }
3538 }
3539}
3540
3541struct AssistantCodeActionProvider {
3542 editor: WeakEntity<Editor>,
3543 workspace: WeakEntity<Workspace>,
3544}
3545
/// Identifier reported by `AssistantCodeActionProvider::id`.
const ASSISTANT_CODE_ACTION_PROVIDER_ID: &str = "assistant";
3547
3548impl CodeActionProvider for AssistantCodeActionProvider {
3549 fn id(&self) -> Arc<str> {
3550 ASSISTANT_CODE_ACTION_PROVIDER_ID.into()
3551 }
3552
3553 fn code_actions(
3554 &self,
3555 buffer: &Entity<Buffer>,
3556 range: Range<text::Anchor>,
3557 _: &mut Window,
3558 cx: &mut App,
3559 ) -> Task<Result<Vec<CodeAction>>> {
3560 if !Assistant::enabled(cx) {
3561 return Task::ready(Ok(Vec::new()));
3562 }
3563
3564 let snapshot = buffer.read(cx).snapshot();
3565 let mut range = range.to_point(&snapshot);
3566
3567 // Expand the range to line boundaries.
3568 range.start.column = 0;
3569 range.end.column = snapshot.line_len(range.end.row);
3570
3571 let mut has_diagnostics = false;
3572 for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3573 range.start = cmp::min(range.start, diagnostic.range.start);
3574 range.end = cmp::max(range.end, diagnostic.range.end);
3575 has_diagnostics = true;
3576 }
3577 if has_diagnostics {
3578 if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3579 if let Some(symbol) = symbols_containing_start.last() {
3580 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3581 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3582 }
3583 }
3584
3585 if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3586 if let Some(symbol) = symbols_containing_end.last() {
3587 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3588 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3589 }
3590 }
3591
3592 Task::ready(Ok(vec![CodeAction {
3593 server_id: language::LanguageServerId(0),
3594 range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3595 lsp_action: LspAction::Action(Box::new(lsp::CodeAction {
3596 title: "Fix with Assistant".into(),
3597 ..Default::default()
3598 })),
3599 resolved: true,
3600 }]))
3601 } else {
3602 Task::ready(Ok(Vec::new()))
3603 }
3604 }
3605
3606 fn apply_code_action(
3607 &self,
3608 buffer: Entity<Buffer>,
3609 action: CodeAction,
3610 excerpt_id: ExcerptId,
3611 _push_to_history: bool,
3612 window: &mut Window,
3613 cx: &mut App,
3614 ) -> Task<Result<ProjectTransaction>> {
3615 let editor = self.editor.clone();
3616 let workspace = self.workspace.clone();
3617 window.spawn(cx, async move |cx| {
3618 let editor = editor.upgrade().context("editor was released")?;
3619 let range = editor
3620 .update(cx, |editor, cx| {
3621 editor.buffer().update(cx, |multibuffer, cx| {
3622 let buffer = buffer.read(cx);
3623 let multibuffer_snapshot = multibuffer.read(cx);
3624
3625 let old_context_range =
3626 multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3627 let mut new_context_range = old_context_range.clone();
3628 if action
3629 .range
3630 .start
3631 .cmp(&old_context_range.start, buffer)
3632 .is_lt()
3633 {
3634 new_context_range.start = action.range.start;
3635 }
3636 if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3637 new_context_range.end = action.range.end;
3638 }
3639 drop(multibuffer_snapshot);
3640
3641 if new_context_range != old_context_range {
3642 multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3643 }
3644
3645 let multibuffer_snapshot = multibuffer.read(cx);
3646 Some(
3647 multibuffer_snapshot
3648 .anchor_in_excerpt(excerpt_id, action.range.start)?
3649 ..multibuffer_snapshot
3650 .anchor_in_excerpt(excerpt_id, action.range.end)?,
3651 )
3652 })
3653 })?
3654 .context("invalid range")?;
3655 let assistant_panel = workspace.update(cx, |workspace, cx| {
3656 workspace
3657 .panel::<AssistantPanel>(cx)
3658 .context("assistant panel was released")
3659 })??;
3660
3661 cx.update_global(|assistant: &mut InlineAssistant, window, cx| {
3662 let assist_id = assistant.suggest_assist(
3663 &editor,
3664 range,
3665 "Fix Diagnostics".into(),
3666 None,
3667 true,
3668 Some(workspace),
3669 Some(&assistant_panel),
3670 window,
3671 cx,
3672 );
3673 assistant.start_assist(assist_id, window, cx);
3674 })?;
3675
3676 Ok(ProjectTransaction::default())
3677 })
3678 }
3679}
3680
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, "```" yields "`" and "``". The full string itself is never yielded. Empty
/// and single-character inputs yield nothing. Prefixes always end on a `char` boundary, so
/// non-ASCII input does not panic.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first marks the end of a proper prefix. This avoids the
    // `text.len() - 1` underflow on "" and byte slicing inside a multi-byte char.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3684
3685fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3686 ranges.sort_unstable_by(|a, b| {
3687 a.start
3688 .cmp(&b.start, buffer)
3689 .then_with(|| b.end.cmp(&a.end, buffer))
3690 });
3691
3692 let mut ix = 0;
3693 while ix + 1 < ranges.len() {
3694 let b = ranges[ix + 1].clone();
3695 let a = &mut ranges[ix];
3696 if a.end.cmp(&b.start, buffer).is_gt() {
3697 if a.end.cmp(&b.end, buffer).is_lt() {
3698 a.end = b.end;
3699 }
3700 ranges.remove(ix + 1);
3701 } else {
3702 ix += 1;
3703 }
3704 }
3705}
3706
#[cfg(test)]
mod tests {
    // Tests for `CodegenAlternative` (streamed edits, autoindentation, active/inactive
    // alternatives) and for the `StripInvalidSpans` stream adapter.
    use super::*;
    use futures::stream::{self};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{
        Buffer, Language, LanguageConfig, LanguageMatcher, Point, language_settings,
        tree_sitter_rust,
    };
    use language_model::{LanguageModelRegistry, TokenUsage};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module; possibly a leftover from an
    // older request API — confirm before removing.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Replaces a multi-line selection inside `main` with streamed text, sent in random
    /// chunks of 1..=10 bytes. The final buffer must have the new body indented like the
    /// rest of the function.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of `let x = 0;` through the end of the `for` loop's `}`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the response in random-sized chunks, letting the codegen process each one.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at an empty range placed after existing text on an indented line
    /// (after `le`). The streamed continuation must come out correctly indented.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at an empty range at column 2 of a whitespace-only line inside `main`.
    /// The streamed lines must be indented to the block's level.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Replaces a tab-indented selection in a buffer with no language attached, sending the
    /// whole response as one chunk. The tab indentation must be preserved.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An alternative created inactive must leave the buffer untouched while streaming.
    /// Activating it applies its edits; deactivating it reverts them.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` must remove cursor markers and a wrapping code fence, and leave
    /// other text alone, regardless of how the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Runs `text` through `StripInvalidSpans` at every chunk size from 1 to
        // `text.len()` and checks the concatenated output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into a stream of `Ok` chunks of at most `size` chars each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Starts `handle_stream` on `codegen` with a fake, already-resolved model response.
    /// Its text chunks are supplied through the returned sender, and dropping the sender
    /// ends the response.
    fn simulate_response_stream(
        codegen: Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Minimal Rust language (tree-sitter grammar plus an indents query) so autoindent has
    /// indentation rules to apply in the tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}