zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::Result;
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{
 21    Event, ZetaPromptInput,
 22    zeta1::{
 23        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 24        START_OF_FILE_MARKER,
 25    },
 26};
 27
/// Token budget passed as the context limit when building the cursor excerpt for a request.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget passed as the editable-region limit when building the cursor excerpt.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted edit-history events included in a request.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 31
 32pub(crate) fn request_prediction_with_zeta1(
 33    store: &mut EditPredictionStore,
 34    EditPredictionModelInput {
 35        project,
 36        buffer,
 37        snapshot,
 38        position,
 39        events,
 40        trigger,
 41        debug_tx,
 42        ..
 43    }: EditPredictionModelInput,
 44    cx: &mut Context<EditPredictionStore>,
 45) -> Task<Result<Option<EditPredictionResult>>> {
 46    let buffer_snapshotted_at = Instant::now();
 47    let client = store.client.clone();
 48    let llm_token = store.llm_token.clone();
 49    let app_version = AppVersion::global(cx);
 50
 51    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 52        let can_collect_file = store.can_collect_file(&project, file, cx);
 53        let git_info = if can_collect_file {
 54            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 55        } else {
 56            None
 57        };
 58        (git_info, can_collect_file)
 59    } else {
 60        (None, false)
 61    };
 62
 63    let full_path: Arc<Path> = snapshot
 64        .file()
 65        .map(|f| Arc::from(f.full_path(cx).as_path()))
 66        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 67    let full_path_str = full_path.to_string_lossy().into_owned();
 68    let cursor_point = position.to_point(&snapshot);
 69    let prompt_for_events = {
 70        let events = events.clone();
 71        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 72    };
 73    let gather_task = gather_context(
 74        full_path_str,
 75        &snapshot,
 76        cursor_point,
 77        prompt_for_events,
 78        trigger,
 79        cx,
 80    );
 81
 82    let (uri, require_auth) = match &store.custom_predict_edits_url {
 83        Some(custom_url) => (custom_url.clone(), false),
 84        None => {
 85            match client
 86                .http_client()
 87                .build_zed_llm_url("/predict_edits/v2", &[])
 88            {
 89                Ok(url) => (url.into(), true),
 90                Err(err) => return Task::ready(Err(err)),
 91            }
 92        }
 93    };
 94
 95    cx.spawn(async move |this, cx| {
 96        let GatherContextOutput {
 97            mut body,
 98            context_range,
 99            editable_range,
100            included_events_count,
101        } = gather_task.await?;
102        let done_gathering_context_at = Instant::now();
103
104        let included_events = &events[events.len() - included_events_count..events.len()];
105        body.can_collect_data = can_collect_file
106            && this
107                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
108                .unwrap_or(false);
109        if body.can_collect_data {
110            body.git_info = git_info;
111        }
112
113        log::debug!(
114            "Events:\n{}\nExcerpt:\n{:?}",
115            body.input_events,
116            body.input_excerpt
117        );
118
119        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
120            |request| {
121                Ok(request
122                    .uri(uri.as_str())
123                    .body(serde_json::to_string(&body)?.into())?)
124            },
125            client,
126            llm_token,
127            app_version,
128            require_auth,
129        )
130        .await;
131
132        let context_start_offset = context_range.start.to_offset(&snapshot);
133        let editable_offset_range = editable_range.to_offset(&snapshot);
134
135        let inputs = ZetaPromptInput {
136            events: included_events.into(),
137            related_files: vec![],
138            cursor_path: full_path,
139            cursor_excerpt: snapshot
140                .text_for_range(context_range)
141                .collect::<String>()
142                .into(),
143            editable_range_in_excerpt: (editable_range.start - context_start_offset)
144                ..(editable_offset_range.end - context_start_offset),
145            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
146        };
147
148        if let Some(debug_tx) = &debug_tx {
149            debug_tx
150                .unbounded_send(DebugEvent::EditPredictionStarted(
151                    EditPredictionStartedDebugEvent {
152                        buffer: buffer.downgrade(),
153                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
154                        position,
155                    },
156                ))
157                .ok();
158        }
159
160        let (response, usage) = match response {
161            Ok(response) => response,
162            Err(err) => {
163                if err.is::<ZedUpdateRequiredError>() {
164                    cx.update(|cx| {
165                        this.update(cx, |ep_store, _cx| {
166                            ep_store.update_required = true;
167                        })
168                        .ok();
169
170                        let error_message: SharedString = err.to_string().into();
171                        show_app_notification(
172                            NotificationId::unique::<ZedUpdateRequiredError>(),
173                            cx,
174                            move |cx| {
175                                cx.new(|cx| {
176                                    ErrorMessagePrompt::new(error_message.clone(), cx)
177                                        .with_link_button("Update Zed", "https://zed.dev/releases")
178                                })
179                            },
180                        );
181                    });
182                }
183
184                return Err(err);
185            }
186        };
187
188        let received_response_at = Instant::now();
189        log::debug!("completion response: {}", &response.output_excerpt);
190
191        if let Some(usage) = usage {
192            this.update(cx, |this, cx| {
193                this.user_store.update(cx, |user_store, cx| {
194                    user_store.update_edit_prediction_usage(usage, cx);
195                });
196            })
197            .ok();
198        }
199
200        if let Some(debug_tx) = &debug_tx {
201            debug_tx
202                .unbounded_send(DebugEvent::EditPredictionFinished(
203                    EditPredictionFinishedDebugEvent {
204                        buffer: buffer.downgrade(),
205                        model_output: Some(response.output_excerpt.clone()),
206                        position,
207                    },
208                ))
209                .ok();
210        }
211
212        let edit_prediction = process_completion_response(
213            response,
214            buffer,
215            &snapshot,
216            editable_range,
217            inputs,
218            buffer_snapshotted_at,
219            received_response_at,
220            cx,
221        )
222        .await;
223
224        let finished_at = Instant::now();
225
226        // record latency for ~1% of requests
227        if rand::random::<u8>() <= 2 {
228            telemetry::event!(
229                "Edit Prediction Request",
230                context_latency = done_gathering_context_at
231                    .duration_since(buffer_snapshotted_at)
232                    .as_millis(),
233                request_latency = received_response_at
234                    .duration_since(done_gathering_context_at)
235                    .as_millis(),
236                process_latency = finished_at.duration_since(received_response_at).as_millis()
237            );
238        }
239
240        edit_prediction.map(Some)
241    })
242}
243
244fn process_completion_response(
245    prediction_response: PredictEditsResponse,
246    buffer: Entity<Buffer>,
247    snapshot: &BufferSnapshot,
248    editable_range: Range<usize>,
249    inputs: ZetaPromptInput,
250    buffer_snapshotted_at: Instant,
251    received_response_at: Instant,
252    cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254    let snapshot = snapshot.clone();
255    let request_id = prediction_response.request_id;
256    let output_excerpt = prediction_response.output_excerpt;
257    cx.spawn(async move |cx| {
258        let output_excerpt: Arc<str> = output_excerpt.into();
259
260        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261            .background_spawn({
262                let output_excerpt = output_excerpt.clone();
263                let editable_range = editable_range.clone();
264                let snapshot = snapshot.clone();
265                async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
266            })
267            .await?
268            .into();
269
270        let id = EditPredictionId(request_id.into());
271        Ok(EditPredictionResult::new(
272            id,
273            &buffer,
274            &snapshot,
275            edits,
276            buffer_snapshotted_at,
277            received_response_at,
278            inputs,
279            cx,
280        )
281        .await)
282    })
283}
284
285pub(crate) fn parse_edits(
286    output_excerpt: &str,
287    editable_range: Range<usize>,
288    snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290    let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292    let start_markers = content
293        .match_indices(EDITABLE_REGION_START_MARKER)
294        .collect::<Vec<_>>();
295    anyhow::ensure!(
296        start_markers.len() <= 1,
297        "expected at most one start marker, found {}",
298        start_markers.len()
299    );
300
301    let end_markers = content
302        .match_indices(EDITABLE_REGION_END_MARKER)
303        .collect::<Vec<_>>();
304    anyhow::ensure!(
305        end_markers.len() <= 1,
306        "expected at most one end marker, found {}",
307        end_markers.len()
308    );
309
310    let sof_markers = content
311        .match_indices(START_OF_FILE_MARKER)
312        .collect::<Vec<_>>();
313    anyhow::ensure!(
314        sof_markers.len() <= 1,
315        "expected at most one start-of-file marker, found {}",
316        sof_markers.len()
317    );
318
319    let content_start = start_markers
320        .first()
321        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322        .unwrap_or(0);
323    let content_end = end_markers
324        .first()
325        .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328    // if there is a single newline between markers, content_start will be 1 more than content_end. .min ensures empty slice in that case
329    let new_text = &content[content_start.min(content_end)..content_end];
330
331    let old_text = snapshot
332        .text_for_range(editable_range.clone())
333        .collect::<String>();
334
335    Ok(compute_edits(
336        old_text,
337        new_text,
338        editable_range.start,
339        snapshot,
340    ))
341}
342
343pub fn compute_edits(
344    old_text: String,
345    new_text: &str,
346    offset: usize,
347    snapshot: &BufferSnapshot,
348) -> Vec<(Range<Anchor>, Arc<str>)> {
349    text_diff(&old_text, new_text)
350        .into_iter()
351        .map(|(mut old_range, new_text)| {
352            let old_slice = &old_text[old_range.clone()];
353
354            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
355            let suffix_len = common_prefix(
356                old_slice[prefix_len..].chars().rev(),
357                new_text[prefix_len..].chars().rev(),
358            );
359
360            old_range.start += offset;
361            old_range.end += offset;
362            old_range.start += prefix_len;
363            old_range.end -= suffix_len;
364
365            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
366            let range = if old_range.is_empty() {
367                let anchor = snapshot.anchor_after(old_range.start);
368                anchor..anchor
369            } else {
370                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
371            };
372            (range, new_text)
373        })
374        .collect()
375}
376
/// Returns the UTF-8 byte length of the longest run of equal leading characters of `a` and `b`.
///
/// Passing reversed iterators yields the byte length of the common suffix instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
383
384fn git_info_for_file(
385    project: &Entity<Project>,
386    project_path: &ProjectPath,
387    cx: &App,
388) -> Option<PredictEditsGitInfo> {
389    let git_store = project.read(cx).git_store().read(cx);
390    if let Some((repository, _repo_path)) =
391        git_store.repository_and_path_for_project_path(project_path, cx)
392    {
393        let repository = repository.read(cx);
394        let head_sha = repository
395            .head_commit
396            .as_ref()
397            .map(|head_commit| head_commit.sha.to_string());
398        let remote_origin_url = repository.remote_origin_url.clone();
399        let remote_upstream_url = repository.remote_upstream_url.clone();
400        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
401            return None;
402        }
403        Some(PredictEditsGitInfo {
404            head_sha,
405            remote_origin_url,
406            remote_upstream_url,
407        })
408    } else {
409        None
410    }
411}
412
/// Result of `gather_context`: the request payload plus the ranges it was built from.
pub struct GatherContextOutput {
    /// Request body to send to the prediction endpoint.
    pub body: PredictEditsBody,
    /// Range of the buffer that was included in the prompt excerpt.
    pub context_range: Range<Point>,
    /// Editable region of the excerpt, converted to offsets in the snapshot.
    pub editable_range: Range<usize>,
    /// Number of events reported as included by the events-prompt closure.
    pub included_events_count: usize,
}
419
420pub fn gather_context(
421    full_path_str: String,
422    snapshot: &BufferSnapshot,
423    cursor_point: language::Point,
424    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
425    trigger: PredictEditsRequestTrigger,
426    cx: &App,
427) -> Task<Result<GatherContextOutput>> {
428    cx.background_spawn({
429        let snapshot = snapshot.clone();
430        async move {
431            let input_excerpt = excerpt_for_cursor_position(
432                cursor_point,
433                &full_path_str,
434                &snapshot,
435                MAX_REWRITE_TOKENS,
436                MAX_CONTEXT_TOKENS,
437            );
438            let (input_events, included_events_count) = prompt_for_events();
439            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
440
441            let body = PredictEditsBody {
442                input_events,
443                input_excerpt: input_excerpt.prompt,
444                can_collect_data: false,
445                diagnostic_groups: None,
446                git_info: None,
447                outline: None,
448                speculated_output: None,
449                trigger,
450            };
451
452            Ok(GatherContextOutput {
453                body,
454                context_range: input_excerpt.context_range,
455                editable_range,
456                included_events_count,
457            })
458        }
459    })
460}
461
462pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
463    prompt_for_events_impl(events, max_tokens).0
464}
465
466fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
467    let mut result = String::new();
468    for (ix, event) in events.iter().rev().enumerate() {
469        let event_string = format_event(event.as_ref());
470        let event_tokens = guess_token_count(event_string.len());
471        if event_tokens > remaining_tokens {
472            return (result, ix);
473        }
474
475        if !result.is_empty() {
476            result.insert_str(0, "\n\n");
477        }
478        result.insert_str(0, &event_string);
479        remaining_tokens -= event_tokens;
480    }
481    return (result, events.len());
482}
483
484pub fn format_event(event: &Event) -> String {
485    match event {
486        Event::BufferChange {
487            path,
488            old_path,
489            diff,
490            ..
491        } => {
492            let mut prompt = String::new();
493
494            if old_path != path {
495                writeln!(
496                    prompt,
497                    "User renamed {} to {}\n",
498                    old_path.display(),
499                    path.display()
500                )
501                .unwrap();
502            }
503
504            if !diff.is_empty() {
505                write!(
506                    prompt,
507                    "User edited {}:\n```diff\n{}\n```",
508                    path.display(),
509                    diff
510                )
511                .unwrap();
512            }
513
514            prompt
515        }
516    }
517}
518
/// The prompt excerpt built around a cursor position, with the ranges it was built from.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Range of the buffer whose text appears in `prompt`.
    pub context_range: Range<Point>,
    /// Sub-range of the excerpt that is wrapped in the editable-region markers.
    pub editable_range: Range<Point>,
    /// Rendered excerpt text, including the path header and all markers.
    pub prompt: String,
}
525
526pub fn excerpt_for_cursor_position(
527    position: Point,
528    path: &str,
529    snapshot: &BufferSnapshot,
530    editable_region_token_limit: usize,
531    context_token_limit: usize,
532) -> InputExcerpt {
533    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
534        position,
535        snapshot,
536        editable_region_token_limit,
537        context_token_limit,
538    );
539
540    let mut prompt = String::new();
541
542    writeln!(&mut prompt, "```{path}").unwrap();
543    if context_range.start == Point::zero() {
544        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
545    }
546
547    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
548        prompt.push_str(chunk.text);
549    }
550
551    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
552
553    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
554        prompt.push_str(chunk.text);
555    }
556    write!(prompt, "\n```").unwrap();
557
558    InputExcerpt {
559        context_range,
560        editable_range,
561        prompt,
562    }
563}
564
565fn push_editable_range(
566    cursor_position: Point,
567    snapshot: &BufferSnapshot,
568    editable_range: Range<Point>,
569    prompt: &mut String,
570) {
571    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
572    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
573        prompt.push_str(chunk.text);
574    }
575    prompt.push_str(CURSOR_MARKER);
576    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
577        prompt.push_str(chunk.text);
578    }
579    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the rendered excerpt for the same cursor under two editable-region budgets.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Same cursor and context budget in both cases; only the editable budget varies.
        let cursor = Point::new(12, 5);
        let prompt_with_editable_limit = |editable_limit| {
            excerpt_for_cursor_position(cursor, "main.rs", &snapshot, editable_limit, 32).prompt
        };

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        assert_eq!(
            prompt_with_editable_limit(50),
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        assert_eq!(
            prompt_with_editable_limit(40),
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }

    /// An output with nothing between the markers must delete the whole editable range.
    #[gpui::test]
    fn test_parse_edits_empty_editable_region(cx: &mut App) {
        let text = "fn foo() {\n    let x = 42;\n}\n";
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let snapshot = buffer.read(cx).snapshot();

        let model_output = "<|editable_region_start|>\n<|editable_region_end|>";
        let edits = parse_edits(model_output, 0..text.len(), &snapshot).unwrap();

        assert_eq!(edits.len(), 1);
        let (deleted_range, replacement) = &edits[0];
        assert_eq!(deleted_range.to_offset(&snapshot), 0..text.len(),);
        assert_eq!(replacement.as_ref(), "");
    }
}