mercury.rs

  1use crate::{
  2    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  3    EditPredictionStartedDebugEvent, EditPredictionStore, open_ai_response::text_from_response,
  4    prediction::EditPredictionResult, zeta::compute_edits,
  5};
  6use anyhow::{Context as _, Result};
  7use cloud_llm_client::EditPredictionRejectReason;
  8use futures::AsyncReadExt as _;
  9use gpui::{
 10    App, AppContext as _, Context, Entity, Global, SharedString, Task,
 11    http_client::{self, AsyncBody, HttpClient, Method, StatusCode},
 12};
 13use language::{ToOffset, ToPoint as _};
 14use language_model::{ApiKeyState, EnvVar, env_var};
 15use release_channel::AppVersion;
 16use serde::{Deserialize, Serialize};
 17use std::{mem, ops::Range, path::Path, sync::Arc};
 18use zeta_prompt::ZetaPromptInput;
 19
/// Endpoint that `Mercury::request_prediction` POSTs its prediction requests to.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 21
/// State held for the Mercury edit-prediction provider.
pub struct Mercury {
    /// API key state; `request_prediction` reads the key from it to build the
    /// `Authorization: Bearer …` header.
    pub api_token: Entity<ApiKeyState>,
    /// Updated after every completed prediction request: `true` when that
    /// request failed with a payment-required error, `false` otherwise.
    payment_required_error: bool,
}
 26
 27impl Mercury {
 28    pub fn new(cx: &mut App) -> Self {
 29        Mercury {
 30            api_token: mercury_api_token(cx),
 31            payment_required_error: false,
 32        }
 33    }
 34
 35    pub fn has_payment_required_error(&self) -> bool {
 36        self.payment_required_error
 37    }
 38
 39    pub fn set_payment_required_error(&mut self, payment_required_error: bool) {
 40        self.payment_required_error = payment_required_error;
 41    }
 42
 43    pub(crate) fn request_prediction(
 44        &mut self,
 45        EditPredictionModelInput {
 46            buffer,
 47            snapshot,
 48            position,
 49            events,
 50            related_files,
 51            debug_tx,
 52            ..
 53        }: EditPredictionModelInput,
 54        cx: &mut Context<EditPredictionStore>,
 55    ) -> Task<Result<Option<EditPredictionResult>>> {
 56        self.api_token.update(cx, |key_state, cx| {
 57            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
 58        });
 59        let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
 60            return Task::ready(Ok(None));
 61        };
 62        let full_path: Arc<Path> = snapshot
 63            .file()
 64            .map(|file| file.full_path(cx))
 65            .unwrap_or_else(|| "untitled".into())
 66            .into();
 67
 68        let http_client = cx.http_client();
 69        let cursor_point = position.to_point(&snapshot);
 70        let request_start = cx.background_executor().now();
 71        let active_buffer = buffer.clone();
 72
 73        let result = cx.background_spawn(async move {
 74            let cursor_offset = cursor_point.to_offset(&snapshot);
 75            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
 76                crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
 77
 78            let related_files = zeta_prompt::filter_redundant_excerpts(
 79                related_files,
 80                full_path.as_ref(),
 81                excerpt_point_range.start.row..excerpt_point_range.end.row,
 82            );
 83
 84            let cursor_excerpt: Arc<str> = snapshot
 85                .text_for_range(excerpt_point_range.clone())
 86                .collect::<String>()
 87                .into();
 88            let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
 89                &snapshot,
 90                cursor_offset,
 91                &excerpt_offset_range,
 92            );
 93            let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
 94                &cursor_excerpt,
 95                cursor_offset_in_excerpt,
 96                &syntax_ranges,
 97            );
 98
 99            let editable_offset_range = (excerpt_offset_range.start
100                + excerpt_ranges.editable_350.start)
101                ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
102
103            let inputs = zeta_prompt::ZetaPromptInput {
104                events,
105                related_files: Some(related_files),
106                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
107                    - excerpt_offset_range.start,
108                cursor_path: full_path.clone(),
109                cursor_excerpt,
110                experiment: None,
111                excerpt_start_row: Some(excerpt_point_range.start.row),
112                excerpt_ranges,
113                syntax_ranges: Some(syntax_ranges),
114                active_buffer_diagnostics: vec![],
115                in_open_source_repo: false,
116                can_collect_data: false,
117                repo_url: None,
118            };
119
120            let prompt = build_prompt(&inputs);
121
122            if let Some(debug_tx) = &debug_tx {
123                debug_tx
124                    .unbounded_send(DebugEvent::EditPredictionStarted(
125                        EditPredictionStartedDebugEvent {
126                            buffer: active_buffer.downgrade(),
127                            prompt: Some(prompt.clone()),
128                            position,
129                        },
130                    ))
131                    .ok();
132            }
133
134            let request_body = open_ai::Request {
135                model: "mercury-coder".into(),
136                messages: vec![open_ai::RequestMessage::User {
137                    content: open_ai::MessageContent::Plain(prompt),
138                }],
139                stream: false,
140                stream_options: None,
141                max_completion_tokens: None,
142                stop: vec![],
143                temperature: None,
144                tool_choice: None,
145                parallel_tool_calls: None,
146                tools: vec![],
147                prompt_cache_key: None,
148                reasoning_effort: None,
149            };
150
151            let buf = serde_json::to_vec(&request_body)?;
152            let body: AsyncBody = buf.into();
153
154            let request = http_client::Request::builder()
155                .uri(MERCURY_API_URL)
156                .header("Content-Type", "application/json")
157                .header("Authorization", format!("Bearer {}", api_token))
158                .header("Connection", "keep-alive")
159                .method(Method::POST)
160                .body(body)
161                .context("Failed to create request")?;
162
163            let mut response = http_client
164                .send(request)
165                .await
166                .context("Failed to send request")?;
167
168            let mut body: Vec<u8> = Vec::new();
169            response
170                .body_mut()
171                .read_to_end(&mut body)
172                .await
173                .context("Failed to read response body")?;
174
175            if !response.status().is_success() {
176                if response.status() == StatusCode::PAYMENT_REQUIRED {
177                    anyhow::bail!(MercuryPaymentRequiredError(
178                        mercury_payment_required_message(&body),
179                    ));
180                }
181
182                anyhow::bail!(
183                    "Request failed with status: {:?}\nBody: {}",
184                    response.status(),
185                    String::from_utf8_lossy(&body),
186                );
187            };
188
189            let mut response: open_ai::Response =
190                serde_json::from_slice(&body).context("Failed to parse response")?;
191
192            let id = mem::take(&mut response.id);
193            let response_str = text_from_response(response).unwrap_or_default();
194
195            if let Some(debug_tx) = &debug_tx {
196                debug_tx
197                    .unbounded_send(DebugEvent::EditPredictionFinished(
198                        EditPredictionFinishedDebugEvent {
199                            buffer: active_buffer.downgrade(),
200                            model_output: Some(response_str.clone()),
201                            position,
202                        },
203                    ))
204                    .ok();
205            }
206
207            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
208            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
209
210            let mut edits = Vec::new();
211            const NO_PREDICTION_OUTPUT: &str = "None";
212
213            if response_str != NO_PREDICTION_OUTPUT {
214                let old_text = snapshot
215                    .text_for_range(editable_offset_range.clone())
216                    .collect::<String>();
217                edits = compute_edits(
218                    old_text,
219                    &response_str,
220                    editable_offset_range.start,
221                    &snapshot,
222                );
223            }
224
225            anyhow::Ok((id, edits, snapshot, inputs))
226        });
227
228        cx.spawn(async move |ep_store, cx| {
229            let result = result.await.context("Mercury edit prediction failed");
230
231            let has_payment_required_error = result
232                .as_ref()
233                .err()
234                .is_some_and(is_mercury_payment_required_error);
235
236            ep_store.update(cx, |store, cx| {
237                store
238                    .mercury
239                    .set_payment_required_error(has_payment_required_error);
240                cx.notify();
241            })?;
242
243            let (id, edits, old_snapshot, inputs) = result?;
244            anyhow::Ok(Some(
245                EditPredictionResult::new(
246                    EditPredictionId(id.into()),
247                    &buffer,
248                    &old_snapshot,
249                    edits.into(),
250                    None,
251                    inputs,
252                    None,
253                    cx.background_executor().now() - request_start,
254                    cx,
255                )
256                .await,
257            ))
258        })
259    }
260}
261
262fn build_prompt(inputs: &ZetaPromptInput) -> String {
263    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
264    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
265    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
266    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
267    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
268    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
269    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
270    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
271    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
272    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
273    const CURSOR_TAG: &str = "<|cursor|>";
274    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
275    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
276
277    let mut prompt = String::new();
278
279    push_delimited(
280        &mut prompt,
281        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
282        |prompt| {
283            for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
284                for related_excerpt in &related_file.excerpts {
285                    push_delimited(
286                        prompt,
287                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
288                        |prompt| {
289                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
290                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
291                            prompt.push('\n');
292                            prompt.push_str(related_excerpt.text.as_ref());
293                        },
294                    );
295                }
296            }
297        },
298    );
299
300    push_delimited(
301        &mut prompt,
302        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
303        |prompt| {
304            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
305            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
306            prompt.push('\n');
307
308            let editable_range = &inputs.excerpt_ranges.editable_350;
309            prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
310            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
311                prompt.push_str(
312                    &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
313                );
314                prompt.push_str(CURSOR_TAG);
315                prompt.push_str(
316                    &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
317                );
318            });
319            prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
320        },
321    );
322
323    push_delimited(
324        &mut prompt,
325        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
326        |prompt| {
327            for event in inputs.events.iter() {
328                zeta_prompt::write_event(prompt, &event);
329            }
330        },
331    );
332
333    prompt
334}
335
/// Appends one delimited section to `prompt`.
///
/// Writes `delimiters.start`, lets `cb` append the section content, then
/// writes a newline followed by `delimiters.end`. The newline is added
/// unconditionally, even when `cb` wrote nothing or already ended its output
/// with a line break.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening_tag,
        end: closing_tag,
    } = delimiters;

    *prompt += opening_tag;
    cb(prompt);
    prompt.extend(["\n", closing_tag]);
}
342
/// Identifier under which the Mercury API key is created, loaded and looked up
/// in `ApiKeyState` (see `mercury_api_token` and `Mercury::request_prediction`).
// NOTE(review): holds the same string as `MERCURY_API_URL` but is declared
// separately — confirm whether the two are meant to be able to diverge.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
// NOTE(review): not referenced anywhere in this file; presumably the username
// stored alongside the credential — confirm against callers.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
346
/// Error produced by `Mercury::request_prediction` when the API answers with
/// `StatusCode::PAYMENT_REQUIRED`.
///
/// Wraps the message extracted by `mercury_payment_required_message`; its
/// `Display` output is exactly that message.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
struct MercuryPaymentRequiredError(SharedString);
350
/// JSON shape that `mercury_payment_required_message` tries to parse an error
/// response body into: `{ "error": { "message": "…" } }`.
#[derive(Deserialize)]
struct MercuryErrorResponse {
    /// The nested error object.
    error: MercuryErrorMessage,
}
355
/// The `error` object inside a `MercuryErrorResponse`.
#[derive(Deserialize)]
struct MercuryErrorMessage {
    /// Error text; used as the payment-required message.
    message: String,
}
360
361fn is_mercury_payment_required_error(error: &anyhow::Error) -> bool {
362    error
363        .downcast_ref::<MercuryPaymentRequiredError>()
364        .is_some()
365}
366
367fn mercury_payment_required_message(body: &[u8]) -> SharedString {
368    serde_json::from_slice::<MercuryErrorResponse>(body)
369        .map(|response| response.error.message.into())
370        .unwrap_or_else(|_| String::from_utf8_lossy(body).trim().to_string().into())
371}
372
/// The `MERCURY_AI_TOKEN` environment variable, handed to `ApiKeyState::new`
/// in `mercury_api_token`.
// NOTE(review): presumably lets the environment supply the API key — confirm
// in `ApiKeyState`.
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
374
/// App-global wrapper holding the single `ApiKeyState` entity that
/// `mercury_api_token` creates and hands out.
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

