1use crate::{
2 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
3 EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
4 prediction::EditPredictionResult, zeta::compute_edits,
5};
6use anyhow::{Context as _, Result};
7use cloud_llm_client::EditPredictionRejectReason;
8use futures::AsyncReadExt as _;
9use gpui::{
10 App, AppContext as _, Context, Entity, Global, SharedString, Task,
11 http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
12};
13use language::{ToOffset, ToPoint as _};
14use language_model::{ApiKeyState, EnvVar, env_var};
15use release_channel::AppVersion;
16use serde::{Deserialize, Serialize};
17use std::{mem, ops::Range, path::Path, sync::Arc};
18use zeta_prompt::ZetaPromptInput;
19
/// Endpoint that Mercury edit prediction requests are POSTed to.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";

/// Edit prediction provider backed by the Inception Labs Mercury API.
pub struct Mercury {
    /// Shared Mercury API key state, obtained from [`mercury_api_token`].
    pub api_token: Entity<ApiKeyState>,
    /// Set after each prediction request to whether it failed with HTTP 402 Payment Required.
    payment_required_error: bool,
}
26
27impl Mercury {
28 pub fn new(cx: &mut App) -> Self {
29 Mercury {
30 api_token: mercury_api_token(cx),
31 payment_required_error: false,
32 }
33 }
34
35 pub fn has_payment_required_error(&self) -> bool {
36 self.payment_required_error
37 }
38
39 pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
40 self.payment_required_error = payment_required_error;
41 }
42
43 pub(crate) fn request_prediction(
44 &mut self,
45 EditPredictionModelInput {
46 buffer,
47 snapshot,
48 position,
49 events,
50 related_files,
51 debug_tx,
52 ..
53 }: EditPredictionModelInput,
54 cx: &mut Context<EditPredictionStore>,
55 ) -> Task<Result<Option<EditPredictionResult>>> {
56 self.api_token.update(cx, |key_state, cx| {
57 _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
58 });
59 let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
60 return Task::ready(Ok(None));
61 };
62 let full_path: Arc<Path> = snapshot
63 .file()
64 .map(|file| file.full_path(cx))
65 .unwrap_or_else(|| "untitled".into())
66 .into();
67
68 let http_client = cx.http_client();
69 let cursor_point = position.to_point(&snapshot);
70 let request_start = cx.background_executor().now();
71 let active_buffer = buffer.clone();
72
73 let result = cx.background_spawn(async move {
74 let cursor_offset = cursor_point.to_offset(&snapshot);
75 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
76 crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
77
78 let related_files = zeta_prompt::filter_redundant_excerpts(
79 related_files,
80 full_path.as_ref(),
81 excerpt_point_range.start.row..excerpt_point_range.end.row,
82 );
83
84 let cursor_excerpt: Arc<str> = snapshot
85 .text_for_range(excerpt_point_range.clone())
86 .collect::<String>()
87 .into();
88 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
89 &snapshot,
90 cursor_offset,
91 &excerpt_offset_range,
92 );
93 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
94 &cursor_excerpt,
95 cursor_offset_in_excerpt,
96 &syntax_ranges,
97 );
98
99 let editable_offset_range = (excerpt_offset_range.start
100 + excerpt_ranges.editable_350.start)
101 ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
102
103 let inputs = zeta_prompt::ZetaPromptInput {
104 events,
105 related_files: Some(related_files),
106 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
107 - excerpt_offset_range.start,
108 cursor_path: full_path.clone(),
109 cursor_excerpt,
110 experiment: None,
111 excerpt_start_row: Some(excerpt_point_range.start.row),
112 excerpt_ranges,
113 syntax_ranges: Some(syntax_ranges),
114 active_buffer_diagnostics: vec![],
115 in_open_source_repo: false,
116 can_collect_data: false,
117 repo_url: None,
118 };
119
120 let prompt = build_prompt(&inputs);
121
122 if let Some(debug_tx) = &debug_tx {
123 debug_tx
124 .unbounded_send(DebugEvent::EditPredictionStarted(
125 EditPredictionStartedDebugEvent {
126 buffer: active_buffer.downgrade(),
127 prompt: Some(prompt.clone()),
128 position,
129 },
130 ))
131 .ok();
132 }
133
134 let request_body = open_ai::Request {
135 model: "mercury-coder".into(),
136 messages: vec![open_ai::RequestMessage::User {
137 content: open_ai::MessageContent::Plain(prompt),
138 }],
139 stream: false,
140 stream_options: None,
141 max_completion_tokens: None,
142 stop: vec![],
143 temperature: None,
144 tool_choice: None,
145 parallel_tool_calls: None,
146 tools: vec![],
147 prompt_cache_key: None,
148 reasoning_effort: None,
149 };
150
151 let buf = serde_json::to_vec(&request_body)?;
152 let body: AsyncBody = buf.into();
153
154 let request = http_client::Request::builder()
155 .uri(MERCURY_API_URL)
156 .header("Content-Type", "application/json")
157 .header("Authorization", format!("Bearer {}", api_token))
158 .header("Connection", "keep-alive")
159 .method(Method::POST)
160 .body(body)
161 .context("Failed to create request")?;
162
163 let mut response = http_client
164 .send(request)
165 .await
166 .context("Failed to send request")?;
167
168 let mut body: Vec<u8> = Vec::new();
169 response
170 .body_mut()
171 .read_to_end(&mut body)
172 .await
173 .context("Failed to read response body")?;
174
175 if !response.status().is_success() {
176 if response.status() == StatusCode::PAYMENT_REQUIRED {
177 anyhow::bail!(MercuryPaymentRequiredError(
178 mercury_payment_required_message(&body),
179 ));
180 }
181
182 anyhow::bail!(
183 "Request failed with status: {:?}\nBody: {}",
184 response.status(),
185 String::from_utf8_lossy(&body),
186 );
187 };
188
189 let mut response: open_ai::Response =
190 serde_json::from_slice(&body).context("Failed to parse response")?;
191
192 let id = mem::take(&mut response.id);
193 let response_str = text_from_response(response).unwrap_or_default();
194
195 if let Some(debug_tx) = &debug_tx {
196 debug_tx
197 .unbounded_send(DebugEvent::EditPredictionFinished(
198 EditPredictionFinishedDebugEvent {
199 buffer: active_buffer.downgrade(),
200 model_output: Some(response_str.clone()),
201 position,
202 },
203 ))
204 .ok();
205 }
206
207 let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
208 let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
209
210 let mut edits = Vec::new();
211 const NO_PREDICTION_OUTPUT: &str = "None";
212
213 if response_str != NO_PREDICTION_OUTPUT {
214 let old_text = snapshot
215 .text_for_range(editable_offset_range.clone())
216 .collect::<String>();
217 edits = compute_edits(
218 old_text,
219 &response_str,
220 editable_offset_range.start,
221 &snapshot,
222 );
223 }
224
225 anyhow::Ok((id, edits, snapshot, inputs))
226 });
227
228 cx.spawn(async move |ep_store, cx| {
229 let result = result.await.context("Mercury edit prediction failed");
230
231 let has_payment_required_error = result
232 .as_ref()
233 .err()
234 .is_some_and(is_mercury_payment_required_error);
235
236 ep_store.update(cx, |store, cx| {
237 store
238 .mercury
239 .set_payment_required_error(has_payment_required_error);
240 cx.notify();
241 })?;
242
243 let (id, edits, old_snapshot, inputs) = result?;
244 anyhow::Ok(Some(
245 EditPredictionResult::new(
246 EditPredictionId(id.into()),
247 &buffer,
248 &old_snapshot,
249 edits.into(),
250 None,
251 inputs,
252 None,
253 cx.background_executor().now() - request_start,
254 cx,
255 )
256 .await,
257 ))
258 })
259 }
260}
261
262fn build_prompt(inputs: &ZetaPromptInput) -> String {
263 const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
264 const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
265 const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
266 const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
267 const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
268 const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
269 const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
270 const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
271 const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
272 const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
273 const CURSOR_TAG: &str = "<|cursor|>";
274 const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
275 const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
276
277 let mut prompt = String::new();
278
279 push_delimited(
280 &mut prompt,
281 RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
282 |prompt| {
283 for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
284 for related_excerpt in &related_file.excerpts {
285 push_delimited(
286 prompt,
287 RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
288 |prompt| {
289 prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
290 prompt.push_str(related_file.path.to_string_lossy().as_ref());
291 prompt.push('\n');
292 prompt.push_str(related_excerpt.text.as_ref());
293 },
294 );
295 }
296 }
297 },
298 );
299
300 push_delimited(
301 &mut prompt,
302 CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
303 |prompt| {
304 prompt.push_str(CURRENT_FILE_PATH_PREFIX);
305 prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
306 prompt.push('\n');
307
308 let editable_range = &inputs.excerpt_ranges.editable_350;
309 prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
310 push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
311 prompt.push_str(
312 &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
313 );
314 prompt.push_str(CURSOR_TAG);
315 prompt.push_str(
316 &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
317 );
318 });
319 prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
320 },
321 );
322
323 push_delimited(
324 &mut prompt,
325 EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
326 |prompt| {
327 for event in inputs.events.iter() {
328 zeta_prompt::write_event(prompt, &event);
329 }
330 },
331 );
332
333 prompt
334}
335
/// Appends `delimiters.start`, then whatever `cb` writes, then a newline and `delimiters.end`.
///
/// The newline before the closing delimiter is always written, even if `cb` appends nothing.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;
    prompt.push_str(opening);
    cb(&mut *prompt);
    prompt.extend(["\n", closing]);
}
342
/// URL under which the Mercury API key is handled by [`ApiKeyState`]: it is passed when the
/// key state is created, when loading is requested, and when the key is read back.
/// Currently the same string as `MERCURY_API_URL`.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
/// Username associated with the stored Mercury API key.
// NOTE(review): not referenced in this file; presumably consumed by the credential store — confirm.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
346
347#[derive(Debug, thiserror::Error)]
348#[error("{0}")]
349struct MercuryPaymentRequiredError(SharedString);
350
351#[derive(Deserialize)]
352struct MercuryErrorResponse {
353 error: MercuryErrorMessage,
354}
355
356#[derive(Deserialize)]
357struct MercuryErrorMessage {
358 message: String,
359}
360
361fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
362 error
363 .downcast_ref::<MercuryPaymentRequiredError>()
364 .is_some()
365}
366
367fn mercury_payment_required_message(body: &[u8]) -> SharedString {
368 serde_json::from_slice::<MercuryErrorResponse>(body)
369 .map(|response| response.error.message.into())
370 .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
371}
372
373pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
374
375struct GlobalMercuryApiKey(Entity<ApiKeyState>);
376
377impl Global for GlobalMercuryApiKey {}
378
379pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
380 if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
381 return global.0.clone();
382 }
383 let entity =
384 cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
385 cx.set_global(GlobalMercuryApiKey(entity.clone()));
386 entity
387}
388
389pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
390 mercury_api_token(cx).update(cx, |key_state, cx| {
391 key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
392 })
393}
394
/// Endpoint that user feedback on Mercury predictions is POSTed to.
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";

