1use crate::{
2 assistant_settings::AssistantSettings, humanize_token_count, prompts::PromptBuilder,
3 AssistantPanel, AssistantPanelEvent, CharOperation, CycleNextInlineAssist,
4 CyclePreviousInlineAssist, LineDiff, LineOperation, ModelSelector, RequestType, StreamingDiff,
5};
6use anyhow::{anyhow, Context as _, Result};
7use client::{telemetry::Telemetry, ErrorExt};
8use collections::{hash_map, HashMap, HashSet, VecDeque};
9use editor::{
10 actions::{MoveDown, MoveUp, SelectAll},
11 display_map::{
12 BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
13 ToDisplayPoint,
14 },
15 Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
16 EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
17 ToOffset as _, ToPoint,
18};
19use feature_flags::{FeatureFlagAppExt as _, ZedPro};
20use fs::Fs;
21use futures::{
22 channel::mpsc,
23 future::{BoxFuture, LocalBoxFuture},
24 join, SinkExt, Stream, StreamExt,
25};
26use gpui::{
27 anchored, deferred, point, AnyElement, AppContext, ClickEvent, EventEmitter, FocusHandle,
28 FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task,
29 TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
30};
31use language::{Buffer, IndentKind, Point, Selection, TransactionId};
32use language_model::{
33 logging::report_assistant_event, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
34 LanguageModelRequestMessage, LanguageModelTextStream, Role,
35};
36use multi_buffer::MultiBufferRow;
37use parking_lot::Mutex;
38use project::{CodeAction, ProjectTransaction};
39use rope::Rope;
40use settings::{Settings, SettingsStore};
41use smol::future::FutureExt;
42use std::{
43 cmp,
44 future::{self, Future},
45 iter, mem,
46 ops::{Range, RangeInclusive},
47 pin::Pin,
48 sync::Arc,
49 task::{self, Poll},
50 time::{Duration, Instant},
51};
52use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
53use terminal_view::terminal_panel::TerminalPanel;
54use text::{OffsetRangeExt, ToPoint as _};
55use theme::ThemeSettings;
56use ui::{prelude::*, text_for_action, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
57use util::{RangeExt, ResultExt};
58use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
59
60pub fn init(
61 fs: Arc<dyn Fs>,
62 prompt_builder: Arc<PromptBuilder>,
63 telemetry: Arc<Telemetry>,
64 cx: &mut AppContext,
65) {
66 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
67 cx.observe_new_views(|_, cx| {
68 let workspace = cx.view().clone();
69 InlineAssistant::update_global(cx, |inline_assistant, cx| {
70 inline_assistant.register_workspace(&workspace, cx)
71 })
72 })
73 .detach();
74}
75
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`; once a new
/// prompt pushes the history past this length, the oldest entry is dropped
/// (see `InlineAssistant::start_assist`).
const PROMPT_HISTORY_MAX_LEN: usize = 20;
77
/// Global bookkeeping for inline assists across all editors.
///
/// Installed as a gpui global by `init`; its methods are reached through
/// `InlineAssistant::update_global`.
pub struct InlineAssistant {
    /// Id given to the next assist (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id given to the next assist group (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every live assist, keyed by id.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor state: the assist ids living in that editor plus its scroll lock.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists; `assist` and `suggest_assist` each create one group per call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Active codegen alternative of assists finished without undo (accepted).
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Recently submitted prompts, oldest first, deduplicated and capped at
    /// `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    prompt_builder: Arc<PromptBuilder>,
    telemetry: Arc<Telemetry>,
    fs: Arc<dyn Fs>,
}

