1use crate::context::attach_context_to_message;
2use crate::context_picker::ContextPicker;
3use crate::context_store::ContextStore;
4use crate::context_strip::ContextStrip;
5use crate::inline_prompt_editor::{
6 render_cancel_button, CodegenStatus, PromptEditorEvent, PromptMode,
7};
8use crate::thread_store::ThreadStore;
9use crate::{
10 assistant_settings::AssistantSettings,
11 prompts::PromptBuilder,
12 streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
13 terminal_inline_assistant::TerminalInlineAssistant,
14 CycleNextInlineAssist, CyclePreviousInlineAssist,
15};
16use crate::{AssistantPanel, ToggleContextPicker};
17use anyhow::{Context as _, Result};
18use client::{telemetry::Telemetry, ErrorExt};
19use collections::{hash_map, HashMap, HashSet, VecDeque};
20use editor::{
21 actions::{MoveDown, MoveUp, SelectAll},
22 display_map::{
23 BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
24 ToDisplayPoint,
25 },
26 Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
27 EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
28 ToOffset as _, ToPoint,
29};
30use feature_flags::{FeatureFlagAppExt as _, ZedPro};
31use fs::Fs;
32use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
33use gpui::{
34 anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
35 FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
36 Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView,
37 WindowContext,
38};
39use language::{Buffer, IndentKind, Point, Selection, TransactionId};
40use language_model::{
41 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
42 LanguageModelTextStream, Role,
43};
44use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
45use language_models::report_assistant_event;
46use multi_buffer::MultiBufferRow;
47use parking_lot::Mutex;
48use project::{CodeAction, ProjectTransaction};
49use rope::Rope;
50use settings::{update_settings_file, Settings, SettingsStore};
51use smol::future::FutureExt;
52use std::{
53 cmp,
54 future::Future,
55 iter, mem,
56 ops::{Range, RangeInclusive},
57 pin::Pin,
58 rc::Rc,
59 sync::Arc,
60 task::{self, Poll},
61 time::Instant,
62};
63use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
64use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
65use text::{OffsetRangeExt, ToPoint as _};
66use theme::ThemeSettings;
67use ui::{
68 prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip,
69};
70use util::{RangeExt, ResultExt};
71use workspace::{dock::Panel, ShowConfiguration};
72use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
73
74pub fn init(
75 fs: Arc<dyn Fs>,
76 prompt_builder: Arc<PromptBuilder>,
77 telemetry: Arc<Telemetry>,
78 cx: &mut AppContext,
79) {
80 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
81 cx.observe_new_views(|_workspace: &mut Workspace, cx| {
82 let workspace = cx.view().clone();
83 InlineAssistant::update_global(cx, |inline_assistant, cx| {
84 inline_assistant.register_workspace(&workspace, cx)
85 })
86 })
87 .detach();
88}
89
/// Maximum number of prompts kept in `InlineAssistant::prompt_history`; `start_assist`
/// drops the oldest entry once the history grows past this length.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
91
/// The view an inline assist is opened against, as resolved by
/// `InlineAssistant::inline_assist`.
enum InlineAssistTarget {
    /// An editor view; dispatched to `InlineAssistant::assist`.
    Editor(View<Editor>),
    /// A terminal view; dispatched to `TerminalInlineAssistant::assist`.
    Terminal(View<TerminalView>),
}
96
/// Global bookkeeping for every inline assist that is open in an editor.
pub struct InlineAssistant {
    /// Id given to the next assist; post-incremented by `assist` and `suggest_assist`.
    next_assist_id: InlineAssistId,
    /// Id given to the next assist group; post-incremented once per `assist` /
    /// `suggest_assist` call.
    next_assist_group_id: InlineAssistGroupId,
    /// All live assists by id. Entries are removed in `finish_assist`.
    assists: HashMap<InlineAssistId, InlineAssist>,
    /// Per-editor state (assist ids, scroll lock, highlight update channel), keyed by a
    /// weak handle to the editor.
    assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
    /// Groups of assists created together by one `assist` / `suggest_assist` call.
    assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
    /// Active codegen alternative of each assist that was finished without being undone.
    confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
    /// Previously submitted prompts, oldest first; capped at `PROMPT_HISTORY_MAX_LEN` and
    /// cloned into every new prompt editor.
    prompt_history: VecDeque<String>,
    /// Passed to every `Codegen` created for an assist.
    prompt_builder: Arc<PromptBuilder>,
    /// Used for assistant telemetry events and passed to every `Codegen`.
    telemetry: Arc<Telemetry>,
    /// Passed to every prompt editor.
    fs: Arc<dyn Fs>,
}
109
// Marker impl that lets `InlineAssistant` be stored as a gpui global (see `init`, which
// installs it with `cx.set_global`).
impl Global for InlineAssistant {}
111
112impl InlineAssistant {
113 pub fn new(
114 fs: Arc<dyn Fs>,
115 prompt_builder: Arc<PromptBuilder>,
116 telemetry: Arc<Telemetry>,
117 ) -> Self {
118 Self {
119 next_assist_id: InlineAssistId::default(),
120 next_assist_group_id: InlineAssistGroupId::default(),
121 assists: HashMap::default(),
122 assists_by_editor: HashMap::default(),
123 assist_groups: HashMap::default(),
124 confirmed_assists: HashMap::default(),
125 prompt_history: VecDeque::default(),
126 prompt_builder,
127 telemetry,
128 fs,
129 }
130 }
131
132 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
133 cx.subscribe(workspace, |workspace, event, cx| {
134 Self::update_global(cx, |this, cx| {
135 this.handle_workspace_event(workspace, event, cx)
136 });
137 })
138 .detach();
139
140 let workspace = workspace.downgrade();
141 cx.observe_global::<SettingsStore>(move |cx| {
142 let Some(workspace) = workspace.upgrade() else {
143 return;
144 };
145 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
146 return;
147 };
148 let enabled = AssistantSettings::get_global(cx).enabled;
149 terminal_panel.update(cx, |terminal_panel, cx| {
150 terminal_panel.asssistant_enabled(enabled, cx)
151 });
152 })
153 .detach();
154 }
155
156 fn handle_workspace_event(
157 &mut self,
158 workspace: View<Workspace>,
159 event: &workspace::Event,
160 cx: &mut WindowContext,
161 ) {
162 match event {
163 workspace::Event::UserSavedItem { item, .. } => {
164 // When the user manually saves an editor, automatically accepts all finished transformations.
165 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
166 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
167 for assist_id in editor_assists.assist_ids.clone() {
168 let assist = &self.assists[&assist_id];
169 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
170 self.finish_assist(assist_id, false, cx)
171 }
172 }
173 }
174 }
175 }
176 workspace::Event::ItemAdded { item } => {
177 self.register_workspace_item(&workspace, item.as_ref(), cx);
178 }
179 _ => (),
180 }
181 }
182
183 fn register_workspace_item(
184 &mut self,
185 workspace: &View<Workspace>,
186 item: &dyn ItemHandle,
187 cx: &mut WindowContext,
188 ) {
189 if let Some(editor) = item.act_as::<Editor>(cx) {
190 editor.update(cx, |editor, cx| {
191 let thread_store = workspace
192 .read(cx)
193 .panel::<AssistantPanel>(cx)
194 .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
195
196 editor.push_code_action_provider(
197 Rc::new(AssistantCodeActionProvider {
198 editor: cx.view().downgrade(),
199 workspace: workspace.downgrade(),
200 thread_store,
201 }),
202 cx,
203 );
204 });
205 }
206 }
207
208 pub fn inline_assist(
209 workspace: &mut Workspace,
210 _action: &zed_actions::InlineAssist,
211 cx: &mut ViewContext<Workspace>,
212 ) {
213 let settings = AssistantSettings::get_global(cx);
214 if !settings.enabled {
215 return;
216 }
217
218 let Some(inline_assist_target) = Self::resolve_inline_assist_target(workspace, cx) else {
219 return;
220 };
221
222 let is_authenticated = || {
223 LanguageModelRegistry::read_global(cx)
224 .active_provider()
225 .map_or(false, |provider| provider.is_authenticated(cx))
226 };
227
228 let thread_store = workspace
229 .panel::<AssistantPanel>(cx)
230 .map(|assistant_panel| assistant_panel.read(cx).thread_store().downgrade());
231
232 let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
233 InlineAssistTarget::Editor(active_editor) => {
234 InlineAssistant::update_global(cx, |assistant, cx| {
235 assistant.assist(&active_editor, cx.view().downgrade(), thread_store, cx)
236 })
237 }
238 InlineAssistTarget::Terminal(active_terminal) => {
239 TerminalInlineAssistant::update_global(cx, |assistant, cx| {
240 assistant.assist(&active_terminal, cx.view().downgrade(), thread_store, cx)
241 })
242 }
243 };
244
245 if is_authenticated() {
246 handle_assist(cx);
247 } else {
248 cx.spawn(|_workspace, mut cx| async move {
249 let Some(task) = cx.update(|cx| {
250 LanguageModelRegistry::read_global(cx)
251 .active_provider()
252 .map_or(None, |provider| Some(provider.authenticate(cx)))
253 })?
254 else {
255 let answer = cx
256 .prompt(
257 gpui::PromptLevel::Warning,
258 "No language model provider configured",
259 None,
260 &["Configure", "Cancel"],
261 )
262 .await
263 .ok();
264 if let Some(answer) = answer {
265 if answer == 0 {
266 cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
267 .ok();
268 }
269 }
270 return Ok(());
271 };
272 task.await?;
273
274 anyhow::Ok(())
275 })
276 .detach_and_log_err(cx);
277
278 if is_authenticated() {
279 handle_assist(cx);
280 }
281 }
282 }
283
/// Opens inline-assist prompts in `editor` for its current selections.
///
/// Non-empty selections are widened to whole lines and selections that touch or overlap
/// are merged. One assist (codegen + prompt editor + decoration blocks) is created for
/// every excerpt range covered by the resulting selections; all of them share a single
/// prompt buffer and a single assist group. The assist whose range lines up with the
/// newest selection is focused, if there is one.
///
/// # Panics
///
/// Panics if the editor reports no selections at all.
pub fn assist(
    &mut self,
    editor: &View<Editor>,
    workspace: WeakView<Workspace>,
    thread_store: Option<WeakModel<ThreadStore>>,
    cx: &mut WindowContext,
) {
    let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
        (
            editor.buffer().read(cx).snapshot(cx),
            editor.selections.all::<Point>(cx),
        )
    });

    // Normalize the selections: expand to full lines, merge neighbours, and remember the
    // selection with the highest id among those that were not merged away.
    let mut selections = Vec::<Selection<Point>>::new();
    let mut newest_selection = None;
    for mut selection in initial_selections {
        if selection.end > selection.start {
            selection.start.column = 0;
            // If the selection ends at the start of the line, we don't want to include it.
            if selection.end.column == 0 {
                selection.end.row -= 1;
            }
            selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
        }

        // A selection starting at or before the end of the previous one extends it.
        if let Some(prev_selection) = selections.last_mut() {
            if selection.start <= prev_selection.end {
                prev_selection.end = selection.end;
                continue;
            }
        }

        let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
        if selection.id > latest_selection.id {
            *latest_selection = selection.clone();
        }
        selections.push(selection);
    }
    let newest_selection = newest_selection.unwrap();

    // Turn each selection into per-excerpt anchor ranges; an `Invoked` telemetry event is
    // reported for every such range when a model is active.
    let mut codegen_ranges = Vec::new();
    for (excerpt_id, buffer, buffer_range) in
        snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
            snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
        }))
    {
        let start = Anchor {
            buffer_id: Some(buffer.remote_id()),
            excerpt_id,
            text_anchor: buffer.anchor_before(buffer_range.start),
        };
        let end = Anchor {
            buffer_id: Some(buffer.remote_id()),
            excerpt_id,
            text_anchor: buffer.anchor_after(buffer_range.end),
        };
        codegen_ranges.push(start..end);

        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
            self.telemetry.report_assistant_event(AssistantEvent {
                conversation_id: None,
                kind: AssistantKind::Inline,
                phase: AssistantPhase::Invoked,
                message_id: None,
                model: model.telemetry_id(),
                model_provider: model.provider_id().to_string(),
                response_latency: None,
                error_message: None,
                language_name: buffer.language().map(|language| language.name().to_proto()),
            });
        }
    }

    let assist_group_id = self.next_assist_group_id.post_inc();
    // A single prompt buffer is shared by every prompt editor created below.
    let prompt_buffer = cx.new_model(|cx| {
        MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
    });

    let mut assists = Vec::new();
    let mut assist_to_focus = None;
    for range in codegen_ranges {
        let assist_id = self.next_assist_id.post_inc();
        let context_store = cx.new_model(|_cx| ContextStore::new());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                editor.read(cx).buffer().clone(),
                range.clone(),
                None,
                context_store.clone(),
                self.telemetry.clone(),
                self.prompt_builder.clone(),
                cx,
            )
        });

        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
        let prompt_editor = cx.new_view(|cx| {
            PromptEditor::new(
                assist_id,
                gutter_dimensions.clone(),
                self.prompt_history.clone(),
                prompt_buffer.clone(),
                codegen.clone(),
                self.fs.clone(),
                context_store,
                workspace.clone(),
                thread_store.clone(),
                cx,
            )
        });

        // Focus the first assist whose range meets the newest selection at its head:
        // the start for a reversed selection, the end otherwise.
        if assist_to_focus.is_none() {
            let focus_assist = if newest_selection.reversed {
                range.start.to_point(&snapshot) == newest_selection.start
            } else {
                range.end.to_point(&snapshot) == newest_selection.end
            };
            if focus_assist {
                assist_to_focus = Some(assist_id);
            }
        }

        let [prompt_block_id, end_block_id] =
            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

        assists.push((
            assist_id,
            range,
            prompt_editor,
            prompt_block_id,
            end_block_id,
        ));
    }

    // Register the new assists with the editor's bookkeeping and with their group.
    let editor_assists = self
        .assists_by_editor
        .entry(editor.downgrade())
        .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
    let mut assist_group = InlineAssistGroup::new();
    for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
        self.assists.insert(
            assist_id,
            InlineAssist::new(
                assist_id,
                assist_group_id,
                editor,
                &prompt_editor,
                prompt_block_id,
                end_block_id,
                range,
                prompt_editor.read(cx).codegen.clone(),
                workspace.clone(),
                cx,
            ),
        );
        assist_group.assist_ids.push(assist_id);
        editor_assists.assist_ids.push(assist_id);
    }
    self.assist_groups.insert(assist_group_id, assist_group);

    if let Some(assist_id) = assist_to_focus {
        self.focus_assist(assist_id, cx);
    }
}
449
450 #[allow(clippy::too_many_arguments)]
451 pub fn suggest_assist(
452 &mut self,
453 editor: &View<Editor>,
454 mut range: Range<Anchor>,
455 initial_prompt: String,
456 initial_transaction_id: Option<TransactionId>,
457 focus: bool,
458 workspace: WeakView<Workspace>,
459 thread_store: Option<WeakModel<ThreadStore>>,
460 cx: &mut WindowContext,
461 ) -> InlineAssistId {
462 let assist_group_id = self.next_assist_group_id.post_inc();
463 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
464 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
465
466 let assist_id = self.next_assist_id.post_inc();
467
468 let buffer = editor.read(cx).buffer().clone();
469 {
470 let snapshot = buffer.read(cx).read(cx);
471 range.start = range.start.bias_left(&snapshot);
472 range.end = range.end.bias_right(&snapshot);
473 }
474
475 let context_store = cx.new_model(|_cx| ContextStore::new());
476
477 let codegen = cx.new_model(|cx| {
478 Codegen::new(
479 editor.read(cx).buffer().clone(),
480 range.clone(),
481 initial_transaction_id,
482 context_store.clone(),
483 self.telemetry.clone(),
484 self.prompt_builder.clone(),
485 cx,
486 )
487 });
488
489 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
490 let prompt_editor = cx.new_view(|cx| {
491 PromptEditor::new(
492 assist_id,
493 gutter_dimensions.clone(),
494 self.prompt_history.clone(),
495 prompt_buffer.clone(),
496 codegen.clone(),
497 self.fs.clone(),
498 context_store,
499 workspace.clone(),
500 thread_store,
501 cx,
502 )
503 });
504
505 let [prompt_block_id, end_block_id] =
506 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
507
508 let editor_assists = self
509 .assists_by_editor
510 .entry(editor.downgrade())
511 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
512
513 let mut assist_group = InlineAssistGroup::new();
514 self.assists.insert(
515 assist_id,
516 InlineAssist::new(
517 assist_id,
518 assist_group_id,
519 editor,
520 &prompt_editor,
521 prompt_block_id,
522 end_block_id,
523 range,
524 prompt_editor.read(cx).codegen.clone(),
525 workspace.clone(),
526 cx,
527 ),
528 );
529 assist_group.assist_ids.push(assist_id);
530 editor_assists.assist_ids.push(assist_id);
531 self.assist_groups.insert(assist_group_id, assist_group);
532
533 if focus {
534 self.focus_assist(assist_id, cx);
535 }
536
537 assist_id
538 }
539
540 fn insert_assist_blocks(
541 &self,
542 editor: &View<Editor>,
543 range: &Range<Anchor>,
544 prompt_editor: &View<PromptEditor>,
545 cx: &mut WindowContext,
546 ) -> [CustomBlockId; 2] {
547 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
548 prompt_editor
549 .editor
550 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
551 });
552 let assist_blocks = vec![
553 BlockProperties {
554 style: BlockStyle::Sticky,
555 placement: BlockPlacement::Above(range.start),
556 height: prompt_editor_height,
557 render: build_assist_editor_renderer(prompt_editor),
558 priority: 0,
559 },
560 BlockProperties {
561 style: BlockStyle::Sticky,
562 placement: BlockPlacement::Below(range.end),
563 height: 0,
564 render: Arc::new(|cx| {
565 v_flex()
566 .h_full()
567 .w_full()
568 .border_t_1()
569 .border_color(cx.theme().status().info_border)
570 .into_any_element()
571 }),
572 priority: 0,
573 },
574 ];
575
576 editor.update(cx, |editor, cx| {
577 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
578 [block_ids[0], block_ids[1]]
579 })
580 }
581
582 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
583 let assist = &self.assists[&assist_id];
584 let Some(decorations) = assist.decorations.as_ref() else {
585 return;
586 };
587 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
588 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
589
590 assist_group.active_assist_id = Some(assist_id);
591 if assist_group.linked {
592 for assist_id in &assist_group.assist_ids {
593 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
594 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
595 prompt_editor.set_show_cursor_when_unfocused(true, cx)
596 });
597 }
598 }
599 }
600
601 assist
602 .editor
603 .update(cx, |editor, cx| {
604 let scroll_top = editor.scroll_position(cx).y;
605 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
606 let prompt_row = editor
607 .row_for_block(decorations.prompt_block_id, cx)
608 .unwrap()
609 .0 as f32;
610
611 if (scroll_top..scroll_bottom).contains(&prompt_row) {
612 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
613 assist_id,
614 distance_from_top: prompt_row - scroll_top,
615 });
616 } else {
617 editor_assists.scroll_lock = None;
618 }
619 })
620 .ok();
621 }
622
623 fn handle_prompt_editor_focus_out(
624 &mut self,
625 assist_id: InlineAssistId,
626 cx: &mut WindowContext,
627 ) {
628 let assist = &self.assists[&assist_id];
629 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
630 if assist_group.active_assist_id == Some(assist_id) {
631 assist_group.active_assist_id = None;
632 if assist_group.linked {
633 for assist_id in &assist_group.assist_ids {
634 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
635 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
636 prompt_editor.set_show_cursor_when_unfocused(false, cx)
637 });
638 }
639 }
640 }
641 }
642 }
643
644 fn handle_prompt_editor_event(
645 &mut self,
646 prompt_editor: View<PromptEditor>,
647 event: &PromptEditorEvent,
648 cx: &mut WindowContext,
649 ) {
650 let assist_id = prompt_editor.read(cx).id;
651 match event {
652 PromptEditorEvent::StartRequested => {
653 self.start_assist(assist_id, cx);
654 }
655 PromptEditorEvent::StopRequested => {
656 self.stop_assist(assist_id, cx);
657 }
658 PromptEditorEvent::ConfirmRequested { execute: _ } => {
659 self.finish_assist(assist_id, false, cx);
660 }
661 PromptEditorEvent::CancelRequested => {
662 self.finish_assist(assist_id, true, cx);
663 }
664 PromptEditorEvent::DismissRequested => {
665 self.dismiss_assist(assist_id, cx);
666 }
667 PromptEditorEvent::Resized { .. } => {
668 // This only matters for the terminal inline
669 }
670 }
671 }
672
673 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
674 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
675 return;
676 };
677
678 if editor.read(cx).selections.count() == 1 {
679 let (selection, buffer) = editor.update(cx, |editor, cx| {
680 (
681 editor.selections.newest::<usize>(cx),
682 editor.buffer().read(cx).snapshot(cx),
683 )
684 });
685 for assist_id in &editor_assists.assist_ids {
686 let assist = &self.assists[assist_id];
687 let assist_range = assist.range.to_offset(&buffer);
688 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
689 {
690 if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
691 self.dismiss_assist(*assist_id, cx);
692 } else {
693 self.finish_assist(*assist_id, false, cx);
694 }
695
696 return;
697 }
698 }
699 }
700
701 cx.propagate();
702 }
703
704 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
705 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
706 return;
707 };
708
709 if editor.read(cx).selections.count() == 1 {
710 let (selection, buffer) = editor.update(cx, |editor, cx| {
711 (
712 editor.selections.newest::<usize>(cx),
713 editor.buffer().read(cx).snapshot(cx),
714 )
715 });
716 let mut closest_assist_fallback = None;
717 for assist_id in &editor_assists.assist_ids {
718 let assist = &self.assists[assist_id];
719 let assist_range = assist.range.to_offset(&buffer);
720 if assist.decorations.is_some() {
721 if assist_range.contains(&selection.start)
722 && assist_range.contains(&selection.end)
723 {
724 self.focus_assist(*assist_id, cx);
725 return;
726 } else {
727 let distance_from_selection = assist_range
728 .start
729 .abs_diff(selection.start)
730 .min(assist_range.start.abs_diff(selection.end))
731 + assist_range
732 .end
733 .abs_diff(selection.start)
734 .min(assist_range.end.abs_diff(selection.end));
735 match closest_assist_fallback {
736 Some((_, old_distance)) => {
737 if distance_from_selection < old_distance {
738 closest_assist_fallback =
739 Some((assist_id, distance_from_selection));
740 }
741 }
742 None => {
743 closest_assist_fallback = Some((assist_id, distance_from_selection))
744 }
745 }
746 }
747 }
748 }
749
750 if let Some((&assist_id, _)) = closest_assist_fallback {
751 self.focus_assist(assist_id, cx);
752 }
753 }
754
755 cx.propagate();
756 }
757
758 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
759 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
760 for assist_id in editor_assists.assist_ids.clone() {
761 self.finish_assist(assist_id, true, cx);
762 }
763 }
764 }
765
766 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
767 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
768 return;
769 };
770 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
771 return;
772 };
773 let assist = &self.assists[&scroll_lock.assist_id];
774 let Some(decorations) = assist.decorations.as_ref() else {
775 return;
776 };
777
778 editor.update(cx, |editor, cx| {
779 let scroll_position = editor.scroll_position(cx);
780 let target_scroll_top = editor
781 .row_for_block(decorations.prompt_block_id, cx)
782 .unwrap()
783 .0 as f32
784 - scroll_lock.distance_from_top;
785 if target_scroll_top != scroll_position.y {
786 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
787 }
788 });
789 }
790
791 fn handle_editor_event(
792 &mut self,
793 editor: View<Editor>,
794 event: &EditorEvent,
795 cx: &mut WindowContext,
796 ) {
797 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
798 return;
799 };
800
801 match event {
802 EditorEvent::Edited { transaction_id } => {
803 let buffer = editor.read(cx).buffer().read(cx);
804 let edited_ranges =
805 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
806 let snapshot = buffer.snapshot(cx);
807
808 for assist_id in editor_assists.assist_ids.clone() {
809 let assist = &self.assists[&assist_id];
810 if matches!(
811 assist.codegen.read(cx).status(cx),
812 CodegenStatus::Error(_) | CodegenStatus::Done
813 ) {
814 let assist_range = assist.range.to_offset(&snapshot);
815 if edited_ranges
816 .iter()
817 .any(|range| range.overlaps(&assist_range))
818 {
819 self.finish_assist(assist_id, false, cx);
820 }
821 }
822 }
823 }
824 EditorEvent::ScrollPositionChanged { .. } => {
825 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
826 let assist = &self.assists[&scroll_lock.assist_id];
827 if let Some(decorations) = assist.decorations.as_ref() {
828 let distance_from_top = editor.update(cx, |editor, cx| {
829 let scroll_top = editor.scroll_position(cx).y;
830 let prompt_row = editor
831 .row_for_block(decorations.prompt_block_id, cx)
832 .unwrap()
833 .0 as f32;
834 prompt_row - scroll_top
835 });
836
837 if distance_from_top != scroll_lock.distance_from_top {
838 editor_assists.scroll_lock = None;
839 }
840 }
841 }
842 }
843 EditorEvent::SelectionsChanged { .. } => {
844 for assist_id in editor_assists.assist_ids.clone() {
845 let assist = &self.assists[&assist_id];
846 if let Some(decorations) = assist.decorations.as_ref() {
847 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
848 return;
849 }
850 }
851 }
852
853 editor_assists.scroll_lock = None;
854 }
855 _ => {}
856 }
857 }
858
/// Finishes the assist `assist_id`, either keeping its result or undoing it.
///
/// If the assist belongs to a linked group, the group is unlinked and every member is
/// finished with the same `undo` flag. Otherwise the assist's decorations are dismissed,
/// it is removed from all bookkeeping, an `Accepted`/`Rejected` assistant event is
/// reported when a model is active, and finally the codegen is undone (`undo == true`)
/// or its active alternative is stored in `confirmed_assists`.
pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
    if let Some(assist) = self.assists.get(&assist_id) {
        let assist_group_id = assist.group_id;
        if self.assist_groups[&assist_group_id].linked {
            // Unlinking first means the recursive calls below take the non-linked path.
            for assist_id in self.unlink_assist_group(assist_group_id, cx) {
                self.finish_assist(assist_id, undo, cx);
            }
            return;
        }
    }

    self.dismiss_assist(assist_id, cx);

    if let Some(assist) = self.assists.remove(&assist_id) {
        // Drop the assist from its group, removing the group once it is empty.
        if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
            }
        }

        // Drop the assist from its editor's list. When it was the last one, the editor's
        // entry is removed and its highlights are refreshed directly; otherwise a
        // highlight update is requested through the entry's channel.
        if let hash_map::Entry::Occupied(mut entry) =
            self.assists_by_editor.entry(assist.editor.clone())
        {
            entry.get_mut().assist_ids.retain(|id| *id != assist_id);
            if entry.get().assist_ids.is_empty() {
                entry.remove();
                if let Some(editor) = assist.editor.upgrade() {
                    self.update_editor_highlights(&editor, cx);
                }
            } else {
                entry.get().highlight_updates.send(()).ok();
            }
        }

        let active_alternative = assist.codegen.read(cx).active_alternative().clone();
        let message_id = active_alternative.read(cx).message_id.clone();

        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
            // Language of the first buffer covered by the assist range, if the editor
            // is still alive.
            let language_name = assist.editor.upgrade().and_then(|editor| {
                let multibuffer = editor.read(cx).buffer().read(cx);
                let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
                ranges
                    .first()
                    .and_then(|(buffer, _, _)| buffer.read(cx).language())
                    .map(|language| language.name())
            });
            report_assistant_event(
                AssistantEvent {
                    conversation_id: None,
                    kind: AssistantKind::Inline,
                    message_id,
                    phase: if undo {
                        AssistantPhase::Rejected
                    } else {
                        AssistantPhase::Accepted
                    },
                    model: model.telemetry_id(),
                    model_provider: model.provider_id().to_string(),
                    response_latency: None,
                    error_message: None,
                    language_name: language_name.map(|name| name.to_proto()),
                },
                Some(self.telemetry.clone()),
                cx.http_client(),
                model.api_key(cx),
                cx.background_executor(),
            );
        }

        if undo {
            assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
        } else {
            self.confirmed_assists.insert(assist_id, active_alternative);
        }
    }
}
937
938 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
939 let Some(assist) = self.assists.get_mut(&assist_id) else {
940 return false;
941 };
942 let Some(editor) = assist.editor.upgrade() else {
943 return false;
944 };
945 let Some(decorations) = assist.decorations.take() else {
946 return false;
947 };
948
949 editor.update(cx, |editor, cx| {
950 let mut to_remove = decorations.removed_line_block_ids;
951 to_remove.insert(decorations.prompt_block_id);
952 to_remove.insert(decorations.end_block_id);
953 editor.remove_blocks(to_remove, None, cx);
954 });
955
956 if decorations
957 .prompt_editor
958 .focus_handle(cx)
959 .contains_focused(cx)
960 {
961 self.focus_next_assist(assist_id, cx);
962 }
963
964 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
965 if editor_assists
966 .scroll_lock
967 .as_ref()
968 .map_or(false, |lock| lock.assist_id == assist_id)
969 {
970 editor_assists.scroll_lock = None;
971 }
972 editor_assists.highlight_updates.send(()).ok();
973 }
974
975 true
976 }
977
978 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
979 let Some(assist) = self.assists.get(&assist_id) else {
980 return;
981 };
982
983 let assist_group = &self.assist_groups[&assist.group_id];
984 let assist_ix = assist_group
985 .assist_ids
986 .iter()
987 .position(|id| *id == assist_id)
988 .unwrap();
989 let assist_ids = assist_group
990 .assist_ids
991 .iter()
992 .skip(assist_ix + 1)
993 .chain(assist_group.assist_ids.iter().take(assist_ix));
994
995 for assist_id in assist_ids {
996 let assist = &self.assists[assist_id];
997 if assist.decorations.is_some() {
998 self.focus_assist(*assist_id, cx);
999 return;
1000 }
1001 }
1002
1003 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
1004 }
1005
1006 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1007 let Some(assist) = self.assists.get(&assist_id) else {
1008 return;
1009 };
1010
1011 if let Some(decorations) = assist.decorations.as_ref() {
1012 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
1013 prompt_editor.editor.update(cx, |editor, cx| {
1014 editor.focus(cx);
1015 editor.select_all(&SelectAll, cx);
1016 })
1017 });
1018 }
1019
1020 self.scroll_to_assist(assist_id, cx);
1021 }
1022
1023 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1024 let Some(assist) = self.assists.get(&assist_id) else {
1025 return;
1026 };
1027 let Some(editor) = assist.editor.upgrade() else {
1028 return;
1029 };
1030
1031 let position = assist.range.start;
1032 editor.update(cx, |editor, cx| {
1033 editor.change_selections(None, cx, |selections| {
1034 selections.select_anchor_ranges([position..position])
1035 });
1036
1037 let mut scroll_target_top;
1038 let mut scroll_target_bottom;
1039 if let Some(decorations) = assist.decorations.as_ref() {
1040 scroll_target_top = editor
1041 .row_for_block(decorations.prompt_block_id, cx)
1042 .unwrap()
1043 .0 as f32;
1044 scroll_target_bottom = editor
1045 .row_for_block(decorations.end_block_id, cx)
1046 .unwrap()
1047 .0 as f32;
1048 } else {
1049 let snapshot = editor.snapshot(cx);
1050 let start_row = assist
1051 .range
1052 .start
1053 .to_display_point(&snapshot.display_snapshot)
1054 .row();
1055 scroll_target_top = start_row.0 as f32;
1056 scroll_target_bottom = scroll_target_top + 1.;
1057 }
1058 scroll_target_top -= editor.vertical_scroll_margin() as f32;
1059 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
1060
1061 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
1062 let scroll_top = editor.scroll_position(cx).y;
1063 let scroll_bottom = scroll_top + height_in_lines;
1064
1065 if scroll_target_top < scroll_top {
1066 editor.set_scroll_position(point(0., scroll_target_top), cx);
1067 } else if scroll_target_bottom > scroll_bottom {
1068 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
1069 editor
1070 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
1071 } else {
1072 editor.set_scroll_position(point(0., scroll_target_top), cx);
1073 }
1074 }
1075 });
1076 }
1077
1078 fn unlink_assist_group(
1079 &mut self,
1080 assist_group_id: InlineAssistGroupId,
1081 cx: &mut WindowContext,
1082 ) -> Vec<InlineAssistId> {
1083 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1084 assist_group.linked = false;
1085 for assist_id in &assist_group.assist_ids {
1086 let assist = self.assists.get_mut(assist_id).unwrap();
1087 if let Some(editor_decorations) = assist.decorations.as_ref() {
1088 editor_decorations
1089 .prompt_editor
1090 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
1091 }
1092 }
1093 assist_group.assist_ids.clone()
1094 }
1095
1096 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1097 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1098 assist
1099 } else {
1100 return;
1101 };
1102
1103 let assist_group_id = assist.group_id;
1104 if self.assist_groups[&assist_group_id].linked {
1105 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1106 self.start_assist(assist_id, cx);
1107 }
1108 return;
1109 }
1110
1111 let Some(user_prompt) = assist.user_prompt(cx) else {
1112 return;
1113 };
1114
1115 self.prompt_history.retain(|prompt| *prompt != user_prompt);
1116 self.prompt_history.push_back(user_prompt.clone());
1117 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1118 self.prompt_history.pop_front();
1119 }
1120
1121 assist
1122 .codegen
1123 .update(cx, |codegen, cx| codegen.start(user_prompt, cx))
1124 .log_err();
1125 }
1126
1127 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1128 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1129 assist
1130 } else {
1131 return;
1132 };
1133
1134 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1135 }
1136
1137 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1138 let mut gutter_pending_ranges = Vec::new();
1139 let mut gutter_transformed_ranges = Vec::new();
1140 let mut foreground_ranges = Vec::new();
1141 let mut inserted_row_ranges = Vec::new();
1142 let empty_assist_ids = Vec::new();
1143 let assist_ids = self
1144 .assists_by_editor
1145 .get(&editor.downgrade())
1146 .map_or(&empty_assist_ids, |editor_assists| {
1147 &editor_assists.assist_ids
1148 });
1149
1150 for assist_id in assist_ids {
1151 if let Some(assist) = self.assists.get(assist_id) {
1152 let codegen = assist.codegen.read(cx);
1153 let buffer = codegen.buffer(cx).read(cx).read(cx);
1154 foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1155
1156 let pending_range =
1157 codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1158 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1159 gutter_pending_ranges.push(pending_range);
1160 }
1161
1162 if let Some(edit_position) = codegen.edit_position(cx) {
1163 let edited_range = assist.range.start..edit_position;
1164 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1165 gutter_transformed_ranges.push(edited_range);
1166 }
1167 }
1168
1169 if assist.decorations.is_some() {
1170 inserted_row_ranges
1171 .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1172 }
1173 }
1174 }
1175
1176 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1177 merge_ranges(&mut foreground_ranges, &snapshot);
1178 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1179 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1180 editor.update(cx, |editor, cx| {
1181 enum GutterPendingRange {}
1182 if gutter_pending_ranges.is_empty() {
1183 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1184 } else {
1185 editor.highlight_gutter::<GutterPendingRange>(
1186 &gutter_pending_ranges,
1187 |cx| cx.theme().status().info_background,
1188 cx,
1189 )
1190 }
1191
1192 enum GutterTransformedRange {}
1193 if gutter_transformed_ranges.is_empty() {
1194 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1195 } else {
1196 editor.highlight_gutter::<GutterTransformedRange>(
1197 &gutter_transformed_ranges,
1198 |cx| cx.theme().status().info,
1199 cx,
1200 )
1201 }
1202
1203 if foreground_ranges.is_empty() {
1204 editor.clear_highlights::<InlineAssist>(cx);
1205 } else {
1206 editor.highlight_text::<InlineAssist>(
1207 foreground_ranges,
1208 HighlightStyle {
1209 fade_out: Some(0.6),
1210 ..Default::default()
1211 },
1212 cx,
1213 );
1214 }
1215
1216 editor.clear_row_highlights::<InlineAssist>();
1217 for row_range in inserted_row_ranges {
1218 editor.highlight_rows::<InlineAssist>(
1219 row_range,
1220 cx.theme().status().info_background,
1221 false,
1222 cx,
1223 );
1224 }
1225 });
1226 }
1227
1228 fn update_editor_blocks(
1229 &mut self,
1230 editor: &View<Editor>,
1231 assist_id: InlineAssistId,
1232 cx: &mut WindowContext,
1233 ) {
1234 let Some(assist) = self.assists.get_mut(&assist_id) else {
1235 return;
1236 };
1237 let Some(decorations) = assist.decorations.as_mut() else {
1238 return;
1239 };
1240
1241 let codegen = assist.codegen.read(cx);
1242 let old_snapshot = codegen.snapshot(cx);
1243 let old_buffer = codegen.old_buffer(cx);
1244 let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
1245
1246 editor.update(cx, |editor, cx| {
1247 let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
1248 editor.remove_blocks(old_blocks, None, cx);
1249
1250 let mut new_blocks = Vec::new();
1251 for (new_row, old_row_range) in deleted_row_ranges {
1252 let (_, buffer_start) = old_snapshot
1253 .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
1254 .unwrap();
1255 let (_, buffer_end) = old_snapshot
1256 .point_to_buffer_offset(Point::new(
1257 *old_row_range.end(),
1258 old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
1259 ))
1260 .unwrap();
1261
1262 let deleted_lines_editor = cx.new_view(|cx| {
1263 let multi_buffer = cx.new_model(|_| {
1264 MultiBuffer::without_headers(language::Capability::ReadOnly)
1265 });
1266 multi_buffer.update(cx, |multi_buffer, cx| {
1267 multi_buffer.push_excerpts(
1268 old_buffer.clone(),
1269 Some(ExcerptRange {
1270 context: buffer_start..buffer_end,
1271 primary: None,
1272 }),
1273 cx,
1274 );
1275 });
1276
1277 enum DeletedLines {}
1278 let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
1279 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
1280 editor.set_show_wrap_guides(false, cx);
1281 editor.set_show_gutter(false, cx);
1282 editor.scroll_manager.set_forbid_vertical_scroll(true);
1283 editor.set_read_only(true);
1284 editor.set_show_inline_completions(Some(false), cx);
1285 editor.highlight_rows::<DeletedLines>(
1286 Anchor::min()..Anchor::max(),
1287 cx.theme().status().deleted_background,
1288 false,
1289 cx,
1290 );
1291 editor
1292 });
1293
1294 let height =
1295 deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
1296 new_blocks.push(BlockProperties {
1297 placement: BlockPlacement::Above(new_row),
1298 height,
1299 style: BlockStyle::Flex,
1300 render: Arc::new(move |cx| {
1301 div()
1302 .block_mouse_down()
1303 .bg(cx.theme().status().deleted_background)
1304 .size_full()
1305 .h(height as f32 * cx.line_height())
1306 .pl(cx.gutter_dimensions.full_width())
1307 .child(deleted_lines_editor.clone())
1308 .into_any_element()
1309 }),
1310 priority: 0,
1311 });
1312 }
1313
1314 decorations.removed_line_block_ids = editor
1315 .insert_blocks(new_blocks, None, cx)
1316 .into_iter()
1317 .collect();
1318 })
1319 }
1320
1321 fn resolve_inline_assist_target(
1322 workspace: &mut Workspace,
1323 cx: &mut WindowContext,
1324 ) -> Option<InlineAssistTarget> {
1325 if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) {
1326 if terminal_panel
1327 .read(cx)
1328 .focus_handle(cx)
1329 .contains_focused(cx)
1330 {
1331 if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
1332 pane.read(cx)
1333 .active_item()
1334 .and_then(|t| t.downcast::<TerminalView>())
1335 }) {
1336 return Some(InlineAssistTarget::Terminal(terminal_view));
1337 }
1338 }
1339 }
1340
1341 if let Some(workspace_editor) = workspace
1342 .active_item(cx)
1343 .and_then(|item| item.act_as::<Editor>(cx))
1344 {
1345 Some(InlineAssistTarget::Editor(workspace_editor))
1346 } else if let Some(terminal_view) = workspace
1347 .active_item(cx)
1348 .and_then(|item| item.act_as::<TerminalView>(cx))
1349 {
1350 Some(InlineAssistTarget::Terminal(terminal_view))
1351 } else {
1352 None
1353 }
1354 }
1355}
1356
1357struct EditorInlineAssists {
1358 assist_ids: Vec<InlineAssistId>,
1359 scroll_lock: Option<InlineAssistScrollLock>,
1360 highlight_updates: async_watch::Sender<()>,
1361 _update_highlights: Task<Result<()>>,
1362 _subscriptions: Vec<gpui::Subscription>,
1363}
1364
1365struct InlineAssistScrollLock {
1366 assist_id: InlineAssistId,
1367 distance_from_top: f32,
1368}
1369
1370impl EditorInlineAssists {
1371 #[allow(clippy::too_many_arguments)]
1372 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1373 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1374 Self {
1375 assist_ids: Vec::new(),
1376 scroll_lock: None,
1377 highlight_updates: highlight_updates_tx,
1378 _update_highlights: cx.spawn(|mut cx| {
1379 let editor = editor.downgrade();
1380 async move {
1381 while let Ok(()) = highlight_updates_rx.changed().await {
1382 let editor = editor.upgrade().context("editor was dropped")?;
1383 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1384 assistant.update_editor_highlights(&editor, cx);
1385 })?;
1386 }
1387 Ok(())
1388 }
1389 }),
1390 _subscriptions: vec![
1391 cx.observe_release(editor, {
1392 let editor = editor.downgrade();
1393 |_, cx| {
1394 InlineAssistant::update_global(cx, |this, cx| {
1395 this.handle_editor_release(editor, cx);
1396 })
1397 }
1398 }),
1399 cx.observe(editor, move |editor, cx| {
1400 InlineAssistant::update_global(cx, |this, cx| {
1401 this.handle_editor_change(editor, cx)
1402 })
1403 }),
1404 cx.subscribe(editor, move |editor, event, cx| {
1405 InlineAssistant::update_global(cx, |this, cx| {
1406 this.handle_editor_event(editor, event, cx)
1407 })
1408 }),
1409 editor.update(cx, |editor, cx| {
1410 let editor_handle = cx.view().downgrade();
1411 editor.register_action(
1412 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1413 InlineAssistant::update_global(cx, |this, cx| {
1414 if let Some(editor) = editor_handle.upgrade() {
1415 this.handle_editor_newline(editor, cx)
1416 }
1417 })
1418 },
1419 )
1420 }),
1421 editor.update(cx, |editor, cx| {
1422 let editor_handle = cx.view().downgrade();
1423 editor.register_action(
1424 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1425 InlineAssistant::update_global(cx, |this, cx| {
1426 if let Some(editor) = editor_handle.upgrade() {
1427 this.handle_editor_cancel(editor, cx)
1428 }
1429 })
1430 },
1431 )
1432 }),
1433 ],
1434 }
1435 }
1436}
1437
1438struct InlineAssistGroup {
1439 assist_ids: Vec<InlineAssistId>,
1440 linked: bool,
1441 active_assist_id: Option<InlineAssistId>,
1442}
1443
1444impl InlineAssistGroup {
1445 fn new() -> Self {
1446 Self {
1447 assist_ids: Vec::new(),
1448 linked: true,
1449 active_assist_id: None,
1450 }
1451 }
1452}
1453
1454fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1455 let editor = editor.clone();
1456 Arc::new(move |cx: &mut BlockContext| {
1457 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1458 editor.clone().into_any_element()
1459 })
1460}
1461
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let issued = self.0;
        self.0 += 1;
        Self(issued)
    }
}
1472
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let issued = self.0;
        self.0 += 1;
        Self(issued)
    }
}
1483
1484struct PromptEditor {
1485 id: InlineAssistId,
1486 editor: View<Editor>,
1487 context_strip: View<ContextStrip>,
1488 context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
1489 language_model_selector: View<LanguageModelSelector>,
1490 edited_since_done: bool,
1491 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1492 prompt_history: VecDeque<String>,
1493 prompt_history_ix: Option<usize>,
1494 pending_prompt: String,
1495 codegen: Model<Codegen>,
1496 _codegen_subscription: Subscription,
1497 editor_subscriptions: Vec<Subscription>,
1498 show_rate_limit_notice: bool,
1499}
1500
1501impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1502
1503impl Render for PromptEditor {
1504 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
1505 let gutter_dimensions = *self.gutter_dimensions.lock();
1506 let mut buttons = Vec::new();
1507 let codegen = self.codegen.read(cx);
1508 if codegen.alternative_count(cx) > 1 {
1509 buttons.push(self.render_cycle_controls(cx));
1510 }
1511 let prompt_mode = if codegen.is_insertion {
1512 PromptMode::Generate {
1513 supports_execute: false,
1514 }
1515 } else {
1516 PromptMode::Transform
1517 };
1518
1519 buttons.extend(render_cancel_button(
1520 codegen.status(cx).into(),
1521 self.edited_since_done,
1522 prompt_mode,
1523 cx,
1524 ));
1525
1526 v_flex()
1527 .border_y_1()
1528 .border_color(cx.theme().status().info_border)
1529 .size_full()
1530 .py(cx.line_height() / 2.5)
1531 .child(
1532 h_flex()
1533 .key_context("PromptEditor")
1534 .bg(cx.theme().colors().editor_background)
1535 .block_mouse_down()
1536 .cursor(CursorStyle::Arrow)
1537 .on_action(cx.listener(Self::toggle_context_picker))
1538 .on_action(cx.listener(Self::confirm))
1539 .on_action(cx.listener(Self::cancel))
1540 .on_action(cx.listener(Self::move_up))
1541 .on_action(cx.listener(Self::move_down))
1542 .capture_action(cx.listener(Self::cycle_prev))
1543 .capture_action(cx.listener(Self::cycle_next))
1544 .child(
1545 h_flex()
1546 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1547 .justify_center()
1548 .gap_2()
1549 .child(LanguageModelSelectorPopoverMenu::new(
1550 self.language_model_selector.clone(),
1551 IconButton::new("context", IconName::SettingsAlt)
1552 .shape(IconButtonShape::Square)
1553 .icon_size(IconSize::Small)
1554 .icon_color(Color::Muted)
1555 .tooltip(move |cx| {
1556 Tooltip::with_meta(
1557 format!(
1558 "Using {}",
1559 LanguageModelRegistry::read_global(cx)
1560 .active_model()
1561 .map(|model| model.name().0)
1562 .unwrap_or_else(|| "No model selected".into()),
1563 ),
1564 None,
1565 "Change Model",
1566 cx,
1567 )
1568 }),
1569 ))
1570 .map(|el| {
1571 let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx)
1572 else {
1573 return el;
1574 };
1575
1576 let error_message = SharedString::from(error.to_string());
1577 if error.error_code() == proto::ErrorCode::RateLimitExceeded
1578 && cx.has_flag::<ZedPro>()
1579 {
1580 el.child(
1581 v_flex()
1582 .child(
1583 IconButton::new(
1584 "rate-limit-error",
1585 IconName::XCircle,
1586 )
1587 .toggle_state(self.show_rate_limit_notice)
1588 .shape(IconButtonShape::Square)
1589 .icon_size(IconSize::Small)
1590 .on_click(
1591 cx.listener(Self::toggle_rate_limit_notice),
1592 ),
1593 )
1594 .children(self.show_rate_limit_notice.then(|| {
1595 deferred(
1596 anchored()
1597 .position_mode(
1598 gpui::AnchoredPositionMode::Local,
1599 )
1600 .position(point(px(0.), px(24.)))
1601 .anchor(gpui::Corner::TopLeft)
1602 .child(self.render_rate_limit_notice(cx)),
1603 )
1604 })),
1605 )
1606 } else {
1607 el.child(
1608 div()
1609 .id("error")
1610 .tooltip(move |cx| {
1611 Tooltip::text(error_message.clone(), cx)
1612 })
1613 .child(
1614 Icon::new(IconName::XCircle)
1615 .size(IconSize::Small)
1616 .color(Color::Error),
1617 ),
1618 )
1619 }
1620 }),
1621 )
1622 .child(div().flex_1().child(self.render_editor(cx)))
1623 .child(h_flex().gap_2().pr_6().children(buttons)),
1624 )
1625 .child(
1626 h_flex()
1627 .child(
1628 h_flex()
1629 .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
1630 .justify_center()
1631 .gap_2(),
1632 )
1633 .child(self.context_strip.clone()),
1634 )
1635 }
1636}
1637
1638impl FocusableView for PromptEditor {
1639 fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
1640 self.editor.focus_handle(cx)
1641 }
1642}
1643
1644impl PromptEditor {
1645 const MAX_LINES: u8 = 8;
1646
1647 #[allow(clippy::too_many_arguments)]
1648 fn new(
1649 id: InlineAssistId,
1650 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1651 prompt_history: VecDeque<String>,
1652 prompt_buffer: Model<MultiBuffer>,
1653 codegen: Model<Codegen>,
1654 fs: Arc<dyn Fs>,
1655 context_store: Model<ContextStore>,
1656 workspace: WeakView<Workspace>,
1657 thread_store: Option<WeakModel<ThreadStore>>,
1658 cx: &mut ViewContext<Self>,
1659 ) -> Self {
1660 let prompt_editor = cx.new_view(|cx| {
1661 let mut editor = Editor::new(
1662 EditorMode::AutoHeight {
1663 max_lines: Self::MAX_LINES as usize,
1664 },
1665 prompt_buffer,
1666 None,
1667 false,
1668 cx,
1669 );
1670 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1671 // Since the prompt editors for all inline assistants are linked,
1672 // always show the cursor (even when it isn't focused) because
1673 // typing in one will make what you typed appear in all of them.
1674 editor.set_show_cursor_when_unfocused(true, cx);
1675 editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), cx), cx);
1676 editor
1677 });
1678 let context_picker_menu_handle = PopoverMenuHandle::default();
1679
1680 let mut this = Self {
1681 id,
1682 editor: prompt_editor.clone(),
1683 context_strip: cx.new_view(|cx| {
1684 ContextStrip::new(
1685 context_store,
1686 workspace.clone(),
1687 thread_store.clone(),
1688 prompt_editor.focus_handle(cx),
1689 context_picker_menu_handle.clone(),
1690 cx,
1691 )
1692 }),
1693 context_picker_menu_handle,
1694 language_model_selector: cx.new_view(|cx| {
1695 let fs = fs.clone();
1696 LanguageModelSelector::new(
1697 move |model, cx| {
1698 update_settings_file::<AssistantSettings>(
1699 fs.clone(),
1700 cx,
1701 move |settings, _| settings.set_model(model.clone()),
1702 );
1703 },
1704 cx,
1705 )
1706 }),
1707 edited_since_done: false,
1708 gutter_dimensions,
1709 prompt_history,
1710 prompt_history_ix: None,
1711 pending_prompt: String::new(),
1712 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1713 editor_subscriptions: Vec::new(),
1714 codegen,
1715 show_rate_limit_notice: false,
1716 };
1717 this.subscribe_to_editor(cx);
1718 this
1719 }
1720
1721 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1722 self.editor_subscriptions.clear();
1723 self.editor_subscriptions
1724 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1725 }
1726
1727 fn set_show_cursor_when_unfocused(
1728 &mut self,
1729 show_cursor_when_unfocused: bool,
1730 cx: &mut ViewContext<Self>,
1731 ) {
1732 self.editor.update(cx, |editor, cx| {
1733 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1734 });
1735 }
1736
1737 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1738 let prompt = self.prompt(cx);
1739 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1740 self.editor = cx.new_view(|cx| {
1741 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1742 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1743 editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx), cx), cx);
1744 editor.set_placeholder_text("Add a prompt…", cx);
1745 editor.set_text(prompt, cx);
1746 if focus {
1747 editor.focus(cx);
1748 }
1749 editor
1750 });
1751 self.subscribe_to_editor(cx);
1752 }
1753
1754 fn placeholder_text(codegen: &Codegen, cx: &WindowContext) -> String {
1755 let action = if codegen.is_insertion {
1756 "Generate"
1757 } else {
1758 "Transform"
1759 };
1760 let assistant_panel_keybinding = ui::text_for_action(&crate::ToggleFocus, cx)
1761 .map(|keybinding| format!("{keybinding} to chat ― "))
1762 .unwrap_or_default();
1763
1764 format!("{action}… ({assistant_panel_keybinding}↓↑ for history)")
1765 }
1766
1767 fn prompt(&self, cx: &AppContext) -> String {
1768 self.editor.read(cx).text(cx)
1769 }
1770
1771 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1772 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1773 if self.show_rate_limit_notice {
1774 cx.focus_view(&self.editor);
1775 }
1776 cx.notify();
1777 }
1778
1779 fn handle_prompt_editor_events(
1780 &mut self,
1781 _: View<Editor>,
1782 event: &EditorEvent,
1783 cx: &mut ViewContext<Self>,
1784 ) {
1785 match event {
1786 EditorEvent::Edited { .. } => {
1787 if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1788 workspace
1789 .update(cx, |workspace, cx| {
1790 let is_via_ssh = workspace
1791 .project()
1792 .update(cx, |project, _| project.is_via_ssh());
1793
1794 workspace
1795 .client()
1796 .telemetry()
1797 .log_edit_event("inline assist", is_via_ssh);
1798 })
1799 .log_err();
1800 }
1801 let prompt = self.editor.read(cx).text(cx);
1802 if self
1803 .prompt_history_ix
1804 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1805 {
1806 self.prompt_history_ix.take();
1807 self.pending_prompt = prompt;
1808 }
1809
1810 self.edited_since_done = true;
1811 cx.notify();
1812 }
1813 EditorEvent::Blurred => {
1814 if self.show_rate_limit_notice {
1815 self.show_rate_limit_notice = false;
1816 cx.notify();
1817 }
1818 }
1819 _ => {}
1820 }
1821 }
1822
1823 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1824 match self.codegen.read(cx).status(cx) {
1825 CodegenStatus::Idle => {
1826 self.editor
1827 .update(cx, |editor, _| editor.set_read_only(false));
1828 }
1829 CodegenStatus::Pending => {
1830 self.editor
1831 .update(cx, |editor, _| editor.set_read_only(true));
1832 }
1833 CodegenStatus::Done => {
1834 self.edited_since_done = false;
1835 self.editor
1836 .update(cx, |editor, _| editor.set_read_only(false));
1837 }
1838 CodegenStatus::Error(error) => {
1839 if cx.has_flag::<ZedPro>()
1840 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1841 && !dismissed_rate_limit_notice()
1842 {
1843 self.show_rate_limit_notice = true;
1844 cx.notify();
1845 }
1846
1847 self.edited_since_done = false;
1848 self.editor
1849 .update(cx, |editor, _| editor.set_read_only(false));
1850 }
1851 }
1852 }
1853
1854 fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext<Self>) {
1855 self.context_picker_menu_handle.toggle(cx);
1856 }
1857
1858 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1859 match self.codegen.read(cx).status(cx) {
1860 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1861 cx.emit(PromptEditorEvent::CancelRequested);
1862 }
1863 CodegenStatus::Pending => {
1864 cx.emit(PromptEditorEvent::StopRequested);
1865 }
1866 }
1867 }
1868
1869 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1870 match self.codegen.read(cx).status(cx) {
1871 CodegenStatus::Idle => {
1872 cx.emit(PromptEditorEvent::StartRequested);
1873 }
1874 CodegenStatus::Pending => {
1875 cx.emit(PromptEditorEvent::DismissRequested);
1876 }
1877 CodegenStatus::Done => {
1878 if self.edited_since_done {
1879 cx.emit(PromptEditorEvent::StartRequested);
1880 } else {
1881 cx.emit(PromptEditorEvent::ConfirmRequested { execute: false });
1882 }
1883 }
1884 CodegenStatus::Error(_) => {
1885 cx.emit(PromptEditorEvent::StartRequested);
1886 }
1887 }
1888 }
1889
1890 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1891 if let Some(ix) = self.prompt_history_ix {
1892 if ix > 0 {
1893 self.prompt_history_ix = Some(ix - 1);
1894 let prompt = self.prompt_history[ix - 1].as_str();
1895 self.editor.update(cx, |editor, cx| {
1896 editor.set_text(prompt, cx);
1897 editor.move_to_beginning(&Default::default(), cx);
1898 });
1899 }
1900 } else if !self.prompt_history.is_empty() {
1901 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1902 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1903 self.editor.update(cx, |editor, cx| {
1904 editor.set_text(prompt, cx);
1905 editor.move_to_beginning(&Default::default(), cx);
1906 });
1907 }
1908 }
1909
1910 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1911 if let Some(ix) = self.prompt_history_ix {
1912 if ix < self.prompt_history.len() - 1 {
1913 self.prompt_history_ix = Some(ix + 1);
1914 let prompt = self.prompt_history[ix + 1].as_str();
1915 self.editor.update(cx, |editor, cx| {
1916 editor.set_text(prompt, cx);
1917 editor.move_to_end(&Default::default(), cx)
1918 });
1919 } else {
1920 self.prompt_history_ix = None;
1921 let prompt = self.pending_prompt.as_str();
1922 self.editor.update(cx, |editor, cx| {
1923 editor.set_text(prompt, cx);
1924 editor.move_to_end(&Default::default(), cx)
1925 });
1926 }
1927 }
1928 }
1929
1930 fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1931 self.codegen
1932 .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1933 }
1934
1935 fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1936 self.codegen
1937 .update(cx, |codegen, cx| codegen.cycle_next(cx));
1938 }
1939
1940 fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
1941 let codegen = self.codegen.read(cx);
1942 let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
1943
1944 let model_registry = LanguageModelRegistry::read_global(cx);
1945 let default_model = model_registry.active_model();
1946 let alternative_models = model_registry.inline_alternative_models();
1947
1948 let get_model_name = |index: usize| -> String {
1949 let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();
1950
1951 match index {
1952 0 => default_model.as_ref().map_or_else(String::new, name),
1953 index if index <= alternative_models.len() => alternative_models
1954 .get(index - 1)
1955 .map_or_else(String::new, name),
1956 _ => String::new(),
1957 }
1958 };
1959
1960 let total_models = alternative_models.len() + 1;
1961
1962 if total_models <= 1 {
1963 return div().into_any_element();
1964 }
1965
1966 let current_index = codegen.active_alternative;
1967 let prev_index = (current_index + total_models - 1) % total_models;
1968 let next_index = (current_index + 1) % total_models;
1969
1970 let prev_model_name = get_model_name(prev_index);
1971 let next_model_name = get_model_name(next_index);
1972
1973 h_flex()
1974 .child(
1975 IconButton::new("previous", IconName::ChevronLeft)
1976 .icon_color(Color::Muted)
1977 .disabled(disabled || current_index == 0)
1978 .shape(IconButtonShape::Square)
1979 .tooltip({
1980 let focus_handle = self.editor.focus_handle(cx);
1981 move |cx| {
1982 cx.new_view(|cx| {
1983 let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
1984 KeyBinding::for_action_in(
1985 &CyclePreviousInlineAssist,
1986 &focus_handle,
1987 cx,
1988 ),
1989 );
1990 if !disabled && current_index != 0 {
1991 tooltip = tooltip.meta(prev_model_name.clone());
1992 }
1993 tooltip
1994 })
1995 .into()
1996 }
1997 })
1998 .on_click(cx.listener(|this, _, cx| {
1999 this.codegen
2000 .update(cx, |codegen, cx| codegen.cycle_prev(cx))
2001 })),
2002 )
2003 .child(
2004 Label::new(format!(
2005 "{}/{}",
2006 codegen.active_alternative + 1,
2007 codegen.alternative_count(cx)
2008 ))
2009 .size(LabelSize::Small)
2010 .color(if disabled {
2011 Color::Disabled
2012 } else {
2013 Color::Muted
2014 }),
2015 )
2016 .child(
2017 IconButton::new("next", IconName::ChevronRight)
2018 .icon_color(Color::Muted)
2019 .disabled(disabled || current_index == total_models - 1)
2020 .shape(IconButtonShape::Square)
2021 .tooltip({
2022 let focus_handle = self.editor.focus_handle(cx);
2023 move |cx| {
2024 cx.new_view(|cx| {
2025 let mut tooltip = Tooltip::new("Next Alternative").key_binding(
2026 KeyBinding::for_action_in(
2027 &CycleNextInlineAssist,
2028 &focus_handle,
2029 cx,
2030 ),
2031 );
2032 if !disabled && current_index != total_models - 1 {
2033 tooltip = tooltip.meta(next_model_name.clone());
2034 }
2035 tooltip
2036 })
2037 .into()
2038 }
2039 })
2040 .on_click(cx.listener(|this, _, cx| {
2041 this.codegen
2042 .update(cx, |codegen, cx| codegen.cycle_next(cx))
2043 })),
2044 )
2045 .into_any_element()
2046 }
2047
2048 fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
2049 Popover::new().child(
2050 v_flex()
2051 .occlude()
2052 .p_2()
2053 .child(
2054 Label::new("Out of Tokens")
2055 .size(LabelSize::Small)
2056 .weight(FontWeight::BOLD),
2057 )
2058 .child(Label::new(
2059 "Try Zed Pro for higher limits, a wider range of models, and more.",
2060 ))
2061 .child(
2062 h_flex()
2063 .justify_between()
2064 .child(CheckboxWithLabel::new(
2065 "dont-show-again",
2066 Label::new("Don't show again"),
2067 if dismissed_rate_limit_notice() {
2068 ui::ToggleState::Selected
2069 } else {
2070 ui::ToggleState::Unselected
2071 },
2072 |selection, cx| {
2073 let is_dismissed = match selection {
2074 ui::ToggleState::Unselected => false,
2075 ui::ToggleState::Indeterminate => return,
2076 ui::ToggleState::Selected => true,
2077 };
2078
2079 set_rate_limit_notice_dismissed(is_dismissed, cx)
2080 },
2081 ))
2082 .child(
2083 h_flex()
2084 .gap_2()
2085 .child(
2086 Button::new("dismiss", "Dismiss")
2087 .style(ButtonStyle::Transparent)
2088 .on_click(cx.listener(Self::toggle_rate_limit_notice)),
2089 )
2090 .child(Button::new("more-info", "More Info").on_click(
2091 |_event, cx| {
2092 cx.dispatch_action(Box::new(
2093 zed_actions::OpenAccountSettings,
2094 ))
2095 },
2096 )),
2097 ),
2098 ),
2099 )
2100 }
2101
2102 fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
2103 let font_size = TextSize::Default.rems(cx);
2104 let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
2105
2106 v_flex()
2107 .key_context("MessageEditor")
2108 .size_full()
2109 .gap_2()
2110 .p_2()
2111 .bg(cx.theme().colors().editor_background)
2112 .child({
2113 let settings = ThemeSettings::get_global(cx);
2114 let text_style = TextStyle {
2115 color: cx.theme().colors().editor_foreground,
2116 font_family: settings.ui_font.family.clone(),
2117 font_features: settings.ui_font.features.clone(),
2118 font_size: font_size.into(),
2119 font_weight: settings.ui_font.weight,
2120 line_height: line_height.into(),
2121 ..Default::default()
2122 };
2123
2124 EditorElement::new(
2125 &self.editor,
2126 EditorStyle {
2127 background: cx.theme().colors().editor_background,
2128 local_player: cx.theme().players().local(),
2129 text: text_style,
2130 ..Default::default()
2131 },
2132 )
2133 })
2134 .into_any_element()
2135 }
2136}
2137
2138const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2139
2140fn dismissed_rate_limit_notice() -> bool {
2141 db::kvp::KEY_VALUE_STORE
2142 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2143 .log_err()
2144 .map_or(false, |s| s.is_some())
2145}
2146
2147fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2148 db::write_and_log(cx, move || async move {
2149 if is_dismissed {
2150 db::kvp::KEY_VALUE_STORE
2151 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2152 .await
2153 } else {
2154 db::kvp::KEY_VALUE_STORE
2155 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2156 .await
2157 }
2158 })
2159}
2160
2161pub struct InlineAssist {
2162 group_id: InlineAssistGroupId,
2163 range: Range<Anchor>,
2164 editor: WeakView<Editor>,
2165 decorations: Option<InlineAssistDecorations>,
2166 codegen: Model<Codegen>,
2167 _subscriptions: Vec<Subscription>,
2168 workspace: WeakView<Workspace>,
2169}
2170
2171impl InlineAssist {
2172 #[allow(clippy::too_many_arguments)]
2173 fn new(
2174 assist_id: InlineAssistId,
2175 group_id: InlineAssistGroupId,
2176 editor: &View<Editor>,
2177 prompt_editor: &View<PromptEditor>,
2178 prompt_block_id: CustomBlockId,
2179 end_block_id: CustomBlockId,
2180 range: Range<Anchor>,
2181 codegen: Model<Codegen>,
2182 workspace: WeakView<Workspace>,
2183 cx: &mut WindowContext,
2184 ) -> Self {
2185 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2186 InlineAssist {
2187 group_id,
2188 editor: editor.downgrade(),
2189 decorations: Some(InlineAssistDecorations {
2190 prompt_block_id,
2191 prompt_editor: prompt_editor.clone(),
2192 removed_line_block_ids: HashSet::default(),
2193 end_block_id,
2194 }),
2195 range,
2196 codegen: codegen.clone(),
2197 workspace: workspace.clone(),
2198 _subscriptions: vec![
2199 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2200 InlineAssistant::update_global(cx, |this, cx| {
2201 this.handle_prompt_editor_focus_in(assist_id, cx)
2202 })
2203 }),
2204 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2205 InlineAssistant::update_global(cx, |this, cx| {
2206 this.handle_prompt_editor_focus_out(assist_id, cx)
2207 })
2208 }),
2209 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2210 InlineAssistant::update_global(cx, |this, cx| {
2211 this.handle_prompt_editor_event(prompt_editor, event, cx)
2212 })
2213 }),
2214 cx.observe(&codegen, {
2215 let editor = editor.downgrade();
2216 move |_, cx| {
2217 if let Some(editor) = editor.upgrade() {
2218 InlineAssistant::update_global(cx, |this, cx| {
2219 if let Some(editor_assists) =
2220 this.assists_by_editor.get(&editor.downgrade())
2221 {
2222 editor_assists.highlight_updates.send(()).ok();
2223 }
2224
2225 this.update_editor_blocks(&editor, assist_id, cx);
2226 })
2227 }
2228 }
2229 }),
2230 cx.subscribe(&codegen, move |codegen, event, cx| {
2231 InlineAssistant::update_global(cx, |this, cx| match event {
2232 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2233 CodegenEvent::Finished => {
2234 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2235 assist
2236 } else {
2237 return;
2238 };
2239
2240 if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2241 if assist.decorations.is_none() {
2242 if let Some(workspace) = assist.workspace.upgrade() {
2243 let error = format!("Inline assistant error: {}", error);
2244 workspace.update(cx, |workspace, cx| {
2245 struct InlineAssistantError;
2246
2247 let id =
2248 NotificationId::composite::<InlineAssistantError>(
2249 assist_id.0,
2250 );
2251
2252 workspace.show_toast(Toast::new(id, error), cx);
2253 })
2254 }
2255 }
2256 }
2257
2258 if assist.decorations.is_none() {
2259 this.finish_assist(assist_id, false, cx);
2260 }
2261 }
2262 })
2263 }),
2264 ],
2265 }
2266 }
2267
2268 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2269 let decorations = self.decorations.as_ref()?;
2270 Some(decorations.prompt_editor.read(cx).prompt(cx))
2271 }
2272}
2273
/// Editor blocks and the prompt editor view associated with an `InlineAssist`.
struct InlineAssistDecorations {
    /// Id of the custom block passed to `InlineAssist::new` as `prompt_block_id`.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view; its text is what `InlineAssist::user_prompt` returns.
    prompt_editor: View<PromptEditor>,
    /// Ids of removed-line blocks; empty when the assist is created.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the custom block passed to `InlineAssist::new` as `end_block_id`.
    end_block_id: CustomBlockId,
}
2280
/// Events emitted by `CodegenAlternative` and re-emitted by `Codegen` for its active
/// alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Emitted when a generation completes (successfully or with an error) or is stopped.
    Finished,
    /// Emitted when the alternative's transformation transaction is undone in the buffer.
    Undone,
}
2286
/// Coordinates one or more `CodegenAlternative`s generating code for the same buffer range,
/// exactly one of which is active at a time.
pub struct Codegen {
    /// All alternatives; index 0 is created in `Codegen::new`, the rest in `Codegen::start`.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the currently active alternative.
    active_alternative: usize,
    /// Indices that have been passed to `activate`.
    seen_alternatives: HashSet<usize>,
    /// Observation and event subscription for the active alternative; replaced on activation.
    subscriptions: Vec<Subscription>,
    /// The multibuffer handed to every alternative.
    buffer: Model<MultiBuffer>,
    /// The anchored range handed to every alternative.
    range: Range<Anchor>,
    /// Transaction undone (at most once) by `Codegen::undo`, if any.
    initial_transaction_id: Option<TransactionId>,
    /// Context store handed to every alternative.
    context_store: Model<ContextStore>,
    /// Telemetry handle handed to every alternative.
    telemetry: Arc<Telemetry>,
    /// Prompt builder handed to every alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty (in offsets) when the codegen was created.
    is_insertion: bool,
}
2300
2301impl Codegen {
2302 pub fn new(
2303 buffer: Model<MultiBuffer>,
2304 range: Range<Anchor>,
2305 initial_transaction_id: Option<TransactionId>,
2306 context_store: Model<ContextStore>,
2307 telemetry: Arc<Telemetry>,
2308 builder: Arc<PromptBuilder>,
2309 cx: &mut ModelContext<Self>,
2310 ) -> Self {
2311 let codegen = cx.new_model(|cx| {
2312 CodegenAlternative::new(
2313 buffer.clone(),
2314 range.clone(),
2315 false,
2316 Some(context_store.clone()),
2317 Some(telemetry.clone()),
2318 builder.clone(),
2319 cx,
2320 )
2321 });
2322 let mut this = Self {
2323 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2324 alternatives: vec![codegen],
2325 active_alternative: 0,
2326 seen_alternatives: HashSet::default(),
2327 subscriptions: Vec::new(),
2328 buffer,
2329 range,
2330 initial_transaction_id,
2331 context_store,
2332 telemetry,
2333 builder,
2334 };
2335 this.activate(0, cx);
2336 this
2337 }
2338
2339 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2340 let codegen = self.active_alternative().clone();
2341 self.subscriptions.clear();
2342 self.subscriptions
2343 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2344 self.subscriptions
2345 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2346 }
2347
2348 fn active_alternative(&self) -> &Model<CodegenAlternative> {
2349 &self.alternatives[self.active_alternative]
2350 }
2351
2352 fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2353 &self.active_alternative().read(cx).status
2354 }
2355
2356 fn alternative_count(&self, cx: &AppContext) -> usize {
2357 LanguageModelRegistry::read_global(cx)
2358 .inline_alternative_models()
2359 .len()
2360 + 1
2361 }
2362
2363 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2364 let next_active_ix = if self.active_alternative == 0 {
2365 self.alternatives.len() - 1
2366 } else {
2367 self.active_alternative - 1
2368 };
2369 self.activate(next_active_ix, cx);
2370 }
2371
2372 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2373 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2374 self.activate(next_active_ix, cx);
2375 }
2376
2377 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2378 self.active_alternative()
2379 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2380 self.seen_alternatives.insert(index);
2381 self.active_alternative = index;
2382 self.active_alternative()
2383 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2384 self.subscribe_to_alternative(cx);
2385 cx.notify();
2386 }
2387
2388 pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
2389 let alternative_models = LanguageModelRegistry::read_global(cx)
2390 .inline_alternative_models()
2391 .to_vec();
2392
2393 self.active_alternative()
2394 .update(cx, |alternative, cx| alternative.undo(cx));
2395 self.activate(0, cx);
2396 self.alternatives.truncate(1);
2397
2398 for _ in 0..alternative_models.len() {
2399 self.alternatives.push(cx.new_model(|cx| {
2400 CodegenAlternative::new(
2401 self.buffer.clone(),
2402 self.range.clone(),
2403 false,
2404 Some(self.context_store.clone()),
2405 Some(self.telemetry.clone()),
2406 self.builder.clone(),
2407 cx,
2408 )
2409 }));
2410 }
2411
2412 let primary_model = LanguageModelRegistry::read_global(cx)
2413 .active_model()
2414 .context("no active model")?;
2415
2416 for (model, alternative) in iter::once(primary_model)
2417 .chain(alternative_models)
2418 .zip(&self.alternatives)
2419 {
2420 alternative.update(cx, |alternative, cx| {
2421 alternative.start(user_prompt.clone(), model.clone(), cx)
2422 })?;
2423 }
2424
2425 Ok(())
2426 }
2427
2428 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2429 for codegen in &self.alternatives {
2430 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2431 }
2432 }
2433
2434 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2435 self.active_alternative()
2436 .update(cx, |codegen, cx| codegen.undo(cx));
2437
2438 self.buffer.update(cx, |buffer, cx| {
2439 if let Some(transaction_id) = self.initial_transaction_id.take() {
2440 buffer.undo_transaction(transaction_id, cx);
2441 buffer.refresh_preview(cx);
2442 }
2443 });
2444 }
2445
2446 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2447 self.active_alternative().read(cx).buffer.clone()
2448 }
2449
2450 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2451 self.active_alternative().read(cx).old_buffer.clone()
2452 }
2453
2454 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2455 self.active_alternative().read(cx).snapshot.clone()
2456 }
2457
2458 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2459 self.active_alternative().read(cx).edit_position
2460 }
2461
2462 fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2463 &self.active_alternative().read(cx).diff
2464 }
2465
2466 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2467 self.active_alternative().read(cx).last_equal_ranges()
2468 }
2469}
2470
// `Codegen` re-emits the `CodegenEvent`s of its active alternative (see
// `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for Codegen {}
2472
/// A single code generation attempt for a range of a multibuffer.
///
/// Streams a model completion, diffs it against the originally selected text and, while
/// `active`, applies the resulting edits to `buffer`.
pub struct CodegenAlternative {
    /// The multibuffer the generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Local copy (text, line ending, language) of the underlying buffer, made in `new`.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken in `new`; streamed offsets are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Set by `start` to the start of `range`, then advanced as operations are streamed in.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges of the `Keep` operations in the most recently received batch; cleared when
    /// generation finishes or is stopped.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all applied edits are merged into, if any edits are applied.
    transformation_transaction_id: Option<TransactionId>,
    /// Current generation status.
    status: CodegenStatus,
    /// Task driving the current generation; replaced with a ready task when stopped.
    generation: Task<()>,
    /// Deleted and inserted row ranges currently computed for this alternative.
    diff: Diff,
    /// When present, its context is attached to the request message.
    context_store: Option<Model<ContextStore>>,
    /// Passed to `report_assistant_event` when a generation finishes.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    /// Builds the inline transformation prompt.
    builder: Arc<PromptBuilder>,
    /// Whether streamed edits are applied to `buffer` as they arrive.
    active: bool,
    /// Every edit received so far; re-applied when the alternative becomes active.
    edits: Vec<(Range<Anchor>, String)>,
    /// The most recently received line operations.
    line_operations: Vec<LineOperation>,
    /// The request built by the last `start` call (not set for the "delete" prompt).
    request: Option<LanguageModelRequest>,
    /// Seconds elapsed between the start and end of the last `handle_stream` run.
    elapsed_time: Option<f64>,
    /// Text of the chunks received during the last `handle_stream` run.
    completion: Option<String>,
    /// Message id reported by the last stream, if any.
    message_id: Option<String>,
}
2496
/// Row-level diff between the original text and the current buffer contents.
#[derive(Default)]
struct Diff {
    /// For each run of deleted rows: an anchor in the new snapshot where the deletion
    /// occurred, and the inclusive range of deleted rows in the old snapshot.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges in the new snapshot covering inserted rows.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2502
2503impl Diff {
2504 fn is_empty(&self) -> bool {
2505 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2506 }
2507}
2508
// `CodegenAlternative` emits `Finished` when generation completes or is stopped, and `Undone`
// when its transformation transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2510
2511impl CodegenAlternative {
2512 pub fn new(
2513 buffer: Model<MultiBuffer>,
2514 range: Range<Anchor>,
2515 active: bool,
2516 context_store: Option<Model<ContextStore>>,
2517 telemetry: Option<Arc<Telemetry>>,
2518 builder: Arc<PromptBuilder>,
2519 cx: &mut ModelContext<Self>,
2520 ) -> Self {
2521 let snapshot = buffer.read(cx).snapshot(cx);
2522
2523 let (old_buffer, _, _) = buffer
2524 .read(cx)
2525 .range_to_buffer_ranges(range.clone(), cx)
2526 .pop()
2527 .unwrap();
2528 let old_buffer = cx.new_model(|cx| {
2529 let old_buffer = old_buffer.read(cx);
2530 let text = old_buffer.as_rope().clone();
2531 let line_ending = old_buffer.line_ending();
2532 let language = old_buffer.language().cloned();
2533 let language_registry = old_buffer.language_registry();
2534
2535 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2536 buffer.set_language(language, cx);
2537 if let Some(language_registry) = language_registry {
2538 buffer.set_language_registry(language_registry)
2539 }
2540 buffer
2541 });
2542
2543 Self {
2544 buffer: buffer.clone(),
2545 old_buffer,
2546 edit_position: None,
2547 message_id: None,
2548 snapshot,
2549 last_equal_ranges: Default::default(),
2550 transformation_transaction_id: None,
2551 status: CodegenStatus::Idle,
2552 generation: Task::ready(()),
2553 diff: Diff::default(),
2554 context_store,
2555 telemetry,
2556 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2557 builder,
2558 active,
2559 edits: Vec::new(),
2560 line_operations: Vec::new(),
2561 range,
2562 request: None,
2563 elapsed_time: None,
2564 completion: None,
2565 }
2566 }
2567
2568 fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2569 if active != self.active {
2570 self.active = active;
2571
2572 if self.active {
2573 let edits = self.edits.clone();
2574 self.apply_edits(edits, cx);
2575 if matches!(self.status, CodegenStatus::Pending) {
2576 let line_operations = self.line_operations.clone();
2577 self.reapply_line_based_diff(line_operations, cx);
2578 } else {
2579 self.reapply_batch_diff(cx).detach();
2580 }
2581 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2582 self.buffer.update(cx, |buffer, cx| {
2583 buffer.undo_transaction(transaction_id, cx);
2584 buffer.forget_transaction(transaction_id, cx);
2585 });
2586 }
2587 }
2588 }
2589
2590 fn handle_buffer_event(
2591 &mut self,
2592 _buffer: Model<MultiBuffer>,
2593 event: &multi_buffer::Event,
2594 cx: &mut ModelContext<Self>,
2595 ) {
2596 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2597 if self.transformation_transaction_id == Some(*transaction_id) {
2598 self.transformation_transaction_id = None;
2599 self.generation = Task::ready(());
2600 cx.emit(CodegenEvent::Undone);
2601 }
2602 }
2603 }
2604
2605 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2606 &self.last_equal_ranges
2607 }
2608
2609 pub fn start(
2610 &mut self,
2611 user_prompt: String,
2612 model: Arc<dyn LanguageModel>,
2613 cx: &mut ModelContext<Self>,
2614 ) -> Result<()> {
2615 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2616 self.buffer.update(cx, |buffer, cx| {
2617 buffer.undo_transaction(transformation_transaction_id, cx);
2618 });
2619 }
2620
2621 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2622
2623 let api_key = model.api_key(cx);
2624 let telemetry_id = model.telemetry_id();
2625 let provider_id = model.provider_id();
2626 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2627 if user_prompt.trim().to_lowercase() == "delete" {
2628 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2629 } else {
2630 let request = self.build_request(user_prompt, cx)?;
2631 self.request = Some(request.clone());
2632
2633 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2634 .boxed_local()
2635 };
2636 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2637 Ok(())
2638 }
2639
2640 fn build_request(
2641 &self,
2642 user_prompt: String,
2643 cx: &mut AppContext,
2644 ) -> Result<LanguageModelRequest> {
2645 let buffer = self.buffer.read(cx).snapshot(cx);
2646 let language = buffer.language_at(self.range.start);
2647 let language_name = if let Some(language) = language.as_ref() {
2648 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2649 None
2650 } else {
2651 Some(language.name())
2652 }
2653 } else {
2654 None
2655 };
2656
2657 let language_name = language_name.as_ref();
2658 let start = buffer.point_to_buffer_offset(self.range.start);
2659 let end = buffer.point_to_buffer_offset(self.range.end);
2660 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2661 let (start_buffer, start_buffer_offset) = start;
2662 let (end_buffer, end_buffer_offset) = end;
2663 if start_buffer.remote_id() == end_buffer.remote_id() {
2664 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2665 } else {
2666 return Err(anyhow::anyhow!("invalid transformation range"));
2667 }
2668 } else {
2669 return Err(anyhow::anyhow!("invalid transformation range"));
2670 };
2671
2672 let prompt = self
2673 .builder
2674 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2675 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2676
2677 let mut request_message = LanguageModelRequestMessage {
2678 role: Role::User,
2679 content: Vec::new(),
2680 cache: false,
2681 };
2682
2683 if let Some(context_store) = &self.context_store {
2684 let context = context_store.update(cx, |this, _cx| this.context().clone());
2685 attach_context_to_message(&mut request_message, context);
2686 }
2687
2688 request_message.content.push(prompt.into());
2689
2690 Ok(LanguageModelRequest {
2691 tools: Vec::new(),
2692 stop: Vec::new(),
2693 temperature: None,
2694 messages: vec![request_message],
2695 })
2696 }
2697
    /// Consumes a completion `stream`, diffs it against the originally selected text and
    /// records (and, while this alternative is active, applies) the resulting edits.
    ///
    /// The status is set to `Pending` immediately. A background task re-indents the streamed
    /// text, feeds it through a streaming character diff and a line diff, and sends the
    /// resulting operations back to this model, which turns them into anchored edits. When
    /// the stream ends, a batch diff replaces the line-based diff, the status becomes `Done`
    /// or `Error`, and `CodegenEvent::Finished` is emitted.
    ///
    /// `model_telemetry_id`, `model_provider_id` and `model_api_key` are only forwarded to
    /// `report_assistant_event`.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut ModelContext<Self>,
    ) {
        let start_time = Instant::now();
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client().clone();
        let telemetry = self.telemetry.clone();
        // Language of the first buffer the range maps to; only used for telemetry below.
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.read(cx).language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset (in `snapshot`) up to which the original text has been consumed by the diff.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        // Accumulates the chunks received from the stream; shared with the background task.
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                let stream = stream.await;
                let message_id = stream
                    .as_ref()
                    .ok()
                    .and_then(|stream| stream.message_id.clone());
                let generate = async {
                    // Carries (char operations, line operations) batches from the background
                    // diff task to the loop below that applies them.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let executor = cx.background_executor().clone();
                    let message_id = message_id.clone();
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(stream?.stream);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current line that has not been diffed yet.
                                let mut new_text = String::new();
                                // Indentation of the first non-blank line of the response.
                                let mut base_indent = None;
                                // Indentation of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first item of the stream.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;
                                    completion_clone.lock().push_str(&chunk);

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Re-indent the line: keep its indentation
                                                // relative to the response's first line, but
                                                // rebase it onto the suggested indentation.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line starts at the selection's
                                                // column, so that much is subtracted from its
                                                // indentation.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only diff the line's text once its indentation has
                                        // been corrected.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so this line ended with '\n'.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush whatever is left and finalize both diffs.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the response before propagating a diff error.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            report_assistant_event(
                                AssistantEvent {
                                    conversation_id: None,
                                    message_id,
                                    kind: AssistantKind::Inline,
                                    phase: AssistantPhase::Response,
                                    model: model_telemetry_id,
                                    model_provider: model_provider_id.to_string(),
                                    response_latency,
                                    error_message,
                                    language_name: language_name.map(|name| name.to_proto()),
                                },
                                telemetry,
                                http_client,
                                model_api_key,
                                &executor,
                            );

                            result?;
                            Ok(())
                        });

                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            // Convert character operations into anchored edits, advancing
                            // `edit_start` past deleted and kept bytes of the original text.
                            let edits = char_ops
                                .into_iter()
                                .filter_map(|operation| match operation {
                                    CharOperation::Insert { text } => {
                                        let edit_start = snapshot.anchor_after(edit_start);
                                        Some((edit_start..edit_start, text))
                                    }
                                    CharOperation::Delete { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        Some((edit_range, String::new()))
                                    }
                                    CharOperation::Keep { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        codegen.last_equal_ranges.push(edit_range);
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            // Inactive alternatives only record the edits; they are applied
                            // when the alternative becomes active.
                            if codegen.active {
                                codegen.apply_edits(edits.iter().cloned(), cx);
                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                            }
                            codegen.edits.extend(edits);
                            codegen.line_operations = line_ops;
                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task =
                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                let elapsed_time = start_time.elapsed().as_secs_f64();

                // Record the outcome; if the model has been dropped the update fails and the
                // failure is ignored.
                codegen
                    .update(&mut cx, |this, cx| {
                        this.message_id = message_id;
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        this.elapsed_time = Some(elapsed_time);
                        this.completion = Some(completion.lock().clone());
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
2960
2961 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2962 self.last_equal_ranges.clear();
2963 if self.diff.is_empty() {
2964 self.status = CodegenStatus::Idle;
2965 } else {
2966 self.status = CodegenStatus::Done;
2967 }
2968 self.generation = Task::ready(());
2969 cx.emit(CodegenEvent::Finished);
2970 cx.notify();
2971 }
2972
2973 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2974 self.buffer.update(cx, |buffer, cx| {
2975 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2976 buffer.undo_transaction(transaction_id, cx);
2977 buffer.refresh_preview(cx);
2978 }
2979 });
2980 }
2981
2982 fn apply_edits(
2983 &mut self,
2984 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
2985 cx: &mut ModelContext<CodegenAlternative>,
2986 ) {
2987 let transaction = self.buffer.update(cx, |buffer, cx| {
2988 // Avoid grouping assistant edits with user edits.
2989 buffer.finalize_last_transaction(cx);
2990 buffer.start_transaction(cx);
2991 buffer.edit(edits, None, cx);
2992 buffer.end_transaction(cx)
2993 });
2994
2995 if let Some(transaction) = transaction {
2996 if let Some(first_transaction) = self.transformation_transaction_id {
2997 // Group all assistant edits into the first transaction.
2998 self.buffer.update(cx, |buffer, cx| {
2999 buffer.merge_transactions(transaction, first_transaction, cx)
3000 });
3001 } else {
3002 self.transformation_transaction_id = Some(transaction);
3003 self.buffer
3004 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
3005 }
3006 }
3007 }
3008
3009 fn reapply_line_based_diff(
3010 &mut self,
3011 line_operations: impl IntoIterator<Item = LineOperation>,
3012 cx: &mut ModelContext<Self>,
3013 ) {
3014 let old_snapshot = self.snapshot.clone();
3015 let old_range = self.range.to_point(&old_snapshot);
3016 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3017 let new_range = self.range.to_point(&new_snapshot);
3018
3019 let mut old_row = old_range.start.row;
3020 let mut new_row = new_range.start.row;
3021
3022 self.diff.deleted_row_ranges.clear();
3023 self.diff.inserted_row_ranges.clear();
3024 for operation in line_operations {
3025 match operation {
3026 LineOperation::Keep { lines } => {
3027 old_row += lines;
3028 new_row += lines;
3029 }
3030 LineOperation::Delete { lines } => {
3031 let old_end_row = old_row + lines - 1;
3032 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3033
3034 if let Some((_, last_deleted_row_range)) =
3035 self.diff.deleted_row_ranges.last_mut()
3036 {
3037 if *last_deleted_row_range.end() + 1 == old_row {
3038 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3039 } else {
3040 self.diff
3041 .deleted_row_ranges
3042 .push((new_row, old_row..=old_end_row));
3043 }
3044 } else {
3045 self.diff
3046 .deleted_row_ranges
3047 .push((new_row, old_row..=old_end_row));
3048 }
3049
3050 old_row += lines;
3051 }
3052 LineOperation::Insert { lines } => {
3053 let new_end_row = new_row + lines - 1;
3054 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3055 let end = new_snapshot.anchor_before(Point::new(
3056 new_end_row,
3057 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3058 ));
3059 self.diff.inserted_row_ranges.push(start..end);
3060 new_row += lines;
3061 }
3062 }
3063
3064 cx.notify();
3065 }
3066 }
3067
3068 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3069 let old_snapshot = self.snapshot.clone();
3070 let old_range = self.range.to_point(&old_snapshot);
3071 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3072 let new_range = self.range.to_point(&new_snapshot);
3073
3074 cx.spawn(|codegen, mut cx| async move {
3075 let (deleted_row_ranges, inserted_row_ranges) = cx
3076 .background_executor()
3077 .spawn(async move {
3078 let old_text = old_snapshot
3079 .text_for_range(
3080 Point::new(old_range.start.row, 0)
3081 ..Point::new(
3082 old_range.end.row,
3083 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3084 ),
3085 )
3086 .collect::<String>();
3087 let new_text = new_snapshot
3088 .text_for_range(
3089 Point::new(new_range.start.row, 0)
3090 ..Point::new(
3091 new_range.end.row,
3092 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3093 ),
3094 )
3095 .collect::<String>();
3096
3097 let mut old_row = old_range.start.row;
3098 let mut new_row = new_range.start.row;
3099 let batch_diff =
3100 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3101
3102 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3103 let mut inserted_row_ranges = Vec::new();
3104 for change in batch_diff.iter_all_changes() {
3105 let line_count = change.value().lines().count() as u32;
3106 match change.tag() {
3107 similar::ChangeTag::Equal => {
3108 old_row += line_count;
3109 new_row += line_count;
3110 }
3111 similar::ChangeTag::Delete => {
3112 let old_end_row = old_row + line_count - 1;
3113 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3114
3115 if let Some((_, last_deleted_row_range)) =
3116 deleted_row_ranges.last_mut()
3117 {
3118 if *last_deleted_row_range.end() + 1 == old_row {
3119 *last_deleted_row_range =
3120 *last_deleted_row_range.start()..=old_end_row;
3121 } else {
3122 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3123 }
3124 } else {
3125 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3126 }
3127
3128 old_row += line_count;
3129 }
3130 similar::ChangeTag::Insert => {
3131 let new_end_row = new_row + line_count - 1;
3132 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3133 let end = new_snapshot.anchor_before(Point::new(
3134 new_end_row,
3135 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3136 ));
3137 inserted_row_ranges.push(start..end);
3138 new_row += line_count;
3139 }
3140 }
3141 }
3142
3143 (deleted_row_ranges, inserted_row_ranges)
3144 })
3145 .await;
3146
3147 codegen
3148 .update(&mut cx, |codegen, cx| {
3149 codegen.diff.deleted_row_ranges = deleted_row_ranges;
3150 codegen.diff.inserted_row_ranges = inserted_row_ranges;
3151 cx.notify();
3152 })
3153 .ok();
3154 })
3155 }
3156}
3157
/// Stream adapter that removes a leading code fence line, a trailing code fence line (when
/// the text started with one) and `<|CURSOR|>` markers from a stream of text chunks.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted or discarded yet.
    buffer: String,
    /// True until a buffered chunk is seen to contain more than one line.
    first_line: bool,
    /// True when a newline still has to be emitted before the next piece of text.
    line_end: bool,
    /// True when the first line was a code fence (and was therefore dropped).
    starts_with_code_block: bool,
}
3166
3167impl<T> StripInvalidSpans<T>
3168where
3169 T: Stream<Item = Result<String>>,
3170{
3171 fn new(stream: T) -> Self {
3172 Self {
3173 stream,
3174 stream_done: false,
3175 buffer: String::new(),
3176 first_line: true,
3177 line_end: false,
3178 starts_with_code_block: false,
3179 }
3180 }
3181}
3182
3183impl<T> Stream for StripInvalidSpans<T>
3184where
3185 T: Stream<Item = Result<String>>,
3186{
3187 type Item = Result<String>;
3188
3189 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3190 const CODE_BLOCK_DELIMITER: &str = "```";
3191 const CURSOR_SPAN: &str = "<|CURSOR|>";
3192
3193 let this = unsafe { self.get_unchecked_mut() };
3194 loop {
3195 if !this.stream_done {
3196 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3197 match stream.as_mut().poll_next(cx) {
3198 Poll::Ready(Some(Ok(chunk))) => {
3199 this.buffer.push_str(&chunk);
3200 }
3201 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3202 Poll::Ready(None) => {
3203 this.stream_done = true;
3204 }
3205 Poll::Pending => return Poll::Pending,
3206 }
3207 }
3208
3209 let mut chunk = String::new();
3210 let mut consumed = 0;
3211 if !this.buffer.is_empty() {
3212 let mut lines = this.buffer.split('\n').enumerate().peekable();
3213 while let Some((line_ix, line)) = lines.next() {
3214 if line_ix > 0 {
3215 this.first_line = false;
3216 }
3217
3218 if this.first_line {
3219 let trimmed_line = line.trim();
3220 if lines.peek().is_some() {
3221 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3222 consumed += line.len() + 1;
3223 this.starts_with_code_block = true;
3224 continue;
3225 }
3226 } else if trimmed_line.is_empty()
3227 || prefixes(CODE_BLOCK_DELIMITER)
3228 .any(|prefix| trimmed_line.starts_with(prefix))
3229 {
3230 break;
3231 }
3232 }
3233
3234 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3235 if lines.peek().is_some() {
3236 if this.line_end {
3237 chunk.push('\n');
3238 }
3239
3240 chunk.push_str(&line_without_cursor);
3241 this.line_end = true;
3242 consumed += line.len() + 1;
3243 } else if this.stream_done {
3244 if !this.starts_with_code_block
3245 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3246 {
3247 if this.line_end {
3248 chunk.push('\n');
3249 }
3250
3251 chunk.push_str(&line);
3252 }
3253
3254 consumed += line.len();
3255 } else {
3256 let trimmed_line = line.trim();
3257 if trimmed_line.is_empty()
3258 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3259 || prefixes(CODE_BLOCK_DELIMITER)
3260 .any(|prefix| trimmed_line.ends_with(prefix))
3261 {
3262 break;
3263 } else {
3264 if this.line_end {
3265 chunk.push('\n');
3266 this.line_end = false;
3267 }
3268
3269 chunk.push_str(&line_without_cursor);
3270 consumed += line.len();
3271 }
3272 }
3273 }
3274 }
3275
3276 this.buffer = this.buffer.split_off(consumed);
3277 if !chunk.is_empty() {
3278 return Poll::Ready(Some(Ok(chunk)));
3279 } else if this.stream_done {
3280 return Poll::Ready(None);
3281 }
3282 }
3283 }
3284}
3285
3286struct AssistantCodeActionProvider {
3287 editor: WeakView<Editor>,
3288 workspace: WeakView<Workspace>,
3289 thread_store: Option<WeakModel<ThreadStore>>,
3290}
3291
3292impl CodeActionProvider for AssistantCodeActionProvider {
3293 fn code_actions(
3294 &self,
3295 buffer: &Model<Buffer>,
3296 range: Range<text::Anchor>,
3297 cx: &mut WindowContext,
3298 ) -> Task<Result<Vec<CodeAction>>> {
3299 if !AssistantSettings::get_global(cx).enabled {
3300 return Task::ready(Ok(Vec::new()));
3301 }
3302
3303 let snapshot = buffer.read(cx).snapshot();
3304 let mut range = range.to_point(&snapshot);
3305
3306 // Expand the range to line boundaries.
3307 range.start.column = 0;
3308 range.end.column = snapshot.line_len(range.end.row);
3309
3310 let mut has_diagnostics = false;
3311 for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3312 range.start = cmp::min(range.start, diagnostic.range.start);
3313 range.end = cmp::max(range.end, diagnostic.range.end);
3314 has_diagnostics = true;
3315 }
3316 if has_diagnostics {
3317 if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3318 if let Some(symbol) = symbols_containing_start.last() {
3319 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3320 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3321 }
3322 }
3323
3324 if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3325 if let Some(symbol) = symbols_containing_end.last() {
3326 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3327 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3328 }
3329 }
3330
3331 Task::ready(Ok(vec![CodeAction {
3332 server_id: language::LanguageServerId(0),
3333 range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3334 lsp_action: lsp::CodeAction {
3335 title: "Fix with Assistant".into(),
3336 ..Default::default()
3337 },
3338 }]))
3339 } else {
3340 Task::ready(Ok(Vec::new()))
3341 }
3342 }
3343
3344 fn apply_code_action(
3345 &self,
3346 buffer: Model<Buffer>,
3347 action: CodeAction,
3348 excerpt_id: ExcerptId,
3349 _push_to_history: bool,
3350 cx: &mut WindowContext,
3351 ) -> Task<Result<ProjectTransaction>> {
3352 let editor = self.editor.clone();
3353 let workspace = self.workspace.clone();
3354 let thread_store = self.thread_store.clone();
3355 cx.spawn(|mut cx| async move {
3356 let editor = editor.upgrade().context("editor was released")?;
3357 let range = editor
3358 .update(&mut cx, |editor, cx| {
3359 editor.buffer().update(cx, |multibuffer, cx| {
3360 let buffer = buffer.read(cx);
3361 let multibuffer_snapshot = multibuffer.read(cx);
3362
3363 let old_context_range =
3364 multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3365 let mut new_context_range = old_context_range.clone();
3366 if action
3367 .range
3368 .start
3369 .cmp(&old_context_range.start, buffer)
3370 .is_lt()
3371 {
3372 new_context_range.start = action.range.start;
3373 }
3374 if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3375 new_context_range.end = action.range.end;
3376 }
3377 drop(multibuffer_snapshot);
3378
3379 if new_context_range != old_context_range {
3380 multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3381 }
3382
3383 let multibuffer_snapshot = multibuffer.read(cx);
3384 Some(
3385 multibuffer_snapshot
3386 .anchor_in_excerpt(excerpt_id, action.range.start)?
3387 ..multibuffer_snapshot
3388 .anchor_in_excerpt(excerpt_id, action.range.end)?,
3389 )
3390 })
3391 })?
3392 .context("invalid range")?;
3393
3394 cx.update_global(|assistant: &mut InlineAssistant, cx| {
3395 let assist_id = assistant.suggest_assist(
3396 &editor,
3397 range,
3398 "Fix Diagnostics".into(),
3399 None,
3400 true,
3401 workspace,
3402 thread_store,
3403 cx,
3404 );
3405 assistant.start_assist(assist_id, cx);
3406 })?;
3407
3408 Ok(ProjectTransaction::default())
3409 })
3410 }
3411}
3412
/// Returns every non-empty *proper* prefix of `text`, shortest first.
///
/// The full string itself is not included, so `prefixes("abc")` yields `"a"` and `"ab"`.
/// Empty and single-character inputs yield nothing. Prefixes always end on a character
/// boundary, so multi-byte characters are never split.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each character start after the first one marks the end of one proper prefix.
    text.char_indices().skip(1).map(move |(end, _)| &text[..end])
}
3416
3417fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3418 ranges.sort_unstable_by(|a, b| {
3419 a.start
3420 .cmp(&b.start, buffer)
3421 .then_with(|| b.end.cmp(&a.end, buffer))
3422 });
3423
3424 let mut ix = 0;
3425 while ix + 1 < ranges.len() {
3426 let b = ranges[ix + 1].clone();
3427 let a = &mut ranges[ix];
3428 if a.end.cmp(&b.start, buffer).is_gt() {
3429 if a.end.cmp(&b.end, buffer).is_lt() {
3430 a.end = b.end;
3431 }
3432 ranges.remove(ix + 1);
3433 } else {
3434 ix += 1;
3435 }
3436 }
3437}
3438
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Replacing a multi-line selection re-indents the generated code to match the buffer,
    /// regardless of how the response is split into chunks.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select the whole body of `main`: from the start of row 1 to the end of row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        // The response is indented by 7 columns rather than the buffer's 4, so the
        // expected output below only matches if the indentation gets transformed.
        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Feed the response in randomly sized chunks of at most 10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor placed after existing text on an indented line indents the
    /// following generated lines to match.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty selection right after `le` (4 columns of indentation + 2 characters).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Feed the response in randomly sized chunks of at most 10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor placed inside the leading whitespace of a blank line still
    /// produces correctly indented code.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "    \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty selection in the middle of the blank line's indentation.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Feed the response in randomly sized chunks of at most 10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Tab-indented code (a buffer without a language) keeps its tabs after a rewrite.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative leaves the buffer untouched until it is activated, and
    /// deactivating it again restores the original text.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select all of row 1 (4 columns of indentation + `let x = 0;`).
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` removes a wrapping code fence and cursor markers for every
    /// possible chunk size.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Streams `text` in chunks of every size from 1 to its length and checks that the
        /// concatenated output equals `expected_text` each time.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into a stream of chunks of `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Hooks `codegen` up to a fake response stream and returns the sender used to feed it.
    /// Dropping the sender ends the response.
    fn simulate_response_stream(
        codegen: Model<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// Minimal Rust language definition with an indentation query, for autoindent tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}