zeta2.rs
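//! Zeta2 edit prediction: prompt construction, the request to the edit
//! prediction model, and conversion of the model output into buffer edits.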

#[cfg(feature = "cli-support")]
use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
    EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow};
use cloud_llm_client::EditPredictionRejectReason;
use gpui::{Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::{CURSOR_MARKER, format_zeta_prompt};

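// Approximate token budgets for the prompt excerpt: context gathered around the cursor,
// and the editable region the model is asked to rewrite.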
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;

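/// Requests an edit prediction for `position` in `buffer` using the Zeta2 prompt format.
/// The prompt is assembled from the cursor excerpt, related files, and recent events; the
/// model's response is diffed against the editable region to produce anchored edits.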
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        ..
    }: EditPredictionModelInput,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let buffer_snapshotted_at = Instant::now();

    let Some(excerpt_path) = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    #[cfg(feature = "cli-support")]
    let eval_cache = store.eval_cache.clone();

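    // Assemble the prompt and send the LLM request on a background thread.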
    let request_task = cx.background_spawn({
        async move {
            let cursor_offset = position.to_offset(&snapshot);
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
            );

            let prompt = format_zeta_prompt(&prompt_input);

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt.clone()),
                            position,
                        },
                    ))
                    .ok();
            }

            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: Default::default(),
                temperature: Default::default(),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "cli-support")]
                eval_cache,
                #[cfg(feature = "cli-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

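            // The model replies with the full rewritten editable region; diff it against
            // the original text to derive edits anchored in the buffer.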
            let old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();
            let edits: Vec<_> = language::text_diff(&old_text, &output_text)
                .into_iter()
                .map(|(range, text)| {
                    (
                        snapshot.anchor_after(editable_offset_range.start + range.start)
                            ..snapshot.anchor_before(editable_offset_range.start + range.end),
                        text,
                    )
                })
                .collect();

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

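    // Back on the foreground executor, convert the raw response into an `EditPredictionResult`.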
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}

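/// Computes the editable and context ranges around `cursor_offset` and assembles the
/// `ZetaPromptInput`. Offsets inside the prompt input are relative to the start of the
/// cursor excerpt, while the returned range is the editable region in buffer offsets.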
pub fn zeta2_prompt_input(
    snapshot: &language::BufferSnapshot,
    related_files: Arc<[zeta_prompt::RelatedFile]>,
    events: Vec<Arc<zeta_prompt::Event>>,
    excerpt_path: Arc<Path>,
    cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
    let cursor_point = cursor_offset.to_point(snapshot);

    let (editable_range, context_range) =
        crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
            cursor_point,
            snapshot,
            MAX_CONTEXT_TOKENS,
            MAX_REWRITE_TOKENS,
        );

    let context_start_offset = context_range.start.to_offset(snapshot);
    let editable_offset_range = editable_range.to_offset(snapshot);
    let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
    let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
        ..(editable_offset_range.end - context_start_offset);

    let prompt_input = zeta_prompt::ZetaPromptInput {
        cursor_path: excerpt_path,
        cursor_excerpt: snapshot
            .text_for_range(context_range)
            .collect::<String>()
            .into(),
        editable_range_in_excerpt,
        cursor_offset_in_excerpt,
        events,
        related_files,
    };
    (editable_offset_range, prompt_input)
}

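/// Applies a unified-diff `patch` to the cursor excerpt and returns the new text of the
/// editable region, failing if the patch touches text outside of that region.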
#[cfg(feature = "cli-support")]
pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
    let text = &input.cursor_excerpt;
    let editable_region = input.editable_range_in_excerpt.clone();
    let old_prefix = &text[..editable_region.start];
    let old_suffix = &text[editable_region.end..];

    let new = crate::udiff::apply_diff_to_string(patch, text)?;
    if !new.starts_with(old_prefix) || !new.ends_with(old_suffix) {
        anyhow::bail!("Patch shouldn't affect text outside of editable region");
    }

    Ok(new[editable_region.start..new.len() - old_suffix.len()].to_string())
}