mercury.rs

  1use anyhow::{Context as _, Result};
  2use cloud_llm_client::predict_edits_v3::Event;
  3use credentials_provider::CredentialsProvider;
  4use edit_prediction_context::RelatedFile;
  5use futures::{AsyncReadExt as _, FutureExt, future::Shared};
  6use gpui::{
  7    App, AppContext as _, Entity, Task,
  8    http_client::{self, AsyncBody, Method},
  9};
 10use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
 11use project::{Project, ProjectPath};
 12use std::{
 13    collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
 14};
 15
 16use crate::{
 17    EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
 18    prediction::EditPredictionResult,
 19};
 20
 21const MERCURY_API_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
 22const MAX_CONTEXT_TOKENS: usize = 150;
 23const MAX_REWRITE_TOKENS: usize = 350;
 24
 25pub struct Mercury {
 26    pub api_token: Shared<Task<Option<String>>>,
 27}
 28
 29impl Mercury {
 30    pub fn new(cx: &App) -> Self {
 31        Mercury {
 32            api_token: load_api_token(cx).shared(),
 33        }
 34    }
 35
 36    pub fn set_api_token(&mut self, api_token: Option<String>, cx: &mut App) -> Task<Result<()>> {
 37        self.api_token = Task::ready(api_token.clone()).shared();
 38        store_api_token_in_keychain(api_token, cx)
 39    }
 40
 41    pub fn request_prediction(
 42        &self,
 43        _project: &Entity<Project>,
 44        active_buffer: &Entity<Buffer>,
 45        snapshot: BufferSnapshot,
 46        position: language::Anchor,
 47        events: Vec<Arc<Event>>,
 48        _recent_paths: &VecDeque<ProjectPath>,
 49        related_files: Vec<RelatedFile>,
 50        _diagnostic_search_range: Range<Point>,
 51        cx: &mut App,
 52    ) -> Task<Result<Option<EditPredictionResult>>> {
 53        let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
 54            return Task::ready(Ok(None));
 55        };
 56        let full_path: Arc<Path> = snapshot
 57            .file()
 58            .map(|file| file.full_path(cx))
 59            .unwrap_or_else(|| "untitled".into())
 60            .into();
 61
 62        let http_client = cx.http_client();
 63        let cursor_point = position.to_point(&snapshot);
 64        let buffer_snapshotted_at = Instant::now();
 65
 66        let result = cx.background_spawn(async move {
 67            let (editable_range, context_range) =
 68                crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
 69                    cursor_point,
 70                    &snapshot,
 71                    MAX_CONTEXT_TOKENS,
 72                    MAX_REWRITE_TOKENS,
 73                );
 74
 75            let offset_range = editable_range.to_offset(&snapshot);
 76            let prompt = build_prompt(
 77                &events,
 78                &related_files,
 79                &snapshot,
 80                full_path.as_ref(),
 81                cursor_point,
 82                editable_range,
 83                context_range.clone(),
 84            );
 85
 86            let inputs = EditPredictionInputs {
 87                events: events,
 88                included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
 89                    path: full_path.clone(),
 90                    max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
 91                    excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
 92                        start_line: cloud_llm_client::predict_edits_v3::Line(
 93                            context_range.start.row,
 94                        ),
 95                        text: snapshot
 96                            .text_for_range(context_range.clone())
 97                            .collect::<String>()
 98                            .into(),
 99                    }],
100                }],
101                cursor_point: cloud_llm_client::predict_edits_v3::Point {
102                    column: cursor_point.column,
103                    line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
104                },
105                cursor_path: full_path.clone(),
106            };
107
108            let request_body = open_ai::Request {
109                model: "mercury-coder".into(),
110                messages: vec![open_ai::RequestMessage::User {
111                    content: open_ai::MessageContent::Plain(prompt),
112                }],
113                stream: false,
114                max_completion_tokens: None,
115                stop: vec![],
116                temperature: None,
117                tool_choice: None,
118                parallel_tool_calls: None,
119                tools: vec![],
120                prompt_cache_key: None,
121                reasoning_effort: None,
122            };
123
124            let buf = serde_json::to_vec(&request_body)?;
125            let body: AsyncBody = buf.into();
126
127            let request = http_client::Request::builder()
128                .uri(MERCURY_API_URL)
129                .header("Content-Type", "application/json")
130                .header("Authorization", format!("Bearer {}", api_token))
131                .header("Connection", "keep-alive")
132                .method(Method::POST)
133                .body(body)
134                .context("Failed to create request")?;
135
136            let mut response = http_client
137                .send(request)
138                .await
139                .context("Failed to send request")?;
140
141            let mut body: Vec<u8> = Vec::new();
142            response
143                .body_mut()
144                .read_to_end(&mut body)
145                .await
146                .context("Failed to read response body")?;
147
148            let response_received_at = Instant::now();
149            if !response.status().is_success() {
150                anyhow::bail!(
151                    "Request failed with status: {:?}\nBody: {}",
152                    response.status(),
153                    String::from_utf8_lossy(&body),
154                );
155            };
156
157            let mut response: open_ai::Response =
158                serde_json::from_slice(&body).context("Failed to parse response")?;
159
160            let id = mem::take(&mut response.id);
161            let response_str = text_from_response(response).unwrap_or_default();
162
163            let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
164            let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
165
166            let mut edits = Vec::new();
167            const NO_PREDICTION_OUTPUT: &str = "None";
168
169            if response_str != NO_PREDICTION_OUTPUT {
170                let old_text = snapshot
171                    .text_for_range(offset_range.clone())
172                    .collect::<String>();
173                edits.extend(
174                    language::text_diff(&old_text, &response_str)
175                        .into_iter()
176                        .map(|(range, text)| {
177                            (
178                                snapshot.anchor_after(offset_range.start + range.start)
179                                    ..snapshot.anchor_before(offset_range.start + range.end),
180                                text,
181                            )
182                        }),
183                );
184            }
185
186            anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
187        });
188
189        let buffer = active_buffer.clone();
190
191        cx.spawn(async move |cx| {
192            let (id, edits, old_snapshot, response_received_at, inputs) =
193                result.await.context("Mercury edit prediction failed")?;
194            anyhow::Ok(Some(
195                EditPredictionResult::new(
196                    EditPredictionId(id.into()),
197                    &buffer,
198                    &old_snapshot,
199                    edits.into(),
200                    buffer_snapshotted_at,
201                    response_received_at,
202                    inputs,
203                    cx,
204                )
205                .await,
206            ))
207        })
208    }
209}
210
211fn build_prompt(
212    events: &[Arc<Event>],
213    related_files: &[RelatedFile],
214    cursor_buffer: &BufferSnapshot,
215    cursor_buffer_path: &Path,
216    cursor_point: Point,
217    editable_range: Range<Point>,
218    context_range: Range<Point>,
219) -> String {
220    const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
221    const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
222    const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
223    const RECENTLY_VIEWED_SNIPPET_END: &str = "<|/recently_viewed_code_snippet|>\n";
224    const CURRENT_FILE_CONTENT_START: &str = "<|current_file_content|>\n";
225    const CURRENT_FILE_CONTENT_END: &str = "<|/current_file_content|>\n";
226    const CODE_TO_EDIT_START: &str = "<|code_to_edit|>\n";
227    const CODE_TO_EDIT_END: &str = "<|/code_to_edit|>\n";
228    const EDIT_DIFF_HISTORY_START: &str = "<|edit_diff_history|>\n";
229    const EDIT_DIFF_HISTORY_END: &str = "<|/edit_diff_history|>\n";
230    const CURSOR_TAG: &str = "<|cursor|>";
231    const CODE_SNIPPET_FILE_PATH_PREFIX: &str = "code_snippet_file_path: ";
232    const CURRENT_FILE_PATH_PREFIX: &str = "current_file_path: ";
233
234    let mut prompt = String::new();
235
236    push_delimited(
237        &mut prompt,
238        RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
239        |prompt| {
240            for related_file in related_files {
241                for related_excerpt in &related_file.excerpts {
242                    push_delimited(
243                        prompt,
244                        RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
245                        |prompt| {
246                            prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
247                            prompt.push_str(related_file.path.path.as_unix_str());
248                            prompt.push('\n');
249                            prompt.push_str(&related_excerpt.text.to_string());
250                        },
251                    );
252                }
253            }
254        },
255    );
256
257    push_delimited(
258        &mut prompt,
259        CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
260        |prompt| {
261            prompt.push_str(CURRENT_FILE_PATH_PREFIX);
262            prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
263            prompt.push('\n');
264
265            let prefix_range = context_range.start..editable_range.start;
266            let suffix_range = editable_range.end..context_range.end;
267
268            prompt.extend(cursor_buffer.text_for_range(prefix_range));
269            push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
270                let range_before_cursor = editable_range.start..cursor_point;
271                let range_after_cursor = cursor_point..editable_range.end;
272                prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
273                prompt.push_str(CURSOR_TAG);
274                prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
275            });
276            prompt.extend(cursor_buffer.text_for_range(suffix_range));
277        },
278    );
279
280    push_delimited(
281        &mut prompt,
282        EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
283        |prompt| {
284            for event in events {
285                writeln!(prompt, "{event}").unwrap();
286            }
287        },
288    );
289
290    prompt
291}
292
/// Appends the opening delimiter, whatever `cb` writes, and then the closing
/// delimiter to `prompt`, in that order.
///
/// The delimiters are passed as a `Range<&str>` purely for call-site brevity:
/// `start` is the opening text and `end` is the closing text.
fn push_delimited(prompt: &mut String, delimiters: Range<&str>, cb: impl FnOnce(&mut String)) {
    let Range {
        start: opening,
        end: closing,
    } = delimiters;
    *prompt += opening;
    cb(prompt);
    *prompt += closing;
}
298
/// Key (the credentials provider's "URL" argument) under which the Mercury API
/// token is written to, read from, and deleted from the system keychain. It is
/// the same string as the API endpoint. Changing it would make tokens stored
/// under the old key unreachable.
pub const MERCURY_CREDENTIALS_URL: &str = "https://api.inceptionlabs.ai/v1/edit/completions";
/// Username recorded alongside the token when it is written to the keychain.
pub const MERCURY_CREDENTIALS_USERNAME: &str = "mercury-api-token";
301
302pub fn load_api_token(cx: &App) -> Task<Option<String>> {
303    if let Some(api_token) = std::env::var("MERCURY_AI_TOKEN")
304        .ok()
305        .filter(|value| !value.is_empty())
306    {
307        return Task::ready(Some(api_token));
308    }
309    let credentials_provider = <dyn CredentialsProvider>::global(cx);
310    cx.spawn(async move |cx| {
311        let (_, credentials) = credentials_provider
312            .read_credentials(MERCURY_CREDENTIALS_URL, &cx)
313            .await
314            .ok()??;
315        String::from_utf8(credentials).ok()
316    })
317}
318
319fn store_api_token_in_keychain(api_token: Option<String>, cx: &App) -> Task<Result<()>> {
320    let credentials_provider = <dyn CredentialsProvider>::global(cx);
321
322    cx.spawn(async move |cx| {
323        if let Some(api_token) = api_token {
324            credentials_provider
325                .write_credentials(
326                    MERCURY_CREDENTIALS_URL,
327                    MERCURY_CREDENTIALS_USERNAME,
328                    api_token.as_bytes(),
329                    cx,
330                )
331                .await
332                .context("Failed to save Mercury API token to system keychain")
333        } else {
334            credentials_provider
335                .delete_credentials(MERCURY_CREDENTIALS_URL, cx)
336                .await
337                .context("Failed to delete Mercury API token from system keychain")
338        }
339    })
340}