mercury.rs

  1use crate::{
  2    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  3    EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
  4    prediction::EditPredictionResult, zeta::compute_edits,
  5};
  6use anyhow::{Context as _, Result};
  7use cloud_llm_client::EditPredictionRejectReason;
  8use credentials_provider::CredentialsProvider;
  9use futures::AsyncReadExt as _;
 10use gpui::{
 11    App, AppContext as _, Context, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
 13};
 14use language::{ToOffset, ToPoint as _};
 15use language_model::{ApiKeyState, EnvVar, env_var};
 16use release_channel::AppVersion;
 17use serde::{Deserialize, Serialize};
 18use std::{mem, ops::Range, path::Path, sync::Arc};
 19use zeta_prompt::ZetaPromptInput;
 20
/// Endpoint that `Mercury::request_prediction` POSTs edit prediction requests to.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 22
/// State for the Mercury edit prediction provider.
pub struct Mercury {
    /// API key state read by `request_prediction` to authorize requests.
    pub api_token: Entity<ApiKeyState>,
    /// Updated after every completed prediction request: `true` when that request failed
    /// with a payment-required error, `false` otherwise. Also writable through
    /// `set_payment_required_error`.
    payment_required_error: bool,
}
 27
 28impl Mercury {
 29    pub fn new(cx: &mut App) -> Self {
 30        Mercury {
 31            api_token: mercury_api_token(cx),
 32            payment_required_error: false,
 33        }
 34    }
 35
 36    pub fn has_payment_required_error(&self) -> bool {
 37        self.payment_required_error
 38    }
 39
 40    pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
 41        self.payment_required_error = payment_required_error;
 42    }
 43
 44    pub(crate) fn request_prediction(
 45        &mut self,
 46        EditPredictionModelInput {
 47            buffer,
 48            snapshot,
 49            position,
 50            events,
 51            related_files,
 52            debug_tx,
 53            ..
 54        }: EditPredictionModelInput,
 55        credentials_provider: Arc<dyn CredentialsProvider>,
 56        cx: &mut Context<EditPredictionStore>,
 57    ) -> Task<Result<Option<EditPredictionResult>>> {
 58        self.api_token.update(cx, |key_state, cx| {
 59            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
 60        });
 61        let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
 62            return Task::ready(Ok(None));
 63        };
 64        let full_path: Arc<Path> = snapshot
 65            .file()
 66            .map(|file| file.full_path(cx))
 67            .unwrap_or_else(|| "untitled".into())
 68            .into();
 69
 70        let http_client = cx.http_client();
 71        let cursor_point = position.to_point(&snapshot);
 72        let request_start = cx.background_executor().now();
 73        let active_buffer = buffer.clone();
 74
 75        let result = cx.background_spawn(async move {
 76            let cursor_offset = cursor_point.to_offset(&snapshot);
 77            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
 78                crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
 79
 80            let related_files = zeta_prompt::filter_redundant_excerpts(
 81                related_files,
 82                full_path.as_ref(),
 83                excerpt_point_range.start.row..excerpt_point_range.end.row,
 84            );
 85
 86            let cursor_excerpt: Arc<str> = snapshot
 87                .text_for_range(excerpt_point_range.clone())
 88                .collect::<String>()
 89                .into();
 90            let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
 91                &snapshot,
 92                cursor_offset,
 93                &excerpt_offset_range,
 94            );
 95            let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
 96                &cursor_excerpt,
 97                cursor_offset_in_excerpt,
 98                &syntax_ranges,
 99            );
