1use crate::{
2 humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
3 Hunk, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, Selection, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use similar::TextDiff;
39use smol::future::FutureExt;
40use std::{
41 cmp,
42 future::{self, Future},
43 mem,
44 ops::{Range, RangeInclusive},
45 pin::Pin,
46 sync::Arc,
47 task::{self, Poll},
48 time::{Duration, Instant},
49};
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
56 cx.set_global(InlineAssistant::new(fs, telemetry));
57}
58
/// Maximum number of prompts retained in `InlineAssistant::prompt_history`;
/// `start_assist` drops the oldest entry once the history grows past this length.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
60
/// Global state tracking every inline assist that is currently alive, across all editors.
pub struct InlineAssistant {
    /// Id that will be handed to the next assist created by `assist`/`suggest_assist`.
    next_assist_id: InlineAssistId,
    /// Id that will be handed to the next assist group.
    next_assist_group_id: InlineAssistGroupId,
    /// All live assists, keyed by their id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, subscriptions), keyed by a weak
    /// handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists that were created by the same call, keyed by group id.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Prompts that have been started, oldest first; capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry handle that is cloned into each `Codegen` created for an assist.
    telemetry: Option<Arc<Telemetry>>,
    /// File system handle that is cloned into each `PromptEditor`.
    fs: Arc<dyn Fs>,
}

// Allows the assistant to be stored on and retrieved from the app context as a global.
impl Global for InlineAssistant {}
73
74impl InlineAssistant {
75 pub fn new(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>) -> Self {
76 Self {
77 next_assist_id: InlineAssistId::default(),
78 next_assist_group_id: InlineAssistGroupId::default(),
79 assists: HashMap::default(),
80 assists_by_editor: HashMap::default(),
81 assist_groups: HashMap::default(),
82 prompt_history: VecDeque::default(),
83 telemetry: Some(telemetry),
84 fs,
85 }
86 }
87
    /// Opens inline assists for the editor's current selections.
    ///
    /// Each non-empty selection is widened to whole lines, touching or overlapping
    /// selections are merged, and one assist (prompt editor + code generator + editor
    /// blocks) is created per excerpt range covered by the resulting selections. All
    /// assists created by one call share a single prompt buffer and a single assist group.
    /// Finally, the assist matching the newest selection (if any) receives focus.
    ///
    /// * `workspace` / `assistant_panel` are forwarded to the prompt editor and the assist.
    /// * `initial_prompt` seeds the shared prompt buffer (empty when `None`).
    ///
    /// NOTE(review): panics on `newest_selection.unwrap()` if the editor reports no
    /// selections at all — presumably an editor always has at least one; confirm.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        assistant_panel: Option<&View<AssistantPanel>>,
        initial_prompt: Option<String>,
        cx: &mut WindowContext,
    ) {
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        let mut selections = Vec::<Selection<Point>>::new();
        let mut newest_selection = None;
        for mut selection in editor.read(cx).selections.all::<Point>(cx) {
            // Expand non-empty selections so they cover entire lines.
            if selection.end > selection.start {
                selection.start.column = 0;
                // If the selection ends at the start of the line, we don't want to include it.
                if selection.end.column == 0 {
                    selection.end.row -= 1;
                }
                selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
            }

            // After expansion a selection may touch or overlap the previous one; fold it
            // into the previous selection instead of creating a second assist.
            if let Some(prev_selection) = selections.last_mut() {
                if selection.start <= prev_selection.end {
                    prev_selection.end = selection.end;
                    continue;
                }
            }

            // Track the kept (non-merged) selection with the highest id.
            let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
            if selection.id > latest_selection.id {
                *latest_selection = selection.clone();
            }
            selections.push(selection);
        }
        let newest_selection = newest_selection.unwrap();

        // Convert the selections into anchor ranges, one per excerpt they intersect.
        let mut codegen_ranges = Vec::new();
        for (excerpt_id, buffer, buffer_range) in
            snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
                snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
            }))
        {
            let start = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_before(buffer_range.start),
            };
            let end = Anchor {
                buffer_id: Some(buffer.remote_id()),
                excerpt_id,
                text_anchor: buffer.anchor_after(buffer_range.end),
            };
            codegen_ranges.push(start..end);
        }

        let assist_group_id = self.next_assist_group_id.post_inc();
        // A single prompt buffer is shared by every prompt editor created below.
        let prompt_buffer =
            cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));

        let mut assists = Vec::new();
        let mut assist_to_focus = None;
        for range in codegen_ranges {
            let assist_id = self.next_assist_id.post_inc();
            let codegen = cx.new_model(|cx| {
                Codegen::new(
                    editor.read(cx).buffer().clone(),
                    range.clone(),
                    None,
                    self.telemetry.clone(),
                    cx,
                )
            });

            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
            let prompt_editor = cx.new_view(|cx| {
                PromptEditor::new(
                    assist_id,
                    gutter_dimensions.clone(),
                    self.prompt_history.clone(),
                    prompt_buffer.clone(),
                    codegen.clone(),
                    editor,
                    assistant_panel,
                    workspace.clone(),
                    self.fs.clone(),
                    cx,
                )
            });

            // Focus the first range whose boundary coincides with the head of the newest
            // selection: its start when that selection is reversed, its end otherwise.
            if assist_to_focus.is_none() {
                let focus_assist = if newest_selection.reversed {
                    range.start.to_point(&snapshot) == newest_selection.start
                } else {
                    range.end.to_point(&snapshot) == newest_selection.end
                };
                if focus_assist {
                    assist_to_focus = Some(assist_id);
                }
            }

            let [prompt_block_id, end_block_id] =
                self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

            assists.push((
                assist_id,
                range,
                prompt_editor,
                prompt_block_id,
                end_block_id,
            ));
        }

        // Register every new assist with the editor's bookkeeping and with the new group.
        let editor_assists = self
            .assists_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
        let mut assist_group = InlineAssistGroup::new();
        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
            self.assists.insert(
                assist_id,
                InlineAssist::new(
                    assist_id,
                    assist_group_id,
                    assistant_panel.is_some(),
                    editor,
                    &prompt_editor,
                    prompt_block_id,
                    end_block_id,
                    range,
                    prompt_editor.read(cx).codegen.clone(),
                    workspace.clone(),
                    cx,
                ),
            );
            assist_group.assist_ids.push(assist_id);
            editor_assists.assist_ids.push(assist_id);
        }
        self.assist_groups.insert(assist_group_id, assist_group);

        if let Some(assist_id) = assist_to_focus {
            self.focus_assist(assist_id, cx);
        }
    }
