edit_prediction_ui.rs

  1mod edit_prediction_button;
  2mod edit_prediction_context_view;
  3mod rate_prediction_modal;
  4
  5use std::any::{Any as _, TypeId};
  6use std::path::Path;
  7use std::sync::Arc;
  8
  9use command_palette_hooks::CommandPaletteFilter;
 10use edit_prediction::{
 11    EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec,
 12};
 13use edit_prediction_context_view::EditPredictionContextView;
 14use editor::Editor;
 15use feature_flags::FeatureFlagAppExt as _;
 16use git::repository::DiffType;
 17use gpui::{Window, actions};
 18use language::ToPoint as _;
 19use log;
 20use project::DisableAiSettings;
 21use rate_prediction_modal::RatePredictionsModal;
 22use settings::{Settings as _, SettingsStore};
 23use text::ToOffset as _;
 24use ui::{App, prelude::*};
 25use workspace::{SplitDirection, Workspace};
 26
 27pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
 28
 29use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
 30
// Actions declared under the `dev` namespace. `init` only attaches a handler
// for `OpenEditPredictionContextView` while `Zeta2FeatureFlag` is enabled.
actions!(
    dev,
    [
        /// Opens the edit prediction context view.
        OpenEditPredictionContextView,
    ]
);
 38
