1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use smol::future::FutureExt;
39use std::{
40 cmp,
41 future::{self, Future},
42 mem,
43 ops::{Range, RangeInclusive},
44 pin::Pin,
45 sync::Arc,
46 task::{self, Poll},
47 time::{Duration, Instant},
48};
49use theme::ThemeSettings;
50use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
51use util::{RangeExt, ResultExt};
52use workspace::{notifications::NotificationId, Toast, Workspace};
53
54pub fn init(
55 fs: Arc<dyn Fs>,
56 prompt_builder: Arc<PromptBuilder>,
57 telemetry: Arc<Telemetry>,
58 cx: &mut AppContext,
59) {
60 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
61}
62
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`.
/// When a newly submitted prompt pushes the history past this length, the
/// oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
64
/// Global state for all inline assists: the live assists, how they are grouped
/// and attached to editors, and the shared resources used to create them.
pub struct InlineAssistant {
    /// Id handed to the next assist that is created (post-incremented on use).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group that is created (post-incremented on use).
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, highlight updates),
    /// keyed by a weak handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists; one group is created per `assist`/`suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Watch-channel pairs created by `observe_assist`. The sender is signalled
    /// when the assist is started or stopped; the entry is removed when the
    /// assist is finished.
    assist_observations:
        HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
    /// Codegen of assists that were finished without undo, kept so that
    /// `undo_assist` can still revert them.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    /// Previously submitted prompts, oldest first, without duplicates and capped
    /// at `PROMPT_HISTORY_MAX_LEN` entries.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` that is created.
    prompt_builder: Arc<PromptBuilder>,
    /// Passed to every `Codegen` that is created.
    telemetry: Option<Arc<Telemetry>>,
    /// Passed to every `PromptEditor` that is created.
    fs: Arc<dyn Fs>,
}
79
// `init` stores an `InlineAssistant` with `cx.set_global`; this marker impl is
// presumably what permits that (confirm against gpui's `Global` trait).
impl Global for InlineAssistant {}
81
82impl InlineAssistant {
83 pub fn new(
84 fs: Arc<dyn Fs>,
85 prompt_builder: Arc<PromptBuilder>,
86 telemetry: Arc<Telemetry>,
87 ) -> Self {
88 Self {
89 next_assist_id: InlineAssistId::default(),
90 next_assist_group_id: InlineAssistGroupId::default(),
91 assists: HashMap::default(),
92 assists_by_editor: HashMap::default(),
93 assist_groups: HashMap::default(),
94 assist_observations: HashMap::default(),
95 confirmed_assists: HashMap::default(),
96 prompt_history: VecDeque::default(),
97 prompt_builder,
98 telemetry: Some(telemetry),
99 fs,
100 }
101 }
102
/// Opens inline assists for the editor's current selections.
///
/// Every non-empty selection is widened to whole lines, selections that touch
/// or overlap the previous one are merged into it, and one assist (a `Codegen`
/// plus a `PromptEditor`) is created for each excerpt range returned by
/// `excerpts_in_ranges`. All assists created by one call share a single prompt
/// buffer and a single assist group.
///
/// `initial_prompt` seeds the shared prompt buffer (empty when `None`).
///
/// After registration, the first assist whose range start equals the newest
/// selection's start (when that selection is reversed) or whose range end
/// equals its end (otherwise) is focused, if there is one.
///
/// Panics if the editor reports no selections (`newest_selection.unwrap()`).
pub fn assist(
    &mut self,
    editor: &View<Editor>,
    workspace: Option<WeakView<Workspace>>,
    assistant_panel: Option<&View<AssistantPanel>>,
    initial_prompt: Option<String>,
    cx: &mut WindowContext,
) {
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

    let mut selections = Vec::<Selection<Point>>::new();
    let mut newest_selection = None;
    for mut selection in editor.read(cx).selections.all::<Point>(cx) {
        // Non-empty selections are expanded to cover complete lines.
        if selection.end > selection.start {
            selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            if selection.end.column == 0 {
                selection.end.row -= 1;
            }
            selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
        }

        // Fold this selection into the previous one when they touch or overlap.
        // Merged selections skip the newest-selection tracking below.
        if let Some(prev_selection) = selections.last_mut() {
            if selection.start <= prev_selection.end {
                prev_selection.end = selection.end;
                continue;
            }
        }

        // Remember the kept selection with the highest id.
        let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
        if selection.id > latest_selection.id {
            *latest_selection = selection.clone();
        }
        selections.push(selection);
    }
    let newest_selection = newest_selection.unwrap();

    // Turn each covered excerpt into an anchor range, using `anchor_before` for
    // the start and `anchor_after` for the end.
    let mut codegen_ranges = Vec::new();
    for (excerpt_id, buffer, buffer_range) in
        snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
            snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
        }))
    {
        let start = Anchor {
            buffer_id: Some(buffer.remote_id()),
            excerpt_id,
            text_anchor: buffer.anchor_before(buffer_range.start),
        };
        let end = Anchor {
            buffer_id: Some(buffer.remote_id()),
            excerpt_id,
            text_anchor: buffer.anchor_after(buffer_range.end),
        };
        codegen_ranges.push(start..end);
    }

    let assist_group_id = self.next_assist_group_id.post_inc();
    // One prompt buffer, shared by every prompt editor created in this call.
    let prompt_buffer =
        cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
    let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));

    // First pass: create the codegen, prompt editor and editor blocks for each range.
    let mut assists = Vec::new();
    let mut assist_to_focus = None;
    for range in codegen_ranges {
        let assist_id = self.next_assist_id.post_inc();
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                range.clone(),
                None,
                self.telemetry.clone(),
                self.prompt_builder.clone(),
                cx,
            )
        });

        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let prompt_editor = cx.new_view(|cx| {
            PromptEditor::new(
                assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                prompt_buffer.clone(),
                codegen.clone(),
                editor,
                assistant_panel,
                workspace.clone(),
                self.fs.clone(),
                cx,
            )
        });

        // Only the first matching assist is chosen for focus.
        if assist_to_focus.is_none() {
            let focus_assist = if newest_selection.reversed {
                range.start.to_point(&snapshot) == newest_selection.start
            } else {
                range.end.to_point(&snapshot) == newest_selection.end
            };
            if focus_assist {
                assist_to_focus = Some(assist_id);
            }
        }

        let [prompt_block_id, end_block_id] =
            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

        assists.push((
            assist_id,
            range,
            prompt_editor,
            prompt_block_id,
            end_block_id,
        ));
    }

    // Second pass: register the assists with the editor entry and the new group.
    let editor_assists = self
        .assists_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
    let mut assist_group = InlineAssistGroup::new();
    for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
        self.assists.insert(
            assist_id,
            InlineAssist::new(
                assist_id,
                assist_group_id,
                assistant_panel.is_some(),
                editor,
                &prompt_editor,
                prompt_block_id,
                end_block_id,
                range,
                prompt_editor.read(cx).codegen.clone(),
                workspace.clone(),
                cx,
            ),
        );
        assist_group.assist_ids.push(assist_id);
        editor_assists.assist_ids.push(assist_id);
    }
    self.assist_groups.insert(assist_group_id, assist_group);

    if let Some(assist_id) = assist_to_focus {
        self.focus_assist(assist_id, cx);
    }
}
249
/// Creates a single assist over `range` with `initial_prompt` pre-filled and
/// returns its id.
///
/// The range start is biased left and the range end biased right before use.
/// `initial_transaction_id` is forwarded to `Codegen::new`. The assist is placed
/// in a new group of its own. Unlike `assist`, this method does not focus the
/// prompt editor.
#[allow(clippy::too_many_arguments)]
pub fn suggest_assist(
    &mut self,
    editor: &View<Editor>,
    mut range: Range<Anchor>,
    initial_prompt: String,
    initial_transaction_id: Option<TransactionId>,
    workspace: Option<WeakView<Workspace>>,
    assistant_panel: Option<&View<AssistantPanel>>,
    cx: &mut WindowContext,
) -> InlineAssistId {
    let assist_group_id = self.next_assist_group_id.post_inc();
    let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
    let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));

    let assist_id = self.next_assist_id.post_inc();

    let buffer = editor.read(cx).buffer().clone();
    {
        // Scoped so the snapshot borrow ends before `cx` is used mutably again.
        let snapshot = buffer.read(cx).read(cx);
        range.start = range.start.bias_left(&snapshot);
        range.end = range.end.bias_right(&snapshot);
    }

    let codegen = cx.new_model(|cx| {
        Codegen::new(
            editor.read(cx).buffer().clone(),
            range.clone(),
            initial_transaction_id,
            self.telemetry.clone(),
            self.prompt_builder.clone(),
            cx,
        )
    });

    let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
    let prompt_editor = cx.new_view(|cx| {
        PromptEditor::new(
            assist_id,
            gutter_dimensions.clone(),
            self.prompt_history.clone(),
            prompt_buffer.clone(),
            codegen.clone(),
            editor,
            assistant_panel,
            workspace.clone(),
            self.fs.clone(),
            cx,
        )
    });

    let [prompt_block_id, end_block_id] =
        self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

    let editor_assists = self
        .assists_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorInlineAssists::new(&editor, cx));

    // Register the assist with both the editor entry and its single-member group.
    let mut assist_group = InlineAssistGroup::new();
    self.assists.insert(
        assist_id,
        InlineAssist::new(
            assist_id,
            assist_group_id,
            assistant_panel.is_some(),
            editor,
            &prompt_editor,
            prompt_block_id,
            end_block_id,
            range,
            prompt_editor.read(cx).codegen.clone(),
            workspace.clone(),
            cx,
        ),
    );
    assist_group.assist_ids.push(assist_id);
    editor_assists.assist_ids.push(assist_id);
    self.assist_groups.insert(assist_group_id, assist_group);
    assist_id
}
331
332 fn insert_assist_blocks(
333 &self,
334 editor: &View<Editor>,
335 range: &Range<Anchor>,
336 prompt_editor: &View<PromptEditor>,
337 cx: &mut WindowContext,
338 ) -> [CustomBlockId; 2] {
339 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
340 prompt_editor
341 .editor
342 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
343 });
344 let assist_blocks = vec![
345 BlockProperties {
346 style: BlockStyle::Sticky,
347 position: range.start,
348 height: prompt_editor_height,
349 render: build_assist_editor_renderer(prompt_editor),
350 disposition: BlockDisposition::Above,
351 priority: 0,
352 },
353 BlockProperties {
354 style: BlockStyle::Sticky,
355 position: range.end,
356 height: 0,
357 render: Box::new(|cx| {
358 v_flex()
359 .h_full()
360 .w_full()
361 .border_t_1()
362 .border_color(cx.theme().status().info_border)
363 .into_any_element()
364 }),
365 disposition: BlockDisposition::Below,
366 priority: 0,
367 },
368 ];
369
370 editor.update(cx, |editor, cx| {
371 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
372 [block_ids[0], block_ids[1]]
373 })
374 }
375
/// Reacts to an assist's prompt editor gaining focus (going by the method name;
/// the call site is not in view).
///
/// Marks the assist as the active one in its group and, if the group is linked,
/// makes every prompt editor in the group show its cursor while unfocused. It
/// then records a scroll lock when the prompt block's row lies within
/// `scroll_top..scroll_bottom`, and clears the editor's scroll lock otherwise.
///
/// Returns without changes when the assist has no decorations. Panics if
/// `assist_id` is unknown or if its group or editor entry is missing.
fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let assist = &self.assists[&assist_id];
    let Some(decorations) = assist.decorations.as_ref() else {
        return;
    };
    let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
    let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

    assist_group.active_assist_id = Some(assist_id);
    if assist_group.linked {
        for assist_id in &assist_group.assist_ids {
            if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                    prompt_editor.set_show_cursor_when_unfocused(true, cx)
                });
            }
        }
    }

    // `assist.editor` is a weak handle; a failed update is ignored via `.ok()`.
    assist
        .editor
        .update(cx, |editor, cx| {
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
            let prompt_row = editor
                .row_for_block(decorations.prompt_block_id, cx)
                .unwrap()
                .0 as f32;

            // Lock only when the prompt is currently within the visible rows.
            if (scroll_top..scroll_bottom).contains(&prompt_row) {
                editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                    assist_id,
                    distance_from_top: prompt_row - scroll_top,
                });
            } else {
                editor_assists.scroll_lock = None;
            }
        })
        .ok();
}
416
417 fn handle_prompt_editor_focus_out(
418 &mut self,
419 assist_id: InlineAssistId,
420 cx: &mut WindowContext,
421 ) {
422 let assist = &self.assists[&assist_id];
423 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
424 if assist_group.active_assist_id == Some(assist_id) {
425 assist_group.active_assist_id = None;
426 if assist_group.linked {
427 for assist_id in &assist_group.assist_ids {
428 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
429 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
430 prompt_editor.set_show_cursor_when_unfocused(false, cx)
431 });
432 }
433 }
434 }
435 }
436 }
437
438 fn handle_prompt_editor_event(
439 &mut self,
440 prompt_editor: View<PromptEditor>,
441 event: &PromptEditorEvent,
442 cx: &mut WindowContext,
443 ) {
444 let assist_id = prompt_editor.read(cx).id;
445 match event {
446 PromptEditorEvent::StartRequested => {
447 self.start_assist(assist_id, cx);
448 }
449 PromptEditorEvent::StopRequested => {
450 self.stop_assist(assist_id, cx);
451 }
452 PromptEditorEvent::ConfirmRequested => {
453 self.finish_assist(assist_id, false, cx);
454 }
455 PromptEditorEvent::CancelRequested => {
456 self.finish_assist(assist_id, true, cx);
457 }
458 PromptEditorEvent::DismissRequested => {
459 self.dismiss_assist(assist_id, cx);
460 }
461 }
462 }
463
464 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
465 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
466 return;
467 };
468
469 let editor = editor.read(cx);
470 if editor.selections.count() == 1 {
471 let selection = editor.selections.newest::<usize>(cx);
472 let buffer = editor.buffer().read(cx).snapshot(cx);
473 for assist_id in &editor_assists.assist_ids {
474 let assist = &self.assists[assist_id];
475 let assist_range = assist.range.to_offset(&buffer);
476 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
477 {
478 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
479 self.dismiss_assist(*assist_id, cx);
480 } else {
481 self.finish_assist(*assist_id, false, cx);
482 }
483
484 return;
485 }
486 }
487 }
488
489 cx.propagate();
490 }
491
492 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
493 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
494 return;
495 };
496
497 let editor = editor.read(cx);
498 if editor.selections.count() == 1 {
499 let selection = editor.selections.newest::<usize>(cx);
500 let buffer = editor.buffer().read(cx).snapshot(cx);
501 for assist_id in &editor_assists.assist_ids {
502 let assist = &self.assists[assist_id];
503 let assist_range = assist.range.to_offset(&buffer);
504 if assist.decorations.is_some()
505 && assist_range.contains(&selection.start)
506 && assist_range.contains(&selection.end)
507 {
508 self.focus_assist(*assist_id, cx);
509 return;
510 }
511 }
512 }
513
514 cx.propagate();
515 }
516
517 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
518 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
519 for assist_id in editor_assists.assist_ids.clone() {
520 self.finish_assist(assist_id, true, cx);
521 }
522 }
523 }
524
525 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
526 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
527 return;
528 };
529 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
530 return;
531 };
532 let assist = &self.assists[&scroll_lock.assist_id];
533 let Some(decorations) = assist.decorations.as_ref() else {
534 return;
535 };
536
537 editor.update(cx, |editor, cx| {
538 let scroll_position = editor.scroll_position(cx);
539 let target_scroll_top = editor
540 .row_for_block(decorations.prompt_block_id, cx)
541 .unwrap()
542 .0 as f32
543 - scroll_lock.distance_from_top;
544 if target_scroll_top != scroll_position.y {
545 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
546 }
547 });
548 }
549
/// Reacts to events emitted by an editor that has assists attached.
///
/// * `Saved`: finishes (without undo) every assist whose codegen is `Done`.
/// * `Edited`: finishes (without undo) every assist in the `Error` or `Done`
///   state whose range overlaps a range edited by the transaction.
/// * `ScrollPositionChanged`: drops the scroll lock when the prompt block's
///   distance from the top no longer equals the locked distance.
/// * `SelectionsChanged`: drops the scroll lock unless one of this editor's
///   prompt editors is focused.
///
/// Other events, and editors without assists, are ignored.
fn handle_editor_event(
    &mut self,
    editor: View<Editor>,
    event: &EditorEvent,
    cx: &mut WindowContext,
) {
    let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
        return;
    };

    match event {
        EditorEvent::Saved => {
            // Iterate over a copy: `finish_assist` edits the id list.
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if let CodegenStatus::Done = &assist.codegen.read(cx).status {
                    self.finish_assist(assist_id, false, cx)
                }
            }
        }
        EditorEvent::Edited { transaction_id } => {
            let buffer = editor.read(cx).buffer().read(cx);
            let edited_ranges =
                buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
            let snapshot = buffer.snapshot(cx);

            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                // Idle and pending assists are left alone.
                if matches!(
                    assist.codegen.read(cx).status,
                    CodegenStatus::Error(_) | CodegenStatus::Done
                ) {
                    let assist_range = assist.range.to_offset(&snapshot);
                    if edited_ranges
                        .iter()
                        .any(|range| range.overlaps(&assist_range))
                    {
                        self.finish_assist(assist_id, false, cx);
                    }
                }
            }
        }
        EditorEvent::ScrollPositionChanged { .. } => {
            if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                let assist = &self.assists[&scroll_lock.assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    let distance_from_top = editor.update(cx, |editor, cx| {
                        let scroll_top = editor.scroll_position(cx).y;
                        let prompt_row = editor
                            .row_for_block(decorations.prompt_block_id, cx)
                            .unwrap()
                            .0 as f32;
                        prompt_row - scroll_top
                    });

                    // A different distance means the scroll did not come from
                    // re-applying the lock, so the lock is released.
                    if distance_from_top != scroll_lock.distance_from_top {
                        editor_assists.scroll_lock = None;
                    }
                }
            }
        }
        EditorEvent::SelectionsChanged { .. } => {
            // Keep the lock while any of this editor's prompt editors has focus.
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                        return;
                    }
                }
            }

            editor_assists.scroll_lock = None;
        }
        _ => {}
    }
}
625
/// Finishes an assist: removes its decorations and all bookkeeping for it, then
/// either undoes its changes (`undo == true`) or keeps its codegen in
/// `confirmed_assists` so that `undo_assist` can revert it later.
///
/// If the assist's group is still linked, the group is unlinked first and every
/// member of the group is finished with the same `undo` flag.
///
/// The observation channel for `assist_id`, if any, is dropped in all
/// non-linked cases, including when the id is unknown.
pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
    if let Some(assist) = self.assists.get(&assist_id) {
        let assist_group_id = assist.group_id;
        if self.assist_groups[&assist_group_id].linked {
            // After unlinking, each recursive call takes the non-linked path below.
            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                self.finish_assist(assist_id, undo, cx);
            }
            return;
        }
    }

    self.dismiss_assist(assist_id, cx);

    if let Some(assist) = self.assists.remove(&assist_id) {
        // Drop the assist from its group, removing the group once it is empty.
        if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
            }
        }

        // Drop the assist from its editor entry. When it was the last one, the
        // entry is removed and highlights are refreshed directly; otherwise a
        // highlight update is signalled through the entry's channel.
        if let hash_map::Entry::Occupied(mut entry) =
            self.assists_by_editor.entry(assist.editor.clone())
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
                if let Some(editor) = assist.editor.upgrade() {
                    self.update_editor_highlights(&editor, cx);
                }
            } else {
                entry.get().highlight_updates.send(()).ok();
            }
        }

        if undo {
            assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
        } else {
            self.confirmed_assists.insert(assist_id, assist.codegen);
        }
    }

    // Remove the assist from the status updates map
    self.assist_observations.remove(&assist_id);
}
672
673 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
674 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
675 return false;
676 };
677 codegen.update(cx, |this, cx| this.undo(cx));
678 true
679 }
680
/// Removes an assist's editor decorations (prompt block, end block and any
/// removed-line blocks) while leaving the assist itself registered.
///
/// Returns `false` when the assist is unknown, its editor has been dropped, or
/// it has no decorations left; returns `true` otherwise.
///
/// If the dismissed prompt editor contained focus, focus moves on via
/// `focus_next_assist`. A scroll lock pointing at this assist is cleared, and a
/// highlight update is signalled for the editor.
fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return false;
    };
    let Some(editor) = assist.editor.upgrade() else {
        return false;
    };
    // `take` leaves the assist without decorations from here on.
    let Some(decorations) = assist.decorations.take() else {
        return false;
    };

    editor.update(cx, |editor, cx| {
        let mut to_remove = decorations.removed_line_block_ids;
        to_remove.insert(decorations.prompt_block_id);
        to_remove.insert(decorations.end_block_id);
        editor.remove_blocks(to_remove, None, cx);
    });

    if decorations
        .prompt_editor
        .focus_handle(cx)
        .contains_focused(cx)
    {
        self.focus_next_assist(assist_id, cx);
    }

    if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
        if editor_assists
            .scroll_lock
            .as_ref()
            .map_or(false, |lock| lock.assist_id == assist_id)
        {
            editor_assists.scroll_lock = None;
        }
        editor_assists.highlight_updates.send(()).ok();
    }

    true
}
720
721 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
722 let Some(assist) = self.assists.get(&assist_id) else {
723 return;
724 };
725
726 let assist_group = &self.assist_groups[&assist.group_id];
727 let assist_ix = assist_group
728 .assist_ids
729 .iter()
730 .position(|id| *id == assist_id)
731 .unwrap();
732 let assist_ids = assist_group
733 .assist_ids
734 .iter()
735 .skip(assist_ix + 1)
736 .chain(assist_group.assist_ids.iter().take(assist_ix));
737
738 for assist_id in assist_ids {
739 let assist = &self.assists[assist_id];
740 if assist.decorations.is_some() {
741 self.focus_assist(*assist_id, cx);
742 return;
743 }
744 }
745
746 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
747 }
748
749 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
750 let Some(assist) = self.assists.get(&assist_id) else {
751 return;
752 };
753
754 if let Some(decorations) = assist.decorations.as_ref() {
755 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
756 prompt_editor.editor.update(cx, |editor, cx| {
757 editor.focus(cx);
758 editor.select_all(&SelectAll, cx);
759 })
760 });
761 }
762
763 self.scroll_to_assist(assist_id, cx);
764 }
765
766 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
767 let Some(assist) = self.assists.get(&assist_id) else {
768 return;
769 };
770 let Some(editor) = assist.editor.upgrade() else {
771 return;
772 };
773
774 let position = assist.range.start;
775 editor.update(cx, |editor, cx| {
776 editor.change_selections(None, cx, |selections| {
777 selections.select_anchor_ranges([position..position])
778 });
779
780 let mut scroll_target_top;
781 let mut scroll_target_bottom;
782 if let Some(decorations) = assist.decorations.as_ref() {
783 scroll_target_top = editor
784 .row_for_block(decorations.prompt_block_id, cx)
785 .unwrap()
786 .0 as f32;
787 scroll_target_bottom = editor
788 .row_for_block(decorations.end_block_id, cx)
789 .unwrap()
790 .0 as f32;
791 } else {
792 let snapshot = editor.snapshot(cx);
793 let start_row = assist
794 .range
795 .start
796 .to_display_point(&snapshot.display_snapshot)
797 .row();
798 scroll_target_top = start_row.0 as f32;
799 scroll_target_bottom = scroll_target_top + 1.;
800 }
801 scroll_target_top -= editor.vertical_scroll_margin() as f32;
802 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
803
804 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
805 let scroll_top = editor.scroll_position(cx).y;
806 let scroll_bottom = scroll_top + height_in_lines;
807
808 if scroll_target_top < scroll_top {
809 editor.set_scroll_position(point(0., scroll_target_top), cx);
810 } else if scroll_target_bottom > scroll_bottom {
811 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
812 editor
813 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
814 } else {
815 editor.set_scroll_position(point(0., scroll_target_top), cx);
816 }
817 }
818 });
819 }
820
821 fn unlink_assist_group(
822 &mut self,
823 assist_group_id: InlineAssistGroupId,
824 cx: &mut WindowContext,
825 ) -> Vec<InlineAssistId> {
826 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
827 assist_group.linked = false;
828 for assist_id in &assist_group.assist_ids {
829 let assist = self.assists.get_mut(assist_id).unwrap();
830 if let Some(editor_decorations) = assist.decorations.as_ref() {
831 editor_decorations
832 .prompt_editor
833 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
834 }
835 }
836 assist_group.assist_ids.clone()
837 }
838
839 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
840 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
841 assist
842 } else {
843 return;
844 };
845
846 let assist_group_id = assist.group_id;
847 if self.assist_groups[&assist_group_id].linked {
848 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
849 self.start_assist(assist_id, cx);
850 }
851 return;
852 }
853
854 let Some(user_prompt) = assist.user_prompt(cx) else {
855 return;
856 };
857
858 self.prompt_history.retain(|prompt| *prompt != user_prompt);
859 self.prompt_history.push_back(user_prompt.clone());
860 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
861 self.prompt_history.pop_front();
862 }
863
864 let assistant_panel_context = assist.assistant_panel_context(cx);
865
866 assist
867 .codegen
868 .update(cx, |codegen, cx| {
869 codegen.start(
870 assist.range.clone(),
871 user_prompt,
872 assistant_panel_context,
873 cx,
874 )
875 })
876 .log_err();
877
878 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
879 tx.send(()).ok();
880 }
881 }
882
883 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
884 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
885 assist
886 } else {
887 return;
888 };
889
890 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
891
892 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
893 tx.send(()).ok();
894 }
895 }
896
897 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
898 if let Some(assist) = self.assists.get(&assist_id) {
899 match &assist.codegen.read(cx).status {
900 CodegenStatus::Idle => InlineAssistStatus::Idle,
901 CodegenStatus::Pending => InlineAssistStatus::Pending,
902 CodegenStatus::Done => InlineAssistStatus::Done,
903 CodegenStatus::Error(_) => InlineAssistStatus::Error,
904 }
905 } else if self.confirmed_assists.contains_key(&assist_id) {
906 InlineAssistStatus::Confirmed
907 } else {
908 InlineAssistStatus::Canceled
909 }
910 }
911
/// Recomputes the gutter, text and row highlights of `editor` from all assists
/// attached to it.
///
/// For each assist:
/// * the pending gutter range runs from the codegen's edit position (or the
///   range start when there is none) to the range end;
/// * the transformed gutter range runs from the range start to the edit position;
/// * the codegen's `last_equal_ranges` are highlighted as faded text;
/// * the diff's inserted row ranges are highlighted, but only while the assist
///   still has decorations.
///
/// Empty gutter ranges are skipped. Highlight kinds with no ranges are cleared.
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
    let mut gutter_pending_ranges = Vec::new();
    let mut gutter_transformed_ranges = Vec::new();
    let mut foreground_ranges = Vec::new();
    let mut inserted_row_ranges = Vec::new();
    // Stand-in list used when the editor has no entry in `assists_by_editor`.
    let empty_assist_ids = Vec::new();
    let assist_ids = self
        .assists_by_editor
        .get(&editor.downgrade())
        .map_or(&empty_assist_ids, |editor_assists| {
            &editor_assists.assist_ids
        });

    for assist_id in assist_ids {
        if let Some(assist) = self.assists.get(assist_id) {
            let codegen = assist.codegen.read(cx);
            let buffer = codegen.buffer.read(cx).read(cx);
            foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

            let pending_range =
                codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
            if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                gutter_pending_ranges.push(pending_range);
            }

            if let Some(edit_position) = codegen.edit_position {
                let edited_range = assist.range.start..edit_position;
                if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                    gutter_transformed_ranges.push(edited_range);
                }
            }

            if assist.decorations.is_some() {
                inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
            }
        }
    }

    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    merge_ranges(&mut foreground_ranges, &snapshot);
    merge_ranges(&mut gutter_pending_ranges, &snapshot);
    merge_ranges(&mut gutter_transformed_ranges, &snapshot);
    editor.update(cx, |editor, cx| {
        // Local marker types used as keys for the two gutter highlight kinds.
        enum GutterPendingRange {}
        if gutter_pending_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterPendingRange>(cx);
        } else {
            editor.highlight_gutter::<GutterPendingRange>(
                &gutter_pending_ranges,
                |cx| cx.theme().status().info_background,
                cx,
            )
        }

        enum GutterTransformedRange {}
        if gutter_transformed_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
        } else {
            editor.highlight_gutter::<GutterTransformedRange>(
                &gutter_transformed_ranges,
                |cx| cx.theme().status().info,
                cx,
            )
        }

        if foreground_ranges.is_empty() {
            editor.clear_highlights::<InlineAssist>(cx);
        } else {
            editor.highlight_text::<InlineAssist>(
                foreground_ranges,
                HighlightStyle {
                    fade_out: Some(0.6),
                    ..Default::default()
                },
                cx,
            );
        }

        // Row highlights are always rebuilt from scratch.
        editor.clear_row_highlights::<InlineAssist>();
        for row_range in inserted_row_ranges {
            editor.highlight_rows::<InlineAssist>(
                row_range,
                Some(cx.theme().status().info_background),
                false,
                cx,
            );
        }
    });
}
1001
/// Rebuilds the blocks that show an assist's deleted lines.
///
/// The assist's previous removed-line blocks are removed. Then, for every
/// `(new_row, old_row_range)` in the codegen diff's `deleted_row_ranges`, a
/// block is inserted above `new_row` containing a read-only editor over the
/// matching text of the codegen's `old_buffer`. The new block ids are stored in
/// the assist's decorations.
///
/// Does nothing if the assist is unknown or has no decorations.
fn update_editor_blocks(
    &mut self,
    editor: &View<Editor>,
    assist_id: InlineAssistId,
    cx: &mut WindowContext,
) {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return;
    };
    let Some(decorations) = assist.decorations.as_mut() else {
        return;
    };

    // Copy what is needed out of the codegen before `cx` is borrowed mutably.
    let codegen = assist.codegen.read(cx);
    let old_snapshot = codegen.snapshot.clone();
    let old_buffer = codegen.old_buffer.clone();
    let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

    editor.update(cx, |editor, cx| {
        let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
        editor.remove_blocks(old_blocks, None, cx);

        let mut new_blocks = Vec::new();
        for (new_row, old_row_range) in deleted_row_ranges {
            // Buffer offsets spanning the deleted rows in the old snapshot, from
            // the start of the first row to the end of the last row.
            let (_, buffer_start) = old_snapshot
                .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                .unwrap();
            let (_, buffer_end) = old_snapshot
                .point_to_buffer_offset(Point::new(
                    *old_row_range.end(),
                    old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                ))
                .unwrap();

            let deleted_lines_editor = cx.new_view(|cx| {
                let multi_buffer = cx.new_model(|_| {
                    MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                });
                multi_buffer.update(cx, |multi_buffer, cx| {
                    multi_buffer.push_excerpts(
                        old_buffer.clone(),
                        Some(ExcerptRange {
                            context: buffer_start..buffer_end,
                            primary: None,
                        }),
                        cx,
                    );
                });

                enum DeletedLines {}
                let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                editor.set_show_wrap_guides(false, cx);
                editor.set_show_gutter(false, cx);
                editor.scroll_manager.set_forbid_vertical_scroll(true);
                editor.set_read_only(true);
                editor.highlight_rows::<DeletedLines>(
                    Anchor::min()..=Anchor::max(),
                    Some(cx.theme().status().deleted_background),
                    false,
                    cx,
                );
                editor
            });

            // Block height is the nested editor's last row plus one.
            let height =
                deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
            new_blocks.push(BlockProperties {
                position: new_row,
                height,
                style: BlockStyle::Flex,
                render: Box::new(move |cx| {
                    div()
                        .bg(cx.theme().status().deleted_background)
                        .size_full()
                        .h(height as f32 * cx.line_height())
                        .pl(cx.gutter_dimensions.full_width())
                        .child(deleted_lines_editor.clone())
                        .into_any_element()
                }),
                disposition: BlockDisposition::Above,
                priority: 0,
            });
        }

        decorations.removed_line_block_ids = editor
            .insert_blocks(new_blocks, None, cx)
            .into_iter()
            .collect();
    })
}
1093
1094 pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
1095 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1096 rx.clone()
1097 } else {
1098 let (tx, rx) = async_watch::channel(());
1099 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1100 rx
1101 }
1102 }
1103}
1104
/// Coarse status of an inline assist.
///
/// NOTE(review): the precise meaning of each variant is inferred from its
/// name; confirm against the code that constructs these values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` only for [`InlineAssistStatus::Pending`].
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// Returns `true` only for [`InlineAssistStatus::Confirmed`].
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// Returns `true` only for [`InlineAssistStatus::Done`].
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1125
/// Per-editor bookkeeping for inline assists, created by
/// `EditorInlineAssists::new` for a single editor view.
struct EditorInlineAssists {
    /// Ids of the assists tracked for this editor (starts out empty).
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one assist (starts out as `None`).
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending half of the watch channel whose receiver is awaited by the
    /// `_update_highlights` task; sending on it requests a highlight refresh.
    highlight_updates: async_watch::Sender<()>,
    /// Task spawned in `new` that calls
    /// `InlineAssistant::update_editor_highlights` each time the channel changes.
    _update_highlights: Task<Result<()>>,
    /// Subscriptions and action registrations created in `new` that forward
    /// editor release/change/event/newline/cancel to the global `InlineAssistant`.
    _subscriptions: Vec<gpui::Subscription>,
}
1133
/// Scroll-lock state stored in `EditorInlineAssists::scroll_lock`.
struct InlineAssistScrollLock {
    /// The assist this lock refers to.
    assist_id: InlineAssistId,
    // NOTE(review): presumably the vertical distance between the top of the
    // editor viewport and the assist, in lines — confirm against the code that
    // writes this field.
    distance_from_top: f32,
}
1138
1139impl EditorInlineAssists {
1140 #[allow(clippy::too_many_arguments)]
1141 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1142 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1143 Self {
1144 assist_ids: Vec::new(),
1145 scroll_lock: None,
1146 highlight_updates: highlight_updates_tx,
1147 _update_highlights: cx.spawn(|mut cx| {
1148 let editor = editor.downgrade();
1149 async move {
1150 while let Ok(()) = highlight_updates_rx.changed().await {
1151 let editor = editor.upgrade().context("editor was dropped")?;
1152 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1153 assistant.update_editor_highlights(&editor, cx);
1154 })?;
1155 }
1156 Ok(())
1157 }
1158 }),
1159 _subscriptions: vec![
1160 cx.observe_release(editor, {
1161 let editor = editor.downgrade();
1162 |_, cx| {
1163 InlineAssistant::update_global(cx, |this, cx| {
1164 this.handle_editor_release(editor, cx);
1165 })
1166 }
1167 }),
1168 cx.observe(editor, move |editor, cx| {
1169 InlineAssistant::update_global(cx, |this, cx| {
1170 this.handle_editor_change(editor, cx)
1171 })
1172 }),
1173 cx.subscribe(editor, move |editor, event, cx| {
1174 InlineAssistant::update_global(cx, |this, cx| {
1175 this.handle_editor_event(editor, event, cx)
1176 })
1177 }),
1178 editor.update(cx, |editor, cx| {
1179 let editor_handle = cx.view().downgrade();
1180 editor.register_action(
1181 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1182 InlineAssistant::update_global(cx, |this, cx| {
1183 if let Some(editor) = editor_handle.upgrade() {
1184 this.handle_editor_newline(editor, cx)
1185 }
1186 })
1187 },
1188 )
1189 }),
1190 editor.update(cx, |editor, cx| {
1191 let editor_handle = cx.view().downgrade();
1192 editor.register_action(
1193 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1194 InlineAssistant::update_global(cx, |this, cx| {
1195 if let Some(editor) = editor_handle.upgrade() {
1196 this.handle_editor_cancel(editor, cx)
1197 }
1198 })
1199 },
1200 )
1201 }),
1202 ],
1203 }
1204 }
1205}
1206
1207struct InlineAssistGroup {
1208 assist_ids: Vec<InlineAssistId>,
1209 linked: bool,
1210 active_assist_id: Option<InlineAssistId>,
1211}
1212
1213impl InlineAssistGroup {
1214 fn new() -> Self {
1215 Self {
1216 assist_ids: Vec::new(),
1217 linked: true,
1218 active_assist_id: None,
1219 }
1220 }
1221}
1222
1223fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1224 let editor = editor.clone();
1225 Box::new(move |cx: &mut BlockContext| {
1226 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1227 editor.clone().into_any_element()
1228 })
1229}
1230
/// Selects where an initial newline is inserted.
// NOTE(review): the exact semantics (relative to what position, and by which
// caller) are not visible in this part of the file — confirm at the use sites.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    NewlineBefore,
    NewlineAfter,
}
1236
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Advances this id by one and returns the value it held beforehand.
    fn post_inc(&mut self) -> InlineAssistId {
        let previous = Self(self.0);
        *self = Self(previous.0 + 1);
        previous
    }
}
1247
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances this id by one and returns the value it held beforehand.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let previous = Self(self.0);
        *self = Self(previous.0 + 1);
        previous
    }
}
1258
/// Events emitted by `PromptEditor` in response to button clicks and the
/// confirm/cancel actions.
enum PromptEditorEvent {
    /// A (re)start of the transformation was requested.
    StartRequested,
    /// Interrupting a pending transformation was requested.
    StopRequested,
    /// Confirming the assist was requested.
    ConfirmRequested,
    /// Cancelling the assist was requested.
    CancelRequested,
    /// Emitted when the confirm action fires while generation is pending.
    DismissRequested,
}
1266
/// The view that hosts the prompt input, status buttons and token counter for
/// one inline assist.
struct PromptEditor {
    /// Id of the assist this prompt editor belongs to.
    id: InlineAssistId,
    /// Filesystem handle passed on to the `ModelSelector` when rendering.
    fs: Arc<dyn Fs>,
    /// The editor the prompt is typed into; replaced by `unlink`.
    editor: View<Editor>,
    /// Set when the prompt is edited; cleared when codegen reaches `Done` or `Error`.
    edited_since_done: bool,
    /// Gutter dimensions shared with the block renderer, which updates them on each render.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Earlier prompts that can be browsed with move-up/move-down.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` currently shown, or `None` when showing `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// The prompt text last typed by the user outside of history browsing.
    pending_prompt: String,
    /// The code generation model whose status drives rendering and actions.
    codegen: Model<Codegen>,
    /// Observation of `codegen` that calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// The most recently spawned token-count task.
    pending_token_count: Task<Result<()>>,
    /// Last computed token count, if any.
    token_count: Option<usize>,
    /// Subscriptions (parent editor, assistant panel) that trigger a token recount.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to focus the assistant panel when the token count is clicked.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit notice popover is currently shown.
    show_rate_limit_notice: bool,
}

