mercury.rs

  1use crate::{
  2    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  3    EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
  4    prediction::EditPredictionResult, zeta::compute_edits,
  5};
  6use anyhow::{Context as _, Result};
  7use cloud_llm_client::EditPredictionRejectReason;
  8use futures::AsyncReadExt as _;
  9use gpui::{
 10    App, AppContext as _, Entity, Global, SharedString, Task,
 11    http_client::{self, AsyncBody, HttpClient, Method},
 12};
 13use language::{ToOffset, ToPoint as _};
 14use language_model::{ApiKeyState, EnvVar, env_var};
 15use release_channel::AppVersion;
 16use serde::Serialize;
 17use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
 18use zeta_prompt::ZetaPromptInput;
 19
/// Endpoint that `Mercury::request_prediction` POSTs edit-completion requests to.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 21
/// Edit-prediction provider that sends requests to the Mercury API.
pub struct Mercury {
    /// API-key state; `request_prediction` reads the key from it before each request.
    pub api_token: Entity<ApiKeyState>,
}
 25
 26impl Mercury {
 27    pub fn new(cx: &mut App) -> Self {
 28        Mercury {
 29            api_token: mercury_api_token(cx),
 30        }
 31    }
 32
 33    pub(crate) fn request_prediction(
 34        &self,
 35        EditPredictionModelInput {
 36            buffer,
 37            snapshot,
 38            position,
 39            events,
 40            related_files,
 41            debug_tx,
 42            ..
 43        }: EditPredictionModelInput,
 44        cx: &mut App,
 45    ) -> Task<Result<Option<EditPredictionResult>>> {
 46        self.api_token.update(cx, |key_state, cx| {
 47            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
 48        });
 49        let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
 50            return Task::ready(Ok(None));
 51        };
 52        let full_path: Arc<Path> = snapshot
 53            .file()
 54            .map(|file| file.full_path(cx))
 55            .unwrap_or_else(|| "untitled".into())
 56            .into();
 57
 58        let http_client = cx.http_client();
 59        let cursor_point = position.to_point(&snapshot);
 60        let buffer_snapshotted_at = Instant::now();
 61        let active_buffer = buffer.clone();
 62
 63        let result = cx.background_spawn(async move {
 64            let cursor_offset = cursor_point.to_offset(&snapshot);
 65            let (excerpt_point_range, excerpt_offset_range, cursor_offset_in_excerpt) =
 66                crate::cursor_excerpt::compute_cursor_excerpt(&snapshot, cursor_offset);
 67
 68            let related_files = zeta_prompt::filter_redundant_excerpts(
 69                related_files,
 70                full_path.as_ref(),
 71                excerpt_point_range.start.row..excerpt_point_range.end.row,
 72            );
 73
 74            let cursor_excerpt: Arc<str> = snapshot
 75                .text_for_range(excerpt_point_range.clone())
 76                .collect::<String>()
 77                .into();
 78            let syntax_ranges = crate::cursor_excerpt::compute_syntax_ranges(
 79                &snapshot,
 80                cursor_offset,
 81                &excerpt_offset_range,
 82            );
 83            let excerpt_ranges = zeta_prompt::compute_legacy_excerpt_ranges(
 84                &cursor_excerpt,
 85                cursor_offset_in_excerpt,
 86                &syntax_ranges,
 87            );
 88
 89            let editable_offset_range = (excerpt_offset_range.start
 90                + excerpt_ranges.editable_350.start)
 91                ..(excerpt_offset_range.start + excerpt_ranges.editable_350.end);
 92
 93            let inputs = zeta_prompt::ZetaPromptInput {
 94                events,
 95                related_files: Some(related_files),
 96                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
 97                    - excerpt_offset_range.start,
 98                cursor_path: full_path.clone(),
 99                cursor_excerpt,
100                experiment: None,
101                excerpt_start_row: Some(excerpt_point_range.start.row),
102                excerpt_ranges,
103                syntax_ranges: Some(syntax_ranges),
104                active_buffer_diagnostics: vec![],
105                in_open_source_repo: false,
106                can_collect_data: false,
107                repo_url: None,
108            };
109
110            let prompt = build_prompt(&inputs);
111
112            if let Some(debug_tx) = &debug_tx {
113                debug_tx
114                    .unbounded_send(DebugEvent::EditPredictionStarted(
115                        EditPredictionStartedDebugEvent {
116                            buffer: active_buffer.downgrade(),
117                            prompt: Some(prompt.clone()),
118                            position,
119                        },
120                    ))
121                    .ok();
122            }
123
124            let request_body = open_ai::Request {
125                model: "mercury-coder".into(),
126                messages: vec![open_ai::RequestMessage::User {
127                    content: open_ai::MessageContent::Plain(prompt),
128                }],
129                stream: false,
130                max_completion_tokens: None,
131                stop: vec![],
132                temperature: None,
133                tool_choice: None,
134                parallel_tool_calls: None,
135                tools: vec![],
136                prompt_cache_key: None,
137                reasoning_effort: None,
138            };
139
140            let buf = serde_json::to_vec(&request_body)?;
141            let body: AsyncBody = buf.into();
142
143            let request = http_client::Request::builder()
144                .uri(MERCURY_API_URL)
145                .header("Content-Type", "application/json")
146                .header("Authorization", format!("Bearer {}", api_token))
147                .header("Connection", "keep-alive")
148                .method(Method::POST)
149                .body(body)
150                .context("Failed to create request")?;
151
152            let mut response = http_client
153                .send(request)
154                .await
155                .context("Failed to send request")?;
156
157            let mut body: Vec<u8> = Vec::new();
158            response
159                .body_mut()
160                .read_to_end(&mut body)
161                .await
162                .context("Failed to read response body")?;
163
164            let response_received_at = Instant::now();
165            if !response.status().is_success() {
166                anyhow::bail!(
167                    "Request failed with status: {:?}\nBody: {}",
168                    response.status(),
169                    String::from_utf8_lossy(&body),
170                );
171            };
172
173            let mut response: open_ai::Response =
174                serde_json::from_slice(&body).context("Failed to parse response")?;
175
176            let id = mem::take(&mut response.id);
177            let response_str = text_from_response(response).unwrap_or_default();
178
179            if let Some(debug_tx) = &debug_tx {
180                debug_tx
181                    .unbounded_send(DebugEvent::EditPredictionFinished(
182                        EditPredictionFinishedDebugEvent {
183                            buffer: active_buffer.downgrade(),
184                            model_output: Some(response_str.clone()),
185                            position,
186                        },
187                    ))
188                    .ok();
189            }
190
191            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
192            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
193
194            let mut edits = Vec::new();
195            const NO_PREDICTION_OUTPUT: &str = "None";
196
197            if response_str != NO_PREDICTION_OUTPUT {
198                let old_text = snapshot
199                    .text_for_range(editable_offset_range.clone())
200                    .collect::<String>();
201                edits = compute_edits(
202                    old_text,
203                    &response_str,
204                    editable_offset_range.start,
205                    &snapshot,
206                );
207            }
208
209            anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
210        });
211
212        cx.spawn(async move |cx| {
213            let (id, edits, old_snapshot, response_received_at, inputs) =
214                result.await.context("Mercury edit prediction failed")?;
215            anyhow::Ok(Some(
216                EditPredictionResult::new(
217                    EditPredictionId(id.into()),
218                    &buffer,
219                    &old_snapshot,
220                    edits.into(),
221                    None,
222                    buffer_snapshotted_at,
223                    response_received_at,
224                    inputs,
225                    None,
226                    cx,
227                )
228                .await,
229            ))
230        })
231    }
232}
233
234fn build_prompt(inputs: &ZetaPromptInput) -> String {
235    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
236    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
237    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
238    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
239    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
240    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
241    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
242    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
243    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
244    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
245    const CURSOR_TAG: &str = "<|cursor|>";
246    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
247    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
248
249    let mut prompt = String::new();
250
251    push_delimited(
252        &mut prompt,
253        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
254        |prompt| {
255            for related_file in inputs.related_files.as_deref().unwrap_or_default().iter() {
256                for related_excerpt in &related_file.excerpts {
257                    push_delimited(
258                        prompt,
259                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
260                        |prompt| {
261                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
262                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
263                            prompt.push('\n');
264                            prompt.push_str(related_excerpt.text.as_ref());
265                        },
266                    );
267                }
268            }
269        },
270    );
271
272    push_delimited(
273        &mut prompt,
274        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
275        |prompt| {
276            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
277            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
278            prompt.push('\n');
279
280            let editable_range = &inputs.excerpt_ranges.editable_350;
281            prompt.push_str(&inputs.cursor_excerpt[0..editable_range.start]);
282            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
283                prompt.push_str(
284                    &inputs.cursor_excerpt[editable_range.start..inputs.cursor_offset_in_excerpt],
285                );
286                prompt.push_str(CURSOR_TAG);
287                prompt.push_str(
288                    &inputs.cursor_excerpt[inputs.cursor_offset_in_excerpt..editable_range.end],
289                );
290            });
291            prompt.push_str(&inputs.cursor_excerpt[editable_range.end..]);
292        },
293    );
294
295    push_delimited(
296        &mut prompt,
297        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
298        |prompt| {
299            for event in inputs.events.iter() {
300                zeta_prompt::write_event(prompt, &event);
301            }
302        },
303    );
304
305    prompt
306}
307
/// Appends a delimited section to `prompt`.
///
/// Writes `delimiters.start`, then whatever `cb` appends, then a newline, then
/// `delimiters.end`. The newline is always emitted, even when `cb` appends nothing.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;

    prompt.push_str(opening);
    cb(prompt);
    prompt.push('\n');
    prompt.push_str(closing);
}
314
/// Key handed to `ApiKeyState` when the state is created and whenever the Mercury
/// API key is loaded or read.
pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
/// NOTE(review): not referenced in this file; presumably the username under which the
/// token is stored in the credential store — confirm against callers.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
/// `EnvVar` for `MERCURY_AI_TOKEN`, passed to `ApiKeyState::new` by `mercury_api_token`.
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");
319
/// Wrapper that lets the Mercury `ApiKeyState` entity be stored as a gpui global;
/// it is created and looked up by `mercury_api_token`.
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