233
234 #[allow(clippy::too_many_arguments)]
235 pub fn suggest_assist(
236 &mut self,
237 editor: &View<Editor>,
238 mut range: Range<Anchor>,
239 initial_prompt: String,
240 initial_transaction_id: Option<TransactionId>,
241 workspace: Option<WeakView<Workspace>>,
242 assistant_panel: Option<&View<AssistantPanel>>,
243 cx: &mut WindowContext,
244 ) -> InlineAssistId {
245 let assist_group_id = self.next_assist_group_id.post_inc();
246 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
247 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
248
249 let assist_id = self.next_assist_id.post_inc();
250
251 let buffer = editor.read(cx).buffer().clone();
252 {
253 let snapshot = buffer.read(cx).read(cx);
254 range.start = range.start.bias_left(&snapshot);
255 range.end = range.end.bias_right(&snapshot);
256 }
257
258 let codegen = cx.new_model(|cx| {
259 Codegen::new(
260 editor.read(cx).buffer().clone(),
261 range.clone(),
262 initial_transaction_id,
263 self.telemetry.clone(),
264 cx,
265 )
266 });
267
268 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
269 let prompt_editor = cx.new_view(|cx| {
270 PromptEditor::new(
271 assist_id,
272 gutter_dimensions.clone(),
273 self.prompt_history.clone(),
274 prompt_buffer.clone(),
275 codegen.clone(),
276 editor,
277 assistant_panel,
278 workspace.clone(),
279 self.fs.clone(),
280 cx,
281 )
282 });
283
284 let [prompt_block_id, end_block_id] =
285 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
286
287 let editor_assists = self
288 .assists_by_editor
289 .entry(editor.downgrade())
290 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
291
292 let mut assist_group = InlineAssistGroup::new();
293 self.assists.insert(
294 assist_id,
295 InlineAssist::new(
296 assist_id,
297 assist_group_id,
298 assistant_panel.is_some(),
299 editor,
300 &prompt_editor,
301 prompt_block_id,
302 end_block_id,
303 range,
304 prompt_editor.read(cx).codegen.clone(),
305 workspace.clone(),
306 cx,
307 ),
308 );
309 assist_group.assist_ids.push(assist_id);
310 editor_assists.assist_ids.push(assist_id);
311 self.assist_groups.insert(assist_group_id, assist_group);
312 assist_id
313 }
314
315 fn insert_assist_blocks(
316 &self,
317 editor: &View<Editor>,
318 range: &Range<Anchor>,
319 prompt_editor: &View<PromptEditor>,
320 cx: &mut WindowContext,
321 ) -> [CustomBlockId; 2] {
322 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
323 prompt_editor
324 .editor
325 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
326 });
327 let assist_blocks = vec![
328 BlockProperties {
329 style: BlockStyle::Sticky,
330 position: range.start,
331 height: prompt_editor_height,
332 render: build_assist_editor_renderer(prompt_editor),
333 disposition: BlockDisposition::Above,
334 },
335 BlockProperties {
336 style: BlockStyle::Sticky,
337 position: range.end,
338 height: 1,
339 render: Box::new(|cx| {
340 v_flex()
341 .h_full()
342 .w_full()
343 .border_t_1()
344 .border_color(cx.theme().status().info_border)
345 .into_any_element()
346 }),
347 disposition: BlockDisposition::Below,
348 },
349 ];
350
351 editor.update(cx, |editor, cx| {
352 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
353 [block_ids[0], block_ids[1]]
354 })
355 }
356
357 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
358 let assist = &self.assists[&assist_id];
359 let Some(decorations) = assist.decorations.as_ref() else {
360 return;
361 };
362 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
363 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
364
365 assist_group.active_assist_id = Some(assist_id);
366 if assist_group.linked {
367 for assist_id in &assist_group.assist_ids {
368 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
369 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
370 prompt_editor.set_show_cursor_when_unfocused(true, cx)
371 });
372 }
373 }
374 }
375
376 assist
377 .editor
378 .update(cx, |editor, cx| {
379 let scroll_top = editor.scroll_position(cx).y;
380 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
381 let prompt_row = editor
382 .row_for_block(decorations.prompt_block_id, cx)
383 .unwrap()
384 .0 as f32;
385
386 if (scroll_top..scroll_bottom).contains(&prompt_row) {
387 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
388 assist_id,
389 distance_from_top: prompt_row - scroll_top,
390 });
391 } else {
392 editor_assists.scroll_lock = None;
393 }
394 })
395 .ok();
396 }
397
398 fn handle_prompt_editor_focus_out(
399 &mut self,
400 assist_id: InlineAssistId,
401 cx: &mut WindowContext,
402 ) {
403 let assist = &self.assists[&assist_id];
404 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
405 if assist_group.active_assist_id == Some(assist_id) {
406 assist_group.active_assist_id = None;
407 if assist_group.linked {
408 for assist_id in &assist_group.assist_ids {
409 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
410 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
411 prompt_editor.set_show_cursor_when_unfocused(false, cx)
412 });
413 }
414 }
415 }
416 }
417 }
418
419 fn handle_prompt_editor_event(
420 &mut self,
421 prompt_editor: View<PromptEditor>,
422 event: &PromptEditorEvent,
423 cx: &mut WindowContext,
424 ) {
425 let assist_id = prompt_editor.read(cx).id;
426 match event {
427 PromptEditorEvent::StartRequested => {
428 self.start_assist(assist_id, cx);
429 }
430 PromptEditorEvent::StopRequested => {
431 self.stop_assist(assist_id, cx);
432 }
433 PromptEditorEvent::ConfirmRequested => {
434 self.finish_assist(assist_id, false, cx);
435 }
436 PromptEditorEvent::CancelRequested => {
437 self.finish_assist(assist_id, true, cx);
438 }
439 PromptEditorEvent::DismissRequested => {
440 self.dismiss_assist(assist_id, cx);
441 }
442 }
443 }
444
445 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
446 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
447 return;
448 };
449
450 let editor = editor.read(cx);
451 if editor.selections.count() == 1 {
452 let selection = editor.selections.newest::<usize>(cx);
453 let buffer = editor.buffer().read(cx).snapshot(cx);
454 for assist_id in &editor_assists.assist_ids {
455 let assist = &self.assists[assist_id];
456 let assist_range = assist.range.to_offset(&buffer);
457 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
458 {
459 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
460 self.dismiss_assist(*assist_id, cx);
461 } else {
462 self.finish_assist(*assist_id, false, cx);
463 }
464
465 return;
466 }
467 }
468 }
469
470 cx.propagate();
471 }
472
473 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
474 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
475 return;
476 };
477
478 let editor = editor.read(cx);
479 if editor.selections.count() == 1 {
480 let selection = editor.selections.newest::<usize>(cx);
481 let buffer = editor.buffer().read(cx).snapshot(cx);
482 for assist_id in &editor_assists.assist_ids {
483 let assist = &self.assists[assist_id];
484 let assist_range = assist.range.to_offset(&buffer);
485 if assist.decorations.is_some()
486 && assist_range.contains(&selection.start)
487 && assist_range.contains(&selection.end)
488 {
489 self.focus_assist(*assist_id, cx);
490 return;
491 }
492 }
493 }
494
495 cx.propagate();
496 }
497
498 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
499 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
500 for assist_id in editor_assists.assist_ids.clone() {
501 self.finish_assist(assist_id, true, cx);
502 }
503 }
504 }
505
506 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
507 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
508 return;
509 };
510 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
511 return;
512 };
513 let assist = &self.assists[&scroll_lock.assist_id];
514 let Some(decorations) = assist.decorations.as_ref() else {
515 return;
516 };
517
518 editor.update(cx, |editor, cx| {
519 let scroll_position = editor.scroll_position(cx);
520 let target_scroll_top = editor
521 .row_for_block(decorations.prompt_block_id, cx)
522 .unwrap()
523 .0 as f32
524 - scroll_lock.distance_from_top;
525 if target_scroll_top != scroll_position.y {
526 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
527 }
528 });
529 }
530
    /// Reacts to events emitted by an editor that has inline assists.
    ///
    /// * `Saved`: finishes (without undo) every assist whose codegen is `Done`.
    /// * `Edited`: finishes (without undo) every assist whose codegen is `Done` or
    ///   `Error` and whose range overlaps a range edited by the transaction.
    /// * `ScrollPositionChanged`: drops the scroll lock if the prompt block is no longer
    ///   at the recorded distance from the top of the viewport.
    /// * `SelectionsChanged`: drops the scroll lock unless one of the editor's prompt
    ///   editors currently has focus.
    ///
    /// Events for editors without registered assists are ignored.
    fn handle_editor_event(
        &mut self,
        editor: View<Editor>,
        event: &EditorEvent,
        cx: &mut WindowContext,
    ) {
        let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
            return;
        };

        match event {
            EditorEvent::Saved => {
                // Iterate over a copy: `finish_assist` mutates the per-editor id list.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let CodegenStatus::Done = &assist.codegen.read(cx).status {
                        self.finish_assist(assist_id, false, cx)
                    }
                }
            }
            EditorEvent::Edited { transaction_id } => {
                let buffer = editor.read(cx).buffer().read(cx);
                let edited_ranges =
                    buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
                let snapshot = buffer.snapshot(cx);

                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    // Only assists that are no longer generating are affected by edits.
                    if matches!(
                        assist.codegen.read(cx).status,
                        CodegenStatus::Error(_) | CodegenStatus::Done
                    ) {
                        let assist_range = assist.range.to_offset(&snapshot);
                        if edited_ranges
                            .iter()
                            .any(|range| range.overlaps(&assist_range))
                        {
                            self.finish_assist(assist_id, false, cx);
                        }
                    }
                }
            }
            EditorEvent::ScrollPositionChanged { .. } => {
                if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                    let assist = &self.assists[&scroll_lock.assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        let distance_from_top = editor.update(cx, |editor, cx| {
                            let scroll_top = editor.scroll_position(cx).y;
                            let prompt_row = editor
                                .row_for_block(decorations.prompt_block_id, cx)
                                .unwrap()
                                .0 as f32;
                            prompt_row - scroll_top
                        });

                        // A different distance means the scroll did not come from
                        // re-applying the lock, so the lock is released.
                        if distance_from_top != scroll_lock.distance_from_top {
                            editor_assists.scroll_lock = None;
                        }
                    }
                }
            }
            EditorEvent::SelectionsChanged { .. } => {
                // Keep the lock while any of this editor's prompt editors is focused.
                for assist_id in editor_assists.assist_ids.clone() {
                    let assist = &self.assists[&assist_id];
                    if let Some(decorations) = assist.decorations.as_ref() {
                        if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                            return;
                        }
                    }
                }

                editor_assists.scroll_lock = None;
            }
            _ => {}
        }
    }
