//! zeta.rs — Zeta edit-prediction request handling.

  1use crate::{
  2    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  3    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, StoredEvent,
  4    ZedUpdateRequiredError, cursor_excerpt::compute_excerpt_ranges,
  5    prediction::EditPredictionResult,
  6};
  7use anyhow::Result;
  8use cloud_llm_client::{
  9    AcceptEditPredictionBody, EditPredictionRejectReason, predict_edits_v3::RawCompletionRequest,
 10};
 11use edit_prediction_types::PredictedCursorPosition;
 12use gpui::{App, AppContext as _, Entity, Task, WeakEntity, prelude::*};
 13use language::{
 14    Buffer, BufferSnapshot, ToOffset as _, ToPoint, language_settings::all_language_settings,
 15    text_diff,
 16};
 17use release_channel::AppVersion;
 18use settings::EditPredictionPromptFormat;
 19use text::{Anchor, Bias};
 20use ui::SharedString;
 21use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
 22use zeta_prompt::{ParsedOutput, ZetaPromptInput};
 23
 24use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 25use zeta_prompt::{
 26    CURSOR_MARKER, ZetaFormat, format_zeta_prompt, get_prefill, parse_zeta2_model_output,
 27    prompt_input_contains_special_tokens, stop_tokens_for_format,
 28    zeta1::{self, EDITABLE_REGION_END_MARKER},
 29};
 30
 31use crate::open_ai_compatible::{
 32    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
 33};
 34
