1mod edit_prediction_button;
2mod edit_prediction_context_view;
3mod rate_prediction_modal;
4
5use std::any::{Any as _, TypeId};
6use std::path::Path;
7use std::sync::Arc;
8
9use command_palette_hooks::CommandPaletteFilter;
10use edit_prediction::{
11 EditPredictionStore, ResetOnboarding, Zeta2FeatureFlag, example_spec::ExampleSpec,
12};
13use edit_prediction_context_view::EditPredictionContextView;
14use editor::Editor;
15use feature_flags::FeatureFlagAppExt as _;
16use git::repository::DiffType;
17use gpui::{Window, actions};
18use language::ToPoint as _;
19use log;
20use project::DisableAiSettings;
21use rate_prediction_modal::RatePredictionsModal;
22use settings::{Settings as _, SettingsStore};
23use text::ToOffset as _;
24use ui::{App, prelude::*};
25use workspace::{SplitDirection, Workspace};
26
27pub use edit_prediction_button::{EditPredictionButton, ToggleMenu};
28
29use crate::rate_prediction_modal::PredictEditsRatePredictionsFeatureFlag;
30
31actions!(
32 dev,
33 [
34 /// Opens the edit prediction context view.
35 OpenEditPredictionContextView,
36 ]
37);
38
39actions!(
40 edit_prediction,
41 [
42 /// Opens the rate completions modal.
43 RatePredictions,
44 /// Captures an ExampleSpec from the current editing session and opens it as Markdown.
45 CaptureExample,
46 ]
47);
48
49pub fn init(cx: &mut App) {
50 feature_gate_predict_edits_actions(cx);
51
52 cx.observe_new(move |workspace: &mut Workspace, _, _cx| {
53 workspace.register_action(|workspace, _: &RatePredictions, window, cx| {
54 if cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>() {
55 RatePredictionsModal::toggle(workspace, window, cx);
56 }
57 });
58
59 workspace.register_action(capture_edit_prediction_example);
60 workspace.register_action_renderer(|div, _, _, cx| {
61 let has_flag = cx.has_flag::<Zeta2FeatureFlag>();
62 div.when(has_flag, |div| {
63 div.on_action(cx.listener(
64 move |workspace, _: &OpenEditPredictionContextView, window, cx| {
65 let project = workspace.project();
66 workspace.split_item(
67 SplitDirection::Right,
68 Box::new(cx.new(|cx| {
69 EditPredictionContextView::new(
70 project.clone(),
71 workspace.client(),
72 workspace.user_store(),
73 window,
74 cx,
75 )
76 })),
77 window,
78 cx,
79 );
80 },
81 ))
82 })
83 });
84 })
85 .detach();
86}
87
88fn feature_gate_predict_edits_actions(cx: &mut App) {
89 let rate_completion_action_types = [TypeId::of::<RatePredictions>()];
90 let reset_onboarding_action_types = [TypeId::of::<ResetOnboarding>()];
91 let all_action_types = [
92 TypeId::of::<RatePredictions>(),
93 TypeId::of::<CaptureExample>(),
94 TypeId::of::<edit_prediction::ResetOnboarding>(),
95 zed_actions::OpenZedPredictOnboarding.type_id(),
96 TypeId::of::<edit_prediction::ClearHistory>(),
97 TypeId::of::<rate_prediction_modal::ThumbsUpActivePrediction>(),
98 TypeId::of::<rate_prediction_modal::ThumbsDownActivePrediction>(),
99 TypeId::of::<rate_prediction_modal::NextEdit>(),
100 TypeId::of::<rate_prediction_modal::PreviousEdit>(),
101 ];
102
103 CommandPaletteFilter::update_global(cx, |filter, _cx| {
104 filter.hide_action_types(&rate_completion_action_types);
105 filter.hide_action_types(&reset_onboarding_action_types);
106 filter.hide_action_types(&[zed_actions::OpenZedPredictOnboarding.type_id()]);
107 });
108
109 cx.observe_global::<SettingsStore>(move |cx| {
110 let is_ai_disabled = DisableAiSettings::get_global(cx).disable_ai;
111 let has_feature_flag = cx.has_flag::<PredictEditsRatePredictionsFeatureFlag>();
112
113 CommandPaletteFilter::update_global(cx, |filter, _cx| {
114 if is_ai_disabled {
115 filter.hide_action_types(&all_action_types);
116 } else if has_feature_flag {
117 filter.show_action_types(&rate_completion_action_types);
118 } else {
119 filter.hide_action_types(&rate_completion_action_types);
120 }
121 });
122 })
123 .detach();
124
125 cx.observe_flag::<PredictEditsRatePredictionsFeatureFlag, _>(move |is_enabled, cx| {
126 if !DisableAiSettings::get_global(cx).disable_ai {
127 if is_enabled {
128 CommandPaletteFilter::update_global(cx, |filter, _cx| {
129 filter.show_action_types(&rate_completion_action_types);
130 });
131 } else {
132 CommandPaletteFilter::update_global(cx, |filter, _cx| {
133 filter.hide_action_types(&rate_completion_action_types);
134 });
135 }
136 }
137 })
138 .detach();
139}
140
/// Handler for [`CaptureExample`]: builds an [`ExampleSpec`] from the current
/// editing session and opens it as a new Markdown buffer in the active pane.
///
/// Returns silently if there is no global [`EditPredictionStore`]. Logs an
/// error and returns without opening anything if:
/// - there is no visible worktree or active repository;
/// - the repository's work directory is not the first visible worktree's root;
/// - the repository has neither an origin nor an upstream remote URL, or has
///   no HEAD commit;
/// - there is no active [`Editor`], or it has no project path;
/// - the cursor cannot be resolved to a buffer anchor.
///
/// The remaining work (diff computation, buffer creation, opening the editor)
/// runs in a detached task whose errors are logged.
fn capture_edit_prediction_example(
    workspace: &mut Workspace,
    _: &CaptureExample,
    window: &mut Window,
    cx: &mut Context<Workspace>,
) {
    let Some(ep_store) = EditPredictionStore::try_global(cx) else {
        return;
    };

    let project = workspace.project().clone();

    // Only the first visible worktree is considered.
    let (worktree_root, repository) = {
        let project_ref = project.read(cx);
        let worktree_root = project_ref
            .visible_worktrees(cx)
            .next()
            .map(|worktree| worktree.read(cx).abs_path());
        let repository = project_ref.active_repository(cx);
        (worktree_root, repository)
    };

    let (Some(worktree_root), Some(repository)) = (worktree_root, repository) else {
        log::error!("CaptureExampleSpec: missing worktree or active repository");
        return;
    };

    // Require the repository root and the worktree root to coincide.
    // NOTE(review): presumably so that worktree-relative and repo-relative
    // paths in the spec agree — confirm.
    let repository_snapshot = repository.read(cx).snapshot();
    if worktree_root.as_ref() != repository_snapshot.work_directory_abs_path.as_ref() {
        log::error!(
            "repository is not at worktree root (repo={:?}, worktree={:?})",
            repository_snapshot.work_directory_abs_path,
            worktree_root
        );
        return;
    }

    // Prefer the origin remote; fall back to the upstream remote.
    let Some(repository_url) = repository_snapshot
        .remote_origin_url
        .clone()
        .or_else(|| repository_snapshot.remote_upstream_url.clone())
    else {
        log::error!("active repository has no origin/upstream remote url");
        return;
    };

    let Some(revision) = repository_snapshot
        .head_commit
        .as_ref()
        .map(|commit| commit.sha.to_string())
    else {
        log::error!("active repository has no head commit");
        return;
    };

    // Edit history for this project. Inside the spawned task below, every
    // event except the last becomes the spec's edit history, and the last one
    // becomes the expected patch.
    // NOTE(review): this runs before the editor/cursor guards below; if the
    // store call has side effects, they happen even when those guards bail.
    let mut events = ep_store.update(cx, |store, cx| {
        store.edit_history_for_project_with_pause_split_last_event(&project, cx)
    });

    let Some(editor) = workspace.active_item_as::<Editor>(cx) else {
        log::error!("no active editor");
        return;
    };

    let Some(project_path) = editor.read(cx).project_path(cx) else {
        log::error!("active editor has no project path");
        return;
    };

    // Map the newest selection's head to a concrete buffer and text anchor.
    let Some((buffer, cursor_anchor)) = editor
        .read(cx)
        .buffer()
        .read(cx)
        .text_anchor_for_position(editor.read(cx).selections.newest_anchor().head(), cx)
    else {
        log::error!("failed to resolve cursor buffer/anchor");
        return;
    };

    // Only the context range is used here; the editable range is discarded.
    // NOTE(review): 100 and 50 are size budgets for the editable/context
    // regions; the units are defined by `cursor_excerpt` — confirm there.
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_anchor.to_point(&snapshot);
    let (_editable_range, context_range) =
        edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
            cursor_point,
            &snapshot,
            100,
            50,
        );

    // Use a repository-relative path when the project path maps into the
    // repository; otherwise fall back to the project-relative path.
    let cursor_path: Arc<Path> = repository
        .read(cx)
        .project_path_to_repo_path(&project_path, cx)
        .map(|repo_path| Path::new(repo_path.as_unix_str()).into())
        .unwrap_or_else(|| Path::new(project_path.path.as_unix_str()).into());

    // Text of the context range with `CURSOR_MARKER` spliced in at the
    // cursor's offset relative to the start of the range. A cursor before the
    // range clamps to 0 (saturating_sub). If the offset is past the end of the
    // excerpt, the marker is omitted.
    let cursor_position = {
        let context_start_offset = context_range.start.to_offset(&snapshot);
        let cursor_offset = cursor_anchor.to_offset(&snapshot);
        let cursor_offset_in_excerpt = cursor_offset.saturating_sub(context_start_offset);
        let mut excerpt = snapshot.text_for_range(context_range).collect::<String>();
        if cursor_offset_in_excerpt <= excerpt.len() {
            excerpt.insert_str(cursor_offset_in_excerpt, zeta_prompt::CURSOR_MARKER);
        }
        excerpt
    };

    let markdown_language = workspace
        .app_state()
        .languages
        .language_for_name("Markdown");

    cx.spawn_in(window, async move |workspace_entity, cx| {
        let markdown_language = markdown_language.await?;

        // Diff from HEAD to the working tree. Both failure modes are logged
        // and end the task without producing a spec.
        let uncommitted_diff_rx = repository.update(cx, |repository, cx| {
            repository.diff(DiffType::HeadToWorktree, cx)
        })?;

        let uncommitted_diff = match uncommitted_diff_rx.await {
            Ok(Ok(diff)) => diff,
            Ok(Err(error)) => {
                log::error!("failed to compute uncommitted diff: {error:#}");
                return Ok(());
            }
            Err(error) => {
                log::error!("uncommitted diff channel dropped: {error:#}");
                return Ok(());
            }
        };

        // All events but the last go into the edit history, each followed by
        // a blank line. The last event alone becomes the expected patch. With
        // no events, both strings stay empty.
        let mut edit_history = String::new();
        let mut expected_patch = String::new();
        if let Some(last_event) = events.pop() {
            for event in &events {
                zeta_prompt::write_event(&mut edit_history, event);
                if !edit_history.ends_with('\n') {
                    edit_history.push('\n');
                }
                edit_history.push('\n');
            }

            zeta_prompt::write_event(&mut expected_patch, &last_event);
        }

        // The spec is named after the current local time, falling back to UTC
        // when the local offset cannot be determined, and to "unknown-time"
        // if parsing or formatting fails.
        let format =
            time::format_description::parse("[year]-[month]-[day] [hour]:[minute]:[second]");
        let name = match format {
            Ok(format) => {
                let now = time::OffsetDateTime::now_local()
                    .unwrap_or_else(|_| time::OffsetDateTime::now_utc());
                now.format(&format)
                    .unwrap_or_else(|_| "unknown-time".to_string())
            }
            Err(_) => "unknown-time".to_string(),
        };

        let markdown = ExampleSpec {
            name,
            repository_url,
            revision,
            uncommitted_diff,
            cursor_path,
            cursor_position,
            edit_history,
            expected_patch,
        }
        .to_markdown();

        // Put the rendered spec into a new project buffer with Markdown
        // highlighting.
        // NOTE(review): the meaning of `create_buffer`'s `false` argument is
        // not visible here — confirm against `Project::create_buffer`.
        let buffer = project
            .update(cx, |project, cx| project.create_buffer(false, cx))?
            .await?;
        buffer.update(cx, |buffer, cx| {
            buffer.set_text(markdown, cx);
            buffer.set_language(Some(markdown_language), cx);
        })?;

        // NOTE(review): `None` / `true` are presumably the destination index
        // and a focus flag — confirm against `Workspace::add_item_to_active_pane`.
        workspace_entity.update_in(cx, |workspace, window, cx| {
            workspace.add_item_to_active_pane(
                Box::new(
                    cx.new(|cx| Editor::for_buffer(buffer, Some(project.clone()), window, cx)),
                ),
                None,
                true,
                window,
                cx,
            );
        })
    })
    .detach_and_log_err(cx);
}