impl Global for GlobalMercuryApiKey {}
323
324pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
325    if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
326        return global.0.clone();
327    }
328    let entity =
329        cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
330    cx.set_global(GlobalMercuryApiKey(entity.clone()));
331    entity
332}
333
334pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
335    mercury_api_token(cx).update(cx, |key_state, cx| {
336        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
337    })
338}
339
/// Endpoint that `send_feedback` POSTs accept/reject/ignore feedback to.
const FEEDBACK_API_URL: &str = "https://api-feedback.inceptionlabs.ai/feedback";
341
/// What the user did with a prediction, as reported in the feedback request.
/// Serialized in snake_case (`accept`, `reject`, `ignore`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
enum MercuryUserAction {
    /// Sent by `edit_prediction_accepted`.
    Accept,
    /// Sent for `EditPredictionRejectReason::Rejected`.
    Reject,
    /// Sent for `EditPredictionRejectReason::Discarded`.
    Ignore,
}
349
/// JSON body that `send_feedback` POSTs to `FEEDBACK_API_URL`.
#[derive(Serialize)]
struct FeedbackRequest {
    /// Inner value of the `EditPredictionId` the feedback refers to.
    request_id: SharedString,
    /// Always "zed" in `send_feedback`.
    provider_name: &'static str,
    /// What the user did with the prediction.
    user_action: MercuryUserAction,
    /// String form of the global `AppVersion`.
    provider_version: String,
}
357
358pub(crate) fn edit_prediction_accepted(
359    prediction_id: EditPredictionId,
360    http_client: Arc<dyn HttpClient>,
361    cx: &App,
362) {
363    send_feedback(prediction_id, MercuryUserAction::Accept, http_client, cx);
364}
365
366pub(crate) fn edit_prediction_rejected(
367    prediction_id: EditPredictionId,
368    was_shown: bool,
369    reason: EditPredictionRejectReason,
370    http_client: Arc<dyn HttpClient>,
371    cx: &App,
372) {
373    if !was_shown {
374        return;
375    }
376    let action = match reason {
377        EditPredictionRejectReason::Rejected => MercuryUserAction::Reject,
378        EditPredictionRejectReason::Discarded => MercuryUserAction::Ignore,
379        _ => return,
380    };
381    send_feedback(prediction_id, action, http_client, cx);
382}
383
384fn send_feedback(
385    prediction_id: EditPredictionId,
386    action: MercuryUserAction,
387    http_client: Arc<dyn HttpClient>,
388    cx: &App,
389) {
390    let request_id = prediction_id.0;
391    let app_version = AppVersion::global(cx);
392    cx.background_spawn(async move {
393        let body = FeedbackRequest {
394            request_id,
395            provider_name: "zed",
396            user_action: action,
397            provider_version: app_version.to_string(),
398        };
399
400        let request = http_client::Request::builder()
401            .uri(FEEDBACK_API_URL)
402            .method(Method::POST)
403            .header("Content-Type", "application/json")
404            .body(AsyncBody::from(serde_json::to_vec(&body)?))?;
405
406        let response = http_client.send(request).await?;
407        if !response.status().is_success() {
408            anyhow::bail!("Feedback API returned status: {}", response.status());
409        }
410
411        log::debug!(
412            "Mercury feedback sent: request_id={}, action={:?}",
413            body.request_id,
414            body.user_action
415        );
416
417        anyhow::Ok(())
418    })
419    .detach_and_log_err(cx);
420}