impl Global for InlineAssistant {}
92
93impl InlineAssistant {
94 pub fn new(
95 fs: Arc<dyn Fs>,
96 prompt_builder: Arc<PromptBuilder>,
97 telemetry: Arc<Telemetry>,
98 ) -> Self {
99 Self {
100 next_assist_id: InlineAssistId::default(),
101 next_assist_group_id: InlineAssistGroupId::default(),
102 assists: HashMap::default(),
103 assists_by_editor: HashMap::default(),
104 assist_groups: HashMap::default(),
105 confirmed_assists: HashMap::default(),
106 prompt_history: VecDeque::default(),
107 prompt_builder,
108 telemetry,
109 fs,
110 }
111 }
112
113 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
114 cx.subscribe(workspace, |workspace, event, cx| {
115 Self::update_global(cx, |this, cx| {
116 this.handle_workspace_event(workspace, event, cx)
117 });
118 })
119 .detach();
120
121 let workspace = workspace.downgrade();
122 cx.observe_global::<SettingsStore>(move |cx| {
123 let Some(workspace) = workspace.upgrade() else {
124 return;
125 };
126 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
127 return;
128 };
129 let enabled = AssistantSettings::get_global(cx).enabled;
130 terminal_panel.update(cx, |terminal_panel, cx| {
131 terminal_panel.asssistant_enabled(enabled, cx)
132 });
133 })
134 .detach();
135 }
136
137 fn handle_workspace_event(
138 &mut self,
139 workspace: View<Workspace>,
140 event: &workspace::Event,
141 cx: &mut WindowContext,
142 ) {
143 match event {
144 workspace::Event::UserSavedItem { item, .. } => {
145 // When the user manually saves an editor, automatically accepts all finished transformations.
146 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
147 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
148 for assist_id in editor_assists.assist_ids.clone() {
149 let assist = &self.assists[&assist_id];
150 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
151 self.finish_assist(assist_id, false, cx)
152 }
153 }
154 }
155 }
156 }
157 workspace::Event::ItemAdded { item } => {
158 self.register_workspace_item(&workspace, item.as_ref(), cx);
159 }
160 _ => (),
161 }
162 }
163
164 fn register_workspace_item(
165 &mut self,
166 workspace: &View<Workspace>,
167 item: &dyn ItemHandle,
168 cx: &mut WindowContext,
169 ) {
170 if let Some(editor) = item.act_as::<Editor>(cx) {
171 editor.update(cx, |editor, cx| {
172 editor.push_code_action_provider(
173 Arc::new(AssistantCodeActionProvider {
174 editor: cx.view().downgrade(),
175 workspace: workspace.downgrade(),
176 }),
177 cx,
178 );
179 });
180 }
181 }
182
183 pub fn assist(
184 &mut self,
185 editor: &View<Editor>,
186 workspace: Option<WeakView<Workspace>>,
187 assistant_panel: Option<&View<AssistantPanel>>,
188 initial_prompt: Option<String>,
189 cx: &mut WindowContext,
190 ) {
191 let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
192 (
193 editor.buffer().read(cx).snapshot(cx),
194 editor.selections.all::<Point>(cx),
195 )
196 });
197
198 let mut selections = Vec::<Selection<Point>>::new();
199 let mut newest_selection = None;
200 for mut selection in initial_selections {
201 if selection.end > selection.start {
202 selection.start.column = 0;
203 // If the selection ends at the start of the line, we don't want to include it.
204 if selection.end.column == 0 {
205 selection.end.row -= 1;
206 }
207 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
208 }
209
210 if let Some(prev_selection) = selections.last_mut() {
211 if selection.start <= prev_selection.end {
212 prev_selection.end = selection.end;
213 continue;
214 }
215 }
216
217 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
218 if selection.id > latest_selection.id {
219 *latest_selection = selection.clone();
220 }
221 selections.push(selection);
222 }
223 let newest_selection = newest_selection.unwrap();
224
225 let mut codegen_ranges = Vec::new();
226 for (excerpt_id, buffer, buffer_range) in
227 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
228 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
229 }))
230 {
231 let start = Anchor {
232 buffer_id: Some(buffer.remote_id()),
233 excerpt_id,
234 text_anchor: buffer.anchor_before(buffer_range.start),
235 };
236 let end = Anchor {
237 buffer_id: Some(buffer.remote_id()),
238 excerpt_id,
239 text_anchor: buffer.anchor_after(buffer_range.end),
240 };
241 codegen_ranges.push(start..end);
242
243 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
244 self.telemetry.report_assistant_event(AssistantEvent {
245 conversation_id: None,
246 kind: AssistantKind::Inline,
247 phase: AssistantPhase::Invoked,
248 message_id: None,
249 model: model.telemetry_id(),
250 model_provider: model.provider_id().to_string(),
251 response_latency: None,
252 error_message: None,
253 language_name: buffer.language().map(|language| language.name().to_proto()),
254 });
255 }
256 }
257
258 let assist_group_id = self.next_assist_group_id.post_inc();
259 let prompt_buffer =
260 cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
261 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
262
263 let mut assists = Vec::new();
264 let mut assist_to_focus = None;
265 for range in codegen_ranges {
266 let assist_id = self.next_assist_id.post_inc();
267 let codegen = cx.new_model(|cx| {
268 Codegen::new(
269 editor.read(cx).buffer().clone(),
270 range.clone(),
271 None,
272 self.telemetry.clone(),
273 self.prompt_builder.clone(),
274 cx,
275 )
276 });
277
278 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
279 let prompt_editor = cx.new_view(|cx| {
280 PromptEditor::new(
281 assist_id,
282 gutter_dimensions.clone(),
283 self.prompt_history.clone(),
284 prompt_buffer.clone(),
285 codegen.clone(),
286 editor,
287 assistant_panel,
288 workspace.clone(),
289 self.fs.clone(),
290 cx,
291 )
292 });
293
294 if assist_to_focus.is_none() {
295 let focus_assist = if newest_selection.reversed {
296 range.start.to_point(&snapshot) == newest_selection.start
297 } else {
298 range.end.to_point(&snapshot) == newest_selection.end
299 };
300 if focus_assist {
301 assist_to_focus = Some(assist_id);
302 }
303 }
304
305 let [prompt_block_id, end_block_id] =
306 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
307
308 assists.push((
309 assist_id,
310 range,
311 prompt_editor,
312 prompt_block_id,
313 end_block_id,
314 ));
315 }
316
317 let editor_assists = self
318 .assists_by_editor
319 .entry(editor.downgrade())
320 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
321 let mut assist_group = InlineAssistGroup::new();
322 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
323 self.assists.insert(
324 assist_id,
325 InlineAssist::new(
326 assist_id,
327 assist_group_id,
328 assistant_panel.is_some(),
329 editor,
330 &prompt_editor,
331 prompt_block_id,
332 end_block_id,
333 range,
334 prompt_editor.read(cx).codegen.clone(),
335 workspace.clone(),
336 cx,
337 ),
338 );
339 assist_group.assist_ids.push(assist_id);
340 editor_assists.assist_ids.push(assist_id);
341 }
342 self.assist_groups.insert(assist_group_id, assist_group);
343
344 if let Some(assist_id) = assist_to_focus {
345 self.focus_assist(assist_id, cx);
346 }
347 }
348
349 #[allow(clippy::too_many_arguments)]
350 pub fn suggest_assist(
351 &mut self,
352 editor: &View<Editor>,
353 mut range: Range<Anchor>,
354 initial_prompt: String,
355 initial_transaction_id: Option<TransactionId>,
356 focus: bool,
357 workspace: Option<WeakView<Workspace>>,
358 assistant_panel: Option<&View<AssistantPanel>>,
359 cx: &mut WindowContext,
360 ) -> InlineAssistId {
361 let assist_group_id = self.next_assist_group_id.post_inc();
362 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
363 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
364
365 let assist_id = self.next_assist_id.post_inc();
366
367 let buffer = editor.read(cx).buffer().clone();
368 {
369 let snapshot = buffer.read(cx).read(cx);
370 range.start = range.start.bias_left(&snapshot);
371 range.end = range.end.bias_right(&snapshot);
372 }
373
374 let codegen = cx.new_model(|cx| {
375 Codegen::new(
376 editor.read(cx).buffer().clone(),
377 range.clone(),
378 initial_transaction_id,
379 self.telemetry.clone(),
380 self.prompt_builder.clone(),
381 cx,
382 )
383 });
384
385 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
386 let prompt_editor = cx.new_view(|cx| {
387 PromptEditor::new(
388 assist_id,
389 gutter_dimensions.clone(),
390 self.prompt_history.clone(),
391 prompt_buffer.clone(),
392 codegen.clone(),
393 editor,
394 assistant_panel,
395 workspace.clone(),
396 self.fs.clone(),
397 cx,
398 )
399 });
400
401 let [prompt_block_id, end_block_id] =
402 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
403
404 let editor_assists = self
405 .assists_by_editor
406 .entry(editor.downgrade())
407 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
408
409 let mut assist_group = InlineAssistGroup::new();
410 self.assists.insert(
411 assist_id,
412 InlineAssist::new(
413 assist_id,
414 assist_group_id,
415 assistant_panel.is_some(),
416 editor,
417 &prompt_editor,
418 prompt_block_id,
419 end_block_id,
420 range,
421 prompt_editor.read(cx).codegen.clone(),
422 workspace.clone(),
423 cx,
424 ),
425 );
426 assist_group.assist_ids.push(assist_id);
427 editor_assists.assist_ids.push(assist_id);
428 self.assist_groups.insert(assist_group_id, assist_group);
429
430 if focus {
431 self.focus_assist(assist_id, cx);
432 }
433
434 assist_id
435 }
436
437 fn insert_assist_blocks(
438 &self,
439 editor: &View<Editor>,
440 range: &Range<Anchor>,
441 prompt_editor: &View<PromptEditor>,
442 cx: &mut WindowContext,
443 ) -> [CustomBlockId; 2] {
444 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
445 prompt_editor
446 .editor
447 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
448 });
449 let assist_blocks = vec![
450 BlockProperties {
451 style: BlockStyle::Sticky,
452 placement: BlockPlacement::Above(range.start),
453 height: prompt_editor_height,
454 render: build_assist_editor_renderer(prompt_editor),
455 priority: 0,
456 },
457 BlockProperties {
458 style: BlockStyle::Sticky,
459 placement: BlockPlacement::Below(range.end),
460 height: 0,
461 render: Box::new(|cx| {
462 v_flex()
463 .h_full()
464 .w_full()
465 .border_t_1()
466 .border_color(cx.theme().status().info_border)
467 .into_any_element()
468 }),
469 priority: 0,
470 },
471 ];
472
473 editor.update(cx, |editor, cx| {
474 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
475 [block_ids[0], block_ids[1]]
476 })
477 }
478
479 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
480 let assist = &self.assists[&assist_id];
481 let Some(decorations) = assist.decorations.as_ref() else {
482 return;
483 };
484 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
485 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
486
487 assist_group.active_assist_id = Some(assist_id);
488 if assist_group.linked {
489 for assist_id in &assist_group.assist_ids {
490 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
491 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
492 prompt_editor.set_show_cursor_when_unfocused(true, cx)
493 });
494 }
495 }
496 }
497
498 assist
499 .editor
500 .update(cx, |editor, cx| {
501 let scroll_top = editor.scroll_position(cx).y;
502 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
503 let prompt_row = editor
504 .row_for_block(decorations.prompt_block_id, cx)
505 .unwrap()
506 .0 as f32;
507
508 if (scroll_top..scroll_bottom).contains(&prompt_row) {
509 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
510 assist_id,
511 distance_from_top: prompt_row - scroll_top,
512 });
513 } else {
514 editor_assists.scroll_lock = None;
515 }
516 })
517 .ok();
518 }
519
520 fn handle_prompt_editor_focus_out(
521 &mut self,
522 assist_id: InlineAssistId,
523 cx: &mut WindowContext,
524 ) {
525 let assist = &self.assists[&assist_id];
526 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
527 if assist_group.active_assist_id == Some(assist_id) {
528 assist_group.active_assist_id = None;
529 if assist_group.linked {
530 for assist_id in &assist_group.assist_ids {
531 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
532 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
533 prompt_editor.set_show_cursor_when_unfocused(false, cx)
534 });
535 }
536 }
537 }
538 }
539 }
540
541 fn handle_prompt_editor_event(
542 &mut self,
543 prompt_editor: View<PromptEditor>,
544 event: &PromptEditorEvent,
545 cx: &mut WindowContext,
546 ) {
547 let assist_id = prompt_editor.read(cx).id;
548 match event {
549 PromptEditorEvent::StartRequested => {
550 self.start_assist(assist_id, cx);
551 }
552 PromptEditorEvent::StopRequested => {
553 self.stop_assist(assist_id, cx);
554 }
555 PromptEditorEvent::ConfirmRequested => {
556 self.finish_assist(assist_id, false, cx);
557 }
558 PromptEditorEvent::CancelRequested => {
559 self.finish_assist(assist_id, true, cx);
560 }
561 PromptEditorEvent::DismissRequested => {
562 self.dismiss_assist(assist_id, cx);
563 }
564 }
565 }
566
567 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
568 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
569 return;
570 };
571
572 if editor.read(cx).selections.count() == 1 {
573 let (selection, buffer) = editor.update(cx, |editor, cx| {
574 (
575 editor.selections.newest::<usize>(cx),
576 editor.buffer().read(cx).snapshot(cx),
577 )
578 });
579 for assist_id in &editor_assists.assist_ids {
580 let assist = &self.assists[assist_id];
581 let assist_range = assist.range.to_offset(&buffer);
582 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
583 {
584 if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
585 self.dismiss_assist(*assist_id, cx);
586 } else {
587 self.finish_assist(*assist_id, false, cx);
588 }
589
590 return;
591 }
592 }
593 }
594
595 cx.propagate();
596 }
597
598 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
599 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
600 return;
601 };
602
603 if editor.read(cx).selections.count() == 1 {
604 let (selection, buffer) = editor.update(cx, |editor, cx| {
605 (
606 editor.selections.newest::<usize>(cx),
607 editor.buffer().read(cx).snapshot(cx),
608 )
609 });
610 let mut closest_assist_fallback = None;
611 for assist_id in &editor_assists.assist_ids {
612 let assist = &self.assists[assist_id];
613 let assist_range = assist.range.to_offset(&buffer);
614 if assist.decorations.is_some() {
615 if assist_range.contains(&selection.start)
616 && assist_range.contains(&selection.end)
617 {
618 self.focus_assist(*assist_id, cx);
619 return;
620 } else {
621 let distance_from_selection = assist_range
622 .start
623 .abs_diff(selection.start)
624 .min(assist_range.start.abs_diff(selection.end))
625 + assist_range
626 .end
627 .abs_diff(selection.start)
628 .min(assist_range.end.abs_diff(selection.end));
629 match closest_assist_fallback {
630 Some((_, old_distance)) => {
631 if distance_from_selection < old_distance {
632 closest_assist_fallback =
633 Some((assist_id, distance_from_selection));
634 }
635 }
636 None => {
637 closest_assist_fallback = Some((assist_id, distance_from_selection))
638 }
639 }
640 }
641 }
642 }
643
644 if let Some((&assist_id, _)) = closest_assist_fallback {
645 self.focus_assist(assist_id, cx);
646 }
647 }
648
649 cx.propagate();
650 }
651
652 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
653 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
654 for assist_id in editor_assists.assist_ids.clone() {
655 self.finish_assist(assist_id, true, cx);
656 }
657 }
658 }
659
660 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
661 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
662 return;
663 };
664 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
665 return;
666 };
667 let assist = &self.assists[&scroll_lock.assist_id];
668 let Some(decorations) = assist.decorations.as_ref() else {
669 return;
670 };
671
672 editor.update(cx, |editor, cx| {
673 let scroll_position = editor.scroll_position(cx);
674 let target_scroll_top = editor
675 .row_for_block(decorations.prompt_block_id, cx)
676 .unwrap()
677 .0 as f32
678 - scroll_lock.distance_from_top;
679 if target_scroll_top != scroll_position.y {
680 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
681 }
682 });
683 }
684
685 fn handle_editor_event(
686 &mut self,
687 editor: View<Editor>,
688 event: &EditorEvent,
689 cx: &mut WindowContext,
690 ) {
691 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
692 return;
693 };
694
695 match event {
696 EditorEvent::Edited { transaction_id } => {
697 let buffer = editor.read(cx).buffer().read(cx);
698 let edited_ranges =
699 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
700 let snapshot = buffer.snapshot(cx);
701
702 for assist_id in editor_assists.assist_ids.clone() {
703 let assist = &self.assists[&assist_id];
704 if matches!(
705 assist.codegen.read(cx).status(cx),
706 CodegenStatus::Error(_) | CodegenStatus::Done
707 ) {
708 let assist_range = assist.range.to_offset(&snapshot);
709 if edited_ranges
710 .iter()
711 .any(|range| range.overlaps(&assist_range))
712 {
713 self.finish_assist(assist_id, false, cx);
714 }
715 }
716 }
717 }
718 EditorEvent::ScrollPositionChanged { .. } => {
719 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
720 let assist = &self.assists[&scroll_lock.assist_id];
721 if let Some(decorations) = assist.decorations.as_ref() {
722 let distance_from_top = editor.update(cx, |editor, cx| {
723 let scroll_top = editor.scroll_position(cx).y;
724 let prompt_row = editor
725 .row_for_block(decorations.prompt_block_id, cx)
726 .unwrap()
727 .0 as f32;
728 prompt_row - scroll_top
729 });
730
731 if distance_from_top != scroll_lock.distance_from_top {
732 editor_assists.scroll_lock = None;
733 }
734 }
735 }
736 }
737 EditorEvent::SelectionsChanged { .. } => {
738 for assist_id in editor_assists.assist_ids.clone() {
739 let assist = &self.assists[&assist_id];
740 if let Some(decorations) = assist.decorations.as_ref() {
741 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
742 return;
743 }
744 }
745 }
746
747 editor_assists.scroll_lock = None;
748 }
749 _ => {}
750 }
751 }
752
    /// Ends an inline assist, rejecting (`undo == true`) or accepting (`undo == false`)
    /// its generated changes.
    ///
    /// If the assist belongs to a group that is still linked, the group is unlinked and
    /// every member is finished with the same `undo` flag instead. Otherwise:
    /// 1. the assist's editor decorations are dismissed;
    /// 2. the assist is removed from `assists`, from its group, and from its editor's
    ///    entry (each of those entries is dropped once it has no assist ids left);
    /// 3. an `Accepted`/`Rejected` assistant event is reported when an active model
    ///    exists;
    /// 4. the codegen is undone (`undo`), or its active alternative is recorded in
    ///    `confirmed_assists`.
    pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
        if let Some(assist) = self.assists.get(&assist_id) {
            let assist_group_id = assist.group_id;
            // Linked assists are finished together: unlink once, then recurse per member
            // (the recursive calls see an unlinked group and take the path below).
            if self.assist_groups[&assist_group_id].linked {
                for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                    self.finish_assist(assist_id, undo, cx);
                }
                return;
            }
        }

        self.dismiss_assist(assist_id, cx);

        if let Some(assist) = self.assists.remove(&assist_id) {
            if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    entry.remove();
                }
            }

            if let hash_map::Entry::Occupied(mut entry) =
                self.assists_by_editor.entry(assist.editor.clone())
            {
                entry.get_mut().assist_ids.retain(|id| *id != assist_id);
                if entry.get().assist_ids.is_empty() {
                    // Last assist in this editor: drop its entry, then recompute highlights
                    // directly (with no tracked assists, every highlight is cleared).
                    entry.remove();
                    if let Some(editor) = assist.editor.upgrade() {
                        self.update_editor_highlights(&editor, cx);
                    }
                } else {
                    // Other assists remain: signal that highlights need updating.
                    entry.get().highlight_updates.send(()).ok();
                }
            }

            let active_alternative = assist.codegen.read(cx).active_alternative().clone();
            let message_id = active_alternative.read(cx).message_id.clone();

            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
                // Language of the first buffer range covered by the assist, if any.
                let language_name = assist.editor.upgrade().and_then(|editor| {
                    let multibuffer = editor.read(cx).buffer().read(cx);
                    let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
                    ranges
                        .first()
                        .and_then(|(buffer, _, _)| buffer.read(cx).language())
                        .map(|language| language.name())
                });
                report_assistant_event(
                    AssistantEvent {
                        conversation_id: None,
                        kind: AssistantKind::Inline,
                        message_id,
                        phase: if undo {
                            AssistantPhase::Rejected
                        } else {
                            AssistantPhase::Accepted
                        },
                        model: model.telemetry_id(),
                        model_provider: model.provider_id().to_string(),
                        response_latency: None,
                        error_message: None,
                        language_name: language_name.map(|name| name.to_proto()),
                    },
                    Some(self.telemetry.clone()),
                    cx.http_client(),
                    model.api_key(cx),
                    cx.background_executor(),
                );
            }

            if undo {
                assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
            } else {
                self.confirmed_assists.insert(assist_id, active_alternative);
            }
        }
    }
