use crate::{
    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
    EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
    prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use futures::AsyncReadExt as _;
use gpui::{
    App, AppContext as _, Entity, Global, SharedString, Task,
    http_client::{self, AsyncBody, Method},
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use language_model::{ApiKeyState, EnvVar, env_var};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;

/// Endpoint for Mercury's edit-completions API.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
/// Token budget for the editable (rewrite) region around the cursor.
const MAX_REWRITE_TOKENS: usize = 150;
/// Token budget for the surrounding context excerpt included in the prompt.
const MAX_CONTEXT_TOKENS: usize = 350;

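/// Edit prediction provider backed by the Mercury edit-completions API.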
pub struct Mercury {
    pub api_token: Entity<ApiKeyState>,
}

impl Mercury {
    pub fn new(cx: &mut App) -> Self {
        Mercury {
            api_token: mercury_api_token(cx),
        }
    }

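    /// Builds a prompt from the current buffer, related excerpts, and edit history, sends it to
    /// Mercury, and diffs the model's rewrite into buffer edits. Returns `Ok(None)` when no API
    /// token is available.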
    pub(crate) fn request_prediction(
        &self,
        EditPredictionModelInput {
            buffer,
            snapshot,
            position,
            events,
            related_files,
            debug_tx,
            ..
        }: EditPredictionModelInput,
        cx: &mut App,
    ) -> Task<Result<Option<EditPredictionResult>>> {
        self.api_token.update(cx, |key_state, cx| {
            _ = key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx);
        });
        let Some(api_token) = self.api_token.read(cx).key(&MERCURY_CREDENTIALS_URL) else {
            return Task::ready(Ok(None));
        };
        let full_path: Arc<Path> = snapshot
            .file()
            .map(|file| file.full_path(cx))
            .unwrap_or_else(|| "untitled".into())
            .into();

        let http_client = cx.http_client();
        let cursor_point = position.to_point(&snapshot);
        let buffer_snapshotted_at = Instant::now();
        let active_buffer = buffer.clone();

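        // Prompt construction, the HTTP round trip, and diffing all happen off the main thread.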
        let result = cx.background_spawn(async move {
            let (editable_range, context_range) =
                crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
                    cursor_point,
                    &snapshot,
                    MAX_CONTEXT_TOKENS,
                    MAX_REWRITE_TOKENS,
                );

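            // Filter out related excerpts made redundant by the context window around the cursor.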
            let related_files = crate::filter_redundant_excerpts(
                related_files,
                full_path.as_ref(),
                context_range.start.row..context_range.end.row,
            );

            let context_offset_range = context_range.to_offset(&snapshot);

            let editable_offset_range = editable_range.to_offset(&snapshot);

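            // Express the cursor and editable range as offsets within the context excerpt.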
            let inputs = zeta_prompt::ZetaPromptInput {
                events,
                related_files,
                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
                    - context_range.start.to_offset(&snapshot),
                cursor_path: full_path.clone(),
                cursor_excerpt: snapshot
                    .text_for_range(context_range)
                    .collect::<String>()
                    .into(),
                editable_range_in_excerpt: (editable_offset_range.start
                    - context_offset_range.start)
                    ..(editable_offset_range.end - context_offset_range.start),
            };

            let prompt = build_prompt(&inputs);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: active_buffer.downgrade(),
                            prompt: Some(prompt.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

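            // The request body reuses the OpenAI-style chat request shape.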
            let request_body = open_ai::Request {
                model: "mercury-coder".into(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: vec![],
                temperature: None,
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            let buf = serde_json::to_vec(&request_body)?;
            let body: AsyncBody = buf.into();

            let request = http_client::Request::builder()
                .uri(MERCURY_API_URL)
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", api_token))
                .header("Connection", "keep-alive")
                .method(Method::POST)
                .body(body)
                .context("Failed to create request")?;

            let mut response = http_client
                .send(request)
                .await
                .context("Failed to send request")?;

            let mut body: Vec<u8> = Vec::new();
            response
                .body_mut()
                .read_to_end(&mut body)
                .await
                .context("Failed to read response body")?;

            let response_received_at = Instant::now();
            if !response.status().is_success() {
                anyhow::bail!(
                    "Request failed with status: {:?}\nBody: {}",
                    response.status(),
                    String::from_utf8_lossy(&body),
                );
            }

            let mut response: open_ai::Response =
                serde_json::from_slice(&body).context("Failed to parse response")?;

            let id = mem::take(&mut response.id);
            let response_str = text_from_response(response).unwrap_or_default();

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: active_buffer.downgrade(),
                            model_output: Some(response_str.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

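            // The model may wrap its output in a fenced code block; strip the fence if present.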
            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);

            let mut edits = Vec::new();
            const NO_PREDICTION_OUTPUT: &str = "None";

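            // "None" means the model declined to predict; otherwise diff its rewrite of the
            // editable region against the original text and anchor the resulting edits.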
            if response_str != NO_PREDICTION_OUTPUT {
                let old_text = snapshot
                    .text_for_range(editable_offset_range.clone())
                    .collect::<String>();
                edits.extend(
                    language::text_diff(&old_text, &response_str)
                        .into_iter()
                        .map(|(range, text)| {
                            (
                                snapshot.anchor_after(editable_offset_range.start + range.start)
                                    ..snapshot
                                        .anchor_before(editable_offset_range.start + range.end),
                                text,
                            )
                        }),
                );
            }

            anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
        });

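        // Back on the foreground executor, resolve the background result into an
        // `EditPredictionResult` against the live buffer.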
        cx.spawn(async move |cx| {
            let (id, edits, old_snapshot, response_received_at, inputs) =
                result.await.context("Mercury edit prediction failed")?;
            anyhow::Ok(Some(
                EditPredictionResult::new(
                    EditPredictionId(id.into()),
                    &buffer,
                    &old_snapshot,
                    edits.into(),
                    buffer_snapshotted_at,
                    response_received_at,
                    inputs,
                    cx,
                )
                .await,
            ))
        })
    }
}

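/// Renders the prompt in Mercury's tag-delimited format: recently viewed snippets, the current
/// file contents with the editable region and cursor marked, and the edit diff history.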
fn build_prompt(inputs: &ZetaPromptInput) -> String {
    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
    const CURSOR_TAG: &str = "<|cursor|>";
    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";

    let mut prompt = String::new();

    push_delimited(
        &mut prompt,
        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
        |prompt| {
            for related_file in inputs.related_files.iter() {
                for related_excerpt in &related_file.excerpts {
                    push_delimited(
                        prompt,
                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
                        |prompt| {
                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
                            prompt.push('\n');
                            prompt.push_str(related_excerpt.text.as_ref());
                        },
                    );
                }
            }
        },
    );

    push_delimited(
        &mut prompt,
        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
        |prompt| {
            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
            prompt.push('\n');

            prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
                prompt.push_str(
                    &inputs.cursor_excerpt
                        [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
                );
                prompt.push_str(CURSOR_TAG);
                prompt.push_str(
                    &inputs.cursor_excerpt
                        [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
                );
            });
            prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
        },
    );

    push_delimited(
        &mut prompt,
        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
        |prompt| {
            for event in inputs.events.iter() {
                zeta_prompt::write_event(prompt, &event);
            }
        },
    );

    prompt
}

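/// Wraps the output of `cb` in the given start/end delimiters, appending a newline before the
/// closing delimiter.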
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    prompt.push_str(delimiters.start);
    cb(prompt);
    prompt.push('\n');
    prompt.push_str(delimiters.end);
}

pub const MERCURY_CREDENTIALS_URL: SharedString =
    SharedString::new_static("https://api.inceptionlabs.ai/v1/edit/completions");
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
pub static MERCURY_TOKEN_ENV_VAR: std::sync::LazyLock<EnvVar> = env_var!("MERCURY_AI_TOKEN");

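/// App-global holder for the shared Mercury [`ApiKeyState`] entity.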
struct GlobalMercuryApiKey(Entity<ApiKeyState>);

impl Global for GlobalMercuryApiKey {}

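/// Returns the shared [`ApiKeyState`] entity for Mercury, creating and registering it as a global
/// on first use.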
pub fn mercury_api_token(cx: &mut App) -> Entity<ApiKeyState> {
    if let Some(global) = cx.try_global::<GlobalMercuryApiKey>() {
        return global.0.clone();
    }
    let entity =
        cx.new(|_| ApiKeyState::new(MERCURY_CREDENTIALS_URL, MERCURY_TOKEN_ENV_VAR.clone()));
    cx.set_global(GlobalMercuryApiKey(entity.clone()));
    entity
}

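/// Loads the Mercury API token from the environment or the credentials store if it hasn't been
/// loaded yet.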
pub fn load_mercury_api_token(cx: &mut App) -> Task<Result<(), language_model::AuthenticateError>> {
    mercury_api_token(cx).update(cx, |key_state, cx| {
        key_state.load_if_needed(MERCURY_CREDENTIALS_URL, |s| s, cx)
    })
}