1use crate::{
2 humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
3 Hunk, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use similar::TextDiff;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
56 cx.set_global(InlineAssistant::new(fs, telemetry));
57}
58
/// Maximum number of submitted prompts kept in `InlineAssistant::prompt_history`.
/// When the limit is exceeded, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
60
/// Global coordinator for inline assists across all editors.
pub struct InlineAssistant {
    /// Id handed to the next assist created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every live assist, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, subscriptions), keyed by a weak handle.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists that were created together from one invocation.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Previously submitted prompts, oldest first, without duplicates.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen`; always `Some` when built through `new`.
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle passed to each prompt editor.
    fs: Arc<dyn Fs>,
}

// Stored as a GPUI global (see `init`).
impl Global for InlineAssistant {}
73
74impl InlineAssistant {
75 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
76 Self {
77 next_assist_id: InlineAssistId::default(),
78 next_assist_group_id: InlineAssistGroupId::default(),
79 assists: HashMap::default(),
80 assists_by_editor: HashMap::default(),
81 assist_groups: HashMap::default(),
82 prompt_history: VecDeque::default(),
83 telemetry: Some(telemetry),
84 fs,
85 }
86 }
87
88 pub fn assist(
89 &mut self,
90 editor: &View<Editor>,
91 workspace: Option<WeakView<Workspace>>,
92 assistant_panel: Option<&View<AssistantPanel>>,
93 initial_prompt: Option<String>,
94 cx: &mut WindowContext,
95 ) {
96 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
97
98 let mut selections = Vec::<Selection<Point>>::new();
99 let mut newest_selection = None;
100 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
101 if selection.end > selection.start {
102 selection.start.column = 0;
103 // If the selection ends at the start of the line, we don't want to include it.
104 if selection.end.column == 0 {
105 selection.end.row -= 1;
106 }
107 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
108 }
109
110 if let Some(prev_selection) = selections.last_mut() {
111 if selection.start <= prev_selection.end {
112 prev_selection.end = selection.end;
113 continue;
114 }
115 }
116
117 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
118 if selection.id > latest_selection.id {
119 *latest_selection = selection.clone();
120 }
121 selections.push(selection);
122 }
123 let newest_selection = newest_selection.unwrap();
124
125 let mut codegen_ranges = Vec::new();
126 for (excerpt_id, buffer, buffer_range) in
127 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
128 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
129 }))
130 {
131 let start = Anchor {
132 buffer_id: Some(buffer.remote_id()),
133 excerpt_id,
134 text_anchor: buffer.anchor_before(buffer_range.start),
135 };
136 let end = Anchor {
137 buffer_id: Some(buffer.remote_id()),
138 excerpt_id,
139 text_anchor: buffer.anchor_after(buffer_range.end),
140 };
141 codegen_ranges.push(start..end);
142 }
143
144 let assist_group_id = self.next_assist_group_id.post_inc();
145 let prompt_buffer =
146 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
147 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
148
149 let mut assists = Vec::new();
150 let mut assist_to_focus = None;
151 for range in codegen_ranges {
152 let assist_id = self.next_assist_id.post_inc();
153 let codegen = cx.new_model(|cx| {
154 Codegen::new(
155 editor.read(cx).buffer().clone(),
156 range.clone(),
157 None,
158 self.telemetry.clone(),
159 cx,
160 )
161 });
162
163 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
164 let prompt_editor = cx.new_view(|cx| {
165 PromptEditor::new(
166 assist_id,
167 gutter_dimensions.clone(),
168 self.prompt_history.clone(),
169 prompt_buffer.clone(),
170 codegen.clone(),
171 editor,
172 assistant_panel,
173 workspace.clone(),
174 self.fs.clone(),
175 cx,
176 )
177 });
178
179 if assist_to_focus.is_none() {
180 let focus_assist = if newest_selection.reversed {
181 range.start.to_point(&snapshot) == newest_selection.start
182 } else {
183 range.end.to_point(&snapshot) == newest_selection.end
184 };
185 if focus_assist {
186 assist_to_focus = Some(assist_id);
187 }
188 }
189
190 let [prompt_block_id, end_block_id] =
191 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
192
193 assists.push((
194 assist_id,
195 range,
196 prompt_editor,
197 prompt_block_id,
198 end_block_id,
199 ));
200 }
201
202 let editor_assists = self
203 .assists_by_editor
204 .entry(editor.downgrade())
205 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
206 let mut assist_group = InlineAssistGroup::new();
207 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
208 self.assists.insert(
209 assist_id,
210 InlineAssist::new(
211 assist_id,
212 assist_group_id,
213 assistant_panel.is_some(),
214 editor,
215 &prompt_editor,
216 prompt_block_id,
217 end_block_id,
218 range,
219 prompt_editor.read(cx).codegen.clone(),
220 workspace.clone(),
221 cx,
222 ),
223 );
224 assist_group.assist_ids.push(assist_id);
225 editor_assists.assist_ids.push(assist_id);
226 }
227 self.assist_groups.insert(assist_group_id, assist_group);
228
229 if let Some(assist_id) = assist_to_focus {
230 self.focus_assist(assist_id, cx);
231 }
232 }
233
234 #[allow(clippy::too_many_arguments)]
235 pub fn suggest_assist(
236 &mut self,
237 editor: &View<Editor>,
238 mut range: Range<Anchor>,
239 initial_prompt: String,
240 initial_transaction_id: Option<TransactionId>,
241 workspace: Option<WeakView<Workspace>>,
242 assistant_panel: Option<&View<AssistantPanel>>,
243 cx: &mut WindowContext,
244 ) -> InlineAssistId {
245 let assist_group_id = self.next_assist_group_id.post_inc();
246 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
247 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
248
249 let assist_id = self.next_assist_id.post_inc();
250
251 let buffer = editor.read(cx).buffer().clone();
252 {
253 let snapshot = buffer.read(cx).read(cx);
254 range.start = range.start.bias_left(&snapshot);
255 range.end = range.end.bias_right(&snapshot);
256 }
257
258 let codegen = cx.new_model(|cx| {
259 Codegen::new(
260 editor.read(cx).buffer().clone(),
261 range.clone(),
262 initial_transaction_id,
263 self.telemetry.clone(),
264 cx,
265 )
266 });
267
268 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
269 let prompt_editor = cx.new_view(|cx| {
270 PromptEditor::new(
271 assist_id,
272 gutter_dimensions.clone(),
273 self.prompt_history.clone(),
274 prompt_buffer.clone(),
275 codegen.clone(),
276 editor,
277 assistant_panel,
278 workspace.clone(),
279 self.fs.clone(),
280 cx,
281 )
282 });
283
284 let [prompt_block_id, end_block_id] =
285 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
286
287 let editor_assists = self
288 .assists_by_editor
289 .entry(editor.downgrade())
290 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
291
292 let mut assist_group = InlineAssistGroup::new();
293 self.assists.insert(
294 assist_id,
295 InlineAssist::new(
296 assist_id,
297 assist_group_id,
298 assistant_panel.is_some(),
299 editor,
300 &prompt_editor,
301 prompt_block_id,
302 end_block_id,
303 range,
304 prompt_editor.read(cx).codegen.clone(),
305 workspace.clone(),
306 cx,
307 ),
308 );
309 assist_group.assist_ids.push(assist_id);
310 editor_assists.assist_ids.push(assist_id);
311 self.assist_groups.insert(assist_group_id, assist_group);
312 assist_id
313 }
314
315 fn insert_assist_blocks(
316 &self,
317 editor: &View<Editor>,
318 range: &Range<Anchor>,
319 prompt_editor: &View<PromptEditor>,
320 cx: &mut WindowContext,
321 ) -> [CustomBlockId; 2] {
322 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
323 prompt_editor
324 .editor
325 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
326 });
327 let assist_blocks = vec![
328 BlockProperties {
329 style: BlockStyle::Sticky,
330 position: range.start,
331 height: prompt_editor_height,
332 render: build_assist_editor_renderer(prompt_editor),
333 disposition: BlockDisposition::Above,
334 },
335 BlockProperties {
336 style: BlockStyle::Sticky,
337 position: range.end,
338 height: 1,
339 render: Box::new(|cx| {
340 v_flex()
341 .h_full()
342 .w_full()
343 .border_t_1()
344 .border_color(cx.theme().status().info_border)
345 .into_any_element()
346 }),
347 disposition: BlockDisposition::Below,
348 },
349 ];
350
351 editor.update(cx, |editor, cx| {
352 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
353 [block_ids[0], block_ids[1]]
354 })
355 }
356
357 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
358 let assist = &self.assists[&assist_id];
359 let Some(decorations) = assist.decorations.as_ref() else {
360 return;
361 };
362 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
363 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
364
365 assist_group.active_assist_id = Some(assist_id);
366 if assist_group.linked {
367 for assist_id in &assist_group.assist_ids {
368 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
369 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
370 prompt_editor.set_show_cursor_when_unfocused(true, cx)
371 });
372 }
373 }
374 }
375
376 assist
377 .editor
378 .update(cx, |editor, cx| {
379 let scroll_top = editor.scroll_position(cx).y;
380 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
381 let prompt_row = editor
382 .row_for_block(decorations.prompt_block_id, cx)
383 .unwrap()
384 .0 as f32;
385
386 if (scroll_top..scroll_bottom).contains(&prompt_row) {
387 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
388 assist_id,
389 distance_from_top: prompt_row - scroll_top,
390 });
391 } else {
392 editor_assists.scroll_lock = None;
393 }
394 })
395 .ok();
396 }
397
398 fn handle_prompt_editor_focus_out(
399 &mut self,
400 assist_id: InlineAssistId,
401 cx: &mut WindowContext,
402 ) {
403 let assist = &self.assists[&assist_id];
404 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
405 if assist_group.active_assist_id == Some(assist_id) {
406 assist_group.active_assist_id = None;
407 if assist_group.linked {
408 for assist_id in &assist_group.assist_ids {
409 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
410 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
411 prompt_editor.set_show_cursor_when_unfocused(false, cx)
412 });
413 }
414 }
415 }
416 }
417 }
418
419 fn handle_prompt_editor_event(
420 &mut self,
421 prompt_editor: View<PromptEditor>,
422 event: &PromptEditorEvent,
423 cx: &mut WindowContext,
424 ) {
425 let assist_id = prompt_editor.read(cx).id;
426 match event {
427 PromptEditorEvent::StartRequested => {
428 self.start_assist(assist_id, cx);
429 }
430 PromptEditorEvent::StopRequested => {
431 self.stop_assist(assist_id, cx);
432 }
433 PromptEditorEvent::ConfirmRequested => {
434 self.finish_assist(assist_id, false, cx);
435 }
436 PromptEditorEvent::CancelRequested => {
437 self.finish_assist(assist_id, true, cx);
438 }
439 PromptEditorEvent::DismissRequested => {
440 self.dismiss_assist(assist_id, cx);
441 }
442 }
443 }
444
445 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
446 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
447 return;
448 };
449
450 let editor = editor.read(cx);
451 if editor.selections.count() == 1 {
452 let selection = editor.selections.newest::<usize>(cx);
453 let buffer = editor.buffer().read(cx).snapshot(cx);
454 for assist_id in &editor_assists.assist_ids {
455 let assist = &self.assists[assist_id];
456 let assist_range = assist.range.to_offset(&buffer);
457 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
458 {
459 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
460 self.dismiss_assist(*assist_id, cx);
461 } else {
462 self.finish_assist(*assist_id, false, cx);
463 }
464
465 return;
466 }
467 }
468 }
469
470 cx.propagate();
471 }
472
473 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
474 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
475 return;
476 };
477
478 let editor = editor.read(cx);
479 if editor.selections.count() == 1 {
480 let selection = editor.selections.newest::<usize>(cx);
481 let buffer = editor.buffer().read(cx).snapshot(cx);
482 for assist_id in &editor_assists.assist_ids {
483 let assist = &self.assists[assist_id];
484 let assist_range = assist.range.to_offset(&buffer);
485 if assist.decorations.is_some()
486 && assist_range.contains(&selection.start)
487 && assist_range.contains(&selection.end)
488 {
489 self.focus_assist(*assist_id, cx);
490 return;
491 }
492 }
493 }
494
495 cx.propagate();
496 }
497
498 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
499 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
500 for assist_id in editor_assists.assist_ids.clone() {
501 self.finish_assist(assist_id, true, cx);
502 }
503 }
504 }
505
506 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
507 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
508 return;
509 };
510 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
511 return;
512 };
513 let assist = &self.assists[&scroll_lock.assist_id];
514 let Some(decorations) = assist.decorations.as_ref() else {
515 return;
516 };
517
518 editor.update(cx, |editor, cx| {
519 let scroll_position = editor.scroll_position(cx);
520 let target_scroll_top = editor
521 .row_for_block(decorations.prompt_block_id, cx)
522 .unwrap()
523 .0 as f32
524 - scroll_lock.distance_from_top;
525 if target_scroll_top != scroll_position.y {
526 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
527 }
528 });
529 }
530
531 fn handle_editor_event(
532 &mut self,
533 editor: View<Editor>,
534 event: &EditorEvent,
535 cx: &mut WindowContext,
536 ) {
537 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
538 return;
539 };
540
541 match event {
542 EditorEvent::Saved => {
543 for assist_id in editor_assists.assist_ids.clone() {
544 let assist = &self.assists[&assist_id];
545 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
546 self.finish_assist(assist_id, false, cx)
547 }
548 }
549 }
550 EditorEvent::Edited { transaction_id } => {
551 let buffer = editor.read(cx).buffer().read(cx);
552 let edited_ranges =
553 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
554 let snapshot = buffer.snapshot(cx);
555
556 for assist_id in editor_assists.assist_ids.clone() {
557 let assist = &self.assists[&assist_id];
558 if matches!(
559 assist.codegen.read(cx).status,
560 CodegenStatus::Error(_) | CodegenStatus::Done
561 ) {
562 let assist_range = assist.range.to_offset(&snapshot);
563 if edited_ranges
564 .iter()
565 .any(|range| range.overlaps(&assist_range))
566 {
567 self.finish_assist(assist_id, false, cx);
568 }
569 }
570 }
571 }
572 EditorEvent::ScrollPositionChanged { .. } => {
573 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
574 let assist = &self.assists[&scroll_lock.assist_id];
575 if let Some(decorations) = assist.decorations.as_ref() {
576 let distance_from_top = editor.update(cx, |editor, cx| {
577 let scroll_top = editor.scroll_position(cx).y;
578 let prompt_row = editor
579 .row_for_block(decorations.prompt_block_id, cx)
580 .unwrap()
581 .0 as f32;
582 prompt_row - scroll_top
583 });
584
585 if distance_from_top != scroll_lock.distance_from_top {
586 editor_assists.scroll_lock = None;
587 }
588 }
589 }
590 }
591 EditorEvent::SelectionsChanged { .. } => {
592 for assist_id in editor_assists.assist_ids.clone() {
593 let assist = &self.assists[&assist_id];
594 if let Some(decorations) = assist.decorations.as_ref() {
595 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
596 return;
597 }
598 }
599 }
600
601 editor_assists.scroll_lock = None;
602 }
603 _ => {}
604 }
605 }
606
607 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
608 if let Some(assist) = self.assists.get(&assist_id) {
609 let assist_group_id = assist.group_id;
610 if self.assist_groups[&assist_group_id].linked {
611 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
612 self.finish_assist(assist_id, undo, cx);
613 }
614 return;
615 }
616 }
617
618 self.dismiss_assist(assist_id, cx);
619
620 if let Some(assist) = self.assists.remove(&assist_id) {
621 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
622 {
623 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
624 if entry.get().assist_ids.is_empty() {
625 entry.remove();
626 }
627 }
628
629 if let hash_map::Entry::Occupied(mut entry) =
630 self.assists_by_editor.entry(assist.editor.clone())
631 {
632 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
633 if entry.get().assist_ids.is_empty() {
634 entry.remove();
635 if let Some(editor) = assist.editor.upgrade() {
636 self.update_editor_highlights(&editor, cx);
637 }
638 } else {
639 entry.get().highlight_updates.send(()).ok();
640 }
641 }
642
643 if undo {
644 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
645 }
646 }
647 }
648
649 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
650 let Some(assist) = self.assists.get_mut(&assist_id) else {
651 return false;
652 };
653 let Some(editor) = assist.editor.upgrade() else {
654 return false;
655 };
656 let Some(decorations) = assist.decorations.take() else {
657 return false;
658 };
659
660 editor.update(cx, |editor, cx| {
661 let mut to_remove = decorations.removed_line_block_ids;
662 to_remove.insert(decorations.prompt_block_id);
663 to_remove.insert(decorations.end_block_id);
664 editor.remove_blocks(to_remove, None, cx);
665 });
666
667 if decorations
668 .prompt_editor
669 .focus_handle(cx)
670 .contains_focused(cx)
671 {
672 self.focus_next_assist(assist_id, cx);
673 }
674
675 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
676 if editor_assists
677 .scroll_lock
678 .as_ref()
679 .map_or(false, |lock| lock.assist_id == assist_id)
680 {
681 editor_assists.scroll_lock = None;
682 }
683 editor_assists.highlight_updates.send(()).ok();
684 }
685
686 true
687 }
688
689 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
690 let Some(assist) = self.assists.get(&assist_id) else {
691 return;
692 };
693
694 let assist_group = &self.assist_groups[&assist.group_id];
695 let assist_ix = assist_group
696 .assist_ids
697 .iter()
698 .position(|id| *id == assist_id)
699 .unwrap();
700 let assist_ids = assist_group
701 .assist_ids
702 .iter()
703 .skip(assist_ix + 1)
704 .chain(assist_group.assist_ids.iter().take(assist_ix));
705
706 for assist_id in assist_ids {
707 let assist = &self.assists[assist_id];
708 if assist.decorations.is_some() {
709 self.focus_assist(*assist_id, cx);
710 return;
711 }
712 }
713
714 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
715 }
716
717 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
718 let Some(assist) = self.assists.get(&assist_id) else {
719 return;
720 };
721
722 if let Some(decorations) = assist.decorations.as_ref() {
723 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
724 prompt_editor.editor.update(cx, |editor, cx| {
725 editor.focus(cx);
726 editor.select_all(&SelectAll, cx);
727 })
728 });
729 }
730
731 self.scroll_to_assist(assist_id, cx);
732 }
733
734 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
735 let Some(assist) = self.assists.get(&assist_id) else {
736 return;
737 };
738 let Some(editor) = assist.editor.upgrade() else {
739 return;
740 };
741
742 let position = assist.range.start;
743 editor.update(cx, |editor, cx| {
744 editor.change_selections(None, cx, |selections| {
745 selections.select_anchor_ranges([position..position])
746 });
747
748 let mut scroll_target_top;
749 let mut scroll_target_bottom;
750 if let Some(decorations) = assist.decorations.as_ref() {
751 scroll_target_top = editor
752 .row_for_block(decorations.prompt_block_id, cx)
753 .unwrap()
754 .0 as f32;
755 scroll_target_bottom = editor
756 .row_for_block(decorations.end_block_id, cx)
757 .unwrap()
758 .0 as f32;
759 } else {
760 let snapshot = editor.snapshot(cx);
761 let start_row = assist
762 .range
763 .start
764 .to_display_point(&snapshot.display_snapshot)
765 .row();
766 scroll_target_top = start_row.0 as f32;
767 scroll_target_bottom = scroll_target_top + 1.;
768 }
769 scroll_target_top -= editor.vertical_scroll_margin() as f32;
770 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
771
772 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
773 let scroll_top = editor.scroll_position(cx).y;
774 let scroll_bottom = scroll_top + height_in_lines;
775
776 if scroll_target_top < scroll_top {
777 editor.set_scroll_position(point(0., scroll_target_top), cx);
778 } else if scroll_target_bottom > scroll_bottom {
779 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
780 editor
781 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
782 } else {
783 editor.set_scroll_position(point(0., scroll_target_top), cx);
784 }
785 }
786 });
787 }
788
789 fn unlink_assist_group(
790 &mut self,
791 assist_group_id: InlineAssistGroupId,
792 cx: &mut WindowContext,
793 ) -> Vec<InlineAssistId> {
794 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
795 assist_group.linked = false;
796 for assist_id in &assist_group.assist_ids {
797 let assist = self.assists.get_mut(assist_id).unwrap();
798 if let Some(editor_decorations) = assist.decorations.as_ref() {
799 editor_decorations
800 .prompt_editor
801 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
802 }
803 }
804 assist_group.assist_ids.clone()
805 }
806
807 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
808 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
809 assist
810 } else {
811 return;
812 };
813
814 let assist_group_id = assist.group_id;
815 if self.assist_groups[&assist_group_id].linked {
816 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
817 self.start_assist(assist_id, cx);
818 }
819 return;
820 }
821
822 let Some(user_prompt) = assist.user_prompt(cx) else {
823 return;
824 };
825
826 self.prompt_history.retain(|prompt| *prompt != user_prompt);
827 self.prompt_history.push_back(user_prompt.clone());
828 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
829 self.prompt_history.pop_front();
830 }
831
832 let assistant_panel_context = assist.assistant_panel_context(cx);
833
834 assist
835 .codegen
836 .update(cx, |codegen, cx| {
837 codegen.start(
838 assist.range.clone(),
839 user_prompt,
840 assistant_panel_context,
841 cx,
842 )
843 })
844 .log_err();
845 }
846
847 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
848 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
849 assist
850 } else {
851 return;
852 };
853
854 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
855 }
856
857 pub fn status_for_assist(
858 &self,
859 assist_id: InlineAssistId,
860 cx: &WindowContext,
861 ) -> Option<CodegenStatus> {
862 let assist = self.assists.get(&assist_id)?;
863 match &assist.codegen.read(cx).status {
864 CodegenStatus::Idle => Some(CodegenStatus::Idle),
865 CodegenStatus::Pending => Some(CodegenStatus::Pending),
866 CodegenStatus::Done => Some(CodegenStatus::Done),
867 CodegenStatus::Error(error) => Some(CodegenStatus::Error(anyhow!("{:?}", error))),
868 }
869 }
870
871 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
872 let mut gutter_pending_ranges = Vec::new();
873 let mut gutter_transformed_ranges = Vec::new();
874 let mut foreground_ranges = Vec::new();
875 let mut inserted_row_ranges = Vec::new();
876 let empty_assist_ids = Vec::new();
877 let assist_ids = self
878 .assists_by_editor
879 .get(&editor.downgrade())
880 .map_or(&empty_assist_ids, |editor_assists| {
881 &editor_assists.assist_ids
882 });
883
884 for assist_id in assist_ids {
885 if let Some(assist) = self.assists.get(assist_id) {
886 let codegen = assist.codegen.read(cx);
887 let buffer = codegen.buffer.read(cx).read(cx);
888 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
889
890 let pending_range =
891 codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
892 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
893 gutter_pending_ranges.push(pending_range);
894 }
895
896 if let Some(edit_position) = codegen.edit_position {
897 let edited_range = assist.range.start..edit_position;
898 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
899 gutter_transformed_ranges.push(edited_range);
900 }
901 }
902
903 if assist.decorations.is_some() {
904 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
905 }
906 }
907 }
908
909 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
910 merge_ranges(&mut foreground_ranges, &snapshot);
911 merge_ranges(&mut gutter_pending_ranges, &snapshot);
912 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
913 editor.update(cx, |editor, cx| {
914 enum GutterPendingRange {}
915 if gutter_pending_ranges.is_empty() {
916 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
917 } else {
918 editor.highlight_gutter::<GutterPendingRange>(
919 &gutter_pending_ranges,
920 |cx| cx.theme().status().info_background,
921 cx,
922 )
923 }
924
925 enum GutterTransformedRange {}
926 if gutter_transformed_ranges.is_empty() {
927 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
928 } else {
929 editor.highlight_gutter::<GutterTransformedRange>(
930 &gutter_transformed_ranges,
931 |cx| cx.theme().status().info,
932 cx,
933 )
934 }
935
936 if foreground_ranges.is_empty() {
937 editor.clear_highlights::<InlineAssist>(cx);
938 } else {
939 editor.highlight_text::<InlineAssist>(
940 foreground_ranges,
941 HighlightStyle {
942 fade_out: Some(0.6),
943 ..Default::default()
944 },
945 cx,
946 );
947 }
948
949 editor.clear_row_highlights::<InlineAssist>();
950 for row_range in inserted_row_ranges {
951 editor.highlight_rows::<InlineAssist>(
952 row_range,
953 Some(cx.theme().status().info_background),
954 false,
955 cx,
956 );
957 }
958 });
959 }
960
961 fn update_editor_blocks(
962 &mut self,
963 editor: &View<Editor>,
964 assist_id: InlineAssistId,
965 cx: &mut WindowContext,
966 ) {
967 let Some(assist) = self.assists.get_mut(&assist_id) else {
968 return;
969 };
970 let Some(decorations) = assist.decorations.as_mut() else {
971 return;
972 };
973
974 let codegen = assist.codegen.read(cx);
975 let old_snapshot = codegen.snapshot.clone();
976 let old_buffer = codegen.old_buffer.clone();
977 let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
978
979 editor.update(cx, |editor, cx| {
980 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
981 editor.remove_blocks(old_blocks, None, cx);
982
983 let mut new_blocks = Vec::new();
984 for (new_row, old_row_range) in deleted_row_ranges {
985 let (_, buffer_start) = old_snapshot
986 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
987 .unwrap();
988 let (_, buffer_end) = old_snapshot
989 .point_to_buffer_offset(Point::new(
990 *old_row_range.end(),
991 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
992 ))
993 .unwrap();
994
995 let deleted_lines_editor = cx.new_view(|cx| {
996 let multi_buffer = cx.new_model(|_| {
997 MultiBuffer::without_headers(0, language::Capability::ReadOnly)
998 });
999 multi_buffer.update(cx, |multi_buffer, cx| {
1000 multi_buffer.push_excerpts(
1001 old_buffer.clone(),
1002 Some(ExcerptRange {
1003 context: buffer_start..buffer_end,
1004 primary: None,
1005 }),
1006 cx,
1007 );
1008 });
1009
1010 enum DeletedLines {}
1011 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
1012 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
1013 editor.set_show_wrap_guides(false, cx);
1014 editor.set_show_gutter(false, cx);
1015 editor.scroll_manager.set_forbid_vertical_scroll(true);
1016 editor.set_read_only(true);
1017 editor.highlight_rows::<DeletedLines>(
1018 Anchor::min()..=Anchor::max(),
1019 Some(cx.theme().status().deleted_background),
1020 false,
1021 cx,
1022 );
1023 editor
1024 });
1025
1026 let height =
1027 deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
1028 new_blocks.push(BlockProperties {
1029 position: new_row,
1030 height,
1031 style: BlockStyle::Flex,
1032 render: Box::new(move |cx| {
1033 div()
1034 .bg(cx.theme().status().deleted_background)
1035 .size_full()
1036 .h(height as f32 * cx.line_height())
1037 .pl(cx.gutter_dimensions.full_width())
1038 .child(deleted_lines_editor.clone())
1039 .into_any_element()
1040 }),
1041 disposition: BlockDisposition::Above,
1042 });
1043 }
1044
1045 decorations.removed_line_block_ids = editor
1046 .insert_blocks(new_blocks, None, cx)
1047 .into_iter()
1048 .collect();
1049 })
1050 }
1051}
1052
/// Per-editor state shared by all inline assists in that editor.
struct EditorInlineAssists {
    /// Ids of the assists living in this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, the editor's scroll position is kept so that this assist's
    /// prompt stays a fixed distance from the top of the viewport.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` wakes `_update_highlights` to recompute the editor's highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that recomputes highlights on each `highlight_updates` signal.
    _update_highlights: Task<Result<()>>,
    /// Editor observers and action registrations; dropped together with this entry.
    _subscriptions: Vec<gpui::Subscription>,
}
1060
/// Pins an assist's prompt block at a fixed offset from the editor's scroll top.
struct InlineAssistScrollLock {
    /// The assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Prompt block row minus scroll top, in rows, captured when the lock was set.
    distance_from_top: f32,
}
1065
1066impl EditorInlineAssists {
1067 #[allow(clippy::too_many_arguments)]
1068 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1069 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1070 Self {
1071 assist_ids: Vec::new(),
1072 scroll_lock: None,
1073 highlight_updates: highlight_updates_tx,
1074 _update_highlights: cx.spawn(|mut cx| {
1075 let editor = editor.downgrade();
1076 async move {
1077 while let Ok(()) = highlight_updates_rx.changed().await {
1078 let editor = editor.upgrade().context("editor was dropped")?;
1079 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1080 assistant.update_editor_highlights(&editor, cx);
1081 })?;
1082 }
1083 Ok(())
1084 }
1085 }),
1086 _subscriptions: vec![
1087 cx.observe_release(editor, {
1088 let editor = editor.downgrade();
1089 |_, cx| {
1090 InlineAssistant::update_global(cx, |this, cx| {
1091 this.handle_editor_release(editor, cx);
1092 })
1093 }
1094 }),
1095 cx.observe(editor, move |editor, cx| {
1096 InlineAssistant::update_global(cx, |this, cx| {
1097 this.handle_editor_change(editor, cx)
1098 })
1099 }),
1100 cx.subscribe(editor, move |editor, event, cx| {
1101 InlineAssistant::update_global(cx, |this, cx| {
1102 this.handle_editor_event(editor, event, cx)
1103 })
1104 }),
1105 editor.update(cx, |editor, cx| {
1106 let editor_handle = cx.view().downgrade();
1107 editor.register_action(
1108 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1109 InlineAssistant::update_global(cx, |this, cx| {
1110 if let Some(editor) = editor_handle.upgrade() {
1111 this.handle_editor_newline(editor, cx)
1112 }
1113 })
1114 },
1115 )
1116 }),
1117 editor.update(cx, |editor, cx| {
1118 let editor_handle = cx.view().downgrade();
1119 editor.register_action(
1120 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1121 InlineAssistant::update_global(cx, |this, cx| {
1122 if let Some(editor) = editor_handle.upgrade() {
1123 this.handle_editor_cancel(editor, cx)
1124 }
1125 })
1126 },
1127 )
1128 }),
1129 ],
1130 }
1131 }
1132}
1133
/// A set of inline assists sharing one group id.
struct InlineAssistGroup {
    /// Assists that belong to this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group's prompt editors are still linked. Linked editors share their prompt
    /// text (see the comment in `PromptEditor::new` and `PromptEditor::unlink`).
    linked: bool,
    /// The assist currently considered active within the group, if any.
    active_assist_id: Option<InlineAssistId>,
}
1139
1140impl InlineAssistGroup {
1141 fn new() -> Self {
1142 Self {
1143 assist_ids: Vec::new(),
1144 linked: true,
1145 active_assist_id: None,
1146 }
1147 }
1148}
1149
1150fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1151 let editor = editor.clone();
1152 Box::new(move |cx: &mut BlockContext| {
1153 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1154 editor.clone().into_any_element()
1155 })
1156}
1157
/// Requested placement of an extra newline around an inline assist.
///
/// NOTE(review): the code that consumes this is outside this chunk. The per-variant
/// descriptions below are inferred from the names; confirm them.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    /// Presumably: insert a newline before the assist's range.
    NewlineBefore,
    /// Presumably: insert a newline after the assist's range.
    NewlineAfter,
}
1163
/// Identifier of a single inline assist. Ids are handed out sequentially via `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one (like `i++`).
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        std::mem::replace(self, next)
    }
}
1174
/// Identifier of a group of inline assists. Ids are handed out sequentially via `post_inc`.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Hands out the current id, then bumps `self` so the next call yields a fresh one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let current = *self;
        *self = InlineAssistGroupId(current.0 + 1);
        current
    }
}
1185
/// Requests emitted by a [`PromptEditor`] in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation. Emitted by the start/restart buttons, and by
    /// `confirm` when idle or when the prompt was edited after finishing.
    StartRequested,
    /// Interrupt a running transformation. Emitted by the stop button, and by `cancel`
    /// while pending.
    StopRequested,
    /// Accept the finished result. Emitted by the confirm button, and by `confirm` after
    /// finishing without further prompt edits.
    ConfirmRequested,
    /// Cancel the assist. Emitted by the close buttons, and by `cancel` when not pending.
    CancelRequested,
    /// Emitted by `confirm` while a transformation is still pending.
    DismissRequested,
}
1193
/// The view rendered in an editor block where the user types an inline-assist prompt.
struct PromptEditor {
    /// The inline assist this prompt belongs to. It is used to look the assist up when
    /// counting tokens.
    id: InlineAssistId,
    /// Filesystem handle, passed to the `ModelSelector`.
    fs: Arc<dyn Fs>,
    /// The editor holding the prompt text. `unlink` replaces it with a fresh editor.
    editor: View<Editor>,
    /// Set when the prompt is edited. It is reset when generation finishes or fails, and it
    /// decides between "restart" and "confirm".
    edited_since_done: bool,
    /// Gutter size of the host editor, shared with the block renderer and read on every
    /// render to size the left-hand column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts, browsable with `MoveUp` / `MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` currently shown, or `None` when showing `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// Text typed before browsing history. It is restored when moving down past the newest entry.
    pending_prompt: String,
    /// The code generator whose status drives the buttons and read-only state.
    codegen: Model<Codegen>,
    /// Observes `codegen` and calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    /// Subscriptions on `editor`. They are rebuilt by `subscribe_to_editor` whenever the
    /// editor is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// The in-flight, debounced token-count task.
    pending_token_count: Task<Result<()>>,
    /// The last computed token count, if any.
    token_count: Option<usize>,
    /// Parent-editor / assistant-panel subscriptions that trigger a token recount.
    _token_count_subscriptions: Vec<Subscription>,
    /// Used to focus the assistant panel when the token count is clicked.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit ("Out of Tokens") popover is currently shown.
    show_rate_limit_notice: bool,
}