831
832 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
833 let Some(assist) = self.assists.get_mut(&assist_id) else {
834 return false;
835 };
836 let Some(editor) = assist.editor.upgrade() else {
837 return false;
838 };
839 let Some(decorations) = assist.decorations.take() else {
840 return false;
841 };
842
843 editor.update(cx, |editor, cx| {
844 let mut to_remove = decorations.removed_line_block_ids;
845 to_remove.insert(decorations.prompt_block_id);
846 to_remove.insert(decorations.end_block_id);
847 editor.remove_blocks(to_remove, None, cx);
848 });
849
850 if decorations
851 .prompt_editor
852 .focus_handle(cx)
853 .contains_focused(cx)
854 {
855 self.focus_next_assist(assist_id, cx);
856 }
857
858 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
859 if editor_assists
860 .scroll_lock
861 .as_ref()
862 .map_or(false, |lock| lock.assist_id == assist_id)
863 {
864 editor_assists.scroll_lock = None;
865 }
866 editor_assists.highlight_updates.send(()).ok();
867 }
868
869 true
870 }
871
872 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
873 let Some(assist) = self.assists.get(&assist_id) else {
874 return;
875 };
876
877 let assist_group = &self.assist_groups[&assist.group_id];
878 let assist_ix = assist_group
879 .assist_ids
880 .iter()
881 .position(|id| *id == assist_id)
882 .unwrap();
883 let assist_ids = assist_group
884 .assist_ids
885 .iter()
886 .skip(assist_ix + 1)
887 .chain(assist_group.assist_ids.iter().take(assist_ix));
888
889 for assist_id in assist_ids {
890 let assist = &self.assists[assist_id];
891 if assist.decorations.is_some() {
892 self.focus_assist(*assist_id, cx);
893 return;
894 }
895 }
896
897 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
898 }
899
900 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
901 let Some(assist) = self.assists.get(&assist_id) else {
902 return;
903 };
904
905 if let Some(decorations) = assist.decorations.as_ref() {
906 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
907 prompt_editor.editor.update(cx, |editor, cx| {
908 editor.focus(cx);
909 editor.select_all(&SelectAll, cx);
910 })
911 });
912 }
913
914 self.scroll_to_assist(assist_id, cx);
915 }
916
917 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
918 let Some(assist) = self.assists.get(&assist_id) else {
919 return;
920 };
921 let Some(editor) = assist.editor.upgrade() else {
922 return;
923 };
924
925 let position = assist.range.start;
926 editor.update(cx, |editor, cx| {
927 editor.change_selections(None, cx, |selections| {
928 selections.select_anchor_ranges([position..position])
929 });
930
931 let mut scroll_target_top;
932 let mut scroll_target_bottom;
933 if let Some(decorations) = assist.decorations.as_ref() {
934 scroll_target_top = editor
935 .row_for_block(decorations.prompt_block_id, cx)
936 .unwrap()
937 .0 as f32;
938 scroll_target_bottom = editor
939 .row_for_block(decorations.end_block_id, cx)
940 .unwrap()
941 .0 as f32;
942 } else {
943 let snapshot = editor.snapshot(cx);
944 let start_row = assist
945 .range
946 .start
947 .to_display_point(&snapshot.display_snapshot)
948 .row();
949 scroll_target_top = start_row.0 as f32;
950 scroll_target_bottom = scroll_target_top + 1.;
951 }
952 scroll_target_top -= editor.vertical_scroll_margin() as f32;
953 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
954
955 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
956 let scroll_top = editor.scroll_position(cx).y;
957 let scroll_bottom = scroll_top + height_in_lines;
958
959 if scroll_target_top < scroll_top {
960 editor.set_scroll_position(point(0., scroll_target_top), cx);
961 } else if scroll_target_bottom > scroll_bottom {
962 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
963 editor
964 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
965 } else {
966 editor.set_scroll_position(point(0., scroll_target_top), cx);
967 }
968 }
969 });
970 }
971
972 fn unlink_assist_group(
973 &mut self,
974 assist_group_id: InlineAssistGroupId,
975 cx: &mut WindowContext,
976 ) -> Vec<InlineAssistId> {
977 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
978 assist_group.linked = false;
979 for assist_id in &assist_group.assist_ids {
980 let assist = self.assists.get_mut(assist_id).unwrap();
981 if let Some(editor_decorations) = assist.decorations.as_ref() {
982 editor_decorations
983 .prompt_editor
984 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
985 }
986 }
987 assist_group.assist_ids.clone()
988 }
989
990 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
991 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
992 assist
993 } else {
994 return;
995 };
996
997 let assist_group_id = assist.group_id;
998 if self.assist_groups[&assist_group_id].linked {
999 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1000 self.start_assist(assist_id, cx);
1001 }
1002 return;
1003 }
1004
1005 let Some(user_prompt) = assist.user_prompt(cx) else {
1006 return;
1007 };
1008
1009 self.prompt_history.retain(|prompt| *prompt != user_prompt);
1010 self.prompt_history.push_back(user_prompt.clone());
1011 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1012 self.prompt_history.pop_front();
1013 }
1014
1015 let assistant_panel_context = assist.assistant_panel_context(cx);
1016
1017 assist
1018 .codegen
1019 .update(cx, |codegen, cx| {
1020 codegen.start(user_prompt, assistant_panel_context, cx)
1021 })
1022 .log_err();
1023 }
1024
1025 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1026 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1027 assist
1028 } else {
1029 return;
1030 };
1031
1032 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1033 }
1034
    /// Recomputes and applies every inline-assist highlight in `editor`.
    ///
    /// For each assist tracked for the editor, this collects:
    /// - faded-out text (`fade_out: 0.6`) over the codegen's `last_equal_ranges`;
    /// - a "pending" gutter range (theme `info_background`) from the codegen's edit
    ///   position, or the assist start when there is none, to the assist end;
    /// - a "transformed" gutter range (theme `info`) from the assist start to the
    ///   edit position;
    /// - inserted-row highlights (theme `info_background`) from the codegen diff, only
    ///   while the assist still has decorations.
    ///
    /// Empty gutter ranges are skipped. Text and gutter ranges go through
    /// `merge_ranges`. Each highlight kind is cleared when it has no ranges, and row
    /// highlights are always cleared before being re-added.
    fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
        let mut gutter_pending_ranges = Vec::new();
        let mut gutter_transformed_ranges = Vec::new();
        let mut foreground_ranges = Vec::new();
        let mut inserted_row_ranges = Vec::new();
        // An editor without tracked assists uses an empty list, so every highlight
        // below ends up cleared.
        let empty_assist_ids = Vec::new();
        let assist_ids = self
            .assists_by_editor
            .get(&editor.downgrade())
            .map_or(&empty_assist_ids, |editor_assists| {
                &editor_assists.assist_ids
            });

        for assist_id in assist_ids {
            if let Some(assist) = self.assists.get(assist_id) {
                let codegen = assist.codegen.read(cx);
                let buffer = codegen.buffer(cx).read(cx).read(cx);
                foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());

                let pending_range =
                    codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
                if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
                    gutter_pending_ranges.push(pending_range);
                }

                if let Some(edit_position) = codegen.edit_position(cx) {
                    let edited_range = assist.range.start..edit_position;
                    if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
                        gutter_transformed_ranges.push(edited_range);
                    }
                }

                if assist.decorations.is_some() {
                    inserted_row_ranges
                        .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
                }
            }
        }

        let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
        merge_ranges(&mut foreground_ranges, &snapshot);
        merge_ranges(&mut gutter_pending_ranges, &snapshot);
        merge_ranges(&mut gutter_transformed_ranges, &snapshot);
        editor.update(cx, |editor, cx| {
            // Marker type keying the pending gutter highlights.
            enum GutterPendingRange {}
            if gutter_pending_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterPendingRange>(cx);
            } else {
                editor.highlight_gutter::<GutterPendingRange>(
                    &gutter_pending_ranges,
                    |cx| cx.theme().status().info_background,
                    cx,
                )
            }

            // Marker type keying the transformed gutter highlights.
            enum GutterTransformedRange {}
            if gutter_transformed_ranges.is_empty() {
                editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
            } else {
                editor.highlight_gutter::<GutterTransformedRange>(
                    &gutter_transformed_ranges,
                    |cx| cx.theme().status().info,
                    cx,
                )
            }

            if foreground_ranges.is_empty() {
                editor.clear_highlights::<InlineAssist>(cx);
            } else {
                editor.highlight_text::<InlineAssist>(
                    foreground_ranges,
                    HighlightStyle {
                        fade_out: Some(0.6),
                        ..Default::default()
                    },
                    cx,
                );
            }

            editor.clear_row_highlights::<InlineAssist>();
            for row_range in inserted_row_ranges {
                editor.highlight_rows::<InlineAssist>(
                    row_range,
                    cx.theme().status().info_background,
                    false,
                    cx,
                );
            }
        });
    }
1125
    /// Rebuilds the blocks that display the lines deleted by an assist's current diff.
    ///
    /// Does nothing when `assist_id` is unknown or the assist has no decorations. Otherwise:
    /// * every block previously recorded in `decorations.removed_line_block_ids` is removed;
    /// * one block is inserted for each entry of `codegen.diff(cx).deleted_row_ranges`.
    ///
    /// Each new block is placed above the corresponding new row. It hosts a read-only,
    /// gutter-less editor showing the deleted rows from the codegen's old buffer, highlighted
    /// with the theme's `deleted_background`. The ids of the new blocks are stored back into
    /// `removed_line_block_ids` for the next rebuild.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Snapshot everything needed from the codegen up front; the codegen borrow of `cx`
        // must end before `editor.update` below takes `cx` mutably.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks from the previous rebuild before inserting fresh ones.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // `old_row_range` is an inclusive row range in `old_snapshot`. It is converted
                // into a buffer offset range from the start of its first row to the end of its
                // last row.
                // NOTE(review): the unwraps assume these rows always resolve to a buffer in
                // `old_snapshot` — confirm this invariant with the codegen diff producer.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A small read-only editor over just the deleted excerpt of the old buffer.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    // Tint every row of the excerpt as deleted.
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        false,
                        cx,
                    );
                    editor
                });

                // The block is exactly as tall as the number of rows in the deleted excerpt.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height,
                    style: BlockStyle::Flex,
                    render: Box::new(move |cx| {
                        // Left padding equals the host editor's full gutter width.
                        div()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                });
            }

            // Remember the new block ids so the next rebuild can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1217}
1218
/// Per-editor bookkeeping for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Ids of the inline assists attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one of this editor's assists.
    // NOTE(review): how the lock is applied lives outside this chunk — confirm with its users.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here asks the `_update_highlights` task to recompute this editor's
    /// highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Loop that calls `InlineAssistant::update_editor_highlights` each time
    /// `highlight_updates` fires; it ends with an error once the editor has been dropped.
    _update_highlights: Task<Result<()>>,
    /// Editor subscriptions and action registrations. They are held here rather than read
    /// (presumably dropping a `Subscription` cancels it).
    _subscriptions: Vec<gpui::Subscription>,
}
1226
/// Scroll-lock state associated with one inline assist.
// NOTE(review): the code that applies this lock is outside this chunk — confirm its semantics.
struct InlineAssistScrollLock {
    /// The assist this lock belongs to.
    assist_id: InlineAssistId,
    /// Recorded distance from the top of the viewport. The unit is not visible here
    /// (presumably rows/lines) — confirm.
    distance_from_top: f32,
}
1231
1232impl EditorInlineAssists {
1233 #[allow(clippy::too_many_arguments)]
1234 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1235 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1236 Self {
1237 assist_ids: Vec::new(),
1238 scroll_lock: None,
1239 highlight_updates: highlight_updates_tx,
1240 _update_highlights: cx.spawn(|mut cx| {
1241 let editor = editor.downgrade();
1242 async move {
1243 while let Ok(()) = highlight_updates_rx.changed().await {
1244 let editor = editor.upgrade().context("editor was dropped")?;
1245 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1246 assistant.update_editor_highlights(&editor, cx);
1247 })?;
1248 }
1249 Ok(())
1250 }
1251 }),
1252 _subscriptions: vec![
1253 cx.observe_release(editor, {
1254 let editor = editor.downgrade();
1255 |_, cx| {
1256 InlineAssistant::update_global(cx, |this, cx| {
1257 this.handle_editor_release(editor, cx);
1258 })
1259 }
1260 }),
1261 cx.observe(editor, move |editor, cx| {
1262 InlineAssistant::update_global(cx, |this, cx| {
1263 this.handle_editor_change(editor, cx)
1264 })
1265 }),
1266 cx.subscribe(editor, move |editor, event, cx| {
1267 InlineAssistant::update_global(cx, |this, cx| {
1268 this.handle_editor_event(editor, event, cx)
1269 })
1270 }),
1271 editor.update(cx, |editor, cx| {
1272 let editor_handle = cx.view().downgrade();
1273 editor.register_action(
1274 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1275 InlineAssistant::update_global(cx, |this, cx| {
1276 if let Some(editor) = editor_handle.upgrade() {
1277 this.handle_editor_newline(editor, cx)
1278 }
1279 })
1280 },
1281 )
1282 }),
1283 editor.update(cx, |editor, cx| {
1284 let editor_handle = cx.view().downgrade();
1285 editor.register_action(
1286 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1287 InlineAssistant::update_global(cx, |this, cx| {
1288 if let Some(editor) = editor_handle.upgrade() {
1289 this.handle_editor_cancel(editor, cx)
1290 }
1291 })
1292 },
1293 )
1294 }),
1295 ],
1296 }
1297 }
1298}
1299
/// A group of inline assists sharing one group id.
struct InlineAssistGroup {
    /// Ids of the assists in this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the group's prompt editors are linked. It starts as `true`; unlinking is
    /// presumably handled by `PromptEditor::unlink` — confirm with the callers.
    linked: bool,
    /// The assist in the group currently considered active, if any.
    active_assist_id: Option<InlineAssistId>,
}
1305
1306impl InlineAssistGroup {
1307 fn new() -> Self {
1308 Self {
1309 assist_ids: Vec::new(),
1310 linked: true,
1311 active_assist_id: None,
1312 }
1313 }
1314}
1315
1316fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1317 let editor = editor.clone();
1318 Box::new(move |cx: &mut BlockContext| {
1319 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1320 editor.clone().into_any_element()
1321 })
1322}
1323
/// Identifier of a single inline assist; `post_inc` yields ids in increasing order.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one (post-increment).
    fn post_inc(&mut self) -> InlineAssistId {
        let next = InlineAssistId(self.0 + 1);
        mem::replace(self, next)
    }
}
1334
/// Identifier of an inline assist group; `post_inc` yields ids in increasing order.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one (post-increment).
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let InlineAssistGroupId(current) = *self;
        self.0 = current + 1;
        InlineAssistGroupId(current)
    }
}
1345
/// Requests emitted by a [`PromptEditor`] from its buttons and key actions.
enum PromptEditorEvent {
    /// Start (or restart) generation: the start/restart buttons, or confirming while idle,
    /// after an error, or after editing a finished prompt.
    StartRequested,
    /// Interrupt a pending generation: the stop button, or cancelling while pending.
    StopRequested,
    /// Accept a finished, unedited result: the confirm button or confirming while done.
    ConfirmRequested,
    /// Cancel the assist: the cancel buttons, or cancelling while not pending.
    CancelRequested,
    /// Emitted when confirming while generation is still pending.
    DismissRequested,
}
1353
/// The prompt row shown in a block above an inline assist.
struct PromptEditor {
    /// The inline assist this prompt belongs to; used to look it up for token counting.
    id: InlineAssistId,
    /// File system handle passed to the model selector.
    fs: Arc<dyn Fs>,
    /// The editor the prompt text is typed into; replaced by `unlink`.
    editor: View<Editor>,
    /// Set when the prompt is edited and cleared when codegen finishes or fails. It decides
    /// whether a finished assist offers "restart" or "confirm".
    edited_since_done: bool,
    /// Host editor gutter dimensions, written by the block renderer and read when rendering.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts, browsed with move up/down.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` currently shown, or `None` when not browsing.
    prompt_history_ix: Option<usize>,
    /// The text typed before browsing history, restored when moving past the last entry.
    pending_prompt: String,
    /// The code generator driven by this prompt.
    codegen: Model<Codegen>,
    /// Observation of `codegen`, syncing read-only state and notices with its status.
    _codegen_subscription: Subscription,
    /// Subscriptions to `editor`'s events, rebuilt whenever `editor` is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// The delayed token-count computation, if one is running.
    pending_token_count: Task<Result<()>>,
    /// The most recent token counts, `None` until the first count completes.
    token_counts: Option<TokenCounts>,
    /// Subscriptions (parent editor and, optionally, the assistant panel) that trigger a recount.
    _token_count_subscriptions: Vec<Subscription>,
    /// Workspace used to focus the assistant panel from the token count, if available.
    workspace: Option<WeakView<Workspace>>,
    /// Whether the rate-limit notice popover is currently shown.
    show_rate_limit_notice: bool,
}
1372
/// Token counts for an inline assist request, displayed next to the prompt.
#[derive(Copy, Clone)]
pub struct TokenCounts {
    /// Total tokens of the request.
    total: usize,
    /// The portion of `total` contributed by the assistant panel's context.
    assistant_panel: usize,
}
1378
// `PromptEditor` emits `PromptEditorEvent`s from its buttons and key actions.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1380
impl Render for PromptEditor {
    /// Renders the prompt row.
    ///
    /// Layout, left to right:
    /// * a column as wide as the host editor's gutter, holding the model selector and (when
    ///   codegen failed) an error indicator;
    /// * the prompt editor itself;
    /// * a trailing group with the token count and status-dependent buttons.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        // Copy the dimensions out; the lock guard is a temporary released right away.
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let codegen = self.codegen.read(cx);

