zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::Result;
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use edit_prediction_types::PredictedCursorPosition;
 14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use text::Bias;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::{
 23    Event, ZetaPromptInput,
 24    zeta1::{
 25        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 26        START_OF_FILE_MARKER,
 27    },
 28};
 29
 30pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 31pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 32pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 33
 34pub(crate) fn request_prediction_with_zeta1(
 35    store: &mut EditPredictionStore,
 36    EditPredictionModelInput {
 37        project,
 38        buffer,
 39        snapshot,
 40        position,
 41        events,
 42        trigger,
 43        debug_tx,
 44        ..
 45    }: EditPredictionModelInput,
 46    cx: &mut Context<EditPredictionStore>,
 47) -> Task<Result<Option<EditPredictionResult>>> {
 48    let buffer_snapshotted_at = Instant::now();
 49    let client = store.client.clone();
 50    let llm_token = store.llm_token.clone();
 51    let app_version = AppVersion::global(cx);
 52
 53    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 54        let can_collect_file = store.can_collect_file(&project, file, cx);
 55        let git_info = if can_collect_file {
 56            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 57        } else {
 58            None
 59        };
 60        (git_info, can_collect_file)
 61    } else {
 62        (None, false)
 63    };
 64
 65    let full_path: Arc<Path> = snapshot
 66        .file()
 67        .map(|f| Arc::from(f.full_path(cx).as_path()))
 68        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 69    let full_path_str = full_path.to_string_lossy().into_owned();
 70    let cursor_point = position.to_point(&snapshot);
 71    let prompt_for_events = {
 72        let events = events.clone();
 73        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 74    };
 75    let gather_task = gather_context(
 76        full_path_str,
 77        &snapshot,
 78        cursor_point,
 79        prompt_for_events,
 80        trigger,
 81        cx,
 82    );
 83
 84    let (uri, require_auth) = match &store.custom_predict_edits_url {
 85        Some(custom_url) => (custom_url.clone(), false),
 86        None => {
 87            match client
 88                .http_client()
 89                .build_zed_llm_url("/predict_edits/v2", &[])
 90            {
 91                Ok(url) => (url.into(), true),
 92                Err(err) => return Task::ready(Err(err)),
 93            }
 94        }
 95    };
 96
 97    cx.spawn(async move |this, cx| {
 98        let GatherContextOutput {
 99            mut body,
100            context_range,
101            editable_range,
102            included_events_count,
103        } = gather_task.await?;
104        let done_gathering_context_at = Instant::now();
105
106        let included_events = &events[events.len() - included_events_count..events.len()];
107        body.can_collect_data = can_collect_file
108            && this
109                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
110                .unwrap_or(false);
111        if body.can_collect_data {
112            body.git_info = git_info;
113        }
114
115        log::debug!(
116            "Events:\n{}\nExcerpt:\n{:?}",
117            body.input_events,
118            body.input_excerpt
119        );
120
121        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
122            |request| {
123                Ok(request
124                    .uri(uri.as_str())
125                    .body(serde_json::to_string(&body)?.into())?)
126            },
127            client,
128            llm_token,
129            app_version,
130            require_auth,
131        )
132        .await;
133
134        let context_start_offset = context_range.start.to_offset(&snapshot);
135        let context_start_row = context_range.start.row;
136        let editable_offset_range = editable_range.to_offset(&snapshot);
137
138        let inputs = ZetaPromptInput {
139            events: included_events.into(),
140            related_files: vec![],
141            cursor_path: full_path,
142            cursor_excerpt: snapshot
143                .text_for_range(context_range)
144                .collect::<String>()
145                .into(),
146            editable_range_in_excerpt: (editable_range.start - context_start_offset)
147                ..(editable_offset_range.end - context_start_offset),
148            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
149            excerpt_start_row: Some(context_start_row),
150        };
151
152        if let Some(debug_tx) = &debug_tx {
153            debug_tx
154                .unbounded_send(DebugEvent::EditPredictionStarted(
155                    EditPredictionStartedDebugEvent {
156                        buffer: buffer.downgrade(),
157                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
158                        position,
159                    },
160                ))
161                .ok();
162        }
163
164        let (response, usage) = match response {
165            Ok(response) => response,
166            Err(err) => {
167                if err.is::<ZedUpdateRequiredError>() {
168                    cx.update(|cx| {
169                        this.update(cx, |ep_store, _cx| {
170                            ep_store.update_required = true;
171                        })
172                        .ok();
173
174                        let error_message: SharedString = err.to_string().into();
175                        show_app_notification(
176                            NotificationId::unique::<ZedUpdateRequiredError>(),
177                            cx,
178                            move |cx| {
179                                cx.new(|cx| {
180                                    ErrorMessagePrompt::new(error_message.clone(), cx)
181                                        .with_link_button("Update Zed", "https://zed.dev/releases")
182                                })
183                            },
184                        );
185                    });
186                }
187
188                return Err(err);
189            }
190        };
191
192        let received_response_at = Instant::now();
193        log::debug!("completion response: {}", &response.output_excerpt);
194
195        if let Some(usage) = usage {
196            this.update(cx, |this, cx| {
197                this.user_store.update(cx, |user_store, cx| {
198                    user_store.update_edit_prediction_usage(usage, cx);
199                });
200            })
201            .ok();
202        }
203
204        if let Some(debug_tx) = &debug_tx {
205            debug_tx
206                .unbounded_send(DebugEvent::EditPredictionFinished(
207                    EditPredictionFinishedDebugEvent {
208                        buffer: buffer.downgrade(),
209                        model_output: Some(response.output_excerpt.clone()),
210                        position,
211                    },
212                ))
213                .ok();
214        }
215
216        let edit_prediction = process_completion_response(
217            response,
218            buffer,
219            &snapshot,
220            editable_range,
221            inputs,
222            buffer_snapshotted_at,
223            received_response_at,
224            cx,
225        )
226        .await;
227
228        let finished_at = Instant::now();
229
230        // record latency for ~1% of requests
231        if rand::random::<u8>() <= 2 {
232            telemetry::event!(
233                "Edit Prediction Request",
234                context_latency = done_gathering_context_at
235                    .duration_since(buffer_snapshotted_at)
236                    .as_millis(),
237                request_latency = received_response_at
238                    .duration_since(done_gathering_context_at)
239                    .as_millis(),
240                process_latency = finished_at.duration_since(received_response_at).as_millis()
241            );
242        }
243
244        edit_prediction.map(Some)
245    })
246}
247
248fn process_completion_response(
249    prediction_response: PredictEditsResponse,
250    buffer: Entity<Buffer>,
251    snapshot: &BufferSnapshot,
252    editable_range: Range<usize>,
253    inputs: ZetaPromptInput,
254    buffer_snapshotted_at: Instant,
255    received_response_at: Instant,
256    cx: &AsyncApp,
257) -> Task<Result<EditPredictionResult>> {
258    let snapshot = snapshot.clone();
259    let request_id = prediction_response.request_id;
260    let output_excerpt = prediction_response.output_excerpt;
261    cx.spawn(async move |cx| {
262        let output_excerpt: Arc<str> = output_excerpt.into();
263
264        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
265            .background_spawn({
266                let output_excerpt = output_excerpt.clone();
267                let editable_range = editable_range.clone();
268                let snapshot = snapshot.clone();
269                async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
270            })
271            .await?
272            .into();
273
274        let id = EditPredictionId(request_id.into());
275        Ok(EditPredictionResult::new(
276            id,
277            &buffer,
278            &snapshot,
279            edits,
280            None,
281            buffer_snapshotted_at,
282            received_response_at,
283            inputs,
284            cx,
285        )
286        .await)
287    })
288}
289
290pub(crate) fn parse_edits(
291    output_excerpt: &str,
292    editable_range: Range<usize>,
293    snapshot: &BufferSnapshot,
294) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
295    let content = output_excerpt.replace(CURSOR_MARKER, "");
296
297    let start_markers = content
298        .match_indices(EDITABLE_REGION_START_MARKER)
299        .collect::<Vec<_>>();
300    anyhow::ensure!(
301        start_markers.len() <= 1,
302        "expected at most one start marker, found {}",
303        start_markers.len()
304    );
305
306    let end_markers = content
307        .match_indices(EDITABLE_REGION_END_MARKER)
308        .collect::<Vec<_>>();
309    anyhow::ensure!(
310        end_markers.len() <= 1,
311        "expected at most one end marker, found {}",
312        end_markers.len()
313    );
314
315    let sof_markers = content
316        .match_indices(START_OF_FILE_MARKER)
317        .collect::<Vec<_>>();
318    anyhow::ensure!(
319        sof_markers.len() <= 1,
320        "expected at most one start-of-file marker, found {}",
321        sof_markers.len()
322    );
323
324    let content_start = start_markers
325        .first()
326        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
327        .unwrap_or(0);
328    let content_end = end_markers
329        .first()
330        .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
331        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
332
333    let new_text = &content[content_start..content_end];
334
335    let old_text = snapshot
336        .text_for_range(editable_range.clone())
337        .collect::<String>();
338
339    Ok(compute_edits(
340        old_text,
341        new_text,
342        editable_range.start,
343        snapshot,
344    ))
345}
346
347pub fn compute_edits(
348    old_text: String,
349    new_text: &str,
350    offset: usize,
351    snapshot: &BufferSnapshot,
352) -> Vec<(Range<Anchor>, Arc<str>)> {
353    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
354}
355
356pub fn compute_edits_and_cursor_position(
357    old_text: String,
358    new_text: &str,
359    offset: usize,
360    cursor_offset_in_new_text: Option<usize>,
361    snapshot: &BufferSnapshot,
362) -> (
363    Vec<(Range<Anchor>, Arc<str>)>,
364    Option<PredictedCursorPosition>,
365) {
366    let diffs = text_diff(&old_text, new_text);
367
368    // Delta represents the cumulative change in byte count from all preceding edits.
369    // new_offset = old_offset + delta, so old_offset = new_offset - delta
370    let mut delta: isize = 0;
371    let mut cursor_position: Option<PredictedCursorPosition> = None;
372
373    let edits = diffs
374        .iter()
375        .map(|(raw_old_range, new_text)| {
376            // Compute cursor position if it falls within or before this edit.
377            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
378                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
379                let edit_end_in_new = edit_start_in_new + new_text.len();
380
381                if cursor_offset < edit_start_in_new {
382                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
383                    cursor_position = Some(PredictedCursorPosition::at_anchor(
384                        snapshot.anchor_after(offset + cursor_in_old),
385                    ));
386                } else if cursor_offset < edit_end_in_new {
387                    let offset_within_insertion = cursor_offset - edit_start_in_new;
388                    cursor_position = Some(PredictedCursorPosition::new(
389                        snapshot.anchor_before(offset + raw_old_range.start),
390                        offset_within_insertion,
391                    ));
392                }
393
394                delta += new_text.len() as isize - raw_old_range.len() as isize;
395            }
396
397            // Compute the edit with prefix/suffix trimming.
398            let mut old_range = raw_old_range.clone();
399            let old_slice = &old_text[old_range.clone()];
400
401            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
402            let suffix_len = common_prefix(
403                old_slice[prefix_len..].chars().rev(),
404                new_text[prefix_len..].chars().rev(),
405            );
406
407            old_range.start += offset;
408            old_range.end += offset;
409            old_range.start += prefix_len;
410            old_range.end -= suffix_len;
411
412            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
413            let range = if old_range.is_empty() {
414                let anchor = snapshot.anchor_after(old_range.start);
415                anchor..anchor
416            } else {
417                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
418            };
419            (range, new_text)
420        })
421        .collect();
422
423    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
424        let cursor_in_old = (cursor_offset as isize - delta) as usize;
425        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
426        cursor_position = Some(PredictedCursorPosition::at_anchor(
427            snapshot.anchor_after(buffer_offset),
428        ));
429    }
430
431    (edits, cursor_position)
432}
433
/// Returns the UTF-8 byte length of the longest common prefix of two character sequences.
///
/// Pass reversed iterators (`.chars().rev()`) to measure a common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
440
441fn git_info_for_file(
442    project: &Entity<Project>,
443    project_path: &ProjectPath,
444    cx: &App,
445) -> Option<PredictEditsGitInfo> {
446    let git_store = project.read(cx).git_store().read(cx);
447    if let Some((repository, _repo_path)) =
448        git_store.repository_and_path_for_project_path(project_path, cx)
449    {
450        let repository = repository.read(cx);
451        let head_sha = repository
452            .head_commit
453            .as_ref()
454            .map(|head_commit| head_commit.sha.to_string());
455        let remote_origin_url = repository.remote_origin_url.clone();
456        let remote_upstream_url = repository.remote_upstream_url.clone();
457        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
458            return None;
459        }
460        Some(PredictEditsGitInfo {
461            head_sha,
462            remote_origin_url,
463            remote_upstream_url,
464        })
465    } else {
466        None
467    }
468}
469
470pub struct GatherContextOutput {
471    pub body: PredictEditsBody,
472    pub context_range: Range<Point>,
473    pub editable_range: Range<usize>,
474    pub included_events_count: usize,
475}
476
477pub fn gather_context(
478    full_path_str: String,
479    snapshot: &BufferSnapshot,
480    cursor_point: language::Point,
481    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
482    trigger: PredictEditsRequestTrigger,
483    cx: &App,
484) -> Task<Result<GatherContextOutput>> {
485    cx.background_spawn({
486        let snapshot = snapshot.clone();
487        async move {
488            let input_excerpt = excerpt_for_cursor_position(
489                cursor_point,
490                &full_path_str,
491                &snapshot,
492                MAX_REWRITE_TOKENS,
493                MAX_CONTEXT_TOKENS,
494            );
495            let (input_events, included_events_count) = prompt_for_events();
496            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
497
498            let body = PredictEditsBody {
499                input_events,
500                input_excerpt: input_excerpt.prompt,
501                can_collect_data: false,
502                diagnostic_groups: None,
503                git_info: None,
504                outline: None,
505                speculated_output: None,
506                trigger,
507            };
508
509            Ok(GatherContextOutput {
510                body,
511                context_range: input_excerpt.context_range,
512                editable_range,
513                included_events_count,
514            })
515        }
516    })
517}
518
519pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
520    prompt_for_events_impl(events, max_tokens).0
521}
522
523fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
524    let mut result = String::new();
525    for (ix, event) in events.iter().rev().enumerate() {
526        let event_string = format_event(event.as_ref());
527        let event_tokens = guess_token_count(event_string.len());
528        if event_tokens > remaining_tokens {
529            return (result, ix);
530        }
531
532        if !result.is_empty() {
533            result.insert_str(0, "\n\n");
534        }
535        result.insert_str(0, &event_string);
536        remaining_tokens -= event_tokens;
537    }
538    return (result, events.len());
539}
540
541pub fn format_event(event: &Event) -> String {
542    match event {
543        Event::BufferChange {
544            path,
545            old_path,
546            diff,
547            ..
548        } => {
549            let mut prompt = String::new();
550
551            if old_path != path {
552                writeln!(
553                    prompt,
554                    "User renamed {} to {}\n",
555                    old_path.display(),
556                    path.display()
557                )
558                .unwrap();
559            }
560
561            if !diff.is_empty() {
562                write!(
563                    prompt,
564                    "User edited {}:\n```diff\n{}\n```",
565                    path.display(),
566                    diff
567                )
568                .unwrap();
569            }
570
571            prompt
572        }
573    }
574}
575
576#[derive(Debug)]
577pub struct InputExcerpt {
578    pub context_range: Range<Point>,
579    pub editable_range: Range<Point>,
580    pub prompt: String,
581}
582
583pub fn excerpt_for_cursor_position(
584    position: Point,
585    path: &str,
586    snapshot: &BufferSnapshot,
587    editable_region_token_limit: usize,
588    context_token_limit: usize,
589) -> InputExcerpt {
590    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
591        position,
592        snapshot,
593        editable_region_token_limit,
594        context_token_limit,
595    );
596
597    let mut prompt = String::new();
598
599    writeln!(&mut prompt, "```{path}").unwrap();
600    if context_range.start == Point::zero() {
601        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
602    }
603
604    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
605        prompt.push_str(chunk.text);
606    }
607
608    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
609
610    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
611        prompt.push_str(chunk.text);
612    }
613    write!(prompt, "\n```").unwrap();
614
615    InputExcerpt {
616        context_range,
617        editable_range,
618        prompt,
619    }
620}
621
622fn push_editable_range(
623    cursor_position: Point,
624    snapshot: &BufferSnapshot,
625    editable_range: Range<Point>,
626    prompt: &mut String,
627) {
628    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
629    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
630        prompt.push_str(chunk.text);
631    }
632    prompt.push_str(CURSOR_MARKER);
633    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
634        prompt.push_str(chunk.text);
635    }
636    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Just after the `r` of `return sum;` inside `bar`.
        let cursor = Point::new(12, 5);

        // The excerpt expands to syntax boundaries; a 50-token editable budget yields a
        // region reaching back to the start of `let mut sum`.
        let wide = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 50, 32);
        assert_eq!(
            wide.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // A smaller editable budget still snaps to syntax boundaries, but tighter.
        let narrow = excerpt_for_cursor_position(cursor, "main.rs", &snapshot, 40, 32);
        assert_eq!(
            narrow.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }
}