zeta.rs

  1use crate::cursor_excerpt::compute_excerpt_ranges;
  2use crate::prediction::EditPredictionResult;
  3use crate::{
  4    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  5    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
  6};
  7use anyhow::Result;
  8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use edit_prediction_types::PredictedCursorPosition;
 11use gpui::{App, AppContext as _, Task, prelude::*};
 12use language::language_settings::all_language_settings;
 13use language::{BufferSnapshot, ToOffset as _, ToPoint, text_diff};
 14use release_channel::AppVersion;
 15use settings::EditPredictionPromptFormat;
 16use text::{Anchor, Bias};
 17
 18use std::{env, ops::Range, path::Path, sync::Arc, time::Instant};
 19use zeta_prompt::{
 20    CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt, get_prefill,
 21    output_with_context_for_format, prompt_input_contains_special_tokens,
 22    zeta1::{self, EDITABLE_REGION_END_MARKER},
 23};
 24
 25use crate::open_ai_compatible::{
 26    load_open_ai_compatible_api_key_if_needed, send_custom_server_request,
 27};
 28
 29pub fn request_prediction_with_zeta(
 30    store: &mut EditPredictionStore,
 31    EditPredictionModelInput {
 32        buffer,
 33        snapshot,
 34        position,
 35        related_files,
 36        events,
 37        debug_tx,
 38        trigger,
 39        project,
 40        can_collect_data,
 41        is_open_source,
 42        ..
 43    }: EditPredictionModelInput,
 44    cx: &mut Context<EditPredictionStore>,
 45) -> Task<Result<Option<EditPredictionResult>>> {
 46    let settings = &all_language_settings(None, cx).edit_predictions;
 47    let provider = settings.provider;
 48    let custom_server_settings = match provider {
 49        settings::EditPredictionProvider::Ollama => settings.ollama.clone(),
 50        settings::EditPredictionProvider::OpenAiCompatibleApi => {
 51            settings.open_ai_compatible_api.clone()
 52        }
 53        _ => None,
 54    };
 55
 56    let http_client = cx.http_client();
 57    let buffer_snapshotted_at = Instant::now();
 58    let raw_config = store.zeta2_raw_config().cloned();
 59    let preferred_experiment = store.preferred_experiment().map(|s| s.to_owned());
 60    let open_ai_compatible_api_key = load_open_ai_compatible_api_key_if_needed(provider, cx);
 61
 62    let excerpt_path: Arc<Path> = snapshot
 63        .file()
 64        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 65        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 66
 67    let repo_url = if can_collect_data {
 68        let buffer_id = buffer.read(cx).remote_id();
 69        project
 70            .read(cx)
 71            .git_store()
 72            .read(cx)
 73            .repository_and_path_for_buffer_id(buffer_id, cx)
 74            .and_then(|(repo, _)| repo.read(cx).default_remote_url())
 75    } else {
 76        None
 77    };
 78
 79    let client = store.client.clone();
 80    let llm_token = store.llm_token.clone();
 81    let organization_id = store
 82        .user_store
 83        .read(cx)
 84        .current_organization()
 85        .map(|organization| organization.id.clone());
 86    let app_version = AppVersion::global(cx);
 87
 88    let request_task = cx.background_spawn({
 89        async move {
 90            let zeta_version = raw_config
 91                .as_ref()
 92                .map(|config| config.format)
 93                .unwrap_or(ZetaFormat::default());
 94
 95            let cursor_offset = position.to_offset(&snapshot);
 96            let editable_range_in_excerpt: Range<usize>;
 97            let (full_context_offset_range, prompt_input) = zeta2_prompt_input(
 98                &snapshot,
 99                related_files,
100                events,
101                excerpt_path,
102                cursor_offset,
103                preferred_experiment,
104                is_open_source,
105                can_collect_data,
106                repo_url,
107            );
108
109            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
110                return Ok((None, None));
111            }
112
113            if let Some(debug_tx) = &debug_tx {
114                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
115                debug_tx
116                    .unbounded_send(DebugEvent::EditPredictionStarted(
117                        EditPredictionStartedDebugEvent {
118                            buffer: buffer.downgrade(),
119                            prompt: Some(prompt),
120                            position,
121                        },
122                    ))
123                    .ok();
124            }
125
126            log::trace!("Sending edit prediction request");
127
128            let (request_id, output_text, model_version, usage) =
129                if let Some(custom_settings) = &custom_server_settings {
130                    let max_tokens = custom_settings.max_output_tokens * 4;
131
132                    match custom_settings.prompt_format {
133                        EditPredictionPromptFormat::Zeta => {
134                            let ranges = &prompt_input.excerpt_ranges;
135                            let prompt = zeta1::format_zeta1_from_input(
136                                &prompt_input,
137                                ranges.editable_350.clone(),
138                                ranges.editable_350_context_150.clone(),
139                            );
140                            editable_range_in_excerpt = ranges.editable_350.clone();
141                            let stop_tokens = vec![
142                                EDITABLE_REGION_END_MARKER.to_string(),
143                                format!("{EDITABLE_REGION_END_MARKER}\n"),
144                                format!("{EDITABLE_REGION_END_MARKER}\n\n"),
145                                format!("{EDITABLE_REGION_END_MARKER}\n\n\n"),
146                            ];
147
148                            let (response_text, request_id) = send_custom_server_request(
149                                provider,
150                                custom_settings,
151                                prompt,
152                                max_tokens,
153                                stop_tokens,
154                                open_ai_compatible_api_key.clone(),
155                                &http_client,
156                            )
157                            .await?;
158
159                            let request_id = EditPredictionId(request_id.into());
160                            let output_text = zeta1::clean_zeta1_model_output(&response_text);
161
162                            (request_id, output_text, None, None)
163                        }
164                        EditPredictionPromptFormat::Zeta2 => {
165                            let prompt = format_zeta_prompt(&prompt_input, zeta_version);
166                            let prefill = get_prefill(&prompt_input, zeta_version);
167                            let prompt = format!("{prompt}{prefill}");
168
169                            editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
170                                zeta_version,
171                                &prompt_input.excerpt_ranges,
172                            )
173                            .0;
174
175                            let (response_text, request_id) = send_custom_server_request(
176                                provider,
177                                custom_settings,
178                                prompt,
179                                max_tokens,
180                                vec![],
181                                open_ai_compatible_api_key.clone(),
182                                &http_client,
183                            )
184                            .await?;
185
186                            let request_id = EditPredictionId(request_id.into());
187                            let output_text = if response_text.is_empty() {
188                                None
189                            } else {
190                                let output = format!("{prefill}{response_text}");
191                                Some(clean_zeta2_model_output(&output, zeta_version).to_string())
192                            };
193
194                            (request_id, output_text, None, None)
195                        }
196                        _ => anyhow::bail!("unsupported prompt format"),
197                    }
198                } else if let Some(config) = &raw_config {
199                    let prompt = format_zeta_prompt(&prompt_input, config.format);
200                    let prefill = get_prefill(&prompt_input, config.format);
201                    let prompt = format!("{prompt}{prefill}");
202                    let environment = config
203                        .environment
204                        .clone()
205                        .or_else(|| Some(config.format.to_string().to_lowercase()));
206                    let request = RawCompletionRequest {
207                        model: config.model_id.clone().unwrap_or_default(),
208                        prompt,
209                        temperature: None,
210                        stop: vec![],
211                        max_tokens: Some(2048),
212                        environment,
213                    };
214
215                    editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
216                        config.format,
217                        &prompt_input.excerpt_ranges,
218                    )
219                    .1;
220
221                    let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
222                        request,
223                        client,
224                        None,
225                        llm_token,
226                        organization_id,
227                        app_version,
228                    )
229                    .await?;
230
231                    let request_id = EditPredictionId(response.id.clone().into());
232                    let output_text = response.choices.pop().map(|choice| {
233                        let response = &choice.text;
234                        let output = format!("{prefill}{response}");
235                        clean_zeta2_model_output(&output, config.format).to_string()
236                    });
237
238                    (request_id, output_text, None, usage)
239                } else {
240                    // Use V3 endpoint - server handles model/version selection and suffix stripping
241                    let (response, usage) = EditPredictionStore::send_v3_request(
242                        prompt_input.clone(),
243                        client,
244                        llm_token,
245                        organization_id,
246                        app_version,
247                        trigger,
248                    )
249                    .await?;
250
251                    let request_id = EditPredictionId(response.request_id.into());
252                    let output_text = if response.output.is_empty() {
253                        None
254                    } else {
255                        Some(response.output)
256                    };
257                    editable_range_in_excerpt = response.editable_range;
258                    let model_version = response.model_version;
259
260                    (request_id, output_text, model_version, usage)
261                };
262
263            let received_response_at = Instant::now();
264
265            log::trace!("Got edit prediction response");
266
267            let Some(mut output_text) = output_text else {
268                return Ok((Some((request_id, None, model_version)), usage));
269            };
270
271            let editable_range_in_buffer = editable_range_in_excerpt.start
272                + full_context_offset_range.start
273                ..editable_range_in_excerpt.end + full_context_offset_range.start;
274
275            let mut old_text = snapshot
276                .text_for_range(editable_range_in_buffer.clone())
277                .collect::<String>();
278
279            // For the hashline format, the model may return <|set|>/<|insert|>
280            // edit commands instead of a full replacement. Apply them against
281            // the original editable region to produce the full replacement text.
282            // This must happen before cursor marker stripping because the cursor
283            // marker is embedded inside edit command content.
284            if let Some(rewritten_output) =
285                output_with_context_for_format(zeta_version, &old_text, &output_text)?
286            {
287                output_text = rewritten_output;
288            }
289
290            // Client-side cursor marker processing (applies to both raw and v3 responses)
291            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
292            if let Some(offset) = cursor_offset_in_output {
293                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
294                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
295            }
296
297            if let Some(debug_tx) = &debug_tx {
298                debug_tx
299                    .unbounded_send(DebugEvent::EditPredictionFinished(
300                        EditPredictionFinishedDebugEvent {
301                            buffer: buffer.downgrade(),
302                            position,
303                            model_output: Some(output_text.clone()),
304                        },
305                    ))
306                    .ok();
307            }
308
309            if !output_text.is_empty() && !output_text.ends_with('\n') {
310                output_text.push('\n');
311            }
312            if !old_text.is_empty() && !old_text.ends_with('\n') {
313                old_text.push('\n');
314            }
315
316            let (edits, cursor_position) = compute_edits_and_cursor_position(
317                old_text,
318                &output_text,
319                editable_range_in_buffer.start,
320                cursor_offset_in_output,
321                &snapshot,
322            );
323
324            anyhow::Ok((
325                Some((
326                    request_id,
327                    Some((
328                        prompt_input,
329                        buffer,
330                        snapshot.clone(),
331                        edits,
332                        cursor_position,
333                        received_response_at,
334                        editable_range_in_buffer,
335                    )),
336                    model_version,
337                )),
338                usage,
339            ))
340        }
341    });
342
343    cx.spawn(async move |this, cx| {
344        let Some((id, prediction, model_version)) =
345            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
346        else {
347            return Ok(None);
348        };
349
350        let Some((
351            inputs,
352            edited_buffer,
353            edited_buffer_snapshot,
354            edits,
355            cursor_position,
356            received_response_at,
357            editable_range_in_buffer,
358        )) = prediction
359        else {
360            return Ok(Some(EditPredictionResult {
361                id,
362                prediction: Err(EditPredictionRejectReason::Empty),
363            }));
364        };
365
366        if can_collect_data {
367            this.update(cx, |this, cx| {
368                this.enqueue_settled_prediction(
369                    id.clone(),
370                    &project,
371                    &edited_buffer,
372                    &edited_buffer_snapshot,
373                    editable_range_in_buffer,
374                    cx,
375                );
376            })
377            .ok();
378        }
379
380        Ok(Some(
381            EditPredictionResult::new(
382                id,
383                &edited_buffer,
384                &edited_buffer_snapshot,
385                edits.into(),
386                cursor_position,
387                buffer_snapshotted_at,
388                received_response_at,
389                inputs,
390                model_version,
391                cx,
392            )
393            .await,
394        ))
395    })
396}
397
398pub fn zeta2_prompt_input(
399    snapshot: &language::BufferSnapshot,
400    related_files: Vec<zeta_prompt::RelatedFile>,
401    events: Vec<Arc<zeta_prompt::Event>>,
402    excerpt_path: Arc<Path>,
403    cursor_offset: usize,
404    preferred_experiment: Option<String>,
405    is_open_source: bool,
406    can_collect_data: bool,
407    repo_url: Option<String>,
408) -> (Range<usize>, zeta_prompt::ZetaPromptInput) {
409    let cursor_point = cursor_offset.to_point(snapshot);
410
411    let (full_context, full_context_offset_range, excerpt_ranges) =
412        compute_excerpt_ranges(cursor_point, snapshot);
413
414    let related_files = crate::filter_redundant_excerpts(
415        related_files,
416        excerpt_path.as_ref(),
417        full_context.start.row..full_context.end.row,
418    );
419
420    let full_context_start_offset = full_context_offset_range.start;
421    let full_context_start_row = full_context.start.row;
422
423    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
424
425    let prompt_input = zeta_prompt::ZetaPromptInput {
426        cursor_path: excerpt_path,
427        cursor_excerpt: snapshot
428            .text_for_range(full_context)
429            .collect::<String>()
430            .into(),
431        cursor_offset_in_excerpt,
432        excerpt_start_row: Some(full_context_start_row),
433        events,
434        related_files,
435        excerpt_ranges,
436        experiment: preferred_experiment,
437        in_open_source_repo: is_open_source,
438        can_collect_data,
439        repo_url,
440    };
441    (full_context_offset_range, prompt_input)
442}
443
444pub(crate) fn edit_prediction_accepted(
445    store: &EditPredictionStore,
446    current_prediction: CurrentEditPrediction,
447    cx: &App,
448) {
449    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
450    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
451        return;
452    }
453
454    let request_id = current_prediction.prediction.id.to_string();
455    let model_version = current_prediction.prediction.model_version;
456    let require_auth = custom_accept_url.is_none();
457    let client = store.client.clone();
458    let llm_token = store.llm_token.clone();
459    let organization_id = store
460        .user_store
461        .read(cx)
462        .current_organization()
463        .map(|organization| organization.id.clone());
464    let app_version = AppVersion::global(cx);
465
466    cx.background_spawn(async move {
467        let url = if let Some(accept_edits_url) = custom_accept_url {
468            gpui::http_client::Url::parse(&accept_edits_url)?
469        } else {
470            client
471                .http_client()
472                .build_zed_llm_url("/predict_edits/accept", &[])?
473        };
474
475        let response = EditPredictionStore::send_api_request::<()>(
476            move |builder| {
477                let req = builder.uri(url.as_ref()).body(
478                    serde_json::to_string(&AcceptEditPredictionBody {
479                        request_id: request_id.clone(),
480                        model_version: model_version.clone(),
481                    })?
482                    .into(),
483                );
484                Ok(req?)
485            },
486            client,
487            llm_token,
488            organization_id,
489            app_version,
490            require_auth,
491        )
492        .await;
493
494        response?;
495        anyhow::Ok(())
496    })
497    .detach_and_log_err(cx);
498}
499
500pub fn compute_edits(
501    old_text: String,
502    new_text: &str,
503    offset: usize,
504    snapshot: &BufferSnapshot,
505) -> Vec<(Range<Anchor>, Arc<str>)> {
506    compute_edits_and_cursor_position(old_text, new_text, offset, None, snapshot).0
507}
508
509pub fn compute_edits_and_cursor_position(
510    old_text: String,
511    new_text: &str,
512    offset: usize,
513    cursor_offset_in_new_text: Option<usize>,
514    snapshot: &BufferSnapshot,
515) -> (
516    Vec<(Range<Anchor>, Arc<str>)>,
517    Option<PredictedCursorPosition>,
518) {
519    let diffs = text_diff(&old_text, new_text);
520
521    // Delta represents the cumulative change in byte count from all preceding edits.
522    // new_offset = old_offset + delta, so old_offset = new_offset - delta
523    let mut delta: isize = 0;
524    let mut cursor_position: Option<PredictedCursorPosition> = None;
525    let buffer_len = snapshot.len();
526
527    let edits = diffs
528        .iter()
529        .map(|(raw_old_range, new_text)| {
530            // Compute cursor position if it falls within or before this edit.
531            if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
532                let edit_start_in_new = (raw_old_range.start as isize + delta) as usize;
533                let edit_end_in_new = edit_start_in_new + new_text.len();
534
535                if cursor_offset < edit_start_in_new {
536                    let cursor_in_old = (cursor_offset as isize - delta) as usize;
537                    let buffer_offset = (offset + cursor_in_old).min(buffer_len);
538                    cursor_position = Some(PredictedCursorPosition::at_anchor(
539                        snapshot.anchor_after(buffer_offset),
540                    ));
541                } else if cursor_offset < edit_end_in_new {
542                    let buffer_offset = (offset + raw_old_range.start).min(buffer_len);
543                    let offset_within_insertion = cursor_offset - edit_start_in_new;
544                    cursor_position = Some(PredictedCursorPosition::new(
545                        snapshot.anchor_before(buffer_offset),
546                        offset_within_insertion,
547                    ));
548                }
549
550                delta += new_text.len() as isize - raw_old_range.len() as isize;
551            }
552
553            // Compute the edit with prefix/suffix trimming.
554            let mut old_range = raw_old_range.clone();
555            let old_slice = &old_text[old_range.clone()];
556
557            let prefix_len = common_prefix(old_slice.chars(), new_text.chars());
558            let suffix_len = common_prefix(
559                old_slice[prefix_len..].chars().rev(),
560                new_text[prefix_len..].chars().rev(),
561            );
562
563            old_range.start += offset;
564            old_range.end += offset;
565            old_range.start += prefix_len;
566            old_range.end -= suffix_len;
567
568            old_range.start = old_range.start.min(buffer_len);
569            old_range.end = old_range.end.min(buffer_len);
570
571            let new_text = new_text[prefix_len..new_text.len() - suffix_len].into();
572            let range = if old_range.is_empty() {
573                let anchor = snapshot.anchor_after(old_range.start);
574                anchor..anchor
575            } else {
576                snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
577            };
578            (range, new_text)
579        })
580        .collect();
581
582    if let (Some(cursor_offset), None) = (cursor_offset_in_new_text, cursor_position) {
583        let cursor_in_old = (cursor_offset as isize - delta) as usize;
584        let buffer_offset = snapshot.clip_offset(offset + cursor_in_old, Bias::Right);
585        cursor_position = Some(PredictedCursorPosition::at_anchor(
586            snapshot.anchor_after(buffer_offset),
587        ));
588    }
589
590    (edits, cursor_position)
591}
592
/// Returns the UTF-8 byte length of the longest run of leading characters
/// that `a` and `b` have in common.
///
/// Comparison stops at the first differing pair or when either iterator is
/// exhausted. Passing reversed iterators yields the common suffix length.
fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
    let mut matched_bytes = 0;
    for (left, right) in a.zip(b) {
        if left != right {
            break;
        }
        matched_bytes += left.len_utf8();
    }
    matched_bytes
}