1use crate::{
2 humanize_token_count, prompts::PromptBuilder, AssistantPanel, AssistantPanelEvent,
3 CharOperation, LineDiff, LineOperation, ModelSelector, StreamingDiff,
4};
5use anyhow::{anyhow, Context as _, Result};
6use client::{telemetry::Telemetry, ErrorExt};
7use collections::{hash_map, HashMap, HashSet, VecDeque};
8use editor::{
9 actions::{MoveDown, MoveUp, SelectAll},
10 display_map::{
11 BlockContext, BlockDisposition, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
12 ToDisplayPoint,
13 },
14 Anchor, AnchorRangeExt, Editor, EditorElement, EditorEvent, EditorMode, EditorStyle,
15 ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
16};
17use feature_flags::{FeatureFlagAppExt as _, ZedPro};
18use fs::Fs;
19use futures::{
20 channel::mpsc,
21 future::{BoxFuture, LocalBoxFuture},
22 stream::{self, BoxStream},
23 SinkExt, Stream, StreamExt,
24};
25use gpui::{
26 anchored, deferred, point, AppContext, ClickEvent, EventEmitter, FocusHandle, FocusableView,
27 FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
28 UpdateGlobal, View, ViewContext, WeakView, WindowContext,
29};
30use language::{Buffer, IndentKind, Point, TransactionId};
31use language_model::{
32 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
33};
34use multi_buffer::MultiBufferRow;
35use parking_lot::Mutex;
36use rope::Rope;
37use settings::Settings;
38use smol::future::FutureExt;
39use std::{
40 future::{self, Future},
41 mem,
42 ops::{Range, RangeInclusive},
43 pin::Pin,
44 sync::Arc,
45 task::{self, Poll},
46 time::{Duration, Instant},
47};
48use text::OffsetRangeExt as _;
49use theme::ThemeSettings;
50use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
51use util::{RangeExt, ResultExt};
52use workspace::{notifications::NotificationId, Toast, Workspace};
53
54pub fn init(
55 fs: Arc<dyn Fs>,
56 prompt_builder: Arc<PromptBuilder>,
57 telemetry: Arc<Telemetry>,
58 cx: &mut AppContext,
59) {
60 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
61 cx.observe_new_views(|_, cx| {
62 let workspace = cx.view().clone();
63 InlineAssistant::update_global(cx, |inline_assistant, cx| {
64 inline_assistant.register_workspace(&workspace, cx)
65 })
66 })
67 .detach();
68}
69
/// Maximum number of submitted prompts kept in `InlineAssistant::prompt_history`;
/// once exceeded, the oldest entry is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;

/// App-global coordinator for inline (in-editor) assists.
///
/// Tracks every live assist, indexes them per editor and per group, and keeps
/// the codegens of confirmed assists so they can still be undone.
pub struct InlineAssistant {
    // Id handed to the next assist that is created (advanced via `post_inc`).
    next_assist_id: InlineAssistId,
    // Id handed to the next assist group that is created (advanced via `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    // Assists that have not been finished yet (`finish_assist` removes them).
    assists: HashMap<InlineAssistId, InlineAssist>,
    // Per-editor bookkeeping: the editor's assist ids, scroll lock, highlight updates.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    // Assists created together by one `assist`/`suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    // Watch channels pinged whenever an assist is started or stopped.
    assist_observations:
        HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
    // Codegens of assists finished without undo, kept for `undo_assist`.
    confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
    // Recently submitted prompts, oldest first, at most PROMPT_HISTORY_MAX_LEN.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    // Always `Some` when built through `InlineAssistant::new`.
    telemetry: Option<Arc<Telemetry>>,
    fs: Arc<dyn Fs>,
}

