zeta2.rs

  1use crate::prediction::EditPredictionResult;
  2use crate::{
  3    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
  4    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  5    EditPredictionStore,
  6};
  7use anyhow::{Result, anyhow};
  8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
  9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 10use gpui::{App, Task, prelude::*};
 11use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 12use release_channel::AppVersion;
 13
 14use std::env;
 15use std::{path::Path, sync::Arc, time::Instant};
 16use zeta_prompt::format_zeta_prompt;
 17use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
 18
 19pub const MAX_CONTEXT_TOKENS: usize = 350;
 20
 21pub fn max_editable_tokens(version: ZetaVersion) -> usize {
 22    match version {
 23        ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
 24        ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
 25    }
 26}
 27
 28pub fn request_prediction_with_zeta2(
 29    store: &mut EditPredictionStore,
 30    EditPredictionModelInput {
 31        buffer,
 32        snapshot,
 33        position,
 34        related_files,
 35        events,
 36        debug_tx,
 37        ..
 38    }: EditPredictionModelInput,
 39    zeta_version: ZetaVersion,
 40    cx: &mut Context<EditPredictionStore>,
 41) -> Task<Result<Option<EditPredictionResult>>> {
 42    let buffer_snapshotted_at = Instant::now();
 43    let custom_url = store.custom_predict_edits_url.clone();
 44
 45    let Some(excerpt_path) = snapshot
 46        .file()
 47        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 48    else {
 49        return Task::ready(Err(anyhow!("No file path for excerpt")));
 50    };
 51
 52    let client = store.client.clone();
 53    let llm_token = store.llm_token.clone();
 54    let app_version = AppVersion::global(cx);
 55
 56    let request_task = cx.background_spawn({
 57        async move {
 58            let cursor_offset = position.to_offset(&snapshot);
 59            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 60                &snapshot,
 61                related_files,
 62                events,
 63                excerpt_path,
 64                cursor_offset,
 65                zeta_version,
 66            );
 67
 68            if let Some(debug_tx) = &debug_tx {
 69                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 70                debug_tx
 71                    .unbounded_send(DebugEvent::EditPredictionStarted(
 72                        EditPredictionStartedDebugEvent {
 73                            buffer: buffer.downgrade(),
 74                            prompt: Some(prompt),
 75                            position,
 76                        },
 77                    ))
 78                    .ok();
 79            }
 80
 81            log::trace!("Sending edit prediction request");
 82
 83            let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
 84                // Use raw endpoint with custom URL
 85                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
 86                let request = RawCompletionRequest {
 87                    model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
 88                    prompt,
 89                    temperature: None,
 90                    stop: vec![],
 91                    max_tokens: Some(2048),
 92                };
 93
 94                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
 95                    request,
 96                    client,
 97                    Some(custom_url),
 98                    llm_token,
 99                    app_version,
100                )
101                .await?;
102
103                let request_id = EditPredictionId(response.id.clone().into());
104                let output_text = response.choices.pop().map(|choice| choice.text);
105                (request_id, output_text, usage)
106            } else {
107                let (response, usage) = EditPredictionStore::send_v3_request(
108                    prompt_input.clone(),
109                    zeta_version,
110                    client,
111                    llm_token,
112                    app_version,
113                )
114                .await?;
115
116                let request_id = EditPredictionId(response.request_id.into());
117                let output_text = if response.output.is_empty() {
118                    None
119                } else {
120                    Some(response.output)
121                };
122                (request_id, output_text, usage)
123            };
124
125            let received_response_at = Instant::now();
126
127            log::trace!("Got edit prediction response");
128
129            let Some(mut output_text) = output_text else {
130                return Ok((Some((request_id, None)), usage));
131            };
132
133            if let Some(debug_tx) = &debug_tx {
134                debug_tx
135                    .unbounded_send(DebugEvent::EditPredictionFinished(
136                        EditPredictionFinishedDebugEvent {
137                            buffer: buffer.downgrade(),
138                            position,
139                            model_output: Some(output_text.clone()),
140                        },
141                    ))
142                    .ok();
143            }
144
145            if output_text.contains(CURSOR_MARKER) {
146                log::trace!("Stripping out {CURSOR_MARKER} from response");
147                output_text = output_text.replace(CURSOR_MARKER, "");
148            }
149
150            if zeta_version == ZetaVersion::V0120GitMergeMarkers {
151                if let Some(stripped) =
152                    output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
153                {
154                    output_text = stripped.to_string();
155                }
156            }
157
158            let mut old_text = snapshot
159                .text_for_range(editable_offset_range.clone())
160                .collect::<String>();
161
162            if !output_text.is_empty() && !output_text.ends_with('\n') {
163                output_text.push('\n');
164            }
165            if !old_text.is_empty() && !old_text.ends_with('\n') {
166                old_text.push('\n');
167            }
168
169            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
170                .into_iter()
171                .map(|(range, text)| {
172                    (
173                        snapshot.anchor_after(editable_offset_range.start + range.start)
174                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
175                        text,
176                    )
177                })
178                .collect();
179
180            anyhow::Ok((
181                Some((
182                    request_id,
183                    Some((
184                        prompt_input,
185                        buffer,
186                        snapshot.clone(),
187                        edits,
188                        received_response_at,
189                    )),
190                )),
191                usage,
192            ))
193        }
194    });
195
196    cx.spawn(async move |this, cx| {
197        let Some((id, prediction)) =
198            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
199        else {
200            return Ok(None);
201        };
202
203        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
204            prediction
205        else {
206            return Ok(Some(EditPredictionResult {
207                id,
208                prediction: Err(EditPredictionRejectReason::Empty),
209            }));
210        };
211
212        Ok(Some(
213            EditPredictionResult::new(
214                id,
215                &edited_buffer,
216                &edited_buffer_snapshot,
217                edits.into(),
218                buffer_snapshotted_at,
219                received_response_at,
220                inputs,
221                cx,
222            )
223            .await,
224        ))
225    })
226}
227
228pub fn zeta2_prompt_input(
229    snapshot: &language::BufferSnapshot,
230    related_files: Vec<zeta_prompt::RelatedFile>,
231    events: Vec<Arc<zeta_prompt::Event>>,
232    excerpt_path: Arc<Path>,
233    cursor_offset: usize,
234    zeta_version: ZetaVersion,
235) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
236    let cursor_point = cursor_offset.to_point(snapshot);
237
238    let (editable_range, context_range) =
239        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
240            cursor_point,
241            snapshot,
242            max_editable_tokens(zeta_version),
243            MAX_CONTEXT_TOKENS,
244        );
245
246    let related_files = crate::filter_redundant_excerpts(
247        related_files,
248        excerpt_path.as_ref(),
249        context_range.start.row..context_range.end.row,
250    );
251
252    let context_start_offset = context_range.start.to_offset(snapshot);
253    let editable_offset_range = editable_range.to_offset(snapshot);
254    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
255    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
256        ..(editable_offset_range.end - context_start_offset);
257
258    let prompt_input = zeta_prompt::ZetaPromptInput {
259        cursor_path: excerpt_path,
260        cursor_excerpt: snapshot
261            .text_for_range(context_range)
262            .collect::<String>()
263            .into(),
264        editable_range_in_excerpt,
265        cursor_offset_in_excerpt,
266        events,
267        related_files,
268    };
269    (editable_offset_range, prompt_input)
270}
271
272pub(crate) fn edit_prediction_accepted(
273    store: &EditPredictionStore,
274    current_prediction: CurrentEditPrediction,
275    cx: &App,
276) {
277    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
278    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
279        return;
280    }
281
282    let request_id = current_prediction.prediction.id.to_string();
283    let require_auth = custom_accept_url.is_none();
284    let client = store.client.clone();
285    let llm_token = store.llm_token.clone();
286    let app_version = AppVersion::global(cx);
287
288    cx.background_spawn(async move {
289        let url = if let Some(accept_edits_url) = custom_accept_url {
290            gpui::http_client::Url::parse(&accept_edits_url)?
291        } else {
292            client
293                .http_client()
294                .build_zed_llm_url("/predict_edits/accept", &[])?
295        };
296
297        let response = EditPredictionStore::send_api_request::<()>(
298            move |builder| {
299                let req = builder.uri(url.as_ref()).body(
300                    serde_json::to_string(&AcceptEditPredictionBody {
301                        request_id: request_id.clone(),
302                    })?
303                    .into(),
304                );
305                Ok(req?)
306            },
307            client,
308            llm_token,
309            app_version,
310            require_auth,
311        )
312        .await;
313
314        response?;
315        anyhow::Ok(())
316    })
317    .detach_and_log_err(cx);
318}