edit_prediction_tools.rs

  1use std::{
  2    collections::hash_map::Entry,
  3    ffi::OsStr,
  4    path::{Path, PathBuf},
  5    str::FromStr,
  6    sync::Arc,
  7    time::{Duration, Instant},
  8};
  9
 10use collections::HashMap;
 11use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
 12use gpui::{
 13    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 14    prelude::*,
 15};
 16use language::{Buffer, DiskState};
 17use project::{Project, WorktreeId};
 18use text::ToPoint;
 19use ui::prelude::*;
 20use ui_input::SingleLineInput;
 21use workspace::{Item, SplitDirection, Workspace};
 22
 23use edit_prediction_context::{
 24    EditPredictionContext, EditPredictionExcerptOptions, SnippetStyle, SyntaxIndex,
 25};
 26
 27actions!(
 28    dev,
 29    [
 30        /// Opens the language server protocol logs viewer.
 31        OpenEditPredictionContext
 32    ]
 33);
 34
 35pub fn init(cx: &mut App) {
 36    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 37        workspace.register_action(
 38            move |workspace, _: &OpenEditPredictionContext, window, cx| {
 39                let workspace_entity = cx.entity();
 40                let project = workspace.project();
 41                let active_editor = workspace.active_item_as::<Editor>(cx);
 42                workspace.split_item(
 43                    SplitDirection::Right,
 44                    Box::new(cx.new(|cx| {
 45                        EditPredictionTools::new(
 46                            &workspace_entity,
 47                            &project,
 48                            active_editor,
 49                            window,
 50                            cx,
 51                        )
 52                    })),
 53                    window,
 54                    cx,
 55                );
 56            },
 57        );
 58    })
 59    .detach();
 60}
 61
