zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::Result;
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use edit_prediction_types::PredictedCursorPosition;
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use text::Bias;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::{
 23    Event, ZetaPromptInput,
 24    zeta1::{
 25        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 26        START_OF_FILE_MARKER,
 27    },
 28};
 29
/// Token budget for the read-only context surrounding the editable region
/// (passed as `context_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget for the editable region the model may rewrite
/// (passed as `editable_region_token_limit` to `excerpt_for_cursor_position`).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted edit-history events included in the request.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 33
/// Requests an edit prediction from the zeta1 `/predict_edits/v2` endpoint.
///
/// Builds the prompt excerpt around `position` plus the most recent edit
/// `events` on a background task (`gather_context`), sends it to the LLM
/// service, and converts the returned excerpt into buffer edits via
/// `process_completion_response`. On success the task always yields `Some`.
///
/// Data collection (`body.can_collect_data`) and git metadata are only attached
/// when both the file and the included events are allowed to be collected.
///
/// # Errors
///
/// The returned task fails if the endpoint URL cannot be built, if context
/// gathering fails, if the API request fails (for `ZedUpdateRequiredError` it also
/// sets `update_required` on the store and shows an "Update Zed" notification),
/// or if the response cannot be parsed into edits.
pub(crate) fn request_prediction_with_zeta1(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        project,
        buffer,
        snapshot,
        position,
        events,
        trigger,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git metadata is only looked up for files whose data may be collected;
    // buffers without a file are never collectable.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = store.can_collect_file(&project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    // Unsaved buffers are presented to the model as "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // Deferred so event formatting happens on the background task.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        trigger,
        cx,
    );

    let uri = match client
        .http_client()
        .build_zed_llm_url("/predict_edits/v2", &[])
    {
        Ok(url) => Arc::from(url),
        Err(err) => return Task::ready(Err(err)),
    };

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // `prompt_for_events_impl` keeps the newest events, so the ones that
        // made it into the prompt are the tail of `events`.
        let included_events = &events[events.len() - included_events_count..events.len()];
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
            |request| {
                Ok(request
                    .uri(uri.as_str())
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
            true,
        )
        .await;

        let context_start_offset = context_range.start.to_offset(&snapshot);
        let context_start_row = context_range.start.row;
        // `editable_range` is already a byte-offset range; both it and
        // `editable_offset_range` are used below as buffer offsets.
        let editable_offset_range = editable_range.to_offset(&snapshot);

        // Offsets below are rebased so they are relative to the start of the excerpt.
        let inputs = ZetaPromptInput {
            events: included_events.into(),
            related_files: vec![],
            cursor_path: full_path,
            cursor_excerpt: snapshot
                .text_for_range(context_range)
                .collect::<String>()
                .into(),
            editable_range_in_excerpt: (editable_range.start - context_start_offset)
                ..(editable_offset_range.end - context_start_offset),
            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
            excerpt_start_row: Some(context_start_row),
        };

        // NOTE(review): the "started" debug event is only emitted after the API
        // response (or error) has already arrived; confirm this ordering is intended.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionStarted(
                    EditPredictionStartedDebugEvent {
                        buffer: buffer.downgrade(),
                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
                        position,
                    },
                ))
                .ok();
        }

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // An outdated client is surfaced to the user as an app notification
                // and remembered on the store before the error is propagated.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |ep_store, _cx| {
                            ep_store.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    });
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionFinished(
                    EditPredictionFinishedDebugEvent {
                        buffer: buffer.downgrade(),
                        model_output: Some(response.output_excerpt.clone()),
                        position,
                    },
                ))
                .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // record latency for ~1% of requests (values 0..=2 of a random u8, i.e. 3/256 ≈ 1.2%)
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction.map(Some)
    })
}
242
243fn process_completion_response(
244    prediction_response: PredictEditsResponse,
245    buffer: Entity<Buffer>,
246    snapshot: &BufferSnapshot,
247    editable_range: Range<usize>,
248    inputs: ZetaPromptInput,
249    buffer_snapshotted_at: Instant,
250    received_response_at: Instant,
251    cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253    let snapshot = snapshot.clone();
254    let request_id = prediction_response.request_id;
255    let output_excerpt = prediction_response.output_excerpt;
256    cx.spawn(async move |cx| {
257        let output_excerpt: Arc<str> = output_excerpt.into();
258
259        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260            .background_spawn({
261                let output_excerpt = output_excerpt.clone();
262                let editable_range = editable_range.clone();
263                let snapshot = snapshot.clone();
264                async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
265            })
266            .await?
267            .into();
268
269        let id = EditPredictionId(request_id.into());
270        Ok(EditPredictionResult::new(
271            id,
272            &buffer,
273            &snapshot,
274            edits,
275            None,
276            buffer_snapshotted_at,
277            received_response_at,
278            inputs,
279            cx,
280        )
281        .await)
282    })
283}
284
285pub(crate) fn parse_edits(
286    output_excerpt: &str,
287    editable_range: Range<usize>,
288    snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290    let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292    let start_markers = content
293        .match_indices(EDITABLE_REGION_START_MARKER)
294        .collect::<Vec<_>>();
295    anyhow::ensure!(
296        start_markers.len() <= 1,
297        "expected at most one start marker, found {}",
298        start_markers.len()
299    );
300
301    let end_markers = content
302        .match_indices(EDITABLE_REGION_END_MARKER)
303        .collect::<Vec<_>>();
304    anyhow::ensure!(
305        end_markers.len() <= 1,
306        "expected at most one end marker, found {}",
307        end_markers.len()
308    );
309
310    let sof_markers = content
311        .match_indices(START_OF_FILE_MARKER)
312        .collect::<Vec<_>>();
313    anyhow::ensure!(
314        sof_markers.len() <= 1,
315        "expected at most one start-of-file marker, found {}",
316        sof_markers.len()
317    );
318
319    let content_start = start_markers
320        .first()
321        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322        .unwrap_or(0);
323    let content_end = end_markers
324        .first()
325        .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328    // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
329    let new_text = &content[content_start.min(content_end)..content_end];
330
331    let old_text = snapshot
332        .text_for_range(editable_range.clone())
333        .collect::<String>();
334
335    Ok(compute_edits(
336        old_text,
337        new_text,
338        editable_range.start,
339        snapshot,
340    ))
341}
342
343pub fn compute_edits(
344    old_text: String,
345    new_text: &str,
346    offset: usize,
347    snapshot: &BufferSnapshot,
348) -> Vec<(Range<Anchor>, Arc<str>)> {
349    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
350}
351
/// Diffs `old_text` against `new_text` and converts the result into edits
/// anchored in `snapshot`, optionally mapping a cursor offset in `new_text`
/// back to a predicted cursor position.
///
/// * `old_text` — the current buffer text, which begins at buffer offset `offset`.
/// * `new_text` — the replacement text proposed for that span.
/// * `cursor_offset_in_new_text` — byte offset of the predicted cursor within
///   `new_text`, if any.
///
/// Each diff hunk is trimmed of its common prefix and suffix before being
/// anchored. Insertions use a single `anchor_after` point; replacements use
/// `anchor_after(start)..anchor_before(end)`.
///
/// The cursor is resolved against the first hunk it precedes or falls inside.
/// If it lies inside inserted text, it is expressed as an anchor at the hunk's
/// untrimmed start plus an offset into the insertion. If it lies after every
/// hunk, it is mapped back to an old-text offset and clipped to the buffer.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // `delta` is only advanced until the cursor is resolved; it is not
            // needed afterwards.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor is in unchanged text before this hunk.
                    // NOTE(review): unlike the fallback after the loop, this offset is not
                    // passed through `clip_offset`.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(offset + cursor_in_old),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor is inside this hunk's inserted text: anchor at the hunk's
                    // untrimmed start (biased left) and record how far into the insertion it is.
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(offset + raw_old_range.start),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Both lengths are in bytes and cover whole chars, so the slices below
            // stay on char boundaries; the suffix is measured after removing the
            // prefix so the two never overlap.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor lies after every hunk: map it back through the total delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
