zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::Result;
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use edit_prediction_types::PredictedCursorPosition;
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use text::Bias;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::{
 23    Event, ZetaPromptInput,
 24    zeta1::{
 25        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 26        START_OF_FILE_MARKER,
 27    },
 28};
 29
 30pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 31pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 32pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 33
 34pub(crate) fn request_prediction_with_zeta1(
 35    store: &mut EditPredictionStore,
 36    EditPredictionModelInput {
 37        project,
 38        buffer,
 39        snapshot,
 40        position,
 41        events,
 42        trigger,
 43        debug_tx,
 44        ..
 45    }: EditPredictionModelInput,
 46    cx: &mut Context<EditPredictionStore>,
 47) -> Task<Result<Option<EditPredictionResult>>> {
 48    let buffer_snapshotted_at = Instant::now();
 49    let client = store.client.clone();
 50    let llm_token = store.llm_token.clone();
 51    let app_version = AppVersion::global(cx);
 52
 53    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 54        let can_collect_file = store.can_collect_file(&project, file, cx);
 55        let git_info = if can_collect_file {
 56            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 57        } else {
 58            None
 59        };
 60        (git_info, can_collect_file)
 61    } else {
 62        (None, false)
 63    };
 64
 65    let full_path: Arc<Path> = snapshot
 66        .file()
 67        .map(|f| Arc::from(f.full_path(cx).as_path()))
 68        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 69    let full_path_str = full_path.to_string_lossy().into_owned();
 70    let cursor_point = position.to_point(&snapshot);
 71    let prompt_for_events = {
 72        let events = events.clone();
 73        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 74    };
 75    let gather_task = gather_context(
 76        full_path_str,
 77        &snapshot,
 78        cursor_point,
 79        prompt_for_events,
 80        trigger,
 81        cx,
 82    );
 83
 84    let (uri, require_auth) = match &store.custom_predict_edits_url {
 85        Some(custom_url) => (custom_url.clone(), false),
 86        None => {
 87            match client
 88                .http_client()
 89                .build_zed_llm_url("/predict_edits/v2", &[])
 90            {
 91                Ok(url) => (url.into(), true),
 92                Err(err) => return Task::ready(Err(err)),
 93            }
 94        }
 95    };
 96
 97    cx.spawn(async move |this, cx| {
 98        let GatherContextOutput {
 99            mut body,
100            context_range,
101            editable_range,
102            included_events_count,
103        } = gather_task.await?;
104        let done_gathering_context_at = Instant::now();
105
106        let included_events = &events[events.len() - included_events_count..events.len()];
107        body.can_collect_data = can_collect_file
108            && this
109                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
110                .unwrap_or(false);
111        if body.can_collect_data {
112            body.git_info = git_info;
113        }
114
115        log::debug!(
116            "Events:\n{}\nExcerpt:\n{:?}",
117            body.input_events,
118            body.input_excerpt
119        );
120
121        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
122            |request| {
123                Ok(request
124                    .uri(uri.as_str())
125                    .body(serde_json::to_string(&body)?.into())?)
126            },
127            client,
128            llm_token,
129            app_version,
130            require_auth,
131        )
132        .await;
133
134        let context_start_offset = context_range.start.to_offset(&snapshot);
135        let context_start_row = context_range.start.row;
136        let editable_offset_range = editable_range.to_offset(&snapshot);
137
138        let inputs = ZetaPromptInput {
139            events: included_events.into(),
140            related_files: vec![],
141            cursor_path: full_path,
142            cursor_excerpt: snapshot
143                .text_for_range(context_range)
144                .collect::<String>()
145                .into(),
146            editable_range_in_excerpt: (editable_range.start - context_start_offset)
147                ..(editable_offset_range.end - context_start_offset),
148            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
149            excerpt_start_row: Some(context_start_row),
150        };
151
152        if let Some(debug_tx) = &debug_tx {
153            debug_tx
154                .unbounded_send(DebugEvent::EditPredictionStarted(
155                    EditPredictionStartedDebugEvent {
156                        buffer: buffer.downgrade(),
157                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
158                        position,
159                    },
160                ))
161                .ok();
162        }
163
164        let (response, usage) = match response {
165            Ok(response) => response,
166            Err(err) => {
167                if err.is::<ZedUpdateRequiredError>() {
168                    cx.update(|cx| {
169                        this.update(cx, |ep_store, _cx| {
170                            ep_store.update_required = true;
171                        })
172                        .ok();
173
174                        let error_message: SharedString = err.to_string().into();
175                        show_app_notification(
176                            NotificationId::unique::<ZedUpdateRequiredError>(),
177                            cx,
178                            move |cx| {
179                                cx.new(|cx| {
180                                    ErrorMessagePrompt::new(error_message.clone(), cx)
181                                        .with_link_button("Update Zed", "https://zed.dev/releases")
182                                })
183                            },
184                        );
185                    });
186                }
187
188                return Err(err);
189            }
190        };
191
192        let received_response_at = Instant::now();
193        log::debug!("completion response: {}", &response.output_excerpt);
194
195        if let Some(usage) = usage {
196            this.update(cx, |this, cx| {
197                this.user_store.update(cx, |user_store, cx| {
198                    user_store.update_edit_prediction_usage(usage, cx);
199                });
200            })
201            .ok();
202        }
203
204        if let Some(debug_tx) = &debug_tx {
205            debug_tx
206                .unbounded_send(DebugEvent::EditPredictionFinished(
207                    EditPredictionFinishedDebugEvent {
208                        buffer: buffer.downgrade(),
209                        model_output: Some(response.output_excerpt.clone()),
210                        position,
211                    },
212                ))
213                .ok();
214        }
215
216        let edit_prediction = process_completion_response(
217            response,
218            buffer,
219            &snapshot,
220            editable_range,
221            inputs,
222            buffer_snapshotted_at,
223            received_response_at,
224            cx,
225        )
226        .await;
227
228        let finished_at = Instant::now();
229
230        // record latency for ~1% of requests
231        if rand::random::<u8>() <= 2 {
232            telemetry::event!(
233                "Edit Prediction Request",
234                context_latency = done_gathering_context_at
235                    .duration_since(buffer_snapshotted_at)
236                    .as_millis(),
237                request_latency = received_response_at
238                    .duration_since(done_gathering_context_at)
239                    .as_millis(),
240                process_latency = finished_at.duration_since(received_response_at).as_millis()
241            );
242        }
243
244        edit_prediction.map(Some)
245    })
246}
247
248fn process_completion_response(
249    prediction_response: PredictEditsResponse,
250    buffer: Entity<Buffer>,
251    snapshot: &BufferSnapshot,
252    editable_range: Range<usize>,
253    inputs: ZetaPromptInput,
254    buffer_snapshotted_at: Instant,
255    received_response_at: Instant,
256    cx: &AsyncApp,
257) -> Task<Result<EditPredictionResult>> {
258    let snapshot = snapshot.clone();
259    let request_id = prediction_response.request_id;
260    let output_excerpt = prediction_response.output_excerpt;
261    cx.spawn(async move |cx| {
262        let output_excerpt: Arc<str> = output_excerpt.into();
263
264        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
265            .background_spawn({
266                let output_excerpt = output_excerpt.clone();
267                let editable_range = editable_range.clone();
268                let snapshot = snapshot.clone();
269                async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
270            })
271            .await?
272            .into();
273
274        let id = EditPredictionId(request_id.into());
275        Ok(EditPredictionResult::new(
276            id,
277            &buffer,
278            &snapshot,
279            edits,
280            None,
281            buffer_snapshotted_at,
282            received_response_at,
283            inputs,
284            cx,
285        )
286        .await)
287    })
288}
289
290pub(crate) fn parse_edits(
291    output_excerpt: &str,
292    editable_range: Range<usize>,
293    snapshot: &BufferSnapshot,
294) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
295    let content = output_excerpt.replace(CURSOR_MARKER, "");
296
297    let start_markers = content
298        .match_indices(EDITABLE_REGION_START_MARKER)
299        .collect::<Vec<_>>();
300    anyhow::ensure!(
301        start_markers.len() <= 1,
302        "expected at most one start marker, found {}",
303        start_markers.len()
304    );
305
306    let end_markers = content
307        .match_indices(EDITABLE_REGION_END_MARKER)
308        .collect::<Vec<_>>();
309    anyhow::ensure!(
310        end_markers.len() <= 1,
311        "expected at most one end marker, found {}",
312        end_markers.len()
313    );
314
315    let sof_markers = content
316        .match_indices(START_OF_FILE_MARKER)
317        .collect::<Vec<_>>();
318    anyhow::ensure!(
319        sof_markers.len() <= 1,
320        "expected at most one start-of-file marker, found {}",
321        sof_markers.len()
322    );
323
324    let content_start = start_markers
325        .first()
326        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
327        .unwrap_or(0);
328    let content_end = end_markers
329        .first()
330        .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
331        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
332
333    // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
334    let new_text = &content[content_start.min(content_end)..content_end];
335
336    let old_text = snapshot
337        .text_for_range(editable_range.clone())
338        .collect::<String>();
339
340    Ok(compute_edits(
341        old_text,
342        new_text,
343        editable_range.start,
344        snapshot,
345    ))
346}
347
348pub fn compute_edits(
349    old_text: String,
350    new_text: &str,
351    offset: usize,
352    snapshot: &BufferSnapshot,
353) -> Vec<(Range<Anchor>, Arc<str>)> {
354    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
355}
356
357pub fn compute_edits_and_cursor_position(
358    old_text: String,
359    new_text: &str,
360    offset: usize,
361    cursor_offset_in_new_text: Option<usize>,
362    snapshot: &BufferSnapshot,
363) -> (
364    Vec<(Range<Anchor>, Arc<str>)>,
365    Option<PredictedCursorPosition>,
366) {
367    let diffs = text_diff(&old_text, new_text);
368
369    // Delta represents the cumulative change in byte count from all preceding edits.
370    // new_offset = old_offset + delta, so old_offset = new_offset - delta
371    let mut delta: isize = 0;
372    let mut cursor_position: Option<PredictedCursorPosition> = None;
373
374    let edits = diffs
375        .iter()
376        .map(|(raw_old_range, new_text)| {
377            // Compute cursor position if it falls within or before this edit.
378            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
379                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
380                let edit_end_in_new = edit_start_in_new + new_text.len();
381
382                if cursor_offset < edit_start_in_new {
383                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
384                    cursor_position = Some(PredictedCursorPosition::at_anchor(
385                        snapshot.anchor_after(offset + cursor_in_old),
386                    ));
387                } else if cursor_offset < edit_end_in_new {
388                    let offset_within_insertion = cursor_offset - edit_start_in_new;
389                    cursor_position = Some(PredictedCursorPosition::new(
390                        snapshot.anchor_before(offset + raw_old_range.start),
391                        offset_within_insertion,
392                    ));
393                }
394
395                delta += new_text.len() as isize - raw_old_range.len() as isize;
396            }
397
398            // Compute the edit with prefix/suffix trimming.
399            let mut old_range = raw_old_range.clone();
400            let old_slice = &old_text[old_range.clone()];
401
402            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
403            let suffix_len = common_prefix(
404                old_slice[prefix_len..].chars().rev(),
405                new_text[prefix_len..].chars().rev(),
406            );
407
408            old_range.start += offset;
409            old_range.end += offset;
410            old_range.start += prefix_len;
411            old_range.end -= suffix_len;
412
413            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
414            let range = if old_range.is_empty() {
415                let anchor = snapshot.anchor_after(old_range.start);
416                anchor..anchor
417            } else {
418                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
419            };
420            (range, new_text)
421        })
422        .collect();
423
424    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
425        let cursor_in_old = (cursor_offset as isize - delta) as usize;
426        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
427        cursor_position = Some(PredictedCursorPosition::at_anchor(
428            snapshot.anchor_after(buffer_offset),
429        ));
430    }
431
432    (edits, cursor_position)
433}
434
/// Returns the length in UTF-8 bytes of the longest run of leading characters that `a` and `b`
/// have in common. Passing reversed iterators yields the common suffix length instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
441
442fn git_info_for_file(
443    project: &Entity<Project>,
444    project_path: &ProjectPath,
445    cx: &App,
446) -> Option<PredictEditsGitInfo> {
447    let git_store = project.read(cx).git_store().read(cx);
448    if let Some((repository, _repo_path)) =
449        git_store.repository_and_path_for_project_path(project_path, cx)
450    {
451        let repository = repository.read(cx);
452        let head_sha = repository
453            .head_commit
454            .as_ref()
455            .map(|head_commit| head_commit.sha.to_string());
456        let remote_origin_url = repository.remote_origin_url.clone();
457        let remote_upstream_url = repository.remote_upstream_url.clone();
458        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
459            return None;
460        }
461        Some(PredictEditsGitInfo {
462            head_sha,
463            remote_origin_url,
464            remote_upstream_url,
465        })
466    } else {
467        None
468    }
469}
470
471pub struct GatherContextOutput {
472    pub body: PredictEditsBody,
473    pub context_range: Range<Point>,
474    pub editable_range: Range<usize>,
475    pub included_events_count: usize,
476}
477
478pub fn gather_context(
479    full_path_str: String,
480    snapshot: &BufferSnapshot,
481    cursor_point: language::Point,
482    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
483    trigger: PredictEditsRequestTrigger,
484    cx: &App,
485) -> Task<Result<GatherContextOutput>> {
486    cx.background_spawn({
487        let snapshot = snapshot.clone();
488        async move {
489            let input_excerpt = excerpt_for_cursor_position(
490                cursor_point,
491                &full_path_str,
492                &snapshot,
493                MAX_REWRITE_TOKENS,
494                MAX_CONTEXT_TOKENS,
495            );
496            let (input_events, included_events_count) = prompt_for_events();
497            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
498
499            let body = PredictEditsBody {
500                input_events,
501                input_excerpt: input_excerpt.prompt,
502                can_collect_data: false,
503                diagnostic_groups: None,
504                git_info: None,
505                outline: None,
506                speculated_output: None,
507                trigger,
508            };
509
510            Ok(GatherContextOutput {
511                body,
512                context_range: input_excerpt.context_range,
513                editable_range,
514                included_events_count,
515            })
516        }
517    })
518}
519
520pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
521    prompt_for_events_impl(events, max_tokens).0
522}
523
524fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
525    let mut result = String::new();
526    for (ix, event) in events.iter().rev().enumerate() {
527        let event_string = format_event(event.as_ref());
528        let event_tokens = guess_token_count(event_string.len());
529        if event_tokens > remaining_tokens {
530            return (result, ix);
531        }
532
533        if !result.is_empty() {
534            result.insert_str(0, "\n\n");
535        }
536        result.insert_str(0, &event_string);
537        remaining_tokens -= event_tokens;
538    }
539    return (result, events.len());
540}
541
542pub fn format_event(event: &Event) -> String {
543    match event {
544        Event::BufferChange {
545            path,
546            old_path,
547            diff,
548            ..
549        } => {
550            let mut prompt = String::new();
551
552            if old_path != path {
553                writeln!(
554                    prompt,
555                    "User renamed {} to {}\n",
556                    old_path.display(),
557                    path.display()
558                )
559                .unwrap();
560            }
561
562            if !diff.is_empty() {
563                write!(
564                    prompt,
565                    "User edited {}:\n```diff\n{}\n```",
566                    path.display(),
567                    diff
568                )
569                .unwrap();
570            }
571
572            prompt
573        }
574    }
575}
576
577#[derive(Debug)]
578pub struct InputExcerpt {
579    pub context_range: Range<Point>,
580    pub editable_range: Range<Point>,
581    pub prompt: String,
582}
583
584pub fn excerpt_for_cursor_position(
585    position: Point,
586    path: &str,
587    snapshot: &BufferSnapshot,
588    editable_region_token_limit: usize,
589    context_token_limit: usize,
590) -> InputExcerpt {
591    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
592        position,
593        snapshot,
594        editable_region_token_limit,
595        context_token_limit,
596    );
597
598    let mut prompt = String::new();
599
600    writeln!(&mut prompt, "```{path}").unwrap();
601    if context_range.start == Point::zero() {
602        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
603    }
604
605    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
606        prompt.push_str(chunk.text);
607    }
608
609    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
610
611    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
612        prompt.push_str(chunk.text);
613    }
614    write!(prompt, "\n```").unwrap();
615
616    InputExcerpt {
617        context_range,
618        editable_range,
619        prompt,
620    }
621}
622
623fn push_editable_range(
624    cursor_position: Point,
625    snapshot: &BufferSnapshot,
626    editable_range: Range<Point>,
627    prompt: &mut String,
628) {
629    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
630    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
631        prompt.push_str(chunk.text);
632    }
633    prompt.push_str(CURSOR_MARKER);
634    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
635        prompt.push_str(chunk.text);
636    }
637    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// The rendered excerpt depends on the editable-region token budget.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();
        let cursor = Point::new(12, 5);

        // Editable limit of 50 tokens: the region expands to syntax boundaries.
        let wide = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 50, 32);
        assert_eq!(
            wide.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // Editable limit of 40 tokens: still syntax-aligned, but a tighter region.
        let narrow = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 40, 32);
        assert_eq!(
            narrow.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }

    /// An output with nothing between the markers deletes the whole editable range.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let source = "fn foo() {\n    let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(source, cx));
        let snapshot = buffer.read(cx).snapshot();

        let model_output = "<|editable_region_start|>\n<|editable_region_end|>";
        let edits = parse_edits(model_output, 0..source.len(), &snapshot).unwrap();

        let [(range, new_text)] = edits.as_slice() else {
            panic!("expected exactly one edit, found {}", edits.len());
        };
        assert_eq!(range.to_offset(&snapshot), 0..source.len(),);
        assert_eq!(new_text.as_ref(), "");
    }
}