1use crate::{
2 humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
3 Hunk, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use similar::TextDiff;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
56 cx.set_global(InlineAssistant::new(fs, telemetry));
57}
58
/// Maximum number of distinct prompts kept in `InlineAssistant::prompt_history`;
/// `start_assist` evicts the oldest entry once this is exceeded.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
60
/// Global state for inline (in-editor) assists: every live assist, plus indexes of those
/// assists by editor and by creation group.
pub struct InlineAssistant {
    /// Id handed to the next assist that is created (via `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group that is created (via `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// All live assists, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping: assist ids, scroll lock, highlight-refresh task, subscriptions.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists created together by a single `assist` / `suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Previously submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` this assistant creates.
    telemetry: Option<Arc<Telemetry>>,
    /// Passed to every `PromptEditor` this assistant creates.
    fs: Arc<dyn Fs>,
}
71
// Stored as a gpui global (`init` calls `set_global`) so editor callbacks can reach it through
// `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
73
74impl InlineAssistant {
75 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
76 Self {
77 next_assist_id: InlineAssistId::default(),
78 next_assist_group_id: InlineAssistGroupId::default(),
79 assists: HashMap::default(),
80 assists_by_editor: HashMap::default(),
81 assist_groups: HashMap::default(),
82 prompt_history: VecDeque::default(),
83 telemetry: Some(telemetry),
84 fs,
85 }
86 }
87
88 pub fn assist(
89 &mut self,
90 editor: &View<Editor>,
91 workspace: Option<WeakView<Workspace>>,
92 assistant_panel: Option<&View<AssistantPanel>>,
93 initial_prompt: Option<String>,
94 cx: &mut WindowContext,
95 ) {
96 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
97
98 let mut selections = Vec::<Selection<Point>>::new();
99 let mut newest_selection = None;
100 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
101 if selection.end > selection.start {
102 selection.start.column = 0;
103 // If the selection ends at the start of the line, we don't want to include it.
104 if selection.end.column == 0 {
105 selection.end.row -= 1;
106 }
107 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
108 }
109
110 if let Some(prev_selection) = selections.last_mut() {
111 if selection.start <= prev_selection.end {
112 prev_selection.end = selection.end;
113 continue;
114 }
115 }
116
117 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
118 if selection.id > latest_selection.id {
119 *latest_selection = selection.clone();
120 }
121 selections.push(selection);
122 }
123 let newest_selection = newest_selection.unwrap();
124
125 let mut codegen_ranges = Vec::new();
126 for (excerpt_id, buffer, buffer_range) in
127 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
128 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
129 }))
130 {
131 let start = Anchor {
132 buffer_id: Some(buffer.remote_id()),
133 excerpt_id,
134 text_anchor: buffer.anchor_before(buffer_range.start),
135 };
136 let end = Anchor {
137 buffer_id: Some(buffer.remote_id()),
138 excerpt_id,
139 text_anchor: buffer.anchor_after(buffer_range.end),
140 };
141 codegen_ranges.push(start..end);
142 }
143
144 let assist_group_id = self.next_assist_group_id.post_inc();
145 let prompt_buffer =
146 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
147 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
148
149 let mut assists = Vec::new();
150 let mut assist_to_focus = None;
151 for range in codegen_ranges {
152 let assist_id = self.next_assist_id.post_inc();
153 let codegen = cx.new_model(|cx| {
154 Codegen::new(
155 editor.read(cx).buffer().clone(),
156 range.clone(),
157 None,
158 self.telemetry.clone(),
159 cx,
160 )
161 });
162
163 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
164 let prompt_editor = cx.new_view(|cx| {
165 PromptEditor::new(
166 assist_id,
167 gutter_dimensions.clone(),
168 self.prompt_history.clone(),
169 prompt_buffer.clone(),
170 codegen.clone(),
171 editor,
172 assistant_panel,
173 workspace.clone(),
174 self.fs.clone(),
175 cx,
176 )
177 });
178
179 if assist_to_focus.is_none() {
180 let focus_assist = if newest_selection.reversed {
181 range.start.to_point(&snapshot) == newest_selection.start
182 } else {
183 range.end.to_point(&snapshot) == newest_selection.end
184 };
185 if focus_assist {
186 assist_to_focus = Some(assist_id);
187 }
188 }
189
190 let [prompt_block_id, end_block_id] =
191 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
192
193 assists.push((
194 assist_id,
195 range,
196 prompt_editor,
197 prompt_block_id,
198 end_block_id,
199 ));
200 }
201
202 let editor_assists = self
203 .assists_by_editor
204 .entry(editor.downgrade())
205 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
206 let mut assist_group = InlineAssistGroup::new();
207 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
208 self.assists.insert(
209 assist_id,
210 InlineAssist::new(
211 assist_id,
212 assist_group_id,
213 assistant_panel.is_some(),
214 editor,
215 &prompt_editor,
216 prompt_block_id,
217 end_block_id,
218 range,
219 prompt_editor.read(cx).codegen.clone(),
220 workspace.clone(),
221 cx,
222 ),
223 );
224 assist_group.assist_ids.push(assist_id);
225 editor_assists.assist_ids.push(assist_id);
226 }
227 self.assist_groups.insert(assist_group_id, assist_group);
228
229 if let Some(assist_id) = assist_to_focus {
230 self.focus_assist(assist_id, cx);
231 }
232 }
233
234 #[allow(clippy::too_many_arguments)]
235 pub fn suggest_assist(
236 &mut self,
237 editor: &View<Editor>,
238 mut range: Range<Anchor>,
239 initial_prompt: String,
240 initial_insertion: Option<InitialInsertion>,
241 workspace: Option<WeakView<Workspace>>,
242 assistant_panel: Option<&View<AssistantPanel>>,
243 cx: &mut WindowContext,
244 ) -> InlineAssistId {
245 let assist_group_id = self.next_assist_group_id.post_inc();
246 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
247 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
248
249 let assist_id = self.next_assist_id.post_inc();
250
251 let buffer = editor.read(cx).buffer().clone();
252 {
253 let snapshot = buffer.read(cx).read(cx);
254
255 let mut point_range = range.to_point(&snapshot);
256 if point_range.is_empty() {
257 point_range.start.column = 0;
258 point_range.end.column = 0;
259 } else {
260 point_range.start.column = 0;
261 if point_range.end.row > point_range.start.row && point_range.end.column == 0 {
262 point_range.end.row -= 1;
263 }
264 point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row));
265 }
266
267 range.start = snapshot.anchor_before(point_range.start);
268 range.end = snapshot.anchor_after(point_range.end);
269 }
270
271 let codegen = cx.new_model(|cx| {
272 Codegen::new(
273 editor.read(cx).buffer().clone(),
274 range.clone(),
275 initial_insertion,
276 self.telemetry.clone(),
277 cx,
278 )
279 });
280
281 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
282 let prompt_editor = cx.new_view(|cx| {
283 PromptEditor::new(
284 assist_id,
285 gutter_dimensions.clone(),
286 self.prompt_history.clone(),
287 prompt_buffer.clone(),
288 codegen.clone(),
289 editor,
290 assistant_panel,
291 workspace.clone(),
292 self.fs.clone(),
293 cx,
294 )
295 });
296
297 let [prompt_block_id, end_block_id] =
298 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
299
300 let editor_assists = self
301 .assists_by_editor
302 .entry(editor.downgrade())
303 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
304
305 let mut assist_group = InlineAssistGroup::new();
306 self.assists.insert(
307 assist_id,
308 InlineAssist::new(
309 assist_id,
310 assist_group_id,
311 assistant_panel.is_some(),
312 editor,
313 &prompt_editor,
314 prompt_block_id,
315 end_block_id,
316 range,
317 prompt_editor.read(cx).codegen.clone(),
318 workspace.clone(),
319 cx,
320 ),
321 );
322 assist_group.assist_ids.push(assist_id);
323 editor_assists.assist_ids.push(assist_id);
324 self.assist_groups.insert(assist_group_id, assist_group);
325 assist_id
326 }
327
328 fn insert_assist_blocks(
329 &self,
330 editor: &View<Editor>,
331 range: &Range<Anchor>,
332 prompt_editor: &View<PromptEditor>,
333 cx: &mut WindowContext,
334 ) -> [CustomBlockId; 2] {
335 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
336 prompt_editor
337 .editor
338 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
339 });
340 let assist_blocks = vec![
341 BlockProperties {
342 style: BlockStyle::Sticky,
343 position: range.start,
344 height: prompt_editor_height,
345 render: build_assist_editor_renderer(prompt_editor),
346 disposition: BlockDisposition::Above,
347 },
348 BlockProperties {
349 style: BlockStyle::Sticky,
350 position: range.end,
351 height: 1,
352 render: Box::new(|cx| {
353 v_flex()
354 .h_full()
355 .w_full()
356 .border_t_1()
357 .border_color(cx.theme().status().info_border)
358 .into_any_element()
359 }),
360 disposition: BlockDisposition::Below,
361 },
362 ];
363
364 editor.update(cx, |editor, cx| {
365 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
366 [block_ids[0], block_ids[1]]
367 })
368 }
369
    /// Handles focus entering the prompt editor of `assist_id`.
    ///
    /// Makes this assist the active one in its group. In a linked group, it turns on
    /// "show cursor when unfocused" for every decorated member. If the prompt block's row is
    /// inside the visible viewport, it locks scrolling for this assist; otherwise it clears the
    /// editor's scroll lock. Returns early when the assist has no decorations.
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            // While one linked prompt editor is focused, show the cursor in all of them.
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        // Record the prompt's offset from the viewport top (only when it is visible) so that
        // `handle_editor_change` can keep it there.
        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                // NOTE(review): panics if the prompt block has no row; assumed impossible while
                // decorations exist.
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            .ok();
    }
