zeta2.rs

  1#[cfg(feature = "cli-support")]
  2use crate::EvalCacheEntryKind;
  3use crate::prediction::EditPredictionResult;
  4use crate::{
  5    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
  6    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  7    EditPredictionStore,
  8};
  9use anyhow::{Result, anyhow};
 10use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 11use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 12use gpui::{App, Task, prelude::*};
 13use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 14use release_channel::AppVersion;
 15
 16use std::env;
 17use std::{path::Path, sync::Arc, time::Instant};
 18use zeta_prompt::format_zeta_prompt;
 19use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
 20
/// Token budget for the context excerpt surrounding the cursor; passed as the context
/// limit to `cursor_excerpt::editable_and_context_ranges_for_cursor_position` when
/// building the zeta2 prompt input.
pub const MAX_CONTEXT_TOKENS: usize = 350;
/// Token budget for the editable region (the span whose model rewrite is diffed against
/// the buffer); passed as the editable limit to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position`.
pub const MAX_EDITABLE_TOKENS: usize = 150;
 23
/// Requests an edit prediction for `position` in `buffer` from the zeta2 model.
///
/// Builds a zeta prompt (formatted for `zeta_version`) from the cursor excerpt, the recent
/// `events`, and the `related_files`, then sends it as a [`RawCompletionRequest`] via
/// `EditPredictionStore::send_raw_llm_request`. `store.custom_predict_edits_url` is passed
/// through to that call. The model's rewrite of the editable region is then diffed against
/// the buffer's current text to produce anchored edits.
///
/// When `debug_tx` is present, an `EditPredictionStarted` event (carrying the prompt) is
/// sent before the request. An `EditPredictionFinished` event (carrying the raw model
/// output) is sent after a choice is received.
///
/// # Returns
///
/// A task resolving to:
/// - `Err` immediately if the snapshot has no associated file (no path for the excerpt);
/// - whatever `EditPredictionStore::handle_api_response` produces for request errors, or
///   `Ok(None)` when it yields `None`;
/// - `Ok(Some(..))` with `prediction: Err(EditPredictionRejectReason::Empty)` when the
///   response contained no choices;
/// - otherwise the [`EditPredictionResult`] built from the computed edits.
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    zeta_version: ZetaVersion,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Captured before any async work; forwarded to `EditPredictionResult::new` below.
    let buffer_snapshotted_at = Instant::now();
    let url = store.custom_predict_edits_url.clone();

    // The prompt is keyed by the file's full path, so file-less buffers are rejected.
    let Some(excerpt_path) = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Only compiled with the `cli-support` feature; presumably used by eval tooling to
    // cache raw LLM responses (see `EvalCacheEntryKind::Prediction` below).
    #[cfg(feature = "cli-support")]
    let eval_cache = store.eval_cache.clone();

    // Prompt construction, the network request, and diffing run on the background executor.
    let request_task = cx.background_spawn({
        async move {
            let cursor_offset = position.to_offset(&snapshot);
            // `editable_offset_range` is in absolute buffer offsets; the model's output
            // replaces exactly this span.
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
            );

            let prompt = format_zeta_prompt(&prompt_input, zeta_version);

            // Debug events are best-effort: a closed receiver is ignored.
            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

            let request = RawCompletionRequest {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                prompt,
                temperature: None,
                stop: vec![],
                max_tokens: Some(2048),
            };

            log::trace!("Sending edit prediction request");

            // The trailing cache arguments only exist when `cli-support` is enabled.
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                url,
                llm_token,
                app_version,
                #[cfg(feature = "cli-support")]
                eval_cache,
                #[cfg(feature = "cli-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            let (mut res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            // Uses the last choice (presumably the only one). A response with no choices
            // is reported with `None` edits, which becomes an `Empty` rejection below.
            let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
                return Ok((Some((request_id, None)), usage));
            };

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // The cursor marker is prompt syntax, not buffer content; never insert it.
            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

            let mut old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();

            // Normalize both sides to end in '\n' so a missing final newline on one side
            // does not surface as a spurious trailing edit.
            // NOTE(review): if the editable range ends at EOF and the buffer has no
            // trailing newline, the '\n' pushed onto `old_text` has no buffer
            // counterpart. A diff range touching it maps to
            // `editable_offset_range.end + 1`, one past the buffer end. Confirm how
            // `anchor_after`/`anchor_before` treat that offset.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            // Diff ranges are relative to `old_text`. Shift them by the editable range's
            // start to get buffer offsets, then anchor them in the snapshot.
            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(editable_offset_range.start + range.start)
                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
                        text,
                    )
                })
                .collect();

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

    // Back on the foreground, the store interprets the response (errors, usage) before
    // the prediction result is assembled.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // `None` here means the model returned no choices.
        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}
198
199pub fn zeta2_prompt_input(
200    snapshot: &language::BufferSnapshot,
201    related_files: Vec<zeta_prompt::RelatedFile>,
202    events: Vec<Arc<zeta_prompt::Event>>,
203    excerpt_path: Arc<Path>,
204    cursor_offset: usize,
205) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
206    let cursor_point = cursor_offset.to_point(snapshot);
207
208    let (editable_range, context_range) =
209        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
210            cursor_point,
211            snapshot,
212            MAX_EDITABLE_TOKENS,
213            MAX_CONTEXT_TOKENS,
214        );
215
216    let related_files = crate::filter_redundant_excerpts(
217        related_files,
218        excerpt_path.as_ref(),
219        context_range.start.row..context_range.end.row,
220    );
221
222    let context_start_offset = context_range.start.to_offset(snapshot);
223    let editable_offset_range = editable_range.to_offset(snapshot);
224    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
225    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
226        ..(editable_offset_range.end - context_start_offset);
227
228    let prompt_input = zeta_prompt::ZetaPromptInput {
229        cursor_path: excerpt_path,
230        cursor_excerpt: snapshot
231            .text_for_range(context_range)
232            .collect::<String>()
233            .into(),
234        editable_range_in_excerpt,
235        cursor_offset_in_excerpt,
236        events,
237        related_files,
238    };
239    (editable_offset_range, prompt_input)
240}
241
242pub(crate) fn edit_prediction_accepted(
243    store: &EditPredictionStore,
244    current_prediction: CurrentEditPrediction,
245    cx: &App,
246) {
247    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
248    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
249        return;
250    }
251
252    let request_id = current_prediction.prediction.id.to_string();
253    let require_auth = custom_accept_url.is_none();
254    let client = store.client.clone();
255    let llm_token = store.llm_token.clone();
256    let app_version = AppVersion::global(cx);
257
258    cx.background_spawn(async move {
259        let url = if let Some(accept_edits_url) = custom_accept_url {
260            gpui::http_client::Url::parse(&accept_edits_url)?
261        } else {
262            client
263                .http_client()
264                .build_zed_llm_url("/predict_edits/accept", &[])?
265        };
266
267        let response = EditPredictionStore::send_api_request::<()>(
268            move |builder| {
269                let req = builder.uri(url.as_ref()).body(
270                    serde_json::to_string(&AcceptEditPredictionBody {
271                        request_id: request_id.clone(),
272                    })?
273                    .into(),
274                );
275                Ok(req?)
276            },
277            client,
278            llm_token,
279            app_version,
280            require_auth,
281        )
282        .await;
283
284        response?;
285        anyhow::Ok(())
286    })
287    .detach_and_log_err(cx);
288}