1use crate::{
2 humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use smol::future::FutureExt;
39use std::{
40 cmp,
41 future::{self, Future},
42 mem,
43 ops::{Range, RangeInclusive},
44 pin::Pin,
45 sync::Arc,
46 task::{self, Poll},
47 time::{Duration, Instant},
48};
49use theme::ThemeSettings;
50use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
51use util::{RangeExt, ResultExt};
52use workspace::{notifications::NotificationId, Toast, Workspace};
53
54pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
55 cx.set_global(InlineAssistant::new(fs, telemetry));
56}
57
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`. When starting an
/// assist pushes the history past this length, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
59
/// Global registry of inline assists across all editors.
///
/// An *assist* pairs a prompt editor with a code generation over one buffer range. Assists
/// are indexed by id, by the editor they live in, and by the group they were created in.
pub struct InlineAssistant {
    /// Id handed to the next assist; advanced with `post_inc` on every allocation.
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group; advanced with `post_inc` on every allocation.
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet, keyed by its id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, subscriptions), keyed by a weak
    /// handle to the editor. The entry is removed once the editor's last assist finishes.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists created together; an entry is removed once its last assist finishes.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Most recently started prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    /// A copy is handed to every new prompt editor.
    prompt_history: VecDeque<String>,
    /// Telemetry client forwarded to each `Codegen` (always `Some` when built via `new`).
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle forwarded to each prompt editor.
    fs: Arc<dyn Fs>,
}
70
// Marker impl: allows the assistant to be stored with `set_global` and reached through
// `update_global`, which is how the rest of this file accesses it.
impl Global for InlineAssistant {}
72
73impl InlineAssistant {
74 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
75 Self {
76 next_assist_id: InlineAssistId::default(),
77 next_assist_group_id: InlineAssistGroupId::default(),
78 assists: HashMap::default(),
79 assists_by_editor: HashMap::default(),
80 assist_groups: HashMap::default(),
81 prompt_history: VecDeque::default(),
82 telemetry: Some(telemetry),
83 fs,
84 }
85 }
86
87 pub fn assist(
88 &mut self,
89 editor: &View<Editor>,
90 workspace: Option<WeakView<Workspace>>,
91 assistant_panel: Option<&View<AssistantPanel>>,
92 initial_prompt: Option<String>,
93 cx: &mut WindowContext,
94 ) {
95 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
96
97 let mut selections = Vec::<Selection<Point>>::new();
98 let mut newest_selection = None;
99 for mut selection in editor.read(cx).selections.all::<Point>(cx) {
100 if selection.end > selection.start {
101 selection.start.column = 0;
102 // If the selection ends at the start of the line, we don't want to include it.
103 if selection.end.column == 0 {
104 selection.end.row -= 1;
105 }
106 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
107 }
108
109 if let Some(prev_selection) = selections.last_mut() {
110 if selection.start <= prev_selection.end {
111 prev_selection.end = selection.end;
112 continue;
113 }
114 }
115
116 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
117 if selection.id > latest_selection.id {
118 *latest_selection = selection.clone();
119 }
120 selections.push(selection);
121 }
122 let newest_selection = newest_selection.unwrap();
123
124 let mut codegen_ranges = Vec::new();
125 for (excerpt_id, buffer, buffer_range) in
126 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
127 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
128 }))
129 {
130 let start = Anchor {
131 buffer_id: Some(buffer.remote_id()),
132 excerpt_id,
133 text_anchor: buffer.anchor_before(buffer_range.start),
134 };
135 let end = Anchor {
136 buffer_id: Some(buffer.remote_id()),
137 excerpt_id,
138 text_anchor: buffer.anchor_after(buffer_range.end),
139 };
140 codegen_ranges.push(start..end);
141 }
142
143 let assist_group_id = self.next_assist_group_id.post_inc();
144 let prompt_buffer =
145 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
146 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
147
148 let mut assists = Vec::new();
149 let mut assist_to_focus = None;
150 for range in codegen_ranges {
151 let assist_id = self.next_assist_id.post_inc();
152 let codegen = cx.new_model(|cx| {
153 Codegen::new(
154 editor.read(cx).buffer().clone(),
155 range.clone(),
156 None,
157 self.telemetry.clone(),
158 cx,
159 )
160 });
161
162 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
163 let prompt_editor = cx.new_view(|cx| {
164 PromptEditor::new(
165 assist_id,
166 gutter_dimensions.clone(),
167 self.prompt_history.clone(),
168 prompt_buffer.clone(),
169 codegen.clone(),
170 editor,
171 assistant_panel,
172 workspace.clone(),
173 self.fs.clone(),
174 cx,
175 )
176 });
177
178 if assist_to_focus.is_none() {
179 let focus_assist = if newest_selection.reversed {
180 range.start.to_point(&snapshot) == newest_selection.start
181 } else {
182 range.end.to_point(&snapshot) == newest_selection.end
183 };
184 if focus_assist {
185 assist_to_focus = Some(assist_id);
186 }
187 }
188
189 let [prompt_block_id, end_block_id] =
190 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
191
192 assists.push((
193 assist_id,
194 range,
195 prompt_editor,
196 prompt_block_id,
197 end_block_id,
198 ));
199 }
200
201 let editor_assists = self
202 .assists_by_editor
203 .entry(editor.downgrade())
204 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
205 let mut assist_group = InlineAssistGroup::new();
206 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
207 self.assists.insert(
208 assist_id,
209 InlineAssist::new(
210 assist_id,
211 assist_group_id,
212 assistant_panel.is_some(),
213 editor,
214 &prompt_editor,
215 prompt_block_id,
216 end_block_id,
217 range,
218 prompt_editor.read(cx).codegen.clone(),
219 workspace.clone(),
220 cx,
221 ),
222 );
223 assist_group.assist_ids.push(assist_id);
224 editor_assists.assist_ids.push(assist_id);
225 }
226 self.assist_groups.insert(assist_group_id, assist_group);
227
228 if let Some(assist_id) = assist_to_focus {
229 self.focus_assist(assist_id, cx);
230 }
231 }
232
233 #[allow(clippy::too_many_arguments)]
234 pub fn suggest_assist(
235 &mut self,
236 editor: &View<Editor>,
237 mut range: Range<Anchor>,
238 initial_prompt: String,
239 initial_transaction_id: Option<TransactionId>,
240 workspace: Option<WeakView<Workspace>>,
241 assistant_panel: Option<&View<AssistantPanel>>,
242 cx: &mut WindowContext,
243 ) -> InlineAssistId {
244 let assist_group_id = self.next_assist_group_id.post_inc();
245 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
246 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
247
248 let assist_id = self.next_assist_id.post_inc();
249
250 let buffer = editor.read(cx).buffer().clone();
251 {
252 let snapshot = buffer.read(cx).read(cx);
253 range.start = range.start.bias_left(&snapshot);
254 range.end = range.end.bias_right(&snapshot);
255 }
256
257 let codegen = cx.new_model(|cx| {
258 Codegen::new(
259 editor.read(cx).buffer().clone(),
260 range.clone(),
261 initial_transaction_id,
262 self.telemetry.clone(),
263 cx,
264 )
265 });
266
267 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
268 let prompt_editor = cx.new_view(|cx| {
269 PromptEditor::new(
270 assist_id,
271 gutter_dimensions.clone(),
272 self.prompt_history.clone(),
273 prompt_buffer.clone(),
274 codegen.clone(),
275 editor,
276 assistant_panel,
277 workspace.clone(),
278 self.fs.clone(),
279 cx,
280 )
281 });
282
283 let [prompt_block_id, end_block_id] =
284 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
285
286 let editor_assists = self
287 .assists_by_editor
288 .entry(editor.downgrade())
289 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
290
291 let mut assist_group = InlineAssistGroup::new();
292 self.assists.insert(
293 assist_id,
294 InlineAssist::new(
295 assist_id,
296 assist_group_id,
297 assistant_panel.is_some(),
298 editor,
299 &prompt_editor,
300 prompt_block_id,
301 end_block_id,
302 range,
303 prompt_editor.read(cx).codegen.clone(),
304 workspace.clone(),
305 cx,
306 ),
307 );
308 assist_group.assist_ids.push(assist_id);
309 editor_assists.assist_ids.push(assist_id);
310 self.assist_groups.insert(assist_group_id, assist_group);
311 assist_id
312 }
313
314 fn insert_assist_blocks(
315 &self,
316 editor: &View<Editor>,
317 range: &Range<Anchor>,
318 prompt_editor: &View<PromptEditor>,
319 cx: &mut WindowContext,
320 ) -> [CustomBlockId; 2] {
321 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
322 prompt_editor
323 .editor
324 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
325 });
326 let assist_blocks = vec![
327 BlockProperties {
328 style: BlockStyle::Sticky,
329 position: range.start,
330 height: prompt_editor_height,
331 render: build_assist_editor_renderer(prompt_editor),
332 disposition: BlockDisposition::Above,
333 },
334 BlockProperties {
335 style: BlockStyle::Sticky,
336 position: range.end,
337 height: 0,
338 render: Box::new(|cx| {
339 v_flex()
340 .h_full()
341 .w_full()
342 .border_t_1()
343 .border_color(cx.theme().status().info_border)
344 .into_any_element()
345 }),
346 disposition: BlockDisposition::Below,
347 },
348 ];
349
350 editor.update(cx, |editor, cx| {
351 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
352 [block_ids[0], block_ids[1]]
353 })
354 }
355
356 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
357 let assist = &self.assists[&assist_id];
358 let Some(decorations) = assist.decorations.as_ref() else {
359 return;
360 };
361 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
362 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
363
364 assist_group.active_assist_id = Some(assist_id);
365 if assist_group.linked {
366 for assist_id in &assist_group.assist_ids {
367 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
368 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
369 prompt_editor.set_show_cursor_when_unfocused(true, cx)
370 });
371 }
372 }
373 }
374
375 assist
376 .editor
377 .update(cx, |editor, cx| {
378 let scroll_top = editor.scroll_position(cx).y;
379 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
380 let prompt_row = editor
381 .row_for_block(decorations.prompt_block_id, cx)
382 .unwrap()
383 .0 as f32;
384
385 if (scroll_top..scroll_bottom).contains(&prompt_row) {
386 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
387 assist_id,
388 distance_from_top: prompt_row - scroll_top,
389 });
390 } else {
391 editor_assists.scroll_lock = None;
392 }
393 })
394 .ok();
395 }
396
397 fn handle_prompt_editor_focus_out(
398 &mut self,
399 assist_id: InlineAssistId,
400 cx: &mut WindowContext,
401 ) {
402 let assist = &self.assists[&assist_id];
403 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
404 if assist_group.active_assist_id == Some(assist_id) {
405 assist_group.active_assist_id = None;
406 if assist_group.linked {
407 for assist_id in &assist_group.assist_ids {
408 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
409 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
410 prompt_editor.set_show_cursor_when_unfocused(false, cx)
411 });
412 }
413 }
414 }
415 }
416 }
417
418 fn handle_prompt_editor_event(
419 &mut self,
420 prompt_editor: View<PromptEditor>,
421 event: &PromptEditorEvent,
422 cx: &mut WindowContext,
423 ) {
424 let assist_id = prompt_editor.read(cx).id;
425 match event {
426 PromptEditorEvent::StartRequested => {
427 self.start_assist(assist_id, cx);
428 }
429 PromptEditorEvent::StopRequested => {
430 self.stop_assist(assist_id, cx);
431 }
432 PromptEditorEvent::ConfirmRequested => {
433 self.finish_assist(assist_id, false, cx);
434 }
435 PromptEditorEvent::CancelRequested => {
436 self.finish_assist(assist_id, true, cx);
437 }
438 PromptEditorEvent::DismissRequested => {
439 self.dismiss_assist(assist_id, cx);
440 }
441 }
442 }
443
444 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
445 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
446 return;
447 };
448
449 let editor = editor.read(cx);
450 if editor.selections.count() == 1 {
451 let selection = editor.selections.newest::<usize>(cx);
452 let buffer = editor.buffer().read(cx).snapshot(cx);
453 for assist_id in &editor_assists.assist_ids {
454 let assist = &self.assists[assist_id];
455 let assist_range = assist.range.to_offset(&buffer);
456 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
457 {
458 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
459 self.dismiss_assist(*assist_id, cx);
460 } else {
461 self.finish_assist(*assist_id, false, cx);
462 }
463
464 return;
465 }
466 }
467 }
468
469 cx.propagate();
470 }
471
472 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
473 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
474 return;
475 };
476
477 let editor = editor.read(cx);
478 if editor.selections.count() == 1 {
479 let selection = editor.selections.newest::<usize>(cx);
480 let buffer = editor.buffer().read(cx).snapshot(cx);
481 for assist_id in &editor_assists.assist_ids {
482 let assist = &self.assists[assist_id];
483 let assist_range = assist.range.to_offset(&buffer);
484 if assist.decorations.is_some()
485 && assist_range.contains(&selection.start)
486 && assist_range.contains(&selection.end)
487 {
488 self.focus_assist(*assist_id, cx);
489 return;
490 }
491 }
492 }
493
494 cx.propagate();
495 }
496
497 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
498 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
499 for assist_id in editor_assists.assist_ids.clone() {
500 self.finish_assist(assist_id, true, cx);
501 }
502 }
503 }
504
505 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
506 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
507 return;
508 };
509 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
510 return;
511 };
512 let assist = &self.assists[&scroll_lock.assist_id];
513 let Some(decorations) = assist.decorations.as_ref() else {
514 return;
515 };
516
517 editor.update(cx, |editor, cx| {
518 let scroll_position = editor.scroll_position(cx);
519 let target_scroll_top = editor
520 .row_for_block(decorations.prompt_block_id, cx)
521 .unwrap()
522 .0 as f32
523 - scroll_lock.distance_from_top;
524 if target_scroll_top != scroll_position.y {
525 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
526 }
527 });
528 }
529
530 fn handle_editor_event(
531 &mut self,
532 editor: View<Editor>,
533 event: &EditorEvent,
534 cx: &mut WindowContext,
535 ) {
536 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
537 return;
538 };
539
540 match event {
541 EditorEvent::Saved => {
542 for assist_id in editor_assists.assist_ids.clone() {
543 let assist = &self.assists[&assist_id];
544 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
545 self.finish_assist(assist_id, false, cx)
546 }
547 }
548 }
549 EditorEvent::Edited { transaction_id } => {
550 let buffer = editor.read(cx).buffer().read(cx);
551 let edited_ranges =
552 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
553 let snapshot = buffer.snapshot(cx);
554
555 for assist_id in editor_assists.assist_ids.clone() {
556 let assist = &self.assists[&assist_id];
557 if matches!(
558 assist.codegen.read(cx).status,
559 CodegenStatus::Error(_) | CodegenStatus::Done
560 ) {
561 let assist_range = assist.range.to_offset(&snapshot);
562 if edited_ranges
563 .iter()
564 .any(|range| range.overlaps(&assist_range))
565 {
566 self.finish_assist(assist_id, false, cx);
567 }
568 }
569 }
570 }
571 EditorEvent::ScrollPositionChanged { .. } => {
572 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
573 let assist = &self.assists[&scroll_lock.assist_id];
574 if let Some(decorations) = assist.decorations.as_ref() {
575 let distance_from_top = editor.update(cx, |editor, cx| {
576 let scroll_top = editor.scroll_position(cx).y;
577 let prompt_row = editor
578 .row_for_block(decorations.prompt_block_id, cx)
579 .unwrap()
580 .0 as f32;
581 prompt_row - scroll_top
582 });
583
584 if distance_from_top != scroll_lock.distance_from_top {
585 editor_assists.scroll_lock = None;
586 }
587 }
588 }
589 }
590 EditorEvent::SelectionsChanged { .. } => {
591 for assist_id in editor_assists.assist_ids.clone() {
592 let assist = &self.assists[&assist_id];
593 if let Some(decorations) = assist.decorations.as_ref() {
594 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
595 return;
596 }
597 }
598 }
599
600 editor_assists.scroll_lock = None;
601 }
602 _ => {}
603 }
604 }
605
606 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
607 if let Some(assist) = self.assists.get(&assist_id) {
608 let assist_group_id = assist.group_id;
609 if self.assist_groups[&assist_group_id].linked {
610 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
611 self.finish_assist(assist_id, undo, cx);
612 }
613 return;
614 }
615 }
616
617 self.dismiss_assist(assist_id, cx);
618
619 if let Some(assist) = self.assists.remove(&assist_id) {
620 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
621 {
622 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
623 if entry.get().assist_ids.is_empty() {
624 entry.remove();
625 }
626 }
627
628 if let hash_map::Entry::Occupied(mut entry) =
629 self.assists_by_editor.entry(assist.editor.clone())
630 {
631 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
632 if entry.get().assist_ids.is_empty() {
633 entry.remove();
634 if let Some(editor) = assist.editor.upgrade() {
635 self.update_editor_highlights(&editor, cx);
636 }
637 } else {
638 entry.get().highlight_updates.send(()).ok();
639 }
640 }
641
642 if undo {
643 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
644 }
645 }
646 }
647
648 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
649 let Some(assist) = self.assists.get_mut(&assist_id) else {
650 return false;
651 };
652 let Some(editor) = assist.editor.upgrade() else {
653 return false;
654 };
655 let Some(decorations) = assist.decorations.take() else {
656 return false;
657 };
658
659 editor.update(cx, |editor, cx| {
660 let mut to_remove = decorations.removed_line_block_ids;
661 to_remove.insert(decorations.prompt_block_id);
662 to_remove.insert(decorations.end_block_id);
663 editor.remove_blocks(to_remove, None, cx);
664 });
665
666 if decorations
667 .prompt_editor
668 .focus_handle(cx)
669 .contains_focused(cx)
670 {
671 self.focus_next_assist(assist_id, cx);
672 }
673
674 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
675 if editor_assists
676 .scroll_lock
677 .as_ref()
678 .map_or(false, |lock| lock.assist_id == assist_id)
679 {
680 editor_assists.scroll_lock = None;
681 }
682 editor_assists.highlight_updates.send(()).ok();
683 }
684
685 true
686 }
687
688 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
689 let Some(assist) = self.assists.get(&assist_id) else {
690 return;
691 };
692
693 let assist_group = &self.assist_groups[&assist.group_id];
694 let assist_ix = assist_group
695 .assist_ids
696 .iter()
697 .position(|id| *id == assist_id)
698 .unwrap();
699 let assist_ids = assist_group
700 .assist_ids
701 .iter()
702 .skip(assist_ix + 1)
703 .chain(assist_group.assist_ids.iter().take(assist_ix));
704
705 for assist_id in assist_ids {
706 let assist = &self.assists[assist_id];
707 if assist.decorations.is_some() {
708 self.focus_assist(*assist_id, cx);
709 return;
710 }
711 }
712
713 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
714 }
715
716 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
717 let Some(assist) = self.assists.get(&assist_id) else {
718 return;
719 };
720
721 if let Some(decorations) = assist.decorations.as_ref() {
722 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
723 prompt_editor.editor.update(cx, |editor, cx| {
724 editor.focus(cx);
725 editor.select_all(&SelectAll, cx);
726 })
727 });
728 }
729
730 self.scroll_to_assist(assist_id, cx);
731 }
732
733 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
734 let Some(assist) = self.assists.get(&assist_id) else {
735 return;
736 };
737 let Some(editor) = assist.editor.upgrade() else {
738 return;
739 };
740
741 let position = assist.range.start;
742 editor.update(cx, |editor, cx| {
743 editor.change_selections(None, cx, |selections| {
744 selections.select_anchor_ranges([position..position])
745 });
746
747 let mut scroll_target_top;
748 let mut scroll_target_bottom;
749 if let Some(decorations) = assist.decorations.as_ref() {
750 scroll_target_top = editor
751 .row_for_block(decorations.prompt_block_id, cx)
752 .unwrap()
753 .0 as f32;
754 scroll_target_bottom = editor
755 .row_for_block(decorations.end_block_id, cx)
756 .unwrap()
757 .0 as f32;
758 } else {
759 let snapshot = editor.snapshot(cx);
760 let start_row = assist
761 .range
762 .start
763 .to_display_point(&snapshot.display_snapshot)
764 .row();
765 scroll_target_top = start_row.0 as f32;
766 scroll_target_bottom = scroll_target_top + 1.;
767 }
768 scroll_target_top -= editor.vertical_scroll_margin() as f32;
769 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
770
771 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
772 let scroll_top = editor.scroll_position(cx).y;
773 let scroll_bottom = scroll_top + height_in_lines;
774
775 if scroll_target_top < scroll_top {
776 editor.set_scroll_position(point(0., scroll_target_top), cx);
777 } else if scroll_target_bottom > scroll_bottom {
778 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
779 editor
780 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
781 } else {
782 editor.set_scroll_position(point(0., scroll_target_top), cx);
783 }
784 }
785 });
786 }
787
788 fn unlink_assist_group(
789 &mut self,
790 assist_group_id: InlineAssistGroupId,
791 cx: &mut WindowContext,
792 ) -> Vec<InlineAssistId> {
793 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
794 assist_group.linked = false;
795 for assist_id in &assist_group.assist_ids {
796 let assist = self.assists.get_mut(assist_id).unwrap();
797 if let Some(editor_decorations) = assist.decorations.as_ref() {
798 editor_decorations
799 .prompt_editor
800 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
801 }
802 }
803 assist_group.assist_ids.clone()
804 }
805
806 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
807 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
808 assist
809 } else {
810 return;
811 };
812
813 let assist_group_id = assist.group_id;
814 if self.assist_groups[&assist_group_id].linked {
815 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
816 self.start_assist(assist_id, cx);
817 }
818 return;
819 }
820
821 let Some(user_prompt) = assist.user_prompt(cx) else {
822 return;
823 };
824
825 self.prompt_history.retain(|prompt| *prompt != user_prompt);
826 self.prompt_history.push_back(user_prompt.clone());
827 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
828 self.prompt_history.pop_front();
829 }
830
831 let assistant_panel_context = assist.assistant_panel_context(cx);
832
833 assist
834 .codegen
835 .update(cx, |codegen, cx| {
836 codegen.start(
837 assist.range.clone(),
838 user_prompt,
839 assistant_panel_context,
840 cx,
841 )
842 })
843 .log_err();
844 }
845
846 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
847 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
848 assist
849 } else {
850 return;
851 };
852
853 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
854 }
855
856 pub fn status_for_assist(
857 &self,
858 assist_id: InlineAssistId,
859 cx: &WindowContext,
860 ) -> Option<CodegenStatus> {
861 let assist = self.assists.get(&assist_id)?;
862 match &assist.codegen.read(cx).status {
863 CodegenStatus::Idle => Some(CodegenStatus::Idle),
864 CodegenStatus::Pending => Some(CodegenStatus::Pending),
865 CodegenStatus::Done => Some(CodegenStatus::Done),
866 CodegenStatus::Error(error) => Some(CodegenStatus::Error(anyhow!("{:?}", error))),
867 }
868 }
869
870 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
871 let mut gutter_pending_ranges = Vec::new();
872 let mut gutter_transformed_ranges = Vec::new();
873 let mut foreground_ranges = Vec::new();
874 let mut inserted_row_ranges = Vec::new();
875 let empty_assist_ids = Vec::new();
876 let assist_ids = self
877 .assists_by_editor
878 .get(&editor.downgrade())
879 .map_or(&empty_assist_ids, |editor_assists| {
880 &editor_assists.assist_ids
881 });
882
883 for assist_id in assist_ids {
884 if let Some(assist) = self.assists.get(assist_id) {
885 let codegen = assist.codegen.read(cx);
886 let buffer = codegen.buffer.read(cx).read(cx);
887 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
888
889 let pending_range =
890 codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
891 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
892 gutter_pending_ranges.push(pending_range);
893 }
894
895 if let Some(edit_position) = codegen.edit_position {
896 let edited_range = assist.range.start..edit_position;
897 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
898 gutter_transformed_ranges.push(edited_range);
899 }
900 }
901
902 if assist.decorations.is_some() {
903 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
904 }
905 }
906 }
907
908 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
909 merge_ranges(&mut foreground_ranges, &snapshot);
910 merge_ranges(&mut gutter_pending_ranges, &snapshot);
911 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
912 editor.update(cx, |editor, cx| {
913 enum GutterPendingRange {}
914 if gutter_pending_ranges.is_empty() {
915 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
916 } else {
917 editor.highlight_gutter::<GutterPendingRange>(
918 &gutter_pending_ranges,
919 |cx| cx.theme().status().info_background,
920 cx,
921 )
922 }
923
924 enum GutterTransformedRange {}
925 if gutter_transformed_ranges.is_empty() {
926 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
927 } else {
928 editor.highlight_gutter::<GutterTransformedRange>(
929 &gutter_transformed_ranges,
930 |cx| cx.theme().status().info,
931 cx,
932 )
933 }
934
935 if foreground_ranges.is_empty() {
936 editor.clear_highlights::<InlineAssist>(cx);
937 } else {
938 editor.highlight_text::<InlineAssist>(
939 foreground_ranges,
940 HighlightStyle {
941 fade_out: Some(0.6),
942 ..Default::default()
943 },
944 cx,
945 );
946 }
947
948 editor.clear_row_highlights::<InlineAssist>();
949 for row_range in inserted_row_ranges {
950 editor.highlight_rows::<InlineAssist>(
951 row_range,
952 Some(cx.theme().status().info_background),
953 false,
954 cx,
955 );
956 }
957 });
958 }
959
/// Rebuilds the blocks that display lines deleted by an assist's diff.
///
/// The assist's previous deleted-line blocks are removed from `editor`, and one new block
/// is inserted above each `new_row` listed in the codegen diff's `deleted_row_ranges`. Each
/// block embeds a small editor over the matching rows of the codegen's `old_buffer`. The
/// ids of the inserted blocks are stored back on the assist's decorations.
///
/// Does nothing when the assist is unknown or has no decorations.
fn update_editor_blocks(
    &mut self,
    editor: &View<Editor>,
    assist_id: InlineAssistId,
    cx: &mut WindowContext,
) {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return;
    };
    let Some(decorations) = assist.decorations.as_mut() else {
        return;
    };

    // Copy what is needed out of the codegen so the closure below does not have to read it.
    let codegen = assist.codegen.read(cx);
    let old_snapshot = codegen.snapshot.clone();
    let old_buffer = codegen.old_buffer.clone();
    let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

    editor.update(cx, |editor, cx| {
        // Remove the blocks from the previous call, leaving the stored set empty.
        let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
        editor.remove_blocks(old_blocks, None, cx);

        let mut new_blocks = Vec::new();
        for (new_row, old_row_range) in deleted_row_ranges {
            // Convert the inclusive row range in the old snapshot into buffer offsets:
            // from column 0 of the first row to the end of the last row.
            let (_, buffer_start) = old_snapshot
                .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                .unwrap();
            let (_, buffer_end) = old_snapshot
                .point_to_buffer_offset(Point::new(
                    *old_row_range.end(),
                    old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                ))
                .unwrap();

            // Editor over a read-only multi-buffer holding one excerpt: the deleted rows.
            let deleted_lines_editor = cx.new_view(|cx| {
                let multi_buffer = cx.new_model(|_| {
                    MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                });
                multi_buffer.update(cx, |multi_buffer, cx| {
                    multi_buffer.push_excerpts(
                        old_buffer.clone(),
                        Some(ExcerptRange {
                            context: buffer_start..buffer_end,
                            primary: None,
                        }),
                        cx,
                    );
                });

                // Marker type keying the row highlight applied to this editor.
                enum DeletedLines {}
                let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                editor.set_show_wrap_guides(false, cx);
                editor.set_show_gutter(false, cx);
                editor.scroll_manager.set_forbid_vertical_scroll(true);
                editor.set_read_only(true);
                // Highlight every row with the theme's "deleted" background.
                editor.highlight_rows::<DeletedLines>(
                    Anchor::min()..=Anchor::max(),
                    Some(cx.theme().status().deleted_background),
                    false,
                    cx,
                );
                editor
            });

            // The block is as tall as the embedded editor has rows.
            let height =
                deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
            new_blocks.push(BlockProperties {
                position: new_row,
                height,
                style: BlockStyle::Flex,
                render: Box::new(move |cx| {
                    div()
                        .bg(cx.theme().status().deleted_background)
                        .size_full()
                        .h(height as f32 * cx.line_height())
                        .pl(cx.gutter_dimensions.full_width())
                        .child(deleted_lines_editor.clone())
                        .into_any_element()
                }),
                disposition: BlockDisposition::Above,
            });
        }

        // Remember the new block ids so the next call (or a dismissal) can remove them.
        decorations.removed_line_block_ids = editor
            .insert_blocks(new_blocks, None, cx)
            .into_iter()
            .collect();
    })
}
1050}
1051
/// Bookkeeping for all assists that live in one editor.
struct EditorInlineAssists {
    /// Ids of the editor's assists, in the order they were created.
    assist_ids: Vec<InlineAssistId>,
    /// When set, the editor is kept scrolled so that the locked assist's prompt stays at a
    /// fixed distance from the top of the viewport.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` requests a highlight refresh from the `_update_highlights` task.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that refreshes the editor's highlights on every request.
    _update_highlights: Task<Result<()>>,
    /// Editor observers, event subscriptions, and action registrations, stored here so
    /// they stay alive as long as this entry does.
    _subscriptions: Vec<gpui::Subscription>,
}
1059
/// Records that an editor's scroll position is pinned relative to an assist's prompt.
struct InlineAssistScrollLock {
    /// The assist whose prompt block anchors the lock.
    assist_id: InlineAssistId,
    /// Rows between the top of the viewport and the prompt block when the lock was taken.
    distance_from_top: f32,
}
1064
1065impl EditorInlineAssists {
1066 #[allow(clippy::too_many_arguments)]
1067 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1068 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1069 Self {
1070 assist_ids: Vec::new(),
1071 scroll_lock: None,
1072 highlight_updates: highlight_updates_tx,
1073 _update_highlights: cx.spawn(|mut cx| {
1074 let editor = editor.downgrade();
1075 async move {
1076 while let Ok(()) = highlight_updates_rx.changed().await {
1077 let editor = editor.upgrade().context("editor was dropped")?;
1078 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1079 assistant.update_editor_highlights(&editor, cx);
1080 })?;
1081 }
1082 Ok(())
1083 }
1084 }),
1085 _subscriptions: vec![
1086 cx.observe_release(editor, {
1087 let editor = editor.downgrade();
1088 |_, cx| {
1089 InlineAssistant::update_global(cx, |this, cx| {
1090 this.handle_editor_release(editor, cx);
1091 })
1092 }
1093 }),
1094 cx.observe(editor, move |editor, cx| {
1095 InlineAssistant::update_global(cx, |this, cx| {
1096 this.handle_editor_change(editor, cx)
1097 })
1098 }),
1099 cx.subscribe(editor, move |editor, event, cx| {
1100 InlineAssistant::update_global(cx, |this, cx| {
1101 this.handle_editor_event(editor, event, cx)
1102 })
1103 }),
1104 editor.update(cx, |editor, cx| {
1105 let editor_handle = cx.view().downgrade();
1106 editor.register_action(
1107 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1108 InlineAssistant::update_global(cx, |this, cx| {
1109 if let Some(editor) = editor_handle.upgrade() {
1110 this.handle_editor_newline(editor, cx)
1111 }
1112 })
1113 },
1114 )
1115 }),
1116 editor.update(cx, |editor, cx| {
1117 let editor_handle = cx.view().downgrade();
1118 editor.register_action(
1119 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1120 InlineAssistant::update_global(cx, |this, cx| {
1121 if let Some(editor) = editor_handle.upgrade() {
1122 this.handle_editor_cancel(editor, cx)
1123 }
1124 })
1125 },
1126 )
1127 }),
1128 ],
1129 }
1130 }
1131}
1132
/// A set of inline assists that are tracked together under one `InlineAssistGroupId`.
struct InlineAssistGroup {
    /// Ids of the assists that belong to this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group is still linked; `new` starts groups out linked.
    /// NOTE(review): presumably "linked" means the assists share one prompt (see the
    /// comment in `PromptEditor::new`) — confirm against the code that clears this flag.
    linked: bool,
    /// The assist in this group that is currently active, if any.
    active_assist_id: Option<InlineAssistId>,
}
1138
1139impl InlineAssistGroup {
1140 fn new() -> Self {
1141 Self {
1142 assist_ids: Vec::new(),
1143 linked: true,
1144 active_assist_id: None,
1145 }
1146 }
1147}
1148
1149fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1150 let editor = editor.clone();
1151 Box::new(move |cx: &mut BlockContext| {
1152 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1153 editor.clone().into_any_element()
1154 })
1155}
1156
/// Selects where a newline is inserted when an assist starts.
/// NOTE(review): the consumer of this enum is outside this section; the variant
/// descriptions below are inferred from the names — confirm against the caller.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    /// Presumably: insert the newline before the target position.
    NewlineBefore,
    /// Presumably: insert the newline after the target position.
    NewlineAfter,
}
1162
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Advances the id by one and returns the value it held before the increment.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
1173
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances the id by one and returns the value it held before the increment.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
1184
/// Requests emitted by `PromptEditor` in response to its buttons and key actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation. Emitted by the "start" and "restart"
    /// buttons and by `confirm` when idle or when the prompt was edited after a run.
    StartRequested,
    /// Stop the running transformation. Emitted by the "stop" button and by `cancel`
    /// while codegen is pending.
    StopRequested,
    /// Accept the result. Emitted by the "confirm" button and by `confirm` after a run
    /// whose prompt was not edited since.
    ConfirmRequested,
    /// Cancel the assist. Emitted by the "cancel" button and by `cancel` when codegen is
    /// not pending.
    CancelRequested,
    /// Emitted by `confirm` while codegen is pending.
    DismissRequested,
}
1192
/// The prompt row of an inline assist: a small editor for the prompt plus the model
/// selector, token count and action buttons rendered around it.
struct PromptEditor {
    /// Id of the assist this prompt belongs to; `count_tokens` uses it to look the assist
    /// up in the global `InlineAssistant`.
    id: InlineAssistId,
    /// File system handle passed to the `ModelSelector` in `render`.
    fs: Arc<dyn Fs>,
    /// The editor holding the prompt text. `unlink` replaces it with a fresh editor.
    editor: View<Editor>,
    /// Set when the prompt is edited; cleared when codegen reaches `Done` or `Error`.
    /// Decides between restarting and confirming.
    edited_since_done: bool,
    /// Gutter dimensions of the host editor, written by the block renderer and read in
    /// `render` to size the leading column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously used prompts, navigated with `move_up` / `move_down`.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` of the entry currently shown, or `None` when the
    /// editor shows `pending_prompt`.
    prompt_history_ix: Option<usize>,
    /// The prompt text last typed by the user, restored when navigating past the end of
    /// the history.
    pending_prompt: String,
    /// The code generation whose status drives the buttons and read-only state.
    codegen: Model<Codegen>,
    /// Observation of `codegen` that calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// The in-flight token counting task started by `count_tokens`.
    pending_token_count: Task<Result<()>>,
    /// The most recent token count, if one has been computed.
    token_count: Option<usize>,
    /// Subscriptions to the parent editor and, when present, the assistant panel; both
    /// trigger a recount.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to focus the assistant panel when the token count is clicked.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate limit popover is currently shown.
    show_rate_limit_notice: bool,
}