606
607 fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
608 if let Some(assist) = self.assists.get(&assist_id) {
609 let assist_group_id = assist.group_id;
610 if self.assist_groups[&assist_group_id].linked {
611 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
612 self.finish_assist(assist_id, undo, cx);
613 }
614 return;
615 }
616 }
617
618 self.dismiss_assist(assist_id, cx);
619
620 if let Some(assist) = self.assists.remove(&assist_id) {
621 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
622 {
623 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
624 if entry.get().assist_ids.is_empty() {
625 entry.remove();
626 }
627 }
628
629 if let hash_map::Entry::Occupied(mut entry) =
630 self.assists_by_editor.entry(assist.editor.clone())
631 {
632 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
633 if entry.get().assist_ids.is_empty() {
634 entry.remove();
635 if let Some(editor) = assist.editor.upgrade() {
636 self.update_editor_highlights(&editor, cx);
637 }
638 } else {
639 entry.get().highlight_updates.send(()).ok();
640 }
641 }
642
643 if undo {
644 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
645 }
646 }
647 }
648
649 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
650 let Some(assist) = self.assists.get_mut(&assist_id) else {
651 return false;
652 };
653 let Some(editor) = assist.editor.upgrade() else {
654 return false;
655 };
656 let Some(decorations) = assist.decorations.take() else {
657 return false;
658 };
659
660 editor.update(cx, |editor, cx| {
661 let mut to_remove = decorations.removed_line_block_ids;
662 to_remove.insert(decorations.prompt_block_id);
663 to_remove.insert(decorations.end_block_id);
664 editor.remove_blocks(to_remove, None, cx);
665 });
666
667 if decorations
668 .prompt_editor
669 .focus_handle(cx)
670 .contains_focused(cx)
671 {
672 self.focus_next_assist(assist_id, cx);
673 }
674
675 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
676 if editor_assists
677 .scroll_lock
678 .as_ref()
679 .map_or(false, |lock| lock.assist_id == assist_id)
680 {
681 editor_assists.scroll_lock = None;
682 }
683 editor_assists.highlight_updates.send(()).ok();
684 }
685
686 true
687 }
688
689 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
690 let Some(assist) = self.assists.get(&assist_id) else {
691 return;
692 };
693
694 let assist_group = &self.assist_groups[&assist.group_id];
695 let assist_ix = assist_group
696 .assist_ids
697 .iter()
698 .position(|id| *id == assist_id)
699 .unwrap();
700 let assist_ids = assist_group
701 .assist_ids
702 .iter()
703 .skip(assist_ix + 1)
704 .chain(assist_group.assist_ids.iter().take(assist_ix));
705
706 for assist_id in assist_ids {
707 let assist = &self.assists[assist_id];
708 if assist.decorations.is_some() {
709 self.focus_assist(*assist_id, cx);
710 return;
711 }
712 }
713
714 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
715 }
716
717 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
718 let assist = &self.assists[&assist_id];
719 let Some(editor) = assist.editor.upgrade() else {
720 return;
721 };
722
723 if let Some(decorations) = assist.decorations.as_ref() {
724 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
725 prompt_editor.editor.update(cx, |editor, cx| {
726 editor.focus(cx);
727 editor.select_all(&SelectAll, cx);
728 })
729 });
730 }
731
732 let position = assist.range.start;
733 editor.update(cx, |editor, cx| {
734 editor.change_selections(None, cx, |selections| {
735 selections.select_anchor_ranges([position..position])
736 });
737
738 let mut scroll_target_top;
739 let mut scroll_target_bottom;
740 if let Some(decorations) = assist.decorations.as_ref() {
741 scroll_target_top = editor
742 .row_for_block(decorations.prompt_block_id, cx)
743 .unwrap()
744 .0 as f32;
745 scroll_target_bottom = editor
746 .row_for_block(decorations.end_block_id, cx)
747 .unwrap()
748 .0 as f32;
749 } else {
750 let snapshot = editor.snapshot(cx);
751 let start_row = assist
752 .range
753 .start
754 .to_display_point(&snapshot.display_snapshot)
755 .row();
756 scroll_target_top = start_row.0 as f32;
757 scroll_target_bottom = scroll_target_top + 1.;
758 }
759 scroll_target_top -= editor.vertical_scroll_margin() as f32;
760 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
761
762 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
763 let scroll_top = editor.scroll_position(cx).y;
764 let scroll_bottom = scroll_top + height_in_lines;
765
766 if scroll_target_top < scroll_top {
767 editor.set_scroll_position(point(0., scroll_target_top), cx);
768 } else if scroll_target_bottom > scroll_bottom {
769 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
770 editor
771 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
772 } else {
773 editor.set_scroll_position(point(0., scroll_target_top), cx);
774 }
775 }
776 });
777 }
778
779 fn unlink_assist_group(
780 &mut self,
781 assist_group_id: InlineAssistGroupId,
782 cx: &mut WindowContext,
783 ) -> Vec<InlineAssistId> {
784 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
785 assist_group.linked = false;
786 for assist_id in &assist_group.assist_ids {
787 let assist = self.assists.get_mut(assist_id).unwrap();
788 if let Some(editor_decorations) = assist.decorations.as_ref() {
789 editor_decorations
790 .prompt_editor
791 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
792 }
793 }
794 assist_group.assist_ids.clone()
795 }
796
797 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
798 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
799 assist
800 } else {
801 return;
802 };
803
804 let assist_group_id = assist.group_id;
805 if self.assist_groups[&assist_group_id].linked {
806 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
807 self.start_assist(assist_id, cx);
808 }
809 return;
810 }
811
812 let Some(user_prompt) = assist.user_prompt(cx) else {
813 return;
814 };
815
816 self.prompt_history.retain(|prompt| *prompt != user_prompt);
817 self.prompt_history.push_back(user_prompt.clone());
818 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
819 self.prompt_history.pop_front();
820 }
821
822 let assistant_panel_context = assist.assistant_panel_context(cx);
823
824 assist
825 .codegen
826 .update(cx, |codegen, cx| {
827 codegen.start(
828 assist.range.clone(),
829 user_prompt,
830 assistant_panel_context,
831 cx,
832 )
833 })
834 .log_err();
835 }
836
837 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
838 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
839 assist
840 } else {
841 return;
842 };
843
844 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
845 }
846
    /// Recomputes every assist-related highlight for `editor` from scratch.
    ///
    /// Collects, across all assists registered for the editor:
    /// * gutter "pending" ranges: from the codegen's edit position (or the range start
    ///   when there is none) to the range end, when non-empty;
    /// * gutter "transformed" ranges: from the range start to the edit position, when
    ///   non-empty;
    /// * foreground ranges: the codegen's last equal ranges, rendered faded out;
    /// * inserted row ranges from the codegen's diff, only for assists that still have
    ///   decorations.
    ///
    /// Each highlight kind is cleared when its collection is empty. If the editor has no
    /// registered assists, all collections are empty and everything is cleared.
    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        // Fallback list used when the editor has no registered assists.
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer.read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

                // The part of the range that codegen has not reached yet.
                let pending_range =
                    codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                // The part of the range that lies before the current edit position.
                if let Some(edit_position) = codegen.edit_position {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                if assist.decorations.is_some() {
                    inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            // Marker types used only as type-level keys for the gutter highlight kinds.
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            // Row highlights are always rebuilt: clear first, then re-add each range.
            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    Some(cx.theme().status().info_background),
                    false,
                    cx,
                );
            }
        });
    }
936
    /// Rebuilds the blocks that display lines deleted by an assist's codegen.
    ///
    /// Removes the assist's previously inserted removed-line blocks, then, for every
    /// entry in the codegen diff's `deleted_row_ranges`, inserts a block above the entry's
    /// new position that embeds a read-only editor showing the deleted rows taken from
    /// the codegen's old buffer. The new block ids are stored back on the assist's
    /// decorations. Does nothing when the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Clone what is needed from the codegen up front so that `cx` can be borrowed
        // mutably by `editor.update` below.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Translate the deleted rows (inclusive range in the old snapshot) into
                // buffer offsets spanning from the first row's start to the last row's end.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A nested read-only editor that shows just the deleted text.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // One block row per row of deleted text.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                });
            }

            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1027}