429
/// Returns the UTF-8 byte length of the longest run of equal chars at the
/// front of `a` and `b`.
///
/// Feed reversed iterators (`.chars().rev()`) to measure a common suffix.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
436
437fn git_info_for_file(
438    project: &Entity<Project>,
439    project_path: &ProjectPath,
440    cx: &App,
441) -> Option<PredictEditsGitInfo> {
442    let git_store = project.read(cx).git_store().read(cx);
443    if let Some((repository, _repo_path)) =
444        git_store.repository_and_path_for_project_path(project_path, cx)
445    {
446        let repository = repository.read(cx);
447        let head_sha = repository
448            .head_commit
449            .as_ref()
450            .map(|head_commit| head_commit.sha.to_string());
451        let remote_origin_url = repository.remote_origin_url.clone();
452        let remote_upstream_url = repository.remote_upstream_url.clone();
453        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
454            return None;
455        }
456        Some(PredictEditsGitInfo {
457            head_sha,
458            remote_origin_url,
459            remote_upstream_url,
460        })
461    } else {
462        None
463    }
464}
465
/// Result of `gather_context`: the request body plus the ranges needed to
/// interpret the model's response.
pub struct GatherContextOutput {
    /// Request body; `can_collect_data` and `git_info` start out unset and are
    /// filled in by the caller.
    pub body: PredictEditsBody,
    /// Full excerpt (context plus editable region) sent to the model, in points.
    pub context_range: Range<Point>,
    /// Editable region as buffer byte offsets.
    pub editable_range: Range<usize>,
    /// How many of the most recent events were included in `body.input_events`.
    pub included_events_count: usize,
}
472
473pub fn gather_context(
474    full_path_str: String,
475    snapshot: &BufferSnapshot,
476    cursor_point: language::Point,
477    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
478    trigger: PredictEditsRequestTrigger,
479    cx: &App,
480) -> Task<Result<GatherContextOutput>> {
481    cx.background_spawn({
482        let snapshot = snapshot.clone();
483        async move {
484            let input_excerpt = excerpt_for_cursor_position(
485                cursor_point,
486                &full_path_str,
487                &snapshot,
488                MAX_REWRITE_TOKENS,
489                MAX_CONTEXT_TOKENS,
490            );
491            let (input_events, included_events_count) = prompt_for_events();
492            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
493
494            let body = PredictEditsBody {
495                input_events,
496                input_excerpt: input_excerpt.prompt,
497                can_collect_data: false,
498                diagnostic_groups: None,
499                git_info: None,
500                outline: None,
501                speculated_output: None,
502                trigger,
503            };
504
505            Ok(GatherContextOutput {
506                body,
507                context_range: input_excerpt.context_range,
508                editable_range,
509                included_events_count,
510            })
511        }
512    })
513}
514
515pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
516    prompt_for_events_impl(events, max_tokens).0
517}
518
519fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
520    let mut result = String::new();
521    for (ix, event) in events.iter().rev().enumerate() {
522        let event_string = format_event(event.as_ref());
523        let event_tokens = guess_token_count(event_string.len());
524        if event_tokens > remaining_tokens {
525            return (result, ix);
526        }
527
528        if !result.is_empty() {
529            result.insert_str(0, "\n\n");
530        }
531        result.insert_str(0, &event_string);
532        remaining_tokens -= event_tokens;
533    }
534    return (result, events.len());
535}
536
537pub fn format_event(event: &Event) -> String {
538    match event {
539        Event::BufferChange {
540            path,
541            old_path,
542            diff,
543            ..
544        } => {
545            let mut prompt = String::new();
546
547            if old_path != path {
548                writeln!(
549                    prompt,
550                    "User renamed {} to {}\n",
551                    old_path.display(),
552                    path.display()
553                )
554                .unwrap();
555            }
556
557            if !diff.is_empty() {
558                write!(
559                    prompt,
560                    "User edited {}:\n```diff\n{}\n```",
561                    path.display(),
562                    diff
563                )
564                .unwrap();
565            }
566
567            prompt
568        }
569    }
570}
571
/// Prompt excerpt produced by `excerpt_for_cursor_position`.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full range of buffer text included in the prompt.
    pub context_range: Range<Point>,
    /// Sub-range the model is allowed to rewrite (wrapped in region markers).
    pub editable_range: Range<Point>,
    /// Rendered prompt: a fenced block labelled with the path, containing the
    /// context, the marked editable region, and the cursor marker.
    pub prompt: String,
}
578
579pub fn excerpt_for_cursor_position(
580    position: Point,
581    path: &str,
582    snapshot: &BufferSnapshot,
583    editable_region_token_limit: usize,
584    context_token_limit: usize,
585) -> InputExcerpt {
586    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
587        position,
588        snapshot,
589        editable_region_token_limit,
590        context_token_limit,
591    );
592
593    let mut prompt = String::new();
594
595    writeln!(&mut prompt, "```{path}").unwrap();
596    if context_range.start == Point::zero() {
597        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
598    }
599
600    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
601        prompt.push_str(chunk.text);
602    }
603
604    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
605
606    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
607        prompt.push_str(chunk.text);
608    }
609    write!(prompt, "\n```").unwrap();
610
611    InputExcerpt {
612        context_range,
613        editable_range,
614        prompt,
615    }
616}
617
618fn push_editable_range(
619    cursor_position: Point,
620    snapshot: &BufferSnapshot,
621    editable_range: Range<Point>,
622    prompt: &mut String,
623) {
624    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
625    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
626        prompt.push_str(chunk.text);
627    }
628    prompt.push_str(CURSOR_MARKER);
629    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
630        prompt.push_str(chunk.text);
631    }
632    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
633}
634
// Unit tests for prompt excerpt construction and model-output parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    // Checks the exact rendered prompt (context, region markers, cursor marker)
    // for two different editable-region token budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    // A start marker immediately followed (after one newline) by the end marker
    // means "delete the whole editable region": one edit replacing it with "".
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n    let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let output = "<|editable_region_start|>\n<|editable_region_end|>";
        let editable_range = 0..text.len();
        let edits = parse_edits(output, editable_range, &snapshot).unwrap();
        assert_eq!(edits.len(), 1);
        let (range, new_text) = &edits[0];
        assert_eq!(range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(new_text.as_ref(), "");
    }
}