410
411 fn handle_prompt_editor_focus_out(
412 &mut self,
413 assist_id: InlineAssistId,
414 cx: &mut WindowContext,
415 ) {
416 let assist = &self.assists[&assist_id];
417 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
418 if assist_group.active_assist_id == Some(assist_id) {
419 assist_group.active_assist_id = None;
420 if assist_group.linked {
421 for assist_id in &assist_group.assist_ids {
422 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
423 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
424 prompt_editor.set_show_cursor_when_unfocused(false, cx)
425 });
426 }
427 }
428 }
429 }
430 }
431
432 fn handle_prompt_editor_event(
433 &mut self,
434 prompt_editor: View<PromptEditor>,
435 event: &PromptEditorEvent,
436 cx: &mut WindowContext,
437 ) {
438 let assist_id = prompt_editor.read(cx).id;
439 match event {
440 PromptEditorEvent::StartRequested => {
441 self.start_assist(assist_id, cx);
442 }
443 PromptEditorEvent::StopRequested => {
444 self.stop_assist(assist_id, cx);
445 }
446 PromptEditorEvent::ConfirmRequested => {
447 self.finish_assist(assist_id, false, cx);
448 }
449 PromptEditorEvent::CancelRequested => {
450 self.finish_assist(assist_id, true, cx);
451 }
452 PromptEditorEvent::DismissRequested => {
453 self.dismiss_assist(assist_id, cx);
454 }
455 }
456 }
457
458 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
459 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
460 return;
461 };
462
463 let editor = editor.read(cx);
464 if editor.selections.count() == 1 {
465 let selection = editor.selections.newest::<usize>(cx);
466 let buffer = editor.buffer().read(cx).snapshot(cx);
467 for assist_id in &editor_assists.assist_ids {
468 let assist = &self.assists[assist_id];
469 let assist_range = assist.range.to_offset(&buffer);
470 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
471 {
472 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
473 self.dismiss_assist(*assist_id, cx);
474 } else {
475 self.finish_assist(*assist_id, false, cx);
476 }
477
478 return;
479 }
480 }
481 }
482
483 cx.propagate();
484 }
485
486 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
487 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
488 return;
489 };
490
491 let editor = editor.read(cx);
492 if editor.selections.count() == 1 {
493 let selection = editor.selections.newest::<usize>(cx);
494 let buffer = editor.buffer().read(cx).snapshot(cx);
495 for assist_id in &editor_assists.assist_ids {
496 let assist = &self.assists[assist_id];
497 let assist_range = assist.range.to_offset(&buffer);
498 if assist.decorations.is_some()
499 && assist_range.contains(&selection.start)
500 && assist_range.contains(&selection.end)
501 {
502 self.focus_assist(*assist_id, cx);
503 return;
504 }
505 }
506 }
507
508 cx.propagate();
509 }
510
511 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
512 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
513 for assist_id in editor_assists.assist_ids.clone() {
514 self.finish_assist(assist_id, true, cx);
515 }
516 }
517 }
518
519 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
520 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
521 return;
522 };
523 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
524 return;
525 };
526 let assist = &self.assists[&scroll_lock.assist_id];
527 let Some(decorations) = assist.decorations.as_ref() else {
528 return;
529 };
530
531 editor.update(cx, |editor, cx| {
532 let scroll_position = editor.scroll_position(cx);
533 let target_scroll_top = editor
534 .row_for_block(decorations.prompt_block_id, cx)
535 .unwrap()
536 .0 as f32
537 - scroll_lock.distance_from_top;
538 if target_scroll_top != scroll_position.y {
539 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
540 }
541 });
542 }
543
    /// Reacts to events from an editor that has inline assists:
    ///
    /// - `Saved`: finishes (without undo) every assist whose codegen status is `Done`.
    /// - `Edited`: finishes (without undo) every `Done`/`Error` assist whose range overlaps a
    ///   range edited by that transaction.
    /// - `ScrollPositionChanged`: drops the scroll lock if the locked prompt's offset from the
    ///   viewport top no longer matches the recorded one.
    /// - `SelectionsChanged`: drops the scroll lock unless some assist's prompt editor has focus.
    fn handle_editor_event(
        &mut self,
        editor: View<Editor>,
        event: &EditorEvent,
        cx: &mut WindowContext,
    ) {
        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
            return;
        };

        match event {
            EditorEvent::Saved => {
                // Iterate a copy: `finish_assist` mutates `assist_ids`.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let CodegenStatus::Done = &assist.codegen.read(cx).status {
                        self.finish_assist(assist_id, false, cx)
                    }
                }
            }
            EditorEvent::Edited { transaction_id } => {
                let buffer = editor.read(cx).buffer().read(cx);
                let edited_ranges =
                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
                let snapshot = buffer.snapshot(cx);

                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    // Only settled assists are finished; pending/generating ones are left alone.
                    if matches!(
                        assist.codegen.read(cx).status,
                        CodegenStatus::Error(_) | CodegenStatus::Done
                    ) {
                        let assist_range = assist.range.to_offset(&snapshot);
                        if edited_ranges
                            .iter()
                            .any(|range| range.overlaps(&assist_range))
                        {
                            self.finish_assist(assist_id, false, cx);
                        }
                    }
                }
            }
            EditorEvent::ScrollPositionChanged { .. } => {
                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                    let assist = &self.assists[&scroll_lock.assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        let distance_from_top = editor.update(cx, |editor, cx| {
                            let scroll_top = editor.scroll_position(cx).y;
                            let prompt_row = editor
                                .row_for_block(decorations.prompt_block_id, cx)
                                .unwrap()
                                .0 as f32;
                            prompt_row - scroll_top
                        });

                        // The prompt moved relative to the viewport: release the lock.
                        if distance_from_top != scroll_lock.distance_from_top {
                            editor_assists.scroll_lock = None;
                        }
                    }
                }
            }
            EditorEvent::SelectionsChanged { .. } => {
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                            return;
                        }
                    }
                }

                editor_assists.scroll_lock = None;
            }
            _ => {}
        }
    }
619
    /// Finishes an assist. This removes its decorations and drops it from every index; when
    /// `undo` is true, its `Codegen` is also asked to undo.
    ///
    /// If the assist belongs to a still-linked group, the group is unlinked first and every
    /// member is finished with the same `undo`. When the editor's last assist is removed, the
    /// editor's bookkeeping entry is dropped and its highlights are recomputed (now empty, so they
    /// are cleared). Otherwise a highlight refresh is signalled.
    fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                // After unlinking, each recursive call takes the non-linked path below.
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            }
        }
    }
661
    /// Removes an assist's editor decorations: the prompt block, end block and deleted-line
    /// blocks. The assist stays registered and none of its edits are undone.
    ///
    /// If the prompt editor contained focus, focus moves to the next decorated assist in the
    /// group, or to the editor. A scroll lock held by this assist is released and a highlight
    /// refresh is signalled. Returns `false` (doing nothing) if the assist is unknown, its editor
    /// was dropped, or it was already dismissed.
    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return false;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return false;
        };
        // Taking the decorations here also guarantees `focus_next_assist` below won't pick this
        // assist again.
        let Some(decorations) = assist.decorations.take() else {
            return false;
        };

        editor.update(cx, |editor, cx| {
            let mut to_remove = decorations.removed_line_block_ids;
            to_remove.insert(decorations.prompt_block_id);
            to_remove.insert(decorations.end_block_id);
            editor.remove_blocks(to_remove, None, cx);
        });

        if decorations
            .prompt_editor
            .focus_handle(cx)
            .contains_focused(cx)
        {
            self.focus_next_assist(assist_id, cx);
        }

        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
            if editor_assists
                .scroll_lock
                .as_ref()
                .map_or(false, |lock| lock.assist_id == assist_id)
            {
                editor_assists.scroll_lock = None;
            }
            editor_assists.highlight_updates.send(()).ok();
        }

        true
    }
