zeta1.rs

  1use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  2
  3use crate::{
  4    DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
  5    EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
  6    cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
  7    prediction::EditPredictionResult,
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use language::{
 15    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
 16};
 17use project::{Project, ProjectPath};
 18use release_channel::AppVersion;
 19use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 20use zeta_prompt::{Event, ZetaPromptInput};
 21
 22const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 23const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
 24const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
 25const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
 26
 27pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 28pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 29pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
 31pub(crate) fn request_prediction_with_zeta1(
 32    store: &mut EditPredictionStore,
 33    EditPredictionModelInput {
 34        project,
 35        buffer,
 36        snapshot,
 37        position,
 38        events,
 39        trigger,
 40        debug_tx,
 41        ..
 42    }: EditPredictionModelInput,
 43    cx: &mut Context<EditPredictionStore>,
 44) -> Task<Result<Option<EditPredictionResult>>> {
 45    let buffer_snapshotted_at = Instant::now();
 46    let client = store.client.clone();
 47    let llm_token = store.llm_token.clone();
 48    let app_version = AppVersion::global(cx);
 49
 50    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 51        let can_collect_file = store.can_collect_file(&project, file, cx);
 52        let git_info = if can_collect_file {
 53            git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 54        } else {
 55            None
 56        };
 57        (git_info, can_collect_file)
 58    } else {
 59        (None, false)
 60    };
 61
 62    let full_path: Arc<Path> = snapshot
 63        .file()
 64        .map(|f| Arc::from(f.full_path(cx).as_path()))
 65        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 66    let full_path_str = full_path.to_string_lossy().into_owned();
 67    let cursor_point = position.to_point(&snapshot);
 68    let prompt_for_events = {
 69        let events = events.clone();
 70        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 71    };
 72    let gather_task = gather_context(
 73        full_path_str,
 74        &snapshot,
 75        cursor_point,
 76        prompt_for_events,
 77        trigger,
 78        cx,
 79    );
 80
 81    let (uri, require_auth) = match &store.custom_predict_edits_url {
 82        Some(custom_url) => (custom_url.clone(), false),
 83        None => {
 84            match client
 85                .http_client()
 86                .build_zed_llm_url("/predict_edits/v2", &[])
 87            {
 88                Ok(url) => (url.into(), true),
 89                Err(err) => return Task::ready(Err(err)),
 90            }
 91        }
 92    };
 93
 94    cx.spawn(async move |this, cx| {
 95        let GatherContextOutput {
 96            mut body,
 97            context_range,
 98            editable_range,
 99            included_events_count,
100        } = gather_task.await?;
101        let done_gathering_context_at = Instant::now();
102
103        let included_events = &events[events.len() - included_events_count..events.len()];
104        body.can_collect_data = can_collect_file
105            && this
106                .read_with(cx, |this, cx| this.can_collect_events(included_events, cx))
107                .unwrap_or(false);
108        if body.can_collect_data {
109            body.git_info = git_info;
110        }
111
112        log::debug!(
113            "Events:\n{}\nExcerpt:\n{:?}",
114            body.input_events,
115            body.input_excerpt
116        );
117
118        let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
119            |request| {
120                Ok(request
121                    .uri(uri.as_str())
122                    .body(serde_json::to_string(&body)?.into())?)
123            },
124            client,
125            llm_token,
126            app_version,
127            require_auth,
128        )
129        .await;
130
131        let context_start_offset = context_range.start.to_offset(&snapshot);
132        let editable_offset_range = editable_range.to_offset(&snapshot);
133
134        let inputs = ZetaPromptInput {
135            events: included_events.into(),
136            related_files: vec![].into(),
137            cursor_path: full_path,
138            cursor_excerpt: snapshot
139                .text_for_range(context_range)
140                .collect::<String>()
141                .into(),
142            editable_range_in_excerpt: (editable_range.start - context_start_offset)
143                ..(editable_offset_range.end - context_start_offset),
144            cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
145        };
146
147        if let Some(debug_tx) = &debug_tx {
148            debug_tx
149                .unbounded_send(DebugEvent::EditPredictionStarted(
150                    EditPredictionStartedDebugEvent {
151                        buffer: buffer.downgrade(),
152                        prompt: Some(serde_json::to_string(&inputs).unwrap()),
153                        position,
154                    },
155                ))
156                .ok();
157        }
158
159        let (response, usage) = match response {
160            Ok(response) => response,
161            Err(err) => {
162                if err.is::<ZedUpdateRequiredError>() {
163                    cx.update(|cx| {
164                        this.update(cx, |ep_store, _cx| {
165                            ep_store.update_required = true;
166                        })
167                        .ok();
168
169                        let error_message: SharedString = err.to_string().into();
170                        show_app_notification(
171                            NotificationId::unique::<ZedUpdateRequiredError>(),
172                            cx,
173                            move |cx| {
174                                cx.new(|cx| {
175                                    ErrorMessagePrompt::new(error_message.clone(), cx)
176                                        .with_link_button("Update Zed", "https://zed.dev/releases")
177                                })
178                            },
179                        );
180                    });
181                }
182
183                return Err(err);
184            }
185        };
186
187        let received_response_at = Instant::now();
188        log::debug!("completion response: {}", &response.output_excerpt);
189
190        if let Some(usage) = usage {
191            this.update(cx, |this, cx| {
192                this.user_store.update(cx, |user_store, cx| {
193                    user_store.update_edit_prediction_usage(usage, cx);
194                });
195            })
196            .ok();
197        }
198
199        if let Some(debug_tx) = &debug_tx {
200            debug_tx
201                .unbounded_send(DebugEvent::EditPredictionFinished(
202                    EditPredictionFinishedDebugEvent {
203                        buffer: buffer.downgrade(),
204                        model_output: Some(response.output_excerpt.clone()),
205                        position,
206                    },
207                ))
208                .ok();
209        }
210
211        let edit_prediction = process_completion_response(
212            response,
213            buffer,
214            &snapshot,
215            editable_range,
216            inputs,
217            buffer_snapshotted_at,
218            received_response_at,
219            cx,
220        )
221        .await;
222
223        let finished_at = Instant::now();
224
225        // record latency for ~1% of requests
226        if rand::random::<u8>() <= 2 {
227            telemetry::event!(
228                "Edit Prediction Request",
229                context_latency = done_gathering_context_at
230                    .duration_since(buffer_snapshotted_at)
231                    .as_millis(),
232                request_latency = received_response_at
233                    .duration_since(done_gathering_context_at)
234                    .as_millis(),
235                process_latency = finished_at.duration_since(received_response_at).as_millis()
236            );
237        }
238
239        edit_prediction.map(Some)
240    })
241}
242
243fn process_completion_response(
244    prediction_response: PredictEditsResponse,
245    buffer: Entity<Buffer>,
246    snapshot: &BufferSnapshot,
247    editable_range: Range<usize>,
248    inputs: ZetaPromptInput,
249    buffer_snapshotted_at: Instant,
250    received_response_at: Instant,
251    cx: &AsyncApp,
252) -> Task<Result<EditPredictionResult>> {
253    let snapshot = snapshot.clone();
254    let request_id = prediction_response.request_id;
255    let output_excerpt = prediction_response.output_excerpt;
256    cx.spawn(async move |cx| {
257        let output_excerpt: Arc<str> = output_excerpt.into();
258
259        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
260            .background_spawn({
261                let output_excerpt = output_excerpt.clone();
262                let editable_range = editable_range.clone();
263                let snapshot = snapshot.clone();
264                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
265            })
266            .await?
267            .into();
268
269        let id = EditPredictionId(request_id.into());
270        Ok(EditPredictionResult::new(
271            id,
272            &buffer,
273            &snapshot,
274            edits,
275            buffer_snapshotted_at,
276            received_response_at,
277            inputs,
278            cx,
279        )
280        .await)
281    })
282}
283
284fn parse_edits(
285    output_excerpt: Arc<str>,
286    editable_range: Range<usize>,
287    snapshot: &BufferSnapshot,
288) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
289    let content = output_excerpt.replace(CURSOR_MARKER, "");
290
291    let start_markers = content
292        .match_indices(EDITABLE_REGION_START_MARKER)
293        .collect::<Vec<_>>();
294    anyhow::ensure!(
295        start_markers.len() == 1,
296        "expected exactly one start marker, found {}",
297        start_markers.len()
298    );
299
300    let end_markers = content
301        .match_indices(EDITABLE_REGION_END_MARKER)
302        .collect::<Vec<_>>();
303    anyhow::ensure!(
304        end_markers.len() == 1,
305        "expected exactly one end marker, found {}",
306        end_markers.len()
307    );
308
309    let sof_markers = content
310        .match_indices(START_OF_FILE_MARKER)
311        .collect::<Vec<_>>();
312    anyhow::ensure!(
313        sof_markers.len() <= 1,
314        "expected at most one start-of-file marker, found {}",
315        sof_markers.len()
316    );
317
318    let codefence_start = start_markers[0].0;
319    let content = &content[codefence_start..];
320
321    let newline_ix = content.find('\n').context("could not find newline")?;
322    let content = &content[newline_ix + 1..];
323
324    let codefence_end = content
325        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
326        .context("could not find end marker")?;
327    let new_text = &content[..codefence_end];
328
329    let old_text = snapshot
330        .text_for_range(editable_range.clone())
331        .collect::<String>();
332
333    Ok(compute_edits(
334        old_text,
335        new_text,
336        editable_range.start,
337        snapshot,
338    ))
339}
340
341pub fn compute_edits(
342    old_text: String,
343    new_text: &str,
344    offset: usize,
345    snapshot: &BufferSnapshot,
346) -> Vec<(Range<Anchor>, Arc<str>)> {
347    text_diff(&old_text, new_text)
348        .into_iter()
349        .map(|(mut old_range, new_text)| {
350            old_range.start += offset;
351            old_range.end += offset;
352
353            let prefix_len = common_prefix(
354                snapshot.chars_for_range(old_range.clone()),
355                new_text.chars(),
356            );
357            old_range.start += prefix_len;
358
359            let suffix_len = common_prefix(
360                snapshot.reversed_chars_for_range(old_range.clone()),
361                new_text[prefix_len..].chars().rev(),
362            );
363            old_range.end = old_range.end.saturating_sub(suffix_len);
364
365            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
366            let range = if old_range.is_empty() {
367                let anchor = snapshot.anchor_after(old_range.start);
368                anchor..anchor
369            } else {
370                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
371            };
372            (range, new_text)
373        })
374        .collect()
375}
376
/// Returns the UTF-8 byte length of the longest common prefix of two character streams.
///
/// The length is measured in the characters of `a`. Because the matched characters are
/// identical, it is also a valid byte offset into the string behind `b`.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut byte_len = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        byte_len += left.len_utf8();
    }
    byte_len
}
383
384fn git_info_for_file(
385    project: &Entity<Project>,
386    project_path: &ProjectPath,
387    cx: &App,
388) -> Option<PredictEditsGitInfo> {
389    let git_store = project.read(cx).git_store().read(cx);
390    if let Some((repository, _repo_path)) =
391        git_store.repository_and_path_for_project_path(project_path, cx)
392    {
393        let repository = repository.read(cx);
394        let head_sha = repository
395            .head_commit
396            .as_ref()
397            .map(|head_commit| head_commit.sha.to_string());
398        let remote_origin_url = repository.remote_origin_url.clone();
399        let remote_upstream_url = repository.remote_upstream_url.clone();
400        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
401            return None;
402        }
403        Some(PredictEditsGitInfo {
404            head_sha,
405            remote_origin_url,
406            remote_upstream_url,
407        })
408    } else {
409        None
410    }
411}
412
413pub struct GatherContextOutput {
414    pub body: PredictEditsBody,
415    pub context_range: Range<Point>,
416    pub editable_range: Range<usize>,
417    pub included_events_count: usize,
418}
419
420pub fn gather_context(
421    full_path_str: String,
422    snapshot: &BufferSnapshot,
423    cursor_point: language::Point,
424    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
425    trigger: PredictEditsRequestTrigger,
426    cx: &App,
427) -> Task<Result<GatherContextOutput>> {
428    cx.background_spawn({
429        let snapshot = snapshot.clone();
430        async move {
431            let input_excerpt = excerpt_for_cursor_position(
432                cursor_point,
433                &full_path_str,
434                &snapshot,
435                MAX_REWRITE_TOKENS,
436                MAX_CONTEXT_TOKENS,
437            );
438            let (input_events, included_events_count) = prompt_for_events();
439            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
440
441            let body = PredictEditsBody {
442                input_events,
443                input_excerpt: input_excerpt.prompt,
444                can_collect_data: false,
445                diagnostic_groups: None,
446                git_info: None,
447                outline: None,
448                speculated_output: None,
449                trigger,
450            };
451
452            Ok(GatherContextOutput {
453                body,
454                context_range: input_excerpt.context_range,
455                editable_range,
456                included_events_count,
457            })
458        }
459    })
460}
461
462fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
463    let mut result = String::new();
464    for (ix, event) in events.iter().rev().enumerate() {
465        let event_string = format_event(event.as_ref());
466        let event_tokens = guess_token_count(event_string.len());
467        if event_tokens > remaining_tokens {
468            return (result, ix);
469        }
470
471        if !result.is_empty() {
472            result.insert_str(0, "\n\n");
473        }
474        result.insert_str(0, &event_string);
475        remaining_tokens -= event_tokens;
476    }
477    return (result, events.len());
478}
479
480pub fn format_event(event: &Event) -> String {
481    match event {
482        Event::BufferChange {
483            path,
484            old_path,
485            diff,
486            ..
487        } => {
488            let mut prompt = String::new();
489
490            if old_path != path {
491                writeln!(
492                    prompt,
493                    "User renamed {} to {}\n",
494                    old_path.display(),
495                    path.display()
496                )
497                .unwrap();
498            }
499
500            if !diff.is_empty() {
501                write!(
502                    prompt,
503                    "User edited {}:\n```diff\n{}\n```",
504                    path.display(),
505                    diff
506                )
507                .unwrap();
508            }
509
510            prompt
511        }
512    }
513}
514
515#[derive(Debug)]
516pub struct InputExcerpt {
517    pub context_range: Range<Point>,
518    pub editable_range: Range<Point>,
519    pub prompt: String,
520}
521
522pub fn excerpt_for_cursor_position(
523    position: Point,
524    path: &str,
525    snapshot: &BufferSnapshot,
526    editable_region_token_limit: usize,
527    context_token_limit: usize,
528) -> InputExcerpt {
529    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
530        position,
531        snapshot,
532        editable_region_token_limit,
533        context_token_limit,
534    );
535
536    let mut prompt = String::new();
537
538    writeln!(&mut prompt, "```{path}").unwrap();
539    if context_range.start == Point::zero() {
540        writeln!(&mut prompt, "{START_OF_FILE_MARKER}").unwrap();
541    }
542
543    for chunk in snapshot.chunks(context_range.start..editable_range.start, false) {
544        prompt.push_str(chunk.text);
545    }
546
547    push_editable_range(position, snapshot, editable_range.clone(), &mut prompt);
548
549    for chunk in snapshot.chunks(editable_range.end..context_range.end, false) {
550        prompt.push_str(chunk.text);
551    }
552    write!(prompt, "\n```").unwrap();
553
554    InputExcerpt {
555        context_range,
556        editable_range,
557        prompt,
558    }
559}
560
561fn push_editable_range(
562    cursor_position: Point,
563    snapshot: &BufferSnapshot,
564    editable_range: Range<Point>,
565    prompt: &mut String,
566) {
567    writeln!(prompt, "{EDITABLE_REGION_START_MARKER}").unwrap();
568    for chunk in snapshot.chunks(editable_range.start..cursor_position, false) {
569        prompt.push_str(chunk.text);
570    }
571    prompt.push_str(CURSOR_MARKER);
572    for chunk in snapshot.chunks(cursor_position..editable_range.end, false) {
573        prompt.push_str(chunk.text);
574    }
575    write!(prompt, "\n{EDITABLE_REGION_END_MARKER}").unwrap();
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::{App, AppContext};
    use indoc::indoc;
    use language::Buffer;

    /// Renders the prompt for a cursor on the `return sum;` line (row 12, column 5) of
    /// `snapshot`, using a 32-token context budget.
    fn prompt_at_return_statement(snapshot: &BufferSnapshot, editable_token_limit: usize) -> String {
        excerpt_for_cursor_position(
            Point::new(12, 5),
            "main.rs",
            snapshot,
            editable_token_limit,
            32,
        )
        .prompt
    }

    #[gpui::test]
    fn test_excerpt_for_cursor_position(cx: &mut App) {
        let source = indoc! {r#"
            fn foo() {
                let x = 42;
                println!("Hello, world!");
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                return sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
                }
                numbers
            }
        "#};
        let buffer =
            cx.new(|cx| Buffer::local(source, cx).with_language(language::rust_lang(), cx));
        let snapshot = buffer.read(cx).snapshot();

        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
        // when a larger scope doesn't fit the editable region.
        assert_eq!(
            prompt_at_return_statement(&snapshot, 50),
            indoc! {r#"
            ```main.rs
                let x = 42;
                println!("Hello, world!");
            <|editable_region_start|>
            }

            fn bar() {
                let x = 42;
                let mut sum = 0;
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
            <|editable_region_end|>
                let mut rng = rand::thread_rng();
                let mut numbers = Vec::new();
            ```"#}
        );

        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
        assert_eq!(
            prompt_at_return_statement(&snapshot, 40),
            indoc! {r#"
            ```main.rs
            fn bar() {
                let x = 42;
                let mut sum = 0;
            <|editable_region_start|>
                for i in 0..x {
                    sum += i;
                }
                println!("Sum: {}", sum);
                r<|user_cursor_is_here|>eturn sum;
            }

            fn generate_random_numbers() -> Vec<i32> {
                let mut rng = rand::thread_rng();
            <|editable_region_end|>
                let mut numbers = Vec::new();
                for _ in 0..5 {
                    numbers.push(rng.random_range(1..101));
            ```"#}
        );
    }
}
669    }
670}