// The prompt editor communicates user intent to the inline assistant via these events.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1214
impl Render for PromptEditor {
    /// Renders the inline prompt row.
    ///
    /// From left to right the row contains:
    /// - a column as wide as the host editor's gutter (plus half its margin), holding the
    ///   model selector and, on failure, an error indicator;
    /// - the prompt editor;
    /// - the token count, followed by buttons that depend on the `CodegenStatus`.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        let buttons = match status {
            // Not started: offer to cancel the assist or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Generating: offer to cancel, or to interrupt while keeping the changes made
            // so far (per the tooltip text).
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished: restart if the prompt was edited since completion or the run
            // failed; otherwise offer to confirm the result.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Left column, sized to line up with the host editor's gutter.
            .child(
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator. Zed Pro users who hit the rate limit get a button
                    // that toggles the upsell popover. Any other error shows its message
                    // in a tooltip.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    // The popover is deferred and anchored 24px below the button.
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            // The prompt editor takes all remaining horizontal space.
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            // Right column: token count, then the status-dependent buttons.
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1395
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt `Editor`.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1401
impl PromptEditor {
    /// Maximum height, in lines, of the auto-height prompt editor.
    const MAX_LINES: u8 = 8;

    /// Creates the prompt editor for inline assist `id`.
    ///
    /// * `gutter_dimensions`: shared with the block renderer and read on every render to
    ///   size the left column.
    /// * `prompt_history`: earlier prompts, browsable with `MoveUp` / `MoveDown`.
    /// * `prompt_buffer`: the buffer backing the prompt text (shared while linked; see the
    ///   comment below).
    /// * `codegen`: observed so the UI follows the generation status.
    /// * `parent_editor` / `assistant_panel`: their edits trigger a token recount.
    /// * `workspace`: lets a click on the token count focus the assistant panel.
    // 10 parameters exceed clippy's default argument limit.
    #[allow(clippy::too_many_arguments)]
    fn new(
        id: InlineAssistId,
        gutter_dimensions: Arc<Mutex<GutterDimensions>>,
        prompt_history: VecDeque<String>,
        prompt_buffer: Model<MultiBuffer>,
        codegen: Model<Codegen>,
        parent_editor: &View<Editor>,
        assistant_panel: Option<&View<AssistantPanel>>,
        workspace: Option<WeakView<Workspace>>,
        fs: Arc<dyn Fs>,
        cx: &mut ViewContext<Self>,
    ) -> Self {
        let prompt_editor = cx.new_view(|cx| {
            let mut editor = Editor::new(
                EditorMode::AutoHeight {
                    max_lines: Self::MAX_LINES as usize,
                },
                prompt_buffer,
                None,
                false,
                cx,
            );
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            // Since the prompt editors for all inline assistants are linked,
            // always show the cursor (even when it isn't focused) because
            // typing in one will make what you typed appear in all of them.
            editor.set_show_cursor_when_unfocused(true, cx);
            editor.set_placeholder_text("Add a prompt…", cx);
            editor
        });

        // Recount tokens when the code being edited or the assistant panel context changes.
        let mut token_count_subscriptions = Vec::new();
        token_count_subscriptions
            .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
        if let Some(assistant_panel) = assistant_panel {
            token_count_subscriptions
                .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
        }

        let mut this = Self {
            id,
            editor: prompt_editor,
            edited_since_done: false,
            gutter_dimensions,
            prompt_history,
            prompt_history_ix: None,
            pending_prompt: String::new(),
            _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
            editor_subscriptions: Vec::new(),
            codegen,
            fs,
            pending_token_count: Task::ready(Ok(())),
            token_count: None,
            _token_count_subscriptions: token_count_subscriptions,
            workspace,
            show_rate_limit_notice: false,
        };
        this.count_tokens(cx);
        this.subscribe_to_editor(cx);
        this
    }

    /// (Re)subscribes to events from the current `self.editor`, dropping old subscriptions.
    fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
        self.editor_subscriptions.clear();
        self.editor_subscriptions
            .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
    }

    /// Controls whether the prompt editor draws its cursor while unfocused.
    fn set_show_cursor_when_unfocused(
        &mut self,
        show_cursor_when_unfocused: bool,
        cx: &mut ViewContext<Self>,
    ) {
        self.editor.update(cx, |editor, cx| {
            editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
        });
    }

    /// Detaches this prompt from the shared prompt buffer.
    ///
    /// The current editor is replaced with a fresh auto-height editor initialised with the
    /// current prompt text. Focus is kept if the old editor had it, and the event
    /// subscription is re-established.
    fn unlink(&mut self, cx: &mut ViewContext<Self>) {
        let prompt = self.prompt(cx);
        let focus = self.editor.focus_handle(cx).contains_focused(cx);
        self.editor = cx.new_view(|cx| {
            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
            editor.set_placeholder_text("Add a prompt…", cx);
            editor.set_text(prompt, cx);
            if focus {
                editor.focus(cx);
            }
            editor
        });
        self.subscribe_to_editor(cx);
    }

    /// Returns the current prompt text.
    fn prompt(&self, cx: &AppContext) -> String {
        self.editor.read(cx).text(cx)
    }

    /// Shows/hides the rate-limit popover.
    ///
    /// When showing, the prompt editor is focused so that its later `Blurred` event (handled
    /// in `handle_prompt_editor_events`) can hide the popover again.
    fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
        self.show_rate_limit_notice = !self.show_rate_limit_notice;
        if self.show_rate_limit_notice {
            cx.focus_view(&self.editor);
        }
        cx.notify();
    }