/// Debug view showing the edit prediction context gathered around the cursor of the most
/// recently inspected editor, together with live-editable excerpt options.
///
/// NOTE: field order is also drop order; keep the subscription/task fields where they are.
pub struct EditPredictionTools {
    focus_handle: FocusHandle,
    project: Entity<Project>,
    /// Result of the latest successful retrieval; `None` when nothing is displayed.
    last_context: Option<ContextState>,
    /// Input whose text is parsed into `EditPredictionExcerptOptions::max_bytes`.
    max_bytes_input: Entity<SingleLineInput>,
    /// Input whose text is parsed into `EditPredictionExcerptOptions::min_bytes`.
    min_bytes_input: Entity<SingleLineInput>,
    /// Input whose text is parsed into
    /// `EditPredictionExcerptOptions::target_before_cursor_over_total_bytes`.
    cursor_context_ratio_input: Entity<SingleLineInput>,
    // TODO move to project or provider?
    syntax_index: Entity<SyntaxIndex>,
    /// Editor whose cursor was last used for retrieval; editing an option input re-runs
    /// retrieval against it.
    last_editor: WeakEntity<Editor>,
    /// Keeps the selection-change subscription to the active editor alive.
    _active_editor_subscription: Option<Subscription>,
    /// Pending retrieval; assigning a new task drops the previous one (gpui tasks are
    /// cancelled on drop), which is how the 50ms debounce takes effect.
    _edit_prediction_context_task: Task<()>,
}
 75
 76struct ContextState {
 77    context_editor: Entity<Editor>,
 78    retrieval_duration: Duration,
 79}
 80
 81impl EditPredictionTools {
 82    pub fn new(
 83        workspace: &Entity<Workspace>,
 84        project: &Entity<Project>,
 85        active_editor: Option<Entity<Editor>>,
 86        window: &mut Window,
 87        cx: &mut Context<Self>,
 88    ) -> Self {
 89        cx.subscribe_in(workspace, window, |this, workspace, event, window, cx| {
 90            if let workspace::Event::ActiveItemChanged = event {
 91                if let Some(editor) = workspace.read(cx).active_item_as::<Editor>(cx) {
 92                    this._active_editor_subscription = Some(cx.subscribe_in(
 93                        &editor,
 94                        window,
 95                        |this, editor, event, window, cx| {
 96                            if let EditorEvent::SelectionsChanged { .. } = event {
 97                                this.update_context(editor, window, cx);
 98                            }
 99                        },
100                    ));
101                    this.update_context(&editor, window, cx);
102                } else {
103                    this._active_editor_subscription = None;
104                }
105            }
106        })
107        .detach();
108        let syntax_index = cx.new(|cx| SyntaxIndex::new(project, cx));
109
110        let number_input = |label: &'static str,
111                            value: &'static str,
112                            window: &mut Window,
113                            cx: &mut Context<Self>|
114         -> Entity<SingleLineInput> {
115            let input = cx.new(|cx| {
116                let input = SingleLineInput::new(window, cx, "")
117                    .label(label)
118                    .label_min_width(px(64.));
119                input.set_text(value, window, cx);
120                input
121            });
122            cx.subscribe_in(
123                &input.read(cx).editor().clone(),
124                window,
125                |this, _, event, window, cx| {
126                    if let EditorEvent::BufferEdited = event
127                        && let Some(editor) = this.last_editor.upgrade()
128                    {
129                        this.update_context(&editor, window, cx);
130                    }
131                },
132            )
133            .detach();
134            input
135        };
136
137        let mut this = Self {
138            focus_handle: cx.focus_handle(),
139            project: project.clone(),
140            last_context: None,
141            max_bytes_input: number_input("Max Bytes", "512", window, cx),
142            min_bytes_input: number_input("Min Bytes", "128", window, cx),
143            cursor_context_ratio_input: number_input("Cursor Context Ratio", "0.5", window, cx),
144            syntax_index,
145            last_editor: WeakEntity::new_invalid(),
146            _active_editor_subscription: None,
147            _edit_prediction_context_task: Task::ready(()),
148        };
149
150        if let Some(editor) = active_editor {
151            this.update_context(&editor, window, cx);
152        }
153
154        this
155    }
156
157    fn update_context(
158        &mut self,
159        editor: &Entity<Editor>,
160        window: &mut Window,
161        cx: &mut Context<Self>,
162    ) {
163        self.last_editor = editor.downgrade();
164
165        let editor = editor.read(cx);
166        let buffer = editor.buffer().clone();
167        let cursor_position = editor.selections.newest_anchor().start;
168
169        let Some(buffer) = buffer.read(cx).buffer_for_anchor(cursor_position, cx) else {
170            self.last_context.take();
171            return;
172        };
173        let current_buffer_snapshot = buffer.read(cx).snapshot();
174        let cursor_position = cursor_position
175            .text_anchor
176            .to_point(&current_buffer_snapshot);
177
178        let language = current_buffer_snapshot.language().cloned();
179        let Some(worktree_id) = self
180            .project
181            .read(cx)
182            .worktrees(cx)
183            .next()
184            .map(|worktree| worktree.read(cx).id())
185        else {
186            log::error!("Open a worktree to use edit prediction debug view");
187            self.last_context.take();
188            return;
189        };
190
191        self._edit_prediction_context_task = cx.spawn_in(window, {
192            let language_registry = self.project.read(cx).languages().clone();
193            async move |this, cx| {
194                cx.background_executor()
195                    .timer(Duration::from_millis(50))
196                    .await;
197
198                let mut start_time = None;
199
200                let Ok(task) = this.update(cx, |this, cx| {
201                    fn number_input_value<T: FromStr + Default>(
202                        input: &Entity<SingleLineInput>,
203                        cx: &App,
204                    ) -> T {
205                        input
206                            .read(cx)
207                            .editor()
208                            .read(cx)
209                            .text(cx)
210                            .parse::<T>()
211                            .unwrap_or_default()
212                    }
213
214                    let options = EditPredictionExcerptOptions {
215                        max_bytes: number_input_value(&this.max_bytes_input, cx),
216                        min_bytes: number_input_value(&this.min_bytes_input, cx),
217                        target_before_cursor_over_total_bytes: number_input_value(
218                            &this.cursor_context_ratio_input,
219                            cx,
220                        ),
221                    };
222
223                    start_time = Some(Instant::now());
224
225                    EditPredictionContext::gather_context_in_background(
226                        cursor_position,
227                        current_buffer_snapshot,
228                        options,
229                        this.syntax_index.clone(),
230                        cx,
231                    )
232                }) else {
233                    this.update(cx, |this, _cx| {
234                        this.last_context.take();
235                    })
236                    .ok();
237                    return;
238                };
239
240                let Some(context) = task.await else {
241                    // TODO: Display message
242                    this.update(cx, |this, _cx| {
243                        this.last_context.take();
244                    })
245                    .ok();
246                    return;
247                };
248                let retrieval_duration = start_time.unwrap().elapsed();
249
250                let mut languages = HashMap::default();
251                for snippet in context.snippets.iter() {
252                    let lang_id = snippet.declaration.identifier().language_id;
253                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
254                        // Most snippets are gonna be the same language,
255                        // so we think it's fine to do this sequentially for now
256                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
257                    }
258                }
259
260                this.update_in(cx, |this, window, cx| {
261                    let context_editor = cx.new(|cx| {
262                        let multibuffer = cx.new(|cx| {
263                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
264                            let excerpt_file = Arc::new(ExcerptMetadataFile {
265                                title: PathBuf::from("Cursor Excerpt").into(),
266                                worktree_id,
267                            });
268
269                            let excerpt_buffer = cx.new(|cx| {
270                                let mut buffer = Buffer::local(context.excerpt_text.body, cx);
271                                buffer.set_language(language, cx);
272                                buffer.file_updated(excerpt_file, cx);
273                                buffer
274                            });
275
276                            multibuffer.push_excerpts(
277                                excerpt_buffer,
278                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
279                                cx,
280                            );
281
282                            for snippet in context.snippets {
283                                let path = this
284                                    .project
285                                    .read(cx)
286                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
287
288                                let snippet_file = Arc::new(ExcerptMetadataFile {
289                                    title: PathBuf::from(format!(
290                                        "{} (Score density: {})",
291                                        path.map(|p| p.path.to_string_lossy().to_string())
292                                            .unwrap_or_else(|| "".to_string()),
293                                        snippet.score_density(SnippetStyle::Declaration)
294                                    ))
295                                    .into(),
296                                    worktree_id,
297                                });
298
299                                let excerpt_buffer = cx.new(|cx| {
300                                    let mut buffer =
301                                        Buffer::local(snippet.declaration.item_text().0, cx);
302                                    buffer.file_updated(snippet_file, cx);
303                                    if let Some(language) =
304                                        languages.get(&snippet.declaration.identifier().language_id)
305                                    {
306                                        buffer.set_language(language.clone(), cx);
307                                    }
308                                    buffer
309                                });
310
311                                multibuffer.push_excerpts(
312                                    excerpt_buffer,
313                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
314                                    cx,
315                                );
316                            }
317
318                            multibuffer
319                        });
320
321                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
322                    });
323
324                    this.last_context = Some(ContextState {
325                        context_editor,
326                        retrieval_duration,
327                    });
328                    cx.notify();
329                })
330                .ok();
331            }
332        });
333    }
334}
335
336impl Focusable for EditPredictionTools {
337    fn focus_handle(&self, _cx: &App) -> FocusHandle {
338        self.focus_handle.clone()
339    }
340}
341
342impl Item for EditPredictionTools {
343    type Event = ();
344
345    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
346        "Edit Prediction Context Debug View".into()
347    }
348
349    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
350        Some(Icon::new(IconName::ZedPredict))
351    }
352}
353
354impl EventEmitter<()> for EditPredictionTools {}
355
356impl Render for EditPredictionTools {
357    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
358        v_flex()
359            .size_full()
360            .bg(cx.theme().colors().editor_background)
361            .child(
362                h_flex()
363                    .items_start()
364                    .w_full()
365                    .child(
366                        v_flex()
367                            .flex_1()
368                            .p_4()
369                            .gap_2()
370                            .child(Headline::new("Excerpt Options").size(HeadlineSize::Small))
371                            .child(
372                                h_flex()
373                                    .gap_2()
374                                    .child(self.max_bytes_input.clone())
375                                    .child(self.min_bytes_input.clone())
376                                    .child(self.cursor_context_ratio_input.clone()),
377                            ),
378                    )
379                    .child(ui::Divider::vertical())
380                    .when_some(self.last_context.as_ref(), |this, last_context| {
381                        this.child(
382                            v_flex()
383                                .p_4()
384                                .gap_2()
385                                .min_w(px(160.))
386                                .child(Headline::new("Stats").size(HeadlineSize::Small))
387                                .child(
388                                    h_flex()
389                                        .gap_1()
390                                        .child(
391                                            Label::new("Time to retrieve")
392                                                .color(Color::Muted)
393                                                .size(LabelSize::Small),
394                                        )
395                                        .child(
396                                            Label::new(
397                                                if last_context.retrieval_duration.as_micros()
398                                                    > 1000
399                                                {
400                                                    format!(
401                                                        "{} ms",
402                                                        last_context.retrieval_duration.as_millis()
403                                                    )
404                                                } else {
405                                                    format!(
406                                                        "{} ยตs",
407                                                        last_context.retrieval_duration.as_micros()
408                                                    )
409                                                },
410                                            )
411                                            .size(LabelSize::Small),
412                                        ),
413                                ),
414                        )
415                    }),
416            )
417            .children(self.last_context.as_ref().map(|c| c.context_editor.clone()))
418    }
419}
420
421// Using same approach as commit view
422
423struct ExcerptMetadataFile {
424    title: Arc<Path>,
425    worktree_id: WorktreeId,
426}
427
428impl language::File for ExcerptMetadataFile {
429    fn as_local(&self) -> Option<&dyn language::LocalFile> {
430        None
431    }
432
433    fn disk_state(&self) -> DiskState {
434        DiskState::New
435    }
436
437    fn path(&self) -> &Arc<Path> {
438        &self.title
439    }
440
441    fn full_path(&self, _: &App) -> PathBuf {
442        self.title.as_ref().into()
443    }
444
445    fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
446        self.title.file_name().unwrap()
447    }
448
449    fn worktree_id(&self, _: &App) -> WorktreeId {
450        self.worktree_id
451    }
452
453    fn to_proto(&self, _: &App) -> language::proto::File {
454        unimplemented!()
455    }
456
457    fn is_private(&self) -> bool {
458        false
459    }
460}