1028
/// Bookkeeping for all inline assists that live in one editor.
struct EditorInlineAssists {
    /// Ids of the assists registered for this editor.
    assist_ids: Vec<InlineAssistId>,
    /// When set, keeps one assist's prompt block at a fixed distance from the top of
    /// the viewport as the editor changes.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending on this channel wakes the `_update_highlights` task, which recomputes the
    /// editor's highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that refreshes highlights whenever `highlight_updates` fires.
    _update_highlights: Task<Result<()>>,
    /// Editor observers, event subscriptions and action registrations owned by this value.
    _subscriptions: Vec<gpui::Subscription>,
}
1036
/// Records that an editor's scroll position should follow one assist's prompt block.
struct InlineAssistScrollLock {
    /// The assist whose prompt block the scroll position is locked to.
    assist_id: InlineAssistId,
    /// Row of the prompt block minus the editor's scroll top at the time the lock was
    /// taken (see `handle_prompt_editor_focus_in`).
    distance_from_top: f32,
}
1041
1042impl EditorInlineAssists {
1043 #[allow(clippy::too_many_arguments)]
1044 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1045 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1046 Self {
1047 assist_ids: Vec::new(),
1048 scroll_lock: None,
1049 highlight_updates: highlight_updates_tx,
1050 _update_highlights: cx.spawn(|mut cx| {
1051 let editor = editor.downgrade();
1052 async move {
1053 while let Ok(()) = highlight_updates_rx.changed().await {
1054 let editor = editor.upgrade().context("editor was dropped")?;
1055 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1056 assistant.update_editor_highlights(&editor, cx);
1057 })?;
1058 }
1059 Ok(())
1060 }
1061 }),
1062 _subscriptions: vec![
1063 cx.observe_release(editor, {
1064 let editor = editor.downgrade();
1065 |_, cx| {
1066 InlineAssistant::update_global(cx, |this, cx| {
1067 this.handle_editor_release(editor, cx);
1068 })
1069 }
1070 }),
1071 cx.observe(editor, move |editor, cx| {
1072 InlineAssistant::update_global(cx, |this, cx| {
1073 this.handle_editor_change(editor, cx)
1074 })
1075 }),
1076 cx.subscribe(editor, move |editor, event, cx| {
1077 InlineAssistant::update_global(cx, |this, cx| {
1078 this.handle_editor_event(editor, event, cx)
1079 })
1080 }),
1081 editor.update(cx, |editor, cx| {
1082 let editor_handle = cx.view().downgrade();
1083 editor.register_action(
1084 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1085 InlineAssistant::update_global(cx, |this, cx| {
1086 if let Some(editor) = editor_handle.upgrade() {
1087 this.handle_editor_newline(editor, cx)
1088 }
1089 })
1090 },
1091 )
1092 }),
1093 editor.update(cx, |editor, cx| {
1094 let editor_handle = cx.view().downgrade();
1095 editor.register_action(
1096 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1097 InlineAssistant::update_global(cx, |this, cx| {
1098 if let Some(editor) = editor_handle.upgrade() {
1099 this.handle_editor_cancel(editor, cx)
1100 }
1101 })
1102 },
1103 )
1104 }),
1105 ],
1106 }
1107 }
1108}
1109
/// A set of inline assists that are tracked together.
struct InlineAssistGroup {
    /// Ids of the assists that belong to the group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group is still linked; new groups start out linked.
    // NOTE(review): presumably "linked" means the group's prompt editors share
    // one prompt buffer (see `PromptEditor::unlink`) — confirm against callers.
    linked: bool,
    /// The assist in the group that is currently active, if any.
    active_assist_id: Option<InlineAssistId>,
}
1115
1116impl InlineAssistGroup {
1117 fn new() -> Self {
1118 Self {
1119 assist_ids: Vec::new(),
1120 linked: true,
1121 active_assist_id: None,
1122 }
1123 }
1124}
1125
1126fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1127 let editor = editor.clone();
1128 Box::new(move |cx: &mut BlockContext| {
1129 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1130 editor.clone().into_any_element()
1131 })
1132}
1133
/// Where a newline is initially inserted relative to the assist's position.
// NOTE(review): semantics inferred from the variant names only; the code that
// consumes this enum is outside this part of the file — confirm.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    /// Insert the newline before.
    NewlineBefore,
    /// Insert the newline after.
    NewlineAfter,
}
1139
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Advances the counter by one and returns the value it held before.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1150
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances the counter by one and returns the value it held before.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, next)
    }
}
1161
/// Requests emitted by [`PromptEditor`] in response to button clicks and the
/// confirm / cancel actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation.
    StartRequested,
    /// Interrupt a transformation that is currently pending.
    StopRequested,
    /// Accept the result of a finished transformation.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted when confirm is invoked while a transformation is pending.
    DismissRequested,
}
1169
/// View that renders the prompt input of an inline assist together with its
/// model selector, error indicator, token count and action buttons.
struct PromptEditor {
    /// Id of the assist this prompt belongs to; used to look the assist up
    /// when counting tokens.
    id: InlineAssistId,
    /// File system handle handed to the model selector.
    fs: Arc<dyn Fs>,
    /// The editor in which the prompt text is typed.
    editor: View<Editor>,
    /// Set when the prompt is edited and cleared when code generation reaches
    /// `Done` or `Error`; decides between "restart" and "confirm".
    edited_since_done: bool,
    /// Gutter dimensions of the parent editor, written by the block renderer
    /// and read in `render`.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously used prompts, navigated with move up / move down.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` while a history entry is shown, otherwise
    /// `None`.
    prompt_history_ix: Option<usize>,
    /// The prompt text that was typed before history navigation started; it is
    /// restored when moving down past the last history entry.
    pending_prompt: String,
    /// The code generation model whose status drives this view.
    codegen: Model<Codegen>,
    /// Observation of `codegen` (see `handle_codegen_changed`).
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; replaced by
    /// `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// The most recently spawned token counting task.
    pending_token_count: Task<Result<()>>,
    /// Last computed token count, if any.
    token_count: Option<usize>,
    /// Subscriptions to the parent editor and the assistant panel that trigger
    /// token recounts.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to focus the assistant panel when the token count is
    /// clicked.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}
1188
/// `PromptEditor` emits [`PromptEditorEvent`]s.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1190
impl Render for PromptEditor {
    /// Renders the prompt row: a gutter-wide column with the model selector
    /// (and an error indicator when code generation failed), the prompt
    /// editor itself, and the token count followed by the action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The set of action buttons depends on the code generation status.
        let buttons = match status {
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            CodegenStatus::Pending => {
                vec![
                    // While pending, the cancel action is advertised on the
                    // stop button below, so this tooltip shows plain text.
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // Offer a restart after an error or once the prompt was
                    // edited; otherwise offer to confirm the result.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Left column, sized from the parent editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // Nothing is added unless code generation failed.
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            // Rate-limit errors get a toggle button that opens
                            // the rate-limit notice popover.
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            // Any other error is shown as an icon whose
                            // tooltip carries the error message.
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1371
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1377
1378impl PromptEditor {
    /// Maximum number of lines the auto-height prompt editor may grow to.
    const MAX_LINES: u8 = 8;
1380
1381 #[allow(clippy::too_many_arguments)]
1382 fn new(
1383 id: InlineAssistId,
1384 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1385 prompt_history: VecDeque<String>,
1386 prompt_buffer: Model<MultiBuffer>,
1387 codegen: Model<Codegen>,
1388 parent_editor: &View<Editor>,
1389 assistant_panel: Option<&View<AssistantPanel>>,
1390 workspace: Option<WeakView<Workspace>>,
1391 fs: Arc<dyn Fs>,
1392 cx: &mut ViewContext<Self>,
1393 ) -> Self {
1394 let prompt_editor = cx.new_view(|cx| {
1395 let mut editor = Editor::new(
1396 EditorMode::AutoHeight {
1397 max_lines: Self::MAX_LINES as usize,
1398 },
1399 prompt_buffer,
1400 None,
1401 false,
1402 cx,
1403 );
1404 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1405 // Since the prompt editors for all inline assistants are linked,
1406 // always show the cursor (even when it isn't focused) because
1407 // typing in one will make what you typed appear in all of them.
1408 editor.set_show_cursor_when_unfocused(true, cx);
1409 editor.set_placeholder_text("Add a prompt…", cx);
1410 editor
1411 });
1412
1413 let mut token_count_subscriptions = Vec::new();
1414 token_count_subscriptions
1415 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1416 if let Some(assistant_panel) = assistant_panel {
1417 token_count_subscriptions
1418 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1419 }
1420
1421 let mut this = Self {
1422 id,
1423 editor: prompt_editor,
1424 edited_since_done: false,
1425 gutter_dimensions,
1426 prompt_history,
1427 prompt_history_ix: None,
1428 pending_prompt: String::new(),
1429 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1430 editor_subscriptions: Vec::new(),
1431 codegen,
1432 fs,
1433 pending_token_count: Task::ready(Ok(())),
1434 token_count: None,
1435 _token_count_subscriptions: token_count_subscriptions,
1436 workspace,
1437 show_rate_limit_notice: false,
1438 };
1439 this.count_tokens(cx);
1440 this.subscribe_to_editor(cx);
1441 this
1442 }
1443
1444 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1445 self.editor_subscriptions.clear();
1446 self.editor_subscriptions
1447 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1448 }
1449
1450 fn set_show_cursor_when_unfocused(
1451 &mut self,
1452 show_cursor_when_unfocused: bool,
1453 cx: &mut ViewContext<Self>,
1454 ) {
1455 self.editor.update(cx, |editor, cx| {
1456 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1457 });
1458 }
1459
1460 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1461 let prompt = self.prompt(cx);
1462 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1463 self.editor = cx.new_view(|cx| {
1464 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1465 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1466 editor.set_placeholder_text("Add a prompt…", cx);
1467 editor.set_text(prompt, cx);
1468 if focus {
1469 editor.focus(cx);
1470 }
1471 editor
1472 });
1473 self.subscribe_to_editor(cx);
1474 }
1475
1476 fn prompt(&self, cx: &AppContext) -> String {
1477 self.editor.read(cx).text(cx)
1478 }
1479
1480 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1481 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1482 if self.show_rate_limit_notice {
1483 cx.focus_view(&self.editor);
1484 }
1485 cx.notify();
1486 }
1487
1488 fn handle_parent_editor_event(
1489 &mut self,
1490 _: View<Editor>,
1491 event: &EditorEvent,
1492 cx: &mut ViewContext<Self>,
1493 ) {
1494 if let EditorEvent::BufferEdited { .. } = event {
1495 self.count_tokens(cx);
1496 }
1497 }
1498
1499 fn handle_assistant_panel_event(
1500 &mut self,
1501 _: View<AssistantPanel>,
1502 event: &AssistantPanelEvent,
1503 cx: &mut ViewContext<Self>,
1504 ) {
1505 let AssistantPanelEvent::ContextEdited { .. } = event;
1506 self.count_tokens(cx);
1507 }
1508
1509 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1510 let assist_id = self.id;
1511 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1512 cx.background_executor().timer(Duration::from_secs(1)).await;
1513 let token_count = cx
1514 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1515 let assist = inline_assistant
1516 .assists
1517 .get(&assist_id)
1518 .context("assist not found")?;
1519 anyhow::Ok(assist.count_tokens(cx))
1520 })??
1521 .await?;
1522
1523 this.update(&mut cx, |this, cx| {
1524 this.token_count = Some(token_count);
1525 cx.notify();
1526 })
1527 })
1528 }
1529
1530 fn handle_prompt_editor_events(
1531 &mut self,
1532 _: View<Editor>,
1533 event: &EditorEvent,
1534 cx: &mut ViewContext<Self>,
1535 ) {
1536 match event {
1537 EditorEvent::Edited { .. } => {
1538 let prompt = self.editor.read(cx).text(cx);
1539 if self
1540 .prompt_history_ix
1541 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1542 {
1543 self.prompt_history_ix.take();
1544 self.pending_prompt = prompt;
1545 }
1546
1547 self.edited_since_done = true;
1548 cx.notify();
1549 }
1550 EditorEvent::BufferEdited => {
1551 self.count_tokens(cx);
1552 }
1553 EditorEvent::Blurred => {
1554 if self.show_rate_limit_notice {
1555 self.show_rate_limit_notice = false;
1556 cx.notify();
1557 }
1558 }
1559 _ => {}
1560 }
1561 }
1562
1563 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1564 match &self.codegen.read(cx).status {
1565 CodegenStatus::Idle => {
1566 self.editor
1567 .update(cx, |editor, _| editor.set_read_only(false));
1568 }
1569 CodegenStatus::Pending => {
1570 self.editor
1571 .update(cx, |editor, _| editor.set_read_only(true));
1572 }
1573 CodegenStatus::Done => {
1574 self.edited_since_done = false;
1575 self.editor
1576 .update(cx, |editor, _| editor.set_read_only(false));
1577 }
1578 CodegenStatus::Error(error) => {
1579 if cx.has_flag::<ZedPro>()
1580 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1581 && !dismissed_rate_limit_notice()
1582 {
1583 self.show_rate_limit_notice = true;
1584 cx.notify();
1585 }
1586
1587 self.edited_since_done = false;
1588 self.editor
1589 .update(cx, |editor, _| editor.set_read_only(false));
1590 }
1591 }
1592 }
1593
1594 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1595 match &self.codegen.read(cx).status {
1596 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1597 cx.emit(PromptEditorEvent::CancelRequested);
1598 }
1599 CodegenStatus::Pending => {
1600 cx.emit(PromptEditorEvent::StopRequested);
1601 }
1602 }
1603 }
1604
1605 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1606 match &self.codegen.read(cx).status {
1607 CodegenStatus::Idle => {
1608 cx.emit(PromptEditorEvent::StartRequested);
1609 }
1610 CodegenStatus::Pending => {
1611 cx.emit(PromptEditorEvent::DismissRequested);
1612 }
1613 CodegenStatus::Done | CodegenStatus::Error(_) => {
1614 if self.edited_since_done {
1615 cx.emit(PromptEditorEvent::StartRequested);
1616 } else {
1617 cx.emit(PromptEditorEvent::ConfirmRequested);
1618 }
1619 }
1620 }
1621 }
1622
1623 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1624 if let Some(ix) = self.prompt_history_ix {
1625 if ix > 0 {
1626 self.prompt_history_ix = Some(ix - 1);
1627 let prompt = self.prompt_history[ix - 1].as_str();
1628 self.editor.update(cx, |editor, cx| {
1629 editor.set_text(prompt, cx);
1630 editor.move_to_beginning(&Default::default(), cx);
1631 });
1632 }
1633 } else if !self.prompt_history.is_empty() {
1634 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1635 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1636 self.editor.update(cx, |editor, cx| {
1637 editor.set_text(prompt, cx);
1638 editor.move_to_beginning(&Default::default(), cx);
1639 });
1640 }
1641 }
1642
1643 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1644 if let Some(ix) = self.prompt_history_ix {
1645 if ix < self.prompt_history.len() - 1 {
1646 self.prompt_history_ix = Some(ix + 1);
1647 let prompt = self.prompt_history[ix + 1].as_str();
1648 self.editor.update(cx, |editor, cx| {
1649 editor.set_text(prompt, cx);
1650 editor.move_to_end(&Default::default(), cx)
1651 });
1652 } else {
1653 self.prompt_history_ix = None;
1654 let prompt = self.pending_prompt.as_str();
1655 self.editor.update(cx, |editor, cx| {
1656 editor.set_text(prompt, cx);
1657 editor.move_to_end(&Default::default(), cx)
1658 });
1659 }
1660 }
1661 }
1662
1663 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1664 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1665 let token_count = self.token_count?;
1666 let max_token_count = model.max_token_count();
1667
1668 let remaining_tokens = max_token_count as isize - token_count as isize;
1669 let token_count_color = if remaining_tokens <= 0 {
1670 Color::Error
1671 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1672 Color::Warning
1673 } else {
1674 Color::Muted
1675 };
1676
1677 let mut token_count = h_flex()
1678 .id("token_count")
1679 .gap_0p5()
1680 .child(
1681 Label::new(humanize_token_count(token_count))
1682 .size(LabelSize::Small)
1683 .color(token_count_color),
1684 )
1685 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1686 .child(
1687 Label::new(humanize_token_count(max_token_count))
1688 .size(LabelSize::Small)
1689 .color(Color::Muted),
1690 );
1691 if let Some(workspace) = self.workspace.clone() {
1692 token_count = token_count
1693 .tooltip(|cx| {
1694 Tooltip::with_meta(
1695 "Tokens Used by Inline Assistant",
1696 None,
1697 "Click to Open Assistant Panel",
1698 cx,
1699 )
1700 })
1701 .cursor_pointer()
1702 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1703 .on_click(move |_, cx| {
1704 cx.stop_propagation();
1705 workspace
1706 .update(cx, |workspace, cx| {
1707 workspace.focus_panel::<AssistantPanel>(cx)
1708 })
1709 .ok();
1710 });
1711 } else {
1712 token_count = token_count
1713 .cursor_default()
1714 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1715 }
1716
1717 Some(token_count)
1718 }
1719
1720 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1721 let settings = ThemeSettings::get_global(cx);
1722 let text_style = TextStyle {
1723 color: if self.editor.read(cx).read_only(cx) {
1724 cx.theme().colors().text_disabled
1725 } else {
1726 cx.theme().colors().text
1727 },
1728 font_family: settings.ui_font.family.clone(),
1729 font_features: settings.ui_font.features.clone(),
1730 font_fallbacks: settings.ui_font.fallbacks.clone(),
1731 font_size: rems(0.875).into(),
1732 font_weight: settings.ui_font.weight,
1733 line_height: relative(1.3),
1734 ..Default::default()
1735 };
1736 EditorElement::new(
1737 &self.editor,
1738 EditorStyle {
1739 background: cx.theme().colors().editor_background,
1740 local_player: cx.theme().players().local(),
1741 text: text_style,
1742 ..Default::default()
1743 },
1744 )
1745 }
1746
1747 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1748 Popover::new().child(
1749 v_flex()
1750 .occlude()
1751 .p_2()
1752 .child(
1753 Label::new("Out of Tokens")
1754 .size(LabelSize::Small)
1755 .weight(FontWeight::BOLD),
1756 )
1757 .child(Label::new(
1758 "Try Zed Pro for higher limits, a wider range of models, and more.",
1759 ))
1760 .child(
1761 h_flex()
1762 .justify_between()
1763 .child(CheckboxWithLabel::new(
1764 "dont-show-again",
1765 Label::new("Don't show again"),
1766 if dismissed_rate_limit_notice() {
1767 ui::Selection::Selected
1768 } else {
1769 ui::Selection::Unselected
1770 },
1771 |selection, cx| {
1772 let is_dismissed = match selection {
1773 ui::Selection::Unselected => false,
1774 ui::Selection::Indeterminate => return,
1775 ui::Selection::Selected => true,
1776 };
1777
1778 set_rate_limit_notice_dismissed(is_dismissed, cx)
1779 },
1780 ))
1781 .child(
1782 h_flex()
1783 .gap_2()
1784 .child(
1785 Button::new("dismiss", "Dismiss")
1786 .style(ButtonStyle::Transparent)
1787 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1788 )
1789 .child(Button::new("more-info", "More Info").on_click(
1790 |_event, cx| {
1791 cx.dispatch_action(Box::new(
1792 zed_actions::OpenAccountSettings,
1793 ))
1794 },
1795 )),
1796 ),
1797 ),
1798 )
1799 }
1800}
1801
/// Key in the key-value store under which the "rate-limit notice dismissed"
/// marker is kept.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1803
1804fn dismissed_rate_limit_notice() -> bool {
1805 db::kvp::KEY_VALUE_STORE
1806 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1807 .log_err()
1808 .map_or(false, |s| s.is_some())
1809}
1810
1811fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1812 db::write_and_log(cx, move || async move {
1813 if is_dismissed {
1814 db::kvp::KEY_VALUE_STORE
1815 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1816 .await
1817 } else {
1818 db::kvp::KEY_VALUE_STORE
1819 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1820 .await
1821 }
1822 })
1823}
1824
/// State of a single inline assist.
struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range handed to code generation as the edit range.
    range: Range<Anchor>,
    /// Weak handle to the editor the assist lives in.
    editor: WeakView<Editor>,
    /// Prompt editor and block ids shown for this assist. When this is `None`
    /// at the time code generation finishes, the assist is finished
    /// automatically (and errors are reported through a workspace toast).
    decorations: Option<InlineAssistDecorations>,
    /// The code generation model driving this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to reach the assistant panel and to show error toasts.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