// Actions declared under the `edit_prediction` namespace. Both are handled per
// workspace (see `init`); `RatePredictions` is additionally hidden from the
// command palette by `feature_gate_predict_edits_actions` unless its feature
// flag is on.
actions!(
    edit_prediction,
    [
        /// Opens the rate completions modal.
        RatePredictions,
        /// Captures an ExampleSpec from the current editing session and opens it as Markdown.
        CaptureExample,
    ]
);
 48
 49pub fn init(cx: &mut App) {
 50    feature_gate_predict_edits_actions(cx);
 51
 52    cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
 53        workspace.register_action(|workspace, _: &RatePredictions, window, cx| {
 54            if cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>() {
 55                RatePredictionsModal::toggle(workspace, window, cx);
 56            }
 57        });
 58
 59        workspace.register_action(capture_edit_prediction_example);
 60        workspace.register_action_renderer(|div, _, _, cx| {
 61            let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
 62            div.when(has_flag, |div| {
 63                div.on_action(cx.listener(
 64                    move |workspace, _: &OpenEditPredictionContextView, window, cx| {
 65                        let project = workspace.project();
 66                        workspace.split_item(
 67                            SplitDirection::Right,
 68                            Box::new(cx.new(|cx| {
 69                                EditPredictionContextView::new(
 70                                    project.clone(),
 71                                    workspace.client(),
 72                                    workspace.user_store(),
 73                                    window,
 74                                    cx,
 75                                )
 76                            })),
 77                            window,
 78                            cx,
 79                        );
 80                    },
 81                ))
 82            })
 83        });
 84    })
 85    .detach();
 86}
 87
 88fn feature_gate_predict_edits_actions(cx: &mut App) {
 89    let rate_completion_action_types = [TypeId::of::<RatePredictions>()];
 90    let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
 91    let all_action_types = [
 92        TypeId::of::<RatePredictions>(),
 93        TypeId::of::<CaptureExample>(),
 94        TypeId::of::<edit_prediction::ResetOnboarding>(),
 95        zed_actions::OpenZedPredictOnboarding.type_id(),
 96        TypeId::of::<edit_prediction::ClearHistory>(),
 97        TypeId::of::<rate_prediction_modal::ThumbsUpActivePrediction>(),
 98        TypeId::of::<rate_prediction_modal::ThumbsDownActivePrediction>(),
 99        TypeId::of::<rate_prediction_modal::NextEdit>(),
100        TypeId::of::<rate_prediction_modal::PreviousEdit>(),
101    ];
102
103    CommandPaletteFilter::update_global(cx, |filter, _cx| {
104        filter.hide_action_types(&rate_completion_action_types);
105        filter.hide_action_types(&reset_onboarding_action_types);
106        filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
107    });
108
109    cx.observe_global::<SettingsStore>(move |cx| {
110        let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
111        let has_feature_flag = cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>();
112
113        CommandPaletteFilter::update_global(cx, |filter, _cx| {
114            if is_ai_disabled {
115                filter.hide_action_types(&all_action_types);
116            } else if has_feature_flag {
117                filter.show_action_types(&rate_completion_action_types);
118            } else {
119                filter.hide_action_types(&rate_completion_action_types);
120            }
121        });
122    })
123    .detach();
124
125    cx.observe_flag::<PredictEditsRatePredictionsFeatureFlag, _>(move |is_enabled, cx| {
126        if !DisableAiSettings::get_global(cx).disable_ai {
127            if is_enabled {
128                CommandPaletteFilter::update_global(cx, |filter, _cx| {
129                    filter.show_action_types(&rate_completion_action_types);
130                });
131            } else {
132                CommandPaletteFilter::update_global(cx, |filter, _cx| {
133                    filter.hide_action_types(&rate_completion_action_types);
134                });
135            }
136        }
137    })
138    .detach();
139}
140
141fn capture_edit_prediction_example(
142    workspace: &mut Workspace,
143    _: &CaptureExample,
144    window: &mut Window,
145    cx: &mut Context<Workspace>,
146) {
147    let Some(ep_store) = EditPredictionStore::try_global(cx) else {
148        return;
149    };
150
151    let project = workspace.project().clone();
152
153    let (worktree_root, repository) = {
154        let project_ref = project.read(cx);
155        let worktree_root = project_ref
156            .visible_worktrees(cx)
157            .next()
158            .map(|worktree| worktree.read(cx).abs_path());
159        let repository = project_ref.active_repository(cx);
160        (worktree_root, repository)
161    };
162
163    let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else {
164        log::error!("CaptureExampleSpec: missing worktree or active repository");
165        return;
166    };
167
168    let repository_snapshot = repository.read(cx).snapshot();
169    if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() {
170        log::error!(
171            "repository is not at worktree root (repo={:?}, worktree={:?})",
172            repository_snapshot.work_directory_abs_path,
173            worktree_root
174        );
175        return;
176    }
177
178    let Some(repository_url) = repository_snapshot
179        .remote_origin_url
180        .clone()
181        .or_else(|| repository_snapshot.remote_upstream_url.clone())
182    else {
183        log::error!("active repository has no origin/upstream remote url");
184        return;
185    };
186
187    let Some(revision) = repository_snapshot
188        .head_commit
189        .as_ref()
190        .map(|commit| commit.sha.to_string())
191    else {
192        log::error!("active repository has no head commit");
193        return;
194    };
195
196    let mut events = ep_store.update(cx, |store, cx| {
197        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
198    });
199
200    let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
201        log::error!("no active editor");
202        return;
203    };
204
205    let Some(project_path) = editor.read(cx).project_path(cx) else {
206        log::error!("active editor has no project path");
207        return;
208    };
209
210    let Some((buffer, cursor_anchor)) = editor
211        .read(cx)
212        .buffer()
213        .read(cx)
214        .text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx)
215    else {
216        log::error!("failed to resolve cursor buffer/anchor");
217        return;
218    };
219
220    let snapshot = buffer.read(cx).snapshot();
221    let cursor_point = cursor_anchor.to_point(&snapshot);
222    let (_editable_range, context_range) =
223        edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
224            cursor_point,
225            &snapshot,
226            100,
227            50,
228        );
229
230    let cursor_path: Arc<Path> = repository
231        .read(cx)
232        .project_path_to_repo_path(&project_path, cx)
233        .map(|repo_path| Path::new(repo_path.as_unix_str()).into())
234        .unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into());
235
236    let cursor_position = {
237        let context_start_offset = context_range.start.to_offset(&snapshot);
238        let cursor_offset = cursor_anchor.to_offset(&snapshot);
239        let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
240        let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
241        if cursor_offset_in_excerpt <= excerpt.len() {
242            excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
243        }
244        excerpt
245    };
246
247    let markdown_language = workspace
248        .app_state()
249        .languages
250        .language_for_name("Markdown");
251
252    cx.spawn_in(window, async move |workspace_entity, cx| {
253        let markdown_language = markdown_language.await?;
254
255        let uncommitted_diff_rx = repository.update(cx, |repository, cx| {
256            repository.diff(DiffType::HeadToWorktree, cx)
257        })?;
258
259        let uncommitted_diff = match uncommitted_diff_rx.await {
260            Ok(Ok(diff)) => diff,
261            Ok(Err(error)) => {
262                log::error!("failed to compute uncommitted diff: {error:#}");
263                return Ok(());
264            }
265            Err(error) => {
266                log::error!("uncommitted diff channel dropped: {error:#}");
267                return Ok(());
268            }
269        };
270
271        let mut edit_history = String::new();
272        let mut expected_patch = String::new();
273        if let Some(last_event) = events.pop() {
274            for event in &events {
275                zeta_prompt::write_event(&mut edit_history, event);
276                if !edit_history.ends_with('\n') {
277                    edit_history.push('\n');
278                }
279                edit_history.push('\n');
280            }
281
282            zeta_prompt::write_event(&mut expected_patch, &last_event);
283        }
284
285        let format =
286            time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
287        let name = match format {
288            Ok(format) => {
289                let now = time::OffsetDateTime::now_local()
290                    .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
291                now.format(&format)
292                    .unwrap_or_else(|_| "unknown-time".to_string())
293            }
294            Err(_) => "unknown-time".to_string(),
295        };
296
297        let markdown = ExampleSpec {
298            name,
299            repository_url,
300            revision,
301            uncommitted_diff,
302            cursor_path,
303            cursor_position,
304            edit_history,
305            expected_patch,
306        }
307        .to_markdown();
308
309        let buffer = project
310            .update(cx, |project, cx| project.create_buffer(false, cx))?
311            .await?;
312        buffer.update(cx, |buffer, cx| {
313            buffer.set_text(markdown, cx);
314            buffer.set_language(Some(markdown_language), cx);
315        })?;
316
317        workspace_entity.update_in(cx, |workspace, window, cx| {
318            workspace.add_item_to_active_pane(
319                Box::new(
320                    cx.new(|cx| Editor::for_buffer(buffer, Some(project.clone()), window, cx)),
321                ),
322                None,
323                true,
324                window,
325                cx,
326            );
327        })
328    })
329    .detach_and_log_err(cx);
330}