impl Global for InlineAssistant {}
88
89impl InlineAssistant {
90 pub fn new(
91 fs: Arc<dyn Fs>,
92 prompt_builder: Arc<PromptBuilder>,
93 telemetry: Arc<Telemetry>,
94 ) -> Self {
95 Self {
96 next_assist_id: InlineAssistId::default(),
97 next_assist_group_id: InlineAssistGroupId::default(),
98 assists: HashMap::default(),
99 assists_by_editor: HashMap::default(),
100 assist_groups: HashMap::default(),
101 assist_observations: HashMap::default(),
102 confirmed_assists: HashMap::default(),
103 prompt_history: VecDeque::default(),
104 prompt_builder,
105 telemetry: Some(telemetry),
106 fs,
107 }
108 }
109
110 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
111 cx.subscribe(workspace, |_, event, cx| {
112 Self::update_global(cx, |this, cx| this.handle_workspace_event(event, cx));
113 })
114 .detach();
115 }
116
117 fn handle_workspace_event(&mut self, event: &workspace::Event, cx: &mut WindowContext) {
118 // When the user manually saves an editor, automatically accepts all finished transformations.
119 if let workspace::Event::UserSavedItem { item, .. } = event {
120 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
121 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
122 for assist_id in editor_assists.assist_ids.clone() {
123 let assist = &self.assists[&assist_id];
124 if let CodegenStatus::Done = &assist.codegen.read(cx).status {
125 self.finish_assist(assist_id, false, cx)
126 }
127 }
128 }
129 }
130 }
131 }
132
/// Starts inline assists for every selection in `editor`.
///
/// Each selection is widened to whole lines. A widened range that starts on or
/// before the end of the previous one is merged into it. Every resulting range
/// gets its own `Codegen`, prompt editor and pair of editor blocks. All assists
/// created by one call share a single prompt buffer (pre-filled with
/// `initial_prompt`, if any) and a single assist group. At the end, focus goes
/// to an assist that covers the newest selection.
pub fn assist(
    &mut self,
    editor: &View<Editor>,
    workspace: Option<WeakView<Workspace>>,
    assistant_panel: Option<&View<AssistantPanel>>,
    initial_prompt: Option<String>,
    cx: &mut WindowContext,
) {
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    // One line-aligned region to transform, plus the selections that fall inside it.
    struct CodegenRange {
        transform_range: Range<Point>,
        selection_ranges: Vec<Range<Point>>,
        // True when any covered selection lies within the newest selection.
        focus_assist: bool,
    }

    let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range();
    let mut codegen_ranges: Vec<CodegenRange> = Vec::new();

    // NOTE(review): `split_ranges` presumably splits selections at multibuffer
    // excerpt boundaries — confirm against its definition.
    let selection_ranges = snapshot
        .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
        .map(|range| range.to_point(&snapshot))
        .collect::<Vec<Range<Point>>>();

    for selection_range in selection_ranges {
        let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
        let mut transform_range = selection_range.start..selection_range.end;

        // Expand the transform range to start/end of lines.
        // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
        transform_range.start.column = 0;
        if transform_range.end.column == 0 && transform_range.end > transform_range.start {
            transform_range.end.row -= 1;
        }
        transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
        // Clamp the selection so it never extends past the (possibly clipped) transform range.
        let selection_range =
            selection_range.start..selection_range.end.min(transform_range.end);

        // If we intersect the previous transform range, extend it and record this
        // selection there instead of starting a new codegen range.
        if let Some(CodegenRange {
            transform_range: prev_transform_range,
            selection_ranges,
            focus_assist,
        }) = codegen_ranges.last_mut()
        {
            if transform_range.start <= prev_transform_range.end {
                prev_transform_range.end = transform_range.end;
                selection_ranges.push(selection_range);
                *focus_assist |= selection_is_newest;
                continue;
            }
        }

        codegen_ranges.push(CodegenRange {
            transform_range,
            selection_ranges: vec![selection_range],
            focus_assist: selection_is_newest,
        })
    }

    let assist_group_id = self.next_assist_group_id.post_inc();
    // A single prompt buffer, cloned into every prompt editor created below.
    let prompt_buffer =
        cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
    let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
    let mut assists = Vec::new();
    let mut assist_to_focus = None;

    for CodegenRange {
        transform_range,
        selection_ranges,
        focus_assist,
    } in codegen_ranges
    {
        // Convert to anchors: the start is anchored before its point, the end after its point.
        let transform_range = snapshot.anchor_before(transform_range.start)
            ..snapshot.anchor_after(transform_range.end);
        let selection_ranges = selection_ranges
            .iter()
            .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
            .collect::<Vec<_>>();

        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                transform_range.clone(),
                selection_ranges,
                None,
                self.telemetry.clone(),
                self.prompt_builder.clone(),
                cx,
            )
        });

        let assist_id = self.next_assist_id.post_inc();
        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let prompt_editor = cx.new_view(|cx| {
            PromptEditor::new(
                assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                prompt_buffer.clone(),
                codegen.clone(),
                editor,
                assistant_panel,
                workspace.clone(),
                self.fs.clone(),
                cx,
            )
        });

        // If several ranges qualify, the last one processed wins.
        if focus_assist {
            assist_to_focus = Some(assist_id);
        }

        let [prompt_block_id, end_block_id] =
            self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);

        assists.push((
            assist_id,
            transform_range,
            prompt_editor,
            prompt_block_id,
            end_block_id,
        ));
    }

    // Record the new assists in the per-editor, per-id and per-group indices.
    let editor_assists = self
        .assists_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
    let mut assist_group = InlineAssistGroup::new();
    for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
        self.assists.insert(
            assist_id,
            InlineAssist::new(
                assist_id,
                assist_group_id,
                assistant_panel.is_some(),
                editor,
                &prompt_editor,
                prompt_block_id,
                end_block_id,
                range,
                prompt_editor.read(cx).codegen.clone(),
                workspace.clone(),
                cx,
            ),
        );
        assist_group.assist_ids.push(assist_id);
        editor_assists.assist_ids.push(assist_id);
    }
    self.assist_groups.insert(assist_group_id, assist_group);

    if let Some(assist_id) = assist_to_focus {
        self.focus_assist(assist_id, cx);
    }
}
288
289 #[allow(clippy::too_many_arguments)]
290 pub fn suggest_assist(
291 &mut self,
292 editor: &View<Editor>,
293 mut range: Range<Anchor>,
294 initial_prompt: String,
295 initial_transaction_id: Option<TransactionId>,
296 workspace: Option<WeakView<Workspace>>,
297 assistant_panel: Option<&View<AssistantPanel>>,
298 cx: &mut WindowContext,
299 ) -> InlineAssistId {
300 let assist_group_id = self.next_assist_group_id.post_inc();
301 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
302 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
303
304 let assist_id = self.next_assist_id.post_inc();
305
306 let buffer = editor.read(cx).buffer().clone();
307 {
308 let snapshot = buffer.read(cx).read(cx);
309 range.start = range.start.bias_left(&snapshot);
310 range.end = range.end.bias_right(&snapshot);
311 }
312
313 let codegen = cx.new_model(|cx| {
314 Codegen::new(
315 editor.read(cx).buffer().clone(),
316 range.clone(),
317 vec![range.clone()],
318 initial_transaction_id,
319 self.telemetry.clone(),
320 self.prompt_builder.clone(),
321 cx,
322 )
323 });
324
325 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
326 let prompt_editor = cx.new_view(|cx| {
327 PromptEditor::new(
328 assist_id,
329 gutter_dimensions.clone(),
330 self.prompt_history.clone(),
331 prompt_buffer.clone(),
332 codegen.clone(),
333 editor,
334 assistant_panel,
335 workspace.clone(),
336 self.fs.clone(),
337 cx,
338 )
339 });
340
341 let [prompt_block_id, end_block_id] =
342 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
343
344 let editor_assists = self
345 .assists_by_editor
346 .entry(editor.downgrade())
347 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
348
349 let mut assist_group = InlineAssistGroup::new();
350 self.assists.insert(
351 assist_id,
352 InlineAssist::new(
353 assist_id,
354 assist_group_id,
355 assistant_panel.is_some(),
356 editor,
357 &prompt_editor,
358 prompt_block_id,
359 end_block_id,
360 range,
361 prompt_editor.read(cx).codegen.clone(),
362 workspace.clone(),
363 cx,
364 ),
365 );
366 assist_group.assist_ids.push(assist_id);
367 editor_assists.assist_ids.push(assist_id);
368 self.assist_groups.insert(assist_group_id, assist_group);
369 assist_id
370 }
371
372 fn insert_assist_blocks(
373 &self,
374 editor: &View<Editor>,
375 range: &Range<Anchor>,
376 prompt_editor: &View<PromptEditor>,
377 cx: &mut WindowContext,
378 ) -> [CustomBlockId; 2] {
379 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
380 prompt_editor
381 .editor
382 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
383 });
384 let assist_blocks = vec![
385 BlockProperties {
386 style: BlockStyle::Sticky,
387 position: range.start,
388 height: prompt_editor_height,
389 render: build_assist_editor_renderer(prompt_editor),
390 disposition: BlockDisposition::Above,
391 priority: 0,
392 },
393 BlockProperties {
394 style: BlockStyle::Sticky,
395 position: range.end,
396 height: 0,
397 render: Box::new(|cx| {
398 v_flex()
399 .h_full()
400 .w_full()
401 .border_t_1()
402 .border_color(cx.theme().status().info_border)
403 .into_any_element()
404 }),
405 disposition: BlockDisposition::Below,
406 priority: 0,
407 },
408 ];
409
410 editor.update(cx, |editor, cx| {
411 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
412 [block_ids[0], block_ids[1]]
413 })
414 }
415
416 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
417 let assist = &self.assists[&assist_id];
418 let Some(decorations) = assist.decorations.as_ref() else {
419 return;
420 };
421 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
422 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
423
424 assist_group.active_assist_id = Some(assist_id);
425 if assist_group.linked {
426 for assist_id in &assist_group.assist_ids {
427 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
428 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
429 prompt_editor.set_show_cursor_when_unfocused(true, cx)
430 });
431 }
432 }
433 }
434
435 assist
436 .editor
437 .update(cx, |editor, cx| {
438 let scroll_top = editor.scroll_position(cx).y;
439 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
440 let prompt_row = editor
441 .row_for_block(decorations.prompt_block_id, cx)
442 .unwrap()
443 .0 as f32;
444
445 if (scroll_top..scroll_bottom).contains(&prompt_row) {
446 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
447 assist_id,
448 distance_from_top: prompt_row - scroll_top,
449 });
450 } else {
451 editor_assists.scroll_lock = None;
452 }
453 })
454 .ok();
455 }
456
457 fn handle_prompt_editor_focus_out(
458 &mut self,
459 assist_id: InlineAssistId,
460 cx: &mut WindowContext,
461 ) {
462 let assist = &self.assists[&assist_id];
463 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
464 if assist_group.active_assist_id == Some(assist_id) {
465 assist_group.active_assist_id = None;
466 if assist_group.linked {
467 for assist_id in &assist_group.assist_ids {
468 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
469 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
470 prompt_editor.set_show_cursor_when_unfocused(false, cx)
471 });
472 }
473 }
474 }
475 }
476 }
477
478 fn handle_prompt_editor_event(
479 &mut self,
480 prompt_editor: View<PromptEditor>,
481 event: &PromptEditorEvent,
482 cx: &mut WindowContext,
483 ) {
484 let assist_id = prompt_editor.read(cx).id;
485 match event {
486 PromptEditorEvent::StartRequested => {
487 self.start_assist(assist_id, cx);
488 }
489 PromptEditorEvent::StopRequested => {
490 self.stop_assist(assist_id, cx);
491 }
492 PromptEditorEvent::ConfirmRequested => {
493 self.finish_assist(assist_id, false, cx);
494 }
495 PromptEditorEvent::CancelRequested => {
496 self.finish_assist(assist_id, true, cx);
497 }
498 PromptEditorEvent::DismissRequested => {
499 self.dismiss_assist(assist_id, cx);
500 }
501 }
502 }
503
504 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
505 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
506 return;
507 };
508
509 let editor = editor.read(cx);
510 if editor.selections.count() == 1 {
511 let selection = editor.selections.newest::<usize>(cx);
512 let buffer = editor.buffer().read(cx).snapshot(cx);
513 for assist_id in &editor_assists.assist_ids {
514 let assist = &self.assists[assist_id];
515 let assist_range = assist.range.to_offset(&buffer);
516 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
517 {
518 if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) {
519 self.dismiss_assist(*assist_id, cx);
520 } else {
521 self.finish_assist(*assist_id, false, cx);
522 }
523
524 return;
525 }
526 }
527 }
528
529 cx.propagate();
530 }
531
532 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
533 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
534 return;
535 };
536
537 let editor = editor.read(cx);
538 if editor.selections.count() == 1 {
539 let selection = editor.selections.newest::<usize>(cx);
540 let buffer = editor.buffer().read(cx).snapshot(cx);
541 let mut closest_assist_fallback = None;
542 for assist_id in &editor_assists.assist_ids {
543 let assist = &self.assists[assist_id];
544 let assist_range = assist.range.to_offset(&buffer);
545 if assist.decorations.is_some() {
546 if assist_range.contains(&selection.start)
547 && assist_range.contains(&selection.end)
548 {
549 self.focus_assist(*assist_id, cx);
550 return;
551 } else {
552 let distance_from_selection = assist_range
553 .start
554 .abs_diff(selection.start)
555 .min(assist_range.start.abs_diff(selection.end))
556 + assist_range
557 .end
558 .abs_diff(selection.start)
559 .min(assist_range.end.abs_diff(selection.end));
560 match closest_assist_fallback {
561 Some((_, old_distance)) => {
562 if distance_from_selection < old_distance {
563 closest_assist_fallback =
564 Some((assist_id, distance_from_selection));
565 }
566 }
567 None => {
568 closest_assist_fallback = Some((assist_id, distance_from_selection))
569 }
570 }
571 }
572 }
573 }
574
575 if let Some((&assist_id, _)) = closest_assist_fallback {
576 self.focus_assist(assist_id, cx);
577 }
578 }
579
580 cx.propagate();
581 }
582
583 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
584 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
585 for assist_id in editor_assists.assist_ids.clone() {
586 self.finish_assist(assist_id, true, cx);
587 }
588 }
589 }
590
591 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
592 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
593 return;
594 };
595 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
596 return;
597 };
598 let assist = &self.assists[&scroll_lock.assist_id];
599 let Some(decorations) = assist.decorations.as_ref() else {
600 return;
601 };
602
603 editor.update(cx, |editor, cx| {
604 let scroll_position = editor.scroll_position(cx);
605 let target_scroll_top = editor
606 .row_for_block(decorations.prompt_block_id, cx)
607 .unwrap()
608 .0 as f32
609 - scroll_lock.distance_from_top;
610 if target_scroll_top != scroll_position.y {
611 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
612 }
613 });
614 }
615
616 fn handle_editor_event(
617 &mut self,
618 editor: View<Editor>,
619 event: &EditorEvent,
620 cx: &mut WindowContext,
621 ) {
622 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
623 return;
624 };
625
626 match event {
627 EditorEvent::Edited { transaction_id } => {
628 let buffer = editor.read(cx).buffer().read(cx);
629 let edited_ranges =
630 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
631 let snapshot = buffer.snapshot(cx);
632
633 for assist_id in editor_assists.assist_ids.clone() {
634 let assist = &self.assists[&assist_id];
635 if matches!(
636 assist.codegen.read(cx).status,
637 CodegenStatus::Error(_) | CodegenStatus::Done
638 ) {
639 let assist_range = assist.range.to_offset(&snapshot);
640 if edited_ranges
641 .iter()
642 .any(|range| range.overlaps(&assist_range))
643 {
644 self.finish_assist(assist_id, false, cx);
645 }
646 }
647 }
648 }
649 EditorEvent::ScrollPositionChanged { .. } => {
650 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
651 let assist = &self.assists[&scroll_lock.assist_id];
652 if let Some(decorations) = assist.decorations.as_ref() {
653 let distance_from_top = editor.update(cx, |editor, cx| {
654 let scroll_top = editor.scroll_position(cx).y;
655 let prompt_row = editor
656 .row_for_block(decorations.prompt_block_id, cx)
657 .unwrap()
658 .0 as f32;
659 prompt_row - scroll_top
660 });
661
662 if distance_from_top != scroll_lock.distance_from_top {
663 editor_assists.scroll_lock = None;
664 }
665 }
666 }
667 }
668 EditorEvent::SelectionsChanged { .. } => {
669 for assist_id in editor_assists.assist_ids.clone() {
670 let assist = &self.assists[&assist_id];
671 if let Some(decorations) = assist.decorations.as_ref() {
672 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
673 return;
674 }
675 }
676 }
677
678 editor_assists.scroll_lock = None;
679 }
680 _ => {}
681 }
682 }
683
684 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
685 if let Some(assist) = self.assists.get(&assist_id) {
686 let assist_group_id = assist.group_id;
687 if self.assist_groups[&assist_group_id].linked {
688 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
689 self.finish_assist(assist_id, undo, cx);
690 }
691 return;
692 }
693 }
694
695 self.dismiss_assist(assist_id, cx);
696
697 if let Some(assist) = self.assists.remove(&assist_id) {
698 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
699 {
700 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
701 if entry.get().assist_ids.is_empty() {
702 entry.remove();
703 }
704 }
705
706 if let hash_map::Entry::Occupied(mut entry) =
707 self.assists_by_editor.entry(assist.editor.clone())
708 {
709 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
710 if entry.get().assist_ids.is_empty() {
711 entry.remove();
712 if let Some(editor) = assist.editor.upgrade() {
713 self.update_editor_highlights(&editor, cx);
714 }
715 } else {
716 entry.get().highlight_updates.send(()).ok();
717 }
718 }
719
720 if undo {
721 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
722 } else {
723 self.confirmed_assists.insert(assist_id, assist.codegen);
724 }
725 }
726
727 // Remove the assist from the status updates map
728 self.assist_observations.remove(&assist_id);
729 }
730
731 pub fn undo_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
732 let Some(codegen) = self.confirmed_assists.remove(&assist_id) else {
733 return false;
734 };
735 codegen.update(cx, |this, cx| this.undo(cx));
736 true
737 }
738
739 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
740 let Some(assist) = self.assists.get_mut(&assist_id) else {
741 return false;
742 };
743 let Some(editor) = assist.editor.upgrade() else {
744 return false;
745 };
746 let Some(decorations) = assist.decorations.take() else {
747 return false;
748 };
749
750 editor.update(cx, |editor, cx| {
751 let mut to_remove = decorations.removed_line_block_ids;
752 to_remove.insert(decorations.prompt_block_id);
753 to_remove.insert(decorations.end_block_id);
754 editor.remove_blocks(to_remove, None, cx);
755 });
756
757 if decorations
758 .prompt_editor
759 .focus_handle(cx)
760 .contains_focused(cx)
761 {
762 self.focus_next_assist(assist_id, cx);
763 }
764
765 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
766 if editor_assists
767 .scroll_lock
768 .as_ref()
769 .map_or(false, |lock| lock.assist_id == assist_id)
770 {
771 editor_assists.scroll_lock = None;
772 }
773 editor_assists.highlight_updates.send(()).ok();
774 }
775
776 true
777 }
778
779 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
780 let Some(assist) = self.assists.get(&assist_id) else {
781 return;
782 };
783
784 let assist_group = &self.assist_groups[&assist.group_id];
785 let assist_ix = assist_group
786 .assist_ids
787 .iter()
788 .position(|id| *id == assist_id)
789 .unwrap();
790 let assist_ids = assist_group
791 .assist_ids
792 .iter()
793 .skip(assist_ix + 1)
794 .chain(assist_group.assist_ids.iter().take(assist_ix));
795
796 for assist_id in assist_ids {
797 let assist = &self.assists[assist_id];
798 if assist.decorations.is_some() {
799 self.focus_assist(*assist_id, cx);
800 return;
801 }
802 }
803
804 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
805 }
806
807 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
808 let Some(assist) = self.assists.get(&assist_id) else {
809 return;
810 };
811
812 if let Some(decorations) = assist.decorations.as_ref() {
813 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
814 prompt_editor.editor.update(cx, |editor, cx| {
815 editor.focus(cx);
816 editor.select_all(&SelectAll, cx);
817 })
818 });
819 }
820
821 self.scroll_to_assist(assist_id, cx);
822 }
823
824 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
825 let Some(assist) = self.assists.get(&assist_id) else {
826 return;
827 };
828 let Some(editor) = assist.editor.upgrade() else {
829 return;
830 };
831
832 let position = assist.range.start;
833 editor.update(cx, |editor, cx| {
834 editor.change_selections(None, cx, |selections| {
835 selections.select_anchor_ranges([position..position])
836 });
837
838 let mut scroll_target_top;
839 let mut scroll_target_bottom;
840 if let Some(decorations) = assist.decorations.as_ref() {
841 scroll_target_top = editor
842 .row_for_block(decorations.prompt_block_id, cx)
843 .unwrap()
844 .0 as f32;
845 scroll_target_bottom = editor
846 .row_for_block(decorations.end_block_id, cx)
847 .unwrap()
848 .0 as f32;
849 } else {
850 let snapshot = editor.snapshot(cx);
851 let start_row = assist
852 .range
853 .start
854 .to_display_point(&snapshot.display_snapshot)
855 .row();
856 scroll_target_top = start_row.0 as f32;
857 scroll_target_bottom = scroll_target_top + 1.;
858 }
859 scroll_target_top -= editor.vertical_scroll_margin() as f32;
860 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
861
862 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
863 let scroll_top = editor.scroll_position(cx).y;
864 let scroll_bottom = scroll_top + height_in_lines;
865
866 if scroll_target_top < scroll_top {
867 editor.set_scroll_position(point(0., scroll_target_top), cx);
868 } else if scroll_target_bottom > scroll_bottom {
869 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
870 editor
871 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
872 } else {
873 editor.set_scroll_position(point(0., scroll_target_top), cx);
874 }
875 }
876 });
877 }
878
879 fn unlink_assist_group(
880 &mut self,
881 assist_group_id: InlineAssistGroupId,
882 cx: &mut WindowContext,
883 ) -> Vec<InlineAssistId> {
884 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
885 assist_group.linked = false;
886 for assist_id in &assist_group.assist_ids {
887 let assist = self.assists.get_mut(assist_id).unwrap();
888 if let Some(editor_decorations) = assist.decorations.as_ref() {
889 editor_decorations
890 .prompt_editor
891 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
892 }
893 }
894 assist_group.assist_ids.clone()
895 }
896
897 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
898 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
899 assist
900 } else {
901 return;
902 };
903
904 let assist_group_id = assist.group_id;
905 if self.assist_groups[&assist_group_id].linked {
906 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
907 self.start_assist(assist_id, cx);
908 }
909 return;
910 }
911
912 let Some(user_prompt) = assist.user_prompt(cx) else {
913 return;
914 };
915
916 self.prompt_history.retain(|prompt| *prompt != user_prompt);
917 self.prompt_history.push_back(user_prompt.clone());
918 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
919 self.prompt_history.pop_front();
920 }
921
922 let assistant_panel_context = assist.assistant_panel_context(cx);
923
924 assist
925 .codegen
926 .update(cx, |codegen, cx| {
927 codegen.start(user_prompt, assistant_panel_context, cx)
928 })
929 .log_err();
930
931 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
932 tx.send(()).ok();
933 }
934 }
935
936 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
937 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
938 assist
939 } else {
940 return;
941 };
942
943 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
944
945 if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
946 tx.send(()).ok();
947 }
948 }
949
950 pub fn assist_status(&self, assist_id: InlineAssistId, cx: &AppContext) -> InlineAssistStatus {
951 if let Some(assist) = self.assists.get(&assist_id) {
952 match &assist.codegen.read(cx).status {
953 CodegenStatus::Idle => InlineAssistStatus::Idle,
954 CodegenStatus::Pending => InlineAssistStatus::Pending,
955 CodegenStatus::Done => InlineAssistStatus::Done,
956 CodegenStatus::Error(_) => InlineAssistStatus::Error,
957 }
958 } else if self.confirmed_assists.contains_key(&assist_id) {
959 InlineAssistStatus::Confirmed
960 } else {
961 InlineAssistStatus::Canceled
962 }
963 }
964
/// Recomputes all inline-assist highlights in `editor` from its live assists:
///
/// - gutter "pending" ranges run from the codegen's edit position (or the range
///   start when there is none) to the range end;
/// - gutter "transformed" ranges run from the range start to the edit position;
/// - text in each codegen's `last_equal_ranges()` is rendered faded;
/// - inserted rows are row-highlighted, but only for assists still decorated.
///
/// Empty gutter ranges are skipped, and an empty highlight set is cleared
/// rather than set.
fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
    let mut gutter_pending_ranges = Vec::new();
    let mut gutter_transformed_ranges = Vec::new();
    let mut foreground_ranges = Vec::new();
    let mut inserted_row_ranges = Vec::new();
    // Fallback so an editor without bookkeeping yields an empty id list.
    let empty_assist_ids = Vec::new();
    let assist_ids = self
        .assists_by_editor
        .get(&editor.downgrade())
        .map_or(&empty_assist_ids, |editor_assists| {
            &editor_assists.assist_ids
        });

    for assist_id in assist_ids {
        if let Some(assist) = self.assists.get(assist_id) {
            let codegen = assist.codegen.read(cx);
            let buffer = codegen.buffer.read(cx).read(cx);
            foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned());

            // Not yet generated: from the current edit position to the end of the range.
            let pending_range =
                codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end;
            if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                gutter_pending_ranges.push(pending_range);
            }

            // Already generated: from the start of the range to the edit position.
            if let Some(edit_position) = codegen.edit_position {
                let edited_range = assist.range.start..edit_position;
                if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                    gutter_transformed_ranges.push(edited_range);
                }
            }

            if assist.decorations.is_some() {
                inserted_row_ranges.extend(codegen.diff.inserted_row_ranges.iter().cloned());
            }
        }
    }

    // NOTE(review): `merge_ranges` presumably coalesces overlapping ranges in place —
    // confirm against its definition.
    let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
    merge_ranges(&mut foreground_ranges, &snapshot);
    merge_ranges(&mut gutter_pending_ranges, &snapshot);
    merge_ranges(&mut gutter_transformed_ranges, &snapshot);
    editor.update(cx, |editor, cx| {
        // Marker types that key the two independent gutter highlight sets.
        enum GutterPendingRange {}
        if gutter_pending_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterPendingRange>(cx);
        } else {
            editor.highlight_gutter::<GutterPendingRange>(
                &gutter_pending_ranges,
                |cx| cx.theme().status().info_background,
                cx,
            )
        }

        enum GutterTransformedRange {}
        if gutter_transformed_ranges.is_empty() {
            editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
        } else {
            editor.highlight_gutter::<GutterTransformedRange>(
                &gutter_transformed_ranges,
                |cx| cx.theme().status().info,
                cx,
            )
        }

        if foreground_ranges.is_empty() {
            editor.clear_highlights::<InlineAssist>(cx);
        } else {
            editor.highlight_text::<InlineAssist>(
                foreground_ranges,
                HighlightStyle {
                    fade_out: Some(0.6),
                    ..Default::default()
                },
                cx,
            );
        }

        // Row highlights are rebuilt from scratch on every update.
        editor.clear_row_highlights::<InlineAssist>();
        for row_range in inserted_row_ranges {
            editor.highlight_rows::<InlineAssist>(
                row_range,
                Some(cx.theme().status().info_background),
                false,
                cx,
            );
        }
    });
}
1054
    /// Rebuilds the "deleted lines" blocks that `editor` shows for one assist.
    ///
    /// Returns early if `assist_id` is unknown or the assist has no decorations.
    /// Otherwise it removes every block recorded in
    /// `decorations.removed_line_block_ids`, then inserts one block per entry of
    /// the codegen's `diff.deleted_row_ranges`. Each block embeds a read-only,
    /// gutterless editor showing the deleted text (read from `codegen.old_buffer`)
    /// on a `deleted_background`, placed above the entry's anchor. The ids of the
    /// newly inserted blocks are stored back into `removed_line_block_ids`.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Clone what is needed out of the codegen so the shared borrow of `cx`
        // taken by `read` ends before `editor.update` borrows `cx` mutably.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot.clone();
        let old_buffer = codegen.old_buffer.clone();
        let deleted_row_ranges = codegen.diff.deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks from the previous update before inserting fresh ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Map the deleted rows (rows of the old snapshot) to a buffer offset
                // range: start of the first row through end of the last row.
                // NOTE(review): these `unwrap`s panic if a row does not resolve to a
                // buffer position in `old_snapshot`; presumably the diff only ever
                // contains rows of that snapshot — confirm.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A read-only editor over a single excerpt of the old buffer that
                // covers exactly the deleted text.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(0, language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    // Marker type used only as the highlight key below.
                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..=Anchor::max(),
                        Some(cx.theme().status().deleted_background),
                        false,
                        cx,
                    );
                    editor
                });

                // Block height in lines = number of rows in the embedded editor.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    position: new_row,
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        // Indent by the host gutter width so the deleted text lines
                        // up with the surrounding code.
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    disposition: BlockDisposition::Above,
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1146
1147 pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
1148 if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
1149 rx.clone()
1150 } else {
1151 let (tx, rx) = async_watch::channel(());
1152 self.assist_observations.insert(assist_id, (tx, rx.clone()));
1153 rx
1154 }
1155 }
1156}
1157
/// Lifecycle status of an inline assist.
pub enum InlineAssistStatus {
    Idle,
    Pending,
    Done,
    Error,
    Confirmed,
    Canceled,
}

