1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use smol::future::FutureExt;
39use std::{
40 future::{self, Future},
41 mem,
42 ops::{Range, RangeInclusive},
43 pin::Pin,
44 sync::Arc,
45 task::{self, Poll},
46 time::{Duration, Instant},
47};
48use text::ToOffset as _;
49use theme::ThemeSettings;
50use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
51use util::{RangeExt, ResultExt};
52use workspace::{notifications::NotificationId, Toast, Workspace};
53
54pub fn init(
55 fs: Arc<dyn Fs>,
56 prompt_builder: Arc<PromptBuilder>,
57 telemetry: Arc<Telemetry>,
58 cx: &mut AppContext,
59) {
60 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
61 cx.observe_new_views(|_, cx| {
62 let workspace = cx.view().clone();
63 InlineAssistant::update_global(cx, |inline_assistant, cx| {
64 inline_assistant.register_workspace(&workspace, cx)
65 })
66 })
67 .detach();
68}
69
/// Maximum number of distinct prompts kept in `InlineAssistant::prompt_history`.
/// When a new prompt pushes the history past this length, the oldest entry is evicted.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
71
/// Global coordinator for inline assists across all editors.
///
/// Tracks every live assist, the editor and group each one belongs to, confirmed assists
/// that can still be undone, and a bounded history of submitted prompts.
pub struct InlineAssistant {
    /// Id given to the next assist that is created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id given to the next assist group that is created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, highlight updates), keyed by a weak
    /// handle to the host editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// One group per `assist` / `suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Watch channels that are signalled whenever an assist is started or stopped.
    assist_observations:
        HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
    /// Assists finished without undo. Their codegen is kept so `undo_assist` can revert them.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    /// Recently submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    /// Always `Some` when constructed through `InlineAssistant::new`.
    telemetry: Option<Arc<Telemetry>>,
    fs: Arc<dyn Fs>,
}

// Registered as a GPUI global so it can be reached through `update_global`.
impl Global for InlineAssistant {}
88
89impl InlineAssistant {
90 pub fn new(
91 fs: Arc<dyn Fs>,
92 prompt_builder: Arc<PromptBuilder>,
93 telemetry: Arc<Telemetry>,
94 ) -> Self {
95 Self {
96 next_assist_id: InlineAssistId::default(),
97 next_assist_group_id: InlineAssistGroupId::default(),
98 assists: HashMap::default(),
99 assists_by_editor: HashMap::default(),
100 assist_groups: HashMap::default(),
101 assist_observations: HashMap::default(),
102 confirmed_assists: HashMap::default(),
103 prompt_history: VecDeque::default(),
104 prompt_builder,
105 telemetry: Some(telemetry),
106 fs,
107 }
108 }
109
110 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
111 cx.subscribe(workspace, |_, event, cx| {
112 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
113 })
114 .detach();
115 }
116
117 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
118 // When the user manually saves an editor, automatically accepts all finished transformations.
119 if let workspace::Event::UserSavedItem { item, .. } = event {
120 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
121 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
122 for assist_id in editor_assists.assist_ids.clone() {
123 let assist = &self.assists[&assist_id];
124 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
125 self.finish_assist(assist_id, false, cx)
126 }
127 }
128 }
129 }
130 }
131 }
132
    /// Starts a new group of inline assists covering every selection in `editor`.
    ///
    /// Each selection is widened to whole lines. A widened range that starts on or before
    /// the end of the previous one is merged into it; this assumes `selections.all` yields
    /// selections in document order. For each resulting range, one `Codegen`, one
    /// `PromptEditor` and a pair of editor blocks (prompt above, end marker below) are
    /// created. All assists in the call share one prompt buffer, seeded with
    /// `initial_prompt` or empty, and one new [`InlineAssistGroup`]. The assist whose range
    /// contains the newest selection is focused at the end.
    pub fn assist(
        &mut self,
        editor: &View<Editor>,
        workspace: Option<WeakView<Workspace>>,
        assistant_panel: Option<&View<AssistantPanel>>,
        initial_prompt: Option<String>,
        cx: &mut WindowContext,
    ) {
        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);

        // A line-aligned range to transform, plus the original selections that fall inside it.
        struct CodegenRange {
            transform_range: Range<Point>,
            selection_ranges: Vec<Range<Point>>,
            // True when this range contains the editor's newest selection.
            focus_assist: bool,
        }

        let newest_selection = editor.read(cx).selections.newest::<Point>(cx);
        let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
        for selection in editor.read(cx).selections.all::<Point>(cx) {
            let selection_is_newest = selection.id == newest_selection.id;
            let mut transform_range = selection.start..selection.end;

            // Expand the transform range to start/end of lines.
            // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
            transform_range.start.column = 0;
            if transform_range.end.column == 0 && transform_range.end > transform_range.start {
                // Cannot underflow: start.column is now 0, so end > start with end.column == 0
                // implies end.row > start.row >= 0.
                transform_range.end.row -= 1;
            }
            transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
            let selection_range = selection.start..selection.end.min(transform_range.end);

            // If this range starts on or before the end of the previous transform range,
            // extend the previous range instead of starting a new one.
            if let Some(CodegenRange {
                transform_range: prev_transform_range,
                selection_ranges,
                focus_assist,
            }) = codegen_ranges.last_mut()
            {
                if transform_range.start <= prev_transform_range.end {
                    prev_transform_range.end = transform_range.end;
                    selection_ranges.push(selection_range);
                    *focus_assist |= selection_is_newest;
                    continue;
                }
            }

            codegen_ranges.push(CodegenRange {
                transform_range,
                selection_ranges: vec![selection_range],
                focus_assist: selection_is_newest,
            })
        }

        let assist_group_id = self.next_assist_group_id.post_inc();
        // A single prompt buffer is shared by every assist in the group.
        let prompt_buffer =
            cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
        let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
        let mut assists = Vec::new();
        let mut assist_to_focus = None;

        for CodegenRange {
            transform_range,
            selection_ranges,
            focus_assist,
        } in codegen_ranges
        {
            // Convert to anchors with outward-facing biases (start before, end after).
            let transform_range = snapshot.anchor_before(transform_range.start)
                ..snapshot.anchor_after(transform_range.end);
            let selection_ranges = selection_ranges
                .iter()
                .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
                .collect::<Vec<_>>();

            let codegen = cx.new_model(|cx| {
                Codegen::new(
                    editor.read(cx).buffer().clone(),
                    transform_range.clone(),
                    selection_ranges,
                    None,
                    self.telemetry.clone(),
                    self.prompt_builder.clone(),
                    cx,
                )
            });

            let assist_id = self.next_assist_id.post_inc();
            let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
            let prompt_editor = cx.new_view(|cx| {
                PromptEditor::new(
                    assist_id,
                    gutter_dimensions.clone(),
                    self.prompt_history.clone(),
                    prompt_buffer.clone(),
                    codegen.clone(),
                    editor,
                    assistant_panel,
                    workspace.clone(),
                    self.fs.clone(),
                    cx,
                )
            });

            if focus_assist {
                assist_to_focus = Some(assist_id);
            }

            let [prompt_block_id, end_block_id] =
                self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);

            assists.push((
                assist_id,
                transform_range,
                prompt_editor,
                prompt_block_id,
                end_block_id,
            ));
        }

        // Register everything only after all blocks exist.
        let editor_assists = self
            .assists_by_editor
            .entry(editor.downgrade())
            .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
        let mut assist_group = InlineAssistGroup::new();
        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
            self.assists.insert(
                assist_id,
                InlineAssist::new(
                    assist_id,
                    assist_group_id,
                    assistant_panel.is_some(),
                    editor,
                    &prompt_editor,
                    prompt_block_id,
                    end_block_id,
                    range,
                    prompt_editor.read(cx).codegen.clone(),
                    workspace.clone(),
                    cx,
                ),
            );
            assist_group.assist_ids.push(assist_id);
            editor_assists.assist_ids.push(assist_id);
        }
        self.assist_groups.insert(assist_group_id, assist_group);

        if let Some(assist_id) = assist_to_focus {
            self.focus_assist(assist_id, cx);
        }
    }