1835
1836impl InlineAssist {
1837 #[allow(clippy::too_many_arguments)]
1838 fn new(
1839 assist_id: InlineAssistId,
1840 group_id: InlineAssistGroupId,
1841 include_context: bool,
1842 editor: &View<Editor>,
1843 prompt_editor: &View<PromptEditor>,
1844 prompt_block_id: CustomBlockId,
1845 end_block_id: CustomBlockId,
1846 range: Range<Anchor>,
1847 codegen: Model<Codegen>,
1848 workspace: Option<WeakView<Workspace>>,
1849 cx: &mut WindowContext,
1850 ) -> Self {
1851 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
1852 InlineAssist {
1853 group_id,
1854 include_context,
1855 editor: editor.downgrade(),
1856 decorations: Some(InlineAssistDecorations {
1857 prompt_block_id,
1858 prompt_editor: prompt_editor.clone(),
1859 removed_line_block_ids: HashSet::default(),
1860 end_block_id,
1861 }),
1862 range,
1863 codegen: codegen.clone(),
1864 workspace: workspace.clone(),
1865 _subscriptions: vec![
1866 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
1867 InlineAssistant::update_global(cx, |this, cx| {
1868 this.handle_prompt_editor_focus_in(assist_id, cx)
1869 })
1870 }),
1871 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
1872 InlineAssistant::update_global(cx, |this, cx| {
1873 this.handle_prompt_editor_focus_out(assist_id, cx)
1874 })
1875 }),
1876 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
1877 InlineAssistant::update_global(cx, |this, cx| {
1878 this.handle_prompt_editor_event(prompt_editor, event, cx)
1879 })
1880 }),
1881 cx.observe(&codegen, {
1882 let editor = editor.downgrade();
1883 move |_, cx| {
1884 if let Some(editor) = editor.upgrade() {
1885 InlineAssistant::update_global(cx, |this, cx| {
1886 if let Some(editor_assists) =
1887 this.assists_by_editor.get(&editor.downgrade())
1888 {
1889 editor_assists.highlight_updates.send(()).ok();
1890 }
1891
1892 this.update_editor_blocks(&editor, assist_id, cx);
1893 })
1894 }
1895 }
1896 }),
1897 cx.subscribe(&codegen, move |codegen, event, cx| {
1898 InlineAssistant::update_global(cx, |this, cx| match event {
1899 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
1900 CodegenEvent::Finished => {
1901 let assist = if let Some(assist) = this.assists.get(&assist_id) {
1902 assist
1903 } else {
1904 return;
1905 };
1906
1907 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
1908 if assist.decorations.is_none() {
1909 if let Some(workspace) = assist
1910 .workspace
1911 .as_ref()
1912 .and_then(|workspace| workspace.upgrade())
1913 {
1914 let error = format!("Inline assistant error: {}", error);
1915 workspace.update(cx, |workspace, cx| {
1916 struct InlineAssistantError;
1917
1918 let id =
1919 NotificationId::identified::<InlineAssistantError>(
1920 assist_id.0,
1921 );
1922
1923 workspace.show_toast(Toast::new(id, error), cx);
1924 })
1925 }
1926 }
1927 }
1928
1929 if assist.decorations.is_none() {
1930 this.finish_assist(assist_id, false, cx);
1931 }
1932 }
1933 })
1934 }),
1935 ],
1936 }
1937 }
1938
1939 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
1940 let decorations = self.decorations.as_ref()?;
1941 Some(decorations.prompt_editor.read(cx).prompt(cx))
1942 }
1943
1944 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
1945 if self.include_context {
1946 let workspace = self.workspace.as_ref()?;
1947 let workspace = workspace.upgrade()?.read(cx);
1948 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
1949 Some(
1950 assistant_panel
1951 .read(cx)
1952 .active_context(cx)?
1953 .read(cx)
1954 .to_completion_request(cx),
1955 )
1956 } else {
1957 None
1958 }
1959 }
1960
1961 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
1962 let Some(user_prompt) = self.user_prompt(cx) else {
1963 return future::ready(Err(anyhow!("no user prompt"))).boxed();
1964 };
1965 let assistant_panel_context = self.assistant_panel_context(cx);
1966 self.codegen.read(cx).count_tokens(
1967 self.range.clone(),
1968 user_prompt,
1969 assistant_panel_context,
1970 cx,
1971 )
1972 }
1973}
1974
/// Editor decorations that belong to one inline assist.
struct InlineAssistDecorations {
    /// Id of the block that hosts the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view shown for the assist.
    prompt_editor: View<PromptEditor>,
    /// Ids of the blocks that display removed lines.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the block that marks the end of the assist.
    // NOTE(review): what the end block renders is not visible here — confirm.
    end_block_id: CustomBlockId,
}
1981
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Code generation finished; subscribers inspect the codegen status to
    /// tell success from failure.
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
1987
/// Drives a single code transformation on a multi-buffer range.
pub struct Codegen {
    /// The multi-buffer that is edited.
    buffer: Model<MultiBuffer>,
    /// A local copy of the underlying buffer's text, taken at construction.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction.
    snapshot: MultiBufferSnapshot,
    /// Start of the edit range, biased right; set when generation starts.
    edit_position: Option<Anchor>,
    /// Ranges exposed through `last_equal_ranges()`.
    // NOTE(review): presumably the unchanged ranges of the latest diff — the
    // code that fills this is outside this part of the file; confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction id supplied by the creator of this `Codegen`.
    initial_transaction_id: Option<TransactionId>,
    /// Id of the transaction that holds the transformation. It is undone
    /// before a restart; when it is undone from elsewhere, `Undone` is emitted.
    transformation_transaction_id: Option<TransactionId>,
    /// Current status of the generation.
    status: CodegenStatus,
    /// The running generation task; reset to a ready task when the
    /// transformation is undone.
    generation: Task<()>,
    /// Diff bookkeeping (see [`Diff`]).
    diff: Diff,
    /// Optional telemetry sink.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
}
2002
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// The initial state of a newly created `Codegen`.
    Idle,
    /// A generation is in progress; the prompt editor is read-only meanwhile.
    Pending,
    /// The generation completed.
    Done,
    /// The generation failed with the contained error.
    Error(anyhow::Error),
}
2009
/// Diff bookkeeping kept by [`Codegen`].
// NOTE(review): the code that fills these fields is outside this part of the
// file; the field descriptions below are inferred from names and types only.
#[derive(Default)]
struct Diff {
    /// Presumably the task that is currently recomputing the diff, if any.
    task: Option<Task<()>>,
    /// Presumably set when another diff update was requested in the meantime.
    should_update: bool,
    /// Deleted row ranges, each paired with an anchor.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted row ranges, expressed as anchors.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2017
/// `Codegen` emits [`CodegenEvent`]s.
impl EventEmitter<CodegenEvent> for Codegen {}
2019
2020impl Codegen {
2021 pub fn new(
2022 buffer: Model<MultiBuffer>,
2023 range: Range<Anchor>,
2024 initial_transaction_id: Option<TransactionId>,
2025 telemetry: Option<Arc<Telemetry>>,
2026 cx: &mut ModelContext<Self>,
2027 ) -> Self {
2028 let snapshot = buffer.read(cx).snapshot(cx);
2029
2030 let (old_buffer, _, _) = buffer
2031 .read(cx)
2032 .range_to_buffer_ranges(range.clone(), cx)
2033 .pop()
2034 .unwrap();
2035 let old_buffer = cx.new_model(|cx| {
2036 let old_buffer = old_buffer.read(cx);
2037 let text = old_buffer.as_rope().clone();
2038 let line_ending = old_buffer.line_ending();
2039 let language = old_buffer.language().cloned();
2040 let language_registry = old_buffer.language_registry();
2041
2042 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2043 buffer.set_language(language, cx);
2044 if let Some(language_registry) = language_registry {
2045 buffer.set_language_registry(language_registry)
2046 }
2047 buffer
2048 });
2049
2050 Self {
2051 buffer: buffer.clone(),
2052 old_buffer,
2053 edit_position: None,
2054 snapshot,
2055 last_equal_ranges: Default::default(),
2056 transformation_transaction_id: None,
2057 status: CodegenStatus::Idle,
2058 generation: Task::ready(()),
2059 diff: Diff::default(),
2060 telemetry,
2061 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2062 initial_transaction_id,
2063 }
2064 }
2065
2066 fn handle_buffer_event(
2067 &mut self,
2068 _buffer: Model<MultiBuffer>,
2069 event: &multi_buffer::Event,
2070 cx: &mut ModelContext<Self>,
2071 ) {
2072 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2073 if self.transformation_transaction_id == Some(*transaction_id) {
2074 self.transformation_transaction_id = None;
2075 self.generation = Task::ready(());
2076 cx.emit(CodegenEvent::Undone);
2077 }
2078 }
2079 }
2080
2081 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2082 &self.last_equal_ranges
2083 }
2084
2085 pub fn count_tokens(
2086 &self,
2087 edit_range: Range<Anchor>,
2088 user_prompt: String,
2089 assistant_panel_context: Option<LanguageModelRequest>,
2090 cx: &AppContext,
2091 ) -> BoxFuture<'static, Result<usize>> {
2092 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2093 let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
2094 model.count_tokens(request, cx)
2095 } else {
2096 future::ready(Err(anyhow!("no active model"))).boxed()
2097 }
2098 }
2099
2100 pub fn start(
2101 &mut self,
2102 edit_range: Range<Anchor>,
2103 user_prompt: String,
2104 assistant_panel_context: Option<LanguageModelRequest>,
2105 cx: &mut ModelContext<Self>,
2106 ) -> Result<()> {
2107 let model = LanguageModelRegistry::read_global(cx)
2108 .active_model()
2109 .context("no active model")?;
2110
2111 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2112 self.buffer.update(cx, |buffer, cx| {
2113 buffer.undo_transaction(transformation_transaction_id, cx)
2114 });
2115 }
2116
2117 self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
2118
2119 let telemetry_id = model.telemetry_id();
2120 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
2121 .trim()
2122 .to_lowercase()
2123 == "delete"
2124 {
2125 async { Ok(stream::empty().boxed()) }.boxed_local()
2126 } else {
2127 let request =
2128 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
2129 let chunks =
2130 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2131 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2132 };
2133 self.handle_stream(telemetry_id, edit_range, chunks, cx);
2134 Ok(())
2135 }
2136
2137 fn build_request(
2138 &self,
2139 user_prompt: String,
2140 assistant_panel_context: Option<LanguageModelRequest>,
2141 edit_range: Range<Anchor>,
2142 cx: &AppContext,
2143 ) -> LanguageModelRequest {
2144 let buffer = self.buffer.read(cx).snapshot(cx);
2145 let language = buffer.language_at(edit_range.start);
2146 let language_name = if let Some(language) = language.as_ref() {
2147 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2148 None
2149 } else {
2150 Some(language.name())
2151 }
2152 } else {
2153 None
2154 };
2155
2156 // Higher Temperature increases the randomness of model outputs.
2157 // If Markdown or No Language is Known, increase the randomness for more creative output
2158 // If Code, decrease temperature to get more deterministic outputs
2159 let temperature = if let Some(language) = language_name.clone() {
2160 if language.as_ref() == "Markdown" {
2161 1.0
2162 } else {
2163 0.5
2164 }
2165 } else {
2166 1.0
2167 };
2168
2169 let language_name = language_name.as_deref();
2170 let start = buffer.point_to_buffer_offset(edit_range.start);
2171 let end = buffer.point_to_buffer_offset(edit_range.end);
2172 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2173 let (start_buffer, start_buffer_offset) = start;
2174 let (end_buffer, end_buffer_offset) = end;
2175 if start_buffer.remote_id() == end_buffer.remote_id() {
2176 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2177 } else {
2178 panic!("invalid transformation range");
2179 }
2180 } else {
2181 panic!("invalid transformation range");
2182 };
2183 let prompt = generate_content_prompt(user_prompt, language_name, buffer, range);
2184
2185 let mut messages = Vec::new();
2186 if let Some(context_request) = assistant_panel_context {
2187 messages = context_request.messages;
2188 }
2189
2190 messages.push(LanguageModelRequestMessage {
2191 role: Role::User,
2192 content: prompt,
2193 });
2194
2195 LanguageModelRequest {
2196 messages,
2197 stop: vec!["|END|>".to_string()],
2198 temperature,
2199 }
2200 }
2201
    /// Applies a stream of generated text to the buffer, replacing `edit_range`.
    ///
    /// The text currently covered by `edit_range` (read from `self.snapshot`)
    /// is diffed against the incoming text with a `StreamingDiff` running on
    /// the background executor. The resulting hunks are sent over a channel and
    /// applied to `self.buffer` by the foreground task stored in
    /// `self.generation`. Leading whitespace of each generated line is
    /// rewritten relative to the indentation suggested for the first selected
    /// row.
    ///
    /// All edits are merged into one transaction, tracked by
    /// `self.transformation_transaction_id`. When the stream ends or fails,
    /// `self.status` becomes `Done` or `Error` and `CodegenEvent::Finished` is
    /// emitted.
    ///
    /// * `model_telemetry_id` - forwarded to the telemetry event reported once
    ///   the stream has been consumed.
    /// * `edit_range` - the range whose contents are replaced.
    /// * `stream` - future resolving to the stream of generated text chunks.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        // Reset the state left over from any previous generation.
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset (in the original snapshot) up to which hunks have been applied.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // The background task awaits every `send`, so with this small bounded
                    // channel it cannot run far ahead of the edits applied below.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line that has not been pushed to the diff yet.
                                let mut new_text = String::new();
                                // Indentation of the first generated line that has content.
                                let mut base_indent = None;
                                // Indentation of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first chunk (or error).
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Re-indent the line: keep its indentation relative to
                                                // the first generated line, but anchor it at the
                                                // indentation suggested for the selection.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // On the first line the selection's start column is
                                                // subtracted — presumably because the text before the
                                                // selection already provides that much indentation.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only flush once the indentation of the line is known.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        // Another segment follows, so this line ended with a newline.
                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                // Flush whatever is left and let the diff emit its final hunks.
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the outcome before propagating a possible error.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch of hunks to the buffer as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Inserted text goes in at the current position.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Removed text is deleted and the position advances past it.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept text produces no edit; its range is recorded in
                                        // `last_equal_ranges` and the position advances.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transformation_transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), cx);
                            cx.notify();
                        })?;
                    }

                    // Surface any error produced by the background diffing task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2426
