1use crate::{
2 assistant_settings::AssistantSettings,
3 prompts::PromptBuilder,
4 streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
5 terminal_inline_assistant::TerminalInlineAssistant,
6 CycleNextInlineAssist, CyclePreviousInlineAssist, ToggleInlineAssist,
7};
8use anyhow::{Context as _, Result};
9use client::{telemetry::Telemetry, ErrorExt};
10use collections::{hash_map, HashMap, HashSet, VecDeque};
11use editor::{
12 actions::{MoveDown, MoveUp, SelectAll},
13 display_map::{
14 BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
15 ToDisplayPoint,
16 },
17 Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
18 EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
19 ToOffset as _, ToPoint,
20};
21use feature_flags::{FeatureFlagAppExt as _, ZedPro};
22use fs::Fs;
23use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
24use gpui::{
25 anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
26 FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
27 Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
28};
29use language::{Buffer, IndentKind, Point, Selection, TransactionId};
30use language_model::{
31 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
32 LanguageModelTextStream, Role,
33};
34use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
35use language_models::report_assistant_event;
36use multi_buffer::MultiBufferRow;
37use parking_lot::Mutex;
38use project::{CodeAction, ProjectTransaction};
39use rope::Rope;
40use settings::{update_settings_file, Settings, SettingsStore};
41use smol::future::FutureExt;
42use std::{
43 cmp,
44 future::Future,
45 iter, mem,
46 ops::{Range, RangeInclusive},
47 pin::Pin,
48 sync::Arc,
49 task::{self, Poll},
50 time::Instant,
51};
52use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
53use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
54use text::{OffsetRangeExt, ToPoint as _};
55use theme::ThemeSettings;
56use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, Tooltip};
57use util::{RangeExt, ResultExt};
58use workspace::{dock::Panel, ShowConfiguration};
59use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
60
61pub fn init(
62 fs: Arc<dyn Fs>,
63 prompt_builder: Arc<PromptBuilder>,
64 telemetry: Arc<Telemetry>,
65 cx: &mut AppContext,
66) {
67 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
68 cx.observe_new_views(|workspace: &mut Workspace, cx| {
69 workspace.register_action(InlineAssistant::toggle_inline_assist);
70
71 let workspace = cx.view().clone();
72 InlineAssistant::update_global(cx, |inline_assistant, cx| {
73 inline_assistant.register_workspace(&workspace, cx)
74 })
75 })
76 .detach();
77}
78
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`; once exceeded, the
/// oldest prompt is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
80
/// Where an inline assist should be opened, as returned by `resolve_inline_assist_target`.
enum InlineAssistTarget {
    /// A text editor; the assist is handled by [`InlineAssistant`].
    Editor(View<Editor>),
    /// A terminal; the assist is handled by `TerminalInlineAssistant`.
    Terminal(View<TerminalView>),
}
85
/// Global state for inline assists in editors: every live assist, how assists are grouped
/// and which editor hosts them, plus the prompt history shared by all prompt editors.
pub struct InlineAssistant {
    /// Id handed to the next assist created (advanced with `post_inc`).
    next_assist_id: InlineAssistId,
    /// Id handed to the next assist group created (advanced with `post_inc`).
    next_assist_group_id: InlineAssistGroupId,
    /// Every assist that has not been finished yet, including dismissed ones.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor bookkeeping (assist ids, scroll lock, highlight update channel); an entry
    /// is removed once its editor has no assists left.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Assists created together; a group is removed once it has no assists left.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Active codegen alternative of each assist that was finished without undo.
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Recently started prompts, oldest first, deduplicated and capped at
    /// `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` this assistant creates.
    prompt_builder: Arc<PromptBuilder>,
    /// Used to report assistant telemetry events; also passed to every `Codegen`.
    telemetry: Arc<Telemetry>,
    /// Passed to every `PromptEditor` this assistant creates.
    fs: Arc<dyn Fs>,
}
98
// Stored as a GPUI global so it can be reached via `InlineAssistant::update_global`.
impl Global for InlineAssistant {}
100
101impl InlineAssistant {
102 pub fn new(
103 fs: Arc<dyn Fs>,
104 prompt_builder: Arc<PromptBuilder>,
105 telemetry: Arc<Telemetry>,
106 ) -> Self {
107 Self {
108 next_assist_id: InlineAssistId::default(),
109 next_assist_group_id: InlineAssistGroupId::default(),
110 assists: HashMap::default(),
111 assists_by_editor: HashMap::default(),
112 assist_groups: HashMap::default(),
113 confirmed_assists: HashMap::default(),
114 prompt_history: VecDeque::default(),
115 prompt_builder,
116 telemetry,
117 fs,
118 }
119 }
120
121 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
122 cx.subscribe(workspace, |workspace, event, cx| {
123 Self::update_global(cx, |this, cx| {
124 this.handle_workspace_event(workspace, event, cx)
125 });
126 })
127 .detach();
128
129 let workspace = workspace.downgrade();
130 cx.observe_global::<SettingsStore>(move |cx| {
131 let Some(workspace) = workspace.upgrade() else {
132 return;
133 };
134 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
135 return;
136 };
137 let enabled = AssistantSettings::get_global(cx).enabled;
138 terminal_panel.update(cx, |terminal_panel, cx| {
139 terminal_panel.asssistant_enabled(enabled, cx)
140 });
141 })
142 .detach();
143 }
144
145 fn handle_workspace_event(
146 &mut self,
147 workspace: View<Workspace>,
148 event: &workspace::Event,
149 cx: &mut WindowContext,
150 ) {
151 match event {
152 workspace::Event::UserSavedItem { item, .. } => {
153 // When the user manually saves an editor, automatically accepts all finished transformations.
154 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
155 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
156 for assist_id in editor_assists.assist_ids.clone() {
157 let assist = &self.assists[&assist_id];
158 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
159 self.finish_assist(assist_id, false, cx)
160 }
161 }
162 }
163 }
164 }
165 workspace::Event::ItemAdded { item } => {
166 self.register_workspace_item(&workspace, item.as_ref(), cx);
167 }
168 _ => (),
169 }
170 }
171
172 fn register_workspace_item(
173 &mut self,
174 workspace: &View<Workspace>,
175 item: &dyn ItemHandle,
176 cx: &mut WindowContext,
177 ) {
178 if let Some(editor) = item.act_as::<Editor>(cx) {
179 editor.update(cx, |editor, cx| {
180 editor.push_code_action_provider(
181 Arc::new(AssistantCodeActionProvider {
182 editor: cx.view().downgrade(),
183 workspace: workspace.downgrade(),
184 }),
185 cx,
186 );
187 });
188 }
189 }
190
191 pub fn toggle_inline_assist(
192 workspace: &mut Workspace,
193 _action: &ToggleInlineAssist,
194 cx: &mut ViewContext<Workspace>,
195 ) {
196 let settings = AssistantSettings::get_global(cx);
197 if !settings.enabled {
198 return;
199 }
200
201 let Some(inline_assist_target) = Self::resolve_inline_assist_target(workspace, cx) else {
202 return;
203 };
204
205 let is_authenticated = || {
206 LanguageModelRegistry::read_global(cx)
207 .active_provider()
208 .map_or(false, |provider| provider.is_authenticated(cx))
209 };
210
211 let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
212 InlineAssistTarget::Editor(active_editor) => {
213 InlineAssistant::update_global(cx, |assistant, cx| {
214 assistant.assist(&active_editor, Some(cx.view().downgrade()), cx)
215 })
216 }
217 InlineAssistTarget::Terminal(active_terminal) => {
218 TerminalInlineAssistant::update_global(cx, |assistant, cx| {
219 assistant.assist(&active_terminal, Some(cx.view().downgrade()), cx)
220 })
221 }
222 };
223
224 if is_authenticated() {
225 handle_assist(cx);
226 } else {
227 cx.spawn(|_workspace, mut cx| async move {
228 let Some(task) = cx.update(|cx| {
229 LanguageModelRegistry::read_global(cx)
230 .active_provider()
231 .map_or(None, |provider| Some(provider.authenticate(cx)))
232 })?
233 else {
234 let answer = cx
235 .prompt(
236 gpui::PromptLevel::Warning,
237 "No language model provider configured",
238 None,
239 &["Configure", "Cancel"],
240 )
241 .await
242 .ok();
243 if let Some(answer) = answer {
244 if answer == 0 {
245 cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
246 .ok();
247 }
248 }
249 return Ok(());
250 };
251 task.await?;
252
253 anyhow::Ok(())
254 })
255 .detach_and_log_err(cx);
256
257 if is_authenticated() {
258 handle_assist(cx);
259 }
260 }
261 }
262
263 pub fn assist(
264 &mut self,
265 editor: &View<Editor>,
266 workspace: Option<WeakView<Workspace>>,
267 cx: &mut WindowContext,
268 ) {
269 let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
270 (
271 editor.buffer().read(cx).snapshot(cx),
272 editor.selections.all::<Point>(cx),
273 )
274 });
275
276 let mut selections = Vec::<Selection<Point>>::new();
277 let mut newest_selection = None;
278 for mut selection in initial_selections {
279 if selection.end > selection.start {
280 selection.start.column = 0;
281 // If the selection ends at the start of the line, we don't want to include it.
282 if selection.end.column == 0 {
283 selection.end.row -= 1;
284 }
285 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
286 }
287
288 if let Some(prev_selection) = selections.last_mut() {
289 if selection.start <= prev_selection.end {
290 prev_selection.end = selection.end;
291 continue;
292 }
293 }
294
295 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
296 if selection.id > latest_selection.id {
297 *latest_selection = selection.clone();
298 }
299 selections.push(selection);
300 }
301 let newest_selection = newest_selection.unwrap();
302
303 let mut codegen_ranges = Vec::new();
304 for (excerpt_id, buffer, buffer_range) in
305 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
306 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
307 }))
308 {
309 let start = Anchor {
310 buffer_id: Some(buffer.remote_id()),
311 excerpt_id,
312 text_anchor: buffer.anchor_before(buffer_range.start),
313 };
314 let end = Anchor {
315 buffer_id: Some(buffer.remote_id()),
316 excerpt_id,
317 text_anchor: buffer.anchor_after(buffer_range.end),
318 };
319 codegen_ranges.push(start..end);
320
321 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
322 self.telemetry.report_assistant_event(AssistantEvent {
323 conversation_id: None,
324 kind: AssistantKind::Inline,
325 phase: AssistantPhase::Invoked,
326 message_id: None,
327 model: model.telemetry_id(),
328 model_provider: model.provider_id().to_string(),
329 response_latency: None,
330 error_message: None,
331 language_name: buffer.language().map(|language| language.name().to_proto()),
332 });
333 }
334 }
335
336 let assist_group_id = self.next_assist_group_id.post_inc();
337 let prompt_buffer = cx.new_model(|cx| {
338 MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
339 });
340
341 let mut assists = Vec::new();
342 let mut assist_to_focus = None;
343 for range in codegen_ranges {
344 let assist_id = self.next_assist_id.post_inc();
345 let codegen = cx.new_model(|cx| {
346 Codegen::new(
347 editor.read(cx).buffer().clone(),
348 range.clone(),
349 None,
350 self.telemetry.clone(),
351 self.prompt_builder.clone(),
352 cx,
353 )
354 });
355
356 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
357 let prompt_editor = cx.new_view(|cx| {
358 PromptEditor::new(
359 assist_id,
360 gutter_dimensions.clone(),
361 self.prompt_history.clone(),
362 prompt_buffer.clone(),
363 codegen.clone(),
364 self.fs.clone(),
365 cx,
366 )
367 });
368
369 if assist_to_focus.is_none() {
370 let focus_assist = if newest_selection.reversed {
371 range.start.to_point(&snapshot) == newest_selection.start
372 } else {
373 range.end.to_point(&snapshot) == newest_selection.end
374 };
375 if focus_assist {
376 assist_to_focus = Some(assist_id);
377 }
378 }
379
380 let [prompt_block_id, end_block_id] =
381 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
382
383 assists.push((
384 assist_id,
385 range,
386 prompt_editor,
387 prompt_block_id,
388 end_block_id,
389 ));
390 }
391
392 let editor_assists = self
393 .assists_by_editor
394 .entry(editor.downgrade())
395 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
396 let mut assist_group = InlineAssistGroup::new();
397 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
398 self.assists.insert(
399 assist_id,
400 InlineAssist::new(
401 assist_id,
402 assist_group_id,
403 editor,
404 &prompt_editor,
405 prompt_block_id,
406 end_block_id,
407 range,
408 prompt_editor.read(cx).codegen.clone(),
409 workspace.clone(),
410 cx,
411 ),
412 );
413 assist_group.assist_ids.push(assist_id);
414 editor_assists.assist_ids.push(assist_id);
415 }
416 self.assist_groups.insert(assist_group_id, assist_group);
417
418 if let Some(assist_id) = assist_to_focus {
419 self.focus_assist(assist_id, cx);
420 }
421 }
422
423 #[allow(clippy::too_many_arguments)]
424 pub fn suggest_assist(
425 &mut self,
426 editor: &View<Editor>,
427 mut range: Range<Anchor>,
428 initial_prompt: String,
429 initial_transaction_id: Option<TransactionId>,
430 focus: bool,
431 workspace: Option<WeakView<Workspace>>,
432 cx: &mut WindowContext,
433 ) -> InlineAssistId {
434 let assist_group_id = self.next_assist_group_id.post_inc();
435 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
436 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
437
438 let assist_id = self.next_assist_id.post_inc();
439
440 let buffer = editor.read(cx).buffer().clone();
441 {
442 let snapshot = buffer.read(cx).read(cx);
443 range.start = range.start.bias_left(&snapshot);
444 range.end = range.end.bias_right(&snapshot);
445 }
446
447 let codegen = cx.new_model(|cx| {
448 Codegen::new(
449 editor.read(cx).buffer().clone(),
450 range.clone(),
451 initial_transaction_id,
452 self.telemetry.clone(),
453 self.prompt_builder.clone(),
454 cx,
455 )
456 });
457
458 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
459 let prompt_editor = cx.new_view(|cx| {
460 PromptEditor::new(
461 assist_id,
462 gutter_dimensions.clone(),
463 self.prompt_history.clone(),
464 prompt_buffer.clone(),
465 codegen.clone(),
466 self.fs.clone(),
467 cx,
468 )
469 });
470
471 let [prompt_block_id, end_block_id] =
472 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
473
474 let editor_assists = self
475 .assists_by_editor
476 .entry(editor.downgrade())
477 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
478
479 let mut assist_group = InlineAssistGroup::new();
480 self.assists.insert(
481 assist_id,
482 InlineAssist::new(
483 assist_id,
484 assist_group_id,
485 editor,
486 &prompt_editor,
487 prompt_block_id,
488 end_block_id,
489 range,
490 prompt_editor.read(cx).codegen.clone(),
491 workspace.clone(),
492 cx,
493 ),
494 );
495 assist_group.assist_ids.push(assist_id);
496 editor_assists.assist_ids.push(assist_id);
497 self.assist_groups.insert(assist_group_id, assist_group);
498
499 if focus {
500 self.focus_assist(assist_id, cx);
501 }
502
503 assist_id
504 }
505
506 fn insert_assist_blocks(
507 &self,
508 editor: &View<Editor>,
509 range: &Range<Anchor>,
510 prompt_editor: &View<PromptEditor>,
511 cx: &mut WindowContext,
512 ) -> [CustomBlockId; 2] {
513 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
514 prompt_editor
515 .editor
516 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
517 });
518 let assist_blocks = vec![
519 BlockProperties {
520 style: BlockStyle::Sticky,
521 placement: BlockPlacement::Above(range.start),
522 height: prompt_editor_height,
523 render: build_assist_editor_renderer(prompt_editor),
524 priority: 0,
525 },
526 BlockProperties {
527 style: BlockStyle::Sticky,
528 placement: BlockPlacement::Below(range.end),
529 height: 0,
530 render: Arc::new(|cx| {
531 v_flex()
532 .h_full()
533 .w_full()
534 .border_t_1()
535 .border_color(cx.theme().status().info_border)
536 .into_any_element()
537 }),
538 priority: 0,
539 },
540 ];
541
542 editor.update(cx, |editor, cx| {
543 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
544 [block_ids[0], block_ids[1]]
545 })
546 }
547
/// Called when an assist's prompt editor gains focus.
///
/// Makes the assist its group's active assist. For a linked group, every prompt editor in
/// the group keeps showing its cursor while unfocused. If the prompt block is currently
/// within the host editor's visible rows, a scroll lock is recorded so `handle_editor_change`
/// can keep the prompt at the same distance from the top; otherwise any scroll lock is
/// cleared. Does nothing if the assist's decorations have already been removed.
fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let assist = &self.assists[&assist_id];
    let Some(decorations) = assist.decorations.as_ref() else {
        return;
    };
    let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
    let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();

    assist_group.active_assist_id = Some(assist_id);
    if assist_group.linked {
        for assist_id in &assist_group.assist_ids {
            if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
                decorations.prompt_editor.update(cx, |prompt_editor, cx| {
                    prompt_editor.set_show_cursor_when_unfocused(true, cx)
                });
            }
        }
    }

    // `editor` is a weak handle; `.ok()` ignores the case where it was released.
    assist
        .editor
        .update(cx, |editor, cx| {
            let scroll_top = editor.scroll_position(cx).y;
            let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
            let prompt_row = editor
                .row_for_block(decorations.prompt_block_id, cx)
                .unwrap()
                .0 as f32;

            // Lock scrolling only if the prompt is currently on screen.
            if (scroll_top..scroll_bottom).contains(&prompt_row) {
                editor_assists.scroll_lock = Some(InlineAssistScrollLock {
                    assist_id,
                    distance_from_top: prompt_row - scroll_top,
                });
            } else {
                editor_assists.scroll_lock = None;
            }
        })
        .ok();
}
588
589 fn handle_prompt_editor_focus_out(
590 &mut self,
591 assist_id: InlineAssistId,
592 cx: &mut WindowContext,
593 ) {
594 let assist = &self.assists[&assist_id];
595 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
596 if assist_group.active_assist_id == Some(assist_id) {
597 assist_group.active_assist_id = None;
598 if assist_group.linked {
599 for assist_id in &assist_group.assist_ids {
600 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
601 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
602 prompt_editor.set_show_cursor_when_unfocused(false, cx)
603 });
604 }
605 }
606 }
607 }
608 }
609
610 fn handle_prompt_editor_event(
611 &mut self,
612 prompt_editor: View<PromptEditor>,
613 event: &PromptEditorEvent,
614 cx: &mut WindowContext,
615 ) {
616 let assist_id = prompt_editor.read(cx).id;
617 match event {
618 PromptEditorEvent::StartRequested => {
619 self.start_assist(assist_id, cx);
620 }
621 PromptEditorEvent::StopRequested => {
622 self.stop_assist(assist_id, cx);
623 }
624 PromptEditorEvent::ConfirmRequested => {
625 self.finish_assist(assist_id, false, cx);
626 }
627 PromptEditorEvent::CancelRequested => {
628 self.finish_assist(assist_id, true, cx);
629 }
630 PromptEditorEvent::DismissRequested => {
631 self.dismiss_assist(assist_id, cx);
632 }
633 }
634 }
635
636 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
637 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
638 return;
639 };
640
641 if editor.read(cx).selections.count() == 1 {
642 let (selection, buffer) = editor.update(cx, |editor, cx| {
643 (
644 editor.selections.newest::<usize>(cx),
645 editor.buffer().read(cx).snapshot(cx),
646 )
647 });
648 for assist_id in &editor_assists.assist_ids {
649 let assist = &self.assists[assist_id];
650 let assist_range = assist.range.to_offset(&buffer);
651 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
652 {
653 if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
654 self.dismiss_assist(*assist_id, cx);
655 } else {
656 self.finish_assist(*assist_id, false, cx);
657 }
658
659 return;
660 }
661 }
662 }
663
664 cx.propagate();
665 }
666
667 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
668 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
669 return;
670 };
671
672 if editor.read(cx).selections.count() == 1 {
673 let (selection, buffer) = editor.update(cx, |editor, cx| {
674 (
675 editor.selections.newest::<usize>(cx),
676 editor.buffer().read(cx).snapshot(cx),
677 )
678 });
679 let mut closest_assist_fallback = None;
680 for assist_id in &editor_assists.assist_ids {
681 let assist = &self.assists[assist_id];
682 let assist_range = assist.range.to_offset(&buffer);
683 if assist.decorations.is_some() {
684 if assist_range.contains(&selection.start)
685 && assist_range.contains(&selection.end)
686 {
687 self.focus_assist(*assist_id, cx);
688 return;
689 } else {
690 let distance_from_selection = assist_range
691 .start
692 .abs_diff(selection.start)
693 .min(assist_range.start.abs_diff(selection.end))
694 + assist_range
695 .end
696 .abs_diff(selection.start)
697 .min(assist_range.end.abs_diff(selection.end));
698 match closest_assist_fallback {
699 Some((_, old_distance)) => {
700 if distance_from_selection < old_distance {
701 closest_assist_fallback =
702 Some((assist_id, distance_from_selection));
703 }
704 }
705 None => {
706 closest_assist_fallback = Some((assist_id, distance_from_selection))
707 }
708 }
709 }
710 }
711 }
712
713 if let Some((&assist_id, _)) = closest_assist_fallback {
714 self.focus_assist(assist_id, cx);
715 }
716 }
717
718 cx.propagate();
719 }
720
721 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
722 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
723 for assist_id in editor_assists.assist_ids.clone() {
724 self.finish_assist(assist_id, true, cx);
725 }
726 }
727 }
728
/// Keeps a scroll-locked prompt at a fixed distance from the top of `editor`'s viewport.
///
/// If the editor has a scroll lock and the locked assist's prompt is still shown, the editor
/// is scrolled so the prompt block sits `distance_from_top` rows below the top. The
/// horizontal scroll position is preserved.
fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
    let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
        return;
    };
    let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
        return;
    };
    let assist = &self.assists[&scroll_lock.assist_id];
    let Some(decorations) = assist.decorations.as_ref() else {
        return;
    };

    editor.update(cx, |editor, cx| {
        let scroll_position = editor.scroll_position(cx);
        let target_scroll_top = editor
            .row_for_block(decorations.prompt_block_id, cx)
            .unwrap()
            .0 as f32
            - scroll_lock.distance_from_top;
        // Only scroll when the prompt has actually moved relative to the viewport.
        if target_scroll_top != scroll_position.y {
            editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
        }
    });
}
753
/// Reacts to events from an editor that hosts inline assists:
/// - `Edited`: assists whose codegen is `Done` or errored, and whose range overlaps a range
///   edited by the transaction, are finished without undo.
/// - `ScrollPositionChanged`: the scroll lock is dropped once the locked prompt is no longer
///   at the recorded distance from the top of the viewport.
/// - `SelectionsChanged`: the scroll lock is dropped unless one of this editor's prompt
///   editors has focus.
///
/// Other events, and editors without tracked assists, are ignored.
fn handle_editor_event(
    &mut self,
    editor: View<Editor>,
    event: &EditorEvent,
    cx: &mut WindowContext,
) {
    let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
        return;
    };

    match event {
        EditorEvent::Edited { transaction_id } => {
            let buffer = editor.read(cx).buffer().read(cx);
            let edited_ranges =
                buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
            let snapshot = buffer.snapshot(cx);

            // Iterate over a copy: `finish_assist` removes ids from `assist_ids`.
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if matches!(
                    assist.codegen.read(cx).status(cx),
                    CodegenStatus::Error(_) | CodegenStatus::Done
                ) {
                    let assist_range = assist.range.to_offset(&snapshot);
                    if edited_ranges
                        .iter()
                        .any(|range| range.overlaps(&assist_range))
                    {
                        self.finish_assist(assist_id, false, cx);
                    }
                }
            }
        }
        EditorEvent::ScrollPositionChanged { .. } => {
            if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
                let assist = &self.assists[&scroll_lock.assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    let distance_from_top = editor.update(cx, |editor, cx| {
                        let scroll_top = editor.scroll_position(cx).y;
                        let prompt_row = editor
                            .row_for_block(decorations.prompt_block_id, cx)
                            .unwrap()
                            .0 as f32;
                        prompt_row - scroll_top
                    });

                    // Any scroll that moves the prompt relative to the viewport releases the lock.
                    if distance_from_top != scroll_lock.distance_from_top {
                        editor_assists.scroll_lock = None;
                    }
                }
            }
        }
        EditorEvent::SelectionsChanged { .. } => {
            // Keep the lock while a prompt editor of this editor is focused.
            for assist_id in editor_assists.assist_ids.clone() {
                let assist = &self.assists[&assist_id];
                if let Some(decorations) = assist.decorations.as_ref() {
                    if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
                        return;
                    }
                }
            }

            editor_assists.scroll_lock = None;
        }
        _ => {}
    }
}
821
/// Finishes an assist, removing its decorations and bookkeeping.
///
/// With `undo == true` the assist's changes are undone via its codegen. Otherwise its
/// active codegen alternative is recorded in `confirmed_assists`.
///
/// If the assist belongs to a still-linked group, the group is unlinked and every assist in
/// it is finished the same way. When a language model is active, an `Accepted` or
/// `Rejected` telemetry event is reported. Unknown ids are ignored.
pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
    if let Some(assist) = self.assists.get(&assist_id) {
        let assist_group_id = assist.group_id;
        if self.assist_groups[&assist_group_id].linked {
            // Linked assists are finished together; each recursive call sees an unlinked group.
            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                self.finish_assist(assist_id, undo, cx);
            }
            return;
        }
    }

    self.dismiss_assist(assist_id, cx);

    if let Some(assist) = self.assists.remove(&assist_id) {
        // Drop the assist from its group, removing the group once it is empty.
        if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
            }
        }

        // Drop the assist from its editor. When it was the editor's last assist, the entry
        // is removed and highlights are recomputed directly; otherwise a highlight update
        // is sent on the editor's channel.
        if let hash_map::Entry::Occupied(mut entry) =
            self.assists_by_editor.entry(assist.editor.clone())
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
                if let Some(editor) = assist.editor.upgrade() {
                    self.update_editor_highlights(&editor, cx);
                }
            } else {
                entry.get().highlight_updates.send(()).ok();
            }
        }

        let active_alternative = assist.codegen.read(cx).active_alternative().clone();
        let message_id = active_alternative.read(cx).message_id.clone();

        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
            // Language of the first buffer covered by the assist's range, if any.
            let language_name = assist.editor.upgrade().and_then(|editor| {
                let multibuffer = editor.read(cx).buffer().read(cx);
                let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
                ranges
                    .first()
                    .and_then(|(buffer, _, _)| buffer.read(cx).language())
                    .map(|language| language.name())
            });
            report_assistant_event(
                AssistantEvent {
                    conversation_id: None,
                    kind: AssistantKind::Inline,
                    message_id,
                    phase: if undo {
                        AssistantPhase::Rejected
                    } else {
                        AssistantPhase::Accepted
                    },
                    model: model.telemetry_id(),
                    model_provider: model.provider_id().to_string(),
                    response_latency: None,
                    error_message: None,
                    language_name: language_name.map(|name| name.to_proto()),
                },
                Some(self.telemetry.clone()),
                cx.http_client(),
                model.api_key(cx),
                cx.background_executor(),
            );
        }

        if undo {
            assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
        } else {
            self.confirmed_assists.insert(assist_id, active_alternative);
        }
    }
}
900
/// Removes an assist's decorations from its editor without finishing the assist.
///
/// This deletes the prompt, end and removed-line blocks. If the prompt editor contained
/// focus, focus moves to the next assist. A scroll lock on this assist is cleared, and a
/// message is sent on the editor's `highlight_updates` channel. The assist itself stays in
/// `self.assists`.
///
/// Returns `false` (doing nothing) if the assist is unknown, its editor was released, or
/// its decorations were already removed; `true` otherwise.
fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return false;
    };
    let Some(editor) = assist.editor.upgrade() else {
        return false;
    };
    // Taking the decorations marks the assist as dismissed.
    let Some(decorations) = assist.decorations.take() else {
        return false;
    };

    editor.update(cx, |editor, cx| {
        let mut to_remove = decorations.removed_line_block_ids;
        to_remove.insert(decorations.prompt_block_id);
        to_remove.insert(decorations.end_block_id);
        editor.remove_blocks(to_remove, None, cx);
    });

    if decorations
        .prompt_editor
        .focus_handle(cx)
        .contains_focused(cx)
    {
        self.focus_next_assist(assist_id, cx);
    }

    if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
        if editor_assists
            .scroll_lock
            .as_ref()
            .map_or(false, |lock| lock.assist_id == assist_id)
        {
            editor_assists.scroll_lock = None;
        }
        editor_assists.highlight_updates.send(()).ok();
    }

    true
}
940
/// Moves focus from `assist_id` to the next assist in the same group that still shows its
/// prompt, searching forward and wrapping around. If there is none, the host editor is
/// focused instead. Unknown ids are ignored.
fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let Some(assist) = self.assists.get(&assist_id) else {
        return;
    };

    let assist_group = &self.assist_groups[&assist.group_id];
    let assist_ix = assist_group
        .assist_ids
        .iter()
        .position(|id| *id == assist_id)
        .unwrap();
    // Assists after this one, then those before it; this one is excluded.
    let assist_ids = assist_group
        .assist_ids
        .iter()
        .skip(assist_ix + 1)
        .chain(assist_group.assist_ids.iter().take(assist_ix));

    for assist_id in assist_ids {
        let assist = &self.assists[assist_id];
        if assist.decorations.is_some() {
            self.focus_assist(*assist_id, cx);
            return;
        }
    }

    // `editor` is weak; `.ok()` ignores an already-released editor.
    assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
}
968
969 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
970 let Some(assist) = self.assists.get(&assist_id) else {
971 return;
972 };
973
974 if let Some(decorations) = assist.decorations.as_ref() {
975 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
976 prompt_editor.editor.update(cx, |editor, cx| {
977 editor.focus(cx);
978 editor.select_all(&SelectAll, cx);
979 })
980 });
981 }
982
983 self.scroll_to_assist(assist_id, cx);
984 }
985
/// Puts the host editor's cursor at the start of the assist and, if needed, scrolls so the
/// assist is visible with the editor's vertical scroll margin above and below it.
///
/// The span to reveal runs from the prompt block to the end block while the prompt is
/// shown; otherwise it is just the assist's first display row. A span taller than the
/// viewport is aligned to the top. Unknown ids and released editors are ignored.
pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
    let Some(assist) = self.assists.get(&assist_id) else {
        return;
    };
    let Some(editor) = assist.editor.upgrade() else {
        return;
    };

    let position = assist.range.start;
    editor.update(cx, |editor, cx| {
        editor.change_selections(None, cx, |selections| {
            selections.select_anchor_ranges([position..position])
        });

        // Row span (in display rows) that should end up on screen.
        let mut scroll_target_top;
        let mut scroll_target_bottom;
        if let Some(decorations) = assist.decorations.as_ref() {
            scroll_target_top = editor
                .row_for_block(decorations.prompt_block_id, cx)
                .unwrap()
                .0 as f32;
            scroll_target_bottom = editor
                .row_for_block(decorations.end_block_id, cx)
                .unwrap()
                .0 as f32;
        } else {
            let snapshot = editor.snapshot(cx);
            let start_row = assist
                .range
                .start
                .to_display_point(&snapshot.display_snapshot)
                .row();
            scroll_target_top = start_row.0 as f32;
            scroll_target_bottom = scroll_target_top + 1.;
        }
        scroll_target_top -= editor.vertical_scroll_margin() as f32;
        scroll_target_bottom += editor.vertical_scroll_margin() as f32;

        let height_in_lines = editor.visible_line_count().unwrap_or(0.);
        let scroll_top = editor.scroll_position(cx).y;
        let scroll_bottom = scroll_top + height_in_lines;

        // Scroll only when the target span is not already fully visible.
        if scroll_target_top < scroll_top {
            editor.set_scroll_position(point(0., scroll_target_top), cx);
        } else if scroll_target_bottom > scroll_bottom {
            if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
                editor
                    .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
            } else {
                editor.set_scroll_position(point(0., scroll_target_top), cx);
            }
        }
    });
}
1040
1041 fn unlink_assist_group(
1042 &mut self,
1043 assist_group_id: InlineAssistGroupId,
1044 cx: &mut WindowContext,
1045 ) -> Vec<InlineAssistId> {
1046 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1047 assist_group.linked = false;
1048 for assist_id in &assist_group.assist_ids {
1049 let assist = self.assists.get_mut(assist_id).unwrap();
1050 if let Some(editor_decorations) = assist.decorations.as_ref() {
1051 editor_decorations
1052 .prompt_editor
1053 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
1054 }
1055 }
1056 assist_group.assist_ids.clone()
1057 }
1058
1059 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1060 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1061 assist
1062 } else {
1063 return;
1064 };
1065
1066 let assist_group_id = assist.group_id;
1067 if self.assist_groups[&assist_group_id].linked {
1068 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1069 self.start_assist(assist_id, cx);
1070 }
1071 return;
1072 }
1073
1074 let Some(user_prompt) = assist.user_prompt(cx) else {
1075 return;
1076 };
1077
1078 self.prompt_history.retain(|prompt| *prompt != user_prompt);
1079 self.prompt_history.push_back(user_prompt.clone());
1080 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1081 self.prompt_history.pop_front();
1082 }
1083
1084 assist
1085 .codegen
1086 .update(cx, |codegen, cx| codegen.start(user_prompt, cx))
1087 .log_err();
1088 }
1089
1090 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1091 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1092 assist
1093 } else {
1094 return;
1095 };
1096
1097 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1098 }
1099
1100 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1101 let mut gutter_pending_ranges = Vec::new();
1102 let mut gutter_transformed_ranges = Vec::new();
1103 let mut foreground_ranges = Vec::new();
1104 let mut inserted_row_ranges = Vec::new();
1105 let empty_assist_ids = Vec::new();
1106 let assist_ids = self
1107 .assists_by_editor
1108 .get(&editor.downgrade())
1109 .map_or(&empty_assist_ids, |editor_assists| {
1110 &editor_assists.assist_ids
1111 });
1112
1113 for assist_id in assist_ids {
1114 if let Some(assist) = self.assists.get(assist_id) {
1115 let codegen = assist.codegen.read(cx);
1116 let buffer = codegen.buffer(cx).read(cx).read(cx);
1117 foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1118
1119 let pending_range =
1120 codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1121 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1122 gutter_pending_ranges.push(pending_range);
1123 }
1124
1125 if let Some(edit_position) = codegen.edit_position(cx) {
1126 let edited_range = assist.range.start..edit_position;
1127 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1128 gutter_transformed_ranges.push(edited_range);
1129 }
1130 }
1131
1132 if assist.decorations.is_some() {
1133 inserted_row_ranges
1134 .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1135 }
1136 }
1137 }
1138
1139 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1140 merge_ranges(&mut foreground_ranges, &snapshot);
1141 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1142 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1143 editor.update(cx, |editor, cx| {
1144 enum GutterPendingRange {}
1145 if gutter_pending_ranges.is_empty() {
1146 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1147 } else {
1148 editor.highlight_gutter::<GutterPendingRange>(
1149 &gutter_pending_ranges,
1150 |cx| cx.theme().status().info_background,
1151 cx,
1152 )
1153 }
1154
1155 enum GutterTransformedRange {}
1156 if gutter_transformed_ranges.is_empty() {
1157 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1158 } else {
1159 editor.highlight_gutter::<GutterTransformedRange>(
1160 &gutter_transformed_ranges,
1161 |cx| cx.theme().status().info,
1162 cx,
1163 )
1164 }
1165
1166 if foreground_ranges.is_empty() {
1167 editor.clear_highlights::<InlineAssist>(cx);
1168 } else {
1169 editor.highlight_text::<InlineAssist>(
1170 foreground_ranges,
1171 HighlightStyle {
1172 fade_out: Some(0.6),
1173 ..Default::default()
1174 },
1175 cx,
1176 );
1177 }
1178
1179 editor.clear_row_highlights::<InlineAssist>();
1180 for row_range in inserted_row_ranges {
1181 editor.highlight_rows::<InlineAssist>(
1182 row_range,
1183 cx.theme().status().info_background,
1184 false,
1185 cx,
1186 );
1187 }
1188 });
1189 }
1190
    /// Rebuilds the "deleted lines" blocks that assist `assist_id` shows inside `editor`.
    ///
    /// First, every block id previously stored in `decorations.removed_line_block_ids` is
    /// removed from the editor. Then, for each `(new_row, old_row_range)` in the codegen
    /// diff's `deleted_row_ranges`:
    /// - a read-only, gutter-less, non-wrapping editor is built over those rows of the
    ///   codegen's old buffer and highlighted with `deleted_background`;
    /// - that editor is inserted as a block placed above `new_row`.
    ///
    /// The ids of the inserted blocks are stored back so the next call can remove them.
    ///
    /// Does nothing when the assist is unknown or has no decorations.
    fn update_editor_blocks(
        &mut self,
        editor: &View<Editor>,
        assist_id: InlineAssistId,
        cx: &mut WindowContext,
    ) {
        let Some(assist) = self.assists.get_mut(&assist_id) else {
            return;
        };
        let Some(decorations) = assist.decorations.as_mut() else {
            return;
        };

        // Capture everything needed from the codegen before `cx` is borrowed mutably below.
        let codegen = assist.codegen.read(cx);
        let old_snapshot = codegen.snapshot(cx);
        let old_buffer = codegen.old_buffer(cx);
        let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

        editor.update(cx, |editor, cx| {
            // Drop the blocks rendered by the previous call.
            let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
            editor.remove_blocks(old_blocks, None, cx);

            let mut new_blocks = Vec::new();
            for (new_row, old_row_range) in deleted_row_ranges {
                // Translate the inclusive old row range (start of first row .. end of last
                // row) into buffer offsets. The unwraps assume these rows exist in
                // `old_snapshot`; they panic otherwise.
                let (_, buffer_start) = old_snapshot
                    .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                    .unwrap();
                let (_, buffer_end) = old_snapshot
                    .point_to_buffer_offset(Point::new(
                        *old_row_range.end(),
                        old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                    ))
                    .unwrap();

                // A small read-only editor showing just the deleted lines of the old buffer.
                let deleted_lines_editor = cx.new_view(|cx| {
                    let multi_buffer = cx.new_model(|_| {
                        MultiBuffer::without_headers(language::Capability::ReadOnly)
                    });
                    multi_buffer.update(cx, |multi_buffer, cx| {
                        multi_buffer.push_excerpts(
                            old_buffer.clone(),
                            Some(ExcerptRange {
                                context: buffer_start..buffer_end,
                                primary: None,
                            }),
                            cx,
                        );
                    });

                    enum DeletedLines {}
                    let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                    editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                    editor.set_show_wrap_guides(false, cx);
                    editor.set_show_gutter(false, cx);
                    editor.scroll_manager.set_forbid_vertical_scroll(true);
                    editor.set_read_only(true);
                    editor.set_show_inline_completions(Some(false), cx);
                    // Tint every row as deleted.
                    editor.highlight_rows::<DeletedLines>(
                        Anchor::min()..Anchor::max(),
                        cx.theme().status().deleted_background,
                        false,
                        cx,
                    );
                    editor
                });

                // Block height in rows = last row index of the deleted-lines editor + 1.
                let height =
                    deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
                new_blocks.push(BlockProperties {
                    placement: BlockPlacement::Above(new_row),
                    height,
                    style: BlockStyle::Flex,
                    render: Arc::new(move |cx| {
                        // Left padding equal to the host gutter aligns the deleted text with
                        // the host editor's text column.
                        div()
                            .block_mouse_down()
                            .bg(cx.theme().status().deleted_background)
                            .size_full()
                            .h(height as f32 * cx.line_height())
                            .pl(cx.gutter_dimensions.full_width())
                            .child(deleted_lines_editor.clone())
                            .into_any_element()
                    }),
                    priority: 0,
                });
            }

            // Remember the new block ids so the next update can remove them.
            decorations.removed_line_block_ids = editor
                .insert_blocks(new_blocks, None, cx)
                .into_iter()
                .collect();
        })
    }
1283
1284 fn resolve_inline_assist_target(
1285 workspace: &mut Workspace,
1286 cx: &mut WindowContext,
1287 ) -> Option<InlineAssistTarget> {
1288 if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) {
1289 if terminal_panel
1290 .read(cx)
1291 .focus_handle(cx)
1292 .contains_focused(cx)
1293 {
1294 if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
1295 pane.read(cx)
1296 .active_item()
1297 .and_then(|t| t.downcast::<TerminalView>())
1298 }) {
1299 return Some(InlineAssistTarget::Terminal(terminal_view));
1300 }
1301 }
1302 }
1303
1304 if let Some(workspace_editor) = workspace
1305 .active_item(cx)
1306 .and_then(|item| item.act_as::<Editor>(cx))
1307 {
1308 Some(InlineAssistTarget::Editor(workspace_editor))
1309 } else if let Some(terminal_view) = workspace
1310 .active_item(cx)
1311 .and_then(|item| item.act_as::<TerminalView>(cx))
1312 {
1313 Some(InlineAssistTarget::Terminal(terminal_view))
1314 } else {
1315 None
1316 }
1317 }
1318}
1319
/// Per-editor bookkeeping for the inline assists attached to a single editor.
struct EditorInlineAssists {
    /// Assists currently attached to this editor.
    assist_ids: Vec<InlineAssistId>,
    /// NOTE(review): presumably pins the editor's scroll position to one assist; the code that
    /// reads it is outside this view — TODO confirm.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sending `()` here makes `_update_highlights` recompute this editor's assist highlights.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that performs those highlight recomputations.
    _update_highlights: Task<Result<()>>,
    /// Editor release/change/event subscriptions and registered action handlers, kept alive
    /// for as long as this struct lives.
    _subscriptions: Vec<gpui::Subscription>,
}
1327
/// Associates an editor's scroll position with one inline assist.
/// NOTE(review): the code that reads these fields is outside this view; the field notes below
/// are inferred from names — TODO confirm.
struct InlineAssistScrollLock {
    /// The assist the scroll position is tied to.
    assist_id: InlineAssistId,
    /// Presumably the assist's offset from the top of the viewport (units unconfirmed).
    distance_from_top: f32,
}
1332
1333impl EditorInlineAssists {
1334 #[allow(clippy::too_many_arguments)]
1335 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1336 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1337 Self {
1338 assist_ids: Vec::new(),
1339 scroll_lock: None,
1340 highlight_updates: highlight_updates_tx,
1341 _update_highlights: cx.spawn(|mut cx| {
1342 let editor = editor.downgrade();
1343 async move {
1344 while let Ok(()) = highlight_updates_rx.changed().await {
1345 let editor = editor.upgrade().context("editor was dropped")?;
1346 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1347 assistant.update_editor_highlights(&editor, cx);
1348 })?;
1349 }
1350 Ok(())
1351 }
1352 }),
1353 _subscriptions: vec![
1354 cx.observe_release(editor, {
1355 let editor = editor.downgrade();
1356 |_, cx| {
1357 InlineAssistant::update_global(cx, |this, cx| {
1358 this.handle_editor_release(editor, cx);
1359 })
1360 }
1361 }),
1362 cx.observe(editor, move |editor, cx| {
1363 InlineAssistant::update_global(cx, |this, cx| {
1364 this.handle_editor_change(editor, cx)
1365 })
1366 }),
1367 cx.subscribe(editor, move |editor, event, cx| {
1368 InlineAssistant::update_global(cx, |this, cx| {
1369 this.handle_editor_event(editor, event, cx)
1370 })
1371 }),
1372 editor.update(cx, |editor, cx| {
1373 let editor_handle = cx.view().downgrade();
1374 editor.register_action(
1375 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1376 InlineAssistant::update_global(cx, |this, cx| {
1377 if let Some(editor) = editor_handle.upgrade() {
1378 this.handle_editor_newline(editor, cx)
1379 }
1380 })
1381 },
1382 )
1383 }),
1384 editor.update(cx, |editor, cx| {
1385 let editor_handle = cx.view().downgrade();
1386 editor.register_action(
1387 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1388 InlineAssistant::update_global(cx, |this, cx| {
1389 if let Some(editor) = editor_handle.upgrade() {
1390 this.handle_editor_cancel(editor, cx)
1391 }
1392 })
1393 },
1394 )
1395 }),
1396 ],
1397 }
1398 }
1399}
1400
/// A set of inline assists that were created together.
struct InlineAssistGroup {
    /// Ids of the assists in this group.
    assist_ids: Vec<InlineAssistId>,
    /// Whether the members' prompts are still linked. Groups start linked; starting any
    /// member of a linked group unlinks the group and starts every member.
    linked: bool,
    /// NOTE(review): presumably the member whose prompt is currently active/focused — TODO
    /// confirm against the code that sets it.
    active_assist_id: Option<InlineAssistId>,
}
1406
1407impl InlineAssistGroup {
1408 fn new() -> Self {
1409 Self {
1410 assist_ids: Vec::new(),
1411 linked: true,
1412 active_assist_id: None,
1413 }
1414 }
1415}
1416
1417fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1418 let editor = editor.clone();
1419 Arc::new(move |cx: &mut BlockContext| {
1420 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1421 editor.clone().into_any_element()
1422 })
1423}
1424
/// Identifies one inline assist. Ids are handed out sequentially via [`InlineAssistId::post_inc`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Advances this counter by one and returns the value it held before the increment
    /// (a post-increment, like `counter++`).
    fn post_inc(&mut self) -> InlineAssistId {
        let InlineAssistId(current) = *self;
        *self = InlineAssistId(current + 1);
        InlineAssistId(current)
    }
}
1435
/// Identifies one inline assist group. Ids are handed out sequentially via
/// [`InlineAssistGroupId::post_inc`].
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Advances this counter by one and returns the value it held before the increment.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let previous = *self;
        *self = InlineAssistGroupId(previous.0 + 1);
        previous
    }
}
1446
/// Requests that a [`PromptEditor`] emits in response to its buttons and actions.
enum PromptEditorEvent {
    /// Start or restart generation. Emitted by the start/restart buttons, and by confirm when
    /// idle, after an error, or when the prompt was edited since generation finished.
    StartRequested,
    /// Interrupt in-flight generation. Emitted by the stop button, and by cancel while pending.
    StopRequested,
    /// Accept the finished result. Emitted by the confirm button, and by confirm when done
    /// and the prompt is unedited.
    ConfirmRequested,
    /// Abandon the assist. Emitted by the cancel buttons, and by cancel when not pending.
    CancelRequested,
    /// Emitted by confirm while generation is still pending.
    DismissRequested,
}
1454
/// The prompt UI of an inline assist: a text editor for the prompt plus model selection,
/// prompt-history navigation, and start/stop/confirm/cancel controls.
struct PromptEditor {
    /// The inline assist this prompt belongs to.
    id: InlineAssistId,
    /// The prompt text editor. It is replaced by a fresh standalone editor when unlinked.
    editor: View<Editor>,
    /// Popover for choosing the language model.
    language_model_selector: View<LanguageModelSelector>,
    /// True when the prompt was edited after the last generation finished. This switches
    /// the primary action from "confirm" to "restart".
    edited_since_done: bool,
    /// Host editor gutter size. It is written by the block renderer and read while rendering
    /// to align the prompt row.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously submitted prompts. Up-arrow starts from the last entry (presumably the
    /// newest).
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` while browsing it; `None` when editing the pending prompt.
    prompt_history_ix: Option<usize>,
    /// Text typed before browsing history; restored when moving down past the last entry.
    pending_prompt: String,
    /// The generation this prompt drives.
    codegen: Model<Codegen>,
    /// Observes `codegen` to update read-only state and the rate-limit notice.
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`. They are rebuilt when the editor is replaced.
    editor_subscriptions: Vec<Subscription>,
    /// Whether the rate-limit popover is currently shown.
    show_rate_limit_notice: bool,
}

/// `PromptEditor` emits [`PromptEditorEvent`]s to its subscribers.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1471
impl Render for PromptEditor {
    /// Builds the prompt row, left to right:
    /// - a gutter-width column with the model selector and, on error, an error indicator;
    /// - the prompt text editor;
    /// - a row of buttons whose contents depend on the codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        // Copy the dimensions out so the mutex guard is released at the end of this statement.
        let gutter_dimensions = *self.gutter_dimensions.lock();
        // NOTE(review): "Add Context" has no click handler here.
        let mut buttons = vec![Button::new("add-context", "Add Context")
            .style(ButtonStyle::Filled)
            .icon(IconName::Plus)
            .icon_position(IconPosition::Start)
            .into_any_element()];
        let codegen = self.codegen.read(cx);
        // Alternative-cycling controls only make sense with more than one alternative.
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        let status = codegen.status(cx);
        buttons.extend(match status {
            // Not started: cancel, or start the transformation.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Generating: cancel the assist, or interrupt generation. The cancel tooltip
            // shows no keybinding here because the cancel action maps to "stop" while pending.
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
                        .into_any_element(),
                ]
            }
            // Finished or failed: cancel, plus either restart (after an error, or when the
            // prompt was edited since completion) or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                            .into_any_element()
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    },
                ]
            }
        });

        h_flex()
            .key_context("PromptEditor")
            .bg(cx.theme().colors().editor_background)
            .block_mouse_down()
            .cursor(CursorStyle::Arrow)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            // Alternative cycling is registered in the capture phase rather than via on_action.
            .capture_action(cx.listener(Self::cycle_prev))
            .capture_action(cx.listener(Self::cycle_next))
            .child(
                // Left column sized to the host editor's gutter so the prompt text lines up
                // with the host editor's text.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(LanguageModelSelectorPopoverMenu::new(
                        self.language_model_selector.clone(),
                        IconButton::new("context", IconName::SettingsAlt)
                            .shape(IconButtonShape::Square)
                            .icon_size(IconSize::Small)
                            .icon_color(Color::Muted)
                            .tooltip(move |cx| {
                                Tooltip::with_meta(
                                    format!(
                                        "Using {}",
                                        LanguageModelRegistry::read_global(cx)
                                            .active_model()
                                            .map(|model| model.name().0)
                                            .unwrap_or_else(|| "No model selected".into()),
                                    ),
                                    None,
                                    "Change Model",
                                    cx,
                                )
                            }),
                    ))
                    .map(|el| {
                        // Append an error indicator only when the codegen is in an error state.
                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        // Rate-limit errors for ZedPro-flagged users get a toggleable popover;
                        // any other error gets a tooltip carrying the error message.
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .toggle_state(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_editor(cx)))
            .child(h_flex().gap_2().pr_6().children(buttons))
    }
}
1659
impl FocusableView for PromptEditor {
    /// Focus is delegated to the prompt text editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1665
1666impl PromptEditor {
1667 const MAX_LINES: u8 = 8;
1668
1669 #[allow(clippy::too_many_arguments)]
1670 fn new(
1671 id: InlineAssistId,
1672 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1673 prompt_history: VecDeque<String>,
1674 prompt_buffer: Model<MultiBuffer>,
1675 codegen: Model<Codegen>,
1676 fs: Arc<dyn Fs>,
1677 cx: &mut ViewContext<Self>,
1678 ) -> Self {
1679 let prompt_editor = cx.new_view(|cx| {
1680 let mut editor = Editor::new(
1681 EditorMode::AutoHeight {
1682 max_lines: Self::MAX_LINES as usize,
1683 },
1684 prompt_buffer,
1685 None,
1686 false,
1687 cx,
1688 );
1689 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1690 // Since the prompt editors for all inline assistants are linked,
1691 // always show the cursor (even when it isn't focused) because
1692 // typing in one will make what you typed appear in all of them.
1693 editor.set_show_cursor_when_unfocused(true, cx);
1694 editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx)), cx);
1695 editor
1696 });
1697
1698 let mut this = Self {
1699 id,
1700 editor: prompt_editor,
1701 language_model_selector: cx.new_view(|cx| {
1702 let fs = fs.clone();
1703 LanguageModelSelector::new(
1704 move |model, cx| {
1705 update_settings_file::<AssistantSettings>(
1706 fs.clone(),
1707 cx,
1708 move |settings, _| settings.set_model(model.clone()),
1709 );
1710 },
1711 cx,
1712 )
1713 }),
1714 edited_since_done: false,
1715 gutter_dimensions,
1716 prompt_history,
1717 prompt_history_ix: None,
1718 pending_prompt: String::new(),
1719 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1720 editor_subscriptions: Vec::new(),
1721 codegen,
1722 show_rate_limit_notice: false,
1723 };
1724 this.subscribe_to_editor(cx);
1725 this
1726 }
1727
1728 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1729 self.editor_subscriptions.clear();
1730 self.editor_subscriptions
1731 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1732 }
1733
1734 fn set_show_cursor_when_unfocused(
1735 &mut self,
1736 show_cursor_when_unfocused: bool,
1737 cx: &mut ViewContext<Self>,
1738 ) {
1739 self.editor.update(cx, |editor, cx| {
1740 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1741 });
1742 }
1743
1744 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1745 let prompt = self.prompt(cx);
1746 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1747 self.editor = cx.new_view(|cx| {
1748 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1749 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1750 editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx)), cx);
1751 editor.set_placeholder_text("Add a prompt…", cx);
1752 editor.set_text(prompt, cx);
1753 if focus {
1754 editor.focus(cx);
1755 }
1756 editor
1757 });
1758 self.subscribe_to_editor(cx);
1759 }
1760
1761 fn placeholder_text(codegen: &Codegen) -> String {
1762 let action = if codegen.is_insertion {
1763 "Generate"
1764 } else {
1765 "Transform"
1766 };
1767
1768 format!("{action}… ↓↑ for history")
1769 }
1770
1771 fn prompt(&self, cx: &AppContext) -> String {
1772 self.editor.read(cx).text(cx)
1773 }
1774
1775 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1776 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1777 if self.show_rate_limit_notice {
1778 cx.focus_view(&self.editor);
1779 }
1780 cx.notify();
1781 }
1782
1783 fn handle_prompt_editor_events(
1784 &mut self,
1785 _: View<Editor>,
1786 event: &EditorEvent,
1787 cx: &mut ViewContext<Self>,
1788 ) {
1789 match event {
1790 EditorEvent::Edited { .. } => {
1791 if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1792 workspace
1793 .update(cx, |workspace, cx| {
1794 let is_via_ssh = workspace
1795 .project()
1796 .update(cx, |project, _| project.is_via_ssh());
1797
1798 workspace
1799 .client()
1800 .telemetry()
1801 .log_edit_event("inline assist", is_via_ssh);
1802 })
1803 .log_err();
1804 }
1805 let prompt = self.editor.read(cx).text(cx);
1806 if self
1807 .prompt_history_ix
1808 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1809 {
1810 self.prompt_history_ix.take();
1811 self.pending_prompt = prompt;
1812 }
1813
1814 self.edited_since_done = true;
1815 cx.notify();
1816 }
1817 EditorEvent::Blurred => {
1818 if self.show_rate_limit_notice {
1819 self.show_rate_limit_notice = false;
1820 cx.notify();
1821 }
1822 }
1823 _ => {}
1824 }
1825 }
1826
1827 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1828 match self.codegen.read(cx).status(cx) {
1829 CodegenStatus::Idle => {
1830 self.editor
1831 .update(cx, |editor, _| editor.set_read_only(false));
1832 }
1833 CodegenStatus::Pending => {
1834 self.editor
1835 .update(cx, |editor, _| editor.set_read_only(true));
1836 }
1837 CodegenStatus::Done => {
1838 self.edited_since_done = false;
1839 self.editor
1840 .update(cx, |editor, _| editor.set_read_only(false));
1841 }
1842 CodegenStatus::Error(error) => {
1843 if cx.has_flag::<ZedPro>()
1844 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1845 && !dismissed_rate_limit_notice()
1846 {
1847 self.show_rate_limit_notice = true;
1848 cx.notify();
1849 }
1850
1851 self.edited_since_done = false;
1852 self.editor
1853 .update(cx, |editor, _| editor.set_read_only(false));
1854 }
1855 }
1856 }
1857
1858 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1859 match self.codegen.read(cx).status(cx) {
1860 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1861 cx.emit(PromptEditorEvent::CancelRequested);
1862 }
1863 CodegenStatus::Pending => {
1864 cx.emit(PromptEditorEvent::StopRequested);
1865 }
1866 }
1867 }
1868
1869 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1870 match self.codegen.read(cx).status(cx) {
1871 CodegenStatus::Idle => {
1872 cx.emit(PromptEditorEvent::StartRequested);
1873 }
1874 CodegenStatus::Pending => {
1875 cx.emit(PromptEditorEvent::DismissRequested);
1876 }
1877 CodegenStatus::Done => {
1878 if self.edited_since_done {
1879 cx.emit(PromptEditorEvent::StartRequested);
1880 } else {
1881 cx.emit(PromptEditorEvent::ConfirmRequested);
1882 }
1883 }
1884 CodegenStatus::Error(_) => {
1885 cx.emit(PromptEditorEvent::StartRequested);
1886 }
1887 }
1888 }
1889
1890 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1891 if let Some(ix) = self.prompt_history_ix {
1892 if ix > 0 {
1893 self.prompt_history_ix = Some(ix - 1);
1894 let prompt = self.prompt_history[ix - 1].as_str();
1895 self.editor.update(cx, |editor, cx| {
1896 editor.set_text(prompt, cx);
1897 editor.move_to_beginning(&Default::default(), cx);
1898 });
1899 }
1900 } else if !self.prompt_history.is_empty() {
1901 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1902 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1903 self.editor.update(cx, |editor, cx| {
1904 editor.set_text(prompt, cx);
1905 editor.move_to_beginning(&Default::default(), cx);
1906 });
1907 }
1908 }
1909
1910 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1911 if let Some(ix) = self.prompt_history_ix {
1912 if ix < self.prompt_history.len() - 1 {
1913 self.prompt_history_ix = Some(ix + 1);
1914 let prompt = self.prompt_history[ix + 1].as_str();
1915 self.editor.update(cx, |editor, cx| {
1916 editor.set_text(prompt, cx);
1917 editor.move_to_end(&Default::default(), cx)
1918 });
1919 } else {
1920 self.prompt_history_ix = None;
1921 let prompt = self.pending_prompt.as_str();
1922 self.editor.update(cx, |editor, cx| {
1923 editor.set_text(prompt, cx);
1924 editor.move_to_end(&Default::default(), cx)
1925 });
1926 }
1927 }
1928 }
1929
1930 fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1931 self.codegen
1932 .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1933 }
1934
1935 fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1936 self.codegen
1937 .update(cx, |codegen, cx| codegen.cycle_next(cx));
1938 }
1939
1940 fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1941 let codegen = self.codegen.read(cx);
1942 let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1943
1944 let model_registry = LanguageModelRegistry::read_global(cx);
1945 let default_model = model_registry.active_model();
1946 let alternative_models = model_registry.inline_alternative_models();
1947
1948 let get_model_name = |index: usize| -> String {
1949 let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();
1950
1951 match index {
1952 0 => default_model.as_ref().map_or_else(String::new, name),
1953 index if index <= alternative_models.len() => alternative_models
1954 .get(index - 1)
1955 .map_or_else(String::new, name),
1956 _ => String::new(),
1957 }
1958 };
1959
1960 let total_models = alternative_models.len() + 1;
1961
1962 if total_models <= 1 {
1963 return div().into_any_element();
1964 }
1965
1966 let current_index = codegen.active_alternative;
1967 let prev_index = (current_index + total_models - 1) % total_models;
1968 let next_index = (current_index + 1) % total_models;
1969
1970 let prev_model_name = get_model_name(prev_index);
1971 let next_model_name = get_model_name(next_index);
1972
1973 h_flex()
1974 .child(
1975 IconButton::new("previous", IconName::ChevronLeft)
1976 .icon_color(Color::Muted)
1977 .disabled(disabled || current_index == 0)
1978 .shape(IconButtonShape::Square)
1979 .tooltip({
1980 let focus_handle = self.editor.focus_handle(cx);
1981 move |cx| {
1982 cx.new_view(|cx| {
1983 let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
1984 KeyBinding::for_action_in(
1985 &CyclePreviousInlineAssist,
1986 &focus_handle,
1987 cx,
1988 ),
1989 );
1990 if !disabled && current_index != 0 {
1991 tooltip = tooltip.meta(prev_model_name.clone());
1992 }
1993 tooltip
1994 })
1995 .into()
1996 }
1997 })
1998 .on_click(cx.listener(|this, _, cx| {
1999 this.codegen
2000 .update(cx, |codegen, cx| codegen.cycle_prev(cx))
2001 })),
2002 )
2003 .child(
2004 Label::new(format!(
2005 "{}/{}",
2006 codegen.active_alternative + 1,
2007 codegen.alternative_count(cx)
2008 ))
2009 .size(LabelSize::Small)
2010 .color(if disabled {
2011 Color::Disabled
2012 } else {
2013 Color::Muted
2014 }),
2015 )
2016 .child(
2017 IconButton::new("next", IconName::ChevronRight)
2018 .icon_color(Color::Muted)
2019 .disabled(disabled || current_index == total_models - 1)
2020 .shape(IconButtonShape::Square)
2021 .tooltip({
2022 let focus_handle = self.editor.focus_handle(cx);
2023 move |cx| {
2024 cx.new_view(|cx| {
2025 let mut tooltip = Tooltip::new("Next Alternative").key_binding(
2026 KeyBinding::for_action_in(
2027 &CycleNextInlineAssist,
2028 &focus_handle,
2029 cx,
2030 ),
2031 );
2032 if !disabled && current_index != total_models - 1 {
2033 tooltip = tooltip.meta(next_model_name.clone());
2034 }
2035 tooltip
2036 })
2037 .into()
2038 }
2039 })
2040 .on_click(cx.listener(|this, _, cx| {
2041 this.codegen
2042 .update(cx, |codegen, cx| codegen.cycle_next(cx))
2043 })),
2044 )
2045 .into_any_element()
2046 }
2047
2048 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2049 Popover::new().child(
2050 v_flex()
2051 .occlude()
2052 .p_2()
2053 .child(
2054 Label::new("Out of Tokens")
2055 .size(LabelSize::Small)
2056 .weight(FontWeight::BOLD),
2057 )
2058 .child(Label::new(
2059 "Try Zed Pro for higher limits, a wider range of models, and more.",
2060 ))
2061 .child(
2062 h_flex()
2063 .justify_between()
2064 .child(CheckboxWithLabel::new(
2065 "dont-show-again",
2066 Label::new("Don't show again"),
2067 if dismissed_rate_limit_notice() {
2068 ui::ToggleState::Selected
2069 } else {
2070 ui::ToggleState::Unselected
2071 },
2072 |selection, cx| {
2073 let is_dismissed = match selection {
2074 ui::ToggleState::Unselected => false,
2075 ui::ToggleState::Indeterminate => return,
2076 ui::ToggleState::Selected => true,
2077 };
2078
2079 set_rate_limit_notice_dismissed(is_dismissed, cx)
2080 },
2081 ))
2082 .child(
2083 h_flex()
2084 .gap_2()
2085 .child(
2086 Button::new("dismiss", "Dismiss")
2087 .style(ButtonStyle::Transparent)
2088 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2089 )
2090 .child(Button::new("more-info", "More Info").on_click(
2091 |_event, cx| {
2092 cx.dispatch_action(Box::new(
2093 zed_actions::OpenAccountSettings,
2094 ))
2095 },
2096 )),
2097 ),
2098 ),
2099 )
2100 }
2101
2102 fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
2103 let font_size = TextSize::Default.rems(cx);
2104 let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
2105
2106 v_flex()
2107 .key_context("MessageEditor")
2108 .size_full()
2109 .gap_2()
2110 .p_2()
2111 .bg(cx.theme().colors().editor_background)
2112 .child({
2113 let settings = ThemeSettings::get_global(cx);
2114 let text_style = TextStyle {
2115 color: cx.theme().colors().editor_foreground,
2116 font_family: settings.ui_font.family.clone(),
2117 font_features: settings.ui_font.features.clone(),
2118 font_size: font_size.into(),
2119 font_weight: settings.ui_font.weight,
2120 line_height: line_height.into(),
2121 ..Default::default()
2122 };
2123
2124 EditorElement::new(
2125 &self.editor,
2126 EditorStyle {
2127 background: cx.theme().colors().editor_background,
2128 local_player: cx.theme().players().local(),
2129 text: text_style,
2130 ..Default::default()
2131 },
2132 )
2133 })
2134 .into_any_element()
2135 }
2136}
2137
2138const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2139
2140fn dismissed_rate_limit_notice() -> bool {
2141 db::kvp::KEY_VALUE_STORE
2142 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2143 .log_err()
2144 .map_or(false, |s| s.is_some())
2145}
2146
2147fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2148 db::write_and_log(cx, move || async move {
2149 if is_dismissed {
2150 db::kvp::KEY_VALUE_STORE
2151 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2152 .await
2153 } else {
2154 db::kvp::KEY_VALUE_STORE
2155 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2156 .await
2157 }
2158 })
2159}
2160
/// State for a single inline assist attached to an editor.
pub struct InlineAssist {
    /// Group this assist belongs to. While the group is linked, starting any member starts
    /// all of them.
    group_id: InlineAssistGroupId,
    /// Anchor range the assist covers in the editor's buffer.
    range: Range<Anchor>,
    /// The editor hosting the assist (held weakly).
    editor: WeakView<Editor>,
    /// Prompt block, end block, and "deleted lines" blocks rendered in the editor. When
    /// `None`, no diff rows are highlighted and no deleted-lines blocks are built.
    decorations: Option<InlineAssistDecorations>,
    /// Drives the generation of replacement text.
    codegen: Model<Codegen>,
    /// Prompt-editor focus/event subscriptions and codegen observers, kept alive with the assist.
    _subscriptions: Vec<Subscription>,
    /// Workspace the assist belongs to, if any (held weakly).
    workspace: Option<WeakView<Workspace>>,
}
2170
2171impl InlineAssist {
2172 #[allow(clippy::too_many_arguments)]
2173 fn new(
2174 assist_id: InlineAssistId,
2175 group_id: InlineAssistGroupId,
2176 editor: &View<Editor>,
2177 prompt_editor: &View<PromptEditor>,
2178 prompt_block_id: CustomBlockId,
2179 end_block_id: CustomBlockId,
2180 range: Range<Anchor>,
2181 codegen: Model<Codegen>,
2182 workspace: Option<WeakView<Workspace>>,
2183 cx: &mut WindowContext,
2184 ) -> Self {
2185 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2186 InlineAssist {
2187 group_id,
2188 editor: editor.downgrade(),
2189 decorations: Some(InlineAssistDecorations {
2190 prompt_block_id,
2191 prompt_editor: prompt_editor.clone(),
2192 removed_line_block_ids: HashSet::default(),
2193 end_block_id,
2194 }),
2195 range,
2196 codegen: codegen.clone(),
2197 workspace: workspace.clone(),
2198 _subscriptions: vec![
2199 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2200 InlineAssistant::update_global(cx, |this, cx| {
2201 this.handle_prompt_editor_focus_in(assist_id, cx)
2202 })
2203 }),
2204 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2205 InlineAssistant::update_global(cx, |this, cx| {
2206 this.handle_prompt_editor_focus_out(assist_id, cx)
2207 })
2208 }),
2209 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2210 InlineAssistant::update_global(cx, |this, cx| {
2211 this.handle_prompt_editor_event(prompt_editor, event, cx)
2212 })
2213 }),
2214 cx.observe(&codegen, {
2215 let editor = editor.downgrade();
2216 move |_, cx| {
2217 if let Some(editor) = editor.upgrade() {
2218 InlineAssistant::update_global(cx, |this, cx| {
2219 if let Some(editor_assists) =
2220 this.assists_by_editor.get(&editor.downgrade())
2221 {
2222 editor_assists.highlight_updates.send(()).ok();
2223 }
2224
2225 this.update_editor_blocks(&editor, assist_id, cx);
2226 })
2227 }
2228 }
2229 }),
2230 cx.subscribe(&codegen, move |codegen, event, cx| {
2231 InlineAssistant::update_global(cx, |this, cx| match event {
2232 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2233 CodegenEvent::Finished => {
2234 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2235 assist
2236 } else {
2237 return;
2238 };
2239
2240 if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2241 if assist.decorations.is_none() {
2242 if let Some(workspace) = assist
2243 .workspace
2244 .as_ref()
2245 .and_then(|workspace| workspace.upgrade())
2246 {
2247 let error = format!("Inline assistant error: {}", error);
2248 workspace.update(cx, |workspace, cx| {
2249 struct InlineAssistantError;
2250
2251 let id =
2252 NotificationId::composite::<InlineAssistantError>(
2253 assist_id.0,
2254 );
2255
2256 workspace.show_toast(Toast::new(id, error), cx);
2257 })
2258 }
2259 }
2260 }
2261
2262 if assist.decorations.is_none() {
2263 this.finish_assist(assist_id, false, cx);
2264 }
2265 }
2266 })
2267 }),
2268 ],
2269 }
2270 }
2271
2272 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2273 let decorations = self.decorations.as_ref()?;
2274 Some(decorations.prompt_editor.read(cx).prompt(cx))
2275 }
2276}
2277
/// Prompt editor and editor blocks displayed for an inline assist.
struct InlineAssistDecorations {
    /// Custom block that presumably hosts the prompt editor — confirm against block creation.
    prompt_block_id: CustomBlockId,
    /// Editor where the assist prompt is typed.
    prompt_editor: View<PromptEditor>,
    /// Custom blocks created for removed lines.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Custom block marking the end of the assist (placement set elsewhere).
    end_block_id: CustomBlockId,
}
2284
/// Events emitted by `Codegen` and `CodegenAlternative`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation ended: it completed, failed, or was stopped.
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
2290
/// Drives an inline transformation of `range` in `buffer`, possibly generated by several
/// models at once. Each model gets its own `CodegenAlternative`; only one is active at a time.
pub struct Codegen {
    /// Alternatives; index 0 is driven by the registry's active model.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the active alternative.
    active_alternative: usize,
    /// Indices of alternatives that have been activated at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions forwarding the active alternative's notifications and events.
    subscriptions: Vec<Subscription>,
    /// Multibuffer being transformed.
    buffer: Model<MultiBuffer>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Transaction that `undo` reverts in addition to the active alternative's edits.
    initial_transaction_id: Option<TransactionId>,
    /// Telemetry handle passed to each alternative.
    telemetry: Arc<Telemetry>,
    /// Prompt builder passed to each alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this codegen was created.
    is_insertion: bool,
}
2303
2304impl Codegen {
2305 pub fn new(
2306 buffer: Model<MultiBuffer>,
2307 range: Range<Anchor>,
2308 initial_transaction_id: Option<TransactionId>,
2309 telemetry: Arc<Telemetry>,
2310 builder: Arc<PromptBuilder>,
2311 cx: &mut ModelContext<Self>,
2312 ) -> Self {
2313 let codegen = cx.new_model(|cx| {
2314 CodegenAlternative::new(
2315 buffer.clone(),
2316 range.clone(),
2317 false,
2318 Some(telemetry.clone()),
2319 builder.clone(),
2320 cx,
2321 )
2322 });
2323 let mut this = Self {
2324 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2325 alternatives: vec![codegen],
2326 active_alternative: 0,
2327 seen_alternatives: HashSet::default(),
2328 subscriptions: Vec::new(),
2329 buffer,
2330 range,
2331 initial_transaction_id,
2332 telemetry,
2333 builder,
2334 };
2335 this.activate(0, cx);
2336 this
2337 }
2338
2339 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2340 let codegen = self.active_alternative().clone();
2341 self.subscriptions.clear();
2342 self.subscriptions
2343 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2344 self.subscriptions
2345 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2346 }
2347
2348 fn active_alternative(&self) -> &Model<CodegenAlternative> {
2349 &self.alternatives[self.active_alternative]
2350 }
2351
2352 fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2353 &self.active_alternative().read(cx).status
2354 }
2355
2356 fn alternative_count(&self, cx: &AppContext) -> usize {
2357 LanguageModelRegistry::read_global(cx)
2358 .inline_alternative_models()
2359 .len()
2360 + 1
2361 }
2362
2363 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2364 let next_active_ix = if self.active_alternative == 0 {
2365 self.alternatives.len() - 1
2366 } else {
2367 self.active_alternative - 1
2368 };
2369 self.activate(next_active_ix, cx);
2370 }
2371
2372 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2373 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2374 self.activate(next_active_ix, cx);
2375 }
2376
2377 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2378 self.active_alternative()
2379 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2380 self.seen_alternatives.insert(index);
2381 self.active_alternative = index;
2382 self.active_alternative()
2383 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2384 self.subscribe_to_alternative(cx);
2385 cx.notify();
2386 }
2387
2388 pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
2389 let alternative_models = LanguageModelRegistry::read_global(cx)
2390 .inline_alternative_models()
2391 .to_vec();
2392
2393 self.active_alternative()
2394 .update(cx, |alternative, cx| alternative.undo(cx));
2395 self.activate(0, cx);
2396 self.alternatives.truncate(1);
2397
2398 for _ in 0..alternative_models.len() {
2399 self.alternatives.push(cx.new_model(|cx| {
2400 CodegenAlternative::new(
2401 self.buffer.clone(),
2402 self.range.clone(),
2403 false,
2404 Some(self.telemetry.clone()),
2405 self.builder.clone(),
2406 cx,
2407 )
2408 }));
2409 }
2410
2411 let primary_model = LanguageModelRegistry::read_global(cx)
2412 .active_model()
2413 .context("no active model")?;
2414
2415 for (model, alternative) in iter::once(primary_model)
2416 .chain(alternative_models)
2417 .zip(&self.alternatives)
2418 {
2419 alternative.update(cx, |alternative, cx| {
2420 alternative.start(user_prompt.clone(), model.clone(), cx)
2421 })?;
2422 }
2423
2424 Ok(())
2425 }
2426
2427 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2428 for codegen in &self.alternatives {
2429 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2430 }
2431 }
2432
2433 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2434 self.active_alternative()
2435 .update(cx, |codegen, cx| codegen.undo(cx));
2436
2437 self.buffer.update(cx, |buffer, cx| {
2438 if let Some(transaction_id) = self.initial_transaction_id.take() {
2439 buffer.undo_transaction(transaction_id, cx);
2440 buffer.refresh_preview(cx);
2441 }
2442 });
2443 }
2444
2445 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2446 self.active_alternative().read(cx).buffer.clone()
2447 }
2448
2449 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2450 self.active_alternative().read(cx).old_buffer.clone()
2451 }
2452
2453 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2454 self.active_alternative().read(cx).snapshot.clone()
2455 }
2456
2457 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2458 self.active_alternative().read(cx).edit_position
2459 }
2460
2461 fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2462 &self.active_alternative().read(cx).diff
2463 }
2464
2465 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2466 self.active_alternative().read(cx).last_equal_ranges()
2467 }
2468}
2469
// `Codegen` re-emits the `CodegenEvent`s of its active alternative.
impl EventEmitter<CodegenEvent> for Codegen {}
2471
/// One model's transformation of a range: streams the completion, applies its edits to the
/// buffer while active, and tracks a row diff against the original snapshot.
pub struct CodegenAlternative {
    /// Multibuffer being edited.
    buffer: Model<MultiBuffer>,
    /// Local copy of the underlying buffer's original text, language and line ending.
    old_buffer: Model<Buffer>,
    /// Multibuffer snapshot taken at creation; the "old" side of every diff.
    snapshot: MultiBufferSnapshot,
    /// Where the latest streamed edit ended (the range start when generation begins).
    edit_position: Option<Anchor>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the latest diff chunk.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all of this alternative's buffer edits are merged into.
    transformation_transaction_id: Option<TransactionId>,
    /// Current lifecycle status.
    status: CodegenStatus,
    /// In-flight generation task; replaced with a ready task to drop it.
    generation: Task<()>,
    /// Row diff currently reported for display.
    diff: Diff,
    /// Optional telemetry handle used for reporting.
    telemetry: Option<Arc<Telemetry>>,
    /// Buffer event subscription used to detect undo of the transformation.
    _subscription: gpui::Subscription,
    /// Builds the model prompt.
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Every edit generated so far, replayed when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest line-based diff operations from the stream.
    line_operations: Vec<LineOperation>,
    /// Last request sent to the model.
    request: Option<LanguageModelRequest>,
    /// Seconds from `handle_stream` to completion.
    elapsed_time: Option<f64>,
    /// Full completion text received.
    completion: Option<String>,
    /// Message id reported by the completion stream, if any.
    message_id: Option<String>,
}
2494
/// Lifecycle status of a `CodegenAlternative`.
enum CodegenStatus {
    /// Not generating; the initial state, or stopped before any diff was produced.
    Idle,
    /// A generation is streaming.
    Pending,
    /// Generation completed, or was stopped after producing a diff.
    Done,
    /// Generation failed with the contained error.
    Error(anyhow::Error),
}
2501
/// Row-level diff of a transformed range between the original snapshot and the current buffer.
#[derive(Default)]
struct Diff {
    /// Deleted rows: an anchor in the current buffer where the deletion sits, paired with
    /// the inclusive row range of the original snapshot that was removed.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Ranges of whole inserted rows in the current buffer.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2507
2508impl Diff {
2509 fn is_empty(&self) -> bool {
2510 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2511 }
2512}
2513
// Emits `Finished` when generation ends and `Undone` when its transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2515
2516impl CodegenAlternative {
2517 pub fn new(
2518 buffer: Model<MultiBuffer>,
2519 range: Range<Anchor>,
2520 active: bool,
2521 telemetry: Option<Arc<Telemetry>>,
2522 builder: Arc<PromptBuilder>,
2523 cx: &mut ModelContext<Self>,
2524 ) -> Self {
2525 let snapshot = buffer.read(cx).snapshot(cx);
2526
2527 let (old_buffer, _, _) = buffer
2528 .read(cx)
2529 .range_to_buffer_ranges(range.clone(), cx)
2530 .pop()
2531 .unwrap();
2532 let old_buffer = cx.new_model(|cx| {
2533 let old_buffer = old_buffer.read(cx);
2534 let text = old_buffer.as_rope().clone();
2535 let line_ending = old_buffer.line_ending();
2536 let language = old_buffer.language().cloned();
2537 let language_registry = old_buffer.language_registry();
2538
2539 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2540 buffer.set_language(language, cx);
2541 if let Some(language_registry) = language_registry {
2542 buffer.set_language_registry(language_registry)
2543 }
2544 buffer
2545 });
2546
2547 Self {
2548 buffer: buffer.clone(),
2549 old_buffer,
2550 edit_position: None,
2551 message_id: None,
2552 snapshot,
2553 last_equal_ranges: Default::default(),
2554 transformation_transaction_id: None,
2555 status: CodegenStatus::Idle,
2556 generation: Task::ready(()),
2557 diff: Diff::default(),
2558 telemetry,
2559 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2560 builder,
2561 active,
2562 edits: Vec::new(),
2563 line_operations: Vec::new(),
2564 range,
2565 request: None,
2566 elapsed_time: None,
2567 completion: None,
2568 }
2569 }
2570
2571 fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2572 if active != self.active {
2573 self.active = active;
2574
2575 if self.active {
2576 let edits = self.edits.clone();
2577 self.apply_edits(edits, cx);
2578 if matches!(self.status, CodegenStatus::Pending) {
2579 let line_operations = self.line_operations.clone();
2580 self.reapply_line_based_diff(line_operations, cx);
2581 } else {
2582 self.reapply_batch_diff(cx).detach();
2583 }
2584 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2585 self.buffer.update(cx, |buffer, cx| {
2586 buffer.undo_transaction(transaction_id, cx);
2587 buffer.forget_transaction(transaction_id, cx);
2588 });
2589 }
2590 }
2591 }
2592
2593 fn handle_buffer_event(
2594 &mut self,
2595 _buffer: Model<MultiBuffer>,
2596 event: &multi_buffer::Event,
2597 cx: &mut ModelContext<Self>,
2598 ) {
2599 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2600 if self.transformation_transaction_id == Some(*transaction_id) {
2601 self.transformation_transaction_id = None;
2602 self.generation = Task::ready(());
2603 cx.emit(CodegenEvent::Undone);
2604 }
2605 }
2606 }
2607
2608 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2609 &self.last_equal_ranges
2610 }
2611
2612 pub fn start(
2613 &mut self,
2614 user_prompt: String,
2615 model: Arc<dyn LanguageModel>,
2616 cx: &mut ModelContext<Self>,
2617 ) -> Result<()> {
2618 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2619 self.buffer.update(cx, |buffer, cx| {
2620 buffer.undo_transaction(transformation_transaction_id, cx);
2621 });
2622 }
2623
2624 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2625
2626 let api_key = model.api_key(cx);
2627 let telemetry_id = model.telemetry_id();
2628 let provider_id = model.provider_id();
2629 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2630 if user_prompt.trim().to_lowercase() == "delete" {
2631 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2632 } else {
2633 let request = self.build_request(user_prompt, cx)?;
2634 self.request = Some(request.clone());
2635
2636 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2637 .boxed_local()
2638 };
2639 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2640 Ok(())
2641 }
2642
2643 fn build_request(&self, user_prompt: String, cx: &AppContext) -> Result<LanguageModelRequest> {
2644 let buffer = self.buffer.read(cx).snapshot(cx);
2645 let language = buffer.language_at(self.range.start);
2646 let language_name = if let Some(language) = language.as_ref() {
2647 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2648 None
2649 } else {
2650 Some(language.name())
2651 }
2652 } else {
2653 None
2654 };
2655
2656 let language_name = language_name.as_ref();
2657 let start = buffer.point_to_buffer_offset(self.range.start);
2658 let end = buffer.point_to_buffer_offset(self.range.end);
2659 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2660 let (start_buffer, start_buffer_offset) = start;
2661 let (end_buffer, end_buffer_offset) = end;
2662 if start_buffer.remote_id() == end_buffer.remote_id() {
2663 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2664 } else {
2665 return Err(anyhow::anyhow!("invalid transformation range"));
2666 }
2667 } else {
2668 return Err(anyhow::anyhow!("invalid transformation range"));
2669 };
2670
2671 let prompt = self
2672 .builder
2673 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2674 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2675
2676 Ok(LanguageModelRequest {
2677 tools: Vec::new(),
2678 stop: Vec::new(),
2679 temperature: None,
2680 messages: vec![LanguageModelRequestMessage {
2681 role: Role::User,
2682 content: vec![prompt.into()],
2683 cache: false,
2684 }],
2685 })
2686 }
2687
2688 pub fn handle_stream(
2689 &mut self,
2690 model_telemetry_id: String,
2691 model_provider_id: String,
2692 model_api_key: Option<String>,
2693 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
2694 cx: &mut ModelContext<Self>,
2695 ) {
2696 let start_time = Instant::now();
2697 let snapshot = self.snapshot.clone();
2698 let selected_text = snapshot
2699 .text_for_range(self.range.start..self.range.end)
2700 .collect::<Rope>();
2701
2702 let selection_start = self.range.start.to_point(&snapshot);
2703
2704 // Start with the indentation of the first line in the selection
2705 let mut suggested_line_indent = snapshot
2706 .suggested_indents(selection_start.row..=selection_start.row, cx)
2707 .into_values()
2708 .next()
2709 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
2710
2711 // If the first line in the selection does not have indentation, check the following lines
2712 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
2713 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
2714 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
2715 // Prefer tabs if a line in the selection uses tabs as indentation
2716 if line_indent.kind == IndentKind::Tab {
2717 suggested_line_indent.kind = IndentKind::Tab;
2718 break;
2719 }
2720 }
2721 }
2722
2723 let http_client = cx.http_client().clone();
2724 let telemetry = self.telemetry.clone();
2725 let language_name = {
2726 let multibuffer = self.buffer.read(cx);
2727 let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
2728 ranges
2729 .first()
2730 .and_then(|(buffer, _, _)| buffer.read(cx).language())
2731 .map(|language| language.name())
2732 };
2733
2734 self.diff = Diff::default();
2735 self.status = CodegenStatus::Pending;
2736 let mut edit_start = self.range.start.to_offset(&snapshot);
2737 let completion = Arc::new(Mutex::new(String::new()));
2738 let completion_clone = completion.clone();
2739
2740 self.generation = cx.spawn(|codegen, mut cx| {
2741 async move {
2742 let stream = stream.await;
2743 let message_id = stream
2744 .as_ref()
2745 .ok()
2746 .and_then(|stream| stream.message_id.clone());
2747 let generate = async {
2748 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
2749 let executor = cx.background_executor().clone();
2750 let message_id = message_id.clone();
2751 let line_based_stream_diff: Task<anyhow::Result<()>> =
2752 cx.background_executor().spawn(async move {
2753 let mut response_latency = None;
2754 let request_start = Instant::now();
2755 let diff = async {
2756 let chunks = StripInvalidSpans::new(stream?.stream);
2757 futures::pin_mut!(chunks);
2758 let mut diff = StreamingDiff::new(selected_text.to_string());
2759 let mut line_diff = LineDiff::default();
2760
2761 let mut new_text = String::new();
2762 let mut base_indent = None;
2763 let mut line_indent = None;
2764 let mut first_line = true;
2765
2766 while let Some(chunk) = chunks.next().await {
2767 if response_latency.is_none() {
2768 response_latency = Some(request_start.elapsed());
2769 }
2770 let chunk = chunk?;
2771 completion_clone.lock().push_str(&chunk);
2772
2773 let mut lines = chunk.split('\n').peekable();
2774 while let Some(line) = lines.next() {
2775 new_text.push_str(line);
2776 if line_indent.is_none() {
2777 if let Some(non_whitespace_ch_ix) =
2778 new_text.find(|ch: char| !ch.is_whitespace())
2779 {
2780 line_indent = Some(non_whitespace_ch_ix);
2781 base_indent = base_indent.or(line_indent);
2782
2783 let line_indent = line_indent.unwrap();
2784 let base_indent = base_indent.unwrap();
2785 let indent_delta =
2786 line_indent as i32 - base_indent as i32;
2787 let mut corrected_indent_len = cmp::max(
2788 0,
2789 suggested_line_indent.len as i32 + indent_delta,
2790 )
2791 as usize;
2792 if first_line {
2793 corrected_indent_len = corrected_indent_len
2794 .saturating_sub(
2795 selection_start.column as usize,
2796 );
2797 }
2798
2799 let indent_char = suggested_line_indent.char();
2800 let mut indent_buffer = [0; 4];
2801 let indent_str =
2802 indent_char.encode_utf8(&mut indent_buffer);
2803 new_text.replace_range(
2804 ..line_indent,
2805 &indent_str.repeat(corrected_indent_len),
2806 );
2807 }
2808 }
2809
2810 if line_indent.is_some() {
2811 let char_ops = diff.push_new(&new_text);
2812 line_diff
2813 .push_char_operations(&char_ops, &selected_text);
2814 diff_tx
2815 .send((char_ops, line_diff.line_operations()))
2816 .await?;
2817 new_text.clear();
2818 }
2819
2820 if lines.peek().is_some() {
2821 let char_ops = diff.push_new("\n");
2822 line_diff
2823 .push_char_operations(&char_ops, &selected_text);
2824 diff_tx
2825 .send((char_ops, line_diff.line_operations()))
2826 .await?;
2827 if line_indent.is_none() {
2828 // Don't write out the leading indentation in empty lines on the next line
2829 // This is the case where the above if statement didn't clear the buffer
2830 new_text.clear();
2831 }
2832 line_indent = None;
2833 first_line = false;
2834 }
2835 }
2836 }
2837
2838 let mut char_ops = diff.push_new(&new_text);
2839 char_ops.extend(diff.finish());
2840 line_diff.push_char_operations(&char_ops, &selected_text);
2841 line_diff.finish(&selected_text);
2842 diff_tx
2843 .send((char_ops, line_diff.line_operations()))
2844 .await?;
2845
2846 anyhow::Ok(())
2847 };
2848
2849 let result = diff.await;
2850
2851 let error_message =
2852 result.as_ref().err().map(|error| error.to_string());
2853 report_assistant_event(
2854 AssistantEvent {
2855 conversation_id: None,
2856 message_id,
2857 kind: AssistantKind::Inline,
2858 phase: AssistantPhase::Response,
2859 model: model_telemetry_id,
2860 model_provider: model_provider_id.to_string(),
2861 response_latency,
2862 error_message,
2863 language_name: language_name.map(|name| name.to_proto()),
2864 },
2865 telemetry,
2866 http_client,
2867 model_api_key,
2868 &executor,
2869 );
2870
2871 result?;
2872 Ok(())
2873 });
2874
2875 while let Some((char_ops, line_ops)) = diff_rx.next().await {
2876 codegen.update(&mut cx, |codegen, cx| {
2877 codegen.last_equal_ranges.clear();
2878
2879 let edits = char_ops
2880 .into_iter()
2881 .filter_map(|operation| match operation {
2882 CharOperation::Insert { text } => {
2883 let edit_start = snapshot.anchor_after(edit_start);
2884 Some((edit_start..edit_start, text))
2885 }
2886 CharOperation::Delete { bytes } => {
2887 let edit_end = edit_start + bytes;
2888 let edit_range = snapshot.anchor_after(edit_start)
2889 ..snapshot.anchor_before(edit_end);
2890 edit_start = edit_end;
2891 Some((edit_range, String::new()))
2892 }
2893 CharOperation::Keep { bytes } => {
2894 let edit_end = edit_start + bytes;
2895 let edit_range = snapshot.anchor_after(edit_start)
2896 ..snapshot.anchor_before(edit_end);
2897 edit_start = edit_end;
2898 codegen.last_equal_ranges.push(edit_range);
2899 None
2900 }
2901 })
2902 .collect::<Vec<_>>();
2903
2904 if codegen.active {
2905 codegen.apply_edits(edits.iter().cloned(), cx);
2906 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
2907 }
2908 codegen.edits.extend(edits);
2909 codegen.line_operations = line_ops;
2910 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
2911
2912 cx.notify();
2913 })?;
2914 }
2915
2916 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
2917 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
2918 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
2919 let batch_diff_task =
2920 codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
2921 let (line_based_stream_diff, ()) =
2922 join!(line_based_stream_diff, batch_diff_task);
2923 line_based_stream_diff?;
2924
2925 anyhow::Ok(())
2926 };
2927
2928 let result = generate.await;
2929 let elapsed_time = start_time.elapsed().as_secs_f64();
2930
2931 codegen
2932 .update(&mut cx, |this, cx| {
2933 this.message_id = message_id;
2934 this.last_equal_ranges.clear();
2935 if let Err(error) = result {
2936 this.status = CodegenStatus::Error(error);
2937 } else {
2938 this.status = CodegenStatus::Done;
2939 }
2940 this.elapsed_time = Some(elapsed_time);
2941 this.completion = Some(completion.lock().clone());
2942 cx.emit(CodegenEvent::Finished);
2943 cx.notify();
2944 })
2945 .ok();
2946 }
2947 });
2948 cx.notify();
2949 }
2950
2951 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2952 self.last_equal_ranges.clear();
2953 if self.diff.is_empty() {
2954 self.status = CodegenStatus::Idle;
2955 } else {
2956 self.status = CodegenStatus::Done;
2957 }
2958 self.generation = Task::ready(());
2959 cx.emit(CodegenEvent::Finished);
2960 cx.notify();
2961 }
2962
2963 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2964 self.buffer.update(cx, |buffer, cx| {
2965 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2966 buffer.undo_transaction(transaction_id, cx);
2967 buffer.refresh_preview(cx);
2968 }
2969 });
2970 }
2971
2972 fn apply_edits(
2973 &mut self,
2974 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
2975 cx: &mut ModelContext<CodegenAlternative>,
2976 ) {
2977 let transaction = self.buffer.update(cx, |buffer, cx| {
2978 // Avoid grouping assistant edits with user edits.
2979 buffer.finalize_last_transaction(cx);
2980 buffer.start_transaction(cx);
2981 buffer.edit(edits, None, cx);
2982 buffer.end_transaction(cx)
2983 });
2984
2985 if let Some(transaction) = transaction {
2986 if let Some(first_transaction) = self.transformation_transaction_id {
2987 // Group all assistant edits into the first transaction.
2988 self.buffer.update(cx, |buffer, cx| {
2989 buffer.merge_transactions(transaction, first_transaction, cx)
2990 });
2991 } else {
2992 self.transformation_transaction_id = Some(transaction);
2993 self.buffer
2994 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
2995 }
2996 }
2997 }
2998
2999 fn reapply_line_based_diff(
3000 &mut self,
3001 line_operations: impl IntoIterator<Item = LineOperation>,
3002 cx: &mut ModelContext<Self>,
3003 ) {
3004 let old_snapshot = self.snapshot.clone();
3005 let old_range = self.range.to_point(&old_snapshot);
3006 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3007 let new_range = self.range.to_point(&new_snapshot);
3008
3009 let mut old_row = old_range.start.row;
3010 let mut new_row = new_range.start.row;
3011
3012 self.diff.deleted_row_ranges.clear();
3013 self.diff.inserted_row_ranges.clear();
3014 for operation in line_operations {
3015 match operation {
3016 LineOperation::Keep { lines } => {
3017 old_row += lines;
3018 new_row += lines;
3019 }
3020 LineOperation::Delete { lines } => {
3021 let old_end_row = old_row + lines - 1;
3022 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3023
3024 if let Some((_, last_deleted_row_range)) =
3025 self.diff.deleted_row_ranges.last_mut()
3026 {
3027 if *last_deleted_row_range.end() + 1 == old_row {
3028 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3029 } else {
3030 self.diff
3031 .deleted_row_ranges
3032 .push((new_row, old_row..=old_end_row));
3033 }
3034 } else {
3035 self.diff
3036 .deleted_row_ranges
3037 .push((new_row, old_row..=old_end_row));
3038 }
3039
3040 old_row += lines;
3041 }
3042 LineOperation::Insert { lines } => {
3043 let new_end_row = new_row + lines - 1;
3044 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3045 let end = new_snapshot.anchor_before(Point::new(
3046 new_end_row,
3047 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3048 ));
3049 self.diff.inserted_row_ranges.push(start..end);
3050 new_row += lines;
3051 }
3052 }
3053
3054 cx.notify();
3055 }
3056 }
3057
3058 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3059 let old_snapshot = self.snapshot.clone();
3060 let old_range = self.range.to_point(&old_snapshot);
3061 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3062 let new_range = self.range.to_point(&new_snapshot);
3063
3064 cx.spawn(|codegen, mut cx| async move {
3065 let (deleted_row_ranges, inserted_row_ranges) = cx
3066 .background_executor()
3067 .spawn(async move {
3068 let old_text = old_snapshot
3069 .text_for_range(
3070 Point::new(old_range.start.row, 0)
3071 ..Point::new(
3072 old_range.end.row,
3073 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3074 ),
3075 )
3076 .collect::<String>();
3077 let new_text = new_snapshot
3078 .text_for_range(
3079 Point::new(new_range.start.row, 0)
3080 ..Point::new(
3081 new_range.end.row,
3082 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3083 ),
3084 )
3085 .collect::<String>();
3086
3087 let mut old_row = old_range.start.row;
3088 let mut new_row = new_range.start.row;
3089 let batch_diff =
3090 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3091
3092 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3093 let mut inserted_row_ranges = Vec::new();
3094 for change in batch_diff.iter_all_changes() {
3095 let line_count = change.value().lines().count() as u32;
3096 match change.tag() {
3097 similar::ChangeTag::Equal => {
3098 old_row += line_count;
3099 new_row += line_count;
3100 }
3101 similar::ChangeTag::Delete => {
3102 let old_end_row = old_row + line_count - 1;
3103 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3104
3105 if let Some((_, last_deleted_row_range)) =
3106 deleted_row_ranges.last_mut()
3107 {
3108 if *last_deleted_row_range.end() + 1 == old_row {
3109 *last_deleted_row_range =
3110 *last_deleted_row_range.start()..=old_end_row;
3111 } else {
3112 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3113 }
3114 } else {
3115 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3116 }
3117
3118 old_row += line_count;
3119 }
3120 similar::ChangeTag::Insert => {
3121 let new_end_row = new_row + line_count - 1;
3122 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3123 let end = new_snapshot.anchor_before(Point::new(
3124 new_end_row,
3125 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3126 ));
3127 inserted_row_ranges.push(start..end);
3128 new_row += line_count;
3129 }
3130 }
3131 }
3132
3133 (deleted_row_ranges, inserted_row_ranges)
3134 })
3135 .await;
3136
3137 codegen
3138 .update(&mut cx, |codegen, cx| {
3139 codegen.diff.deleted_row_ranges = deleted_row_ranges;
3140 codegen.diff.inserted_row_ranges = inserted_row_ranges;
3141 cx.notify();
3142 })
3143 .ok();
3144 })
3145 }
3146}
3147
/// Stream adapter that removes a wrapping code fence and `<|CURSOR|>` spans from a
/// model's text stream.
struct StripInvalidSpans<T> {
    /// Inner text stream.
    stream: T,
    /// Whether `stream` has returned `None`.
    stream_done: bool,
    /// Text received but not emitted yet.
    buffer: String,
    /// Whether no newline has been seen in the buffered text yet.
    first_line: bool,
    /// Whether a newline must be emitted before the next emitted text.
    line_end: bool,
    /// Whether the response opened with a code block delimiter line.
    starts_with_code_block: bool,
}
3156
3157impl<T> StripInvalidSpans<T>
3158where
3159 T: Stream<Item = Result<String>>,
3160{
3161 fn new(stream: T) -> Self {
3162 Self {
3163 stream,
3164 stream_done: false,
3165 buffer: String::new(),
3166 first_line: true,
3167 line_end: false,
3168 starts_with_code_block: false,
3169 }
3170 }
3171}
3172
3173impl<T> Stream for StripInvalidSpans<T>
3174where
3175 T: Stream<Item = Result<String>>,
3176{
3177 type Item = Result<String>;
3178
3179 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3180 const CODE_BLOCK_DELIMITER: &str = "```";
3181 const CURSOR_SPAN: &str = "<|CURSOR|>";
3182
3183 let this = unsafe { self.get_unchecked_mut() };
3184 loop {
3185 if !this.stream_done {
3186 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3187 match stream.as_mut().poll_next(cx) {
3188 Poll::Ready(Some(Ok(chunk))) => {
3189 this.buffer.push_str(&chunk);
3190 }
3191 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3192 Poll::Ready(None) => {
3193 this.stream_done = true;
3194 }
3195 Poll::Pending => return Poll::Pending,
3196 }
3197 }
3198
3199 let mut chunk = String::new();
3200 let mut consumed = 0;
3201 if !this.buffer.is_empty() {
3202 let mut lines = this.buffer.split('\n').enumerate().peekable();
3203 while let Some((line_ix, line)) = lines.next() {
3204 if line_ix > 0 {
3205 this.first_line = false;
3206 }
3207
3208 if this.first_line {
3209 let trimmed_line = line.trim();
3210 if lines.peek().is_some() {
3211 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3212 consumed += line.len() + 1;
3213 this.starts_with_code_block = true;
3214 continue;
3215 }
3216 } else if trimmed_line.is_empty()
3217 || prefixes(CODE_BLOCK_DELIMITER)
3218 .any(|prefix| trimmed_line.starts_with(prefix))
3219 {
3220 break;
3221 }
3222 }
3223
3224 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3225 if lines.peek().is_some() {
3226 if this.line_end {
3227 chunk.push('\n');
3228 }
3229
3230 chunk.push_str(&line_without_cursor);
3231 this.line_end = true;
3232 consumed += line.len() + 1;
3233 } else if this.stream_done {
3234 if !this.starts_with_code_block
3235 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3236 {
3237 if this.line_end {
3238 chunk.push('\n');
3239 }
3240
3241 chunk.push_str(&line);
3242 }
3243
3244 consumed += line.len();
3245 } else {
3246 let trimmed_line = line.trim();
3247 if trimmed_line.is_empty()
3248 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3249 || prefixes(CODE_BLOCK_DELIMITER)
3250 .any(|prefix| trimmed_line.ends_with(prefix))
3251 {
3252 break;
3253 } else {
3254 if this.line_end {
3255 chunk.push('\n');
3256 this.line_end = false;
3257 }
3258
3259 chunk.push_str(&line_without_cursor);
3260 consumed += line.len();
3261 }
3262 }
3263 }
3264 }
3265
3266 this.buffer = this.buffer.split_off(consumed);
3267 if !chunk.is_empty() {
3268 return Poll::Ready(Some(Ok(chunk)));
3269 } else if this.stream_done {
3270 return Poll::Ready(None);
3271 }
3272 }
3273 }
3274}
3275
/// Code action provider that offers "Fix with Assistant" on ranges that
/// intersect diagnostics and, when applied, starts an inline assist in
/// `editor`.
struct AssistantCodeActionProvider {
    /// Weak handle to the editor the inline assist is opened in.
    editor: WeakView<Editor>,
    /// Workspace handed to the inline assistant when an assist is suggested.
    workspace: WeakView<Workspace>,
}
3280
impl CodeActionProvider for AssistantCodeActionProvider {
    /// Returns a single "Fix with Assistant" action when `range`, expanded to
    /// whole lines, intersects at least one diagnostic; otherwise returns no
    /// actions. Always returns no actions when the assistant is disabled in
    /// `AssistantSettings`.
    fn code_actions(
        &self,
        buffer: &Model<Buffer>,
        range: Range<text::Anchor>,
        cx: &mut WindowContext,
    ) -> Task<Result<Vec<CodeAction>>> {
        if !AssistantSettings::get_global(cx).enabled {
            return Task::ready(Ok(Vec::new()));
        }

        let snapshot = buffer.read(cx).snapshot();
        let mut range = range.to_point(&snapshot);

        // Expand the range to line boundaries.
        range.start.column = 0;
        range.end.column = snapshot.line_len(range.end.row);

        // Grow the range so it fully covers every diagnostic found in it.
        let mut has_diagnostics = false;
        for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
            range.start = cmp::min(range.start, diagnostic.range.start);
            range.end = cmp::max(range.end, diagnostic.range.end);
            has_diagnostics = true;
        }
        if has_diagnostics {
            // Widen further to the last symbol reported as containing each
            // endpoint. NOTE(review): presumably that is the innermost enclosing
            // symbol; confirm the ordering of `symbols_containing`. The end lookup
            // deliberately uses the range as already widened by the start lookup.
            if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
                if let Some(symbol) = symbols_containing_start.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
                if let Some(symbol) = symbols_containing_end.last() {
                    range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
                    range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
                }
            }