// `PromptEditor` reports user intent to its owner through `PromptEditorEvent`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1287
impl Render for PromptEditor {
    /// Renders the prompt row: a gutter-width column holding the model
    /// selector (and an error indicator when codegen failed), the prompt
    /// editor itself, and a trailing column with the token count and the
    /// status-dependent action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The set of action buttons depends on the codegen status.
        let buttons = match status {
            // Idle: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Pending: cancel the assist, or interrupt the running transformation.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished (successfully or not): cancel, plus either restart or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // Offer a restart when the prompt changed since completion
                    // or when generation failed; otherwise offer confirmation.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Leading column, sized from the parent editor's gutter dimensions.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // Only decorate the column when codegen ended in an error.
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            // Rate-limit errors get a toggle button that opens the notice popover.
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            // Any other error: an icon whose tooltip shows the error message.
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1468
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt `editor`.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1474
1475impl PromptEditor {
    /// Maximum number of lines the auto-height prompt editor is created with.
    const MAX_LINES: u8 = 8;
1477
1478 #[allow(clippy::too_many_arguments)]
1479 fn new(
1480 id: InlineAssistId,
1481 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1482 prompt_history: VecDeque<String>,
1483 prompt_buffer: Model<MultiBuffer>,
1484 codegen: Model<Codegen>,
1485 parent_editor: &View<Editor>,
1486 assistant_panel: Option<&View<AssistantPanel>>,
1487 workspace: Option<WeakView<Workspace>>,
1488 fs: Arc<dyn Fs>,
1489 cx: &mut ViewContext<Self>,
1490 ) -> Self {
1491 let prompt_editor = cx.new_view(|cx| {
1492 let mut editor = Editor::new(
1493 EditorMode::AutoHeight {
1494 max_lines: Self::MAX_LINES as usize,
1495 },
1496 prompt_buffer,
1497 None,
1498 false,
1499 cx,
1500 );
1501 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1502 // Since the prompt editors for all inline assistants are linked,
1503 // always show the cursor (even when it isn't focused) because
1504 // typing in one will make what you typed appear in all of them.
1505 editor.set_show_cursor_when_unfocused(true, cx);
1506 editor.set_placeholder_text("Add a prompt…", cx);
1507 editor
1508 });
1509
1510 let mut token_count_subscriptions = Vec::new();
1511 token_count_subscriptions
1512 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1513 if let Some(assistant_panel) = assistant_panel {
1514 token_count_subscriptions
1515 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1516 }
1517
1518 let mut this = Self {
1519 id,
1520 editor: prompt_editor,
1521 edited_since_done: false,
1522 gutter_dimensions,
1523 prompt_history,
1524 prompt_history_ix: None,
1525 pending_prompt: String::new(),
1526 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1527 editor_subscriptions: Vec::new(),
1528 codegen,
1529 fs,
1530 pending_token_count: Task::ready(Ok(())),
1531 token_count: None,
1532 _token_count_subscriptions: token_count_subscriptions,
1533 workspace,
1534 show_rate_limit_notice: false,
1535 };
1536 this.count_tokens(cx);
1537 this.subscribe_to_editor(cx);
1538 this
1539 }
1540
1541 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1542 self.editor_subscriptions.clear();
1543 self.editor_subscriptions
1544 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1545 }
1546
1547 fn set_show_cursor_when_unfocused(
1548 &mut self,
1549 show_cursor_when_unfocused: bool,
1550 cx: &mut ViewContext<Self>,
1551 ) {
1552 self.editor.update(cx, |editor, cx| {
1553 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1554 });
1555 }
1556
1557 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1558 let prompt = self.prompt(cx);
1559 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1560 self.editor = cx.new_view(|cx| {
1561 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1562 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1563 editor.set_placeholder_text("Add a prompt…", cx);
1564 editor.set_text(prompt, cx);
1565 if focus {
1566 editor.focus(cx);
1567 }
1568 editor
1569 });
1570 self.subscribe_to_editor(cx);
1571 }
1572
    /// Returns the current text of the prompt editor.
    fn prompt(&self, cx: &AppContext) -> String {
        self.editor.read(cx).text(cx)
    }
1576
1577 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1578 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1579 if self.show_rate_limit_notice {
1580 cx.focus_view(&self.editor);
1581 }
1582 cx.notify();
1583 }
1584
1585 fn handle_parent_editor_event(
1586 &mut self,
1587 _: View<Editor>,
1588 event: &EditorEvent,
1589 cx: &mut ViewContext<Self>,
1590 ) {
1591 if let EditorEvent::BufferEdited { .. } = event {
1592 self.count_tokens(cx);
1593 }
1594 }
1595
1596 fn handle_assistant_panel_event(
1597 &mut self,
1598 _: View<AssistantPanel>,
1599 event: &AssistantPanelEvent,
1600 cx: &mut ViewContext<Self>,
1601 ) {
1602 let AssistantPanelEvent::ContextEdited { .. } = event;
1603 self.count_tokens(cx);
1604 }
1605
1606 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1607 let assist_id = self.id;
1608 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1609 cx.background_executor().timer(Duration::from_secs(1)).await;
1610 let token_count = cx
1611 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1612 let assist = inline_assistant
1613 .assists
1614 .get(&assist_id)
1615 .context("assist not found")?;
1616 anyhow::Ok(assist.count_tokens(cx))
1617 })??
1618 .await?;
1619
1620 this.update(&mut cx, |this, cx| {
1621 this.token_count = Some(token_count);
1622 cx.notify();
1623 })
1624 })
1625 }
1626
1627 fn handle_prompt_editor_events(
1628 &mut self,
1629 _: View<Editor>,
1630 event: &EditorEvent,
1631 cx: &mut ViewContext<Self>,
1632 ) {
1633 match event {
1634 EditorEvent::Edited { .. } => {
1635 let prompt = self.editor.read(cx).text(cx);
1636 if self
1637 .prompt_history_ix
1638 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1639 {
1640 self.prompt_history_ix.take();
1641 self.pending_prompt = prompt;
1642 }
1643
1644 self.edited_since_done = true;
1645 cx.notify();
1646 }
1647 EditorEvent::BufferEdited => {
1648 self.count_tokens(cx);
1649 }
1650 EditorEvent::Blurred => {
1651 if self.show_rate_limit_notice {
1652 self.show_rate_limit_notice = false;
1653 cx.notify();
1654 }
1655 }
1656 _ => {}
1657 }
1658 }
1659
1660 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1661 match &self.codegen.read(cx).status {
1662 CodegenStatus::Idle => {
1663 self.editor
1664 .update(cx, |editor, _| editor.set_read_only(false));
1665 }
1666 CodegenStatus::Pending => {
1667 self.editor
1668 .update(cx, |editor, _| editor.set_read_only(true));
1669 }
1670 CodegenStatus::Done => {
1671 self.edited_since_done = false;
1672 self.editor
1673 .update(cx, |editor, _| editor.set_read_only(false));
1674 }
1675 CodegenStatus::Error(error) => {
1676 if cx.has_flag::<ZedPro>()
1677 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1678 && !dismissed_rate_limit_notice()
1679 {
1680 self.show_rate_limit_notice = true;
1681 cx.notify();
1682 }
1683
1684 self.edited_since_done = false;
1685 self.editor
1686 .update(cx, |editor, _| editor.set_read_only(false));
1687 }
1688 }
1689 }
1690
1691 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1692 match &self.codegen.read(cx).status {
1693 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1694 cx.emit(PromptEditorEvent::CancelRequested);
1695 }
1696 CodegenStatus::Pending => {
1697 cx.emit(PromptEditorEvent::StopRequested);
1698 }
1699 }
1700 }
1701
1702 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1703 match &self.codegen.read(cx).status {
1704 CodegenStatus::Idle => {
1705 cx.emit(PromptEditorEvent::StartRequested);
1706 }
1707 CodegenStatus::Pending => {
1708 cx.emit(PromptEditorEvent::DismissRequested);
1709 }
1710 CodegenStatus::Done | CodegenStatus::Error(_) => {
1711 if self.edited_since_done {
1712 cx.emit(PromptEditorEvent::StartRequested);
1713 } else {
1714 cx.emit(PromptEditorEvent::ConfirmRequested);
1715 }
1716 }
1717 }
1718 }
1719
1720 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1721 if let Some(ix) = self.prompt_history_ix {
1722 if ix > 0 {
1723 self.prompt_history_ix = Some(ix - 1);
1724 let prompt = self.prompt_history[ix - 1].as_str();
1725 self.editor.update(cx, |editor, cx| {
1726 editor.set_text(prompt, cx);
1727 editor.move_to_beginning(&Default::default(), cx);
1728 });
1729 }
1730 } else if !self.prompt_history.is_empty() {
1731 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1732 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1733 self.editor.update(cx, |editor, cx| {
1734 editor.set_text(prompt, cx);
1735 editor.move_to_beginning(&Default::default(), cx);
1736 });
1737 }
1738 }
1739
1740 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1741 if let Some(ix) = self.prompt_history_ix {
1742 if ix < self.prompt_history.len() - 1 {
1743 self.prompt_history_ix = Some(ix + 1);
1744 let prompt = self.prompt_history[ix + 1].as_str();
1745 self.editor.update(cx, |editor, cx| {
1746 editor.set_text(prompt, cx);
1747 editor.move_to_end(&Default::default(), cx)
1748 });
1749 } else {
1750 self.prompt_history_ix = None;
1751 let prompt = self.pending_prompt.as_str();
1752 self.editor.update(cx, |editor, cx| {
1753 editor.set_text(prompt, cx);
1754 editor.move_to_end(&Default::default(), cx)
1755 });
1756 }
1757 }
1758 }
1759
1760 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1761 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1762 let token_count = self.token_count?;
1763 let max_token_count = model.max_token_count();
1764
1765 let remaining_tokens = max_token_count as isize - token_count as isize;
1766 let token_count_color = if remaining_tokens <= 0 {
1767 Color::Error
1768 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1769 Color::Warning
1770 } else {
1771 Color::Muted
1772 };
1773
1774 let mut token_count = h_flex()
1775 .id("token_count")
1776 .gap_0p5()
1777 .child(
1778 Label::new(humanize_token_count(token_count))
1779 .size(LabelSize::Small)
1780 .color(token_count_color),
1781 )
1782 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1783 .child(
1784 Label::new(humanize_token_count(max_token_count))
1785 .size(LabelSize::Small)
1786 .color(Color::Muted),
1787 );
1788 if let Some(workspace) = self.workspace.clone() {
1789 token_count = token_count
1790 .tooltip(|cx| {
1791 Tooltip::with_meta(
1792 "Tokens Used by Inline Assistant",
1793 None,
1794 "Click to Open Assistant Panel",
1795 cx,
1796 )
1797 })
1798 .cursor_pointer()
1799 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1800 .on_click(move |_, cx| {
1801 cx.stop_propagation();
1802 workspace
1803 .update(cx, |workspace, cx| {
1804 workspace.focus_panel::<AssistantPanel>(cx)
1805 })
1806 .ok();
1807 });
1808 } else {
1809 token_count = token_count
1810 .cursor_default()
1811 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1812 }
1813
1814 Some(token_count)
1815 }
1816
1817 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1818 let settings = ThemeSettings::get_global(cx);
1819 let text_style = TextStyle {
1820 color: if self.editor.read(cx).read_only(cx) {
1821 cx.theme().colors().text_disabled
1822 } else {
1823 cx.theme().colors().text
1824 },
1825 font_family: settings.ui_font.family.clone(),
1826 font_features: settings.ui_font.features.clone(),
1827 font_fallbacks: settings.ui_font.fallbacks.clone(),
1828 font_size: rems(0.875).into(),
1829 font_weight: settings.ui_font.weight,
1830 line_height: relative(1.3),
1831 ..Default::default()
1832 };
1833 EditorElement::new(
1834 &self.editor,
1835 EditorStyle {
1836 background: cx.theme().colors().editor_background,
1837 local_player: cx.theme().players().local(),
1838 text: text_style,
1839 ..Default::default()
1840 },
1841 )
1842 }
1843
    /// Builds the "Out of Tokens" popover shown for rate-limit errors.
    ///
    /// It contains a "Don't show again" checkbox backed by the persisted
    /// dismissal flag, a "Dismiss" button that toggles the notice, and a
    /// "More Info" button that dispatches `zed_actions::OpenAccountSettings`.
    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        Popover::new().child(
            v_flex()
                .occlude()
                .p_2()
                .child(
                    Label::new("Out of Tokens")
                        .size(LabelSize::Small)
                        .weight(FontWeight::BOLD),
                )
                .child(Label::new(
                    "Try Zed Pro for higher limits, a wider range of models, and more.",
                ))
                .child(
                    h_flex()
                        .justify_between()
                        .child(CheckboxWithLabel::new(
                            "dont-show-again",
                            Label::new("Don't show again"),
                            // Reflect the persisted dismissal state.
                            if dismissed_rate_limit_notice() {
                                ui::Selection::Selected
                            } else {
                                ui::Selection::Unselected
                            },
                            |selection, cx| {
                                // An indeterminate selection leaves the stored flag untouched.
                                let is_dismissed = match selection {
                                    ui::Selection::Unselected => false,
                                    ui::Selection::Indeterminate => return,
                                    ui::Selection::Selected => true,
                                };

                                set_rate_limit_notice_dismissed(is_dismissed, cx)
                            },
                        ))
                        .child(
                            h_flex()
                                .gap_2()
                                .child(
                                    Button::new("dismiss", "Dismiss")
                                        .style(ButtonStyle::Transparent)
                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                )
                                .child(Button::new("more-info", "More Info").on_click(
                                    |_event, cx| {
                                        cx.dispatch_action(Box::new(
                                            zed_actions::OpenAccountSettings,
                                        ))
                                    },
                                )),
                        ),
                ),
        )
    }