impl InlineAssistStatus {
    /// Returns `true` exactly when the status is [`InlineAssistStatus::Pending`].
    pub(crate) fn is_pending(&self) -> bool {
        matches!(*self, InlineAssistStatus::Pending)
    }

    /// Returns `true` exactly when the status is [`InlineAssistStatus::Confirmed`].
    pub(crate) fn is_confirmed(&self) -> bool {
        matches!(*self, InlineAssistStatus::Confirmed)
    }

    /// Returns `true` exactly when the status is [`InlineAssistStatus::Done`].
    pub(crate) fn is_done(&self) -> bool {
        matches!(*self, InlineAssistStatus::Done)
    }
}
1178
/// Inline-assistant state kept for one editor (looked up by editor in
/// `InlineAssistant::assists_by_editor`).
struct EditorInlineAssists {
    /// Ids of the inline assists that belong to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one of this editor's assists.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Send `()` here to request a highlight refresh; the `_update_highlights`
    /// task listens on the receiving half.
    highlight_updates: async_watch::Sender<()>,
    /// Task that calls `InlineAssistant::update_editor_highlights` whenever
    /// `highlight_updates` is signalled.
    _update_highlights: Task<Result<()>>,
    /// Held only to keep the editor subscriptions/action registrations alive.
    _subscriptions: Vec<gpui::Subscription>,
}
1186
/// Scroll lock for an editor, tied to one inline assist.
///
/// NOTE(review): how the lock is applied lives outside this chunk; the field
/// descriptions below are inferred from their names — confirm.
struct InlineAssistScrollLock {
    /// The assist the scroll position is tied to.
    assist_id: InlineAssistId,
    /// Distance of that assist from the top of the viewport (units presumably
    /// rows — TODO confirm).
    distance_from_top: f32,
}
1191
1192impl EditorInlineAssists {
1193 #[allow(clippy::too_many_arguments)]
1194 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1195 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1196 Self {
1197 assist_ids: Vec::new(),
1198 scroll_lock: None,
1199 highlight_updates: highlight_updates_tx,
1200 _update_highlights: cx.spawn(|mut cx| {
1201 let editor = editor.downgrade();
1202 async move {
1203 while let Ok(()) = highlight_updates_rx.changed().await {
1204 let editor = editor.upgrade().context("editor was dropped")?;
1205 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1206 assistant.update_editor_highlights(&editor, cx);
1207 })?;
1208 }
1209 Ok(())
1210 }
1211 }),
1212 _subscriptions: vec![
1213 cx.observe_release(editor, {
1214 let editor = editor.downgrade();
1215 |_, cx| {
1216 InlineAssistant::update_global(cx, |this, cx| {
1217 this.handle_editor_release(editor, cx);
1218 })
1219 }
1220 }),
1221 cx.observe(editor, move |editor, cx| {
1222 InlineAssistant::update_global(cx, |this, cx| {
1223 this.handle_editor_change(editor, cx)
1224 })
1225 }),
1226 cx.subscribe(editor, move |editor, event, cx| {
1227 InlineAssistant::update_global(cx, |this, cx| {
1228 this.handle_editor_event(editor, event, cx)
1229 })
1230 }),
1231 editor.update(cx, |editor, cx| {
1232 let editor_handle = cx.view().downgrade();
1233 editor.register_action(
1234 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1235 InlineAssistant::update_global(cx, |this, cx| {
1236 if let Some(editor) = editor_handle.upgrade() {
1237 this.handle_editor_newline(editor, cx)
1238 }
1239 })
1240 },
1241 )
1242 }),
1243 editor.update(cx, |editor, cx| {
1244 let editor_handle = cx.view().downgrade();
1245 editor.register_action(
1246 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1247 InlineAssistant::update_global(cx, |this, cx| {
1248 if let Some(editor) = editor_handle.upgrade() {
1249 this.handle_editor_cancel(editor, cx)
1250 }
1251 })
1252 },
1253 )
1254 }),
1255 ],
1256 }
1257 }
1258}
1259
1260struct InlineAssistGroup {
1261 assist_ids: Vec<InlineAssistId>,
1262 linked: bool,
1263 active_assist_id: Option<InlineAssistId>,
1264}
1265
1266impl InlineAssistGroup {
1267 fn new() -> Self {
1268 Self {
1269 assist_ids: Vec::new(),
1270 linked: true,
1271 active_assist_id: None,
1272 }
1273 }
1274}
1275
1276fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1277 let editor = editor.clone();
1278 Box::new(move |cx: &mut BlockContext| {
1279 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1280 editor.clone().into_any_element()
1281 })
1282}
1283
/// Identifies a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Post-increment: advances `self` to the next id and returns the value it
    /// held before the call.
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1294
/// Identifies a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and bumps `self` to the following one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        *self = InlineAssistGroupId(current + 1);
        InlineAssistGroupId(current)
    }
}
1305
/// Requests emitted by a [`PromptEditor`] for its owner to act on.
enum PromptEditorEvent {
    /// Emitted by the start/restart buttons and by `menu::Confirm` when idle or
    /// after the prompt was edited.
    StartRequested,
    /// Emitted by the stop button and by the cancel action while a generation is
    /// pending.
    StopRequested,
    /// Emitted by the confirm button and by `menu::Confirm` in some finished states.
    ConfirmRequested,
    /// Emitted by the cancel buttons and by the cancel action when nothing is
    /// pending.
    CancelRequested,
    /// Emitted by `menu::Confirm` while a generation is pending.
    DismissRequested,
}
1313
/// The inline prompt row for an assist: a prompt text editor plus status
/// buttons, a model selector and a token counter.
struct PromptEditor {
    /// Assist this prompt belongs to.
    id: InlineAssistId,
    /// Filesystem handle passed to the model selector.
    fs: Arc<dyn Fs>,
    /// The text editor holding the prompt; replaced by `unlink`.
    editor: View<Editor>,
    /// Set on every prompt edit; cleared when the codegen reaches `Done` or
    /// `Error`.
    edited_since_done: bool,
    /// Host editor gutter dimensions, written by the block renderer and read when
    /// laying out the leading column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously used prompts, navigated with `MoveUp`/`MoveDown` (the last
    /// entry is reached first, so presumably newest-last).
    prompt_history: VecDeque<String>,
    /// Index of the history entry currently shown, or `None` when not browsing.
    prompt_history_ix: Option<usize>,
    /// The prompt being typed before history browsing started; restored when
    /// moving down past the last entry.
    pending_prompt: String,
    /// Code generation whose status drives the buttons and read-only state.
    codegen: Model<Codegen>,
    /// Observation of `codegen` that calls `handle_codegen_changed`.
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// Latest delayed token-count task (see `count_tokens`).
    pending_token_count: Task<Result<()>>,
    /// Most recently computed token counts, if any.
    token_counts: Option<TokenCounts>,
    /// Parent-editor / assistant-panel subscriptions that trigger recounts.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to open the Assistant Panel from the token counter.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}
1332
/// Token usage shown next to an inline assist prompt.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    /// Total number of tokens counted for the request.
    total: usize,
    /// How many of those tokens come from the Assistant Panel context.
    assistant_panel: usize,
}
1338
// Lets `PromptEditor` emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1340
impl Render for PromptEditor {
    /// Lays out the prompt row, in order:
    /// 1. a leading column sized to the host gutter, holding the model selector
    ///    and (on error) an error indicator;
    /// 2. the prompt editor;
    /// 3. a trailing column with the token count and status-dependent buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let status = &self.codegen.read(cx).status;
        let buttons = match status {
            // Idle: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        ),
                ]
            }
            // Pending: cancel, or stop the in-flight generation (the tooltip
            // promises changes made so far are kept).
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)),
                        ),
                ]
            }
            // Finished (successfully or not): cancel, plus either restart or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        ),
                    // Restart after an error or once the prompt was edited; otherwise
                    // offer to confirm. NOTE(review): keep `Self::confirm` (the
                    // `menu::Confirm` handler) consistent with this choice.
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                    },
                ]
            }
        };

        h_flex()
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Leading column: width matches the host gutter plus half its margin.
            .child(
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SlidersAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    // Error indicator, only when the codegen is in `Error`.
                    .map(|el| {
                        let CodegenStatus::Error(error) = &self.codegen.read(cx).status else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // Rate-limit errors for Zed Pro users get a toggleable popover
                        // instead of a plain tooltip.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            // Trailing column: token counter (if available) and the buttons above.
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1521
// Focusing a `PromptEditor` means focusing its inner prompt `Editor`.
impl FocusableView for PromptEditor {
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1527
1528impl PromptEditor {
1529 const MAX_LINES: u8 = 8;
1530
1531 #[allow(clippy::too_many_arguments)]
1532 fn new(
1533 id: InlineAssistId,
1534 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1535 prompt_history: VecDeque<String>,
1536 prompt_buffer: Model<MultiBuffer>,
1537 codegen: Model<Codegen>,
1538 parent_editor: &View<Editor>,
1539 assistant_panel: Option<&View<AssistantPanel>>,
1540 workspace: Option<WeakView<Workspace>>,
1541 fs: Arc<dyn Fs>,
1542 cx: &mut ViewContext<Self>,
1543 ) -> Self {
1544 let prompt_editor = cx.new_view(|cx| {
1545 let mut editor = Editor::new(
1546 EditorMode::AutoHeight {
1547 max_lines: Self::MAX_LINES as usize,
1548 },
1549 prompt_buffer,
1550 None,
1551 false,
1552 cx,
1553 );
1554 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1555 // Since the prompt editors for all inline assistants are linked,
1556 // always show the cursor (even when it isn't focused) because
1557 // typing in one will make what you typed appear in all of them.
1558 editor.set_show_cursor_when_unfocused(true, cx);
1559 editor.set_placeholder_text("Add a prompt…", cx);
1560 editor
1561 });
1562
1563 let mut token_count_subscriptions = Vec::new();
1564 token_count_subscriptions
1565 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1566 if let Some(assistant_panel) = assistant_panel {
1567 token_count_subscriptions
1568 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1569 }
1570
1571 let mut this = Self {
1572 id,
1573 editor: prompt_editor,
1574 edited_since_done: false,
1575 gutter_dimensions,
1576 prompt_history,
1577 prompt_history_ix: None,
1578 pending_prompt: String::new(),
1579 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1580 editor_subscriptions: Vec::new(),
1581 codegen,
1582 fs,
1583 pending_token_count: Task::ready(Ok(())),
1584 token_counts: None,
1585 _token_count_subscriptions: token_count_subscriptions,
1586 workspace,
1587 show_rate_limit_notice: false,
1588 };
1589 this.count_tokens(cx);
1590 this.subscribe_to_editor(cx);
1591 this
1592 }
1593
1594 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1595 self.editor_subscriptions.clear();
1596 self.editor_subscriptions
1597 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1598 }
1599
1600 fn set_show_cursor_when_unfocused(
1601 &mut self,
1602 show_cursor_when_unfocused: bool,
1603 cx: &mut ViewContext<Self>,
1604 ) {
1605 self.editor.update(cx, |editor, cx| {
1606 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1607 });
1608 }
1609
1610 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1611 let prompt = self.prompt(cx);
1612 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1613 self.editor = cx.new_view(|cx| {
1614 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1615 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1616 editor.set_placeholder_text("Add a prompt…", cx);
1617 editor.set_text(prompt, cx);
1618 if focus {
1619 editor.focus(cx);
1620 }
1621 editor
1622 });
1623 self.subscribe_to_editor(cx);
1624 }
1625
1626 fn prompt(&self, cx: &AppContext) -> String {
1627 self.editor.read(cx).text(cx)
1628 }
1629
1630 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1631 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1632 if self.show_rate_limit_notice {
1633 cx.focus_view(&self.editor);
1634 }
1635 cx.notify();
1636 }
1637
1638 fn handle_parent_editor_event(
1639 &mut self,
1640 _: View<Editor>,
1641 event: &EditorEvent,
1642 cx: &mut ViewContext<Self>,
1643 ) {
1644 if let EditorEvent::BufferEdited { .. } = event {
1645 self.count_tokens(cx);
1646 }
1647 }
1648
1649 fn handle_assistant_panel_event(
1650 &mut self,
1651 _: View<AssistantPanel>,
1652 event: &AssistantPanelEvent,
1653 cx: &mut ViewContext<Self>,
1654 ) {
1655 let AssistantPanelEvent::ContextEdited { .. } = event;
1656 self.count_tokens(cx);
1657 }
1658
1659 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1660 let assist_id = self.id;
1661 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1662 cx.background_executor().timer(Duration::from_secs(1)).await;
1663 let token_count = cx
1664 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1665 let assist = inline_assistant
1666 .assists
1667 .get(&assist_id)
1668 .context("assist not found")?;
1669 anyhow::Ok(assist.count_tokens(cx))
1670 })??
1671 .await?;
1672
1673 this.update(&mut cx, |this, cx| {
1674 this.token_counts = Some(token_count);
1675 cx.notify();
1676 })
1677 })
1678 }
1679
1680 fn handle_prompt_editor_events(
1681 &mut self,
1682 _: View<Editor>,
1683 event: &EditorEvent,
1684 cx: &mut ViewContext<Self>,
1685 ) {
1686 match event {
1687 EditorEvent::Edited { .. } => {
1688 let prompt = self.editor.read(cx).text(cx);
1689 if self
1690 .prompt_history_ix
1691 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1692 {
1693 self.prompt_history_ix.take();
1694 self.pending_prompt = prompt;
1695 }
1696
1697 self.edited_since_done = true;
1698 cx.notify();
1699 }
1700 EditorEvent::BufferEdited => {
1701 self.count_tokens(cx);
1702 }
1703 EditorEvent::Blurred => {
1704 if self.show_rate_limit_notice {
1705 self.show_rate_limit_notice = false;
1706 cx.notify();
1707 }
1708 }
1709 _ => {}
1710 }
1711 }
1712
1713 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1714 match &self.codegen.read(cx).status {
1715 CodegenStatus::Idle => {
1716 self.editor
1717 .update(cx, |editor, _| editor.set_read_only(false));
1718 }
1719 CodegenStatus::Pending => {
1720 self.editor
1721 .update(cx, |editor, _| editor.set_read_only(true));
1722 }
1723 CodegenStatus::Done => {
1724 self.edited_since_done = false;
1725 self.editor
1726 .update(cx, |editor, _| editor.set_read_only(false));
1727 }
1728 CodegenStatus::Error(error) => {
1729 if cx.has_flag::<ZedPro>()
1730 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1731 && !dismissed_rate_limit_notice()
1732 {
1733 self.show_rate_limit_notice = true;
1734 cx.notify();
1735 }
1736
1737 self.edited_since_done = false;
1738 self.editor
1739 .update(cx, |editor, _| editor.set_read_only(false));
1740 }
1741 }
1742 }
1743
1744 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1745 match &self.codegen.read(cx).status {
1746 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1747 cx.emit(PromptEditorEvent::CancelRequested);
1748 }
1749 CodegenStatus::Pending => {
1750 cx.emit(PromptEditorEvent::StopRequested);
1751 }
1752 }
1753 }
1754
1755 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1756 match &self.codegen.read(cx).status {
1757 CodegenStatus::Idle => {
1758 cx.emit(PromptEditorEvent::StartRequested);
1759 }
1760 CodegenStatus::Pending => {
1761 cx.emit(PromptEditorEvent::DismissRequested);
1762 }
1763 CodegenStatus::Done | CodegenStatus::Error(_) => {
1764 if self.edited_since_done {
1765 cx.emit(PromptEditorEvent::StartRequested);
1766 } else {
1767 cx.emit(PromptEditorEvent::ConfirmRequested);
1768 }
1769 }
1770 }
1771 }
1772
1773 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1774 if let Some(ix) = self.prompt_history_ix {
1775 if ix > 0 {
1776 self.prompt_history_ix = Some(ix - 1);
1777 let prompt = self.prompt_history[ix - 1].as_str();
1778 self.editor.update(cx, |editor, cx| {
1779 editor.set_text(prompt, cx);
1780 editor.move_to_beginning(&Default::default(), cx);
1781 });
1782 }
1783 } else if !self.prompt_history.is_empty() {
1784 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1785 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1786 self.editor.update(cx, |editor, cx| {
1787 editor.set_text(prompt, cx);
1788 editor.move_to_beginning(&Default::default(), cx);
1789 });
1790 }
1791 }
1792
1793 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1794 if let Some(ix) = self.prompt_history_ix {
1795 if ix < self.prompt_history.len() - 1 {
1796 self.prompt_history_ix = Some(ix + 1);
1797 let prompt = self.prompt_history[ix + 1].as_str();
1798 self.editor.update(cx, |editor, cx| {
1799 editor.set_text(prompt, cx);
1800 editor.move_to_end(&Default::default(), cx)
1801 });
1802 } else {
1803 self.prompt_history_ix = None;
1804 let prompt = self.pending_prompt.as_str();
1805 self.editor.update(cx, |editor, cx| {
1806 editor.set_text(prompt, cx);
1807 editor.move_to_end(&Default::default(), cx)
1808 });
1809 }
1810 }
1811 }
1812
1813 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1814 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1815 let token_counts = self.token_counts?;
1816 let max_token_count = model.max_token_count();
1817
1818 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1819 let token_count_color = if remaining_tokens <= 0 {
1820 Color::Error
1821 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1822 Color::Warning
1823 } else {
1824 Color::Muted
1825 };
1826
1827 let mut token_count = h_flex()
1828 .id("token_count")
1829 .gap_0p5()
1830 .child(
1831 Label::new(humanize_token_count(token_counts.total))
1832 .size(LabelSize::Small)
1833 .color(token_count_color),
1834 )
1835 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1836 .child(
1837 Label::new(humanize_token_count(max_token_count))
1838 .size(LabelSize::Small)
1839 .color(Color::Muted),
1840 );
1841 if let Some(workspace) = self.workspace.clone() {
1842 token_count = token_count
1843 .tooltip(move |cx| {
1844 Tooltip::with_meta(
1845 format!(
1846 "Tokens Used ({} from the Assistant Panel)",
1847 humanize_token_count(token_counts.assistant_panel)
1848 ),
1849 None,
1850 "Click to open the Assistant Panel",
1851 cx,
1852 )
1853 })
1854 .cursor_pointer()
1855 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
1856 .on_click(move |_, cx| {
1857 cx.stop_propagation();
1858 workspace
1859 .update(cx, |workspace, cx| {
1860 workspace.focus_panel::<AssistantPanel>(cx)
1861 })
1862 .ok();
1863 });
1864 } else {
1865 token_count = token_count
1866 .cursor_default()
1867 .tooltip(|cx| Tooltip::text("Tokens used", cx));
1868 }
1869
1870 Some(token_count)
1871 }
1872
1873 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1874 let settings = ThemeSettings::get_global(cx);
1875 let text_style = TextStyle {
1876 color: if self.editor.read(cx).read_only(cx) {
1877 cx.theme().colors().text_disabled
1878 } else {
1879 cx.theme().colors().text
1880 },
1881 font_family: settings.ui_font.family.clone(),
1882 font_features: settings.ui_font.features.clone(),
1883 font_fallbacks: settings.ui_font.fallbacks.clone(),
1884 font_size: rems(0.875).into(),
1885 font_weight: settings.ui_font.weight,
1886 line_height: relative(1.3),
1887 ..Default::default()
1888 };
1889 EditorElement::new(
1890 &self.editor,
1891 EditorStyle {
1892 background: cx.theme().colors().editor_background,
1893 local_player: cx.theme().players().local(),
1894 text: text_style,
1895 ..Default::default()
1896 },
1897 )
1898 }
1899
1900 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1901 Popover::new().child(
1902 v_flex()
1903 .occlude()
1904 .p_2()
1905 .child(
1906 Label::new("Out of Tokens")
1907 .size(LabelSize::Small)
1908 .weight(FontWeight::BOLD),
1909 )
1910 .child(Label::new(
1911 "Try Zed Pro for higher limits, a wider range of models, and more.",
1912 ))
1913 .child(
1914 h_flex()
1915 .justify_between()
1916 .child(CheckboxWithLabel::new(
1917 "dont-show-again",
1918 Label::new("Don't show again"),
1919 if dismissed_rate_limit_notice() {
1920 ui::Selection::Selected
1921 } else {
1922 ui::Selection::Unselected
1923 },
1924 |selection, cx| {
1925 let is_dismissed = match selection {
1926 ui::Selection::Unselected => false,
1927 ui::Selection::Indeterminate => return,
1928 ui::Selection::Selected => true,
1929 };
1930
1931 set_rate_limit_notice_dismissed(is_dismissed, cx)
1932 },
1933 ))
1934 .child(
1935 h_flex()
1936 .gap_2()
1937 .child(
1938 Button::new("dismiss", "Dismiss")
1939 .style(ButtonStyle::Transparent)
1940 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
1941 )
1942 .child(Button::new("more-info", "More Info").on_click(
1943 |_event, cx| {
1944 cx.dispatch_action(Box::new(
1945 zed_actions::OpenAccountSettings,
1946 ))
1947 },
1948 )),
1949 ),
1950 ),
1951 )
1952 }
1953}
1954
/// Key-value-store key whose presence means the user chose "Don't show again"
/// on the rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
1956
1957fn dismissed_rate_limit_notice() -> bool {
1958 db::kvp::KEY_VALUE_STORE
1959 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
1960 .log_err()
1961 .map_or(false, |s| s.is_some())
1962}
1963
1964fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
1965 db::write_and_log(cx, move || async move {
1966 if is_dismissed {
1967 db::kvp::KEY_VALUE_STORE
1968 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
1969 .await
1970 } else {
1971 db::kvp::KEY_VALUE_STORE
1972 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
1973 .await
1974 }
1975 })
1976}
1977
/// State for one inline assist.
struct InlineAssist {
    /// Group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Range of the host editor's buffer covered by the assist.
    range: Range<Anchor>,
    /// The editor the assist lives in.
    editor: WeakView<Editor>,
    /// Inline UI (prompt block, end block, deleted-line blocks). `None` for
    /// assists without inline UI; codegen errors for those are shown as
    /// workspace toasts instead.
    decorations: Option<InlineAssistDecorations>,
    /// Code-generation state (status, diff, generation task).
    codegen: Model<Codegen>,
    /// Held only to keep focus, prompt-editor and codegen subscriptions alive.
    _subscriptions: Vec<Subscription>,
    /// Workspace used for error toasts and to find the Assistant Panel.
    workspace: Option<WeakView<Workspace>>,
    /// When `true`, `assistant_panel_context` returns the panel's active context
    /// (used e.g. for token counting).
    include_context: bool,
}
1988
1989impl InlineAssist {
1990 #[allow(clippy::too_many_arguments)]
1991 fn new(
1992 assist_id: InlineAssistId,
1993 group_id: InlineAssistGroupId,
1994 include_context: bool,
1995 editor: &View<Editor>,
1996 prompt_editor: &View<PromptEditor>,
1997 prompt_block_id: CustomBlockId,
1998 end_block_id: CustomBlockId,
1999 range: Range<Anchor>,
2000 codegen: Model<Codegen>,
2001 workspace: Option<WeakView<Workspace>>,
2002 cx: &mut WindowContext,
2003 ) -> Self {
2004 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2005 InlineAssist {
2006 group_id,
2007 include_context,
2008 editor: editor.downgrade(),
2009 decorations: Some(InlineAssistDecorations {
2010 prompt_block_id,
2011 prompt_editor: prompt_editor.clone(),
2012 removed_line_block_ids: HashSet::default(),
2013 end_block_id,
2014 }),
2015 range,
2016 codegen: codegen.clone(),
2017 workspace: workspace.clone(),
2018 _subscriptions: vec![
2019 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2020 InlineAssistant::update_global(cx, |this, cx| {
2021 this.handle_prompt_editor_focus_in(assist_id, cx)
2022 })
2023 }),
2024 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2025 InlineAssistant::update_global(cx, |this, cx| {
2026 this.handle_prompt_editor_focus_out(assist_id, cx)
2027 })
2028 }),
2029 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2030 InlineAssistant::update_global(cx, |this, cx| {
2031 this.handle_prompt_editor_event(prompt_editor, event, cx)
2032 })
2033 }),
2034 cx.observe(&codegen, {
2035 let editor = editor.downgrade();
2036 move |_, cx| {
2037 if let Some(editor) = editor.upgrade() {
2038 InlineAssistant::update_global(cx, |this, cx| {
2039 if let Some(editor_assists) =
2040 this.assists_by_editor.get(&editor.downgrade())
2041 {
2042 editor_assists.highlight_updates.send(()).ok();
2043 }
2044
2045 this.update_editor_blocks(&editor, assist_id, cx);
2046 })
2047 }
2048 }
2049 }),
2050 cx.subscribe(&codegen, move |codegen, event, cx| {
2051 InlineAssistant::update_global(cx, |this, cx| match event {
2052 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2053 CodegenEvent::Finished => {
2054 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2055 assist
2056 } else {
2057 return;
2058 };
2059
2060 if let CodegenStatus::Error(error) = &codegen.read(cx).status {
2061 if assist.decorations.is_none() {
2062 if let Some(workspace) = assist
2063 .workspace
2064 .as_ref()
2065 .and_then(|workspace| workspace.upgrade())
2066 {
2067 let error = format!("Inline assistant error: {}", error);
2068 workspace.update(cx, |workspace, cx| {
2069 struct InlineAssistantError;
2070
2071 let id =
2072 NotificationId::identified::<InlineAssistantError>(
2073 assist_id.0,
2074 );
2075
2076 workspace.show_toast(Toast::new(id, error), cx);
2077 })
2078 }
2079 }
2080 }
2081
2082 if assist.decorations.is_none() {
2083 this.finish_assist(assist_id, false, cx);
2084 } else if let Some(tx) = this.assist_observations.get(&assist_id) {
2085 tx.0.send(()).ok();
2086 }
2087 }
2088 })
2089 }),
2090 ],
2091 }
2092 }
2093
2094 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2095 let decorations = self.decorations.as_ref()?;
2096 Some(decorations.prompt_editor.read(cx).prompt(cx))
2097 }
2098
2099 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2100 if self.include_context {
2101 let workspace = self.workspace.as_ref()?;
2102 let workspace = workspace.upgrade()?.read(cx);
2103 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2104 Some(
2105 assistant_panel
2106 .read(cx)
2107 .active_context(cx)?
2108 .read(cx)
2109 .to_completion_request(cx),
2110 )
2111 } else {
2112 None
2113 }
2114 }
2115
2116 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2117 let Some(user_prompt) = self.user_prompt(cx) else {
2118 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2119 };
2120 let assistant_panel_context = self.assistant_panel_context(cx);
2121 self.codegen
2122 .read(cx)
2123 .count_tokens(user_prompt, assistant_panel_context, cx)
2124 }
2125}
2126
/// Editor blocks and views that make up an assist's inline UI.
struct InlineAssistDecorations {
    /// Id of the block that presumably hosts `prompt_editor`.
    prompt_block_id: CustomBlockId,
    /// The prompt editor shown for this assist.
    prompt_editor: View<PromptEditor>,
    /// Blocks previewing deleted lines; replaced wholesale by
    /// `InlineAssistant::update_editor_blocks`.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the block presumably marking the end of the assist — TODO confirm.
    end_block_id: CustomBlockId,
}
2133
/// Events emitted by [`Codegen`].
#[derive(Debug)]
pub enum CodegenEvent {
    /// A generation run ended. The codegen's `status` may be `Error` at this
    /// point.
    Finished,
    /// The codegen's changes were undone. The subscriber in `InlineAssist::new`
    /// responds with `finish_assist(assist_id, false, cx)`.
    Undone,
}
2139
/// Model-driven transformation of a range of a buffer, with its status and the
/// resulting row-level diff.
pub struct Codegen {
    /// The buffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Standalone copy of the original buffer's text, language and language
    /// registry, made in `Codegen::new`. Deleted-line previews are rendered from it.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `Codegen::new`; the old row ranges in
    /// `diff.deleted_row_ranges` refer to it.
    snapshot: MultiBufferSnapshot,
    /// Range of `buffer` being transformed.
    transform_range: Range<Anchor>,
    /// Ranges passed in as `selected_ranges` to `Codegen::new`.
    selected_ranges: Vec<Range<Anchor>>,
    /// How far generated output has been applied. The inline assistant highlights
    /// text before it as transformed and text after it as still pending.
    edit_position: Option<Anchor>,
    /// NOTE(review): presumably ranges left unchanged by the latest diff — confirm.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction id passed to `Codegen::new`, if any.
    initial_transaction_id: Option<TransactionId>,
    /// NOTE(review): presumably the transaction grouping the generated edits — confirm.
    transformation_transaction_id: Option<TransactionId>,
    /// Current lifecycle state, observed by the prompt UI.
    status: CodegenStatus,
    /// The in-flight generation task (`Task::ready(())` initially).
    generation: Task<()>,
    /// Row-level diff between `snapshot` and the generated text.
    diff: Diff,
    /// Optional telemetry sink.
    telemetry: Option<Arc<Telemetry>>,
    /// Held only to keep a subscription alive.
    _subscription: gpui::Subscription,
    /// Prompt builder passed in as `builder` to `Codegen::new`.
    prompt_builder: Arc<PromptBuilder>,
}
2157
/// Lifecycle state of a [`Codegen`].
enum CodegenStatus {
    /// Initial state, and the state after `stop` when nothing had been diffed yet.
    Idle,
    /// A generation is in progress (set when `handle_stream` begins).
    Pending,
    /// The generation completed, or was stopped after producing some diff.
    Done,
    /// The generation failed; holds the error.
    Error(anyhow::Error),
}
2164
/// Line-level summary of the changes made by the current generation.
#[derive(Default)]
struct Diff {
    /// Each entry pairs an anchor at column 0 of the current-buffer row where rows were
    /// deleted with the inclusive range of rows (in the original snapshot) that were deleted.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inclusive anchor ranges, from the start of the first to the end of the last inserted
    /// row, in the current buffer.
    inserted_row_ranges: Vec<RangeInclusive<Anchor>>,
}
2170
2171impl Diff {
2172 fn is_empty(&self) -> bool {
2173 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2174 }
2175}
2176
/// Allows [`Codegen`] to emit [`CodegenEvent`]s to its subscribers.
impl EventEmitter<CodegenEvent> for Codegen {}
2178
2179impl Codegen {
2180 pub fn new(
2181 buffer: Model<MultiBuffer>,
2182 transform_range: Range<Anchor>,
2183 selected_ranges: Vec<Range<Anchor>>,
2184 initial_transaction_id: Option<TransactionId>,
2185 telemetry: Option<Arc<Telemetry>>,
2186 builder: Arc<PromptBuilder>,
2187 cx: &mut ModelContext<Self>,
2188 ) -> Self {
2189 let snapshot = buffer.read(cx).snapshot(cx);
2190
2191 let (old_buffer, _, _) = buffer
2192 .read(cx)
2193 .range_to_buffer_ranges(transform_range.clone(), cx)
2194 .pop()
2195 .unwrap();
2196 let old_buffer = cx.new_model(|cx| {
2197 let old_buffer = old_buffer.read(cx);
2198 let text = old_buffer.as_rope().clone();
2199 let line_ending = old_buffer.line_ending();
2200 let language = old_buffer.language().cloned();
2201 let language_registry = old_buffer.language_registry();
2202
2203 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2204 buffer.set_language(language, cx);
2205 if let Some(language_registry) = language_registry {
2206 buffer.set_language_registry(language_registry)
2207 }
2208 buffer
2209 });
2210
2211 Self {
2212 buffer: buffer.clone(),
2213 old_buffer,
2214 edit_position: None,
2215 snapshot,
2216 last_equal_ranges: Default::default(),
2217 transformation_transaction_id: None,
2218 status: CodegenStatus::Idle,
2219 generation: Task::ready(()),
2220 diff: Diff::default(),
2221 telemetry,
2222 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2223 initial_transaction_id,
2224 prompt_builder: builder,
2225 transform_range,
2226 selected_ranges,
2227 }
2228 }
2229
2230 fn handle_buffer_event(
2231 &mut self,
2232 _buffer: Model<MultiBuffer>,
2233 event: &multi_buffer::Event,
2234 cx: &mut ModelContext<Self>,
2235 ) {
2236 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2237 if self.transformation_transaction_id == Some(*transaction_id) {
2238 self.transformation_transaction_id = None;
2239 self.generation = Task::ready(());
2240 cx.emit(CodegenEvent::Undone);
2241 }
2242 }
2243 }
2244
2245 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2246 &self.last_equal_ranges
2247 }
2248
2249 pub fn count_tokens(
2250 &self,
2251 user_prompt: String,
2252 assistant_panel_context: Option<LanguageModelRequest>,
2253 cx: &AppContext,
2254 ) -> BoxFuture<'static, Result<TokenCounts>> {
2255 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2256 let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2257 match request {
2258 Ok(request) => {
2259 let total_count = model.count_tokens(request.clone(), cx);
2260 let assistant_panel_count = assistant_panel_context
2261 .map(|context| model.count_tokens(context, cx))
2262 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2263
2264 async move {
2265 Ok(TokenCounts {
2266 total: total_count.await?,
2267 assistant_panel: assistant_panel_count.await?,
2268 })
2269 }
2270 .boxed()
2271 }
2272 Err(error) => futures::future::ready(Err(error)).boxed(),
2273 }
2274 } else {
2275 future::ready(Err(anyhow!("no active model"))).boxed()
2276 }
2277 }
2278
2279 pub fn start(
2280 &mut self,
2281 user_prompt: String,
2282 assistant_panel_context: Option<LanguageModelRequest>,
2283 cx: &mut ModelContext<Self>,
2284 ) -> Result<()> {
2285 let model = LanguageModelRegistry::read_global(cx)
2286 .active_model()
2287 .context("no active model")?;
2288
2289 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2290 self.buffer.update(cx, |buffer, cx| {
2291 buffer.undo_transaction(transformation_transaction_id, cx);
2292 });
2293 }
2294
2295 self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
2296
2297 let telemetry_id = model.telemetry_id();
2298 let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
2299 if user_prompt.trim().to_lowercase() == "delete" {
2300 async { Ok(stream::empty().boxed()) }.boxed_local()
2301 } else {
2302 let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2303
2304 let chunks =
2305 cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
2306 async move { Ok(chunks.await?.boxed()) }.boxed_local()
2307 };
2308 self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
2309 Ok(())
2310 }
2311
2312 fn build_request(
2313 &self,
2314 user_prompt: String,
2315 assistant_panel_context: Option<LanguageModelRequest>,
2316 cx: &AppContext,
2317 ) -> Result<LanguageModelRequest> {
2318 let buffer = self.buffer.read(cx).snapshot(cx);
2319 let language = buffer.language_at(self.transform_range.start);
2320 let language_name = if let Some(language) = language.as_ref() {
2321 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2322 None
2323 } else {
2324 Some(language.name())
2325 }
2326 } else {
2327 None
2328 };
2329
2330 // Higher Temperature increases the randomness of model outputs.
2331 // If Markdown or No Language is Known, increase the randomness for more creative output
2332 // If Code, decrease temperature to get more deterministic outputs
2333 let temperature = if let Some(language) = language_name.clone() {
2334 if language.as_ref() == "Markdown" {
2335 1.0
2336 } else {
2337 0.5
2338 }
2339 } else {
2340 1.0
2341 };
2342
2343 let language_name = language_name.as_deref();
2344 let start = buffer.point_to_buffer_offset(self.transform_range.start);
2345 let end = buffer.point_to_buffer_offset(self.transform_range.end);
2346 let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
2347 let (start_buffer, start_buffer_offset) = start;
2348 let (end_buffer, end_buffer_offset) = end;
2349 if start_buffer.remote_id() == end_buffer.remote_id() {
2350 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2351 } else {
2352 return Err(anyhow::anyhow!("invalid transformation range"));
2353 }
2354 } else {
2355 return Err(anyhow::anyhow!("invalid transformation range"));
2356 };
2357
2358 let mut transform_context_range = transform_range.to_point(&transform_buffer);
2359 transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
2360 transform_context_range.start.column = 0;
2361 transform_context_range.end =
2362 (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
2363 transform_context_range.end.column =
2364 transform_buffer.line_len(transform_context_range.end.row);
2365 let transform_context_range = transform_context_range.to_offset(&transform_buffer);
2366
2367 let selected_ranges = self
2368 .selected_ranges
2369 .iter()
2370 .filter_map(|selected_range| {
2371 let start = buffer
2372 .point_to_buffer_offset(selected_range.start)
2373 .map(|(_, offset)| offset)?;
2374 let end = buffer
2375 .point_to_buffer_offset(selected_range.end)
2376 .map(|(_, offset)| offset)?;
2377 Some(start..end)
2378 })
2379 .collect::<Vec<_>>();
2380
2381 let prompt = self
2382 .prompt_builder
2383 .generate_content_prompt(
2384 user_prompt,
2385 language_name,
2386 transform_buffer,
2387 transform_range,
2388 selected_ranges,
2389 transform_context_range,
2390 )
2391 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2392
2393 let mut messages = Vec::new();
2394 if let Some(context_request) = assistant_panel_context {
2395 messages = context_request.messages;
2396 }
2397
2398 messages.push(LanguageModelRequestMessage {
2399 role: Role::User,
2400 content: vec![prompt.into()],
2401 cache: false,
2402 });
2403
2404 Ok(LanguageModelRequest {
2405 messages,
2406 stop: vec!["|END|>".to_string()],
2407 temperature,
2408 })
2409 }
2410
    /// Streams a model completion into the buffer, rewriting the text in `edit_range`.
    ///
    /// - `model_telemetry_id`: reported with the inline-assistant telemetry event.
    /// - `edit_range`: range (resolved against `self.snapshot`) whose text is the baseline the
    ///   streamed output is diffed against.
    /// - `stream`: resolves to the completion's text chunks. An error from it, from any chunk,
    ///   or from the diff task ends the generation with `CodegenStatus::Error`.
    ///
    /// The previous diff is cleared and the status becomes `Pending` right away. Diffing runs
    /// on the background executor. Each batch of character operations is applied on this
    /// model's task as its own transaction; the first batch's transaction becomes
    /// `transformation_transaction_id` and later ones are merged into it. When the stream
    /// ends, the status becomes `Done` or `Error` and `CodegenEvent::Finished` is emitted.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        edit_range: Range<Anchor>,
        stream: impl 'static + Future<Output = Result<BoxStream<'static, Result<String>>>>,
        cx: &mut ModelContext<Self>,
    ) {
        let snapshot = self.snapshot.clone();
        // Original text of the range; both the char-level and line-level diffs compare against it.
        let selected_text = snapshot
            .text_for_range(edit_range.start..edit_range.end)
            .collect::<Rope>();

        let selection_start = edit_range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        // NOTE(review): `suggested_line_indent` is computed here but never read later in this
        // function; possibly dead code left over from an earlier version — confirm before removing.
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=edit_range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let telemetry = self.telemetry.clone();
        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Byte offset into `snapshot` (the pre-generation text) up to which diff operations have
        // been applied. Delete and Keep operations advance it, across all batches.
        let mut edit_start = edit_range.start.to_offset(&snapshot);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let chunks = stream.await;
                let generate = async {
                    // Small bounded channel, so the background diff cannot run far ahead of the
                    // edits being applied below.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(chunks?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first chunk received.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;
                                    let char_ops = diff.push_new(&chunk);
                                    line_diff.push_char_operations(&char_ops, &selected_text);
                                    diff_tx
                                        .send((char_ops, line_diff.line_operations()))
                                        .await?;
                                }

                                // The stream is exhausted: send the final batch of operations.
                                let char_ops = diff.finish();
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // One telemetry event per generation, carrying any error message.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    // Apply each batch as it arrives; the loop ends when the diff task drops its sender.
                    while let Some((char_ops, line_diff)) = diff_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    char_ops
                                        .into_iter()
                                        .filter_map(|operation| match operation {
                                            // Insert at the cursor without consuming original text.
                                            CharOperation::Insert { text } => {
                                                let edit_start = snapshot.anchor_after(edit_start);
                                                Some((edit_start..edit_start, text))
                                            }
                                            // Remove `bytes` of original text and advance past them.
                                            CharOperation::Delete { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                Some((edit_range, String::new()))
                                            }
                                            // Keep `bytes` of original text: no edit, just record the range.
                                            CharOperation::Keep { bytes } => {
                                                let edit_end = edit_start + bytes;
                                                let edit_range = snapshot.anchor_after(edit_start)
                                                    ..snapshot.anchor_before(edit_end);
                                                edit_start = edit_end;
                                                this.last_equal_ranges.push(edit_range);
                                                None
                                            }
                                        }),
                                    None,
                                    cx,
                                );
                                this.edit_position = Some(snapshot.anchor_after(edit_start));

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transformation_transaction_id
                                {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transformation_transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            this.update_diff(edit_range.clone(), line_diff, cx);

                            cx.notify();
                        })?;
                    }

                    // Surface any error from the background diff task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        cx.notify();
    }