            Task::ready(Ok(vec![CodeAction {
                // NOTE(review): this action does not come from a language server;
                // id 0 appears to be a placeholder.
                server_id: language::LanguageServerId(0),
                range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
                lsp_action: lsp::CodeAction {
                    title: "Fix with Assistant".into(),
                    ..Default::default()
                },
            }]))
        } else {
            Task::ready(Ok(Vec::new()))
        }
    }

    /// Starts a "Fix Diagnostics" inline assist over `action.range` in the
    /// excerpt `excerpt_id`.
    ///
    /// If the excerpt's context range does not cover the whole action range,
    /// the excerpt is first resized to include it. The task resolves to an
    /// empty `ProjectTransaction`.
    ///
    /// # Errors
    ///
    /// Fails with "editor was released" if the editor is gone, and with
    /// "invalid range" if the action range cannot be anchored in the excerpt.
    /// It also fails if updating the editor or the global `InlineAssistant`
    /// fails.
    fn apply_code_action(
        &self,
        buffer: Model<Buffer>,
        action: CodeAction,
        excerpt_id: ExcerptId,
        _push_to_history: bool,
        cx: &mut WindowContext,
    ) -> Task<Result<ProjectTransaction>> {
        let editor = self.editor.clone();
        let workspace = self.workspace.clone();
        cx.spawn(|mut cx| async move {
            let editor = editor.upgrade().context("editor was released")?;
            let range = editor
                .update(&mut cx, |editor, cx| {
                    editor.buffer().update(cx, |multibuffer, cx| {
                        let buffer = buffer.read(cx);
                        let multibuffer_snapshot = multibuffer.read(cx);

                        // Union the excerpt's current context range with the action range.
                        let old_context_range =
                            multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
                        let mut new_context_range = old_context_range.clone();
                        if action
                            .range
                            .start
                            .cmp(&old_context_range.start, buffer)
                            .is_lt()
                        {
                            new_context_range.start = action.range.start;
                        }
                        if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
                            new_context_range.end = action.range.end;
                        }
                        // Release the read borrow of `multibuffer` before
                        // `resize_excerpt` mutates it.
                        drop(multibuffer_snapshot);

                        if new_context_range != old_context_range {
                            multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
                        }

                        // Re-read after a possible resize so both endpoints resolve
                        // inside the excerpt; `None` here becomes "invalid range".
                        let multibuffer_snapshot = multibuffer.read(cx);
                        Some(
                            multibuffer_snapshot
                                .anchor_in_excerpt(excerpt_id, action.range.start)?
                                ..multibuffer_snapshot
                                    .anchor_in_excerpt(excerpt_id, action.range.end)?,
                        )
                    })
                })?
                .context("invalid range")?;

            cx.update_global(|assistant: &mut InlineAssistant, cx| {
                let assist_id = assistant.suggest_assist(
                    &editor,
                    range,
                    "Fix Diagnostics".into(),
                    None,
                    true,
                    Some(workspace),
                    cx,
                );
                assistant.start_assist(assist_id, cx);
            })?;

            Ok(ProjectTransaction::default())
        })
    }
}
3399
/// Returns every proper, non-empty prefix of `text`, shortest first.
///
/// The full string is not included, so `"```"` yields `"`"` and `"``"`. This is
/// used to detect buffered text that might be the beginning of a delimiter
/// still streaming in. Prefixes always end on a `char` boundary, and an empty
/// `text` yields nothing. Subtracting 1 from the byte length and slicing by
/// byte would instead underflow on `""` and panic on non-ASCII input.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each char start after the first is the exclusive end of one prefix.
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
3403
3404fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3405 ranges.sort_unstable_by(|a, b| {
3406 a.start
3407 .cmp(&b.start, buffer)
3408 .then_with(|| b.end.cmp(&a.end, buffer))
3409 });
3410
3411 let mut ix = 0;
3412 while ix + 1 < ranges.len() {
3413 let b = ranges[ix + 1].clone();
3414 let a = &mut ranges[ix];
3415 if a.end.cmp(&b.start, buffer).is_gt() {
3416 if a.end.cmp(&b.end, buffer).is_lt() {
3417 a.end = b.end;
3418 }
3419 ranges.remove(ix + 1);
3420 } else {
3421 ix += 1;
3422 }
3423 }
3424}
3425
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module; possibly a leftover.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    // Streams a rewrite of rows 1..=4 in randomly sized chunks and checks that the
    // final buffer is auto-indented to match the enclosing function body.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the response in random chunks of 1..=10 bytes, letting the codegen
        // process each chunk before the next one arrives.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Inserts at the end of a partially typed line ("    le") and checks that the
    // following generated lines are indented relative to that line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Inserts on a line that holds only partial indentation (two spaces) and
    // checks that the generated lines get the full indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Rewrites tab-indented text in a buffer with no language and checks that the
    // tabs in the generated text are preserved.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    // A codegen alternative created inactive must not touch the buffer; toggling
    // it active applies its edits, and toggling it back reverts them.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // Checks fence and cursor-marker stripping. Each input is fed through
    // `StripInvalidSpans` at every chunk size from 1 to its length, so markers
    // split across chunk boundaries are exercised too.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Asserts that stripping `text` yields `expected_text` for every chunk size.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into a stream of chunks of `size` chars (the last chunk
        // may be shorter); splitting by char keeps every chunk valid UTF-8.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // Starts `codegen` on a response stream fed by the returned sender. Each sent
    // string becomes one chunk, and dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: Model<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                })),
                cx,
            );
        });
        chunks_tx
    }

    // Minimal Rust language: tree-sitter grammar plus an indents query, which is
    // enough to drive auto-indentation in the tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}