/// Requests an edit prediction from the Zeta model family.
///
/// Assembles a prompt from the cursor excerpt, related files, and recent edit
/// events, then dispatches it to one of three backends:
/// - a user-configured custom server (Ollama or an OpenAI-compatible API),
///   using either the Zeta1 or Zeta2 prompt format,
/// - the raw LLM completion endpoint, when a `zeta2_raw_config` is present,
/// - otherwise the V3 endpoint, where the server selects model and version.
///
/// The model output is parsed into anchored buffer edits plus an optional
/// predicted cursor position and wrapped in an [`EditPredictionResult`]. An
/// empty model response resolves to a result carrying
/// `EditPredictionRejectReason::Empty`. When `can_collect_data` is set, the
/// settled prediction (and optionally an example built from `capture_data`)
/// is enqueued on the store.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    capture_data: Option<Vec<StoredEvent>>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Custom-server settings only exist for the Ollama and OpenAI-compatible
    // providers; any other provider goes through Zed's own endpoints below.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    // Captured now so the final result can measure end-to-end latency against
    // `received_response_at`.
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Fall back to "untitled" when the buffer is not backed by a file.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    // The repository remote URL is attached only when the user has opted into
    // data collection.
    let repo_url = if can_collect_data {
        let buffer_id = buffer.read(cx).remote_id();
        project
            .read(cx)
            .git_store()
            .read(cx)
            .repository_and_path_for_buffer_id(buffer_id, cx)
            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
    } else {
        None
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    // Intermediate result produced on the background thread; converted into an
    // `EditPredictionResult` back on the foreground executor.
    struct Prediction {
        prompt_input: ZetaPromptInput,
        buffer: Entity<Buffer>,
        snapshot: BufferSnapshot,
        edits: Vec<(Range<Anchor>, Arc<str>)>,
        cursor_position: Option<PredictedCursorPosition>,
        received_response_at: Instant,
        editable_range_in_buffer: Range<usize>,
        model_version: Option<String>,
    }

    let request_task = cx.background_spawn({
        async move {
            // The prompt format comes from the raw config when present,
            // otherwise the default Zeta format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
                repo_url,
            );

            // Refuse to send prompts whose content collides with the model's
            // special tokens — the output could not be parsed reliably.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Err(anyhow::anyhow!("prompt contains special tokens"));
            }

            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            // Each backend branch yields (request_id, parsed_output,
            // model_version, usage).
            let (request_id, output, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): output budget is 4x the configured
                    // max_output_tokens — presumably headroom for region
                    // markers/formatting; confirm the intended factor.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            let editable_range_in_excerpt = ranges.editable_350.clone();
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                editable_range_in_excerpt.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            // Stop on the end-of-editable-region marker,
                            // tolerating up to three trailing newlines.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            // `clean_zeta1_model_output` returns None when the
                            // response is unusable.
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
                            let parsed_output = output_text.map(|text| ParsedOutput {
                                new_editable_region: text,
                                range_in_excerpt: editable_range_in_excerpt,
                            });

                            (request_id, parsed_output, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // The prefill is appended to the prompt and must be
                            // re-prepended to the response before parsing.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens_for_format(zeta_version)
                                    .iter()
                                    .map(|token| token.to_string())
                                    .collect(),
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(parse_zeta2_model_output(
                                    &output,
                                    zeta_version,
                                    &prompt_input,
                                )?)
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw completion endpoint: client formats the prompt and
                    // parses the output itself.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Environment defaults to the lowercased format name when
                    // not configured explicitly.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: stop_tokens_for_format(config.format)
                            .iter()
                            .map(|token| std::borrow::Cow::Borrowed(*token))
                            .collect(),
                        max_tokens: Some(2048),
                        environment,
                    };

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    let output = if let Some(choice) = response.choices.pop() {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        Some(parse_zeta2_model_output(
                            &output,
                            config.format,
                            &prompt_input,
                        )?)
                    } else {
                        None
                    };

                    (request_id, output, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = Some(response.output).filter(|s| !s.is_empty());
                    let model_version = response.model_version;
                    // NOTE(review): an empty V3 output still produces
                    // Some(ParsedOutput) with an empty editable region, unlike
                    // the other branches which yield None — confirm intended.
                    let parsed_output = ParsedOutput {
                        new_editable_region: output_text.unwrap_or_default(),
                        range_in_excerpt: response.editable_range,
                    };

                    (request_id, Some(parsed_output), model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No parsed output: report the request id with no prediction.
            let Some(ParsedOutput {
                new_editable_region: mut output_text,
                range_in_excerpt: editable_range_in_excerpt,
            }) = output
            else {
                return Ok(((request_id, None), None));
            };

            // Translate the excerpt-relative editable range into absolute
            // buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize trailing newlines on both sides so the diff does not
            // produce a spurious newline edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                (
                    request_id,
                    Some(Prediction {
                        prompt_input,
                        buffer,
                        snapshot: snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                        model_version,
                    }),
                ),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        // Records usage and surfaces update-required errors; propagates any
        // request failure.
        let (id, prediction) = handle_api_response(&this, request_task.await, cx)?;

        let Some(Prediction {
            prompt_input: inputs,
            buffer: edited_buffer,
            snapshot: edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
            model_version,
        }) = prediction
        else {
            // The model returned nothing usable.
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            let weak_this = this.clone();
            let id = id.clone();
            let edited_buffer = edited_buffer.clone();
            let edited_buffer_snapshot = edited_buffer_snapshot.clone();
            // Optionally build a training example from the stored events.
            let example_task = capture_data.and_then(|stored_events| {
                cx.update(|cx| {
                    crate::capture_example(
                        project.clone(),
                        edited_buffer.clone(),
                        position,
                        stored_events,
                        false,
                        cx,
                    )
                })
            });
            // Enqueue the settled prediction in the background; failures are
            // ignored (the store may have been dropped).
            cx.spawn(async move |cx| {
                let example_spec = if let Some(task) = example_task {
                    task.await.ok()
                } else {
                    None
                };

                weak_this
                    .update(cx, |this, cx| {
                        this.enqueue_settled_prediction(
                            id.clone(),
                            &project,
                            &edited_buffer,
                            &edited_buffer_snapshot,
                            editable_range_in_buffer,
                            example_spec,
                            cx,
                        );
                    })
                    .ok();
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
438
439fn handle_api_response<T>(
440    this: &WeakEntity<EditPredictionStore>,
441    response: Result<(T, Option<client::EditPredictionUsage>)>,
442    cx: &mut gpui::AsyncApp,
443) -> Result<T> {
444    match response {
445        Ok((data, usage)) => {
446            if let Some(usage) = usage {
447                this.update(cx, |this, cx| {
448                    this.user_store.update(cx, |user_store, cx| {
449                        user_store.update_edit_prediction_usage(usage, cx);
450                    });
451                })
452                .ok();
453            }
454            Ok(data)
455        }
456        Err(err) => {
457            if err.is::<ZedUpdateRequiredError>() {
458                cx.update(|cx| {
459                    this.update(cx, |this, _cx| {
460                        this.update_required = true;
461                    })
462                    .ok();
463
464                    let error_message: SharedString = err.to_string().into();
465                    show_app_notification(
466                        NotificationId::unique::<ZedUpdateRequiredError>(),
467                        cx,
468                        move |cx| {
469                            cx.new(|cx| {
470                                ErrorMessagePrompt::new(error_message.clone(), cx)
471                                    .with_link_button("Update Zed", "https://zed.dev/releases")
472                            })
473                        },
474                    );
475                });
476            }
477            Err(err)
478        }
479    }
480}
481
482pub fn zeta2_prompt_input(
483    snapshot: &language::BufferSnapshot,
484    related_files: Vec<zeta_prompt::RelatedFile>,
485    events: Vec<Arc<zeta_prompt::Event>>,
486    excerpt_path: Arc<Path>,
487    cursor_offset: usize,
488    preferred_experiment: Option<String>,
489    is_open_source: bool,
490    can_collect_data: bool,
491    repo_url: Option<String>,
492) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
493    let cursor_point = cursor_offset.to_point(snapshot);
494
495    let (full_context, full_context_offset_range, excerpt_ranges) =
496        compute_excerpt_ranges(cursor_point, snapshot);
497
498    let full_context_start_offset = full_context_offset_range.start;
499    let full_context_start_row = full_context.start.row;
500
501    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
502
503    let prompt_input = zeta_prompt::ZetaPromptInput {
504        cursor_path: excerpt_path,
505        cursor_excerpt: snapshot
506            .text_for_range(full_context)
507            .collect::<String>()
508            .into(),
509        cursor_offset_in_excerpt,
510        excerpt_start_row: Some(full_context_start_row),
511        events,
512        related_files: Some(related_files),
513        excerpt_ranges,
514        experiment: preferred_experiment,
515        in_open_source_repo: is_open_source,
516        can_collect_data,
517        repo_url,
518    };
519    (full_context_offset_range, prompt_input)
520}
521
522pub(crate) fn edit_prediction_accepted(
523    store: &EditPredictionStore,
524    current_prediction: CurrentEditPrediction,
525    cx: &App,
526) {
527    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
528    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
529        return;
530    }
531
532    let request_id = current_prediction.prediction.id.to_string();
533    let model_version = current_prediction.prediction.model_version;
534    let require_auth = custom_accept_url.is_none();
535    let client = store.client.clone();
536    let llm_token = store.llm_token.clone();
537    let organization_id = store
538        .user_store
539        .read(cx)
540        .current_organization()
541        .map(|organization| organization.id.clone());
542    let app_version = AppVersion::global(cx);
543
544    cx.background_spawn(async move {
545        let url = if let Some(accept_edits_url) = custom_accept_url {
546            gpui::http_client::Url::parse(&accept_edits_url)?
547        } else {
548            client
549                .http_client()
550                .build_zed_llm_url("/predict_edits/accept", &[])?
551        };
552
553        let response = EditPredictionStore::send_api_request::<()>(
554            move |builder| {
555                let req = builder.uri(url.as_ref()).body(
556                    serde_json::to_string(&AcceptEditPredictionBody {
557                        request_id: request_id.clone(),
558                        model_version: model_version.clone(),
559                    })?
560                    .into(),
561                );
562                Ok(req?)
563            },
564            client,
565            llm_token,
566            organization_id,
567            app_version,
568            require_auth,
569        )
570        .await;
571
572        response?;
573        anyhow::Ok(())
574    })
575    .detach_and_log_err(cx);
576}
577
578pub fn compute_edits(
579    old_text: String,
580    new_text: &str,
581    offset: usize,
582    snapshot: &BufferSnapshot,
583) -> Vec<(Range<Anchor>, Arc<str>)> {
584    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
585}
586
/// Diffs `old_text` against `new_text` and converts the differences into
/// anchored buffer edits, optionally tracking where a cursor (given as a byte
/// offset into `new_text`) lands in the buffer.
///
/// `offset` is the buffer offset at which `old_text` begins; edit ranges are
/// translated into buffer coordinates by adding it. Returns the edits along
/// with the predicted cursor position (`None` when
/// `cursor_offset_in_new_text` is `None`).
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    // (`delta` is only maintained while the cursor is still being located; it
    // does not affect the edits themselves.)
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // back into old-text coordinates via `delta`.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and record the offset within the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim the common prefix and suffix (byte lengths) so the anchors
            // wrap only the text that actually changed. The suffix is computed
            // on the slices past the prefix, so the two trims cannot overlap.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            // Clamp in case the edit range extends past the buffer's end.
            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: a single anchor keeps the edit glued to the
                // insertion point.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the last edit (or there were no edits): map it back
    // through the accumulated delta and clip to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
670
/// Returns the length in UTF-8 bytes (not chars) of the longest common prefix
/// of the two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}