1use crate::{
2 DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
3 EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
4 prediction::EditPredictionResult, zeta::compute_edits,
5};
6use anyhow::{Context as _, Result};
7use cloud_llm_client::EditPredictionRejectReason;
8use credentials_provider::CredentialsProvider;
9use futures::AsyncReadExt as _;
10use gpui::{
11 App, AppContext as _, Context, Entity, Global, SharedString, Task,
12 http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
13};
14use language::{ToOffset, ToPoint as _};
15use language_model::{ApiKeyState, EnvVar, env_var};
16use release_channel::AppVersion;
17use serde::{Deserialize, Serialize};
18use std::{mem, ops::Range, path::Path, sync::Arc};
19use zeta_prompt::ZetaPromptInput;
20
/// Endpoint that edit-prediction requests are POSTed to (see `Mercury::request_prediction`).
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
22
/// Edit prediction provider backed by the Inception Labs Mercury API.
pub struct Mercury {
    /// Shared state holding the Mercury API key (obtained from `mercury_api_token`).
    pub api_token: Entity<ApiKeyState>,
    /// Whether the most recent prediction request failed with HTTP 402 Payment Required.
    payment_required_error: bool,
}
27
28impl Mercury {
29 pub fn new(cx: &mut App) -> Self {
30 Mercury {
31 api_token: mercury_api_token(cx),
32 payment_required_error: false,
33 }
34 }
35
36 pub fn has_payment_required_error(&self) -> bool {
37 self.payment_required_error
38 }
39
40 pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
41 self.payment_required_error = payment_required_error;
42 }
43
44 pub(crate) fn request_prediction(
45 &mut self,
46 EditPredictionModelInput {
47 buffer,
48 snapshot,
49 position,
50 events,
51 related_files,
52 debug_tx,
53 ..
54 }: EditPredictionModelInput,
55 credentials_provider: Arc<dyn CredentialsProvider>,
56 cx: &mut Context<EditPredictionStore>,
57 ) -> Task<Result<Option<EditPredictionResult>>> {
58 self.api_token.update(cx, |key_state, cx| {
59 _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
60 });
61 let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
62 return Task::ready(Ok(None));
63 };
64 let full_path: Arc<Path> = snapshot
65 .file()
66 .map(|file| file.full_path(cx))
67 .unwrap_or_else(|| "untitled".into())
68 .into();
69
70 let http_client = cx.http_client();
71 let cursor_point = position.to_point(&snapshot);
72 let request_start = cx.background_executor().now();
73 let active_buffer = buffer.clone();
74
75 let result = cx.background_spawn(async move {
76 let cursor_offset = cursor_point.to_offset(&snapshot);
77 let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
78 crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
79
80 let related_files = zeta_prompt::filter_redundant_excerpts(
81 related_files,
82 full_path.as_ref(),
83 excerpt_point_range.start.row..excerpt_point_range.end.row,
84 );
85
86 let cursor_excerpt: Arc<str> = snapshot
87 .text_for_range(excerpt_point_range.clone())
88 .collect::<String>()
89 .into();
90 let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
91 &snapshot,
92 cursor_offset,
93 &excerpt_offset_range,
94 );
95 let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
96 &cursor_excerpt,
97 cursor_offset_in_excerpt,
98 &syntax_ranges,
99 );
100
101 let editable_offset_range = (excerpt_offset_range.start
102 + excerpt_ranges.editable_350.start)
103 ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
104
105 let inputs = zeta_prompt::ZetaPromptInput {
106 events,
107 related_files: Some(related_files),
108 cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
109 - excerpt_offset_range.start,
110 cursor_path: full_path.clone(),
111 cursor_excerpt,
112 experiment: None,
113 excerpt_start_row: Some(excerpt_point_range.start.row),
114 excerpt_ranges,
115 syntax_ranges: Some(syntax_ranges),
116 active_buffer_diagnostics: vec![],
117 in_open_source_repo: false,
118 can_collect_data: false,
119 repo_url: None,
120 };
121
122 let prompt = build_prompt(&inputs);
123
124 if let Some(debug_tx) = &debug_tx {
125 debug_tx
126 .unbounded_send(DebugEvent::EditPredictionStarted(
127 EditPredictionStartedDebugEvent {
128 buffer: active_buffer.downgrade(),
129 prompt: Some(prompt.clone()),
130 position,
131 },
132 ))
133 .ok();
134 }
135
136 let request_body = open_ai::Request {
137 model: "mercury-coder".into(),
138 messages: vec![open_ai::RequestMessage::User {
139 content: open_ai::MessageContent::Plain(prompt),
140 }],
141 stream: false,
142 stream_options: None,
143 max_completion_tokens: None,
144 stop: vec![],
145 temperature: None,
146 tool_choice: None,
147 parallel_tool_calls: None,
148 tools: vec![],
149 prompt_cache_key: None,
150 reasoning_effort: None,
151 };
152
153 let buf = serde_json::to_vec(&request_body)?;
154 let body: AsyncBody = buf.into();
155
156 let request = http_client::Request::builder()
157 .uri(MERCURY_API_URL)
158 .header("Content-Type", "application/json")
159 .header("Authorization", format!("Bearer {}", api_token))
160 .header("Connection", "keep-alive")
161 .method(Method::POST)
162 .body(body)
163 .context("Failed to create request")?;
164
165 let mut response = http_client
166 .send(request)
167 .await
168 .context("Failed to send request")?;
169
170 let mut body: Vec<u8> = Vec::new();
171 response
172 .body_mut()
173 .read_to_end(&mut body)
174 .await
175 .context("Failed to read response body")?;
176
177 if !response.status().is_success() {
178 if response.status() == StatusCode::PAYMENT_REQUIRED {
179 anyhow::bail!(MercuryPaymentRequiredError(
180 mercury_payment_required_message(&body),
181 ));
182 }
183
184 anyhow::bail!(
185 "Request failed with status: {:?}\nBody: {}",
186 response.status(),
187 String::from_utf8_lossy(&body),
188 );
189 };
190
191 let mut response: open_ai::Response =
192 serde_json::from_slice(&body).context("Failed to parse response")?;
193
194 let id = mem::take(&mut response.id);
195 let response_str = text_from_response(response).unwrap_or_default();
196
197 if let Some(debug_tx) = &debug_tx {
198 debug_tx
199 .unbounded_send(DebugEvent::EditPredictionFinished(
200 EditPredictionFinishedDebugEvent {
201 buffer: active_buffer.downgrade(),
202 model_output: Some(response_str.clone()),
203 position,
204 },
205 ))
206 .ok();
207 }
208
209 let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
210 let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
211
212 let mut edits = Vec::new();
213 const NO_PREDICTION_OUTPUT: &str = "None";
214
215 if response_str != NO_PREDICTION_OUTPUT {
216 let old_text = snapshot
217 .text_for_range(editable_offset_range.clone())
218 .collect::<String>();
219 edits = compute_edits(
220 old_text,
221 &response_str,
222 editable_offset_range.start,
223 &snapshot,
224 );
225 }
226
227 anyhow::Ok((id, edits, snapshot, inputs))
228 });
229
230 cx.spawn(async move |ep_store, cx| {
231 let result = result.await.context("Mercury edit prediction failed");
232
233 let has_payment_required_error = result
234 .as_ref()
235 .err()
236 .is_some_and(is_mercury_payment_required_error);
237
238 ep_store.update(cx, |store, cx| {
239 store
240 .mercury
241 .set_payment_required_error(has_payment_required_error);
242 cx.notify();
243 })?;
244
245 let (id, edits, old_snapshot, inputs) = result?;
246 anyhow::Ok(Some(
247 EditPredictionResult::new(
248 EditPredictionId(id.into()),
249 &buffer,
250 &old_snapshot,
251 edits.into(),
252 None,
253 inputs,
254 None,
255 cx.background_executor().now() - request_start,
256 cx,
257 )
258 .await,
259 ))
260 })
261 }
262}
263
264fn build_prompt(inputs: &ZetaPromptInput) -> String {
265 const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
266 const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
267 const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
268 const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
269 const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
270 const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
271 const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
272 const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
273 const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
274 const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
275 const CURSOR_TAG: &str = "<|cursor|>";
276 const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
277 const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
278
279 let mut prompt = String::new();
280
281 push_delimited(
282 &mut prompt,
283 RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
284 |prompt| {
285 for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
286 for related_excerpt in &related_file.excerpts {
287 push_delimited(
288 prompt,
289 RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
290 |prompt| {
291 prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
292 prompt.push_str(related_file.path.to_string_lossy().as_ref());
293 prompt.push('\n');
294 prompt.push_str(related_excerpt.text.as_ref());
295 },
296 );
297 }
298 }
299 },
300 );
301
302 push_delimited(
303 &mut prompt,
304 CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
305 |prompt| {
306 prompt.push_str(CURRENT_FILE_PATH_PREFIX);
307 prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
308 prompt.push('\n');
309
310 let editable_range = &inputs.excerpt_ranges.editable_350;
311 prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
312 push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
313 prompt.push_str(
314 &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
315 );
316 prompt.push_str(CURSOR_TAG);
317 prompt.push_str(
318 &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
319 );
320 });
321 prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
322 },
323 );
324
325 push_delimited(
326 &mut prompt,
327 EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
328 |prompt| {
329 for event in inputs.events.iter() {
330 zeta_prompt::write_event(prompt, &event);
331 }
332 },
333 );
334
335 prompt
336}
337
/// Appends to `prompt` the opening delimiter, whatever `cb` writes, a newline, and then the
/// closing delimiter.
///
/// `delimiters.start` is the opening tag and `delimiters.end` is the closing tag.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;

    prompt.push_str(opening);
    cb(prompt);
    // A newline is always inserted, so the closing tag is preceded by a line break.
    prompt.push('\n');
    prompt.push_str(closing);
}
344
/// Key under which the Mercury API key is loaded and looked up in `ApiKeyState`.
///
/// It is currently the same string as `MERCURY_API_URL`.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
/// Credential username for the Mercury API token.
///
/// NOTE(review): not referenced in this file; presumably used by callers elsewhere — confirm.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
348
/// Error raised when the Mercury API responds with HTTP 402 Payment Required.
///
/// Holds the message extracted from the response body. Its `Display` output is exactly that
/// message.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
struct MercuryPaymentRequiredError(SharedString);
352
/// JSON error body of the form `{"error": {"message": "..."}}` returned by the Mercury API.
#[derive(Deserialize)]
struct MercuryErrorResponse {
    error: MercuryErrorMessage,
}