282
283 #[allow(clippy::too_many_arguments)]
284 pub fn suggest_assist(
285 &mut self,
286 editor: &View<Editor>,
287 mut range: Range<Anchor>,
288 initial_prompt: String,
289 initial_transaction_id: Option<TransactionId>,
290 workspace: Option<WeakView<Workspace>>,
291 assistant_panel: Option<&View<AssistantPanel>>,
292 cx: &mut WindowContext,
293 ) -> InlineAssistId {
294 let assist_group_id = self.next_assist_group_id.post_inc();
295 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
296 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
297
298 let assist_id = self.next_assist_id.post_inc();
299
300 let buffer = editor.read(cx).buffer().clone();
301 {
302 let snapshot = buffer.read(cx).read(cx);
303 range.start = range.start.bias_left(&snapshot);
304 range.end = range.end.bias_right(&snapshot);
305 }
306
307 let codegen = cx.new_model(|cx| {
308 Codegen::new(
309 editor.read(cx).buffer().clone(),
310 range.clone(),
311 vec![range.clone()],
312 initial_transaction_id,
313 self.telemetry.clone(),
314 self.prompt_builder.clone(),
315 cx,
316 )
317 });
318
319 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
320 let prompt_editor = cx.new_view(|cx| {
321 PromptEditor::new(
322 assist_id,
323 gutter_dimensions.clone(),
324 self.prompt_history.clone(),
325 prompt_buffer.clone(),
326 codegen.clone(),
327 editor,
328 assistant_panel,
329 workspace.clone(),
330 self.fs.clone(),
331 cx,
332 )
333 });
334
335 let [prompt_block_id, end_block_id] =
336 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
337
338 let editor_assists = self
339 .assists_by_editor
340 .entry(editor.downgrade())
341 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
342
343 let mut assist_group = InlineAssistGroup::new();
344 self.assists.insert(
345 assist_id,
346 InlineAssist::new(
347 assist_id,
348 assist_group_id,
349 assistant_panel.is_some(),
350 editor,
351 &prompt_editor,
352 prompt_block_id,
353 end_block_id,
354 range,
355 prompt_editor.read(cx).codegen.clone(),
356 workspace.clone(),
357 cx,
358 ),
359 );
360 assist_group.assist_ids.push(assist_id);
361 editor_assists.assist_ids.push(assist_id);
362 self.assist_groups.insert(assist_group_id, assist_group);
363 assist_id
364 }
365
366 fn insert_assist_blocks(
367 &self,
368 editor: &View<Editor>,
369 range: &Range<Anchor>,
370 prompt_editor: &View<PromptEditor>,
371 cx: &mut WindowContext,
372 ) -> [CustomBlockId; 2] {
373 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
374 prompt_editor
375 .editor
376 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
377 });
378 let assist_blocks = vec![
379 BlockProperties {
380 style: BlockStyle::Sticky,
381 position: range.start,
382 height: prompt_editor_height,
383 render: build_assist_editor_renderer(prompt_editor),
384 disposition: BlockDisposition::Above,
385 priority: 0,
386 },
387 BlockProperties {
388 style: BlockStyle::Sticky,
389 position: range.end,
390 height: 0,
391 render: Box::new(|cx| {
392 v_flex()
393 .h_full()
394 .w_full()
395 .border_t_1()
396 .border_color(cx.theme().status().info_border)
397 .into_any_element()
398 }),
399 disposition: BlockDisposition::Below,
400 priority: 0,
401 },
402 ];
403
404 editor.update(cx, |editor, cx| {
405 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
406 [block_ids[0], block_ids[1]]
407 })
408 }
409
    /// Called when the prompt editor of `assist_id` gains focus.
    ///
    /// Marks the assist as its group's active assist. If the group is linked, the cursor is
    /// shown in every member's prompt editor. A scroll lock is taken on the host editor when
    /// the prompt block's row is currently on screen, and cleared otherwise;
    /// `handle_editor_change` uses the lock to keep the block at the same distance from the
    /// top. Does nothing for an assist that has already been dismissed (no decorations).
    fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = &self.assists[&assist_id];
        let Some(decorations) = assist.decorations.as_ref() else {
            return;
        };
        let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
        let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

        assist_group.active_assist_id = Some(assist_id);
        if assist_group.linked {
            // Linked prompts share one buffer, so show the cursor in all of them.
            for assist_id in &assist_group.assist_ids {
                if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                    decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                        prompt_editor.set_show_cursor_when_unfocused(true, cx)
                    });
                }
            }
        }

        assist
            .editor
            .update(cx, |editor, cx| {
                let scroll_top = editor.scroll_position(cx).y;
                let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
                let prompt_row = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;

                // Lock scrolling only if the prompt block is currently visible.
                if (scroll_top..scroll_bottom).contains(&prompt_row) {
                    editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                        assist_id,
                        distance_from_top: prompt_row - scroll_top,
                    });
                } else {
                    editor_assists.scroll_lock = None;
                }
            })
            .ok();
    }
450
451 fn handle_prompt_editor_focus_out(
452 &mut self,
453 assist_id: InlineAssistId,
454 cx: &mut WindowContext,
455 ) {
456 let assist = &self.assists[&assist_id];
457 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
458 if assist_group.active_assist_id == Some(assist_id) {
459 assist_group.active_assist_id = None;
460 if assist_group.linked {
461 for assist_id in &assist_group.assist_ids {
462 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
463 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
464 prompt_editor.set_show_cursor_when_unfocused(false, cx)
465 });
466 }
467 }
468 }
469 }
470 }
471
472 fn handle_prompt_editor_event(
473 &mut self,
474 prompt_editor: View<PromptEditor>,
475 event: &PromptEditorEvent,
476 cx: &mut WindowContext,
477 ) {
478 let assist_id = prompt_editor.read(cx).id;
479 match event {
480 PromptEditorEvent::StartRequested => {
481 self.start_assist(assist_id, cx);
482 }
483 PromptEditorEvent::StopRequested => {
484 self.stop_assist(assist_id, cx);
485 }
486 PromptEditorEvent::ConfirmRequested => {
487 self.finish_assist(assist_id, false, cx);
488 }
489 PromptEditorEvent::CancelRequested => {
490 self.finish_assist(assist_id, true, cx);
491 }
492 PromptEditorEvent::DismissRequested => {
493 self.dismiss_assist(assist_id, cx);
494 }
495 }
496 }
497
498 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
499 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
500 return;
501 };
502
503 let editor = editor.read(cx);
504 if editor.selections.count() == 1 {
505 let selection = editor.selections.newest::<usize>(cx);
506 let buffer = editor.buffer().read(cx).snapshot(cx);
507 for assist_id in &editor_assists.assist_ids {
508 let assist = &self.assists[assist_id];
509 let assist_range = assist.range.to_offset(&buffer);
510 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
511 {
512 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
513 self.dismiss_assist(*assist_id, cx);
514 } else {
515 self.finish_assist(*assist_id, false, cx);
516 }
517
518 return;
519 }
520 }
521 }
522
523 cx.propagate();
524 }
525
526 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
527 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
528 return;
529 };
530
531 let editor = editor.read(cx);
532 if editor.selections.count() == 1 {
533 let selection = editor.selections.newest::<usize>(cx);
534 let buffer = editor.buffer().read(cx).snapshot(cx);
535 for assist_id in &editor_assists.assist_ids {
536 let assist = &self.assists[assist_id];
537 let assist_range = assist.range.to_offset(&buffer);
538 if assist.decorations.is_some()
539 && assist_range.contains(&selection.start)
540 && assist_range.contains(&selection.end)
541 {
542 self.focus_assist(*assist_id, cx);
543 return;
544 }
545 }
546 }
547
548 cx.propagate();
549 }
550
551 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
552 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
553 for assist_id in editor_assists.assist_ids.clone() {
554 self.finish_assist(assist_id, true, cx);
555 }
556 }
557 }
558
559 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
560 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
561 return;
562 };
563 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
564 return;
565 };
566 let assist = &self.assists[&scroll_lock.assist_id];
567 let Some(decorations) = assist.decorations.as_ref() else {
568 return;
569 };
570
571 editor.update(cx, |editor, cx| {
572 let scroll_position = editor.scroll_position(cx);
573 let target_scroll_top = editor
574 .row_for_block(decorations.prompt_block_id, cx)
575 .unwrap()
576 .0 as f32
577 - scroll_lock.distance_from_top;
578 if target_scroll_top != scroll_position.y {
579 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
580 }
581 });
582 }
583
584 fn handle_editor_event(
585 &mut self,
586 editor: View<Editor>,
587 event: &EditorEvent,
588 cx: &mut WindowContext,
589 ) {
590 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
591 return;
592 };
593
594 match event {
595 EditorEvent::Edited { transaction_id } => {
596 let buffer = editor.read(cx).buffer().read(cx);
597 let edited_ranges =
598 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
599 let snapshot = buffer.snapshot(cx);
600
601 for assist_id in editor_assists.assist_ids.clone() {
602 let assist = &self.assists[&assist_id];
603 if matches!(
604 assist.codegen.read(cx).status,
605 CodegenStatus::Error(_) | CodegenStatus::Done
606 ) {
607 let assist_range = assist.range.to_offset(&snapshot);
608 if edited_ranges
609 .iter()
610 .any(|range| range.overlaps(&assist_range))
611 {
612 self.finish_assist(assist_id, false, cx);
613 }
614 }
615 }
616 }
617 EditorEvent::ScrollPositionChanged { .. } => {
618 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
619 let assist = &self.assists[&scroll_lock.assist_id];
620 if let Some(decorations) = assist.decorations.as_ref() {
621 let distance_from_top = editor.update(cx, |editor, cx| {
622 let scroll_top = editor.scroll_position(cx).y;
623 let prompt_row = editor
624 .row_for_block(decorations.prompt_block_id, cx)
625 .unwrap()
626 .0 as f32;
627 prompt_row - scroll_top
628 });
629
630 if distance_from_top != scroll_lock.distance_from_top {
631 editor_assists.scroll_lock = None;
632 }
633 }
634 }
635 }
636 EditorEvent::SelectionsChanged { .. } => {
637 for assist_id in editor_assists.assist_ids.clone() {
638 let assist = &self.assists[&assist_id];
639 if let Some(decorations) = assist.decorations.as_ref() {
640 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
641 return;
642 }
643 }
644 }
645
646 editor_assists.scroll_lock = None;
647 }
648 _ => {}
649 }
650 }
651
    /// Finishes `assist_id`, removing it from every bookkeeping map.
    ///
    /// If the assist belongs to a linked group, the group is unlinked first and every member
    /// is finished in turn through a recursive call. Otherwise:
    /// - the assist's decorations are dismissed;
    /// - it is removed from its group and editor entries, and an entry is dropped once it has
    ///   no assists left (clearing the editor's highlights in that case);
    /// - with `undo`, the codegen's edits are undone; without it, the codegen is kept in
    ///   `confirmed_assists` so `undo_assist` can revert it later.
    ///
    /// The assist's observation channel is always removed, even if the assist was unknown.
    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            if self.assist_groups[&assist_group_id].linked {
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    // Last assist in this editor: drop its entry and recompute highlights now,
                    // since there is no longer an entry whose `highlight_updates` we could signal.
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, assist.codegen);
            }
        }

        // Remove the assist from the status updates map
        self.assist_observations.remove(&assist_id);
    }
698
699 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
700 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
701 return false;
702 };
703 codegen.update(cx, |this, cx| this.undo(cx));
704 true
705 }
706
    /// Removes the visual decorations of `assist_id` from its editor without finishing it.
    ///
    /// The prompt, end and deleted-line blocks are removed. If focus was inside the prompt
    /// editor, it moves to the next assist (or the host editor). The editor's scroll lock is
    /// cleared if it belonged to this assist, and `highlight_updates` is signalled. The
    /// assist itself stays registered in `self.assists`.
    ///
    /// Returns `false` without doing anything when the assist is unknown, its editor has
    /// been dropped, or it was already dismissed. Returns `true` otherwise.
    fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return false;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return false;
        };
        // Taking the decorations marks the assist as dismissed. This must happen before
        // `focus_next_assist` below, which skips assists without decorations.
        let Some(decorations) = assist.decorations.take() else {
            return false;
        };

        editor.update(cx, |editor, cx| {
            let mut to_remove = decorations.removed_line_block_ids;
            to_remove.insert(decorations.prompt_block_id);
            to_remove.insert(decorations.end_block_id);
            editor.remove_blocks(to_remove, None, cx);
        });

        // Only move focus if it was inside the prompt editor being removed.
        if decorations
            .prompt_editor
            .focus_handle(cx)
            .contains_focused(cx)
        {
            self.focus_next_assist(assist_id, cx);
        }

        if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
            if editor_assists
                .scroll_lock
                .as_ref()
                .map_or(false, |lock| lock.assist_id == assist_id)
            {
                editor_assists.scroll_lock = None;
            }
            editor_assists.highlight_updates.send(()).ok();
        }

        true
    }
