1use crate::{
2 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
3 EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
4 prediction::EditPredictionResult, zeta::compute_edits,
5};
6use anyhow::{Context as _, Result};
7use cloud_llm_client::EditPredictionRejectReason;
8use credentials_provider::CredentialsProvider;
9use futures::AsyncReadExt as _;
10use gpui::{
11 App, AppContext as _, Context, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
13};
14use language::{ToOffset, ToPoint as _};
15use language_model::{ApiKeyState, EnvVar, env_var};
16use release_channel::AppVersion;
17use serde::{Deserialize, Serialize};
18use std::{mem, ops::Range, path::Path, sync::Arc};
19use zed_credentials_provider::global as global_credentials_provider;
20use zeta_prompt::ZetaPromptInput;
21
/// Endpoint that Mercury edit-prediction requests are POSTed to
/// (see `Mercury::request_prediction`).
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";

/// Edit-prediction provider backed by Inception Labs' Mercury API.
pub struct Mercury {
    /// Mercury API key state. `Mercury::new` takes it from `mercury_api_token`,
    /// i.e. the app-wide shared entity.
    pub api_token: Entity<ApiKeyState>,
    /// Whether the most recently completed prediction request failed with
    /// HTTP 402 (Payment Required); refreshed after every request.
    payment_required_error: bool,
}
28
29impl Mercury {
30 pub fn new(cx: &mut App) -> Self {
31 Mercury {
32 api_token: mercury_api_token(cx),
33 payment_required_error: false,
34 }
35 }
36
37 pub fn has_payment_required_error(&self) -> bool {
38 self.payment_required_error
39 }
40
41 pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
42 self.payment_required_error = payment_required_error;
43 }
44
45 pub(crate) fn request_prediction(
46 &mut self,
47 EditPredictionModelInput {
48 buffer,
49 snapshot,
50 position,
51 events,
52 related_files,
53 debug_tx,
54 ..
55 }: EditPredictionModelInput,
56 credentials_provider: Arc<dyn CredentialsProvider>,
57 cx: &mut Context<EditPredictionStore>,
58 ) -> Task<Result<Option<EditPredictionResult>>> {
59 self.api_token.update(cx, |key_state, cx| {
60 _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
61 });
62 let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
63 return Task::ready(Ok(None));
64 };
65 let full_path: Arc<Path> = snapshot
66 .file()
67 .map(|file| file.full_path(cx))
68 .unwrap_or_else(|| "untitled".into())
69 .into();
70
71 let http_client = cx.http_client();
72 let cursor_point = position.to_point(&snapshot);
73 let request_start = cx.background_executor().now();
74 let active_buffer = buffer.clone();
75
76 let result = cx.background_spawn(async move {
77 let cursor_offset = cursor_point.to_offset(&snapshot);
78 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
79 crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
80
81 let related_files = zeta_prompt::filter_redundant_excerpts(
82 related_files,
83 full_path.as_ref(),
84 excerpt_point_range.start.row..excerpt_point_range.end.row,
85 );
86
87 let cursor_excerpt: Arc<str> = snapshot
88 .text_for_range(excerpt_point_range.clone())
89 .collect::<String>()
90 .into();
91 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
92 &snapshot,
93 cursor_offset,
94 &excerpt_offset_range,
95 );
96 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
97 &cursor_excerpt,
98 cursor_offset_in_excerpt,
99 &syntax_ranges,
100 );
101
102 let editable_offset_range = (excerpt_offset_range.start
103 + excerpt_ranges.editable_350.start)
104 ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
105
106 let inputs = zeta_prompt::ZetaPromptInput {
107 events,
108 related_files: Some(related_files),
109 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
110 - excerpt_offset_range.start,
111 cursor_path: full_path.clone(),
112 cursor_excerpt,
113 experiment: None,
114 excerpt_start_row: Some(excerpt_point_range.start.row),
115 excerpt_ranges,
116 syntax_ranges: Some(syntax_ranges),
117 active_buffer_diagnostics: vec![],
118 in_open_source_repo: false,
119 can_collect_data: false,
120 repo_url: None,
121 };
122
123 let prompt = build_prompt(&inputs);
124
125 if let Some(debug_tx) = &debug_tx {
126 debug_tx
127 .unbounded_send(DebugEvent::EditPredictionStarted(
128 EditPredictionStartedDebugEvent {
129 buffer: active_buffer.downgrade(),
130 prompt: Some(prompt.clone()),
131 position,
132 },
133 ))
134 .ok();
135 }
136
137 let request_body = open_ai::Request {
138 model: "mercury-coder".into(),
139 messages: vec![open_ai::RequestMessage::User {
140 content: open_ai::MessageContent::Plain(prompt),
141 }],
142 stream: false,
143 stream_options: None,
144 max_completion_tokens: None,
145 stop: vec![],
146 temperature: None,
147 tool_choice: None,
148 parallel_tool_calls: None,
149 tools: vec![],
150 prompt_cache_key: None,
151 reasoning_effort: None,
152 };
153
154 let buf = serde_json::to_vec(&request_body)?;
155 let body: AsyncBody = buf.into();
156
157 let request = http_client::Request::builder()
158 .uri(MERCURY_API_URL)
159 .header("Content-Type", "application/json")
160 .header("Authorization", format!("Bearer {}", api_token))
161 .header("Connection", "keep-alive")
162 .method(Method::POST)
163 .body(body)
164 .context("Failed to create request")?;
165
166 let mut response = http_client
167 .send(request)
168 .await
169 .context("Failed to send request")?;
170
171 let mut body: Vec<u8> = Vec::new();
172 response
173 .body_mut()
174 .read_to_end(&mut body)
175 .await
176 .context("Failed to read response body")?;
177
178 if !response.status().is_success() {
179 if response.status() == StatusCode::PAYMENT_REQUIRED {
180 anyhow::bail!(MercuryPaymentRequiredError(
181 mercury_payment_required_message(&body),
182 ));
183 }
184
185 anyhow::bail!(
186 "Request failed with status: {:?}\nBody: {}",
187 response.status(),
188 String::from_utf8_lossy(&body),
189 );
190 };
191
192 let mut response: open_ai::Response =
193 serde_json::from_slice(&body).context("Failed to parse response")?;
194
195 let id = mem::take(&mut response.id);
196 let response_str = text_from_response(response).unwrap_or_default();
197
198 if let Some(debug_tx) = &debug_tx {
199 debug_tx
200 .unbounded_send(DebugEvent::EditPredictionFinished(
201 EditPredictionFinishedDebugEvent {
202 buffer: active_buffer.downgrade(),
203 model_output: Some(response_str.clone()),
204 position,
205 },
206 ))
207 .ok();
208 }
209
210 let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
211 let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
212
213 let mut edits = Vec::new();
214 const NO_PREDICTION_OUTPUT: &str = "None";
215
216 if response_str != NO_PREDICTION_OUTPUT {
217 let old_text = snapshot
218 .text_for_range(editable_offset_range.clone())
219 .collect::<String>();
220 edits = compute_edits(
221 old_text,
222 &response_str,
223 editable_offset_range.start,
224 &snapshot,
225 );
226 }
227
228 anyhow::Ok((id, edits, snapshot, inputs))
229 });
230
231 cx.spawn(async move |ep_store, cx| {
232 let result = result.await.context("Mercury edit prediction failed");
233
234 let has_payment_required_error = result
235 .as_ref()
236 .err()
237 .is_some_and(is_mercury_payment_required_error);
238
239 ep_store.update(cx, |store, cx| {
240 store
241 .mercury
242 .set_payment_required_error(has_payment_required_error);
243 cx.notify();
244 })?;
245
246 let (id, edits, old_snapshot, inputs) = result?;
247 anyhow::Ok(Some(
248 EditPredictionResult::new(
249 EditPredictionId(id.into()),
250 &buffer,
251 &old_snapshot,
252 edits.into(),
253 None,
254 inputs,
255 None,
256 cx.background_executor().now() - request_start,
257 cx,
258 )
259 .await,
260 ))
261 })
262 }
263}
264
265fn build_prompt(inputs: &ZetaPromptInput) -> String {
266 const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
267 const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
268 const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
269 const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
270 const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
271 const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
272 const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
273 const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
274 const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
275 const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
276 const CURSOR_TAG: &str = "<|cursor|>";
277 const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
278 const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
279
280 let mut prompt = String::new();
281
282 push_delimited(
283 &mut prompt,
284 RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
285 |prompt| {
286 for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
287 for related_excerpt in &related_file.excerpts {
288 push_delimited(
289 prompt,
290 RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
291 |prompt| {
292 prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
293 prompt.push_str(related_file.path.to_string_lossy().as_ref());
294 prompt.push('\n');
295 prompt.push_str(related_excerpt.text.as_ref());
296 },
297 );
298 }
299 }
300 },
301 );
302
303 push_delimited(
304 &mut prompt,
305 CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
306 |prompt| {
307 prompt.push_str(CURRENT_FILE_PATH_PREFIX);
308 prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
309 prompt.push('\n');
310
311 let editable_range = &inputs.excerpt_ranges.editable_350;
312 prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
313 push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
314 prompt.push_str(
315 &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
316 );
317 prompt.push_str(CURSOR_TAG);
318 prompt.push_str(
319 &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
320 );
321 });
322 prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
323 },
324 );
325
326 push_delimited(
327 &mut prompt,
328 EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
329 |prompt| {
330 for event in inputs.events.iter() {
331 zeta_prompt::write_event(prompt, &event);
332 }
333 },
334 );
335
336 prompt
337}
338
/// Appends one delimited section to `prompt`.
///
/// The section is `delimiters.start`, then whatever `cb` writes, then a newline
/// and `delimiters.end`. The newline is always inserted, so a body that already
/// ends in `\n` leaves a blank line before the closing delimiter.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;
    *prompt += opening;
    cb(prompt);
    prompt.push('\n');
    *prompt += closing;
}
345
/// URL that identifies the Mercury credential in `ApiKeyState`. It is passed to
/// `ApiKeyState::new`, `load_if_needed`, and `key`, and is the same URL as the
/// completions endpoint.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
/// Credential username for the Mercury API token. It is not referenced in this
/// file; presumably it is used where the token is stored (NOTE(review): confirm).
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
349
/// Error produced when the Mercury API answers with HTTP 402 (Payment Required).
///
/// It wraps the message extracted from the response body and displays it
/// verbatim. `is_mercury_payment_required_error` detects this type, so that
/// `Mercury::request_prediction` can set the store's payment-required flag.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
struct MercuryPaymentRequiredError(SharedString);
353
/// Shape of the JSON error body this module expects from the Mercury API,
/// i.e. `{"error": {"message": "..."}}`. Only the message is read; serde
/// ignores any other fields.
#[derive(Deserialize)]
struct MercuryErrorResponse {
    error: MercuryErrorMessage,
}