701
702 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
703 let Some(assist) = self.assists.get(&assist_id) else {
704 return;
705 };
706
707 let assist_group = &self.assist_groups[&assist.group_id];
708 let assist_ix = assist_group
709 .assist_ids
710 .iter()
711 .position(|id| *id == assist_id)
712 .unwrap();
713 let assist_ids = assist_group
714 .assist_ids
715 .iter()
716 .skip(assist_ix + 1)
717 .chain(assist_group.assist_ids.iter().take(assist_ix));
718
719 for assist_id in assist_ids {
720 let assist = &self.assists[assist_id];
721 if assist.decorations.is_some() {
722 self.focus_assist(*assist_id, cx);
723 return;
724 }
725 }
726
727 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
728 }
729
    /// Focuses an assist.
    ///
    /// If decorated, its prompt editor is focused with all text selected. The host editor's
    /// cursor moves to the start of the assist range. The editor then scrolls, if needed, so the
    /// assist is visible, padded by the editor's vertical scroll margin. For a decorated assist
    /// that span runs from the prompt block to the end block; otherwise it is the range's first
    /// display row. Panics if `assist_id` is unknown.
    fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(editor) = assist.editor.upgrade() else {
            return;
        };

        if let Some(decorations) = assist.decorations.as_ref() {
            decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                prompt_editor.editor.update(cx, |editor, cx| {
                    editor.focus(cx);
                    editor.select_all(&SelectAll, cx);
                })
            });
        }

        let position = assist.range.start;
        editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([position..position])
            });

            // Rows (in display coordinates) that should end up on screen.
            let mut scroll_target_top;
            let mut scroll_target_bottom;
            if let Some(decorations) = assist.decorations.as_ref() {
                scroll_target_top = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;
                scroll_target_bottom = editor
                    .row_for_block(decorations.end_block_id, cx)
                    .unwrap()
                    .0 as f32;
            } else {
                let snapshot = editor.snapshot(cx);
                let start_row = assist
                    .range
                    .start
                    .to_display_point(&snapshot.display_snapshot)
                    .row();
                scroll_target_top = start_row.0 as f32;
                scroll_target_bottom = scroll_target_top + 1.;
            }
            scroll_target_top -= editor.vertical_scroll_margin() as f32;
            scroll_target_bottom += editor.vertical_scroll_margin() as f32;

            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + height_in_lines;

            // Scroll up to reveal the top; otherwise scroll down just enough to reveal the
            // bottom, or align to the top when the target is taller than the viewport.
            if scroll_target_top < scroll_top {
                editor.set_scroll_position(point(0., scroll_target_top), cx);
            } else if scroll_target_bottom > scroll_bottom {
                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
                    editor
                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
                } else {
                    editor.set_scroll_position(point(0., scroll_target_top), cx);
                }
            }
        });
    }
791
792 fn unlink_assist_group(
793 &mut self,
794 assist_group_id: InlineAssistGroupId,
795 cx: &mut WindowContext,
796 ) -> Vec<InlineAssistId> {
797 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
798 assist_group.linked = false;
799 for assist_id in &assist_group.assist_ids {
800 let assist = self.assists.get_mut(assist_id).unwrap();
801 if let Some(editor_decorations) = assist.decorations.as_ref() {
802 editor_decorations
803 .prompt_editor
804 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
805 }
806 }
807 assist_group.assist_ids.clone()
808 }
809
810 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
811 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
812 assist
813 } else {
814 return;
815 };
816
817 let assist_group_id = assist.group_id;
818 if self.assist_groups[&assist_group_id].linked {
819 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
820 self.start_assist(assist_id, cx);
821 }
822 return;
823 }
824
825 let Some(user_prompt) = assist.user_prompt(cx) else {
826 return;
827 };
828
829 self.prompt_history.retain(|prompt| *prompt != user_prompt);
830 self.prompt_history.push_back(user_prompt.clone());
831 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
832 self.prompt_history.pop_front();
833 }
834
835 let assistant_panel_context = assist.assistant_panel_context(cx);
836
837 assist
838 .codegen
839 .update(cx, |codegen, cx| {
840 codegen.start(
841 assist.range.clone(),
842 user_prompt,
843 assistant_panel_context,
844 cx,
845 )
846 })
847 .log_err();
848 }
849
850 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
851 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
852 assist
853 } else {
854 return;
855 };
856
857 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
858 }
859
860 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
861 let mut gutter_pending_ranges = Vec::new();
862 let mut gutter_transformed_ranges = Vec::new();
863 let mut foreground_ranges = Vec::new();
864 let mut inserted_row_ranges = Vec::new();
865 let empty_assist_ids = Vec::new();
866 let assist_ids = self
867 .assists_by_editor
868 .get(&editor.downgrade())
869 .map_or(&empty_assist_ids, |editor_assists| {
870 &editor_assists.assist_ids
871 });
872
873 for assist_id in assist_ids {
874 if let Some(assist) = self.assists.get(assist_id) {
875 let codegen = assist.codegen.read(cx);
876 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
877
878 gutter_pending_ranges
879 .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end);
880
881 if let Some(edit_position) = codegen.edit_position {
882 gutter_transformed_ranges.push(assist.range.start..edit_position);
883 }
884
885 if assist.decorations.is_some() {
886 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
887 }
888 }
889 }
890
891 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
892 merge_ranges(&mut foreground_ranges, &snapshot);
893 merge_ranges(&mut gutter_pending_ranges, &snapshot);
894 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
895 editor.update(cx, |editor, cx| {
896 enum GutterPendingRange {}
897 if gutter_pending_ranges.is_empty() {
898 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
899 } else {
900 editor.highlight_gutter::<GutterPendingRange>(
901 &gutter_pending_ranges,
902 |cx| cx.theme().status().info_background,
903 cx,
904 )
905 }
906
907 enum GutterTransformedRange {}
908 if gutter_transformed_ranges.is_empty() {
909 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
910 } else {
911 editor.highlight_gutter::<GutterTransformedRange>(
912 &gutter_transformed_ranges,
913 |cx| cx.theme().status().info,
914 cx,
915 )
916 }
917
918 if foreground_ranges.is_empty() {
919 editor.clear_highlights::<InlineAssist>(cx);
920 } else {
921 editor.highlight_text::<InlineAssist>(
922 foreground_ranges,
923 HighlightStyle {
924 fade_out: Some(0.6),
925 ..Default::default()
926 },
927 cx,
928 );
929 }
930
931 editor.clear_row_highlights::<InlineAssist>();
932 for row_range in inserted_row_ranges {
933 editor.highlight_rows::<InlineAssist>(
934 row_range,
935 Some(cx.theme().status().info_background),
936 false,
937 cx,
938 );
939 }
940 });
941 }
942
    /// Rebuilds the blocks that display lines deleted by an assist's codegen.
    ///
    /// All previously inserted deleted-line blocks are removed. Then, for each entry in the
    /// codegen diff's `deleted_row_ranges`, a block is inserted above the new position. Each
    /// block holds a read-only, gutter-less, non-wrapping editor over those rows of the
    /// pre-edit buffer, drawn with the theme's "deleted" background. Does nothing if the assist
    /// is unknown or undecorated.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Convert the deleted multibuffer rows (in the old snapshot) into an offset range
                // within the old buffer, spanning whole lines.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // One block row per line shown in the deleted-lines editor.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                });
            }

            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1032}
