1use crate::cursor_excerpt::{compute_excerpt_ranges, excerpt_ranges_to_byte_offsets};
2use crate::prediction::EditPredictionResult;
3use crate::zeta1::compute_edits_and_cursor_position;
4use crate::{
5 CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
6 EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
7};
8use anyhow::Result;
9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
11use gpui::{App, Task, prelude::*};
12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
13use release_channel::AppVersion;
14
15use std::env;
16use std::{path::Path, sync::Arc, time::Instant};
17use zeta_prompt::{
18 CURSOR_MARKER, EditPredictionModelKind, ZetaFormat, clean_zeta2_model_output,
19 format_zeta_prompt, get_prefill,
20};
21
/// Token budget for the context gathered around the cursor.
// NOTE(review): not referenced within this chunk; presumably consumed by the
// excerpt-range computation elsewhere in the crate — confirm before relying
// on this description.
pub const MAX_CONTEXT_TOKENS: usize = 350;
23
24pub fn max_editable_tokens(format: ZetaFormat) -> usize {
25 match format {
26 ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
27 ZetaFormat::V0114180EditableRegion => 180,
28 ZetaFormat::V0120GitMergeMarkers => 180,
29 ZetaFormat::V0131GitMergeMarkersPrefix => 180,
30 ZetaFormat::V0211Prefill => 180,
31 ZetaFormat::V0211SeedCoder => 180,
32 }
33}
34
/// Requests an edit prediction from the Zeta2 model family for `position`
/// in `buffer`.
///
/// The heavy lifting runs on a background task:
/// 1. Assemble the prompt input (cursor excerpt, related files, edit events).
/// 2. Send the request — either a raw completion request when a raw Zeta2
///    config is active, or a V3 request where the server selects the
///    model/version and strips the suffix.
/// 3. Strip the cursor marker from the output and diff it against the
///    editable region to produce concrete buffer edits.
///
/// Resolves to `Ok(None)` when the API layer produced no handled response,
/// or to an `EditPredictionResult` — rejected with
/// `EditPredictionRejectReason::Empty` when the model returned no output.
pub fn request_prediction_with_zeta2(
    store: &mut EditPredictionStore,
    EditPredictionModelInput {
        buffer,
        snapshot,
        position,
        related_files,
        events,
        debug_tx,
        trigger,
        project,
        ..
    }: EditPredictionModelInput,
    preferred_model: Option<EditPredictionModelKind>,
    cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
    // Captured now so the final result can report snapshot-to-response latency.
    let buffer_snapshotted_at = Instant::now();
    // When set, predictions go through the raw completion endpoint instead of
    // the V3 endpoint (used for experimental model/format configurations).
    let raw_config = store.zeta2_raw_config().cloned();

    // Untitled buffers have no backing file; use a placeholder path.
    let excerpt_path: Arc<Path> = snapshot
        .file()
        .map(|file| -> Arc<Path> { file.full_path(cx).into() })
        .unwrap_or_else(|| Arc::from(Path::new("untitled")));

    let client = store.client.clone();
    let llm_token = store.llm_token.clone();
    let app_version = AppVersion::global(cx);

    // Open source only when the current file, every event, and every related
    // file all come from open-source repos.
    let is_open_source = snapshot
        .file()
        .map_or(false, |file| store.is_file_open_source(&project, file, cx))
        && events.iter().all(|event| event.in_open_source_repo())
        && related_files.iter().all(|file| file.in_open_source_repo);

    // Data collection additionally requires the user-facing setting to be on.
    let can_collect_data = is_open_source && store.is_data_collection_enabled(cx);

    let request_task = cx.background_spawn({
        async move {
            // A raw config can pin a specific prompt format; otherwise fall
            // back to the default format.
            let zeta_version = raw_config
                .as_ref()
                .map(|config| config.format)
                .unwrap_or(ZetaFormat::default());

            let cursor_offset = position.to_offset(&snapshot);
            let (editable_offset_range, prompt_input) = zeta2_prompt_input(
                &snapshot,
                related_files,
                events,
                excerpt_path,
                cursor_offset,
                zeta_version,
                preferred_model,
                is_open_source,
                can_collect_data,
            );

            // Surface the formatted prompt to the debug view, if attached.
            // Failures to send are ignored — the receiver may be gone.
            if let Some(debug_tx) = &debug_tx {
                let prompt = format_zeta_prompt(&prompt_input, zeta_version);
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionStarted(
                        EditPredictionStartedDebugEvent {
                            buffer: buffer.downgrade(),
                            prompt: Some(prompt),
                            position,
                        },
                    ))
                    .ok();
            }

            log::trace!("Sending edit prediction request");

            let (request_id, output_text, usage) = if let Some(config) = &raw_config {
                // Raw path: format the prompt locally and append the prefill;
                // the model's completion continues from the prefill text.
                let prompt = format_zeta_prompt(&prompt_input, config.format);
                let prefill = get_prefill(&prompt_input, config.format);
                let prompt = format!("{prompt}{prefill}");
                let request = RawCompletionRequest {
                    model: config.model_id.clone().unwrap_or_default(),
                    prompt,
                    temperature: None,
                    stop: vec![],
                    max_tokens: Some(2048),
                    environment: Some(config.format.to_string().to_lowercase()),
                };

                let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
                    request,
                    client,
                    None,
                    llm_token,
                    app_version,
                )
                .await?;

                let request_id = EditPredictionId(response.id.clone().into());
                // Re-attach the prefill before cleaning, since the completion
                // is a continuation of it; `None` when there are no choices.
                let output_text = response.choices.pop().map(|choice| {
                    let response = &choice.text;
                    let output = format!("{prefill}{response}");
                    clean_zeta2_model_output(&output, config.format).to_string()
                });

                (request_id, output_text, usage)
            } else {
                // Use V3 endpoint - server handles model/version selection and suffix stripping
                let (response, usage) = EditPredictionStore::send_v3_request(
                    prompt_input.clone(),
                    client,
                    llm_token,
                    app_version,
                    trigger,
                )
                .await?;

                let request_id = EditPredictionId(response.request_id.into());
                // Normalize an empty output string to `None`.
                let output_text = if response.output.is_empty() {
                    None
                } else {
                    Some(response.output)
                };
                (request_id, output_text, usage)
            };

            let received_response_at = Instant::now();

            log::trace!("Got edit prediction response");

            // No output at all: report the request id with an empty prediction
            // payload (mapped to a rejection on the foreground side).
            let Some(mut output_text) = output_text else {
                return Ok((Some((request_id, None)), usage));
            };

            // Client-side cursor marker processing (applies to both raw and v3 responses)
            // The marker's offset (pre-removal) is kept to place the cursor.
            let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
            if let Some(offset) = cursor_offset_in_output {
                log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
                output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(DebugEvent::EditPredictionFinished(
                        EditPredictionFinishedDebugEvent {
                            buffer: buffer.downgrade(),
                            position,
                            model_output: Some(output_text.clone()),
                        },
                    ))
                    .ok();
            }

            // The current contents of the editable region, to diff against.
            let mut old_text = snapshot
                .text_for_range(editable_offset_range.clone())
                .collect::<String>();

            // Normalize both sides to end with a newline before diffing.
            if !output_text.is_empty() && !output_text.ends_with('\n') {
                output_text.push('\n');
            }
            if !old_text.is_empty() && !old_text.ends_with('\n') {
                old_text.push('\n');
            }

            // Diff model output against the editable region to produce buffer
            // edits, and map the stripped marker offset to a buffer position.
            let (edits, cursor_position) = compute_edits_and_cursor_position(
                old_text,
                &output_text,
                editable_offset_range.start,
                cursor_offset_in_output,
                &snapshot,
            );

            anyhow::Ok((
                Some((
                    request_id,
                    Some((
                        prompt_input,
                        buffer,
                        snapshot.clone(),
                        edits,
                        cursor_position,
                        received_response_at,
                    )),
                )),
                usage,
            ))
        }
    });

    // Foreground task: translate the API response into a result.
    cx.spawn(async move |this, cx| {
        let Some((id, prediction)) =
            EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
        else {
            return Ok(None);
        };

        // An id without a prediction payload means the model returned nothing.
        let Some((
            inputs,
            edited_buffer,
            edited_buffer_snapshot,
            edits,
            cursor_position,
            received_response_at,
        )) = prediction
        else {
            return Ok(Some(EditPredictionResult {
                id,
                prediction: Err(EditPredictionRejectReason::Empty),
            }));
        };

        Ok(Some(
            EditPredictionResult::new(
                id,
                &edited_buffer,
                &edited_buffer_snapshot,
                edits.into(),
                cursor_position,
                buffer_snapshotted_at,
                received_response_at,
                inputs,
                cx,
            )
            .await,
        ))
    })
}
257
258pub fn zeta2_prompt_input(
259 snapshot: &language::BufferSnapshot,
260 related_files: Vec<zeta_prompt::RelatedFile>,
261 events: Vec<Arc<zeta_prompt::Event>>,
262 excerpt_path: Arc<Path>,
263 cursor_offset: usize,
264 zeta_format: ZetaFormat,
265 preferred_model: Option<EditPredictionModelKind>,
266 is_open_source: bool,
267 can_collect_data: bool,
268) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
269 let cursor_point = cursor_offset.to_point(snapshot);
270
271 let (full_context, range_points) = compute_excerpt_ranges(cursor_point, snapshot);
272
273 let related_files = crate::filter_redundant_excerpts(
274 related_files,
275 excerpt_path.as_ref(),
276 full_context.start.row..full_context.end.row,
277 );
278
279 let full_context_start_offset = full_context.start.to_offset(snapshot);
280 let full_context_start_row = full_context.start.row;
281
282 let excerpt_ranges =
283 excerpt_ranges_to_byte_offsets(&range_points, full_context_start_offset, snapshot);
284
285 let editable_range = match preferred_model {
286 Some(EditPredictionModelKind::Zeta1) => &range_points.editable_350,
287 _ => match zeta_format {
288 ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => &range_points.editable_150,
289 _ => &range_points.editable_180,
290 },
291 };
292
293 let editable_offset_range = editable_range.to_offset(snapshot);
294 let cursor_offset_in_excerpt = cursor_offset - full_context_start_offset;
295 let editable_range_in_excerpt = (editable_offset_range.start - full_context_start_offset)
296 ..(editable_offset_range.end - full_context_start_offset);
297
298 let prompt_input = zeta_prompt::ZetaPromptInput {
299 cursor_path: excerpt_path,
300 cursor_excerpt: snapshot
301 .text_for_range(full_context)
302 .collect::<String>()
303 .into(),
304 editable_range_in_excerpt,
305 cursor_offset_in_excerpt,
306 excerpt_start_row: Some(full_context_start_row),
307 events,
308 related_files,
309 excerpt_ranges: Some(excerpt_ranges),
310 preferred_model,
311 in_open_source_repo: is_open_source,
312 can_collect_data,
313 };
314 (editable_offset_range, prompt_input)
315}
316
317pub(crate) fn edit_prediction_accepted(
318 store: &EditPredictionStore,
319 current_prediction: CurrentEditPrediction,
320 cx: &App,
321) {
322 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
323 if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
324 return;
325 }
326
327 let request_id = current_prediction.prediction.id.to_string();
328 let require_auth = custom_accept_url.is_none();
329 let client = store.client.clone();
330 let llm_token = store.llm_token.clone();
331 let app_version = AppVersion::global(cx);
332
333 cx.background_spawn(async move {
334 let url = if let Some(accept_edits_url) = custom_accept_url {
335 gpui::http_client::Url::parse(&accept_edits_url)?
336 } else {
337 client
338 .http_client()
339 .build_zed_llm_url("/predict_edits/accept", &[])?
340 };
341
342 let response = EditPredictionStore::send_api_request::<()>(
343 move |builder| {
344 let req = builder.uri(url.as_ref()).body(
345 serde_json::to_string(&AcceptEditPredictionBody {
346 request_id: request_id.clone(),
347 })?
348 .into(),
349 );
350 Ok(req?)
351 },
352 client,
353 llm_token,
354 app_version,
355 require_auth,
356 )
357 .await;
358
359 response?;
360 anyhow::Ok(())
361 })
362 .detach_and_log_err(cx);
363}