zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
  5    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  6    prediction::{EditPredictionInputs, EditPredictionResult},
  7};
  8use anyhow::{Context as _, Result};
  9use cloud_llm_client::{
 10    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 11    predict_edits_v3::Event,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20
 21const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 22const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
 23const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
 24const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
 25
 26pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 27pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 28pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 29
 30pub(crate) fn request_prediction_with_zeta1(
 31    store: &mut EditPredictionStore,
 32    project: &Entity<Project>,
 33    buffer: &Entity<Buffer>,
 34    snapshot: BufferSnapshot,
 35    position: language::Anchor,
 36    events: Vec<Arc<Event>>,
 37    trigger: PredictEditsRequestTrigger,
 38    cx: &mut Context<EditPredictionStore>,
 39) -> Task<Result<Option<EditPredictionResult>>> {
 40    let buffer = buffer.clone();
 41    let buffer_snapshotted_at = Instant::now();
 42    let client = store.client.clone();
 43    let llm_token = store.llm_token.clone();
 44    let app_version = AppVersion::global(cx);
 45
 46    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 47        let can_collect_file = store.can_collect_file(project, file, cx);
 48        let git_info = if can_collect_file {
 49            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 50        } else {
 51            None
 52        };
 53        (git_info, can_collect_file)
 54    } else {
 55        (None, false)
 56    };
 57
 58    let full_path: Arc<Path> = snapshot
 59        .file()
 60        .map(|f| Arc::from(f.full_path(cx).as_path()))
 61        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 62    let full_path_str = full_path.to_string_lossy().into_owned();
 63    let cursor_point = position.to_point(&snapshot);
 64    let prompt_for_events = {
 65        let events = events.clone();
 66        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 67    };
 68    let gather_task = gather_context(
 69        full_path_str,
 70        &snapshot,
 71        cursor_point,
 72        prompt_for_events,
 73        trigger,
 74        cx,
 75    );
 76
 77    cx.spawn(async move |this, cx| {
 78        let GatherContextOutput {
 79            mut body,
 80            context_range,
 81            editable_range,
 82            included_events_count,
 83        } = gather_task.await?;
 84        let done_gathering_context_at = Instant::now();
 85
 86        let included_events = &events[events.len() - included_events_count..events.len()];
 87        body.can_collect_data = can_collect_file
 88            && this
 89                .read_with(cx, |this, _| this.can_collect_events(included_events))
 90                .unwrap_or(false);
 91        if body.can_collect_data {
 92            body.git_info = git_info;
 93        }
 94
 95        log::debug!(
 96            "Events:\n{}\nExcerpt:\n{:?}",
 97            body.input_events,
 98            body.input_excerpt
 99        );
100
101        let http_client = client.http_client();
102
103        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
104            |request| {
105                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
106                    predict_edits_url
107                } else {
108                    http_client
109                        .build_zed_llm_url("/predict_edits/v2", &[])?
110                        .as_str()
111                        .into()
112                };
113                Ok(request
114                    .uri(uri)
115                    .body(serde_json::to_string(&body)?.into())?)
116            },
117            client,
118            llm_token,
119            app_version,
120        )
121        .await;
122
123        let inputs = EditPredictionInputs {
124            events: included_events.into(),
125            included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
126                path: full_path.clone(),
127                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
128                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
129                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
130                    text: snapshot
131                        .text_for_range(context_range)
132                        .collect::<String>()
133                        .into(),
134                }],
135            }],
136            cursor_point: cloud_llm_client::predict_edits_v3::Point {
137                column: cursor_point.column,
138                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
139            },
140            cursor_path: full_path,
141        };
142
143        // let response = perform_predict_edits(PerformPredictEditsParams {
144        //     client,
145        //     llm_token,
146        //     app_version,
147        //     body,
148        // })
149        // .await;
150
151        let (response, usage) = match response {
152            Ok(response) => response,
153            Err(err) => {
154                if err.is::<ZedUpdateRequiredError>() {
155                    cx.update(|cx| {
156                        this.update(cx, |ep_store, _cx| {
157                            ep_store.update_required = true;
158                        })
159                        .ok();
160
161                        let error_message: SharedString = err.to_string().into();
162                        show_app_notification(
163                            NotificationId::unique::<ZedUpdateRequiredError>(),
164                            cx,
165                            move |cx| {
166                                cx.new(|cx| {
167                                    ErrorMessagePrompt::new(error_message.clone(), cx)
168                                        .with_link_button("Update Zed", "https://zed.dev/releases")
169                                })
170                            },
171                        );
172                    })
173                    .ok();
174                }
175
176                return Err(err);
177            }
178        };
179
180        let received_response_at = Instant::now();
181        log::debug!("completion response: {}", &response.output_excerpt);
182
183        if let Some(usage) = usage {
184            this.update(cx, |this, cx| {
185                this.user_store.update(cx, |user_store, cx| {
186                    user_store.update_edit_prediction_usage(usage, cx);
187                });
188            })
189            .ok();
190        }
191
192        let edit_prediction = process_completion_response(
193            response,
194            buffer,
195            &snapshot,
196            editable_range,
197            inputs,
198            buffer_snapshotted_at,
199            received_response_at,
200            cx,
201        )
202        .await;
203
204        let finished_at = Instant::now();
205
206        // record latency for ~1% of requests
207        if rand::random::<u8>() <= 2 {
208            telemetry::event!(
209                "Edit Prediction Request",
210                context_latency = done_gathering_context_at
211                    .duration_since(buffer_snapshotted_at)
212                    .as_millis(),
213                request_latency = received_response_at
214                    .duration_since(done_gathering_context_at)
215                    .as_millis(),
216                process_latency = finished_at.duration_since(received_response_at).as_millis()
217            );
218        }
219
220        edit_prediction.map(Some)
221    })
222}
223
224fn process_completion_response(
225    prediction_response: PredictEditsResponse,
226    buffer: Entity<Buffer>,
227    snapshot: &BufferSnapshot,
228    editable_range: Range<usize>,
229    inputs: EditPredictionInputs,
230    buffer_snapshotted_at: Instant,
231    received_response_at: Instant,
232    cx: &AsyncApp,
233) -> Task<Result<EditPredictionResult>> {
234    let snapshot = snapshot.clone();
235    let request_id = prediction_response.request_id;
236    let output_excerpt = prediction_response.output_excerpt;
237    cx.spawn(async move |cx| {
238        let output_excerpt: Arc<str> = output_excerpt.into();
239
240        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
241            .background_spawn({
242                let output_excerpt = output_excerpt.clone();
243                let editable_range = editable_range.clone();
244                let snapshot = snapshot.clone();
245                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
246            })
247            .await?
248            .into();
249
250        let id = EditPredictionId(request_id.into());
251        Ok(EditPredictionResult::new(
252            id,
253            &buffer,
254            &snapshot,
255            edits,
256            buffer_snapshotted_at,
257            received_response_at,
258            inputs,
259            cx,
260        )
261        .await)
262    })
263}
264
265fn parse_edits(
266    output_excerpt: Arc<str>,
267    editable_range: Range<usize>,
268    snapshot: &BufferSnapshot,
269) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
270    let content = output_excerpt.replace(CURSOR_MARKER, "");
271
272    let start_markers = content
273        .match_indices(EDITABLE_REGION_START_MARKER)
274        .collect::<Vec<_>>();
275    anyhow::ensure!(
276        start_markers.len() == 1,
277        "expected exactly one start marker, found {}",
278        start_markers.len()
279    );
280
281    let end_markers = content
282        .match_indices(EDITABLE_REGION_END_MARKER)
283        .collect::<Vec<_>>();
284    anyhow::ensure!(
285        end_markers.len() == 1,
286        "expected exactly one end marker, found {}",
287        end_markers.len()
288    );
289
290    let sof_markers = content
291        .match_indices(START_OF_FILE_MARKER)
292        .collect::<Vec<_>>();
293    anyhow::ensure!(
294        sof_markers.len() <= 1,
295        "expected at most one start-of-file marker, found {}",
296        sof_markers.len()
297    );
298
299    let codefence_start = start_markers[0].0;
300    let content = &content[codefence_start..];
301
302    let newline_ix = content.find('\n').context("could not find newline")?;
303    let content = &content[newline_ix + 1..];
304
305    let codefence_end = content
306        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
307        .context("could not find end marker")?;
308    let new_text = &content[..codefence_end];
309
310    let old_text = snapshot
311        .text_for_range(editable_range.clone())
312        .collect::<String>();
313
314    Ok(compute_edits(
315        old_text,
316        new_text,
317        editable_range.start,
318        snapshot,
319    ))
320}
321
322pub fn compute_edits(
323    old_text: String,
324    new_text: &str,
325    offset: usize,
326    snapshot: &BufferSnapshot,
327) -> Vec<(Range<Anchor>, Arc<str>)> {
328    text_diff(&old_text, new_text)
329        .into_iter()
330        .map(|(mut old_range, new_text)| {
331            old_range.start += offset;
332            old_range.end += offset;
333
334            let prefix_len = common_prefix(
335                snapshot.chars_for_range(old_range.clone()),
336                new_text.chars(),
337            );
338            old_range.start += prefix_len;
339
340            let suffix_len = common_prefix(
341                snapshot.reversed_chars_for_range(old_range.clone()),
342                new_text[prefix_len..].chars().rev(),
343            );
344            old_range.end = old_range.end.saturating_sub(suffix_len);
345
346            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
347            let range = if old_range.is_empty() {
348                let anchor = snapshot.anchor_after(old_range.start);
349                anchor..anchor
350            } else {
351                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
352            };
353            (range, new_text)
354        })
355        .collect()
356}
357
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two character streams.
///
/// Characters are compared pairwise until the first mismatch or until either stream ends.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
364
365fn git_info_for_file(
366    project: &Entity<Project>,
367    project_path: &ProjectPath,
368    cx: &App,
369) -> Option<PredictEditsGitInfo> {
370    let git_store = project.read(cx).git_store().read(cx);
371    if let Some((repository, _repo_path)) =
372        git_store.repository_and_path_for_project_path(project_path, cx)
373    {
374        let repository = repository.read(cx);
375        let head_sha = repository
376            .head_commit
377            .as_ref()
378            .map(|head_commit| head_commit.sha.to_string());
379        let remote_origin_url = repository.remote_origin_url.clone();
380        let remote_upstream_url = repository.remote_upstream_url.clone();
381        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
382            return None;
383        }
384        Some(PredictEditsGitInfo {
385            head_sha,
386            remote_origin_url,
387            remote_upstream_url,
388        })
389    } else {
390        None
391    }
392}
393
/// Output of [`gather_context`]: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` as `false` and `git_info` as
    /// `None`; the caller decides whether to fill them in.
    pub body: PredictEditsBody,
    /// Range of buffer text (context plus editable region) included in the prompt excerpt.
    pub context_range: Range<Point>,
    /// Buffer offsets of the editable region the model is allowed to rewrite.
    pub editable_range: Range<usize>,
    /// How many events, counted from the end of the event list, were rendered into the prompt.
    pub included_events_count: usize,
}
400
401pub fn gather_context(
402    full_path_str: String,
403    snapshot: &BufferSnapshot,
404    cursor_point: language::Point,
405    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
406    trigger: PredictEditsRequestTrigger,
407    cx: &App,
408) -> Task<Result<GatherContextOutput>> {
409    cx.background_spawn({
410        let snapshot = snapshot.clone();
411        async move {
412            let input_excerpt = excerpt_for_cursor_position(
413                cursor_point,
414                &full_path_str,
415                &snapshot,
416                MAX_REWRITE_TOKENS,
417                MAX_CONTEXT_TOKENS,
418            );
419            let (input_events, included_events_count) = prompt_for_events();
420            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
421
422            let body = PredictEditsBody {
423                input_events,
424                input_excerpt: input_excerpt.prompt,
425                can_collect_data: false,
426                diagnostic_groups: None,
427                git_info: None,
428                outline: None,
429                speculated_output: None,
430                trigger,
431            };
432
433            Ok(GatherContextOutput {
434                body,
435                context_range: input_excerpt.context_range,
436                editable_range,
437                included_events_count,
438            })
439        }
440    })
441}
442
443fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
444    let mut result = String::new();
445    for (ix, event) in events.iter().rev().enumerate() {
446        let event_string = format_event(event.as_ref());
447        let event_tokens = guess_token_count(event_string.len());
448        if event_tokens > remaining_tokens {
449            return (result, ix);
450        }
451
452        if !result.is_empty() {
453            result.insert_str(0, "\n\n");
454        }
455        result.insert_str(0, &event_string);
456        remaining_tokens -= event_tokens;
457    }
458    return (result, events.len());
459}
460
461pub fn format_event(event: &Event) -> String {
462    match event {
463        Event::BufferChange {
464            path,
465            old_path,
466            diff,
467            ..
468        } => {
469            let mut prompt = String::new();
470
471            if old_path != path {
472                writeln!(
473                    prompt,
474                    "User renamed {} to {}\n",
475                    old_path.display(),
476                    path.display()
477                )
478                .unwrap();
479            }
480
481            if !diff.is_empty() {
482                write!(
483                    prompt,
484                    "User edited {}:\n```diff\n{}\n```",
485                    path.display(),
486                    diff
487                )
488                .unwrap();
489            }
490
491            prompt
492        }
493    }
494}
495
/// Prompt excerpt built around the cursor, together with the buffer ranges it covers.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Range of buffer text rendered into the prompt: the context before the editable
    /// region, the editable region itself, and the context after it.
    pub context_range: Range<Point>,
    /// Sub-range the model may rewrite; wrapped in the editable-region markers in `prompt`.
    pub editable_range: Range<Point>,
    /// Rendered prompt: a code fence labelled with the file path that contains the excerpt
    /// and all markers.
    pub prompt: String,
}
502
503pub fn excerpt_for_cursor_position(
504    position: Point,
505    path: &str,
506    snapshot: &BufferSnapshot,
507    editable_region_token_limit: usize,
508    context_token_limit: usize,
509) -> InputExcerpt {
510    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
511        position,
512        snapshot,
513        editable_region_token_limit,
514        context_token_limit,
515    );
516
517    let mut prompt = String::new();
518
519    writeln!(&mut prompt, "```{path}").unwrap();
520    if context_range.start == Point::zero() {
521        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
522    }
523
524    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
525        prompt.push_str(chunk.text);
526    }
527
528    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
529
530    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
531        prompt.push_str(chunk.text);
532    }
533    write!(prompt, "\n```").unwrap();
534
535    InputExcerpt {
536        context_range,
537        editable_range,
538        prompt,
539    }
540}
541
542fn push_editable_range(
543    cursor_position: Point,
544    snapshot: &BufferSnapshot,
545    editable_range: Range<Point>,
546    prompt: &mut String,
547) {
548    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
549    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
550        prompt.push_str(chunk.text);
551    }
552    prompt.push_str(CURSOR_MARKER);
553    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
554        prompt.push_str(chunk.text);
555    }
556    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the prompt built by `excerpt_for_cursor_position` for a cursor inside `bar`.
    /// The cursor is at row 12, column 5, just after the `r` of `return sum;`. Two
    /// editable-region token limits are tried, with the context limit fixed at 32.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}