mercury.rs

  1use anyhow::{Context as _, Result};
  2use credentials_provider::CredentialsProvider;
  3use futures::{AsyncReadExt as _, FutureExt, future::Shared};
  4use gpui::{
  5    App, AppContext as _, Task,
  6    http_client::{self, AsyncBody, Method},
  7};
  8use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
  9use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
 10use zeta_prompt::ZetaPromptInput;
 11
 12use crate::{
 13    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
 14    EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
 15    prediction::EditPredictionResult,
 16};
 17
 /// Token budget passed as the rewrite limit to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position`; presumably bounds
/// the editable region the model may rewrite (TODO confirm against that helper).
const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget passed as the context limit to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position`; presumably bounds
/// the surrounding context excerpt (TODO confirm against that helper).
const MAX_CONTEXT_TOKENS: usize = 150;
/// Endpoint that receives the POSTed edit-completion requests.
const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 21
 22pub struct Mercury {
 23    pub api_token: Shared<Task<Option<String>>>,
 24}
 25
 26impl Mercury {
 27    pub fn new(cx: &App) -> Self {
 28        Mercury {
 29            api_token: load_api_token(cx).shared(),
 30        }
 31    }
 32
 33    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
 34        self.api_token = Task::ready(api_token.clone()).shared();
 35        store_api_token_in_keychain(api_token, cx)
 36    }
 37
 38    pub(crate) fn request_prediction(
 39        &self,
 40        EditPredictionModelInput {
 41            buffer,
 42            snapshot,
 43            position,
 44            events,
 45            related_files,
 46            debug_tx,
 47            ..
 48        }: EditPredictionModelInput,
 49        cx: &mut App,
 50    ) -> Task<Result<Option<EditPredictionResult>>> {
 51        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
 52            return Task::ready(Ok(None));
 53        };
 54        let full_path: Arc<Path> = snapshot
 55            .file()
 56            .map(|file| file.full_path(cx))
 57            .unwrap_or_else(|| "untitled".into())
 58            .into();
 59
 60        let http_client = cx.http_client();
 61        let cursor_point = position.to_point(&snapshot);
 62        let buffer_snapshotted_at = Instant::now();
 63        let active_buffer = buffer.clone();
 64
 65        let result = cx.background_spawn(async move {
 66            let (editable_range, context_range) =
 67                crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
 68                    cursor_point,
 69                    &snapshot,
 70                    MAX_CONTEXT_TOKENS,
 71                    MAX_REWRITE_TOKENS,
 72                );
 73
 74            let context_offset_range = context_range.to_offset(&snapshot);
 75
 76            let editable_offset_range = editable_range.to_offset(&snapshot);
 77
 78            let inputs = zeta_prompt::ZetaPromptInput {
 79                events,
 80                related_files,
 81                cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
 82                    - context_range.start.to_offset(&snapshot),
 83                cursor_path: full_path.clone(),
 84                cursor_excerpt: snapshot
 85                    .text_for_range(context_range)
 86                    .collect::<String>()
 87                    .into(),
 88                editable_range_in_excerpt: (editable_offset_range.start
 89                    - context_offset_range.start)
 90                    ..(editable_offset_range.end - context_offset_range.start),
 91            };
 92
 93            let prompt = build_prompt(&inputs);
 94
 95            if let Some(debug_tx) = &debug_tx {
 96                debug_tx
 97                    .unbounded_send(DebugEvent::EditPredictionStarted(
 98                        EditPredictionStartedDebugEvent {
 99                            buffer: active_buffer.downgrade(),
100                            prompt: Some(prompt.clone()),
101                            position,
102                        },
103                    ))
104                    .ok();
105            }
106
107            let request_body = open_ai::Request {
108                model: "mercury-coder".into(),
109                messages: vec![open_ai::RequestMessage::User {
110                    content: open_ai::MessageContent::Plain(prompt),
111                }],
112                stream: false,
113                max_completion_tokens: None,
114                stop: vec![],
115                temperature: None,
116                tool_choice: None,
117                parallel_tool_calls: None,
118                tools: vec![],
119                prompt_cache_key: None,
120                reasoning_effort: None,
121            };
122
123            let buf = serde_json::to_vec(&request_body)?;
124            let body: AsyncBody = buf.into();
125
126            let request = http_client::Request::builder()
127                .uri(MERCURY_API_URL)
128                .header("Content-Type", "application/json")
129                .header("Authorization", format!("Bearer {}", api_token))
130                .header("Connection", "keep-alive")
131                .method(Method::POST)
132                .body(body)
133                .context("Failed to create request")?;
134
135            let mut response = http_client
136                .send(request)
137                .await
138                .context("Failed to send request")?;
139
140            let mut body: Vec<u8> = Vec::new();
141            response
142                .body_mut()
143                .read_to_end(&mut body)
144                .await
145                .context("Failed to read response body")?;
146
147            let response_received_at = Instant::now();
148            if !response.status().is_success() {
149                anyhow::bail!(
150                    "Request failed with status: {:?}\nBody: {}",
151                    response.status(),
152                    String::from_utf8_lossy(&body),
153                );
154            };
155
156            let mut response: open_ai::Response =
157                serde_json::from_slice(&body).context("Failed to parse response")?;
158
159            let id = mem::take(&mut response.id);
160            let response_str = text_from_response(response).unwrap_or_default();
161
162            if let Some(debug_tx) = &debug_tx {
163                debug_tx
164                    .unbounded_send(DebugEvent::EditPredictionFinished(
165                        EditPredictionFinishedDebugEvent {
166                            buffer: active_buffer.downgrade(),
167                            model_output: Some(response_str.clone()),
168                            position,
169                        },
170                    ))
171                    .ok();
172            }
173
174            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
175            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
176
177            let mut edits = Vec::new();
178            const NO_PREDICTION_OUTPUT: &str = "None";
179
180            if response_str != NO_PREDICTION_OUTPUT {
181                let old_text = snapshot
182                    .text_for_range(editable_offset_range.clone())
183                    .collect::<String>();
184                edits.extend(
185                    language::text_diff(&old_text, &response_str)
186                        .into_iter()
187                        .map(|(range, text)| {
188                            (
189                                snapshot.anchor_after(editable_offset_range.start + range.start)
190                                    ..snapshot
191                                        .anchor_before(editable_offset_range.start + range.end),
192                                text,
193                            )
194                        }),
195                );
196            }
197
198            anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
199        });
200
201        cx.spawn(async move |cx| {
202            let (id, edits, old_snapshot, response_received_at, inputs) =
203                result.await.context("Mercury edit prediction failed")?;
204            anyhow::Ok(Some(
205                EditPredictionResult::new(
206                    EditPredictionId(id.into()),
207                    &buffer,
208                    &old_snapshot,
209                    edits.into(),
210                    buffer_snapshotted_at,
211                    response_received_at,
212                    inputs,
213                    cx,
214                )
215                .await,
216            ))
217        })
218    }
219}
220
221fn build_prompt(inputs: &ZetaPromptInput) -> String {
222    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
223    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
224    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
225    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
226    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
227    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
228    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
229    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
230    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
231    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
232    const CURSOR_TAG: &str = "<|cursor|>";
233    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
234    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
235
236    let mut prompt = String::new();
237
238    push_delimited(
239        &mut prompt,
240        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
241        |prompt| {
242            for related_file in inputs.related_files.iter() {
243                for related_excerpt in &related_file.excerpts {
244                    push_delimited(
245                        prompt,
246                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
247                        |prompt| {
248                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
249                            prompt.push_str(related_file.path.to_string_lossy().as_ref());
250                            prompt.push('\n');
251                            prompt.push_str(&related_excerpt.text.to_string());
252                        },
253                    );
254                }
255            }
256        },
257    );
258
259    push_delimited(
260        &mut prompt,
261        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
262        |prompt| {
263            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
264            prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
265            prompt.push('\n');
266
267            prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
268            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
269                prompt.push_str(
270                    &inputs.cursor_excerpt
271                        [inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
272                );
273                prompt.push_str(CURSOR_TAG);
274                prompt.push_str(
275                    &inputs.cursor_excerpt
276                        [inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
277                );
278            });
279            prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
280        },
281    );
282
283    push_delimited(
284        &mut prompt,
285        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
286        |prompt| {
287            for event in inputs.events.iter() {
288                zeta_prompt::write_event(prompt, &event);
289            }
290        },
291    );
292
293    prompt
294}
295
/// Appends `delimiters.start`, then whatever `cb` writes, then `delimiters.end` to `prompt`.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;
    prompt.push_str(opening);
    cb(prompt);
    prompt.push_str(closing);
}
301
/// Username stored alongside the Mercury API token when it is written to the keychain.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
/// URL under which the Mercury API token is read, written and deleted in the keychain.
/// NOTE(review): changing it presumably orphans tokens already stored under the old value.
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
304
305pub fn load_api_token(cx: &App) -> Task<Option<String>> {
306    if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
307        .ok()
308        .filter(|value| !value.is_empty())
309    {
310        return Task::ready(Some(api_token));
311    }
312    let credentials_provider = <dyn CredentialsProvider>::global(cx);
313    cx.spawn(async move |cx| {
314        let (_, credentials) = credentials_provider
315            .read_credentials(MERCURY_CREDENTIALS_URL, &cx)
316            .await
317            .ok()??;
318        String::from_utf8(credentials).ok()
319    })
320}
321
322fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
323    let credentials_provider = <dyn CredentialsProvider>::global(cx);
324
325    cx.spawn(async move |cx| {
326        if let Some(api_token) = api_token {
327            credentials_provider
328                .write_credentials(
329                    MERCURY_CREDENTIALS_URL,
330                    MERCURY_CREDENTIALS_USERNAME,
331                    api_token.as_bytes(),
332                    cx,
333                )
334                .await
335                .context("Failed to save Mercury API token to system keychain")
336        } else {
337            credentials_provider
338                .delete_credentials(MERCURY_CREDENTIALS_URL, cx)
339                .await
340                .context("Failed to delete Mercury API token from system keychain")
341        }
342    })
343}