1897}
1898
/// Key under which the "rate-limit notice dismissed" flag is stored in the
/// key-value store.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1900
1901fn dismissed_rate_limit_notice() -> bool {
1902 db::kvp::KEY_VALUE_STORE
1903 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1904 .log_err()
1905 .map_or(false, |s| s.is_some())
1906}
1907
1908fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1909 db::write_and_log(cx, move || async move {
1910 if is_dismissed {
1911 db::kvp::KEY_VALUE_STORE
1912 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1913 .await
1914 } else {
1915 db::kvp::KEY_VALUE_STORE
1916 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1917 .await
1918 }
1919 })
1920}
1921
/// State of a single inline assist attached to an editor.
struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range passed to `Codegen::count_tokens` as the edit range.
    range: Range<Anchor>,
    /// Weak handle to the editor the assist lives in.
    editor: WeakView<Editor>,
    /// Prompt editor and block ids for this assist. When this is `None`, a
    /// finished codegen ends the assist and errors are reported as a toast.
    decorations: Option<InlineAssistDecorations>,
    /// The code generation model driving this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to show error toasts and to find the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
1932
1933impl InlineAssist {
    /// Creates the assist state and wires up its subscriptions.
    ///
    /// The subscriptions forward prompt-editor focus changes and events to the
    /// global `InlineAssistant`, refresh highlights and blocks whenever the
    /// codegen model changes, and react to `CodegenEvent`s (see the comments
    /// on each subscription below).
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: CustomBlockId,
        end_block_id: CustomBlockId,
        range: Range<Anchor>,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            range,
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                // Focus entering the prompt editor.
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                // Focus leaving the prompt editor.
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                // Events emitted by the prompt editor.
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                // Any codegen change: request a highlight refresh and rebuild
                // the editor blocks for this assist, while the editor is alive.
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        // The transformation was undone: finish the assist.
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // Without decorations there is no prompt editor to
                            // display the error, so surface it as a toast.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::identified::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // Without decorations the assist is finished right
                            // away; otherwise any observer registered for this
                            // assist is notified.
                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            } else if let Some(tx) = this.assist_observations.get(&assist_id) {
                                tx.0.send(()).ok();
                            }
                        }
                    })
                }),
            ],
        }
    }
