zeta1.rs

  1mod input_excerpt;
  2
  3use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
  4
  5use crate::{
  6    EditPredictionId, ZedUpdateRequiredError, Zeta,
  7    prediction::{EditPredictionInputs, EditPredictionResult},
  8};
  9use anyhow::{Context as _, Result};
 10use cloud_llm_client::{
 11    PredictEditsBody, PredictEditsGitInfo, PredictEditsResponse, predict_edits_v3::Event,
 12};
 13use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
 14use input_excerpt::excerpt_for_cursor_position;
 15use language::{
 16    Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
 17};
 18use project::{Project, ProjectPath};
 19use release_channel::AppVersion;
 20use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 21
 22const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
 23const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
 24const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>";
 25const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>";
 26
 27pub(crate) const MAX_CONTEXT_TOKENS: usize = 150;
 28pub(crate) const MAX_REWRITE_TOKENS: usize = 350;
 29pub(crate) const MAX_EVENT_TOKENS: usize = 500;
 30
 31pub(crate) fn request_prediction_with_zeta1(
 32    zeta: &mut Zeta,
 33    project: &Entity<Project>,
 34    buffer: &Entity<Buffer>,
 35    snapshot: BufferSnapshot,
 36    position: language::Anchor,
 37    events: Vec<Arc<Event>>,
 38    cx: &mut Context<Zeta>,
 39) -> Task<Result<Option<EditPredictionResult>>> {
 40    let buffer = buffer.clone();
 41    let buffer_snapshotted_at = Instant::now();
 42    let client = zeta.client.clone();
 43    let llm_token = zeta.llm_token.clone();
 44    let app_version = AppVersion::global(cx);
 45
 46    let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
 47        let can_collect_file = zeta.can_collect_file(project, file, cx);
 48        let git_info = if can_collect_file {
 49            git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
 50        } else {
 51            None
 52        };
 53        (git_info, can_collect_file)
 54    } else {
 55        (None, false)
 56    };
 57
 58    let full_path: Arc<Path> = snapshot
 59        .file()
 60        .map(|f| Arc::from(f.full_path(cx).as_path()))
 61        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 62    let full_path_str = full_path.to_string_lossy().into_owned();
 63    let cursor_point = position.to_point(&snapshot);
 64    let prompt_for_events = {
 65        let events = events.clone();
 66        move || prompt_for_events_impl(&events, MAX_EVENT_TOKENS)
 67    };
 68    let gather_task = gather_context(
 69        full_path_str,
 70        &snapshot,
 71        cursor_point,
 72        prompt_for_events,
 73        cx,
 74    );
 75
 76    cx.spawn(async move |this, cx| {
 77        let GatherContextOutput {
 78            mut body,
 79            context_range,
 80            editable_range,
 81            included_events_count,
 82        } = gather_task.await?;
 83        let done_gathering_context_at = Instant::now();
 84
 85        let included_events = &events[events.len() - included_events_count..events.len()];
 86        body.can_collect_data = can_collect_file
 87            && this
 88                .read_with(cx, |this, _| this.can_collect_events(included_events))
 89                .unwrap_or(false);
 90        if body.can_collect_data {
 91            body.git_info = git_info;
 92        }
 93
 94        log::debug!(
 95            "Events:\n{}\nExcerpt:\n{:?}",
 96            body.input_events,
 97            body.input_excerpt
 98        );
 99
100        let http_client = client.http_client();
101
102        let response = Zeta::send_api_request::<PredictEditsResponse>(
103            |request| {
104                let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
105                    predict_edits_url
106                } else {
107                    http_client
108                        .build_zed_llm_url("/predict_edits/v2", &[])?
109                        .as_str()
110                        .into()
111                };
112                Ok(request
113                    .uri(uri)
114                    .body(serde_json::to_string(&body)?.into())?)
115            },
116            client,
117            llm_token,
118            app_version,
119        )
120        .await;
121
122        let inputs = EditPredictionInputs {
123            events: included_events.into(),
124            included_files: vec![cloud_llm_client::predict_edits_v3::IncludedFile {
125                path: full_path.clone(),
126                max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
127                excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
128                    start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
129                    text: snapshot
130                        .text_for_range(context_range)
131                        .collect::<String>()
132                        .into(),
133                }],
134            }],
135            cursor_point: cloud_llm_client::predict_edits_v3::Point {
136                column: cursor_point.column,
137                line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
138            },
139            cursor_path: full_path,
140        };
141
142        // let response = perform_predict_edits(PerformPredictEditsParams {
143        //     client,
144        //     llm_token,
145        //     app_version,
146        //     body,
147        // })
148        // .await;
149
150        let (response, usage) = match response {
151            Ok(response) => response,
152            Err(err) => {
153                if err.is::<ZedUpdateRequiredError>() {
154                    cx.update(|cx| {
155                        this.update(cx, |zeta, _cx| {
156                            zeta.update_required = true;
157                        })
158                        .ok();
159
160                        let error_message: SharedString = err.to_string().into();
161                        show_app_notification(
162                            NotificationId::unique::<ZedUpdateRequiredError>(),
163                            cx,
164                            move |cx| {
165                                cx.new(|cx| {
166                                    ErrorMessagePrompt::new(error_message.clone(), cx)
167                                        .with_link_button("Update Zed", "https://zed.dev/releases")
168                                })
169                            },
170                        );
171                    })
172                    .ok();
173                }
174
175                return Err(err);
176            }
177        };
178
179        let received_response_at = Instant::now();
180        log::debug!("completion response: {}", &response.output_excerpt);
181
182        if let Some(usage) = usage {
183            this.update(cx, |this, cx| {
184                this.user_store.update(cx, |user_store, cx| {
185                    user_store.update_edit_prediction_usage(usage, cx);
186                });
187            })
188            .ok();
189        }
190
191        let edit_prediction = process_completion_response(
192            response,
193            buffer,
194            &snapshot,
195            editable_range,
196            inputs,
197            buffer_snapshotted_at,
198            received_response_at,
199            cx,
200        )
201        .await;
202
203        let finished_at = Instant::now();
204
205        // record latency for ~1% of requests
206        if rand::random::<u8>() <= 2 {
207            telemetry::event!(
208                "Edit Prediction Request",
209                context_latency = done_gathering_context_at
210                    .duration_since(buffer_snapshotted_at)
211                    .as_millis(),
212                request_latency = received_response_at
213                    .duration_since(done_gathering_context_at)
214                    .as_millis(),
215                process_latency = finished_at.duration_since(received_response_at).as_millis()
216            );
217        }
218
219        edit_prediction.map(Some)
220    })
221}
222
223fn process_completion_response(
224    prediction_response: PredictEditsResponse,
225    buffer: Entity<Buffer>,
226    snapshot: &BufferSnapshot,
227    editable_range: Range<usize>,
228    inputs: EditPredictionInputs,
229    buffer_snapshotted_at: Instant,
230    received_response_at: Instant,
231    cx: &AsyncApp,
232) -> Task<Result<EditPredictionResult>> {
233    let snapshot = snapshot.clone();
234    let request_id = prediction_response.request_id;
235    let output_excerpt = prediction_response.output_excerpt;
236    cx.spawn(async move |cx| {
237        let output_excerpt: Arc<str> = output_excerpt.into();
238
239        let edits: Arc<[(Range<Anchor>, Arc<str>)]> = cx
240            .background_spawn({
241                let output_excerpt = output_excerpt.clone();
242                let editable_range = editable_range.clone();
243                let snapshot = snapshot.clone();
244                async move { parse_edits(output_excerpt, editable_range, &snapshot) }
245            })
246            .await?
247            .into();
248
249        let id = EditPredictionId(request_id.into());
250        Ok(EditPredictionResult::new(
251            id,
252            &buffer,
253            &snapshot,
254            edits,
255            buffer_snapshotted_at,
256            received_response_at,
257            inputs,
258            cx,
259        )
260        .await)
261    })
262}
263
264fn parse_edits(
265    output_excerpt: Arc<str>,
266    editable_range: Range<usize>,
267    snapshot: &BufferSnapshot,
268) -> Result<Vec<(Range<Anchor>, Arc<str>)>> {
269    let content = output_excerpt.replace(CURSOR_MARKER, "");
270
271    let start_markers = content
272        .match_indices(EDITABLE_REGION_START_MARKER)
273        .collect::<Vec<_>>();
274    anyhow::ensure!(
275        start_markers.len() == 1,
276        "expected exactly one start marker, found {}",
277        start_markers.len()
278    );
279
280    let end_markers = content
281        .match_indices(EDITABLE_REGION_END_MARKER)
282        .collect::<Vec<_>>();
283    anyhow::ensure!(
284        end_markers.len() == 1,
285        "expected exactly one end marker, found {}",
286        end_markers.len()
287    );
288
289    let sof_markers = content
290        .match_indices(START_OF_FILE_MARKER)
291        .collect::<Vec<_>>();
292    anyhow::ensure!(
293        sof_markers.len() <= 1,
294        "expected at most one start-of-file marker, found {}",
295        sof_markers.len()
296    );
297
298    let codefence_start = start_markers[0].0;
299    let content = &content[codefence_start..];
300
301    let newline_ix = content.find('\n').context("could not find newline")?;
302    let content = &content[newline_ix + 1..];
303
304    let codefence_end = content
305        .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
306        .context("could not find end marker")?;
307    let new_text = &content[..codefence_end];
308
309    let old_text = snapshot
310        .text_for_range(editable_range.clone())
311        .collect::<String>();
312
313    Ok(compute_edits(
314        old_text,
315        new_text,
316        editable_range.start,
317        snapshot,
318    ))
319}
320
321pub fn compute_edits(
322    old_text: String,
323    new_text: &str,
324    offset: usize,
325    snapshot: &BufferSnapshot,
326) -> Vec<(Range<Anchor>, Arc<str>)> {
327    text_diff(&old_text, new_text)
328        .into_iter()
329        .map(|(mut old_range, new_text)| {
330            old_range.start += offset;
331            old_range.end += offset;
332
333            let prefix_len = common_prefix(
334                snapshot.chars_for_range(old_range.clone()),
335                new_text.chars(),
336            );
337            old_range.start += prefix_len;
338
339            let suffix_len = common_prefix(
340                snapshot.reversed_chars_for_range(old_range.clone()),
341                new_text[prefix_len..].chars().rev(),
342            );
343            old_range.end = old_range.end.saturating_sub(suffix_len);
344
345            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
346            let range = if old_range.is_empty() {
347                let anchor = snapshot.anchor_after(old_range.start);
348                anchor..anchor
349            } else {
350                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
351            };
352            (range, new_text)
353        })
354        .collect()
355}
356
/// Returns the length in UTF-8 bytes of the longest common prefix of two character streams.
///
/// The streams are compared pairwise. Counting stops at the first mismatch or when either
/// stream is exhausted.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut shared_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        shared_bytes += left.len_utf8();
    }
    shared_bytes
}
363
364fn git_info_for_file(
365    project: &Entity<Project>,
366    project_path: &ProjectPath,
367    cx: &App,
368) -> Option<PredictEditsGitInfo> {
369    let git_store = project.read(cx).git_store().read(cx);
370    if let Some((repository, _repo_path)) =
371        git_store.repository_and_path_for_project_path(project_path, cx)
372    {
373        let repository = repository.read(cx);
374        let head_sha = repository
375            .head_commit
376            .as_ref()
377            .map(|head_commit| head_commit.sha.to_string());
378        let remote_origin_url = repository.remote_origin_url.clone();
379        let remote_upstream_url = repository.remote_upstream_url.clone();
380        if head_sha.is_none() && remote_origin_url.is_none() && remote_upstream_url.is_none() {
381            return None;
382        }
383        Some(PredictEditsGitInfo {
384            head_sha,
385            remote_origin_url,
386            remote_upstream_url,
387        })
388    } else {
389        None
390    }
391}
392
/// Output of `gather_context`: the request body plus the buffer ranges it was built from.
pub struct GatherContextOutput {
    /// Request body for the predict-edits endpoint. `gather_context` leaves
    /// `can_collect_data` as `false` and `git_info` as `None`; the caller fills them in.
    pub body: PredictEditsBody,
    /// Row/column range of the buffer excerpt sent as context.
    pub context_range: Range<Point>,
    /// Buffer offset range of the editable region whose new text the model returns.
    pub editable_range: Range<usize>,
    /// Number of most-recent events included in `body.input_events`.
    pub included_events_count: usize,
}
399
400pub fn gather_context(
401    full_path_str: String,
402    snapshot: &BufferSnapshot,
403    cursor_point: language::Point,
404    prompt_for_events: impl FnOnce() -> (String, usize) + Send + 'static,
405    cx: &App,
406) -> Task<Result<GatherContextOutput>> {
407    cx.background_spawn({
408        let snapshot = snapshot.clone();
409        async move {
410            let input_excerpt = excerpt_for_cursor_position(
411                cursor_point,
412                &full_path_str,
413                &snapshot,
414                MAX_REWRITE_TOKENS,
415                MAX_CONTEXT_TOKENS,
416            );
417            let (input_events, included_events_count) = prompt_for_events();
418            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
419
420            let body = PredictEditsBody {
421                input_events,
422                input_excerpt: input_excerpt.prompt,
423                can_collect_data: false,
424                diagnostic_groups: None,
425                git_info: None,
426                outline: None,
427                speculated_output: None,
428            };
429
430            Ok(GatherContextOutput {
431                body,
432                context_range: input_excerpt.context_range,
433                editable_range,
434                included_events_count,
435            })
436        }
437    })
438}
439
440fn prompt_for_events_impl(events: &[Arc<Event>], mut remaining_tokens: usize) -> (String, usize) {
441    let mut result = String::new();
442    for (ix, event) in events.iter().rev().enumerate() {
443        let event_string = format_event(event.as_ref());
444        let event_tokens = guess_token_count(event_string.len());
445        if event_tokens > remaining_tokens {
446            return (result, ix);
447        }
448
449        if !result.is_empty() {
450            result.insert_str(0, "\n\n");
451        }
452        result.insert_str(0, &event_string);
453        remaining_tokens -= event_tokens;
454    }
455    return (result, events.len());
456}
457
458pub fn format_event(event: &Event) -> String {
459    match event {
460        Event::BufferChange {
461            path,
462            old_path,
463            diff,
464            ..
465        } => {
466            let mut prompt = String::new();
467
468            if old_path != path {
469                writeln!(
470                    prompt,
471                    "User renamed {} to {}\n",
472                    old_path.display(),
473                    path.display()
474                )
475                .unwrap();
476            }
477
478            if !diff.is_empty() {
479                write!(
480                    prompt,
481                    "User edited {}:\n```diff\n{}\n```",
482                    path.display(),
483                    diff
484                )
485                .unwrap();
486            }
487
488            prompt
489        }
490    }
491}
492
/// Assumed number of UTF-8 bytes per model token when budgeting prompt input.
///
/// Deliberately small so that byte lengths translate into generous token estimates, keeping
/// prompts on the safe side of the real limits.
pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;

/// Estimates how many tokens a string of `bytes` UTF-8 bytes occupies (rounded down).
fn guess_token_count(bytes: usize) -> usize {
    bytes.div_euclid(BYTES_PER_TOKEN_GUESS)
}