1033
/// Bookkeeping for the inline assists that live in one editor.
struct EditorInlineAssists {
    /// Ids of the assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, scrolling is kept so this assist's prompt stays at a fixed offset from the
    /// viewport top.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Signalled to request a recomputation of this editor's assist highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Task that recomputes highlights each time `highlight_updates` is signalled.
    _update_highlights: Task<Result<()>>,
    /// Editor observers and action registrations, held for as long as this value lives.
    _subscriptions: Vec<gpui::Subscription>,
}
1041
/// Pins an assist's prompt block at a fixed row offset from the top of the editor viewport.
struct InlineAssistScrollLock {
    /// Assist whose prompt block is pinned.
    assist_id: InlineAssistId,
    /// Prompt row minus scroll top, in rows, recorded when the lock was taken.
    distance_from_top: f32,
}
1046
1047impl EditorInlineAssists {
1048 #[allow(clippy::too_many_arguments)]
1049 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1050 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1051 Self {
1052 assist_ids: Vec::new(),
1053 scroll_lock: None,
1054 highlight_updates: highlight_updates_tx,
1055 _update_highlights: cx.spawn(|mut cx| {
1056 let editor = editor.downgrade();
1057 async move {
1058 while let Ok(()) = highlight_updates_rx.changed().await {
1059 let editor = editor.upgrade().context("editor was dropped")?;
1060 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1061 assistant.update_editor_highlights(&editor, cx);
1062 })?;
1063 }
1064 Ok(())
1065 }
1066 }),
1067 _subscriptions: vec![
1068 cx.observe_release(editor, {
1069 let editor = editor.downgrade();
1070 |_, cx| {
1071 InlineAssistant::update_global(cx, |this, cx| {
1072 this.handle_editor_release(editor, cx);
1073 })
1074 }
1075 }),
1076 cx.observe(editor, move |editor, cx| {
1077 InlineAssistant::update_global(cx, |this, cx| {
1078 this.handle_editor_change(editor, cx)
1079 })
1080 }),
1081 cx.subscribe(editor, move |editor, event, cx| {
1082 InlineAssistant::update_global(cx, |this, cx| {
1083 this.handle_editor_event(editor, event, cx)
1084 })
1085 }),
1086 editor.update(cx, |editor, cx| {
1087 let editor_handle = cx.view().downgrade();
1088 editor.register_action(
1089 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1090 InlineAssistant::update_global(cx, |this, cx| {
1091 if let Some(editor) = editor_handle.upgrade() {
1092 this.handle_editor_newline(editor, cx)
1093 }
1094 })
1095 },
1096 )
1097 }),
1098 editor.update(cx, |editor, cx| {
1099 let editor_handle = cx.view().downgrade();
1100 editor.register_action(
1101 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1102 InlineAssistant::update_global(cx, |this, cx| {
1103 if let Some(editor) = editor_handle.upgrade() {
1104 this.handle_editor_cancel(editor, cx)
1105 }
1106 })
1107 },
1108 )
1109 }),
1110 ],
1111 }
1112 }
1113}
1114
1115struct InlineAssistGroup {
1116 assist_ids: Vec<InlineAssistId>,
1117 linked: bool,
1118 active_assist_id: Option<InlineAssistId>,
1119}
1120
1121impl InlineAssistGroup {
1122 fn new() -> Self {
1123 Self {
1124 assist_ids: Vec::new(),
1125 linked: true,
1126 active_assist_id: None,
1127 }
1128 }
1129}
1130
1131fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1132 let editor = editor.clone();
1133 Box::new(move |cx: &mut BlockContext| {
1134 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1135 editor.clone().into_any_element()
1136 })
1137}
1138
/// Whitespace that `Codegen::start` inserts at the start of the edit range before generating.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    /// Insert `"\n\n"` and generate between the two newlines.
    NewlineBefore,
    /// Insert a single `"\n"` and generate at the original offset, i.e. before that newline.
    NewlineAfter,
}
1144
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: yields the current id and leaves `self` pointing at the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let successor = InlineAssistId(self.0 + 1);
        mem::replace(self, successor)
    }
}
1155
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Hands out the current group id, then bumps the counter for the next caller.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        *self = InlineAssistGroupId(current + 1);
        InlineAssistGroupId(current)
    }
}
1166
/// Requests a [`PromptEditor`] emits in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation for the current prompt.
    StartRequested,
    /// Interrupt the in-flight transformation ("Interrupt Transformation").
    StopRequested,
    /// Confirm the assist ("Confirm Assist").
    ConfirmRequested,
    /// Cancel the assist ("Cancel Assist").
    CancelRequested,
    /// Emitted by the `menu::Confirm` handler while generation is pending.
    DismissRequested,
}

/// The prompt UI for one inline assist, rendered inside an editor block.
struct PromptEditor {
    /// The assist this prompt belongs to.
    id: InlineAssistId,
    /// Passed to the model selector.
    fs: Arc<dyn Fs>,
    /// Text editor holding the user's prompt; replaced by `unlink`.
    editor: View<Editor>,
    /// Set on every prompt edit; cleared when codegen reports `Done` or `Error`.
    edited_since_done: bool,
    /// Gutter size of the host editor, written by the block renderer and read when rendering.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts, navigated with `MoveUp` / `MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` being shown, or `None` while editing `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// The prompt being typed before the user started browsing history.
    pending_prompt: String,
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`; rebuilt whenever `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// Debounced token-count computation; replaced on every recount.
    pending_token_count: Task<Result<()>>,
    /// Most recently computed token count, if any.
    token_count: Option<usize>,
    /// Parent-editor / assistant-panel subscriptions that trigger a token recount.
    _token_count_subscriptions: Vec<Subscription>,
    /// Used to open the assistant panel from the token count indicator.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the "Out of Tokens" popover is visible.
    show_rate_limit_notice: bool,
}

impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1195
impl Render for PromptEditor {
    /// Renders one prompt row, made of three parts:
    /// - a gutter-aligned column with the model selector and any error indicator;
    /// - the prompt text editor;
    /// - a trailing token count followed by status-dependent action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The trailing action buttons depend on the current codegen state.
        let buttons = match status {
            // Not started yet: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Generating: cancel, or stop while keeping the changes made so far.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished or failed.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // A failed run, or one whose prompt has since been edited, can only be
                    // restarted. Otherwise the result can be confirmed.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Left column, sized to line up with the host editor's gutter.
            .child(
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator: a toggle for the "Out of Tokens" popover on rate-limit
                    // errors for Zed Pro users, otherwise an icon whose tooltip shows the error.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            // The prompt text itself takes the remaining width.
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            // Trailing token count and action buttons.
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1376
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt text editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1382
1383impl PromptEditor {
1384 const MAX_LINES: u8 = 8;
1385
1386 #[allow(clippy::too_many_arguments)]
1387 fn new(
1388 id: InlineAssistId,
1389 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1390 prompt_history: VecDeque<String>,
1391 prompt_buffer: Model<MultiBuffer>,
1392 codegen: Model<Codegen>,
1393 parent_editor: &View<Editor>,
1394 assistant_panel: Option<&View<AssistantPanel>>,
1395 workspace: Option<WeakView<Workspace>>,
1396 fs: Arc<dyn Fs>,
1397 cx: &mut ViewContext<Self>,
1398 ) -> Self {
1399 let prompt_editor = cx.new_view(|cx| {
1400 let mut editor = Editor::new(
1401 EditorMode::AutoHeight {
1402 max_lines: Self::MAX_LINES as usize,
1403 },
1404 prompt_buffer,
1405 None,
1406 false,
1407 cx,
1408 );
1409 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1410 // Since the prompt editors for all inline assistants are linked,
1411 // always show the cursor (even when it isn't focused) because
1412 // typing in one will make what you typed appear in all of them.
1413 editor.set_show_cursor_when_unfocused(true, cx);
1414 editor.set_placeholder_text("Add a prompt…", cx);
1415 editor
1416 });
1417
1418 let mut token_count_subscriptions = Vec::new();
1419 token_count_subscriptions
1420 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1421 if let Some(assistant_panel) = assistant_panel {
1422 token_count_subscriptions
1423 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1424 }
1425
1426 let mut this = Self {
1427 id,
1428 editor: prompt_editor,
1429 edited_since_done: false,
1430 gutter_dimensions,
1431 prompt_history,
1432 prompt_history_ix: None,
1433 pending_prompt: String::new(),
1434 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1435 editor_subscriptions: Vec::new(),
1436 codegen,
1437 fs,
1438 pending_token_count: Task::ready(Ok(())),
1439 token_count: None,
1440 _token_count_subscriptions: token_count_subscriptions,
1441 workspace,
1442 show_rate_limit_notice: false,
1443 };
1444 this.count_tokens(cx);
1445 this.subscribe_to_editor(cx);
1446 this
1447 }
1448
1449 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1450 self.editor_subscriptions.clear();
1451 self.editor_subscriptions
1452 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1453 }
1454
1455 fn set_show_cursor_when_unfocused(
1456 &mut self,
1457 show_cursor_when_unfocused: bool,
1458 cx: &mut ViewContext<Self>,
1459 ) {
1460 self.editor.update(cx, |editor, cx| {
1461 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1462 });
1463 }
1464
1465 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1466 let prompt = self.prompt(cx);
1467 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1468 self.editor = cx.new_view(|cx| {
1469 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1470 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1471 editor.set_placeholder_text("Add a prompt…", cx);
1472 editor.set_text(prompt, cx);
1473 if focus {
1474 editor.focus(cx);
1475 }
1476 editor
1477 });
1478 self.subscribe_to_editor(cx);
1479 }
1480
1481 fn prompt(&self, cx: &AppContext) -> String {
1482 self.editor.read(cx).text(cx)
1483 }
1484
1485 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1486 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1487 if self.show_rate_limit_notice {
1488 cx.focus_view(&self.editor);
1489 }
1490 cx.notify();
1491 }
1492
1493 fn handle_parent_editor_event(
1494 &mut self,
1495 _: View<Editor>,
1496 event: &EditorEvent,
1497 cx: &mut ViewContext<Self>,
1498 ) {
1499 if let EditorEvent::BufferEdited { .. } = event {
1500 self.count_tokens(cx);
1501 }
1502 }
1503
1504 fn handle_assistant_panel_event(
1505 &mut self,
1506 _: View<AssistantPanel>,
1507 event: &AssistantPanelEvent,
1508 cx: &mut ViewContext<Self>,
1509 ) {
1510 let AssistantPanelEvent::ContextEdited { .. } = event;
1511 self.count_tokens(cx);
1512 }
1513
1514 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1515 let assist_id = self.id;
1516 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1517 cx.background_executor().timer(Duration::from_secs(1)).await;
1518 let token_count = cx
1519 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1520 let assist = inline_assistant
1521 .assists
1522 .get(&assist_id)
1523 .context("assist not found")?;
1524 anyhow::Ok(assist.count_tokens(cx))
1525 })??
1526 .await?;
1527
1528 this.update(&mut cx, |this, cx| {
1529 this.token_count = Some(token_count);
1530 cx.notify();
1531 })
1532 })
1533 }
1534
1535 fn handle_prompt_editor_events(
1536 &mut self,
1537 _: View<Editor>,
1538 event: &EditorEvent,
1539 cx: &mut ViewContext<Self>,
1540 ) {
1541 match event {
1542 EditorEvent::Edited { .. } => {
1543 let prompt = self.editor.read(cx).text(cx);
1544 if self
1545 .prompt_history_ix
1546 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1547 {
1548 self.prompt_history_ix.take();
1549 self.pending_prompt = prompt;
1550 }
1551
1552 self.edited_since_done = true;
1553 cx.notify();
1554 }
1555 EditorEvent::BufferEdited => {
1556 self.count_tokens(cx);
1557 }
1558 EditorEvent::Blurred => {
1559 if self.show_rate_limit_notice {
1560 self.show_rate_limit_notice = false;
1561 cx.notify();
1562 }
1563 }
1564 _ => {}
1565 }
1566 }
1567
1568 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1569 match &self.codegen.read(cx).status {
1570 CodegenStatus::Idle => {
1571 self.editor
1572 .update(cx, |editor, _| editor.set_read_only(false));
1573 }
1574 CodegenStatus::Pending => {
1575 self.editor
1576 .update(cx, |editor, _| editor.set_read_only(true));
1577 }
1578 CodegenStatus::Done => {
1579 self.edited_since_done = false;
1580 self.editor
1581 .update(cx, |editor, _| editor.set_read_only(false));
1582 }
1583 CodegenStatus::Error(error) => {
1584 if cx.has_flag::<ZedPro>()
1585 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1586 && !dismissed_rate_limit_notice()
1587 {
1588 self.show_rate_limit_notice = true;
1589 cx.notify();
1590 }
1591
1592 self.edited_since_done = false;
1593 self.editor
1594 .update(cx, |editor, _| editor.set_read_only(false));
1595 }
1596 }
1597 }
1598
1599 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1600 match &self.codegen.read(cx).status {
1601 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1602 cx.emit(PromptEditorEvent::CancelRequested);
1603 }
1604 CodegenStatus::Pending => {
1605 cx.emit(PromptEditorEvent::StopRequested);
1606 }
1607 }
1608 }
1609
1610 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1611 match &self.codegen.read(cx).status {
1612 CodegenStatus::Idle => {
1613 cx.emit(PromptEditorEvent::StartRequested);
1614 }
1615 CodegenStatus::Pending => {
1616 cx.emit(PromptEditorEvent::DismissRequested);
1617 }
1618 CodegenStatus::Done | CodegenStatus::Error(_) => {
1619 if self.edited_since_done {
1620 cx.emit(PromptEditorEvent::StartRequested);
1621 } else {
1622 cx.emit(PromptEditorEvent::ConfirmRequested);
1623 }
1624 }
1625 }
1626 }
1627
1628 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1629 if let Some(ix) = self.prompt_history_ix {
1630 if ix > 0 {
1631 self.prompt_history_ix = Some(ix - 1);
1632 let prompt = self.prompt_history[ix - 1].as_str();
1633 self.editor.update(cx, |editor, cx| {
1634 editor.set_text(prompt, cx);
1635 editor.move_to_beginning(&Default::default(), cx);
1636 });
1637 }
1638 } else if !self.prompt_history.is_empty() {
1639 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1640 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1641 self.editor.update(cx, |editor, cx| {
1642 editor.set_text(prompt, cx);
1643 editor.move_to_beginning(&Default::default(), cx);
1644 });
1645 }
1646 }
1647
1648 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1649 if let Some(ix) = self.prompt_history_ix {
1650 if ix < self.prompt_history.len() - 1 {
1651 self.prompt_history_ix = Some(ix + 1);
1652 let prompt = self.prompt_history[ix + 1].as_str();
1653 self.editor.update(cx, |editor, cx| {
1654 editor.set_text(prompt, cx);
1655 editor.move_to_end(&Default::default(), cx)
1656 });
1657 } else {
1658 self.prompt_history_ix = None;
1659 let prompt = self.pending_prompt.as_str();
1660 self.editor.update(cx, |editor, cx| {
1661 editor.set_text(prompt, cx);
1662 editor.move_to_end(&Default::default(), cx)
1663 });
1664 }
1665 }
1666 }
1667
1668 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1669 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1670 let token_count = self.token_count?;
1671 let max_token_count = model.max_token_count();
1672
1673 let remaining_tokens = max_token_count as isize - token_count as isize;
1674 let token_count_color = if remaining_tokens <= 0 {
1675 Color::Error
1676 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1677 Color::Warning
1678 } else {
1679 Color::Muted
1680 };
1681
1682 let mut token_count = h_flex()
1683 .id("token_count")
1684 .gap_0p5()
1685 .child(
1686 Label::new(humanize_token_count(token_count))
1687 .size(LabelSize::Small)
1688 .color(token_count_color),
1689 )
1690 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1691 .child(
1692 Label::new(humanize_token_count(max_token_count))
1693 .size(LabelSize::Small)
1694 .color(Color::Muted),
1695 );
1696 if let Some(workspace) = self.workspace.clone() {
1697 token_count = token_count
1698 .tooltip(|cx| {
1699 Tooltip::with_meta(
1700 "Tokens Used by Inline Assistant",
1701 None,
1702 "Click to Open Assistant Panel",
1703 cx,
1704 )
1705 })
1706 .cursor_pointer()
1707 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1708 .on_click(move |_, cx| {
1709 cx.stop_propagation();
1710 workspace
1711 .update(cx, |workspace, cx| {
1712 workspace.focus_panel::<AssistantPanel>(cx)
1713 })
1714 .ok();
1715 });
1716 } else {
1717 token_count = token_count
1718 .cursor_default()
1719 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1720 }
1721
1722 Some(token_count)
1723 }
1724
1725 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1726 let settings = ThemeSettings::get_global(cx);
1727 let text_style = TextStyle {
1728 color: if self.editor.read(cx).read_only(cx) {
1729 cx.theme().colors().text_disabled
1730 } else {
1731 cx.theme().colors().text
1732 },
1733 font_family: settings.ui_font.family.clone(),
1734 font_features: settings.ui_font.features.clone(),
1735 font_fallbacks: settings.ui_font.fallbacks.clone(),
1736 font_size: rems(0.875).into(),
1737 font_weight: settings.ui_font.weight,
1738 line_height: relative(1.3),
1739 ..Default::default()
1740 };
1741 EditorElement::new(
1742 &self.editor,
1743 EditorStyle {
1744 background: cx.theme().colors().editor_background,
1745 local_player: cx.theme().players().local(),
1746 text: text_style,
1747 ..Default::default()
1748 },
1749 )
1750 }
1751
1752 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1753 Popover::new().child(
1754 v_flex()
1755 .occlude()
1756 .p_2()
1757 .child(
1758 Label::new("Out of Tokens")
1759 .size(LabelSize::Small)
1760 .weight(FontWeight::BOLD),
1761 )
1762 .child(Label::new(
1763 "Try Zed Pro for higher limits, a wider range of models, and more.",
1764 ))
1765 .child(
1766 h_flex()
1767 .justify_between()
1768 .child(CheckboxWithLabel::new(
1769 "dont-show-again",
1770 Label::new("Don't show again"),
1771 if dismissed_rate_limit_notice() {
1772 ui::Selection::Selected
1773 } else {
1774 ui::Selection::Unselected
1775 },
1776 |selection, cx| {
1777 let is_dismissed = match selection {
1778 ui::Selection::Unselected => false,
1779 ui::Selection::Indeterminate => return,
1780 ui::Selection::Selected => true,
1781 };
1782
1783 set_rate_limit_notice_dismissed(is_dismissed, cx)
1784 },
1785 ))
1786 .child(
1787 h_flex()
1788 .gap_2()
1789 .child(
1790 Button::new("dismiss", "Dismiss")
1791 .style(ButtonStyle::Transparent)
1792 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1793 )
1794 .child(Button::new("more-info", "More Info").on_click(
1795 |_event, cx| {
1796 cx.dispatch_action(Box::new(
1797 zed_actions::OpenAccountSettings,
1798 ))
1799 },
1800 )),
1801 ),
1802 ),
1803 )
1804 }
1805}
1806
1807const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1808
1809fn dismissed_rate_limit_notice() -> bool {
1810 db::kvp::KEY_VALUE_STORE
1811 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1812 .log_err()
1813 .map_or(false, |s| s.is_some())
1814}
1815
1816fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1817 db::write_and_log(cx, move || async move {
1818 if is_dismissed {
1819 db::kvp::KEY_VALUE_STORE
1820 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1821 .await
1822 } else {
1823 db::kvp::KEY_VALUE_STORE
1824 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1825 .await
1826 }
1827 })
1828}
1829
/// State for one inline assist applied to a range of an editor.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range of the editor's buffer being transformed.
    range: Range<Anchor>,
    /// The editor hosting the assist (weak so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// Prompt/end blocks and the prompt editor; `None` once the decorations are gone.
    decorations: Option<InlineAssistDecorations>,
    codegen: Model<Codegen>,
    /// Focus, prompt-event and codegen subscriptions created in `InlineAssist::new`.
    _subscriptions: Vec<Subscription>,
    /// Used to read assistant-panel context and to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the active assistant-panel context is included in requests.
    include_context: bool,
}
1840
impl InlineAssist {
    /// Builds the state for assist `assist_id` over `range` in `editor`.
    ///
    /// The prompt block, end block and prompt editor are recorded as decorations. The
    /// subscriptions created here do the following:
    /// - forward prompt-editor focus-in/focus-out and prompt events to the global
    ///   [`InlineAssistant`];
    /// - on every codegen change, signal the editor's highlight refresh and update this
    ///   assist's blocks;
    /// - on `CodegenEvent::Undone`, finish the assist;
    /// - on `CodegenEvent::Finished` for an assist without decorations, show any codegen error
    ///   as a workspace toast and finish the assist.
    // All arguments are required construction inputs supplied by the caller.
    #[allow(clippy::too_many_arguments)]
    fn new(
        assist_id: InlineAssistId,
        group_id: InlineAssistGroupId,
        include_context: bool,
        editor: &View<Editor>,
        prompt_editor: &View<PromptEditor>,
        prompt_block_id: CustomBlockId,
        end_block_id: CustomBlockId,
        range: Range<Anchor>,
        codegen: Model<Codegen>,
        workspace: Option<WeakView<Workspace>>,
        cx: &mut WindowContext,
    ) -> Self {
        let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
        InlineAssist {
            group_id,
            include_context,
            editor: editor.downgrade(),
            decorations: Some(InlineAssistDecorations {
                prompt_block_id,
                prompt_editor: prompt_editor.clone(),
                removed_line_block_ids: HashSet::default(),
                end_block_id,
            }),
            range,
            codegen: codegen.clone(),
            workspace: workspace.clone(),
            _subscriptions: vec![
                cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_in(assist_id, cx)
                    })
                }),
                cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_focus_out(assist_id, cx)
                    })
                }),
                cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| {
                        this.handle_prompt_editor_event(prompt_editor, event, cx)
                    })
                }),
                cx.observe(&codegen, {
                    let editor = editor.downgrade();
                    move |_, cx| {
                        // Skip silently if the editor has already been dropped.
                        if let Some(editor) = editor.upgrade() {
                            InlineAssistant::update_global(cx, |this, cx| {
                                // Ask the editor's background task to refresh highlights.
                                if let Some(editor_assists) =
                                    this.assists_by_editor.get(&editor.downgrade())
                                {
                                    editor_assists.highlight_updates.send(()).ok();
                                }

                                this.update_editor_blocks(&editor, assist_id, cx);
                            })
                        }
                    }
                }),
                cx.subscribe(&codegen, move |codegen, event, cx| {
                    InlineAssistant::update_global(cx, |this, cx| match event {
                        CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
                        CodegenEvent::Finished => {
                            // The assist may already have been removed.
                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
                                assist
                            } else {
                                return;
                            };

                            // Without decorations there is no prompt UI to show the error
                            // in, so surface it as a workspace toast instead.
                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
                                if assist.decorations.is_none() {
                                    if let Some(workspace) = assist
                                        .workspace
                                        .as_ref()
                                        .and_then(|workspace| workspace.upgrade())
                                    {
                                        let error = format!("Inline assistant error: {}", error);
                                        workspace.update(cx, |workspace, cx| {
                                            struct InlineAssistantError;

                                            let id =
                                                NotificationId::identified::<InlineAssistantError>(
                                                    assist_id.0,
                                                );

                                            workspace.show_toast(Toast::new(id, error), cx);
                                        })
                                    }
                                }
                            }

                            if assist.decorations.is_none() {
                                this.finish_assist(assist_id, false, cx);
                            }
                        }
                    })
                }),
            ],
        }
    }

    /// Returns the prompt text, or `None` once the decorations (and prompt editor) are gone.
    fn user_prompt(&self, cx: &AppContext) -> Option<String> {
        let decorations = self.decorations.as_ref()?;
        Some(decorations.prompt_editor.read(cx).prompt(cx))
    }

    /// Returns the active assistant-panel context as a completion request.
    ///
    /// Returns `None` when `include_context` is false, or when the workspace, panel or active
    /// context is unavailable.
    fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
        if self.include_context {
            let workspace = self.workspace.as_ref()?;
            let workspace = workspace.upgrade()?.read(cx);
            let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
            Some(
                assistant_panel
                    .read(cx)
                    .active_context(cx)?
                    .read(cx)
                    .to_completion_request(cx),
            )
        } else {
            None
        }
    }

    /// Counts the tokens of the request this assist would send.
    ///
    /// Resolves to an error if there is no user prompt.
    pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
        let Some(user_prompt) = self.user_prompt(cx) else {
            return future::ready(Err(anyhow!("no user prompt"))).boxed();
        };
        let assistant_panel_context = self.assistant_panel_context(cx);
        self.codegen.read(cx).count_tokens(
            self.range.clone(),
            user_prompt,
            assistant_panel_context,
            cx,
        )
    }
}
1979
/// Editor blocks and views that make up an assist's on-screen UI.
struct InlineAssistDecorations {
    /// Block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    prompt_editor: View<PromptEditor>,
    /// Blocks showing removed lines; replaced wholesale when the blocks are updated.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Block marking the end of the assist.
    end_block_id: CustomBlockId,
}

