1use crate::{
2 assistant_settings::AssistantSettings,
3 prompts::PromptBuilder,
4 streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
5 terminal_inline_assistant::TerminalInlineAssistant,
6 CycleNextInlineAssist, CyclePreviousInlineAssist, ToggleInlineAssist,
7};
8use anyhow::{Context as _, Result};
9use client::{telemetry::Telemetry, ErrorExt};
10use collections::{hash_map, HashMap, HashSet, VecDeque};
11use editor::{
12 actions::{MoveDown, MoveUp, SelectAll},
13 display_map::{
14 BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
15 ToDisplayPoint,
16 },
17 Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
18 EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
19 ToOffset as _, ToPoint,
20};
21use feature_flags::{FeatureFlagAppExt as _, ZedPro};
22use fs::Fs;
23use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
24use gpui::{
25 anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
26 FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
27 Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakView, WindowContext,
28};
29use language::{Buffer, IndentKind, Point, Selection, TransactionId};
30use language_model::{
31 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
32 LanguageModelTextStream, Role,
33};
34use language_model_selector::LanguageModelSelector;
35use language_models::report_assistant_event;
36use multi_buffer::MultiBufferRow;
37use parking_lot::Mutex;
38use project::{CodeAction, ProjectTransaction};
39use rope::Rope;
40use settings::{update_settings_file, Settings, SettingsStore};
41use smol::future::FutureExt;
42use std::{
43 cmp,
44 future::Future,
45 iter, mem,
46 ops::{Range, RangeInclusive},
47 pin::Pin,
48 sync::Arc,
49 task::{self, Poll},
50 time::Instant,
51};
52use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
53use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
54use text::{OffsetRangeExt, ToPoint as _};
55use theme::ThemeSettings;
56use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, Tooltip};
57use util::{RangeExt, ResultExt};
58use workspace::{dock::Panel, ShowConfiguration};
59use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
60
61pub fn init(
62 fs: Arc<dyn Fs>,
63 prompt_builder: Arc<PromptBuilder>,
64 telemetry: Arc<Telemetry>,
65 cx: &mut AppContext,
66) {
67 cx.set_global(InlineAssistant::new(fs, prompt_builder, telemetry));
68 cx.observe_new_views(|workspace: &mut Workspace, cx| {
69 workspace.register_action(InlineAssistant::toggle_inline_assist);
70
71 let workspace = cx.view().clone();
72 InlineAssistant::update_global(cx, |inline_assistant, cx| {
73 inline_assistant.register_workspace(&workspace, cx)
74 })
75 })
76 .detach();
77}
78
/// Maximum number of prompts retained in `InlineAssistant::prompt_history`; when a new prompt
/// pushes the history past this length, the oldest entry is discarded.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
80
81enum InlineAssistTarget {
82 Editor(View<Editor>),
83 Terminal(View<TerminalView>),
84}
85
86pub struct InlineAssistant {
87 next_assist_id: InlineAssistId,
88 next_assist_group_id: InlineAssistGroupId,
89 assists: HashMap<InlineAssistId, InlineAssist>,
90 assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
91 assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
92 confirmed_assists: HashMap<InlineAssistId, Model<CodegenAlternative>>,
93 prompt_history: VecDeque<String>,
94 prompt_builder: Arc<PromptBuilder>,
95 telemetry: Arc<Telemetry>,
96 fs: Arc<dyn Fs>,
97}
98
99impl Global for InlineAssistant {}
100
101impl InlineAssistant {
102 pub fn new(
103 fs: Arc<dyn Fs>,
104 prompt_builder: Arc<PromptBuilder>,
105 telemetry: Arc<Telemetry>,
106 ) -> Self {
107 Self {
108 next_assist_id: InlineAssistId::default(),
109 next_assist_group_id: InlineAssistGroupId::default(),
110 assists: HashMap::default(),
111 assists_by_editor: HashMap::default(),
112 assist_groups: HashMap::default(),
113 confirmed_assists: HashMap::default(),
114 prompt_history: VecDeque::default(),
115 prompt_builder,
116 telemetry,
117 fs,
118 }
119 }
120
121 pub fn register_workspace(&mut self, workspace: &View<Workspace>, cx: &mut WindowContext) {
122 cx.subscribe(workspace, |workspace, event, cx| {
123 Self::update_global(cx, |this, cx| {
124 this.handle_workspace_event(workspace, event, cx)
125 });
126 })
127 .detach();
128
129 let workspace = workspace.downgrade();
130 cx.observe_global::<SettingsStore>(move |cx| {
131 let Some(workspace) = workspace.upgrade() else {
132 return;
133 };
134 let Some(terminal_panel) = workspace.read(cx).panel::<TerminalPanel>(cx) else {
135 return;
136 };
137 let enabled = AssistantSettings::get_global(cx).enabled;
138 terminal_panel.update(cx, |terminal_panel, cx| {
139 terminal_panel.asssistant_enabled(enabled, cx)
140 });
141 })
142 .detach();
143 }
144
145 fn handle_workspace_event(
146 &mut self,
147 workspace: View<Workspace>,
148 event: &workspace::Event,
149 cx: &mut WindowContext,
150 ) {
151 match event {
152 workspace::Event::UserSavedItem { item, .. } => {
153 // When the user manually saves an editor, automatically accepts all finished transformations.
154 if let Some(editor) = item.upgrade().and_then(|item| item.act_as::<Editor>(cx)) {
155 if let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) {
156 for assist_id in editor_assists.assist_ids.clone() {
157 let assist = &self.assists[&assist_id];
158 if let CodegenStatus::Done = assist.codegen.read(cx).status(cx) {
159 self.finish_assist(assist_id, false, cx)
160 }
161 }
162 }
163 }
164 }
165 workspace::Event::ItemAdded { item } => {
166 self.register_workspace_item(&workspace, item.as_ref(), cx);
167 }
168 _ => (),
169 }
170 }
171
172 fn register_workspace_item(
173 &mut self,
174 workspace: &View<Workspace>,
175 item: &dyn ItemHandle,
176 cx: &mut WindowContext,
177 ) {
178 if let Some(editor) = item.act_as::<Editor>(cx) {
179 editor.update(cx, |editor, cx| {
180 editor.push_code_action_provider(
181 Arc::new(AssistantCodeActionProvider {
182 editor: cx.view().downgrade(),
183 workspace: workspace.downgrade(),
184 }),
185 cx,
186 );
187 });
188 }
189 }
190
191 pub fn toggle_inline_assist(
192 workspace: &mut Workspace,
193 _action: &ToggleInlineAssist,
194 cx: &mut ViewContext<Workspace>,
195 ) {
196 let settings = AssistantSettings::get_global(cx);
197 if !settings.enabled {
198 return;
199 }
200
201 let Some(inline_assist_target) = Self::resolve_inline_assist_target(workspace, cx) else {
202 return;
203 };
204
205 let is_authenticated = || {
206 LanguageModelRegistry::read_global(cx)
207 .active_provider()
208 .map_or(false, |provider| provider.is_authenticated(cx))
209 };
210
211 let handle_assist = |cx: &mut ViewContext<Workspace>| match inline_assist_target {
212 InlineAssistTarget::Editor(active_editor) => {
213 InlineAssistant::update_global(cx, |assistant, cx| {
214 assistant.assist(&active_editor, Some(cx.view().downgrade()), cx)
215 })
216 }
217 InlineAssistTarget::Terminal(active_terminal) => {
218 TerminalInlineAssistant::update_global(cx, |assistant, cx| {
219 assistant.assist(&active_terminal, Some(cx.view().downgrade()), cx)
220 })
221 }
222 };
223
224 if is_authenticated() {
225 handle_assist(cx);
226 } else {
227 cx.spawn(|_workspace, mut cx| async move {
228 let Some(task) = cx.update(|cx| {
229 LanguageModelRegistry::read_global(cx)
230 .active_provider()
231 .map_or(None, |provider| Some(provider.authenticate(cx)))
232 })?
233 else {
234 let answer = cx
235 .prompt(
236 gpui::PromptLevel::Warning,
237 "No language model provider configured",
238 None,
239 &["Configure", "Cancel"],
240 )
241 .await
242 .ok();
243 if let Some(answer) = answer {
244 if answer == 0 {
245 cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
246 .ok();
247 }
248 }
249 return Ok(());
250 };
251 task.await?;
252
253 anyhow::Ok(())
254 })
255 .detach_and_log_err(cx);
256
257 if is_authenticated() {
258 handle_assist(cx);
259 }
260 }
261 }
262
263 pub fn assist(
264 &mut self,
265 editor: &View<Editor>,
266 workspace: Option<WeakView<Workspace>>,
267 cx: &mut WindowContext,
268 ) {
269 let (snapshot, initial_selections) = editor.update(cx, |editor, cx| {
270 (
271 editor.buffer().read(cx).snapshot(cx),
272 editor.selections.all::<Point>(cx),
273 )
274 });
275
276 let mut selections = Vec::<Selection<Point>>::new();
277 let mut newest_selection = None;
278 for mut selection in initial_selections {
279 if selection.end > selection.start {
280 selection.start.column = 0;
281 // If the selection ends at the start of the line, we don't want to include it.
282 if selection.end.column == 0 {
283 selection.end.row -= 1;
284 }
285 selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
286 }
287
288 if let Some(prev_selection) = selections.last_mut() {
289 if selection.start <= prev_selection.end {
290 prev_selection.end = selection.end;
291 continue;
292 }
293 }
294
295 let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
296 if selection.id > latest_selection.id {
297 *latest_selection = selection.clone();
298 }
299 selections.push(selection);
300 }
301 let newest_selection = newest_selection.unwrap();
302
303 let mut codegen_ranges = Vec::new();
304 for (excerpt_id, buffer, buffer_range) in
305 snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
306 snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
307 }))
308 {
309 let start = Anchor {
310 buffer_id: Some(buffer.remote_id()),
311 excerpt_id,
312 text_anchor: buffer.anchor_before(buffer_range.start),
313 };
314 let end = Anchor {
315 buffer_id: Some(buffer.remote_id()),
316 excerpt_id,
317 text_anchor: buffer.anchor_after(buffer_range.end),
318 };
319 codegen_ranges.push(start..end);
320
321 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
322 self.telemetry.report_assistant_event(AssistantEvent {
323 conversation_id: None,
324 kind: AssistantKind::Inline,
325 phase: AssistantPhase::Invoked,
326 message_id: None,
327 model: model.telemetry_id(),
328 model_provider: model.provider_id().to_string(),
329 response_latency: None,
330 error_message: None,
331 language_name: buffer.language().map(|language| language.name().to_proto()),
332 });
333 }
334 }
335
336 let assist_group_id = self.next_assist_group_id.post_inc();
337 let prompt_buffer = cx.new_model(|cx| {
338 MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
339 });
340
341 let mut assists = Vec::new();
342 let mut assist_to_focus = None;
343 for range in codegen_ranges {
344 let assist_id = self.next_assist_id.post_inc();
345 let codegen = cx.new_model(|cx| {
346 Codegen::new(
347 editor.read(cx).buffer().clone(),
348 range.clone(),
349 None,
350 self.telemetry.clone(),
351 self.prompt_builder.clone(),
352 cx,
353 )
354 });
355
356 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
357 let prompt_editor = cx.new_view(|cx| {
358 PromptEditor::new(
359 assist_id,
360 gutter_dimensions.clone(),
361 self.prompt_history.clone(),
362 prompt_buffer.clone(),
363 codegen.clone(),
364 self.fs.clone(),
365 cx,
366 )
367 });
368
369 if assist_to_focus.is_none() {
370 let focus_assist = if newest_selection.reversed {
371 range.start.to_point(&snapshot) == newest_selection.start
372 } else {
373 range.end.to_point(&snapshot) == newest_selection.end
374 };
375 if focus_assist {
376 assist_to_focus = Some(assist_id);
377 }
378 }
379
380 let [prompt_block_id, end_block_id] =
381 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
382
383 assists.push((
384 assist_id,
385 range,
386 prompt_editor,
387 prompt_block_id,
388 end_block_id,
389 ));
390 }
391
392 let editor_assists = self
393 .assists_by_editor
394 .entry(editor.downgrade())
395 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
396 let mut assist_group = InlineAssistGroup::new();
397 for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
398 self.assists.insert(
399 assist_id,
400 InlineAssist::new(
401 assist_id,
402 assist_group_id,
403 editor,
404 &prompt_editor,
405 prompt_block_id,
406 end_block_id,
407 range,
408 prompt_editor.read(cx).codegen.clone(),
409 workspace.clone(),
410 cx,
411 ),
412 );
413 assist_group.assist_ids.push(assist_id);
414 editor_assists.assist_ids.push(assist_id);
415 }
416 self.assist_groups.insert(assist_group_id, assist_group);
417
418 if let Some(assist_id) = assist_to_focus {
419 self.focus_assist(assist_id, cx);
420 }
421 }
422
423 #[allow(clippy::too_many_arguments)]
424 pub fn suggest_assist(
425 &mut self,
426 editor: &View<Editor>,
427 mut range: Range<Anchor>,
428 initial_prompt: String,
429 initial_transaction_id: Option<TransactionId>,
430 focus: bool,
431 workspace: Option<WeakView<Workspace>>,
432 cx: &mut WindowContext,
433 ) -> InlineAssistId {
434 let assist_group_id = self.next_assist_group_id.post_inc();
435 let prompt_buffer = cx.new_model(|cx| Buffer::local(&initial_prompt, cx));
436 let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
437
438 let assist_id = self.next_assist_id.post_inc();
439
440 let buffer = editor.read(cx).buffer().clone();
441 {
442 let snapshot = buffer.read(cx).read(cx);
443 range.start = range.start.bias_left(&snapshot);
444 range.end = range.end.bias_right(&snapshot);
445 }
446
447 let codegen = cx.new_model(|cx| {
448 Codegen::new(
449 editor.read(cx).buffer().clone(),
450 range.clone(),
451 initial_transaction_id,
452 self.telemetry.clone(),
453 self.prompt_builder.clone(),
454 cx,
455 )
456 });
457
458 let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
459 let prompt_editor = cx.new_view(|cx| {
460 PromptEditor::new(
461 assist_id,
462 gutter_dimensions.clone(),
463 self.prompt_history.clone(),
464 prompt_buffer.clone(),
465 codegen.clone(),
466 self.fs.clone(),
467 cx,
468 )
469 });
470
471 let [prompt_block_id, end_block_id] =
472 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
473
474 let editor_assists = self
475 .assists_by_editor
476 .entry(editor.downgrade())
477 .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
478
479 let mut assist_group = InlineAssistGroup::new();
480 self.assists.insert(
481 assist_id,
482 InlineAssist::new(
483 assist_id,
484 assist_group_id,
485 editor,
486 &prompt_editor,
487 prompt_block_id,
488 end_block_id,
489 range,
490 prompt_editor.read(cx).codegen.clone(),
491 workspace.clone(),
492 cx,
493 ),
494 );
495 assist_group.assist_ids.push(assist_id);
496 editor_assists.assist_ids.push(assist_id);
497 self.assist_groups.insert(assist_group_id, assist_group);
498
499 if focus {
500 self.focus_assist(assist_id, cx);
501 }
502
503 assist_id
504 }
505
506 fn insert_assist_blocks(
507 &self,
508 editor: &View<Editor>,
509 range: &Range<Anchor>,
510 prompt_editor: &View<PromptEditor>,
511 cx: &mut WindowContext,
512 ) -> [CustomBlockId; 2] {
513 let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
514 prompt_editor
515 .editor
516 .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1 + 2)
517 });
518 let assist_blocks = vec![
519 BlockProperties {
520 style: BlockStyle::Sticky,
521 placement: BlockPlacement::Above(range.start),
522 height: prompt_editor_height,
523 render: build_assist_editor_renderer(prompt_editor),
524 priority: 0,
525 },
526 BlockProperties {
527 style: BlockStyle::Sticky,
528 placement: BlockPlacement::Below(range.end),
529 height: 0,
530 render: Arc::new(|cx| {
531 v_flex()
532 .h_full()
533 .w_full()
534 .border_t_1()
535 .border_color(cx.theme().status().info_border)
536 .into_any_element()
537 }),
538 priority: 0,
539 },
540 ];
541
542 editor.update(cx, |editor, cx| {
543 let block_ids = editor.insert_blocks(assist_blocks, None, cx);
544 [block_ids[0], block_ids[1]]
545 })
546 }
547
548 fn handle_prompt_editor_focus_in(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
549 let assist = &self.assists[&assist_id];
550 let Some(decorations) = assist.decorations.as_ref() else {
551 return;
552 };
553 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
554 let editor_assists = self.assists_by_editor.get_mut(&assist.editor).unwrap();
555
556 assist_group.active_assist_id = Some(assist_id);
557 if assist_group.linked {
558 for assist_id in &assist_group.assist_ids {
559 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
560 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
561 prompt_editor.set_show_cursor_when_unfocused(true, cx)
562 });
563 }
564 }
565 }
566
567 assist
568 .editor
569 .update(cx, |editor, cx| {
570 let scroll_top = editor.scroll_position(cx).y;
571 let scroll_bottom = scroll_top + editor.visible_line_count().unwrap_or(0.);
572 let prompt_row = editor
573 .row_for_block(decorations.prompt_block_id, cx)
574 .unwrap()
575 .0 as f32;
576
577 if (scroll_top..scroll_bottom).contains(&prompt_row) {
578 editor_assists.scroll_lock = Some(InlineAssistScrollLock {
579 assist_id,
580 distance_from_top: prompt_row - scroll_top,
581 });
582 } else {
583 editor_assists.scroll_lock = None;
584 }
585 })
586 .ok();
587 }
588
589 fn handle_prompt_editor_focus_out(
590 &mut self,
591 assist_id: InlineAssistId,
592 cx: &mut WindowContext,
593 ) {
594 let assist = &self.assists[&assist_id];
595 let assist_group = self.assist_groups.get_mut(&assist.group_id).unwrap();
596 if assist_group.active_assist_id == Some(assist_id) {
597 assist_group.active_assist_id = None;
598 if assist_group.linked {
599 for assist_id in &assist_group.assist_ids {
600 if let Some(decorations) = self.assists[assist_id].decorations.as_ref() {
601 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
602 prompt_editor.set_show_cursor_when_unfocused(false, cx)
603 });
604 }
605 }
606 }
607 }
608 }
609
610 fn handle_prompt_editor_event(
611 &mut self,
612 prompt_editor: View<PromptEditor>,
613 event: &PromptEditorEvent,
614 cx: &mut WindowContext,
615 ) {
616 let assist_id = prompt_editor.read(cx).id;
617 match event {
618 PromptEditorEvent::StartRequested => {
619 self.start_assist(assist_id, cx);
620 }
621 PromptEditorEvent::StopRequested => {
622 self.stop_assist(assist_id, cx);
623 }
624 PromptEditorEvent::ConfirmRequested => {
625 self.finish_assist(assist_id, false, cx);
626 }
627 PromptEditorEvent::CancelRequested => {
628 self.finish_assist(assist_id, true, cx);
629 }
630 PromptEditorEvent::DismissRequested => {
631 self.dismiss_assist(assist_id, cx);
632 }
633 }
634 }
635
636 fn handle_editor_newline(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
637 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
638 return;
639 };
640
641 if editor.read(cx).selections.count() == 1 {
642 let (selection, buffer) = editor.update(cx, |editor, cx| {
643 (
644 editor.selections.newest::<usize>(cx),
645 editor.buffer().read(cx).snapshot(cx),
646 )
647 });
648 for assist_id in &editor_assists.assist_ids {
649 let assist = &self.assists[assist_id];
650 let assist_range = assist.range.to_offset(&buffer);
651 if assist_range.contains(&selection.start) && assist_range.contains(&selection.end)
652 {
653 if matches!(assist.codegen.read(cx).status(cx), CodegenStatus::Pending) {
654 self.dismiss_assist(*assist_id, cx);
655 } else {
656 self.finish_assist(*assist_id, false, cx);
657 }
658
659 return;
660 }
661 }
662 }
663
664 cx.propagate();
665 }
666
667 fn handle_editor_cancel(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
668 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
669 return;
670 };
671
672 if editor.read(cx).selections.count() == 1 {
673 let (selection, buffer) = editor.update(cx, |editor, cx| {
674 (
675 editor.selections.newest::<usize>(cx),
676 editor.buffer().read(cx).snapshot(cx),
677 )
678 });
679 let mut closest_assist_fallback = None;
680 for assist_id in &editor_assists.assist_ids {
681 let assist = &self.assists[assist_id];
682 let assist_range = assist.range.to_offset(&buffer);
683 if assist.decorations.is_some() {
684 if assist_range.contains(&selection.start)
685 && assist_range.contains(&selection.end)
686 {
687 self.focus_assist(*assist_id, cx);
688 return;
689 } else {
690 let distance_from_selection = assist_range
691 .start
692 .abs_diff(selection.start)
693 .min(assist_range.start.abs_diff(selection.end))
694 + assist_range
695 .end
696 .abs_diff(selection.start)
697 .min(assist_range.end.abs_diff(selection.end));
698 match closest_assist_fallback {
699 Some((_, old_distance)) => {
700 if distance_from_selection < old_distance {
701 closest_assist_fallback =
702 Some((assist_id, distance_from_selection));
703 }
704 }
705 None => {
706 closest_assist_fallback = Some((assist_id, distance_from_selection))
707 }
708 }
709 }
710 }
711 }
712
713 if let Some((&assist_id, _)) = closest_assist_fallback {
714 self.focus_assist(assist_id, cx);
715 }
716 }
717
718 cx.propagate();
719 }
720
721 fn handle_editor_release(&mut self, editor: WeakView<Editor>, cx: &mut WindowContext) {
722 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor) {
723 for assist_id in editor_assists.assist_ids.clone() {
724 self.finish_assist(assist_id, true, cx);
725 }
726 }
727 }
728
729 fn handle_editor_change(&mut self, editor: View<Editor>, cx: &mut WindowContext) {
730 let Some(editor_assists) = self.assists_by_editor.get(&editor.downgrade()) else {
731 return;
732 };
733 let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() else {
734 return;
735 };
736 let assist = &self.assists[&scroll_lock.assist_id];
737 let Some(decorations) = assist.decorations.as_ref() else {
738 return;
739 };
740
741 editor.update(cx, |editor, cx| {
742 let scroll_position = editor.scroll_position(cx);
743 let target_scroll_top = editor
744 .row_for_block(decorations.prompt_block_id, cx)
745 .unwrap()
746 .0 as f32
747 - scroll_lock.distance_from_top;
748 if target_scroll_top != scroll_position.y {
749 editor.set_scroll_position(point(scroll_position.x, target_scroll_top), cx);
750 }
751 });
752 }
753
754 fn handle_editor_event(
755 &mut self,
756 editor: View<Editor>,
757 event: &EditorEvent,
758 cx: &mut WindowContext,
759 ) {
760 let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) else {
761 return;
762 };
763
764 match event {
765 EditorEvent::Edited { transaction_id } => {
766 let buffer = editor.read(cx).buffer().read(cx);
767 let edited_ranges =
768 buffer.edited_ranges_for_transaction::<usize>(*transaction_id, cx);
769 let snapshot = buffer.snapshot(cx);
770
771 for assist_id in editor_assists.assist_ids.clone() {
772 let assist = &self.assists[&assist_id];
773 if matches!(
774 assist.codegen.read(cx).status(cx),
775 CodegenStatus::Error(_) | CodegenStatus::Done
776 ) {
777 let assist_range = assist.range.to_offset(&snapshot);
778 if edited_ranges
779 .iter()
780 .any(|range| range.overlaps(&assist_range))
781 {
782 self.finish_assist(assist_id, false, cx);
783 }
784 }
785 }
786 }
787 EditorEvent::ScrollPositionChanged { .. } => {
788 if let Some(scroll_lock) = editor_assists.scroll_lock.as_ref() {
789 let assist = &self.assists[&scroll_lock.assist_id];
790 if let Some(decorations) = assist.decorations.as_ref() {
791 let distance_from_top = editor.update(cx, |editor, cx| {
792 let scroll_top = editor.scroll_position(cx).y;
793 let prompt_row = editor
794 .row_for_block(decorations.prompt_block_id, cx)
795 .unwrap()
796 .0 as f32;
797 prompt_row - scroll_top
798 });
799
800 if distance_from_top != scroll_lock.distance_from_top {
801 editor_assists.scroll_lock = None;
802 }
803 }
804 }
805 }
806 EditorEvent::SelectionsChanged { .. } => {
807 for assist_id in editor_assists.assist_ids.clone() {
808 let assist = &self.assists[&assist_id];
809 if let Some(decorations) = assist.decorations.as_ref() {
810 if decorations.prompt_editor.focus_handle(cx).is_focused(cx) {
811 return;
812 }
813 }
814 }
815
816 editor_assists.scroll_lock = None;
817 }
818 _ => {}
819 }
820 }
821
822 pub fn finish_assist(&mut self, assist_id: InlineAssistId, undo: bool, cx: &mut WindowContext) {
823 if let Some(assist) = self.assists.get(&assist_id) {
824 let assist_group_id = assist.group_id;
825 if self.assist_groups[&assist_group_id].linked {
826 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
827 self.finish_assist(assist_id, undo, cx);
828 }
829 return;
830 }
831 }
832
833 self.dismiss_assist(assist_id, cx);
834
835 if let Some(assist) = self.assists.remove(&assist_id) {
836 if let hash_map::Entry::Occupied(mut entry) = self.assist_groups.entry(assist.group_id)
837 {
838 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
839 if entry.get().assist_ids.is_empty() {
840 entry.remove();
841 }
842 }
843
844 if let hash_map::Entry::Occupied(mut entry) =
845 self.assists_by_editor.entry(assist.editor.clone())
846 {
847 entry.get_mut().assist_ids.retain(|id| *id != assist_id);
848 if entry.get().assist_ids.is_empty() {
849 entry.remove();
850 if let Some(editor) = assist.editor.upgrade() {
851 self.update_editor_highlights(&editor, cx);
852 }
853 } else {
854 entry.get().highlight_updates.send(()).ok();
855 }
856 }
857
858 let active_alternative = assist.codegen.read(cx).active_alternative().clone();
859 let message_id = active_alternative.read(cx).message_id.clone();
860
861 if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
862 let language_name = assist.editor.upgrade().and_then(|editor| {
863 let multibuffer = editor.read(cx).buffer().read(cx);
864 let ranges = multibuffer.range_to_buffer_ranges(assist.range.clone(), cx);
865 ranges
866 .first()
867 .and_then(|(buffer, _, _)| buffer.read(cx).language())
868 .map(|language| language.name())
869 });
870 report_assistant_event(
871 AssistantEvent {
872 conversation_id: None,
873 kind: AssistantKind::Inline,
874 message_id,
875 phase: if undo {
876 AssistantPhase::Rejected
877 } else {
878 AssistantPhase::Accepted
879 },
880 model: model.telemetry_id(),
881 model_provider: model.provider_id().to_string(),
882 response_latency: None,
883 error_message: None,
884 language_name: language_name.map(|name| name.to_proto()),
885 },
886 Some(self.telemetry.clone()),
887 cx.http_client(),
888 model.api_key(cx),
889 cx.background_executor(),
890 );
891 }
892
893 if undo {
894 assist.codegen.update(cx, |codegen, cx| codegen.undo(cx));
895 } else {
896 self.confirmed_assists.insert(assist_id, active_alternative);
897 }
898 }
899 }
900
901 fn dismiss_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) -> bool {
902 let Some(assist) = self.assists.get_mut(&assist_id) else {
903 return false;
904 };
905 let Some(editor) = assist.editor.upgrade() else {
906 return false;
907 };
908 let Some(decorations) = assist.decorations.take() else {
909 return false;
910 };
911
912 editor.update(cx, |editor, cx| {
913 let mut to_remove = decorations.removed_line_block_ids;
914 to_remove.insert(decorations.prompt_block_id);
915 to_remove.insert(decorations.end_block_id);
916 editor.remove_blocks(to_remove, None, cx);
917 });
918
919 if decorations
920 .prompt_editor
921 .focus_handle(cx)
922 .contains_focused(cx)
923 {
924 self.focus_next_assist(assist_id, cx);
925 }
926
927 if let Some(editor_assists) = self.assists_by_editor.get_mut(&editor.downgrade()) {
928 if editor_assists
929 .scroll_lock
930 .as_ref()
931 .map_or(false, |lock| lock.assist_id == assist_id)
932 {
933 editor_assists.scroll_lock = None;
934 }
935 editor_assists.highlight_updates.send(()).ok();
936 }
937
938 true
939 }
940
941 fn focus_next_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
942 let Some(assist) = self.assists.get(&assist_id) else {
943 return;
944 };
945
946 let assist_group = &self.assist_groups[&assist.group_id];
947 let assist_ix = assist_group
948 .assist_ids
949 .iter()
950 .position(|id| *id == assist_id)
951 .unwrap();
952 let assist_ids = assist_group
953 .assist_ids
954 .iter()
955 .skip(assist_ix + 1)
956 .chain(assist_group.assist_ids.iter().take(assist_ix));
957
958 for assist_id in assist_ids {
959 let assist = &self.assists[assist_id];
960 if assist.decorations.is_some() {
961 self.focus_assist(*assist_id, cx);
962 return;
963 }
964 }
965
966 assist.editor.update(cx, |editor, cx| editor.focus(cx)).ok();
967 }
968
969 fn focus_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
970 let Some(assist) = self.assists.get(&assist_id) else {
971 return;
972 };
973
974 if let Some(decorations) = assist.decorations.as_ref() {
975 decorations.prompt_editor.update(cx, |prompt_editor, cx| {
976 prompt_editor.editor.update(cx, |editor, cx| {
977 editor.focus(cx);
978 editor.select_all(&SelectAll, cx);
979 })
980 });
981 }
982
983 self.scroll_to_assist(assist_id, cx);
984 }
985
986 pub fn scroll_to_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
987 let Some(assist) = self.assists.get(&assist_id) else {
988 return;
989 };
990 let Some(editor) = assist.editor.upgrade() else {
991 return;
992 };
993
994 let position = assist.range.start;
995 editor.update(cx, |editor, cx| {
996 editor.change_selections(None, cx, |selections| {
997 selections.select_anchor_ranges([position..position])
998 });
999
1000 let mut scroll_target_top;
1001 let mut scroll_target_bottom;
1002 if let Some(decorations) = assist.decorations.as_ref() {
1003 scroll_target_top = editor
1004 .row_for_block(decorations.prompt_block_id, cx)
1005 .unwrap()
1006 .0 as f32;
1007 scroll_target_bottom = editor
1008 .row_for_block(decorations.end_block_id, cx)
1009 .unwrap()
1010 .0 as f32;
1011 } else {
1012 let snapshot = editor.snapshot(cx);
1013 let start_row = assist
1014 .range
1015 .start
1016 .to_display_point(&snapshot.display_snapshot)
1017 .row();
1018 scroll_target_top = start_row.0 as f32;
1019 scroll_target_bottom = scroll_target_top + 1.;
1020 }
1021 scroll_target_top -= editor.vertical_scroll_margin() as f32;
1022 scroll_target_bottom += editor.vertical_scroll_margin() as f32;
1023
1024 let height_in_lines = editor.visible_line_count().unwrap_or(0.);
1025 let scroll_top = editor.scroll_position(cx).y;
1026 let scroll_bottom = scroll_top + height_in_lines;
1027
1028 if scroll_target_top < scroll_top {
1029 editor.set_scroll_position(point(0., scroll_target_top), cx);
1030 } else if scroll_target_bottom > scroll_bottom {
1031 if (scroll_target_bottom - scroll_target_top) <= height_in_lines {
1032 editor
1033 .set_scroll_position(point(0., scroll_target_bottom - height_in_lines), cx);
1034 } else {
1035 editor.set_scroll_position(point(0., scroll_target_top), cx);
1036 }
1037 }
1038 });
1039 }
1040
1041 fn unlink_assist_group(
1042 &mut self,
1043 assist_group_id: InlineAssistGroupId,
1044 cx: &mut WindowContext,
1045 ) -> Vec<InlineAssistId> {
1046 let assist_group = self.assist_groups.get_mut(&assist_group_id).unwrap();
1047 assist_group.linked = false;
1048 for assist_id in &assist_group.assist_ids {
1049 let assist = self.assists.get_mut(assist_id).unwrap();
1050 if let Some(editor_decorations) = assist.decorations.as_ref() {
1051 editor_decorations
1052 .prompt_editor
1053 .update(cx, |prompt_editor, cx| prompt_editor.unlink(cx));
1054 }
1055 }
1056 assist_group.assist_ids.clone()
1057 }
1058
1059 pub fn start_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1060 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1061 assist
1062 } else {
1063 return;
1064 };
1065
1066 let assist_group_id = assist.group_id;
1067 if self.assist_groups[&assist_group_id].linked {
1068 for assist_id in self.unlink_assist_group(assist_group_id, cx) {
1069 self.start_assist(assist_id, cx);
1070 }
1071 return;
1072 }
1073
1074 let Some(user_prompt) = assist.user_prompt(cx) else {
1075 return;
1076 };
1077
1078 self.prompt_history.retain(|prompt| *prompt != user_prompt);
1079 self.prompt_history.push_back(user_prompt.clone());
1080 if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
1081 self.prompt_history.pop_front();
1082 }
1083
1084 assist
1085 .codegen
1086 .update(cx, |codegen, cx| codegen.start(user_prompt, cx))
1087 .log_err();
1088 }
1089
1090 pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
1091 let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
1092 assist
1093 } else {
1094 return;
1095 };
1096
1097 assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
1098 }
1099
1100 fn update_editor_highlights(&self, editor: &View<Editor>, cx: &mut WindowContext) {
1101 let mut gutter_pending_ranges = Vec::new();
1102 let mut gutter_transformed_ranges = Vec::new();
1103 let mut foreground_ranges = Vec::new();
1104 let mut inserted_row_ranges = Vec::new();
1105 let empty_assist_ids = Vec::new();
1106 let assist_ids = self
1107 .assists_by_editor
1108 .get(&editor.downgrade())
1109 .map_or(&empty_assist_ids, |editor_assists| {
1110 &editor_assists.assist_ids
1111 });
1112
1113 for assist_id in assist_ids {
1114 if let Some(assist) = self.assists.get(assist_id) {
1115 let codegen = assist.codegen.read(cx);
1116 let buffer = codegen.buffer(cx).read(cx).read(cx);
1117 foreground_ranges.extend(codegen.last_equal_ranges(cx).iter().cloned());
1118
1119 let pending_range =
1120 codegen.edit_position(cx).unwrap_or(assist.range.start)..assist.range.end;
1121 if pending_range.end.to_offset(&buffer) > pending_range.start.to_offset(&buffer) {
1122 gutter_pending_ranges.push(pending_range);
1123 }
1124
1125 if let Some(edit_position) = codegen.edit_position(cx) {
1126 let edited_range = assist.range.start..edit_position;
1127 if edited_range.end.to_offset(&buffer) > edited_range.start.to_offset(&buffer) {
1128 gutter_transformed_ranges.push(edited_range);
1129 }
1130 }
1131
1132 if assist.decorations.is_some() {
1133 inserted_row_ranges
1134 .extend(codegen.diff(cx).inserted_row_ranges.iter().cloned());
1135 }
1136 }
1137 }
1138
1139 let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
1140 merge_ranges(&mut foreground_ranges, &snapshot);
1141 merge_ranges(&mut gutter_pending_ranges, &snapshot);
1142 merge_ranges(&mut gutter_transformed_ranges, &snapshot);
1143 editor.update(cx, |editor, cx| {
1144 enum GutterPendingRange {}
1145 if gutter_pending_ranges.is_empty() {
1146 editor.clear_gutter_highlights::<GutterPendingRange>(cx);
1147 } else {
1148 editor.highlight_gutter::<GutterPendingRange>(
1149 &gutter_pending_ranges,
1150 |cx| cx.theme().status().info_background,
1151 cx,
1152 )
1153 }
1154
1155 enum GutterTransformedRange {}
1156 if gutter_transformed_ranges.is_empty() {
1157 editor.clear_gutter_highlights::<GutterTransformedRange>(cx);
1158 } else {
1159 editor.highlight_gutter::<GutterTransformedRange>(
1160 &gutter_transformed_ranges,
1161 |cx| cx.theme().status().info,
1162 cx,
1163 )
1164 }
1165
1166 if foreground_ranges.is_empty() {
1167 editor.clear_highlights::<InlineAssist>(cx);
1168 } else {
1169 editor.highlight_text::<InlineAssist>(
1170 foreground_ranges,
1171 HighlightStyle {
1172 fade_out: Some(0.6),
1173 ..Default::default()
1174 },
1175 cx,
1176 );
1177 }
1178
1179 editor.clear_row_highlights::<InlineAssist>();
1180 for row_range in inserted_row_ranges {
1181 editor.highlight_rows::<InlineAssist>(
1182 row_range,
1183 cx.theme().status().info_background,
1184 false,
1185 cx,
1186 );
1187 }
1188 });
1189 }
1190
/// Rebuilds the "removed lines" blocks shown in `editor` for the given assist.
///
/// The previously inserted blocks are removed, and one new block is inserted
/// above each row reported by the codegen diff's `deleted_row_ranges`. Each
/// block embeds a read-only editor displaying the deleted rows taken from the
/// codegen's old buffer.
///
/// Does nothing if the assist is unknown or has no decorations. Panics if a
/// deleted row range cannot be mapped to an offset in the old snapshot.
fn update_editor_blocks(
    &mut self,
    editor: &View<Editor>,
    assist_id: InlineAssistId,
    cx: &mut WindowContext,
) {
    let Some(assist) = self.assists.get_mut(&assist_id) else {
        return;
    };
    let Some(decorations) = assist.decorations.as_mut() else {
        return;
    };

    let codegen = assist.codegen.read(cx);
    let old_snapshot = codegen.snapshot(cx);
    let old_buffer = codegen.old_buffer(cx);
    let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();

    editor.update(cx, |editor, cx| {
        // Drop the blocks from the previous update before inserting new ones.
        let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
        editor.remove_blocks(old_blocks, None, cx);

        let mut new_blocks = Vec::new();
        for (new_row, old_row_range) in deleted_row_ranges {
            // Translate the inclusive old row range into buffer offsets that
            // span from the start of the first row to the end of the last row.
            let (_, buffer_start) = old_snapshot
                .point_to_buffer_offset(Point::new(*old_row_range.start(), 0))
                .unwrap();
            let (_, buffer_end) = old_snapshot
                .point_to_buffer_offset(Point::new(
                    *old_row_range.end(),
                    old_snapshot.line_len(MultiBufferRow(*old_row_range.end())),
                ))
                .unwrap();

            // A read-only editor over a single excerpt of the old buffer.
            let deleted_lines_editor = cx.new_view(|cx| {
                let multi_buffer = cx.new_model(|_| {
                    MultiBuffer::without_headers(language::Capability::ReadOnly)
                });
                multi_buffer.update(cx, |multi_buffer, cx| {
                    multi_buffer.push_excerpts(
                        old_buffer.clone(),
                        Some(ExcerptRange {
                            context: buffer_start..buffer_end,
                            primary: None,
                        }),
                        cx,
                    );
                });

                enum DeletedLines {}
                let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
                editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
                editor.set_show_wrap_guides(false, cx);
                editor.set_show_gutter(false, cx);
                editor.scroll_manager.set_forbid_vertical_scroll(true);
                editor.set_read_only(true);
                editor.set_show_inline_completions(Some(false), cx);
                // Highlight every row of the embedded editor as deleted.
                editor.highlight_rows::<DeletedLines>(
                    Anchor::min()..Anchor::max(),
                    cx.theme().status().deleted_background,
                    false,
                    cx,
                );
                editor
            });

            // The block is as tall as the embedded editor has rows.
            let height =
                deleted_lines_editor.update(cx, |editor, cx| editor.max_point(cx).row().0 + 1);
            new_blocks.push(BlockProperties {
                placement: BlockPlacement::Above(new_row),
                height,
                style: BlockStyle::Flex,
                render: Arc::new(move |cx| {
                    div()
                        .block_mouse_down()
                        .bg(cx.theme().status().deleted_background)
                        .size_full()
                        .h(height as f32 * cx.line_height())
                        .pl(cx.gutter_dimensions.full_width())
                        .child(deleted_lines_editor.clone())
                        .into_any_element()
                }),
                priority: 0,
            });
        }

        // Remember the new block ids so the next update can remove them.
        decorations.removed_line_block_ids = editor
            .insert_blocks(new_blocks, None, cx)
            .into_iter()
            .collect();
    })
}
1283
1284 fn resolve_inline_assist_target(
1285 workspace: &mut Workspace,
1286 cx: &mut WindowContext,
1287 ) -> Option<InlineAssistTarget> {
1288 if let Some(terminal_panel) = workspace.panel::<TerminalPanel>(cx) {
1289 if terminal_panel
1290 .read(cx)
1291 .focus_handle(cx)
1292 .contains_focused(cx)
1293 {
1294 if let Some(terminal_view) = terminal_panel.read(cx).pane().and_then(|pane| {
1295 pane.read(cx)
1296 .active_item()
1297 .and_then(|t| t.downcast::<TerminalView>())
1298 }) {
1299 return Some(InlineAssistTarget::Terminal(terminal_view));
1300 }
1301 }
1302 }
1303
1304 if let Some(workspace_editor) = workspace
1305 .active_item(cx)
1306 .and_then(|item| item.act_as::<Editor>(cx))
1307 {
1308 Some(InlineAssistTarget::Editor(workspace_editor))
1309 } else if let Some(terminal_view) = workspace
1310 .active_item(cx)
1311 .and_then(|item| item.act_as::<TerminalView>(cx))
1312 {
1313 Some(InlineAssistTarget::Terminal(terminal_view))
1314 } else {
1315 None
1316 }
1317 }
1318}
1319
/// Per-editor bookkeeping for the inline assists attached to one editor.
struct EditorInlineAssists {
    /// Ids of the assists registered for this editor.
    assist_ids: Vec<InlineAssistId>,
    /// Optional scroll lock tied to one of the assists.
    scroll_lock: Option<InlineAssistScrollLock>,
    /// Sender half of the watch channel that wakes `_update_highlights`.
    highlight_updates: async_watch::Sender<()>,
    /// Background task that refreshes the editor's highlights on each signal.
    _update_highlights: Task<Result<()>>,
    /// Editor observers and action registrations, held for this struct's lifetime.
    _subscriptions: Vec<gpui::Subscription>,
}
1327
/// Scroll lock state associated with a single assist.
// NOTE(review): the code that reads these fields is outside this section;
// presumably it keeps the assist at a fixed distance from the viewport top — confirm.
struct InlineAssistScrollLock {
    /// The assist the lock refers to.
    assist_id: InlineAssistId,
    /// Distance from the top recorded for the lock.
    distance_from_top: f32,
}
1332
1333impl EditorInlineAssists {
1334 #[allow(clippy::too_many_arguments)]
1335 fn new(editor: &View<Editor>, cx: &mut WindowContext) -> Self {
1336 let (highlight_updates_tx, mut highlight_updates_rx) = async_watch::channel(());
1337 Self {
1338 assist_ids: Vec::new(),
1339 scroll_lock: None,
1340 highlight_updates: highlight_updates_tx,
1341 _update_highlights: cx.spawn(|mut cx| {
1342 let editor = editor.downgrade();
1343 async move {
1344 while let Ok(()) = highlight_updates_rx.changed().await {
1345 let editor = editor.upgrade().context("editor was dropped")?;
1346 cx.update_global(|assistant: &mut InlineAssistant, cx| {
1347 assistant.update_editor_highlights(&editor, cx);
1348 })?;
1349 }
1350 Ok(())
1351 }
1352 }),
1353 _subscriptions: vec![
1354 cx.observe_release(editor, {
1355 let editor = editor.downgrade();
1356 |_, cx| {
1357 InlineAssistant::update_global(cx, |this, cx| {
1358 this.handle_editor_release(editor, cx);
1359 })
1360 }
1361 }),
1362 cx.observe(editor, move |editor, cx| {
1363 InlineAssistant::update_global(cx, |this, cx| {
1364 this.handle_editor_change(editor, cx)
1365 })
1366 }),
1367 cx.subscribe(editor, move |editor, event, cx| {
1368 InlineAssistant::update_global(cx, |this, cx| {
1369 this.handle_editor_event(editor, event, cx)
1370 })
1371 }),
1372 editor.update(cx, |editor, cx| {
1373 let editor_handle = cx.view().downgrade();
1374 editor.register_action(
1375 move |_: &editor::actions::Newline, cx: &mut WindowContext| {
1376 InlineAssistant::update_global(cx, |this, cx| {
1377 if let Some(editor) = editor_handle.upgrade() {
1378 this.handle_editor_newline(editor, cx)
1379 }
1380 })
1381 },
1382 )
1383 }),
1384 editor.update(cx, |editor, cx| {
1385 let editor_handle = cx.view().downgrade();
1386 editor.register_action(
1387 move |_: &editor::actions::Cancel, cx: &mut WindowContext| {
1388 InlineAssistant::update_global(cx, |this, cx| {
1389 if let Some(editor) = editor_handle.upgrade() {
1390 this.handle_editor_cancel(editor, cx)
1391 }
1392 })
1393 },
1394 )
1395 }),
1396 ],
1397 }
1398 }
1399}
1400
1401struct InlineAssistGroup {
1402 assist_ids: Vec<InlineAssistId>,
1403 linked: bool,
1404 active_assist_id: Option<InlineAssistId>,
1405}
1406
1407impl InlineAssistGroup {
1408 fn new() -> Self {
1409 Self {
1410 assist_ids: Vec::new(),
1411 linked: true,
1412 active_assist_id: None,
1413 }
1414 }
1415}
1416
1417fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
1418 let editor = editor.clone();
1419 Arc::new(move |cx: &mut BlockContext| {
1420 *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
1421 editor.clone().into_any_element()
1422 })
1423}
1424
/// Identifier of a single inline assist.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

