1use crate::{
2 humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
3 Hunk, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use similar::TextDiff;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
56 cx.set_global(InlineAssistant::new(fs, telemetry));
57}
58
/// Maximum number of prompts kept in [`InlineAssistant`]'s prompt history;
/// once exceeded, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
60
/// Global coordinator for inline assists in every editor.
///
/// Owns every live [`InlineAssist`] and indexes it by id, by the editor it is
/// attached to, and by the group it was created in. It also remembers recently
/// submitted prompts so new prompt editors can offer them as history.
pub struct InlineAssistant {
    /// Id handed out (via `post_inc`) to the next assist that is created.
    next_assist_id: InlineAssistId,
    /// Id handed out (via `post_inc`) to the next assist group that is created.
    next_assist_group_id: InlineAssistGroupId,
    /// Every live assist, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping: assist ids, scroll lock and highlight refresh task.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists that were created by a single `assist`/`suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Recently submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry handle forwarded to each `Codegen`.
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle forwarded to each `PromptEditor`.
    fs: Arc<dyn Fs>,
}
71
/// Lets the assistant be stored as a gpui global (see [`init`]).
impl Global for InlineAssistant {}
73
74impl InlineAssistant {
75 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
76 Self {
77 next_assist_id: InlineAssistId::default(),
78 next_assist_group_id: InlineAssistGroupId::default(),
79 assists: HashMap::default(),
80 assists_by_editor: HashMap::default(),
81 assist_groups: HashMap::default(),
82 prompt_history: VecDeque::default(),
83 telemetry: Some(telemetry),
84 fs,
85 }
86 }
87
88 pub fn assist(
89 &mut self,
90 editor: &View<Editor>,
91 workspace: Option<WeakView<Workspace>>,
92 assistant_panel: Option<&View<AssistantPanel>>,
93 initial_prompt: Option<String>,
94 cx: &mut WindowContext,
95 ) {
96 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
97
98 let mut selections = Vec::<Selection<Point>>::new();
99 let mut newest_selection = None;
100 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
101 if selection.end > selection.start {
102 selection.start.column = 0;
103 // If the selection ends at the start of the line, we don't want to include it.
104 if selection.end.column == 0 {
105 selection.end.row -= 1;
106 }
107 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
108 }
109
110 if let Some(prev_selection) = selections.last_mut() {
111 if selection.start <= prev_selection.end {
112 prev_selection.end = selection.end;
113 continue;
114 }
115 }
116
117 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
118 if selection.id > latest_selection.id {
119 *latest_selection = selection.clone();
120 }
121 selections.push(selection);
122 }
123 let newest_selection = newest_selection.unwrap();
124
125 let mut codegen_ranges = Vec::new();
126 for (excerpt_id, buffer, buffer_range) in
127 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
128 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
129 }))
130 {
131 let start = Anchor {
132 buffer_id: Some(buffer.remote_id()),
133 excerpt_id,
134 text_anchor: buffer.anchor_before(buffer_range.start),
135 };
136 let end = Anchor {
137 buffer_id: Some(buffer.remote_id()),
138 excerpt_id,
139 text_anchor: buffer.anchor_after(buffer_range.end),
140 };
141 codegen_ranges.push(start..end);
142 }
143
144 let assist_group_id = self.next_assist_group_id.post_inc();
145 let prompt_buffer =
146 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
147 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
148
149 let mut assists = Vec::new();
150 let mut assist_to_focus = None;
151 for range in codegen_ranges {
152 let assist_id = self.next_assist_id.post_inc();
153 let codegen = cx.new_model(|cx| {
154 Codegen::new(
155 editor.read(cx).buffer().clone(),
156 range.clone(),
157 None,
158 self.telemetry.clone(),
159 cx,
160 )
161 });
162
163 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
164 let prompt_editor = cx.new_view(|cx| {
165 PromptEditor::new(
166 assist_id,
167 gutter_dimensions.clone(),
168 self.prompt_history.clone(),
169 prompt_buffer.clone(),
170 codegen.clone(),
171 editor,
172 assistant_panel,
173 workspace.clone(),
174 self.fs.clone(),
175 cx,
176 )
177 });
178
179 if assist_to_focus.is_none() {
180 let focus_assist = if newest_selection.reversed {
181 range.start.to_point(&snapshot) == newest_selection.start
182 } else {
183 range.end.to_point(&snapshot) == newest_selection.end
184 };
185 if focus_assist {
186 assist_to_focus = Some(assist_id);
187 }
188 }
189
190 let [prompt_block_id, end_block_id] =
191 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
192
193 assists.push((
194 assist_id,
195 range,
196 prompt_editor,
197 prompt_block_id,
198 end_block_id,
199 ));
200 }
201
202 let editor_assists = self
203 .assists_by_editor
204 .entry(editor.downgrade())
205 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
206 let mut assist_group = InlineAssistGroup::new();
207 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
208 self.assists.insert(
209 assist_id,
210 InlineAssist::new(
211 assist_id,
212 assist_group_id,
213 assistant_panel.is_some(),
214 editor,
215 &prompt_editor,
216 prompt_block_id,
217 end_block_id,
218 range,
219 prompt_editor.read(cx).codegen.clone(),
220 workspace.clone(),
221 cx,
222 ),
223 );
224 assist_group.assist_ids.push(assist_id);
225 editor_assists.assist_ids.push(assist_id);
226 }
227 self.assist_groups.insert(assist_group_id, assist_group);
228
229 if let Some(assist_id) = assist_to_focus {
230 self.focus_assist(assist_id, cx);
231 }
232 }
233
234 #[allow(clippy::too_many_arguments)]
235 pub fn suggest_assist(
236 &mut self,
237 editor: &View<Editor>,
238 mut range: Range<Anchor>,
239 initial_prompt: String,
240 initial_insertion: Option<InitialInsertion>,
241 workspace: Option<WeakView<Workspace>>,
242 assistant_panel: Option<&View<AssistantPanel>>,
243 cx: &mut WindowContext,
244 ) -> InlineAssistId {
245 let assist_group_id = self.next_assist_group_id.post_inc();
246 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
247 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
248
249 let assist_id = self.next_assist_id.post_inc();
250
251 let buffer = editor.read(cx).buffer().clone();
252 {
253 let snapshot = buffer.read(cx).read(cx);
254
255 let mut point_range = range.to_point(&snapshot);
256 if point_range.is_empty() {
257 point_range.start.column = 0;
258 point_range.end.column = 0;
259 } else {
260 point_range.start.column = 0;
261 if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
262 point_range.end.row -= 1;
263 }
264 point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
265 }
266
267 range.start = snapshot.anchor_before(point_range.start);
268 range.end = snapshot.anchor_after(point_range.end);
269 }
270
271 let codegen = cx.new_model(|cx| {
272 Codegen::new(
273 editor.read(cx).buffer().clone(),
274 range.clone(),
275 initial_insertion,
276 self.telemetry.clone(),
277 cx,
278 )
279 });
280
281 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
282 let prompt_editor = cx.new_view(|cx| {
283 PromptEditor::new(
284 assist_id,
285 gutter_dimensions.clone(),
286 self.prompt_history.clone(),
287 prompt_buffer.clone(),
288 codegen.clone(),
289 editor,
290 assistant_panel,
291 workspace.clone(),
292 self.fs.clone(),
293 cx,
294 )
295 });
296
297 let [prompt_block_id, end_block_id] =
298 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
299
300 let editor_assists = self
301 .assists_by_editor
302 .entry(editor.downgrade())
303 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
304
305 let mut assist_group = InlineAssistGroup::new();
306 self.assists.insert(
307 assist_id,
308 InlineAssist::new(
309 assist_id,
310 assist_group_id,
311 assistant_panel.is_some(),
312 editor,
313 &prompt_editor,
314 prompt_block_id,
315 end_block_id,
316 range,
317 prompt_editor.read(cx).codegen.clone(),
318 workspace.clone(),
319 cx,
320 ),
321 );
322 assist_group.assist_ids.push(assist_id);
323 editor_assists.assist_ids.push(assist_id);
324 self.assist_groups.insert(assist_group_id, assist_group);
325 assist_id
326 }
327
328 fn insert_assist_blocks(
329 &self,
330 editor: &View<Editor>,
331 range: &Range<Anchor>,
332 prompt_editor: &View<PromptEditor>,
333 cx: &mut WindowContext,
334 ) -> [CustomBlockId; 2] {
335 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
336 prompt_editor
337 .editor
338 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
339 });
340 let assist_blocks = vec![
341 BlockProperties {
342 style: BlockStyle::Sticky,
343 position: range.start,
344 height: prompt_editor_height,
345 render: build_assist_editor_renderer(prompt_editor),
346 disposition: BlockDisposition::Above,
347 },
348 BlockProperties {
349 style: BlockStyle::Sticky,
350 position: range.end,
351 height: 1,
352 render: Box::new(|cx| {
353 v_flex()
354 .h_full()
355 .w_full()
356 .border_t_1()
357 .border_color(cx.theme().status().info_border)
358 .into_any_element()
359 }),
360 disposition: BlockDisposition::Below,
361 },
362 ];
363
364 editor.update(cx, |editor, cx| {
365 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
366 [block_ids[0], block_ids[1]]
367 })
368 }
369
370 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
371 let assist = &self.assists[&assist_id];
372 let Some(decorations) = assist.decorations.as_ref() else {
373 return;
374 };
375 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
376 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
377
378 assist_group.active_assist_id = Some(assist_id);
379 if assist_group.linked {
380 for assist_id in &assist_group.assist_ids {
381 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
382 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
383 prompt_editor.set_show_cursor_when_unfocused(true, cx)
384 });
385 }
386 }
387 }
388
389 assist
390 .editor
391 .update(cx, |editor, cx| {
392 let scroll_top = editor.scroll_position(cx).y;
393 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
394 let prompt_row = editor
395 .row_for_block(decorations.prompt_block_id, cx)
396 .unwrap()
397 .0 as f32;
398
399 if (scroll_top..scroll_bottom).contains(&prompt_row) {
400 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
401 assist_id,
402 distance_from_top: prompt_row - scroll_top,
403 });
404 } else {
405 editor_assists.scroll_lock = None;
406 }
407 })
408 .ok();
409 }
410
411 fn handle_prompt_editor_focus_out(
412 &mut self,
413 assist_id: InlineAssistId,
414 cx: &mut WindowContext,
415 ) {
416 let assist = &self.assists[&assist_id];
417 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
418 if assist_group.active_assist_id == Some(assist_id) {
419 assist_group.active_assist_id = None;
420 if assist_group.linked {
421 for assist_id in &assist_group.assist_ids {
422 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
423 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
424 prompt_editor.set_show_cursor_when_unfocused(false, cx)
425 });
426 }
427 }
428 }
429 }
430 }
431
432 fn handle_prompt_editor_event(
433 &mut self,
434 prompt_editor: View<PromptEditor>,
435 event: &PromptEditorEvent,
436 cx: &mut WindowContext,
437 ) {
438 let assist_id = prompt_editor.read(cx).id;
439 match event {
440 PromptEditorEvent::StartRequested => {
441 self.start_assist(assist_id, cx);
442 }
443 PromptEditorEvent::StopRequested => {
444 self.stop_assist(assist_id, cx);
445 }
446 PromptEditorEvent::ConfirmRequested => {
447 self.finish_assist(assist_id, false, cx);
448 }
449 PromptEditorEvent::CancelRequested => {
450 self.finish_assist(assist_id, true, cx);
451 }
452 PromptEditorEvent::DismissRequested => {
453 self.dismiss_assist(assist_id, cx);
454 }
455 }
456 }
457
458 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
459 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
460 return;
461 };
462
463 let editor = editor.read(cx);
464 if editor.selections.count() == 1 {
465 let selection = editor.selections.newest::<usize>(cx);
466 let buffer = editor.buffer().read(cx).snapshot(cx);
467 for assist_id in &editor_assists.assist_ids {
468 let assist = &self.assists[assist_id];
469 let assist_range = assist.range.to_offset(&buffer);
470 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
471 {
472 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
473 self.dismiss_assist(*assist_id, cx);
474 } else {
475 self.finish_assist(*assist_id, false, cx);
476 }
477
478 return;
479 }
480 }
481 }
482
483 cx.propagate();
484 }
485
486 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
487 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
488 return;
489 };
490
491 let editor = editor.read(cx);
492 if editor.selections.count() == 1 {
493 let selection = editor.selections.newest::<usize>(cx);
494 let buffer = editor.buffer().read(cx).snapshot(cx);
495 for assist_id in &editor_assists.assist_ids {
496 let assist = &self.assists[assist_id];
497 let assist_range = assist.range.to_offset(&buffer);
498 if assist.decorations.is_some()
499 && assist_range.contains(&selection.start)
500 && assist_range.contains(&selection.end)
501 {
502 self.focus_assist(*assist_id, cx);
503 return;
504 }
505 }
506 }
507
508 cx.propagate();
509 }
510
511 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
512 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
513 for assist_id in editor_assists.assist_ids.clone() {
514 self.finish_assist(assist_id, true, cx);
515 }
516 }
517 }
518
519 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
520 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
521 return;
522 };
523 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
524 return;
525 };
526 let assist = &self.assists[&scroll_lock.assist_id];
527 let Some(decorations) = assist.decorations.as_ref() else {
528 return;
529 };
530
531 editor.update(cx, |editor, cx| {
532 let scroll_position = editor.scroll_position(cx);
533 let target_scroll_top = editor
534 .row_for_block(decorations.prompt_block_id, cx)
535 .unwrap()
536 .0 as f32
537 - scroll_lock.distance_from_top;
538 if target_scroll_top != scroll_position.y {
539 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
540 }
541 });
542 }
543
544 fn handle_editor_event(
545 &mut self,
546 editor: View<Editor>,
547 event: &EditorEvent,
548 cx: &mut WindowContext,
549 ) {
550 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
551 return;
552 };
553
554 match event {
555 EditorEvent::Saved => {
556 for assist_id in editor_assists.assist_ids.clone() {
557 let assist = &self.assists[&assist_id];
558 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
559 self.finish_assist(assist_id, false, cx)
560 }
561 }
562 }
563 EditorEvent::Edited { transaction_id } => {
564 let buffer = editor.read(cx).buffer().read(cx);
565 let edited_ranges =
566 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
567 let snapshot = buffer.snapshot(cx);
568
569 for assist_id in editor_assists.assist_ids.clone() {
570 let assist = &self.assists[&assist_id];
571 if matches!(
572 assist.codegen.read(cx).status,
573 CodegenStatus::Error(_) | CodegenStatus::Done
574 ) {
575 let assist_range = assist.range.to_offset(&snapshot);
576 if edited_ranges
577 .iter()
578 .any(|range| range.overlaps(&assist_range))
579 {
580 self.finish_assist(assist_id, false, cx);
581 }
582 }
583 }
584 }
585 EditorEvent::ScrollPositionChanged { .. } => {
586 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
587 let assist = &self.assists[&scroll_lock.assist_id];
588 if let Some(decorations) = assist.decorations.as_ref() {
589 let distance_from_top = editor.update(cx, |editor, cx| {
590 let scroll_top = editor.scroll_position(cx).y;
591 let prompt_row = editor
592 .row_for_block(decorations.prompt_block_id, cx)
593 .unwrap()
594 .0 as f32;
595 prompt_row - scroll_top
596 });
597
598 if distance_from_top != scroll_lock.distance_from_top {
599 editor_assists.scroll_lock = None;
600 }
601 }
602 }
603 }
604 EditorEvent::SelectionsChanged { .. } => {
605 for assist_id in editor_assists.assist_ids.clone() {
606 let assist = &self.assists[&assist_id];
607 if let Some(decorations) = assist.decorations.as_ref() {
608 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
609 return;
610 }
611 }
612 }
613
614 editor_assists.scroll_lock = None;
615 }
616 _ => {}
617 }
618 }
619
620 fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
621 if let Some(assist) = self.assists.get(&assist_id) {
622 let assist_group_id = assist.group_id;
623 if self.assist_groups[&assist_group_id].linked {
624 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
625 self.finish_assist(assist_id, undo, cx);
626 }
627 return;
628 }
629 }
630
631 self.dismiss_assist(assist_id, cx);
632
633 if let Some(assist) = self.assists.remove(&assist_id) {
634 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
635 {
636 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
637 if entry.get().assist_ids.is_empty() {
638 entry.remove();
639 }
640 }
641
642 if let hash_map::Entry::Occupied(mut entry) =
643 self.assists_by_editor.entry(assist.editor.clone())
644 {
645 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
646 if entry.get().assist_ids.is_empty() {
647 entry.remove();
648 if let Some(editor) = assist.editor.upgrade() {
649 self.update_editor_highlights(&editor, cx);
650 }
651 } else {
652 entry.get().highlight_updates.send(()).ok();
653 }
654 }
655
656 if undo {
657 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
658 }
659 }
660 }
661
662 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
663 let Some(assist) = self.assists.get_mut(&assist_id) else {
664 return false;
665 };
666 let Some(editor) = assist.editor.upgrade() else {
667 return false;
668 };
669 let Some(decorations) = assist.decorations.take() else {
670 return false;
671 };
672
673 editor.update(cx, |editor, cx| {
674 let mut to_remove = decorations.removed_line_block_ids;
675 to_remove.insert(decorations.prompt_block_id);
676 to_remove.insert(decorations.end_block_id);
677 editor.remove_blocks(to_remove, None, cx);
678 });
679
680 if decorations
681 .prompt_editor
682 .focus_handle(cx)
683 .contains_focused(cx)
684 {
685 self.focus_next_assist(assist_id, cx);
686 }
687
688 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
689 if editor_assists
690 .scroll_lock
691 .as_ref()
692 .map_or(false, |lock| lock.assist_id == assist_id)
693 {
694 editor_assists.scroll_lock = None;
695 }
696 editor_assists.highlight_updates.send(()).ok();
697 }
698
699 true
700 }
701
702 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
703 let Some(assist) = self.assists.get(&assist_id) else {
704 return;
705 };
706
707 let assist_group = &self.assist_groups[&assist.group_id];
708 let assist_ix = assist_group
709 .assist_ids
710 .iter()
711 .position(|id| *id == assist_id)
712 .unwrap();
713 let assist_ids = assist_group
714 .assist_ids
715 .iter()
716 .skip(assist_ix + 1)
717 .chain(assist_group.assist_ids.iter().take(assist_ix));
718
719 for assist_id in assist_ids {
720 let assist = &self.assists[assist_id];
721 if assist.decorations.is_some() {
722 self.focus_assist(*assist_id, cx);
723 return;
724 }
725 }
726
727 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
728 }
729
730 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
731 let assist = &self.assists[&assist_id];
732 let Some(editor) = assist.editor.upgrade() else {
733 return;
734 };
735
736 if let Some(decorations) = assist.decorations.as_ref() {
737 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
738 prompt_editor.editor.update(cx, |editor, cx| {
739 editor.focus(cx);
740 editor.select_all(&SelectAll, cx);
741 })
742 });
743 }
744
745 let position = assist.range.start;
746 editor.update(cx, |editor, cx| {
747 editor.change_selections(None, cx, |selections| {
748 selections.select_anchor_ranges([position..position])
749 });
750
751 let mut scroll_target_top;
752 let mut scroll_target_bottom;
753 if let Some(decorations) = assist.decorations.as_ref() {
754 scroll_target_top = editor
755 .row_for_block(decorations.prompt_block_id, cx)
756 .unwrap()
757 .0 as f32;
758 scroll_target_bottom = editor
759 .row_for_block(decorations.end_block_id, cx)
760 .unwrap()
761 .0 as f32;
762 } else {
763 let snapshot = editor.snapshot(cx);
764 let start_row = assist
765 .range
766 .start
767 .to_display_point(&snapshot.display_snapshot)
768 .row();
769 scroll_target_top = start_row.0 as f32;
770 scroll_target_bottom = scroll_target_top + 1.;
771 }
772 scroll_target_top -= editor.vertical_scroll_margin() as f32;
773 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
774
775 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
776 let scroll_top = editor.scroll_position(cx).y;
777 let scroll_bottom = scroll_top + height_in_lines;
778
779 if scroll_target_top < scroll_top {
780 editor.set_scroll_position(point(0., scroll_target_top), cx);
781 } else if scroll_target_bottom > scroll_bottom {
782 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
783 editor
784 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
785 } else {
786 editor.set_scroll_position(point(0., scroll_target_top), cx);
787 }
788 }
789 });
790 }
791
792 fn unlink_assist_group(
793 &mut self,
794 assist_group_id: InlineAssistGroupId,
795 cx: &mut WindowContext,
796 ) -> Vec<InlineAssistId> {
797 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
798 assist_group.linked = false;
799 for assist_id in &assist_group.assist_ids {
800 let assist = self.assists.get_mut(assist_id).unwrap();
801 if let Some(editor_decorations) = assist.decorations.as_ref() {
802 editor_decorations
803 .prompt_editor
804 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
805 }
806 }
807 assist_group.assist_ids.clone()
808 }
809
810 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
811 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
812 assist
813 } else {
814 return;
815 };
816
817 let assist_group_id = assist.group_id;
818 if self.assist_groups[&assist_group_id].linked {
819 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
820 self.start_assist(assist_id, cx);
821 }
822 return;
823 }
824
825 let Some(user_prompt) = assist.user_prompt(cx) else {
826 return;
827 };
828
829 self.prompt_history.retain(|prompt| *prompt != user_prompt);
830 self.prompt_history.push_back(user_prompt.clone());
831 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
832 self.prompt_history.pop_front();
833 }
834
835 let assistant_panel_context = assist.assistant_panel_context(cx);
836
837 assist
838 .codegen
839 .update(cx, |codegen, cx| {
840 codegen.start(
841 assist.range.clone(),
842 user_prompt,
843 assistant_panel_context,
844 cx,
845 )
846 })
847 .log_err();
848 }
849
850 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
851 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
852 assist
853 } else {
854 return;
855 };
856
857 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
858 }
859
860 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
861 let mut gutter_pending_ranges = Vec::new();
862 let mut gutter_transformed_ranges = Vec::new();
863 let mut foreground_ranges = Vec::new();
864 let mut inserted_row_ranges = Vec::new();
865 let empty_assist_ids = Vec::new();
866 let assist_ids = self
867 .assists_by_editor
868 .get(&editor.downgrade())
869 .map_or(&empty_assist_ids, |editor_assists| {
870 &editor_assists.assist_ids
871 });
872
873 for assist_id in assist_ids {
874 if let Some(assist) = self.assists.get(assist_id) {
875 let codegen = assist.codegen.read(cx);
876 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
877
878 gutter_pending_ranges
879 .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
880
881 if let Some(edit_position) = codegen.edit_position {
882 gutter_transformed_ranges.push(assist.range.start..edit_position);
883 }
884
885 if assist.decorations.is_some() {
886 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
887 }
888 }
889 }
890
891 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
892 merge_ranges(&mut foreground_ranges, &snapshot);
893 merge_ranges(&mut gutter_pending_ranges, &snapshot);
894 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
895 editor.update(cx, |editor, cx| {
896 enum GutterPendingRange {}
897 if gutter_pending_ranges.is_empty() {
898 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
899 } else {
900 editor.highlight_gutter::<GutterPendingRange>(
901 &gutter_pending_ranges,
902 |cx| cx.theme().status().info_background,
903 cx,
904 )
905 }
906
907 enum GutterTransformedRange {}
908 if gutter_transformed_ranges.is_empty() {
909 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
910 } else {
911 editor.highlight_gutter::<GutterTransformedRange>(
912 &gutter_transformed_ranges,
913 |cx| cx.theme().status().info,
914 cx,
915 )
916 }
917
918 if foreground_ranges.is_empty() {
919 editor.clear_highlights::<InlineAssist>(cx);
920 } else {
921 editor.highlight_text::<InlineAssist>(
922 foreground_ranges,
923 HighlightStyle {
924 fade_out: Some(0.6),
925 ..Default::default()
926 },
927 cx,
928 );
929 }
930
931 editor.clear_row_highlights::<InlineAssist>();
932 for row_range in inserted_row_ranges {
933 editor.highlight_rows::<InlineAssist>(
934 row_range,
935 Some(cx.theme().status().info_background),
936 false,
937 cx,
938 );
939 }
940 });
941 }
942
943 fn update_editor_blocks(
944 &mut self,
945 editor: &View<Editor>,
946 assist_id: InlineAssistId,
947 cx: &mut WindowContext,
948 ) {
949 let Some(assist) = self.assists.get_mut(&assist_id) else {
950 return;
951 };
952 let Some(decorations) = assist.decorations.as_mut() else {
953 return;
954 };
955
956 let codegen = assist.codegen.read(cx);
957 let old_snapshot = codegen.snapshot.clone();
958 let old_buffer = codegen.old_buffer.clone();
959 let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();
960
961 editor.update(cx, |editor, cx| {
962 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
963 editor.remove_blocks(old_blocks, None, cx);
964
965 let mut new_blocks = Vec::new();
966 for (new_row, old_row_range) in deleted_row_ranges {
967 let (_, buffer_start) = old_snapshot
968 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
969 .unwrap();
970 let (_, buffer_end) = old_snapshot
971 .point_to_buffer_offset(Point::new(
972 *old_row_range.end(),
973 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
974 ))
975 .unwrap();
976
977 let deleted_lines_editor = cx.new_view(|cx| {
978 let multi_buffer = cx.new_model(|_| {
979 MultiBuffer::without_headers(0, language::Capability::ReadOnly)
980 });
981 multi_buffer.update(cx, |multi_buffer, cx| {
982 multi_buffer.push_excerpts(
983 old_buffer.clone(),
984 Some(ExcerptRange {
985 context: buffer_start..buffer_end,
986 primary: None,
987 }),
988 cx,
989 );
990 });
991
992 enum DeletedLines {}
993 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
994 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
995 editor.set_show_wrap_guides(false, cx);
996 editor.set_show_gutter(false, cx);
997 editor.scroll_manager.set_forbid_vertical_scroll(true);
998 editor.set_read_only(true);
999 editor.highlight_rows::<DeletedLines>(
1000 Anchor::min()..=Anchor::max(),
1001 Some(cx.theme().status().deleted_background),
1002 false,
1003 cx,
1004 );
1005 editor
1006 });
1007
1008 let height =
1009 deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
1010 new_blocks.push(BlockProperties {
1011 position: new_row,
1012 height,
1013 style: BlockStyle::Flex,
1014 render: Box::new(move |cx| {
1015 div()
1016 .bg(cx.theme().status().deleted_background)
1017 .size_full()
1018 .h(height as f32 * cx.line_height())
1019 .pl(cx.gutter_dimensions.full_width())
1020 .child(deleted_lines_editor.clone())
1021 .into_any_element()
1022 }),
1023 disposition: BlockDisposition::Above,
1024 });
1025 }
1026
1027 decorations.removed_line_block_ids = editor
1028 .insert_blocks(new_blocks, None, cx)
1029 .into_iter()
1030 .collect();
1031 })
1032 }
1033}
1034
/// Bookkeeping for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Assists currently attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, the editor is re-scrolled after changes so that the locked
    /// assist's prompt block stays at the same distance from the viewport top.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` asks `_update_highlights` to recompute this editor's highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Task that calls `InlineAssistant::update_editor_highlights` on each signal.
    _update_highlights: Task<Result<()>>,
    /// Editor release/change/event observers and registered action handlers.
    _subscriptions: Vec<gpui::Subscription>,
}
1042
/// Pins an assist's prompt block at a fixed vertical offset within the viewport.
struct InlineAssistScrollLock {
    /// The assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Rows between the top of the viewport and the prompt block when the lock was taken.
    distance_from_top: f32,
}
1047
1048impl EditorInlineAssists {
1049 #[allow(clippy::too_many_arguments)]
1050 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1051 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1052 Self {
1053 assist_ids: Vec::new(),
1054 scroll_lock: None,
1055 highlight_updates: highlight_updates_tx,
1056 _update_highlights: cx.spawn(|mut cx| {
1057 let editor = editor.downgrade();
1058 async move {
1059 while let Ok(()) = highlight_updates_rx.changed().await {
1060 let editor = editor.upgrade().context("editor was dropped")?;
1061 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1062 assistant.update_editor_highlights(&editor, cx);
1063 })?;
1064 }
1065 Ok(())
1066 }
1067 }),
1068 _subscriptions: vec![
1069 cx.observe_release(editor, {
1070 let editor = editor.downgrade();
1071 |_, cx| {
1072 InlineAssistant::update_global(cx, |this, cx| {
1073 this.handle_editor_release(editor, cx);
1074 })
1075 }
1076 }),
1077 cx.observe(editor, move |editor, cx| {
1078 InlineAssistant::update_global(cx, |this, cx| {
1079 this.handle_editor_change(editor, cx)
1080 })
1081 }),
1082 cx.subscribe(editor, move |editor, event, cx| {
1083 InlineAssistant::update_global(cx, |this, cx| {
1084 this.handle_editor_event(editor, event, cx)
1085 })
1086 }),
1087 editor.update(cx, |editor, cx| {
1088 let editor_handle = cx.view().downgrade();
1089 editor.register_action(
1090 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1091 InlineAssistant::update_global(cx, |this, cx| {
1092 if let Some(editor) = editor_handle.upgrade() {
1093 this.handle_editor_newline(editor, cx)
1094 }
1095 })
1096 },
1097 )
1098 }),
1099 editor.update(cx, |editor, cx| {
1100 let editor_handle = cx.view().downgrade();
1101 editor.register_action(
1102 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1103 InlineAssistant::update_global(cx, |this, cx| {
1104 if let Some(editor) = editor_handle.upgrade() {
1105 this.handle_editor_cancel(editor, cx)
1106 }
1107 })
1108 },
1109 )
1110 }),
1111 ],
1112 }
1113 }
1114}
1115
/// A set of inline assists that were created together.
struct InlineAssistGroup {
    /// Ids of the assists in this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group's prompt editors are still linked (new groups start linked).
    linked: bool,
    /// The assist in this group currently considered active, if any.
    active_assist_id: Option<InlineAssistId>,
}
1121
1122impl InlineAssistGroup {
1123 fn new() -> Self {
1124 Self {
1125 assist_ids: Vec::new(),
1126 linked: true,
1127 active_assist_id: None,
1128 }
1129 }
1130}
1131
1132fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1133 let editor = editor.clone();
1134 Box::new(move |cx: &mut BlockContext| {
1135 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1136 editor.clone().into_any_element()
1137 })
1138}
1139
/// Blank line(s) that `Codegen::start` inserts at the start of the edit range
/// before generating, so the generated text gets its own line.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    /// Inserts "\n\n" and generates between the two newlines.
    NewlineBefore,
    /// Inserts "\n" and generates at the original offset, ahead of that newline.
    NewlineAfter,
}
1145
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: hands out the current id and advances `self` by one.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1156
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Post-increment: hands out the current id and advances `self` by one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        self.0 = current + 1;
        InlineAssistGroupId(current)
    }
}
1167
/// Requests a `PromptEditor` emits in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start (or restart) generation: the start/restart buttons or `confirm`.
    StartRequested,
    /// Stop an in-progress generation: the stop button, or `cancel` while pending.
    StopRequested,
    /// Accept the finished generation: the confirm button or `confirm`.
    ConfirmRequested,
    /// Cancel the assist: the cancel buttons, or `cancel` when not pending.
    CancelRequested,
    /// Emitted by `confirm` while generation is pending.
    DismissRequested,
}
1175
/// The prompt row of an inline assist.
///
/// It holds the text field for the prompt and the controls for starting,
/// stopping, confirming and cancelling generation.
struct PromptEditor {
    id: InlineAssistId,
    fs: Arc<dyn Fs>,
    /// The prompt text field; `unlink` replaces it with a fresh editor.
    editor: View<Editor>,
    /// Set when the prompt is edited and cleared when codegen reaches `Done` or
    /// `Error`. Decides between the restart and confirm button/action.
    edited_since_done: bool,
    /// Shared with the block renderer, which writes the host editor's gutter
    /// dimensions into it on every render.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` while browsing it with up/down; `None` otherwise.
    prompt_history_ix: Option<usize>,
    /// Text typed before browsing history; restored when moving past the newest entry.
    pending_prompt: String,
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// Delayed token-count task; replaced by every call to `count_tokens`.
    pending_token_count: Task<Result<()>>,
    /// Most recently completed token count, if any.
    token_count: Option<usize>,
    _token_count_subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}