        let mut buttons = Vec::new();
        // Cycling controls are only shown when there is more than one alternative.
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        let status = codegen.status(cx);
        buttons.extend(match status {
            // Not started: offer cancel and start.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Generating: offer cancel and interrupt (the interrupt keeps changes so far).
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
                        .into_any_element(),
                ]
            }
            // Finished or failed: offer cancel and a second button. The second button is
            // restart when the prompt changed or the run failed, and confirm otherwise.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                            .into_any_element()
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    },
                ]
            }
        });

        h_flex()
            .key_context("PromptEditor")
            .bg(cx.theme().colors().editor_background)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .capture_action(cx.listener(Self::cycle_prev))
            .capture_action(cx.listener(Self::cycle_next))
            .child(
                // Leading column sized from the host editor's gutter.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        ModelSelector::new(
                            self.fs.clone(),
                            IconButton::new("context", IconName::SettingsAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .with_info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // An error indicator is only added when the last run failed.
                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // When the rate limit was exceeded and the `ZedPro` flag is on, show an
                        // icon that toggles the rate-limit notice popover. Otherwise show an
                        // error icon whose tooltip carries the error message.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_prompt_editor(cx)))
            .child(
                h_flex()
                    .gap_2()
                    .pr_6()
                    .children(self.render_token_count(cx))
                    .children(buttons),
            )
    }
}
1576
1577impl FocusableView for PromptEditor {
1578 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1579 self.editor.focus_handle(cx)
1580 }
1581}
1582
1583impl PromptEditor {
1584 const MAX_LINES: u8 = 8;
1585
1586 #[allow(clippy::too_many_arguments)]
1587 fn new(
1588 id: InlineAssistId,
1589 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1590 prompt_history: VecDeque<String>,
1591 prompt_buffer: Model<MultiBuffer>,
1592 codegen: Model<Codegen>,
1593 parent_editor: &View<Editor>,
1594 assistant_panel: Option<&View<AssistantPanel>>,
1595 workspace: Option<WeakView<Workspace>>,
1596 fs: Arc<dyn Fs>,
1597 cx: &mut ViewContext<Self>,
1598 ) -> Self {
1599 let prompt_editor = cx.new_view(|cx| {
1600 let mut editor = Editor::new(
1601 EditorMode::AutoHeight {
1602 max_lines: Self::MAX_LINES as usize,
1603 },
1604 prompt_buffer,
1605 None,
1606 false,
1607 cx,
1608 );
1609 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1610 // Since the prompt editors for all inline assistants are linked,
1611 // always show the cursor (even when it isn't focused) because
1612 // typing in one will make what you typed appear in all of them.
1613 editor.set_show_cursor_when_unfocused(true, cx);
1614 editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), cx), cx);
1615 editor
1616 });
1617
1618 let mut token_count_subscriptions = Vec::new();
1619 token_count_subscriptions
1620 .push(cx.subscribe(parent_editor, Self::handle_parent_editor_event));
1621 if let Some(assistant_panel) = assistant_panel {
1622 token_count_subscriptions
1623 .push(cx.subscribe(assistant_panel, Self::handle_assistant_panel_event));
1624 }
1625
1626 let mut this = Self {
1627 id,
1628 editor: prompt_editor,
1629 edited_since_done: false,
1630 gutter_dimensions,
1631 prompt_history,
1632 prompt_history_ix: None,
1633 pending_prompt: String::new(),
1634 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1635 editor_subscriptions: Vec::new(),
1636 codegen,
1637 fs,
1638 pending_token_count: Task::ready(Ok(())),
1639 token_counts: None,
1640 _token_count_subscriptions: token_count_subscriptions,
1641 workspace,
1642 show_rate_limit_notice: false,
1643 };
1644 this.count_tokens(cx);
1645 this.subscribe_to_editor(cx);
1646 this
1647 }
1648
1649 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1650 self.editor_subscriptions.clear();
1651 self.editor_subscriptions
1652 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1653 }
1654
1655 fn set_show_cursor_when_unfocused(
1656 &mut self,
1657 show_cursor_when_unfocused: bool,
1658 cx: &mut ViewContext<Self>,
1659 ) {
1660 self.editor.update(cx, |editor, cx| {
1661 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1662 });
1663 }
1664
1665 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1666 let prompt = self.prompt(cx);
1667 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1668 self.editor = cx.new_view(|cx| {
1669 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1670 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1671 editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx), cx), cx);
1672 editor.set_placeholder_text("Add a prompt…", cx);
1673 editor.set_text(prompt, cx);
1674 if focus {
1675 editor.focus(cx);
1676 }
1677 editor
1678 });
1679 self.subscribe_to_editor(cx);
1680 }
1681
1682 fn placeholder_text(codegen: &Codegen, cx: &WindowContext) -> String {
1683 let context_keybinding = text_for_action(&crate::ToggleFocus, cx)
1684 .map(|keybinding| format!(" • {keybinding} for context"))
1685 .unwrap_or_default();
1686
1687 let action = if codegen.is_insertion {
1688 "Generate"
1689 } else {
1690 "Transform"
1691 };
1692
1693 format!("{action}…{context_keybinding} • ↓↑ for history")
1694 }
1695
1696 fn prompt(&self, cx: &AppContext) -> String {
1697 self.editor.read(cx).text(cx)
1698 }
1699
1700 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1701 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1702 if self.show_rate_limit_notice {
1703 cx.focus_view(&self.editor);
1704 }
1705 cx.notify();
1706 }
1707
1708 fn handle_parent_editor_event(
1709 &mut self,
1710 _: View<Editor>,
1711 event: &EditorEvent,
1712 cx: &mut ViewContext<Self>,
1713 ) {
1714 if let EditorEvent::BufferEdited { .. } = event {
1715 self.count_tokens(cx);
1716 }
1717 }
1718
1719 fn handle_assistant_panel_event(
1720 &mut self,
1721 _: View<AssistantPanel>,
1722 event: &AssistantPanelEvent,
1723 cx: &mut ViewContext<Self>,
1724 ) {
1725 let AssistantPanelEvent::ContextEdited { .. } = event;
1726 self.count_tokens(cx);
1727 }
1728
1729 fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
1730 let assist_id = self.id;
1731 self.pending_token_count = cx.spawn(|this, mut cx| async move {
1732 cx.background_executor().timer(Duration::from_secs(1)).await;
1733 let token_count = cx
1734 .update_global(|inline_assistant: &mut InlineAssistant, cx| {
1735 let assist = inline_assistant
1736 .assists
1737 .get(&assist_id)
1738 .context("assist not found")?;
1739 anyhow::Ok(assist.count_tokens(cx))
1740 })??
1741 .await?;
1742
1743 this.update(&mut cx, |this, cx| {
1744 this.token_counts = Some(token_count);
1745 cx.notify();
1746 })
1747 })
1748 }
1749
1750 fn handle_prompt_editor_events(
1751 &mut self,
1752 _: View<Editor>,
1753 event: &EditorEvent,
1754 cx: &mut ViewContext<Self>,
1755 ) {
1756 match event {
1757 EditorEvent::Edited { .. } => {
1758 if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1759 workspace
1760 .update(cx, |workspace, cx| {
1761 let is_via_ssh = workspace
1762 .project()
1763 .update(cx, |project, _| project.is_via_ssh());
1764
1765 workspace
1766 .client()
1767 .telemetry()
1768 .log_edit_event("inline assist", is_via_ssh);
1769 })
1770 .log_err();
1771 }
1772 let prompt = self.editor.read(cx).text(cx);
1773 if self
1774 .prompt_history_ix
1775 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1776 {
1777 self.prompt_history_ix.take();
1778 self.pending_prompt = prompt;
1779 }
1780
1781 self.edited_since_done = true;
1782 cx.notify();
1783 }
1784 EditorEvent::BufferEdited => {
1785 self.count_tokens(cx);
1786 }
1787 EditorEvent::Blurred => {
1788 if self.show_rate_limit_notice {
1789 self.show_rate_limit_notice = false;
1790 cx.notify();
1791 }
1792 }
1793 _ => {}
1794 }
1795 }
1796
1797 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1798 match self.codegen.read(cx).status(cx) {
1799 CodegenStatus::Idle => {
1800 self.editor
1801 .update(cx, |editor, _| editor.set_read_only(false));
1802 }
1803 CodegenStatus::Pending => {
1804 self.editor
1805 .update(cx, |editor, _| editor.set_read_only(true));
1806 }
1807 CodegenStatus::Done => {
1808 self.edited_since_done = false;
1809 self.editor
1810 .update(cx, |editor, _| editor.set_read_only(false));
1811 }
1812 CodegenStatus::Error(error) => {
1813 if cx.has_flag::<ZedPro>()
1814 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1815 && !dismissed_rate_limit_notice()
1816 {
1817 self.show_rate_limit_notice = true;
1818 cx.notify();
1819 }
1820
1821 self.edited_since_done = false;
1822 self.editor
1823 .update(cx, |editor, _| editor.set_read_only(false));
1824 }
1825 }
1826 }
1827
1828 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1829 match self.codegen.read(cx).status(cx) {
1830 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1831 cx.emit(PromptEditorEvent::CancelRequested);
1832 }
1833 CodegenStatus::Pending => {
1834 cx.emit(PromptEditorEvent::StopRequested);
1835 }
1836 }
1837 }
1838
1839 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1840 match self.codegen.read(cx).status(cx) {
1841 CodegenStatus::Idle => {
1842 cx.emit(PromptEditorEvent::StartRequested);
1843 }
1844 CodegenStatus::Pending => {
1845 cx.emit(PromptEditorEvent::DismissRequested);
1846 }
1847 CodegenStatus::Done => {
1848 if self.edited_since_done {
1849 cx.emit(PromptEditorEvent::StartRequested);
1850 } else {
1851 cx.emit(PromptEditorEvent::ConfirmRequested);
1852 }
1853 }
1854 CodegenStatus::Error(_) => {
1855 cx.emit(PromptEditorEvent::StartRequested);
1856 }
1857 }
1858 }
1859
1860 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1861 if let Some(ix) = self.prompt_history_ix {
1862 if ix > 0 {
1863 self.prompt_history_ix = Some(ix - 1);
1864 let prompt = self.prompt_history[ix - 1].as_str();
1865 self.editor.update(cx, |editor, cx| {
1866 editor.set_text(prompt, cx);
1867 editor.move_to_beginning(&Default::default(), cx);
1868 });
1869 }
1870 } else if !self.prompt_history.is_empty() {
1871 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1872 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1873 self.editor.update(cx, |editor, cx| {
1874 editor.set_text(prompt, cx);
1875 editor.move_to_beginning(&Default::default(), cx);
1876 });
1877 }
1878 }
1879
1880 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1881 if let Some(ix) = self.prompt_history_ix {
1882 if ix < self.prompt_history.len() - 1 {
1883 self.prompt_history_ix = Some(ix + 1);
1884 let prompt = self.prompt_history[ix + 1].as_str();
1885 self.editor.update(cx, |editor, cx| {
1886 editor.set_text(prompt, cx);
1887 editor.move_to_end(&Default::default(), cx)
1888 });
1889 } else {
1890 self.prompt_history_ix = None;
1891 let prompt = self.pending_prompt.as_str();
1892 self.editor.update(cx, |editor, cx| {
1893 editor.set_text(prompt, cx);
1894 editor.move_to_end(&Default::default(), cx)
1895 });
1896 }
1897 }
1898 }
1899
1900 fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1901 self.codegen
1902 .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1903 }
1904
1905 fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1906 self.codegen
1907 .update(cx, |codegen, cx| codegen.cycle_next(cx));
1908 }
1909
1910 fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1911 let codegen = self.codegen.read(cx);
1912 let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1913
1914 h_flex()
1915 .child(
1916 IconButton::new("previous", IconName::ChevronLeft)
1917 .icon_color(Color::Muted)
1918 .disabled(disabled)
1919 .shape(IconButtonShape::Square)
1920 .tooltip({
1921 let focus_handle = self.editor.focus_handle(cx);
1922 move |cx| {
1923 Tooltip::for_action_in(
1924 "Previous Alternative",
1925 &CyclePreviousInlineAssist,
1926 &focus_handle,
1927 cx,
1928 )
1929 }
1930 })
1931 .on_click(cx.listener(|this, _, cx| {
1932 this.codegen
1933 .update(cx, |codegen, cx| codegen.cycle_prev(cx))
1934 })),
1935 )
1936 .child(
1937 Label::new(format!(
1938 "{}/{}",
1939 codegen.active_alternative + 1,
1940 codegen.alternative_count(cx)
1941 ))
1942 .size(LabelSize::Small)
1943 .color(if disabled {
1944 Color::Disabled
1945 } else {
1946 Color::Muted
1947 }),
1948 )
1949 .child(
1950 IconButton::new("next", IconName::ChevronRight)
1951 .icon_color(Color::Muted)
1952 .disabled(disabled)
1953 .shape(IconButtonShape::Square)
1954 .tooltip({
1955 let focus_handle = self.editor.focus_handle(cx);
1956 move |cx| {
1957 Tooltip::for_action_in(
1958 "Next Alternative",
1959 &CycleNextInlineAssist,
1960 &focus_handle,
1961 cx,
1962 )
1963 }
1964 })
1965 .on_click(cx.listener(|this, _, cx| {
1966 this.codegen
1967 .update(cx, |codegen, cx| codegen.cycle_next(cx))
1968 })),
1969 )
1970 .into_any_element()
1971 }
1972
1973 fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
1974 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1975 let token_counts = self.token_counts?;
1976 let max_token_count = model.max_token_count();
1977
1978 let remaining_tokens = max_token_count as isize - token_counts.total as isize;
1979 let token_count_color = if remaining_tokens <= 0 {
1980 Color::Error
1981 } else if token_counts.total as f32 / max_token_count as f32 >= 0.8 {
1982 Color::Warning
1983 } else {
1984 Color::Muted
1985 };
1986
1987 let mut token_count = h_flex()
1988 .id("token_count")
1989 .gap_0p5()
1990 .child(
1991 Label::new(humanize_token_count(token_counts.total))
1992 .size(LabelSize::Small)
1993 .color(token_count_color),
1994 )
1995 .child(Label::new("/").size(LabelSize::Small).color(Color::Muted))
1996 .child(
1997 Label::new(humanize_token_count(max_token_count))
1998 .size(LabelSize::Small)
1999 .color(Color::Muted),
2000 );
2001 if let Some(workspace) = self.workspace.clone() {
2002 token_count = token_count
2003 .tooltip(move |cx| {
2004 Tooltip::with_meta(
2005 format!(
2006 "Tokens Used ({} from the Assistant Panel)",
2007 humanize_token_count(token_counts.assistant_panel)
2008 ),
2009 None,
2010 "Click to open the Assistant Panel",
2011 cx,
2012 )
2013 })
2014 .cursor_pointer()
2015 .on_mouse_down(gpui::MouseButton::Left, |_, cx| cx.stop_propagation())
2016 .on_click(move |_, cx| {
2017 cx.stop_propagation();
2018 workspace
2019 .update(cx, |workspace, cx| {
2020 workspace.focus_panel::<AssistantPanel>(cx)
2021 })
2022 .ok();
2023 });
2024 } else {
2025 token_count = token_count
2026 .cursor_default()
2027 .tooltip(|cx| Tooltip::text("Tokens used", cx));
2028 }
2029
2030 Some(token_count)
2031 }
2032
2033 fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2034 let settings = ThemeSettings::get_global(cx);
2035 let text_style = TextStyle {
2036 color: if self.editor.read(cx).read_only(cx) {
2037 cx.theme().colors().text_disabled
2038 } else {
2039 cx.theme().colors().text
2040 },
2041 font_family: settings.buffer_font.family.clone(),
2042 font_fallbacks: settings.buffer_font.fallbacks.clone(),
2043 font_size: settings.buffer_font_size.into(),
2044 font_weight: settings.buffer_font.weight,
2045 line_height: relative(settings.buffer_line_height.value()),
2046 ..Default::default()
2047 };
2048 EditorElement::new(
2049 &self.editor,
2050 EditorStyle {
2051 background: cx.theme().colors().editor_background,
2052 local_player: cx.theme().players().local(),
2053 text: text_style,
2054 ..Default::default()
2055 },
2056 )
2057 }
2058
2059 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2060 Popover::new().child(
2061 v_flex()
2062 .occlude()
2063 .p_2()
2064 .child(
2065 Label::new("Out of Tokens")
2066 .size(LabelSize::Small)
2067 .weight(FontWeight::BOLD),
2068 )
2069 .child(Label::new(
2070 "Try Zed Pro for higher limits, a wider range of models, and more.",
2071 ))
2072 .child(
2073 h_flex()
2074 .justify_between()
2075 .child(CheckboxWithLabel::new(
2076 "dont-show-again",
2077 Label::new("Don't show again"),
2078 if dismissed_rate_limit_notice() {
2079 ui::Selection::Selected
2080 } else {
2081 ui::Selection::Unselected
2082 },
2083 |selection, cx| {
2084 let is_dismissed = match selection {
2085 ui::Selection::Unselected => false,
2086 ui::Selection::Indeterminate => return,
2087 ui::Selection::Selected => true,
2088 };
2089
2090 set_rate_limit_notice_dismissed(is_dismissed, cx)
2091 },
2092 ))
2093 .child(
2094 h_flex()
2095 .gap_2()
2096 .child(
2097 Button::new("dismiss", "Dismiss")
2098 .style(ButtonStyle::Transparent)
2099 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2100 )
2101 .child(Button::new("more-info", "More Info").on_click(
2102 |_event, cx| {
2103 cx.dispatch_action(Box::new(
2104 zed_actions::OpenAccountSettings,
2105 ))
2106 },
2107 )),
2108 ),
2109 ),
2110 )
2111 }
2112}
2113
/// Key-value store key whose presence records that the user chose "Don't show again" for the
/// rate-limit notice.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2115
2116fn dismissed_rate_limit_notice() -> bool {
2117 db::kvp::KEY_VALUE_STORE
2118 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2119 .log_err()
2120 .map_or(false, |s| s.is_some())
2121}
2122
2123fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2124 db::write_and_log(cx, move || async move {
2125 if is_dismissed {
2126 db::kvp::KEY_VALUE_STORE
2127 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2128 .await
2129 } else {
2130 db::kvp::KEY_VALUE_STORE
2131 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2132 .await
2133 }
2134 })
2135}
2136
/// State of a single inline assist.
struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// Anchor range in the host editor's buffer that the assist works on; used for its
    /// gutter highlights.
    range: Range<Anchor>,
    /// The editor hosting the assist (weak, so the assist does not keep it alive).
    editor: WeakView<Editor>,
    /// Prompt block, end block, and removed-line blocks shown in the editor. When this is
    /// `None`, codegen errors are reported through `workspace` instead.
    decorations: Option<InlineAssistDecorations>,
    /// The code generator producing the assist's edits.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions, held so they stay active.
    _subscriptions: Vec<Subscription>,
    /// Workspace used to surface errors when there are no decorations.
    workspace: Option<WeakView<Workspace>>,
    /// Presumably whether assistant panel context is included in the request — confirm
    /// against its use.
    include_context: bool,
}
2147
2148impl InlineAssist {
2149 #[allow(clippy::too_many_arguments)]
2150 fn new(
2151 assist_id: InlineAssistId,
2152 group_id: InlineAssistGroupId,
2153 include_context: bool,
2154 editor: &View<Editor>,
2155 prompt_editor: &View<PromptEditor>,
2156 prompt_block_id: CustomBlockId,
2157 end_block_id: CustomBlockId,
2158 range: Range<Anchor>,
2159 codegen: Model<Codegen>,
2160 workspace: Option<WeakView<Workspace>>,
2161 cx: &mut WindowContext,
2162 ) -> Self {
2163 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2164 InlineAssist {
2165 group_id,
2166 include_context,
2167 editor: editor.downgrade(),
2168 decorations: Some(InlineAssistDecorations {
2169 prompt_block_id,
2170 prompt_editor: prompt_editor.clone(),
2171 removed_line_block_ids: HashSet::default(),
2172 end_block_id,
2173 }),
2174 range,
2175 codegen: codegen.clone(),
2176 workspace: workspace.clone(),
2177 _subscriptions: vec![
2178 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2179 InlineAssistant::update_global(cx, |this, cx| {
2180 this.handle_prompt_editor_focus_in(assist_id, cx)
2181 })
2182 }),
2183 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2184 InlineAssistant::update_global(cx, |this, cx| {
2185 this.handle_prompt_editor_focus_out(assist_id, cx)
2186 })
2187 }),
2188 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2189 InlineAssistant::update_global(cx, |this, cx| {
2190 this.handle_prompt_editor_event(prompt_editor, event, cx)
2191 })
2192 }),
2193 cx.observe(&codegen, {
2194 let editor = editor.downgrade();
2195 move |_, cx| {
2196 if let Some(editor) = editor.upgrade() {
2197 InlineAssistant::update_global(cx, |this, cx| {
2198 if let Some(editor_assists) =
2199 this.assists_by_editor.get(&editor.downgrade())
2200 {
2201 editor_assists.highlight_updates.send(()).ok();
2202 }
2203
2204 this.update_editor_blocks(&editor, assist_id, cx);
2205 })
2206 }
2207 }
2208 }),
2209 cx.subscribe(&codegen, move |codegen, event, cx| {
2210 InlineAssistant::update_global(cx, |this, cx| match event {
2211 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2212 CodegenEvent::Finished => {
2213 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2214 assist
2215 } else {
2216 return;
2217 };
2218
2219 if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2220 if assist.decorations.is_none() {
2221 if let Some(workspace) = assist
2222 .workspace
2223 .as_ref()
2224 .and_then(|workspace| workspace.upgrade())
2225 {
2226 let error = format!("Inline assistant error: {}", error);
2227 workspace.update(cx, |workspace, cx| {
2228 struct InlineAssistantError;
2229
2230 let id =
2231 NotificationId::composite::<InlineAssistantError>(
2232 assist_id.0,
2233 );
2234
2235 workspace.show_toast(Toast::new(id, error), cx);
2236 })
2237 }
2238 }
2239 }
2240
2241 if assist.decorations.is_none() {
2242 this.finish_assist(assist_id, false, cx);
2243 }
2244 }
2245 })
2246 }),
2247 ],
2248 }
2249 }
2250
2251 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2252 let decorations = self.decorations.as_ref()?;
2253 Some(decorations.prompt_editor.read(cx).prompt(cx))
2254 }
2255
2256 fn assistant_panel_context(&self, cx: &WindowContext) -> Option<LanguageModelRequest> {
2257 if self.include_context {
2258 let workspace = self.workspace.as_ref()?;
2259 let workspace = workspace.upgrade()?.read(cx);
2260 let assistant_panel = workspace.panel::<AssistantPanel>(cx)?;
2261 Some(
2262 assistant_panel
2263 .read(cx)
2264 .active_context(cx)?
2265 .read(cx)
2266 .to_completion_request(RequestType::Chat, cx),
2267 )
2268 } else {
2269 None
2270 }
2271 }
2272
2273 pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result<TokenCounts>> {
2274 let Some(user_prompt) = self.user_prompt(cx) else {
2275 return future::ready(Err(anyhow!("no user prompt"))).boxed();
2276 };
2277 let assistant_panel_context = self.assistant_panel_context(cx);
2278 self.codegen
2279 .read(cx)
2280 .count_tokens(user_prompt, assistant_panel_context, cx)
2281 }
2282}
2283
/// UI attached to an editor for an inline assist: the prompt editor block, blocks showing
/// removed lines, and the block marking the end of the assist.
struct InlineAssistDecorations {
    /// Block hosting the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The editor the user types the prompt into.
    prompt_editor: View<PromptEditor>,
    /// Blocks currently rendered for removed lines.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Block rendered after the assisted range.
    end_block_id: CustomBlockId,
}
2290
/// Events emitted by `CodegenAlternative` and re-emitted by `Codegen`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation finished, failed, or was stopped.
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
2296
/// Coordinates one or more `CodegenAlternative`s (the primary model plus any configured
/// inline alternative models) for the same buffer range. Only one alternative is active
/// at a time.
pub struct Codegen {
    /// All alternatives; index 0 is the primary model's.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the active one.
    active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative (replaced when it changes).
    subscriptions: Vec<Subscription>,
    buffer: Model<MultiBuffer>,
    range: Range<Anchor>,
    /// Transaction undone by `Codegen::undo` in addition to the alternative's edits.
    initial_transaction_id: Option<TransactionId>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2309
2310impl Codegen {
2311 pub fn new(
2312 buffer: Model<MultiBuffer>,
2313 range: Range<Anchor>,
2314 initial_transaction_id: Option<TransactionId>,
2315 telemetry: Arc<Telemetry>,
2316 builder: Arc<PromptBuilder>,
2317 cx: &mut ModelContext<Self>,
2318 ) -> Self {
2319 let codegen = cx.new_model(|cx| {
2320 CodegenAlternative::new(
2321 buffer.clone(),
2322 range.clone(),
2323 false,
2324 Some(telemetry.clone()),
2325 builder.clone(),
2326 cx,
2327 )
2328 });
2329 let mut this = Self {
2330 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2331 alternatives: vec![codegen],
2332 active_alternative: 0,
2333 seen_alternatives: HashSet::default(),
2334 subscriptions: Vec::new(),
2335 buffer,
2336 range,
2337 initial_transaction_id,
2338 telemetry,
2339 builder,
2340 };
2341 this.activate(0, cx);
2342 this
2343 }
2344
2345 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2346 let codegen = self.active_alternative().clone();
2347 self.subscriptions.clear();
2348 self.subscriptions
2349 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2350 self.subscriptions
2351 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2352 }
2353
2354 fn active_alternative(&self) -> &Model<CodegenAlternative> {
2355 &self.alternatives[self.active_alternative]
2356 }
2357
2358 fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2359 &self.active_alternative().read(cx).status
2360 }
2361
2362 fn alternative_count(&self, cx: &AppContext) -> usize {
2363 LanguageModelRegistry::read_global(cx)
2364 .inline_alternative_models()
2365 .len()
2366 + 1
2367 }
2368
2369 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2370 let next_active_ix = if self.active_alternative == 0 {
2371 self.alternatives.len() - 1
2372 } else {
2373 self.active_alternative - 1
2374 };
2375 self.activate(next_active_ix, cx);
2376 }
2377
2378 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2379 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2380 self.activate(next_active_ix, cx);
2381 }
2382
2383 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2384 self.active_alternative()
2385 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2386 self.seen_alternatives.insert(index);
2387 self.active_alternative = index;
2388 self.active_alternative()
2389 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2390 self.subscribe_to_alternative(cx);
2391 cx.notify();
2392 }
2393
2394 pub fn start(
2395 &mut self,
2396 user_prompt: String,
2397 assistant_panel_context: Option<LanguageModelRequest>,
2398 cx: &mut ModelContext<Self>,
2399 ) -> Result<()> {
2400 let alternative_models = LanguageModelRegistry::read_global(cx)
2401 .inline_alternative_models()
2402 .to_vec();
2403
2404 self.active_alternative()
2405 .update(cx, |alternative, cx| alternative.undo(cx));
2406 self.activate(0, cx);
2407 self.alternatives.truncate(1);
2408
2409 for _ in 0..alternative_models.len() {
2410 self.alternatives.push(cx.new_model(|cx| {
2411 CodegenAlternative::new(
2412 self.buffer.clone(),
2413 self.range.clone(),
2414 false,
2415 Some(self.telemetry.clone()),
2416 self.builder.clone(),
2417 cx,
2418 )
2419 }));
2420 }
2421
2422 let primary_model = LanguageModelRegistry::read_global(cx)
2423 .active_model()
2424 .context("no active model")?;
2425
2426 for (model, alternative) in iter::once(primary_model)
2427 .chain(alternative_models)
2428 .zip(&self.alternatives)
2429 {
2430 alternative.update(cx, |alternative, cx| {
2431 alternative.start(
2432 user_prompt.clone(),
2433 assistant_panel_context.clone(),
2434 model.clone(),
2435 cx,
2436 )
2437 })?;
2438 }
2439
2440 Ok(())
2441 }
2442
2443 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2444 for codegen in &self.alternatives {
2445 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2446 }
2447 }
2448
2449 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2450 self.active_alternative()
2451 .update(cx, |codegen, cx| codegen.undo(cx));
2452
2453 self.buffer.update(cx, |buffer, cx| {
2454 if let Some(transaction_id) = self.initial_transaction_id.take() {
2455 buffer.undo_transaction(transaction_id, cx);
2456 buffer.refresh_preview(cx);
2457 }
2458 });
2459 }
2460
2461 pub fn count_tokens(
2462 &self,
2463 user_prompt: String,
2464 assistant_panel_context: Option<LanguageModelRequest>,
2465 cx: &AppContext,
2466 ) -> BoxFuture<'static, Result<TokenCounts>> {
2467 self.active_alternative()
2468 .read(cx)
2469 .count_tokens(user_prompt, assistant_panel_context, cx)
2470 }
2471
2472 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2473 self.active_alternative().read(cx).buffer.clone()
2474 }
2475
2476 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2477 self.active_alternative().read(cx).old_buffer.clone()
2478 }
2479
2480 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2481 self.active_alternative().read(cx).snapshot.clone()
2482 }
2483
2484 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2485 self.active_alternative().read(cx).edit_position
2486 }
2487
2488 fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2489 &self.active_alternative().read(cx).diff
2490 }
2491
2492 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2493 self.active_alternative().read(cx).last_equal_ranges()
2494 }
2495}
2496
// `Codegen` re-emits the events of its active alternative (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for Codegen {}
2498
/// Generates and applies one model's rewrite of a buffer range.
pub struct CodegenAlternative {
    /// The multibuffer being edited.
    buffer: Model<MultiBuffer>,
    /// Detached copy of the original buffer (text, line ending, language) taken at creation.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot taken at creation; streamed edits are expressed against it.
    snapshot: MultiBufferSnapshot,
    /// Position just past the most recently applied edit.
    edit_position: Option<Anchor>,
    /// The range being rewritten.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the most recent diff update.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction into which all of this alternative's edits are merged.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// Task driving the current generation; replacing it drops the previous task.
    generation: Task<()>,
    /// Line-level diff between `snapshot` and the current buffer contents.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (used to detect undo of the transformation).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Edits produced by the current generation, replayed when re-activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest line-based diff operations of the current generation.
    line_operations: Vec<LineOperation>,
    /// The last request sent to the model.
    request: Option<LanguageModelRequest>,
    /// Wall-clock duration of the last generation, in seconds.
    elapsed_time: Option<f64>,
    /// Message id reported by the model's stream, if any.
    message_id: Option<String>,
}
2520
/// Lifecycle of a `CodegenAlternative`'s generation.
enum CodegenStatus {
    /// Not started, or stopped before producing any diff.
    Idle,
    /// Streaming is in progress.
    Pending,
    /// Generation completed (or was stopped after producing a diff).
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}
2527
/// Row-level diff between a `CodegenAlternative`'s original snapshot and the current
/// buffer, as computed by `reapply_line_based_diff` / `reapply_batch_diff`.
#[derive(Default)]
struct Diff {
    /// Each entry pairs an anchor in the new buffer (start of the row where the deletion
    /// occurred) with the inclusive range of deleted rows in the original snapshot.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges in the new buffer covering inserted rows.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2533
2534impl Diff {
2535 fn is_empty(&self) -> bool {
2536 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2537 }
2538}
2539
// Alternatives emit `Finished` when a generation ends and `Undone` when their edits are undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2541
2542impl CodegenAlternative {
2543 pub fn new(
2544 buffer: Model<MultiBuffer>,
2545 range: Range<Anchor>,
2546 active: bool,
2547 telemetry: Option<Arc<Telemetry>>,
2548 builder: Arc<PromptBuilder>,
2549 cx: &mut ModelContext<Self>,
2550 ) -> Self {
2551 let snapshot = buffer.read(cx).snapshot(cx);
2552
2553 let (old_buffer, _, _) = buffer
2554 .read(cx)
2555 .range_to_buffer_ranges(range.clone(), cx)
2556 .pop()
2557 .unwrap();
2558 let old_buffer = cx.new_model(|cx| {
2559 let old_buffer = old_buffer.read(cx);
2560 let text = old_buffer.as_rope().clone();
2561 let line_ending = old_buffer.line_ending();
2562 let language = old_buffer.language().cloned();
2563 let language_registry = old_buffer.language_registry();
2564
2565 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2566 buffer.set_language(language, cx);
2567 if let Some(language_registry) = language_registry {
2568 buffer.set_language_registry(language_registry)
2569 }
2570 buffer
2571 });
2572
2573 Self {
2574 buffer: buffer.clone(),
2575 old_buffer,
2576 edit_position: None,
2577 message_id: None,
2578 snapshot,
2579 last_equal_ranges: Default::default(),
2580 transformation_transaction_id: None,
2581 status: CodegenStatus::Idle,
2582 generation: Task::ready(()),
2583 diff: Diff::default(),
2584 telemetry,
2585 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2586 builder,
2587 active,
2588 edits: Vec::new(),
2589 line_operations: Vec::new(),
2590 range,
2591 request: None,
2592 elapsed_time: None,
2593 }
2594 }
2595
2596 fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2597 if active != self.active {
2598 self.active = active;
2599
2600 if self.active {
2601 let edits = self.edits.clone();
2602 self.apply_edits(edits, cx);
2603 if matches!(self.status, CodegenStatus::Pending) {
2604 let line_operations = self.line_operations.clone();
2605 self.reapply_line_based_diff(line_operations, cx);
2606 } else {
2607 self.reapply_batch_diff(cx).detach();
2608 }
2609 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2610 self.buffer.update(cx, |buffer, cx| {
2611 buffer.undo_transaction(transaction_id, cx);
2612 buffer.forget_transaction(transaction_id, cx);
2613 });
2614 }
2615 }
2616 }
2617
2618 fn handle_buffer_event(
2619 &mut self,
2620 _buffer: Model<MultiBuffer>,
2621 event: &multi_buffer::Event,
2622 cx: &mut ModelContext<Self>,
2623 ) {
2624 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2625 if self.transformation_transaction_id == Some(*transaction_id) {
2626 self.transformation_transaction_id = None;
2627 self.generation = Task::ready(());
2628 cx.emit(CodegenEvent::Undone);
2629 }
2630 }
2631 }
2632
2633 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2634 &self.last_equal_ranges
2635 }
2636
2637 pub fn count_tokens(
2638 &self,
2639 user_prompt: String,
2640 assistant_panel_context: Option<LanguageModelRequest>,
2641 cx: &AppContext,
2642 ) -> BoxFuture<'static, Result<TokenCounts>> {
2643 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
2644 let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
2645 match request {
2646 Ok(request) => {
2647 let total_count = model.count_tokens(request.clone(), cx);
2648 let assistant_panel_count = assistant_panel_context
2649 .map(|context| model.count_tokens(context, cx))
2650 .unwrap_or_else(|| future::ready(Ok(0)).boxed());
2651
2652 async move {
2653 Ok(TokenCounts {
2654 total: total_count.await?,
2655 assistant_panel: assistant_panel_count.await?,
2656 })
2657 }
2658 .boxed()
2659 }
2660 Err(error) => futures::future::ready(Err(error)).boxed(),
2661 }
2662 } else {
2663 future::ready(Err(anyhow!("no active model"))).boxed()
2664 }
2665 }
2666
2667 pub fn start(
2668 &mut self,
2669 user_prompt: String,
2670 assistant_panel_context: Option<LanguageModelRequest>,
2671 model: Arc<dyn LanguageModel>,
2672 cx: &mut ModelContext<Self>,
2673 ) -> Result<()> {
2674 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2675 self.buffer.update(cx, |buffer, cx| {
2676 buffer.undo_transaction(transformation_transaction_id, cx);
2677 });
2678 }
2679
2680 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2681
2682 let api_key = model.api_key(cx);
2683 let telemetry_id = model.telemetry_id();
2684 let provider_id = model.provider_id();
2685 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2686 if user_prompt.trim().to_lowercase() == "delete" {
2687 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2688 } else {
2689 let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
2690 self.request = Some(request.clone());
2691
2692 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2693 .boxed_local()
2694 };
2695 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2696 Ok(())
2697 }
2698
2699 fn build_request(
2700 &self,
2701 user_prompt: String,
2702 assistant_panel_context: Option<LanguageModelRequest>,
2703 cx: &AppContext,
2704 ) -> Result<LanguageModelRequest> {
2705 let buffer = self.buffer.read(cx).snapshot(cx);
2706 let language = buffer.language_at(self.range.start);
2707 let language_name = if let Some(language) = language.as_ref() {
2708 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2709 None
2710 } else {
2711 Some(language.name())
2712 }
2713 } else {
2714 None
2715 };
2716
2717 let language_name = language_name.as_ref();
2718 let start = buffer.point_to_buffer_offset(self.range.start);
2719 let end = buffer.point_to_buffer_offset(self.range.end);
2720 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2721 let (start_buffer, start_buffer_offset) = start;
2722 let (end_buffer, end_buffer_offset) = end;
2723 if start_buffer.remote_id() == end_buffer.remote_id() {
2724 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2725 } else {
2726 return Err(anyhow::anyhow!("invalid transformation range"));
2727 }
2728 } else {
2729 return Err(anyhow::anyhow!("invalid transformation range"));
2730 };
2731
2732 let prompt = self
2733 .builder
2734 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2735 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2736
2737 let mut messages = Vec::new();
2738 if let Some(context_request) = assistant_panel_context {
2739 messages = context_request.messages;
2740 }
2741
2742 messages.push(LanguageModelRequestMessage {
2743 role: Role::User,
2744 content: vec![prompt.into()],
2745 cache: false,
2746 });
2747
2748 Ok(LanguageModelRequest {
2749 messages,
2750 tools: Vec::new(),
2751 stop: Vec::new(),
2752 temperature: None,
2753 })
2754 }
2755
2756 pub fn handle_stream(
2757 &mut self,
2758 model_telemetry_id: String,
2759 model_provider_id: String,
2760 model_api_key: Option<String>,
2761 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
2762 cx: &mut ModelContext<Self>,
2763 ) {
2764 let start_time = Instant::now();
2765 let snapshot = self.snapshot.clone();
2766 let selected_text = snapshot
2767 .text_for_range(self.range.start..self.range.end)
2768 .collect::<Rope>();
2769
2770 let selection_start = self.range.start.to_point(&snapshot);
2771
2772 // Start with the indentation of the first line in the selection
2773 let mut suggested_line_indent = snapshot
2774 .suggested_indents(selection_start.row..=selection_start.row, cx)
2775 .into_values()
2776 .next()
2777 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2778
2779 // If the first line in the selection does not have indentation, check the following lines
2780 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2781 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
2782 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2783 // Prefer tabs if a line in the selection uses tabs as indentation
2784 if line_indent.kind == IndentKind::Tab {
2785 suggested_line_indent.kind = IndentKind::Tab;
2786 break;
2787 }
2788 }
2789 }
2790
2791 let http_client = cx.http_client().clone();
2792 let telemetry = self.telemetry.clone();
2793 let language_name = {
2794 let multibuffer = self.buffer.read(cx);
2795 let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
2796 ranges
2797 .first()
2798 .and_then(|(buffer, _, _)| buffer.read(cx).language())
2799 .map(|language| language.name())
2800 };
2801
2802 self.diff = Diff::default();
2803 self.status = CodegenStatus::Pending;
2804 let mut edit_start = self.range.start.to_offset(&snapshot);
2805 self.generation = cx.spawn(|codegen, mut cx| {
2806 async move {
2807 let stream = stream.await;
2808 let message_id = stream
2809 .as_ref()
2810 .ok()
2811 .and_then(|stream| stream.message_id.clone());
2812 let generate = async {
2813 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2814 let executor = cx.background_executor().clone();
2815 let message_id = message_id.clone();
2816 let line_based_stream_diff: Task<anyhow::Result<()>> =
2817 cx.background_executor().spawn(async move {
2818 let mut response_latency = None;
2819 let request_start = Instant::now();
2820 let diff = async {
2821 let chunks = StripInvalidSpans::new(stream?.stream);
2822 futures::pin_mut!(chunks);
2823 let mut diff = StreamingDiff::new(selected_text.to_string());
2824 let mut line_diff = LineDiff::default();
2825
2826 let mut new_text = String::new();
2827 let mut base_indent = None;
2828 let mut line_indent = None;
2829 let mut first_line = true;
2830
2831 while let Some(chunk) = chunks.next().await {
2832 if response_latency.is_none() {
2833 response_latency = Some(request_start.elapsed());
2834 }
2835 let chunk = chunk?;
2836
2837 let mut lines = chunk.split('\n').peekable();
2838 while let Some(line) = lines.next() {
2839 new_text.push_str(line);
2840 if line_indent.is_none() {
2841 if let Some(non_whitespace_ch_ix) =
2842 new_text.find(|ch: char| !ch.is_whitespace())
2843 {
2844 line_indent = Some(non_whitespace_ch_ix);
2845 base_indent = base_indent.or(line_indent);
2846
2847 let line_indent = line_indent.unwrap();
2848 let base_indent = base_indent.unwrap();
2849 let indent_delta =
2850 line_indent as i32 - base_indent as i32;
2851 let mut corrected_indent_len = cmp::max(
2852 0,
2853 suggested_line_indent.len as i32 + indent_delta,
2854 )
2855 as usize;
2856 if first_line {
2857 corrected_indent_len = corrected_indent_len
2858 .saturating_sub(
2859 selection_start.column as usize,
2860 );
2861 }
2862
2863 let indent_char = suggested_line_indent.char();
2864 let mut indent_buffer = [0; 4];
2865 let indent_str =
2866 indent_char.encode_utf8(&mut indent_buffer);
2867 new_text.replace_range(
2868 ..line_indent,
2869 &indent_str.repeat(corrected_indent_len),
2870 );
2871 }
2872 }
2873
2874 if line_indent.is_some() {
2875 let char_ops = diff.push_new(&new_text);
2876 line_diff
2877 .push_char_operations(&char_ops, &selected_text);
2878 diff_tx
2879 .send((char_ops, line_diff.line_operations()))
2880 .await?;
2881 new_text.clear();
2882 }
2883
2884 if lines.peek().is_some() {
2885 let char_ops = diff.push_new("\n");
2886 line_diff
2887 .push_char_operations(&char_ops, &selected_text);
2888 diff_tx
2889 .send((char_ops, line_diff.line_operations()))
2890 .await?;
2891 if line_indent.is_none() {
2892 // Don't write out the leading indentation in empty lines on the next line
2893 // This is the case where the above if statement didn't clear the buffer
2894 new_text.clear();
2895 }
2896 line_indent = None;
2897 first_line = false;
2898 }
2899 }
2900 }
2901
2902 let mut char_ops = diff.push_new(&new_text);
2903 char_ops.extend(diff.finish());
2904 line_diff.push_char_operations(&char_ops, &selected_text);
2905 line_diff.finish(&selected_text);
2906 diff_tx
2907 .send((char_ops, line_diff.line_operations()))
2908 .await?;
2909
2910 anyhow::Ok(())
2911 };
2912
2913 let result = diff.await;
2914
2915 let error_message =
2916 result.as_ref().err().map(|error| error.to_string());
2917 report_assistant_event(
2918 AssistantEvent {
2919 conversation_id: None,
2920 message_id,
2921 kind: AssistantKind::Inline,
2922 phase: AssistantPhase::Response,
2923 model: model_telemetry_id,
2924 model_provider: model_provider_id.to_string(),
2925 response_latency,
2926 error_message,
2927 language_name: language_name.map(|name| name.to_proto()),
2928 },
2929 telemetry,
2930 http_client,
2931 model_api_key,
2932 &executor,
2933 );
2934
2935 result?;
2936 Ok(())
2937 });
2938
2939 while let Some((char_ops, line_ops)) = diff_rx.next().await {
2940 codegen.update(&mut cx, |codegen, cx| {
2941 codegen.last_equal_ranges.clear();
2942
2943 let edits = char_ops
2944 .into_iter()
2945 .filter_map(|operation| match operation {
2946 CharOperation::Insert { text } => {
2947 let edit_start = snapshot.anchor_after(edit_start);
2948 Some((edit_start..edit_start, text))
2949 }
2950 CharOperation::Delete { bytes } => {
2951 let edit_end = edit_start + bytes;
2952 let edit_range = snapshot.anchor_after(edit_start)
2953 ..snapshot.anchor_before(edit_end);
2954 edit_start = edit_end;
2955 Some((edit_range, String::new()))
2956 }
2957 CharOperation::Keep { bytes } => {
2958 let edit_end = edit_start + bytes;
2959 let edit_range = snapshot.anchor_after(edit_start)
2960 ..snapshot.anchor_before(edit_end);
2961 edit_start = edit_end;
2962 codegen.last_equal_ranges.push(edit_range);
2963 None
2964 }
2965 })
2966 .collect::<Vec<_>>();
2967
2968 if codegen.active {
2969 codegen.apply_edits(edits.iter().cloned(), cx);
2970 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
2971 }
2972 codegen.edits.extend(edits);
2973 codegen.line_operations = line_ops;
2974 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
2975
2976 cx.notify();
2977 })?;
2978 }
2979
2980 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
2981 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
2982 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
2983 let batch_diff_task =
2984 codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
2985 let (line_based_stream_diff, ()) =
2986 join!(line_based_stream_diff, batch_diff_task);
2987 line_based_stream_diff?;
2988
2989 anyhow::Ok(())
2990 };
2991
2992 let result = generate.await;
2993 let elapsed_time = start_time.elapsed().as_secs_f64();
2994
2995 codegen
2996 .update(&mut cx, |this, cx| {
2997 this.message_id = message_id;
2998 this.last_equal_ranges.clear();
2999 if let Err(error) = result {
3000 this.status = CodegenStatus::Error(error);
3001 } else {
3002 this.status = CodegenStatus::Done;
3003 }
3004 this.elapsed_time = Some(elapsed_time);
3005 cx.emit(CodegenEvent::Finished);
3006 cx.notify();
3007 })
3008 .ok();
3009 }
3010 });
3011 cx.notify();
3012 }
3013
3014 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
3015 self.last_equal_ranges.clear();
3016 if self.diff.is_empty() {
3017 self.status = CodegenStatus::Idle;
3018 } else {
3019 self.status = CodegenStatus::Done;
3020 }
3021 self.generation = Task::ready(());
3022 cx.emit(CodegenEvent::Finished);
3023 cx.notify();
3024 }
3025
3026 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
3027 self.buffer.update(cx, |buffer, cx| {
3028 if let Some(transaction_id) = self.transformation_transaction_id.take() {
3029 buffer.undo_transaction(transaction_id, cx);
3030 buffer.refresh_preview(cx);
3031 }
3032 });
3033 }
3034
3035 fn apply_edits(
3036 &mut self,
3037 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
3038 cx: &mut ModelContext<CodegenAlternative>,
3039 ) {
3040 let transaction = self.buffer.update(cx, |buffer, cx| {
3041 // Avoid grouping assistant edits with user edits.
3042 buffer.finalize_last_transaction(cx);
3043 buffer.start_transaction(cx);
3044 buffer.edit(edits, None, cx);
3045 buffer.end_transaction(cx)
3046 });
3047
3048 if let Some(transaction) = transaction {
3049 if let Some(first_transaction) = self.transformation_transaction_id {
3050 // Group all assistant edits into the first transaction.
3051 self.buffer.update(cx, |buffer, cx| {
3052 buffer.merge_transactions(transaction, first_transaction, cx)
3053 });
3054 } else {
3055 self.transformation_transaction_id = Some(transaction);
3056 self.buffer
3057 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3058 }
3059 }
3060 }
3061
3062 fn reapply_line_based_diff(
3063 &mut self,
3064 line_operations: impl IntoIterator<Item = LineOperation>,
3065 cx: &mut ModelContext<Self>,
3066 ) {
3067 let old_snapshot = self.snapshot.clone();
3068 let old_range = self.range.to_point(&old_snapshot);
3069 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3070 let new_range = self.range.to_point(&new_snapshot);
3071
3072 let mut old_row = old_range.start.row;
3073 let mut new_row = new_range.start.row;
3074
3075 self.diff.deleted_row_ranges.clear();
3076 self.diff.inserted_row_ranges.clear();
3077 for operation in line_operations {
3078 match operation {
3079 LineOperation::Keep { lines } => {
3080 old_row += lines;
3081 new_row += lines;
3082 }
3083 LineOperation::Delete { lines } => {
3084 let old_end_row = old_row + lines - 1;
3085 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3086
3087 if let Some((_, last_deleted_row_range)) =
3088 self.diff.deleted_row_ranges.last_mut()
3089 {
3090 if *last_deleted_row_range.end() + 1 == old_row {
3091 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3092 } else {
3093 self.diff
3094 .deleted_row_ranges
3095 .push((new_row, old_row..=old_end_row));
3096 }
3097 } else {
3098 self.diff
3099 .deleted_row_ranges
3100 .push((new_row, old_row..=old_end_row));
3101 }
3102
3103 old_row += lines;
3104 }
3105 LineOperation::Insert { lines } => {
3106 let new_end_row = new_row + lines - 1;
3107 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3108 let end = new_snapshot.anchor_before(Point::new(
3109 new_end_row,
3110 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3111 ));
3112 self.diff.inserted_row_ranges.push(start..end);
3113 new_row += lines;
3114 }
3115 }
3116
3117 cx.notify();
3118 }
3119 }
3120
3121 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3122 let old_snapshot = self.snapshot.clone();
3123 let old_range = self.range.to_point(&old_snapshot);
3124 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3125 let new_range = self.range.to_point(&new_snapshot);
3126
3127 cx.spawn(|codegen, mut cx| async move {
3128 let (deleted_row_ranges, inserted_row_ranges) = cx
3129 .background_executor()
3130 .spawn(async move {
3131 let old_text = old_snapshot
3132 .text_for_range(
3133 Point::new(old_range.start.row, 0)
3134 ..Point::new(
3135 old_range.end.row,
3136 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3137 ),
3138 )
3139 .collect::<String>();
3140 let new_text = new_snapshot
3141 .text_for_range(
3142 Point::new(new_range.start.row, 0)
3143 ..Point::new(
3144 new_range.end.row,
3145 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3146 ),
3147 )
3148 .collect::<String>();
3149
3150 let mut old_row = old_range.start.row;
3151 let mut new_row = new_range.start.row;
3152 let batch_diff =
3153 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3154
3155 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3156 let mut inserted_row_ranges = Vec::new();
3157 for change in batch_diff.iter_all_changes() {
3158 let line_count = change.value().lines().count() as u32;
3159 match change.tag() {
3160 similar::ChangeTag::Equal => {
3161 old_row += line_count;
3162 new_row += line_count;
3163 }
3164 similar::ChangeTag::Delete => {
3165 let old_end_row = old_row + line_count - 1;
3166 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3167
3168 if let Some((_, last_deleted_row_range)) =
3169 deleted_row_ranges.last_mut()
3170 {
3171 if *last_deleted_row_range.end() + 1 == old_row {
3172 *last_deleted_row_range =
3173 *last_deleted_row_range.start()..=old_end_row;
3174 } else {
3175 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3176 }
3177 } else {
3178 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3179 }
3180
3181 old_row += line_count;
3182 }
3183 similar::ChangeTag::Insert => {
3184 let new_end_row = new_row + line_count - 1;
3185 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3186 let end = new_snapshot.anchor_before(Point::new(
3187 new_end_row,
3188 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3189 ));
3190 inserted_row_ranges.push(start..end);
3191 new_row += line_count;
3192 }
3193 }
3194 }
3195
3196 (deleted_row_ranges, inserted_row_ranges)
3197 })
3198 .await;
3199
3200 codegen
3201 .update(&mut cx, |codegen, cx| {
3202 codegen.diff.deleted_row_ranges = deleted_row_ranges;
3203 codegen.diff.inserted_row_ranges = inserted_row_ranges;
3204 cx.notify();
3205 })
3206 .ok();
3207 })
3208 }
3209}
3210
3211struct StripInvalidSpans<T> {
3212 stream: T,
3213 stream_done: bool,
3214 buffer: String,
3215 first_line: bool,
3216 line_end: bool,
3217 starts_with_code_block: bool,
3218}
3219
3220impl<T> StripInvalidSpans<T>
3221where
3222 T: Stream<Item = Result<String>>,
3223{
3224 fn new(stream: T) -> Self {
3225 Self {
3226 stream,
3227 stream_done: false,
3228 buffer: String::new(),
3229 first_line: true,
3230 line_end: false,
3231 starts_with_code_block: false,
3232 }
3233 }
3234}
3235
3236impl<T> Stream for StripInvalidSpans<T>
3237where
3238 T: Stream<Item = Result<String>>,
3239{
3240 type Item = Result<String>;
3241
3242 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3243 const CODE_BLOCK_DELIMITER: &str = "```";
3244 const CURSOR_SPAN: &str = "<|CURSOR|>";
3245
3246 let this = unsafe { self.get_unchecked_mut() };
3247 loop {
3248 if !this.stream_done {
3249 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3250 match stream.as_mut().poll_next(cx) {
3251 Poll::Ready(Some(Ok(chunk))) => {
3252 this.buffer.push_str(&chunk);
3253 }
3254 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3255 Poll::Ready(None) => {
3256 this.stream_done = true;
3257 }
3258 Poll::Pending => return Poll::Pending,
3259 }
3260 }
3261
3262 let mut chunk = String::new();
3263 let mut consumed = 0;
3264 if !this.buffer.is_empty() {
3265 let mut lines = this.buffer.split('\n').enumerate().peekable();
3266 while let Some((line_ix, line)) = lines.next() {
3267 if line_ix > 0 {
3268 this.first_line = false;
3269 }
3270
3271 if this.first_line {
3272 let trimmed_line = line.trim();
3273 if lines.peek().is_some() {
3274 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3275 consumed += line.len() + 1;
3276 this.starts_with_code_block = true;
3277 continue;
3278 }
3279 } else if trimmed_line.is_empty()
3280 || prefixes(CODE_BLOCK_DELIMITER)
3281 .any(|prefix| trimmed_line.starts_with(prefix))
3282 {
3283 break;
3284 }
3285 }
3286
3287 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3288 if lines.peek().is_some() {
3289 if this.line_end {
3290 chunk.push('\n');
3291 }
3292
3293 chunk.push_str(&line_without_cursor);
3294 this.line_end = true;
3295 consumed += line.len() + 1;
3296 } else if this.stream_done {
3297 if !this.starts_with_code_block
3298 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3299 {
3300 if this.line_end {
3301 chunk.push('\n');
3302 }
3303
3304 chunk.push_str(&line);
3305 }
3306
3307 consumed += line.len();
3308 } else {
3309 let trimmed_line = line.trim();
3310 if trimmed_line.is_empty()
3311 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3312 || prefixes(CODE_BLOCK_DELIMITER)
3313 .any(|prefix| trimmed_line.ends_with(prefix))
3314 {
3315 break;
3316 } else {
3317 if this.line_end {
3318 chunk.push('\n');
3319 this.line_end = false;
3320 }
3321
3322 chunk.push_str(&line_without_cursor);
3323 consumed += line.len();
3324 }
3325 }
3326 }
3327 }
3328
3329 this.buffer = this.buffer.split_off(consumed);
3330 if !chunk.is_empty() {
3331 return Poll::Ready(Some(Ok(chunk)));
3332 } else if this.stream_done {
3333 return Poll::Ready(None);
3334 }
3335 }
3336 }
3337}
3338
3339struct AssistantCodeActionProvider {
3340 editor: WeakView<Editor>,
3341 workspace: WeakView<Workspace>,
3342}
3343
3344impl CodeActionProvider for AssistantCodeActionProvider {
3345 fn code_actions(
3346 &self,
3347 buffer: &Model<Buffer>,
3348 range: Range<text::Anchor>,
3349 cx: &mut WindowContext,
3350 ) -> Task<Result<Vec<CodeAction>>> {
3351 if !AssistantSettings::get_global(cx).enabled {
3352 return Task::ready(Ok(Vec::new()));
3353 }
3354
3355 let snapshot = buffer.read(cx).snapshot();
3356 let mut range = range.to_point(&snapshot);
3357
3358 // Expand the range to line boundaries.
3359 range.start.column = 0;
3360 range.end.column = snapshot.line_len(range.end.row);
3361
3362 let mut has_diagnostics = false;
3363 for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3364 range.start = cmp::min(range.start, diagnostic.range.start);
3365 range.end = cmp::max(range.end, diagnostic.range.end);
3366 has_diagnostics = true;
3367 }
3368 if has_diagnostics {
3369 if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3370 if let Some(symbol) = symbols_containing_start.last() {
3371 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3372 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3373 }
3374 }
3375
3376 if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3377 if let Some(symbol) = symbols_containing_end.last() {
3378 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3379 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3380 }
3381 }
3382
3383 Task::ready(Ok(vec![CodeAction {
3384 server_id: language::LanguageServerId(0),
3385 range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3386 lsp_action: lsp::CodeAction {
3387 title: "Fix with Assistant".into(),
3388 ..Default::default()
3389 },
3390 }]))
3391 } else {
3392 Task::ready(Ok(Vec::new()))
3393 }
3394 }
3395
3396 fn apply_code_action(
3397 &self,
3398 buffer: Model<Buffer>,
3399 action: CodeAction,
3400 excerpt_id: ExcerptId,
3401 _push_to_history: bool,
3402 cx: &mut WindowContext,
3403 ) -> Task<Result<ProjectTransaction>> {
3404 let editor = self.editor.clone();
3405 let workspace = self.workspace.clone();
3406 cx.spawn(|mut cx| async move {
3407 let editor = editor.upgrade().context("editor was released")?;
3408 let range = editor
3409 .update(&mut cx, |editor, cx| {
3410 editor.buffer().update(cx, |multibuffer, cx| {
3411 let buffer = buffer.read(cx);
3412 let multibuffer_snapshot = multibuffer.read(cx);
3413
3414 let old_context_range =
3415 multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3416 let mut new_context_range = old_context_range.clone();
3417 if action
3418 .range
3419 .start
3420 .cmp(&old_context_range.start, buffer)
3421 .is_lt()
3422 {
3423 new_context_range.start = action.range.start;
3424 }
3425 if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3426 new_context_range.end = action.range.end;
3427 }
3428 drop(multibuffer_snapshot);
3429
3430 if new_context_range != old_context_range {
3431 multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3432 }
3433
3434 let multibuffer_snapshot = multibuffer.read(cx);
3435 Some(
3436 multibuffer_snapshot
3437 .anchor_in_excerpt(excerpt_id, action.range.start)?
3438 ..multibuffer_snapshot
3439 .anchor_in_excerpt(excerpt_id, action.range.end)?,
3440 )
3441 })
3442 })?
3443 .context("invalid range")?;
3444 let assistant_panel = workspace.update(&mut cx, |workspace, cx| {
3445 workspace
3446 .panel::<AssistantPanel>(cx)
3447 .context("assistant panel was released")
3448 })??;
3449
3450 cx.update_global(|assistant: &mut InlineAssistant, cx| {
3451 let assist_id = assistant.suggest_assist(
3452 &editor,
3453 range,
3454 "Fix Diagnostics".into(),
3455 None,
3456 true,
3457 Some(workspace),
3458 Some(&assistant_panel),
3459 cx,
3460 );
3461 assistant.start_assist(assist_id, cx);
3462 })?;
3463
3464 Ok(ProjectTransaction::default())
3465 })
3466 }
3467}
3468
/// Yields every proper, non-empty prefix of `text`, from shortest to longest.
///
/// For example, `"```"` yields `` "`" `` and `` "``" ``. The full string is not
/// yielded, so empty and single-character inputs yield nothing. Prefixes are
/// cut on `char` boundaries, so non-ASCII input does not panic.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first marks the end of a proper prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3472
3473fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3474 ranges.sort_unstable_by(|a, b| {
3475 a.start
3476 .cmp(&b.start, buffer)
3477 .then_with(|| b.end.cmp(&a.end, buffer))
3478 });
3479
3480 let mut ix = 0;
3481 while ix + 1 < ranges.len() {
3482 let b = ranges[ix + 1].clone();
3483 let a = &mut ranges[ix];
3484 if a.end.cmp(&b.start, buffer).is_gt() {
3485 if a.end.cmp(&b.end, buffer).is_lt() {
3486 a.end = b.end;
3487 }
3488 ranges.remove(ix + 1);
3489 } else {
3490 ix += 1;
3491 }
3492 }
3493}
3494
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Serializable placeholder request type. Not referenced by any test in this module.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Streams an over-indented replacement for the body of `main` in random-sized
    /// chunks and expects the applied result to match the surrounding indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select rows 1 through 4: the whole body of `main`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Deliver the response in random chunks of 1..=10 bytes, letting the
        // codegen process each chunk before the next one arrives.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// The cursor sits after the partial `le` on an already-indented row; the
    /// model completes the statement and the following lines must be indented.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range at the end of `le` on row 1.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Deliver the response in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// The cursor sits on a whitespace-only row with less indentation than the
    /// block expects; the generated lines must still end up properly indented.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range at the end of the whitespace on row 1.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Deliver the response in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// A tab-indented selection (no language attached) is replaced by a
    /// tab-indented response sent in one chunk; the tabs must be kept as-is.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of the file through the `\t}` on row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An alternative created inactive buffers its generated edit. Activating
    /// it applies the edit, and deactivating it reverts the buffer.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select the whole `let x = 0;` row.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        // The third argument (`false`) creates the alternative inactive.
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` removes wrapping code fences and `<|CURSOR|>` markers.
    /// Every case is run at every chunk size, so the output must not depend on
    /// how the response is split.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` at every chunk size from 1 to
        /// `text.len()` and asserts that the concatenated output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into successive chunks of `size` chars each (the last
        /// one may be shorter) and yields them as an always-`Ok` stream.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Starts `codegen` on a fake model response and returns the sending half of
    /// the channel that feeds it. Each sent `String` arrives as one response
    /// chunk; tests drop the sender to end the response.
    fn simulate_response_stream(
        codegen: Model<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// A minimal Rust language: the tree-sitter Rust grammar plus an indents
    /// query covering calls, field expressions, and parenthesized/braced nodes.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}