zeta2.rs

  1use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
  2use crate::prediction::EditPredictionResult;
  3use crate::zeta1::compute_edits_and_cursor_position;
  4use crate::{
  5    CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
  6    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
  7};
  8use anyhow::Result;
  9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 11use gpui::{App, Task, prelude::*};
 12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 13use release_channel::AppVersion;
 14
 15use std::env;
 16use std::{path::Path, sync::Arc, time::Instant};
 17use zeta_prompt::{
 18    CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
 19    format_zeta_prompt, get_prefill, prompt_input_contains_special_tokens,
 20};
 21
/// Context token budget, exported for use outside this file.
///
/// NOTE(review): nothing in the visible part of this file reads this constant; presumably it
/// bounds the size of the context excerpt gathered around the cursor — confirm at the use sites.
pub const MAX_CONTEXT_TOKENS: usize = 350;
 23
 24pub fn max_editable_tokens(format: ZetaFormat) -> usize {
 25    match format {
 26        ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
 27        ZetaFormat::V0114180EditableRegion => 180,
 28        ZetaFormat::V0120GitMergeMarkers => 180,
 29        ZetaFormat::V0131GitMergeMarkersPrefix => 180,
 30        ZetaFormat::V0211Prefill => 180,
 31        ZetaFormat::V0211SeedCoder => 180,
 32    }
 33}
 34
 35pub fn request_prediction_with_zeta2(
 36    store: &mut EditPredictionStore,
 37    EditPredictionModelInput {
 38        buffer,
 39        snapshot,
 40        position,
 41        related_files,
 42        events,
 43        debug_tx,
 44        trigger,
 45        project,
 46        ..
 47    }: EditPredictionModelInput,
 48    preferred_model: Option<EditPredictionModelKind>,
 49    cx: &mut Context<EditPredictionStore>,
 50) -> Task<Result<Option<EditPredictionResult>>> {
 51    let buffer_snapshotted_at = Instant::now();
 52    let raw_config = store.zeta2_raw_config().cloned();
 53
 54    let excerpt_path: Arc<Path> = snapshot
 55        .file()
 56        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 57        .unwrap_or_else(|| Arc::from(Path::new("untitled")));
 58
 59    let client = store.client.clone();
 60    let llm_token = store.llm_token.clone();
 61    let app_version = AppVersion::global(cx);
 62
 63    let is_open_source = snapshot
 64        .file()
 65        .map_or(false, |file| store.is_file_open_source(&project, file, cx))
 66        && events.iter().all(|event| event.in_open_source_repo())
 67        && related_files.iter().all(|file| file.in_open_source_repo);
 68
 69    let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);
 70
 71    let request_task = cx.background_spawn({
 72        async move {
 73            let zeta_version = raw_config
 74                .as_ref()
 75                .map(|config| config.format)
 76                .unwrap_or(ZetaFormat::default());
 77
 78            let cursor_offset = position.to_offset(&snapshot);
 79            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 80                &snapshot,
 81                related_files,
 82                events,
 83                excerpt_path,
 84                cursor_offset,
 85                zeta_version,
 86                preferred_model,
 87                is_open_source,
 88                can_collect_data,
 89            );
 90
 91            if prompt_input_contains_special_tokens(&prompt_input, zeta_version) {
 92                return Ok((None, None));
 93            }
 94
 95            if let Some(debug_tx) = &debug_tx {
 96                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 97                debug_tx
 98                    .unbounded_send(DebugEvent::EditPredictionStarted(
 99                        EditPredictionStartedDebugEvent {
100                            buffer: buffer.downgrade(),
101                            prompt: Some(prompt),
102                            position,
103                        },
104                    ))
105                    .ok();
106            }
107
108            log::trace!("Sending edit prediction request");
109
110            let (request_id, output_text, usage) = if let Some(config) = &raw_config {
111                let prompt = format_zeta_prompt(&prompt_input, config.format);
112                let prefill = get_prefill(&prompt_input, config.format);
113                let prompt = format!("{prompt}{prefill}");
114                let request = RawCompletionRequest {
115                    model: config.model_id.clone().unwrap_or_default(),
116                    prompt,
117                    temperature: None,
118                    stop: vec![],
119                    max_tokens: Some(2048),
120                    environment: Some(config.format.to_string().to_lowercase()),
121                };
122
123                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
124                    request,
125                    client,
126                    None,
127                    llm_token,
128                    app_version,
129                )
130                .await?;
131
132                let request_id = EditPredictionId(response.id.clone().into());
133                let output_text = response.choices.pop().map(|choice| {
134                    let response = &choice.text;
135                    let output = format!("{prefill}{response}");
136                    clean_zeta2_model_output(&output, config.format).to_string()
137                });
138
139                (request_id, output_text, usage)
140            } else {
141                // Use V3 endpoint - server handles model/version selection and suffix stripping
142                let (response, usage) = EditPredictionStore::send_v3_request(
143                    prompt_input.clone(),
144                    client,
145                    llm_token,
146                    app_version,
147                    trigger,
148                )
149                .await?;
150
151                let request_id = EditPredictionId(response.request_id.into());
152                let output_text = if response.output.is_empty() {
153                    None
154                } else {
155                    Some(response.output)
156                };
157                (request_id, output_text, usage)
158            };
159
160            let received_response_at = Instant::now();
161
162            log::trace!("Got edit prediction response");
163
164            let Some(mut output_text) = output_text else {
165                return Ok((Some((request_id, None)), usage));
166            };
167
168            // Client-side cursor marker processing (applies to both raw and v3 responses)
169            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
170            if let Some(offset) = cursor_offset_in_output {
171                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
172                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
173            }
174
175            if let Some(debug_tx) = &debug_tx {
176                debug_tx
177                    .unbounded_send(DebugEvent::EditPredictionFinished(
178                        EditPredictionFinishedDebugEvent {
179                            buffer: buffer.downgrade(),
180                            position,
181                            model_output: Some(output_text.clone()),
182                        },
183                    ))
184                    .ok();
185            }
186
187            let mut old_text = snapshot
188                .text_for_range(editable_offset_range.clone())
189                .collect::<String>();
190
191            if !output_text.is_empty() && !output_text.ends_with('\n') {
192                output_text.push('\n');
193            }
194            if !old_text.is_empty() && !old_text.ends_with('\n') {
195                old_text.push('\n');
196            }
197
198            let (edits, cursor_position) = compute_edits_and_cursor_position(
199                old_text,
200                &output_text,
201                editable_offset_range.start,
202                cursor_offset_in_output,
203                &snapshot,
204            );
205
206            anyhow::Ok((
207                Some((
208                    request_id,
209                    Some((
210                        prompt_input,
211                        buffer,
212                        snapshot.clone(),
213                        edits,
214                        cursor_position,
215                        received_response_at,
216                    )),
217                )),
218                usage,
219            ))
220        }
221    });
222
223    cx.spawn(async move |this, cx| {
224        let Some((id, prediction)) =
225            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
226        else {
227            return Ok(None);
228        };
229
230        let Some((
231            inputs,
232            edited_buffer,
233            edited_buffer_snapshot,
234            edits,
235            cursor_position,
236            received_response_at,
237        )) = prediction
238        else {
239            return Ok(Some(EditPredictionResult {
240                id,
241                prediction: Err(EditPredictionRejectReason::Empty),
242            }));
243        };
244
245        Ok(Some(
246            EditPredictionResult::new(
247                id,
248                &edited_buffer,
249                &edited_buffer_snapshot,
250                edits.into(),
251                cursor_position,
252                buffer_snapshotted_at,
253                received_response_at,
254                inputs,
255                cx,
256            )
257            .await,
258        ))
259    })
260}
261
262pub fn zeta2_prompt_input(
263    snapshot: &language::BufferSnapshot,
264    related_files: Vec<zeta_prompt::RelatedFile>,
265    events: Vec<Arc<zeta_prompt::Event>>,
266    excerpt_path: Arc<Path>,
267    cursor_offset: usize,
268    zeta_format: ZetaFormat,
269    preferred_model: Option<EditPredictionModelKind>,
270    is_open_source: bool,
271    can_collect_data: bool,
272) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
273    let cursor_point = cursor_offset.to_point(snapshot);
274
275    let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
276
277    let related_files = crate::filter_redundant_excerpts(
278        related_files,
279        excerpt_path.as_ref(),
280        full_context.start.row..full_context.end.row,
281    );
282
283    let full_context_start_offset = full_context.start.to_offset(snapshot);
284    let full_context_start_row = full_context.start.row;
285
286    let excerpt_ranges =
287        excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
288
289    let editable_range = match preferred_model {
290        Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
291        _ => match zeta_format {
292            ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
293            _ => &range_points.editable_180,
294        },
295    };
296
297    let editable_offset_range = editable_range.to_offset(snapshot);
298    let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
299    let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
300        ..(editable_offset_range.end - full_context_start_offset);
301
302    let prompt_input = zeta_prompt::ZetaPromptInput {
303        cursor_path: excerpt_path,
304        cursor_excerpt: snapshot
305            .text_for_range(full_context)
306            .collect::<String>()
307            .into(),
308        editable_range_in_excerpt,
309        cursor_offset_in_excerpt,
310        excerpt_start_row: Some(full_context_start_row),
311        events,
312        related_files,
313        excerpt_ranges: Some(excerpt_ranges),
314        preferred_model,
315        in_open_source_repo: is_open_source,
316        can_collect_data,
317    };
318    (editable_offset_range, prompt_input)
319}
320
321pub(crate) fn edit_prediction_accepted(
322    store: &EditPredictionStore,
323    current_prediction: CurrentEditPrediction,
324    cx: &App,
325) {
326    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
327    if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
328        return;
329    }
330
331    let request_id = current_prediction.prediction.id.to_string();
332    let require_auth = custom_accept_url.is_none();
333    let client = store.client.clone();
334    let llm_token = store.llm_token.clone();
335    let app_version = AppVersion::global(cx);
336
337    cx.background_spawn(async move {
338        let url = if let Some(accept_edits_url) = custom_accept_url {
339            gpui::http_client::Url::parse(&accept_edits_url)?
340        } else {
341            client
342                .http_client()
343                .build_zed_llm_url("/predict_edits/accept", &[])?
344        };
345
346        let response = EditPredictionStore::send_api_request::<()>(
347            move |builder| {
348                let req = builder.uri(url.as_ref()).body(
349                    serde_json::to_string(&AcceptEditPredictionBody {
350                        request_id: request_id.clone(),
351                    })?
352                    .into(),
353                );
354                Ok(req?)
355            },
356            client,
357            llm_token,
358            app_version,
359            require_auth,
360        )
361        .await;
362
363        response?;
364        anyhow::Ok(())
365    })
366    .detach_and_log_err(cx);
367}