1194
// Lets `PromptEditor` emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1196
/// Renders the prompt row in three parts:
/// - a leading column sized to the host editor's gutter, holding the model
///   selector and any error indicator;
/// - the prompt text field;
/// - a trailing group with the token count and status-dependent buttons.
impl Render for PromptEditor {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // Pick the trailing buttons for the current generation status.
        let buttons = match status {
            // Not started yet: cancel the assist, or start transforming.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Generating: cancel the assist, or interrupt generation while
            // keeping what was produced ("Changes won't be discarded").
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished or failed: cancel, plus either restart (after an error or
            // an edit to the prompt) or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Leading column: as wide as the host gutter plus half its margin.
            .child(
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // On error, add an indicator. Zed Pro rate-limit errors get
                    // a toggle for the rate-limit popover; other errors get an
                    // icon whose tooltip shows the error text.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            // Trailing group: token count followed by the status buttons.
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1377
/// Focus for the prompt editor is the focus of its inner text editor.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1383
1384impl PromptEditor {
1385 const MAX_LINES: u8 = 8;
1386
1387 #[allow(clippy::too_many_arguments)]
1388 fn new(
1389 id: InlineAssistId,
1390 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1391 prompt_history: VecDeque<String>,
1392 prompt_buffer: Model<MultiBuffer>,
1393 codegen: Model<Codegen>,
1394 parent_editor: &View<Editor>,
1395 assistant_panel: Option<&View<AssistantPanel>>,
1396 workspace: Option<WeakView<Workspace>>,
1397 fs: Arc<dyn Fs>,
1398 cx: &mut ViewContext<Self>,
1399 ) -> Self {
1400 let prompt_editor = cx.new_view(|cx| {
1401 let mut editor = Editor::new(
1402 EditorMode::AutoHeight {
1403 max_lines: Self::MAX_LINES as usize,
1404 },
1405 prompt_buffer,
1406 None,
1407 false,
1408 cx,
1409 );
1410 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1411 // Since the prompt editors for all inline assistants are linked,
1412 // always show the cursor (even when it isn't focused) because
1413 // typing in one will make what you typed appear in all of them.
1414 editor.set_show_cursor_when_unfocused(true, cx);
1415 editor.set_placeholder_text("Add a prompt…", cx);
1416 editor
1417 });
1418
1419 let mut token_count_subscriptions = Vec::new();
1420 token_count_subscriptions
1421 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1422 if let Some(assistant_panel) = assistant_panel {
1423 token_count_subscriptions
1424 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1425 }
1426
1427 let mut this = Self {
1428 id,
1429 editor: prompt_editor,
1430 edited_since_done: false,
1431 gutter_dimensions,
1432 prompt_history,
1433 prompt_history_ix: None,
1434 pending_prompt: String::new(),
1435 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1436 editor_subscriptions: Vec::new(),
1437 codegen,
1438 fs,
1439 pending_token_count: Task::ready(Ok(())),
1440 token_count: None,
1441 _token_count_subscriptions: token_count_subscriptions,
1442 workspace,
1443 show_rate_limit_notice: false,
1444 };
1445 this.count_tokens(cx);
1446 this.subscribe_to_editor(cx);
1447 this
1448 }
1449
1450 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1451 self.editor_subscriptions.clear();
1452 self.editor_subscriptions
1453 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1454 }
1455
1456 fn set_show_cursor_when_unfocused(
1457 &mut self,
1458 show_cursor_when_unfocused: bool,
1459 cx: &mut ViewContext<Self>,
1460 ) {
1461 self.editor.update(cx, |editor, cx| {
1462 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1463 });
1464 }
1465
1466 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1467 let prompt = self.prompt(cx);
1468 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1469 self.editor = cx.new_view(|cx| {
1470 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1471 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1472 editor.set_placeholder_text("Add a prompt…", cx);
1473 editor.set_text(prompt, cx);
1474 if focus {
1475 editor.focus(cx);
1476 }
1477 editor
1478 });
1479 self.subscribe_to_editor(cx);
1480 }
1481
1482 fn prompt(&self, cx: &AppContext) -> String {
1483 self.editor.read(cx).text(cx)
1484 }
1485
1486 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1487 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1488 if self.show_rate_limit_notice {
1489 cx.focus_view(&self.editor);
1490 }
1491 cx.notify();
1492 }
1493
1494 fn handle_parent_editor_event(
1495 &mut self,
1496 _: View<Editor>,
1497 event: &EditorEvent,
1498 cx: &mut ViewContext<Self>,
1499 ) {
1500 if let EditorEvent::BufferEdited { .. } = event {
1501 self.count_tokens(cx);
1502 }
1503 }
1504
1505 fn handle_assistant_panel_event(
1506 &mut self,
1507 _: View<AssistantPanel>,
1508 event: &AssistantPanelEvent,
1509 cx: &mut ViewContext<Self>,
1510 ) {
1511 let AssistantPanelEvent::ContextEdited { .. } = event;
1512 self.count_tokens(cx);
1513 }
1514
1515 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1516 let assist_id = self.id;
1517 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1518 cx.background_executor().timer(Duration::from_secs(1)).await;
1519 let token_count = cx
1520 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1521 let assist = inline_assistant
1522 .assists
1523 .get(&assist_id)
1524 .context("assist not found")?;
1525 anyhow::Ok(assist.count_tokens(cx))
1526 })??
1527 .await?;
1528
1529 this.update(&mut cx, |this, cx| {
1530 this.token_count = Some(token_count);
1531 cx.notify();
1532 })
1533 })
1534 }
1535
1536 fn handle_prompt_editor_events(
1537 &mut self,
1538 _: View<Editor>,
1539 event: &EditorEvent,
1540 cx: &mut ViewContext<Self>,
1541 ) {
1542 match event {
1543 EditorEvent::Edited { .. } => {
1544 let prompt = self.editor.read(cx).text(cx);
1545 if self
1546 .prompt_history_ix
1547 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1548 {
1549 self.prompt_history_ix.take();
1550 self.pending_prompt = prompt;
1551 }
1552
1553 self.edited_since_done = true;
1554 cx.notify();
1555 }
1556 EditorEvent::BufferEdited => {
1557 self.count_tokens(cx);
1558 }
1559 EditorEvent::Blurred => {
1560 if self.show_rate_limit_notice {
1561 self.show_rate_limit_notice = false;
1562 cx.notify();
1563 }
1564 }
1565 _ => {}
1566 }
1567 }
1568
1569 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1570 match &self.codegen.read(cx).status {
1571 CodegenStatus::Idle => {
1572 self.editor
1573 .update(cx, |editor, _| editor.set_read_only(false));
1574 }
1575 CodegenStatus::Pending => {
1576 self.editor
1577 .update(cx, |editor, _| editor.set_read_only(true));
1578 }
1579 CodegenStatus::Done => {
1580 self.edited_since_done = false;
1581 self.editor
1582 .update(cx, |editor, _| editor.set_read_only(false));
1583 }
1584 CodegenStatus::Error(error) => {
1585 if cx.has_flag::<ZedPro>()
1586 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1587 && !dismissed_rate_limit_notice()
1588 {
1589 self.show_rate_limit_notice = true;
1590 cx.notify();
1591 }
1592
1593 self.edited_since_done = false;
1594 self.editor
1595 .update(cx, |editor, _| editor.set_read_only(false));
1596 }
1597 }
1598 }
1599
1600 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1601 match &self.codegen.read(cx).status {
1602 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1603 cx.emit(PromptEditorEvent::CancelRequested);
1604 }
1605 CodegenStatus::Pending => {
1606 cx.emit(PromptEditorEvent::StopRequested);
1607 }
1608 }
1609 }
1610
1611 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1612 match &self.codegen.read(cx).status {
1613 CodegenStatus::Idle => {
1614 cx.emit(PromptEditorEvent::StartRequested);
1615 }
1616 CodegenStatus::Pending => {
1617 cx.emit(PromptEditorEvent::DismissRequested);
1618 }
1619 CodegenStatus::Done | CodegenStatus::Error(_) => {
1620 if self.edited_since_done {
1621 cx.emit(PromptEditorEvent::StartRequested);
1622 } else {
1623 cx.emit(PromptEditorEvent::ConfirmRequested);
1624 }
1625 }
1626 }
1627 }
1628
1629 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1630 if let Some(ix) = self.prompt_history_ix {
1631 if ix > 0 {
1632 self.prompt_history_ix = Some(ix - 1);
1633 let prompt = self.prompt_history[ix - 1].as_str();
1634 self.editor.update(cx, |editor, cx| {
1635 editor.set_text(prompt, cx);
1636 editor.move_to_beginning(&Default::default(), cx);
1637 });
1638 }
1639 } else if !self.prompt_history.is_empty() {
1640 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1641 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1642 self.editor.update(cx, |editor, cx| {
1643 editor.set_text(prompt, cx);
1644 editor.move_to_beginning(&Default::default(), cx);
1645 });
1646 }
1647 }
1648
1649 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1650 if let Some(ix) = self.prompt_history_ix {
1651 if ix < self.prompt_history.len() - 1 {
1652 self.prompt_history_ix = Some(ix + 1);
1653 let prompt = self.prompt_history[ix + 1].as_str();
1654 self.editor.update(cx, |editor, cx| {
1655 editor.set_text(prompt, cx);
1656 editor.move_to_end(&Default::default(), cx)
1657 });
1658 } else {
1659 self.prompt_history_ix = None;
1660 let prompt = self.pending_prompt.as_str();
1661 self.editor.update(cx, |editor, cx| {
1662 editor.set_text(prompt, cx);
1663 editor.move_to_end(&Default::default(), cx)
1664 });
1665 }
1666 }
1667 }
1668
1669 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1670 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1671 let token_count = self.token_count?;
1672 let max_token_count = model.max_token_count();
1673
1674 let remaining_tokens = max_token_count as isize - token_count as isize;
1675 let token_count_color = if remaining_tokens <= 0 {
1676 Color::Error
1677 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1678 Color::Warning
1679 } else {
1680 Color::Muted
1681 };
1682
1683 let mut token_count = h_flex()
1684 .id("token_count")
1685 .gap_0p5()
1686 .child(
1687 Label::new(humanize_token_count(token_count))
1688 .size(LabelSize::Small)
1689 .color(token_count_color),
1690 )
1691 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1692 .child(
1693 Label::new(humanize_token_count(max_token_count))
1694 .size(LabelSize::Small)
1695 .color(Color::Muted),
1696 );
1697 if let Some(workspace) = self.workspace.clone() {
1698 token_count = token_count
1699 .tooltip(|cx| {
1700 Tooltip::with_meta(
1701 "Tokens Used by Inline Assistant",
1702 None,
1703 "Click to Open Assistant Panel",
1704 cx,
1705 )
1706 })
1707 .cursor_pointer()
1708 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1709 .on_click(move |_, cx| {
1710 cx.stop_propagation();
1711 workspace
1712 .update(cx, |workspace, cx| {
1713 workspace.focus_panel::<AssistantPanel>(cx)
1714 })
1715 .ok();
1716 });
1717 } else {
1718 token_count = token_count
1719 .cursor_default()
1720 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1721 }
1722
1723 Some(token_count)
1724 }
1725
1726 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1727 let settings = ThemeSettings::get_global(cx);
1728 let text_style = TextStyle {
1729 color: if self.editor.read(cx).read_only(cx) {
1730 cx.theme().colors().text_disabled
1731 } else {
1732 cx.theme().colors().text
1733 },
1734 font_family: settings.ui_font.family.clone(),
1735 font_features: settings.ui_font.features.clone(),
1736 font_fallbacks: settings.ui_font.fallbacks.clone(),
1737 font_size: rems(0.875).into(),
1738 font_weight: settings.ui_font.weight,
1739 line_height: relative(1.3),
1740 ..Default::default()
1741 };
1742 EditorElement::new(
1743 &self.editor,
1744 EditorStyle {
1745 background: cx.theme().colors().editor_background,
1746 local_player: cx.theme().players().local(),
1747 text: text_style,
1748 ..Default::default()
1749 },
1750 )
1751 }
1752
1753 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1754 Popover::new().child(
1755 v_flex()
1756 .occlude()
1757 .p_2()
1758 .child(
1759 Label::new("Out of Tokens")
1760 .size(LabelSize::Small)
1761 .weight(FontWeight::BOLD),
1762 )
1763 .child(Label::new(
1764 "Try Zed Pro for higher limits, a wider range of models, and more.",
1765 ))
1766 .child(
1767 h_flex()
1768 .justify_between()
1769 .child(CheckboxWithLabel::new(
1770 "dont-show-again",
1771 Label::new("Don't show again"),
1772 if dismissed_rate_limit_notice() {
1773 ui::Selection::Selected
1774 } else {
1775 ui::Selection::Unselected
1776 },
1777 |selection, cx| {
1778 let is_dismissed = match selection {
1779 ui::Selection::Unselected => false,
1780 ui::Selection::Indeterminate => return,
1781 ui::Selection::Selected => true,
1782 };
1783
1784 set_rate_limit_notice_dismissed(is_dismissed, cx)
1785 },
1786 ))
1787 .child(
1788 h_flex()
1789 .gap_2()
1790 .child(
1791 Button::new("dismiss", "Dismiss")
1792 .style(ButtonStyle::Transparent)
1793 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1794 )
1795 .child(Button::new("more-info", "More Info").on_click(
1796 |_event, cx| {
1797 cx.dispatch_action(Box::new(
1798 zed_actions::OpenAccountSettings,
1799 ))
1800 },
1801 )),
1802 ),
1803 ),
1804 )
1805 }
1806}
1807
/// Key-value store key; its presence records that the rate-limit notice was dismissed.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1809
1810fn dismissed_rate_limit_notice() -> bool {
1811 db::kvp::KEY_VALUE_STORE
1812 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1813 .log_err()
1814 .map_or(false, |s| s.is_some())
1815}
1816
1817fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1818 db::write_and_log(cx, move || async move {
1819 if is_dismissed {
1820 db::kvp::KEY_VALUE_STORE
1821 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1822 .await
1823 } else {
1824 db::kvp::KEY_VALUE_STORE
1825 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1826 .await
1827 }
1828 })
1829}
1830
/// State for a single inline assist.
struct InlineAssist {
    group_id: InlineAssistGroupId,
    /// The buffer range this assist transforms.
    range: Range<Anchor>,
    editor: WeakView<Editor>,
    /// The prompt editor and editor blocks shown for this assist. Assists with
    /// `None` here are finished as soon as generation completes.
    decorations: Option<InlineAssistDecorations>,
    codegen: Model<Codegen>,
    _subscriptions: Vec<Subscription>,
    workspace: Option<WeakView<Workspace>>,
    /// Whether requests include the active assistant-panel context.
    include_context: bool,
}
1841
1842impl InlineAssist {
1843 #[allow(clippy::too_many_arguments)]
1844 fn new(
1845 assist_id: InlineAssistId,
1846 group_id: InlineAssistGroupId,
1847 include_context: bool,
1848 editor: &View<Editor>,
1849 prompt_editor: &View<PromptEditor>,
1850 prompt_block_id: CustomBlockId,
1851 end_block_id: CustomBlockId,
1852 range: Range<Anchor>,
1853 codegen: Model<Codegen>,
1854 workspace: Option<WeakView<Workspace>>,
1855 cx: &mut WindowContext,
1856 ) -> Self {
1857 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
1858 InlineAssist {
1859 group_id,
1860 include_context,
1861 editor: editor.downgrade(),
1862 decorations: Some(InlineAssistDecorations {
1863 prompt_block_id,
1864 prompt_editor: prompt_editor.clone(),
1865 removed_line_block_ids: HashSet::default(),
1866 end_block_id,
1867 }),
1868 range,
1869 codegen: codegen.clone(),
1870 workspace: workspace.clone(),
1871 _subscriptions: vec![
1872 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
1873 InlineAssistant::update_global(cx, |this, cx| {
1874 this.handle_prompt_editor_focus_in(assist_id, cx)
1875 })
1876 }),
1877 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
1878 InlineAssistant::update_global(cx, |this, cx| {
1879 this.handle_prompt_editor_focus_out(assist_id, cx)
1880 })
1881 }),
1882 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
1883 InlineAssistant::update_global(cx, |this, cx| {
1884 this.handle_prompt_editor_event(prompt_editor, event, cx)
1885 })
1886 }),
1887 cx.observe(&codegen, {
1888 let editor = editor.downgrade();
1889 move |_, cx| {
1890 if let Some(editor) = editor.upgrade() {
1891 InlineAssistant::update_global(cx, |this, cx| {
1892 if let Some(editor_assists) =
1893 this.assists_by_editor.get(&editor.downgrade())
1894 {
1895 editor_assists.highlight_updates.send(()).ok();
1896 }
1897
1898 this.update_editor_blocks(&editor, assist_id, cx);
1899 })
1900 }
1901 }
1902 }),
1903 cx.subscribe(&codegen, move |codegen, event, cx| {
1904 InlineAssistant::update_global(cx, |this, cx| match event {
1905 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
1906 CodegenEvent::Finished => {
1907 let assist = if let Some(assist) = this.assists.get(&assist_id) {
1908 assist
1909 } else {
1910 return;
1911 };
1912
1913 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
1914 if assist.decorations.is_none() {
1915 if let Some(workspace) = assist
1916 .workspace
1917 .as_ref()
1918 .and_then(|workspace| workspace.upgrade())
1919 {
1920 let error = format!("Inline assistant error: {}", error);
1921 workspace.update(cx, |workspace, cx| {
1922 struct InlineAssistantError;
1923
1924 let id =
1925 NotificationId::identified::<InlineAssistantError>(
1926 assist_id.0,
1927 );
1928
1929 workspace.show_toast(Toast::new(id, error), cx);
1930 })
1931 }
1932 }
1933 }
1934
1935 if assist.decorations.is_none() {
1936 this.finish_assist(assist_id, false, cx);
1937 }
1938 }
1939 })
1940 }),
1941 ],
1942 }
1943 }
1944
1945 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
1946 let decorations = self.decorations.as_ref()?;
1947 Some(decorations.prompt_editor.read(cx).prompt(cx))
1948 }
1949
1950 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
1951 if self.include_context {
1952 let workspace = self.workspace.as_ref()?;
1953 let workspace = workspace.upgrade()?.read(cx);
1954 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
1955 Some(
1956 assistant_panel
1957 .read(cx)
1958 .active_context(cx)?
1959 .read(cx)
1960 .to_completion_request(cx),
1961 )
1962 } else {
1963 None
1964 }
1965 }
1966
1967 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
1968 let Some(user_prompt) = self.user_prompt(cx) else {
1969 return future::ready(Err(anyhow!("no user prompt"))).boxed();
1970 };
1971 let assistant_panel_context = self.assistant_panel_context(cx);
1972 self.codegen.read(cx).count_tokens(
1973 self.range.clone(),
1974 user_prompt,
1975 assistant_panel_context,
1976 cx,
1977 )
1978 }
1979}
1980
/// The prompt editor and editor blocks displayed for a decorated inline assist.
struct InlineAssistDecorations {
    /// Block that hosts the prompt editor.
    prompt_block_id: CustomBlockId,
    prompt_editor: View<PromptEditor>,
    /// Blocks showing removed lines; replaced wholesale when blocks are re-inserted.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// The assist's end block (its contents are not visible here).
    end_block_id: CustomBlockId,
}
1987
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation ended; an error, if any, is stored in the codegen's status.
    Finished,
    /// The generation's buffer transaction was undone.
    Undone,
}
1993
/// Generates text for one inline assist from a language-model stream.
///
/// The streaming/editing logic is in `handle_stream`.
pub struct Codegen {
    /// The multi-buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the original buffer's text, taken in `Codegen::new`.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` used to resolve anchors; refreshed by `start` after
    /// an initial insertion.
    snapshot: MultiBufferSnapshot,
    /// Where generated text begins; set by `start`.
    edit_position: Option<Anchor>,
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Buffer transaction of the current generation; cleared when it is undone.
    transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// The running generation; reset to a ready task when its transaction is undone.
    generation: Task<()>,
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    _subscription: gpui::Subscription,
    /// Newline(s) to insert before generating, if any.
    initial_insertion: Option<InitialInsertion>,
}
2008
/// Lifecycle of a [`Codegen`].
enum CodegenStatus {
    /// Not started yet (initial state; the prompt is editable).
    Idle,
    /// Generation in progress (the prompt is read-only).
    Pending,
    /// Generation finished.
    Done,
    /// Generation failed with this error.
    Error(anyhow::Error),
}
2015
/// Diff bookkeeping for a `Codegen` (maintained by code outside this view).
#[derive(Default)]
struct Diff {
    task: Option<Task<()>>,
    should_update: bool,
    /// NOTE(review): presumably an anchor plus the inclusive rows deleted there — confirm.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// NOTE(review): presumably the row ranges inserted by generation — confirm.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2023
// Lets `Codegen` emit `CodegenEvent`s via `cx.emit`.
impl EventEmitter<CodegenEvent> for Codegen {}
2025
2026impl Codegen {
2027 pub fn new(
2028 buffer: Model<MultiBuffer>,
2029 range: Range<Anchor>,
2030 initial_insertion: Option<InitialInsertion>,
2031 telemetry: Option<Arc<Telemetry>>,
2032 cx: &mut ModelContext<Self>,
2033 ) -> Self {
2034 let snapshot = buffer.read(cx).snapshot(cx);
2035
2036 let (old_buffer, _, _) = buffer
2037 .read(cx)
2038 .range_to_buffer_ranges(range.clone(), cx)
2039 .pop()
2040 .unwrap();
2041 let old_buffer = cx.new_model(|cx| {
2042 let old_buffer = old_buffer.read(cx);
2043 let text = old_buffer.as_rope().clone();
2044 let line_ending = old_buffer.line_ending();
2045 let language = old_buffer.language().cloned();
2046 let language_registry = old_buffer.language_registry();
2047
2048 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2049 buffer.set_language(language, cx);
2050 if let Some(language_registry) = language_registry {
2051 buffer.set_language_registry(language_registry)
2052 }
2053 buffer
2054 });
2055
2056 Self {
2057 buffer: buffer.clone(),
2058 old_buffer,
2059 edit_position: None,
2060 snapshot,
2061 last_equal_ranges: Default::default(),
2062 transaction_id: None,
2063 status: CodegenStatus::Idle,
2064 generation: Task::ready(()),
2065 diff: Diff::default(),
2066 telemetry,
2067 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2068 initial_insertion,
2069 }
2070 }
2071
2072 fn handle_buffer_event(
2073 &mut self,
2074 _buffer: Model<MultiBuffer>,
2075 event: &multi_buffer::Event,
2076 cx: &mut ModelContext<Self>,
2077 ) {
2078 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2079 if self.transaction_id == Some(*transaction_id) {
2080 self.transaction_id = None;
2081 self.generation = Task::ready(());
2082 cx.emit(CodegenEvent::Undone);
2083 }
2084 }
2085 }
2086
2087 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2088 &self.last_equal_ranges
2089 }
2090
2091 pub fn count_tokens(
2092 &self,
2093 edit_range: Range<Anchor>,
2094 user_prompt: String,
2095 assistant_panel_context: Option<LanguageModelRequest>,
2096 cx: &AppContext,
2097 ) -> BoxFuture<'static, Result<usize>> {
2098 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2099 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2100 model.count_tokens(request, cx)
2101 } else {
2102 future::ready(Err(anyhow!("no active model"))).boxed()
2103 }
2104 }
2105
    /// Starts (or restarts) generation for `edit_range` with the active model.
    ///
    /// Steps, in order:
    /// 1. Call `self.undo` (presumably reverting a previous generation — its
    ///    body is not visible here).
    /// 2. If an `InitialInsertion` is configured, insert the newline(s) in
    ///    their own buffer transaction and collapse `edit_range` to the point
    ///    where generation begins.
    /// 3. Stream the completion into `handle_stream`. A prompt that equals
    ///    "delete" after trimming and lowercasing streams no text and does not
    ///    query the model.
    ///
    /// # Errors
    ///
    /// Returns an error if no language model is active.
    pub fn start(
        &mut self,
        mut edit_range: Range<Anchor>,
        user_prompt: String,
        assistant_panel_context: Option<LanguageModelRequest>,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        let model = LanguageModelRegistry::read_global(cx)
            .active_model()
            .context("no active model")?;

        self.undo(cx);

        // Handle initial insertion
        // The inserted newline(s) form their own transaction, whose id becomes
        // `self.transaction_id`; `self.snapshot` is refreshed so the new
        // anchors resolve against the edited buffer.
        self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
            self.buffer.update(cx, |buffer, cx| {
                buffer.start_transaction(cx);
                let offset = edit_range.start.to_offset(&self.snapshot);
                let edit_position;
                match initial_insertion {
                    InitialInsertion::NewlineBefore => {
                        buffer.edit([(offset..offset, "\n\n")], None, cx);
                        self.snapshot = buffer.snapshot(cx);
                        // Generate between the two inserted newlines.
                        edit_position = self.snapshot.anchor_after(offset + 1);
                    }
                    InitialInsertion::NewlineAfter => {
                        buffer.edit([(offset..offset, "\n")], None, cx);
                        self.snapshot = buffer.snapshot(cx);
                        // Generate at the original offset, ahead of the newline.
                        edit_position = self.snapshot.anchor_after(offset);
                    }
                }
                self.edit_position = Some(edit_position);
                // Collapse the range to an empty range at the insertion point.
                edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
                buffer.end_transaction(cx)
            })
        } else {
            self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
            None
        };

        let telemetry_id = model.telemetry_id();
        // The literal prompt "delete" skips the model and yields an empty stream.
        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
            .trim()
            .to_lowercase()
            == "delete"
        {
            async { Ok(stream::empty().boxed()) }.boxed_local()
        } else {
            let request =
                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
            let chunks =
                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
            async move { Ok(chunks.await?.boxed()) }.boxed_local()
        };
        self.handle_stream(telemetry_id, edit_range, chunks, cx);
        Ok(())
    }