100
101            let editable_offset_range = (excerpt_offset_range.start
102                + excerpt_ranges.editable_350.start)
103                ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
104
105            let inputs = zeta_prompt::ZetaPromptInput {
106                events,
107                related_files: Some(related_files),
108                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
109                    - excerpt_offset_range.start,
110                cursor_path: full_path.clone(),
111                cursor_excerpt,
112                experiment: None,
113                excerpt_start_row: Some(excerpt_point_range.start.row),
114                excerpt_ranges,
115                syntax_ranges: Some(syntax_ranges),
116                active_buffer_diagnostics: vec![],
117                in_open_source_repo: false,
118                can_collect_data: false,
119                repo_url: None,
120            };
121
122            let prompt = build_prompt(&inputs);
123
124            if let Some(debug_tx) = &debug_tx {
125                debug_tx
126                    .unbounded_send(DebugEvent::EditPredictionStarted(
127                        EditPredictionStartedDebugEvent {
128                            buffer: active_buffer.downgrade(),
129                            prompt: Some(prompt.clone()),
130                            position,
131                        },
132                    ))
133                    .ok();
134            }
135
136            let request_body = open_ai::Request {
137                model: "mercury-coder".into(),
138                messages: vec![open_ai::RequestMessage::User {
139                    content: open_ai::MessageContent::Plain(prompt),
140                }],
141                stream: false,
142                stream_options: None,
143                max_completion_tokens: None,
144                stop: vec![],
145                temperature: None,
146                tool_choice: None,
147                parallel_tool_calls: None,
148                tools: vec![],
149                prompt_cache_key: None,
150                reasoning_effort: None,
151            };
152
153            let buf = serde_json::to_vec(&request_body)?;
154            let body: AsyncBody = buf.into();
155
156            let request = http_client::Request::builder()
157                .uri(MERCURY_API_URL)
158                .header("Content-Type", "application/json")
159                .header("Authorization", format!("Bearer {}", api_token))
160                .header("Connection", "keep-alive")
161                .method(Method::POST)
162                .body(body)
163                .context("Failed to create request")?;
164
165            let mut response = http_client
166                .send(request)
167                .await
168                .context("Failed to send request")?;
169
170            let mut body: Vec<u8> = Vec::new();
171            response
172                .body_mut()
173                .read_to_end(&mut body)
174                .await
175                .context("Failed to read response body")?;
176
177            if !response.status().is_success() {
178                if response.status() == StatusCode::PAYMENT_REQUIRED {
179                    anyhow::bail!(MercuryPaymentRequiredError(
180                        mercury_payment_required_message(&body),
181                    ));
182                }
183
184                anyhow::bail!(
185                    "Request failed with status: {:?}\nBody: {}",
186                    response.status(),
187                    String::from_utf8_lossy(&body),
188                );
189            };
190
191            let mut response: open_ai::Response =
192                serde_json::from_slice(&body).context("Failed to parse response")?;
193
194            let id = mem::take(&mut response.id);
195            let response_str = text_from_response(response).unwrap_or_default();
196
197            if let Some(debug_tx) = &debug_tx {
198                debug_tx
199                    .unbounded_send(DebugEvent::EditPredictionFinished(
200                        EditPredictionFinishedDebugEvent {
201                            buffer: active_buffer.downgrade(),
202                            model_output: Some(response_str.clone()),
203                            position,
204                        },
205                    ))
206                    .ok();
207            }
208
209            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
210            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
211
212            let mut edits = Vec::new();
213            const NO_PREDICTION_OUTPUT: &str = "None";
214
215            if response_str != NO_PREDICTION_OUTPUT {
216                let old_text = snapshot
217                    .text_for_range(editable_offset_range.clone())
218                    .collect::<String>();
219                edits = compute_edits(
220                    old_text,
221                    &response_str,
222                    editable_offset_range.start,
223                    &snapshot,
224                );
225            }
226
227            anyhow::Ok((id, edits, snapshot, inputs))
228        });
229
230        cx.spawn(async move |ep_store, cx| {
231            let result = result.await.context("Mercury edit prediction failed");
232
233            let has_payment_required_error = result
234                .as_ref()
235                .err()
236                .is_some_and(is_mercury_payment_required_error);
237
238            ep_store.update(cx, |store, cx| {
239                store
240                    .mercury
241                    .set_payment_required_error(has_payment_required_error);
242                cx.notify();
243            })?;
244
245            let (id, edits, old_snapshot, inputs) = result?;
246            anyhow::Ok(Some(
247                EditPredictionResult::new(
248                    EditPredictionId(id.into()),
249                    &buffer,
250                    &old_snapshot,
251                    edits.into(),
252                    None,
253                    inputs,
254                    None,
255                    cx.background_executor().now() - request_start,
256                    cx,
257                )
258                .await,
259            ))
260        })
261    }
262}
263
264fn build_prompt(inputs: &ZetaPromptInput) -> String {
265    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
266    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
267    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
268    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
269    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
270    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
271    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
272    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
273    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
274    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
275    const CURSOR_TAG: &str = "<|cursor|>";
276    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
277    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
278
279    let mut prompt = String::new();
280
281    push_delimited(
282        &mut prompt,
283        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
284        |prompt| {
285            for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
286                for related_excerpt in &related_file.excerpts {
287                    push_delimited(
288                        prompt,
289                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
290                        |prompt| {
291                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
292                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
293                            prompt.push('\n');
294                            prompt.push_str(related_excerpt.text.as_ref());
295                        },
296                    );
297                }
298            }
299        },
300    );
301
302    push_delimited(
303        &mut prompt,
304        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
305        |prompt| {
306            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
307            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
308            prompt.push('\n');
309
310            let editable_range = &inputs.excerpt_ranges.editable_350;
311            prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
312            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
313                prompt.push_str(
314                    &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
315                );
316                prompt.push_str(CURSOR_TAG);
317                prompt.push_str(
318                    &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
319                );
320            });
321            prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
322        },
323    );
324
325    push_delimited(
326        &mut prompt,
327        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
328        |prompt| {
329            for event in inputs.events.iter() {
330                zeta_prompt::write_event(prompt, &event);
331            }
332        },
333    );
334
335    prompt
336}
337
/// Appends `delimiters.start`, then whatever `cb` writes, then a newline followed by
/// `delimiters.end` to `prompt`.
///
/// The newline is always appended after the callback's output, even when the callback
/// writes nothing.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;

    prompt.push_str(opening);
    cb(prompt);
    prompt.extend(["\n", closing]);
}
344
/// URL under which the Mercury API key is loaded and looked up in `ApiKeyState`.
/// It has the same value as the request endpoint.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
/// NOTE(review): presumably the username stored alongside the API key credential; it is
/// not referenced elsewhere in this file — confirm against callers.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
348
/// Error raised when the Mercury API answers with a payment-required status.
/// The wrapped string is the message produced by `mercury_payment_required_message`,
/// and is what the error displays as.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
struct MercuryPaymentRequiredError(SharedString);
352
/// JSON error body of the form `{"error": {"message": ...}}`, deserialized by
/// `mercury_payment_required_message`.
#[derive(Deserialize)]
struct MercuryErrorResponse {
    error: MercuryErrorMessage,
}
357
/// The `error` object of a `MercuryErrorResponse`.
#[derive(Deserialize)]
struct MercuryErrorMessage {
    /// Error text taken from the response body.
    message: String,
}
362
363fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
364    error
365        .downcast_ref::<MercuryPaymentRequiredError>()
366        .is_some()
367}
368
369fn mercury_payment_required_message(body: &[u8]) -> SharedString {
370    serde_json::from_slice::<MercuryErrorResponse>(body)
371        .map(|response| response.error.message.into())
372        .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
373}
374
/// The `MERCURY_AI_TOKEN` environment variable, handed to `ApiKeyState::new` when the
/// shared key state is created.
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
376
/// App global holding the shared Mercury API key entity; see `mercury_api_token`.
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