impl InlineAssistId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistId {
        let current = *self;
        *self = Self(current.0 + 1);
        current
    }
}
1435
/// Identifier of a group of inline assists.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);

impl InlineAssistGroupId {
    /// Returns the current id and advances `self` to the next one.
    fn post_inc(&mut self) -> InlineAssistGroupId {
        let current = *self;
        *self = Self(current.0 + 1);
        current
    }
}
1446
/// Requests emitted by a [`PromptEditor`] in response to button clicks and
/// the confirm/cancel actions.
enum PromptEditorEvent {
    /// Start (or restart) code generation with the current prompt.
    StartRequested,
    /// Interrupt code generation that is currently pending.
    StopRequested,
    /// Accept the finished generation.
    ConfirmRequested,
    /// Cancel the assist.
    CancelRequested,
    /// Emitted when confirm is invoked while generation is still pending.
    DismissRequested,
}
1454
/// The prompt row shown for an inline assist: a text editor for the prompt
/// plus controls that reflect the state of the associated codegen.
struct PromptEditor {
    /// Id of the assist this prompt editor belongs to.
    id: InlineAssistId,
    /// File system handle used when persisting the selected model to settings.
    fs: Arc<dyn Fs>,
    /// The editor the prompt is typed into; replaced by `unlink`.
    editor: View<Editor>,
    /// Set when the prompt is edited; cleared when codegen finishes or errors.
    edited_since_done: bool,
    /// Gutter dimensions shared with the block renderer, used to size the
    /// left-hand column.
    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
    /// Previously used prompts, navigated with move up/down.
    prompt_history: VecDeque<String>,
    /// Index into `prompt_history` currently shown, if browsing history.
    prompt_history_ix: Option<usize>,
    /// The last prompt text typed by the user that is not a history entry;
    /// restored when moving down past the newest history entry.
    pending_prompt: String,
    /// The code generation driven by this prompt.
    codegen: Model<Codegen>,
    /// Observer that calls `handle_codegen_changed` when `codegen` changes.
    _codegen_subscription: Subscription,
    /// Subscriptions to the current `editor`; rebuilt by `subscribe_to_editor`.
    editor_subscriptions: Vec<Subscription>,
    /// Whether the rate-limit notice popover is currently shown.
    show_rate_limit_notice: bool,
}
1469
// Allows `PromptEditor` to emit `PromptEditorEvent`s via `cx.emit`.
impl EventEmitter<PromptEditorEvent> for PromptEditor {}
1471
impl Render for PromptEditor {
    /// Renders the prompt row: a gutter-width column holding the model
    /// selector (plus an error indicator when codegen is in the error state),
    /// the prompt editor, and a trailing row of action buttons whose contents
    /// depend on the codegen status.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let gutter_dimensions = *self.gutter_dimensions.lock();
        let mut buttons = vec![Button::new("add-context", "Add Context")
            .style(ButtonStyle::Filled)
            .icon(IconName::Plus)
            .icon_position(IconPosition::Start)
            .into_any_element()];
        let codegen = self.codegen.read(cx);
        // Cycle controls only appear when there is more than one alternative.
        if codegen.alternative_count(cx) > 1 {
            buttons.push(self.render_cycle_controls(cx));
        }