/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    // NOTE(review): emitted outside this view, presumably when a generation stream ends — confirm.
    Finished,
    /// Emitted when this codegen's own transaction is undone in the buffer.
    Undone,
}
1992
/// Drives an LLM transformation of a range of a multibuffer and tracks its result.
pub struct Codegen {
    /// The multibuffer being edited.
    buffer: Model<MultiBuffer>,
    /// Detached copy of the original buffer's text, line ending and language, taken in `new`.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer`, taken in `new` and refreshed when `start` inserts newlines.
    snapshot: MultiBufferSnapshot,
    /// Where generated text is inserted; set by `start`.
    edit_position: Option<Anchor>,
    /// Exposed via `last_equal_ranges()`; populated outside this view.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction created by this codegen, cleared when that transaction is undone.
    transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// In-flight generation; replacing it with `Task::ready(())` drops the old task.
    generation: Task<()>,
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to detect undo of our transaction.
    _subscription: gpui::Subscription,
    /// Whitespace to insert before generating, if any.
    initial_insertion: Option<InitialInsertion>,
}

/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// Not started.
    Idle,
    /// Generation in progress.
    Pending,
    /// Generation completed.
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}

// NOTE(review): the diff is computed outside this view; field meanings below are inferred
// from names and types — confirm.
#[derive(Default)]
struct Diff {
    /// Background diff computation, if one is running.
    task: Option<Task<()>>,
    /// Presumably set when the diff needs recomputing.
    should_update: bool,
    /// Presumably removed rows (in `old_buffer`) paired with the anchor they were removed at.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Presumably row ranges of newly inserted text in `buffer`.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}

impl EventEmitter<CodegenEvent> for Codegen {}
2024
2025impl Codegen {
2026 pub fn new(
2027 buffer: Model<MultiBuffer>,
2028 range: Range<Anchor>,
2029 initial_insertion: Option<InitialInsertion>,
2030 telemetry: Option<Arc<Telemetry>>,
2031 cx: &mut ModelContext<Self>,
2032 ) -> Self {
2033 let snapshot = buffer.read(cx).snapshot(cx);
2034
2035 let (old_buffer, _, _) = buffer
2036 .read(cx)
2037 .range_to_buffer_ranges(range.clone(), cx)
2038 .pop()
2039 .unwrap();
2040 let old_buffer = cx.new_model(|cx| {
2041 let old_buffer = old_buffer.read(cx);
2042 let text = old_buffer.as_rope().clone();
2043 let line_ending = old_buffer.line_ending();
2044 let language = old_buffer.language().cloned();
2045 let language_registry = old_buffer.language_registry();
2046
2047 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2048 buffer.set_language(language, cx);
2049 if let Some(language_registry) = language_registry {
2050 buffer.set_language_registry(language_registry)
2051 }
2052 buffer
2053 });
2054
2055 Self {
2056 buffer: buffer.clone(),
2057 old_buffer,
2058 edit_position: None,
2059 snapshot,
2060 last_equal_ranges: Default::default(),
2061 transaction_id: None,
2062 status: CodegenStatus::Idle,
2063 generation: Task::ready(()),
2064 diff: Diff::default(),
2065 telemetry,
2066 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2067 initial_insertion,
2068 }
2069 }
2070
2071 fn handle_buffer_event(
2072 &mut self,
2073 _buffer: Model<MultiBuffer>,
2074 event: &multi_buffer::Event,
2075 cx: &mut ModelContext<Self>,
2076 ) {
2077 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2078 if self.transaction_id == Some(*transaction_id) {
2079 self.transaction_id = None;
2080 self.generation = Task::ready(());
2081 cx.emit(CodegenEvent::Undone);
2082 }
2083 }
2084 }
2085
2086 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2087 &self.last_equal_ranges
2088 }
2089
2090 pub fn count_tokens(
2091 &self,
2092 edit_range: Range<Anchor>,
2093 user_prompt: String,
2094 assistant_panel_context: Option<LanguageModelRequest>,
2095 cx: &AppContext,
2096 ) -> BoxFuture<'static, Result<usize>> {
2097 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2098 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2099 model.count_tokens(request, cx)
2100 } else {
2101 future::ready(Err(anyhow!("no active model"))).boxed()
2102 }
2103 }
2104
    /// Starts (or restarts) generation for `edit_range` using `user_prompt`.
    ///
    /// Steps:
    /// 1. Calls `self.undo` (presumably reverting a previous run).
    /// 2. If `initial_insertion` is set, inserts the newline(s) in their own transaction,
    ///    whose id is stored in `transaction_id`, and collapses `edit_range` to the resulting
    ///    edit position.
    /// 3. Builds the chunk stream. A prompt equal to `"delete"` (trimmed, case-insensitive)
    ///    bypasses the model and yields an empty stream; otherwise the built request is
    ///    streamed from the active model.
    /// 4. Hands the stream to `handle_stream`.
    ///
    /// # Errors
    ///
    /// Returns an error if no language model is active.
    pub fn start(
        &mut self,
        mut edit_range: Range<Anchor>,
        user_prompt: String,
        assistant_panel_context: Option<LanguageModelRequest>,
        cx: &mut ModelContext<Self>,
    ) -> Result<()> {
        let model = LanguageModelRegistry::read_global(cx)
            .active_model()
            .context("no active model")?;

        self.undo(cx);

        // Handle initial insertion
        self.transaction_id = if let Some(initial_insertion) = self.initial_insertion {
            self.buffer.update(cx, |buffer, cx| {
                buffer.start_transaction(cx);
                let offset = edit_range.start.to_offset(&self.snapshot);
                let edit_position;
                match initial_insertion {
                    // Generate between the two inserted newlines.
                    InitialInsertion::NewlineBefore => {
                        buffer.edit([(offset..offset, "\n\n")], None, cx);
                        self.snapshot = buffer.snapshot(cx);
                        edit_position = self.snapshot.anchor_after(offset + 1);
                    }
                    // Generate at the original offset, ahead of the inserted newline.
                    InitialInsertion::NewlineAfter => {
                        buffer.edit([(offset..offset, "\n")], None, cx);
                        self.snapshot = buffer.snapshot(cx);
                        edit_position = self.snapshot.anchor_after(offset);
                    }
                }
                self.edit_position = Some(edit_position);
                // Collapse the range to an empty range at the edit position.
                edit_range = edit_position.bias_left(&self.snapshot)..edit_position;
                buffer.end_transaction(cx)
            })
        } else {
            self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
            None
        };

        let telemetry_id = model.telemetry_id();
        // The literal prompt "delete" skips the model entirely: an empty stream means the
        // range's contents are replaced with nothing.
        let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
            .trim()
            .to_lowercase()
            == "delete"
        {
            async { Ok(stream::empty().boxed()) }.boxed_local()
        } else {
            let request =
                self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
            let chunks =
                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
            async move { Ok(chunks.await?.boxed()) }.boxed_local()
        };
        self.handle_stream(telemetry_id, edit_range, chunks, cx);
        Ok(())
    }