2590
2591 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2592 self.last_equal_ranges.clear();
2593 if self.diff.is_empty() {
2594 self.status = CodegenStatus::Idle;
2595 } else {
2596 self.status = CodegenStatus::Done;
2597 }
2598 self.generation = Task::ready(());
2599 cx.emit(CodegenEvent::Finished);
2600 cx.notify();
2601 }
2602
2603 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2604 self.buffer.update(cx, |buffer, cx| {
2605 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2606 buffer.undo_transaction(transaction_id, cx);
2607 buffer.refresh_preview(cx);
2608 }
2609
2610 if let Some(transaction_id) = self.initial_transaction_id.take() {
2611 buffer.undo_transaction(transaction_id, cx);
2612 buffer.refresh_preview(cx);
2613 }
2614 });
2615 }
2616
2617 fn update_diff(
2618 &mut self,
2619 edit_range: Range<Anchor>,
2620 line_operations: Vec<LineOperation>,
2621 cx: &mut ModelContext<Self>,
2622 ) {
2623 let old_snapshot = self.snapshot.clone();
2624 let old_range = edit_range.to_point(&old_snapshot);
2625 let new_snapshot = self.buffer.read(cx).snapshot(cx);
2626 let new_range = edit_range.to_point(&new_snapshot);
2627
2628 let mut old_row = old_range.start.row;
2629 let mut new_row = new_range.start.row;
2630
2631 self.diff.deleted_row_ranges.clear();
2632 self.diff.inserted_row_ranges.clear();
2633 for operation in line_operations {
2634 match operation {
2635 LineOperation::Keep { lines } => {
2636 old_row += lines;
2637 new_row += lines;
2638 }
2639 LineOperation::Delete { lines } => {
2640 let old_end_row = old_row + lines - 1;
2641 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
2642
2643 if let Some((_, last_deleted_row_range)) =
2644 self.diff.deleted_row_ranges.last_mut()
2645 {
2646 if *last_deleted_row_range.end() + 1 == old_row {
2647 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
2648 } else {
2649 self.diff
2650 .deleted_row_ranges
2651 .push((new_row, old_row..=old_end_row));
2652 }
2653 } else {
2654 self.diff
2655 .deleted_row_ranges
2656 .push((new_row, old_row..=old_end_row));
2657 }
2658
2659 old_row += lines;
2660 }
2661 LineOperation::Insert { lines } => {
2662 let new_end_row = new_row + lines - 1;
2663 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
2664 let end = new_snapshot.anchor_before(Point::new(
2665 new_end_row,
2666 new_snapshot.line_len(MultiBufferRow(new_end_row)),
2667 ));
2668 self.diff.inserted_row_ranges.push(start..=end);
2669 new_row += lines;
2670 }
2671 }
2672
2673 cx.notify();
2674 }
2675 }
2676}
2677
/// Stream adapter that removes a surrounding code fence and `<|CURSOR|>` markers from a
/// model's text completion, regardless of how the text is split into chunks.
struct StripInvalidSpans<T> {
    /// Inner stream of text chunks.
    stream: T,
    /// Set once the inner stream has returned `None`.
    stream_done: bool,
    /// Received text that has not been emitted yet.
    buffer: String,
    /// Still processing the first line of the response.
    first_line: bool,
    /// A newline is owed before the next emitted text (newlines are emitted lazily).
    line_end: bool,
    /// The first line was a ``` fence, so a closing fence on the final line is dropped too.
    starts_with_code_block: bool,
}
2686
2687impl<T> StripInvalidSpans<T>
2688where
2689 T: Stream<Item = Result<String>>,
2690{
2691 fn new(stream: T) -> Self {
2692 Self {
2693 stream,
2694 stream_done: false,
2695 buffer: String::new(),
2696 first_line: true,
2697 line_end: false,
2698 starts_with_code_block: false,
2699 }
2700 }
2701}
2702
2703impl<T> Stream for StripInvalidSpans<T>
2704where
2705 T: Stream<Item = Result<String>>,
2706{
2707 type Item = Result<String>;
2708
2709 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
2710 const CODE_BLOCK_DELIMITER: &str = "```";
2711 const CURSOR_SPAN: &str = "<|CURSOR|>";
2712
2713 let this = unsafe { self.get_unchecked_mut() };
2714 loop {
2715 if !this.stream_done {
2716 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
2717 match stream.as_mut().poll_next(cx) {
2718 Poll::Ready(Some(Ok(chunk))) => {
2719 this.buffer.push_str(&chunk);
2720 }
2721 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
2722 Poll::Ready(None) => {
2723 this.stream_done = true;
2724 }
2725 Poll::Pending => return Poll::Pending,
2726 }
2727 }
2728
2729 let mut chunk = String::new();
2730 let mut consumed = 0;
2731 if !this.buffer.is_empty() {
2732 let mut lines = this.buffer.split('\n').enumerate().peekable();
2733 while let Some((line_ix, line)) = lines.next() {
2734 if line_ix > 0 {
2735 this.first_line = false;
2736 }
2737
2738 if this.first_line {
2739 let trimmed_line = line.trim();
2740 if lines.peek().is_some() {
2741 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
2742 consumed += line.len() + 1;
2743 this.starts_with_code_block = true;
2744 continue;
2745 }
2746 } else if trimmed_line.is_empty()
2747 || prefixes(CODE_BLOCK_DELIMITER)
2748 .any(|prefix| trimmed_line.starts_with(prefix))
2749 {
2750 break;
2751 }
2752 }
2753
2754 let line_without_cursor = line.replace(CURSOR_SPAN, "");
2755 if lines.peek().is_some() {
2756 if this.line_end {
2757 chunk.push('\n');
2758 }
2759
2760 chunk.push_str(&line_without_cursor);
2761 this.line_end = true;
2762 consumed += line.len() + 1;
2763 } else if this.stream_done {
2764 if !this.starts_with_code_block
2765 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
2766 {
2767 if this.line_end {
2768 chunk.push('\n');
2769 }
2770
2771 chunk.push_str(&line);
2772 }
2773
2774 consumed += line.len();
2775 } else {
2776 let trimmed_line = line.trim();
2777 if trimmed_line.is_empty()
2778 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
2779 || prefixes(CODE_BLOCK_DELIMITER)
2780 .any(|prefix| trimmed_line.ends_with(prefix))
2781 {
2782 break;
2783 } else {
2784 if this.line_end {
2785 chunk.push('\n');
2786 this.line_end = false;
2787 }
2788
2789 chunk.push_str(&line_without_cursor);
2790 consumed += line.len();
2791 }
2792 }
2793 }
2794 }
2795
2796 this.buffer = this.buffer.split_off(consumed);
2797 if !chunk.is_empty() {
2798 return Poll::Ready(Some(Ok(chunk)));
2799 } else if this.stream_done {
2800 return Poll::Ready(None);
2801 }
2802 }
2803 }
2804}
2805
/// Returns every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string is not included, so `prefixes("```")` yields "`" and "``". Prefixes are
/// cut on `char` boundaries, so non-ASCII input cannot cause a slicing panic, and an empty
/// `text` yields nothing instead of underflowing `text.len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
2809
2810fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
2811 ranges.sort_unstable_by(|a, b| {
2812 a.start
2813 .cmp(&b.start, buffer)
2814 .then_with(|| b.end.cmp(&a.end, buffer))
2815 });
2816
2817 let mut ix = 0;
2818 while ix + 1 < ranges.len() {
2819 let b = ranges[ix + 1].clone();
2820 let a = &mut ranges[ix];
2821 if a.end.cmp(&b.start, buffer).is_gt() {
2822 if a.end.cmp(&b.end, buffer).is_lt() {
2823 a.end = b.end;
2824 }
2825 ranges.remove(ix + 1);
2826 } else {
2827 ix += 1;
2828 }
2829 }
2830}
2831
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use serde::Serialize;

    // NOTE(review): not referenced by the tests in this module; possibly unused — confirm.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Checks that `StripInvalidSpans` removes code fences and cursor markers identically for
    /// every way the input can be split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // Plain text passes through unchanged.
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        // A leading fence is dropped, with or without a matching closing fence.
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        // Only the outer fence pair is stripped; inner fences are kept.
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        // Two backticks are not a fence, so nothing is stripped.
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        // Cursor markers are removed wherever they appear.
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Feeds `text` through `StripInvalidSpans` at every chunk size from 1 to `text.len()`
        // and asserts the concatenated output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into a stream of `Ok` chunks of at most `size` chars each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }
}