/// The `error` object nested inside a `MercuryErrorResponse`.
#[derive(Deserialize)]
struct MercuryErrorMessage {
    message: String,
}
362
363fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
364 error
365 .downcast_ref::<MercuryPaymentRequiredError>()
366 .is_some()
367}
368
369fn mercury_payment_required_message(body: &[u8]) -> SharedString {
370 serde_json::from_slice::<MercuryErrorResponse>(body)
371 .map(|response| response.error.message.into())
372 .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
373}
374
/// The `MERCURY_AI_TOKEN` environment variable, passed to `ApiKeyState::new` for the Mercury key.
///
/// Presumably this lets the key be supplied via the environment — confirm in `ApiKeyState`.
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
376
/// App-global wrapper around the single `ApiKeyState` entity.
///
/// `mercury_api_token` uses it, so every caller shares one entity.
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

impl Global for GlobalMercuryApiKey {}
380
381pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
382 if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
383 return global.0.clone();
384 }
385 let entity =
386 cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
387 cx.set_global(GlobalMercuryApiKey(entity.clone()));
388 entity
389}
390
391pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
392 let credentials_provider = zed_credentials_provider::global(cx);
393 mercury_api_token(cx).update(cx, |key_state, cx| {
394 key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
395 })
396}
397
/// Endpoint that user feedback on Mercury predictions is POSTed to (see `send_feedback`).
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
399
/// The user's reaction to a shown prediction, as reported to the feedback API.
///
/// Serialized in snake_case: `"accept"`, `"reject"`, `"ignore"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    Accept,
    Reject,
    Ignore,
}
407
/// JSON body POSTed to `FEEDBACK_API_URL`.
#[derive(Serialize)]
struct FeedbackRequest {
    /// Id of the prediction being reported.
    ///
    /// Presumably the Mercury response id that became the `EditPredictionId`.
    request_id: SharedString,
    /// Identifies the client; `send_feedback` always sets `"zed"`.
    provider_name: &'static str,
    user_action: MercuryUserAction,
    /// The Zed app version string (from `AppVersion::global`).
    provider_version: String,
}
415
416pub(crate) fn edit_prediction_accepted(
417 prediction_id: EditPredictionId,
418 http_client: Arc<dyn HttpClient>,
419 cx: &App,
420) {
421 send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
422}
423
424pub(crate) fn edit_prediction_rejected(
425 prediction_id: EditPredictionId,
426 was_shown: bool,
427 reason: EditPredictionRejectReason,
428 http_client: Arc<dyn HttpClient>,
429 cx: &App,
430) {
431 if !was_shown {
432 return;
433 }
434 let action = match reason {
435 EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
436 EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
437 _ => return,
438 };
439 send_feedback(prediction_id, action, http_client, cx);
440}
441
442fn send_feedback(
443 prediction_id: EditPredictionId,
444 action: MercuryUserAction,
445 http_client: Arc<dyn HttpClient>,
446 cx: &App,
447) {
448 let request_id = prediction_id.0;
449 let app_version = AppVersion::global(cx);
450 cx.background_spawn(async move {
451 let body = FeedbackRequest {
452 request_id,
453 provider_name: "zed",
454 user_action: action,
455 provider_version: app_version.to_string(),
456 };
457
458 let request = http_client::Request::builder()
459 .uri(FEEDBACK_API_URL)
460 .method(Method::POST)
461 .header("Content-Type", "application/json")
462 .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
463
464 let response = http_client.send(request).await?;
465 if !response.status().is_success() {
466 anyhow::bail!("Feedback API returned status: {}", response.status());
467 }
468
469 log::debug!(
470 "Mercury feedback sent: request_id={}, action={:?}",
471 body.request_id,
472 body.user_action
473 );
474
475 anyhow::Ok(())
476 })
477 .detach_and_log_err(cx);
478}