2162
2163 fn build_request(
2164 &self,
2165 user_prompt: String,
2166 assistant_panel_context: Option<LanguageModelRequest>,
2167 edit_range: Range<Anchor>,
2168 cx: &AppContext,
2169 ) -> LanguageModelRequest {
2170 let buffer = self.buffer.read(cx).snapshot(cx);
2171 let language = buffer.language_at(edit_range.start);
2172 let language_name = if let Some(language) = language.as_ref() {
2173 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2174 None
2175 } else {
2176 Some(language.name())
2177 }
2178 } else {
2179 None
2180 };
2181
2182 // Higher Temperature increases the randomness of model outputs.
2183 // If Markdown or No Language is Known, increase the randomness for more creative output
2184 // If Code, decrease temperature to get more deterministic outputs
2185 let temperature = if let Some(language) = language_name.clone() {
2186 if language.as_ref() == "Markdown" {
2187 1.0
2188 } else {
2189 0.5
2190 }
2191 } else {
2192 1.0
2193 };
2194
2195 let language_name = language_name.as_deref();
2196 let start = buffer.point_to_buffer_offset(edit_range.start);
2197 let end = buffer.point_to_buffer_offset(edit_range.end);
2198 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2199 let (start_buffer, start_buffer_offset) = start;
2200 let (end_buffer, end_buffer_offset) = end;
2201 if start_buffer.remote_id() == end_buffer.remote_id() {
2202 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2203 } else {
2204 panic!("invalid transformation range");
2205 }
2206 } else {
2207 panic!("invalid transformation range");
2208 };
2209 let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
2210
2211 let mut messages = Vec::new();
2212 if let Some(context_request) = assistant_panel_context {
2213 messages = context_request.messages;
2214 }
2215
2216 messages.push(LanguageModelRequestMessage {
2217 role: Role::User,
2218 content: prompt,
2219 });
2220
2221 LanguageModelRequest {
2222 messages,
2223 stop: vec!["|END|>".to_string()],
2224 temperature,
2225 }
2226 }
2227
    /// Applies a streamed model response to the buffer, replacing the text in
    /// `edit_range`.
    ///
    /// `stream` resolves to a stream of text chunks. The chunks are passed
    /// through `StripInvalidSpans`, re-indented relative to
    /// `suggested_line_indent`, and diffed against the originally selected text
    /// with a `StreamingDiff` on the background executor. The resulting hunks
    /// are applied to the buffer on the foreground as they arrive, and all of
    /// them are merged into a single transaction (`transaction_id`).
    ///
    /// When the stream finishes, `status` is set to `CodegenStatus::Done` or
    /// `CodegenStatus::Error` and `CodegenEvent::Finished` is emitted. If
    /// telemetry is available, an assistant event with the response latency and
    /// any error message is reported under `model_telemetry_id`. The spawned
    /// work is stored in `self.generation`, replacing the previous value.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        // Offsets and anchors below are computed against this snapshot; it is a
        // clone and does not see the edits made while streaming.
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        // (the language's suggested indent, falling back to the line's actual indent).
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        // Reset per-generation state.
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset in `snapshot` up to which hunks have been applied; advanced over
        // kept and removed text as hunks arrive.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // The background task diffs the response and sends batches of
                    // hunks over a bounded channel; the loop further down applies
                    // them to the buffer.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line that has not been sent to the diff yet.
                                let mut new_text = String::new();
                                // Indentation (in bytes) of the first generated line with content.
                                let mut base_indent = None;
                                // Indentation of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first item received.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                // Keep the line's indentation relative to the
                                                // base line, rebased onto the suggested indent.
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line is inserted at `selection_start`,
                                                // so subtract the columns already occupied on that row.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                // Replace the model's leading whitespace with the corrected indent.
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Once the line's indentation is settled, flush it to the diff.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                // Flush any trailing text and finalize the diff.
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                // Hunk offsets are relative to `snapshot`; they are
                                // converted to anchors in it before editing.
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        Hunk::Keep { len } => {
                                            // Unchanged text is not edited, only recorded.
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), cx);
                            cx.notify();
                        })?;
                    }

                    // Surface any error from the background diff task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2451
