//! zeta2.rs — builds and sends zeta2 edit-prediction requests.

  1#[cfg(feature = "eval-support")]
  2use crate::EvalCacheEntryKind;
  3use crate::open_ai_response::text_from_response;
  4use crate::prediction::EditPredictionResult;
  5use crate::{
  6    DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
  7    EditPredictionRequestedDebugEvent, EditPredictionStore,
  8};
  9use anyhow::{Result, anyhow, bail};
 10use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
 11use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
 12use cloud_zeta2_prompt::CURSOR_MARKER;
 13use edit_prediction_context::{EditPredictionExcerpt, Line};
 14use edit_prediction_context::{RelatedExcerpt, RelatedFile};
 15use futures::channel::oneshot;
 16use gpui::{Entity, Task, prelude::*};
 17use language::{Anchor, BufferSnapshot};
 18use language::{Buffer, Point, ToOffset as _, ToPoint};
 19use project::{Project, ProjectItem as _};
 20use release_channel::AppVersion;
 21use std::{
 22    env,
 23    path::Path,
 24    sync::Arc,
 25    time::{Duration, Instant},
 26};
 27
/// Requests an edit prediction from the zeta2 backend for `position` in
/// `active_buffer`.
///
/// Selects an excerpt around the cursor, folds it into `included_files`,
/// builds a `predict_edits_v3::PredictEditsRequest` plus a local prompt, sends
/// it as a raw LLM request on a background task, and parses the model output
/// (according to `options.prompt_format`) into buffer edits.
///
/// Returns `Ok(None)` when no excerpt could be selected around the cursor,
/// and an `EditPredictionResult` rejected with
/// `EditPredictionRejectReason::Empty` when the response contained no text.
///
/// # Errors
///
/// Fails when the active buffer has no file path, when prompt construction or
/// the network request fails, when the response diff/XML cannot be parsed, or
/// when `options.prompt_format` is not one of the supported formats.
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    project: &Entity<Project>,
    active_buffer: &Entity<Buffer>,
    active_snapshot: BufferSnapshot,
    position: Anchor,
    events: Vec<Arc<Event>>,
    mut included_files: Vec<RelatedFile>,
    trigger: PredictEditsRequestTrigger,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    let options = store.options.clone();
    // Captured now so the final result can report how stale the snapshot was
    // by the time the response arrived.
    let buffer_snapshotted_at = Instant::now();

    // Both a full path (for the cloud request) and a project path (for the
    // active file's RelatedFile entry) are required; bail if either is absent.
    let Some((excerpt_path, active_project_path)) = active_snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .zip(active_buffer.read(cx).project_path(cx))
    else {
        return Task::ready(Err(anyhow!("No file path for excerpt")));
    };

    // Clone everything the background task needs before leaving the main
    // context — `cx` cannot be used inside `background_spawn`.
    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);
    let debug_tx = store.debug_tx.clone();

    let file = active_buffer.read(cx).file();

    let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));

    // TODO data collection
    let can_collect_data = file
        .as_ref()
        .map_or(false, |file| store.can_collect_file(project, file, cx));

    #[cfg(feature = "eval-support")]
    let eval_cache = store.eval_cache.clone();

    // Heavy work (excerpt selection, prompt building, the network round trip,
    // and response parsing) all happens off the main thread.
    let request_task = cx.background_spawn({
        let active_buffer = active_buffer.clone();
        async move {
            let cursor_offset = position.to_offset(&active_snapshot);
            let cursor_point = cursor_offset.to_point(&active_snapshot);

            let before_retrieval = Instant::now();

            let excerpt_options = options.context;

            // No excerpt around the cursor means there is nothing to predict
            // against; the outer task maps this to `Ok(None)`.
            let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
                cursor_point,
                &active_snapshot,
                &excerpt_options,
            ) else {
                return Ok((None, None));
            };

            // Anchor the excerpt so edits parsed from the response can later
            // be resolved against this same region of the buffer.
            let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
                ..active_snapshot.anchor_before(excerpt.range.end);
            let related_excerpt = RelatedExcerpt {
                anchor_range: excerpt_anchor_range.clone(),
                point_range: Point::new(excerpt.line_range.start.0, 0)
                    ..Point::new(excerpt.line_range.end.0, 0),
                text: active_snapshot.as_rope().slice(excerpt.range),
            };

            // Fold the cursor excerpt into the active buffer's existing entry
            // (merging overlapping excerpts) or append a new entry for it.
            if let Some(buffer_ix) = included_files
                .iter()
                .position(|file| file.buffer.entity_id() == active_buffer.entity_id())
            {
                let file = &mut included_files[buffer_ix];
                file.excerpts.push(related_excerpt);
                file.merge_excerpts();
                // Move the active file's entry to the end of the list.
                // NOTE(review): presumably so the active file appears last in
                // the prompt — confirm against cloud_zeta2_prompt.
                let last_ix = included_files.len() - 1;
                included_files.swap(buffer_ix, last_ix);
            } else {
                let active_file = RelatedFile {
                    path: active_project_path,
                    buffer: active_buffer.downgrade(),
                    excerpts: vec![related_excerpt],
                    max_row: active_snapshot.max_point().row,
                };
                included_files.push(active_file);
            }

            // Convert the in-process RelatedFile structures to the wire
            // format, materializing each excerpt's rope slice as a string.
            let included_files = included_files
                .iter()
                .map(|related_file| predict_edits_v3::RelatedFile {
                    path: Arc::from(related_file.path.path.as_std_path()),
                    max_row: Line(related_file.max_row),
                    excerpts: related_file
                        .excerpts
                        .iter()
                        .map(|excerpt| predict_edits_v3::Excerpt {
                            start_line: Line(excerpt.point_range.start.row),
                            text: excerpt.text.to_string().into(),
                        })
                        .collect(),
                })
                .collect::<Vec<_>>();

            // NOTE(review): the excerpt / excerpt_line_range / excerpt_range
            // fields are sent empty — the excerpt content travels inside
            // `related_files` instead. Confirm the server ignores these.
            let cloud_request = predict_edits_v3::PredictEditsRequest {
                excerpt_path,
                excerpt: String::new(),
                excerpt_line_range: Line(0)..Line(0),
                excerpt_range: 0..0,
                cursor_point: predict_edits_v3::Point {
                    line: predict_edits_v3::Line(cursor_point.row),
                    column: cursor_point.column,
                },
                related_files: included_files,
                events,
                can_collect_data,
                debug_info: debug_tx.is_some(),
                prompt_max_bytes: Some(options.max_prompt_bytes),
                prompt_format: options.prompt_format,
                excerpt_parent: None,
                git_info: None,
                trigger,
            };

            // Build the prompt locally; the Result is kept (not unwrapped yet)
            // so a failure can still be reported to the debug channel below.
            let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);

            // Reuse the request's owned fields as the inputs record attached
            // to the eventual prediction result.
            let inputs = EditPredictionInputs {
                included_files: cloud_request.related_files,
                events: cloud_request.events,
                cursor_point: cloud_request.cursor_point,
                cursor_path: cloud_request.excerpt_path,
            };

            let retrieval_time = Instant::now() - before_retrieval;

            // When a debug listener is attached, report the request inputs
            // and keep a oneshot sender to forward the eventual response.
            let debug_response_tx = if let Some(debug_tx) = &debug_tx {
                let (response_tx, response_rx) = oneshot::channel();

                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionRequested(
                        EditPredictionRequestedDebugEvent {
                            inputs: inputs.clone(),
                            retrieval_time,
                            buffer: active_buffer.downgrade(),
                            local_prompt: match prompt_result.as_ref() {
                                Ok(prompt) => Ok(prompt.clone()),
                                Err(err) => Err(err.to_string()),
                            },
                            position,
                            response_rx,
                        },
                    ))
                    .ok();
                Some(response_tx)
            } else {
                None
            };

            // Debug-build escape hatch: skip the network request entirely
            // when ZED_ZETA2_SKIP_REQUEST is set (any value).
            if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
                if let Some(debug_response_tx) = debug_response_tx {
                    debug_response_tx
                        .send((Err("Request skipped".to_string()), Duration::ZERO))
                        .ok();
                }
                anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
            }

            let prompt = prompt_result?;
            let generation_params =
                cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
            let request = open_ai::Request {
                model: EDIT_PREDICTIONS_MODEL_ID.clone(),
                messages: vec![open_ai::RequestMessage::User {
                    content: open_ai::MessageContent::Plain(prompt),
                }],
                stream: false,
                max_completion_tokens: None,
                stop: generation_params.stop.unwrap_or_default(),
                // Fall back to 0.7 when the prompt format specifies none.
                temperature: generation_params.temperature.or(Some(0.7)),
                tool_choice: None,
                parallel_tool_calls: None,
                tools: vec![],
                prompt_cache_key: None,
                reasoning_effort: None,
            };

            log::trace!("Sending edit prediction request");

            let before_request = Instant::now();
            let response = EditPredictionStore::send_raw_llm_request(
                request,
                client,
                llm_token,
                app_version,
                #[cfg(feature = "eval-support")]
                eval_cache,
                #[cfg(feature = "eval-support")]
                EvalCacheEntryKind::Prediction,
            )
            .await;
            let received_response_at = Instant::now();
            let request_time = received_response_at - before_request;

            log::trace!("Got edit prediction response");

            // Forward the raw response (or its error, stringified) to the
            // debug listener before any parsing can fail.
            if let Some(debug_response_tx) = debug_response_tx {
                debug_response_tx
                    .send((
                        response
                            .as_ref()
                            .map_err(|err| err.to_string())
                            .map(|response| response.0.clone()),
                        request_time,
                    ))
                    .ok();
            }

            let (res, usage) = response?;
            let request_id = EditPredictionId(res.id.clone().into());
            // A response with no text is still a (rejected) prediction: keep
            // the id but carry no edits.
            let Some(mut output_text) = text_from_response(res) else {
                return Ok((Some((request_id, None)), usage));
            };

            // The model sometimes echoes the cursor marker back; remove it so
            // it never leaks into parsed edits.
            if output_text.contains(CURSOR_MARKER) {
                log::trace!("Stripping out {CURSOR_MARKER} from response");
                output_text = output_text.replace(CURSOR_MARKER, "");
            }

            // Resolves a path named in the model output back to a snapshot
            // and its anchored excerpt ranges; only the active buffer is
            // known here, so any other path yields None.
            let get_buffer_from_context = |path: &Path| {
                if Some(path) == active_file_full_path.as_deref() {
                    Some((
                        &active_snapshot,
                        std::slice::from_ref(&excerpt_anchor_range),
                    ))
                } else {
                    None
                }
            };

            // Parse the output into edits according to the prompt format the
            // request was built with.
            let (_, edits) = match options.prompt_format {
                PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
                    // The diff-style formats emit this literal header when the
                    // model decides no edits are needed.
                    if output_text.contains("--- a/\n+++ b/\nNo edits") {
                        let edits = vec![];
                        (&active_snapshot, edits)
                    } else {
                        crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
                    }
                }
                PromptFormat::OldTextNewText => {
                    crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
                }
                _ => {
                    bail!("unsupported prompt format {}", options.prompt_format)
                }
            };

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        inputs,
                        active_buffer,
                        active_snapshot.clone(),
                        edits,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

    // Back on the foreground: translate the background task's output into the
    // caller-facing result.
    cx.spawn(async move |this, cx| {
        // None here means no excerpt could be selected — not an error.
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // A request id without prediction data means the response had no
        // text; surface it as an empty-rejected prediction.
        let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
            prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}