2427 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2428 self.last_equal_ranges.clear();
2429 self.status = CodegenStatus::Done;
2430 self.generation = Task::ready(());
2431 cx.emit(CodegenEvent::Finished);
2432 cx.notify();
2433 }
2434
2435 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2436 self.buffer.update(cx, |buffer, cx| {
2437 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2438 buffer.undo_transaction(transaction_id, cx);
2439 }
2440
2441 if let Some(transaction_id) = self.initial_transaction_id.take() {
2442 buffer.undo_transaction(transaction_id, cx);
2443 }
2444 });
2445 }
2446
    /// Recomputes which rows were deleted and inserted within `edit_range`.
    ///
    /// The rows covered by `edit_range` in `self.snapshot` (the old text) are
    /// line-diffed on the background executor against the same rows in the
    /// buffer's current snapshot. The results are stored in
    /// `self.diff.deleted_row_ranges` and `self.diff.inserted_row_ranges`.
    ///
    /// Only one diff runs at a time: when a task is already in flight, this
    /// call just sets `should_update`, and the running task starts another
    /// pass when it finishes.
    fn update_diff(&mut self, edit_range: Range<Anchor>, cx: &mut ModelContext<Self>) {
        if self.diff.task.is_some() {
            self.diff.should_update = true;
        } else {
            self.diff.should_update = false;

            let old_snapshot = self.snapshot.clone();
            let old_range = edit_range.to_point(&old_snapshot);
            let new_snapshot = self.buffer.read(cx).snapshot(cx);
            let new_range = edit_range.to_point(&new_snapshot);

            self.diff.task = Some(cx.spawn(|this, mut cx| async move {
                let (deleted_row_ranges, inserted_row_ranges) = cx
                    .background_executor()
                    .spawn(async move {
                        // Compare whole rows: from column 0 of the first row to the end of
                        // the last row of the range, in both snapshots.
                        let old_text = old_snapshot
                            .text_for_range(
                                Point::new(old_range.start.row, 0)
                                    ..Point::new(
                                        old_range.end.row,
                                        old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
                                    ),
                            )
                            .collect::<String>();
                        let new_text = new_snapshot
                            .text_for_range(
                                Point::new(new_range.start.row, 0)
                                    ..Point::new(
                                        new_range.end.row,
                                        new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
                                    ),
                            )
                            .collect::<String>();

                        // Row cursors into the old and new text, advanced as changes are consumed.
                        let mut old_row = old_range.start.row;
                        let mut new_row = new_range.start.row;
                        let diff = TextDiff::from_lines(old_text.as_str(), new_text.as_str());

                        let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
                        let mut inserted_row_ranges = Vec::new();
                        for change in diff.iter_all_changes() {
                            let line_count = change.value().lines().count() as u32;
                            match change.tag() {
                                similar::ChangeTag::Equal => {
                                    old_row += line_count;
                                    new_row += line_count;
                                }
                                similar::ChangeTag::Delete => {
                                    let old_end_row = old_row + line_count - 1;
                                    // Anchor in the new text at the row where the deletion happened.
                                    let new_row =
                                        new_snapshot.anchor_before(Point::new(new_row, 0));

                                    if let Some((_, last_deleted_row_range)) =
                                        deleted_row_ranges.last_mut()
                                    {
                                        // Extend the previous range when this deletion directly
                                        // follows it in the old text.
                                        if *last_deleted_row_range.end() + 1 == old_row {
                                            *last_deleted_row_range =
                                                *last_deleted_row_range.start()..=old_end_row;
                                        } else {
                                            deleted_row_ranges
                                                .push((new_row, old_row..=old_end_row));
                                        }
                                    } else {
                                        deleted_row_ranges.push((new_row, old_row..=old_end_row));
                                    }

                                    old_row += line_count;
                                }
                                similar::ChangeTag::Insert => {
                                    // Inserted rows are recorded as anchors spanning from the start
                                    // of the first row to the end of the last row.
                                    let new_end_row = new_row + line_count - 1;
                                    let start = new_snapshot.anchor_before(Point::new(new_row, 0));
                                    let end = new_snapshot.anchor_before(Point::new(
                                        new_end_row,
                                        new_snapshot.line_len(MultiBufferRow(new_end_row)),
                                    ));
                                    inserted_row_ranges.push(start..=end);
                                    new_row += line_count;
                                }
                            }
                        }

                        (deleted_row_ranges, inserted_row_ranges)
                    })
                    .await;

                this.update(&mut cx, |this, cx| {
                    this.diff.deleted_row_ranges = deleted_row_ranges;
                    this.diff.inserted_row_ranges = inserted_row_ranges;
                    this.diff.task = None;
                    // Another update was requested while this one was running.
                    if this.diff.should_update {
                        this.update_diff(edit_range, cx);
                    }
                    cx.notify();
                })
                .ok();
            }));
        }
    }
