1use crate::{
2 assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
3 AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
4 CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, StreamingDiff,
5};
6use anyhow::{anyhow, Context as _, Result};
7use client::{telemetry::Telemetry, ErrorExt};
8use collections::{hash_map, HashMap, HashSet, VecDeque};
9use editor::{
10 actions::{MoveDown, MoveUp, SelectAll},
11 display_map::{
12 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
13 ToDisplayPoint,
14 },
15 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
16 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
17};
18use feature_flags::{FeatureFlagAppExt as _, ZedPro};
19use fs::Fs;
20use futures::{
21 channel::mpsc,
22 future::{BoxFuture, LocalBoxFuture},
23 join,
24 stream::{self, BoxStream},
25 SinkExt, Stream, StreamExt,
26};
27use gpui::{
28 anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
29 FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
30 TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
31};
32use language::{Buffer, IndentKind, Point, Selection, TransactionId};
33use language_model::{
34 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
35};
36use multi_buffer::MultiBufferRow;
37use parking_lot::Mutex;
38use rope::Rope;
39use settings::{Settings, SettingsStore};
40use smol::future::FutureExt;
41use std::{
42 cmp,
43 future::{self, Future},
44 iter, mem,
45 ops::{Range, RangeInclusive},
46 pin::Pin,
47 sync::Arc,
48 task::{self, Poll},
49 time::{Duration, Instant},
50};
51use terminal_view::terminal_panel::TerminalPanel;
52use theme::ThemeSettings;
53use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
54use util::{RangeExt, ResultExt};
55use workspace::{notifications::NotificationId, Toast, Workspace};
56
57pub fn init(
58 fs: Arc<dyn Fs>,
59 prompt_builder: Arc<PromptBuilder>,
60 telemetry: Arc<Telemetry>,
61 cx: &mut AppContext,
62) {
63 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
64 cx.observe_new_views(|_, cx| {
65 let workspace = cx.view().clone();
66 InlineAssistant::update_global(cx, |inline_assistant, cx| {
67 inline_assistant.register_workspace(&workspace, cx)
68 })
69 })
70 .detach();
71}
72
/// Maximum number of distinct prompts kept in the inline assistant's prompt history.
/// Once a new prompt pushes the history past this length, the oldest entry is evicted.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
74
/// Global state shared by all inline assists across all editors.
pub struct InlineAssistant {
    /// Id handed out to the next assist created by `assist` / `suggest_assist`.
    next_assist_id: InlineAssistId,
    /// Id handed out to the next assist group.
    next_assist_group_id: InlineAssistGroupId,
    /// Assists that have not been finished yet, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping: the assists hosted by each editor and its scroll lock.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists created together (e.g. one per selection by `assist`) share a group.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Per-assist status channel. `start_assist` / `stop_assist` publish on the sender;
    /// the entry is removed by `finish_assist`.
    assist_observations: HashMap<
        InlineAssistId,
        (
            async_watch::Sender<AssistStatus>,
            async_watch::Receiver<AssistStatus>,
        ),
    >,
    /// Codegen alternative that was accepted for each confirmed assist, kept so that
    /// `undo_assist` can revert it.
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Recently submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    /// Telemetry sink; always `Some` when constructed through `InlineAssistant::new`.
    telemetry: Option<Arc<Telemetry>>,
    fs: Arc<dyn Fs>,
}
94
/// Coarse lifecycle status of an inline assist, published to observers through a watch channel.
pub enum AssistStatus {
    /// No generation has been started yet (presumably the channel's initial value — confirm
    /// where the channel is created).
    Idle,
    /// Published by `start_assist` after code generation was started.
    Started,
    /// Published by `stop_assist` after code generation was stopped.
    Stopped,
    /// Generation completed (the publisher is not visible in this section of the file).
    Finished,
}
101
102impl AssistStatus {
103 pub fn is_done(&self) -> bool {
104 matches!(self, Self::Stopped | Self::Finished)
105 }
106}
107
// Registered as a GPUI global so `init` can install it with `set_global` and callers can
// reach it through `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
109
110impl InlineAssistant {
111 pub fn new(
112 fs: Arc<dyn Fs>,
113 prompt_builder: Arc<PromptBuilder>,
114 telemetry: Arc<Telemetry>,
115 ) -> Self {
116 Self {
117 next_assist_id: InlineAssistId::default(),
118 next_assist_group_id: InlineAssistGroupId::default(),
119 assists: HashMap::default(),
120 assists_by_editor: HashMap::default(),
121 assist_groups: HashMap::default(),
122 assist_observations: HashMap::default(),
123 confirmed_assists: HashMap::default(),
124 prompt_history: VecDeque::default(),
125 prompt_builder,
126 telemetry: Some(telemetry),
127 fs,
128 }
129 }
130
131 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
132 cx.subscribe(workspace, |_, event, cx| {
133 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
134 })
135 .detach();
136
137 let workspace = workspace.downgrade();
138 cx.observe_global::<SettingsStore>(move |cx| {
139 let Some(workspace) = workspace.upgrade() else {
140 return;
141 };
142 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
143 return;
144 };
145 let enabled = AssistantSettings::get_global(cx).enabled;
146 terminal_panel.update(cx, |terminal_panel, cx| {
147 terminal_panel.asssistant_enabled(enabled, cx)
148 });
149 })
150 .detach();
151 }
152
153 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
154 // When the user manually saves an editor, automatically accepts all finished transformations.
155 if let workspace::Event::UserSavedItem { item, .. } = event {
156 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
157 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
158 for assist_id in editor_assists.assist_ids.clone() {
159 let assist = &self.assists[&assist_id];
160 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
161 self.finish_assist(assist_id, false, cx)
162 }
163 }
164 }
165 }
166 }
167 }
168
169 pub fn assist(
170 &mut self,
171 editor: &View<Editor>,
172 workspace: Option<WeakView<Workspace>>,
173 assistant_panel: Option<&View<AssistantPanel>>,
174 initial_prompt: Option<String>,
175 cx: &mut WindowContext,
176 ) {
177 if let Some(telemetry) = self.telemetry.as_ref() {
178 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
179 telemetry.report_assistant_event(
180 None,
181 telemetry_events::AssistantKind::Inline,
182 telemetry_events::AssistantPhase::Invoked,
183 model.telemetry_id(),
184 None,
185 None,
186 );
187 }
188 }
189 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
190
191 let mut selections = Vec::<Selection<Point>>::new();
192 let mut newest_selection = None;
193 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
194 if selection.end > selection.start {
195 selection.start.column = 0;
196 // If the selection ends at the start of the line, we don't want to include it.
197 if selection.end.column == 0 {
198 selection.end.row -= 1;
199 }
200 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
201 }
202
203 if let Some(prev_selection) = selections.last_mut() {
204 if selection.start <= prev_selection.end {
205 prev_selection.end = selection.end;
206 continue;
207 }
208 }
209
210 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
211 if selection.id > latest_selection.id {
212 *latest_selection = selection.clone();
213 }
214 selections.push(selection);
215 }
216 let newest_selection = newest_selection.unwrap();
217
218 let mut codegen_ranges = Vec::new();
219 for (excerpt_id, buffer, buffer_range) in
220 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
221 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
222 }))
223 {
224 let start = Anchor {
225 buffer_id: Some(buffer.remote_id()),
226 excerpt_id,
227 text_anchor: buffer.anchor_before(buffer_range.start),
228 };
229 let end = Anchor {
230 buffer_id: Some(buffer.remote_id()),
231 excerpt_id,
232 text_anchor: buffer.anchor_after(buffer_range.end),
233 };
234 codegen_ranges.push(start..end);
235 }
236
237 let assist_group_id = self.next_assist_group_id.post_inc();
238 let prompt_buffer =
239 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
240 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
241
242 let mut assists = Vec::new();
243 let mut assist_to_focus = None;
244 for range in codegen_ranges {
245 let assist_id = self.next_assist_id.post_inc();
246 let codegen = cx.new_model(|cx| {
247 Codegen::new(
248 editor.read(cx).buffer().clone(),
249 range.clone(),
250 None,
251 self.telemetry.clone(),
252 self.prompt_builder.clone(),
253 cx,
254 )
255 });
256
257 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
258 let prompt_editor = cx.new_view(|cx| {
259 PromptEditor::new(
260 assist_id,
261 gutter_dimensions.clone(),
262 self.prompt_history.clone(),
263 prompt_buffer.clone(),
264 codegen.clone(),
265 editor,
266 assistant_panel,
267 workspace.clone(),
268 self.fs.clone(),
269 cx,
270 )
271 });
272
273 if assist_to_focus.is_none() {
274 let focus_assist = if newest_selection.reversed {
275 range.start.to_point(&snapshot) == newest_selection.start
276 } else {
277 range.end.to_point(&snapshot) == newest_selection.end
278 };
279 if focus_assist {
280 assist_to_focus = Some(assist_id);
281 }
282 }
283
284 let [prompt_block_id, end_block_id] =
285 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
286
287 assists.push((
288 assist_id,
289 range,
290 prompt_editor,
291 prompt_block_id,
292 end_block_id,
293 ));
294 }
295
296 let editor_assists = self
297 .assists_by_editor
298 .entry(editor.downgrade())
299 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
300 let mut assist_group = InlineAssistGroup::new();
301 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
302 self.assists.insert(
303 assist_id,
304 InlineAssist::new(
305 assist_id,
306 assist_group_id,
307 assistant_panel.is_some(),
308 editor,
309 &prompt_editor,
310 prompt_block_id,
311 end_block_id,
312 range,
313 prompt_editor.read(cx).codegen.clone(),
314 workspace.clone(),
315 cx,
316 ),
317 );
318 assist_group.assist_ids.push(assist_id);
319 editor_assists.assist_ids.push(assist_id);
320 }
321 self.assist_groups.insert(assist_group_id, assist_group);
322
323 if let Some(assist_id) = assist_to_focus {
324 self.focus_assist(assist_id, cx);
325 }
326 }
327
328 #[allow(clippy::too_many_arguments)]
329 pub fn suggest_assist(
330 &mut self,
331 editor: &View<Editor>,
332 mut range: Range<Anchor>,
333 initial_prompt: String,
334 initial_transaction_id: Option<TransactionId>,
335 workspace: Option<WeakView<Workspace>>,
336 assistant_panel: Option<&View<AssistantPanel>>,
337 cx: &mut WindowContext,
338 ) -> InlineAssistId {
339 let assist_group_id = self.next_assist_group_id.post_inc();
340 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
341 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
342
343 let assist_id = self.next_assist_id.post_inc();
344
345 let buffer = editor.read(cx).buffer().clone();
346 {
347 let snapshot = buffer.read(cx).read(cx);
348 range.start = range.start.bias_left(&snapshot);
349 range.end = range.end.bias_right(&snapshot);
350 }
351
352 let codegen = cx.new_model(|cx| {
353 Codegen::new(
354 editor.read(cx).buffer().clone(),
355 range.clone(),
356 initial_transaction_id,
357 self.telemetry.clone(),
358 self.prompt_builder.clone(),
359 cx,
360 )
361 });
362
363 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
364 let prompt_editor = cx.new_view(|cx| {
365 PromptEditor::new(
366 assist_id,
367 gutter_dimensions.clone(),
368 self.prompt_history.clone(),
369 prompt_buffer.clone(),
370 codegen.clone(),
371 editor,
372 assistant_panel,
373 workspace.clone(),
374 self.fs.clone(),
375 cx,
376 )
377 });
378
379 let [prompt_block_id, end_block_id] =
380 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
381
382 let editor_assists = self
383 .assists_by_editor
384 .entry(editor.downgrade())
385 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
386
387 let mut assist_group = InlineAssistGroup::new();
388 self.assists.insert(
389 assist_id,
390 InlineAssist::new(
391 assist_id,
392 assist_group_id,
393 assistant_panel.is_some(),
394 editor,
395 &prompt_editor,
396 prompt_block_id,
397 end_block_id,
398 range,
399 prompt_editor.read(cx).codegen.clone(),
400 workspace.clone(),
401 cx,
402 ),
403 );
404 assist_group.assist_ids.push(assist_id);
405 editor_assists.assist_ids.push(assist_id);
406 self.assist_groups.insert(assist_group_id, assist_group);
407 assist_id
408 }
409
410 fn insert_assist_blocks(
411 &self,
412 editor: &View<Editor>,
413 range: &Range<Anchor>,
414 prompt_editor: &View<PromptEditor>,
415 cx: &mut WindowContext,
416 ) -> [CustomBlockId; 2] {
417 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
418 prompt_editor
419 .editor
420 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
421 });
422 let assist_blocks = vec![
423 BlockProperties {
424 style: BlockStyle::Sticky,
425 position: range.start,
426 height: prompt_editor_height,
427 render: build_assist_editor_renderer(prompt_editor),
428 disposition: BlockDisposition::Above,
429 priority: 0,
430 },
431 BlockProperties {
432 style: BlockStyle::Sticky,
433 position: range.end,
434 height: 0,
435 render: Box::new(|cx| {
436 v_flex()
437 .h_full()
438 .w_full()
439 .border_t_1()
440 .border_color(cx.theme().status().info_border)
441 .into_any_element()
442 }),
443 disposition: BlockDisposition::Below,
444 priority: 0,
445 },
446 ];
447
448 editor.update(cx, |editor, cx| {
449 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
450 [block_ids[0], block_ids[1]]
451 })
452 }
453
/// Runs when the prompt editor of `assist_id` gains focus.
///
/// - Marks the assist as its group's active assist.
/// - For a linked group, asks every member's prompt editor to keep showing its cursor while
///   unfocused.
/// - If the prompt block is within the host editor's visible rows, records a scroll lock
///   holding its current distance from the top of the viewport. Otherwise it clears any
///   existing scroll lock for that editor.
///
/// Does nothing if the assist has no decorations (it was dismissed).
fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let assist = &self.assists[&assist_id];
    let Some(decorations) = assist.decorations.as_ref() else {
        return;
    };
    let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
    let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

    assist_group.active_assist_id = Some(assist_id);
    if assist_group.linked {
        for assist_id in &assist_group.assist_ids {
            if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                    prompt_editor.set_show_cursor_when_unfocused(true, cx)
                });
            }
        }
    }

    assist
        .editor
        .update(cx, |editor, cx| {
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
            let prompt_row = editor
                .row_for_block(decorations.prompt_block_id, cx)
                .unwrap()
                .0 as f32;

            // Only lock scrolling when the prompt is on screen; the lock remembers the
            // prompt's offset from the top of the viewport.
            if (scroll_top..scroll_bottom).contains(&prompt_row) {
                editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                    assist_id,
                    distance_from_top: prompt_row - scroll_top,
                });
            } else {
                editor_assists.scroll_lock = None;
            }
        })
        // `editor` is a weak handle; an error (e.g. the editor is gone) is ignored.
        .ok();
}
494
495 fn handle_prompt_editor_focus_out(
496 &mut self,
497 assist_id: InlineAssistId,
498 cx: &mut WindowContext,
499 ) {
500 let assist = &self.assists[&assist_id];
501 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
502 if assist_group.active_assist_id == Some(assist_id) {
503 assist_group.active_assist_id = None;
504 if assist_group.linked {
505 for assist_id in &assist_group.assist_ids {
506 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
507 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
508 prompt_editor.set_show_cursor_when_unfocused(false, cx)
509 });
510 }
511 }
512 }
513 }
514 }
515
516 fn handle_prompt_editor_event(
517 &mut self,
518 prompt_editor: View<PromptEditor>,
519 event: &PromptEditorEvent,
520 cx: &mut WindowContext,
521 ) {
522 let assist_id = prompt_editor.read(cx).id;
523 match event {
524 PromptEditorEvent::StartRequested => {
525 self.start_assist(assist_id, cx);
526 }
527 PromptEditorEvent::StopRequested => {
528 self.stop_assist(assist_id, cx);
529 }
530 PromptEditorEvent::ConfirmRequested => {
531 self.finish_assist(assist_id, false, cx);
532 }
533 PromptEditorEvent::CancelRequested => {
534 self.finish_assist(assist_id, true, cx);
535 }
536 PromptEditorEvent::DismissRequested => {
537 self.dismiss_assist(assist_id, cx);
538 }
539 }
540 }
541
/// Handles a newline action in an editor that hosts inline assists.
///
/// If the editor has exactly one selection and it lies entirely inside an assist's range:
/// - a still-pending assist is dismissed;
/// - any other assist is finished (accepted, no undo).
///
/// In both cases the action is consumed. Otherwise the action is propagated. If the editor
/// has no registered assists, the method returns early without propagating.
fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
    let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
        return;
    };

    let editor = editor.read(cx);
    if editor.selections.count() == 1 {
        let selection = editor.selections.newest::<usize>(cx);
        let buffer = editor.buffer().read(cx).snapshot(cx);
        for assist_id in &editor_assists.assist_ids {
            let assist = &self.assists[assist_id];
            let assist_range = assist.range.to_offset(&buffer);
            if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
            {
                if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
                    self.dismiss_assist(*assist_id, cx);
                } else {
                    self.finish_assist(*assist_id, false, cx);
                }

                // Returning right after the mutable call is what lets the borrows of
                // `editor_assists` / `editor` above end here.
                return;
            }
        }
    }

    cx.propagate();
}
569
/// Handles a cancel action in an editor that hosts inline assists.
///
/// Only applies when the editor has exactly one selection, and only considers assists that
/// still have decorations:
/// - If the selection lies inside such an assist's range, that assist's prompt is focused
///   and the action is consumed.
/// - Otherwise the decorated assist closest to the selection is focused and the action is
///   still propagated. Distance here is the sum, over both range endpoints, of the distance
///   to the nearer selection endpoint.
///
/// If the editor has no registered assists, the method returns early without propagating.
fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
    let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
        return;
    };

    let editor = editor.read(cx);
    if editor.selections.count() == 1 {
        let selection = editor.selections.newest::<usize>(cx);
        let buffer = editor.buffer().read(cx).snapshot(cx);
        // (assist id, distance) of the nearest decorated assist that does not contain the selection.
        let mut closest_assist_fallback = None;
        for assist_id in &editor_assists.assist_ids {
            let assist = &self.assists[assist_id];
            let assist_range = assist.range.to_offset(&buffer);
            if assist.decorations.is_some() {
                if assist_range.contains(&selection.start)
                    && assist_range.contains(&selection.end)
                {
                    self.focus_assist(*assist_id, cx);
                    return;
                } else {
                    let distance_from_selection = assist_range
                        .start
                        .abs_diff(selection.start)
                        .min(assist_range.start.abs_diff(selection.end))
                        + assist_range
                            .end
                            .abs_diff(selection.start)
                            .min(assist_range.end.abs_diff(selection.end));
                    match closest_assist_fallback {
                        Some((_, old_distance)) => {
                            if distance_from_selection < old_distance {
                                closest_assist_fallback =
                                    Some((assist_id, distance_from_selection));
                            }
                        }
                        None => {
                            closest_assist_fallback = Some((assist_id, distance_from_selection))
                        }
                    }
                }
            }
        }

        if let Some((&assist_id, _)) = closest_assist_fallback {
            self.focus_assist(assist_id, cx);
        }
    }

    cx.propagate();
}
620
621 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
622 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
623 for assist_id in editor_assists.assist_ids.clone() {
624 self.finish_assist(assist_id, true, cx);
625 }
626 }
627 }
628
629 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
630 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
631 return;
632 };
633 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
634 return;
635 };
636 let assist = &self.assists[&scroll_lock.assist_id];
637 let Some(decorations) = assist.decorations.as_ref() else {
638 return;
639 };
640
641 editor.update(cx, |editor, cx| {
642 let scroll_position = editor.scroll_position(cx);
643 let target_scroll_top = editor
644 .row_for_block(decorations.prompt_block_id, cx)
645 .unwrap()
646 .0 as f32
647 - scroll_lock.distance_from_top;
648 if target_scroll_top != scroll_position.y {
649 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
650 }
651 });
652 }
653
/// Reacts to events from an editor that hosts inline assists.
///
/// - `Edited`: finishes (accepts) every assist whose codegen is `Done` or `Error` when the
///   transaction touched its range.
/// - `ScrollPositionChanged`: drops the scroll lock if the locked prompt block is no longer
///   at the recorded distance from the top (i.e. the user scrolled).
/// - `SelectionsChanged`: drops the scroll lock unless one of the editor's prompt editors
///   currently has focus.
fn handle_editor_event(
    &mut self,
    editor: View<Editor>,
    event: &EditorEvent,
    cx: &mut WindowContext,
) {
    let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
        return;
    };

    match event {
        EditorEvent::Edited { transaction_id } => {
            let buffer = editor.read(cx).buffer().read(cx);
            let edited_ranges =
                buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
            let snapshot = buffer.snapshot(cx);

            // Iterate a copy of the ids: `finish_assist` mutates the per-editor bookkeeping.
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if matches!(
                    assist.codegen.read(cx).status(cx),
                    CodegenStatus::Error(_) | CodegenStatus::Done
                ) {
                    let assist_range = assist.range.to_offset(&snapshot);
                    if edited_ranges
                        .iter()
                        .any(|range| range.overlaps(&assist_range))
                    {
                        self.finish_assist(assist_id, false, cx);
                    }
                }
            }
        }
        EditorEvent::ScrollPositionChanged { .. } => {
            if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                let assist = &self.assists[&scroll_lock.assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    let distance_from_top = editor.update(cx, |editor, cx| {
                        let scroll_top = editor.scroll_position(cx).y;
                        let prompt_row = editor
                            .row_for_block(decorations.prompt_block_id, cx)
                            .unwrap()
                            .0 as f32;
                        prompt_row - scroll_top
                    });

                    if distance_from_top != scroll_lock.distance_from_top {
                        editor_assists.scroll_lock = None;
                    }
                }
            }
        }
        EditorEvent::SelectionsChanged { .. } => {
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                        return;
                    }
                }
            }

            editor_assists.scroll_lock = None;
        }
        _ => {}
    }
}
721
722 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
723 if let Some(telemetry) = self.telemetry.as_ref() {
724 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
725 telemetry.report_assistant_event(
726 None,
727 telemetry_events::AssistantKind::Inline,
728 if undo {
729 telemetry_events::AssistantPhase::Rejected
730 } else {
731 telemetry_events::AssistantPhase::Accepted
732 },
733 model.telemetry_id(),
734 None,
735 None,
736 );
737 }
738 }
739 if let Some(assist) = self.assists.get(&assist_id) {
740 let assist_group_id = assist.group_id;
741 if self.assist_groups[&assist_group_id].linked {
742 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
743 self.finish_assist(assist_id, undo, cx);
744 }
745 return;
746 }
747 }
748
749 self.dismiss_assist(assist_id, cx);
750
751 if let Some(assist) = self.assists.remove(&assist_id) {
752 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
753 {
754 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
755 if entry.get().assist_ids.is_empty() {
756 entry.remove();
757 }
758 }
759
760 if let hash_map::Entry::Occupied(mut entry) =
761 self.assists_by_editor.entry(assist.editor.clone())
762 {
763 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
764 if entry.get().assist_ids.is_empty() {
765 entry.remove();
766 if let Some(editor) = assist.editor.upgrade() {
767 self.update_editor_highlights(&editor, cx);
768 }
769 } else {
770 entry.get().highlight_updates.send(()).ok();
771 }
772 }
773
774 if undo {
775 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
776 } else {
777 let confirmed_alternative = assist.codegen.read(cx).active_alternative().clone();
778 self.confirmed_assists
779 .insert(assist_id, confirmed_alternative);
780 }
781 }
782
783 // Remove the assist from the status updates map
784 self.assist_observations.remove(&assist_id);
785 }
786
787 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
788 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
789 return false;
790 };
791 codegen.update(cx, |this, cx| this.undo(cx));
792 true
793 }
794
/// Removes the on-screen decorations of `assist_id` without finishing it. The prompt block,
/// end block, and removed-line blocks are deleted, but the assist stays in `self.assists`.
///
/// If the prompt editor contained focus, focus moves to the next assist in the group. The
/// editor's scroll lock is cleared if it pointed at this assist, and `highlight_updates` is
/// signalled.
///
/// Returns `false` (and changes nothing) if the assist doesn't exist, its editor has been
/// released, or it was already dismissed.
fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return false;
    };
    let Some(editor) = assist.editor.upgrade() else {
        return false;
    };
    // Taking the decorations marks the assist as dismissed.
    let Some(decorations) = assist.decorations.take() else {
        return false;
    };

    editor.update(cx, |editor, cx| {
        let mut to_remove = decorations.removed_line_block_ids;
        to_remove.insert(decorations.prompt_block_id);
        to_remove.insert(decorations.end_block_id);
        editor.remove_blocks(to_remove, None, cx);
    });

    if decorations
        .prompt_editor
        .focus_handle(cx)
        .contains_focused(cx)
    {
        self.focus_next_assist(assist_id, cx);
    }

    if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
        if editor_assists
            .scroll_lock
            .as_ref()
            .map_or(false, |lock| lock.assist_id == assist_id)
        {
            editor_assists.scroll_lock = None;
        }
        editor_assists.highlight_updates.send(()).ok();
    }

    true
}
834
835 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
836 let Some(assist) = self.assists.get(&assist_id) else {
837 return;
838 };
839
840 let assist_group = &self.assist_groups[&assist.group_id];
841 let assist_ix = assist_group
842 .assist_ids
843 .iter()
844 .position(|id| *id == assist_id)
845 .unwrap();
846 let assist_ids = assist_group
847 .assist_ids
848 .iter()
849 .skip(assist_ix + 1)
850 .chain(assist_group.assist_ids.iter().take(assist_ix));
851
852 for assist_id in assist_ids {
853 let assist = &self.assists[assist_id];
854 if assist.decorations.is_some() {
855 self.focus_assist(*assist_id, cx);
856 return;
857 }
858 }
859
860 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
861 }
862
863 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
864 let Some(assist) = self.assists.get(&assist_id) else {
865 return;
866 };
867
868 if let Some(decorations) = assist.decorations.as_ref() {
869 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
870 prompt_editor.editor.update(cx, |editor, cx| {
871 editor.focus(cx);
872 editor.select_all(&SelectAll, cx);
873 })
874 });
875 }
876
877 self.scroll_to_assist(assist_id, cx);
878 }
879
880 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
881 let Some(assist) = self.assists.get(&assist_id) else {
882 return;
883 };
884 let Some(editor) = assist.editor.upgrade() else {
885 return;
886 };
887
888 let position = assist.range.start;
889 editor.update(cx, |editor, cx| {
890 editor.change_selections(None, cx, |selections| {
891 selections.select_anchor_ranges([position..position])
892 });
893
894 let mut scroll_target_top;
895 let mut scroll_target_bottom;
896 if let Some(decorations) = assist.decorations.as_ref() {
897 scroll_target_top = editor
898 .row_for_block(decorations.prompt_block_id, cx)
899 .unwrap()
900 .0 as f32;
901 scroll_target_bottom = editor
902 .row_for_block(decorations.end_block_id, cx)
903 .unwrap()
904 .0 as f32;
905 } else {
906 let snapshot = editor.snapshot(cx);
907 let start_row = assist
908 .range
909 .start
910 .to_display_point(&snapshot.display_snapshot)
911 .row();
912 scroll_target_top = start_row.0 as f32;
913 scroll_target_bottom = scroll_target_top + 1.;
914 }
915 scroll_target_top -= editor.vertical_scroll_margin() as f32;
916 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
917
918 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
919 let scroll_top = editor.scroll_position(cx).y;
920 let scroll_bottom = scroll_top + height_in_lines;
921
922 if scroll_target_top < scroll_top {
923 editor.set_scroll_position(point(0., scroll_target_top), cx);
924 } else if scroll_target_bottom > scroll_bottom {
925 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
926 editor
927 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
928 } else {
929 editor.set_scroll_position(point(0., scroll_target_top), cx);
930 }
931 }
932 });
933 }
934
935 fn unlink_assist_group(
936 &mut self,
937 assist_group_id: InlineAssistGroupId,
938 cx: &mut WindowContext,
939 ) -> Vec<InlineAssistId> {
940 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
941 assist_group.linked = false;
942 for assist_id in &assist_group.assist_ids {
943 let assist = self.assists.get_mut(assist_id).unwrap();
944 if let Some(editor_decorations) = assist.decorations.as_ref() {
945 editor_decorations
946 .prompt_editor
947 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
948 }
949 }
950 assist_group.assist_ids.clone()
951 }
952
953 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
954 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
955 assist
956 } else {
957 return;
958 };
959
960 let assist_group_id = assist.group_id;
961 if self.assist_groups[&assist_group_id].linked {
962 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
963 self.start_assist(assist_id, cx);
964 }
965 return;
966 }
967
968 let Some(user_prompt) = assist.user_prompt(cx) else {
969 return;
970 };
971
972 self.prompt_history.retain(|prompt| *prompt != user_prompt);
973 self.prompt_history.push_back(user_prompt.clone());
974 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
975 self.prompt_history.pop_front();
976 }
977
978 let assistant_panel_context = assist.assistant_panel_context(cx);
979
980 assist
981 .codegen
982 .update(cx, |codegen, cx| {
983 codegen.start(user_prompt, assistant_panel_context, cx)
984 })
985 .log_err();
986
987 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
988 tx.send(AssistStatus::Started).ok();
989 }
990 }
991
992 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
993 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
994 assist
995 } else {
996 return;
997 };
998
999 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1000
1001 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
1002 tx.send(AssistStatus::Stopped).ok();
1003 }
1004 }
1005
1006 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
1007 if let Some(assist) = self.assists.get(&assist_id) {
1008 match assist.codegen.read(cx).status(cx) {
1009 CodegenStatus::Idle => InlineAssistStatus::Idle,
1010 CodegenStatus::Pending => InlineAssistStatus::Pending,
1011 CodegenStatus::Done => InlineAssistStatus::Done,
1012 CodegenStatus::Error(_) => InlineAssistStatus::Error,
1013 }
1014 } else if self.confirmed_assists.contains_key(&assist_id) {
1015 InlineAssistStatus::Confirmed
1016 } else {
1017 InlineAssistStatus::Canceled
1018 }
1019 }
1020
/// Recomputes every inline-assist highlight in `editor` from the assists currently
/// registered for it:
/// - gutter "pending" ranges, colored `info_background`: from each assist's edit position
///   (or its range start when there is none yet) to its range end;
/// - gutter "transformed" ranges, colored `info`: from the range start to the edit position;
/// - faded (0.6) foreground text for each codegen's `last_equal_ranges`;
/// - `info_background` row highlights for inserted rows, only for assists that still have
///   decorations.
///
/// Empty gutter ranges are skipped. The foreground and gutter ranges go through
/// `merge_ranges`, and any category with nothing to show is cleared.
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
    let mut gutter_pending_ranges = Vec::new();
    let mut gutter_transformed_ranges = Vec::new();
    let mut foreground_ranges = Vec::new();
    let mut inserted_row_ranges = Vec::new();
    // Editors without assists still get their highlights cleared below.
    let empty_assist_ids = Vec::new();
    let assist_ids = self
        .assists_by_editor
        .get(&editor.downgrade())
        .map_or(&empty_assist_ids, |editor_assists| {
            &editor_assists.assist_ids
        });

    for assist_id in assist_ids {
        if let Some(assist) = self.assists.get(assist_id) {
            let codegen = assist.codegen.read(cx);
            let buffer = codegen.buffer(cx).read(cx).read(cx);
            foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());

            let pending_range =
                codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
            if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                gutter_pending_ranges.push(pending_range);
            }

            if let Some(edit_position) = codegen.edit_position(cx) {
                let edited_range = assist.range.start..edit_position;
                if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                    gutter_transformed_ranges.push(edited_range);
                }
            }

            if assist.decorations.is_some() {
                inserted_row_ranges
                    .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
            }
        }
    }

    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    merge_ranges(&mut foreground_ranges, &snapshot);
    merge_ranges(&mut gutter_pending_ranges, &snapshot);
    merge_ranges(&mut gutter_transformed_ranges, &snapshot);
    editor.update(cx, |editor, cx| {
        // Marker type keying the "pending" gutter highlight layer.
        enum GutterPendingRange {}
        if gutter_pending_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterPendingRange>(cx);
        } else {
            editor.highlight_gutter::<GutterPendingRange>(
                &gutter_pending_ranges,
                |cx| cx.theme().status().info_background,
                cx,
            )
        }

        // Marker type keying the "transformed" gutter highlight layer.
        enum GutterTransformedRange {}
        if gutter_transformed_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
        } else {
            editor.highlight_gutter::<GutterTransformedRange>(
                &gutter_transformed_ranges,
                |cx| cx.theme().status().info,
                cx,
            )
        }

        if foreground_ranges.is_empty() {
            editor.clear_highlights::<InlineAssist>(cx);
        } else {
            editor.highlight_text::<InlineAssist>(
                foreground_ranges,
                HighlightStyle {
                    fade_out: Some(0.6),
                    ..Default::default()
                },
                cx,
            );
        }

        editor.clear_row_highlights::<InlineAssist>();
        for row_range in inserted_row_ranges {
            editor.highlight_rows::<InlineAssist>(
                row_range,
                Some(cx.theme().status().info_background),
                false,
                cx,
            );
        }
    });
}
1111
/// Rebuilds the blocks that show the lines deleted by this assist's codegen.
///
/// Returns early if the assist no longer exists or has no decorations.
/// Otherwise it does three things:
/// - removes the previously inserted deleted-line blocks;
/// - for each `(new_row, old_row_range)` in the codegen diff's
///   `deleted_row_ranges`, inserts a block above `new_row`. The block holds a
///   read-only editor over those rows of the old buffer, highlighted with the
///   theme's `deleted_background`;
/// - stores the new block ids in `decorations.removed_line_block_ids`.
fn update_editor_blocks(
    &mut self,
    editor: &View<Editor>,
    assist_id: InlineAssistId,
    cx: &mut WindowContext,
) {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return;
    };
    let Some(decorations) = assist.decorations.as_mut() else {
        return;
    };

    let codegen = assist.codegen.read(cx);
    let old_snapshot = codegen.snapshot(cx);
    let old_buffer = codegen.old_buffer(cx);
    let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

    editor.update(cx, |editor, cx| {
        // Drop the blocks from the previous update before inserting fresh ones.
        let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
        editor.remove_blocks(old_blocks, None, cx);

        let mut new_blocks = Vec::new();
        for (new_row, old_row_range) in deleted_row_ranges {
            // Whole-line buffer offset span of the deleted rows in the old snapshot:
            // column 0 of the first row through the end of the last row.
            // NOTE(review): the `unwrap`s assume the diff's rows always lie within
            // `old_snapshot` — confirm this invariant holds.
            let (_, buffer_start) = old_snapshot
                .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                .unwrap();
            let (_, buffer_end) = old_snapshot
                .point_to_buffer_offset(Point::new(
                    *old_row_range.end(),
                    old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                ))
                .unwrap();

            // A minimal, non-scrollable, read-only editor showing just the deleted text.
            let deleted_lines_editor = cx.new_view(|cx| {
                let multi_buffer = cx.new_model(|_| {
                    MultiBuffer::without_headers(language::Capability::ReadOnly)
                });
                multi_buffer.update(cx, |multi_buffer, cx| {
                    multi_buffer.push_excerpts(
                        old_buffer.clone(),
                        Some(ExcerptRange {
                            context: buffer_start..buffer_end,
                            primary: None,
                        }),
                        cx,
                    );
                });

                enum DeletedLines {}
                let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                editor.set_show_wrap_guides(false, cx);
                editor.set_show_gutter(false, cx);
                editor.scroll_manager.set_forbid_vertical_scroll(true);
                editor.set_read_only(true);
                editor.set_show_inline_completions(Some(false), cx);
                editor.highlight_rows::<DeletedLines>(
                    Anchor::min()..=Anchor::max(),
                    Some(cx.theme().status().deleted_background),
                    false,
                    cx,
                );
                editor
            });

            // Block height in rows: the deleted-lines editor's last row index plus one.
            let height =
                deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
            new_blocks.push(BlockProperties {
                position: new_row,
                height,
                style: BlockStyle::Flex,
                render: Box::new(move |cx| {
                    // Left padding equal to the gutter width keeps the deleted text
                    // aligned with the host editor's text column.
                    div()
                        .bg(cx.theme().status().deleted_background)
                        .size_full()
                        .h(height as f32 * cx.line_height())
                        .pl(cx.gutter_dimensions.full_width())
                        .child(deleted_lines_editor.clone())
                        .into_any_element()
                }),
                disposition: BlockDisposition::Above,
                priority: 0,
            });
        }

        decorations.removed_line_block_ids = editor
            .insert_blocks(new_blocks, None, cx)
            .into_iter()
            .collect();
    })
}
1204
1205 pub fn observe_assist(
1206 &mut self,
1207 assist_id: InlineAssistId,
1208 ) -> async_watch::Receiver<AssistStatus> {
1209 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1210 rx.clone()
1211 } else {
1212 let (tx, rx) = async_watch::channel(AssistStatus::Idle);
1213 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1214 rx
1215 }
1216 }
1217}
1218
/// Lifecycle state of an inline assist, as reported by
/// `InlineAssistant::assist_status`.
///
/// `Idle`, `Pending`, `Done` and `Error` mirror the codegen status of a live
/// assist. `Confirmed` and `Canceled` describe assists that are no longer live.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum InlineAssistStatus {
    /// Codegen has not started.
    Idle,
    /// Codegen is in progress.
    Pending,
    /// Codegen finished successfully.
    Done,
    /// Codegen failed.
    Error,
    /// The assist is no longer live and was recorded as confirmed.
    Confirmed,
    /// The assist is no longer live and was not recorded as confirmed.
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` while codegen is in progress.
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// Returns `true` once the assist has been confirmed.
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// Returns `true` when codegen finished successfully.
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1241
/// Per-editor bookkeeping for the inline assists attached to one editor.
struct EditorInlineAssists {
    /// Ids of the inline assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock for this editor's assists.
    /// NOTE(review): semantics are defined by code outside this view — confirm.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here wakes `_update_highlights`, which then recomputes this
    /// editor's highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that listens on the `highlight_updates` channel.
    _update_highlights: Task<Result<()>>,
    /// Subscriptions and action registrations on the editor; they are kept alive
    /// by being stored here.
    _subscriptions: Vec<gpui::Subscription>,
}
1249
/// Scroll-lock state stored in `EditorInlineAssists::scroll_lock`.
/// NOTE(review): the field meanings below are inferred from their names only —
/// confirm against the code that sets and reads this struct.
struct InlineAssistScrollLock {
    /// Presumably the assist the editor's scroll position is locked to.
    assist_id: InlineAssistId,
    /// Presumably that assist's distance from the top of the viewport
    /// (units unconfirmed).
    distance_from_top: f32,
}
1254
1255impl EditorInlineAssists {
1256 #[allow(clippy::too_many_arguments)]
1257 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1258 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1259 Self {
1260 assist_ids: Vec::new(),
1261 scroll_lock: None,
1262 highlight_updates: highlight_updates_tx,
1263 _update_highlights: cx.spawn(|mut cx| {
1264 let editor = editor.downgrade();
1265 async move {
1266 while let Ok(()) = highlight_updates_rx.changed().await {
1267 let editor = editor.upgrade().context("editor was dropped")?;
1268 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1269 assistant.update_editor_highlights(&editor, cx);
1270 })?;
1271 }
1272 Ok(())
1273 }
1274 }),
1275 _subscriptions: vec![
1276 cx.observe_release(editor, {
1277 let editor = editor.downgrade();
1278 |_, cx| {
1279 InlineAssistant::update_global(cx, |this, cx| {
1280 this.handle_editor_release(editor, cx);
1281 })
1282 }
1283 }),
1284 cx.observe(editor, move |editor, cx| {
1285 InlineAssistant::update_global(cx, |this, cx| {
1286 this.handle_editor_change(editor, cx)
1287 })
1288 }),
1289 cx.subscribe(editor, move |editor, event, cx| {
1290 InlineAssistant::update_global(cx, |this, cx| {
1291 this.handle_editor_event(editor, event, cx)
1292 })
1293 }),
1294 editor.update(cx, |editor, cx| {
1295 let editor_handle = cx.view().downgrade();
1296 editor.register_action(
1297 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1298 InlineAssistant::update_global(cx, |this, cx| {
1299 if let Some(editor) = editor_handle.upgrade() {
1300 this.handle_editor_newline(editor, cx)
1301 }
1302 })
1303 },
1304 )
1305 }),
1306 editor.update(cx, |editor, cx| {
1307 let editor_handle = cx.view().downgrade();
1308 editor.register_action(
1309 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1310 InlineAssistant::update_global(cx, |this, cx| {
1311 if let Some(editor) = editor_handle.upgrade() {
1312 this.handle_editor_cancel(editor, cx)
1313 }
1314 })
1315 },
1316 )
1317 }),
1318 ],
1319 }
1320 }
1321}
1322
1323struct InlineAssistGroup {
1324 assist_ids: Vec<InlineAssistId>,
1325 linked: bool,
1326 active_assist_id: Option<InlineAssistId>,
1327}
1328
1329impl InlineAssistGroup {
1330 fn new() -> Self {
1331 Self {
1332 assist_ids: Vec::new(),
1333 linked: true,
1334 active_assist_id: None,
1335 }
1336 }
1337}
1338
1339fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1340 let editor = editor.clone();
1341 Box::new(move |cx: &mut BlockContext| {
1342 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1343 editor.clone().into_any_element()
1344 })
1345}
1346
/// Identifier of an inline assist. It is a plain counter value advanced by `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1357
/// Identifier of an `InlineAssistGroup`. It is a plain counter value advanced by `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances this counter and returns the value it held before the increment.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        self.0 = current + 1;
        InlineAssistGroupId(current)
    }
}
1368
/// Requests a `PromptEditor` emits in response to its buttons and keyboard actions.
enum PromptEditorEvent {
    /// Start (or restart) generation for the current prompt.
    StartRequested,
    /// Interrupt an in-progress generation.
    StopRequested,
    /// Accept the finished generation.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted by the confirm action while generation is still pending.
    DismissRequested,
}
1376
/// The inline prompt row shown for an inline assist: a prompt text editor plus
/// status buttons, model selector, error indicator and token count.
struct PromptEditor {
    /// The inline assist this prompt belongs to.
    id: InlineAssistId,
    /// Filesystem handle passed to the model selector.
    fs: Arc<dyn Fs>,
    /// The text editor holding the prompt.
    editor: View<Editor>,
    /// Set when the prompt is edited; cleared when codegen reaches `Done` or `Error`.
    /// Decides whether the finished state offers "restart" or "confirm".
    edited_since_done: bool,
    /// Gutter dimensions of the host editor, written by the block renderer and read
    /// here to size the left column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts, browsed with MoveUp/MoveDown.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` while browsing, `None` otherwise.
    prompt_history_ix: Option<usize>,
    /// The user's in-progress prompt, restored when browsing past the newest entry.
    pending_prompt: String,
    /// Code generation model driving this assist.
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; replaced when the editor is swapped out.
    editor_subscriptions: Vec<Subscription>,
    /// In-flight token counting task.
    pending_token_count: Task<Result<()>>,
    /// Latest token count result, if any has completed.
    token_counts: Option<TokenCounts>,
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to focus the assistant panel from the token count.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}
1395
/// Token usage of an inline assist request, shown next to the prompt.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct TokenCounts {
    /// Total tokens counted for the request.
    total: usize,
    /// The portion of `total` that the UI attributes to the Assistant Panel.
    assistant_panel: usize,
}
1401
// PromptEditor emits `PromptEditorEvent`s from its buttons and action handlers.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1403
impl Render for PromptEditor {
    /// Renders the prompt row, which has three parts:
    /// - a left column as wide as the host editor's gutter, holding the model
    ///   selector and, when codegen failed, an error indicator;
    /// - the prompt text editor;
    /// - the token count and buttons that depend on the codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let codegen = self.codegen.read(cx);