746
747 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
748 let Some(assist) = self.assists.get(&assist_id) else {
749 return;
750 };
751
752 let assist_group = &self.assist_groups[&assist.group_id];
753 let assist_ix = assist_group
754 .assist_ids
755 .iter()
756 .position(|id| *id == assist_id)
757 .unwrap();
758 let assist_ids = assist_group
759 .assist_ids
760 .iter()
761 .skip(assist_ix + 1)
762 .chain(assist_group.assist_ids.iter().take(assist_ix));
763
764 for assist_id in assist_ids {
765 let assist = &self.assists[assist_id];
766 if assist.decorations.is_some() {
767 self.focus_assist(*assist_id, cx);
768 return;
769 }
770 }
771
772 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
773 }
774
775 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
776 let Some(assist) = self.assists.get(&assist_id) else {
777 return;
778 };
779
780 if let Some(decorations) = assist.decorations.as_ref() {
781 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
782 prompt_editor.editor.update(cx, |editor, cx| {
783 editor.focus(cx);
784 editor.select_all(&SelectAll, cx);
785 })
786 });
787 }
788
789 self.scroll_to_assist(assist_id, cx);
790 }
791
    /// Moves the host editor's cursor to the start of `assist_id` and scrolls it into view.
    ///
    /// The target span runs from the prompt block row to the end block row. For a dismissed
    /// assist it is the single display row of the range start. The span is padded by the
    /// editor's vertical scroll margin. The editor only scrolls if the span is not already
    /// fully visible: bottom-aligned when the span fits in the viewport, top-aligned
    /// otherwise. Any scroll resets the horizontal position to 0. Does nothing if the
    /// assist is unknown or its editor has been dropped.
    pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let Some(assist) = self.assists.get(&assist_id) else {
            return;
        };
        let Some(editor) = assist.editor.upgrade() else {
            return;
        };

        let position = assist.range.start;
        editor.update(cx, |editor, cx| {
            editor.change_selections(None, cx, |selections| {
                selections.select_anchor_ranges([position..position])
            });

            let mut scroll_target_top;
            let mut scroll_target_bottom;
            if let Some(decorations) = assist.decorations.as_ref() {
                scroll_target_top = editor
                    .row_for_block(decorations.prompt_block_id, cx)
                    .unwrap()
                    .0 as f32;
                scroll_target_bottom = editor
                    .row_for_block(decorations.end_block_id, cx)
                    .unwrap()
                    .0 as f32;
            } else {
                // Dismissed assist: target just the display row of its start.
                let snapshot = editor.snapshot(cx);
                let start_row = assist
                    .range
                    .start
                    .to_display_point(&snapshot.display_snapshot)
                    .row();
                scroll_target_top = start_row.0 as f32;
                scroll_target_bottom = scroll_target_top + 1.;
            }
            scroll_target_top -= editor.vertical_scroll_margin() as f32;
            scroll_target_bottom += editor.vertical_scroll_margin() as f32;

            let height_in_lines = editor.visible_line_count().unwrap_or(0.);
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + height_in_lines;

            if scroll_target_top < scroll_top {
                editor.set_scroll_position(point(0., scroll_target_top), cx);
            } else if scroll_target_bottom > scroll_bottom {
                // Bottom-align when the whole span fits; otherwise prefer showing its top.
                if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
                    editor
                        .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
                } else {
                    editor.set_scroll_position(point(0., scroll_target_top), cx);
                }
            }
        });
    }
846
847 fn unlink_assist_group(
848 &mut self,
849 assist_group_id: InlineAssistGroupId,
850 cx: &mut WindowContext,
851 ) -> Vec<InlineAssistId> {
852 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
853 assist_group.linked = false;
854 for assist_id in &assist_group.assist_ids {
855 let assist = self.assists.get_mut(assist_id).unwrap();
856 if let Some(editor_decorations) = assist.decorations.as_ref() {
857 editor_decorations
858 .prompt_editor
859 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
860 }
861 }
862 assist_group.assist_ids.clone()
863 }
864
    /// Submits the prompt of `assist_id` and starts code generation.
    ///
    /// If the assist's group is still linked, the group is unlinked and every member is
    /// started individually through a recursive call. Does nothing if the assist is unknown
    /// or `user_prompt` returns `None`. The prompt is moved to the back of `prompt_history`
    /// (deduplicated, oldest evicted beyond `PROMPT_HISTORY_MAX_LEN`). Codegen start errors
    /// are logged, not returned. Any watcher in `assist_observations` is then signalled.
    pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
            assist
        } else {
            return;
        };

        let assist_group_id = assist.group_id;
        if self.assist_groups[&assist_group_id].linked {
            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                self.start_assist(assist_id, cx);
            }
            return;
        }

        let Some(user_prompt) = assist.user_prompt(cx) else {
            return;
        };

        // Move this prompt to the most-recent end of the history, keeping it bounded.
        self.prompt_history.retain(|prompt| *prompt != user_prompt);
        self.prompt_history.push_back(user_prompt.clone());
        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
            self.prompt_history.pop_front();
        }

        let assistant_panel_context = assist.assistant_panel_context(cx);

        assist
            .codegen
            .update(cx, |codegen, cx| {
                codegen.start(user_prompt, assistant_panel_context, cx)
            })
            .log_err();

        if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
            tx.send(()).ok();
        }
    }
903
904 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
905 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
906 assist
907 } else {
908 return;
909 };
910
911 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
912
913 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
914 tx.send(()).ok();
915 }
916 }
917
918 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
919 if let Some(assist) = self.assists.get(&assist_id) {
920 match &assist.codegen.read(cx).status {
921 CodegenStatus::Idle => InlineAssistStatus::Idle,
922 CodegenStatus::Pending => InlineAssistStatus::Pending,
923 CodegenStatus::Done => InlineAssistStatus::Done,
924 CodegenStatus::Error(_) => InlineAssistStatus::Error,
925 }
926 } else if self.confirmed_assists.contains_key(&assist_id) {
927 InlineAssistStatus::Confirmed
928 } else {
929 InlineAssistStatus::Canceled
930 }
931 }
932
    /// Recomputes and applies all inline-assist highlights in `editor`.
    ///
    /// For every assist registered for the editor this collects:
    /// - the text ranges from `codegen.last_equal_ranges()`, drawn faded (`fade_out: 0.6`);
    /// - a "pending" gutter range from `codegen.edit_position` (or the assist start when it
    ///   is `None`) to the assist end, drawn with `info_background`;
    /// - a "transformed" gutter range from the assist start to `edit_position`, drawn with
    ///   `info`;
    /// - the diff's inserted row ranges, but only while the assist still has decorations.
    ///
    /// Empty gutter ranges are skipped. Each range list goes through `merge_ranges`, which
    /// presumably coalesces overlaps. A highlight kind with no ranges is cleared. Row
    /// highlights are always cleared and re-added. If the editor has no assists, everything
    /// is cleared.
    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer.read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

                // Not yet reached by generation: from the edit position to the end.
                let pending_range =
                    codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                // Already rewritten: from the start to the edit position.
                if let Some(edit_position) = codegen.edit_position {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                if assist.decorations.is_some() {
                    inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    Some(cx.theme().status().info_background),
                    false,
                    cx,
                );
            }
        });
    }
1022
    /// Rebuilds the blocks that show lines deleted by the diff of `assist_id`.
    ///
    /// Blocks previously recorded in `decorations.removed_line_block_ids` are removed first.
    /// Then, for each `(new_row, old_row_range)` in `codegen.diff.deleted_row_ranges`, a
    /// read-only editor over those rows of `codegen.old_buffer` is created. It has no gutter,
    /// soft wrap, wrap guides or vertical scrolling, and is tinted with `deleted_background`.
    /// That editor is inserted as a flex block above `new_row`. The new block ids replace
    /// the recorded ones. Does nothing for unknown or dismissed assists.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Map the deleted rows of the old multibuffer snapshot (from the start of the
                // first row to the end of the last) to offsets in the underlying buffer.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // One block row per line shown in the deleted-lines editor.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1114
