mercury.rs

  1use crate::{
  2    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  3    EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
  4    prediction::EditPredictionResult, zeta::compute_edits,
  5};
  6use anyhow::{Context as _, Result};
  7use cloud_llm_client::EditPredictionRejectReason;
  8use credentials_provider::CredentialsProvider;
  9use futures::AsyncReadExt as _;
 10use gpui::{
 11    App, AppContext as _, Context, Entity, Global, SharedString, Task,
 12    http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
 13};
 14use language::{ToOffset, ToPoint as _};
 15use language_model::{ApiKeyState, EnvVar, env_var};
 16use release_channel::AppVersion;
 17use serde::{Deserialize, Serialize};
 18use std::{mem, ops::Range, path::Path, sync::Arc};
 19use zed_credentials_provider::global as global_credentials_provider;
 20use zeta_prompt::ZetaPromptInput;
 21
 22const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 23
 24pub struct Mercury {
 25    pub api_token: Entity<ApiKeyState>,
 26    payment_required_error: bool,
 27}
 28
 29impl Mercury {
 30    pub fn new(cx: &mut App) -> Self {
 31        Mercury {
 32            api_token: mercury_api_token(cx),
 33            payment_required_error: false,
 34        }
 35    }
 36
 37    pub fn has_payment_required_error(&self) -> bool {
 38        self.payment_required_error
 39    }
 40
 41    pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
 42        self.payment_required_error = payment_required_error;
 43    }
 44
 45    pub(crate) fn request_prediction(
 46        &mut self,
 47        EditPredictionModelInput {
 48            buffer,
 49            snapshot,
 50            position,
 51            events,
 52            related_files,
 53            debug_tx,
 54            ..
 55        }: EditPredictionModelInput,
 56        credentials_provider: Arc<dyn CredentialsProvider>,
 57        cx: &mut Context<EditPredictionStore>,
 58    ) -> Task<Result<Option<EditPredictionResult>>> {
 59        self.api_token.update(cx, |key_state, cx| {
 60            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx);
 61        });
 62        let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
 63            return Task::ready(Ok(None));
 64        };
 65        let full_path: Arc<Path> = snapshot
 66            .file()
 67            .map(|file| file.full_path(cx))
 68            .unwrap_or_else(|| "untitled".into())
 69            .into();
 70
 71        let http_client = cx.http_client();
 72        let cursor_point = position.to_point(&snapshot);
 73        let request_start = cx.background_executor().now();
 74        let active_buffer = buffer.clone();
 75
 76        let result = cx.background_spawn(async move {
 77            let cursor_offset = cursor_point.to_offset(&snapshot);
 78            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
 79                crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
 80
 81            let related_files = zeta_prompt::filter_redundant_excerpts(
 82                related_files,
 83                full_path.as_ref(),
 84                excerpt_point_range.start.row..excerpt_point_range.end.row,
 85            );
 86
 87            let cursor_excerpt: Arc<str> = snapshot
 88                .text_for_range(excerpt_point_range.clone())
 89                .collect::<String>()
 90                .into();
 91            let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
 92                &snapshot,
 93                cursor_offset,
 94                &excerpt_offset_range,
 95            );
 96            let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
 97                &cursor_excerpt,
 98                cursor_offset_in_excerpt,
 99                &syntax_ranges,
100            );
101
102            let editable_offset_range = (excerpt_offset_range.start
103                + excerpt_ranges.editable_350.start)
104                ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
105
106            let inputs = zeta_prompt::ZetaPromptInput {
107                events,
108                related_files: Some(related_files),
109                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
110                    - excerpt_offset_range.start,
111                cursor_path: full_path.clone(),
112                cursor_excerpt,
113                experiment: None,
114                excerpt_start_row: Some(excerpt_point_range.start.row),
115                excerpt_ranges,
116                syntax_ranges: Some(syntax_ranges),
117                active_buffer_diagnostics: vec![],
118                in_open_source_repo: false,
119                can_collect_data: false,
120                repo_url: None,
121            };
122
123            let prompt = build_prompt(&inputs);
124
125            if let Some(debug_tx) = &debug_tx {
126                debug_tx
127                    .unbounded_send(DebugEvent::EditPredictionStarted(
128                        EditPredictionStartedDebugEvent {
129                            buffer: active_buffer.downgrade(),
130                            prompt: Some(prompt.clone()),
131                            position,
132                        },
133                    ))
134                    .ok();
135            }
136
137            let request_body = open_ai::Request {
138                model: "mercury-coder".into(),
139                messages: vec![open_ai::RequestMessage::User {
140                    content: open_ai::MessageContent::Plain(prompt),
141                }],
142                stream: false,
143                stream_options: None,
144                max_completion_tokens: None,
145                stop: vec![],
146                temperature: None,
147                tool_choice: None,
148                parallel_tool_calls: None,
149                tools: vec![],
150                prompt_cache_key: None,
151                reasoning_effort: None,
152            };
153
154            let buf = serde_json::to_vec(&request_body)?;
155            let body: AsyncBody = buf.into();
156
157            let request = http_client::Request::builder()
158                .uri(MERCURY_API_URL)
159                .header("Content-Type", "application/json")
160                .header("Authorization", format!("Bearer {}", api_token))
161                .header("Connection", "keep-alive")
162                .method(Method::POST)
163                .body(body)
164                .context("Failed to create request")?;
165
166            let mut response = http_client
167                .send(request)
168                .await
169                .context("Failed to send request")?;
170
171            let mut body: Vec<u8> = Vec::new();
172            response
173                .body_mut()
174                .read_to_end(&mut body)
175                .await
176                .context("Failed to read response body")?;
177
178            if !response.status().is_success() {
179                if response.status() == StatusCode::PAYMENT_REQUIRED {
180                    anyhow::bail!(MercuryPaymentRequiredError(
181                        mercury_payment_required_message(&body),
182                    ));
183                }
184
185                anyhow::bail!(
186                    "Request failed with status: {:?}\nBody: {}",
187                    response.status(),
188                    String::from_utf8_lossy(&body),
189                );
190            };
191
192            let mut response: open_ai::Response =
193                serde_json::from_slice(&body).context("Failed to parse response")?;
194
195            let id = mem::take(&mut response.id);
196            let response_str = text_from_response(response).unwrap_or_default();
197
198            if let Some(debug_tx) = &debug_tx {
199                debug_tx
200                    .unbounded_send(DebugEvent::EditPredictionFinished(
201                        EditPredictionFinishedDebugEvent {
202                            buffer: active_buffer.downgrade(),
203                            model_output: Some(response_str.clone()),
204                            position,
205                        },
206                    ))
207                    .ok();
208            }
209
210            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
211            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
212
213            let mut edits = Vec::new();
214            const NO_PREDICTION_OUTPUT: &str = "None";
215
216            if response_str != NO_PREDICTION_OUTPUT {
217                let old_text = snapshot
218                    .text_for_range(editable_offset_range.clone())
219                    .collect::<String>();
220                edits = compute_edits(
221                    old_text,
222                    &response_str,
223                    editable_offset_range.start,
224                    &snapshot,
225                );
226            }
227
228            anyhow::Ok((id, edits, snapshot, inputs))
229        });
230
231        cx.spawn(async move |ep_store, cx| {
232            let result = result.await.context("Mercury edit prediction failed");
233
234            let has_payment_required_error = result
235                .as_ref()
236                .err()
237                .is_some_and(is_mercury_payment_required_error);
238
239            ep_store.update(cx, |store, cx| {
240                store
241                    .mercury
242                    .set_payment_required_error(has_payment_required_error);
243                cx.notify();
244            })?;
245
246            let (id, edits, old_snapshot, inputs) = result?;
247            anyhow::Ok(Some(
248                EditPredictionResult::new(
249                    EditPredictionId(id.into()),
250                    &buffer,
251                    &old_snapshot,
252                    edits.into(),
253                    None,
254                    inputs,
255                    None,
256                    cx.background_executor().now() - request_start,
257                    cx,
258                )
259                .await,
260            ))
261        })
262    }
263}
264
265fn build_prompt(inputs: &ZetaPromptInput) -> String {
266    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
267    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
268    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
269    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
270    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
271    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
272    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
273    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
274    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
275    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
276    const CURSOR_TAG: &str = "<|cursor|>";
277    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
278    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
279
280    let mut prompt = String::new();
281
282    push_delimited(
283        &mut prompt,
284        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
285        |prompt| {
286            for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
287                for related_excerpt in &related_file.excerpts {
288                    push_delimited(
289                        prompt,
290                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
291                        |prompt| {
292                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
293                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
294                            prompt.push('\n');
295                            prompt.push_str(related_excerpt.text.as_ref());
296                        },
297                    );
298                }
299            }
300        },
301    );
302
303    push_delimited(
304        &mut prompt,
305        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
306        |prompt| {
307            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
308            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
309            prompt.push('\n');
310
311            let editable_range = &inputs.excerpt_ranges.editable_350;
312            prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
313            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
314                prompt.push_str(
315                    &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
316                );
317                prompt.push_str(CURSOR_TAG);
318                prompt.push_str(
319                    &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
320                );
321            });
322            prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
323        },
324    );
325
326    push_delimited(
327        &mut prompt,
328        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
329        |prompt| {
330            for event in inputs.events.iter() {
331                zeta_prompt::write_event(prompt, &event);
332            }
333        },
334    );
335
336    prompt
337}
338
/// Appends `delimiters.start`, then whatever `cb` writes, then a newline and
/// `delimiters.end` to `prompt`.
///
/// The newline is always inserted, so the closing delimiter starts on its own
/// line even when the section body does not end with one.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: open_tag,
        end: close_tag,
    } = delimiters;
    *prompt += open_tag;
    cb(prompt);
    *prompt += "\n";
    *prompt += close_tag;
}
345
346pub const MERCURY_CREDENTIALS_URL: SharedString =
347    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
348pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
349
350#[derive(Debug, thiserror::Error)]
351#[error("{0}")]
352struct MercuryPaymentRequiredError(SharedString);
353
354#[derive(Deserialize)]
355struct MercuryErrorResponse {
356    error: MercuryErrorMessage,
357}
358
359#[derive(Deserialize)]
360struct MercuryErrorMessage {
361    message: String,
362}
363
364fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
365    error
366        .downcast_ref::<MercuryPaymentRequiredError>()
367        .is_some()
368}
369
370fn mercury_payment_required_message(body: &[u8]) -> SharedString {
371    serde_json::from_slice::<MercuryErrorResponse>(body)
372        .map(|response| response.error.message.into())
373        .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
374}
375
376pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
377
378struct GlobalMercuryApiKey(Entity<ApiKeyState>);
379
380impl Global for GlobalMercuryApiKey {}
381
382pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
383    if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
384        return global.0.clone();
385    }
386    let entity =
387        cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
388    cx.set_global(GlobalMercuryApiKey(entity.clone()));
389    entity
390}
391
392pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
393    let credentials_provider = global_credentials_provider(cx);
394    mercury_api_token(cx).update(cx, |key_state, cx| {
395        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, credentials_provider, cx)
396    })
397}
398
399const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
400
401#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
402#[serde(rename_all = "snake_case")]
403enum MercuryUserAction {
404    Accept,
405    Reject,
406    Ignore,
407}
408
409#[derive(Serialize)]
410struct FeedbackRequest {
411    request_id: SharedString,
412    provider_name: &'static str,
413    user_action: MercuryUserAction,
414    provider_version: String,
415}
416
417pub(crate) fn edit_prediction_accepted(
418    prediction_id: EditPredictionId,
419    http_client: Arc<dyn HttpClient>,
420    cx: &App,
421) {
422    send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
423}
424
425pub(crate) fn edit_prediction_rejected(
426    prediction_id: EditPredictionId,
427    was_shown: bool,
428    reason: EditPredictionRejectReason,
429    http_client: Arc<dyn HttpClient>,
430    cx: &App,
431) {
432    if !was_shown {
433        return;
434    }
435    let action = match reason {
436        EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
437        EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
438        _ => return,
439    };
440    send_feedback(prediction_id, action, http_client, cx);
441}
442
443fn send_feedback(
444    prediction_id: EditPredictionId,
445    action: MercuryUserAction,
446    http_client: Arc<dyn HttpClient>,
447    cx: &App,
448) {
449    let request_id = prediction_id.0;
450    let app_version = AppVersion::global(cx);
451    cx.background_spawn(async move {
452        let body = FeedbackRequest {
453            request_id,
454            provider_name: "zed",
455            user_action: action,
456            provider_version: app_version.to_string(),
457        };
458
459        let request = http_client::Request::builder()
460            .uri(FEEDBACK_API_URL)
461            .method(Method::POST)
462            .header("Content-Type", "application/json")
463            .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
464
465        let response = http_client.send(request).await?;
466        if !response.status().is_success() {
467            anyhow::bail!("Feedback API returned status: {}", response.status());
468        }
469
470        log::debug!(
471            "Mercury feedback sent: request_id={}, action={:?}",
472            body.request_id,
473            body.user_action
474        );
475
476        anyhow::Ok(())
477    })
478    .detach_and_log_err(cx);
479}