zeta1.rs

  1mod input_excerpt;
  2
  3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  4
  5use crate::{
  6    EditPredictionId, ZedUpdateRequiredError, Zeta,
  7    prediction::{EditPrediction, EditPredictionInputs},
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use input_excerpt::excerpt_for_cursor_position;
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 21
/// Marker for the user's cursor position. `parse_edits` strips every occurrence of it from the
/// model output before looking for the editable region.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker for the start of the file. `parse_edits` rejects model output that contains it more
/// than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker that opens the editable region. `parse_edits` requires exactly one in the model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker that closes the editable region. `parse_edits` requires exactly one in the model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Context token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Rewrite token budget handed to `excerpt_for_cursor_position` by `gather_context`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the events section of the prompt (see `prompt_for_events_impl`).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
/// Requests an edit prediction for `position` in `buffer` from the `/predict_edits/v2` endpoint.
///
/// The work is split into three stages whose latencies are sampled into telemetry:
/// 1. gathering context (excerpt around the cursor plus the most recent `events` that fit in
///    `MAX_EVENT_TOKENS`),
/// 2. sending the request and waiting for the response,
/// 3. turning the response into an [`EditPrediction`] via `process_completion_response`.
///
/// Git information is attached to the request only when both the file and the included events
/// are allowed to be collected.
///
/// # Errors
///
/// The returned task fails if gathering context, sending the request, or parsing the response
/// fails. When the failure is a [`ZedUpdateRequiredError`], `update_required` is set on `Zeta`
/// and an "Update Zed" notification is shown before the error is returned.
pub(crate) fn request_prediction_with_zeta1(
    zeta: &mut Zeta,
    project: &Entity<Project>,
    buffer: &Entity<Buffer>,
    snapshot: BufferSnapshot,
    position: language::Anchor,
    events: Vec<Arc<Event>>,
    cx: &mut Context<Zeta>,
) -> Task<Result<Option<EditPrediction>>> {
    let buffer = buffer.clone();
    let buffer_snapshotted_at = Instant::now();
    let client = zeta.client.clone();
    let llm_token = zeta.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git info is only looked up when the file itself may be collected; buffers without a
    // backing file are never collectable.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = zeta.can_collect_file(project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    // Buffers without a file are reported under the placeholder path "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // The closure owns its own clone of `events` so that `events` stays available to the
    // spawned task below.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        cx,
    );

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // `prompt_for_events_impl` walks the events newest-first, so the included ones are the
        // last `included_events_count` entries.
        let included_events = &events[events.len() - included_events_count..events.len()];
        // Data collection requires both the file and every included event to be collectable;
        // if `this` can no longer be read, collection is treated as disallowed.
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, _| this.can_collect_events(included_events))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        let http_client = client.http_client();

        let response = Zeta::send_api_request::<PredictEditsResponse>(
            |request| {
                // The `ZED_PREDICT_EDITS_URL` environment variable overrides the default endpoint.
                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    predict_edits_url
                } else {
                    http_client
                        .build_zed_llm_url("/predict_edits/v2", &[])?
                        .as_str()
                        .into()
                };
                Ok(request
                    .uri(uri)
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
        )
        .await;

        // Record what was sent so it can be stored alongside the resulting prediction.
        let inputs = EditPredictionInputs {
            events: included_events.into(),
            included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
                path: full_path.clone(),
                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
                    text: snapshot
                        .text_for_range(context_range)
                        .collect::<String>()
                        .into(),
                }],
            }],
            cursor_point: cloud_llm_client::predict_edits_v3::Point {
                column: cursor_point.column,
                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
            },
            cursor_path: full_path,
        };

        // NOTE(review): the commented-out call below appears to be a leftover from an earlier
        // request path; `Zeta::send_api_request` above is what actually runs. Consider deleting.
        // let response = perform_predict_edits(PerformPredictEditsParams {
        //     client,
        //     llm_token,
        //     app_version,
        //     body,
        // })
        // .await;

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // An outdated client gets a persistent flag plus a visible prompt; every error,
                // including this one, is still propagated to the caller.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |zeta, _cx| {
                            zeta.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        // Forward usage information, when the response carried any, to the user store.
        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // record latency for ~1% of requests
        // (a random `u8` is <= 2 for 3 of its 256 possible values)
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction
    })
}
222
223fn process_completion_response(
224    prediction_response: PredictEditsResponse,
225    buffer: Entity<Buffer>,
226    snapshot: &BufferSnapshot,
227    editable_range: Range<usize>,
228    inputs: EditPredictionInputs,
229    buffer_snapshotted_at: Instant,
230    received_response_at: Instant,
231    cx: &AsyncApp,
232) -> Task<Result<Option<EditPrediction>>> {
233    let snapshot = snapshot.clone();
234    let request_id = prediction_response.request_id;
235    let output_excerpt = prediction_response.output_excerpt;
236    cx.spawn(async move |cx| {
237        let output_excerpt: Arc<str> = output_excerpt.into();
238
239        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
240            .background_spawn({
241                let output_excerpt = output_excerpt.clone();
242                let editable_range = editable_range.clone();
243                let snapshot = snapshot.clone();
244                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
245            })
246            .await?
247            .into();
248
249        Ok(EditPrediction::new(
250            EditPredictionId(request_id.into()),
251            &buffer,
252            &snapshot,
253            edits,
254            buffer_snapshotted_at,
255            received_response_at,
256            inputs,
257            cx,
258        )
259        .await)
260    })
261}
262
263fn parse_edits(
264    output_excerpt: Arc<str>,
265    editable_range: Range<usize>,
266    snapshot: &BufferSnapshot,
267) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
268    let content = output_excerpt.replace(CURSOR_MARKER, "");
269
270    let start_markers = content
271        .match_indices(EDITABLE_REGION_START_MARKER)
272        .collect::<Vec<_>>();
273    anyhow::ensure!(
274        start_markers.len() == 1,
275        "expected exactly one start marker, found {}",
276        start_markers.len()
277    );
278
279    let end_markers = content
280        .match_indices(EDITABLE_REGION_END_MARKER)
281        .collect::<Vec<_>>();
282    anyhow::ensure!(
283        end_markers.len() == 1,
284        "expected exactly one end marker, found {}",
285        end_markers.len()
286    );
287
288    let sof_markers = content
289        .match_indices(START_OF_FILE_MARKER)
290        .collect::<Vec<_>>();
291    anyhow::ensure!(
292        sof_markers.len() <= 1,
293        "expected at most one start-of-file marker, found {}",
294        sof_markers.len()
295    );
296
297    let codefence_start = start_markers[0].0;
298    let content = &content[codefence_start..];
299
300    let newline_ix = content.find('\n').context("could not find newline")?;
301    let content = &content[newline_ix + 1..];
302
303    let codefence_end = content
304        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
305        .context("could not find end marker")?;
306    let new_text = &content[..codefence_end];
307
308    let old_text = snapshot
309        .text_for_range(editable_range.clone())
310        .collect::<String>();
311
312    Ok(compute_edits(
313        old_text,
314        new_text,
315        editable_range.start,
316        snapshot,
317    ))
318}
319
320pub fn compute_edits(
321    old_text: String,
322    new_text: &str,
323    offset: usize,
324    snapshot: &BufferSnapshot,
325) -> Vec<(Range<Anchor>, Arc<str>)> {
326    text_diff(&old_text, new_text)
327        .into_iter()
328        .map(|(mut old_range, new_text)| {
329            old_range.start += offset;
330            old_range.end += offset;
331
332            let prefix_len = common_prefix(
333                snapshot.chars_for_range(old_range.clone()),
334                new_text.chars(),
335            );
336            old_range.start += prefix_len;
337
338            let suffix_len = common_prefix(
339                snapshot.reversed_chars_for_range(old_range.clone()),
340                new_text[prefix_len..].chars().rev(),
341            );
342            old_range.end = old_range.end.saturating_sub(suffix_len);
343
344            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
345            let range = if old_range.is_empty() {
346                let anchor = snapshot.anchor_after(old_range.start);
347                anchor..anchor
348            } else {
349                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
350            };
351            (range, new_text)
352        })
353        .collect()
354}
355
/// Returns the length, in UTF-8 bytes, of the longest run of equal characters at the front of
/// both iterators.
///
/// Comparison stops at the first mismatch or as soon as either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
362
363fn git_info_for_file(
364    project: &Entity<Project>,
365    project_path: &ProjectPath,
366    cx: &App,
367) -> Option<PredictEditsGitInfo> {
368    let git_store = project.read(cx).git_store().read(cx);
369    if let Some((repository, _repo_path)) =
370        git_store.repository_and_path_for_project_path(project_path, cx)
371    {
372        let repository = repository.read(cx);
373        let head_sha = repository
374            .head_commit
375            .as_ref()
376            .map(|head_commit| head_commit.sha.to_string());
377        let remote_origin_url = repository.remote_origin_url.clone();
378        let remote_upstream_url = repository.remote_upstream_url.clone();
379        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
380            return None;
381        }
382        Some(PredictEditsGitInfo {
383            head_sha,
384            remote_origin_url,
385            remote_upstream_url,
386        })
387    } else {
388        None
389    }
390}
391
/// Result of [`gather_context`]: the request body plus the ranges and counts needed to
/// interpret the response.
pub struct GatherContextOutput {
    /// Request body holding the formatted events and the input excerpt prompt.
    pub body: PredictEditsBody,
    /// Context range of the excerpt, as reported by `excerpt_for_cursor_position`.
    pub context_range: Range<Point>,
    /// Editable range of the excerpt, converted to buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events that were included in `body.input_events`.
    pub included_events_count: usize,
}
398
399pub fn gather_context(
400    full_path_str: String,
401    snapshot: &BufferSnapshot,
402    cursor_point: language::Point,
403    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
404    cx: &App,
405) -> Task<Result<GatherContextOutput>> {
406    cx.background_spawn({
407        let snapshot = snapshot.clone();
408        async move {
409            let input_excerpt = excerpt_for_cursor_position(
410                cursor_point,
411                &full_path_str,
412                &snapshot,
413                MAX_REWRITE_TOKENS,
414                MAX_CONTEXT_TOKENS,
415            );
416            let (input_events, included_events_count) = prompt_for_events();
417            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
418
419            let body = PredictEditsBody {
420                input_events,
421                input_excerpt: input_excerpt.prompt,
422                can_collect_data: false,
423                diagnostic_groups: None,
424                git_info: None,
425                outline: None,
426                speculated_output: None,
427            };
428
429            Ok(GatherContextOutput {
430                body,
431                context_range: input_excerpt.context_range,
432                editable_range,
433                included_events_count,
434            })
435        }
436    })
437}
438
439fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
440    let mut result = String::new();
441    for (ix, event) in events.iter().rev().enumerate() {
442        let event_string = format_event(event.as_ref());
443        let event_tokens = guess_token_count(event_string.len());
444        if event_tokens > remaining_tokens {
445            return (result, ix);
446        }
447
448        if !result.is_empty() {
449            result.insert_str(0, "\n\n");
450        }
451        result.insert_str(0, &event_string);
452        remaining_tokens -= event_tokens;
453    }
454    return (result, events.len());
455}
456
457pub fn format_event(event: &Event) -> String {
458    match event {
459        Event::BufferChange {
460            path,
461            old_path,
462            diff,
463            ..
464        } => {
465            let mut prompt = String::new();
466
467            if old_path != path {
468                writeln!(
469                    prompt,
470                    "User renamed {} to {}\n",
471                    old_path.display(),
472                    path.display()
473                )
474                .unwrap();
475            }
476
477            if !diff.is_empty() {
478                write!(
479                    prompt,
480                    "User edited {}:\n```diff\n{}\n```",
481                    path.display(),
482                    diff
483                )
484                .unwrap();
485            }
486
487            prompt
488        }
489    }
490}
491
/// Assumed number of string bytes per token when budgeting model input.
///
/// The value is deliberately small: fewer bytes per token means a higher token estimate for the
/// same text, so budgets are exhausted early rather than exceeded.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    let estimated_tokens = bytes / BYTES_PER_TOKEN_GUESS;
    estimated_tokens
}