        let status = codegen.status(cx);
        buttons.extend(match status {
            // Idle: cancel + start.
            CodegenStatus::Idle => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("start", IconName::SparkleAlt)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Transform", &menu::Confirm, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
                        )
                        .into_any_element(),
                ]
            }
            // Pending: cancel + stop (interrupt without discarding changes).
            CodegenStatus::Pending => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    IconButton::new("stop", IconName::Stop)
                        .icon_color(Color::Error)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| {
                            Tooltip::with_meta(
                                "Interrupt Transformation",
                                Some(&menu::Cancel),
                                "Changes won't be discarded",
                                cx,
                            )
                        })
                        .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
                        .into_any_element(),
                ]
            }
            // Error or Done: cancel, then either restart (if the prompt was
            // edited since finishing, or on error) or confirm.
            CodegenStatus::Error(_) | CodegenStatus::Done => {
                vec![
                    IconButton::new("cancel", IconName::Close)
                        .icon_color(Color::Muted)
                        .shape(IconButtonShape::Square)
                        .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
                        .on_click(
                            cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
                        )
                        .into_any_element(),
                    if self.edited_since_done || matches!(status, CodegenStatus::Error(_)) {
                        IconButton::new("restart", IconName::RotateCw)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| {
                                Tooltip::with_meta(
                                    "Restart Transformation",
                                    Some(&menu::Confirm),
                                    "Changes will be discarded",
                                    cx,
                                )
                            })
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::StartRequested);
                            }))
                            .into_any_element()
                    } else {
                        IconButton::new("confirm", IconName::Check)
                            .icon_color(Color::Info)
                            .shape(IconButtonShape::Square)
                            .tooltip(|cx| Tooltip::for_action("Confirm Assist", &menu::Confirm, cx))
                            .on_click(cx.listener(|_, _, cx| {
                                cx.emit(PromptEditorEvent::ConfirmRequested);
                            }))
                            .into_any_element()
                    },
                ]
            }
        });

        h_flex()
            .key_context("PromptEditor")
            .bg(cx.theme().colors().editor_background)
            .block_mouse_down()
            .cursor(CursorStyle::Arrow)
            .border_y_1()
            .border_color(cx.theme().status().info_border)
            .size_full()
            .py(cx.line_height() / 2.5)
            .on_action(cx.listener(Self::confirm))
            .on_action(cx.listener(Self::cancel))
            .on_action(cx.listener(Self::move_up))
            .on_action(cx.listener(Self::move_down))
            .capture_action(cx.listener(Self::cycle_prev))
            .capture_action(cx.listener(Self::cycle_next))
            .child(
                // Left column, sized from the host editor's gutter dimensions.
                h_flex()
                    .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
                    .justify_center()
                    .gap_2()
                    .child(
                        LanguageModelSelector::new(
                            {
                                // Persist the chosen model to the settings file.
                                let fs = self.fs.clone();
                                move |model, cx| {
                                    update_settings_file::<AssistantSettings>(
                                        fs.clone(),
                                        cx,
                                        move |settings, _| settings.set_model(model.clone()),
                                    );
                                }
                            },
                            IconButton::new("context", IconName::SettingsAlt)
                                .shape(IconButtonShape::Square)
                                .icon_size(IconSize::Small)
                                .icon_color(Color::Muted)
                                .tooltip(move |cx| {
                                    Tooltip::with_meta(
                                        format!(
                                            "Using {}",
                                            LanguageModelRegistry::read_global(cx)
                                                .active_model()
                                                .map(|model| model.name().0)
                                                .unwrap_or_else(|| "No model selected".into()),
                                        ),
                                        None,
                                        "Change Model",
                                        cx,
                                    )
                                }),
                        )
                        .info_text(
                            "Inline edits use context\n\
                            from the currently selected\n\
                            assistant panel tab.",
                        ),
                    )
                    .map(|el| {
                        // Only the error state adds an indicator to the column.
                        let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx) else {
                            return el;
                        };

                        let error_message = SharedString::from(error.to_string());
                        if error.error_code() == proto::ErrorCode::RateLimitExceeded
                            && cx.has_flag::<ZedPro>()
                        {
                            // Rate-limit errors get a toggle button that shows
                            // the rate-limit notice popover.
                            el.child(
                                v_flex()
                                    .child(
                                        IconButton::new("rate-limit-error", IconName::XCircle)
                                            .selected(self.show_rate_limit_notice)
                                            .shape(IconButtonShape::Square)
                                            .icon_size(IconSize::Small)
                                            .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                                    )
                                    .children(self.show_rate_limit_notice.then(|| {
                                        deferred(
                                            anchored()
                                                .position_mode(gpui::AnchoredPositionMode::Local)
                                                .position(point(px(0.), px(24.)))
                                                .anchor(gpui::AnchorCorner::TopLeft)
                                                .child(self.render_rate_limit_notice(cx)),
                                        )
                                    })),
                            )
                        } else {
                            // Any other error: an icon with the message as tooltip.
                            el.child(
                                div()
                                    .id("error")
                                    .tooltip(move |cx| Tooltip::text(error_message.clone(), cx))
                                    .child(
                                        Icon::new(IconName::XCircle)
                                            .size(IconSize::Small)
                                            .color(Color::Error),
                                    ),
                            )
                        }
                    }),
            )
            .child(div().flex_1().child(self.render_editor(cx)))
            .child(h_flex().gap_2().pr_6().children(buttons))
    }
}
1675
impl FocusableView for PromptEditor {
    /// Focus is delegated to the inner prompt text editor.
    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
        self.editor.focus_handle(cx)
    }
}
1681
1682impl PromptEditor {
/// Maximum number of lines the auto-height prompt editor may grow to.
const MAX_LINES: u8 = 8;
1684
1685 #[allow(clippy::too_many_arguments)]
1686 fn new(
1687 id: InlineAssistId,
1688 gutter_dimensions: Arc<Mutex<GutterDimensions>>,
1689 prompt_history: VecDeque<String>,
1690 prompt_buffer: Model<MultiBuffer>,
1691 codegen: Model<Codegen>,
1692 fs: Arc<dyn Fs>,
1693 cx: &mut ViewContext<Self>,
1694 ) -> Self {
1695 let prompt_editor = cx.new_view(|cx| {
1696 let mut editor = Editor::new(
1697 EditorMode::AutoHeight {
1698 max_lines: Self::MAX_LINES as usize,
1699 },
1700 prompt_buffer,
1701 None,
1702 false,
1703 cx,
1704 );
1705 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1706 // Since the prompt editors for all inline assistants are linked,
1707 // always show the cursor (even when it isn't focused) because
1708 // typing in one will make what you typed appear in all of them.
1709 editor.set_show_cursor_when_unfocused(true, cx);
1710 editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx)), cx);
1711 editor
1712 });
1713
1714 let mut this = Self {
1715 id,
1716 editor: prompt_editor,
1717 edited_since_done: false,
1718 gutter_dimensions,
1719 prompt_history,
1720 prompt_history_ix: None,
1721 pending_prompt: String::new(),
1722 _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
1723 editor_subscriptions: Vec::new(),
1724 codegen,
1725 fs,
1726 show_rate_limit_notice: false,
1727 };
1728 this.subscribe_to_editor(cx);
1729 this
1730 }
1731
1732 fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
1733 self.editor_subscriptions.clear();
1734 self.editor_subscriptions
1735 .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
1736 }
1737
1738 fn set_show_cursor_when_unfocused(
1739 &mut self,
1740 show_cursor_when_unfocused: bool,
1741 cx: &mut ViewContext<Self>,
1742 ) {
1743 self.editor.update(cx, |editor, cx| {
1744 editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
1745 });
1746 }
1747
1748 fn unlink(&mut self, cx: &mut ViewContext<Self>) {
1749 let prompt = self.prompt(cx);
1750 let focus = self.editor.focus_handle(cx).contains_focused(cx);
1751 self.editor = cx.new_view(|cx| {
1752 let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
1753 editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
1754 editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx)), cx);
1755 editor.set_placeholder_text("Add a prompt…", cx);
1756 editor.set_text(prompt, cx);
1757 if focus {
1758 editor.focus(cx);
1759 }
1760 editor
1761 });
1762 self.subscribe_to_editor(cx);
1763 }
1764
1765 fn placeholder_text(codegen: &Codegen) -> String {
1766 let action = if codegen.is_insertion {
1767 "Generate"
1768 } else {
1769 "Transform"
1770 };
1771
1772 format!("{action}… ↓↑ for history")
1773 }
1774
1775 fn prompt(&self, cx: &AppContext) -> String {
1776 self.editor.read(cx).text(cx)
1777 }
1778
1779 fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
1780 self.show_rate_limit_notice = !self.show_rate_limit_notice;
1781 if self.show_rate_limit_notice {
1782 cx.focus_view(&self.editor);
1783 }
1784 cx.notify();
1785 }
1786
1787 fn handle_prompt_editor_events(
1788 &mut self,
1789 _: View<Editor>,
1790 event: &EditorEvent,
1791 cx: &mut ViewContext<Self>,
1792 ) {
1793 match event {
1794 EditorEvent::Edited { .. } => {
1795 if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
1796 workspace
1797 .update(cx, |workspace, cx| {
1798 let is_via_ssh = workspace
1799 .project()
1800 .update(cx, |project, _| project.is_via_ssh());
1801
1802 workspace
1803 .client()
1804 .telemetry()
1805 .log_edit_event("inline assist", is_via_ssh);
1806 })
1807 .log_err();
1808 }
1809 let prompt = self.editor.read(cx).text(cx);
1810 if self
1811 .prompt_history_ix
1812 .map_or(true, |ix| self.prompt_history[ix] != prompt)
1813 {
1814 self.prompt_history_ix.take();
1815 self.pending_prompt = prompt;
1816 }
1817
1818 self.edited_since_done = true;
1819 cx.notify();
1820 }
1821 EditorEvent::Blurred => {
1822 if self.show_rate_limit_notice {
1823 self.show_rate_limit_notice = false;
1824 cx.notify();
1825 }
1826 }
1827 _ => {}
1828 }
1829 }
1830
1831 fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
1832 match self.codegen.read(cx).status(cx) {
1833 CodegenStatus::Idle => {
1834 self.editor
1835 .update(cx, |editor, _| editor.set_read_only(false));
1836 }
1837 CodegenStatus::Pending => {
1838 self.editor
1839 .update(cx, |editor, _| editor.set_read_only(true));
1840 }
1841 CodegenStatus::Done => {
1842 self.edited_since_done = false;
1843 self.editor
1844 .update(cx, |editor, _| editor.set_read_only(false));
1845 }
1846 CodegenStatus::Error(error) => {
1847 if cx.has_flag::<ZedPro>()
1848 && error.error_code() == proto::ErrorCode::RateLimitExceeded
1849 && !dismissed_rate_limit_notice()
1850 {
1851 self.show_rate_limit_notice = true;
1852 cx.notify();
1853 }
1854
1855 self.edited_since_done = false;
1856 self.editor
1857 .update(cx, |editor, _| editor.set_read_only(false));
1858 }
1859 }
1860 }
1861
1862 fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
1863 match self.codegen.read(cx).status(cx) {
1864 CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
1865 cx.emit(PromptEditorEvent::CancelRequested);
1866 }
1867 CodegenStatus::Pending => {
1868 cx.emit(PromptEditorEvent::StopRequested);
1869 }
1870 }
1871 }
1872
1873 fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
1874 match self.codegen.read(cx).status(cx) {
1875 CodegenStatus::Idle => {
1876 cx.emit(PromptEditorEvent::StartRequested);
1877 }
1878 CodegenStatus::Pending => {
1879 cx.emit(PromptEditorEvent::DismissRequested);
1880 }
1881 CodegenStatus::Done => {
1882 if self.edited_since_done {
1883 cx.emit(PromptEditorEvent::StartRequested);
1884 } else {
1885 cx.emit(PromptEditorEvent::ConfirmRequested);
1886 }
1887 }
1888 CodegenStatus::Error(_) => {
1889 cx.emit(PromptEditorEvent::StartRequested);
1890 }
1891 }
1892 }
1893
1894 fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
1895 if let Some(ix) = self.prompt_history_ix {
1896 if ix > 0 {
1897 self.prompt_history_ix = Some(ix - 1);
1898 let prompt = self.prompt_history[ix - 1].as_str();
1899 self.editor.update(cx, |editor, cx| {
1900 editor.set_text(prompt, cx);
1901 editor.move_to_beginning(&Default::default(), cx);
1902 });
1903 }
1904 } else if !self.prompt_history.is_empty() {
1905 self.prompt_history_ix = Some(self.prompt_history.len() - 1);
1906 let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
1907 self.editor.update(cx, |editor, cx| {
1908 editor.set_text(prompt, cx);
1909 editor.move_to_beginning(&Default::default(), cx);
1910 });
1911 }
1912 }
1913
1914 fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
1915 if let Some(ix) = self.prompt_history_ix {
1916 if ix < self.prompt_history.len() - 1 {
1917 self.prompt_history_ix = Some(ix + 1);
1918 let prompt = self.prompt_history[ix + 1].as_str();
1919 self.editor.update(cx, |editor, cx| {
1920 editor.set_text(prompt, cx);
1921 editor.move_to_end(&Default::default(), cx)
1922 });
1923 } else {
1924 self.prompt_history_ix = None;
1925 let prompt = self.pending_prompt.as_str();
1926 self.editor.update(cx, |editor, cx| {
1927 editor.set_text(prompt, cx);
1928 editor.move_to_end(&Default::default(), cx)
1929 });
1930 }
1931 }
1932 }
1933
1934 fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
1935 self.codegen
1936 .update(cx, |codegen, cx| codegen.cycle_prev(cx));
1937 }
1938
1939 fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
1940 self.codegen
1941 .update(cx, |codegen, cx| codegen.cycle_next(cx));
1942 }
1943
/// Renders the "previous / N/M / next" controls for cycling between codegen
/// alternatives.
///
/// All controls are disabled while codegen is idle. The previous button is
/// also disabled on the first alternative and the next button on the last.
/// When a button is enabled, its tooltip names the model it would switch to.
/// Returns an empty `div` when the registry reports no alternative models.
fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
    let codegen = self.codegen.read(cx);
    let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);

    let model_registry = LanguageModelRegistry::read_global(cx);
    let default_model = model_registry.active_model();
    let alternative_models = model_registry.inline_alternative_models();

    // Index 0 is the active (default) model; indices 1.. map onto the
    // registry's inline alternative models. Unknown indices yield "".
    let get_model_name = |index: usize| -> String {
        let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();

        match index {
            0 => default_model.as_ref().map_or_else(String::new, name),
            index if index <= alternative_models.len() => alternative_models
                .get(index - 1)
                .map_or_else(String::new, name),
            _ => String::new(),
        }
    };

    let total_models = alternative_models.len() + 1;

    if total_models <= 1 {
        return div().into_any_element();
    }

    // Neighbouring indices wrap around; they are only used for the tooltips.
    let current_index = codegen.active_alternative;
    let prev_index = (current_index + total_models - 1) % total_models;
    let next_index = (current_index + 1) % total_models;

    let prev_model_name = get_model_name(prev_index);
    let next_model_name = get_model_name(next_index);

    h_flex()
        .child(
            IconButton::new("previous", IconName::ChevronLeft)
                .icon_color(Color::Muted)
                .disabled(disabled || current_index == 0)
                .shape(IconButtonShape::Square)
                .tooltip({
                    let focus_handle = self.editor.focus_handle(cx);
                    move |cx| {
                        cx.new_view(|cx| {
                            let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
                                KeyBinding::for_action_in(
                                    &CyclePreviousInlineAssist,
                                    &focus_handle,
                                    cx,
                                ),
                            );
                            // Only name the target model when the button is usable.
                            if !disabled && current_index != 0 {
                                tooltip = tooltip.meta(prev_model_name.clone());
                            }
                            tooltip
                        })
                        .into()
                    }
                })
                .on_click(cx.listener(|this, _, cx| {
                    this.codegen
                        .update(cx, |codegen, cx| codegen.cycle_prev(cx))
                })),
        )
        .child(
            // NOTE(review): the label's total comes from
            // `codegen.alternative_count`, while `total_models` above is
            // derived from the registry — confirm these always agree.
            Label::new(format!(
                "{}/{}",
                codegen.active_alternative + 1,
                codegen.alternative_count(cx)
            ))
            .size(LabelSize::Small)
            .color(if disabled {
                Color::Disabled
            } else {
                Color::Muted
            }),
        )
        .child(
            IconButton::new("next", IconName::ChevronRight)
                .icon_color(Color::Muted)
                .disabled(disabled || current_index == total_models - 1)
                .shape(IconButtonShape::Square)
                .tooltip({
                    let focus_handle = self.editor.focus_handle(cx);
                    move |cx| {
                        cx.new_view(|cx| {
                            let mut tooltip = Tooltip::new("Next Alternative").key_binding(
                                KeyBinding::for_action_in(
                                    &CycleNextInlineAssist,
                                    &focus_handle,
                                    cx,
                                ),
                            );
                            if !disabled && current_index != total_models - 1 {
                                tooltip = tooltip.meta(next_model_name.clone());
                            }
                            tooltip
                        })
                        .into()
                    }
                })
                .on_click(cx.listener(|this, _, cx| {
                    this.codegen
                        .update(cx, |codegen, cx| codegen.cycle_next(cx))
                })),
        )
        .into_any_element()
}
2051
/// Renders the "Out of Tokens" popover shown for rate-limit errors.
///
/// It contains a "Don't show again" checkbox backed by the persisted
/// dismissal flag, a "Dismiss" button that toggles the notice, and a
/// "More Info" button that dispatches `OpenAccountSettings`.
fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
    Popover::new().child(
        v_flex()
            .occlude()
            .p_2()
            .child(
                Label::new("Out of Tokens")
                    .size(LabelSize::Small)
                    .weight(FontWeight::BOLD),
            )
            .child(Label::new(
                "Try Zed Pro for higher limits, a wider range of models, and more.",
            ))
            .child(
                h_flex()
                    .justify_between()
                    .child(CheckboxWithLabel::new(
                        "dont-show-again",
                        Label::new("Don't show again"),
                        // Initial state mirrors the persisted dismissal flag.
                        if dismissed_rate_limit_notice() {
                            ui::Selection::Selected
                        } else {
                            ui::Selection::Unselected
                        },
                        |selection, cx| {
                            // An indeterminate selection leaves the flag untouched.
                            let is_dismissed = match selection {
                                ui::Selection::Unselected => false,
                                ui::Selection::Indeterminate => return,
                                ui::Selection::Selected => true,
                            };

                            set_rate_limit_notice_dismissed(is_dismissed, cx)
                        },
                    ))
                    .child(
                        h_flex()
                            .gap_2()
                            .child(
                                Button::new("dismiss", "Dismiss")
                                    .style(ButtonStyle::Transparent)
                                    .on_click(cx.listener(Self::toggle_rate_limit_notice)),
                            )
                            .child(Button::new("more-info", "More Info").on_click(
                                |_event, cx| {
                                    cx.dispatch_action(Box::new(
                                        zed_actions::OpenAccountSettings,
                                    ))
                                },
                            )),
                    ),
            ),
    )
}
2105
2106 fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
2107 let font_size = TextSize::Default.rems(cx);
2108 let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
2109
2110 v_flex()
2111 .key_context("MessageEditor")
2112 .size_full()
2113 .gap_2()
2114 .p_2()
2115 .bg(cx.theme().colors().editor_background)
2116 .child({
2117 let settings = ThemeSettings::get_global(cx);
2118 let text_style = TextStyle {
2119 color: cx.theme().colors().editor_foreground,
2120 font_family: settings.ui_font.family.clone(),
2121 font_features: settings.ui_font.features.clone(),
2122 font_size: font_size.into(),
2123 font_weight: settings.ui_font.weight,
2124 line_height: line_height.into(),
2125 ..Default::default()
2126 };
2127
2128 EditorElement::new(
2129 &self.editor,
2130 EditorStyle {
2131 background: cx.theme().colors().editor_background,
2132 local_player: cx.theme().players().local(),
2133 text: text_style,
2134 ..Default::default()
2135 },
2136 )
2137 })
2138 .into_any_element()
2139 }
2140}
2141
/// Key-value store key whose presence marks the rate-limit notice as dismissed.
const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
2143
2144fn dismissed_rate_limit_notice() -> bool {
2145 db::kvp::KEY_VALUE_STORE
2146 .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
2147 .log_err()
2148 .map_or(false, |s| s.is_some())
2149}
2150
2151fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
2152 db::write_and_log(cx, move || async move {
2153 if is_dismissed {
2154 db::kvp::KEY_VALUE_STORE
2155 .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
2156 .await
2157 } else {
2158 db::kvp::KEY_VALUE_STORE
2159 .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
2160 .await
2161 }
2162 })
2163}
2164
/// State of a single inline assist attached to an editor.
pub struct InlineAssist {
    /// The group this assist belongs to.
    group_id: InlineAssistGroupId,
    /// The range of the editor's buffer the assist operates on.
    range: Range<Anchor>,
    /// The editor hosting the assist, held weakly.
    editor: WeakView<Editor>,
    /// Prompt editor and block ids shown in the editor; `None` when the
    /// assist has no decorations.
    decorations: Option<InlineAssistDecorations>,
    /// The code generation backing this assist.
    codegen: Model<Codegen>,
    /// Focus, prompt-editor and codegen subscriptions kept alive with the assist.
    _subscriptions: Vec<Subscription>,
    /// The workspace the assist was created in, if any, held weakly.
    workspace: Option<WeakView<Workspace>>,
}
2174
2175impl InlineAssist {
2176 #[allow(clippy::too_many_arguments)]
2177 fn new(
2178 assist_id: InlineAssistId,
2179 group_id: InlineAssistGroupId,
2180 editor: &View<Editor>,
2181 prompt_editor: &View<PromptEditor>,
2182 prompt_block_id: CustomBlockId,
2183 end_block_id: CustomBlockId,
2184 range: Range<Anchor>,
2185 codegen: Model<Codegen>,
2186 workspace: Option<WeakView<Workspace>>,
2187 cx: &mut WindowContext,
2188 ) -> Self {
2189 let prompt_editor_focus_handle = prompt_editor.focus_handle(cx);
2190 InlineAssist {
2191 group_id,
2192 editor: editor.downgrade(),
2193 decorations: Some(InlineAssistDecorations {
2194 prompt_block_id,
2195 prompt_editor: prompt_editor.clone(),
2196 removed_line_block_ids: HashSet::default(),
2197 end_block_id,
2198 }),
2199 range,
2200 codegen: codegen.clone(),
2201 workspace: workspace.clone(),
2202 _subscriptions: vec![
2203 cx.on_focus_in(&prompt_editor_focus_handle, move |cx| {
2204 InlineAssistant::update_global(cx, |this, cx| {
2205 this.handle_prompt_editor_focus_in(assist_id, cx)
2206 })
2207 }),
2208 cx.on_focus_out(&prompt_editor_focus_handle, move |_, cx| {
2209 InlineAssistant::update_global(cx, |this, cx| {
2210 this.handle_prompt_editor_focus_out(assist_id, cx)
2211 })
2212 }),
2213 cx.subscribe(prompt_editor, |prompt_editor, event, cx| {
2214 InlineAssistant::update_global(cx, |this, cx| {
2215 this.handle_prompt_editor_event(prompt_editor, event, cx)
2216 })
2217 }),
2218 cx.observe(&codegen, {
2219 let editor = editor.downgrade();
2220 move |_, cx| {
2221 if let Some(editor) = editor.upgrade() {
2222 InlineAssistant::update_global(cx, |this, cx| {
2223 if let Some(editor_assists) =
2224 this.assists_by_editor.get(&editor.downgrade())
2225 {
2226 editor_assists.highlight_updates.send(()).ok();
2227 }
2228
2229 this.update_editor_blocks(&editor, assist_id, cx);
2230 })
2231 }
2232 }
2233 }),
2234 cx.subscribe(&codegen, move |codegen, event, cx| {
2235 InlineAssistant::update_global(cx, |this, cx| match event {
2236 CodegenEvent::Undone => this.finish_assist(assist_id, false, cx),
2237 CodegenEvent::Finished => {
2238 let assist = if let Some(assist) = this.assists.get(&assist_id) {
2239 assist
2240 } else {
2241 return;
2242 };
2243
2244 if let CodegenStatus::Error(error) = codegen.read(cx).status(cx) {
2245 if assist.decorations.is_none() {
2246 if let Some(workspace) = assist
2247 .workspace
2248 .as_ref()
2249 .and_then(|workspace| workspace.upgrade())
2250 {
2251 let error = format!("Inline assistant error: {}", error);
2252 workspace.update(cx, |workspace, cx| {
2253 struct InlineAssistantError;
2254
2255 let id =
2256 NotificationId::composite::<InlineAssistantError>(
2257 assist_id.0,
2258 );
2259
2260 workspace.show_toast(Toast::new(id, error), cx);
2261 })
2262 }
2263 }
2264 }
2265
2266 if assist.decorations.is_none() {
2267 this.finish_assist(assist_id, false, cx);
2268 }
2269 }
2270 })
2271 }),
2272 ],
2273 }
2274 }
2275
2276 fn user_prompt(&self, cx: &AppContext) -> Option<String> {
2277 let decorations = self.decorations.as_ref()?;
2278 Some(decorations.prompt_editor.read(cx).prompt(cx))
2279 }
2280}
2281
/// Editor decorations owned by an `InlineAssist`: the prompt editor view and
/// the ids of the custom blocks associated with the assist.
struct InlineAssistDecorations {
    /// Id of the custom block associated with the prompt editor.
    prompt_block_id: CustomBlockId,
    /// The prompt editor view the user types the instruction into.
    prompt_editor: View<PromptEditor>,
    /// Ids of the custom blocks for removed lines; empty at construction.
    removed_line_block_ids: HashSet<CustomBlockId>,
    /// Id of the custom block marking the end of the assist.
    end_block_id: CustomBlockId,
}
2288
/// Events emitted by `Codegen` and `CodegenAlternative`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation ended: it completed, failed, or was stopped.
    Finished,
    /// The transaction holding the generated edits was undone in the buffer.
    Undone,
}
2294
/// Coordinates one or more `CodegenAlternative`s for the same buffer range,
/// exactly one of which is active at a time.
pub struct Codegen {
    /// All alternatives; index 0 is created in `new` and kept across `start` calls.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the currently active alternative.
    active_alternative: usize,
    /// Indices that have been passed to `activate` so far.
    seen_alternatives: HashSet<usize>,
    /// Observation and event subscriptions to the active alternative.
    subscriptions: Vec<Subscription>,
    /// Multi-buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Range of `buffer` the generation applies to.
    range: Range<Anchor>,
    /// Transaction undone (once) by `undo`, in addition to the active
    /// alternative's own edits.
    initial_transaction_id: Option<TransactionId>,
    /// Telemetry handle passed on to each alternative.
    telemetry: Arc<Telemetry>,
    /// Prompt builder passed on to each alternative.
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty when this `Codegen` was created.
    is_insertion: bool,
}
2307
2308impl Codegen {
2309 pub fn new(
2310 buffer: Model<MultiBuffer>,
2311 range: Range<Anchor>,
2312 initial_transaction_id: Option<TransactionId>,
2313 telemetry: Arc<Telemetry>,
2314 builder: Arc<PromptBuilder>,
2315 cx: &mut ModelContext<Self>,
2316 ) -> Self {
2317 let codegen = cx.new_model(|cx| {
2318 CodegenAlternative::new(
2319 buffer.clone(),
2320 range.clone(),
2321 false,
2322 Some(telemetry.clone()),
2323 builder.clone(),
2324 cx,
2325 )
2326 });
2327 let mut this = Self {
2328 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
2329 alternatives: vec![codegen],
2330 active_alternative: 0,
2331 seen_alternatives: HashSet::default(),
2332 subscriptions: Vec::new(),
2333 buffer,
2334 range,
2335 initial_transaction_id,
2336 telemetry,
2337 builder,
2338 };
2339 this.activate(0, cx);
2340 this
2341 }
2342
2343 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
2344 let codegen = self.active_alternative().clone();
2345 self.subscriptions.clear();
2346 self.subscriptions
2347 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
2348 self.subscriptions
2349 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
2350 }
2351
2352 fn active_alternative(&self) -> &Model<CodegenAlternative> {
2353 &self.alternatives[self.active_alternative]
2354 }
2355
2356 fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
2357 &self.active_alternative().read(cx).status
2358 }
2359
2360 fn alternative_count(&self, cx: &AppContext) -> usize {
2361 LanguageModelRegistry::read_global(cx)
2362 .inline_alternative_models()
2363 .len()
2364 + 1
2365 }
2366
2367 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
2368 let next_active_ix = if self.active_alternative == 0 {
2369 self.alternatives.len() - 1
2370 } else {
2371 self.active_alternative - 1
2372 };
2373 self.activate(next_active_ix, cx);
2374 }
2375
2376 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
2377 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
2378 self.activate(next_active_ix, cx);
2379 }
2380
2381 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
2382 self.active_alternative()
2383 .update(cx, |codegen, cx| codegen.set_active(false, cx));
2384 self.seen_alternatives.insert(index);
2385 self.active_alternative = index;
2386 self.active_alternative()
2387 .update(cx, |codegen, cx| codegen.set_active(true, cx));
2388 self.subscribe_to_alternative(cx);
2389 cx.notify();
2390 }
2391
2392 pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
2393 let alternative_models = LanguageModelRegistry::read_global(cx)
2394 .inline_alternative_models()
2395 .to_vec();
2396
2397 self.active_alternative()
2398 .update(cx, |alternative, cx| alternative.undo(cx));
2399 self.activate(0, cx);
2400 self.alternatives.truncate(1);
2401
2402 for _ in 0..alternative_models.len() {
2403 self.alternatives.push(cx.new_model(|cx| {
2404 CodegenAlternative::new(
2405 self.buffer.clone(),
2406 self.range.clone(),
2407 false,
2408 Some(self.telemetry.clone()),
2409 self.builder.clone(),
2410 cx,
2411 )
2412 }));
2413 }
2414
2415 let primary_model = LanguageModelRegistry::read_global(cx)
2416 .active_model()
2417 .context("no active model")?;
2418
2419 for (model, alternative) in iter::once(primary_model)
2420 .chain(alternative_models)
2421 .zip(&self.alternatives)
2422 {
2423 alternative.update(cx, |alternative, cx| {
2424 alternative.start(user_prompt.clone(), model.clone(), cx)
2425 })?;
2426 }
2427
2428 Ok(())
2429 }
2430
2431 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2432 for codegen in &self.alternatives {
2433 codegen.update(cx, |codegen, cx| codegen.stop(cx));
2434 }
2435 }
2436
2437 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2438 self.active_alternative()
2439 .update(cx, |codegen, cx| codegen.undo(cx));
2440
2441 self.buffer.update(cx, |buffer, cx| {
2442 if let Some(transaction_id) = self.initial_transaction_id.take() {
2443 buffer.undo_transaction(transaction_id, cx);
2444 buffer.refresh_preview(cx);
2445 }
2446 });
2447 }
2448
2449 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
2450 self.active_alternative().read(cx).buffer.clone()
2451 }
2452
2453 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
2454 self.active_alternative().read(cx).old_buffer.clone()
2455 }
2456
2457 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
2458 self.active_alternative().read(cx).snapshot.clone()
2459 }
2460
2461 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
2462 self.active_alternative().read(cx).edit_position
2463 }
2464
2465 fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
2466 &self.active_alternative().read(cx).diff
2467 }
2468
2469 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
2470 self.active_alternative().read(cx).last_equal_ranges()
2471 }
2472}
2473
// `Codegen` re-emits the `CodegenEvent`s of its active alternative.
impl EventEmitter<CodegenEvent> for Codegen {}
2475
/// One candidate generation for a buffer range, produced by a single model.
pub struct CodegenAlternative {
    /// Multi-buffer the generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Local copy of the underlying buffer's text, taken at construction.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction; edit anchors are created
    /// against it.
    snapshot: MultiBufferSnapshot,
    /// Position just after the most recently processed diff operation.
    edit_position: Option<Anchor>,
    /// Range being transformed.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the most recent batch of character operations.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction into which all of this alternative's edits are merged.
    transformation_transaction_id: Option<TransactionId>,
    /// Current lifecycle state of the generation.
    status: CodegenStatus,
    /// Task driving the in-progress generation.
    generation: Task<()>,
    /// Row-level diff between the original and the generated text.
    diff: Diff,
    /// Telemetry handle used when reporting the assistant event.
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to events from `buffer`.
    _subscription: gpui::Subscription,
    /// Builds the prompt sent to the model.
    builder: Arc<PromptBuilder>,
    /// Whether generated edits are applied to `buffer` as they arrive.
    active: bool,
    /// Edits produced so far; re-applied when the alternative becomes active.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations from the most recent diff batch.
    line_operations: Vec<LineOperation>,
    /// Last request built by `start`, if any.
    request: Option<LanguageModelRequest>,
    /// Seconds the last generation took, set when it ends.
    elapsed_time: Option<f64>,
    /// Text streamed by the model during the last generation, set when it ends.
    completion: Option<String>,
    /// Message id reported by the stream, if any.
    message_id: Option<String>,
}
2498
/// Lifecycle state of a `CodegenAlternative`.
enum CodegenStatus {
    /// No generation has run, or one was stopped while the diff was empty.
    Idle,
    /// A generation is in progress.
    Pending,
    /// The generation finished successfully, or was stopped with a non-empty diff.
    Done,
    /// The generation failed with the contained error.
    Error(anyhow::Error),
}
2505
/// Row-level summary of how generated text differs from the original.
#[derive(Default)]
struct Diff {
    /// For each run of deleted rows: an anchor in the new text at the row where
    /// the deletion occurred, and the inclusive range of rows in the old text.
    deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Ranges in the new text covering inserted rows, from the start of the
    /// first inserted row to the end of the last.
    inserted_row_ranges: Vec<Range<Anchor>>,
}
2511
2512impl Diff {
2513 fn is_empty(&self) -> bool {
2514 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
2515 }
2516}
2517
// `CodegenAlternative` emits `Finished` when a generation ends or is stopped,
// and `Undone` when its transformation transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
2519
2520impl CodegenAlternative {
2521 pub fn new(
2522 buffer: Model<MultiBuffer>,
2523 range: Range<Anchor>,
2524 active: bool,
2525 telemetry: Option<Arc<Telemetry>>,
2526 builder: Arc<PromptBuilder>,
2527 cx: &mut ModelContext<Self>,
2528 ) -> Self {
2529 let snapshot = buffer.read(cx).snapshot(cx);
2530
2531 let (old_buffer, _, _) = buffer
2532 .read(cx)
2533 .range_to_buffer_ranges(range.clone(), cx)
2534 .pop()
2535 .unwrap();
2536 let old_buffer = cx.new_model(|cx| {
2537 let old_buffer = old_buffer.read(cx);
2538 let text = old_buffer.as_rope().clone();
2539 let line_ending = old_buffer.line_ending();
2540 let language = old_buffer.language().cloned();
2541 let language_registry = old_buffer.language_registry();
2542
2543 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
2544 buffer.set_language(language, cx);
2545 if let Some(language_registry) = language_registry {
2546 buffer.set_language_registry(language_registry)
2547 }
2548 buffer
2549 });
2550
2551 Self {
2552 buffer: buffer.clone(),
2553 old_buffer,
2554 edit_position: None,
2555 message_id: None,
2556 snapshot,
2557 last_equal_ranges: Default::default(),
2558 transformation_transaction_id: None,
2559 status: CodegenStatus::Idle,
2560 generation: Task::ready(()),
2561 diff: Diff::default(),
2562 telemetry,
2563 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
2564 builder,
2565 active,
2566 edits: Vec::new(),
2567 line_operations: Vec::new(),
2568 range,
2569 request: None,
2570 elapsed_time: None,
2571 completion: None,
2572 }
2573 }
2574
2575 fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
2576 if active != self.active {
2577 self.active = active;
2578
2579 if self.active {
2580 let edits = self.edits.clone();
2581 self.apply_edits(edits, cx);
2582 if matches!(self.status, CodegenStatus::Pending) {
2583 let line_operations = self.line_operations.clone();
2584 self.reapply_line_based_diff(line_operations, cx);
2585 } else {
2586 self.reapply_batch_diff(cx).detach();
2587 }
2588 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
2589 self.buffer.update(cx, |buffer, cx| {
2590 buffer.undo_transaction(transaction_id, cx);
2591 buffer.forget_transaction(transaction_id, cx);
2592 });
2593 }
2594 }
2595 }
2596
2597 fn handle_buffer_event(
2598 &mut self,
2599 _buffer: Model<MultiBuffer>,
2600 event: &multi_buffer::Event,
2601 cx: &mut ModelContext<Self>,
2602 ) {
2603 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
2604 if self.transformation_transaction_id == Some(*transaction_id) {
2605 self.transformation_transaction_id = None;
2606 self.generation = Task::ready(());
2607 cx.emit(CodegenEvent::Undone);
2608 }
2609 }
2610 }
2611
2612 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
2613 &self.last_equal_ranges
2614 }
2615
2616 pub fn start(
2617 &mut self,
2618 user_prompt: String,
2619 model: Arc<dyn LanguageModel>,
2620 cx: &mut ModelContext<Self>,
2621 ) -> Result<()> {
2622 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
2623 self.buffer.update(cx, |buffer, cx| {
2624 buffer.undo_transaction(transformation_transaction_id, cx);
2625 });
2626 }
2627
2628 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
2629
2630 let api_key = model.api_key(cx);
2631 let telemetry_id = model.telemetry_id();
2632 let provider_id = model.provider_id();
2633 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
2634 if user_prompt.trim().to_lowercase() == "delete" {
2635 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
2636 } else {
2637 let request = self.build_request(user_prompt, cx)?;
2638 self.request = Some(request.clone());
2639
2640 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
2641 .boxed_local()
2642 };
2643 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
2644 Ok(())
2645 }
2646
2647 fn build_request(&self, user_prompt: String, cx: &AppContext) -> Result<LanguageModelRequest> {
2648 let buffer = self.buffer.read(cx).snapshot(cx);
2649 let language = buffer.language_at(self.range.start);
2650 let language_name = if let Some(language) = language.as_ref() {
2651 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
2652 None
2653 } else {
2654 Some(language.name())
2655 }
2656 } else {
2657 None
2658 };
2659
2660 let language_name = language_name.as_ref();
2661 let start = buffer.point_to_buffer_offset(self.range.start);
2662 let end = buffer.point_to_buffer_offset(self.range.end);
2663 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
2664 let (start_buffer, start_buffer_offset) = start;
2665 let (end_buffer, end_buffer_offset) = end;
2666 if start_buffer.remote_id() == end_buffer.remote_id() {
2667 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
2668 } else {
2669 return Err(anyhow::anyhow!("invalid transformation range"));
2670 }
2671 } else {
2672 return Err(anyhow::anyhow!("invalid transformation range"));
2673 };
2674
2675 let prompt = self
2676 .builder
2677 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
2678 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
2679
2680 Ok(LanguageModelRequest {
2681 tools: Vec::new(),
2682 stop: Vec::new(),
2683 temperature: None,
2684 messages: vec![LanguageModelRequestMessage {
2685 role: Role::User,
2686 content: vec![prompt.into()],
2687 cache: false,
2688 }],
2689 })
2690 }
2691
    /// Consumes a model text stream and turns it into buffer edits.
    ///
    /// The diff and status are reset (`Pending`), then a task is spawned and
    /// stored in `self.generation`. That task:
    /// 1. awaits `stream` and remembers its message id;
    /// 2. on the background executor, strips invalid spans from the chunks,
    ///    re-indents each line relative to the suggested indentation of the
    ///    selection, and feeds the text through a streaming character diff and
    ///    a line diff against the originally selected text;
    /// 3. on the foreground, converts each batch of character operations into
    ///    anchored edits, applies them when this alternative is active, and
    ///    records them in `self.edits` / `self.line_operations`;
    /// 4. once streaming ends, recomputes the diff in batch, stores the final
    ///    status, elapsed time and completion text, and emits
    ///    `CodegenEvent::Finished`.
    ///
    /// An assistant telemetry event is reported from the background task
    /// whether the diffing succeeded or failed.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut ModelContext<Self>,
    ) {
        let start_time = Instant::now();
        let snapshot = self.snapshot.clone();
        // Original text of the range; the streamed text is diffed against it.
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client().clone();
        let telemetry = self.telemetry.clone();
        // Language of the first underlying buffer in the range, for telemetry.
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.read(cx).language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset in `snapshot` up to which diff operations have been consumed.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        // Text streamed so far, shared between the background task (writer)
        // and the final foreground update (reader).
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                let stream = stream.await;
                let message_id = stream
                    .as_ref()
                    .ok()
                    .and_then(|stream| stream.message_id.clone());
                let generate = async {
                    // Capacity-1 channel carrying (char ops, line ops) batches
                    // from the background diffing task to the loop below.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let executor = cx.background_executor().clone();
                    let message_id = message_id.clone();
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(stream?.stream);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the current line not yet pushed into the diff.
                                let mut new_text = String::new();
                                // Indentation of the first non-blank generated line.
                                let mut base_indent = None;
                                // Indentation of the current line, once its first
                                // non-whitespace character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first chunk.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;
                                    completion_clone.lock().push_str(&chunk);

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Re-base the line's indentation: keep its offset
                                                // relative to the first generated line, applied on
                                                // top of the suggested indentation (never negative).
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line starts at the selection's column,
                                                // so that much indentation is already present.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only flush once the line's indentation is known.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so this line ended with '\n'.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush whatever remains and close both diffs.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the event before propagating any error.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            report_assistant_event(
                                AssistantEvent {
                                    conversation_id: None,
                                    message_id,
                                    kind: AssistantKind::Inline,
                                    phase: AssistantPhase::Response,
                                    model: model_telemetry_id,
                                    model_provider: model_provider_id.to_string(),
                                    response_latency,
                                    error_message,
                                    language_name: language_name.map(|name| name.to_proto()),
                                },
                                telemetry,
                                http_client,
                                model_api_key,
                                &executor,
                            );

                            result?;
                            Ok(())
                        });

                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            // Translate character operations into anchored edits,
                            // advancing `edit_start` over deleted and kept bytes.
                            let edits = char_ops
                                .into_iter()
                                .filter_map(|operation| match operation {
                                    CharOperation::Insert { text } => {
                                        let edit_start = snapshot.anchor_after(edit_start);
                                        Some((edit_start..edit_start, text))
                                    }
                                    CharOperation::Delete { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        Some((edit_range, String::new()))
                                    }
                                    CharOperation::Keep { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        codegen.last_equal_ranges.push(edit_range);
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            // Inactive alternatives only record the edits; they are
                            // applied later by `set_active`.
                            if codegen.active {
                                codegen.apply_edits(edits.iter().cloned(), cx);
                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                            }
                            codegen.edits.extend(edits);
                            codegen.line_operations = line_ops;
                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task =
                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                let elapsed_time = start_time.elapsed().as_secs_f64();

                // Record the outcome; the update result is ignored via `.ok()`.
                codegen
                    .update(&mut cx, |this, cx| {
                        this.message_id = message_id;
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        this.elapsed_time = Some(elapsed_time);
                        this.completion = Some(completion.lock().clone());
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
2954
2955 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
2956 self.last_equal_ranges.clear();
2957 if self.diff.is_empty() {
2958 self.status = CodegenStatus::Idle;
2959 } else {
2960 self.status = CodegenStatus::Done;
2961 }
2962 self.generation = Task::ready(());
2963 cx.emit(CodegenEvent::Finished);
2964 cx.notify();
2965 }
2966
2967 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
2968 self.buffer.update(cx, |buffer, cx| {
2969 if let Some(transaction_id) = self.transformation_transaction_id.take() {
2970 buffer.undo_transaction(transaction_id, cx);
2971 buffer.refresh_preview(cx);
2972 }
2973 });
2974 }
2975
2976 fn apply_edits(
2977 &mut self,
2978 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
2979 cx: &mut ModelContext<CodegenAlternative>,
2980 ) {
2981 let transaction = self.buffer.update(cx, |buffer, cx| {
2982 // Avoid grouping assistant edits with user edits.
2983 buffer.finalize_last_transaction(cx);
2984 buffer.start_transaction(cx);
2985 buffer.edit(edits, None, cx);
2986 buffer.end_transaction(cx)
2987 });
2988
2989 if let Some(transaction) = transaction {
2990 if let Some(first_transaction) = self.transformation_transaction_id {
2991 // Group all assistant edits into the first transaction.
2992 self.buffer.update(cx, |buffer, cx| {
2993 buffer.merge_transactions(transaction, first_transaction, cx)
2994 });
2995 } else {
2996 self.transformation_transaction_id = Some(transaction);
2997 self.buffer
2998 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
2999 }
3000 }
3001 }
3002
3003 fn reapply_line_based_diff(
3004 &mut self,
3005 line_operations: impl IntoIterator<Item = LineOperation>,
3006 cx: &mut ModelContext<Self>,
3007 ) {
3008 let old_snapshot = self.snapshot.clone();
3009 let old_range = self.range.to_point(&old_snapshot);
3010 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3011 let new_range = self.range.to_point(&new_snapshot);
3012
3013 let mut old_row = old_range.start.row;
3014 let mut new_row = new_range.start.row;
3015
3016 self.diff.deleted_row_ranges.clear();
3017 self.diff.inserted_row_ranges.clear();
3018 for operation in line_operations {
3019 match operation {
3020 LineOperation::Keep { lines } => {
3021 old_row += lines;
3022 new_row += lines;
3023 }
3024 LineOperation::Delete { lines } => {
3025 let old_end_row = old_row + lines - 1;
3026 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3027
3028 if let Some((_, last_deleted_row_range)) =
3029 self.diff.deleted_row_ranges.last_mut()
3030 {
3031 if *last_deleted_row_range.end() + 1 == old_row {
3032 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
3033 } else {
3034 self.diff
3035 .deleted_row_ranges
3036 .push((new_row, old_row..=old_end_row));
3037 }
3038 } else {
3039 self.diff
3040 .deleted_row_ranges
3041 .push((new_row, old_row..=old_end_row));
3042 }
3043
3044 old_row += lines;
3045 }
3046 LineOperation::Insert { lines } => {
3047 let new_end_row = new_row + lines - 1;
3048 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3049 let end = new_snapshot.anchor_before(Point::new(
3050 new_end_row,
3051 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3052 ));
3053 self.diff.inserted_row_ranges.push(start..end);
3054 new_row += lines;
3055 }
3056 }
3057
3058 cx.notify();
3059 }
3060 }
3061
3062 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
3063 let old_snapshot = self.snapshot.clone();
3064 let old_range = self.range.to_point(&old_snapshot);
3065 let new_snapshot = self.buffer.read(cx).snapshot(cx);
3066 let new_range = self.range.to_point(&new_snapshot);
3067
3068 cx.spawn(|codegen, mut cx| async move {
3069 let (deleted_row_ranges, inserted_row_ranges) = cx
3070 .background_executor()
3071 .spawn(async move {
3072 let old_text = old_snapshot
3073 .text_for_range(
3074 Point::new(old_range.start.row, 0)
3075 ..Point::new(
3076 old_range.end.row,
3077 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
3078 ),
3079 )
3080 .collect::<String>();
3081 let new_text = new_snapshot
3082 .text_for_range(
3083 Point::new(new_range.start.row, 0)
3084 ..Point::new(
3085 new_range.end.row,
3086 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
3087 ),
3088 )
3089 .collect::<String>();
3090
3091 let mut old_row = old_range.start.row;
3092 let mut new_row = new_range.start.row;
3093 let batch_diff =
3094 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
3095
3096 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
3097 let mut inserted_row_ranges = Vec::new();
3098 for change in batch_diff.iter_all_changes() {
3099 let line_count = change.value().lines().count() as u32;
3100 match change.tag() {
3101 similar::ChangeTag::Equal => {
3102 old_row += line_count;
3103 new_row += line_count;
3104 }
3105 similar::ChangeTag::Delete => {
3106 let old_end_row = old_row + line_count - 1;
3107 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
3108
3109 if let Some((_, last_deleted_row_range)) =
3110 deleted_row_ranges.last_mut()
3111 {
3112 if *last_deleted_row_range.end() + 1 == old_row {
3113 *last_deleted_row_range =
3114 *last_deleted_row_range.start()..=old_end_row;
3115 } else {
3116 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3117 }
3118 } else {
3119 deleted_row_ranges.push((new_row, old_row..=old_end_row));
3120 }
3121
3122 old_row += line_count;
3123 }
3124 similar::ChangeTag::Insert => {
3125 let new_end_row = new_row + line_count - 1;
3126 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
3127 let end = new_snapshot.anchor_before(Point::new(
3128 new_end_row,
3129 new_snapshot.line_len(MultiBufferRow(new_end_row)),
3130 ));
3131 inserted_row_ranges.push(start..end);
3132 new_row += line_count;
3133 }
3134 }
3135 }
3136
3137 (deleted_row_ranges, inserted_row_ranges)
3138 })
3139 .await;
3140
3141 codegen
3142 .update(&mut cx, |codegen, cx| {
3143 codegen.diff.deleted_row_ranges = deleted_row_ranges;
3144 codegen.diff.inserted_row_ranges = inserted_row_ranges;
3145 cx.notify();
3146 })
3147 .ok();
3148 })
3149 }
3150}
3151
/// Stream adapter that removes a wrapping Markdown code fence and
/// `<|CURSOR|>` markers from the text chunks of the inner stream.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted or discarded yet.
    buffer: String,
    /// True until a second line has been seen in the buffered text.
    first_line: bool,
    /// True when a complete line was emitted and the newline that ended it
    /// still has to be written before the next emitted text.
    line_end: bool,
    /// True when the first line was a code fence and was dropped.
    starts_with_code_block: bool,
}
3160
3161impl<T> StripInvalidSpans<T>
3162where
3163 T: Stream<Item = Result<String>>,
3164{
3165 fn new(stream: T) -> Self {
3166 Self {
3167 stream,
3168 stream_done: false,
3169 buffer: String::new(),
3170 first_line: true,
3171 line_end: false,
3172 starts_with_code_block: false,
3173 }
3174 }
3175}
3176
3177impl<T> Stream for StripInvalidSpans<T>
3178where
3179 T: Stream<Item = Result<String>>,
3180{
3181 type Item = Result<String>;
3182
3183 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
3184 const CODE_BLOCK_DELIMITER: &str = "```";
3185 const CURSOR_SPAN: &str = "<|CURSOR|>";
3186
3187 let this = unsafe { self.get_unchecked_mut() };
3188 loop {
3189 if !this.stream_done {
3190 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
3191 match stream.as_mut().poll_next(cx) {
3192 Poll::Ready(Some(Ok(chunk))) => {
3193 this.buffer.push_str(&chunk);
3194 }
3195 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
3196 Poll::Ready(None) => {
3197 this.stream_done = true;
3198 }
3199 Poll::Pending => return Poll::Pending,
3200 }
3201 }
3202
3203 let mut chunk = String::new();
3204 let mut consumed = 0;
3205 if !this.buffer.is_empty() {
3206 let mut lines = this.buffer.split('\n').enumerate().peekable();
3207 while let Some((line_ix, line)) = lines.next() {
3208 if line_ix > 0 {
3209 this.first_line = false;
3210 }
3211
3212 if this.first_line {
3213 let trimmed_line = line.trim();
3214 if lines.peek().is_some() {
3215 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
3216 consumed += line.len() + 1;
3217 this.starts_with_code_block = true;
3218 continue;
3219 }
3220 } else if trimmed_line.is_empty()
3221 || prefixes(CODE_BLOCK_DELIMITER)
3222 .any(|prefix| trimmed_line.starts_with(prefix))
3223 {
3224 break;
3225 }
3226 }
3227
3228 let line_without_cursor = line.replace(CURSOR_SPAN, "");
3229 if lines.peek().is_some() {
3230 if this.line_end {
3231 chunk.push('\n');
3232 }
3233
3234 chunk.push_str(&line_without_cursor);
3235 this.line_end = true;
3236 consumed += line.len() + 1;
3237 } else if this.stream_done {
3238 if !this.starts_with_code_block
3239 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
3240 {
3241 if this.line_end {
3242 chunk.push('\n');
3243 }
3244
3245 chunk.push_str(&line);
3246 }
3247
3248 consumed += line.len();
3249 } else {
3250 let trimmed_line = line.trim();
3251 if trimmed_line.is_empty()
3252 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
3253 || prefixes(CODE_BLOCK_DELIMITER)
3254 .any(|prefix| trimmed_line.ends_with(prefix))
3255 {
3256 break;
3257 } else {
3258 if this.line_end {
3259 chunk.push('\n');
3260 this.line_end = false;
3261 }
3262
3263 chunk.push_str(&line_without_cursor);
3264 consumed += line.len();
3265 }
3266 }
3267 }
3268 }
3269
3270 this.buffer = this.buffer.split_off(consumed);
3271 if !chunk.is_empty() {
3272 return Poll::Ready(Some(Ok(chunk)));
3273 } else if this.stream_done {
3274 return Poll::Ready(None);
3275 }
3276 }
3277 }
3278}
3279
3280struct AssistantCodeActionProvider {
3281 editor: WeakView<Editor>,
3282 workspace: WeakView<Workspace>,
3283}
3284
3285impl CodeActionProvider for AssistantCodeActionProvider {
3286 fn code_actions(
3287 &self,
3288 buffer: &Model<Buffer>,
3289 range: Range<text::Anchor>,
3290 cx: &mut WindowContext,
3291 ) -> Task<Result<Vec<CodeAction>>> {
3292 if !AssistantSettings::get_global(cx).enabled {
3293 return Task::ready(Ok(Vec::new()));
3294 }
3295
3296 let snapshot = buffer.read(cx).snapshot();
3297 let mut range = range.to_point(&snapshot);
3298
3299 // Expand the range to line boundaries.
3300 range.start.column = 0;
3301 range.end.column = snapshot.line_len(range.end.row);
3302
3303 let mut has_diagnostics = false;
3304 for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
3305 range.start = cmp::min(range.start, diagnostic.range.start);
3306 range.end = cmp::max(range.end, diagnostic.range.end);
3307 has_diagnostics = true;
3308 }
3309 if has_diagnostics {
3310 if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
3311 if let Some(symbol) = symbols_containing_start.last() {
3312 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3313 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3314 }
3315 }
3316
3317 if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
3318 if let Some(symbol) = symbols_containing_end.last() {
3319 range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
3320 range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
3321 }
3322 }
3323
3324 Task::ready(Ok(vec![CodeAction {
3325 server_id: language::LanguageServerId(0),
3326 range: snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end),
3327 lsp_action: lsp::CodeAction {
3328 title: "Fix with Assistant".into(),
3329 ..Default::default()
3330 },
3331 }]))
3332 } else {
3333 Task::ready(Ok(Vec::new()))
3334 }
3335 }
3336
3337 fn apply_code_action(
3338 &self,
3339 buffer: Model<Buffer>,
3340 action: CodeAction,
3341 excerpt_id: ExcerptId,
3342 _push_to_history: bool,
3343 cx: &mut WindowContext,
3344 ) -> Task<Result<ProjectTransaction>> {
3345 let editor = self.editor.clone();
3346 let workspace = self.workspace.clone();
3347 cx.spawn(|mut cx| async move {
3348 let editor = editor.upgrade().context("editor was released")?;
3349 let range = editor
3350 .update(&mut cx, |editor, cx| {
3351 editor.buffer().update(cx, |multibuffer, cx| {
3352 let buffer = buffer.read(cx);
3353 let multibuffer_snapshot = multibuffer.read(cx);
3354
3355 let old_context_range =
3356 multibuffer_snapshot.context_range_for_excerpt(excerpt_id)?;
3357 let mut new_context_range = old_context_range.clone();
3358 if action
3359 .range
3360 .start
3361 .cmp(&old_context_range.start, buffer)
3362 .is_lt()
3363 {
3364 new_context_range.start = action.range.start;
3365 }
3366 if action.range.end.cmp(&old_context_range.end, buffer).is_gt() {
3367 new_context_range.end = action.range.end;
3368 }
3369 drop(multibuffer_snapshot);
3370
3371 if new_context_range != old_context_range {
3372 multibuffer.resize_excerpt(excerpt_id, new_context_range, cx);
3373 }
3374
3375 let multibuffer_snapshot = multibuffer.read(cx);
3376 Some(
3377 multibuffer_snapshot
3378 .anchor_in_excerpt(excerpt_id, action.range.start)?
3379 ..multibuffer_snapshot
3380 .anchor_in_excerpt(excerpt_id, action.range.end)?,
3381 )
3382 })
3383 })?
3384 .context("invalid range")?;
3385
3386 cx.update_global(|assistant: &mut InlineAssistant, cx| {
3387 let assist_id = assistant.suggest_assist(
3388 &editor,
3389 range,
3390 "Fix Diagnostics".into(),
3391 None,
3392 true,
3393 Some(workspace),
3394 cx,
3395 );
3396 assistant.start_assist(assist_id, cx);
3397 })?;
3398
3399 Ok(ProjectTransaction::default())
3400 })
3401 }
3402}
3403
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full string itself is never yielded: `"abc"` produces `"a"` and `"ab"`.
/// Prefixes end on `char` boundaries, so multi-byte text is handled, and empty
/// or single-character input yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Every char start after the first is the exclusive end of a proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
3407
3408fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
3409 ranges.sort_unstable_by(|a, b| {
3410 a.start
3411 .cmp(&b.start, buffer)
3412 .then_with(|| b.end.cmp(&a.end, buffer))
3413 });
3414
3415 let mut ix = 0;
3416 while ix + 1 < ranges.len() {
3417 let b = ranges[ix + 1].clone();
3418 let a = &mut ranges[ix];
3419 if a.end.cmp(&b.start, buffer).is_gt() {
3420 if a.end.cmp(&b.end, buffer).is_lt() {
3421 a.end = b.end;
3422 }
3423 ranges.remove(ix + 1);
3424 } else {
3425 ix += 1;
3426 }
3427 }
3428}
3429
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module — confirm whether it
    // is still needed.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Replacing a selection that spans whole, indented lines re-indents the
    /// generated text to fit the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // From the start of `let x = 0;` to just after the loop's closing brace.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        // NOTE(review): the leading whitespace of these literals was reconstructed;
        // presumably only the nesting of `x += 1;` relative to the other lines
        // matters to the assertion below — confirm against version control.
        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        send_in_random_chunks(new_text, &chunks_tx, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor placed after existing text on an indented line
    /// indents the following generated lines to match.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // An empty range right after `le`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(new_text, &chunks_tx, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a cursor placed inside the leading whitespace of a line
    /// still produces correctly indented code.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // An empty range at the end of the whitespace-only line.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        send_in_random_chunks(new_text, &chunks_tx, &mut rng, cx);
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Tab-indented text in a buffer without a language keeps its tabs.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative leaves the buffer alone until it is activated, and
    /// deactivating it restores the original text.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // The whole `let x = 0;` line, without its newline.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` produces the same output for every way of splitting the
    /// input into equally sized chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Streams `text` with every chunk size from 1 to its length and checks
        /// the concatenated output each time.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into successful chunks of `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Sends `text` through `chunks_tx` in random pieces of 1 to 10 bytes, running
    /// the executor until parked after each piece.
    fn send_in_random_chunks(
        mut text: &str,
        chunks_tx: &mpsc::UnboundedSender<String>,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
    }

    /// Hooks `codegen` up to a channel-backed response stream and returns the
    /// sending half; dropping the sender ends the simulated response.
    fn simulate_response_stream(
        codegen: Model<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// A Rust language definition with just enough of an indents query for the
    /// auto-indent tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}