// Marker impl so the wrapper can be stored with `set_global` and read back
// with `try_global`.
impl Global for GlobalMercuryApiKey {}
378
379pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
380    if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
381        return global.0.clone();
382    }
383    let entity =
384        cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
385    cx.set_global(GlobalMercuryApiKey(entity.clone()));
386    entity
387}
388
389pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
390    mercury_api_token(cx).update(cx, |key_state, cx| {
391        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
392    })
393}
394
/// Endpoint that `send_feedback` POSTs accept/reject/ignore reports to.
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
396
/// What the user did with a prediction, as reported to the feedback endpoint.
///
/// Serialized in snake case, i.e. as `"accept"`, `"reject"` or `"ignore"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    /// Sent by `edit_prediction_accepted`.
    Accept,
    /// Sent for `EditPredictionRejectReason::Rejected`.
    Reject,
    /// Sent for `EditPredictionRejectReason::Discarded`.
    Ignore,
}
404
/// JSON body of a feedback request built by `send_feedback`.
#[derive(Serialize)]
struct FeedbackRequest {
    /// Inner value of the `EditPredictionId` the feedback refers to.
    request_id: SharedString,
    /// Always `"zed"` in this file.
    provider_name: &'static str,
    /// What the user did with the prediction.
    user_action: MercuryUserAction,
    /// The app version (`AppVersion::global`) rendered as a string.
    provider_version: String,
}
412
413pub(crate) fn edit_prediction_accepted(
414    prediction_id: EditPredictionId,
415    http_client: Arc<dyn HttpClient>,
416    cx: &App,
417) {
418    send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
419}
420
421pub(crate) fn edit_prediction_rejected(
422    prediction_id: EditPredictionId,
423    was_shown: bool,
424    reason: EditPredictionRejectReason,
425    http_client: Arc<dyn HttpClient>,
426    cx: &App,
427) {
428    if !was_shown {
429        return;
430    }
431    let action = match reason {
432        EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
433        EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
434        _ => return,
435    };
436    send_feedback(prediction_id, action, http_client, cx);
437}
438
439fn send_feedback(
440    prediction_id: EditPredictionId,
441    action: MercuryUserAction,
442    http_client: Arc<dyn HttpClient>,
443    cx: &App,
444) {
445    let request_id = prediction_id.0;
446    let app_version = AppVersion::global(cx);
447    cx.background_spawn(async move {
448        let body = FeedbackRequest {
449            request_id,
450            provider_name: "zed",
451            user_action: action,
452            provider_version: app_version.to_string(),
453        };
454
455        let request = http_client::Request::builder()
456            .uri(FEEDBACK_API_URL)
457            .method(Method::POST)
458            .header("Content-Type", "application/json")
459            .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
460
461        let response = http_client.send(request).await?;
462        if !response.status().is_success() {
463            anyhow::bail!("Feedback API returned status: {}", response.status());
464        }
465
466        log::debug!(
467            "Mercury feedback sent: request_id={}, action={:?}",
468            body.request_id,
469            body.user_action
470        );
471
472        anyhow::Ok(())
473    })
474    .detach_and_log_err(cx);
475}