zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{Event, ZetaPromptInput};
 21
/// Inserted into the prompt at the cursor position; stripped from the model output before diffing.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Written right after the opening code fence when the context range starts at the buffer start.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opens the editable region (on its own line) in the prompt and in the expected model output.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closes the editable region (preceded by a newline) in the prompt and in the expected model output.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Approximate token budget for the read-only context around the editable region.
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Approximate token budget for the editable (rewritable) region around the cursor.
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Approximate token budget for the formatted recent-edit events included in the prompt.
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
 31pub(crate) fn request_prediction_with_zeta1(
 32    store: &mut EditPredictionStore,
 33    EditPredictionModelInput {
 34        project,
 35        buffer,
 36        snapshot,
 37        position,
 38        events,
 39        trigger,
 40        debug_tx,
 41        ..
 42    }: EditPredictionModelInput,
 43    cx: &mut Context<EditPredictionStore>,
 44) -> Task<Result<Option<EditPredictionResult>>> {
 45    let buffer_snapshotted_at = Instant::now();
 46    let client = store.client.clone();
 47    let llm_token = store.llm_token.clone();
 48    let app_version = AppVersion::global(cx);
 49
 50    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 51        let can_collect_file = store.can_collect_file(&project, file, cx);
 52        let git_info = if can_collect_file {
 53            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 54        } else {
 55            None
 56        };
 57        (git_info, can_collect_file)
 58    } else {
 59        (None, false)
 60    };
 61
 62    let full_path: Arc<Path> = snapshot
 63        .file()
 64        .map(|f| Arc::from(f.full_path(cx).as_path()))
 65        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 66    let full_path_str = full_path.to_string_lossy().into_owned();
 67    let cursor_point = position.to_point(&snapshot);
 68    let prompt_for_events = {
 69        let events = events.clone();
 70        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 71    };
 72    let gather_task = gather_context(
 73        full_path_str,
 74        &snapshot,
 75        cursor_point,
 76        prompt_for_events,
 77        trigger,
 78        cx,
 79    );
 80
 81    let (uri, require_auth) = match &store.custom_predict_edits_url {
 82        Some(custom_url) => (custom_url.clone(), false),
 83        None => {
 84            match client
 85                .http_client()
 86                .build_zed_llm_url("/predict_edits/v2", &[])
 87            {
 88                Ok(url) => (url.into(), true),
 89                Err(err) => return Task::ready(Err(err)),
 90            }
 91        }
 92    };
 93
 94    cx.spawn(async move |this, cx| {
 95        let GatherContextOutput {
 96            mut body,
 97            context_range,
 98            editable_range,
 99            included_events_count,
100        } = gather_task.await?;
101        let done_gathering_context_at = Instant::now();
102
103        let included_events = &events[events.len() - included_events_count..events.len()];
104        body.can_collect_data = can_collect_file
105            && this
106                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
107                .unwrap_or(false);
108        if body.can_collect_data {
109            body.git_info = git_info;
110        }
111
112        log::debug!(
113            "Events:\n{}\nExcerpt:\n{:?}",
114            body.input_events,
115            body.input_excerpt
116        );
117
118        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
119            |request| {
120                Ok(request
121                    .uri(uri.as_str())
122                    .body(serde_json::to_string(&body)?.into())?)
123            },
124            client,
125            llm_token,
126            app_version,
127            require_auth,
128        )
129        .await;
130
131        let context_start_offset = context_range.start.to_offset(&snapshot);
132        let editable_offset_range = editable_range.to_offset(&snapshot);
133
134        let inputs = ZetaPromptInput {
135            events: included_events.into(),
136            related_files: vec![],
137            cursor_path: full_path,
138            cursor_excerpt: snapshot
139                .text_for_range(context_range)
140                .collect::<String>()
141                .into(),
142            editable_range_in_excerpt: (editable_range.start - context_start_offset)
143                ..(editable_offset_range.end - context_start_offset),
144            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
145        };
146
147        if let Some(debug_tx) = &debug_tx {
148            debug_tx
149                .unbounded_send(DebugEvent::EditPredictionStarted(
150                    EditPredictionStartedDebugEvent {
151                        buffer: buffer.downgrade(),
152                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
153                        position,
154                    },
155                ))
156                .ok();
157        }
158
159        let (response, usage) = match response {
160            Ok(response) => response,
161            Err(err) => {
162                if err.is::<ZedUpdateRequiredError>() {
163                    cx.update(|cx| {
164                        this.update(cx, |ep_store, _cx| {
165                            ep_store.update_required = true;
166                        })
167                        .ok();
168
169                        let error_message: SharedString = err.to_string().into();
170                        show_app_notification(
171                            NotificationId::unique::<ZedUpdateRequiredError>(),
172                            cx,
173                            move |cx| {
174                                cx.new(|cx| {
175                                    ErrorMessagePrompt::new(error_message.clone(), cx)
176                                        .with_link_button("Update Zed", "https://zed.dev/releases")
177                                })
178                            },
179                        );
180                    });
181                }
182
183                return Err(err);
184            }
185        };
186
187        let received_response_at = Instant::now();
188        log::debug!("completion response: {}", &response.output_excerpt);
189
190        if let Some(usage) = usage {
191            this.update(cx, |this, cx| {
192                this.user_store.update(cx, |user_store, cx| {
193                    user_store.update_edit_prediction_usage(usage, cx);
194                });
195            })
196            .ok();
197        }
198
199        if let Some(debug_tx) = &debug_tx {
200            debug_tx
201                .unbounded_send(DebugEvent::EditPredictionFinished(
202                    EditPredictionFinishedDebugEvent {
203                        buffer: buffer.downgrade(),
204                        model_output: Some(response.output_excerpt.clone()),
205                        position,
206                    },
207                ))
208                .ok();
209        }
210
211        let edit_prediction = process_completion_response(
212            response,
213            buffer,
214            &snapshot,
215            editable_range,
216            inputs,
217            buffer_snapshotted_at,
218            received_response_at,
219            cx,
220        )
221        .await;
222
223        let finished_at = Instant::now();
224
225        // record latency for ~1% of requests
226        if rand::random::<u8>() <= 2 {
227            telemetry::event!(
228                "Edit Prediction Request",
229                context_latency = done_gathering_context_at
230                    .duration_since(buffer_snapshotted_at)
231                    .as_millis(),
232                request_latency = received_response_at
233                    .duration_since(done_gathering_context_at)
234                    .as_millis(),
235                process_latency = finished_at.duration_since(received_response_at).as_millis()
236            );
237        }
238
239        edit_prediction.map(Some)
240    })
241}
242
243fn process_completion_response(
244    prediction_response: PredictEditsResponse,
245    buffer: Entity<Buffer>,
246    snapshot: &BufferSnapshot,
247    editable_range: Range<usize>,
248    inputs: ZetaPromptInput,
249    buffer_snapshotted_at: Instant,
250    received_response_at: Instant,
251    cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253    let snapshot = snapshot.clone();
254    let request_id = prediction_response.request_id;
255    let output_excerpt = prediction_response.output_excerpt;
256    cx.spawn(async move |cx| {
257        let output_excerpt: Arc<str> = output_excerpt.into();
258
259        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260            .background_spawn({
261                let output_excerpt = output_excerpt.clone();
262                let editable_range = editable_range.clone();
263                let snapshot = snapshot.clone();
264                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
265            })
266            .await?
267            .into();
268
269        let id = EditPredictionId(request_id.into());
270        Ok(EditPredictionResult::new(
271            id,
272            &buffer,
273            &snapshot,
274            edits,
275            buffer_snapshotted_at,
276            received_response_at,
277            inputs,
278            cx,
279        )
280        .await)
281    })
282}
283
284fn parse_edits(
285    output_excerpt: Arc<str>,
286    editable_range: Range<usize>,
287    snapshot: &BufferSnapshot,
288) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
289    let content = output_excerpt.replace(CURSOR_MARKER, "");
290
291    let start_markers = content
292        .match_indices(EDITABLE_REGION_START_MARKER)
293        .collect::<Vec<_>>();
294    anyhow::ensure!(
295        start_markers.len() == 1,
296        "expected exactly one start marker, found {}",
297        start_markers.len()
298    );
299
300    let end_markers = content
301        .match_indices(EDITABLE_REGION_END_MARKER)
302        .collect::<Vec<_>>();
303    anyhow::ensure!(
304        end_markers.len() == 1,
305        "expected exactly one end marker, found {}",
306        end_markers.len()
307    );
308
309    let sof_markers = content
310        .match_indices(START_OF_FILE_MARKER)
311        .collect::<Vec<_>>();
312    anyhow::ensure!(
313        sof_markers.len() <= 1,
314        "expected at most one start-of-file marker, found {}",
315        sof_markers.len()
316    );
317
318    let codefence_start = start_markers[0].0;
319    let content = &content[codefence_start..];
320
321    let newline_ix = content.find('\n').context("could not find newline")?;
322    let content = &content[newline_ix + 1..];
323
324    let codefence_end = content
325        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
326        .context("could not find end marker")?;
327    let new_text = &content[..codefence_end];
328
329    let old_text = snapshot
330        .text_for_range(editable_range.clone())
331        .collect::<String>();
332
333    Ok(compute_edits(
334        old_text,
335        new_text,
336        editable_range.start,
337        snapshot,
338    ))
339}
340
341pub fn compute_edits(
342    old_text: String,
343    new_text: &str,
344    offset: usize,
345    snapshot: &BufferSnapshot,
346) -> Vec<(Range<Anchor>, Arc<str>)> {
347    text_diff(&old_text, new_text)
348        .into_iter()
349        .map(|(mut old_range, new_text)| {
350            let old_slice = &old_text[old_range.clone()];
351
352            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
353            let suffix_len = common_prefix(
354                old_slice[prefix_len..].chars().rev(),
355                new_text[prefix_len..].chars().rev(),
356            );
357
358            old_range.start += offset;
359            old_range.end += offset;
360            old_range.start += prefix_len;
361            old_range.end -= suffix_len;
362
363            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
364            let range = if old_range.is_empty() {
365                let anchor = snapshot.anchor_after(old_range.start);
366                anchor..anchor
367            } else {
368                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
369            };
370            (range, new_text)
371        })
372        .collect()
373}
374
/// Returns the length in bytes (UTF-8) of the longest common prefix of two char iterators.
///
/// Passing reversed iterators yields the common suffix length instead.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
381
382fn git_info_for_file(
383    project: &Entity<Project>,
384    project_path: &ProjectPath,
385    cx: &App,
386) -> Option<PredictEditsGitInfo> {
387    let git_store = project.read(cx).git_store().read(cx);
388    if let Some((repository, _repo_path)) =
389        git_store.repository_and_path_for_project_path(project_path, cx)
390    {
391        let repository = repository.read(cx);
392        let head_sha = repository
393            .head_commit
394            .as_ref()
395            .map(|head_commit| head_commit.sha.to_string());
396        let remote_origin_url = repository.remote_origin_url.clone();
397        let remote_upstream_url = repository.remote_upstream_url.clone();
398        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
399            return None;
400        }
401        Some(PredictEditsGitInfo {
402            head_sha,
403            remote_origin_url,
404            remote_upstream_url,
405        })
406    } else {
407        None
408    }
409}
410
/// Result of `gather_context`: the request body plus the ranges needed to interpret the
/// model's response.
pub struct GatherContextOutput {
    /// Request body with the excerpt and events filled in. `can_collect_data` starts as
    /// `false` and `git_info` as `None`; the caller decides whether to set them.
    pub body: PredictEditsBody,
    /// Buffer range covered by the prompt excerpt (the context around the editable region).
    pub context_range: Range<Point>,
    /// Editable region, already converted to a buffer offset range.
    pub editable_range: Range<usize>,
    /// Number of events that `prompt_for_events` reported as included in `body.input_events`.
    pub included_events_count: usize,
}
417
418pub fn gather_context(
419    full_path_str: String,
420    snapshot: &BufferSnapshot,
421    cursor_point: language::Point,
422    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
423    trigger: PredictEditsRequestTrigger,
424    cx: &App,
425) -> Task<Result<GatherContextOutput>> {
426    cx.background_spawn({
427        let snapshot = snapshot.clone();
428        async move {
429            let input_excerpt = excerpt_for_cursor_position(
430                cursor_point,
431                &full_path_str,
432                &snapshot,
433                MAX_REWRITE_TOKENS,
434                MAX_CONTEXT_TOKENS,
435            );
436            let (input_events, included_events_count) = prompt_for_events();
437            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
438
439            let body = PredictEditsBody {
440                input_events,
441                input_excerpt: input_excerpt.prompt,
442                can_collect_data: false,
443                diagnostic_groups: None,
444                git_info: None,
445                outline: None,
446                speculated_output: None,
447                trigger,
448            };
449
450            Ok(GatherContextOutput {
451                body,
452                context_range: input_excerpt.context_range,
453                editable_range,
454                included_events_count,
455            })
456        }
457    })
458}
459
460fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
461    let mut result = String::new();
462    for (ix, event) in events.iter().rev().enumerate() {
463        let event_string = format_event(event.as_ref());
464        let event_tokens = guess_token_count(event_string.len());
465        if event_tokens > remaining_tokens {
466            return (result, ix);
467        }
468
469        if !result.is_empty() {
470            result.insert_str(0, "\n\n");
471        }
472        result.insert_str(0, &event_string);
473        remaining_tokens -= event_tokens;
474    }
475    return (result, events.len());
476}
477
478pub fn format_event(event: &Event) -> String {
479    match event {
480        Event::BufferChange {
481            path,
482            old_path,
483            diff,
484            ..
485        } => {
486            let mut prompt = String::new();
487
488            if old_path != path {
489                writeln!(
490                    prompt,
491                    "User renamed {} to {}\n",
492                    old_path.display(),
493                    path.display()
494                )
495                .unwrap();
496            }
497
498            if !diff.is_empty() {
499                write!(
500                    prompt,
501                    "User edited {}:\n```diff\n{}\n```",
502                    path.display(),
503                    diff
504                )
505                .unwrap();
506            }
507
508            prompt
509        }
510    }
511}
512
/// Prompt excerpt produced by `excerpt_for_cursor_position`.
#[derive(Debug)]
pub struct InputExcerpt {
    /// Buffer range included in the prompt. It is expected to enclose `editable_range`,
    /// since the prompt is written as context-before, editable region, context-after.
    pub context_range: Range<Point>,
    /// Buffer range written between the editable-region markers.
    pub editable_range: Range<Point>,
    /// Fenced prompt text: context, marked editable region with the cursor marker, and
    /// trailing context.
    pub prompt: String,
}
520pub fn excerpt_for_cursor_position(
521    position: Point,
522    path: &str,
523    snapshot: &BufferSnapshot,
524    editable_region_token_limit: usize,
525    context_token_limit: usize,
526) -> InputExcerpt {
527    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
528        position,
529        snapshot,
530        editable_region_token_limit,
531        context_token_limit,
532    );
533
534    let mut prompt = String::new();
535
536    writeln!(&mut prompt, "```{path}").unwrap();
537    if context_range.start == Point::zero() {
538        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
539    }
540
541    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
542        prompt.push_str(chunk.text);
543    }
544
545    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
546
547    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
548        prompt.push_str(chunk.text);
549    }
550    write!(prompt, "\n```").unwrap();
551
552    InputExcerpt {
553        context_range,
554        editable_range,
555        prompt,
556    }
557}
558
559fn push_editable_range(
560    cursor_position: Point,
561    snapshot: &BufferSnapshot,
562    editable_range: Range<Point>,
563    prompt: &mut String,
564) {
565    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
566    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
567        prompt.push_str(chunk.text);
568    }
569    prompt.push_str(CURSOR_MARKER);
570    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
571        prompt.push_str(chunk.text);
572    }
573    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
574}
575
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    // Checks the exact prompt layout produced by `excerpt_for_cursor_position`:
    // - the fence header;
    // - the placement of the editable-region markers;
    // - the cursor marker;
    // - the trailing fence.
    // Two editable-token budgets are exercised with the same context budget (32).
    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let text = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // The excerpt expands to syntax boundaries.
        // With 50 token editable limit, we get a region that expands to syntax nodes.
        // The cursor, Point::new(12, 5), sits after the `r` of `return sum;` in `bar`.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs

            fn bar() {
                let x = 42;
            <|editable_region_start|>
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // With smaller budget, the region expands to syntax boundaries but is tighter.
        let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
        assert_eq!(
            excerpt.prompt,
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
            <|editable_region_start|>
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
            ```"#}
        );
    }
}