2037
2038 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2039 let decorations = self.decorations.as_ref()?;
2040 Some(decorations.prompt_editor.read(cx).prompt(cx))
2041 }
2042
2043 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2044 if self.include_context {
2045 let workspace = self.workspace.as_ref()?;
2046 let workspace = workspace.upgrade()?.read(cx);
2047 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2048 Some(
2049 assistant_panel
2050 .read(cx)
2051 .active_context(cx)?
2052 .read(cx)
2053 .to_completion_request(cx),
2054 )
2055 } else {
2056 None
2057 }
2058 }
2059
2060 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
2061 let Some(user_prompt) = self.user_prompt(cx) else {
2062 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2063 };
2064 let assistant_panel_context = self.assistant_panel_context(cx);
2065 self.codegen.read(cx).count_tokens(
2066 self.range.clone(),
2067 user_prompt,
2068 assistant_panel_context,
2069 cx,
2070 )
2071 }
2072}
2073
/// Editor decorations owned by an inline assist.
struct InlineAssistDecorations {
    /// Id of the editor block associated with the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view for this assist.
    prompt_editor: View<PromptEditor>,
    /// Ids of the blocks currently showing deleted lines; replaced each time
    /// the editor blocks are rebuilt.
    removed_line_block_ids: HashSet<CustomBlockId>,
    // NOTE(review): presumably the block marking the end of the assist's
    // range — confirm where this block is inserted.
    end_block_id: CustomBlockId,
}
2080
/// Events emitted by `Codegen`.
#[derive(Debug)]
pub enum CodegenEvent {
    // NOTE(review): presumably emitted when a generation run ends; the emit
    // site is outside this part of the file — confirm.
    Finished,
    /// Emitted when the transformation's transaction is undone in the buffer.
    Undone,
}
2086
/// Model that drives a code transformation over a multi-buffer.
pub struct Codegen {
    /// The multi-buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// A separate buffer created in `new` from the text of the buffer
    /// containing the range, with the same line ending and language.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `new`.
    snapshot: MultiBufferSnapshot,
    /// Set in `start` to the start of the edit range, biased right.
    edit_position: Option<Anchor>,
    /// Ranges exposed through `last_equal_ranges()`.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction id supplied by the creator of this `Codegen`.
    initial_transaction_id: Option<TransactionId>,
    /// Transaction holding the transformation's edits; it is undone before a
    /// new run starts, and undoing it in the buffer emits `CodegenEvent::Undone`.
    transformation_transaction_id: Option<TransactionId>,
    /// Current status of the generation.
    status: CodegenStatus,
    /// The generation task; replaced with a ready task when the
    /// transformation is undone.
    generation: Task<()>,
    /// Deleted and inserted row ranges of the current transformation.
    diff: Diff,
    /// Optional telemetry handle.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, handled by `handle_buffer_event`.
    _subscription: gpui::Subscription,
    /// Prompt builder supplied by the creator of this `Codegen`.
    builder: Arc<PromptBuilder>,
}
2102
/// Status of a `Codegen` run.
enum CodegenStatus {
    /// Initial status set by `Codegen::new`.
    Idle,
    /// Generation in progress; the prompt editor is read-only in this state.
    Pending,
    /// Generation completed.
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}
2109
2110#[derive(Default)]
2111struct Diff {
2112 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2113 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2114}
2115
2116impl Diff {
2117 fn is_empty(&self) -> bool {
2118 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2119 }
2120}
2121
// `Codegen` notifies subscribers through `CodegenEvent`.
impl EventEmitter<CodegenEvent> for Codegen {}
2123
2124impl Codegen {
    /// Creates an idle `Codegen` for `buffer`.
    ///
    /// A snapshot of the multi-buffer is taken, and the text of the last
    /// underlying buffer that `range` maps to is copied into a separate "old"
    /// buffer with the same line ending, language and language registry.
    ///
    /// # Panics
    ///
    /// Panics if `range` does not map to any underlying buffer.
    pub fn new(
        buffer: Model<MultiBuffer>,
        range: Range<Anchor>,
        initial_transaction_id: Option<TransactionId>,
        telemetry: Option<Arc<Telemetry>>,
        builder: Arc<PromptBuilder>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let snapshot = buffer.read(cx).snapshot(cx);

        // Take the last buffer range that `range` resolves to.
        let (old_buffer, _, _) = buffer
            .read(cx)
            .range_to_buffer_ranges(range.clone(), cx)
            .pop()
            .unwrap();
        // Copy that buffer's contents and language settings into a new local buffer.
        let old_buffer = cx.new_model(|cx| {
            let old_buffer = old_buffer.read(cx);
            let text = old_buffer.as_rope().clone();
            let line_ending = old_buffer.line_ending();
            let language = old_buffer.language().cloned();
            let language_registry = old_buffer.language_registry();

            let mut buffer = Buffer::local_normalized(text, line_ending, cx);
            buffer.set_language(language, cx);
            if let Some(language_registry) = language_registry {
                buffer.set_language_registry(language_registry)
            }
            buffer
        });

        Self {
            buffer: buffer.clone(),
            old_buffer,
            edit_position: None,
            snapshot,
            last_equal_ranges: Default::default(),
            transformation_transaction_id: None,
            status: CodegenStatus::Idle,
            generation: Task::ready(()),
            diff: Diff::default(),
            telemetry,
            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
            initial_transaction_id,
            builder,
        }
    }
2171
2172 fn handle_buffer_event(
2173 &mut self,
2174 _buffer: Model<MultiBuffer>,
2175 event: &multi_buffer::Event,
2176 cx: &mut ModelContext<Self>,
2177 ) {
2178 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2179 if self.transformation_transaction_id == Some(*transaction_id) {
2180 self.transformation_transaction_id = None;
2181 self.generation = Task::ready(());
2182 cx.emit(CodegenEvent::Undone);
2183 }
2184 }
2185 }
2186
2187 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2188 &self.last_equal_ranges
2189 }
2190
2191 pub fn count_tokens(
2192 &self,
2193 edit_range: Range<Anchor>,
2194 user_prompt: String,
2195 assistant_panel_context: Option<LanguageModelRequest>,
2196 cx: &AppContext,
2197 ) -> BoxFuture<'static, Result<usize>> {
2198 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2199 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2200 match request {
2201 Ok(request) => model.count_tokens(request, cx),
2202 Err(error) => futures::future::ready(Err(error)).boxed(),
2203 }
2204 } else {
2205 future::ready(Err(anyhow!("no active model"))).boxed()
2206 }
2207 }
2208
2209 pub fn start(
2210 &mut self,
2211 edit_range: Range<Anchor>,
2212 user_prompt: String,
2213 assistant_panel_context: Option<LanguageModelRequest>,
2214 cx: &mut ModelContext<Self>,
2215 ) -> Result<()> {
2216 let model = LanguageModelRegistry::read_global(cx)
2217 .active_model()
2218 .context("no active model")?;
2219
2220 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2221 self.buffer.update(cx, |buffer, cx| {
2222 buffer.undo_transaction(transformation_transaction_id, cx);
2223 });
2224 }
2225
2226 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2227
2228 let telemetry_id = model.telemetry_id();
2229 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2230 .trim()
2231 .to_lowercase()
2232 == "delete"
2233 {
2234 async { Ok(stream::empty().boxed()) }.boxed_local()
2235 } else {
2236 let request =
2237 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
2238
2239 let chunks =
2240 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2241 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2242 };
2243 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2244 Ok(())
2245 }
2246
2247 fn build_request(
2248 &self,
2249 user_prompt: String,
2250 assistant_panel_context: Option<LanguageModelRequest>,
2251 edit_range: Range<Anchor>,
2252 cx: &AppContext,
2253 ) -> Result<LanguageModelRequest> {
2254 let buffer = self.buffer.read(cx).snapshot(cx);
2255 let language = buffer.language_at(edit_range.start);
2256 let language_name = if let Some(language) = language.as_ref() {
2257 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2258 None
2259 } else {
2260 Some(language.name())
2261 }
2262 } else {
2263 None
2264 };
2265
2266 // Higher Temperature increases the randomness of model outputs.
2267 // If Markdown or No Language is Known, increase the randomness for more creative output
2268 // If Code, decrease temperature to get more deterministic outputs
2269 let temperature = if let Some(language) = language_name.clone() {
2270 if language.as_ref() == "Markdown" {
2271 1.0
2272 } else {
2273 0.5
2274 }
2275 } else {
2276 1.0
2277 };
2278
2279 let language_name = language_name.as_deref();
2280 let start = buffer.point_to_buffer_offset(edit_range.start);
2281 let end = buffer.point_to_buffer_offset(edit_range.end);
2282 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2283 let (start_buffer, start_buffer_offset) = start;
2284 let (end_buffer, end_buffer_offset) = end;
2285 if start_buffer.remote_id() == end_buffer.remote_id() {
2286 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2287 } else {
2288 return Err(anyhow::anyhow!("invalid transformation range"));
2289 }
2290 } else {
2291 return Err(anyhow::anyhow!("invalid transformation range"));
2292 };
2293 let prompt = self
2294 .builder
2295 .generate_content_prompt(user_prompt, language_name, buffer, range)
2296 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2297
2298 let mut messages = Vec::new();
2299 if let Some(context_request) = assistant_panel_context {
2300 messages = context_request.messages;
2301 }
2302
2303 messages.push(LanguageModelRequestMessage {
2304 role: Role::User,
2305 content: prompt,
2306 });
2307
2308 Ok(LanguageModelRequest {
2309 messages,
2310 stop: vec!["|END|>".to_string()],
2311 temperature,
2312 })
2313 }
2314
    /// Consumes a stream of generated text and applies it to the buffer as a live
    /// diff against the text currently inside `edit_range`.
    ///
    /// The work is split in two. A background task strips invalid spans from the
    /// incoming chunks, re-indents each generated line and diffs the result against
    /// the selected text. The foreground loop receives the resulting operations and
    /// applies them to the buffer, merging every edit into one transaction.
    ///
    /// The status is set to `Pending` immediately; once the stream ends or fails it
    /// becomes `Done` or `Error` and `CodegenEvent::Finished` is emitted.
    ///
    /// `model_telemetry_id` is only used for the telemetry event that is reported
    /// when a telemetry handle is present.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        // All offsets and anchors below are resolved against the snapshot captured
        // when this `Codegen` was created, not against the live buffer.
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset in `snapshot` up to which the old text has been consumed by the
        // `Keep`/`Delete` operations applied so far.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // Capacity-1 channel carrying diff results from the background
                    // task to the foreground loop below.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            // Time until the first chunk arrives, measured from the
                            // start of this task (i.e. after `stream` has resolved).
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current generated line that has not yet
                                // been pushed into the diff.
                                let mut new_text = String::new();
                                // Indentation of the first generated line that has
                                // non-whitespace content; other lines are re-indented
                                // relative to it.
                                let mut base_indent = None;
                                // Indentation of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Shift this line by the same amount
                                                // relative to the suggested indent as it
                                                // was shifted relative to the base line,
                                                // never going below zero.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line starts at the selection's
                                                // column, so the columns before it are
                                                // subtracted from the indent to insert.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                // Replace the generated leading whitespace
                                                // with the corrected indentation.
                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only push text once the line's indentation is
                                        // known; until then it stays buffered in `new_text`.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so this one ended with a
                                        // newline in the chunk.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush whatever is still buffered and finish both diffs.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the outcome, including the error message if any,
                            // before propagating the result.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of operations to the buffer as it arrives.
                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    char_ops
                                        .into_iter()
                                        .filter_map(|operation| match operation {
                                            // Inserts happen at the current position and
                                            // do not consume any old text.
                                            CharOperation::Insert { text } => {
                                                let edit_start = snapshot.anchor_after(edit_start);
                                                Some((edit_start..edit_start, text))
                                            }
                                            // Deletes remove `bytes` of old text and
                                            // advance past it.
                                            CharOperation::Delete { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                Some((edit_range, String::new()))
                                            }
                                            // Kept text produces no edit; its range is
                                            // recorded in `last_equal_ranges` instead.
                                            CharOperation::Keep { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                this.last_equal_ranges.push(edit_range);
                                                None
                                            }
                                        }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First edit of this generation: remember its
                                    // transaction and finalize it.
                                    this.transformation_transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), line_diff, cx);

                            cx.notify();
                        })?;
                    }

                    // Surface any error produced by the background task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2559
