zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::Result;
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{
 21    Event, ZetaPromptInput,
 22    zeta1::{
 23        CURSOR_MARKER, EDITABLE_REGION_END_MARKER, EDITABLE_REGION_START_MARKER,
 24        START_OF_FILE_MARKER,
 25    },
 26};
 27
 28pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 29pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 30pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 31
 32pub(crate) fn request_prediction_with_zeta1(
 33    store: &mut EditPredictionStore,
 34    EditPredictionModelInput {
 35        project,
 36        buffer,
 37        snapshot,
 38        position,
 39        events,
 40        trigger,
 41        debug_tx,
 42        ..
 43    }: EditPredictionModelInput,
 44    cx: &mut Context<EditPredictionStore>,
 45) -> Task<Result<Option<EditPredictionResult>>> {
 46    let buffer_snapshotted_at = Instant::now();
 47    let client = store.client.clone();
 48    let llm_token = store.llm_token.clone();
 49    let app_version = AppVersion::global(cx);
 50
 51    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 52        let can_collect_file = store.can_collect_file(&project, file, cx);
 53        let git_info = if can_collect_file {
 54            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 55        } else {
 56            None
 57        };
 58        (git_info, can_collect_file)
 59    } else {
 60        (None, false)
 61    };
 62
 63    let full_path: Arc<Path> = snapshot
 64        .file()
 65        .map(|f| Arc::from(f.full_path(cx).as_path()))
 66        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 67    let full_path_str = full_path.to_string_lossy().into_owned();
 68    let cursor_point = position.to_point(&snapshot);
 69    let prompt_for_events = {
 70        let events = events.clone();
 71        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 72    };
 73    let gather_task = gather_context(
 74        full_path_str,
 75        &snapshot,
 76        cursor_point,
 77        prompt_for_events,
 78        trigger,
 79        cx,
 80    );
 81
 82    let (uri, require_auth) = match &store.custom_predict_edits_url {
 83        Some(custom_url) => (custom_url.clone(), false),
 84        None => {
 85            match client
 86                .http_client()
 87                .build_zed_llm_url("/predict_edits/v2", &[])
 88            {
 89                Ok(url) => (url.into(), true),
 90                Err(err) => return Task::ready(Err(err)),
 91            }
 92        }
 93    };
 94
 95    cx.spawn(async move |this, cx| {
 96        let GatherContextOutput {
 97            mut body,
 98            context_range,
 99            editable_range,
100            included_events_count,
101        } = gather_task.await?;
102        let done_gathering_context_at = Instant::now();
103
104        let included_events = &events[events.len() - included_events_count..events.len()];
105        body.can_collect_data = can_collect_file
106            && this
107                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
108                .unwrap_or(false);
109        if body.can_collect_data {
110            body.git_info = git_info;
111        }
112
113        log::debug!(
114            "Events:\n{}\nExcerpt:\n{:?}",
115            body.input_events,
116            body.input_excerpt
117        );
118
119        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
120            |request| {
121                Ok(request
122                    .uri(uri.as_str())
123                    .body(serde_json::to_string(&body)?.into())?)
124            },
125            client,
126            llm_token,
127            app_version,
128            require_auth,
129        )
130        .await;
131
132        let context_start_offset = context_range.start.to_offset(&snapshot);
133        let editable_offset_range = editable_range.to_offset(&snapshot);
134
135        let inputs = ZetaPromptInput {
136            events: included_events.into(),
137            related_files: vec![],
138            cursor_path: full_path,
139            cursor_excerpt: snapshot
140                .text_for_range(context_range)
141                .collect::<String>()
142                .into(),
143            editable_range_in_excerpt: (editable_range.start - context_start_offset)
144                ..(editable_offset_range.end - context_start_offset),
145            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
146        };
147
148        if let Some(debug_tx) = &debug_tx {
149            debug_tx
150                .unbounded_send(DebugEvent::EditPredictionStarted(
151                    EditPredictionStartedDebugEvent {
152                        buffer: buffer.downgrade(),
153                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
154                        position,
155                    },
156                ))
157                .ok();
158        }
159
160        let (response, usage) = match response {
161            Ok(response) => response,
162            Err(err) => {
163                if err.is::<ZedUpdateRequiredError>() {
164                    cx.update(|cx| {
165                        this.update(cx, |ep_store, _cx| {
166                            ep_store.update_required = true;
167                        })
168                        .ok();
169
170                        let error_message: SharedString = err.to_string().into();
171                        show_app_notification(
172                            NotificationId::unique::<ZedUpdateRequiredError>(),
173                            cx,
174                            move |cx| {
175                                cx.new(|cx| {
176                                    ErrorMessagePrompt::new(error_message.clone(), cx)
177                                        .with_link_button("Update Zed", "https://zed.dev/releases")
178                                })
179                            },
180                        );
181                    });
182                }
183
184                return Err(err);
185            }
186        };
187
188        let received_response_at = Instant::now();
189        log::debug!("completion response: {}", &response.output_excerpt);
190
191        if let Some(usage) = usage {
192            this.update(cx, |this, cx| {
193                this.user_store.update(cx, |user_store, cx| {
194                    user_store.update_edit_prediction_usage(usage, cx);
195                });
196            })
197            .ok();
198        }
199
200        if let Some(debug_tx) = &debug_tx {
201            debug_tx
202                .unbounded_send(DebugEvent::EditPredictionFinished(
203                    EditPredictionFinishedDebugEvent {
204                        buffer: buffer.downgrade(),
205                        model_output: Some(response.output_excerpt.clone()),
206                        position,
207                    },
208                ))
209                .ok();
210        }
211
212        let edit_prediction = process_completion_response(
213            response,
214            buffer,
215            &snapshot,
216            editable_range,
217            inputs,
218            buffer_snapshotted_at,
219            received_response_at,
220            cx,
221        )
222        .await;
223
224        let finished_at = Instant::now();
225
226        // record latency for ~1% of requests
227        if rand::random::<u8>() <= 2 {
228            telemetry::event!(
229                "Edit Prediction Request",
230                context_latency = done_gathering_context_at
231                    .duration_since(buffer_snapshotted_at)
232                    .as_millis(),
233                request_latency = received_response_at
234                    .duration_since(done_gathering_context_at)
235                    .as_millis(),
236                process_latency = finished_at.duration_since(received_response_at).as_millis()
237            );
238        }
239
240        edit_prediction.map(Some)
241    })
242}
243
244fn process_completion_response(
245    prediction_response: PredictEditsResponse,
246    buffer: Entity<Buffer>,
247    snapshot: &BufferSnapshot,
248    editable_range: Range<usize>,
249    inputs: ZetaPromptInput,
250    buffer_snapshotted_at: Instant,
251    received_response_at: Instant,
252    cx: &AsyncApp,
253) -> Task<Result<EditPredictionResult>> {
254    let snapshot = snapshot.clone();
255    let request_id = prediction_response.request_id;
256    let output_excerpt = prediction_response.output_excerpt;
257    cx.spawn(async move |cx| {
258        let output_excerpt: Arc<str> = output_excerpt.into();
259
260        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
261            .background_spawn({
262                let output_excerpt = output_excerpt.clone();
263                let editable_range = editable_range.clone();
264                let snapshot = snapshot.clone();
265                async move { parse_edits(output_excerpt.as_ref(), editable_range, &snapshot) }
266            })
267            .await?
268            .into();
269
270        let id = EditPredictionId(request_id.into());
271        Ok(EditPredictionResult::new(
272            id,
273            &buffer,
274            &snapshot,
275            edits,
276            buffer_snapshotted_at,
277            received_response_at,
278            inputs,
279            cx,
280        )
281        .await)
282    })
283}
284
285pub(crate) fn parse_edits(
286    output_excerpt: &str,
287    editable_range: Range<usize>,
288    snapshot: &BufferSnapshot,
289) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
290    let content = output_excerpt.replace(CURSOR_MARKER, "");
291
292    let start_markers = content
293        .match_indices(EDITABLE_REGION_START_MARKER)
294        .collect::<Vec<_>>();
295    anyhow::ensure!(
296        start_markers.len() <= 1,
297        "expected at most one start marker, found {}",
298        start_markers.len()
299    );
300
301    let end_markers = content
302        .match_indices(EDITABLE_REGION_END_MARKER)
303        .collect::<Vec<_>>();
304    anyhow::ensure!(
305        end_markers.len() <= 1,
306        "expected at most one end marker, found {}",
307        end_markers.len()
308    );
309
310    let sof_markers = content
311        .match_indices(START_OF_FILE_MARKER)
312        .collect::<Vec<_>>();
313    anyhow::ensure!(
314        sof_markers.len() <= 1,
315        "expected at most one start-of-file marker, found {}",
316        sof_markers.len()
317    );
318
319    let content_start = start_markers
320        .first()
321        .map(|e| e.0 + EDITABLE_REGION_START_MARKER.len() + 1) // +1 to skip \n after marker
322        .unwrap_or(0);
323    let content_end = end_markers
324        .first()
325        .map(|e| e.0.saturating_sub(1)) // -1 to exclude \n before marker
326        .unwrap_or(content.strip_suffix("\n").unwrap_or(&content).len());
327
328    let new_text = &content[content_start..content_end];
329
330    let old_text = snapshot
331        .text_for_range(editable_range.clone())
332        .collect::<String>();
333
334    Ok(compute_edits(
335        old_text,
336        new_text,
337        editable_range.start,
338        snapshot,
339    ))
340}
341
342pub fn compute_edits(
343    old_text: String,
344    new_text: &str,
345    offset: usize,
346    snapshot: &BufferSnapshot,
347) -> Vec<(Range<Anchor>, Arc<str>)> {
348    text_diff(&old_text, new_text)
349        .into_iter()
350        .map(|(mut old_range, new_text)| {
351            let old_slice = &old_text[old_range.clone()];
352
353            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
354            let suffix_len = common_prefix(
355                old_slice[prefix_len..].chars().rev(),
356                new_text[prefix_len..].chars().rev(),
357            );
358
359            old_range.start += offset;
360            old_range.end += offset;
361            old_range.start += prefix_len;
362            old_range.end -= suffix_len;
363
364            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
365            let range = if old_range.is_empty() {
366                let anchor = snapshot.anchor_after(old_range.start);
367                anchor..anchor
368            } else {
369                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
370            };
371            (range, new_text)
372        })
373        .collect()
374}
375
/// Returns the UTF-8 byte length of the longest common prefix of two character
/// streams (pass reversed iterators to measure a common suffix).
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}
382
383fn git_info_for_file(
384    project: &Entity<Project>,
385    project_path: &ProjectPath,
386    cx: &App,
387) -> Option<PredictEditsGitInfo> {
388    let git_store = project.read(cx).git_store().read(cx);
389    if let Some((repository, _repo_path)) =
390        git_store.repository_and_path_for_project_path(project_path, cx)
391    {
392        let repository = repository.read(cx);
393        let head_sha = repository
394            .head_commit
395            .as_ref()
396            .map(|head_commit| head_commit.sha.to_string());
397        let remote_origin_url = repository.remote_origin_url.clone();
398        let remote_upstream_url = repository.remote_upstream_url.clone();
399        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
400            return None;
401        }
402        Some(PredictEditsGitInfo {
403            head_sha,
404            remote_origin_url,
405            remote_upstream_url,
406        })
407    } else {
408        None
409    }
410}
411
412pub struct GatherContextOutput {
413    pub body: PredictEditsBody,
414    pub context_range: Range<Point>,
415    pub editable_range: Range<usize>,
416    pub included_events_count: usize,
417}
418
419pub fn gather_context(
420    full_path_str: String,
421    snapshot: &BufferSnapshot,
422    cursor_point: language::Point,
423    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
424    trigger: PredictEditsRequestTrigger,
425    cx: &App,
426) -> Task<Result<GatherContextOutput>> {
427    cx.background_spawn({
428        let snapshot = snapshot.clone();
429        async move {
430            let input_excerpt = excerpt_for_cursor_position(
431                cursor_point,
432                &full_path_str,
433                &snapshot,
434                MAX_REWRITE_TOKENS,
435                MAX_CONTEXT_TOKENS,
436            );
437            let (input_events, included_events_count) = prompt_for_events();
438            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
439
440            let body = PredictEditsBody {
441                input_events,
442                input_excerpt: input_excerpt.prompt,
443                can_collect_data: false,
444                diagnostic_groups: None,
445                git_info: None,
446                outline: None,
447                speculated_output: None,
448                trigger,
449            };
450
451            Ok(GatherContextOutput {
452                body,
453                context_range: input_excerpt.context_range,
454                editable_range,
455                included_events_count,
456            })
457        }
458    })
459}
460
461pub(crate) fn prompt_for_events(events: &[Arc<Event>], max_tokens: usize) -> String {
462    prompt_for_events_impl(events, max_tokens).0
463}
464
465fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
466    let mut result = String::new();
467    for (ix, event) in events.iter().rev().enumerate() {
468        let event_string = format_event(event.as_ref());
469        let event_tokens = guess_token_count(event_string.len());
470        if event_tokens > remaining_tokens {
471            return (result, ix);
472        }
473
474        if !result.is_empty() {
475            result.insert_str(0, "\n\n");
476        }
477        result.insert_str(0, &event_string);
478        remaining_tokens -= event_tokens;
479    }
480    return (result, events.len());
481}
482
483pub fn format_event(event: &Event) -> String {
484    match event {
485        Event::BufferChange {
486            path,
487            old_path,
488            diff,
489            ..
490        } => {
491            let mut prompt = String::new();
492
493            if old_path != path {
494                writeln!(
495                    prompt,
496                    "User renamed {} to {}\n",
497                    old_path.display(),
498                    path.display()
499                )
500                .unwrap();
501            }
502
503            if !diff.is_empty() {
504                write!(
505                    prompt,
506                    "User edited {}:\n```diff\n{}\n```",
507                    path.display(),
508                    diff
509                )
510                .unwrap();
511            }
512
513            prompt
514        }
515    }
516}
517
518#[derive(Debug)]
519pub struct InputExcerpt {
520    pub context_range: Range<Point>,
521    pub editable_range: Range<Point>,
522    pub prompt: String,
523}
524
525pub fn excerpt_for_cursor_position(
526    position: Point,
527    path: &str,
528    snapshot: &BufferSnapshot,
529    editable_region_token_limit: usize,
530    context_token_limit: usize,
531) -> InputExcerpt {
532    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
533        position,
534        snapshot,
535        editable_region_token_limit,
536        context_token_limit,
537    );
538
539    let mut prompt = String::new();
540
541    writeln!(&mut prompt, "```{path}").unwrap();
542    if context_range.start == Point::zero() {
543        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
544    }
545
546    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
547        prompt.push_str(chunk.text);
548    }
549
550    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
551
552    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
553        prompt.push_str(chunk.text);
554    }
555    write!(prompt, "\n```").unwrap();
556
557    InputExcerpt {
558        context_range,
559        editable_range,
560        prompt,
561    }
562}
563
564fn push_editable_range(
565    cursor_position: Point,
566    snapshot: &BufferSnapshot,
567    editable_range: Range<Point>,
568    prompt: &mut String,
569) {
570    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
571    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
572        prompt.push_str(chunk.text);
573    }
574    prompt.push_str(CURSOR_MARKER);
575    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
576        prompt.push_str(chunk.text);
577    }
578    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();
        let cursor = Point::new(12, 5);

        let cases = [
            // The largest syntax scope that fits is tried first, falling back to
            // line-based expansion once a larger scope exceeds the editable budget.
            (
                50,
                indoc! {r#"
                ```main.rs
                    let x = 42;
                    println!("Hello, world!");
                <|editable_region_start|>
                }

                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                <|editable_region_end|>
                    let mut rng = rand::thread_rng();
                    let mut numbers = Vec::new();
                ```"#},
            ),
            // With a smaller budget all of `bar` no longer fits, so expansion is
            // purely line-based.
            (
                40,
                indoc! {r#"
                ```main.rs
                fn bar() {
                    let x = 42;
                    let mut sum = 0;
                <|editable_region_start|>
                    for i in 0..x {
                        sum += i;
                    }
                    println!("Sum: {}", sum);
                    r<|user_cursor_is_here|>eturn sum;
                }

                fn generate_random_numbers() -> Vec<i32> {
                    let mut rng = rand::thread_rng();
                <|editable_region_end|>
                    let mut numbers = Vec::new();
                    for _ in 0..5 {
                        numbers.push(rng.random_range(1..101));
                ```"#},
            ),
        ];

        for (editable_token_limit, expected_prompt) in cases {
            let excerpt = excerpt_for_cursor_position(
                cursor,
                "main.rs",
                &snapshot,
                editable_token_limit,
                32,
            );
            assert_eq!(excerpt.prompt, expected_prompt);
        }
    }
}