/// User reaction reported to the feedback endpoint; serialized as `"accept"`, `"reject"`,
/// or `"ignore"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    /// The user accepted the prediction.
    Accept,
    /// Sent for shown predictions rejected with `EditPredictionRejectReason::Rejected`.
    Reject,
    /// Sent for shown predictions rejected with `EditPredictionRejectReason::Discarded`.
    Ignore,
}

/// JSON body of a feedback request.
#[derive(Serialize)]
struct FeedbackRequest {
    /// The prediction's `EditPredictionId`; for Mercury predictions, the completion response id.
    request_id: SharedString,
    /// Name of the integrating client; set to `"zed"` when sending.
    provider_name: &'static str,
    /// What the user did with the prediction.
    user_action: MercuryUserAction,
    /// The app version, rendered as a string.
    provider_version: String,
}
412
413pub(crate) fn edit_prediction_accepted(
414 prediction_id: EditPredictionId,
415 http_client: Arc<dyn HttpClient>,
416 cx: &App,
417) {
418 send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
419}
420
421pub(crate) fn edit_prediction_rejected(
422 prediction_id: EditPredictionId,
423 was_shown: bool,
424 reason: EditPredictionRejectReason,
425 http_client: Arc<dyn HttpClient>,
426 cx: &App,
427) {
428 if !was_shown {
429 return;
430 }
431 let action = match reason {
432 EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
433 EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
434 _ => return,
435 };
436 send_feedback(prediction_id, action, http_client, cx);
437}
438
439fn send_feedback(
440 prediction_id: EditPredictionId,
441 action: MercuryUserAction,
442 http_client: Arc<dyn HttpClient>,
443 cx: &App,
444) {
445 let request_id = prediction_id.0;
446 let app_version = AppVersion::global(cx);
447 cx.background_spawn(async move {
448 let body = FeedbackRequest {
449 request_id,
450 provider_name: "zed",
451 user_action: action,
452 provider_version: app_version.to_string(),
453 };
454
455 let request = http_client::Request::builder()
456 .uri(FEEDBACK_API_URL)
457 .method(Method::POST)
458 .header("Content-Type", "application/json")
459 .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
460
461 let response = http_client.send(request).await?;
462 if !response.status().is_success() {
463 anyhow::bail!("Feedback API returned status: {}", response.status());
464 }
465
466 log::debug!(
467 "Mercury feedback sent: request_id={}, action={:?}",
468 body.request_id,
469 body.user_action
470 );
471
472 anyhow::Ok(())
473 })
474 .detach_and_log_err(cx);
475}