2545}
2546
/// Stream adapter that removes code-fence lines and `<|CURSOR|>` markers from
/// a stream of text chunks.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has returned `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been emitted yet.
    buffer: String,
    /// True until a second line is reached while processing the buffer.
    first_line: bool,
    /// True when a complete line was emitted whose trailing newline has been
    /// consumed but not yet written to the output.
    line_end: bool,
    /// Set when the first line started with a code-fence delimiter.
    starts_with_code_block: bool,
}
2555
2556impl<T> StripInvalidSpans<T>
2557where
2558 T: Stream<Item = Result<String>>,
2559{
2560 fn new(stream: T) -> Self {
2561 Self {
2562 stream,
2563 stream_done: false,
2564 buffer: String::new(),
2565 first_line: true,
2566 line_end: false,
2567 starts_with_code_block: false,
2568 }
2569 }
2570}
2571
2572impl<T> Stream for StripInvalidSpans<T>
2573where
2574 T: Stream<Item = Result<String>>,
2575{
2576 type Item = Result<String>;
2577
2578 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2579 const CODE_BLOCK_DELIMITER: &str = "```";
2580 const CURSOR_SPAN: &str = "<|CURSOR|>";
2581
2582 let this = unsafe { self.get_unchecked_mut() };
2583 loop {
2584 if !this.stream_done {
2585 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2586 match stream.as_mut().poll_next(cx) {
2587 Poll::Ready(Some(Ok(chunk))) => {
2588 this.buffer.push_str(&chunk);
2589 }
2590 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2591 Poll::Ready(None) => {
2592 this.stream_done = true;
2593 }
2594 Poll::Pending => return Poll::Pending,
2595 }
2596 }
2597
2598 let mut chunk = String::new();
2599 let mut consumed = 0;
2600 if !this.buffer.is_empty() {
2601 let mut lines = this.buffer.split('\n').enumerate().peekable();
2602 while let Some((line_ix, line)) = lines.next() {
2603 if line_ix > 0 {
2604 this.first_line = false;
2605 }
2606
2607 if this.first_line {
2608 let trimmed_line = line.trim();
2609 if lines.peek().is_some() {
2610 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2611 consumed += line.len() + 1;
2612 this.starts_with_code_block = true;
2613 continue;
2614 }
2615 } else if trimmed_line.is_empty()
2616 || prefixes(CODE_BLOCK_DELIMITER)
2617 .any(|prefix| trimmed_line.starts_with(prefix))
2618 {
2619 break;
2620 }
2621 }
2622
2623 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2624 if lines.peek().is_some() {
2625 if this.line_end {
2626 chunk.push('\n');
2627 }
2628
2629 chunk.push_str(&line_without_cursor);
2630 this.line_end = true;
2631 consumed += line.len() + 1;
2632 } else if this.stream_done {
2633 if !this.starts_with_code_block
2634 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2635 {
2636 if this.line_end {
2637 chunk.push('\n');
2638 }
2639
2640 chunk.push_str(&line);
2641 }
2642
2643 consumed += line.len();
2644 } else {
2645 let trimmed_line = line.trim();
2646 if trimmed_line.is_empty()
2647 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2648 || prefixes(CODE_BLOCK_DELIMITER)
2649 .any(|prefix| trimmed_line.ends_with(prefix))
2650 {
2651 break;
2652 } else {
2653 if this.line_end {
2654 chunk.push('\n');
2655 this.line_end = false;
2656 }
2657
2658 chunk.push_str(&line_without_cursor);
2659 consumed += line.len();
2660 }
2661 }
2662 }
2663 }
2664
2665 this.buffer = this.buffer.split_off(consumed);
2666 if !chunk.is_empty() {
2667 return Poll::Ready(Some(Ok(chunk)));
2668 } else if this.stream_done {
2669 return Poll::Ready(None);
2670 }
2671 }
2672 }
2673}
2674
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string is never yielded, so an empty or single-character input
/// produces no items. Prefixes always end on a character boundary, which makes
/// this safe for text containing multi-byte characters.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each character after the first starts where a proper prefix ends.
    text.char_indices()
        .skip(1)
        .map(move |(boundary, _)| &text[..boundary])
}
2678
2679fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2680 ranges.sort_unstable_by(|a, b| {
2681 a.start
2682 .cmp(&b.start, buffer)
2683 .then_with(|| b.end.cmp(&a.end, buffer))
2684 });
2685
2686 let mut ix = 0;
2687 while ix + 1 < ranges.len() {
2688 let b = ranges[ix + 1].clone();
2689 let a = &mut ranges[ix];
2690 if a.end.cmp(&b.start, buffer).is_gt() {
2691 if a.end.cmp(&b.end, buffer).is_lt() {
2692 a.end = b.end;
2693 }
2694 ranges.remove(ix + 1);
2695 } else {
2696 ix += 1;
2697 }
2698 }
2699}
2700
2701#[cfg(test)]
2702mod tests {
2703 use super::*;
2704 use futures::stream::{self};
2705 use gpui::{Context, TestAppContext};
2706 use indoc::indoc;
2707 use language::{
2708 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
2709 Point,
2710 };
2711 use language_model::LanguageModelRegistry;
2712 use rand::prelude::*;
2713 use serde::Serialize;
2714 use settings::SettingsStore;
2715 use std::{future, sync::Arc};
2716
    /// Serializable placeholder request carrying only a name.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        /// Name stored in the request.
        pub name: String,
    }