1115 pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
1116 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1117 rx.clone()
1118 } else {
1119 let (tx, rx) = async_watch::channel(());
1120 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1121 rx
1122 }
1123 }
1124}
1125
/// Externally visible lifecycle state of an inline assist.
///
/// NOTE(review): the mapping from internal codegen state to these variants lives
/// outside this view — confirm there before relying on exact transitions.
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}
1134
1135impl InlineAssistStatus {
1136 pub(crate) fn is_pending(&self) -> bool {
1137 matches!(self, Self::Pending)
1138 }
1139 pub(crate) fn is_confirmed(&self) -> bool {
1140 matches!(self, Self::Confirmed)
1141 }
1142 pub(crate) fn is_done(&self) -> bool {
1143 matches!(self, Self::Done)
1144 }
1145}
1146
/// Inline-assist state kept per editor: the assists attached to that editor plus
/// the task and subscriptions that keep its highlights and events in sync.
struct EditorInlineAssists {
    /// Ids of the assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional lock tying the scroll position to one assist — NOTE(review): the
    /// code that reads this is outside this view; confirm its exact behavior.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here wakes `_update_highlights`, which recomputes highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Loop that calls `InlineAssistant::update_editor_highlights` on each update.
    _update_highlights: Task<Result<()>>,
    /// Observers, subscriptions and action registrations held for this editor
    /// (gpui subscriptions presumably unsubscribe when dropped).
    _subscriptions: Vec<gpui::Subscription>,
}
1154
/// Records which assist the editor's scroll position is locked to and how far
/// from the top it was.
///
/// NOTE(review): the unit of `distance_from_top` is not visible here (likely
/// rows) — confirm where it is written.
struct InlineAssistScrollLock {
    assist_id: InlineAssistId,
    distance_from_top: f32,
}
1159
1160impl EditorInlineAssists {
1161 #[allow(clippy::too_many_arguments)]
1162 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1163 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1164 Self {
1165 assist_ids: Vec::new(),
1166 scroll_lock: None,
1167 highlight_updates: highlight_updates_tx,
1168 _update_highlights: cx.spawn(|mut cx| {
1169 let editor = editor.downgrade();
1170 async move {
1171 while let Ok(()) = highlight_updates_rx.changed().await {
1172 let editor = editor.upgrade().context("editor was dropped")?;
1173 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1174 assistant.update_editor_highlights(&editor, cx);
1175 })?;
1176 }
1177 Ok(())
1178 }
1179 }),
1180 _subscriptions: vec![
1181 cx.observe_release(editor, {
1182 let editor = editor.downgrade();
1183 |_, cx| {
1184 InlineAssistant::update_global(cx, |this, cx| {
1185 this.handle_editor_release(editor, cx);
1186 })
1187 }
1188 }),
1189 cx.observe(editor, move |editor, cx| {
1190 InlineAssistant::update_global(cx, |this, cx| {
1191 this.handle_editor_change(editor, cx)
1192 })
1193 }),
1194 cx.subscribe(editor, move |editor, event, cx| {
1195 InlineAssistant::update_global(cx, |this, cx| {
1196 this.handle_editor_event(editor, event, cx)
1197 })
1198 }),
1199 editor.update(cx, |editor, cx| {
1200 let editor_handle = cx.view().downgrade();
1201 editor.register_action(
1202 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1203 InlineAssistant::update_global(cx, |this, cx| {
1204 if let Some(editor) = editor_handle.upgrade() {
1205 this.handle_editor_newline(editor, cx)
1206 }
1207 })
1208 },
1209 )
1210 }),
1211 editor.update(cx, |editor, cx| {
1212 let editor_handle = cx.view().downgrade();
1213 editor.register_action(
1214 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1215 InlineAssistant::update_global(cx, |this, cx| {
1216 if let Some(editor) = editor_handle.upgrade() {
1217 this.handle_editor_cancel(editor, cx)
1218 }
1219 })
1220 },
1221 )
1222 }),
1223 ],
1224 }
1225 }
1226}
1227
/// A set of inline assists created together.
struct InlineAssistGroup {
    /// Members of the group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the members' prompt editors are still linked (presumably sharing
    /// one prompt; see `PromptEditor::unlink`) — NOTE(review): confirm.
    linked: bool,
    /// The member currently considered active, if any.
    active_assist_id: Option<InlineAssistId>,
}
1233
1234impl InlineAssistGroup {
1235 fn new() -> Self {
1236 Self {
1237 assist_ids: Vec::new(),
1238 linked: true,
1239 active_assist_id: None,
1240 }
1241 }
1242}
1243
1244fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1245 let editor = editor.clone();
1246 Box::new(move |cx: &mut BlockContext| {
1247 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1248 editor.clone().into_any_element()
1249 })
1250}
1251
/// Which side of the assisted position gets an extra newline when the assist is
/// created — NOTE(review): inferred from the variant names; the code that
/// consumes this is outside this view.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
    NewlineBefore,
    NewlineAfter,
}
1257
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: returns the current id and advances `self` to the next
    /// one, like `i++`.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1268