2452 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2453 self.last_equal_ranges.clear();
2454 self.status = CodegenStatus::Done;
2455 self.generation = Task::ready(());
2456 cx.emit(CodegenEvent::Finished);
2457 cx.notify();
2458 }
2459
2460 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2461 if let Some(transaction_id) = self.transaction_id.take() {
2462 self.buffer
2463 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
2464 }
2465 }
2466
    /// Recomputes the line-level diff between the text covered by `edit_range`
    /// in `self.snapshot` (old) and in the current buffer snapshot (new).
    ///
    /// The comparison runs on the background executor. Its result is stored in
    /// `self.diff.deleted_row_ranges` (old row ranges, each paired with an
    /// anchor at the corresponding row of the new text) and
    /// `self.diff.inserted_row_ranges` (anchor ranges in the new text).
    ///
    /// At most one computation runs at a time: if one is already in flight, this
    /// only sets `should_update`, and the running task calls `update_diff` again
    /// once it finishes.
    fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
        if self.diff.task.is_some() {
            // A diff is already running; ask it to re-run when done.
            self.diff.should_update = true;
        } else {
            self.diff.should_update = false;

            let old_snapshot = self.snapshot.clone();
            let old_range = edit_range.to_point(&old_snapshot);
            let new_snapshot = self.buffer.read(cx).snapshot(cx);
            let new_range = edit_range.to_point(&new_snapshot);

            self.diff.task = Some(cx.spawn(|this, mut cx| async move {
                let (deleted_row_ranges, inserted_row_ranges) = cx
                    .background_executor()
                    .spawn(async move {
                        // Compare whole lines: expand both ranges to full rows.
                        let old_text = old_snapshot
                            .text_for_range(
                                Point::new(old_range.start.row, 0)
                                    ..Point::new(
                                        old_range.end.row,
                                        old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                                    ),
                            )
                            .collect::<String>();
                        let new_text = new_snapshot
                            .text_for_range(
                                Point::new(new_range.start.row, 0)
                                    ..Point::new(
                                        new_range.end.row,
                                        new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                                    ),
                            )
                            .collect::<String>();

                        // Current row in each snapshot as the diff is walked.
                        let mut old_row = old_range.start.row;
                        let mut new_row = new_range.start.row;
                        let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                        let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                        let mut inserted_row_ranges = Vec::new();
                        for change in diff.iter_all_changes() {
                            let line_count = change.value().lines().count() as u32;
                            match change.tag() {
                                similar::ChangeTag::Equal => {
                                    old_row += line_count;
                                    new_row += line_count;
                                }
                                similar::ChangeTag::Delete => {
                                    let old_end_row = old_row + line_count - 1;
                                    // Where the deleted rows sit in the new text.
                                    let new_row =
                                        new_snapshot.anchor_before(Point::new(new_row, 0));

                                    // Extend the previous deletion when it ends on the
                                    // row just before this one; otherwise start a new one.
                                    if let Some((_, last_deleted_row_range)) =
                                        deleted_row_ranges.last_mut()
                                    {
                                        if *last_deleted_row_range.end() + 1 == old_row {
                                            *last_deleted_row_range =
                                                *last_deleted_row_range.start()..=old_end_row;
                                        } else {
                                            deleted_row_ranges
                                                .push((new_row, old_row..=old_end_row));
                                        }
                                    } else {
                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                    }

                                    old_row += line_count;
                                }
                                similar::ChangeTag::Insert => {
                                    // Inserted rows become an anchor range spanning
                                    // the full inserted lines in the new text.
                                    let new_end_row = new_row + line_count - 1;
                                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                                    let end = new_snapshot.anchor_before(Point::new(
                                        new_end_row,
                                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
                                    ));
                                    inserted_row_ranges.push(start..=end);
                                    new_row += line_count;
                                }
                            }
                        }

                        (deleted_row_ranges, inserted_row_ranges)
                    })
                    .await;

                this.update(&mut cx, |this, cx| {
                    this.diff.deleted_row_ranges = deleted_row_ranges;
                    this.diff.inserted_row_ranges = inserted_row_ranges;
                    this.diff.task = None;
                    // Pick up any request that arrived while this diff was running.
                    if this.diff.should_update {
                        this.update_diff(edit_range, cx);
                    }
                    cx.notify();
                })
                .ok();
            }));
        }
    }
2565}
2566
/// Stream adapter that cleans up a model's streamed text output.
///
/// It removes a leading Markdown code-fence line (and the matching closing
/// fence at the end of the response) as well as every `<|CURSOR|>` marker.
/// Text that might still turn into one of those is held back until more input
/// arrives.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Text received from `stream` that has not been emitted yet.
    buffer: String,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// True while the first line of the response is being processed.
    first_line: bool,
    /// Whether a `'\n'` must be emitted before the next piece of text.
    line_end: bool,
    /// Whether an opening code-fence line was stripped from the response.
    starts_with_code_block: bool,
}
2575
2576impl<T> StripInvalidSpans<T>
2577where
2578 T: Stream<Item = Result<String>>,
2579{
2580 fn new(stream: T) -> Self {
2581 Self {
2582 stream,
2583 stream_done: false,
2584 buffer: String::new(),
2585 first_line: true,
2586 line_end: false,
2587 starts_with_code_block: false,
2588 }
2589 }
2590}
2591
2592impl<T> Stream for StripInvalidSpans<T>
2593where
2594 T: Stream<Item = Result<String>>,
2595{
2596 type Item = Result<String>;
2597
2598 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2599 const CODE_BLOCK_DELIMITER: &str = "```";
2600 const CURSOR_SPAN: &str = "<|CURSOR|>";
2601
2602 let this = unsafe { self.get_unchecked_mut() };
2603 loop {
2604 if !this.stream_done {
2605 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2606 match stream.as_mut().poll_next(cx) {
2607 Poll::Ready(Some(Ok(chunk))) => {
2608 this.buffer.push_str(&chunk);
2609 }
2610 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2611 Poll::Ready(None) => {
2612 this.stream_done = true;
2613 }
2614 Poll::Pending => return Poll::Pending,
2615 }
2616 }
2617
2618 let mut chunk = String::new();
2619 let mut consumed = 0;
2620 if !this.buffer.is_empty() {
2621 let mut lines = this.buffer.split('\n').enumerate().peekable();
2622 while let Some((line_ix, line)) = lines.next() {
2623 if line_ix > 0 {
2624 this.first_line = false;
2625 }
2626
2627 if this.first_line {
2628 let trimmed_line = line.trim();
2629 if lines.peek().is_some() {
2630 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2631 consumed += line.len() + 1;
2632 this.starts_with_code_block = true;
2633 continue;
2634 }
2635 } else if trimmed_line.is_empty()
2636 || prefixes(CODE_BLOCK_DELIMITER)
2637 .any(|prefix| trimmed_line.starts_with(prefix))
2638 {
2639 break;
2640 }
2641 }
2642
2643 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2644 if lines.peek().is_some() {
2645 if this.line_end {
2646 chunk.push('\n');
2647 }
2648
2649 chunk.push_str(&line_without_cursor);
2650 this.line_end = true;
2651 consumed += line.len() + 1;
2652 } else if this.stream_done {
2653 if !this.starts_with_code_block
2654 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2655 {
2656 if this.line_end {
2657 chunk.push('\n');
2658 }
2659
2660 chunk.push_str(&line);
2661 }
2662
2663 consumed += line.len();
2664 } else {
2665 let trimmed_line = line.trim();
2666 if trimmed_line.is_empty()
2667 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2668 || prefixes(CODE_BLOCK_DELIMITER)
2669 .any(|prefix| trimmed_line.ends_with(prefix))
2670 {
2671 break;
2672 } else {
2673 if this.line_end {
2674 chunk.push('\n');
2675 this.line_end = false;
2676 }
2677
2678 chunk.push_str(&line_without_cursor);
2679 consumed += line.len();
2680 }
2681 }
2682 }
2683 }
2684
2685 this.buffer = this.buffer.split_off(consumed);
2686 if !chunk.is_empty() {
2687 return Poll::Ready(Some(Ok(chunk)));
2688 } else if this.stream_done {
2689 return Poll::Ready(None);
2690 }
2691 }
2692 }
2693}
2694
/// Returns every non-empty proper prefix of `text`, shortest first
/// (e.g. `"abc"` yields `"a"` and `"ab"`).
///
/// Prefixes are cut at `char` boundaries, so multi-byte text never panics, and
/// an empty `text` yields nothing instead of underflowing `text.len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2698
2699fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2700 ranges.sort_unstable_by(|a, b| {
2701 a.start
2702 .cmp(&b.start, buffer)
2703 .then_with(|| b.end.cmp(&a.end, buffer))
2704 });
2705
2706 let mut ix = 0;
2707 while ix + 1 < ranges.len() {
2708 let b = ranges[ix + 1].clone();
2709 let a = &mut ranges[ix];
2710 if a.end.cmp(&b.start, buffer).is_gt() {
2711 if a.end.cmp(&b.end, buffer).is_lt() {
2712 a.end = b.end;
2713 }
2714 ranges.remove(ix + 1);
2715 } else {
2716 ix += 1;
2717 }
2718 }
2719}
2720
#[cfg(test)]
mod tests {
    //! Tests for `Codegen`'s streaming auto-indentation and for `StripInvalidSpans`.

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module — possibly dead code.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Replaces a multi-line selection with randomly chunked text and checks the
    /// generated lines are re-indented to match the surrounding function body
    /// while keeping their relative indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range,
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the response in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts right after a partially typed word (`le`) past the
    /// line's indentation; the continuation of that line gets no extra indent.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts inside a blank line's indentation (column 2); the first
    /// generated line is padded up to the indentation the language expects.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a buffer without a language, a selection that uses tab indentation
    /// makes the generated text use tabs as well.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// `StripInvalidSpans` strips code fences and cursor markers no matter how
    /// the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` once per chunk size from 1 to
        /// `text.len()` and checks every run produces `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into chunks of `size` chars (the last may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// A minimal Rust language with just enough of an indents query for the
    /// auto-indent tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}