2560 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2561 self.last_equal_ranges.clear();
2562 if self.diff.is_empty() {
2563 self.status = CodegenStatus::Idle;
2564 } else {
2565 self.status = CodegenStatus::Done;
2566 }
2567 self.generation = Task::ready(());
2568 cx.emit(CodegenEvent::Finished);
2569 cx.notify();
2570 }
2571
2572 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2573 self.buffer.update(cx, |buffer, cx| {
2574 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2575 buffer.undo_transaction(transaction_id, cx);
2576 buffer.refresh_preview(cx);
2577 }
2578
2579 if let Some(transaction_id) = self.initial_transaction_id.take() {
2580 buffer.undo_transaction(transaction_id, cx);
2581 buffer.refresh_preview(cx);
2582 }
2583 });
2584 }
2585
2586 fn update_diff(
2587 &mut self,
2588 edit_range: Range<Anchor>,
2589 line_operations: Vec<LineOperation>,
2590 cx: &mut ModelContext<Self>,
2591 ) {
2592 let old_snapshot = self.snapshot.clone();
2593 let old_range = edit_range.to_point(&old_snapshot);
2594 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2595 let new_range = edit_range.to_point(&new_snapshot);
2596
2597 let mut old_row = old_range.start.row;
2598 let mut new_row = new_range.start.row;
2599
2600 self.diff.deleted_row_ranges.clear();
2601 self.diff.inserted_row_ranges.clear();
2602 for operation in line_operations {
2603 match operation {
2604 LineOperation::Keep { lines } => {
2605 old_row += lines;
2606 new_row += lines;
2607 }
2608 LineOperation::Delete { lines } => {
2609 let old_end_row = old_row + lines - 1;
2610 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2611
2612 if let Some((_, last_deleted_row_range)) =
2613 self.diff.deleted_row_ranges.last_mut()
2614 {
2615 if *last_deleted_row_range.end() + 1 == old_row {
2616 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2617 } else {
2618 self.diff
2619 .deleted_row_ranges
2620 .push((new_row, old_row..=old_end_row));
2621 }
2622 } else {
2623 self.diff
2624 .deleted_row_ranges
2625 .push((new_row, old_row..=old_end_row));
2626 }
2627
2628 old_row += lines;
2629 }
2630 LineOperation::Insert { lines } => {
2631 let new_end_row = new_row + lines - 1;
2632 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2633 let end = new_snapshot.anchor_before(Point::new(
2634 new_end_row,
2635 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2636 ));
2637 self.diff.inserted_row_ranges.push(start..=end);
2638 new_row += lines;
2639 }
2640 }
2641
2642 cx.notify();
2643 }
2644 }
2645}
2646
/// Stream adapter that removes a code fence wrapping the whole response and any
/// `<|CURSOR|>` markers from the text yielded by the inner stream.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been emitted yet.
    buffer: String,
    /// `true` until a line after the first one has been seen.
    first_line: bool,
    /// Whether a newline must be emitted before the next piece of text.
    line_end: bool,
    /// Whether the first line was a code fence that has been dropped.
    starts_with_code_block: bool,
}
2655
2656impl<T> StripInvalidSpans<T>
2657where
2658 T: Stream<Item = Result<String>>,
2659{
2660 fn new(stream: T) -> Self {
2661 Self {
2662 stream,
2663 stream_done: false,
2664 buffer: String::new(),
2665 first_line: true,
2666 line_end: false,
2667 starts_with_code_block: false,
2668 }
2669 }
2670}
2671
2672impl<T> Stream for StripInvalidSpans<T>
2673where
2674 T: Stream<Item = Result<String>>,
2675{
2676 type Item = Result<String>;
2677
2678 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2679 const CODE_BLOCK_DELIMITER: &str = "```";
2680 const CURSOR_SPAN: &str = "<|CURSOR|>";
2681
2682 let this = unsafe { self.get_unchecked_mut() };
2683 loop {
2684 if !this.stream_done {
2685 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2686 match stream.as_mut().poll_next(cx) {
2687 Poll::Ready(Some(Ok(chunk))) => {
2688 this.buffer.push_str(&chunk);
2689 }
2690 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2691 Poll::Ready(None) => {
2692 this.stream_done = true;
2693 }
2694 Poll::Pending => return Poll::Pending,
2695 }
2696 }
2697
2698 let mut chunk = String::new();
2699 let mut consumed = 0;
2700 if !this.buffer.is_empty() {
2701 let mut lines = this.buffer.split('\n').enumerate().peekable();
2702 while let Some((line_ix, line)) = lines.next() {
2703 if line_ix > 0 {
2704 this.first_line = false;
2705 }
2706
2707 if this.first_line {
2708 let trimmed_line = line.trim();
2709 if lines.peek().is_some() {
2710 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2711 consumed += line.len() + 1;
2712 this.starts_with_code_block = true;
2713 continue;
2714 }
2715 } else if trimmed_line.is_empty()
2716 || prefixes(CODE_BLOCK_DELIMITER)
2717 .any(|prefix| trimmed_line.starts_with(prefix))
2718 {
2719 break;
2720 }
2721 }
2722
2723 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2724 if lines.peek().is_some() {
2725 if this.line_end {
2726 chunk.push('\n');
2727 }
2728
2729 chunk.push_str(&line_without_cursor);
2730 this.line_end = true;
2731 consumed += line.len() + 1;
2732 } else if this.stream_done {
2733 if !this.starts_with_code_block
2734 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2735 {
2736 if this.line_end {
2737 chunk.push('\n');
2738 }
2739
2740 chunk.push_str(&line);
2741 }
2742
2743 consumed += line.len();
2744 } else {
2745 let trimmed_line = line.trim();
2746 if trimmed_line.is_empty()
2747 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2748 || prefixes(CODE_BLOCK_DELIMITER)
2749 .any(|prefix| trimmed_line.ends_with(prefix))
2750 {
2751 break;
2752 } else {
2753 if this.line_end {
2754 chunk.push('\n');
2755 this.line_end = false;
2756 }
2757
2758 chunk.push_str(&line_without_cursor);
2759 consumed += line.len();
2760 }
2761 }
2762 }
2763 }
2764
2765 this.buffer = this.buffer.split_off(consumed);
2766 if !chunk.is_empty() {
2767 return Poll::Ready(Some(Ok(chunk)));
2768 } else if this.stream_done {
2769 return Poll::Ready(None);
2770 }
2771 }
2772 }
2773}
2774
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is not included, so empty and single-character inputs
/// yield nothing. Prefixes always end on a character boundary.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each character after the first starts at the byte offset where a proper
    // prefix ends.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