2721
    /// Replaces a multi-line selection (row 1, column 0 through row 4,
    /// column 5) with generated text streamed in random-sized chunks of at most
    /// ten bytes, then checks the resulting buffer text.
    ///
    /// NOTE(review): whitespace inside the string literals of this test looks
    /// collapsed (every line carries the same indentation) — confirm against
    /// version control before relying on the expected text.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 let x = 0;
 for _ in 0..10 {
 x += 1;
 }
 }
 "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        // Feed the codegen from a channel so the test controls the chunking.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range,
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Send the text in random pieces, letting the codegen catch up after each.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
2785
    /// Generates text at an empty range placed after existing content on a
    /// line (row 1, column 6), streaming it in random-sized chunks of at most
    /// ten bytes, then checks the resulting buffer text.
    ///
    /// NOTE(review): whitespace inside the string literals of this test looks
    /// collapsed (every line carries the same indentation) — confirm against
    /// version control before relying on the expected text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 le
 }
 "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        // Feed the codegen from a channel so the test controls the chunking.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Send the text in random pieces, letting the codegen catch up after each.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
2850
    /// Generates text at an empty range on a whitespace-only line (row 1,
    /// column 2), streaming it in random-sized chunks of at most ten bytes,
    /// then checks the resulting buffer text.
    ///
    /// NOTE(review): whitespace inside the string literals of this test looks
    /// collapsed (every line carries the same indentation) — confirm against
    /// version control before relying on the expected text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        // Feed the codegen from a channel so the test controls the chunking.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Send the text in random pieces, letting the codegen catch up after each.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
2916
    /// Replaces a tab-indented selection in a buffer without a language,
    /// sending the generated text as a single chunk, and checks that the
    /// resulting buffer text is indented with tabs.
    ///
    /// NOTE(review): whitespace inside the multi-line string literals of this
    /// test may have been altered — confirm against version control.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
 func main() {
 \tx := 0
 \tfor i := 0; i < 10; i++ {
 \t\tx++
 \t}
 }
 "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx));

        // Feed the codegen from a channel so the test controls the chunking.
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                range.clone(),
                future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
                cx,
            )
        });

        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        // The whole replacement arrives at once, then the stream ends.
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 func main() {
 \tx := 0
 \tfor x < 10 {
 \t\tx++
 \t}
 }
 "}
        );
    }
2973
    /// Checks the output of `StripInvalidSpans` for inputs with and without
    /// surrounding code fences and cursor markers, at every chunk size from one
    /// character up to the whole input.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        // Only the outermost fence lines are removed.
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        // Two backticks do not open a code block, so nothing is stripped.
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` once per chunk size and
        /// asserts that the concatenated output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into pieces of at most `size` characters and returns
        /// them as a stream of `Ok` chunks.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }
3013
    /// Builds a "Rust" language for the tests in this module: files ending in
    /// `rs`, the tree-sitter Rust grammar, and a small indents query.
    ///
    /// Panics if the indents query is rejected.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
 (call_expression) @indent
 (field_expression) @indent
 (_ "(" ")" @end) @indent
 (_ "{" "}" @end) @indent
 "#,
        )
        .unwrap()
    }
3036}