/// Identifier of an [`InlineAssistGroup`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Post-increment: returns the current id and advances `self` to the next
    /// one, like `i++`.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let next = InlineAssistGroupId(self.0 + 1);
        mem::replace(self, next)
    }
}
1279
/// Requests a [`PromptEditor`] emits in response to its buttons and key actions.
enum PromptEditorEvent {
    /// Start (or restart) the transformation.
    StartRequested,
    /// Interrupt a pending transformation (the stop button's tooltip says
    /// changes are not discarded).
    StopRequested,
    /// Accept the finished transformation.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted by the confirm action while a transformation is pending.
    DismissRequested,
}
1287
/// The prompt row shown for an inline assist: a text editor for the prompt plus
/// model selection, token count, error display and start/stop/confirm/cancel
/// controls.
struct PromptEditor {
    /// The inline assist this prompt belongs to.
    id: InlineAssistId,
    /// Filesystem handle, passed to the `ModelSelector`.
    fs: Arc<dyn Fs>,
    /// Editor holding the prompt text; `unlink` replaces it with a standalone one.
    editor: View<Editor>,
    /// Set on prompt edits and cleared when codegen reaches `Done` or `Error`;
    /// chooses between "restart" and "confirm".
    edited_since_done: bool,
    /// Gutter dimensions of the host editor, written by the block renderer and
    /// used to size the left column of the prompt row.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Earlier prompts, browsed with `MoveUp` / `MoveDown`.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` being shown, or `None` for the user's own text.
    prompt_history_ix: Option<usize>,
    /// The user's in-progress prompt, restored after browsing past the newest entry.
    pending_prompt: String,
    /// The code generation this prompt drives.
    codegen: Model<Codegen>,
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`; rebuilt whenever `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// In-flight (debounced) token-count task.
    pending_token_count: Task<Result<()>>,
    /// Most recently computed token count, if any.
    token_count: Option<usize>,
    /// Subscriptions (parent editor, assistant panel) that trigger recounts.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to open the assistant panel from the token count.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the "Out of Tokens" popover is visible.
    show_rate_limit_notice: bool,
}
1306
// Lets `PromptEditor` emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1308
1309impl Render for PromptEditor {
1310 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1311 let gutter_dimensions = *self.gutter_dimensions.lock();
1312 let status = &self.codegen.read(cx).status;
1313 let buttons = match status {
1314 CodegenStatus::Idle => {
1315 vec![
1316 IconButton::new("cancel", IconName::Close)
1317 .icon_color(Color::Muted)
1318 .shape(IconButtonShape::Square)
1319 .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1320 .on_click(
1321 cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1322 ),
1323 IconButton::new("start", IconName::SparkleAlt)
1324 .icon_color(Color::Muted)
1325 .shape(IconButtonShape::Square)
1326 .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
1327 .on_click(
1328 cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
1329 ),
1330 ]
1331 }
1332 CodegenStatus::Pending => {
1333 vec![
1334 IconButton::new("cancel", IconName::Close)
1335 .icon_color(Color::Muted)
1336 .shape(IconButtonShape::Square)
1337 .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
1338 .on_click(
1339 cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1340 ),
1341 IconButton::new("stop", IconName::Stop)
1342 .icon_color(Color::Error)
1343 .shape(IconButtonShape::Square)
1344 .tooltip(|cx| {
1345 Tooltip::with_meta(
1346 "Interrupt Transformation",
1347 Some(&menu::Cancel),
1348 "Changes won't be discarded",
1349 cx,
1350 )
1351 })
1352 .on_click(
1353 cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
1354 ),
1355 ]
1356 }
1357 CodegenStatus::Error(_) | CodegenStatus::Done => {
1358 vec![
1359 IconButton::new("cancel", IconName::Close)
1360 .icon_color(Color::Muted)
1361 .shape(IconButtonShape::Square)
1362 .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
1363 .on_click(
1364 cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
1365 ),
1366 if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
1367 IconButton::new("restart", IconName::RotateCw)
1368 .icon_color(Color::Info)
1369 .shape(IconButtonShape::Square)
1370 .tooltip(|cx| {
1371 Tooltip::with_meta(
1372 "Restart Transformation",
1373 Some(&menu::Confirm),
1374 "Changes will be discarded",
1375 cx,
1376 )
1377 })
1378 .on_click(cx.listener(|_, _, cx| {
1379 cx.emit(PromptEditorEvent::StartRequested);
1380 }))
1381 } else {
1382 IconButton::new("confirm", IconName::Check)
1383 .icon_color(Color::Info)
1384 .shape(IconButtonShape::Square)
1385 .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
1386 .on_click(cx.listener(|_, _, cx| {
1387 cx.emit(PromptEditorEvent::ConfirmRequested);
1388 }))
1389 },
1390 ]
1391 }
1392 };
1393
1394 h_flex()
1395 .bg(cx.theme().colors().editor_background)
1396 .border_y_1()
1397 .border_color(cx.theme().status().info_border)
1398 .size_full()
1399 .py(cx.line_height() / 2.)
1400 .on_action(cx.listener(Self::confirm))
1401 .on_action(cx.listener(Self::cancel))
1402 .on_action(cx.listener(Self::move_up))
1403 .on_action(cx.listener(Self::move_down))
1404 .child(
1405 h_flex()
1406 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1407 .justify_center()
1408 .gap_2()
1409 .child(
1410 ModelSelector::new(
1411 self.fs.clone(),
1412 IconButton::new("context", IconName::SlidersAlt)
1413 .shape(IconButtonShape::Square)
1414 .icon_size(IconSize::Small)
1415 .icon_color(Color::Muted)
1416 .tooltip(move |cx| {
1417 Tooltip::with_meta(
1418 format!(
1419 "Using {}",
1420 LanguageModelRegistry::read_global(cx)
1421 .active_model()
1422 .map(|model| model.name().0)
1423 .unwrap_or_else(|| "No model selected".into()),
1424 ),
1425 None,
1426 "Change Model",
1427 cx,
1428 )
1429 }),
1430 )
1431 .with_info_text(
1432 "Inline edits use context\n\
1433 from the currently selected\n\
1434 assistant panel tab.",
1435 ),
1436 )
1437 .map(|el| {
1438 let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
1439 return el;
1440 };
1441
1442 let error_message = SharedString::from(error.to_string());
1443 if error.error_code() == proto::ErrorCode::RateLimitExceeded
1444 && cx.has_flag::<ZedPro>()
1445 {
1446 el.child(
1447 v_flex()
1448 .child(
1449 IconButton::new("rate-limit-error", IconName::XCircle)
1450 .selected(self.show_rate_limit_notice)
1451 .shape(IconButtonShape::Square)
1452 .icon_size(IconSize::Small)
1453 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1454 )
1455 .children(self.show_rate_limit_notice.then(|| {
1456 deferred(
1457 anchored()
1458 .position_mode(gpui::AnchoredPositionMode::Local)
1459 .position(point(px(0.), px(24.)))
1460 .anchor(gpui::AnchorCorner::TopLeft)
1461 .child(self.render_rate_limit_notice(cx)),
1462 )
1463 })),
1464 )
1465 } else {
1466 el.child(
1467 div()
1468 .id("error")
1469 .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
1470 .child(
1471 Icon::new(IconName::XCircle)
1472 .size(IconSize::Small)
1473 .color(Color::Error),
1474 ),
1475 )
1476 }
1477 }),
1478 )
1479 .child(div().flex_1().child(self.render_prompt_editor(cx)))
1480 .child(
1481 h_flex()
1482 .gap_2()
1483 .pr_6()
1484 .children(self.render_token_count(cx))
1485 .children(buttons),
1486 )
1487 }
1488}
1489
// Focusing a `PromptEditor` focuses its inner prompt text editor.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1495
1496impl PromptEditor {
1497 const MAX_LINES: u8 = 8;
1498
1499 #[allow(clippy::too_many_arguments)]
1500 fn new(
1501 id: InlineAssistId,
1502 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1503 prompt_history: VecDeque<String>,
1504 prompt_buffer: Model<MultiBuffer>,
1505 codegen: Model<Codegen>,
1506 parent_editor: &View<Editor>,
1507 assistant_panel: Option<&View<AssistantPanel>>,
1508 workspace: Option<WeakView<Workspace>>,
1509 fs: Arc<dyn Fs>,
1510 cx: &mut ViewContext<Self>,
1511 ) -> Self {
1512 let prompt_editor = cx.new_view(|cx| {
1513 let mut editor = Editor::new(
1514 EditorMode::AutoHeight {
1515 max_lines: Self::MAX_LINES as usize,
1516 },
1517 prompt_buffer,
1518 None,
1519 false,
1520 cx,
1521 );
1522 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1523 // Since the prompt editors for all inline assistants are linked,
1524 // always show the cursor (even when it isn't focused) because
1525 // typing in one will make what you typed appear in all of them.
1526 editor.set_show_cursor_when_unfocused(true, cx);
1527 editor.set_placeholder_text("Add a prompt…", cx);
1528 editor
1529 });
1530
1531 let mut token_count_subscriptions = Vec::new();
1532 token_count_subscriptions
1533 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1534 if let Some(assistant_panel) = assistant_panel {
1535 token_count_subscriptions
1536 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1537 }
1538
1539 let mut this = Self {
1540 id,
1541 editor: prompt_editor,
1542 edited_since_done: false,
1543 gutter_dimensions,
1544 prompt_history,
1545 prompt_history_ix: None,
1546 pending_prompt: String::new(),
1547 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1548 editor_subscriptions: Vec::new(),
1549 codegen,
1550 fs,
1551 pending_token_count: Task::ready(Ok(())),
1552 token_count: None,
1553 _token_count_subscriptions: token_count_subscriptions,
1554 workspace,
1555 show_rate_limit_notice: false,
1556 };
1557 this.count_tokens(cx);
1558 this.subscribe_to_editor(cx);
1559 this
1560 }
1561
1562 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1563 self.editor_subscriptions.clear();
1564 self.editor_subscriptions
1565 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1566 }
1567
1568 fn set_show_cursor_when_unfocused(
1569 &mut self,
1570 show_cursor_when_unfocused: bool,
1571 cx: &mut ViewContext<Self>,
1572 ) {
1573 self.editor.update(cx, |editor, cx| {
1574 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1575 });
1576 }
1577
1578 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1579 let prompt = self.prompt(cx);
1580 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1581 self.editor = cx.new_view(|cx| {
1582 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1583 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1584 editor.set_placeholder_text("Add a prompt…", cx);
1585 editor.set_text(prompt, cx);
1586 if focus {
1587 editor.focus(cx);
1588 }
1589 editor
1590 });
1591 self.subscribe_to_editor(cx);
1592 }
1593
1594 fn prompt(&self, cx: &AppContext) -> String {
1595 self.editor.read(cx).text(cx)
1596 }
1597
1598 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1599 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1600 if self.show_rate_limit_notice {
1601 cx.focus_view(&self.editor);
1602 }
1603 cx.notify();
1604 }
1605
1606 fn handle_parent_editor_event(
1607 &mut self,
1608 _: View<Editor>,
1609 event: &EditorEvent,
1610 cx: &mut ViewContext<Self>,
1611 ) {
1612 if let EditorEvent::BufferEdited { .. } = event {
1613 self.count_tokens(cx);
1614 }
1615 }
1616
1617 fn handle_assistant_panel_event(
1618 &mut self,
1619 _: View<AssistantPanel>,
1620 event: &AssistantPanelEvent,
1621 cx: &mut ViewContext<Self>,
1622 ) {
1623 let AssistantPanelEvent::ContextEdited { .. } = event;
1624 self.count_tokens(cx);
1625 }
1626
1627 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1628 let assist_id = self.id;
1629 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1630 cx.background_executor().timer(Duration::from_secs(1)).await;
1631 let token_count = cx
1632 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1633 let assist = inline_assistant
1634 .assists
1635 .get(&assist_id)
1636 .context("assist not found")?;
1637 anyhow::Ok(assist.count_tokens(cx))
1638 })??
1639 .await?;
1640
1641 this.update(&mut cx, |this, cx| {
1642 this.token_count = Some(token_count);
1643 cx.notify();
1644 })
1645 })
1646 }
1647
1648 fn handle_prompt_editor_events(
1649 &mut self,
1650 _: View<Editor>,
1651 event: &EditorEvent,
1652 cx: &mut ViewContext<Self>,
1653 ) {
1654 match event {
1655 EditorEvent::Edited { .. } => {
1656 let prompt = self.editor.read(cx).text(cx);
1657 if self
1658 .prompt_history_ix
1659 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1660 {
1661 self.prompt_history_ix.take();
1662 self.pending_prompt = prompt;
1663 }
1664
1665 self.edited_since_done = true;
1666 cx.notify();
1667 }
1668 EditorEvent::BufferEdited => {
1669 self.count_tokens(cx);
1670 }
1671 EditorEvent::Blurred => {
1672 if self.show_rate_limit_notice {
1673 self.show_rate_limit_notice = false;
1674 cx.notify();
1675 }
1676 }
1677 _ => {}
1678 }
1679 }
1680
1681 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1682 match &self.codegen.read(cx).status {
1683 CodegenStatus::Idle => {
1684 self.editor
1685 .update(cx, |editor, _| editor.set_read_only(false));
1686 }
1687 CodegenStatus::Pending => {
1688 self.editor
1689 .update(cx, |editor, _| editor.set_read_only(true));
1690 }
1691 CodegenStatus::Done => {
1692 self.edited_since_done = false;
1693 self.editor
1694 .update(cx, |editor, _| editor.set_read_only(false));
1695 }
1696 CodegenStatus::Error(error) => {
1697 if cx.has_flag::<ZedPro>()
1698 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1699 && !dismissed_rate_limit_notice()
1700 {
1701 self.show_rate_limit_notice = true;
1702 cx.notify();
1703 }
1704
1705 self.edited_since_done = false;
1706 self.editor
1707 .update(cx, |editor, _| editor.set_read_only(false));
1708 }
1709 }
1710 }
1711
1712 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1713 match &self.codegen.read(cx).status {
1714 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1715 cx.emit(PromptEditorEvent::CancelRequested);
1716 }
1717 CodegenStatus::Pending => {
1718 cx.emit(PromptEditorEvent::StopRequested);
1719 }
1720 }
1721 }
1722
1723 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1724 match &self.codegen.read(cx).status {
1725 CodegenStatus::Idle => {
1726 cx.emit(PromptEditorEvent::StartRequested);
1727 }
1728 CodegenStatus::Pending => {
1729 cx.emit(PromptEditorEvent::DismissRequested);
1730 }
1731 CodegenStatus::Done | CodegenStatus::Error(_) => {
1732 if self.edited_since_done {
1733 cx.emit(PromptEditorEvent::StartRequested);
1734 } else {
1735 cx.emit(PromptEditorEvent::ConfirmRequested);
1736 }
1737 }
1738 }
1739 }
1740
1741 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1742 if let Some(ix) = self.prompt_history_ix {
1743 if ix > 0 {
1744 self.prompt_history_ix = Some(ix - 1);
1745 let prompt = self.prompt_history[ix - 1].as_str();
1746 self.editor.update(cx, |editor, cx| {
1747 editor.set_text(prompt, cx);
1748 editor.move_to_beginning(&Default::default(), cx);
1749 });
1750 }
1751 } else if !self.prompt_history.is_empty() {
1752 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1753 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1754 self.editor.update(cx, |editor, cx| {
1755 editor.set_text(prompt, cx);
1756 editor.move_to_beginning(&Default::default(), cx);
1757 });
1758 }
1759 }
1760
1761 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1762 if let Some(ix) = self.prompt_history_ix {
1763 if ix < self.prompt_history.len() - 1 {
1764 self.prompt_history_ix = Some(ix + 1);
1765 let prompt = self.prompt_history[ix + 1].as_str();
1766 self.editor.update(cx, |editor, cx| {
1767 editor.set_text(prompt, cx);
1768 editor.move_to_end(&Default::default(), cx)
1769 });
1770 } else {
1771 self.prompt_history_ix = None;
1772 let prompt = self.pending_prompt.as_str();
1773 self.editor.update(cx, |editor, cx| {
1774 editor.set_text(prompt, cx);
1775 editor.move_to_end(&Default::default(), cx)
1776 });
1777 }
1778 }
1779 }
1780
1781 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1782 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1783 let token_count = self.token_count?;
1784 let max_token_count = model.max_token_count();
1785
1786 let remaining_tokens = max_token_count as isize - token_count as isize;
1787 let token_count_color = if remaining_tokens <= 0 {
1788 Color::Error
1789 } else if token_count as f32 / max_token_count as f32 >= 0.8 {
1790 Color::Warning
1791 } else {
1792 Color::Muted
1793 };
1794
1795 let mut token_count = h_flex()
1796 .id("token_count")
1797 .gap_0p5()
1798 .child(
1799 Label::new(humanize_token_count(token_count))
1800 .size(LabelSize::Small)
1801 .color(token_count_color),
1802 )
1803 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1804 .child(
1805 Label::new(humanize_token_count(max_token_count))
1806 .size(LabelSize::Small)
1807 .color(Color::Muted),
1808 );
1809 if let Some(workspace) = self.workspace.clone() {
1810 token_count = token_count
1811 .tooltip(|cx| {
1812 Tooltip::with_meta(
1813 "Tokens Used by Inline Assistant",
1814 None,
1815 "Click to Open Assistant Panel",
1816 cx,
1817 )
1818 })
1819 .cursor_pointer()
1820 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1821 .on_click(move |_, cx| {
1822 cx.stop_propagation();
1823 workspace
1824 .update(cx, |workspace, cx| {
1825 workspace.focus_panel::<AssistantPanel>(cx)
1826 })
1827 .ok();
1828 });
1829 } else {
1830 token_count = token_count
1831 .cursor_default()
1832 .tooltip(|cx| Tooltip::text("Tokens Used by Inline Assistant", cx));
1833 }
1834
1835 Some(token_count)
1836 }
1837
1838 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1839 let settings = ThemeSettings::get_global(cx);
1840 let text_style = TextStyle {
1841 color: if self.editor.read(cx).read_only(cx) {
1842 cx.theme().colors().text_disabled
1843 } else {
1844 cx.theme().colors().text
1845 },
1846 font_family: settings.ui_font.family.clone(),
1847 font_features: settings.ui_font.features.clone(),
1848 font_fallbacks: settings.ui_font.fallbacks.clone(),
1849 font_size: rems(0.875).into(),
1850 font_weight: settings.ui_font.weight,
1851 line_height: relative(1.3),
1852 ..Default::default()
1853 };
1854 EditorElement::new(
1855 &self.editor,
1856 EditorStyle {
1857 background: cx.theme().colors().editor_background,
1858 local_player: cx.theme().players().local(),
1859 text: text_style,
1860 ..Default::default()
1861 },
1862 )
1863 }
1864
1865 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1866 Popover::new().child(
1867 v_flex()
1868 .occlude()
1869 .p_2()
1870 .child(
1871 Label::new("Out of Tokens")
1872 .size(LabelSize::Small)
1873 .weight(FontWeight::BOLD),
1874 )
1875 .child(Label::new(
1876 "Try Zed Pro for higher limits, a wider range of models, and more.",
1877 ))
1878 .child(
1879 h_flex()
1880 .justify_between()
1881 .child(CheckboxWithLabel::new(
1882 "dont-show-again",
1883 Label::new("Don't show again"),
1884 if dismissed_rate_limit_notice() {
1885 ui::Selection::Selected
1886 } else {
1887 ui::Selection::Unselected
1888 },
1889 |selection, cx| {
1890 let is_dismissed = match selection {
1891 ui::Selection::Unselected => false,
1892 ui::Selection::Indeterminate => return,
1893 ui::Selection::Selected => true,
1894 };
1895
1896 set_rate_limit_notice_dismissed(is_dismissed, cx)
1897 },
1898 ))
1899 .child(
1900 h_flex()
1901 .gap_2()
1902 .child(
1903 Button::new("dismiss", "Dismiss")
1904 .style(ButtonStyle::Transparent)
1905 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1906 )
1907 .child(Button::new("more-info", "More Info").on_click(
1908 |_event, cx| {
1909 cx.dispatch_action(Box::new(
1910 zed_actions::OpenAccountSettings,
1911 ))
1912 },
1913 )),
1914 ),
1915 ),
1916 )
1917 }
1918}
1919
/// Key-value-store key whose presence means the user chose "Don't show again"
/// for the rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1921
1922fn dismissed_rate_limit_notice() -> bool {
1923 db::kvp::KEY_VALUE_STORE
1924 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1925 .log_err()
1926 .map_or(false, |s| s.is_some())
1927}
1928
1929fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1930 db::write_and_log(cx, move || async move {
1931 if is_dismissed {
1932 db::kvp::KEY_VALUE_STORE
1933 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1934 .await
1935 } else {
1936 db::kvp::KEY_VALUE_STORE
1937 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1938 .await
1939 }
1940 })
1941}
1942
/// Bookkeeping for a single inline assist.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchored range in the editor's buffer that the assist covers.
    range: Range<Anchor>,
    /// Editor hosting the assist (weak, so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// Prompt/deleted-line blocks. When `None`, a finished codegen finishes the
    /// assist immediately and errors are shown as workspace toasts instead.
    decorations: Option<InlineAssistDecorations>,
    /// The code generation producing the edit.
    codegen: Model<Codegen>,
    _subscriptions: Vec<Subscription>,
    /// Workspace used for toasts and for reading assistant-panel context.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the assistant panel's active context is added to requests.
    include_context: bool,
}
1953
1954impl InlineAssist {
1955 #[allow(clippy::too_many_arguments)]
1956 fn new(
1957 assist_id: InlineAssistId,
1958 group_id: InlineAssistGroupId,
1959 include_context: bool,
1960 editor: &View<Editor>,
1961 prompt_editor: &View<PromptEditor>,
1962 prompt_block_id: CustomBlockId,
1963 end_block_id: CustomBlockId,
1964 range: Range<Anchor>,
1965 codegen: Model<Codegen>,
1966 workspace: Option<WeakView<Workspace>>,
1967 cx: &mut WindowContext,
1968 ) -> Self {
1969 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
1970 InlineAssist {
1971 group_id,
1972 include_context,
1973 editor: editor.downgrade(),
1974 decorations: Some(InlineAssistDecorations {
1975 prompt_block_id,
1976 prompt_editor: prompt_editor.clone(),
1977 removed_line_block_ids: HashSet::default(),
1978 end_block_id,
1979 }),
1980 range,
1981 codegen: codegen.clone(),
1982 workspace: workspace.clone(),
1983 _subscriptions: vec![
1984 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
1985 InlineAssistant::update_global(cx, |this, cx| {
1986 this.handle_prompt_editor_focus_in(assist_id, cx)
1987 })
1988 }),
1989 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
1990 InlineAssistant::update_global(cx, |this, cx| {
1991 this.handle_prompt_editor_focus_out(assist_id, cx)
1992 })
1993 }),
1994 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
1995 InlineAssistant::update_global(cx, |this, cx| {
1996 this.handle_prompt_editor_event(prompt_editor, event, cx)
1997 })
1998 }),
1999 cx.observe(&codegen, {
2000 let editor = editor.downgrade();
2001 move |_, cx| {
2002 if let Some(editor) = editor.upgrade() {
2003 InlineAssistant::update_global(cx, |this, cx| {
2004 if let Some(editor_assists) =
2005 this.assists_by_editor.get(&editor.downgrade())
2006 {
2007 editor_assists.highlight_updates.send(()).ok();
2008 }
2009
2010 this.update_editor_blocks(&editor, assist_id, cx);
2011 })
2012 }
2013 }
2014 }),
2015 cx.subscribe(&codegen, move |codegen, event, cx| {
2016 InlineAssistant::update_global(cx, |this, cx| match event {
2017 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2018 CodegenEvent::Finished => {
2019 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2020 assist
2021 } else {
2022 return;
2023 };
2024
2025 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
2026 if assist.decorations.is_none() {
2027 if let Some(workspace) = assist
2028 .workspace
2029 .as_ref()
2030 .and_then(|workspace| workspace.upgrade())
2031 {
2032 let error = format!("Inline assistant error: {}", error);
2033 workspace.update(cx, |workspace, cx| {
2034 struct InlineAssistantError;
2035
2036 let id =
2037 NotificationId::identified::<InlineAssistantError>(
2038 assist_id.0,
2039 );
2040
2041 workspace.show_toast(Toast::new(id, error), cx);
2042 })
2043 }
2044 }
2045 }
2046
2047 if assist.decorations.is_none() {
2048 this.finish_assist(assist_id, false, cx);
2049 } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2050 tx.0.send(()).ok();
2051 }
2052 }
2053 })
2054 }),
2055 ],
2056 }
2057 }
2058
2059 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2060 let decorations = self.decorations.as_ref()?;
2061 Some(decorations.prompt_editor.read(cx).prompt(cx))
2062 }
2063
2064 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2065 if self.include_context {
2066 let workspace = self.workspace.as_ref()?;
2067 let workspace = workspace.upgrade()?.read(cx);
2068 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2069 Some(
2070 assistant_panel
2071 .read(cx)
2072 .active_context(cx)?
2073 .read(cx)
2074 .to_completion_request(cx),
2075 )
2076 } else {
2077 None
2078 }
2079 }
2080
2081 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<usize>> {
2082 let Some(user_prompt) = self.user_prompt(cx) else {
2083 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2084 };
2085 let assistant_panel_context = self.assistant_panel_context(cx);
2086 self.codegen
2087 .read(cx)
2088 .count_tokens(user_prompt, assistant_panel_context, cx)
2089 }
2090}
2091
/// Editor UI attached to an inline assist.
struct InlineAssistDecorations {
    /// Block hosting `prompt_editor` (presumably rendered via
    /// `build_assist_editor_renderer`).
    prompt_block_id: CustomBlockId,
    prompt_editor: View<PromptEditor>,
    /// Blocks showing deleted lines; replaced wholesale by `update_editor_blocks`.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Block presumably marking the end of the assist's range — NOTE(review):
    /// confirm where it is created.
    end_block_id: CustomBlockId,
}
2098
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// Generation finished (the status tells whether it succeeded).
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
2104
/// Drives a language-model transformation of a range of a multibuffer.
pub struct Codegen {
    /// The multibuffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Local copy of the original buffer's text taken at construction; backs the
    /// "deleted lines" editors.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot taken at construction; deleted row ranges refer to it.
    snapshot: MultiBufferSnapshot,
    /// Range being transformed.
    transform_range: Range<Anchor>,
    /// Selections passed in at construction.
    selected_ranges: Vec<Range<Anchor>>,
    /// NOTE(review): written by generation code outside this view — confirm.
    edit_position: Option<Anchor>,
    /// NOTE(review): exposed via `last_equal_ranges()`; computed outside this view.
    last_equal_ranges: Vec<Range<Anchor>>,
    initial_transaction_id: Option<TransactionId>,
    /// Transaction holding the generated edits; cleared when it is undone.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// Current generation task; replaced with a ready task on undo.
    generation: Task<()>,
    /// Rows deleted/inserted so far, used for blocks and highlights.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    prompt_builder: Arc<PromptBuilder>,
}
2122
/// Lifecycle of a [`Codegen`] run.
enum CodegenStatus {
    /// Not started yet.
    Idle,
    /// Generation in progress; the prompt editor is read-only.
    Pending,
    /// Generation completed.
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}
2129
/// Row-level diff between the original and transformed text.
#[derive(Default)]
struct Diff {
    /// `(anchor in current buffer, inclusive row range in the original snapshot)`
    /// for each run of deleted rows.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inclusive anchor ranges of inserted rows in the current buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2135
2136impl Diff {
2137 fn is_empty(&self) -> bool {
2138 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2139 }
2140}
2141
// Lets `Codegen` emit `CodegenEvent`s via `cx.emit`.
impl EventEmitter<CodegenEvent> for Codegen {}
2143
2144impl Codegen {
2145 pub fn new(
2146 buffer: Model<MultiBuffer>,
2147 transform_range: Range<Anchor>,
2148 selected_ranges: Vec<Range<Anchor>>,
2149 initial_transaction_id: Option<TransactionId>,
2150 telemetry: Option<Arc<Telemetry>>,
2151 builder: Arc<PromptBuilder>,
2152 cx: &mut ModelContext<Self>,
2153 ) -> Self {
2154 let snapshot = buffer.read(cx).snapshot(cx);
2155
2156 let (old_buffer, _, _) = buffer
2157 .read(cx)
2158 .range_to_buffer_ranges(transform_range.clone(), cx)
2159 .pop()
2160 .unwrap();
2161 let old_buffer = cx.new_model(|cx| {
2162 let old_buffer = old_buffer.read(cx);
2163 let text = old_buffer.as_rope().clone();
2164 let line_ending = old_buffer.line_ending();
2165 let language = old_buffer.language().cloned();
2166 let language_registry = old_buffer.language_registry();
2167
2168 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2169 buffer.set_language(language, cx);
2170 if let Some(language_registry) = language_registry {
2171 buffer.set_language_registry(language_registry)
2172 }
2173 buffer
2174 });
2175
2176 Self {
2177 buffer: buffer.clone(),
2178 old_buffer,
2179 edit_position: None,
2180 snapshot,
2181 last_equal_ranges: Default::default(),
2182 transformation_transaction_id: None,
2183 status: CodegenStatus::Idle,
2184 generation: Task::ready(()),
2185 diff: Diff::default(),
2186 telemetry,
2187 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2188 initial_transaction_id,
2189 prompt_builder: builder,
2190 transform_range,
2191 selected_ranges,
2192 }
2193 }
2194
2195 fn handle_buffer_event(
2196 &mut self,
2197 _buffer: Model<MultiBuffer>,
2198 event: &multi_buffer::Event,
2199 cx: &mut ModelContext<Self>,
2200 ) {
2201 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2202 if self.transformation_transaction_id == Some(*transaction_id) {
2203 self.transformation_transaction_id = None;
2204 self.generation = Task::ready(());
2205 cx.emit(CodegenEvent::Undone);
2206 }
2207 }
2208 }
2209
2210 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2211 &self.last_equal_ranges
2212 }
2213
2214 pub fn count_tokens(
2215 &self,
2216 user_prompt: String,
2217 assistant_panel_context: Option<LanguageModelRequest>,
2218 cx: &AppContext,
2219 ) -> BoxFuture<'static, Result<usize>> {
2220 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2221 let request = self.build_request(user_prompt, assistant_panel_context, cx);
2222 match request {
2223 Ok(request) => model.count_tokens(request, cx),
2224 Err(error) => futures::future::ready(Err(error)).boxed(),
2225 }
2226 } else {
2227 future::ready(Err(anyhow!("no active model"))).boxed()
2228 }
2229 }
2230
2231 pub fn start(
2232 &mut self,
2233 user_prompt: String,
2234 assistant_panel_context: Option<LanguageModelRequest>,
2235 cx: &mut ModelContext<Self>,
2236 ) -> Result<()> {
2237 let model = LanguageModelRegistry::read_global(cx)
2238 .active_model()
2239 .context("no active model")?;
2240
2241 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2242 self.buffer.update(cx, |buffer, cx| {
2243 buffer.undo_transaction(transformation_transaction_id, cx);
2244 });
2245 }
2246
2247 self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
2248
2249 let telemetry_id = model.telemetry_id();
2250 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
2251 if user_prompt.trim().to_lowercase() == "delete" {
2252 async { Ok(stream::empty().boxed()) }.boxed_local()
2253 } else {
2254 let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2255
2256 let chunks =
2257 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2258 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2259 };
2260 self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
2261 Ok(())
2262 }
2263
2264 fn build_request(
2265 &self,
2266 user_prompt: String,
2267 assistant_panel_context: Option<LanguageModelRequest>,
2268 cx: &AppContext,
2269 ) -> Result<LanguageModelRequest> {
2270 let buffer = self.buffer.read(cx).snapshot(cx);
2271 let language = buffer.language_at(self.transform_range.start);
2272 let language_name = if let Some(language) = language.as_ref() {
2273 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2274 None
2275 } else {
2276 Some(language.name())
2277 }
2278 } else {
2279 None
2280 };
2281
2282 // Higher Temperature increases the randomness of model outputs.
2283 // If Markdown or No Language is Known, increase the randomness for more creative output
2284 // If Code, decrease temperature to get more deterministic outputs
2285 let temperature = if let Some(language) = language_name.clone() {
2286 if language.as_ref() == "Markdown" {
2287 1.0
2288 } else {
2289 0.5
2290 }
2291 } else {
2292 1.0
2293 };
2294
2295 let language_name = language_name.as_deref();
2296 let start = buffer.point_to_buffer_offset(self.transform_range.start);
2297 let end = buffer.point_to_buffer_offset(self.transform_range.end);
2298 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2299 let (start_buffer, start_buffer_offset) = start;
2300 let (end_buffer, end_buffer_offset) = end;
2301 if start_buffer.remote_id() == end_buffer.remote_id() {
2302 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2303 } else {
2304 return Err(anyhow::anyhow!("invalid transformation range"));
2305 }
2306 } else {
2307 return Err(anyhow::anyhow!("invalid transformation range"));
2308 };
2309
2310 let selected_ranges = self
2311 .selected_ranges
2312 .iter()
2313 .map(|range| {
2314 let start = range.start.text_anchor.to_offset(&buffer);
2315 let end = range.end.text_anchor.to_offset(&buffer);
2316 start..end
2317 })
2318 .collect::<Vec<_>>();
2319
2320 let prompt = self
2321 .prompt_builder
2322 .generate_content_prompt(user_prompt, language_name, buffer, range, selected_ranges)
2323 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2324
2325 let mut messages = Vec::new();
2326 if let Some(context_request) = assistant_panel_context {
2327 messages = context_request.messages;
2328 }
2329
2330 messages.push(LanguageModelRequestMessage {
2331 role: Role::User,
2332 content: prompt,
2333 });
2334
2335 Ok(LanguageModelRequest {
2336 messages,
2337 stop: vec!["|END|>".to_string()],
2338 temperature,
2339 })
2340 }
2341
/// Consumes a stream of model output and applies it to `edit_range` as it arrives.
///
/// Streamed text passes through `StripInvalidSpans`. It is then diffed on the background
/// executor against the text of `edit_range` in `self.snapshot`. Each batch of character
/// operations is applied to the buffer in its own transaction. Every transaction after the first
/// is merged into the first, whose id is stored in `transformation_transaction_id`.
///
/// `model_telemetry_id` is reported together with the first-response latency and any error once
/// the stream finishes. Finally, `status` is set to `Done` or `Error`, and
/// `CodegenEvent::Finished` is emitted.
pub fn handle_stream(
    &mut self,
    model_telemetry_id: String,
    edit_range: Range<Anchor>,
    stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
    cx: &mut ModelContext<Self>,
) {
    let snapshot = self.snapshot.clone();
    // Original text of the range; the streaming diff compares model output against it.
    let selected_text = snapshot
        .text_for_range(edit_range.start..edit_range.end)
        .collect::<Rope>();

    let selection_start = edit_range.start.to_point(&snapshot);

    // Start with the indentation of the first line in the selection
    // NOTE(review): `suggested_line_indent` is never read after this computation in this method;
    // it looks like a leftover of removed re-indentation logic. Confirm, then remove it (and the
    // `IndentKind` import if nothing else uses it).
    let mut suggested_line_indent = snapshot
        .suggested_indents(selection_start.row..=selection_start.row, cx)
        .into_values()
        .next()
        .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

    // If the first line in the selection does not have indentation, check the following lines
    if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
        for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
            let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
            // Prefer tabs if a line in the selection uses tabs as indentation
            if line_indent.kind == IndentKind::Tab {
                suggested_line_indent.kind = IndentKind::Tab;
                break;
            }
        }
    }

    let telemetry = self.telemetry.clone();
    self.diff = Diff::default();
    self.status = CodegenStatus::Pending;
    // Offset in `snapshot` up to which streamed operations have been applied; advanced by the
    // Keep/Delete operations below.
    let mut edit_start = edit_range.start.to_offset(&snapshot);
    self.generation = cx.spawn(|this, mut cx| {
        async move {
            let chunks = stream.await;
            let generate = async {
                // Bounded channel carrying (char ops, line ops) batches from the background
                // diffing task to this task, which applies them to the buffer.
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                let diff: Task<anyhow::Result<()>> =
                    cx.background_executor().spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = StripInvalidSpans::new(chunks?);
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            while let Some(chunk) = chunks.next().await {
                                // Latency is measured to the first item received from the stream.
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;
                                let char_ops = diff.push_new(&chunk);
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;
                            }

                            // The stream has ended: send the final batch of operations.
                            let char_ops = diff.finish();
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        // Report the outcome (including any error text) before propagating it.
                        let error_message =
                            result.as_ref().err().map(|error| error.to_string());
                        if let Some(telemetry) = telemetry {
                            telemetry.report_assistant_event(
                                None,
                                telemetry_events::AssistantKind::Inline,
                                model_telemetry_id,
                                response_latency,
                                error_message,
                            );
                        }

                        result?;
                        Ok(())
                    });

                // Apply each batch as it arrives; the loop ends when the background task drops
                // its sender.
                while let Some((char_ops, line_diff)) = diff_rx.next().await {
                    this.update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();

                        let transaction = this.buffer.update(cx, |buffer, cx| {
                            // Avoid grouping assistant edits with user edits.
                            buffer.finalize_last_transaction(cx);

                            buffer.start_transaction(cx);
                            buffer.edit(
                                char_ops
                                    .into_iter()
                                    .filter_map(|operation| match operation {
                                        CharOperation::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        CharOperation::Delete { bytes } => {
                                            let edit_end = edit_start + bytes;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        CharOperation::Keep { bytes } => {
                                            // Unchanged text is recorded, not edited.
                                            let edit_end = edit_start + bytes;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                None,
                                cx,
                            );
                            this.edit_position = Some(snapshot.anchor_after(edit_start));

                            buffer.end_transaction(cx)
                        });

                        if let Some(transaction) = transaction {
                            if let Some(first_transaction) = this.transformation_transaction_id
                            {
                                // Group all assistant edits into the first transaction.
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.merge_transactions(
                                        transaction,
                                        first_transaction,
                                        cx,
                                    )
                                });
                            } else {
                                this.transformation_transaction_id = Some(transaction);
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.finalize_last_transaction(cx)
                                });
                            }
                        }

                        this.update_diff(edit_range.clone(), line_diff, cx);

                        cx.notify();
                    })?;
                }

                diff.await?;

                anyhow::Ok(())
            };

            let result = generate.await;
            // Record the final status; `.ok()` ignores failure to update (e.g. if the model is gone).
            this.update(&mut cx, |this, cx| {
                this.last_equal_ranges.clear();
                if let Err(error) = result {
                    this.status = CodegenStatus::Error(error);
                } else {
                    this.status = CodegenStatus::Done;
                }
                cx.emit(CodegenEvent::Finished);
                cx.notify();
            })
            .ok();
        }
    });
    cx.notify();
}
2521
2522 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2523 self.last_equal_ranges.clear();
2524 if self.diff.is_empty() {
2525 self.status = CodegenStatus::Idle;
2526 } else {
2527 self.status = CodegenStatus::Done;
2528 }
2529 self.generation = Task::ready(());
2530 cx.emit(CodegenEvent::Finished);
2531 cx.notify();
2532 }
2533
2534 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2535 self.buffer.update(cx, |buffer, cx| {
2536 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2537 buffer.undo_transaction(transaction_id, cx);
2538 buffer.refresh_preview(cx);
2539 }
2540
2541 if let Some(transaction_id) = self.initial_transaction_id.take() {
2542 buffer.undo_transaction(transaction_id, cx);
2543 buffer.refresh_preview(cx);
2544 }
2545 });
2546 }
2547
2548 fn update_diff(
2549 &mut self,
2550 edit_range: Range<Anchor>,
2551 line_operations: Vec<LineOperation>,
2552 cx: &mut ModelContext<Self>,
2553 ) {
2554 let old_snapshot = self.snapshot.clone();
2555 let old_range = edit_range.to_point(&old_snapshot);
2556 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2557 let new_range = edit_range.to_point(&new_snapshot);
2558
2559 let mut old_row = old_range.start.row;
2560 let mut new_row = new_range.start.row;
2561
2562 self.diff.deleted_row_ranges.clear();
2563 self.diff.inserted_row_ranges.clear();
2564 for operation in line_operations {
2565 match operation {
2566 LineOperation::Keep { lines } => {
2567 old_row += lines;
2568 new_row += lines;
2569 }
2570 LineOperation::Delete { lines } => {
2571 let old_end_row = old_row + lines - 1;
2572 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2573
2574 if let Some((_, last_deleted_row_range)) =
2575 self.diff.deleted_row_ranges.last_mut()
2576 {
2577 if *last_deleted_row_range.end() + 1 == old_row {
2578 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2579 } else {
2580 self.diff
2581 .deleted_row_ranges
2582 .push((new_row, old_row..=old_end_row));
2583 }
2584 } else {
2585 self.diff
2586 .deleted_row_ranges
2587 .push((new_row, old_row..=old_end_row));
2588 }
2589
2590 old_row += lines;
2591 }
2592 LineOperation::Insert { lines } => {
2593 let new_end_row = new_row + lines - 1;
2594 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2595 let end = new_snapshot.anchor_before(Point::new(
2596 new_end_row,
2597 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2598 ));
2599 self.diff.inserted_row_ranges.push(start..=end);
2600 new_row += lines;
2601 }
2602 }
2603
2604 cx.notify();
2605 }
2606 }
2607}
2608
/// Stream adapter that removes a wrapping code fence and `<|CURSOR|>` markers from model output.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Whether the wrapped stream has returned `None`.
    stream_done: bool,
    /// Received text that has not been emitted or discarded yet.
    buffer: String,
    /// Whether output is still on the first line of the response.
    first_line: bool,
    /// Whether a newline is owed before the next emitted text.
    line_end: bool,
    /// Whether a leading code-fence line was dropped, so a closing fence should be dropped too.
    starts_with_code_block: bool,
}
2617
2618impl<T> StripInvalidSpans<T>
2619where
2620 T: Stream<Item = Result<String>>,
2621{
2622 fn new(stream: T) -> Self {
2623 Self {
2624 stream,
2625 stream_done: false,
2626 buffer: String::new(),
2627 first_line: true,
2628 line_end: false,
2629 starts_with_code_block: false,
2630 }
2631 }
2632}
2633
2634impl<T> Stream for StripInvalidSpans<T>
2635where
2636 T: Stream<Item = Result<String>>,
2637{
2638 type Item = Result<String>;
2639
2640 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2641 const CODE_BLOCK_DELIMITER: &str = "```";
2642 const CURSOR_SPAN: &str = "<|CURSOR|>";
2643
2644 let this = unsafe { self.get_unchecked_mut() };
2645 loop {
2646 if !this.stream_done {
2647 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2648 match stream.as_mut().poll_next(cx) {
2649 Poll::Ready(Some(Ok(chunk))) => {
2650 this.buffer.push_str(&chunk);
2651 }
2652 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2653 Poll::Ready(None) => {
2654 this.stream_done = true;
2655 }
2656 Poll::Pending => return Poll::Pending,
2657 }
2658 }
2659
2660 let mut chunk = String::new();
2661 let mut consumed = 0;
2662 if !this.buffer.is_empty() {
2663 let mut lines = this.buffer.split('\n').enumerate().peekable();
2664 while let Some((line_ix, line)) = lines.next() {
2665 if line_ix > 0 {
2666 this.first_line = false;
2667 }
2668
2669 if this.first_line {
2670 let trimmed_line = line.trim();
2671 if lines.peek().is_some() {
2672 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2673 consumed += line.len() + 1;
2674 this.starts_with_code_block = true;
2675 continue;
2676 }
2677 } else if trimmed_line.is_empty()
2678 || prefixes(CODE_BLOCK_DELIMITER)
2679 .any(|prefix| trimmed_line.starts_with(prefix))
2680 {
2681 break;
2682 }
2683 }
2684
2685 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2686 if lines.peek().is_some() {
2687 if this.line_end {
2688 chunk.push('\n');
2689 }
2690
2691 chunk.push_str(&line_without_cursor);
2692 this.line_end = true;
2693 consumed += line.len() + 1;
2694 } else if this.stream_done {
2695 if !this.starts_with_code_block
2696 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2697 {
2698 if this.line_end {
2699 chunk.push('\n');
2700 }
2701
2702 chunk.push_str(&line);
2703 }
2704
2705 consumed += line.len();
2706 } else {
2707 let trimmed_line = line.trim();
2708 if trimmed_line.is_empty()
2709 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2710 || prefixes(CODE_BLOCK_DELIMITER)
2711 .any(|prefix| trimmed_line.ends_with(prefix))
2712 {
2713 break;
2714 } else {
2715 if this.line_end {
2716 chunk.push('\n');
2717 this.line_end = false;
2718 }
2719
2720 chunk.push_str(&line_without_cursor);
2721 consumed += line.len();
2722 }
2723 }
2724 }
2725 }
2726
2727 this.buffer = this.buffer.split_off(consumed);
2728 if !chunk.is_empty() {
2729 return Poll::Ready(Some(Ok(chunk)));
2730 } else if this.stream_done {
2731 return Poll::Ready(None);
2732 }
2733 }
2734 }
2735}
2736
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, three backticks yield one backtick and then two backticks; the full string is
/// not included. Prefixes always end on a `char` boundary, so non-ASCII input cannot cause a
/// slicing panic. An empty `text` yields nothing instead of underflowing `len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2740
2741fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2742 ranges.sort_unstable_by(|a, b| {
2743 a.start
2744 .cmp(&b.start, buffer)
2745 .then_with(|| b.end.cmp(&a.end, buffer))
2746 });
2747
2748 let mut ix = 0;
2749 while ix + 1 < ranges.len() {
2750 let b = ranges[ix + 1].clone();
2751 let a = &mut ranges[ix];
2752 if a.end.cmp(&b.start, buffer).is_gt() {
2753 if a.end.cmp(&b.end, buffer).is_lt() {
2754 a.end = b.end;
2755 }
2756 ranges.remove(ix + 1);
2757 } else {
2758 ix += 1;
2759 }
2760 }
2761}
2762
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Checks that `StripInvalidSpans` removes wrapping fences and cursor markers for every
    /// possible chunking of the input.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Streams `text` in every chunk size from 1 to its length and compares the output.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into chunks of `size` chars each, as a stream of `Ok` items.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }
}