zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{Event, ZetaPromptInput};
 21
/// Marker inserted into the prompt at the cursor position (see `push_editable_range`).
/// `parse_edits` strips every occurrence of it from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker line written right after the opening code fence when the context range starts at
/// the very beginning of the buffer (see `excerpt_for_cursor_position`).
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Marker line that opens the editable region, both in the prompt and in the model output
/// that `parse_edits` consumes.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Marker line that closes the editable region, both in the prompt and in the model output
/// that `parse_edits` consumes.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
 26
/// Token limit handed to `excerpt_for_cursor_position` as `context_token_limit`.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token limit handed to `excerpt_for_cursor_position` as `editable_region_token_limit`.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the formatted edit-history events included in a request
/// (see `prompt_for_events_impl`).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
/// Requests an edit prediction for `position` in `buffer` using the zeta1 request format.
///
/// The steps are:
/// 1. Build the prompt excerpt and the formatted event history via `gather_context`.
/// 2. POST the resulting `PredictEditsBody` as JSON to either the store's custom
///    predict-edits URL or the `/predict_edits/v2` Zed LLM endpoint.
/// 3. Parse the returned excerpt into anchored edits via `process_completion_response`.
///
/// The returned task resolves to `Ok(Some(result))` on success; this function never yields
/// `Ok(None)` itself. Errors from gathering context, from the request, or from parsing the
/// response are propagated. If the LLM URL cannot be built, an already-failed task is
/// returned immediately.
///
/// Side effects visible here: optional debug events on `debug_tx`, an "update required"
/// flag plus app notification when the server answers with `ZedUpdateRequiredError`,
/// usage bookkeeping on the user store, and sampled latency telemetry.
pub(crate) fn request_prediction_with_zeta1(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        project,
        buffer,
        snapshot,
        position,
        events,
        trigger,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Reference point for the latency measurements reported at the end.
    let buffer_snapshotted_at = Instant::now();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Git metadata is only looked up for files the store allows collecting; buffers without
    // a file get neither.
    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
        let can_collect_file = store.can_collect_file(&project, file, cx);
        let git_info = if can_collect_file {
            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
        } else {
            None
        };
        (git_info, can_collect_file)
    } else {
        (None, false)
    };

    // Buffers without a backing file are reported under the placeholder path "untitled".
    let full_path: Arc<Path> = snapshot
        .file()
        .map(|f| Arc::from(f.full_path(cx).as_path()))
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
    let full_path_str = full_path.to_string_lossy().into_owned();
    let cursor_point = position.to_point(&snapshot);
    // The closure owns its own copy of `events` so it can be moved into the gathering task
    // while `events` stays available to the async block below.
    let prompt_for_events = {
        let events = events.clone();
        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
    };
    let gather_task = gather_context(
        full_path_str,
        &snapshot,
        cursor_point,
        prompt_for_events,
        trigger,
        cx,
    );

    // A custom URL is used as-is with `require_auth = false`; otherwise the request targets
    // the Zed LLM endpoint with `require_auth = true`.
    let (uri, require_auth) = match &store.custom_predict_edits_url {
        Some(custom_url) => (custom_url.clone(), false),
        None => {
            match client
                .http_client()
                .build_zed_llm_url("/predict_edits/v2", &[])
            {
                Ok(url) => (url.into(), true),
                Err(err) => return Task::ready(Err(err)),
            }
        }
    };

    cx.spawn(async move |this, cx| {
        let GatherContextOutput {
            mut body,
            context_range,
            editable_range,
            included_events_count,
        } = gather_task.await?;
        let done_gathering_context_at = Instant::now();

        // `prompt_for_events_impl` walks the events newest-first, so the events that made it
        // into the prompt are the last `included_events_count` entries.
        let included_events = &events[events.len() - included_events_count..events.len()];
        // Data collection requires both the file and every included event to be collectable;
        // if the store cannot be read, collection is treated as disallowed.
        body.can_collect_data = can_collect_file
            && this
                .read_with(cx, |this, _| this.can_collect_events(included_events))
                .unwrap_or(false);
        if body.can_collect_data {
            body.git_info = git_info;
        }

        log::debug!(
            "Events:\n{}\nExcerpt:\n{:?}",
            body.input_events,
            body.input_excerpt
        );

        // The result is deliberately not unwrapped yet: the "started" debug event below is
        // sent regardless of whether the request succeeded.
        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
            |request| {
                Ok(request
                    .uri(uri.as_str())
                    .body(serde_json::to_string(&body)?.into())?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        let context_start_offset = context_range.start.to_offset(&snapshot);
        let editable_offset_range = editable_range.to_offset(&snapshot);

        // Offsets stored in `inputs` are relative to the start of the context excerpt.
        let inputs = ZetaPromptInput {
            events: included_events.into(),
            related_files: vec![].into(),
            cursor_path: full_path,
            cursor_excerpt: snapshot
                .text_for_range(context_range)
                .collect::<String>()
                .into(),
            editable_range_in_excerpt: (editable_range.start - context_start_offset)
                ..(editable_offset_range.end - context_start_offset),
            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionStarted(
                    EditPredictionStartedDebugEvent {
                        buffer: buffer.downgrade(),
                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
                        position,
                    },
                ))
                .ok();
        }

        let (response, usage) = match response {
            Ok(response) => response,
            Err(err) => {
                // An outdated client gets flagged on the store and shown a notification with
                // a link to the releases page; the error is still returned to the caller.
                if err.is::<ZedUpdateRequiredError>() {
                    cx.update(|cx| {
                        this.update(cx, |ep_store, _cx| {
                            ep_store.update_required = true;
                        })
                        .ok();

                        let error_message: SharedString = err.to_string().into();
                        show_app_notification(
                            NotificationId::unique::<ZedUpdateRequiredError>(),
                            cx,
                            move |cx| {
                                cx.new(|cx| {
                                    ErrorMessagePrompt::new(error_message.clone(), cx)
                                        .with_link_button("Update Zed", "https://zed.dev/releases")
                                })
                            },
                        );
                    })
                    .ok();
                }

                return Err(err);
            }
        };

        let received_response_at = Instant::now();
        log::debug!("completion response: {}", &response.output_excerpt);

        if let Some(usage) = usage {
            this.update(cx, |this, cx| {
                this.user_store.update(cx, |user_store, cx| {
                    user_store.update_edit_prediction_usage(usage, cx);
                });
            })
            .ok();
        }

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(DebugEvent::EditPredictionFinished(
                    EditPredictionFinishedDebugEvent {
                        buffer: buffer.downgrade(),
                        model_output: Some(response.output_excerpt.clone()),
                        position,
                    },
                ))
                .ok();
        }

        let edit_prediction = process_completion_response(
            response,
            buffer,
            &snapshot,
            editable_range,
            inputs,
            buffer_snapshotted_at,
            received_response_at,
            cx,
        )
        .await;

        let finished_at = Instant::now();

        // record latency for ~1% of requests
        // (`<= 2` accepts 3 of the 256 possible `u8` values, i.e. roughly 1.2%)
        if rand::random::<u8>() <= 2 {
            telemetry::event!(
                "Edit Prediction Request",
                context_latency = done_gathering_context_at
                    .duration_since(buffer_snapshotted_at)
                    .as_millis(),
                request_latency = received_response_at
                    .duration_since(done_gathering_context_at)
                    .as_millis(),
                process_latency = finished_at.duration_since(received_response_at).as_millis()
            );
        }

        edit_prediction.map(Some)
    })
}
243
244fn process_completion_response(
245    prediction_response: PredictEditsResponse,
246    buffer: Entity<Buffer>,
247    snapshot: &BufferSnapshot,
248    editable_range: Range<usize>,
249    inputs: ZetaPromptInput,
250    buffer_snapshotted_at: Instant,
251    received_response_at: Instant,
252    cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254    let snapshot = snapshot.clone();
255    let request_id = prediction_response.request_id;
256    let output_excerpt = prediction_response.output_excerpt;
257    cx.spawn(async move |cx| {
258        let output_excerpt: Arc<str> = output_excerpt.into();
259
260        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261            .background_spawn({
262                let output_excerpt = output_excerpt.clone();
263                let editable_range = editable_range.clone();
264                let snapshot = snapshot.clone();
265                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
266            })
267            .await?
268            .into();
269
270        let id = EditPredictionId(request_id.into());
271        Ok(EditPredictionResult::new(
272            id,
273            &buffer,
274            &snapshot,
275            edits,
276            buffer_snapshotted_at,
277            received_response_at,
278            inputs,
279            cx,
280        )
281        .await)
282    })
283}
284
285fn parse_edits(
286    output_excerpt: Arc<str>,
287    editable_range: Range<usize>,
288    snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290    let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292    let start_markers = content
293        .match_indices(EDITABLE_REGION_START_MARKER)
294        .collect::<Vec<_>>();
295    anyhow::ensure!(
296        start_markers.len() == 1,
297        "expected exactly one start marker, found {}",
298        start_markers.len()
299    );
300
301    let end_markers = content
302        .match_indices(EDITABLE_REGION_END_MARKER)
303        .collect::<Vec<_>>();
304    anyhow::ensure!(
305        end_markers.len() == 1,
306        "expected exactly one end marker, found {}",
307        end_markers.len()
308    );
309
310    let sof_markers = content
311        .match_indices(START_OF_FILE_MARKER)
312        .collect::<Vec<_>>();
313    anyhow::ensure!(
314        sof_markers.len() <= 1,
315        "expected at most one start-of-file marker, found {}",
316        sof_markers.len()
317    );
318
319    let codefence_start = start_markers[0].0;
320    let content = &content[codefence_start..];
321
322    let newline_ix = content.find('\n').context("could not find newline")?;
323    let content = &content[newline_ix + 1..];
324
325    let codefence_end = content
326        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
327        .context("could not find end marker")?;
328    let new_text = &content[..codefence_end];
329
330    let old_text = snapshot
331        .text_for_range(editable_range.clone())
332        .collect::<String>();
333
334    Ok(compute_edits(
335        old_text,
336        new_text,
337        editable_range.start,
338        snapshot,
339    ))
340}
341
342pub fn compute_edits(
343    old_text: String,
344    new_text: &str,
345    offset: usize,
346    snapshot: &BufferSnapshot,
347) -> Vec<(Range<Anchor>, Arc<str>)> {
348    text_diff(&old_text, new_text)
349        .into_iter()
350        .map(|(mut old_range, new_text)| {
351            old_range.start += offset;
352            old_range.end += offset;
353
354            let prefix_len = common_prefix(
355                snapshot.chars_for_range(old_range.clone()),
356                new_text.chars(),
357            );
358            old_range.start += prefix_len;
359
360            let suffix_len = common_prefix(
361                snapshot.reversed_chars_for_range(old_range.clone()),
362                new_text[prefix_len..].chars().rev(),
363            );
364            old_range.end = old_range.end.saturating_sub(suffix_len);
365
366            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
367            let range = if old_range.is_empty() {
368                let anchor = snapshot.anchor_after(old_range.start);
369                anchor..anchor
370            } else {
371                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
372            };
373            (range, new_text)
374        })
375        .collect()
376}
377
/// Returns the length, in UTF-8 bytes, of the longest run of equal characters at the start
/// of both iterators.
///
/// Comparison stops at the first mismatch or as soon as either iterator is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
384
385fn git_info_for_file(
386    project: &Entity<Project>,
387    project_path: &ProjectPath,
388    cx: &App,
389) -> Option<PredictEditsGitInfo> {
390    let git_store = project.read(cx).git_store().read(cx);
391    if let Some((repository, _repo_path)) =
392        git_store.repository_and_path_for_project_path(project_path, cx)
393    {
394        let repository = repository.read(cx);
395        let head_sha = repository
396            .head_commit
397            .as_ref()
398            .map(|head_commit| head_commit.sha.to_string());
399        let remote_origin_url = repository.remote_origin_url.clone();
400        let remote_upstream_url = repository.remote_upstream_url.clone();
401        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
402            return None;
403        }
404        Some(PredictEditsGitInfo {
405            head_sha,
406            remote_origin_url,
407            remote_upstream_url,
408        })
409    } else {
410        None
411    }
412}
413
/// Result of `gather_context`: the request body plus the ranges and event count needed to
/// interpret the model's response.
pub struct GatherContextOutput {
    /// Request body with the prompt excerpt and formatted events filled in.
    pub body: PredictEditsBody,
    /// Range of the buffer included in the prompt excerpt (editable region plus context).
    pub context_range: Range<Point>,
    /// Editable region of the excerpt, as buffer offsets.
    pub editable_range: Range<usize>,
    /// Number of events included in the prompt, as reported by the `prompt_for_events`
    /// callback.
    pub included_events_count: usize,
}
420
421pub fn gather_context(
422    full_path_str: String,
423    snapshot: &BufferSnapshot,
424    cursor_point: language::Point,
425    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
426    trigger: PredictEditsRequestTrigger,
427    cx: &App,
428) -> Task<Result<GatherContextOutput>> {
429    cx.background_spawn({
430        let snapshot = snapshot.clone();
431        async move {
432            let input_excerpt = excerpt_for_cursor_position(
433                cursor_point,
434                &full_path_str,
435                &snapshot,
436                MAX_REWRITE_TOKENS,
437                MAX_CONTEXT_TOKENS,
438            );
439            let (input_events, included_events_count) = prompt_for_events();
440            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
441
442            let body = PredictEditsBody {
443                input_events,
444                input_excerpt: input_excerpt.prompt,
445                can_collect_data: false,
446                diagnostic_groups: None,
447                git_info: None,
448                outline: None,
449                speculated_output: None,
450                trigger,
451            };
452
453            Ok(GatherContextOutput {
454                body,
455                context_range: input_excerpt.context_range,
456                editable_range,
457                included_events_count,
458            })
459        }
460    })
461}
462
463fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
464    let mut result = String::new();
465    for (ix, event) in events.iter().rev().enumerate() {
466        let event_string = format_event(event.as_ref());
467        let event_tokens = guess_token_count(event_string.len());
468        if event_tokens > remaining_tokens {
469            return (result, ix);
470        }
471
472        if !result.is_empty() {
473            result.insert_str(0, "\n\n");
474        }
475        result.insert_str(0, &event_string);
476        remaining_tokens -= event_tokens;
477    }
478    return (result, events.len());
479}
480
481pub fn format_event(event: &Event) -> String {
482    match event {
483        Event::BufferChange {
484            path,
485            old_path,
486            diff,
487            ..
488        } => {
489            let mut prompt = String::new();
490
491            if old_path != path {
492                writeln!(
493                    prompt,
494                    "User renamed {} to {}\n",
495                    old_path.display(),
496                    path.display()
497                )
498                .unwrap();
499            }
500
501            if !diff.is_empty() {
502                write!(
503                    prompt,
504                    "User edited {}:\n```diff\n{}\n```",
505                    path.display(),
506                    diff
507                )
508                .unwrap();
509            }
510
511            prompt
512        }
513    }
514}
515
/// Prompt excerpt produced by `excerpt_for_cursor_position`, together with the buffer
/// ranges it was built from.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Full range of the buffer that appears in `prompt`.
    pub context_range: Range<Point>,
    /// Sub-range of the buffer that is wrapped in the editable-region markers.
    pub editable_range: Range<Point>,
    /// Fenced excerpt text, including the region markers and the cursor marker.
    pub prompt: String,
}
522
523pub fn excerpt_for_cursor_position(
524    position: Point,
525    path: &str,
526    snapshot: &BufferSnapshot,
527    editable_region_token_limit: usize,
528    context_token_limit: usize,
529) -> InputExcerpt {
530    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
531        position,
532        snapshot,
533        editable_region_token_limit,
534        context_token_limit,
535    );
536
537    let mut prompt = String::new();
538
539    writeln!(&mut prompt, "```{path}").unwrap();
540    if context_range.start == Point::zero() {
541        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
542    }
543
544    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
545        prompt.push_str(chunk.text);
546    }
547
548    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
549
550    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
551        prompt.push_str(chunk.text);
552    }
553    write!(prompt, "\n```").unwrap();
554
555    InputExcerpt {
556        context_range,
557        editable_range,
558        prompt,
559    }
560}
561
562fn push_editable_range(
563    cursor_position: Point,
564    snapshot: &BufferSnapshot,
565    editable_range: Range<Point>,
566    prompt: &mut String,
567) {
568    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
569    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
570        prompt.push_str(chunk.text);
571    }
572    prompt.push_str(CURSOR_MARKER);
573    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
574        prompt.push_str(chunk.text);
575    }
576    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Checks the exact prompt text produced by `excerpt_for_cursor_position` for the same
    /// cursor position under two different editable-region token limits.
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Point::new(12, 5) is row 12, column 5: just after the `r` of `return sum;` in `bar`.
        // The last two arguments are the editable-region and context token limits.

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}