edit_prediction_tools.rs

  1use std::{
  2    collections::hash_map::Entry,
  3    ffi::OsStr,
  4    path::{Path, PathBuf},
  5    str::FromStr,
  6    sync::Arc,
  7    time::{Duration, Instant},
  8};
  9
 10use collections::HashMap;
 11use editor::{Editor, EditorEvent, EditorMode, ExcerptRange, MultiBuffer};
 12use gpui::{
 13    Entity, EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity, actions,
 14    prelude::*,
 15};
 16use language::{Buffer, DiskState};
 17use project::{Project, WorktreeId};
 18use text::ToPoint;
 19use ui::prelude::*;
 20use ui_input::SingleLineInput;
 21use workspace::{Item, SplitDirection, Workspace};
 22
 23use edit_prediction_context::{
 24    EditPredictionContext, EditPredictionExcerptOptions, SnippetStyle, SyntaxIndex,
 25};
 26
 27actions!(
 28    dev,
 29    [
 30        /// Opens the language server protocol logs viewer.
 31        OpenEditPredictionContext
 32    ]
 33);
 34
 35pub fn init(cx: &mut App) {
 36    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 37        workspace.register_action(
 38            move |workspace, _: &OpenEditPredictionContext, window, cx| {
 39                let workspace_entity = cx.entity();
 40                let project = workspace.project();
 41                let active_editor = workspace.active_item_as::<Editor>(cx);
 42                workspace.split_item(
 43                    SplitDirection::Right,
 44                    Box::new(cx.new(|cx| {
 45                        EditPredictionTools::new(
 46                            &workspace_entity,
 47                            &project,
 48                            active_editor,
 49                            window,
 50                            cx,
 51                        )
 52                    })),
 53                    window,
 54                    cx,
 55                );
 56            },
 57        );
 58    })
 59    .detach();
 60}
 61
/// Debug view showing the edit prediction context gathered at the cursor of the most
/// recently inspected editor, together with inputs for tweaking the excerpt options.
pub struct EditPredictionTools {
    /// Handle returned from `Focusable::focus_handle`.
    focus_handle: FocusHandle,
    /// Project used for worktree ids, the language registry, and snippet paths.
    project: Entity<Project>,
    /// Most recently gathered context; `None` when nothing is displayed.
    last_context: Option<ContextState>,
    /// Input whose text is parsed into the excerpt options' `max_bytes`.
    max_bytes_input: Entity<SingleLineInput>,
    /// Input whose text is parsed into the excerpt options' `min_bytes`.
    min_bytes_input: Entity<SingleLineInput>,
    /// Input whose text is parsed into `target_before_cursor_over_total_bytes`.
    cursor_context_ratio_input: Entity<SingleLineInput>,
    // TODO move to project or provider?
    /// Syntax index handed to the context retrieval.
    syntax_index: Entity<SyntaxIndex>,
    /// Editor whose cursor was inspected last; re-used when an option input is edited.
    last_editor: WeakEntity<Editor>,
    /// Subscription to the active editor's selection changes.
    _active_editor_subscription: Option<Subscription>,
    /// Task running the latest context retrieval; replaced on every update.
    _edit_prediction_context_task: Task<()>,
}
 75
 76struct ContextState {
 77    context_editor: Entity<Editor>,
 78    retrieval_duration: Duration,
 79}
 80
 81impl EditPredictionTools {
 82    pub fn new(
 83        workspace: &Entity<Workspace>,
 84        project: &Entity<Project>,
 85        active_editor: Option<Entity<Editor>>,
 86        window: &mut Window,
 87        cx: &mut Context<Self>,
 88    ) -> Self {
 89        cx.subscribe_in(workspace, window, |this, workspace, event, window, cx| {
 90            if let workspace::Event::ActiveItemChanged = event {
 91                if let Some(editor) = workspace.read(cx).active_item_as::<Editor>(cx) {
 92                    this._active_editor_subscription = Some(cx.subscribe_in(
 93                        &editor,
 94                        window,
 95                        |this, editor, event, window, cx| {
 96                            if let EditorEvent::SelectionsChanged { .. } = event {
 97                                this.update_context(editor, window, cx);
 98                            }
 99                        },
100                    ));
101                    this.update_context(&editor, window, cx);
102                } else {
103                    this._active_editor_subscription = None;
104                }
105            }
106        })
107        .detach();
108        let syntax_index = cx.new(|cx| SyntaxIndex::new(project, cx));
109
110        let number_input = |label: &'static str,
111                            value: &'static str,
112                            window: &mut Window,
113                            cx: &mut Context<Self>|
114         -> Entity<SingleLineInput> {
115            let input = cx.new(|cx| {
116                let input = SingleLineInput::new(window, cx, "")
117                    .label(label)
118                    .label_min_width(px(64.));
119                input.set_text(value, window, cx);
120                input
121            });
122            cx.subscribe_in(
123                &input.read(cx).editor().clone(),
124                window,
125                |this, _, event, window, cx| {
126                    if let EditorEvent::BufferEdited = event
127                        && let Some(editor) = this.last_editor.upgrade()
128                    {
129                        this.update_context(&editor, window, cx);
130                    }
131                },
132            )
133            .detach();
134            input
135        };
136
137        let mut this = Self {
138            focus_handle: cx.focus_handle(),
139            project: project.clone(),
140            last_context: None,
141            max_bytes_input: number_input("Max Bytes", "512", window, cx),
142            min_bytes_input: number_input("Min Bytes", "128", window, cx),
143            cursor_context_ratio_input: number_input("Cursor Context Ratio", "0.5", window, cx),
144            syntax_index,
145            last_editor: WeakEntity::new_invalid(),
146            _active_editor_subscription: None,
147            _edit_prediction_context_task: Task::ready(()),
148        };
149
150        if let Some(editor) = active_editor {
151            this.update_context(&editor, window, cx);
152        }
153
154        this
155    }
156
157    fn update_context(
158        &mut self,
159        editor: &Entity<Editor>,
160        window: &mut Window,
161        cx: &mut Context<Self>,
162    ) {
163        self.last_editor = editor.downgrade();
164
165        let editor = editor.read(cx);
166        let buffer = editor.buffer().clone();
167        let cursor_position = editor.selections.newest_anchor().start;
168
169        let Some(buffer) = buffer.read(cx).buffer_for_anchor(cursor_position, cx) else {
170            self.last_context.take();
171            return;
172        };
173        let current_buffer_snapshot = buffer.read(cx).snapshot();
174        let cursor_position = cursor_position
175            .text_anchor
176            .to_point(&current_buffer_snapshot);
177
178        let language = current_buffer_snapshot.language().cloned();
179        let Some(worktree_id) = self
180            .project
181            .read(cx)
182            .worktrees(cx)
183            .next()
184            .map(|worktree| worktree.read(cx).id())
185        else {
186            log::error!("Open a worktree to use edit prediction debug view");
187            self.last_context.take();
188            return;
189        };
190
191        self._edit_prediction_context_task = cx.spawn_in(window, {
192            let language_registry = self.project.read(cx).languages().clone();
193            async move |this, cx| {
194                cx.background_executor()
195                    .timer(Duration::from_millis(50))
196                    .await;
197
198                let mut start_time = None;
199
200                let Ok(task) = this.update(cx, |this, cx| {
201                    fn number_input_value<T: FromStr + Default>(
202                        input: &Entity<SingleLineInput>,
203                        cx: &App,
204                    ) -> T {
205                        input
206                            .read(cx)
207                            .editor()
208                            .read(cx)
209                            .text(cx)
210                            .parse::<T>()
211                            .unwrap_or_default()
212                    }
213
214                    let options = EditPredictionExcerptOptions {
215                        max_bytes: number_input_value(&this.max_bytes_input, cx),
216                        min_bytes: number_input_value(&this.min_bytes_input, cx),
217                        target_before_cursor_over_total_bytes: number_input_value(
218                            &this.cursor_context_ratio_input,
219                            cx,
220                        ),
221                    };
222
223                    start_time = Some(Instant::now());
224
225                    // TODO use global zeta instead
226                    EditPredictionContext::gather_context_in_background(
227                        cursor_position,
228                        current_buffer_snapshot,
229                        options,
230                        Some(this.syntax_index.clone()),
231                        cx,
232                    )
233                }) else {
234                    this.update(cx, |this, _cx| {
235                        this.last_context.take();
236                    })
237                    .ok();
238                    return;
239                };
240
241                let Some(context) = task.await else {
242                    // TODO: Display message
243                    this.update(cx, |this, _cx| {
244                        this.last_context.take();
245                    })
246                    .ok();
247                    return;
248                };
249                let retrieval_duration = start_time.unwrap().elapsed();
250
251                let mut languages = HashMap::default();
252                for snippet in context.snippets.iter() {
253                    let lang_id = snippet.declaration.identifier().language_id;
254                    if let Entry::Vacant(entry) = languages.entry(lang_id) {
255                        // Most snippets are gonna be the same language,
256                        // so we think it's fine to do this sequentially for now
257                        entry.insert(language_registry.language_for_id(lang_id).await.ok());
258                    }
259                }
260
261                this.update_in(cx, |this, window, cx| {
262                    let context_editor = cx.new(|cx| {
263                        let multibuffer = cx.new(|cx| {
264                            let mut multibuffer = MultiBuffer::new(language::Capability::ReadOnly);
265                            let excerpt_file = Arc::new(ExcerptMetadataFile {
266                                title: PathBuf::from("Cursor Excerpt").into(),
267                                worktree_id,
268                            });
269
270                            let excerpt_buffer = cx.new(|cx| {
271                                let mut buffer = Buffer::local(context.excerpt_text.body, cx);
272                                buffer.set_language(language, cx);
273                                buffer.file_updated(excerpt_file, cx);
274                                buffer
275                            });
276
277                            multibuffer.push_excerpts(
278                                excerpt_buffer,
279                                [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
280                                cx,
281                            );
282
283                            for snippet in context.snippets {
284                                let path = this
285                                    .project
286                                    .read(cx)
287                                    .path_for_entry(snippet.declaration.project_entry_id(), cx);
288
289                                let snippet_file = Arc::new(ExcerptMetadataFile {
290                                    title: PathBuf::from(format!(
291                                        "{} (Score density: {})",
292                                        path.map(|p| p.path.to_string_lossy().to_string())
293                                            .unwrap_or_else(|| "".to_string()),
294                                        snippet.score_density(SnippetStyle::Declaration)
295                                    ))
296                                    .into(),
297                                    worktree_id,
298                                });
299
300                                let excerpt_buffer = cx.new(|cx| {
301                                    let mut buffer =
302                                        Buffer::local(snippet.declaration.item_text().0, cx);
303                                    buffer.file_updated(snippet_file, cx);
304                                    if let Some(language) =
305                                        languages.get(&snippet.declaration.identifier().language_id)
306                                    {
307                                        buffer.set_language(language.clone(), cx);
308                                    }
309                                    buffer
310                                });
311
312                                multibuffer.push_excerpts(
313                                    excerpt_buffer,
314                                    [ExcerptRange::new(text::Anchor::MIN..text::Anchor::MAX)],
315                                    cx,
316                                );
317                            }
318
319                            multibuffer
320                        });
321
322                        Editor::new(EditorMode::full(), multibuffer, None, window, cx)
323                    });
324
325                    this.last_context = Some(ContextState {
326                        context_editor,
327                        retrieval_duration,
328                    });
329                    cx.notify();
330                })
331                .ok();
332            }
333        });
334    }
335}
336
337impl Focusable for EditPredictionTools {
338    fn focus_handle(&self, _cx: &App) -> FocusHandle {
339        self.focus_handle.clone()
340    }
341}
342
343impl Item for EditPredictionTools {
344    type Event = ();
345
346    fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
347        "Edit Prediction Context Debug View".into()
348    }
349
350    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
351        Some(Icon::new(IconName::ZedPredict))
352    }
353}
354
// The emitted event type is `()`, matching `type Event = ()` in the `Item` impl.
impl EventEmitter<()> for EditPredictionTools {}
356
357impl Render for EditPredictionTools {
358    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
359        v_flex()
360            .size_full()
361            .bg(cx.theme().colors().editor_background)
362            .child(
363                h_flex()
364                    .items_start()
365                    .w_full()
366                    .child(
367                        v_flex()
368                            .flex_1()
369                            .p_4()
370                            .gap_2()
371                            .child(Headline::new("Excerpt Options").size(HeadlineSize::Small))
372                            .child(
373                                h_flex()
374                                    .gap_2()
375                                    .child(self.max_bytes_input.clone())
376                                    .child(self.min_bytes_input.clone())
377                                    .child(self.cursor_context_ratio_input.clone()),
378                            ),
379                    )
380                    .child(ui::Divider::vertical())
381                    .when_some(self.last_context.as_ref(), |this, last_context| {
382                        this.child(
383                            v_flex()
384                                .p_4()
385                                .gap_2()
386                                .min_w(px(160.))
387                                .child(Headline::new("Stats").size(HeadlineSize::Small))
388                                .child(
389                                    h_flex()
390                                        .gap_1()
391                                        .child(
392                                            Label::new("Time to retrieve")
393                                                .color(Color::Muted)
394                                                .size(LabelSize::Small),
395                                        )
396                                        .child(
397                                            Label::new(
398                                                if last_context.retrieval_duration.as_micros()
399                                                    > 1000
400                                                {
401                                                    format!(
402                                                        "{} ms",
403                                                        last_context.retrieval_duration.as_millis()
404                                                    )
405                                                } else {
406                                                    format!(
407                                                        "{} ยตs",
408                                                        last_context.retrieval_duration.as_micros()
409                                                    )
410                                                },
411                                            )
412                                            .size(LabelSize::Small),
413                                        ),
414                                ),
415                        )
416                    }),
417            )
418            .children(self.last_context.as_ref().map(|c| c.context_editor.clone()))
419    }
420}
421
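// Excerpt buffers are given this synthetic `language::File`, whose path is a descriptive
// title; this mirrors the approach used by the commit view.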
423
424struct ExcerptMetadataFile {
425    title: Arc<Path>,
426    worktree_id: WorktreeId,
427}
428
429impl language::File for ExcerptMetadataFile {
430    fn as_local(&self) -> Option<&dyn language::LocalFile> {
431        None
432    }
433
434    fn disk_state(&self) -> DiskState {
435        DiskState::New
436    }
437
438    fn path(&self) -> &Arc<Path> {
439        &self.title
440    }
441
442    fn full_path(&self, _: &App) -> PathBuf {
443        self.title.as_ref().into()
444    }
445
446    fn file_name<'a>(&'a self, _: &'a App) -> &'a OsStr {
447        self.title.file_name().unwrap()
448    }
449
450    fn worktree_id(&self, _: &App) -> WorktreeId {
451        self.worktree_id
452    }
453
454    fn to_proto(&self, _: &App) -> language::proto::File {
455        unimplemented!()
456    }
457
458    fn is_private(&self) -> bool {
459        false
460    }
461}