// `PromptEditor` reports user intent to its owner through `PromptEditorEvent`s.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1213
impl Render for PromptEditor {
    /// Renders the prompt row.
    ///
    /// The row consists of a leading column sized from the host editor's gutter (model
    /// selector plus an error indicator when codegen failed), the prompt editor itself,
    /// and a trailing column with the token count and the action buttons that match the
    /// current codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The set of buttons depends on the codegen status.
        let buttons = match status {
            // Nothing has run yet: cancel or start.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // A transformation is running: cancel or stop. The `menu::Cancel` binding is
            // advertised on the stop button here, not on the cancel button.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // A transformation finished or failed: cancel, plus either restart (prompt
            // edited since, or an error occurred) or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Leading column, sized from the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // Only add an error indicator when codegen is in the error state.
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            // Rate limit errors get a toggle button with an anchored popover.
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            // Any other error is shown as an icon with the message as tooltip.
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                // Trailing column: token count (when available) followed by the buttons.
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1394
1395impl FocusableView for PromptEditor {
1396 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1397 self.editor.focus_handle(cx)
1398 }
1399}
1400
1401impl PromptEditor {
1402 const MAX_LINES: u8 = 8;
1403
1404 #[allow(clippy::too_many_arguments)]
1405 fn new(
1406 id: InlineAssistId,
1407 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1408 prompt_history: VecDeque<String>,
1409 prompt_buffer: Model<MultiBuffer>,
1410 codegen: Model<Codegen>,
1411 parent_editor: &View<Editor>,
1412 assistant_panel: Option<&View<AssistantPanel>>,
1413 workspace: Option<WeakView<Workspace>>,
1414 fs: Arc<dyn Fs>,
1415 cx: &mut ViewContext<Self>,
1416 ) -> Self {
1417 let prompt_editor = cx.new_view(|cx| {
1418 let mut editor = Editor::new(
1419 EditorMode::AutoHeight {
1420 max_lines: Self::MAX_LINES as usize,
1421 },
1422 prompt_buffer,
1423 None,
1424 false,
1425 cx,
1426 );
1427 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1428 // Since the prompt editors for all inline assistants are linked,
1429 // always show the cursor (even when it isn't focused) because
1430 // typing in one will make what you typed appear in all of them.
1431 editor.set_show_cursor_when_unfocused(true, cx);
1432 editor.set_placeholder_text("Add a prompt…", cx);
1433 editor
1434 });
1435
1436 let mut token_count_subscriptions = Vec::new();
1437 token_count_subscriptions
1438 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1439 if let Some(assistant_panel) = assistant_panel {
1440 token_count_subscriptions
1441 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1442 }
1443
1444 let mut this = Self {
1445 id,
1446 editor: prompt_editor,
1447 edited_since_done: false,
1448 gutter_dimensions,
1449 prompt_history,
1450 prompt_history_ix: None,
1451 pending_prompt: String::new(),
1452 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1453 editor_subscriptions: Vec::new(),
1454 codegen,
1455 fs,
1456 pending_token_count: Task::ready(Ok(())),
1457 token_count: None,
1458 _token_count_subscriptions: token_count_subscriptions,
1459 workspace,
1460 show_rate_limit_notice: false,
1461 };
1462 this.count_tokens(cx);
1463 this.subscribe_to_editor(cx);
1464 this
1465 }
1466
1467 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1468 self.editor_subscriptions.clear();
1469 self.editor_subscriptions
1470 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1471 }
1472
1473 fn set_show_cursor_when_unfocused(
1474 &mut self,
1475 show_cursor_when_unfocused: bool,
1476 cx: &mut ViewContext<Self>,
1477 ) {
1478 self.editor.update(cx, |editor, cx| {
1479 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1480 });
1481 }
1482
1483 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1484 let prompt = self.prompt(cx);
1485 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1486 self.editor = cx.new_view(|cx| {
1487 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1488 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1489 editor.set_placeholder_text("Add a prompt…", cx);
1490 editor.set_text(prompt, cx);
1491 if focus {
1492 editor.focus(cx);
1493 }
1494 editor
1495 });
1496 self.subscribe_to_editor(cx);
1497 }
1498
1499 fn prompt(&self, cx: &AppContext) -> String {
1500 self.editor.read(cx).text(cx)
1501 }
1502
1503 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1504 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1505 if self.show_rate_limit_notice {
1506 cx.focus_view(&self.editor);
1507 }
1508 cx.notify();
1509 }
1510
1511 fn handle_parent_editor_event(
1512 &mut self,
1513 _: View<Editor>,
1514 event: &EditorEvent,
1515 cx: &mut ViewContext<Self>,
1516 ) {
1517 if let EditorEvent::BufferEdited { .. } = event {
1518 self.count_tokens(cx);
1519 }
1520 }
1521
1522 fn handle_assistant_panel_event(
1523 &mut self,
1524 _: View<AssistantPanel>,
1525 event: &AssistantPanelEvent,
1526 cx: &mut ViewContext<Self>,
1527 ) {
1528 let AssistantPanelEvent::ContextEdited { .. } = event;
1529 self.count_tokens(cx);
1530 }
1531
1532 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1533 let assist_id = self.id;
1534 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1535 cx.background_executor().timer(Duration::from_secs(1)).await;
1536 let token_count = cx
1537 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1538 let assist = inline_assistant
1539 .assists
1540 .get(&assist_id)
1541 .context("assist not found")?;
1542 anyhow::Ok(assist.count_tokens(cx))
1543 })??
1544 .await?;
1545
1546 this.update(&mut cx, |this, cx| {
1547 this.token_count = Some(token_count);
1548 cx.notify();
1549 })
1550 })
1551 }
1552
1553 fn handle_prompt_editor_events(
1554 &mut self,
1555 _: View<Editor>,
1556 event: &EditorEvent,
1557 cx: &mut ViewContext<Self>,
1558 ) {
1559 match event {
1560 EditorEvent::Edited { .. } => {
1561 let prompt = self.editor.read(cx).text(cx);
1562 if self
1563 .prompt_history_ix
1564 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1565 {
1566 self.prompt_history_ix.take();
1567 self.pending_prompt = prompt;
1568 }
1569
1570 self.edited_since_done = true;
1571 cx.notify();
1572 }
1573 EditorEvent::BufferEdited => {
1574 self.count_tokens(cx);
1575 }
1576 EditorEvent::Blurred => {
1577 if self.show_rate_limit_notice {
1578 self.show_rate_limit_notice = false;
1579 cx.notify();
1580 }
1581 }
1582 _ => {}
1583 }
1584 }
1585
1586 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1587 match &self.codegen.read(cx).status {
1588 CodegenStatus::Idle => {
1589 self.editor
1590 .update(cx, |editor, _| editor.set_read_only(false));
1591 }
1592 CodegenStatus::Pending => {
1593 self.editor
1594 .update(cx, |editor, _| editor.set_read_only(true));
1595 }
1596 CodegenStatus::Done => {
1597 self.edited_since_done = false;
1598 self.editor
1599 .update(cx, |editor, _| editor.set_read_only(false));
1600 }
1601 CodegenStatus::Error(error) => {
1602 if cx.has_flag::<ZedPro>()
1603 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1604 && !dismissed_rate_limit_notice()
1605 {
1606 self.show_rate_limit_notice = true;
1607 cx.notify();
1608 }
1609
1610 self.edited_since_done = false;
1611 self.editor
1612 .update(cx, |editor, _| editor.set_read_only(false));
1613 }
1614 }
1615 }
1616
1617 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1618 match &self.codegen.read(cx).status {
1619 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1620 cx.emit(PromptEditorEvent::CancelRequested);
1621 }
1622 CodegenStatus::Pending => {
1623 cx.emit(PromptEditorEvent::StopRequested);
1624 }
1625 }
1626 }
1627
1628 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1629 match &self.codegen.read(cx).status {
1630 CodegenStatus::Idle => {
1631 cx.emit(PromptEditorEvent::StartRequested);
1632 }
1633 CodegenStatus::Pending => {
1634 cx.emit(PromptEditorEvent::DismissRequested);
1635 }
1636 CodegenStatus::Done | CodegenStatus::Error(_) => {
1637 if self.edited_since_done {
1638 cx.emit(PromptEditorEvent::StartRequested);
1639 } else {
1640 cx.emit(PromptEditorEvent::ConfirmRequested);
1641 }
1642 }
1643 }
1644 }
1645
1646 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1647 if let Some(ix) = self.prompt_history_ix {
1648 if ix > 0 {
1649 self.prompt_history_ix = Some(ix - 1);
1650 let prompt = self.prompt_history[ix - 1].as_str();
1651 self.editor.update(cx, |editor, cx| {
1652 editor.set_text(prompt, cx);
1653 editor.move_to_beginning(&Default::default(), cx);
1654 });
1655 }
1656 } else if !self.prompt_history.is_empty() {
1657 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1658 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1659 self.editor.update(cx, |editor, cx| {
1660 editor.set_text(prompt, cx);
1661 editor.move_to_beginning(&Default::default(), cx);
1662 });
1663 }
1664 }
1665
1666 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1667 if let Some(ix) = self.prompt_history_ix {
1668 if ix < self.prompt_history.len() - 1 {
1669 self.prompt_history_ix = Some(ix + 1);
1670 let prompt = self.prompt_history[ix + 1].as_str();
1671 self.editor.update(cx, |editor, cx| {
1672 editor.set_text(prompt, cx);
1673 editor.move_to_end(&Default::default(), cx)
1674 });
1675 } else {
1676 self.prompt_history_ix = None;
1677 let prompt = self.pending_prompt.as_str();
1678 self.editor.update(cx, |editor, cx| {
1679 editor.set_text(prompt, cx);
1680 editor.move_to_end(&Default::default(), cx)
1681 });
1682 }
1683 }
1684 }
1685
1686 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1687 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1688 let token_count = self.token_count?;
1689 let max_token_count = model.max_token_count();
1690
1691 let remaining_tokens = max_token_count as isize - token_count as isize;
1692 let token_count_color = if remaining_tokens <= 0 {
1693 Color::Error
1694 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1695 Color::Warning
1696 } else {
1697 Color::Muted
1698 };
1699
1700 let mut token_count = h_flex()
1701 .id("token_count")
1702 .gap_0p5()
1703 .child(
1704 Label::new(humanize_token_count(token_count))
1705 .size(LabelSize::Small)
1706 .color(token_count_color),
1707 )
1708 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1709 .child(
1710 Label::new(humanize_token_count(max_token_count))
1711 .size(LabelSize::Small)
1712 .color(Color::Muted),
1713 );
1714 if let Some(workspace) = self.workspace.clone() {
1715 token_count = token_count
1716 .tooltip(|cx| {
1717 Tooltip::with_meta(
1718 "Tokens Used by Inline Assistant",
1719 None,
1720 "Click to Open Assistant Panel",
1721 cx,
1722 )
1723 })
1724 .cursor_pointer()
1725 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1726 .on_click(move |_, cx| {
1727 cx.stop_propagation();
1728 workspace
1729 .update(cx, |workspace, cx| {
1730 workspace.focus_panel::<AssistantPanel>(cx)
1731 })
1732 .ok();
1733 });
1734 } else {
1735 token_count = token_count
1736 .cursor_default()
1737 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1738 }
1739
1740 Some(token_count)
1741 }
1742
1743 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1744 let settings = ThemeSettings::get_global(cx);
1745 let text_style = TextStyle {
1746 color: if self.editor.read(cx).read_only(cx) {
1747 cx.theme().colors().text_disabled
1748 } else {
1749 cx.theme().colors().text
1750 },
1751 font_family: settings.ui_font.family.clone(),
1752 font_features: settings.ui_font.features.clone(),
1753 font_fallbacks: settings.ui_font.fallbacks.clone(),
1754 font_size: rems(0.875).into(),
1755 font_weight: settings.ui_font.weight,
1756 line_height: relative(1.3),
1757 ..Default::default()
1758 };
1759 EditorElement::new(
1760 &self.editor,
1761 EditorStyle {
1762 background: cx.theme().colors().editor_background,
1763 local_player: cx.theme().players().local(),
1764 text: text_style,
1765 ..Default::default()
1766 },
1767 )
1768 }
1769
1770 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1771 Popover::new().child(
1772 v_flex()
1773 .occlude()
1774 .p_2()
1775 .child(
1776 Label::new("Out of Tokens")
1777 .size(LabelSize::Small)
1778 .weight(FontWeight::BOLD),
1779 )
1780 .child(Label::new(
1781 "Try Zed Pro for higher limits, a wider range of models, and more.",
1782 ))
1783 .child(
1784 h_flex()
1785 .justify_between()
1786 .child(CheckboxWithLabel::new(
1787 "dont-show-again",
1788 Label::new("Don't show again"),
1789 if dismissed_rate_limit_notice() {
1790 ui::Selection::Selected
1791 } else {
1792 ui::Selection::Unselected
1793 },
1794 |selection, cx| {
1795 let is_dismissed = match selection {
1796 ui::Selection::Unselected => false,
1797 ui::Selection::Indeterminate => return,
1798 ui::Selection::Selected => true,
1799 };
1800
1801 set_rate_limit_notice_dismissed(is_dismissed, cx)
1802 },
1803 ))
1804 .child(
1805 h_flex()
1806 .gap_2()
1807 .child(
1808 Button::new("dismiss", "Dismiss")
1809 .style(ButtonStyle::Transparent)
1810 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1811 )
1812 .child(Button::new("more-info", "More Info").on_click(
1813 |_event, cx| {
1814 cx.dispatch_action(Box::new(
1815 zed_actions::OpenAccountSettings,
1816 ))
1817 },
1818 )),
1819 ),
1820 ),
1821 )
1822 }
1823}
1824
1825const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1826
1827fn dismissed_rate_limit_notice() -> bool {
1828 db::kvp::KEY_VALUE_STORE
1829 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1830 .log_err()
1831 .map_or(false, |s| s.is_some())
1832}
1833
1834fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1835 db::write_and_log(cx, move || async move {
1836 if is_dismissed {
1837 db::kvp::KEY_VALUE_STORE
1838 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1839 .await
1840 } else {
1841 db::kvp::KEY_VALUE_STORE
1842 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1843 .await
1844 }
1845 })
1846}
1847
/// State of a single inline assist attached to an editor.
struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range in the editor's buffer the assist operates on; passed to codegen when
    /// counting tokens.
    range: Range<Anchor>,
    /// The editor the assist is attached to, held weakly.
    editor: WeakView<Editor>,
    /// The prompt editor and block ids shown for this assist. When this is `None` and
    /// codegen finishes, the assist is finished automatically and errors are reported via
    /// a workspace toast instead.
    decorations: Option<InlineAssistDecorations>,
    /// The code generation backing this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt editor and codegen subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to reach the assistant panel and to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the active assistant panel context is included in requests.
    include_context: bool,
}
1858
1859impl InlineAssist {
1860 #[allow(clippy::too_many_arguments)]
1861 fn new(
1862 assist_id: InlineAssistId,
1863 group_id: InlineAssistGroupId,
1864 include_context: bool,
1865 editor: &View<Editor>,
1866 prompt_editor: &View<PromptEditor>,
1867 prompt_block_id: CustomBlockId,
1868 end_block_id: CustomBlockId,
1869 range: Range<Anchor>,
1870 codegen: Model<Codegen>,
1871 workspace: Option<WeakView<Workspace>>,
1872 cx: &mut WindowContext,
1873 ) -> Self {
1874 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
1875 InlineAssist {
1876 group_id,
1877 include_context,
1878 editor: editor.downgrade(),
1879 decorations: Some(InlineAssistDecorations {
1880 prompt_block_id,
1881 prompt_editor: prompt_editor.clone(),
1882 removed_line_block_ids: HashSet::default(),
1883 end_block_id,
1884 }),
1885 range,
1886 codegen: codegen.clone(),
1887 workspace: workspace.clone(),
1888 _subscriptions: vec![
1889 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
1890 InlineAssistant::update_global(cx, |this, cx| {
1891 this.handle_prompt_editor_focus_in(assist_id, cx)
1892 })
1893 }),
1894 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
1895 InlineAssistant::update_global(cx, |this, cx| {
1896 this.handle_prompt_editor_focus_out(assist_id, cx)
1897 })
1898 }),
1899 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
1900 InlineAssistant::update_global(cx, |this, cx| {
1901 this.handle_prompt_editor_event(prompt_editor, event, cx)
1902 })
1903 }),
1904 cx.observe(&codegen, {
1905 let editor = editor.downgrade();
1906 move |_, cx| {
1907 if let Some(editor) = editor.upgrade() {
1908 InlineAssistant::update_global(cx, |this, cx| {
1909 if let Some(editor_assists) =
1910 this.assists_by_editor.get(&editor.downgrade())
1911 {
1912 editor_assists.highlight_updates.send(()).ok();
1913 }
1914
1915 this.update_editor_blocks(&editor, assist_id, cx);
1916 })
1917 }
1918 }
1919 }),
1920 cx.subscribe(&codegen, move |codegen, event, cx| {
1921 InlineAssistant::update_global(cx, |this, cx| match event {
1922 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
1923 CodegenEvent::Finished => {
1924 let assist = if let Some(assist) = this.assists.get(&assist_id) {
1925 assist
1926 } else {
1927 return;
1928 };
1929
1930 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
1931 if assist.decorations.is_none() {
1932 if let Some(workspace) = assist
1933 .workspace
1934 .as_ref()
1935 .and_then(|workspace| workspace.upgrade())
1936 {
1937 let error = format!("Inline assistant error: {}", error);
1938 workspace.update(cx, |workspace, cx| {
1939 struct InlineAssistantError;
1940
1941 let id =
1942 NotificationId::identified::<InlineAssistantError>(
1943 assist_id.0,
1944 );
1945
1946 workspace.show_toast(Toast::new(id, error), cx);
1947 })
1948 }
1949 }
1950 }
1951
1952 if assist.decorations.is_none() {
1953 this.finish_assist(assist_id, false, cx);
1954 }
1955 }
1956 })
1957 }),
1958 ],
1959 }
1960 }
1961
1962 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
1963 let decorations = self.decorations.as_ref()?;
1964 Some(decorations.prompt_editor.read(cx).prompt(cx))
1965 }
1966
1967 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
1968 if self.include_context {
1969 let workspace = self.workspace.as_ref()?;
1970 let workspace = workspace.upgrade()?.read(cx);
1971 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
1972 Some(
1973 assistant_panel
1974 .read(cx)
1975 .active_context(cx)?
1976 .read(cx)
1977 .to_completion_request(cx),
1978 )
1979 } else {
1980 None
1981 }
1982 }
1983
1984 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
1985 let Some(user_prompt) = self.user_prompt(cx) else {
1986 return future::ready(Err(anyhow!("no user prompt"))).boxed();
1987 };
1988 let assistant_panel_context = self.assistant_panel_context(cx);
1989 self.codegen.read(cx).count_tokens(
1990 self.range.clone(),
1991 user_prompt,
1992 assistant_panel_context,
1993 cx,
1994 )
1995 }
1996}
1997
/// The visual elements an inline assist adds to its editor.
struct InlineAssistDecorations {
    /// Id of the block that hosts the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view shown for the assist.
    prompt_editor: View<PromptEditor>,
    /// Ids of the blocks that display removed (deleted) lines; starts out empty.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the block marking the end of the assist.
    /// NOTE(review): its content is defined outside this section — confirm.
    end_block_id: CustomBlockId,
}
2004
/// Events emitted by `Codegen`.
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation has ended; listeners check `Codegen::status` for an error.
    /// NOTE(review): the emit site is outside this section — confirm exact timing.
    Finished,
    /// The transformation's transaction was undone in the buffer.
    Undone,
}
2010
/// Drives a model-backed transformation of a range in a multi-buffer.
pub struct Codegen {
    /// The multi-buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Local copy, made in `new`, of the underlying buffer the range resolved to.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `new`.
    snapshot: MultiBufferSnapshot,
    /// Start of the edit range biased right; set by `start`.
    edit_position: Option<Anchor>,
    /// Exposed through `last_equal_ranges`.
    /// NOTE(review): populated outside this section — presumably the ranges left
    /// unchanged by the latest diff; confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction id supplied by the creator of this codegen, if any.
    initial_transaction_id: Option<TransactionId>,
    /// Transaction holding the current transformation. `start` undoes it before a new
    /// run, and observing its undo emits `CodegenEvent::Undone`.
    transformation_transaction_id: Option<TransactionId>,
    /// Current status; starts as `CodegenStatus::Idle`.
    status: CodegenStatus,
    /// The running generation task; replaced by a ready task when the transformation is
    /// undone.
    generation: Task<()>,
    /// Row ranges deleted and inserted by the transformation.
    diff: Diff,
    /// Optional telemetry sink.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, handled by `handle_buffer_event`.
    _subscription: gpui::Subscription,
}
2025
/// Lifecycle state of a `Codegen`.
pub enum CodegenStatus {
    /// Nothing has been started yet; the initial status set by `Codegen::new`.
    Idle,
    /// A generation is in progress.
    Pending,
    /// The last generation completed.
    Done,
    /// The last generation failed with the contained error.
    Error(anyhow::Error),
}
2032
2033#[derive(Default)]
2034struct Diff {
2035 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2036 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2037}
2038
2039impl Diff {
2040 fn is_empty(&self) -> bool {
2041 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2042 }
2043}
2044
// `Codegen` notifies its subscribers through `CodegenEvent`s.
impl EventEmitter<CodegenEvent> for Codegen {}
2046
2047impl Codegen {
2048 pub fn new(
2049 buffer: Model<MultiBuffer>,
2050 range: Range<Anchor>,
2051 initial_transaction_id: Option<TransactionId>,
2052 telemetry: Option<Arc<Telemetry>>,
2053 cx: &mut ModelContext<Self>,
2054 ) -> Self {
2055 let snapshot = buffer.read(cx).snapshot(cx);
2056
2057 let (old_buffer, _, _) = buffer
2058 .read(cx)
2059 .range_to_buffer_ranges(range.clone(), cx)
2060 .pop()
2061 .unwrap();
2062 let old_buffer = cx.new_model(|cx| {
2063 let old_buffer = old_buffer.read(cx);
2064 let text = old_buffer.as_rope().clone();
2065 let line_ending = old_buffer.line_ending();
2066 let language = old_buffer.language().cloned();
2067 let language_registry = old_buffer.language_registry();
2068
2069 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2070 buffer.set_language(language, cx);
2071 if let Some(language_registry) = language_registry {
2072 buffer.set_language_registry(language_registry)
2073 }
2074 buffer
2075 });
2076
2077 Self {
2078 buffer: buffer.clone(),
2079 old_buffer,
2080 edit_position: None,
2081 snapshot,
2082 last_equal_ranges: Default::default(),
2083 transformation_transaction_id: None,
2084 status: CodegenStatus::Idle,
2085 generation: Task::ready(()),
2086 diff: Diff::default(),
2087 telemetry,
2088 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2089 initial_transaction_id,
2090 }
2091 }
2092
2093 fn handle_buffer_event(
2094 &mut self,
2095 _buffer: Model<MultiBuffer>,
2096 event: &multi_buffer::Event,
2097 cx: &mut ModelContext<Self>,
2098 ) {
2099 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2100 if self.transformation_transaction_id == Some(*transaction_id) {
2101 self.transformation_transaction_id = None;
2102 self.generation = Task::ready(());
2103 cx.emit(CodegenEvent::Undone);
2104 }
2105 }
2106 }
2107
2108 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2109 &self.last_equal_ranges
2110 }
2111
2112 pub fn count_tokens(
2113 &self,
2114 edit_range: Range<Anchor>,
2115 user_prompt: String,
2116 assistant_panel_context: Option<LanguageModelRequest>,
2117 cx: &AppContext,
2118 ) -> BoxFuture<'static, Result<usize>> {
2119 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2120 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2121 model.count_tokens(request, cx)
2122 } else {
2123 future::ready(Err(anyhow!("no active model"))).boxed()
2124 }
2125 }
2126
2127 pub fn start(
2128 &mut self,
2129 edit_range: Range<Anchor>,
2130 user_prompt: String,
2131 assistant_panel_context: Option<LanguageModelRequest>,
2132 cx: &mut ModelContext<Self>,
2133 ) -> Result<()> {
2134 let model = LanguageModelRegistry::read_global(cx)
2135 .active_model()
2136 .context("no active model")?;
2137
2138 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2139 self.buffer.update(cx, |buffer, cx| {
2140 buffer.undo_transaction(transformation_transaction_id, cx)
2141 });
2142 }
2143
2144 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2145
2146 let telemetry_id = model.telemetry_id();
2147 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2148 .trim()
2149 .to_lowercase()
2150 == "delete"
2151 {
2152 async { Ok(stream::empty().boxed()) }.boxed_local()
2153 } else {
2154 let request =
2155 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
2156 let chunks =
2157 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2158 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2159 };
2160 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2161 Ok(())
2162 }
2163
2164 fn build_request(
2165 &self,
2166 user_prompt: String,
2167 assistant_panel_context: Option<LanguageModelRequest>,
2168 edit_range: Range<Anchor>,
2169 cx: &AppContext,
2170 ) -> LanguageModelRequest {
2171 let buffer = self.buffer.read(cx).snapshot(cx);
2172 let language = buffer.language_at(edit_range.start);
2173 let language_name = if let Some(language) = language.as_ref() {
2174 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2175 None
2176 } else {
2177 Some(language.name())
2178 }
2179 } else {
2180 None
2181 };
2182
2183 // Higher Temperature increases the randomness of model outputs.
2184 // If Markdown or No Language is Known, increase the randomness for more creative output
2185 // If Code, decrease temperature to get more deterministic outputs
2186 let temperature = if let Some(language) = language_name.clone() {
2187 if language.as_ref() == "Markdown" {
2188 1.0
2189 } else {
2190 0.5
2191 }
2192 } else {
2193 1.0
2194 };
2195
2196 let language_name = language_name.as_deref();
2197 let start = buffer.point_to_buffer_offset(edit_range.start);
2198 let end = buffer.point_to_buffer_offset(edit_range.end);
2199 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2200 let (start_buffer, start_buffer_offset) = start;
2201 let (end_buffer, end_buffer_offset) = end;
2202 if start_buffer.remote_id() == end_buffer.remote_id() {
2203 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2204 } else {
2205 panic!("invalid transformation range");
2206 }
2207 } else {
2208 panic!("invalid transformation range");
2209 };
2210 let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
2211
2212 let mut messages = Vec::new();
2213 if let Some(context_request) = assistant_panel_context {
2214 messages = context_request.messages;
2215 }
2216
2217 messages.push(LanguageModelRequestMessage {
2218 role: Role::User,
2219 content: prompt,
2220 });
2221
2222 LanguageModelRequest {
2223 messages,
2224 stop: vec!["|END|>".to_string()],
2225 temperature,
2226 }
2227 }
2228
/// Applies a streamed completion to `edit_range`, re-indenting it and
/// diffing it against the text that was originally selected.
///
/// `stream` resolves to the stream of text chunks. A task on the background
/// executor strips invalid spans from the chunks, corrects each line's
/// indentation, feeds the text into a `StreamingDiff`/`LineDiff` pair and
/// sends the resulting operations over a bounded channel. The task stored in
/// `self.generation` receives them, applies the character operations to the
/// buffer inside transactions and refreshes `self.diff`.
///
/// When the stream ends or fails, `self.status` becomes `Done` or `Error`
/// and `CodegenEvent::Finished` is emitted. A telemetry event with the
/// latency of the first chunk and any error message is reported when a
/// telemetry handle is available.
pub fn handle_stream(
    &mut self,
    model_telemetry_id: String,
    edit_range: Range<Anchor>,
    stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
    cx: &mut ModelContext<Self>,
) {
    // All offsets and anchors below are resolved against the snapshot stored
    // on `self`, not against the buffer's current contents.
    let snapshot = self.snapshot.clone();
    let selected_text = snapshot
        .text_for_range(edit_range.start..edit_range.end)
        .collect::<Rope>();

    let selection_start = edit_range.start.to_point(&snapshot);

    // Start with the indentation of the first line in the selection
    let mut suggested_line_indent = snapshot
        .suggested_indents(selection_start.row..=selection_start.row, cx)
        .into_values()
        .next()
        .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

    // If the first line in the selection does not have indentation, check the following lines
    if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
        for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
            let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
            // Prefer tabs if a line in the selection uses tabs as indentation
            if line_indent.kind == IndentKind::Tab {
                suggested_line_indent.kind = IndentKind::Tab;
                break;
            }
        }
    }

    let telemetry = self.telemetry.clone();
    self.diff = Diff::default();
    self.status = CodegenStatus::Pending;
    // Running byte offset into `snapshot`; `Delete` and `Keep` operations
    // advance it, `Insert` operations do not.
    let mut edit_start = edit_range.start.to_offset(&snapshot);
    self.generation = cx.spawn(|this, mut cx| {
        async move {
            let chunks = stream.await;
            let generate = async {
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                // Producer: runs on the background executor and turns text
                // chunks into (character operations, line operations) pairs.
                let diff: Task<anyhow::Result<()>> =
                    cx.background_executor().spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = StripInvalidSpans::new(chunks?);
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            // Text of the current line that has not been pushed
                            // into the diff yet.
                            let mut new_text = String::new();
                            // Indentation of the first non-blank generated line;
                            // later lines are indented relative to it.
                            let mut base_indent = None;
                            // Indentation of the current line, known once its
                            // first non-whitespace character has arrived.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                // Latency is measured up to the first item of
                                // the stream, whether it is text or an error.
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    if line_indent.is_none() {
                                        // The byte index of the first non-whitespace
                                        // character is used as the line's indent width.
                                        if let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                        {
                                            line_indent = Some(non_whitespace_ch_ix);
                                            base_indent = base_indent.or(line_indent);

                                            let line_indent = line_indent.unwrap();
                                            let base_indent = base_indent.unwrap();
                                            let indent_delta =
                                                line_indent as i32 - base_indent as i32;
                                            // Keep the model's relative nesting, anchored
                                            // at the indentation suggested for the
                                            // selection; never go below zero.
                                            let mut corrected_indent_len = cmp::max(
                                                0,
                                                suggested_line_indent.len as i32 + indent_delta,
                                            )
                                                as usize;
                                            // NOTE(review): presumably the columns before the
                                            // selection start are already present in the
                                            // buffer on the first row, so they are subtracted
                                            // from the indentation to insert. Confirm.
                                            if first_line {
                                                corrected_indent_len = corrected_indent_len
                                                    .saturating_sub(
                                                        selection_start.column as usize,
                                                    );
                                            }

                                            // Replace the model's leading whitespace with the
                                            // corrected number of indent characters.
                                            let indent_char = suggested_line_indent.char();
                                            let mut indent_buffer = [0; 4];
                                            let indent_str =
                                                indent_char.encode_utf8(&mut indent_buffer);
                                            new_text.replace_range(
                                                ..line_indent,
                                                &indent_str.repeat(corrected_indent_len),
                                            );
                                        }
                                    }

                                    // Text is only flushed once the line's indentation is
                                    // known; until then it stays buffered in `new_text`.
                                    if line_indent.is_some() {
                                        let char_ops = diff.push_new(&new_text);
                                        line_diff
                                            .push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        new_text.clear();
                                    }

                                    // Another piece follows, so a newline separated this
                                    // piece from the next one inside the chunk.
                                    if lines.peek().is_some() {
                                        let char_ops = diff.push_new("\n");
                                        line_diff
                                            .push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        if line_indent.is_none() {
                                            // Don't write out the leading indentation in empty lines on the next line
                                            // This is the case where the above if statement didn't clear the buffer
                                            new_text.clear();
                                        }
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }

                            // Flush whatever is left and close both diffs.
                            let mut char_ops = diff.push_new(&new_text);
                            char_ops.extend(diff.finish());
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        let error_message =
                            result.as_ref().err().map(|error| error.to_string());
                        if let Some(telemetry) = telemetry {
                            telemetry.report_assistant_event(
                                None,
                                telemetry_events::AssistantKind::Inline,
                                model_telemetry_id,
                                response_latency,
                                error_message,
                            );
                        }

                        result?;
                        Ok(())
                    });

                // Consumer: apply each batch of operations to the buffer.
                while let Some((char_ops, line_diff)) = diff_rx.next().await {
                    this.update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();

                        let transaction = this.buffer.update(cx, |buffer, cx| {
                            // Avoid grouping assistant edits with user edits.
                            buffer.finalize_last_transaction(cx);

                            buffer.start_transaction(cx);
                            buffer.edit(
                                char_ops
                                    .into_iter()
                                    .filter_map(|operation| match operation {
                                        // Inserted text goes at the current position
                                        // without advancing it.
                                        CharOperation::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Deleted bytes are replaced with an empty string.
                                        CharOperation::Delete { bytes } => {
                                            let edit_end = edit_start + bytes;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept bytes produce no edit; the range is only
                                        // recorded in `last_equal_ranges`.
                                        CharOperation::Keep { bytes } => {
                                            let edit_end = edit_start + bytes;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                None,
                                cx,
                            );
                            this.edit_position = Some(snapshot.anchor_after(edit_start));

                            buffer.end_transaction(cx)
                        });

                        if let Some(transaction) = transaction {
                            if let Some(first_transaction) = this.transformation_transaction_id
                            {
                                // Group all assistant edits into the first transaction.
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.merge_transactions(
                                        transaction,
                                        first_transaction,
                                        cx,
                                    )
                                });
                            } else {
                                this.transformation_transaction_id = Some(transaction);
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.finalize_last_transaction(cx)
                                });
                            }
                        }

                        this.update_diff(edit_range.clone(), line_diff, cx);

                        cx.notify();
                    })?;
                }

                // The channel is closed; surface any error from the producer.
                diff.await?;

                anyhow::Ok(())
            };

            let result = generate.await;
            // `.ok()` discards the error `update` returns when it cannot run
            // the closure.
            this.update(&mut cx, |this, cx| {
                this.last_equal_ranges.clear();
                if let Err(error) = result {
                    this.status = CodegenStatus::Error(error);
                } else {
                    this.status = CodegenStatus::Done;
                }
                cx.emit(CodegenEvent::Finished);
                cx.notify();
            })
            .ok();
        }
    });
    cx.notify();
}
2473
2474 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2475 self.last_equal_ranges.clear();
2476 if self.diff.is_empty() {
2477 self.status = CodegenStatus::Idle;
2478 } else {
2479 self.status = CodegenStatus::Done;
2480 }
2481 self.generation = Task::ready(());
2482 cx.emit(CodegenEvent::Finished);
2483 cx.notify();
2484 }
2485
2486 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2487 self.buffer.update(cx, |buffer, cx| {
2488 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2489 buffer.undo_transaction(transaction_id, cx);
2490 }
2491
2492 if let Some(transaction_id) = self.initial_transaction_id.take() {
2493 buffer.undo_transaction(transaction_id, cx);
2494 }
2495 });
2496 }
2497
2498 fn update_diff(
2499 &mut self,
2500 edit_range: Range<Anchor>,
2501 line_operations: Vec<LineOperation>,
2502 cx: &mut ModelContext<Self>,
2503 ) {
2504 let old_snapshot = self.snapshot.clone();
2505 let old_range = edit_range.to_point(&old_snapshot);
2506 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2507 let new_range = edit_range.to_point(&new_snapshot);
2508
2509 let mut old_row = old_range.start.row;
2510 let mut new_row = new_range.start.row;
2511
2512 self.diff.deleted_row_ranges.clear();
2513 self.diff.inserted_row_ranges.clear();
2514 for operation in line_operations {
2515 match operation {
2516 LineOperation::Keep { lines } => {
2517 old_row += lines;
2518 new_row += lines;
2519 }
2520 LineOperation::Delete { lines } => {
2521 let old_end_row = old_row + lines - 1;
2522 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2523
2524 if let Some((_, last_deleted_row_range)) =
2525 self.diff.deleted_row_ranges.last_mut()
2526 {
2527 if *last_deleted_row_range.end() + 1 == old_row {
2528 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2529 } else {
2530 self.diff
2531 .deleted_row_ranges
2532 .push((new_row, old_row..=old_end_row));
2533 }
2534 } else {
2535 self.diff
2536 .deleted_row_ranges
2537 .push((new_row, old_row..=old_end_row));
2538 }
2539
2540 old_row += lines;
2541 }
2542 LineOperation::Insert { lines } => {
2543 let new_end_row = new_row + lines - 1;
2544 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2545 let end = new_snapshot.anchor_before(Point::new(
2546 new_end_row,
2547 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2548 ));
2549 self.diff.inserted_row_ranges.push(start..=end);
2550 new_row += lines;
2551 }
2552 }
2553
2554 cx.notify();
2555 }
2556 }
2557}
2558
/// Stream adapter that removes `<|CURSOR|>` markers and a wrapping Markdown
/// code fence from the text chunks of the inner stream.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been emitted yet.
    buffer: String,
    /// True until a newline has been seen in the buffered text.
    first_line: bool,
    /// True when a complete line was emitted and its trailing newline is
    /// still owed to the output.
    line_end: bool,
    /// True when the first line was an opening code fence that was dropped.
    starts_with_code_block: bool,
}
2567
2568impl<T> StripInvalidSpans<T>
2569where
2570 T: Stream<Item = Result<String>>,
2571{
2572 fn new(stream: T) -> Self {
2573 Self {
2574 stream,
2575 stream_done: false,
2576 buffer: String::new(),
2577 first_line: true,
2578 line_end: false,
2579 starts_with_code_block: false,
2580 }
2581 }
2582}
2583
2584impl<T> Stream for StripInvalidSpans<T>
2585where
2586 T: Stream<Item = Result<String>>,
2587{
2588 type Item = Result<String>;
2589
2590 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2591 const CODE_BLOCK_DELIMITER: &str = "```";
2592 const CURSOR_SPAN: &str = "<|CURSOR|>";
2593
2594 let this = unsafe { self.get_unchecked_mut() };
2595 loop {
2596 if !this.stream_done {
2597 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2598 match stream.as_mut().poll_next(cx) {
2599 Poll::Ready(Some(Ok(chunk))) => {
2600 this.buffer.push_str(&chunk);
2601 }
2602 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2603 Poll::Ready(None) => {
2604 this.stream_done = true;
2605 }
2606 Poll::Pending => return Poll::Pending,
2607 }
2608 }
2609
2610 let mut chunk = String::new();
2611 let mut consumed = 0;
2612 if !this.buffer.is_empty() {
2613 let mut lines = this.buffer.split('\n').enumerate().peekable();
2614 while let Some((line_ix, line)) = lines.next() {
2615 if line_ix > 0 {
2616 this.first_line = false;
2617 }
2618
2619 if this.first_line {
2620 let trimmed_line = line.trim();
2621 if lines.peek().is_some() {
2622 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2623 consumed += line.len() + 1;
2624 this.starts_with_code_block = true;
2625 continue;
2626 }
2627 } else if trimmed_line.is_empty()
2628 || prefixes(CODE_BLOCK_DELIMITER)
2629 .any(|prefix| trimmed_line.starts_with(prefix))
2630 {
2631 break;
2632 }
2633 }
2634
2635 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2636 if lines.peek().is_some() {
2637 if this.line_end {
2638 chunk.push('\n');
2639 }
2640
2641 chunk.push_str(&line_without_cursor);
2642 this.line_end = true;
2643 consumed += line.len() + 1;
2644 } else if this.stream_done {
2645 if !this.starts_with_code_block
2646 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2647 {
2648 if this.line_end {
2649 chunk.push('\n');
2650 }
2651
2652 chunk.push_str(&line);
2653 }
2654
2655 consumed += line.len();
2656 } else {
2657 let trimmed_line = line.trim();
2658 if trimmed_line.is_empty()
2659 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2660 || prefixes(CODE_BLOCK_DELIMITER)
2661 .any(|prefix| trimmed_line.ends_with(prefix))
2662 {
2663 break;
2664 } else {
2665 if this.line_end {
2666 chunk.push('\n');
2667 this.line_end = false;
2668 }
2669
2670 chunk.push_str(&line_without_cursor);
2671 consumed += line.len();
2672 }
2673 }
2674 }
2675 }
2676
2677 this.buffer = this.buffer.split_off(consumed);
2678 if !chunk.is_empty() {
2679 return Poll::Ready(Some(Ok(chunk)));
2680 } else if this.stream_done {
2681 return Poll::Ready(None);
2682 }
2683 }
2684 }
2685}
2686
/// Yields every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is not yielded, so an empty or single-character
/// `text` produces nothing. Prefixes always end on a character boundary.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // The byte offset of every character except the first is exactly the
    // end of one proper prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2690