2163
2164 fn build_request(
2165 &self,
2166 user_prompt: String,
2167 assistant_panel_context: Option<LanguageModelRequest>,
2168 edit_range: Range<Anchor>,
2169 cx: &AppContext,
2170 ) -> LanguageModelRequest {
2171 let buffer = self.buffer.read(cx).snapshot(cx);
2172 let language = buffer.language_at(edit_range.start);
2173 let language_name = if let Some(language) = language.as_ref() {
2174 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2175 None
2176 } else {
2177 Some(language.name())
2178 }
2179 } else {
2180 None
2181 };
2182
2183 // Higher Temperature increases the randomness of model outputs.
2184 // If Markdown or No Language is Known, increase the randomness for more creative output
2185 // If Code, decrease temperature to get more deterministic outputs
2186 let temperature = if let Some(language) = language_name.clone() {
2187 if language.as_ref() == "Markdown" {
2188 1.0
2189 } else {
2190 0.5
2191 }
2192 } else {
2193 1.0
2194 };
2195
2196 let language_name = language_name.as_deref();
2197 let start = buffer.point_to_buffer_offset(edit_range.start);
2198 let end = buffer.point_to_buffer_offset(edit_range.end);
2199 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2200 let (start_buffer, start_buffer_offset) = start;
2201 let (end_buffer, end_buffer_offset) = end;
2202 if start_buffer.remote_id() == end_buffer.remote_id() {
2203 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2204 } else {
2205 panic!("invalid transformation range");
2206 }
2207 } else {
2208 panic!("invalid transformation range");
2209 };
2210 let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
2211
2212 let mut messages = Vec::new();
2213 if let Some(context_request) = assistant_panel_context {
2214 messages = context_request.messages;
2215 }
2216
2217 messages.push(LanguageModelRequestMessage {
2218 role: Role::User,
2219 content: prompt,
2220 });
2221
2222 LanguageModelRequest {
2223 messages,
2224 stop: vec!["|END|>".to_string()],
2225 temperature,
2226 }
2227 }
2228
    /// Streams a model completion into the buffer, replacing the text in `edit_range`.
    ///
    /// `stream` resolves to a stream of text chunks. On the background executor the chunks are
    /// passed through `StripInvalidSpans`, re-indented line by line, and diffed against the text
    /// that was originally in `edit_range` (read from `self.snapshot`). The resulting hunks are
    /// sent over a channel and applied to the buffer on the foreground as they arrive. Every
    /// batch after the first is merged into the first assistant transaction.
    ///
    /// Resets `self.diff` and sets `status` to `Pending` immediately. When generation finishes,
    /// `status` becomes `Done` or `Error(..)` and `CodegenEvent::Finished` is emitted. If
    /// telemetry is present, one assistant event is reported with `model_telemetry_id`, the
    /// latency until the first chunk arrived, and any error message.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        // The original text the streamed output is diffed against.
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset into `snapshot` (pre-generation text) up to which hunks have been applied;
        // advanced by the foreground loop below as Keep/Remove hunks consume original text.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                // An error from the request future surfaces at `chunks?` inside the differ.
                let chunks = stream.await;
                let generate = async {
                    // Bounded channel carrying hunk batches from the background differ to the
                    // foreground loop that applies them.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current output line not yet fed to the differ.
                                let mut new_text = String::new();
                                // Byte length of the leading whitespace of the first non-blank
                                // line in the response; later lines are indented relative to it.
                                let mut base_indent = None;
                                // Leading-whitespace length of the current line, known once its
                                // first non-whitespace character has arrived.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Re-indent: keep this line's offset relative to
                                                // `base_indent`, rebased onto the suggested
                                                // indent. On the first line, subtract the columns
                                                // already present before the selection start.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Once the line's indentation is settled, stream its
                                        // text out immediately rather than waiting for '\n'.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the outcome once the stream is exhausted or has failed.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks to the buffer as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insertions consume no original text, so
                                        // `edit_start` does not advance.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Unchanged text is skipped over and remembered in
                                        // `last_equal_ranges`.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), cx);
                            cx.notify();
                        })?;
                    }

                    // Surface any error from the background differ.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                // Record the final status; ignored if the model has already been dropped.
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2452
2453 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2454 self.last_equal_ranges.clear();
2455 self.status = CodegenStatus::Done;
2456 self.generation = Task::ready(());
2457 cx.emit(CodegenEvent::Finished);
2458 cx.notify();
2459 }
2460
2461 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2462 if let Some(transaction_id) = self.transaction_id.take() {
2463 self.buffer
2464 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
2465 }
2466 }
2467
    /// Recomputes which rows were deleted and inserted inside `edit_range`.
    ///
    /// The comparison is between `self.snapshot` (the text before generation) and the buffer's
    /// current contents. Both sides are expanded to whole lines and line-diffed on the
    /// background executor. The results are stored in `self.diff.deleted_row_ranges` as
    /// old-snapshot row ranges, each paired with an anchor in the new text where it applies.
    /// They are also stored in `self.diff.inserted_row_ranges` as anchor ranges in the new text.
    ///
    /// At most one diff task runs at a time. A call that arrives while a task is running only
    /// sets `should_update`, and the running task re-invokes this method when it finishes.
    fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
        if self.diff.task.is_some() {
            self.diff.should_update = true;
        } else {
            self.diff.should_update = false;

            let old_snapshot = self.snapshot.clone();
            let old_range = edit_range.to_point(&old_snapshot);
            let new_snapshot = self.buffer.read(cx).snapshot(cx);
            let new_range = edit_range.to_point(&new_snapshot);

            self.diff.task = Some(cx.spawn(|this, mut cx| async move {
                let (deleted_row_ranges, inserted_row_ranges) = cx
                    .background_executor()
                    .spawn(async move {
                        // Full lines covering the range, before and after generation.
                        let old_text = old_snapshot
                            .text_for_range(
                                Point::new(old_range.start.row, 0)
                                    ..Point::new(
                                        old_range.end.row,
                                        old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                                    ),
                            )
                            .collect::<String>();
                        let new_text = new_snapshot
                            .text_for_range(
                                Point::new(new_range.start.row, 0)
                                    ..Point::new(
                                        new_range.end.row,
                                        new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                                    ),
                            )
                            .collect::<String>();

                        // Row cursors in the old and new text, advanced as changes are walked.
                        let mut old_row = old_range.start.row;
                        let mut new_row = new_range.start.row;
                        let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                        let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                        let mut inserted_row_ranges = Vec::new();
                        for change in diff.iter_all_changes() {
                            let line_count = change.value().lines().count() as u32;
                            match change.tag() {
                                similar::ChangeTag::Equal => {
                                    old_row += line_count;
                                    new_row += line_count;
                                }
                                similar::ChangeTag::Delete => {
                                    let old_end_row = old_row + line_count - 1;
                                    // Where the deleted rows sit in the new text.
                                    let new_row =
                                        new_snapshot.anchor_before(Point::new(new_row, 0));

                                    // Extend the previous deletion when it ends on the row just
                                    // before this one; otherwise start a new entry.
                                    if let Some((_, last_deleted_row_range)) =
                                        deleted_row_ranges.last_mut()
                                    {
                                        if *last_deleted_row_range.end() + 1 == old_row {
                                            *last_deleted_row_range =
                                                *last_deleted_row_range.start()..=old_end_row;
                                        } else {
                                            deleted_row_ranges
                                                .push((new_row, old_row..=old_end_row));
                                        }
                                    } else {
                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                    }

                                    old_row += line_count;
                                }
                                similar::ChangeTag::Insert => {
                                    let new_end_row = new_row + line_count - 1;
                                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                                    let end = new_snapshot.anchor_before(Point::new(
                                        new_end_row,
                                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
                                    ));
                                    inserted_row_ranges.push(start..=end);
                                    new_row += line_count;
                                }
                            }
                        }

                        (deleted_row_ranges, inserted_row_ranges)
                    })
                    .await;

                this.update(&mut cx, |this, cx| {
                    this.diff.deleted_row_ranges = deleted_row_ranges;
                    this.diff.inserted_row_ranges = inserted_row_ranges;
                    this.diff.task = None;
                    // A request arrived while this task ran; recompute against the latest text.
                    if this.diff.should_update {
                        this.update_diff(edit_range, cx);
                    }
                    cx.notify();
                })
                .ok();
            }));
        }
    }
