1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 join,
23 stream::{self, BoxStream},
24 SinkExt, Stream, StreamExt,
25};
26use gpui::{
27 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
28 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
29 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
30};
31use language::{Buffer, IndentKind, Point, TransactionId};
32use language_model::{
33 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
34};
35use multi_buffer::MultiBufferRow;
36use parking_lot::Mutex;
37use rope::Rope;
38use settings::Settings;
39use smol::future::FutureExt;
40use std::{
41 future::{self, Future},
42 mem,
43 ops::{Range, RangeInclusive},
44 pin::Pin,
45 sync::Arc,
46 task::{self, Poll},
47 time::{Duration, Instant},
48};
49use text::OffsetRangeExt as _;
50use theme::ThemeSettings;
51use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
52use util::{RangeExt, ResultExt};
53use workspace::{notifications::NotificationId, Toast, Workspace};
54
55pub fn init(
56 fs: Arc<dyn Fs>,
57 prompt_builder: Arc<PromptBuilder>,
58 telemetry: Arc<Telemetry>,
59 cx: &mut AppContext,
60) {
61 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
62 cx.observe_new_views(|_, cx| {
63 let workspace = cx.view().clone();
64 InlineAssistant::update_global(cx, |inline_assistant, cx| {
65 inline_assistant.register_workspace(&workspace, cx)
66 })
67 })
68 .detach();
69}
70
/// Upper bound on the number of entries kept in `InlineAssistant::prompt_history`;
/// `start_assist` drops the oldest entry once the history grows past this length.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
72
/// Global state for all inline assists, installed via `cx.set_global` in `init`.
pub struct InlineAssistant {
    /// Id handed to the next assist that gets created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group that gets created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, highlight updates), keyed by a
    /// weak handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists that were created together by a single call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Watch-channel pairs per assist; the sender is signalled when an assist is
    /// started or stopped, and the entry is removed when the assist is finished.
    assist_observations:
        HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
    /// Codegen models of assists that were finished without undoing, retained so that
    /// `undo_assist` can still revert them.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    /// Previously submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` this assistant creates.
    prompt_builder: Arc<PromptBuilder>,
    /// Passed to every `Codegen` this assistant creates; `new` always stores `Some`.
    telemetry: Option<Arc<Telemetry>>,
    /// Passed to every `PromptEditor` this assistant creates.
    fs: Arc<dyn Fs>,
}
87
// Marker impl: lets `InlineAssistant` be stored with `cx.set_global` and reached
// through `update_global`, as done in `init`.
impl Global for InlineAssistant {}
89
90impl InlineAssistant {
91 pub fn new(
92 fs: Arc<dyn Fs>,
93 prompt_builder: Arc<PromptBuilder>,
94 telemetry: Arc<Telemetry>,
95 ) -> Self {
96 Self {
97 next_assist_id: InlineAssistId::default(),
98 next_assist_group_id: InlineAssistGroupId::default(),
99 assists: HashMap::default(),
100 assists_by_editor: HashMap::default(),
101 assist_groups: HashMap::default(),
102 assist_observations: HashMap::default(),
103 confirmed_assists: HashMap::default(),
104 prompt_history: VecDeque::default(),
105 prompt_builder,
106 telemetry: Some(telemetry),
107 fs,
108 }
109 }
110
111 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
112 cx.subscribe(workspace, |_, event, cx| {
113 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
114 })
115 .detach();
116 }
117
118 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
119 // When the user manually saves an editor, automatically accepts all finished transformations.
120 if let workspace::Event::UserSavedItem { item, .. } = event {
121 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
122 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
123 for assist_id in editor_assists.assist_ids.clone() {
124 let assist = &self.assists[&assist_id];
125 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
126 self.finish_assist(assist_id, false, cx)
127 }
128 }
129 }
130 }
131 }
132 }
133
134 pub fn assist(
135 &mut self,
136 editor: &View<Editor>,
137 workspace: Option<WeakView<Workspace>>,
138 assistant_panel: Option<&View<AssistantPanel>>,
139 initial_prompt: Option<String>,
140 cx: &mut WindowContext,
141 ) {
142 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
143 struct CodegenRange {
144 transform_range: Range<Point>,
145 selection_ranges: Vec<Range<Point>>,
146 focus_assist: bool,
147 }
148
149 let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range();
150 let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
151
152 let selection_ranges = snapshot
153 .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
154 .map(|range| range.to_point(&snapshot))
155 .collect::<Vec<Range<Point>>>();
156
157 for selection_range in selection_ranges {
158 let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
159 let mut transform_range = selection_range.start..selection_range.end;
160
161 // Expand the transform range to start/end of lines.
162 // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
163 transform_range.start.column = 0;
164 if transform_range.end.column == 0 && transform_range.end > transform_range.start {
165 transform_range.end.row -= 1;
166 }
167 transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
168 let selection_range =
169 selection_range.start..selection_range.end.min(transform_range.end);
170
171 // If we intersect the previous transform range,
172 if let Some(CodegenRange {
173 transform_range: prev_transform_range,
174 selection_ranges,
175 focus_assist,
176 }) = codegen_ranges.last_mut()
177 {
178 if transform_range.start <= prev_transform_range.end {
179 prev_transform_range.end = transform_range.end;
180 selection_ranges.push(selection_range);
181 *focus_assist |= selection_is_newest;
182 continue;
183 }
184 }
185
186 codegen_ranges.push(CodegenRange {
187 transform_range,
188 selection_ranges: vec![selection_range],
189 focus_assist: selection_is_newest,
190 })
191 }
192
193 let assist_group_id = self.next_assist_group_id.post_inc();
194 let prompt_buffer =
195 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
196 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
197 let mut assists = Vec::new();
198 let mut assist_to_focus = None;
199
200 for CodegenRange {
201 transform_range,
202 selection_ranges,
203 focus_assist,
204 } in codegen_ranges
205 {
206 let transform_range = snapshot.anchor_before(transform_range.start)
207 ..snapshot.anchor_after(transform_range.end);
208 let selection_ranges = selection_ranges
209 .iter()
210 .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
211 .collect::<Vec<_>>();
212
213 let codegen = cx.new_model(|cx| {
214 Codegen::new(
215 editor.read(cx).buffer().clone(),
216 transform_range.clone(),
217 selection_ranges,
218 None,
219 self.telemetry.clone(),
220 self.prompt_builder.clone(),
221 cx,
222 )
223 });
224
225 let assist_id = self.next_assist_id.post_inc();
226 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
227 let prompt_editor = cx.new_view(|cx| {
228 PromptEditor::new(
229 assist_id,
230 gutter_dimensions.clone(),
231 self.prompt_history.clone(),
232 prompt_buffer.clone(),
233 codegen.clone(),
234 editor,
235 assistant_panel,
236 workspace.clone(),
237 self.fs.clone(),
238 cx,
239 )
240 });
241
242 if focus_assist {
243 assist_to_focus = Some(assist_id);
244 }
245
246 let [prompt_block_id, end_block_id] =
247 self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
248
249 assists.push((
250 assist_id,
251 transform_range,
252 prompt_editor,
253 prompt_block_id,
254 end_block_id,
255 ));
256 }
257
258 let editor_assists = self
259 .assists_by_editor
260 .entry(editor.downgrade())
261 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
262 let mut assist_group = InlineAssistGroup::new();
263 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
264 self.assists.insert(
265 assist_id,
266 InlineAssist::new(
267 assist_id,
268 assist_group_id,
269 assistant_panel.is_some(),
270 editor,
271 &prompt_editor,
272 prompt_block_id,
273 end_block_id,
274 range,
275 prompt_editor.read(cx).codegen.clone(),
276 workspace.clone(),
277 cx,
278 ),
279 );
280 assist_group.assist_ids.push(assist_id);
281 editor_assists.assist_ids.push(assist_id);
282 }
283 self.assist_groups.insert(assist_group_id, assist_group);
284
285 if let Some(assist_id) = assist_to_focus {
286 self.focus_assist(assist_id, cx);
287 }
288 }
289
290 #[allow(clippy::too_many_arguments)]
291 pub fn suggest_assist(
292 &mut self,
293 editor: &View<Editor>,
294 mut range: Range<Anchor>,
295 initial_prompt: String,
296 initial_transaction_id: Option<TransactionId>,
297 workspace: Option<WeakView<Workspace>>,
298 assistant_panel: Option<&View<AssistantPanel>>,
299 cx: &mut WindowContext,
300 ) -> InlineAssistId {
301 let assist_group_id = self.next_assist_group_id.post_inc();
302 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
303 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
304
305 let assist_id = self.next_assist_id.post_inc();
306
307 let buffer = editor.read(cx).buffer().clone();
308 {
309 let snapshot = buffer.read(cx).read(cx);
310 range.start = range.start.bias_left(&snapshot);
311 range.end = range.end.bias_right(&snapshot);
312 }
313
314 let codegen = cx.new_model(|cx| {
315 Codegen::new(
316 editor.read(cx).buffer().clone(),
317 range.clone(),
318 vec![range.clone()],
319 initial_transaction_id,
320 self.telemetry.clone(),
321 self.prompt_builder.clone(),
322 cx,
323 )
324 });
325
326 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
327 let prompt_editor = cx.new_view(|cx| {
328 PromptEditor::new(
329 assist_id,
330 gutter_dimensions.clone(),
331 self.prompt_history.clone(),
332 prompt_buffer.clone(),
333 codegen.clone(),
334 editor,
335 assistant_panel,
336 workspace.clone(),
337 self.fs.clone(),
338 cx,
339 )
340 });
341
342 let [prompt_block_id, end_block_id] =
343 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
344
345 let editor_assists = self
346 .assists_by_editor
347 .entry(editor.downgrade())
348 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
349
350 let mut assist_group = InlineAssistGroup::new();
351 self.assists.insert(
352 assist_id,
353 InlineAssist::new(
354 assist_id,
355 assist_group_id,
356 assistant_panel.is_some(),
357 editor,
358 &prompt_editor,
359 prompt_block_id,
360 end_block_id,
361 range,
362 prompt_editor.read(cx).codegen.clone(),
363 workspace.clone(),
364 cx,
365 ),
366 );
367 assist_group.assist_ids.push(assist_id);
368 editor_assists.assist_ids.push(assist_id);
369 self.assist_groups.insert(assist_group_id, assist_group);
370 assist_id
371 }
372
373 fn insert_assist_blocks(
374 &self,
375 editor: &View<Editor>,
376 range: &Range<Anchor>,
377 prompt_editor: &View<PromptEditor>,
378 cx: &mut WindowContext,
379 ) -> [CustomBlockId; 2] {
380 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
381 prompt_editor
382 .editor
383 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
384 });
385 let assist_blocks = vec![
386 BlockProperties {
387 style: BlockStyle::Sticky,
388 position: range.start,
389 height: prompt_editor_height,
390 render: build_assist_editor_renderer(prompt_editor),
391 disposition: BlockDisposition::Above,
392 priority: 0,
393 },
394 BlockProperties {
395 style: BlockStyle::Sticky,
396 position: range.end,
397 height: 0,
398 render: Box::new(|cx| {
399 v_flex()
400 .h_full()
401 .w_full()
402 .border_t_1()
403 .border_color(cx.theme().status().info_border)
404 .into_any_element()
405 }),
406 disposition: BlockDisposition::Below,
407 priority: 0,
408 },
409 ];
410
411 editor.update(cx, |editor, cx| {
412 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
413 [block_ids[0], block_ids[1]]
414 })
415 }
416
417 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
418 let assist = &self.assists[&assist_id];
419 let Some(decorations) = assist.decorations.as_ref() else {
420 return;
421 };
422 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
423 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
424
425 assist_group.active_assist_id = Some(assist_id);
426 if assist_group.linked {
427 for assist_id in &assist_group.assist_ids {
428 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
429 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
430 prompt_editor.set_show_cursor_when_unfocused(true, cx)
431 });
432 }
433 }
434 }
435
436 assist
437 .editor
438 .update(cx, |editor, cx| {
439 let scroll_top = editor.scroll_position(cx).y;
440 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
441 let prompt_row = editor
442 .row_for_block(decorations.prompt_block_id, cx)
443 .unwrap()
444 .0 as f32;
445
446 if (scroll_top..scroll_bottom).contains(&prompt_row) {
447 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
448 assist_id,
449 distance_from_top: prompt_row - scroll_top,
450 });
451 } else {
452 editor_assists.scroll_lock = None;
453 }
454 })
455 .ok();
456 }
457
458 fn handle_prompt_editor_focus_out(
459 &mut self,
460 assist_id: InlineAssistId,
461 cx: &mut WindowContext,
462 ) {
463 let assist = &self.assists[&assist_id];
464 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
465 if assist_group.active_assist_id == Some(assist_id) {
466 assist_group.active_assist_id = None;
467 if assist_group.linked {
468 for assist_id in &assist_group.assist_ids {
469 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
470 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
471 prompt_editor.set_show_cursor_when_unfocused(false, cx)
472 });
473 }
474 }
475 }
476 }
477 }
478
479 fn handle_prompt_editor_event(
480 &mut self,
481 prompt_editor: View<PromptEditor>,
482 event: &PromptEditorEvent,
483 cx: &mut WindowContext,
484 ) {
485 let assist_id = prompt_editor.read(cx).id;
486 match event {
487 PromptEditorEvent::StartRequested => {
488 self.start_assist(assist_id, cx);
489 }
490 PromptEditorEvent::StopRequested => {
491 self.stop_assist(assist_id, cx);
492 }
493 PromptEditorEvent::ConfirmRequested => {
494 self.finish_assist(assist_id, false, cx);
495 }
496 PromptEditorEvent::CancelRequested => {
497 self.finish_assist(assist_id, true, cx);
498 }
499 PromptEditorEvent::DismissRequested => {
500 self.dismiss_assist(assist_id, cx);
501 }
502 }
503 }
504
505 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
506 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
507 return;
508 };
509
510 let editor = editor.read(cx);
511 if editor.selections.count() == 1 {
512 let selection = editor.selections.newest::<usize>(cx);
513 let buffer = editor.buffer().read(cx).snapshot(cx);
514 for assist_id in &editor_assists.assist_ids {
515 let assist = &self.assists[assist_id];
516 let assist_range = assist.range.to_offset(&buffer);
517 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
518 {
519 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
520 self.dismiss_assist(*assist_id, cx);
521 } else {
522 self.finish_assist(*assist_id, false, cx);
523 }
524
525 return;
526 }
527 }
528 }
529
530 cx.propagate();
531 }
532
533 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
534 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
535 return;
536 };
537
538 let editor = editor.read(cx);
539 if editor.selections.count() == 1 {
540 let selection = editor.selections.newest::<usize>(cx);
541 let buffer = editor.buffer().read(cx).snapshot(cx);
542 let mut closest_assist_fallback = None;
543 for assist_id in &editor_assists.assist_ids {
544 let assist = &self.assists[assist_id];
545 let assist_range = assist.range.to_offset(&buffer);
546 if assist.decorations.is_some() {
547 if assist_range.contains(&selection.start)
548 && assist_range.contains(&selection.end)
549 {
550 self.focus_assist(*assist_id, cx);
551 return;
552 } else {
553 let distance_from_selection = assist_range
554 .start
555 .abs_diff(selection.start)
556 .min(assist_range.start.abs_diff(selection.end))
557 + assist_range
558 .end
559 .abs_diff(selection.start)
560 .min(assist_range.end.abs_diff(selection.end));
561 match closest_assist_fallback {
562 Some((_, old_distance)) => {
563 if distance_from_selection < old_distance {
564 closest_assist_fallback =
565 Some((assist_id, distance_from_selection));
566 }
567 }
568 None => {
569 closest_assist_fallback = Some((assist_id, distance_from_selection))
570 }
571 }
572 }
573 }
574 }
575
576 if let Some((&assist_id, _)) = closest_assist_fallback {
577 self.focus_assist(assist_id, cx);
578 }
579 }
580
581 cx.propagate();
582 }
583
584 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
585 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
586 for assist_id in editor_assists.assist_ids.clone() {
587 self.finish_assist(assist_id, true, cx);
588 }
589 }
590 }
591
592 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
593 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
594 return;
595 };
596 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
597 return;
598 };
599 let assist = &self.assists[&scroll_lock.assist_id];
600 let Some(decorations) = assist.decorations.as_ref() else {
601 return;
602 };
603
604 editor.update(cx, |editor, cx| {
605 let scroll_position = editor.scroll_position(cx);
606 let target_scroll_top = editor
607 .row_for_block(decorations.prompt_block_id, cx)
608 .unwrap()
609 .0 as f32
610 - scroll_lock.distance_from_top;
611 if target_scroll_top != scroll_position.y {
612 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
613 }
614 });
615 }
616
617 fn handle_editor_event(
618 &mut self,
619 editor: View<Editor>,
620 event: &EditorEvent,
621 cx: &mut WindowContext,
622 ) {
623 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
624 return;
625 };
626
627 match event {
628 EditorEvent::Edited { transaction_id } => {
629 let buffer = editor.read(cx).buffer().read(cx);
630 let edited_ranges =
631 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
632 let snapshot = buffer.snapshot(cx);
633
634 for assist_id in editor_assists.assist_ids.clone() {
635 let assist = &self.assists[&assist_id];
636 if matches!(
637 assist.codegen.read(cx).status,
638 CodegenStatus::Error(_) | CodegenStatus::Done
639 ) {
640 let assist_range = assist.range.to_offset(&snapshot);
641 if edited_ranges
642 .iter()
643 .any(|range| range.overlaps(&assist_range))
644 {
645 self.finish_assist(assist_id, false, cx);
646 }
647 }
648 }
649 }
650 EditorEvent::ScrollPositionChanged { .. } => {
651 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
652 let assist = &self.assists[&scroll_lock.assist_id];
653 if let Some(decorations) = assist.decorations.as_ref() {
654 let distance_from_top = editor.update(cx, |editor, cx| {
655 let scroll_top = editor.scroll_position(cx).y;
656 let prompt_row = editor
657 .row_for_block(decorations.prompt_block_id, cx)
658 .unwrap()
659 .0 as f32;
660 prompt_row - scroll_top
661 });
662
663 if distance_from_top != scroll_lock.distance_from_top {
664 editor_assists.scroll_lock = None;
665 }
666 }
667 }
668 }
669 EditorEvent::SelectionsChanged { .. } => {
670 for assist_id in editor_assists.assist_ids.clone() {
671 let assist = &self.assists[&assist_id];
672 if let Some(decorations) = assist.decorations.as_ref() {
673 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
674 return;
675 }
676 }
677 }
678
679 editor_assists.scroll_lock = None;
680 }
681 _ => {}
682 }
683 }
684
685 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
686 if let Some(assist) = self.assists.get(&assist_id) {
687 let assist_group_id = assist.group_id;
688 if self.assist_groups[&assist_group_id].linked {
689 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
690 self.finish_assist(assist_id, undo, cx);
691 }
692 return;
693 }
694 }
695
696 self.dismiss_assist(assist_id, cx);
697
698 if let Some(assist) = self.assists.remove(&assist_id) {
699 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
700 {
701 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
702 if entry.get().assist_ids.is_empty() {
703 entry.remove();
704 }
705 }
706
707 if let hash_map::Entry::Occupied(mut entry) =
708 self.assists_by_editor.entry(assist.editor.clone())
709 {
710 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
711 if entry.get().assist_ids.is_empty() {
712 entry.remove();
713 if let Some(editor) = assist.editor.upgrade() {
714 self.update_editor_highlights(&editor, cx);
715 }
716 } else {
717 entry.get().highlight_updates.send(()).ok();
718 }
719 }
720
721 if undo {
722 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
723 } else {
724 self.confirmed_assists.insert(assist_id, assist.codegen);
725 }
726 }
727
728 // Remove the assist from the status updates map
729 self.assist_observations.remove(&assist_id);
730 }
731
732 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
733 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
734 return false;
735 };
736 codegen.update(cx, |this, cx| this.undo(cx));
737 true
738 }
739
740 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
741 let Some(assist) = self.assists.get_mut(&assist_id) else {
742 return false;
743 };
744 let Some(editor) = assist.editor.upgrade() else {
745 return false;
746 };
747 let Some(decorations) = assist.decorations.take() else {
748 return false;
749 };
750
751 editor.update(cx, |editor, cx| {
752 let mut to_remove = decorations.removed_line_block_ids;
753 to_remove.insert(decorations.prompt_block_id);
754 to_remove.insert(decorations.end_block_id);
755 editor.remove_blocks(to_remove, None, cx);
756 });
757
758 if decorations
759 .prompt_editor
760 .focus_handle(cx)
761 .contains_focused(cx)
762 {
763 self.focus_next_assist(assist_id, cx);
764 }
765
766 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
767 if editor_assists
768 .scroll_lock
769 .as_ref()
770 .map_or(false, |lock| lock.assist_id == assist_id)
771 {
772 editor_assists.scroll_lock = None;
773 }
774 editor_assists.highlight_updates.send(()).ok();
775 }
776
777 true
778 }
779
780 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
781 let Some(assist) = self.assists.get(&assist_id) else {
782 return;
783 };
784
785 let assist_group = &self.assist_groups[&assist.group_id];
786 let assist_ix = assist_group
787 .assist_ids
788 .iter()
789 .position(|id| *id == assist_id)
790 .unwrap();
791 let assist_ids = assist_group
792 .assist_ids
793 .iter()
794 .skip(assist_ix + 1)
795 .chain(assist_group.assist_ids.iter().take(assist_ix));
796
797 for assist_id in assist_ids {
798 let assist = &self.assists[assist_id];
799 if assist.decorations.is_some() {
800 self.focus_assist(*assist_id, cx);
801 return;
802 }
803 }
804
805 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
806 }
807
808 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
809 let Some(assist) = self.assists.get(&assist_id) else {
810 return;
811 };
812
813 if let Some(decorations) = assist.decorations.as_ref() {
814 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
815 prompt_editor.editor.update(cx, |editor, cx| {
816 editor.focus(cx);
817 editor.select_all(&SelectAll, cx);
818 })
819 });
820 }
821
822 self.scroll_to_assist(assist_id, cx);
823 }
824
825 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
826 let Some(assist) = self.assists.get(&assist_id) else {
827 return;
828 };
829 let Some(editor) = assist.editor.upgrade() else {
830 return;
831 };
832
833 let position = assist.range.start;
834 editor.update(cx, |editor, cx| {
835 editor.change_selections(None, cx, |selections| {
836 selections.select_anchor_ranges([position..position])
837 });
838
839 let mut scroll_target_top;
840 let mut scroll_target_bottom;
841 if let Some(decorations) = assist.decorations.as_ref() {
842 scroll_target_top = editor
843 .row_for_block(decorations.prompt_block_id, cx)
844 .unwrap()
845 .0 as f32;
846 scroll_target_bottom = editor
847 .row_for_block(decorations.end_block_id, cx)
848 .unwrap()
849 .0 as f32;
850 } else {
851 let snapshot = editor.snapshot(cx);
852 let start_row = assist
853 .range
854 .start
855 .to_display_point(&snapshot.display_snapshot)
856 .row();
857 scroll_target_top = start_row.0 as f32;
858 scroll_target_bottom = scroll_target_top + 1.;
859 }
860 scroll_target_top -= editor.vertical_scroll_margin() as f32;
861 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
862
863 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
864 let scroll_top = editor.scroll_position(cx).y;
865 let scroll_bottom = scroll_top + height_in_lines;
866
867 if scroll_target_top < scroll_top {
868 editor.set_scroll_position(point(0., scroll_target_top), cx);
869 } else if scroll_target_bottom > scroll_bottom {
870 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
871 editor
872 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
873 } else {
874 editor.set_scroll_position(point(0., scroll_target_top), cx);
875 }
876 }
877 });
878 }
879
880 fn unlink_assist_group(
881 &mut self,
882 assist_group_id: InlineAssistGroupId,
883 cx: &mut WindowContext,
884 ) -> Vec<InlineAssistId> {
885 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
886 assist_group.linked = false;
887 for assist_id in &assist_group.assist_ids {
888 let assist = self.assists.get_mut(assist_id).unwrap();
889 if let Some(editor_decorations) = assist.decorations.as_ref() {
890 editor_decorations
891 .prompt_editor
892 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
893 }
894 }
895 assist_group.assist_ids.clone()
896 }
897
898 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
899 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
900 assist
901 } else {
902 return;
903 };
904
905 let assist_group_id = assist.group_id;
906 if self.assist_groups[&assist_group_id].linked {
907 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
908 self.start_assist(assist_id, cx);
909 }
910 return;
911 }
912
913 let Some(user_prompt) = assist.user_prompt(cx) else {
914 return;
915 };
916
917 self.prompt_history.retain(|prompt| *prompt != user_prompt);
918 self.prompt_history.push_back(user_prompt.clone());
919 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
920 self.prompt_history.pop_front();
921 }
922
923 let assistant_panel_context = assist.assistant_panel_context(cx);
924
925 assist
926 .codegen
927 .update(cx, |codegen, cx| {
928 codegen.start(user_prompt, assistant_panel_context, cx)
929 })
930 .log_err();
931
932 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
933 tx.send(()).ok();
934 }
935 }
936
937 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
938 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
939 assist
940 } else {
941 return;
942 };
943
944 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
945
946 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
947 tx.send(()).ok();
948 }
949 }
950
951 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
952 if let Some(assist) = self.assists.get(&assist_id) {
953 match &assist.codegen.read(cx).status {
954 CodegenStatus::Idle => InlineAssistStatus::Idle,
955 CodegenStatus::Pending => InlineAssistStatus::Pending,
956 CodegenStatus::Done => InlineAssistStatus::Done,
957 CodegenStatus::Error(_) => InlineAssistStatus::Error,
958 }
959 } else if self.confirmed_assists.contains_key(&assist_id) {
960 InlineAssistStatus::Confirmed
961 } else {
962 InlineAssistStatus::Canceled
963 }
964 }
965
966 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
967 let mut gutter_pending_ranges = Vec::new();
968 let mut gutter_transformed_ranges = Vec::new();
969 let mut foreground_ranges = Vec::new();
970 let mut inserted_row_ranges = Vec::new();
971 let empty_assist_ids = Vec::new();
972 let assist_ids = self
973 .assists_by_editor
974 .get(&editor.downgrade())
975 .map_or(&empty_assist_ids, |editor_assists| {
976 &editor_assists.assist_ids
977 });
978
979 for assist_id in assist_ids {
980 if let Some(assist) = self.assists.get(assist_id) {
981 let codegen = assist.codegen.read(cx);
982 let buffer = codegen.buffer.read(cx).read(cx);
983 foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());
984
985 let pending_range =
986 codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
987 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
988 gutter_pending_ranges.push(pending_range);
989 }
990
991 if let Some(edit_position) = codegen.edit_position {
992 let edited_range = assist.range.start..edit_position;
993 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
994 gutter_transformed_ranges.push(edited_range);
995 }
996 }
997
998 if assist.decorations.is_some() {
999 inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
1000 }
1001 }
1002 }
1003
1004 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1005 merge_ranges(&mut foreground_ranges, &snapshot);
1006 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1007 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1008 editor.update(cx, |editor, cx| {
1009 enum GutterPendingRange {}
1010 if gutter_pending_ranges.is_empty() {
1011 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1012 } else {
1013 editor.highlight_gutter::<GutterPendingRange>(
1014 &gutter_pending_ranges,
1015 |cx| cx.theme().status().info_background,
1016 cx,
1017 )
1018 }
1019
1020 enum GutterTransformedRange {}
1021 if gutter_transformed_ranges.is_empty() {
1022 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1023 } else {
1024 editor.highlight_gutter::<GutterTransformedRange>(
1025 &gutter_transformed_ranges,
1026 |cx| cx.theme().status().info,
1027 cx,
1028 )
1029 }
1030
1031 if foreground_ranges.is_empty() {
1032 editor.clear_highlights::<InlineAssist>(cx);
1033 } else {
1034 editor.highlight_text::<InlineAssist>(
1035 foreground_ranges,
1036 HighlightStyle {
1037 fade_out: Some(0.6),
1038 ..Default::default()
1039 },
1040 cx,
1041 );
1042 }
1043
1044 editor.clear_row_highlights::<InlineAssist>();
1045 for row_range in inserted_row_ranges {
1046 editor.highlight_rows::<InlineAssist>(
1047 row_range,
1048 Some(cx.theme().status().info_background),
1049 false,
1050 cx,
1051 );
1052 }
1053 });
1054 }
1055
    /// Rebuilds the blocks that display deleted lines for the given assist in `editor`.
    ///
    /// Returns early if the assist is unknown or has no decorations. Otherwise the
    /// previously inserted removed-line blocks are removed and one new block is
    /// inserted per entry of the codegen diff's `deleted_row_ranges`. Each block
    /// embeds a read-only editor showing the corresponding rows of the old buffer.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Clone everything needed from the codegen model up front, before `cx`
        // is handed to `editor.update` below.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks from the previous update before inserting new ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Convert the deleted rows of the old snapshot into buffer offsets:
                // from column 0 of the first row to the end of the last row.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only editor over just the deleted range of the old buffer,
                // with gutter, wrap guides, soft wrap and vertical scrolling disabled.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    // Highlight every row with the "deleted" background color.
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // Block height in rows: the last row index of the embedded editor plus one.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1147
1148 pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
1149 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1150 rx.clone()
1151 } else {
1152 let (tx, rx) = async_watch::channel(());
1153 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1154 rx
1155 }
1156 }
1157}
1158
/// Status of an inline assist.
///
/// The enum is a plain, field-less value type, so it derives the common
/// value-type traits to make it printable, copyable and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` only for [`InlineAssistStatus::Pending`].
    pub(crate) fn is_pending(&self) -> bool {
        matches!(self, Self::Pending)
    }

    /// Returns `true` only for [`InlineAssistStatus::Confirmed`].
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(self, Self::Confirmed)
    }

    /// Returns `true` only for [`InlineAssistStatus::Done`].
    pub(crate) fn is_done(&self) -> bool {
        matches!(self, Self::Done)
    }
}
1179
/// Per-editor bookkeeping for inline assists.
struct EditorInlineAssists {
    // Ids of inline assists tracked for this editor; starts out empty.
    assist_ids: Vec<InlineAssistId>,
    // Optional scroll lock; starts out as `None`.
    scroll_lock: Option<InlineAssistScrollLock>,
    // Sending on this channel wakes the `_update_highlights` task.
    highlight_updates: async_watch::Sender<()>,
    // Task that calls `update_editor_highlights` whenever `highlight_updates` changes.
    _update_highlights: Task<Result<()>>,
    // Observers, event subscriptions and action registrations created in `new`;
    // stored here so they are kept for as long as this struct.
    _subscriptions: Vec<gpui::Subscription>,
}
1187
/// Scroll-lock state stored in `EditorInlineAssists::scroll_lock`.
struct InlineAssistScrollLock {
    // The assist this lock refers to.
    assist_id: InlineAssistId,
    // NOTE(review): presumably the vertical distance between the assist and the
    // top of the viewport that should be preserved — confirm against the code
    // that reads this field.
    distance_from_top: f32,
}
1192
1193impl EditorInlineAssists {
1194 #[allow(clippy::too_many_arguments)]
1195 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1196 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1197 Self {
1198 assist_ids: Vec::new(),
1199 scroll_lock: None,
1200 highlight_updates: highlight_updates_tx,
1201 _update_highlights: cx.spawn(|mut cx| {
1202 let editor = editor.downgrade();
1203 async move {
1204 while let Ok(()) = highlight_updates_rx.changed().await {
1205 let editor = editor.upgrade().context("editor was dropped")?;
1206 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1207 assistant.update_editor_highlights(&editor, cx);
1208 })?;
1209 }
1210 Ok(())
1211 }
1212 }),
1213 _subscriptions: vec![
1214 cx.observe_release(editor, {
1215 let editor = editor.downgrade();
1216 |_, cx| {
1217 InlineAssistant::update_global(cx, |this, cx| {
1218 this.handle_editor_release(editor, cx);
1219 })
1220 }
1221 }),
1222 cx.observe(editor, move |editor, cx| {
1223 InlineAssistant::update_global(cx, |this, cx| {
1224 this.handle_editor_change(editor, cx)
1225 })
1226 }),
1227 cx.subscribe(editor, move |editor, event, cx| {
1228 InlineAssistant::update_global(cx, |this, cx| {
1229 this.handle_editor_event(editor, event, cx)
1230 })
1231 }),
1232 editor.update(cx, |editor, cx| {
1233 let editor_handle = cx.view().downgrade();
1234 editor.register_action(
1235 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1236 InlineAssistant::update_global(cx, |this, cx| {
1237 if let Some(editor) = editor_handle.upgrade() {
1238 this.handle_editor_newline(editor, cx)
1239 }
1240 })
1241 },
1242 )
1243 }),
1244 editor.update(cx, |editor, cx| {
1245 let editor_handle = cx.view().downgrade();
1246 editor.register_action(
1247 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1248 InlineAssistant::update_global(cx, |this, cx| {
1249 if let Some(editor) = editor_handle.upgrade() {
1250 this.handle_editor_cancel(editor, cx)
1251 }
1252 })
1253 },
1254 )
1255 }),
1256 ],
1257 }
1258 }
1259}
1260
1261struct InlineAssistGroup {
1262 assist_ids: Vec<InlineAssistId>,
1263 linked: bool,
1264 active_assist_id: Option<InlineAssistId>,
1265}
1266
1267impl InlineAssistGroup {
1268 fn new() -> Self {
1269 Self {
1270 assist_ids: Vec::new(),
1271 linked: true,
1272 active_assist_id: None,
1273 }
1274 }
1275}
1276
1277fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1278 let editor = editor.clone();
1279 Box::new(move |cx: &mut BlockContext| {
1280 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1281 editor.clone().into_any_element()
1282 })
1283}
1284
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistId {
        let following = InlineAssistId(self.0 + 1);
        mem::replace(self, following)
    }
}
1295
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let current = self.0;
        self.0 = current + 1;
        Self(current)
    }
}
1306
/// Events emitted by `PromptEditor` in response to its buttons and actions.
enum PromptEditorEvent {
    /// Emitted by the "start"/"restart" buttons, and by the confirm action when
    /// codegen is idle or the prompt was edited since codegen finished.
    StartRequested,
    /// Emitted by the "stop" button, and by the cancel action while codegen is pending.
    StopRequested,
    /// Emitted by the "confirm" button, and by the confirm action when codegen has
    /// finished and the prompt was not edited since.
    ConfirmRequested,
    /// Emitted by the "cancel" button, and by the cancel action when codegen is
    /// not pending.
    CancelRequested,
    /// Emitted by the confirm action while codegen is pending.
    DismissRequested,
}
1314
/// View rendered for an inline assist's prompt: the prompt editor itself plus a
/// model selector, an error indicator, the token count and action buttons.
struct PromptEditor {
    // Id of the assist this prompt editor belongs to; used to look the assist up
    // when counting tokens.
    id: InlineAssistId,
    // Handed to the `ModelSelector` when rendering.
    fs: Arc<dyn Fs>,
    // The editor the prompt is typed into.
    editor: View<Editor>,
    // Set whenever the prompt is edited; reset when codegen reaches `Done` or `Error`.
    edited_since_done: bool,
    // Shared gutter dimensions, read in `render` to size the leading column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    // Earlier prompts, navigated with the `MoveUp`/`MoveDown` actions.
    prompt_history: VecDeque<String>,
    // Index into `prompt_history` of the entry currently shown, if any.
    prompt_history_ix: Option<usize>,
    // Prompt text restored when moving down past the last history entry.
    pending_prompt: String,
    // The codegen model whose status drives the buttons and read-only state.
    codegen: Model<Codegen>,
    // Observation of `codegen` that invokes `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    // Subscriptions to `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    // The most recently scheduled token-count task.
    pending_token_count: Task<Result<()>>,
    // Latest token counts, or `None` until the first count completes.
    token_counts: Option<TokenCounts>,
    // Subscriptions (parent editor and, if present, assistant panel) that trigger
    // a token recount.
    _token_count_subscriptions: Vec<Subscription>,
    // When present, clicking the token count focuses the assistant panel of this workspace.
    workspace: Option<WeakView<Workspace>>,
    // Whether the rate-limit notice popover is currently shown.
    show_rate_limit_notice: bool,
}
1333
/// Token counts displayed next to the prompt editor.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    // Total number of tokens used; compared against the model's maximum token count.
    total: usize,
    // Number of those tokens reported in the tooltip as coming from the Assistant Panel.
    assistant_panel: usize,
}
1339
// Allows `PromptEditor` to emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1341
impl Render for PromptEditor {
    /// Renders the prompt row: a leading column (model selector and, on error, an
    /// error indicator) sized from the gutter dimensions, the prompt editor, and a
    /// trailing column with the token count and status-dependent action buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        // The set of action buttons depends on the current codegen status.
        let buttons = match status {
            // Idle: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Pending: cancel the assist, or stop the running transformation.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished (successfully or not): cancel, plus either "restart" (when the
            // prompt was edited or codegen failed) or "confirm".
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .child(
                // Leading column: as wide as the gutter plus half of its margin.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // Only add an error indicator when codegen is in the error state.
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // Rate-limit errors (with the `ZedPro` flag) get a toggle button
                        // and an optional notice popover; all other errors get an icon
                        // with the error message as its tooltip.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            // The prompt editor takes up the remaining horizontal space.
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                // Trailing column: token count (if available) followed by the buttons.
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1522
1523impl FocusableView for PromptEditor {
1524 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1525 self.editor.focus_handle(cx)
1526 }
1527}
1528
1529impl PromptEditor {
    /// Maximum number of lines passed to the auto-height prompt editor.
    const MAX_LINES: u8 = 8;
1531
1532 #[allow(clippy::too_many_arguments)]
1533 fn new(
1534 id: InlineAssistId,
1535 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1536 prompt_history: VecDeque<String>,
1537 prompt_buffer: Model<MultiBuffer>,
1538 codegen: Model<Codegen>,
1539 parent_editor: &View<Editor>,
1540 assistant_panel: Option<&View<AssistantPanel>>,
1541 workspace: Option<WeakView<Workspace>>,
1542 fs: Arc<dyn Fs>,
1543 cx: &mut ViewContext<Self>,
1544 ) -> Self {
1545 let prompt_editor = cx.new_view(|cx| {
1546 let mut editor = Editor::new(
1547 EditorMode::AutoHeight {
1548 max_lines: Self::MAX_LINES as usize,
1549 },
1550 prompt_buffer,
1551 None,
1552 false,
1553 cx,
1554 );
1555 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1556 // Since the prompt editors for all inline assistants are linked,
1557 // always show the cursor (even when it isn't focused) because
1558 // typing in one will make what you typed appear in all of them.
1559 editor.set_show_cursor_when_unfocused(true, cx);
1560 editor.set_placeholder_text("Add a prompt…", cx);
1561 editor
1562 });
1563
1564 let mut token_count_subscriptions = Vec::new();
1565 token_count_subscriptions
1566 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1567 if let Some(assistant_panel) = assistant_panel {
1568 token_count_subscriptions
1569 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1570 }
1571
1572 let mut this = Self {
1573 id,
1574 editor: prompt_editor,
1575 edited_since_done: false,
1576 gutter_dimensions,
1577 prompt_history,
1578 prompt_history_ix: None,
1579 pending_prompt: String::new(),
1580 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1581 editor_subscriptions: Vec::new(),
1582 codegen,
1583 fs,
1584 pending_token_count: Task::ready(Ok(())),
1585 token_counts: None,
1586 _token_count_subscriptions: token_count_subscriptions,
1587 workspace,
1588 show_rate_limit_notice: false,
1589 };
1590 this.count_tokens(cx);
1591 this.subscribe_to_editor(cx);
1592 this
1593 }
1594
1595 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1596 self.editor_subscriptions.clear();
1597 self.editor_subscriptions
1598 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1599 }
1600
1601 fn set_show_cursor_when_unfocused(
1602 &mut self,
1603 show_cursor_when_unfocused: bool,
1604 cx: &mut ViewContext<Self>,
1605 ) {
1606 self.editor.update(cx, |editor, cx| {
1607 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1608 });
1609 }
1610
1611 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1612 let prompt = self.prompt(cx);
1613 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1614 self.editor = cx.new_view(|cx| {
1615 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1616 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1617 editor.set_placeholder_text("Add a prompt…", cx);
1618 editor.set_text(prompt, cx);
1619 if focus {
1620 editor.focus(cx);
1621 }
1622 editor
1623 });
1624 self.subscribe_to_editor(cx);
1625 }
1626
1627 fn prompt(&self, cx: &AppContext) -> String {
1628 self.editor.read(cx).text(cx)
1629 }
1630
1631 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1632 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1633 if self.show_rate_limit_notice {
1634 cx.focus_view(&self.editor);
1635 }
1636 cx.notify();
1637 }
1638
1639 fn handle_parent_editor_event(
1640 &mut self,
1641 _: View<Editor>,
1642 event: &EditorEvent,
1643 cx: &mut ViewContext<Self>,
1644 ) {
1645 if let EditorEvent::BufferEdited { .. } = event {
1646 self.count_tokens(cx);
1647 }
1648 }
1649
1650 fn handle_assistant_panel_event(
1651 &mut self,
1652 _: View<AssistantPanel>,
1653 event: &AssistantPanelEvent,
1654 cx: &mut ViewContext<Self>,
1655 ) {
1656 let AssistantPanelEvent::ContextEdited { .. } = event;
1657 self.count_tokens(cx);
1658 }
1659
1660 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1661 let assist_id = self.id;
1662 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1663 cx.background_executor().timer(Duration::from_secs(1)).await;
1664 let token_count = cx
1665 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1666 let assist = inline_assistant
1667 .assists
1668 .get(&assist_id)
1669 .context("assist not found")?;
1670 anyhow::Ok(assist.count_tokens(cx))
1671 })??
1672 .await?;
1673
1674 this.update(&mut cx, |this, cx| {
1675 this.token_counts = Some(token_count);
1676 cx.notify();
1677 })
1678 })
1679 }
1680
1681 fn handle_prompt_editor_events(
1682 &mut self,
1683 _: View<Editor>,
1684 event: &EditorEvent,
1685 cx: &mut ViewContext<Self>,
1686 ) {
1687 match event {
1688 EditorEvent::Edited { .. } => {
1689 let prompt = self.editor.read(cx).text(cx);
1690 if self
1691 .prompt_history_ix
1692 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1693 {
1694 self.prompt_history_ix.take();
1695 self.pending_prompt = prompt;
1696 }
1697
1698 self.edited_since_done = true;
1699 cx.notify();
1700 }
1701 EditorEvent::BufferEdited => {
1702 self.count_tokens(cx);
1703 }
1704 EditorEvent::Blurred => {
1705 if self.show_rate_limit_notice {
1706 self.show_rate_limit_notice = false;
1707 cx.notify();
1708 }
1709 }
1710 _ => {}
1711 }
1712 }
1713
1714 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1715 match &self.codegen.read(cx).status {
1716 CodegenStatus::Idle => {
1717 self.editor
1718 .update(cx, |editor, _| editor.set_read_only(false));
1719 }
1720 CodegenStatus::Pending => {
1721 self.editor
1722 .update(cx, |editor, _| editor.set_read_only(true));
1723 }
1724 CodegenStatus::Done => {
1725 self.edited_since_done = false;
1726 self.editor
1727 .update(cx, |editor, _| editor.set_read_only(false));
1728 }
1729 CodegenStatus::Error(error) => {
1730 if cx.has_flag::<ZedPro>()
1731 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1732 && !dismissed_rate_limit_notice()
1733 {
1734 self.show_rate_limit_notice = true;
1735 cx.notify();
1736 }
1737
1738 self.edited_since_done = false;
1739 self.editor
1740 .update(cx, |editor, _| editor.set_read_only(false));
1741 }
1742 }
1743 }
1744
1745 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1746 match &self.codegen.read(cx).status {
1747 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1748 cx.emit(PromptEditorEvent::CancelRequested);
1749 }
1750 CodegenStatus::Pending => {
1751 cx.emit(PromptEditorEvent::StopRequested);
1752 }
1753 }
1754 }
1755
1756 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1757 match &self.codegen.read(cx).status {
1758 CodegenStatus::Idle => {
1759 cx.emit(PromptEditorEvent::StartRequested);
1760 }
1761 CodegenStatus::Pending => {
1762 cx.emit(PromptEditorEvent::DismissRequested);
1763 }
1764 CodegenStatus::Done | CodegenStatus::Error(_) => {
1765 if self.edited_since_done {
1766 cx.emit(PromptEditorEvent::StartRequested);
1767 } else {
1768 cx.emit(PromptEditorEvent::ConfirmRequested);
1769 }
1770 }
1771 }
1772 }
1773
1774 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1775 if let Some(ix) = self.prompt_history_ix {
1776 if ix > 0 {
1777 self.prompt_history_ix = Some(ix - 1);
1778 let prompt = self.prompt_history[ix - 1].as_str();
1779 self.editor.update(cx, |editor, cx| {
1780 editor.set_text(prompt, cx);
1781 editor.move_to_beginning(&Default::default(), cx);
1782 });
1783 }
1784 } else if !self.prompt_history.is_empty() {
1785 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1786 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1787 self.editor.update(cx, |editor, cx| {
1788 editor.set_text(prompt, cx);
1789 editor.move_to_beginning(&Default::default(), cx);
1790 });
1791 }
1792 }
1793
1794 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1795 if let Some(ix) = self.prompt_history_ix {
1796 if ix < self.prompt_history.len() - 1 {
1797 self.prompt_history_ix = Some(ix + 1);
1798 let prompt = self.prompt_history[ix + 1].as_str();
1799 self.editor.update(cx, |editor, cx| {
1800 editor.set_text(prompt, cx);
1801 editor.move_to_end(&Default::default(), cx)
1802 });
1803 } else {
1804 self.prompt_history_ix = None;
1805 let prompt = self.pending_prompt.as_str();
1806 self.editor.update(cx, |editor, cx| {
1807 editor.set_text(prompt, cx);
1808 editor.move_to_end(&Default::default(), cx)
1809 });
1810 }
1811 }
1812 }
1813
1814 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1815 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1816 let token_counts = self.token_counts?;
1817 let max_token_count = model.max_token_count();
1818
1819 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1820 let token_count_color = if remaining_tokens <= 0 {
1821 Color::Error
1822 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1823 Color::Warning
1824 } else {
1825 Color::Muted
1826 };
1827
1828 let mut token_count = h_flex()
1829 .id("token_count")
1830 .gap_0p5()
1831 .child(
1832 Label::new(humanize_token_count(token_counts.total))
1833 .size(LabelSize::Small)
1834 .color(token_count_color),
1835 )
1836 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1837 .child(
1838 Label::new(humanize_token_count(max_token_count))
1839 .size(LabelSize::Small)
1840 .color(Color::Muted),
1841 );
1842 if let Some(workspace) = self.workspace.clone() {
1843 token_count = token_count
1844 .tooltip(move |cx| {
1845 Tooltip::with_meta(
1846 format!(
1847 "Tokens Used ({} from the Assistant Panel)",
1848 humanize_token_count(token_counts.assistant_panel)
1849 ),
1850 None,
1851 "Click to open the Assistant Panel",
1852 cx,
1853 )
1854 })
1855 .cursor_pointer()
1856 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1857 .on_click(move |_, cx| {
1858 cx.stop_propagation();
1859 workspace
1860 .update(cx, |workspace, cx| {
1861 workspace.focus_panel::<AssistantPanel>(cx)
1862 })
1863 .ok();
1864 });
1865 } else {
1866 token_count = token_count
1867 .cursor_default()
1868 .tooltip(|cx| Tooltip::text("Tokens used", cx));
1869 }
1870
1871 Some(token_count)
1872 }
1873
1874 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1875 let settings = ThemeSettings::get_global(cx);
1876 let text_style = TextStyle {
1877 color: if self.editor.read(cx).read_only(cx) {
1878 cx.theme().colors().text_disabled
1879 } else {
1880 cx.theme().colors().text
1881 },
1882 font_family: settings.ui_font.family.clone(),
1883 font_features: settings.ui_font.features.clone(),
1884 font_fallbacks: settings.ui_font.fallbacks.clone(),
1885 font_size: rems(0.875).into(),
1886 font_weight: settings.ui_font.weight,
1887 line_height: relative(1.3),
1888 ..Default::default()
1889 };
1890 EditorElement::new(
1891 &self.editor,
1892 EditorStyle {
1893 background: cx.theme().colors().editor_background,
1894 local_player: cx.theme().players().local(),
1895 text: text_style,
1896 ..Default::default()
1897 },
1898 )
1899 }
1900
1901 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1902 Popover::new().child(
1903 v_flex()
1904 .occlude()
1905 .p_2()
1906 .child(
1907 Label::new("Out of Tokens")
1908 .size(LabelSize::Small)
1909 .weight(FontWeight::BOLD),
1910 )
1911 .child(Label::new(
1912 "Try Zed Pro for higher limits, a wider range of models, and more.",
1913 ))
1914 .child(
1915 h_flex()
1916 .justify_between()
1917 .child(CheckboxWithLabel::new(
1918 "dont-show-again",
1919 Label::new("Don't show again"),
1920 if dismissed_rate_limit_notice() {
1921 ui::Selection::Selected
1922 } else {
1923 ui::Selection::Unselected
1924 },
1925 |selection, cx| {
1926 let is_dismissed = match selection {
1927 ui::Selection::Unselected => false,
1928 ui::Selection::Indeterminate => return,
1929 ui::Selection::Selected => true,
1930 };
1931
1932 set_rate_limit_notice_dismissed(is_dismissed, cx)
1933 },
1934 ))
1935 .child(
1936 h_flex()
1937 .gap_2()
1938 .child(
1939 Button::new("dismiss", "Dismiss")
1940 .style(ButtonStyle::Transparent)
1941 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1942 )
1943 .child(Button::new("more-info", "More Info").on_click(
1944 |_event, cx| {
1945 cx.dispatch_action(Box::new(
1946 zed_actions::OpenAccountSettings,
1947 ))
1948 },
1949 )),
1950 ),
1951 ),
1952 )
1953 }
1954}
1955
/// Key in the key-value store under which dismissal of the rate-limit notice is
/// recorded: the key is present while the notice is dismissed and deleted otherwise.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1957
1958fn dismissed_rate_limit_notice() -> bool {
1959 db::kvp::KEY_VALUE_STORE
1960 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1961 .log_err()
1962 .map_or(false, |s| s.is_some())
1963}
1964
1965fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1966 db::write_and_log(cx, move || async move {
1967 if is_dismissed {
1968 db::kvp::KEY_VALUE_STORE
1969 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1970 .await
1971 } else {
1972 db::kvp::KEY_VALUE_STORE
1973 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1974 .await
1975 }
1976 })
1977}
1978
/// State of a single inline assist attached to an editor.
struct InlineAssist {
    // Group this assist belongs to.
    group_id: InlineAssistGroupId,
    // Anchored range of the editor this assist applies to.
    range: Range<Anchor>,
    // The editor hosting the assist, held weakly.
    editor: WeakView<Editor>,
    // Prompt editor and block ids shown for this assist. When `None`, a codegen
    // error is reported as a workspace toast and the assist is finished as soon
    // as codegen finishes.
    decorations: Option<InlineAssistDecorations>,
    // The code generation model for this assist.
    codegen: Model<Codegen>,
    // Focus, prompt-editor and codegen subscriptions created in `new`.
    _subscriptions: Vec<Subscription>,
    // Workspace used to show error toasts and to look up the assistant panel.
    workspace: Option<WeakView<Workspace>>,
    // Whether the assistant panel's active context is included in requests.
    include_context: bool,
}
1989
1990impl InlineAssist {
1991 #[allow(clippy::too_many_arguments)]
1992 fn new(
1993 assist_id: InlineAssistId,
1994 group_id: InlineAssistGroupId,
1995 include_context: bool,
1996 editor: &View<Editor>,
1997 prompt_editor: &View<PromptEditor>,
1998 prompt_block_id: CustomBlockId,
1999 end_block_id: CustomBlockId,
2000 range: Range<Anchor>,
2001 codegen: Model<Codegen>,
2002 workspace: Option<WeakView<Workspace>>,
2003 cx: &mut WindowContext,
2004 ) -> Self {
2005 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2006 InlineAssist {
2007 group_id,
2008 include_context,
2009 editor: editor.downgrade(),
2010 decorations: Some(InlineAssistDecorations {
2011 prompt_block_id,
2012 prompt_editor: prompt_editor.clone(),
2013 removed_line_block_ids: HashSet::default(),
2014 end_block_id,
2015 }),
2016 range,
2017 codegen: codegen.clone(),
2018 workspace: workspace.clone(),
2019 _subscriptions: vec![
2020 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2021 InlineAssistant::update_global(cx, |this, cx| {
2022 this.handle_prompt_editor_focus_in(assist_id, cx)
2023 })
2024 }),
2025 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2026 InlineAssistant::update_global(cx, |this, cx| {
2027 this.handle_prompt_editor_focus_out(assist_id, cx)
2028 })
2029 }),
2030 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2031 InlineAssistant::update_global(cx, |this, cx| {
2032 this.handle_prompt_editor_event(prompt_editor, event, cx)
2033 })
2034 }),
2035 cx.observe(&codegen, {
2036 let editor = editor.downgrade();
2037 move |_, cx| {
2038 if let Some(editor) = editor.upgrade() {
2039 InlineAssistant::update_global(cx, |this, cx| {
2040 if let Some(editor_assists) =
2041 this.assists_by_editor.get(&editor.downgrade())
2042 {
2043 editor_assists.highlight_updates.send(()).ok();
2044 }
2045
2046 this.update_editor_blocks(&editor, assist_id, cx);
2047 })
2048 }
2049 }
2050 }),
2051 cx.subscribe(&codegen, move |codegen, event, cx| {
2052 InlineAssistant::update_global(cx, |this, cx| match event {
2053 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2054 CodegenEvent::Finished => {
2055 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2056 assist
2057 } else {
2058 return;
2059 };
2060
2061 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
2062 if assist.decorations.is_none() {
2063 if let Some(workspace) = assist
2064 .workspace
2065 .as_ref()
2066 .and_then(|workspace| workspace.upgrade())
2067 {
2068 let error = format!("Inline assistant error: {}", error);
2069 workspace.update(cx, |workspace, cx| {
2070 struct InlineAssistantError;
2071
2072 let id =
2073 NotificationId::identified::<InlineAssistantError>(
2074 assist_id.0,
2075 );
2076
2077 workspace.show_toast(Toast::new(id, error), cx);
2078 })
2079 }
2080 }
2081 }
2082
2083 if assist.decorations.is_none() {
2084 this.finish_assist(assist_id, false, cx);
2085 } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2086 tx.0.send(()).ok();
2087 }
2088 }
2089 })
2090 }),
2091 ],
2092 }
2093 }
2094
2095 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2096 let decorations = self.decorations.as_ref()?;
2097 Some(decorations.prompt_editor.read(cx).prompt(cx))
2098 }
2099
2100 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2101 if self.include_context {
2102 let workspace = self.workspace.as_ref()?;
2103 let workspace = workspace.upgrade()?.read(cx);
2104 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2105 Some(
2106 assistant_panel
2107 .read(cx)
2108 .active_context(cx)?
2109 .read(cx)
2110 .to_completion_request(cx),
2111 )
2112 } else {
2113 None
2114 }
2115 }
2116
2117 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2118 let Some(user_prompt) = self.user_prompt(cx) else {
2119 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2120 };
2121 let assistant_panel_context = self.assistant_panel_context(cx);
2122 self.codegen
2123 .read(cx)
2124 .count_tokens(user_prompt, assistant_panel_context, cx)
2125 }
2126}
2127
/// Editor decorations that belong to an inline assist.
struct InlineAssistDecorations {
    // Id of the block associated with the prompt editor.
    prompt_block_id: CustomBlockId,
    // The prompt editor view of the assist.
    prompt_editor: View<PromptEditor>,
    // Ids of the blocks currently showing deleted lines; replaced on every
    // `update_editor_blocks` call.
    removed_line_block_ids: HashSet<CustomBlockId>,
    // NOTE(review): presumably the block marking the end of the assist's range —
    // confirm where it is created.
    end_block_id: CustomBlockId,
}
2134
/// Events emitted by `Codegen`.
#[derive(Debug)]
pub enum CodegenEvent {
    /// Subscribers in this file inspect the codegen status for an error when
    /// this event arrives.
    Finished,
    /// Subscribers in this file finish the assist when this event arrives.
    Undone,
}
2140
/// Model holding the state of a code generation for a range of a multi-buffer.
pub struct Codegen {
    // The multi-buffer being transformed.
    buffer: Model<MultiBuffer>,
    // Local copy of the original buffer text (same line ending, language and
    // language registry), created in `new`; used to display deleted lines.
    old_buffer: Model<Buffer>,
    // Snapshot of `buffer` taken in `new`.
    snapshot: MultiBufferSnapshot,
    // The range of `buffer` to transform.
    transform_range: Range<Anchor>,
    selected_ranges: Vec<Range<Anchor>>,
    // Boundary between already-transformed and still-pending text, when set;
    // starts out as `None`.
    edit_position: Option<Anchor>,
    last_equal_ranges: Vec<Range<Anchor>>,
    initial_transaction_id: Option<TransactionId>,
    transformation_transaction_id: Option<TransactionId>,
    // Current status; starts out as `Idle`.
    status: CodegenStatus,
    // Task associated with the generation; starts out as an already-completed task.
    generation: Task<()>,
    // Deleted and inserted row ranges used for decorations and highlights.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    _subscription: gpui::Subscription,
    prompt_builder: Arc<PromptBuilder>,
}
2158
2159enum CodegenStatus {
2160 Idle,
2161 Pending,
2162 Done,
2163 Error(anyhow::Error),
2164}
2165
2166#[derive(Default)]
2167struct Diff {
2168 deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
2169 inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
2170}
2171
2172impl Diff {
2173 fn is_empty(&self) -> bool {
2174 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2175 }
2176}
2177
2178impl EventEmitter<CodegenEvent> for Codegen {}
2179
2180impl Codegen {
2181 pub fn new(
2182 buffer: Model<MultiBuffer>,
2183 transform_range: Range<Anchor>,
2184 selected_ranges: Vec<Range<Anchor>>,
2185 initial_transaction_id: Option<TransactionId>,
2186 telemetry: Option<Arc<Telemetry>>,
2187 builder: Arc<PromptBuilder>,
2188 cx: &mut ModelContext<Self>,
2189 ) -> Self {
2190 let snapshot = buffer.read(cx).snapshot(cx);
2191
2192 let (old_buffer, _, _) = buffer
2193 .read(cx)
2194 .range_to_buffer_ranges(transform_range.clone(), cx)
2195 .pop()
2196 .unwrap();
2197 let old_buffer = cx.new_model(|cx| {
2198 let old_buffer = old_buffer.read(cx);
2199 let text = old_buffer.as_rope().clone();
2200 let line_ending = old_buffer.line_ending();
2201 let language = old_buffer.language().cloned();
2202 let language_registry = old_buffer.language_registry();
2203
2204 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2205 buffer.set_language(language, cx);
2206 if let Some(language_registry) = language_registry {
2207 buffer.set_language_registry(language_registry)
2208 }
2209 buffer
2210 });
2211
2212 Self {
2213 buffer: buffer.clone(),
2214 old_buffer,
2215 edit_position: None,
2216 snapshot,
2217 last_equal_ranges: Default::default(),
2218 transformation_transaction_id: None,
2219 status: CodegenStatus::Idle,
2220 generation: Task::ready(()),
2221 diff: Diff::default(),
2222 telemetry,
2223 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2224 initial_transaction_id,
2225 prompt_builder: builder,
2226 transform_range,
2227 selected_ranges,
2228 }
2229 }
2230
2231 fn handle_buffer_event(
2232 &mut self,
2233 _buffer: Model<MultiBuffer>,
2234 event: &multi_buffer::Event,
2235 cx: &mut ModelContext<Self>,
2236 ) {
2237 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2238 if self.transformation_transaction_id == Some(*transaction_id) {
2239 self.transformation_transaction_id = None;
2240 self.generation = Task::ready(());
2241 cx.emit(CodegenEvent::Undone);
2242 }
2243 }
2244 }
2245
2246 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2247 &self.last_equal_ranges
2248 }
2249
2250 pub fn count_tokens(
2251 &self,
2252 user_prompt: String,
2253 assistant_panel_context: Option<LanguageModelRequest>,
2254 cx: &AppContext,
2255 ) -> BoxFuture<'static, Result<TokenCounts>> {
2256 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2257 let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2258 match request {
2259 Ok(request) => {
2260 let total_count = model.count_tokens(request.clone(), cx);
2261 let assistant_panel_count = assistant_panel_context
2262 .map(|context| model.count_tokens(context, cx))
2263 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2264
2265 async move {
2266 Ok(TokenCounts {
2267 total: total_count.await?,
2268 assistant_panel: assistant_panel_count.await?,
2269 })
2270 }
2271 .boxed()
2272 }
2273 Err(error) => futures::future::ready(Err(error)).boxed(),
2274 }
2275 } else {
2276 future::ready(Err(anyhow!("no active model"))).boxed()
2277 }
2278 }
2279
2280 pub fn start(
2281 &mut self,
2282 user_prompt: String,
2283 assistant_panel_context: Option<LanguageModelRequest>,
2284 cx: &mut ModelContext<Self>,
2285 ) -> Result<()> {
2286 let model = LanguageModelRegistry::read_global(cx)
2287 .active_model()
2288 .context("no active model")?;
2289
2290 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2291 self.buffer.update(cx, |buffer, cx| {
2292 buffer.undo_transaction(transformation_transaction_id, cx);
2293 });
2294 }
2295
2296 self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
2297
2298 let telemetry_id = model.telemetry_id();
2299 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
2300 if user_prompt.trim().to_lowercase() == "delete" {
2301 async { Ok(stream::empty().boxed()) }.boxed_local()
2302 } else {
2303 let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2304
2305 let chunks =
2306 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2307 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2308 };
2309 self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
2310 Ok(())
2311 }
2312
2313 fn build_request(
2314 &self,
2315 user_prompt: String,
2316 assistant_panel_context: Option<LanguageModelRequest>,
2317 cx: &AppContext,
2318 ) -> Result<LanguageModelRequest> {
2319 let buffer = self.buffer.read(cx).snapshot(cx);
2320 let language = buffer.language_at(self.transform_range.start);
2321 let language_name = if let Some(language) = language.as_ref() {
2322 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2323 None
2324 } else {
2325 Some(language.name())
2326 }
2327 } else {
2328 None
2329 };
2330
2331 // Higher Temperature increases the randomness of model outputs.
2332 // If Markdown or No Language is Known, increase the randomness for more creative output
2333 // If Code, decrease temperature to get more deterministic outputs
2334 let temperature = if let Some(language) = language_name.clone() {
2335 if language.as_ref() == "Markdown" {
2336 1.0
2337 } else {
2338 0.5
2339 }
2340 } else {
2341 1.0
2342 };
2343
2344 let language_name = language_name.as_deref();
2345 let start = buffer.point_to_buffer_offset(self.transform_range.start);
2346 let end = buffer.point_to_buffer_offset(self.transform_range.end);
2347 let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
2348 let (start_buffer, start_buffer_offset) = start;
2349 let (end_buffer, end_buffer_offset) = end;
2350 if start_buffer.remote_id() == end_buffer.remote_id() {
2351 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2352 } else {
2353 return Err(anyhow::anyhow!("invalid transformation range"));
2354 }
2355 } else {
2356 return Err(anyhow::anyhow!("invalid transformation range"));
2357 };
2358
2359 let mut transform_context_range = transform_range.to_point(&transform_buffer);
2360 transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
2361 transform_context_range.start.column = 0;
2362 transform_context_range.end =
2363 (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
2364 transform_context_range.end.column =
2365 transform_buffer.line_len(transform_context_range.end.row);
2366 let transform_context_range = transform_context_range.to_offset(&transform_buffer);
2367
2368 let selected_ranges = self
2369 .selected_ranges
2370 .iter()
2371 .filter_map(|selected_range| {
2372 let start = buffer
2373 .point_to_buffer_offset(selected_range.start)
2374 .map(|(_, offset)| offset)?;
2375 let end = buffer
2376 .point_to_buffer_offset(selected_range.end)
2377 .map(|(_, offset)| offset)?;
2378 Some(start..end)
2379 })
2380 .collect::<Vec<_>>();
2381
2382 let prompt = self
2383 .prompt_builder
2384 .generate_content_prompt(
2385 user_prompt,
2386 language_name,
2387 transform_buffer,
2388 transform_range,
2389 selected_ranges,
2390 transform_context_range,
2391 )
2392 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2393
2394 let mut messages = Vec::new();
2395 if let Some(context_request) = assistant_panel_context {
2396 messages = context_request.messages;
2397 }
2398
2399 messages.push(LanguageModelRequestMessage {
2400 role: Role::User,
2401 content: vec![prompt.into()],
2402 cache: false,
2403 });
2404
2405 Ok(LanguageModelRequest {
2406 messages,
2407 stop: vec!["|END|>".to_string()],
2408 temperature,
2409 })
2410 }
2411
2412 pub fn handle_stream(
2413 &mut self,
2414 model_telemetry_id: String,
2415 edit_range: Range<Anchor>,
2416 stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
2417 cx: &mut ModelContext<Self>,
2418 ) {
2419 let snapshot = self.snapshot.clone();
2420 let selected_text = snapshot
2421 .text_for_range(edit_range.start..edit_range.end)
2422 .collect::<Rope>();
2423
2424 let selection_start = edit_range.start.to_point(&snapshot);
2425
2426 // Start with the indentation of the first line in the selection
2427 let mut suggested_line_indent = snapshot
2428 .suggested_indents(selection_start.row..=selection_start.row, cx)
2429 .into_values()
2430 .next()
2431 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2432
2433 // If the first line in the selection does not have indentation, check the following lines
2434 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2435 for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
2436 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2437 // Prefer tabs if a line in the selection uses tabs as indentation
2438 if line_indent.kind == IndentKind::Tab {
2439 suggested_line_indent.kind = IndentKind::Tab;
2440 break;
2441 }
2442 }
2443 }
2444
2445 let telemetry = self.telemetry.clone();
2446 self.diff = Diff::default();
2447 self.status = CodegenStatus::Pending;
2448 let mut edit_start = edit_range.start.to_offset(&snapshot);
2449 self.generation = cx.spawn(|codegen, mut cx| {
2450 async move {
2451 let chunks = stream.await;
2452 let generate = async {
2453 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2454 let line_based_stream_diff: Task<anyhow::Result<()>> =
2455 cx.background_executor().spawn(async move {
2456 let mut response_latency = None;
2457 let request_start = Instant::now();
2458 let diff = async {
2459 let chunks = StripInvalidSpans::new(chunks?);
2460 futures::pin_mut!(chunks);
2461 let mut diff = StreamingDiff::new(selected_text.to_string());
2462 let mut line_diff = LineDiff::default();
2463
2464 while let Some(chunk) = chunks.next().await {
2465 if response_latency.is_none() {
2466 response_latency = Some(request_start.elapsed());
2467 }
2468 let chunk = chunk?;
2469 let char_ops = diff.push_new(&chunk);
2470 line_diff.push_char_operations(&char_ops, &selected_text);
2471 diff_tx
2472 .send((char_ops, line_diff.line_operations()))
2473 .await?;
2474 }
2475
2476 let char_ops = diff.finish();
2477 line_diff.push_char_operations(&char_ops, &selected_text);
2478 line_diff.finish(&selected_text);
2479 diff_tx
2480 .send((char_ops, line_diff.line_operations()))
2481 .await?;
2482
2483 anyhow::Ok(())
2484 };
2485
2486 let result = diff.await;
2487
2488 let error_message =
2489 result.as_ref().err().map(|error| error.to_string());
2490 if let Some(telemetry) = telemetry {
2491 telemetry.report_assistant_event(
2492 None,
2493 telemetry_events::AssistantKind::Inline,
2494 model_telemetry_id,
2495 response_latency,
2496 error_message,
2497 );
2498 }
2499
2500 result?;
2501 Ok(())
2502 });
2503
2504 while let Some((char_ops, line_diff)) = diff_rx.next().await {
2505 codegen.update(&mut cx, |codegen, cx| {
2506 codegen.last_equal_ranges.clear();
2507
2508 let transaction = codegen.buffer.update(cx, |buffer, cx| {
2509 // Avoid grouping assistant edits with user edits.
2510 buffer.finalize_last_transaction(cx);
2511
2512 buffer.start_transaction(cx);
2513 buffer.edit(
2514 char_ops
2515 .into_iter()
2516 .filter_map(|operation| match operation {
2517 CharOperation::Insert { text } => {
2518 let edit_start = snapshot.anchor_after(edit_start);
2519 Some((edit_start..edit_start, text))
2520 }
2521 CharOperation::Delete { bytes } => {
2522 let edit_end = edit_start + bytes;
2523 let edit_range = snapshot.anchor_after(edit_start)
2524 ..snapshot.anchor_before(edit_end);
2525 edit_start = edit_end;
2526 Some((edit_range, String::new()))
2527 }
2528 CharOperation::Keep { bytes } => {
2529 let edit_end = edit_start + bytes;
2530 let edit_range = snapshot.anchor_after(edit_start)
2531 ..snapshot.anchor_before(edit_end);
2532 edit_start = edit_end;
2533 codegen.last_equal_ranges.push(edit_range);
2534 None
2535 }
2536 }),
2537 None,
2538 cx,
2539 );
2540 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
2541
2542 buffer.end_transaction(cx)
2543 });
2544
2545 if let Some(transaction) = transaction {
2546 if let Some(first_transaction) =
2547 codegen.transformation_transaction_id
2548 {
2549 // Group all assistant edits into the first transaction.
2550 codegen.buffer.update(cx, |buffer, cx| {
2551 buffer.merge_transactions(
2552 transaction,
2553 first_transaction,
2554 cx,
2555 )
2556 });
2557 } else {
2558 codegen.transformation_transaction_id = Some(transaction);
2559 codegen.buffer.update(cx, |buffer, cx| {
2560 buffer.finalize_last_transaction(cx)
2561 });
2562 }
2563 }
2564
2565 codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);
2566
2567 cx.notify();
2568 })?;
2569 }
2570
2571 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
2572 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
2573 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
2574 let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
2575 codegen.reapply_batch_diff(edit_range.clone(), cx)
2576 })?;
2577 let (line_based_stream_diff, ()) =
2578 join!(line_based_stream_diff, batch_diff_task);
2579 line_based_stream_diff?;
2580
2581 anyhow::Ok(())
2582 };
2583
2584 let result = generate.await;
2585 codegen
2586 .update(&mut cx, |this, cx| {
2587 this.last_equal_ranges.clear();
2588 if let Err(error) = result {
2589 this.status = CodegenStatus::Error(error);
2590 } else {
2591 this.status = CodegenStatus::Done;
2592 }
2593 cx.emit(CodegenEvent::Finished);
2594 cx.notify();
2595 })
2596 .ok();
2597 }
2598 });
2599 cx.notify();
2600 }
2601
2602 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2603 self.last_equal_ranges.clear();
2604 if self.diff.is_empty() {
2605 self.status = CodegenStatus::Idle;
2606 } else {
2607 self.status = CodegenStatus::Done;
2608 }
2609 self.generation = Task::ready(());
2610 cx.emit(CodegenEvent::Finished);
2611 cx.notify();
2612 }
2613
2614 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2615 self.buffer.update(cx, |buffer, cx| {
2616 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2617 buffer.undo_transaction(transaction_id, cx);
2618 buffer.refresh_preview(cx);
2619 }
2620
2621 if let Some(transaction_id) = self.initial_transaction_id.take() {
2622 buffer.undo_transaction(transaction_id, cx);
2623 buffer.refresh_preview(cx);
2624 }
2625 });
2626 }
2627
2628 fn reapply_line_based_diff(
2629 &mut self,
2630 edit_range: Range<Anchor>,
2631 line_operations: Vec<LineOperation>,
2632 cx: &mut ModelContext<Self>,
2633 ) {
2634 let old_snapshot = self.snapshot.clone();
2635 let old_range = edit_range.to_point(&old_snapshot);
2636 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2637 let new_range = edit_range.to_point(&new_snapshot);
2638
2639 let mut old_row = old_range.start.row;
2640 let mut new_row = new_range.start.row;
2641
2642 self.diff.deleted_row_ranges.clear();
2643 self.diff.inserted_row_ranges.clear();
2644 for operation in line_operations {
2645 match operation {
2646 LineOperation::Keep { lines } => {
2647 old_row += lines;
2648 new_row += lines;
2649 }
2650 LineOperation::Delete { lines } => {
2651 let old_end_row = old_row + lines - 1;
2652 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2653
2654 if let Some((_, last_deleted_row_range)) =
2655 self.diff.deleted_row_ranges.last_mut()
2656 {
2657 if *last_deleted_row_range.end() + 1 == old_row {
2658 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2659 } else {
2660 self.diff
2661 .deleted_row_ranges
2662 .push((new_row, old_row..=old_end_row));
2663 }
2664 } else {
2665 self.diff
2666 .deleted_row_ranges
2667 .push((new_row, old_row..=old_end_row));
2668 }
2669
2670 old_row += lines;
2671 }
2672 LineOperation::Insert { lines } => {
2673 let new_end_row = new_row + lines - 1;
2674 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2675 let end = new_snapshot.anchor_before(Point::new(
2676 new_end_row,
2677 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2678 ));
2679 self.diff.inserted_row_ranges.push(start..=end);
2680 new_row += lines;
2681 }
2682 }
2683
2684 cx.notify();
2685 }
2686 }
2687
2688 fn reapply_batch_diff(
2689 &mut self,
2690 edit_range: Range<Anchor>,
2691 cx: &mut ModelContext<Self>,
2692 ) -> Task<()> {
2693 let old_snapshot = self.snapshot.clone();
2694 let old_range = edit_range.to_point(&old_snapshot);
2695 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2696 let new_range = edit_range.to_point(&new_snapshot);
2697
2698 cx.spawn(|codegen, mut cx| async move {
2699 let (deleted_row_ranges, inserted_row_ranges) = cx
2700 .background_executor()
2701 .spawn(async move {
2702 let old_text = old_snapshot
2703 .text_for_range(
2704 Point::new(old_range.start.row, 0)
2705 ..Point::new(
2706 old_range.end.row,
2707 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
2708 ),
2709 )
2710 .collect::<String>();
2711 let new_text = new_snapshot
2712 .text_for_range(
2713 Point::new(new_range.start.row, 0)
2714 ..Point::new(
2715 new_range.end.row,
2716 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
2717 ),
2718 )
2719 .collect::<String>();
2720
2721 let mut old_row = old_range.start.row;
2722 let mut new_row = new_range.start.row;
2723 let batch_diff =
2724 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
2725
2726 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
2727 let mut inserted_row_ranges = Vec::new();
2728 for change in batch_diff.iter_all_changes() {
2729 let line_count = change.value().lines().count() as u32;
2730 match change.tag() {
2731 similar::ChangeTag::Equal => {
2732 old_row += line_count;
2733 new_row += line_count;
2734 }
2735 similar::ChangeTag::Delete => {
2736 let old_end_row = old_row + line_count - 1;
2737 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2738
2739 if let Some((_, last_deleted_row_range)) =
2740 deleted_row_ranges.last_mut()
2741 {
2742 if *last_deleted_row_range.end() + 1 == old_row {
2743 *last_deleted_row_range =
2744 *last_deleted_row_range.start()..=old_end_row;
2745 } else {
2746 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2747 }
2748 } else {
2749 deleted_row_ranges.push((new_row, old_row..=old_end_row));
2750 }
2751
2752 old_row += line_count;
2753 }
2754 similar::ChangeTag::Insert => {
2755 let new_end_row = new_row + line_count - 1;
2756 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2757 let end = new_snapshot.anchor_before(Point::new(
2758 new_end_row,
2759 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2760 ));
2761 inserted_row_ranges.push(start..=end);
2762 new_row += line_count;
2763 }
2764 }
2765 }
2766
2767 (deleted_row_ranges, inserted_row_ranges)
2768 })
2769 .await;
2770
2771 codegen
2772 .update(&mut cx, |codegen, cx| {
2773 codegen.diff.deleted_row_ranges = deleted_row_ranges;
2774 codegen.diff.inserted_row_ranges = inserted_row_ranges;
2775 cx.notify();
2776 })
2777 .ok();
2778 })
2779 }
2780}
2781
/// Stream adapter that removes a wrapping Markdown code fence and `<|CURSOR|>` markers from a
/// model's text response.
struct StripInvalidSpans<T> {
    /// The wrapped response stream.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been emitted yet.
    buffer: String,
    /// `true` until a second line of the response has been seen.
    first_line: bool,
    /// Set after a complete line has been emitted; a newline is written before the next text.
    line_end: bool,
    /// Whether the response opened with a code fence on its first line.
    starts_with_code_block: bool,
}
2790
2791impl<T> StripInvalidSpans<T>
2792where
2793 T: Stream<Item = Result<String>>,
2794{
2795 fn new(stream: T) -> Self {
2796 Self {
2797 stream,
2798 stream_done: false,
2799 buffer: String::new(),
2800 first_line: true,
2801 line_end: false,
2802 starts_with_code_block: false,
2803 }
2804 }
2805}
2806
2807impl<T> Stream for StripInvalidSpans<T>
2808where
2809 T: Stream<Item = Result<String>>,
2810{
2811 type Item = Result<String>;
2812
2813 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2814 const CODE_BLOCK_DELIMITER: &str = "```";
2815 const CURSOR_SPAN: &str = "<|CURSOR|>";
2816
2817 let this = unsafe { self.get_unchecked_mut() };
2818 loop {
2819 if !this.stream_done {
2820 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2821 match stream.as_mut().poll_next(cx) {
2822 Poll::Ready(Some(Ok(chunk))) => {
2823 this.buffer.push_str(&chunk);
2824 }
2825 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2826 Poll::Ready(None) => {
2827 this.stream_done = true;
2828 }
2829 Poll::Pending => return Poll::Pending,
2830 }
2831 }
2832
2833 let mut chunk = String::new();
2834 let mut consumed = 0;
2835 if !this.buffer.is_empty() {
2836 let mut lines = this.buffer.split('\n').enumerate().peekable();
2837 while let Some((line_ix, line)) = lines.next() {
2838 if line_ix > 0 {
2839 this.first_line = false;
2840 }
2841
2842 if this.first_line {
2843 let trimmed_line = line.trim();
2844 if lines.peek().is_some() {
2845 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2846 consumed += line.len() + 1;
2847 this.starts_with_code_block = true;
2848 continue;
2849 }
2850 } else if trimmed_line.is_empty()
2851 || prefixes(CODE_BLOCK_DELIMITER)
2852 .any(|prefix| trimmed_line.starts_with(prefix))
2853 {
2854 break;
2855 }
2856 }
2857
2858 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2859 if lines.peek().is_some() {
2860 if this.line_end {
2861 chunk.push('\n');
2862 }
2863
2864 chunk.push_str(&line_without_cursor);
2865 this.line_end = true;
2866 consumed += line.len() + 1;
2867 } else if this.stream_done {
2868 if !this.starts_with_code_block
2869 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2870 {
2871 if this.line_end {
2872 chunk.push('\n');
2873 }
2874
2875 chunk.push_str(&line);
2876 }
2877
2878 consumed += line.len();
2879 } else {
2880 let trimmed_line = line.trim();
2881 if trimmed_line.is_empty()
2882 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2883 || prefixes(CODE_BLOCK_DELIMITER)
2884 .any(|prefix| trimmed_line.ends_with(prefix))
2885 {
2886 break;
2887 } else {
2888 if this.line_end {
2889 chunk.push('\n');
2890 this.line_end = false;
2891 }
2892
2893 chunk.push_str(&line_without_cursor);
2894 consumed += line.len();
2895 }
2896 }
2897 }
2898 }
2899
2900 this.buffer = this.buffer.split_off(consumed);
2901 if !chunk.is_empty() {
2902 return Poll::Ready(Some(Ok(chunk)));
2903 } else if this.stream_done {
2904 return Poll::Ready(None);
2905 }
2906 }
2907 }
2908}
2909
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is not included, so inputs with fewer than two characters yield
/// nothing. Prefixes always end on a character boundary, so non-ASCII input is handled safely.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // The byte offset of each character after the first marks the end of one proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
2913
2914fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2915 ranges.sort_unstable_by(|a, b| {
2916 a.start
2917 .cmp(&b.start, buffer)
2918 .then_with(|| b.end.cmp(&a.end, buffer))
2919 });
2920
2921 let mut ix = 0;
2922 while ix + 1 < ranges.len() {
2923 let b = ranges[ix + 1].clone();
2924 let a = &mut ranges[ix];
2925 if a.end.cmp(&b.start, buffer).is_gt() {
2926 if a.end.cmp(&b.end, buffer).is_lt() {
2927 a.end = b.end;
2928 }
2929 ranges.remove(ix + 1);
2930 } else {
2931 ix += 1;
2932 }
2933 }
2934}
2935
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use serde::Serialize;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Feeds each input through `StripInvalidSpans` at every possible chunk size and checks
    /// that the concatenated output matches the expected text.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // (raw model output, expected text after stripping)
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("Lorem<|CURSOR|> ipsum", "Lorem ipsum"),
            ("Lorem ipsum", "Lorem ipsum"),
            ("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum"),
        ];
        for (text, expected_text) in cases {
            assert_chunks(text, expected_text).await;
        }

        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let stripped = StripInvalidSpans::new(chunks(text, chunk_size));
                let actual_text = stripped
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into pieces of at most `size` characters, each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            let characters = text.chars().collect::<Vec<_>>();
            let pieces = characters
                .chunks(size)
                .map(|piece| Ok(piece.iter().collect::<String>()))
                .collect::<Vec<_>>();
            stream::iter(pieces)
        }
    }
}