impl Global for GlobalMercuryApiKey {}
380
381pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
382    if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
383        return global.0.clone();
384    }
385    let entity =
386        cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
387    cx.set_global(GlobalMercuryApiKey(entity.clone()));
388    entity
389}
390
391pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
392    let credentials_provider = zed_credentials_provider::global(cx);
393    mercury_api_token(cx).update(cx, |key_state, cx| {
394        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
395    })
396}
397
/// Endpoint that `send_feedback` POSTs prediction feedback to.
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
399
/// What the user did with a prediction, as reported in a `FeedbackRequest`.
/// Variants serialize in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    /// Sent by `edit_prediction_accepted`.
    Accept,
    /// Sent for a shown prediction rejected with `EditPredictionRejectReason::Rejected`.
    Reject,
    /// Sent for a shown prediction rejected with `EditPredictionRejectReason::Discarded`.
    Ignore,
}
407
/// JSON body posted to the feedback endpoint by `send_feedback`.
#[derive(Serialize)]
struct FeedbackRequest {
    /// The inner value of the prediction's `EditPredictionId`.
    request_id: SharedString,
    /// Always `"zed"` in `send_feedback`.
    provider_name: &'static str,
    /// The action being reported.
    user_action: MercuryUserAction,
    /// The app version, rendered as a string.
    provider_version: String,
}
415
416pub(crate) fn edit_prediction_accepted(
417    prediction_id: EditPredictionId,
418    http_client: Arc<dyn HttpClient>,
419    cx: &App,
420) {
421    send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
422}
423
424pub(crate) fn edit_prediction_rejected(
425    prediction_id: EditPredictionId,
426    was_shown: bool,
427    reason: EditPredictionRejectReason,
428    http_client: Arc<dyn HttpClient>,
429    cx: &App,
430) {
431    if !was_shown {
432        return;
433    }
434    let action = match reason {
435        EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
436        EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
437        _ => return,
438    };
439    send_feedback(prediction_id, action, http_client, cx);
440}
441
442fn send_feedback(
443    prediction_id: EditPredictionId,
444    action: MercuryUserAction,
445    http_client: Arc<dyn HttpClient>,
446    cx: &App,
447) {
448    let request_id = prediction_id.0;
449    let app_version = AppVersion::global(cx);
450    cx.background_spawn(async move {
451        let body = FeedbackRequest {
452            request_id,
453            provider_name: "zed",
454            user_action: action,
455            provider_version: app_version.to_string(),
456        };
457
458        let request = http_client::Request::builder()
459            .uri(FEEDBACK_API_URL)
460            .method(Method::POST)
461            .header("Content-Type", "application/json")
462            .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
463
464        let response = http_client.send(request).await?;
465        if !response.status().is_success() {
466            anyhow::bail!("Feedback API returned status: {}", response.status());
467        }
468
469        log::debug!(
470            "Mercury feedback sent: request_id={}, action={:?}",
471            body.request_id,
472            body.user_action
473        );
474
475        anyhow::Ok(())
476    })
477    .detach_and_log_err(cx);
478}