2566}
2567
/// Stream adapter that cleans raw model output before it is applied to the buffer: it
/// removes `<|CURSOR|>` markers and strips a markdown code fence that wraps the response.
struct StripInvalidSpans<T> {
    /// The wrapped stream of raw text chunks.
    stream: T,
    /// Text received from `stream` that has not yet been yielded.
    buffer: String,
    /// Set once `stream` has returned `None`.
    stream_done: bool,
    /// True while processing the first line of the response.
    first_line: bool,
    /// True when a complete line has been yielded and its trailing newline is still pending.
    line_end: bool,
    /// True once the response's first line was recognised as an opening code fence.
    starts_with_code_block: bool,
}
2576
2577impl<T> StripInvalidSpans<T>
2578where
2579 T: Stream<Item = Result<String>>,
2580{
2581 fn new(stream: T) -> Self {
2582 Self {
2583 stream,
2584 stream_done: false,
2585 buffer: String::new(),
2586 first_line: true,
2587 line_end: false,
2588 starts_with_code_block: false,
2589 }
2590 }
2591}
2592
2593impl<T> Stream for StripInvalidSpans<T>
2594where
2595 T: Stream<Item = Result<String>>,
2596{
2597 type Item = Result<String>;
2598
2599 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2600 const CODE_BLOCK_DELIMITER: &str = "```";
2601 const CURSOR_SPAN: &str = "<|CURSOR|>";
2602
2603 let this = unsafe { self.get_unchecked_mut() };
2604 loop {
2605 if !this.stream_done {
2606 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2607 match stream.as_mut().poll_next(cx) {
2608 Poll::Ready(Some(Ok(chunk))) => {
2609 this.buffer.push_str(&chunk);
2610 }
2611 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2612 Poll::Ready(None) => {
2613 this.stream_done = true;
2614 }
2615 Poll::Pending => return Poll::Pending,
2616 }
2617 }
2618
2619 let mut chunk = String::new();
2620 let mut consumed = 0;
2621 if !this.buffer.is_empty() {
2622 let mut lines = this.buffer.split('\n').enumerate().peekable();
2623 while let Some((line_ix, line)) = lines.next() {
2624 if line_ix > 0 {
2625 this.first_line = false;
2626 }
2627
2628 if this.first_line {
2629 let trimmed_line = line.trim();
2630 if lines.peek().is_some() {
2631 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2632 consumed += line.len() + 1;
2633 this.starts_with_code_block = true;
2634 continue;
2635 }
2636 } else if trimmed_line.is_empty()
2637 || prefixes(CODE_BLOCK_DELIMITER)
2638 .any(|prefix| trimmed_line.starts_with(prefix))
2639 {
2640 break;
2641 }
2642 }
2643
2644 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2645 if lines.peek().is_some() {
2646 if this.line_end {
2647 chunk.push('\n');
2648 }
2649
2650 chunk.push_str(&line_without_cursor);
2651 this.line_end = true;
2652 consumed += line.len() + 1;
2653 } else if this.stream_done {
2654 if !this.starts_with_code_block
2655 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2656 {
2657 if this.line_end {
2658 chunk.push('\n');
2659 }
2660
2661 chunk.push_str(&line);
2662 }
2663
2664 consumed += line.len();
2665 } else {
2666 let trimmed_line = line.trim();
2667 if trimmed_line.is_empty()
2668 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2669 || prefixes(CODE_BLOCK_DELIMITER)
2670 .any(|prefix| trimmed_line.ends_with(prefix))
2671 {
2672 break;
2673 } else {
2674 if this.line_end {
2675 chunk.push('\n');
2676 this.line_end = false;
2677 }
2678
2679 chunk.push_str(&line_without_cursor);
2680 consumed += line.len();
2681 }
2682 }
2683 }
2684 }
2685
2686 this.buffer = this.buffer.split_off(consumed);
2687 if !chunk.is_empty() {
2688 return Poll::Ready(Some(Ok(chunk)));
2689 } else if this.stream_done {
2690 return Poll::Ready(None);
2691 }
2692 }
2693 }
2694}
2695
/// Returns every non-empty proper prefix of `text`, shortest first, excluding `text` itself.
///
/// Used to detect buffered text that begins or ends with a partial delimiter. Prefixes are cut
/// on `char` boundaries, so non-ASCII input never panics. An empty or single-character `text`
/// yields nothing; the previous `text.len() - 1` form underflowed on empty input.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices()
        .skip(1)
        .map(move |(ix, _)| &text[..ix])
}
2699
2700fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2701 ranges.sort_unstable_by(|a, b| {
2702 a.start
2703 .cmp(&b.start, buffer)
2704 .then_with(|| b.end.cmp(&a.end, buffer))
2705 });
2706
2707 let mut ix = 0;
2708 while ix + 1 < ranges.len() {
2709 let b = ranges[ix + 1].clone();
2710 let a = &mut ranges[ix];
2711 if a.end.cmp(&b.start, buffer).is_gt() {
2712 if a.end.cmp(&b.end, buffer).is_lt() {
2713 a.end = b.end;
2714 }
2715 ranges.remove(ix + 1);
2716 } else {
2717 ix += 1;
2718 }
2719 }
2720}
2721
/// Tests for streaming codegen: re-indentation of generated text and stripping of
/// cursor markers / code fences from raw model output.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Streams over-indented replacement text in random-sized chunks for a multi-line
    /// selection and checks it is re-indented to match the surrounding Rust code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range,
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        // Deliberately over-indented; only the relative indentation should survive.
        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts at a cursor placed after existing text on an indented line and checks the
    /// first generated line is not given extra indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts at a cursor inside a line's leading whitespace and checks generated lines
    /// are indented to the suggested level.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a buffer with no language set, a selection indented with tabs keeps tab
    /// indentation in the generated text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// `StripInvalidSpans` output must not depend on how the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Feeds `text` through the adapter at every chunk size from 1 to its full length.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into chunks of `size` chars (not bytes).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Minimal Rust language with an indents query, enough to drive suggested indentation.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}