zeta2.rs

  1#[cfg(feature = "cli-support")]
  2use crate::EvalCacheEntryKind;
  3use crate::prediction::EditPredictionResult;
  4use crate::{
  5    CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
  6    EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
  7    EditPredictionStore,
  8};
  9use anyhow::{Result, anyhow};
 10use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
 11use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
 12use gpui::{App, Task, prelude::*};
 13use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
 14use release_channel::AppVersion;
 15
 16use std::env;
 17use std::{path::Path, sync::Arc, time::Instant};
 18use zeta_prompt::CURSOR_MARKER;
 19use zeta_prompt::format_zeta_prompt;
 20
  /// Context budget handed to
  /// `cursor_excerpt::editable_and_context_ranges_for_cursor_position` by
  /// `zeta2_prompt_input` (passed as the last argument).
  // NOTE(review): presumably an upper bound, in tokens, on the excerpt that surrounds the
  // editable region — confirm against the `cursor_excerpt` module.
  21pub const MAX_CONTEXT_TOKENS: usize = 350;
  /// Editable-region budget handed to
  /// `cursor_excerpt::editable_and_context_ranges_for_cursor_position` by
  /// `zeta2_prompt_input` (passed just before `MAX_CONTEXT_TOKENS`).
  // NOTE(review): presumably an upper bound, in tokens, on the region the model may rewrite —
  // confirm against the `cursor_excerpt` module.
  22pub const MAX_EDITABLE_TOKENS: usize = 150;
 23
 24pub fn request_prediction_with_zeta2(
 25    store: &mut EditPredictionStore,
 26    EditPredictionModelInput {
 27        buffer,
 28        snapshot,
 29        position,
 30        related_files,
 31        events,
 32        debug_tx,
 33        ..
 34    }: EditPredictionModelInput,
 35    cx: &mut Context<EditPredictionStore>,
 36) -> Task<Result<Option<EditPredictionResult>>> {
 37    let buffer_snapshotted_at = Instant::now();
 38    let url = store.custom_predict_edits_url.clone();
 39
 40    let Some(excerpt_path) = snapshot
 41        .file()
 42        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
 43    else {
 44        return Task::ready(Err(anyhow!("No file path for excerpt")));
 45    };
 46
 47    let client = store.client.clone();
 48    let llm_token = store.llm_token.clone();
 49    let app_version = AppVersion::global(cx);
 50
 51    #[cfg(feature = "cli-support")]
 52    let eval_cache = store.eval_cache.clone();
 53
 54    let request_task = cx.background_spawn({
 55        async move {
 56            let cursor_offset = position.to_offset(&snapshot);
 57            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
 58                &snapshot,
 59                related_files,
 60                events,
 61                excerpt_path,
 62                cursor_offset,
 63            );
 64
 65            let prompt = format_zeta_prompt(&prompt_input);
 66
 67            if let Some(debug_tx) = &debug_tx {
 68                debug_tx
 69                    .unbounded_send(DebugEvent::EditPredictionStarted(
 70                        EditPredictionStartedDebugEvent {
 71                            buffer: buffer.downgrade(),
 72                            prompt: Some(prompt.clone()),
 73                            position,
 74                        },
 75                    ))
 76                    .ok();
 77            }
 78
 79            let request = RawCompletionRequest {
 80                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
 81                prompt,
 82                temperature: None,
 83                stop: vec![],
 84                max_tokens: 1024,
 85            };
 86
 87            log::trace!("Sending edit prediction request");
 88
 89            let response = EditPredictionStore::send_raw_llm_request(
 90                request,
 91                client,
 92                url,
 93                llm_token,
 94                app_version,
 95                #[cfg(feature = "cli-support")]
 96                eval_cache,
 97                #[cfg(feature = "cli-support")]
 98                EvalCacheEntryKind::Prediction,
 99            )
100            .await;
101            let received_response_at = Instant::now();
102
103            log::trace!("Got edit prediction response");
104
105            let (mut res, usage) = response?;
106            let request_id = EditPredictionId(res.id.clone().into());
107            let Some(mut output_text) = res.choices.pop().map(|choice| choice.text) else {
108                return Ok((Some((request_id, None)), usage));
109            };
110
111            if let Some(debug_tx) = &debug_tx {
112                debug_tx
113                    .unbounded_send(DebugEvent::EditPredictionFinished(
114                        EditPredictionFinishedDebugEvent {
115                            buffer: buffer.downgrade(),
116                            position,
117                            model_output: Some(output_text.clone()),
118                        },
119                    ))
120                    .ok();
121            }
122
123            if output_text.contains(CURSOR_MARKER) {
124                log::trace!("Stripping out {CURSOR_MARKER} from response");
125                output_text = output_text.replace(CURSOR_MARKER, "");
126            }
127
128            let old_text = snapshot
129                .text_for_range(editable_offset_range.clone())
130                .collect::<String>();
131            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
132                .into_iter()
133                .map(|(range, text)| {
134                    (
135                        snapshot.anchor_after(editable_offset_range.start + range.start)
136                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
137                        text,
138                    )
139                })
140                .collect();
141
142            anyhow::Ok((
143                Some((
144                    request_id,
145                    Some((
146                        prompt_input,
147                        buffer,
148                        snapshot.clone(),
149                        edits,
150                        received_response_at,
151                    )),
152                )),
153                usage,
154            ))
155        }
156    });
157
158    cx.spawn(async move |this, cx| {
159        let Some((id, prediction)) =
160            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
161        else {
162            return Ok(None);
163        };
164
165        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
166            prediction
167        else {
168            return Ok(Some(EditPredictionResult {
169                id,
170                prediction: Err(EditPredictionRejectReason::Empty),
171            }));
172        };
173
174        Ok(Some(
175            EditPredictionResult::new(
176                id,
177                &edited_buffer,
178                &edited_buffer_snapshot,
179                edits.into(),
180                buffer_snapshotted_at,
181                received_response_at,
182                inputs,
183                cx,
184            )
185            .await,
186        ))
187    })
188}
189
190pub fn zeta2_prompt_input(
191    snapshot: &language::BufferSnapshot,
192    related_files: Arc<[zeta_prompt::RelatedFile]>,
193    events: Vec<Arc<zeta_prompt::Event>>,
194    excerpt_path: Arc<Path>,
195    cursor_offset: usize,
196) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
197    let cursor_point = cursor_offset.to_point(snapshot);
198
199    let (editable_range, context_range) =
200        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
201            cursor_point,
202            snapshot,
203            MAX_EDITABLE_TOKENS,
204            MAX_CONTEXT_TOKENS,
205        );
206
207    let context_start_offset = context_range.start.to_offset(snapshot);
208    let editable_offset_range = editable_range.to_offset(snapshot);
209    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
210    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
211        ..(editable_offset_range.end - context_start_offset);
212
213    let prompt_input = zeta_prompt::ZetaPromptInput {
214        cursor_path: excerpt_path,
215        cursor_excerpt: snapshot
216            .text_for_range(context_range)
217            .collect::<String>()
218            .into(),
219        editable_range_in_excerpt,
220        cursor_offset_in_excerpt,
221        events,
222        related_files,
223    };
224    (editable_offset_range, prompt_input)
225}
226
227pub(crate) fn edit_prediction_accepted(
228    store: &EditPredictionStore,
229    current_prediction: CurrentEditPrediction,
230    cx: &App,
231) {
232    let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
233    if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
234        return;
235    }
236
237    let request_id = current_prediction.prediction.id.to_string();
238    let require_auth = custom_accept_url.is_none();
239    let client = store.client.clone();
240    let llm_token = store.llm_token.clone();
241    let app_version = AppVersion::global(cx);
242
243    cx.background_spawn(async move {
244        let url = if let Some(accept_edits_url) = custom_accept_url {
245            gpui::http_client::Url::parse(&accept_edits_url)?
246        } else {
247            client
248                .http_client()
249                .build_zed_llm_url("/predict_edits/accept", &[])?
250        };
251
252        let response = EditPredictionStore::send_api_request::<()>(
253            move |builder| {
254                let req = builder.uri(url.as_ref()).body(
255                    serde_json::to_string(&AcceptEditPredictionBody {
256                        request_id: request_id.clone(),
257                    })?
258                    .into(),
259                );
260                Ok(req?)
261            },
262            client,
263            llm_token,
264            app_version,
265            require_auth,
266        )
267        .await;
268
269        response?;
270        anyhow::Ok(())
271    })
272    .detach_and_log_err(cx);
273}