zeta.rs

  1use crate::cursor_excerpt::compute_excerpt_ranges;
  2use crate::prediction::EditPredictionResult;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
  6};
  7use anyhow::Result;
  8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use edit_prediction_types::PredictedCursorPosition;
 11use gpui::{App, AppContext as _, Task, prelude::*};
 12use language::language_settings::all_language_settings;
 13use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 14use release_channel::AppVersion;
 15use settings::EditPredictionPromptFormat;
 16use text::{Anchor, Bias};
 17
 18use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 19use zeta_prompt::{
 20    CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
 21    output_with_context_for_format, prompt_input_contains_special_tokens,
 22    zeta1::{self, EDITABLE_REGION_END_MARKER},
 23};
 24
 25use crate::open_ai_compatible::{
 26    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
 27};
 28
/// Requests an edit prediction for `buffer` using one of the Zeta backends,
/// chosen from the user's edit-prediction settings:
///
/// 1. A user-configured custom server (Ollama or an OpenAI-compatible API),
///    formatted with either the Zeta1 or Zeta2 prompt format.
/// 2. Zed's raw completion endpoint, when a `zeta2_raw_config` is present.
/// 3. Zed's v3 endpoint (the default), where the server selects the model
///    and performs output post-processing.
///
/// The returned task resolves to `Ok(None)` when no prediction was produced
/// (e.g. the prompt contained special tokens, or the request was superseded).
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        can_collect_data,
        is_open_source,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only the Ollama and OpenAI-compatible providers route through a
    // user-configured custom server; all other providers use Zed's backend.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    // Everything needed inside the background task is cloned up front, since
    // the async block below moves off the main thread.
    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();
    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);

    // Unsaved buffers get a placeholder path so the prompt always has one.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    let request_task = cx.background_spawn({
        async move {
            // The raw config's format wins; otherwise fall back to the
            // default Zeta format version.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each branch of the if/else chain
            // below; expressed in bytes relative to the excerpt start.
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                preferred_experiment,
                is_open_source,
                can_collect_data,
            );

            // Bail out rather than send content that would collide with the
            // model's own special tokens.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            // Surface the fully formatted prompt to any attached debug UI.
            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) =
                if let Some(custom_settings) = &custom_server_settings {
                    // NOTE(review): scaled x4 — presumably headroom over the
                    // configured output budget; confirm the intended factor.
                    let max_tokens = custom_settings.max_output_tokens * 4;

                    match custom_settings.prompt_format {
                        EditPredictionPromptFormat::Zeta => {
                            let ranges = &prompt_input.excerpt_ranges;
                            // Zeta1 prompts use the `editable_350` /
                            // `editable_350_context_150` ranges — presumably
                            // a 350-line editable window with 150 lines of
                            // surrounding context (per the field names).
                            let prompt = zeta1::format_zeta1_from_input(
                                &prompt_input,
                                ranges.editable_350.clone(),
                                ranges.editable_350_context_150.clone(),
                            );
                            editable_range_in_excerpt = ranges.editable_350.clone();
                            // Stop generation at the end-of-editable-region
                            // marker, tolerating trailing newlines the model
                            // may emit after it.
                            let stop_tokens = vec![
                                EDITABLE_REGION_END_MARKER.to_string(),
                                format!("{EDITABLE_REGION_END_MARKER}\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                            ];

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                stop_tokens,
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = zeta1::clean_zeta1_model_output(&response_text);

                            // Custom servers report no model version or usage.
                            (request_id, output_text, None, None)
                        }
                        EditPredictionPromptFormat::Zeta2 => {
                            // The prefill is appended to the prompt so the
                            // model continues from it, and re-prepended to the
                            // response below before cleaning.
                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                            let prefill = get_prefill(&prompt_input, zeta_version);
                            let prompt = format!("{prompt}{prefill}");

                            // NOTE(review): this branch takes `.0` of
                            // `excerpt_range_for_format` while the raw-config
                            // branch below takes `.1` — confirm intentional.
                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                                zeta_version,
                                &prompt_input.excerpt_ranges,
                            )
                            .0;

                            let (response_text, request_id) = send_custom_server_request(
                                provider,
                                custom_settings,
                                prompt,
                                max_tokens,
                                vec![],
                                open_ai_compatible_api_key.clone(),
                                &http_client,
                            )
                            .await?;

                            let request_id = EditPredictionId(request_id.into());
                            let output_text = if response_text.is_empty() {
                                None
                            } else {
                                let output = format!("{prefill}{response_text}");
                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                            };

                            (request_id, output_text, None, None)
                        }
                        _ => anyhow::bail!("unsupported prompt format"),
                    }
                } else if let Some(config) = &raw_config {
                    // Raw endpoint: the client formats the full prompt itself
                    // and sends it to Zed's raw completion API.
                    let prompt = format_zeta_prompt(&prompt_input, config.format);
                    let prefill = get_prefill(&prompt_input, config.format);
                    let prompt = format!("{prompt}{prefill}");
                    // Default the environment label to the lowercased format
                    // name when the config doesn't specify one.
                    let environment = config
                        .environment
                        .clone()
                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                    let request = RawCompletionRequest {
                        model: config.model_id.clone().unwrap_or_default(),
                        prompt,
                        temperature: None,
                        stop: vec![],
                        max_tokens: Some(2048),
                        environment,
                    };

                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
                        config.format,
                        &prompt_input.excerpt_ranges,
                    )
                    .1;

                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                        request,
                        client,
                        None,
                        llm_token,
                        organization_id,
                        app_version,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.id.clone().into());
                    // Take the last (typically only) choice as the output.
                    let output_text = response.choices.pop().map(|choice| {
                        let response = &choice.text;
                        let output = format!("{prefill}{response}");
                        clean_zeta2_model_output(&output, config.format).to_string()
                    });

                    (request_id, output_text, None, usage)
                } else {
                    // Use V3 endpoint - server handles model/version selection and suffix stripping
                    let (response, usage) = EditPredictionStore::send_v3_request(
                        prompt_input.clone(),
                        client,
                        llm_token,
                        organization_id,
                        app_version,
                        trigger,
                    )
                    .await?;

                    let request_id = EditPredictionId(response.request_id.into());
                    let output_text = if response.output.is_empty() {
                        None
                    } else {
                        Some(response.output)
                    };
                    editable_range_in_buffer_note: {}
                    editable_range_in_excerpt = response.editable_range;
                    let model_version = response.model_version;

                    (request_id, output_text, model_version, usage)
                };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // An empty prediction still carries a request id so the caller
            // can record the (empty) result.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Rebase the excerpt-relative editable range onto buffer offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // For the hashline format, the model may return <|set|>/<|insert|>
            // edit commands instead of a full replacement. Apply them against
            // the original editable region to produce the full replacement text.
            // This must happen before cursor marker stripping because the cursor
            // marker is embedded inside edit command content.
            if let Some(rewritten_output) =
                output_with_context_for_format(zeta_version, &old_text, &output_text)?
            {
                output_text = rewritten_output;
            }

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Normalize both sides to end with a newline so the diff treats
            // the final line uniformly.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    // Back on the foreground: unwrap the API response and build the result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // No prediction payload means the model returned nothing usable.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        // Best-effort data collection; failure to enqueue is ignored.
        if can_collect_data {
            this.update(cx, |this, cx| {
                this.enqueue_settled_prediction(
                    id.clone(),
                    &project,
                    &edited_buffer,
                    &edited_buffer_snapshot,
                    editable_range_in_buffer,
                    cx,
                );
            })
            .ok();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
384
385pub fn zeta2_prompt_input(
386    snapshot: &language::BufferSnapshot,
387    related_files: Vec<zeta_prompt::RelatedFile>,
388    events: Vec<Arc<zeta_prompt::Event>>,
389    excerpt_path: Arc<Path>,
390    cursor_offset: usize,
391    preferred_experiment: Option<String>,
392    is_open_source: bool,
393    can_collect_data: bool,
394) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
395    let cursor_point = cursor_offset.to_point(snapshot);
396
397    let (full_context, full_context_offset_range, excerpt_ranges) =
398        compute_excerpt_ranges(cursor_point, snapshot);
399
400    let full_context_start_offset = full_context_offset_range.start;
401    let full_context_start_row = full_context.start.row;
402
403    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
404
405    let prompt_input = zeta_prompt::ZetaPromptInput {
406        cursor_path: excerpt_path,
407        cursor_excerpt: snapshot
408            .text_for_range(full_context)
409            .collect::<String>()
410            .into(),
411        cursor_offset_in_excerpt,
412        excerpt_start_row: Some(full_context_start_row),
413        events,
414        related_files,
415        excerpt_ranges,
416        experiment: preferred_experiment,
417        in_open_source_repo: is_open_source,
418        can_collect_data,
419    };
420    (full_context_offset_range, prompt_input)
421}
422
/// Reports that the user accepted `current_prediction` to the prediction
/// backend, or to the URL in `ZED_ACCEPT_PREDICTION_URL` when that override
/// is set.
///
/// When a zeta2 raw config is active and no custom accept URL is provided,
/// acceptance is not reported at all. The request is fire-and-forget: it is
/// detached and only logged on error.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Only authenticate against Zed's own endpoint; a custom URL skips auth.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let organization_id = store
        .user_store
        .read(cx)
        .current_organization()
        .map(|organization| organization.id.clone());
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        // No response body is expected; `()` deserializes an empty payload.
        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            organization_id,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    .detach_and_log_err(cx);
}
478
/// Computes anchored edits that transform `old_text` (located at byte
/// `offset` in `snapshot`) into `new_text`.
///
/// Thin wrapper around [`compute_edits_and_cursor_position`] that discards
/// the cursor-position output.
pub fn compute_edits(
    old_text: String,
    new_text: &str,
    offset: usize,
    snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, Arc<str>)> {
    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
}
487
/// Diffs `old_text` against `new_text` and converts the resulting hunks into
/// anchor-ranged edits relative to `snapshot`, where `old_text` begins at
/// byte `offset` within the buffer.
///
/// If `cursor_offset_in_new_text` is provided (a byte offset into
/// `new_text`), it is mapped back to a buffer position: anchored directly
/// when it falls in unchanged text, or expressed as an offset within an
/// insertion when it falls inside an edit.
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Once `cursor_position` is resolved, delta tracking stops — it is
            // only consumed while the cursor is still unresolved.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // back through the accumulated delta and clamp to the
                    // buffer.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and remember the offset into the
                    // insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Shrink the hunk to the minimal differing span: drop the common
            // prefix, then the common suffix of what remains.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Translate from `old_text`-relative to buffer offsets, then
            // apply the trimming, clamping to the buffer's length.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: represent the insertion point with a single
                // anchor.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after every edit: map it through the final delta and clip
    // to a valid buffer offset.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
571
/// Returns the length in bytes (UTF-8) of the longest common prefix of the
/// two character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut b = b;
    let mut bytes = 0;
    for ch in a {
        match b.next() {
            Some(other) if other == ch => bytes += ch.len_utf8(),
            _ => break,
        }
    }
    bytes
}