2691fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2692 ranges.sort_unstable_by(|a, b| {
2693 a.start
2694 .cmp(&b.start, buffer)
2695 .then_with(|| b.end.cmp(&a.end, buffer))
2696 });
2697
2698 let mut ix = 0;
2699 while ix + 1 < ranges.len() {
2700 let b = ranges[ix + 1].clone();
2701 let a = &mut ranges[ix];
2702 if a.end.cmp(&b.start, buffer).is_gt() {
2703 if a.end.cmp(&b.end, buffer).is_lt() {
2704 a.end = b.end;
2705 }
2706 ranges.remove(ix + 1);
2707 } else {
2708 ix += 1;
2709 }
2710 }
2711}
2712
2713#[cfg(test)]
2714mod tests {
2715 use super::*;
2716 use futures::stream::{self};
2717 use gpui::{Context, TestAppContext};
2718 use indoc::indoc;
2719 use language::{
2720 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2721 Point,
2722 };
2723 use language_model::LanguageModelRegistry;
2724 use rand::prelude::*;
2725 use serde::Serialize;
2726 use settings::SettingsStore;
2727 use std::{future, sync::Arc};
2728
// Serializable request stand-in.
// NOTE(review): nothing in this test module references this type; confirm
// whether it is still needed.
#[derive(Serialize)]
pub struct DummyCompletionRequest {
    pub name: String,
}
2733
2734 #[gpui::test(iterations = 10)]
2735 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
2736 cx.set_global(cx.update(SettingsStore::test));
2737 cx.update(language_model::LanguageModelRegistry::test);
2738 cx.update(language_settings::init);
2739
2740 let text = indoc! {"
2741 fn main() {
2742 let x = 0;
2743 for _ in 0..10 {
2744 x += 1;
2745 }
2746 }
2747 "};
2748 let buffer =
2749 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2750 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2751 let range = buffer.read_with(cx, |buffer, cx| {
2752 let snapshot = buffer.snapshot(cx);
2753 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
2754 });
2755 let codegen =
2756 cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
2757
2758 let (chunks_tx, chunks_rx) = mpsc::unbounded();
2759 codegen.update(cx, |codegen, cx| {
2760 codegen.handle_stream(
2761 String::new(),
2762 range,
2763 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2764 cx,
2765 )
2766 });
2767
2768 let mut new_text = concat!(
2769 " let mut x = 0;\n",
2770 " while x < 10 {\n",
2771 " x += 1;\n",
2772 " }",
2773 );
2774 while !new_text.is_empty() {
2775 let max_len = cmp::min(new_text.len(), 10);
2776 let len = rng.gen_range(1..=max_len);
2777 let (chunk, suffix) = new_text.split_at(len);
2778 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2779 new_text = suffix;
2780 cx.background_executor.run_until_parked();
2781 }
2782 drop(chunks_tx);
2783 cx.background_executor.run_until_parked();
2784
2785 assert_eq!(
2786 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2787 indoc! {"
2788 fn main() {
2789 let mut x = 0;
2790 while x < 10 {
2791 x += 1;
2792 }
2793 }
2794 "}
2795 );
2796 }
2797
2798 #[gpui::test(iterations = 10)]
2799 async fn test_autoindent_when_generating_past_indentation(
2800 cx: &mut TestAppContext,
2801 mut rng: StdRng,
2802 ) {
2803 cx.set_global(cx.update(SettingsStore::test));
2804 cx.update(language_settings::init);
2805
2806 let text = indoc! {"
2807 fn main() {
2808 le
2809 }
2810 "};
2811 let buffer =
2812 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2813 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2814 let range = buffer.read_with(cx, |buffer, cx| {
2815 let snapshot = buffer.snapshot(cx);
2816 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
2817 });
2818 let codegen =
2819 cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
2820
2821 let (chunks_tx, chunks_rx) = mpsc::unbounded();
2822 codegen.update(cx, |codegen, cx| {
2823 codegen.handle_stream(
2824 String::new(),
2825 range.clone(),
2826 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2827 cx,
2828 )
2829 });
2830
2831 cx.background_executor.run_until_parked();
2832
2833 let mut new_text = concat!(
2834 "t mut x = 0;\n",
2835 "while x < 10 {\n",
2836 " x += 1;\n",
2837 "}", //
2838 );
2839 while !new_text.is_empty() {
2840 let max_len = cmp::min(new_text.len(), 10);
2841 let len = rng.gen_range(1..=max_len);
2842 let (chunk, suffix) = new_text.split_at(len);
2843 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2844 new_text = suffix;
2845 cx.background_executor.run_until_parked();
2846 }
2847 drop(chunks_tx);
2848 cx.background_executor.run_until_parked();
2849
2850 assert_eq!(
2851 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2852 indoc! {"
2853 fn main() {
2854 let mut x = 0;
2855 while x < 10 {
2856 x += 1;
2857 }
2858 }
2859 "}
2860 );
2861 }
2862
2863 #[gpui::test(iterations = 10)]
2864 async fn test_autoindent_when_generating_before_indentation(
2865 cx: &mut TestAppContext,
2866 mut rng: StdRng,
2867 ) {
2868 cx.update(LanguageModelRegistry::test);
2869 cx.set_global(cx.update(SettingsStore::test));
2870 cx.update(language_settings::init);
2871
2872 let text = concat!(
2873 "fn main() {\n",
2874 " \n",
2875 "}\n" //
2876 );
2877 let buffer =
2878 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
2879 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
2880 let range = buffer.read_with(cx, |buffer, cx| {
2881 let snapshot = buffer.snapshot(cx);
2882 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
2883 });
2884 let codegen =
2885 cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));
2886
2887 let (chunks_tx, chunks_rx) = mpsc::unbounded();
2888 codegen.update(cx, |codegen, cx| {
2889 codegen.handle_stream(
2890 String::new(),
2891 range.clone(),
2892 future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
2893 cx,
2894 )
2895 });
2896
2897 cx.background_executor.run_until_parked();
2898
2899 let mut new_text = concat!(
2900 "let mut x = 0;\n",
2901 "while x < 10 {\n",
2902 " x += 1;\n",
2903 "}", //
2904 );
2905 while !new_text.is_empty() {
2906 let max_len = cmp::min(new_text.len(), 10);
2907 let len = rng.gen_range(1..=max_len);
2908 let (chunk, suffix) = new_text.split_at(len);
2909 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
2910 new_text = suffix;
2911 cx.background_executor.run_until_parked();
2912 }
2913 drop(chunks_tx);
2914 cx.background_executor.run_until_parked();
2915
2916 assert_eq!(
2917 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
2918 indoc! {"
2919 fn main() {
2920 let mut x = 0;
2921 while x < 10 {
2922 x += 1;
2923 }
2924 }
2925 "}
2926 );
2927 }
2928
// Replaces a selection whose first line is unindented but whose later lines
// are indented with tabs, and checks that the tabs of the generated text are
// kept as tabs.
//
// The buffer has no language attached, and the whole replacement is sent as
// a single chunk.
// NOTE(review): the test takes no random generator, so the ten iterations
// look identical. Confirm whether `iterations = 10` is needed.
#[gpui::test(iterations = 10)]
async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
    cx.update(LanguageModelRegistry::test);
    cx.set_global(cx.update(SettingsStore::test));
    cx.update(language_settings::init);

    let text = indoc! {"
        func main() {
        \tx := 0
        \tfor i := 0; i < 10; i++ {
        \t\tx++
        \t}
        }
    "};
    let buffer = cx.new_model(|cx| Buffer::local(text, cx));
    let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
    // Select from the start of the file to the end of the `\t}` row.
    let range = buffer.read_with(cx, |buffer, cx| {
        let snapshot = buffer.snapshot(cx);
        snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
    });
    let codegen =
        cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

    let (chunks_tx, chunks_rx) = mpsc::unbounded();
    codegen.update(cx, |codegen, cx| {
        codegen.handle_stream(
            String::new(),
            range.clone(),
            future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
            cx,
        )
    });

    let new_text = concat!(
        "func main() {\n",
        "\tx := 0\n",
        "\tfor x < 10 {\n",
        "\t\tx++\n",
        "\t}", //
    );
    chunks_tx.unbounded_send(new_text.to_string()).unwrap();
    drop(chunks_tx);
    cx.background_executor.run_until_parked();

    assert_eq!(
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
        indoc! {"
            func main() {
            \tx := 0
            \tfor x < 10 {
            \t\tx++
            \t}
            }
        "}
    );
}
2985
2986 #[gpui::test]
2987 async fn test_strip_invalid_spans_from_codeblock() {
2988 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
2989 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
2990 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
2991 assert_chunks(
2992 "```html\n```js\nLorem ipsum dolor\n```\n```",
2993 "```js\nLorem ipsum dolor\n```",
2994 )
2995 .await;
2996 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
2997 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
2998 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
2999 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
3000
3001 async fn assert_chunks(text: &str, expected_text: &str) {
3002 for chunk_size in 1..=text.len() {
3003 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
3004 .map(|chunk| chunk.unwrap())
3005 .collect::<String>()
3006 .await;
3007 assert_eq!(
3008 actual_text, expected_text,
3009 "failed to strip invalid spans, chunk size: {}",
3010 chunk_size
3011 );
3012 }
3013 }
3014
3015 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
3016 stream::iter(
3017 text.chars()
3018 .collect::<Vec<_>>()
3019 .chunks(size)
3020 .map(|chunk| Ok(chunk.iter().collect::<String>()))
3021 .collect::<Vec<_>>(),
3022 )
3023 }
3024 }
3025
3026 fn rust_lang() -> Language {
3027 Language::new(
3028 LanguageConfig {
3029 name: "Rust".into(),
3030 matcher: LanguageMatcher {
3031 path_suffixes: vec!["rs".to_string()],
3032 ..Default::default()
3033 },
3034 ..Default::default()
3035 },
3036 Some(tree_sitter_rust::language()),
3037 )
3038 .with_indents_query(
3039 r#"
3040 (call_expression) @indent
3041 (field_expression) @indent
3042 (_ "(" ")" @end) @indent
3043 (_ "{" "}" @end) @indent
3044 "#,
3045 )
3046 .unwrap()
3047 }
3048}