zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{Event, ZetaPromptInput};
 21
 22const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 23const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
 24const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
 25const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
 26
 27pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 28pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 29pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
/// Requests a zeta1 edit prediction for the cursor at `position` in `buffer`.
///
/// The prompt (a fenced excerpt around the cursor plus recent edit events) is assembled on a
/// background thread by [`gather_context`], then posted through
/// [`EditPredictionStore::send_api_request`] to `/predict_edits/v2`. The endpoint can be
/// overridden with the `ZED_PREDICT_EDITS_URL` environment variable.
///
/// Data collection is enabled only when both the file and every included event may be
/// collected. Git info is attached only in that case.
///
/// If the request fails with [`ZedUpdateRequiredError`], the store's `update_required` flag is
/// set and an "Update Zed" notification is shown before the error is returned.
///
/// The returned task resolves to `Ok(Some(..))` on success. Errors from context gathering, the
/// API request, or response parsing are propagated.
pub(crate) fn request_prediction_with_zeta1(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        project,
        buffer,
        snapshot,
        position,
        events,
        trigger,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git metadata is only looked up for files we are allowed to collect; buffers without a
    // file are never collectable.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = store.can_collect_file(&project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    // Unsaved buffers are labelled "untitled" in the prompt.
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // Deferred so event formatting runs on the background thread inside `gather_context`.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        trigger,
        cx,
    );

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // `prompt_for_events_impl` keeps the most recent `included_events_count` events, so the
        // included events are the tail of `events`.
        let included_events = &events[events.len() - included_events_count..events.len()];
        // NOTE(review): `read_with` presumably fails once the store entity is released; in
        // that case data collection is disabled.
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, _| this.can_collect_events(included_events))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        let http_client = client.http_client();

        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
            |request| {
                // `ZED_PREDICT_EDITS_URL` overrides the default Zed LLM endpoint.
                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    predict_edits_url
                } else {
                    http_client
                        .build_zed_llm_url("/predict_edits/v2", &[])?
                        .as_str()
                        .into()
                };
                Ok(request
                    .uri(uri)
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
        )
        .await;

        let context_start_offset = context_range.start.to_offset(&snapshot);
        // NOTE(review): `editable_range` is already a byte-offset range (see `gather_context`),
        // so this conversion is presumably an identity; `.start` and `.end` below come from the
        // two variables interchangeably.
        let editable_offset_range = editable_range.to_offset(&snapshot);

        // Excerpt-relative description of the prompt, used for debugging and stored on the result.
        let inputs = ZetaPromptInput {
            events: included_events.into(),
            related_files: vec![].into(),
            cursor_path: full_path,
            cursor_excerpt: snapshot
                .text_for_range(context_range)
                .collect::<String>()
                .into(),
            editable_range_in_excerpt: (editable_range.start - context_start_offset)
                ..(editable_offset_range.end - context_start_offset),
            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
        };

        // NOTE(review): the "started" debug event is only emitted after the API response (or
        // error) has arrived, not when the request is sent.
        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionStarted(
                    EditPredictionStartedDebugEvent {
                        buffer: buffer.downgrade(),
                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
                        position,
                    },
                ))
                .ok();
        }

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // The server rejected this client version: remember it on the store and prompt
                // the user to update. Failures to update the UI are ignored.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |ep_store, _cx| {
                            ep_store.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        // Forward usage information reported by the server to the user store.
        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionFinished(
                    EditPredictionFinishedDebugEvent {
                        buffer: buffer.downgrade(),
                        model_output: Some(response.output_excerpt.clone()),
                        position,
                    },
                ))
                .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // Record latency for ~1% of requests: `u8` values 0..=2 are 3 of 256 (about 1.2%).
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction.map(Some)
    })
}
239
240fn process_completion_response(
241    prediction_response: PredictEditsResponse,
242    buffer: Entity<Buffer>,
243    snapshot: &BufferSnapshot,
244    editable_range: Range<usize>,
245    inputs: ZetaPromptInput,
246    buffer_snapshotted_at: Instant,
247    received_response_at: Instant,
248    cx: &AsyncApp,
249) -> Task<Result<EditPredictionResult>> {
250    let snapshot = snapshot.clone();
251    let request_id = prediction_response.request_id;
252    let output_excerpt = prediction_response.output_excerpt;
253    cx.spawn(async move |cx| {
254        let output_excerpt: Arc<str> = output_excerpt.into();
255
256        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
257            .background_spawn({
258                let output_excerpt = output_excerpt.clone();
259                let editable_range = editable_range.clone();
260                let snapshot = snapshot.clone();
261                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
262            })
263            .await?
264            .into();
265
266        let id = EditPredictionId(request_id.into());
267        Ok(EditPredictionResult::new(
268            id,
269            &buffer,
270            &snapshot,
271            edits,
272            buffer_snapshotted_at,
273            received_response_at,
274            inputs,
275            cx,
276        )
277        .await)
278    })
279}
280
281fn parse_edits(
282    output_excerpt: Arc<str>,
283    editable_range: Range<usize>,
284    snapshot: &BufferSnapshot,
285) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
286    let content = output_excerpt.replace(CURSOR_MARKER, "");
287
288    let start_markers = content
289        .match_indices(EDITABLE_REGION_START_MARKER)
290        .collect::<Vec<_>>();
291    anyhow::ensure!(
292        start_markers.len() == 1,
293        "expected exactly one start marker, found {}",
294        start_markers.len()
295    );
296
297    let end_markers = content
298        .match_indices(EDITABLE_REGION_END_MARKER)
299        .collect::<Vec<_>>();
300    anyhow::ensure!(
301        end_markers.len() == 1,
302        "expected exactly one end marker, found {}",
303        end_markers.len()
304    );
305
306    let sof_markers = content
307        .match_indices(START_OF_FILE_MARKER)
308        .collect::<Vec<_>>();
309    anyhow::ensure!(
310        sof_markers.len() <= 1,
311        "expected at most one start-of-file marker, found {}",
312        sof_markers.len()
313    );
314
315    let codefence_start = start_markers[0].0;
316    let content = &content[codefence_start..];
317
318    let newline_ix = content.find('\n').context("could not find newline")?;
319    let content = &content[newline_ix + 1..];
320
321    let codefence_end = content
322        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
323        .context("could not find end marker")?;
324    let new_text = &content[..codefence_end];
325
326    let old_text = snapshot
327        .text_for_range(editable_range.clone())
328        .collect::<String>();
329
330    Ok(compute_edits(
331        old_text,
332        new_text,
333        editable_range.start,
334        snapshot,
335    ))
336}
337
338pub fn compute_edits(
339    old_text: String,
340    new_text: &str,
341    offset: usize,
342    snapshot: &BufferSnapshot,
343) -> Vec<(Range<Anchor>, Arc<str>)> {
344    text_diff(&old_text, new_text)
345        .into_iter()
346        .map(|(mut old_range, new_text)| {
347            old_range.start += offset;
348            old_range.end += offset;
349
350            let prefix_len = common_prefix(
351                snapshot.chars_for_range(old_range.clone()),
352                new_text.chars(),
353            );
354            old_range.start += prefix_len;
355
356            let suffix_len = common_prefix(
357                snapshot.reversed_chars_for_range(old_range.clone()),
358                new_text[prefix_len..].chars().rev(),
359            );
360            old_range.end = old_range.end.saturating_sub(suffix_len);
361
362            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
363            let range = if old_range.is_empty() {
364                let anchor = snapshot.anchor_after(old_range.start);
365                anchor..anchor
366            } else {
367                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
368            };
369            (range, new_text)
370        })
371        .collect()
372}
373
/// Returns the length in UTF-8 bytes of the longest common prefix of two `char` sequences.
///
/// The byte length is measured on `a`'s characters. They equal `b`'s over the prefix, so it is
/// the same for both sequences.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
380
381fn git_info_for_file(
382    project: &Entity<Project>,
383    project_path: &ProjectPath,
384    cx: &App,
385) -> Option<PredictEditsGitInfo> {
386    let git_store = project.read(cx).git_store().read(cx);
387    if let Some((repository, _repo_path)) =
388        git_store.repository_and_path_for_project_path(project_path, cx)
389    {
390        let repository = repository.read(cx);
391        let head_sha = repository
392            .head_commit
393            .as_ref()
394            .map(|head_commit| head_commit.sha.to_string());
395        let remote_origin_url = repository.remote_origin_url.clone();
396        let remote_upstream_url = repository.remote_upstream_url.clone();
397        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
398            return None;
399        }
400        Some(PredictEditsGitInfo {
401            head_sha,
402            remote_origin_url,
403            remote_upstream_url,
404        })
405    } else {
406        None
407    }
408}
409
410pub struct GatherContextOutput {
411    pub body: PredictEditsBody,
412    pub context_range: Range<Point>,
413    pub editable_range: Range<usize>,
414    pub included_events_count: usize,
415}
416
417pub fn gather_context(
418    full_path_str: String,
419    snapshot: &BufferSnapshot,
420    cursor_point: language::Point,
421    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
422    trigger: PredictEditsRequestTrigger,
423    cx: &App,
424) -> Task<Result<GatherContextOutput>> {
425    cx.background_spawn({
426        let snapshot = snapshot.clone();
427        async move {
428            let input_excerpt = excerpt_for_cursor_position(
429                cursor_point,
430                &full_path_str,
431                &snapshot,
432                MAX_REWRITE_TOKENS,
433                MAX_CONTEXT_TOKENS,
434            );
435            let (input_events, included_events_count) = prompt_for_events();
436            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
437
438            let body = PredictEditsBody {
439                input_events,
440                input_excerpt: input_excerpt.prompt,
441                can_collect_data: false,
442                diagnostic_groups: None,
443                git_info: None,
444                outline: None,
445                speculated_output: None,
446                trigger,
447            };
448
449            Ok(GatherContextOutput {
450                body,
451                context_range: input_excerpt.context_range,
452                editable_range,
453                included_events_count,
454            })
455        }
456    })
457}
458
459fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
460    let mut result = String::new();
461    for (ix, event) in events.iter().rev().enumerate() {
462        let event_string = format_event(event.as_ref());
463        let event_tokens = guess_token_count(event_string.len());
464        if event_tokens > remaining_tokens {
465            return (result, ix);
466        }
467
468        if !result.is_empty() {
469            result.insert_str(0, "\n\n");
470        }
471        result.insert_str(0, &event_string);
472        remaining_tokens -= event_tokens;
473    }
474    return (result, events.len());
475}
476
477pub fn format_event(event: &Event) -> String {
478    match event {
479        Event::BufferChange {
480            path,
481            old_path,
482            diff,
483            ..
484        } => {
485            let mut prompt = String::new();
486
487            if old_path != path {
488                writeln!(
489                    prompt,
490                    "User renamed {} to {}\n",
491                    old_path.display(),
492                    path.display()
493                )
494                .unwrap();
495            }
496
497            if !diff.is_empty() {
498                write!(
499                    prompt,
500                    "User edited {}:\n```diff\n{}\n```",
501                    path.display(),
502                    diff
503                )
504                .unwrap();
505            }
506
507            prompt
508        }
509    }
510}
511
512#[derive(Debug)]
513pub struct InputExcerpt {
514    pub context_range: Range<Point>,
515    pub editable_range: Range<Point>,
516    pub prompt: String,
517}
518
519pub fn excerpt_for_cursor_position(
520    position: Point,
521    path: &str,
522    snapshot: &BufferSnapshot,
523    editable_region_token_limit: usize,
524    context_token_limit: usize,
525) -> InputExcerpt {
526    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
527        position,
528        snapshot,
529        editable_region_token_limit,
530        context_token_limit,
531    );
532
533    let mut prompt = String::new();
534
535    writeln!(&mut prompt, "```{path}").unwrap();
536    if context_range.start == Point::zero() {
537        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
538    }
539
540    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
541        prompt.push_str(chunk.text);
542    }
543
544    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
545
546    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
547        prompt.push_str(chunk.text);
548    }
549    write!(prompt, "\n```").unwrap();
550
551    InputExcerpt {
552        context_range,
553        editable_range,
554        prompt,
555    }
556}
557
558fn push_editable_range(
559    cursor_position: Point,
560    snapshot: &BufferSnapshot,
561    editable_range: Range<Point>,
562    prompt: &mut String,
563) {
564    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
565    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
566        prompt.push_str(chunk.text);
567    }
568    prompt.push_str(CURSOR_MARKER);
569    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
570        prompt.push_str(chunk.text);
571    }
572    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The cursor sits right after the `r` of `return sum;` inside `bar`.
        let cursor = Point::new(12, 5);
        let render_prompt = |editable_tokens: usize, context_tokens: usize| {
            excerpt_for_cursor_position(
                cursor,
                "main.rs",
                &snapshot,
                editable_tokens,
                context_tokens,
            )
            .prompt
        };

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        assert_eq!(
            render_prompt(50, 32),
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        assert_eq!(
            render_prompt(40, 32),
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}