    /// Recounts tokens whenever the parent editor's buffer is edited.
    fn handle_parent_editor_event(
        &mut self,
        _: View<Editor>,
        event: &EditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        if let EditorEvent::BufferEdited { .. } = event {
            self.count_tokens(cx);
        }
    }

    /// Recounts tokens whenever the assistant panel's context is edited.
    fn handle_assistant_panel_event(
        &mut self,
        _: View<AssistantPanel>,
        event: &AssistantPanelEvent,
        cx: &mut ViewContext<Self>,
    ) {
        // Irrefutable: `ContextEdited` is the only variant matched here.
        let AssistantPanelEvent::ContextEdited { .. } = event;
        self.count_tokens(cx);
    }

    /// Schedules a token count for this assist after a 1-second delay.
    ///
    /// The result is stored in `token_count` and the view is re-rendered. The new task
    /// replaces `pending_token_count` (presumably dropping, and thereby cancelling, any
    /// earlier count). The task fails if the assist is no longer registered with the
    /// `InlineAssistant`.
    fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
        let assist_id = self.id;
        self.pending_token_count = cx.spawn(|this, mut cx| async move {
            // Debounce bursts of edits.
            cx.background_executor().timer(Duration::from_secs(1)).await;
            let token_count = cx
                .update_global(|inline_assistant: &mut InlineAssistant, cx| {
                    let assist = inline_assistant
                        .assists
                        .get(&assist_id)
                        .context("assist not found")?;
                    anyhow::Ok(assist.count_tokens(cx))
                })??
                .await?;

            this.update(&mut cx, |this, cx| {
                this.token_count = Some(token_count);
                cx.notify();
            })
        })
    }

    /// Reacts to edits, buffer changes and blur in the prompt editor.
    fn handle_prompt_editor_events(
        &mut self,
        _: View<Editor>,
        event: &EditorEvent,
        cx: &mut ViewContext<Self>,
    ) {
        match event {
            EditorEvent::Edited { .. } => {
                let prompt = self.editor.read(cx).text(cx);
                // Leave history-browsing mode once the text no longer matches the shown
                // history entry, and remember the typed text as the pending prompt.
                if self
                    .prompt_history_ix
                    .map_or(true, |ix| self.prompt_history[ix] != prompt)
                {
                    self.prompt_history_ix.take();
                    self.pending_prompt = prompt;
                }

                self.edited_since_done = true;
                cx.notify();
            }
            EditorEvent::BufferEdited => {
                self.count_tokens(cx);
            }
            EditorEvent::Blurred => {
                if self.show_rate_limit_notice {
                    self.show_rate_limit_notice = false;
                    cx.notify();
                }
            }
            _ => {}
        }
    }

    /// Syncs the prompt editor with the code generation status.
    ///
    /// The prompt is read-only only while generation is pending. Finishing or failing
    /// resets `edited_since_done`. A Zed Pro rate-limit error opens the rate-limit popover
    /// unless the user dismissed it permanently.
    fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
            CodegenStatus::Pending => {
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(true));
            }
            CodegenStatus::Done => {
                self.edited_since_done = false;
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
            CodegenStatus::Error(error) => {
                if cx.has_flag::<ZedPro>()
                    && error.error_code() == proto::ErrorCode::RateLimitExceeded
                    && !dismissed_rate_limit_notice()
                {
                    self.show_rate_limit_notice = true;
                    cx.notify();
                }

                self.edited_since_done = false;
                self.editor
                    .update(cx, |editor, _| editor.set_read_only(false));
            }
        }
    }

    /// `Cancel` action: stops a pending generation, otherwise cancels the assist.
    fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
                cx.emit(PromptEditorEvent::CancelRequested);
            }
            CodegenStatus::Pending => {
                cx.emit(PromptEditorEvent::StopRequested);
            }
        }
    }

    /// `Confirm` action.
    ///
    /// - Idle: starts generation.
    /// - Pending: requests dismissal.
    /// - Done / Error: restarts if the prompt was edited since finishing, otherwise
    ///   confirms the result.
    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        match &self.codegen.read(cx).status {
            CodegenStatus::Idle => {
                cx.emit(PromptEditorEvent::StartRequested);
            }
            CodegenStatus::Pending => {
                cx.emit(PromptEditorEvent::DismissRequested);
            }
            CodegenStatus::Done | CodegenStatus::Error(_) => {
                if self.edited_since_done {
                    cx.emit(PromptEditorEvent::StartRequested);
                } else {
                    cx.emit(PromptEditorEvent::ConfirmRequested);
                }
            }
        }
    }

    /// Shows the previous (older) history entry, with the cursor at the start.
    ///
    /// When not browsing yet, this starts at the newest entry. It does nothing at the oldest
    /// entry or when the history is empty.
    fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
        if let Some(ix) = self.prompt_history_ix {
            if ix > 0 {
                self.prompt_history_ix = Some(ix - 1);
                let prompt = self.prompt_history[ix - 1].as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_beginning(&Default::default(), cx);
                });
            }
        } else if !self.prompt_history.is_empty() {
            self.prompt_history_ix = Some(self.prompt_history.len() - 1);
            let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
            self.editor.update(cx, |editor, cx| {
                editor.set_text(prompt, cx);
                editor.move_to_beginning(&Default::default(), cx);
            });
        }
    }

    /// Shows the next (newer) history entry, with the cursor at the end.
    ///
    /// Past the newest entry, this leaves browsing mode and restores `pending_prompt`.
    /// It does nothing when not browsing.
    fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
        if let Some(ix) = self.prompt_history_ix {
            // `prompt_history_ix` is only ever `Some` for a non-empty history (see
            // `move_up`), so `len() - 1` cannot underflow here.
            if ix < self.prompt_history.len() - 1 {
                self.prompt_history_ix = Some(ix + 1);
                let prompt = self.prompt_history[ix + 1].as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_end(&Default::default(), cx)
                });
            } else {
                self.prompt_history_ix = None;
                let prompt = self.pending_prompt.as_str();
                self.editor.update(cx, |editor, cx| {
                    editor.set_text(prompt, cx);
                    editor.move_to_end(&Default::default(), cx)
                });
            }
        }
    }

    /// Renders "used / max" tokens for the active model.
    ///
    /// The used count is red when no tokens remain and yellow at 80% usage or more. Returns
    /// `None` if there is no active model or no count yet. With a workspace, clicking the
    /// label focuses the assistant panel.
    fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
        let model = LanguageModelRegistry::read_global(cx).active_model()?;
        let token_count = self.token_count?;
        let max_token_count = model.max_token_count();

        let remaining_tokens = max_token_count as isize - token_count as isize;
        let token_count_color = if remaining_tokens <= 0 {
            Color::Error
        } else if token_count as f32 / max_token_count as f32 >= 0.8 {
            Color::Warning
        } else {
            Color::Muted
        };

        let mut token_count = h_flex()
            .id("token_count")
            .gap_0p5()
            .child(
                Label::new(humanize_token_count(token_count))
                    .size(LabelSize::Small)
                    .color(token_count_color),
            )
            .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
            .child(
                Label::new(humanize_token_count(max_token_count))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
            );
        if let Some(workspace) = self.workspace.clone() {
            token_count = token_count
                .tooltip(|cx| {
                    Tooltip::with_meta(
                        "Tokens Used by Inline Assistant",
                        None,
                        "Click to Open Assistant Panel",
                        cx,
                    )
                })
                .cursor_pointer()
                // Keep the mouse-down from reaching the elements underneath.
                .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
                .on_click(move |_, cx| {
                    cx.stop_propagation();
                    workspace
                        .update(cx, |workspace, cx| {
                            workspace.focus_panel::<AssistantPanel>(cx)
                        })
                        .ok();
                });
        } else {
            token_count = token_count
                .cursor_default()
                .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
        }

        Some(token_count)
    }

    /// Renders the prompt editor in the UI font. Text uses the disabled color while the
    /// editor is read-only.
    fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: if self.editor.read(cx).read_only(cx) {
                cx.theme().colors().text_disabled
            } else {
                cx.theme().colors().text
            },
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            line_height: relative(1.3),
            ..Default::default()
        };
        EditorElement::new(
            &self.editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    /// Renders the "Out of Tokens" upsell popover.
    ///
    /// The popover offers a persisted "Don't show again" checkbox (an indeterminate state
    /// is ignored), a dismiss button, and a link to account settings.
    fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        Popover::new().child(
            v_flex()
                .occlude()
                .p_2()
                .child(
                    Label::new("Out of Tokens")
                        .size(LabelSize::Small)
                        .weight(FontWeight::BOLD),
                )
                .child(Label::new(
                    "Try Zed Pro for higher limits, a wider range of models, and more.",
                ))
                .child(
                    h_flex()
                        .justify_between()
                        .child(CheckboxWithLabel::new(
                            "dont-show-again",
                            Label::new("Don't show again"),
                            if dismissed_rate_limit_notice() {
                                ui::Selection::Selected
                            } else {
                                ui::Selection::Unselected
                            },
                            |selection, cx| {
                                let is_dismissed = match selection {
                                    ui::Selection::Unselected => false,
                                    ui::Selection::Indeterminate => return,
                                    ui::Selection::Selected => true,
                                };

                                set_rate_limit_notice_dismissed(is_dismissed, cx)
                            },
                        ))
                        .child(
                            h_flex()
                                .gap_2()
                                .child(
                                    Button::new("dismiss", "Dismiss")
                                        .style(ButtonStyle::Transparent)
                                        .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                )
                                .child(Button::new("more-info", "More Info").on_click(
                                    |_event, cx| {
                                        cx.dispatch_action(Box::new(
                                            zed_actions::OpenAccountSettings,
                                        ))
                                    },
                                )),
                        ),
                ),
        )
    }
}
1825
1826const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1827
1828fn dismissed_rate_limit_notice() -> bool {
1829 db::kvp::KEY_VALUE_STORE
1830 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1831 .log_err()
1832 .map_or(false, |s| s.is_some())
1833}
1834
1835fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1836 db::write_and_log(cx, move || async move {
1837 if is_dismissed {
1838 db::kvp::KEY_VALUE_STORE
1839 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1840 .await
1841 } else {
1842 db::kvp::KEY_VALUE_STORE
1843 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1844 .await
1845 }
1846 })
1847}
1848
/// State for one inline assist: the edited range, its code generator, and its UI.
struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range of the host editor's buffer being transformed.
    range: Range<Anchor>,
    /// The host editor (weak, so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// Prompt/diff blocks shown in the editor. `None` once they have been removed; the
    /// codegen handlers in `InlineAssist::new` then finish the assist themselves.
    decorations: Option<InlineAssistDecorations>,
    /// The generator producing the transformation.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions, held only to keep them alive.
    _subscriptions: Vec<Subscription>,
    /// Used for error toasts and to read the assistant panel's context.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the active assistant-panel conversation is sent along with the prompt.
    include_context: bool,
}
1859
impl InlineAssist {
    /// Creates an assist and wires its prompt editor and code generator to the global
    /// [`InlineAssistant`].
    ///
    /// The subscriptions:
    /// - forward focus in/out of the prompt editor and its events;
    /// - on every `codegen` change, signal the editor's highlight task and refresh the
    ///   assist's blocks;
    /// - handle `CodegenEvent::Undone` and `CodegenEvent::Finished`.
    // 11 parameters exceed clippy's default argument limit.
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: CustomBlockId,
        end_block_id: CustomBlockId,
        range: Range<Anchor>,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            range,
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                // Wake the per-editor highlight task (see `EditorInlineAssists`).
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // Without decorations there is no prompt UI left to show the
                            // error in, so report it as a workspace toast instead.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::identified::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            // Assists without decorations finish as soon as generation ends.
                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        }
    }

    /// The prompt typed by the user, or `None` once the decorations are gone.
    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
        let decorations = self.decorations.as_ref()?;
        Some(decorations.prompt_editor.read(cx).prompt(cx))
    }

    /// The active assistant-panel context as a completion request.
    ///
    /// Returns it only when `include_context` is set and the workspace, panel and active
    /// context all still exist; otherwise returns `None`.
    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
        if self.include_context {
            let workspace = self.workspace.as_ref()?;
            let workspace = workspace.upgrade()?.read(cx);
            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
            Some(
                assistant_panel
                    .read(cx)
                    .active_context(cx)?
                    .read(cx)
                    .to_completion_request(cx),
            )
        } else {
            None
        }
    }

    /// Counts the tokens the request for this assist would use.
    ///
    /// Resolves to an error ("no user prompt") when the prompt UI is gone, and otherwise
    /// defers to `Codegen::count_tokens`.
    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
        let Some(user_prompt) = self.user_prompt(cx) else {
            return future::ready(Err(anyhow!("no user prompt"))).boxed();
        };
        let assistant_panel_context = self.assistant_panel_context(cx);
        self.codegen.read(cx).count_tokens(
            self.range.clone(),
            user_prompt,
            assistant_panel_context,
            cx,
        )
    }
}
1998
/// Editor blocks and the prompt view that make up an inline assist's on-screen UI.
struct InlineAssistDecorations {
    /// Block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor rendered in `prompt_block_id`.
    prompt_editor: View<PromptEditor>,
    /// Blocks showing deleted lines. They are replaced wholesale whenever the blocks are
    /// refreshed.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Block presumably marking the end of the assist's range. TODO: confirm at creation site.
    end_block_id: CustomBlockId,
}
2005
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended. It is emitted outside this chunk, presumably from `handle_stream`.
    Finished,
    /// The transformation's transaction was undone in the buffer (see `handle_buffer_event`).
    Undone,
}
2011
/// Streams a language-model transformation of a range into a multibuffer.
pub struct Codegen {
    /// The multibuffer being edited.
    buffer: Model<MultiBuffer>,
    /// Detached local copy of the original buffer's contents, made in `new`.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot captured in `new`; `start` uses it to bias the edit start.
    snapshot: MultiBufferSnapshot,
    /// Where generated text is written; set by `start` to the right-biased range start.
    edit_position: Option<Anchor>,
    /// Exposed via `last_equal_ranges()`; populated outside this chunk.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction supplied by the caller of `new`; its usage is outside this chunk.
    initial_transaction_id: Option<TransactionId>,
    /// The transaction holding the current transformation. `start` undoes it before
    /// starting again, and it is cleared when undone externally.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation state.
    status: CodegenStatus,
    /// The running generation task; reset to a ready task when the transformation is undone.
    generation: Task<()>,
    /// Diff bookkeeping between the original and generated text.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, routed to `handle_buffer_event`.
    _subscription: gpui::Subscription,
}
2026
/// Lifecycle of a [`Codegen`] run.
pub enum CodegenStatus {
    /// No generation has been started.
    Idle,
    /// A generation is in progress; the prompt editor is read-only meanwhile.
    Pending,
    /// The generation completed.
    Done,
    /// The generation failed with the contained error.
    Error(anyhow::Error),
}
2033
/// Diff state between the original text and the generated text.
///
/// NOTE(review): the code that fills these fields is outside this chunk; the descriptions
/// below are partly inferred and should be confirmed.
#[derive(Default)]
struct Diff {
    /// Background task computing the diff, if one is running.
    task: Option<Task<()>>,
    /// Presumably set when another diff pass is needed after the current task finishes.
    should_update: bool,
    /// Deleted rows: an anchor in the new text plus a row range, presumably in `old_buffer`.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Ranges of rows inserted by the generation.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2041
// `Codegen` notifies subscribers with `CodegenEvent`s (e.g. `Undone` from `handle_buffer_event`).
impl EventEmitter<CodegenEvent> for Codegen {}
2043
2044impl Codegen {
2045 pub fn new(
2046 buffer: Model<MultiBuffer>,
2047 range: Range<Anchor>,
2048 initial_transaction_id: Option<TransactionId>,
2049 telemetry: Option<Arc<Telemetry>>,
2050 cx: &mut ModelContext<Self>,
2051 ) -> Self {
2052 let snapshot = buffer.read(cx).snapshot(cx);
2053
2054 let (old_buffer, _, _) = buffer
2055 .read(cx)
2056 .range_to_buffer_ranges(range.clone(), cx)
2057 .pop()
2058 .unwrap();
2059 let old_buffer = cx.new_model(|cx| {
2060 let old_buffer = old_buffer.read(cx);
2061 let text = old_buffer.as_rope().clone();
2062 let line_ending = old_buffer.line_ending();
2063 let language = old_buffer.language().cloned();
2064 let language_registry = old_buffer.language_registry();
2065
2066 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2067 buffer.set_language(language, cx);
2068 if let Some(language_registry) = language_registry {
2069 buffer.set_language_registry(language_registry)
2070 }
2071 buffer
2072 });
2073
2074 Self {
2075 buffer: buffer.clone(),
2076 old_buffer,
2077 edit_position: None,
2078 snapshot,
2079 last_equal_ranges: Default::default(),
2080 transformation_transaction_id: None,
2081 status: CodegenStatus::Idle,
2082 generation: Task::ready(()),
2083 diff: Diff::default(),
2084 telemetry,
2085 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2086 initial_transaction_id,
2087 }
2088 }
2089
2090 fn handle_buffer_event(
2091 &mut self,
2092 _buffer: Model<MultiBuffer>,
2093 event: &multi_buffer::Event,
2094 cx: &mut ModelContext<Self>,
2095 ) {
2096 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2097 if self.transformation_transaction_id == Some(*transaction_id) {
2098 self.transformation_transaction_id = None;
2099 self.generation = Task::ready(());
2100 cx.emit(CodegenEvent::Undone);
2101 }
2102 }
2103 }
2104
2105 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2106 &self.last_equal_ranges
2107 }
2108
2109 pub fn count_tokens(
2110 &self,
2111 edit_range: Range<Anchor>,
2112 user_prompt: String,
2113 assistant_panel_context: Option<LanguageModelRequest>,
2114 cx: &AppContext,
2115 ) -> BoxFuture<'static, Result<usize>> {
2116 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2117 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2118 model.count_tokens(request, cx)
2119 } else {
2120 future::ready(Err(anyhow!("no active model"))).boxed()
2121 }
2122 }
2123
2124 pub fn start(
2125 &mut self,
2126 edit_range: Range<Anchor>,
2127 user_prompt: String,
2128 assistant_panel_context: Option<LanguageModelRequest>,
2129 cx: &mut ModelContext<Self>,
2130 ) -> Result<()> {
2131 let model = LanguageModelRegistry::read_global(cx)
2132 .active_model()
2133 .context("no active model")?;
2134
2135 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2136 self.buffer.update(cx, |buffer, cx| {
2137 buffer.undo_transaction(transformation_transaction_id, cx)
2138 });
2139 }
2140
2141 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2142
2143 let telemetry_id = model.telemetry_id();
2144 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2145 .trim()
2146 .to_lowercase()
2147 == "delete"
2148 {
2149 async { Ok(stream::empty().boxed()) }.boxed_local()
2150 } else {
2151 let request =
2152 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
2153 let chunks =
2154 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2155 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2156 };
2157 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2158 Ok(())
2159 }
2160
2161 fn build_request(
2162 &self,
2163 user_prompt: String,
2164 assistant_panel_context: Option<LanguageModelRequest>,
2165 edit_range: Range<Anchor>,
2166 cx: &AppContext,
2167 ) -> LanguageModelRequest {
2168 let buffer = self.buffer.read(cx).snapshot(cx);
2169 let language = buffer.language_at(edit_range.start);
2170 let language_name = if let Some(language) = language.as_ref() {
2171 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2172 None
2173 } else {
2174 Some(language.name())
2175 }
2176 } else {
2177 None
2178 };
2179
2180 // Higher Temperature increases the randomness of model outputs.
2181 // If Markdown or No Language is Known, increase the randomness for more creative output
2182 // If Code, decrease temperature to get more deterministic outputs
2183 let temperature = if let Some(language) = language_name.clone() {
2184 if language.as_ref() == "Markdown" {
2185 1.0
2186 } else {
2187 0.5
2188 }
2189 } else {
2190 1.0
2191 };
2192
2193 let language_name = language_name.as_deref();
2194 let start = buffer.point_to_buffer_offset(edit_range.start);
2195 let end = buffer.point_to_buffer_offset(edit_range.end);
2196 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2197 let (start_buffer, start_buffer_offset) = start;
2198 let (end_buffer, end_buffer_offset) = end;
2199 if start_buffer.remote_id() == end_buffer.remote_id() {
2200 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2201 } else {
2202 panic!("invalid transformation range");
2203 }
2204 } else {
2205 panic!("invalid transformation range");
2206 };
2207 let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
2208
2209 let mut messages = Vec::new();
2210 if let Some(context_request) = assistant_panel_context {
2211 messages = context_request.messages;
2212 }
2213
2214 messages.push(LanguageModelRequestMessage {
2215 role: Role::User,
2216 content: prompt,
2217 });
2218
2219 LanguageModelRequest {
2220 messages,
2221 stop: vec!["|END|>".to_string()],
2222 temperature,
2223 }
2224 }
2225
    /// Streams generated text into the buffer, replacing the text in `edit_range`.
    ///
    /// `stream` resolves to a stream of text chunks. On the background executor, the chunks are
    /// passed through `StripInvalidSpans`, re-indented relative to the selection, and diffed
    /// against the originally selected text with `StreamingDiff`. The resulting hunks are sent
    /// back over a channel and applied to the buffer on the foreground as they arrive. Every
    /// batch of edits is merged into the first transaction of this generation
    /// (`transformation_transaction_id`).
    ///
    /// When the stream ends, `status` becomes `CodegenStatus::Done` or
    /// `CodegenStatus::Error`, and `CodegenEvent::Finished` is emitted. If `self.telemetry` is
    /// set, an assistant event tagged with `model_telemetry_id` is reported. It includes the
    /// latency until the first chunk and the error message, if any.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        // The original selection, which the streamed text is diffed against.
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset in `snapshot` (the pre-generation text) up to which hunks have been applied.
        // Remove/Keep hunks advance it; Insert hunks are placed at it.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // Carries batches of hunks from the background diff task to the foreground.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current output line that has not yet been sent to
                                // the diff.
                                let mut new_text = String::new();
                                // Leading-whitespace length of the first non-blank response
                                // line. Later lines are indented relative to it.
                                let mut base_indent = None;
                                // Leading-whitespace length of the current line, once known.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured to the first item, even an error.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        // Re-indent the line as soon as its first
                                        // non-whitespace character is known.
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                // Keep the response's relative indentation,
                                                // rebased onto the suggested indentation.
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line is inserted at the selection
                                                // start, so the columns before it on that row
                                                // already provide that much indentation.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Once the line's indentation is settled, forward its
                                        // text immediately instead of buffering it.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Telemetry is reported for both successful and failed generations.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks on the foreground as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                // Hunk offsets refer to `snapshot`, so they are anchored there.
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        Hunk::Keep { len } => {
                                            // Unchanged text produces no edit; it is recorded in
                                            // `last_equal_ranges` instead.
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transformation_transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), cx);
                            cx.notify();
                        })?;
                    }

                    // Surface any error from the background diff task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2450
