zeta.rs

  1use crate::cursor_excerpt::compute_excerpt_ranges;
  2use crate::prediction::EditPredictionResult;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore, ollama,
  6};
  7use anyhow::{Context as _, Result};
  8use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use edit_prediction_types::PredictedCursorPosition;
 11use futures::AsyncReadExt as _;
 12use gpui::{App, AppContext as _, Task, http_client, prelude::*};
 13use language::language_settings::{OpenAiCompatibleEditPredictionSettings, all_language_settings};
 14use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 15use release_channel::AppVersion;
 16use text::{Anchor, Bias};
 17
 18use std::env;
 19use std::ops::Range;
 20use std::{path::Path, sync::Arc, time::Instant};
 21use zeta_prompt::{
 22    CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
 23    format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
 24    zeta1::{self, EDITABLE_REGION_END_MARKER},
 25};
 26
/// Requests an edit prediction for `position` in `buffer` using one of the
/// zeta backends, selecting the request path from settings and store state:
///
/// 1. a user-configured custom server (Ollama or an OpenAI-compatible API),
/// 2. the raw zeta2 endpoint when `store.zeta2_raw_config()` is set, or
/// 3. the default v3 endpoint otherwise.
///
/// The returned task resolves to `Ok(None)` when no prediction was produced
/// (e.g. the prompt contained special tokens), or to an
/// `EditPredictionResult` — possibly rejected as `Empty` when the model
/// returned no output.
pub fn request_prediction_with_zeta(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let settings = &all_language_settings(None, cx).edit_predictions;
    let provider = settings.provider;
    // Only Ollama / OpenAI-compatible providers carry custom-server settings;
    // any other provider falls through to the raw-zeta2 or v3 paths below.
    let custom_server_settings = match provider {
        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
        settings::EditPredictionProvider::OpenAiCompatibleApi => {
            settings.open_ai_compatible_api.clone()
        }
        _ => None,
    };

    let http_client = cx.http_client();
    let buffer_snapshotted_at = Instant::now();
    let raw_config = store.zeta2_raw_config().cloned();

    // Buffers without a backing file are labeled "untitled".
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Open-source status requires the edited file, every event, and every
    // related file to all come from open-source repos.
    let is_open_source = snapshot
        .file()
        .map_or(false, |file| store.is_file_open_source(&project, file, cx))
        && events.iter().all(|event| event.in_open_source_repo())
        && related_files.iter().all(|file| file.in_open_source_repo);

    let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);

    let request_task = cx.background_spawn({
        async move {
            // The prompt format comes from the raw config when present,
            // otherwise the default format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            // Assigned exactly once in each of the three request branches
            // below (relative to the excerpt, not the buffer).
            let editable_range_in_excerpt: Range<usize>;
            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // Bail out rather than send model-special tokens verbatim.
            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
                return Ok((None, None));
            }

            let is_zeta1 = preferred_model == Some(EditPredictionModelKind::Zeta1);
            let excerpt_ranges = prompt_input
                .excerpt_ranges
                .as_ref()
                .ok_or_else(|| anyhow::anyhow!("excerpt_ranges missing from prompt input"))?;

            // Notify debug listeners of the outgoing prompt.
            if let Some(debug_tx) = &debug_tx {
                let prompt = if is_zeta1 {
                    zeta1::format_zeta1_from_input(
                        &prompt_input,
                        excerpt_ranges.editable_350.clone(),
                        excerpt_ranges.editable_350_context_150.clone(),
                    )
                } else {
                    format_zeta_prompt(&prompt_input, zeta_version)
                };
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, model_version, usage) = if let Some(custom_settings) =
                &custom_server_settings
            {
                // Custom-server path (Ollama / OpenAI-compatible).
                // NOTE(review): budget is 4x the configured max output
                // tokens — presumably headroom for markers/prefill; confirm.
                let max_tokens = custom_settings.max_output_tokens * 4;

                if is_zeta1 {
                    // Zeta1 prompt uses the fixed 350-line editable window and
                    // stops on the editable-region end marker (with trailing
                    // newline variants).
                    let ranges = excerpt_ranges;
                    let prompt = zeta1::format_zeta1_from_input(
                        &prompt_input,
                        ranges.editable_350.clone(),
                        ranges.editable_350_context_150.clone(),
                    );
                    editable_range_in_excerpt = ranges.editable_350.clone();
                    let stop_tokens = vec![
                        EDITABLE_REGION_END_MARKER.to_string(),
                        format!("{EDITABLE_REGION_END_MARKER}\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n"),
                        format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
                    ];

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        stop_tokens,
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = zeta1::clean_zeta1_model_output(&response_text);

                    // Custom servers report no model version or usage.
                    (request_id, output_text, None, None)
                } else {
                    // Zeta2-format prompt against the custom server; the
                    // prefill is prepended to both the prompt and the output
                    // before cleanup.
                    let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                    let prefill = get_prefill(&prompt_input, zeta_version);
                    let prompt = format!("{prompt}{prefill}");

                    editable_range_in_excerpt = prompt_input
                        .excerpt_ranges
                        .as_ref()
                        .map(|ranges| zeta_prompt::excerpt_range_for_format(zeta_version, ranges).0)
                        .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                    let (response_text, request_id) = send_custom_server_request(
                        provider,
                        custom_settings,
                        prompt,
                        max_tokens,
                        vec![],
                        &http_client,
                    )
                    .await?;

                    let request_id = EditPredictionId(request_id.into());
                    let output_text = if response_text.is_empty() {
                        None
                    } else {
                        let output = format!("{prefill}{response_text}");
                        Some(clean_zeta2_model_output(&output, zeta_version).to_string())
                    };

                    (request_id, output_text, None, None)
                }
            } else if let Some(config) = &raw_config {
                // Raw zeta2 endpoint path, driven by the configured format.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                // NOTE(review): this branch takes tuple element .1 of
                // `excerpt_range_for_format`, while the custom-server branch
                // takes .0 — confirm the asymmetry is intentional.
                editable_range_in_excerpt = prompt_input
                    .excerpt_ranges
                    .as_ref()
                    .map(|ranges| zeta_prompt::excerpt_range_for_format(config.format, ranges).1)
                    .unwrap_or(prompt_input.editable_range_in_excerpt.clone());

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, None, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                editable_range_in_excerpt = response.editable_range;
                let model_version = response.model_version;

                (request_id, output_text, model_version, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output: surface the request id so the caller can record an
            // empty prediction.
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None, model_version)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // Translate the excerpt-relative editable range into buffer
            // offsets.
            let editable_range_in_buffer = editable_range_in_excerpt.start
                + full_context_offset_range.start
                ..editable_range_in_excerpt.end + full_context_offset_range.start;

            let mut old_text = snapshot
                .text_for_range(editable_range_in_buffer.clone())
                .collect::<String>();

            // Normalize both sides to end with a newline so the diff doesn't
            // produce a spurious trailing edit.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_range_in_buffer.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                        full_context_offset_range,
                        editable_range_in_buffer,
                    )),
                    model_version,
                )),
                usage,
            ))
        }
    });

    cx.spawn(async move |this, cx| {
        let Some((id, prediction, model_version)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // A request id without prediction payload means the model returned
        // nothing: report an empty rejection.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
            full_context_offset_range,
            editable_range_in_buffer,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        if can_collect_data {
            // Telemetry: snapshot the context region 30s after the prediction
            // to capture what the buffer looked like post-edit. Anchors keep
            // the ranges valid across intervening edits.
            cx.spawn({
                let weak_buffer = edited_buffer.downgrade();
                let context_anchor_range =
                    edited_buffer_snapshot.anchor_range_around(full_context_offset_range);
                let editable_anchor_range =
                    edited_buffer_snapshot.anchor_range_around(editable_range_in_buffer);
                let request_id = id.0.clone();
                async move |cx| {
                    cx.background_executor()
                        .timer(std::time::Duration::from_secs(30))
                        .await;

                    let Some(buffer) = weak_buffer.upgrade() else {
                        return;
                    };
                    let (new_cursor_region, editable_range_in_excerpt) =
                        buffer.read_with(cx, |buffer, _| {
                            let context_start =
                                buffer.offset_for_anchor(&context_anchor_range.start);
                            let editable_range_in_excerpt = (buffer
                                .offset_for_anchor(&editable_anchor_range.start)
                                - context_start)
                                ..(buffer.offset_for_anchor(&editable_anchor_range.end)
                                    - context_start);
                            let text = buffer
                                .text_for_range(context_anchor_range)
                                .collect::<String>();
                            (text, editable_range_in_excerpt)
                        });
                    telemetry::event!(
                        "Edit Prediction Snapshot",
                        request_id,
                        new_cursor_region,
                        editable_range_in_excerpt,
                    );
                }
            })
            .detach();
        }

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                model_version,
                cx,
            )
            .await,
        ))
    })
}
404
405pub fn zeta2_prompt_input(
406    snapshot: &language::BufferSnapshot,
407    related_files: Vec<zeta_prompt::RelatedFile>,
408    events: Vec<Arc<zeta_prompt::Event>>,
409    excerpt_path: Arc<Path>,
410    cursor_offset: usize,
411    zeta_format: ZetaFormat,
412    preferred_model: Option<EditPredictionModelKind>,
413    is_open_source: bool,
414    can_collect_data: bool,
415) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
416    let cursor_point = cursor_offset.to_point(snapshot);
417
418    let (full_context, full_context_offset_range, excerpt_ranges) =
419        compute_excerpt_ranges(cursor_point, snapshot);
420
421    let related_files = crate::filter_redundant_excerpts(
422        related_files,
423        excerpt_path.as_ref(),
424        full_context.start.row..full_context.end.row,
425    );
426
427    let full_context_start_offset = full_context_offset_range.start;
428    let full_context_start_row = full_context.start.row;
429
430    let editable_offset_range = match preferred_model {
431        Some(EditPredictionModelKind::Zeta1) => excerpt_ranges.editable_350.clone(),
432        _ => zeta_prompt::excerpt_range_for_format(zeta_format, &excerpt_ranges).0,
433    };
434
435    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
436
437    let prompt_input = zeta_prompt::ZetaPromptInput {
438        cursor_path: excerpt_path,
439        cursor_excerpt: snapshot
440            .text_for_range(full_context)
441            .collect::<String>()
442            .into(),
443        editable_range_in_excerpt: editable_offset_range,
444        cursor_offset_in_excerpt,
445        excerpt_start_row: Some(full_context_start_row),
446        events,
447        related_files,
448        excerpt_ranges: Some(excerpt_ranges),
449        preferred_model,
450        in_open_source_repo: is_open_source,
451        can_collect_data,
452    };
453    (full_context_offset_range, prompt_input)
454}
455
456pub(crate) async fn send_custom_server_request(
457    provider: settings::EditPredictionProvider,
458    settings: &OpenAiCompatibleEditPredictionSettings,
459    prompt: String,
460    max_tokens: u32,
461    stop_tokens: Vec<String>,
462    http_client: &Arc<dyn http_client::HttpClient>,
463) -> Result<(String, String)> {
464    match provider {
465        settings::EditPredictionProvider::Ollama => {
466            let response =
467                ollama::make_request(settings.clone(), prompt, stop_tokens, http_client.clone())
468                    .await?;
469            Ok((response.response, response.created_at))
470        }
471        _ => {
472            let request = RawCompletionRequest {
473                model: settings.model.clone(),
474                prompt,
475                max_tokens: Some(max_tokens),
476                temperature: None,
477                stop: stop_tokens
478                    .into_iter()
479                    .map(std::borrow::Cow::Owned)
480                    .collect(),
481                environment: None,
482            };
483
484            let request_body = serde_json::to_string(&request)?;
485            let http_request = http_client::Request::builder()
486                .method(http_client::Method::POST)
487                .uri(settings.api_url.as_ref())
488                .header("Content-Type", "application/json")
489                .body(http_client::AsyncBody::from(request_body))?;
490
491            let mut response = http_client.send(http_request).await?;
492            let status = response.status();
493
494            if !status.is_success() {
495                let mut body = String::new();
496                response.body_mut().read_to_string(&mut body).await?;
497                anyhow::bail!("custom server error: {} - {}", status, body);
498            }
499
500            let mut body = String::new();
501            response.body_mut().read_to_string(&mut body).await?;
502
503            let parsed: RawCompletionResponse =
504                serde_json::from_str(&body).context("Failed to parse completion response")?;
505            let text = parsed
506                .choices
507                .into_iter()
508                .next()
509                .map(|choice| choice.text)
510                .unwrap_or_default();
511            Ok((text, parsed.id))
512        }
513    }
514}
515
/// Reports an accepted edit prediction to the predict-edits accept endpoint
/// as a detached background request.
///
/// The destination can be overridden with the `ZED_ACCEPT_PREDICTION_URL`
/// environment variable; when the override is present, authentication is not
/// required. When a raw zeta2 config is active and no override is set,
/// acceptance is not reported at all.
pub(crate) fn edit_prediction_accepted(
    store: &EditPredictionStore,
    current_prediction: CurrentEditPrediction,
    cx: &App,
) {
    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
    // Raw-config predictions aren't reported unless explicitly redirected.
    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
        return;
    }

    let request_id = current_prediction.prediction.id.to_string();
    let model_version = current_prediction.prediction.model_version;
    // Only the default LLM endpoint requires auth; a custom URL does not.
    let require_auth = custom_accept_url.is_none();
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    cx.background_spawn(async move {
        let url = if let Some(accept_edits_url) = custom_accept_url {
            gpui::http_client::Url::parse(&accept_edits_url)?
        } else {
            client
                .http_client()
                .build_zed_llm_url("/predict_edits/accept", &[])?
        };

        let response = EditPredictionStore::send_api_request::<()>(
            move |builder| {
                let req = builder.uri(url.as_ref()).body(
                    serde_json::to_string(&AcceptEditPredictionBody {
                        request_id: request_id.clone(),
                        model_version: model_version.clone(),
                    })?
                    .into(),
                );
                Ok(req?)
            },
            client,
            llm_token,
            app_version,
            require_auth,
        )
        .await;

        response?;
        anyhow::Ok(())
    })
    // Fire-and-forget: failures are logged, not surfaced to the user.
    .detach_and_log_err(cx);
}
565
/// Convenience wrapper around [`compute_edits_and_cursor_position`] that
/// discards the predicted cursor position and returns only the anchored
/// edits. `offset` is the buffer offset at which `old_text` begins.
pub fn compute_edits(
    old_text: String,
    new_text: &str,
    offset: usize,
    snapshot: &BufferSnapshot,
) -> Vec<(Range<Anchor>, Arc<str>)> {
    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
}
574
/// Diffs `old_text` against `new_text` and converts the hunks into anchored
/// buffer edits, where `offset` is the buffer offset at which `old_text`
/// begins.
///
/// If `cursor_offset_in_new_text` is given (a byte offset within
/// `new_text`), it is also mapped to a [`PredictedCursorPosition`] in the
/// buffer: either anchored directly (cursor falls between/before edits) or
/// expressed as an anchor plus an offset into the inserted text (cursor
/// falls inside an insertion).
pub fn compute_edits_and_cursor_position(
    old_text: String,
    new_text: &str,
    offset: usize,
    cursor_offset_in_new_text: Option<usize>,
    snapshot: &BufferSnapshot,
) -> (
    Vec<(Range<Anchor>, Arc<str>)>,
    Option<PredictedCursorPosition>,
) {
    let diffs = text_diff(&old_text, new_text);

    // Delta represents the cumulative change in byte count from all preceding edits.
    // new_offset = old_offset + delta, so old_offset = new_offset - delta
    let mut delta: isize = 0;
    let mut cursor_position: Option<PredictedCursorPosition> = None;
    let buffer_len = snapshot.len();

    let edits = diffs
        .iter()
        .map(|(raw_old_range, new_text)| {
            // Compute cursor position if it falls within or before this edit.
            // Note: `delta` is only maintained while the cursor is still
            // unresolved — once `cursor_position` is set (or no cursor was
            // requested), delta is never read again.
            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
                let edit_end_in_new = edit_start_in_new + new_text.len();

                if cursor_offset < edit_start_in_new {
                    // Cursor lies in unchanged text before this edit: map it
                    // back to old-text coordinates and anchor it directly.
                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
                    cursor_position = Some(PredictedCursorPosition::at_anchor(
                        snapshot.anchor_after(buffer_offset),
                    ));
                } else if cursor_offset < edit_end_in_new {
                    // Cursor lies inside this edit's inserted text: anchor at
                    // the edit start and carry the offset into the insertion.
                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
                    let offset_within_insertion = cursor_offset - edit_start_in_new;
                    cursor_position = Some(PredictedCursorPosition::new(
                        snapshot.anchor_before(buffer_offset),
                        offset_within_insertion,
                    ));
                }

                delta += new_text.len() as isize - raw_old_range.len() as isize;
            }

            // Compute the edit with prefix/suffix trimming.
            let mut old_range = raw_old_range.clone();
            let old_slice = &old_text[old_range.clone()];

            // Trim the common prefix, then the common suffix of what remains,
            // so the edit covers only the text that actually changed.
            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
            let suffix_len = common_prefix(
                old_slice[prefix_len..].chars().rev(),
                new_text[prefix_len..].chars().rev(),
            );

            // Shift from old-text coordinates into buffer coordinates, then
            // apply the trims and clamp to the buffer length.
            old_range.start += offset;
            old_range.end += offset;
            old_range.start += prefix_len;
            old_range.end -= suffix_len;

            old_range.start = old_range.start.min(buffer_len);
            old_range.end = old_range.end.min(buffer_len);

            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
            let range = if old_range.is_empty() {
                // Pure insertion: both ends share one anchor.
                let anchor = snapshot.anchor_after(old_range.start);
                anchor..anchor
            } else {
                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
            };
            (range, new_text)
        })
        .collect();

    // Cursor fell after the last edit: map it through the final delta.
    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
        let cursor_in_old = (cursor_offset as isize - delta) as usize;
        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
        cursor_position = Some(PredictedCursorPosition::at_anchor(
            snapshot.anchor_after(buffer_offset),
        ));
    }

    (edits, cursor_position)
}
658
/// Length, in UTF-8 bytes, of the longest common prefix of the two
/// character streams.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut bytes = 0;
    for (ca, cb) in a.zip(b) {
        if ca != cb {
            break;
        }
        bytes += ca.len_utf8();
    }
    bytes
}