2778
2779fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2780 ranges.sort_unstable_by(|a, b| {
2781 a.start
2782 .cmp(&b.start, buffer)
2783 .then_with(|| b.end.cmp(&a.end, buffer))
2784 });
2785
2786 let mut ix = 0;
2787 while ix + 1 < ranges.len() {
2788 let b = ranges[ix + 1].clone();
2789 let a = &mut ranges[ix];
2790 if a.end.cmp(&b.start, buffer).is_gt() {
2791 if a.end.cmp(&b.end, buffer).is_lt() {
2792 a.end = b.end;
2793 }
2794 ranges.remove(ix + 1);
2795 } else {
2796 ix += 1;
2797 }
2798 }
2799}
2800
2801#[cfg(test)]
2802mod tests {
2803 use super::*;
2804 use futures::stream::{self};
2805 use gpui::{Context, TestAppContext};
2806 use indoc::indoc;
2807 use language::{
2808 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2809 Point,
2810 };
2811 use language_model::LanguageModelRegistry;
2812 use rand::prelude::*;
2813 use serde::Serialize;
2814 use settings::SettingsStore;
2815 use std::{future, sync::Arc};
2816
    /// Minimal serializable request type carrying only a name.
    // NOTE(review): not referenced by any test in this module — confirm it is still needed.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
2821
2822 #[gpui::test(iterations = 10)]
2823 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
2824 cx.set_global(cx.update(SettingsStore::test));
2825 cx.update(language_model::LanguageModelRegistry::test);
2826 cx.update(language_settings::init);
2827
2828 let text = indoc! {"
2829 fn main() {
2830 let x = 0;
2831 for _ in 0..10 {
2832 x += 1;
2833 }
2834 }
2835 "};
2836 let buffer =
2837 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2838 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2839 let range = buffer.read_with(cx, |buffer, cx| {
2840 let snapshot = buffer.snapshot(cx);
2841 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
2842 });
2843 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2844 let codegen = cx.new_model(|cx| {
2845 Codegen::new(
2846 buffer.clone(),
2847 range.clone(),
2848 None,
2849 None,
2850 prompt_builder,
2851 cx,
2852 )
2853 });
2854
2855 let (chunks_tx, chunks_rx) = mpsc::unbounded();
2856 codegen.update(cx, |codegen, cx| {
2857 codegen.handle_stream(
2858 String::new(),
2859 range,
2860 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2861 cx,
2862 )
2863 });
2864
2865 let mut new_text = concat!(
2866 " let mut x = 0;\n",
2867 " while x < 10 {\n",
2868 " x += 1;\n",
2869 " }",
2870 );
2871 while !new_text.is_empty() {
2872 let max_len = cmp::min(new_text.len(), 10);
2873 let len = rng.gen_range(1..=max_len);
2874 let (chunk, suffix) = new_text.split_at(len);
2875 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2876 new_text = suffix;
2877 cx.background_executor.run_until_parked();
2878 }
2879 drop(chunks_tx);
2880 cx.background_executor.run_until_parked();
2881
2882 assert_eq!(
2883 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2884 indoc! {"
2885 fn main() {
2886 let mut x = 0;
2887 while x < 10 {
2888 x += 1;
2889 }
2890 }
2891 "}
2892 );
2893 }
2894
2895 #[gpui::test(iterations = 10)]
2896 async fn test_autoindent_when_generating_past_indentation(
2897 cx: &mut TestAppContext,
2898 mut rng: StdRng,
2899 ) {
2900 cx.set_global(cx.update(SettingsStore::test));
2901 cx.update(language_settings::init);
2902
2903 let text = indoc! {"
2904 fn main() {
2905 le
2906 }
2907 "};
2908 let buffer =
2909 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2910 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2911 let range = buffer.read_with(cx, |buffer, cx| {
2912 let snapshot = buffer.snapshot(cx);
2913 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
2914 });
2915 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2916 let codegen = cx.new_model(|cx| {
2917 Codegen::new(
2918 buffer.clone(),
2919 range.clone(),
2920 None,
2921 None,
2922 prompt_builder,
2923 cx,
2924 )
2925 });
2926
2927 let (chunks_tx, chunks_rx) = mpsc::unbounded();
2928 codegen.update(cx, |codegen, cx| {
2929 codegen.handle_stream(
2930 String::new(),
2931 range.clone(),
2932 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2933 cx,
2934 )
2935 });
2936
2937 cx.background_executor.run_until_parked();
2938
2939 let mut new_text = concat!(
2940 "t mut x = 0;\n",
2941 "while x < 10 {\n",
2942 " x += 1;\n",
2943 "}", //
2944 );
2945 while !new_text.is_empty() {
2946 let max_len = cmp::min(new_text.len(), 10);
2947 let len = rng.gen_range(1..=max_len);
2948 let (chunk, suffix) = new_text.split_at(len);
2949 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2950 new_text = suffix;
2951 cx.background_executor.run_until_parked();
2952 }
2953 drop(chunks_tx);
2954 cx.background_executor.run_until_parked();
2955
2956 assert_eq!(
2957 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2958 indoc! {"
2959 fn main() {
2960 let mut x = 0;
2961 while x < 10 {
2962 x += 1;
2963 }
2964 }
2965 "}
2966 );
2967 }
2968
2969 #[gpui::test(iterations = 10)]
2970 async fn test_autoindent_when_generating_before_indentation(
2971 cx: &mut TestAppContext,
2972 mut rng: StdRng,
2973 ) {
2974 cx.update(LanguageModelRegistry::test);
2975 cx.set_global(cx.update(SettingsStore::test));
2976 cx.update(language_settings::init);
2977
2978 let text = concat!(
2979 "fn main() {\n",
2980 " \n",
2981 "}\n" //
2982 );
2983 let buffer =
2984 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2985 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2986 let range = buffer.read_with(cx, |buffer, cx| {
2987 let snapshot = buffer.snapshot(cx);
2988 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
2989 });
2990 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2991 let codegen = cx.new_model(|cx| {
2992 Codegen::new(
2993 buffer.clone(),
2994 range.clone(),
2995 None,
2996 None,
2997 prompt_builder,
2998 cx,
2999 )
3000 });
3001
3002 let (chunks_tx, chunks_rx) = mpsc::unbounded();
3003 codegen.update(cx, |codegen, cx| {
3004 codegen.handle_stream(
3005 String::new(),
3006 range.clone(),
3007 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
3008 cx,
3009 )
3010 });
3011
3012 cx.background_executor.run_until_parked();
3013
3014 let mut new_text = concat!(
3015 "let mut x = 0;\n",
3016 "while x < 10 {\n",
3017 " x += 1;\n",
3018 "}", //
3019 );
3020 while !new_text.is_empty() {
3021 let max_len = cmp::min(new_text.len(), 10);
3022 let len = rng.gen_range(1..=max_len);
3023 let (chunk, suffix) = new_text.split_at(len);
3024 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
3025 new_text = suffix;
3026 cx.background_executor.run_until_parked();
3027 }
3028 drop(chunks_tx);
3029 cx.background_executor.run_until_parked();
3030
3031 assert_eq!(
3032 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
3033 indoc! {"
3034 fn main() {
3035 let mut x = 0;
3036 while x < 10 {
3037 x += 1;
3038 }
3039 }
3040 "}
3041 );
3042 }
3043
    /// A selection whose first line is unindented but whose later lines use tabs
    /// keeps tab indentation in the generated replacement.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        // No language is assigned to this buffer.
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of the file through the loop's closing brace.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                range.clone(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        // The whole response is delivered as a single chunk.
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }
3109
    /// `StripInvalidSpans` removes a wrapping code fence and cursor markers, and
    /// produces the same text however the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        // Only the outermost fence is removed.
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        // Two backticks do not open a fence, so nothing is stripped.
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Runs `text` through the adapter once for every chunk size from 1 up to
        // the byte length of `text` and checks the concatenated output each time.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into chunks of `size` characters, each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }
3149
    /// Builds a Rust `Language` for tests, using the tree-sitter Rust grammar and an
    /// indents query covering calls, field expressions and bracketed blocks.
    ///
    /// Panics if the indents query is rejected.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
3172}