/// The `error` object nested in [`MercuryErrorResponse`].
#[derive(Deserialize)]
struct MercuryErrorMessage {
    message: String,
}
363
364fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
365 error
366 .downcast_ref::<MercuryPaymentRequiredError>()
367 .is_some()
368}
369
370fn mercury_payment_required_message(body: &[u8]) -> SharedString {
371 serde_json::from_slice::<MercuryErrorResponse>(body)
372 .map(|response| response.error.message.into())
373 .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
374}
375
/// Environment variable (`MERCURY_AI_TOKEN`) handed to `ApiKeyState::new` as the
/// env-var source for the Mercury API token.
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
377
/// App-global holder for the shared Mercury key-state entity, so every call to
/// `mercury_api_token` returns the same entity.
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

impl Global for GlobalMercuryApiKey {}
381
382pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
383 if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
384 return global.0.clone();
385 }
386 let entity =
387 cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
388 cx.set_global(GlobalMercuryApiKey(entity.clone()));
389 entity
390}
391
392pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
393 let credentials_provider = global_credentials_provider(cx);
394 mercury_api_token(cx).update(cx, |key_state, cx| {
395 key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
396 })
397}
398
/// Endpoint that `send_feedback` POSTs user reactions to Mercury predictions to.
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
400
/// User reaction to a Mercury prediction, as reported to the feedback API.
/// Serialized in snake_case: `"accept"`, `"reject"`, `"ignore"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    /// Sent by `edit_prediction_accepted`.
    Accept,
    /// Sent for a shown prediction rejected with `EditPredictionRejectReason::Rejected`.
    Reject,
    /// Sent for a shown prediction rejected with `EditPredictionRejectReason::Discarded`.
    Ignore,
}
408
/// JSON body POSTed to `FEEDBACK_API_URL`.
#[derive(Serialize)]
struct FeedbackRequest {
    /// The prediction's id, i.e. the `id` of the Mercury completion response.
    request_id: SharedString,
    /// Set to `"zed"` by `send_feedback`.
    provider_name: &'static str,
    /// What the user did with the prediction.
    user_action: MercuryUserAction,
    /// The app version (`AppVersion::global`) rendered as a string.
    provider_version: String,
}
416
417pub(crate) fn edit_prediction_accepted(
418 prediction_id: EditPredictionId,
419 http_client: Arc<dyn HttpClient>,
420 cx: &App,
421) {
422 send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
423}
424
425pub(crate) fn edit_prediction_rejected(
426 prediction_id: EditPredictionId,
427 was_shown: bool,
428 reason: EditPredictionRejectReason,
429 http_client: Arc<dyn HttpClient>,
430 cx: &App,
431) {
432 if !was_shown {
433 return;
434 }
435 let action = match reason {
436 EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
437 EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
438 _ => return,
439 };
440 send_feedback(prediction_id, action, http_client, cx);
441}
442
443fn send_feedback(
444 prediction_id: EditPredictionId,
445 action: MercuryUserAction,
446 http_client: Arc<dyn HttpClient>,
447 cx: &App,
448) {
449 let request_id = prediction_id.0;
450 let app_version = AppVersion::global(cx);
451 cx.background_spawn(async move {
452 let body = FeedbackRequest {
453 request_id,
454 provider_name: "zed",
455 user_action: action,
456 provider_version: app_version.to_string(),
457 };
458
459 let request = http_client::Request::builder()
460 .uri(FEEDBACK_API_URL)
461 .method(Method::POST)
462 .header("Content-Type", "application/json")
463 .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
464
465 let response = http_client.send(request).await?;
466 if !response.status().is_success() {
467 anyhow::bail!("Feedback API returned status: {}", response.status());
468 }
469
470 log::debug!(
471 "Mercury feedback sent: request_id={}, action={:?}",
472 body.request_id,
473 body.user_action
474 );
475
476 anyhow::Ok(())
477 })
478 .detach_and_log_err(cx);
479}