2451 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2452 self.last_equal_ranges.clear();
2453 self.status = CodegenStatus::Done;
2454 self.generation = Task::ready(());
2455 cx.emit(CodegenEvent::Finished);
2456 cx.notify();
2457 }
2458
2459 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2460 self.buffer.update(cx, |buffer, cx| {
2461 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2462 buffer.undo_transaction(transaction_id, cx);
2463 }
2464
2465 if let Some(transaction_id) = self.initial_transaction_id.take() {
2466 buffer.undo_transaction(transaction_id, cx);
2467 }
2468 });
2469 }
2470
2471 fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
2472 if self.diff.task.is_some() {
2473 self.diff.should_update = true;
2474 } else {
2475 self.diff.should_update = false;
2476
2477 let old_snapshot = self.snapshot.clone();
2478 let old_range = edit_range.to_point(&old_snapshot);
2479 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2480 let new_range = edit_range.to_point(&new_snapshot);
2481
2482 self.diff.task = Some(cx.spawn(|this, mut cx| async move {
2483 let (deleted_row_ranges, inserted_row_ranges) = cx
2484 .background_executor()
2485 .spawn(async move {
2486 let old_text = old_snapshot
2487 .text_for_range(
2488 Point::new(old_range.start.row, 0)
2489 ..Point::new(
2490 old_range.end.row,
2491 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
2492 ),
2493 )
2494 .collect::<String>();
2495 let new_text = new_snapshot
2496 .text_for_range(
2497 Point::new(new_range.start.row, 0)
2498 ..Point::new(
2499 new_range.end.row,
2500 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
2501 ),
2502 )
2503 .collect::<String>();
2504
2505 let mut old_row = old_range.start.row;
2506 let mut new_row = new_range.start.row;
2507 let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());
2508
2509 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
2510 let mut inserted_row_ranges = Vec::new();
2511 for change in diff.iter_all_changes() {
2512 let line_count = change.value().lines().count() as u32;
2513 match change.tag() {
2514 similar::ChangeTag::Equal => {
2515 old_row += line_count;
2516 new_row += line_count;
2517 }
2518 similar::ChangeTag::Delete => {
2519 let old_end_row = old_row + line_count - 1;
2520 let new_row =
2521 new_snapshot.anchor_before(Point::new(new_row, 0));
2522
2523 if let Some((_, last_deleted_row_range)) =
2524 deleted_row_ranges.last_mut()
2525 {
2526 if *last_deleted_row_range.end() + 1 == old_row {
2527 *last_deleted_row_range =
2528 *last_deleted_row_range.start()..=old_end_row;
2529 } else {
2530 deleted_row_ranges
2531 .push((new_row, old_row..=old_end_row));
2532 }
2533 } else {
2534 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2535 }
2536
2537 old_row += line_count;
2538 }
2539 similar::ChangeTag::Insert => {
2540 let new_end_row = new_row + line_count - 1;
2541 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2542 let end = new_snapshot.anchor_before(Point::new(
2543 new_end_row,
2544 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2545 ));
2546 inserted_row_ranges.push(start..=end);
2547 new_row += line_count;
2548 }
2549 }
2550 }
2551
2552 (deleted_row_ranges, inserted_row_ranges)
2553 })
2554 .await;
2555
2556 this.update(&mut cx, |this, cx| {
2557 this.diff.deleted_row_ranges = deleted_row_ranges;
2558 this.diff.inserted_row_ranges = inserted_row_ranges;
2559 this.diff.task = None;
2560 if this.diff.should_update {
2561 this.update_diff(edit_range, cx);
2562 }
2563 cx.notify();
2564 })
2565 .ok();
2566 }));
2567 }
2568 }
2569}
2570
/// Stream adapter that removes text the model should not write into the buffer.
///
/// It strips an opening code fence on the first line, the matching closing fence, and
/// `<|CURSOR|>` markers.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// Text received from `stream` but not yet emitted. It is held back while it may still be
    /// the start of a span to strip.
    buffer: String,
    /// True until a line break has been seen in the input.
    first_line: bool,
    /// True when a complete line has been emitted but its trailing newline has not. The newline
    /// is written before the next emitted text.
    line_end: bool,
    /// True when the first line was an opening code fence that was stripped.
    starts_with_code_block: bool,
}
2579
2580impl<T> StripInvalidSpans<T>
2581where
2582 T: Stream<Item = Result<String>>,
2583{
2584 fn new(stream: T) -> Self {
2585 Self {
2586 stream,
2587 stream_done: false,
2588 buffer: String::new(),
2589 first_line: true,
2590 line_end: false,
2591 starts_with_code_block: false,
2592 }
2593 }
2594}
2595
2596impl<T> Stream for StripInvalidSpans<T>
2597where
2598 T: Stream<Item = Result<String>>,
2599{
2600 type Item = Result<String>;
2601
2602 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2603 const CODE_BLOCK_DELIMITER: &str = "```";
2604 const CURSOR_SPAN: &str = "<|CURSOR|>";
2605
2606 let this = unsafe { self.get_unchecked_mut() };
2607 loop {
2608 if !this.stream_done {
2609 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2610 match stream.as_mut().poll_next(cx) {
2611 Poll::Ready(Some(Ok(chunk))) => {
2612 this.buffer.push_str(&chunk);
2613 }
2614 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2615 Poll::Ready(None) => {
2616 this.stream_done = true;
2617 }
2618 Poll::Pending => return Poll::Pending,
2619 }
2620 }
2621
2622 let mut chunk = String::new();
2623 let mut consumed = 0;
2624 if !this.buffer.is_empty() {
2625 let mut lines = this.buffer.split('\n').enumerate().peekable();
2626 while let Some((line_ix, line)) = lines.next() {
2627 if line_ix > 0 {
2628 this.first_line = false;
2629 }
2630
2631 if this.first_line {
2632 let trimmed_line = line.trim();
2633 if lines.peek().is_some() {
2634 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2635 consumed += line.len() + 1;
2636 this.starts_with_code_block = true;
2637 continue;
2638 }
2639 } else if trimmed_line.is_empty()
2640 || prefixes(CODE_BLOCK_DELIMITER)
2641 .any(|prefix| trimmed_line.starts_with(prefix))
2642 {
2643 break;
2644 }
2645 }
2646
2647 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2648 if lines.peek().is_some() {
2649 if this.line_end {
2650 chunk.push('\n');
2651 }
2652
2653 chunk.push_str(&line_without_cursor);
2654 this.line_end = true;
2655 consumed += line.len() + 1;
2656 } else if this.stream_done {
2657 if !this.starts_with_code_block
2658 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2659 {
2660 if this.line_end {
2661 chunk.push('\n');
2662 }
2663
2664 chunk.push_str(&line);
2665 }
2666
2667 consumed += line.len();
2668 } else {
2669 let trimmed_line = line.trim();
2670 if trimmed_line.is_empty()
2671 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2672 || prefixes(CODE_BLOCK_DELIMITER)
2673 .any(|prefix| trimmed_line.ends_with(prefix))
2674 {
2675 break;
2676 } else {
2677 if this.line_end {
2678 chunk.push('\n');
2679 this.line_end = false;
2680 }
2681
2682 chunk.push_str(&line_without_cursor);
2683 consumed += line.len();
2684 }
2685 }
2686 }
2687 }
2688
2689 this.buffer = this.buffer.split_off(consumed);
2690 if !chunk.is_empty() {
2691 return Poll::Ready(Some(Ok(chunk)));
2692 } else if this.stream_done {
2693 return Poll::Ready(None);
2694 }
2695 }
2696 }
2697}
2698
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is excluded, so a three-backtick fence yields one backtick and then
/// two backticks. Prefixes always end on a `char` boundary. An empty or single-character
/// `text` yields nothing, which avoids the `len() - 1` underflow.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2702
2703fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2704 ranges.sort_unstable_by(|a, b| {
2705 a.start
2706 .cmp(&b.start, buffer)
2707 .then_with(|| b.end.cmp(&a.end, buffer))
2708 });
2709
2710 let mut ix = 0;
2711 while ix + 1 < ranges.len() {
2712 let b = ranges[ix + 1].clone();
2713 let a = &mut ranges[ix];
2714 if a.end.cmp(&b.start, buffer).is_gt() {
2715 if a.end.cmp(&b.end, buffer).is_lt() {
2716 a.end = b.end;
2717 }
2718 ranges.remove(ix + 1);
2719 } else {
2720 ix += 1;
2721 }
2722 }
2723}
2724
// Tests for `Codegen`'s streaming and auto-indentation, and for `StripInvalidSpans`.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    // Replaces a multi-line selection and checks that the streamed lines are re-indented to
    // match the surrounding code. The text is sent in random chunk sizes.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range,
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the response in random chunks of 1..=10 bytes, letting the codegen settle
        // after each one.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generates at a cursor placed after existing text on an indented line. The first streamed
    // line continues that text, and later lines are indented relative to it.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generates at a cursor on a whitespace-only line whose indentation is shorter than the
    // suggested one. Unindented response lines are indented to the suggested level.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // With no language set, a selection whose lines use tab indentation keeps tabs in the
    // generated text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    // `StripInvalidSpans` must strip code fences and cursor markers the same way for every
    // chunk size.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Runs `text` through `StripInvalidSpans` at every chunk size from 1 to `text.len()`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into chunks of `size` chars (not bytes), each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // Minimal Rust language definition with an indents query, used to drive
    // `suggested_indents` in the tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}