        let mut buttons = Vec::new();
        // Alternative-cycling controls only appear when there is more than one alternative.
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        let status = codegen.status(cx);
        buttons.extend(match status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            CodegenStatus::Pending => {
                // While generating, the cancel button's tooltip shows no keybinding.
                // Presumably this is because the Cancel action stops generation in
                // this state (see `PromptEditor::cancel`).
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
                        .into_any_element(),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    // Offer a restart if the prompt changed after completion or codegen
                    // failed; otherwise offer to confirm the result.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                            .into_any_element()
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    },
                ]
            }
        });

        h_flex()
            .key_context("PromptEditor")
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Cycling is handled in the capture phase, i.e. before the focused
            // prompt editor sees the action.
            .capture_action(cx.listener(Self::cycle_prev))
            .capture_action(cx.listener(Self::cycle_next))
            .child(
                // Left column, sized to line up with the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SettingsAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                            return el;
                        };

                        // Zed Pro users hitting a rate limit get a toggleable popover
                        // notice; all other errors show an icon whose tooltip carries
                        // the error message.
                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1599
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt text editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1605
1606impl PromptEditor {
1607 const MAX_LINES: u8 = 8;
1608
1609 #[allow(clippy::too_many_arguments)]
1610 fn new(
1611 id: InlineAssistId,
1612 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1613 prompt_history: VecDeque<String>,
1614 prompt_buffer: Model<MultiBuffer>,
1615 codegen: Model<Codegen>,
1616 parent_editor: &View<Editor>,
1617 assistant_panel: Option<&View<AssistantPanel>>,
1618 workspace: Option<WeakView<Workspace>>,
1619 fs: Arc<dyn Fs>,
1620 cx: &mut ViewContext<Self>,
1621 ) -> Self {
1622 let prompt_editor = cx.new_view(|cx| {
1623 let mut editor = Editor::new(
1624 EditorMode::AutoHeight {
1625 max_lines: Self::MAX_LINES as usize,
1626 },
1627 prompt_buffer,
1628 None,
1629 false,
1630 cx,
1631 );
1632 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1633 // Since the prompt editors for all inline assistants are linked,
1634 // always show the cursor (even when it isn't focused) because
1635 // typing in one will make what you typed appear in all of them.
1636 editor.set_show_cursor_when_unfocused(true, cx);
1637 editor.set_placeholder_text("Add a prompt…", cx);
1638 editor
1639 });
1640
1641 let mut token_count_subscriptions = Vec::new();
1642 token_count_subscriptions
1643 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1644 if let Some(assistant_panel) = assistant_panel {
1645 token_count_subscriptions
1646 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1647 }
1648
1649 let mut this = Self {
1650 id,
1651 editor: prompt_editor,
1652 edited_since_done: false,
1653 gutter_dimensions,
1654 prompt_history,
1655 prompt_history_ix: None,
1656 pending_prompt: String::new(),
1657 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1658 editor_subscriptions: Vec::new(),
1659 codegen,
1660 fs,
1661 pending_token_count: Task::ready(Ok(())),
1662 token_counts: None,
1663 _token_count_subscriptions: token_count_subscriptions,
1664 workspace,
1665 show_rate_limit_notice: false,
1666 };
1667 this.count_tokens(cx);
1668 this.subscribe_to_editor(cx);
1669 this
1670 }
1671
1672 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1673 self.editor_subscriptions.clear();
1674 self.editor_subscriptions
1675 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1676 }
1677
1678 fn set_show_cursor_when_unfocused(
1679 &mut self,
1680 show_cursor_when_unfocused: bool,
1681 cx: &mut ViewContext<Self>,
1682 ) {
1683 self.editor.update(cx, |editor, cx| {
1684 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1685 });
1686 }
1687
1688 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1689 let prompt = self.prompt(cx);
1690 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1691 self.editor = cx.new_view(|cx| {
1692 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1693 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1694 editor.set_placeholder_text("Add a prompt…", cx);
1695 editor.set_text(prompt, cx);
1696 if focus {
1697 editor.focus(cx);
1698 }
1699 editor
1700 });
1701 self.subscribe_to_editor(cx);
1702 }
1703
1704 fn prompt(&self, cx: &AppContext) -> String {
1705 self.editor.read(cx).text(cx)
1706 }
1707
1708 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1709 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1710 if self.show_rate_limit_notice {
1711 cx.focus_view(&self.editor);
1712 }
1713 cx.notify();
1714 }
1715
1716 fn handle_parent_editor_event(
1717 &mut self,
1718 _: View<Editor>,
1719 event: &EditorEvent,
1720 cx: &mut ViewContext<Self>,
1721 ) {
1722 if let EditorEvent::BufferEdited { .. } = event {
1723 self.count_tokens(cx);
1724 }
1725 }
1726
1727 fn handle_assistant_panel_event(
1728 &mut self,
1729 _: View<AssistantPanel>,
1730 event: &AssistantPanelEvent,
1731 cx: &mut ViewContext<Self>,
1732 ) {
1733 let AssistantPanelEvent::ContextEdited { .. } = event;
1734 self.count_tokens(cx);
1735 }
1736
1737 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1738 let assist_id = self.id;
1739 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1740 cx.background_executor().timer(Duration::from_secs(1)).await;
1741 let token_count = cx
1742 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1743 let assist = inline_assistant
1744 .assists
1745 .get(&assist_id)
1746 .context("assist not found")?;
1747 anyhow::Ok(assist.count_tokens(cx))
1748 })??
1749 .await?;
1750
1751 this.update(&mut cx, |this, cx| {
1752 this.token_counts = Some(token_count);
1753 cx.notify();
1754 })
1755 })
1756 }
1757
1758 fn handle_prompt_editor_events(
1759 &mut self,
1760 _: View<Editor>,
1761 event: &EditorEvent,
1762 cx: &mut ViewContext<Self>,
1763 ) {
1764 match event {
1765 EditorEvent::Edited { .. } => {
1766 let prompt = self.editor.read(cx).text(cx);
1767 if self
1768 .prompt_history_ix
1769 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1770 {
1771 self.prompt_history_ix.take();
1772 self.pending_prompt = prompt;
1773 }
1774
1775 self.edited_since_done = true;
1776 cx.notify();
1777 }
1778 EditorEvent::BufferEdited => {
1779 self.count_tokens(cx);
1780 }
1781 EditorEvent::Blurred => {
1782 if self.show_rate_limit_notice {
1783 self.show_rate_limit_notice = false;
1784 cx.notify();
1785 }
1786 }
1787 _ => {}
1788 }
1789 }
1790
1791 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1792 match self.codegen.read(cx).status(cx) {
1793 CodegenStatus::Idle => {
1794 self.editor
1795 .update(cx, |editor, _| editor.set_read_only(false));
1796 }
1797 CodegenStatus::Pending => {
1798 self.editor
1799 .update(cx, |editor, _| editor.set_read_only(true));
1800 }
1801 CodegenStatus::Done => {
1802 self.edited_since_done = false;
1803 self.editor
1804 .update(cx, |editor, _| editor.set_read_only(false));
1805 }
1806 CodegenStatus::Error(error) => {
1807 if cx.has_flag::<ZedPro>()
1808 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1809 && !dismissed_rate_limit_notice()
1810 {
1811 self.show_rate_limit_notice = true;
1812 cx.notify();
1813 }
1814
1815 self.edited_since_done = false;
1816 self.editor
1817 .update(cx, |editor, _| editor.set_read_only(false));
1818 }
1819 }
1820 }
1821
1822 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1823 match self.codegen.read(cx).status(cx) {
1824 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1825 cx.emit(PromptEditorEvent::CancelRequested);
1826 }
1827 CodegenStatus::Pending => {
1828 cx.emit(PromptEditorEvent::StopRequested);
1829 }
1830 }
1831 }
1832
1833 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1834 match self.codegen.read(cx).status(cx) {
1835 CodegenStatus::Idle => {
1836 cx.emit(PromptEditorEvent::StartRequested);
1837 }
1838 CodegenStatus::Pending => {
1839 cx.emit(PromptEditorEvent::DismissRequested);
1840 }
1841 CodegenStatus::Done => {
1842 if self.edited_since_done {
1843 cx.emit(PromptEditorEvent::StartRequested);
1844 } else {
1845 cx.emit(PromptEditorEvent::ConfirmRequested);
1846 }
1847 }
1848 CodegenStatus::Error(_) => {
1849 cx.emit(PromptEditorEvent::StartRequested);
1850 }
1851 }
1852 }
1853
1854 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1855 if let Some(ix) = self.prompt_history_ix {
1856 if ix > 0 {
1857 self.prompt_history_ix = Some(ix - 1);
1858 let prompt = self.prompt_history[ix - 1].as_str();
1859 self.editor.update(cx, |editor, cx| {
1860 editor.set_text(prompt, cx);
1861 editor.move_to_beginning(&Default::default(), cx);
1862 });
1863 }
1864 } else if !self.prompt_history.is_empty() {
1865 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1866 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1867 self.editor.update(cx, |editor, cx| {
1868 editor.set_text(prompt, cx);
1869 editor.move_to_beginning(&Default::default(), cx);
1870 });
1871 }
1872 }
1873
1874 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1875 if let Some(ix) = self.prompt_history_ix {
1876 if ix < self.prompt_history.len() - 1 {
1877 self.prompt_history_ix = Some(ix + 1);
1878 let prompt = self.prompt_history[ix + 1].as_str();
1879 self.editor.update(cx, |editor, cx| {
1880 editor.set_text(prompt, cx);
1881 editor.move_to_end(&Default::default(), cx)
1882 });
1883 } else {
1884 self.prompt_history_ix = None;
1885 let prompt = self.pending_prompt.as_str();
1886 self.editor.update(cx, |editor, cx| {
1887 editor.set_text(prompt, cx);
1888 editor.move_to_end(&Default::default(), cx)
1889 });
1890 }
1891 }
1892 }
1893
1894 fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1895 self.codegen
1896 .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1897 }
1898
1899 fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1900 self.codegen
1901 .update(cx, |codegen, cx| codegen.cycle_next(cx));
1902 }
1903
1904 fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1905 let codegen = self.codegen.read(cx);
1906 let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1907
1908 h_flex()
1909 .child(
1910 IconButton::new("previous", IconName::ChevronLeft)
1911 .icon_color(Color::Muted)
1912 .disabled(disabled)
1913 .shape(IconButtonShape::Square)
1914 .tooltip({
1915 let focus_handle = self.editor.focus_handle(cx);
1916 move |cx| {
1917 Tooltip::for_action_in(
1918 "Previous Alternative",
1919 &CyclePreviousInlineAssist,
1920 &focus_handle,
1921 cx,
1922 )
1923 }
1924 })
1925 .on_click(cx.listener(|this, _, cx| {
1926 this.codegen
1927 .update(cx, |codegen, cx| codegen.cycle_prev(cx))
1928 })),
1929 )
1930 .child(
1931 Label::new(format!(
1932 "{}/{}",
1933 codegen.active_alternative + 1,
1934 codegen.alternative_count(cx)
1935 ))
1936 .size(LabelSize::Small)
1937 .color(if disabled {
1938 Color::Disabled
1939 } else {
1940 Color::Muted
1941 }),
1942 )
1943 .child(
1944 IconButton::new("next", IconName::ChevronRight)
1945 .icon_color(Color::Muted)
1946 .disabled(disabled)
1947 .shape(IconButtonShape::Square)
1948 .tooltip({
1949 let focus_handle = self.editor.focus_handle(cx);
1950 move |cx| {
1951 Tooltip::for_action_in(
1952 "Next Alternative",
1953 &CycleNextInlineAssist,
1954 &focus_handle,
1955 cx,
1956 )
1957 }
1958 })
1959 .on_click(cx.listener(|this, _, cx| {
1960 this.codegen
1961 .update(cx, |codegen, cx| codegen.cycle_next(cx))
1962 })),
1963 )
1964 .into_any_element()
1965 }
1966
1967 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1968 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1969 let token_counts = self.token_counts?;
1970 let max_token_count = model.max_token_count();
1971
1972 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1973 let token_count_color = if remaining_tokens <= 0 {
1974 Color::Error
1975 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1976 Color::Warning
1977 } else {
1978 Color::Muted
1979 };
1980
1981 let mut token_count = h_flex()
1982 .id("token_count")
1983 .gap_0p5()
1984 .child(
1985 Label::new(humanize_token_count(token_counts.total))
1986 .size(LabelSize::Small)
1987 .color(token_count_color),
1988 )
1989 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1990 .child(
1991 Label::new(humanize_token_count(max_token_count))
1992 .size(LabelSize::Small)
1993 .color(Color::Muted),
1994 );
1995 if let Some(workspace) = self.workspace.clone() {
1996 token_count = token_count
1997 .tooltip(move |cx| {
1998 Tooltip::with_meta(
1999 format!(
2000 "Tokens Used ({} from the Assistant Panel)",
2001 humanize_token_count(token_counts.assistant_panel)
2002 ),
2003 None,
2004 "Click to open the Assistant Panel",
2005 cx,
2006 )
2007 })
2008 .cursor_pointer()
2009 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
2010 .on_click(move |_, cx| {
2011 cx.stop_propagation();
2012 workspace
2013 .update(cx, |workspace, cx| {
2014 workspace.focus_panel::<AssistantPanel>(cx)
2015 })
2016 .ok();
2017 });
2018 } else {
2019 token_count = token_count
2020 .cursor_default()
2021 .tooltip(|cx| Tooltip::text("Tokens used", cx));
2022 }
2023
2024 Some(token_count)
2025 }
2026
2027 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2028 let settings = ThemeSettings::get_global(cx);
2029 let text_style = TextStyle {
2030 color: if self.editor.read(cx).read_only(cx) {
2031 cx.theme().colors().text_disabled
2032 } else {
2033 cx.theme().colors().text
2034 },
2035 font_family: settings.buffer_font.family.clone(),
2036 font_fallbacks: settings.buffer_font.fallbacks.clone(),
2037 font_size: settings.buffer_font_size.into(),
2038 font_weight: settings.buffer_font.weight,
2039 line_height: relative(settings.buffer_line_height.value()),
2040 ..Default::default()
2041 };
2042 EditorElement::new(
2043 &self.editor,
2044 EditorStyle {
2045 background: cx.theme().colors().editor_background,
2046 local_player: cx.theme().players().local(),
2047 text: text_style,
2048 ..Default::default()
2049 },
2050 )
2051 }
2052
2053 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2054 Popover::new().child(
2055 v_flex()
2056 .occlude()
2057 .p_2()
2058 .child(
2059 Label::new("Out of Tokens")
2060 .size(LabelSize::Small)
2061 .weight(FontWeight::BOLD),
2062 )
2063 .child(Label::new(
2064 "Try Zed Pro for higher limits, a wider range of models, and more.",
2065 ))
2066 .child(
2067 h_flex()
2068 .justify_between()
2069 .child(CheckboxWithLabel::new(
2070 "dont-show-again",
2071 Label::new("Don't show again"),
2072 if dismissed_rate_limit_notice() {
2073 ui::Selection::Selected
2074 } else {
2075 ui::Selection::Unselected
2076 },
2077 |selection, cx| {
2078 let is_dismissed = match selection {
2079 ui::Selection::Unselected => false,
2080 ui::Selection::Indeterminate => return,
2081 ui::Selection::Selected => true,
2082 };
2083
2084 set_rate_limit_notice_dismissed(is_dismissed, cx)
2085 },
2086 ))
2087 .child(
2088 h_flex()
2089 .gap_2()
2090 .child(
2091 Button::new("dismiss", "Dismiss")
2092 .style(ButtonStyle::Transparent)
2093 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2094 )
2095 .child(Button::new("more-info", "More Info").on_click(
2096 |_event, cx| {
2097 cx.dispatch_action(Box::new(
2098 zed_actions::OpenAccountSettings,
2099 ))
2100 },
2101 )),
2102 ),
2103 ),
2104 )
2105 }
2106}
2107
/// Key-value store key whose presence records that the user opted out of the
/// rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2109
2110fn dismissed_rate_limit_notice() -> bool {
2111 db::kvp::KEY_VALUE_STORE
2112 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2113 .log_err()
2114 .map_or(false, |s| s.is_some())
2115}
2116
2117fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2118 db::write_and_log(cx, move || async move {
2119 if is_dismissed {
2120 db::kvp::KEY_VALUE_STORE
2121 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2122 .await
2123 } else {
2124 db::kvp::KEY_VALUE_STORE
2125 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2126 .await
2127 }
2128 })
2129}
2130
/// State for a single inline assist attached to an editor.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Buffer range the assist covers; drives the gutter highlights.
    range: Range<Anchor>,
    /// The editor the assist lives in (weak, so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// Prompt block, prompt editor, deleted-line blocks and end block.
    /// When `None`, codegen errors are reported as workspace toasts and the assist
    /// finishes as soon as codegen does.
    decorations: Option<InlineAssistDecorations>,
    /// Code generation model for this assist.
    codegen: Model<Codegen>,
    _subscriptions: Vec<Subscription>,
    /// Workspace used to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// NOTE(review): presumably whether assistant-panel context is included in the
    /// request — confirm against its use.
    include_context: bool,
}
2141
2142impl InlineAssist {
/// Creates the bookkeeping for a single inline assist and wires up its subscriptions.
///
/// The returned assist owns its decorations (prompt block, prompt editor, end block) and
/// subscriptions that route prompt-editor focus changes, prompt-editor events, and codegen
/// updates/events back into the global `InlineAssistant`, keyed by `assist_id`.
#[allow(clippy::too_many_arguments)]
fn new(
    assist_id: InlineAssistId,
    group_id: InlineAssistGroupId,
    include_context: bool,
    editor: &View<Editor>,
    prompt_editor: &View<PromptEditor>,
    prompt_block_id: CustomBlockId,
    end_block_id: CustomBlockId,
    range: Range<Anchor>,
    codegen: Model<Codegen>,
    workspace: Option<WeakView<Workspace>>,
    cx: &mut WindowContext,
) -> Self {
    let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
    InlineAssist {
        group_id,
        include_context,
        editor: editor.downgrade(),
        decorations: Some(InlineAssistDecorations {
            prompt_block_id,
            prompt_editor: prompt_editor.clone(),
            removed_line_block_ids: HashSet::default(),
            end_block_id,
        }),
        range,
        codegen: codegen.clone(),
        workspace: workspace.clone(),
        _subscriptions: vec![
            // Forward prompt-editor focus changes to the global assistant.
            cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                InlineAssistant::update_global(cx, |this, cx| {
                    this.handle_prompt_editor_focus_in(assist_id, cx)
                })
            }),
            cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                InlineAssistant::update_global(cx, |this, cx| {
                    this.handle_prompt_editor_focus_out(assist_id, cx)
                })
            }),
            // Forward events emitted by the prompt editor.
            cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                InlineAssistant::update_global(cx, |this, cx| {
                    this.handle_prompt_editor_event(prompt_editor, event, cx)
                })
            }),
            // On any codegen change (while the editor is still alive): ping the editor's
            // highlight-update channel (send failures ignored) and refresh the assist's blocks.
            cx.observe(&codegen, {
                let editor = editor.downgrade();
                move |_, cx| {
                    if let Some(editor) = editor.upgrade() {
                        InlineAssistant::update_global(cx, |this, cx| {
                            if let Some(editor_assists) =
                                this.assists_by_editor.get(&editor.downgrade())
                            {
                                editor_assists.highlight_updates.send(()).ok();
                            }

                            this.update_editor_blocks(&editor, assist_id, cx);
                        })
                    }
                }
            }),
            // React to codegen lifecycle events.
            cx.subscribe(&codegen, move |codegen, event, cx| {
                InlineAssistant::update_global(cx, |this, cx| match event {
                    // The generated transaction was undone: finish the assist (second
                    // argument `false`; its meaning is defined by `finish_assist`).
                    CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                    CodegenEvent::Finished => {
                        // The assist may already have been removed.
                        let assist = if let Some(assist) = this.assists.get(&assist_id) {
                            assist
                        } else {
                            return;
                        };

                        // With no decorations left there is no prompt UI to show the error
                        // in, so surface it as a workspace toast instead.
                        if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
                            if assist.decorations.is_none() {
                                if let Some(workspace) = assist
                                    .workspace
                                    .as_ref()
                                    .and_then(|workspace| workspace.upgrade())
                                {
                                    let error = format!("Inline assistant error: {}", error);
                                    workspace.update(cx, |workspace, cx| {
                                        struct InlineAssistantError;

                                        let id =
                                            NotificationId::identified::<InlineAssistantError>(
                                                assist_id.0,
                                            );

                                        workspace.show_toast(Toast::new(id, error), cx);
                                    })
                                }
                            }
                        }

                        // Undecorated assists are finished right away; decorated ones only
                        // notify any observers registered for this assist.
                        if assist.decorations.is_none() {
                            this.finish_assist(assist_id, false, cx);
                        } else if let Some(tx) = this.assist_observations.get(&assist_id) {
                            tx.0.send(AssistStatus::Finished).ok();
                        }
                    }
                })
            }),
        ],
    }
}
2246
2247 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2248 let decorations = self.decorations.as_ref()?;
2249 Some(decorations.prompt_editor.read(cx).prompt(cx))
2250 }
2251
2252 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2253 if self.include_context {
2254 let workspace = self.workspace.as_ref()?;
2255 let workspace = workspace.upgrade()?.read(cx);
2256 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2257 Some(
2258 assistant_panel
2259 .read(cx)
2260 .active_context(cx)?
2261 .read(cx)
2262 .to_completion_request(cx),
2263 )
2264 } else {
2265 None
2266 }
2267 }
2268
2269 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2270 let Some(user_prompt) = self.user_prompt(cx) else {
2271 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2272 };
2273 let assistant_panel_context = self.assistant_panel_context(cx);
2274 self.codegen
2275 .read(cx)
2276 .count_tokens(user_prompt, assistant_panel_context, cx)
2277 }
2278}
2279
/// Editor decorations owned by an inline assist while its prompt UI is shown.
struct InlineAssistDecorations {
    /// Custom block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    /// Editor in which the user types the assist prompt.
    prompt_editor: View<PromptEditor>,
    /// Custom blocks for removed lines (populated elsewhere; presumably renders deleted
    /// text from the codegen diff — verify against `update_editor_blocks`).
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Custom block placed at the end of the assist.
    end_block_id: CustomBlockId,
}
2286
/// Lifecycle events emitted by `Codegen` and `CodegenAlternative`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation ended, successfully, with an error, or because it was stopped.
    Finished,
    /// The transaction containing the generated edits was undone.
    Undone,
}
2292
/// Runs an inline transformation against one or more models and lets the user cycle
/// between the resulting alternatives.
pub struct Codegen {
    /// One alternative per model; index 0 is driven by the active model, the rest by the
    /// registry's inline alternative models (see `start`).
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently applied to the buffer.
    active_alternative: usize,
    /// Subscriptions to the active alternative, replaced whenever it changes.
    subscriptions: Vec<Subscription>,
    /// Buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Transaction that `undo` reverts in addition to the active alternative's edits.
    initial_transaction_id: Option<TransactionId>,
    /// Telemetry handed to each alternative.
    telemetry: Option<Arc<Telemetry>>,
    /// Prompt builder handed to each alternative.
    builder: Arc<PromptBuilder>,
}
2303
2304impl Codegen {
2305 pub fn new(
2306 buffer: Model<MultiBuffer>,
2307 range: Range<Anchor>,
2308 initial_transaction_id: Option<TransactionId>,
2309 telemetry: Option<Arc<Telemetry>>,
2310 builder: Arc<PromptBuilder>,
2311 cx: &mut ModelContext<Self>,
2312 ) -> Self {
2313 let codegen = cx.new_model(|cx| {
2314 CodegenAlternative::new(
2315 buffer.clone(),
2316 range.clone(),
2317 false,
2318 telemetry.clone(),
2319 builder.clone(),
2320 cx,
2321 )
2322 });
2323 let mut this = Self {
2324 alternatives: vec![codegen],
2325 active_alternative: 0,
2326 subscriptions: Vec::new(),
2327 buffer,
2328 range,
2329 initial_transaction_id,
2330 telemetry,
2331 builder,
2332 };
2333 this.activate(0, cx);
2334 this
2335 }
2336
2337 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2338 let codegen = self.active_alternative().clone();
2339 self.subscriptions.clear();
2340 self.subscriptions
2341 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2342 self.subscriptions
2343 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2344 }
2345
2346 fn active_alternative(&self) -> &Model<CodegenAlternative> {
2347 &self.alternatives[self.active_alternative]
2348 }
2349
2350 fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2351 &self.active_alternative().read(cx).status
2352 }
2353
2354 fn alternative_count(&self, cx: &AppContext) -> usize {
2355 LanguageModelRegistry::read_global(cx)
2356 .inline_alternative_models()
2357 .len()
2358 + 1
2359 }
2360
2361 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2362 let next_active_ix = if self.active_alternative == 0 {
2363 self.alternatives.len() - 1
2364 } else {
2365 self.active_alternative - 1
2366 };
2367 self.activate(next_active_ix, cx);
2368 }
2369
2370 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2371 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2372 self.activate(next_active_ix, cx);
2373 }
2374
2375 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2376 self.active_alternative()
2377 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2378 self.active_alternative = index;
2379 self.active_alternative()
2380 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2381 self.subscribe_to_alternative(cx);
2382 cx.notify();
2383 }
2384
2385 pub fn start(
2386 &mut self,
2387 user_prompt: String,
2388 assistant_panel_context: Option<LanguageModelRequest>,
2389 cx: &mut ModelContext<Self>,
2390 ) -> Result<()> {
2391 let alternative_models = LanguageModelRegistry::read_global(cx)
2392 .inline_alternative_models()
2393 .to_vec();
2394
2395 self.active_alternative()
2396 .update(cx, |alternative, cx| alternative.undo(cx));
2397 self.activate(0, cx);
2398 self.alternatives.truncate(1);
2399
2400 for _ in 0..alternative_models.len() {
2401 self.alternatives.push(cx.new_model(|cx| {
2402 CodegenAlternative::new(
2403 self.buffer.clone(),
2404 self.range.clone(),
2405 false,
2406 self.telemetry.clone(),
2407 self.builder.clone(),
2408 cx,
2409 )
2410 }));
2411 }
2412
2413 let primary_model = LanguageModelRegistry::read_global(cx)
2414 .active_model()
2415 .context("no active model")?;
2416
2417 for (model, alternative) in iter::once(primary_model)
2418 .chain(alternative_models)
2419 .zip(&self.alternatives)
2420 {
2421 alternative.update(cx, |alternative, cx| {
2422 alternative.start(
2423 user_prompt.clone(),
2424 assistant_panel_context.clone(),
2425 model.clone(),
2426 cx,
2427 )
2428 })?;
2429 }
2430
2431 Ok(())
2432 }
2433
2434 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2435 for codegen in &self.alternatives {
2436 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2437 }
2438 }
2439
2440 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2441 self.active_alternative()
2442 .update(cx, |codegen, cx| codegen.undo(cx));
2443
2444 self.buffer.update(cx, |buffer, cx| {
2445 if let Some(transaction_id) = self.initial_transaction_id.take() {
2446 buffer.undo_transaction(transaction_id, cx);
2447 buffer.refresh_preview(cx);
2448 }
2449 });
2450 }
2451
2452 pub fn count_tokens(
2453 &self,
2454 user_prompt: String,
2455 assistant_panel_context: Option<LanguageModelRequest>,
2456 cx: &AppContext,
2457 ) -> BoxFuture<'static, Result<TokenCounts>> {
2458 self.active_alternative()
2459 .read(cx)
2460 .count_tokens(user_prompt, assistant_panel_context, cx)
2461 }
2462
2463 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2464 self.active_alternative().read(cx).buffer.clone()
2465 }
2466
2467 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2468 self.active_alternative().read(cx).old_buffer.clone()
2469 }
2470
2471 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2472 self.active_alternative().read(cx).snapshot.clone()
2473 }
2474
2475 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2476 self.active_alternative().read(cx).edit_position
2477 }
2478
2479 fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2480 &self.active_alternative().read(cx).diff
2481 }
2482
2483 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2484 self.active_alternative().read(cx).last_equal_ranges()
2485 }
2486}
2487
// `Codegen` re-emits its active alternative's events (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for Codegen {}
2489
/// One model's attempt at an inline transformation.
///
/// Only the active alternative has its edits applied to the buffer; inactive alternatives
/// keep their edits in `edits` so they can be replayed when re-activated (see `set_active`).
pub struct CodegenAlternative {
    /// Multibuffer being edited.
    buffer: Model<MultiBuffer>,
    /// Local copy of the underlying buffer's contents taken at construction.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot taken at construction; streamed edits are anchored against it.
    snapshot: MultiBufferSnapshot,
    /// Where the streamed edits have progressed to, if generation has started.
    edit_position: Option<Anchor>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Ranges left unchanged by the most recent diff step.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all of this alternative's edits are merged into.
    transformation_transaction_id: Option<TransactionId>,
    /// Generation status.
    status: CodegenStatus,
    /// In-flight generation task; replaced with `Task::ready(())` by `stop` and on undo.
    generation: Task<()>,
    /// Row-level diff shown for the generated text.
    diff: Diff,
    /// Telemetry sink for response reporting.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (handled by `handle_buffer_event`).
    _subscription: gpui::Subscription,
    /// Builds the content prompt sent to the model.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Edits produced while streaming, replayed by `set_active(true)`.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest streaming line-diff operations, replayed while generation is pending.
    line_operations: Vec<LineOperation>,
}
2508
/// Generation state of a `CodegenAlternative`.
enum CodegenStatus {
    /// Not generating, and no diff was produced (initial state, or stopped with an empty diff).
    Idle,
    /// A response is being streamed.
    Pending,
    /// Generation completed (or was stopped after producing a diff).
    Done,
    /// Generation failed with the given error.
    Error(anyhow::Error),
}
2515
2516#[derive(Default)]
2517struct Diff {
2518 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2519 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2520}
2521
2522impl Diff {
2523 fn is_empty(&self) -> bool {
2524 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2525 }
2526}
2527
// Emits `Finished` when generation ends and `Undone` when its transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2529
2530impl CodegenAlternative {
2531 pub fn new(
2532 buffer: Model<MultiBuffer>,
2533 range: Range<Anchor>,
2534 active: bool,
2535 telemetry: Option<Arc<Telemetry>>,
2536 builder: Arc<PromptBuilder>,
2537 cx: &mut ModelContext<Self>,
2538 ) -> Self {
2539 let snapshot = buffer.read(cx).snapshot(cx);
2540
2541 let (old_buffer, _, _) = buffer
2542 .read(cx)
2543 .range_to_buffer_ranges(range.clone(), cx)
2544 .pop()
2545 .unwrap();
2546 let old_buffer = cx.new_model(|cx| {
2547 let old_buffer = old_buffer.read(cx);
2548 let text = old_buffer.as_rope().clone();
2549 let line_ending = old_buffer.line_ending();
2550 let language = old_buffer.language().cloned();
2551 let language_registry = old_buffer.language_registry();
2552
2553 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2554 buffer.set_language(language, cx);
2555 if let Some(language_registry) = language_registry {
2556 buffer.set_language_registry(language_registry)
2557 }
2558 buffer
2559 });
2560
2561 Self {
2562 buffer: buffer.clone(),
2563 old_buffer,
2564 edit_position: None,
2565 snapshot,
2566 last_equal_ranges: Default::default(),
2567 transformation_transaction_id: None,
2568 status: CodegenStatus::Idle,
2569 generation: Task::ready(()),
2570 diff: Diff::default(),
2571 telemetry,
2572 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2573 builder,
2574 active,
2575 edits: Vec::new(),
2576 line_operations: Vec::new(),
2577 range,
2578 }
2579 }
2580
2581 fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2582 if active != self.active {
2583 self.active = active;
2584
2585 if self.active {
2586 let edits = self.edits.clone();
2587 self.apply_edits(edits, cx);
2588 if matches!(self.status, CodegenStatus::Pending) {
2589 let line_operations = self.line_operations.clone();
2590 self.reapply_line_based_diff(line_operations, cx);
2591 } else {
2592 self.reapply_batch_diff(cx).detach();
2593 }
2594 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2595 self.buffer.update(cx, |buffer, cx| {
2596 buffer.undo_transaction(transaction_id, cx);
2597 buffer.forget_transaction(transaction_id, cx);
2598 });
2599 }
2600 }
2601 }
2602
2603 fn handle_buffer_event(
2604 &mut self,
2605 _buffer: Model<MultiBuffer>,
2606 event: &multi_buffer::Event,
2607 cx: &mut ModelContext<Self>,
2608 ) {
2609 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2610 if self.transformation_transaction_id == Some(*transaction_id) {
2611 self.transformation_transaction_id = None;
2612 self.generation = Task::ready(());
2613 cx.emit(CodegenEvent::Undone);
2614 }
2615 }
2616 }
2617
2618 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2619 &self.last_equal_ranges
2620 }
2621
2622 pub fn count_tokens(
2623 &self,
2624 user_prompt: String,
2625 assistant_panel_context: Option<LanguageModelRequest>,
2626 cx: &AppContext,
2627 ) -> BoxFuture<'static, Result<TokenCounts>> {
2628 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2629 let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2630 match request {
2631 Ok(request) => {
2632 let total_count = model.count_tokens(request.clone(), cx);
2633 let assistant_panel_count = assistant_panel_context
2634 .map(|context| model.count_tokens(context, cx))
2635 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2636
2637 async move {
2638 Ok(TokenCounts {
2639 total: total_count.await?,
2640 assistant_panel: assistant_panel_count.await?,
2641 })
2642 }
2643 .boxed()
2644 }
2645 Err(error) => futures::future::ready(Err(error)).boxed(),
2646 }
2647 } else {
2648 future::ready(Err(anyhow!("no active model"))).boxed()
2649 }
2650 }
2651
2652 pub fn start(
2653 &mut self,
2654 user_prompt: String,
2655 assistant_panel_context: Option<LanguageModelRequest>,
2656 model: Arc<dyn LanguageModel>,
2657 cx: &mut ModelContext<Self>,
2658 ) -> Result<()> {
2659 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2660 self.buffer.update(cx, |buffer, cx| {
2661 buffer.undo_transaction(transformation_transaction_id, cx);
2662 });
2663 }
2664
2665 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2666
2667 let telemetry_id = model.telemetry_id();
2668 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
2669 if user_prompt.trim().to_lowercase() == "delete" {
2670 async { Ok(stream::empty().boxed()) }.boxed_local()
2671 } else {
2672 let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2673
2674 let chunks = cx
2675 .spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await });
2676 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2677 };
2678 self.handle_stream(telemetry_id, chunks, cx);
2679 Ok(())
2680 }
2681
2682 fn build_request(
2683 &self,
2684 user_prompt: String,
2685 assistant_panel_context: Option<LanguageModelRequest>,
2686 cx: &AppContext,
2687 ) -> Result<LanguageModelRequest> {
2688 let buffer = self.buffer.read(cx).snapshot(cx);
2689 let language = buffer.language_at(self.range.start);
2690 let language_name = if let Some(language) = language.as_ref() {
2691 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2692 None
2693 } else {
2694 Some(language.name())
2695 }
2696 } else {
2697 None
2698 };
2699
2700 let language_name = language_name.as_ref();
2701 let start = buffer.point_to_buffer_offset(self.range.start);
2702 let end = buffer.point_to_buffer_offset(self.range.end);
2703 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2704 let (start_buffer, start_buffer_offset) = start;
2705 let (end_buffer, end_buffer_offset) = end;
2706 if start_buffer.remote_id() == end_buffer.remote_id() {
2707 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2708 } else {
2709 return Err(anyhow::anyhow!("invalid transformation range"));
2710 }
2711 } else {
2712 return Err(anyhow::anyhow!("invalid transformation range"));
2713 };
2714
2715 let prompt = self
2716 .builder
2717 .generate_content_prompt(user_prompt, language_name, buffer, range)
2718 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2719
2720 let mut messages = Vec::new();
2721 if let Some(context_request) = assistant_panel_context {
2722 messages = context_request.messages;
2723 }
2724
2725 messages.push(LanguageModelRequestMessage {
2726 role: Role::User,
2727 content: vec![prompt.into()],
2728 cache: false,
2729 });
2730
2731 Ok(LanguageModelRequest {
2732 messages,
2733 tools: Vec::new(),
2734 stop: Vec::new(),
2735 temperature: None,
2736 })
2737 }
2738
/// Consumes a streamed model response and incrementally rewrites `self.range` with it.
///
/// `stream` resolves to a stream of text chunks. A background task:
/// - strips code fences and cursor markers (`StripInvalidSpans`);
/// - re-indents each generated line relative to the selection;
/// - diffs the result against the originally selected text at character and line granularity;
/// - sends each diff step over a channel of capacity 1.
///
/// The foreground loop turns character operations into edits anchored against
/// `self.snapshot`. It applies them to the buffer only while this alternative is active, and
/// always appends them to `self.edits`.
///
/// When the stream ends:
/// - a batch line diff is recomputed;
/// - the response latency (time to first chunk) and any error are reported to telemetry;
/// - `self.status` becomes `Done` or `Error`;
/// - `CodegenEvent::Finished` is emitted.
pub fn handle_stream(
    &mut self,
    model_telemetry_id: String,
    stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
    cx: &mut ModelContext<Self>,
) {
    let snapshot = self.snapshot.clone();
    // Original text of the selection; the streaming diff is computed against it.
    let selected_text = snapshot
        .text_for_range(self.range.start..self.range.end)
        .collect::<Rope>();

    let selection_start = self.range.start.to_point(&snapshot);

    // Start with the indentation of the first line in the selection
    let mut suggested_line_indent = snapshot
        .suggested_indents(selection_start.row..=selection_start.row, cx)
        .into_values()
        .next()
        .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

    // If the first line in the selection does not have indentation, check the following lines
    if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
        for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
            let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
            // Prefer tabs if a line in the selection uses tabs as indentation
            if line_indent.kind == IndentKind::Tab {
                suggested_line_indent.kind = IndentKind::Tab;
                break;
            }
        }
    }

    let telemetry = self.telemetry.clone();
    self.diff = Diff::default();
    self.status = CodegenStatus::Pending;
    // Byte offset (in `snapshot`) where the next streamed operation applies; advanced by
    // Keep/Delete operations in the foreground loop below.
    let mut edit_start = self.range.start.to_offset(&snapshot);
    self.generation = cx.spawn(|codegen, mut cx| {
        async move {
            let chunks = stream.await;
            let generate = async {
                // Carries (char operations, current line operations) per diff step.
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                let line_based_stream_diff: Task<anyhow::Result<()>> =
                    cx.background_executor().spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = StripInvalidSpans::new(chunks?);
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            // Text of the current line not yet pushed into the diff.
                            let mut new_text = String::new();
                            // Indent of the first non-blank generated line; later lines keep
                            // their indent relative to it.
                            let mut base_indent = None;
                            // Indent of the current line, once a non-whitespace char is seen.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    if line_indent.is_none() {
                                        if let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                        {
                                            line_indent = Some(non_whitespace_ch_ix);
                                            base_indent = base_indent.or(line_indent);

                                            // Re-base the line's indent: suggested indent plus
                                            // this line's offset from the base indent, clamped
                                            // at zero.
                                            let line_indent = line_indent.unwrap();
                                            let base_indent = base_indent.unwrap();
                                            let indent_delta =
                                                line_indent as i32 - base_indent as i32;
                                            let mut corrected_indent_len = cmp::max(
                                                0,
                                                suggested_line_indent.len as i32 + indent_delta,
                                            )
                                                as usize;
                                            // The first line is inserted at the selection
                                            // start, so subtract the column already present
                                            // there (presumably to avoid double indentation).
                                            if first_line {
                                                corrected_indent_len = corrected_indent_len
                                                    .saturating_sub(
                                                        selection_start.column as usize,
                                                    );
                                            }

                                            let indent_char = suggested_line_indent.char();
                                            let mut indent_buffer = [0; 4];
                                            let indent_str =
                                                indent_char.encode_utf8(&mut indent_buffer);
                                            new_text.replace_range(
                                                ..line_indent,
                                                &indent_str.repeat(corrected_indent_len),
                                            );
                                        }
                                    }

                                    // Only flush once the indent is known, so leading
                                    // whitespace can still be rewritten above.
                                    if line_indent.is_some() {
                                        let char_ops = diff.push_new(&new_text);
                                        line_diff
                                            .push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        new_text.clear();
                                    }

                                    // A following piece means this chunk contained a newline.
                                    if lines.peek().is_some() {
                                        let char_ops = diff.push_new("\n");
                                        line_diff
                                            .push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        if line_indent.is_none() {
                                            // Don't write out the leading indentation in empty lines on the next line
                                            // This is the case where the above if statement didn't clear the buffer
                                            new_text.clear();
                                        }
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }

                            // Flush any trailing text and finish both diffs.
                            let mut char_ops = diff.push_new(&new_text);
                            char_ops.extend(diff.finish());
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        let error_message =
                            result.as_ref().err().map(|error| error.to_string());
                        if let Some(telemetry) = telemetry {
                            telemetry.report_assistant_event(
                                None,
                                telemetry_events::AssistantKind::Inline,
                                telemetry_events::AssistantPhase::Response,
                                model_telemetry_id,
                                response_latency,
                                error_message,
                            );
                        }

                        result?;
                        Ok(())
                    });

                // Foreground: convert char operations into edits anchored in `snapshot`.
                while let Some((char_ops, line_ops)) = diff_rx.next().await {
                    codegen.update(&mut cx, |codegen, cx| {
                        codegen.last_equal_ranges.clear();

                        let edits = char_ops
                            .into_iter()
                            .filter_map(|operation| match operation {
                                CharOperation::Insert { text } => {
                                    let edit_start = snapshot.anchor_after(edit_start);
                                    Some((edit_start..edit_start, text))
                                }
                                CharOperation::Delete { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    Some((edit_range, String::new()))
                                }
                                CharOperation::Keep { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    codegen.last_equal_ranges.push(edit_range);
                                    None
                                }
                            })
                            .collect::<Vec<_>>();

                        // Inactive alternatives only record edits; they are replayed by
                        // `set_active(true)`.
                        if codegen.active {
                            codegen.apply_edits(edits.iter().cloned(), cx);
                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                        }
                        codegen.edits.extend(edits);
                        codegen.line_operations = line_ops;
                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                        cx.notify();
                    })?;
                }

                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                let batch_diff_task =
                    codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                let (line_based_stream_diff, ()) =
                    join!(line_based_stream_diff, batch_diff_task);
                line_based_stream_diff?;

                anyhow::Ok(())
            };

            // Record the final status and announce completion.
            let result = generate.await;
            codegen
                .update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
        }
    });
    cx.notify();
}
2966
2967 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2968 self.last_equal_ranges.clear();
2969 if self.diff.is_empty() {
2970 self.status = CodegenStatus::Idle;
2971 } else {
2972 self.status = CodegenStatus::Done;
2973 }
2974 self.generation = Task::ready(());
2975 cx.emit(CodegenEvent::Finished);
2976 cx.notify();
2977 }
2978
2979 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2980 self.buffer.update(cx, |buffer, cx| {
2981 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2982 buffer.undo_transaction(transaction_id, cx);
2983 buffer.refresh_preview(cx);
2984 }
2985 });
2986 }
2987
2988 fn apply_edits(
2989 &mut self,
2990 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
2991 cx: &mut ModelContext<CodegenAlternative>,
2992 ) {
2993 let transaction = self.buffer.update(cx, |buffer, cx| {
2994 // Avoid grouping assistant edits with user edits.
2995 buffer.finalize_last_transaction(cx);
2996 buffer.start_transaction(cx);
2997 buffer.edit(edits, None, cx);
2998 buffer.end_transaction(cx)
2999 });
3000
3001 if let Some(transaction) = transaction {
3002 if let Some(first_transaction) = self.transformation_transaction_id {
3003 // Group all assistant edits into the first transaction.
3004 self.buffer.update(cx, |buffer, cx| {
3005 buffer.merge_transactions(transaction, first_transaction, cx)
3006 });
3007 } else {
3008 self.transformation_transaction_id = Some(transaction);
3009 self.buffer
3010 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3011 }
3012 }
3013 }
3014
3015 fn reapply_line_based_diff(
3016 &mut self,
3017 line_operations: impl IntoIterator<Item = LineOperation>,
3018 cx: &mut ModelContext<Self>,
3019 ) {
3020 let old_snapshot = self.snapshot.clone();
3021 let old_range = self.range.to_point(&old_snapshot);
3022 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3023 let new_range = self.range.to_point(&new_snapshot);
3024
3025 let mut old_row = old_range.start.row;
3026 let mut new_row = new_range.start.row;
3027
3028 self.diff.deleted_row_ranges.clear();
3029 self.diff.inserted_row_ranges.clear();
3030 for operation in line_operations {
3031 match operation {
3032 LineOperation::Keep { lines } => {
3033 old_row += lines;
3034 new_row += lines;
3035 }
3036 LineOperation::Delete { lines } => {
3037 let old_end_row = old_row + lines - 1;
3038 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3039
3040 if let Some((_, last_deleted_row_range)) =
3041 self.diff.deleted_row_ranges.last_mut()
3042 {
3043 if *last_deleted_row_range.end() + 1 == old_row {
3044 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3045 } else {
3046 self.diff
3047 .deleted_row_ranges
3048 .push((new_row, old_row..=old_end_row));
3049 }
3050 } else {
3051 self.diff
3052 .deleted_row_ranges
3053 .push((new_row, old_row..=old_end_row));
3054 }
3055
3056 old_row += lines;
3057 }
3058 LineOperation::Insert { lines } => {
3059 let new_end_row = new_row + lines - 1;
3060 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3061 let end = new_snapshot.anchor_before(Point::new(
3062 new_end_row,
3063 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3064 ));
3065 self.diff.inserted_row_ranges.push(start..=end);
3066 new_row += lines;
3067 }
3068 }
3069
3070 cx.notify();
3071 }
3072 }
3073
/// Recomputes `self.diff` with a regular line diff between the original and current text.
///
/// The compared text spans whole rows covering `self.range`, from the original snapshot
/// and from the buffer's current snapshot. The diff (`similar::TextDiff::from_lines`) runs on
/// the background executor. Deleted and inserted rows are recorded with the same conventions
/// as `reapply_line_based_diff`. The result is stored back into `self.diff` (followed by a
/// notify) only if this model is still alive.
///
/// Returns the task doing the work; callers either await it or detach it.
fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
    let old_snapshot = self.snapshot.clone();
    let old_range = self.range.to_point(&old_snapshot);
    let new_snapshot = self.buffer.read(cx).snapshot(cx);
    let new_range = self.range.to_point(&new_snapshot);

    cx.spawn(|codegen, mut cx| async move {
        let (deleted_row_ranges, inserted_row_ranges) = cx
            .background_executor()
            .spawn(async move {
                // Whole-row text of the range in the original snapshot.
                let old_text = old_snapshot
                    .text_for_range(
                        Point::new(old_range.start.row, 0)
                            ..Point::new(
                                old_range.end.row,
                                old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                            ),
                    )
                    .collect::<String>();
                // Whole-row text of the range in the current buffer.
                let new_text = new_snapshot
                    .text_for_range(
                        Point::new(new_range.start.row, 0)
                            ..Point::new(
                                new_range.end.row,
                                new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                            ),
                    )
                    .collect::<String>();

                let mut old_row = old_range.start.row;
                let mut new_row = new_range.start.row;
                let batch_diff =
                    similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                let mut inserted_row_ranges = Vec::new();
                for change in batch_diff.iter_all_changes() {
                    let line_count = change.value().lines().count() as u32;
                    match change.tag() {
                        similar::ChangeTag::Equal => {
                            old_row += line_count;
                            new_row += line_count;
                        }
                        similar::ChangeTag::Delete => {
                            let old_end_row = old_row + line_count - 1;
                            // Anchor in the new buffer where the deleted rows used to be.
                            let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));

                            // Merge with the previous deletion when the rows are contiguous.
                            if let Some((_, last_deleted_row_range)) =
                                deleted_row_ranges.last_mut()
                            {
                                if *last_deleted_row_range.end() + 1 == old_row {
                                    *last_deleted_row_range =
                                        *last_deleted_row_range.start()..=old_end_row;
                                } else {
                                    deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                }
                            } else {
                                deleted_row_ranges.push((new_row, old_row..=old_end_row));
                            }

                            old_row += line_count;
                        }
                        similar::ChangeTag::Insert => {
                            let new_end_row = new_row + line_count - 1;
                            let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                            let end = new_snapshot.anchor_before(Point::new(
                                new_end_row,
                                new_snapshot.line_len(MultiBufferRow(new_end_row)),
                            ));
                            inserted_row_ranges.push(start..=end);
                            new_row += line_count;
                        }
                    }
                }

                (deleted_row_ranges, inserted_row_ranges)
            })
            .await;

        // Ignore the result: the model may have been dropped while diffing.
        codegen
            .update(&mut cx, |codegen, cx| {
                codegen.diff.deleted_row_ranges = deleted_row_ranges;
                codegen.diff.inserted_row_ranges = inserted_row_ranges;
                cx.notify();
            })
            .ok();
    })
}
3162}
3163
/// Stream adapter that removes a wrapping Markdown code fence and `<|CURSOR|>` markers from
/// model output, buffering text until it is safe to emit.
struct StripInvalidSpans<T> {
    /// Inner stream of text chunks.
    stream: T,
    /// Set once the inner stream has returned `None`.
    stream_done: bool,
    /// Received text that has not been emitted yet.
    buffer: String,
    /// True until a second line has been seen; used to detect a leading code fence.
    first_line: bool,
    /// Whether a newline is owed before the next emitted text. Newlines are written lazily,
    /// presumably so a newline before a closing fence can be dropped — verify in `poll_next`.
    line_end: bool,
    /// Set when the first line was a code fence, enabling removal of a closing fence.
    starts_with_code_block: bool,
}
3172
3173impl<T> StripInvalidSpans<T>
3174where
3175 T: Stream<Item = Result<String>>,
3176{
3177 fn new(stream: T) -> Self {
3178 Self {
3179 stream,
3180 stream_done: false,
3181 buffer: String::new(),
3182 first_line: true,
3183 line_end: false,
3184 starts_with_code_block: false,
3185 }
3186 }
3187}
3188
3189impl<T> Stream for StripInvalidSpans<T>
3190where
3191 T: Stream<Item = Result<String>>,
3192{
3193 type Item = Result<String>;
3194
3195 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3196 const CODE_BLOCK_DELIMITER: &str = "```";
3197 const CURSOR_SPAN: &str = "<|CURSOR|>";
3198
3199 let this = unsafe { self.get_unchecked_mut() };
3200 loop {
3201 if !this.stream_done {
3202 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3203 match stream.as_mut().poll_next(cx) {
3204 Poll::Ready(Some(Ok(chunk))) => {
3205 this.buffer.push_str(&chunk);
3206 }
3207 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3208 Poll::Ready(None) => {
3209 this.stream_done = true;
3210 }
3211 Poll::Pending => return Poll::Pending,
3212 }
3213 }
3214
3215 let mut chunk = String::new();
3216 let mut consumed = 0;
3217 if !this.buffer.is_empty() {
3218 let mut lines = this.buffer.split('\n').enumerate().peekable();
3219 while let Some((line_ix, line)) = lines.next() {
3220 if line_ix > 0 {
3221 this.first_line = false;
3222 }
3223
3224 if this.first_line {
3225 let trimmed_line = line.trim();
3226 if lines.peek().is_some() {
3227 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3228 consumed += line.len() + 1;
3229 this.starts_with_code_block = true;
3230 continue;
3231 }
3232 } else if trimmed_line.is_empty()
3233 || prefixes(CODE_BLOCK_DELIMITER)
3234 .any(|prefix| trimmed_line.starts_with(prefix))
3235 {
3236 break;
3237 }
3238 }
3239
3240 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3241 if lines.peek().is_some() {
3242 if this.line_end {
3243 chunk.push('\n');
3244 }
3245
3246 chunk.push_str(&line_without_cursor);
3247 this.line_end = true;
3248 consumed += line.len() + 1;
3249 } else if this.stream_done {
3250 if !this.starts_with_code_block
3251 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3252 {
3253 if this.line_end {
3254 chunk.push('\n');
3255 }
3256
3257 chunk.push_str(&line);
3258 }
3259
3260 consumed += line.len();
3261 } else {
3262 let trimmed_line = line.trim();
3263 if trimmed_line.is_empty()
3264 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3265 || prefixes(CODE_BLOCK_DELIMITER)
3266 .any(|prefix| trimmed_line.ends_with(prefix))
3267 {
3268 break;
3269 } else {
3270 if this.line_end {
3271 chunk.push('\n');
3272 this.line_end = false;
3273 }
3274
3275 chunk.push_str(&line_without_cursor);
3276 consumed += line.len();
3277 }
3278 }
3279 }
3280 }
3281
3282 this.buffer = this.buffer.split_off(consumed);
3283 if !chunk.is_empty() {
3284 return Poll::Ready(Some(Ok(chunk)));
3285 } else if this.stream_done {
3286 return Poll::Ready(None);
3287 }
3288 }
3289 }
3290}
3291
/// Returns every proper, non-empty prefix of `text`, shortest first.
///
/// For example, `"```"` yields `"`"` and `"``"`; the full string itself is never included.
/// Prefixes always end on a `char` boundary, and an empty or single-character `text` yields
/// nothing (instead of underflowing on `len() - 1`).
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3295
3296fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3297 ranges.sort_unstable_by(|a, b| {
3298 a.start
3299 .cmp(&b.start, buffer)
3300 .then_with(|| b.end.cmp(&a.end, buffer))
3301 });
3302
3303 let mut ix = 0;
3304 while ix + 1 < ranges.len() {
3305 let b = ranges[ix + 1].clone();
3306 let a = &mut ranges[ix];
3307 if a.end.cmp(&b.start, buffer).is_gt() {
3308 if a.end.cmp(&b.end, buffer).is_lt() {
3309 a.end = b.end;
3310 }
3311 ranges.remove(ix + 1);
3312 } else {
3313 ix += 1;
3314 }
3315 }
3316}
3317
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module; confirm whether it is still needed.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Streams a replacement for the selected loop (rows 1..=4) in random-sized chunks and checks
    /// the final buffer text, whose expected indentation matches the enclosing `fn main` body.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        // Feed the codegen from an unbounded channel so the test controls chunk boundaries.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the new text in random chunks of 1..=10 bytes, letting the executor settle in between.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Closing the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at an empty range placed after the partial token `le` (row 1, column 6)
    /// and checks that the streamed continuation ends up indented inside `fn main`.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generates at an empty range on a whitespace-only line (row 1, column 2) and checks
    /// that the streamed, unindented code ends up indented inside `fn main`.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Uses a buffer with no language and tab-indented content. The streamed replacement is also
    /// tab-indented, and the test checks that the tabs are kept rather than converted.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        // The whole replacement is sent as a single chunk.
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// Creates the codegen with `false` as its third argument (inactive). Its output must then
    /// leave the buffer alone until `set_active(true)` applies it, and `set_active(false)`
    /// undoes it again.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                future::ready(Ok(chunks_rx.map(Ok).boxed())),
                cx,
            )
        });

        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// Runs `StripInvalidSpans` over each input split at every chunk size from 1 to the input
    /// length, and requires the same stripped output regardless of chunking.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Collects the adapter's output for every chunk size and compares it to `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into successive pieces of `size` chars (the last may be shorter),
        // each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Builds a minimal Rust `Language` (tree-sitter grammar plus a small indents query) so the
    /// autoindent tests above can compute indentation.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}