zeta1.rs

  1mod input_excerpt;
  2
  3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  4
  5use crate::{
  6    EditPredictionId, ZedUpdateRequiredError, Zeta,
  7    prediction::{EditPrediction, EditPredictionInputs},
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use input_excerpt::excerpt_for_cursor_position;
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 21
/// Cursor-position marker. `parse_edits` removes every occurrence of it from the model output
/// before extracting edits.
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Start-of-file marker. `parse_edits` rejects model output that contains it more than once.
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
/// Opening delimiter of the editable region. `parse_edits` requires exactly one in the model
/// output and treats the lines after it as the rewritten region.
const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
/// Closing delimiter of the editable region. `parse_edits` requires exactly one in the model
/// output, preceded by a newline.
const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";

/// Token budget passed to `excerpt_for_cursor_position` by `gather_context` as its last
/// argument; presumably the size limit for context around the editable region (confirm in
/// `input_excerpt`).
pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
/// Token budget passed to `excerpt_for_cursor_position` by `gather_context`; presumably the size
/// limit for the editable (rewritable) region (confirm in `input_excerpt`).
pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
/// Token budget for the edit-history section of the prompt (see `prompt_for_events_impl`).
pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
 31pub(crate) fn request_prediction_with_zeta1(
 32    zeta: &mut Zeta,
 33    project: &Entity<Project>,
 34    buffer: &Entity<Buffer>,
 35    position: language::Anchor,
 36    cx: &mut Context<Zeta>,
 37) -> Task<Result<Option<EditPrediction>>> {
 38    let buffer = buffer.clone();
 39    let buffer_snapshotted_at = Instant::now();
 40    let snapshot = buffer.read(cx).snapshot();
 41    let client = zeta.client.clone();
 42    let llm_token = zeta.llm_token.clone();
 43    let app_version = AppVersion::global(cx);
 44
 45    let zeta_project = zeta.get_or_init_zeta_project(project, cx);
 46    let events = Arc::new(zeta_project.events(cx));
 47
 48    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 49        let can_collect_file = zeta.can_collect_file(project, file, cx);
 50        let git_info = if can_collect_file {
 51            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 52        } else {
 53            None
 54        };
 55        (git_info, can_collect_file)
 56    } else {
 57        (None, false)
 58    };
 59
 60    let full_path: Arc<Path> = snapshot
 61        .file()
 62        .map(|f| Arc::from(f.full_path(cx).as_path()))
 63        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 64    let full_path_str = full_path.to_string_lossy().into_owned();
 65    let cursor_point = position.to_point(&snapshot);
 66    let prompt_for_events = {
 67        let events = events.clone();
 68        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 69    };
 70    let gather_task = gather_context(
 71        full_path_str,
 72        &snapshot,
 73        cursor_point,
 74        prompt_for_events,
 75        cx,
 76    );
 77
 78    cx.spawn(async move |this, cx| {
 79        let GatherContextOutput {
 80            mut body,
 81            context_range,
 82            editable_range,
 83            included_events_count,
 84        } = gather_task.await?;
 85        let done_gathering_context_at = Instant::now();
 86
 87        let included_events = &events[events.len() - included_events_count..events.len()];
 88        body.can_collect_data = can_collect_file
 89            && this
 90                .read_with(cx, |this, _| this.can_collect_events(included_events))
 91                .unwrap_or(false);
 92        if body.can_collect_data {
 93            body.git_info = git_info;
 94        }
 95
 96        log::debug!(
 97            "Events:\n{}\nExcerpt:\n{:?}",
 98            body.input_events,
 99            body.input_excerpt
100        );
101
102        let http_client = client.http_client();
103
104        let response = Zeta::send_api_request::<PredictEditsResponse>(
105            |request| {
106                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
107                    predict_edits_url
108                } else {
109                    http_client
110                        .build_zed_llm_url("/predict_edits/v2", &[])?
111                        .as_str()
112                        .into()
113                };
114                Ok(request
115                    .uri(uri)
116                    .body(serde_json::to_string(&body)?.into())?)
117            },
118            client,
119            llm_token,
120            app_version,
121        )
122        .await;
123
124        let inputs = EditPredictionInputs {
125            events: included_events.into(),
126            included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
127                path: full_path.clone(),
128                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
129                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
130                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
131                    text: snapshot
132                        .text_for_range(context_range)
133                        .collect::<String>()
134                        .into(),
135                }],
136            }],
137            cursor_point: cloud_llm_client::predict_edits_v3::Point {
138                column: cursor_point.column,
139                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
140            },
141            cursor_path: full_path,
142        };
143
144        // let response = perform_predict_edits(PerformPredictEditsParams {
145        //     client,
146        //     llm_token,
147        //     app_version,
148        //     body,
149        // })
150        // .await;
151
152        let (response, usage) = match response {
153            Ok(response) => response,
154            Err(err) => {
155                if err.is::<ZedUpdateRequiredError>() {
156                    cx.update(|cx| {
157                        this.update(cx, |zeta, _cx| {
158                            zeta.update_required = true;
159                        })
160                        .ok();
161
162                        let error_message: SharedString = err.to_string().into();
163                        show_app_notification(
164                            NotificationId::unique::<ZedUpdateRequiredError>(),
165                            cx,
166                            move |cx| {
167                                cx.new(|cx| {
168                                    ErrorMessagePrompt::new(error_message.clone(), cx)
169                                        .with_link_button("Update Zed", "https://zed.dev/releases")
170                                })
171                            },
172                        );
173                    })
174                    .ok();
175                }
176
177                return Err(err);
178            }
179        };
180
181        let received_response_at = Instant::now();
182        log::debug!("completion response: {}", &response.output_excerpt);
183
184        if let Some(usage) = usage {
185            this.update(cx, |this, cx| {
186                this.user_store.update(cx, |user_store, cx| {
187                    user_store.update_edit_prediction_usage(usage, cx);
188                });
189            })
190            .ok();
191        }
192
193        let edit_prediction = process_completion_response(
194            response,
195            buffer,
196            &snapshot,
197            editable_range,
198            inputs,
199            buffer_snapshotted_at,
200            received_response_at,
201            cx,
202        )
203        .await;
204
205        let finished_at = Instant::now();
206
207        // record latency for ~1% of requests
208        if rand::random::<u8>() <= 2 {
209            telemetry::event!(
210                "Edit Prediction Request",
211                context_latency = done_gathering_context_at
212                    .duration_since(buffer_snapshotted_at)
213                    .as_millis(),
214                request_latency = received_response_at
215                    .duration_since(done_gathering_context_at)
216                    .as_millis(),
217                process_latency = finished_at.duration_since(received_response_at).as_millis()
218            );
219        }
220
221        edit_prediction
222    })
223}
224
225fn process_completion_response(
226    prediction_response: PredictEditsResponse,
227    buffer: Entity<Buffer>,
228    snapshot: &BufferSnapshot,
229    editable_range: Range<usize>,
230    inputs: EditPredictionInputs,
231    buffer_snapshotted_at: Instant,
232    received_response_at: Instant,
233    cx: &AsyncApp,
234) -> Task<Result<Option<EditPrediction>>> {
235    let snapshot = snapshot.clone();
236    let request_id = prediction_response.request_id;
237    let output_excerpt = prediction_response.output_excerpt;
238    cx.spawn(async move |cx| {
239        let output_excerpt: Arc<str> = output_excerpt.into();
240
241        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
242            .background_spawn({
243                let output_excerpt = output_excerpt.clone();
244                let editable_range = editable_range.clone();
245                let snapshot = snapshot.clone();
246                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
247            })
248            .await?
249            .into();
250
251        Ok(EditPrediction::new(
252            EditPredictionId(request_id.into()),
253            &buffer,
254            &snapshot,
255            edits,
256            buffer_snapshotted_at,
257            received_response_at,
258            inputs,
259            cx,
260        )
261        .await)
262    })
263}
264
265fn parse_edits(
266    output_excerpt: Arc<str>,
267    editable_range: Range<usize>,
268    snapshot: &BufferSnapshot,
269) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
270    let content = output_excerpt.replace(CURSOR_MARKER, "");
271
272    let start_markers = content
273        .match_indices(EDITABLE_REGION_START_MARKER)
274        .collect::<Vec<_>>();
275    anyhow::ensure!(
276        start_markers.len() == 1,
277        "expected exactly one start marker, found {}",
278        start_markers.len()
279    );
280
281    let end_markers = content
282        .match_indices(EDITABLE_REGION_END_MARKER)
283        .collect::<Vec<_>>();
284    anyhow::ensure!(
285        end_markers.len() == 1,
286        "expected exactly one end marker, found {}",
287        end_markers.len()
288    );
289
290    let sof_markers = content
291        .match_indices(START_OF_FILE_MARKER)
292        .collect::<Vec<_>>();
293    anyhow::ensure!(
294        sof_markers.len() <= 1,
295        "expected at most one start-of-file marker, found {}",
296        sof_markers.len()
297    );
298
299    let codefence_start = start_markers[0].0;
300    let content = &content[codefence_start..];
301
302    let newline_ix = content.find('\n').context("could not find newline")?;
303    let content = &content[newline_ix + 1..];
304
305    let codefence_end = content
306        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
307        .context("could not find end marker")?;
308    let new_text = &content[..codefence_end];
309
310    let old_text = snapshot
311        .text_for_range(editable_range.clone())
312        .collect::<String>();
313
314    Ok(compute_edits(
315        old_text,
316        new_text,
317        editable_range.start,
318        snapshot,
319    ))
320}
321
322pub fn compute_edits(
323    old_text: String,
324    new_text: &str,
325    offset: usize,
326    snapshot: &BufferSnapshot,
327) -> Vec<(Range<Anchor>, Arc<str>)> {
328    text_diff(&old_text, new_text)
329        .into_iter()
330        .map(|(mut old_range, new_text)| {
331            old_range.start += offset;
332            old_range.end += offset;
333
334            let prefix_len = common_prefix(
335                snapshot.chars_for_range(old_range.clone()),
336                new_text.chars(),
337            );
338            old_range.start += prefix_len;
339
340            let suffix_len = common_prefix(
341                snapshot.reversed_chars_for_range(old_range.clone()),
342                new_text[prefix_len..].chars().rev(),
343            );
344            old_range.end = old_range.end.saturating_sub(suffix_len);
345
346            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
347            let range = if old_range.is_empty() {
348                let anchor = snapshot.anchor_after(old_range.start);
349                anchor..anchor
350            } else {
351                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
352            };
353            (range, new_text)
354        })
355        .collect()
356}
357
/// Returns the length, in UTF-8 bytes, of the longest common prefix of two character streams.
///
/// Comparison stops at the first mismatching pair or when either stream runs out.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
364
365fn git_info_for_file(
366    project: &Entity<Project>,
367    project_path: &ProjectPath,
368    cx: &App,
369) -> Option<PredictEditsGitInfo> {
370    let git_store = project.read(cx).git_store().read(cx);
371    if let Some((repository, _repo_path)) =
372        git_store.repository_and_path_for_project_path(project_path, cx)
373    {
374        let repository = repository.read(cx);
375        let head_sha = repository
376            .head_commit
377            .as_ref()
378            .map(|head_commit| head_commit.sha.to_string());
379        let remote_origin_url = repository.remote_origin_url.clone();
380        let remote_upstream_url = repository.remote_upstream_url.clone();
381        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
382            return None;
383        }
384        Some(PredictEditsGitInfo {
385            head_sha,
386            remote_origin_url,
387            remote_upstream_url,
388        })
389    } else {
390        None
391    }
392}
393
/// Result of `gather_context`: the prediction request body plus the buffer ranges it was built
/// from.
pub struct GatherContextOutput {
    /// Request body. `gather_context` leaves `can_collect_data` false and `git_info` unset; the
    /// caller decides whether to fill them in.
    pub body: PredictEditsBody,
    /// Point range of the context excerpt around the cursor.
    pub context_range: Range<Point>,
    /// Editable region as a byte-offset range in the buffer; model output is diffed against the
    /// text in this range.
    pub editable_range: Range<usize>,
    /// Number of events included in `body.input_events`, as reported by the `prompt_for_events`
    /// callback.
    pub included_events_count: usize,
}
400
401pub fn gather_context(
402    full_path_str: String,
403    snapshot: &BufferSnapshot,
404    cursor_point: language::Point,
405    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
406    cx: &App,
407) -> Task<Result<GatherContextOutput>> {
408    cx.background_spawn({
409        let snapshot = snapshot.clone();
410        async move {
411            let input_excerpt = excerpt_for_cursor_position(
412                cursor_point,
413                &full_path_str,
414                &snapshot,
415                MAX_REWRITE_TOKENS,
416                MAX_CONTEXT_TOKENS,
417            );
418            let (input_events, included_events_count) = prompt_for_events();
419            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
420
421            let body = PredictEditsBody {
422                input_events,
423                input_excerpt: input_excerpt.prompt,
424                can_collect_data: false,
425                diagnostic_groups: None,
426                git_info: None,
427                outline: None,
428                speculated_output: None,
429            };
430
431            Ok(GatherContextOutput {
432                body,
433                context_range: input_excerpt.context_range,
434                editable_range,
435                included_events_count,
436            })
437        }
438    })
439}
440
441fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
442    let mut result = String::new();
443    for (ix, event) in events.iter().rev().enumerate() {
444        let event_string = format_event(event.as_ref());
445        let event_tokens = guess_token_count(event_string.len());
446        if event_tokens > remaining_tokens {
447            return (result, ix);
448        }
449
450        if !result.is_empty() {
451            result.insert_str(0, "\n\n");
452        }
453        result.insert_str(0, &event_string);
454        remaining_tokens -= event_tokens;
455    }
456    return (result, events.len());
457}
458
459pub fn format_event(event: &Event) -> String {
460    match event {
461        Event::BufferChange {
462            path,
463            old_path,
464            diff,
465            ..
466        } => {
467            let mut prompt = String::new();
468
469            if old_path != path {
470                writeln!(
471                    prompt,
472                    "User renamed {} to {}\n",
473                    old_path.display(),
474                    path.display()
475                )
476                .unwrap();
477            }
478
479            if !diff.is_empty() {
480                write!(
481                    prompt,
482                    "User edited {}:\n```diff\n{}\n```",
483                    path.display(),
484                    diff
485                )
486                .unwrap();
487            }
488
489            prompt
490        }
491    }
492}
493
/// Assumed number of bytes of text per model token, used when limiting model input.
///
/// The value is deliberately small, so token counts are over-estimated rather than
/// under-estimated. This errs on the side of underestimating how much input fits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens `bytes` bytes of text occupy, rounding down.
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}