1use crate::prediction::EditPredictionResult;
2use crate::zeta1::compute_edits_and_cursor_position;
3use crate::{
4 CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
5 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
6 EditPredictionStore,
7};
8use anyhow::{Result, anyhow};
9use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
10use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
11use gpui::{App, Task, prelude::*};
12use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
13use release_channel::AppVersion;
14
15use std::env;
16use std::{path::Path, sync::Arc, time::Instant};
17use zeta_prompt::format_zeta_prompt;
18use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
19
/// Context token limit handed to
/// `cursor_excerpt::editable_and_context_ranges_for_cursor_position` by
/// [`zeta2_prompt_input`], alongside the per-version editable token limit.
pub const MAX_CONTEXT_TOKENS: usize = 350;
21
22pub fn max_editable_tokens(version: ZetaVersion) -> usize {
23 match version {
24 ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
25 ZetaVersion::V0114180EditableRegion => 180,
26 ZetaVersion::V0120GitMergeMarkers => 180,
27 ZetaVersion::V0131GitMergeMarkersPrefix => 180,
28 }
29}
30
31pub fn request_prediction_with_zeta2(
32 store: &mut EditPredictionStore,
33 EditPredictionModelInput {
34 buffer,
35 snapshot,
36 position,
37 related_files,
38 events,
39 debug_tx,
40 trigger,
41 ..
42 }: EditPredictionModelInput,
43 zeta_version: ZetaVersion,
44 cx: &mut Context<EditPredictionStore>,
45) -> Task<Result<Option<EditPredictionResult>>> {
46 let buffer_snapshotted_at = Instant::now();
47 let custom_url = store.custom_predict_edits_url.clone();
48
49 let Some(excerpt_path) = snapshot
50 .file()
51 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
52 else {
53 return Task::ready(Err(anyhow!("No file path for excerpt")));
54 };
55
56 let client = store.client.clone();
57 let llm_token = store.llm_token.clone();
58 let app_version = AppVersion::global(cx);
59
60 let request_task = cx.background_spawn({
61 async move {
62 let cursor_offset = position.to_offset(&snapshot);
63 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
64 &snapshot,
65 related_files,
66 events,
67 excerpt_path,
68 cursor_offset,
69 zeta_version,
70 );
71
72 if let Some(debug_tx) = &debug_tx {
73 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
74 debug_tx
75 .unbounded_send(DebugEvent::EditPredictionStarted(
76 EditPredictionStartedDebugEvent {
77 buffer: buffer.downgrade(),
78 prompt: Some(prompt),
79 position,
80 },
81 ))
82 .ok();
83 }
84
85 log::trace!("Sending edit prediction request");
86
87 let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
88 // Use raw endpoint with custom URL
89 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
90 let request = RawCompletionRequest {
91 model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
92 prompt,
93 temperature: None,
94 stop: vec![],
95 max_tokens: Some(2048),
96 };
97
98 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
99 request,
100 client,
101 Some(custom_url),
102 llm_token,
103 app_version,
104 )
105 .await?;
106
107 let request_id = EditPredictionId(response.id.clone().into());
108 let output_text = response.choices.pop().map(|choice| choice.text);
109 (request_id, output_text, usage)
110 } else {
111 let (response, usage) = EditPredictionStore::send_v3_request(
112 prompt_input.clone(),
113 zeta_version,
114 client,
115 llm_token,
116 app_version,
117 trigger,
118 )
119 .await?;
120
121 let request_id = EditPredictionId(response.request_id.into());
122 let output_text = if response.output.is_empty() {
123 None
124 } else {
125 Some(response.output)
126 };
127 (request_id, output_text, usage)
128 };
129
130 let received_response_at = Instant::now();
131
132 log::trace!("Got edit prediction response");
133
134 let Some(mut output_text) = output_text else {
135 return Ok((Some((request_id, None)), usage));
136 };
137
138 if let Some(debug_tx) = &debug_tx {
139 debug_tx
140 .unbounded_send(DebugEvent::EditPredictionFinished(
141 EditPredictionFinishedDebugEvent {
142 buffer: buffer.downgrade(),
143 position,
144 model_output: Some(output_text.clone()),
145 },
146 ))
147 .ok();
148 }
149
150 let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
151 if let Some(offset) = cursor_offset_in_output {
152 log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
153 output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
154 }
155
156 if zeta_version == ZetaVersion::V0120GitMergeMarkers {
157 if let Some(stripped) =
158 output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
159 {
160 output_text = stripped.to_string();
161 }
162 }
163
164 let mut old_text = snapshot
165 .text_for_range(editable_offset_range.clone())
166 .collect::<String>();
167
168 if !output_text.is_empty() && !output_text.ends_with('\n') {
169 output_text.push('\n');
170 }
171 if !old_text.is_empty() && !old_text.ends_with('\n') {
172 old_text.push('\n');
173 }
174
175 let (edits, cursor_position) = compute_edits_and_cursor_position(
176 old_text,
177 &output_text,
178 editable_offset_range.start,
179 cursor_offset_in_output,
180 &snapshot,
181 );
182
183 anyhow::Ok((
184 Some((
185 request_id,
186 Some((
187 prompt_input,
188 buffer,
189 snapshot.clone(),
190 edits,
191 cursor_position,
192 received_response_at,
193 )),
194 )),
195 usage,
196 ))
197 }
198 });
199
200 cx.spawn(async move |this, cx| {
201 let Some((id, prediction)) =
202 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
203 else {
204 return Ok(None);
205 };
206
207 let Some((
208 inputs,
209 edited_buffer,
210 edited_buffer_snapshot,
211 edits,
212 cursor_position,
213 received_response_at,
214 )) = prediction
215 else {
216 return Ok(Some(EditPredictionResult {
217 id,
218 prediction: Err(EditPredictionRejectReason::Empty),
219 }));
220 };
221
222 Ok(Some(
223 EditPredictionResult::new(
224 id,
225 &edited_buffer,
226 &edited_buffer_snapshot,
227 edits.into(),
228 cursor_position,
229 buffer_snapshotted_at,
230 received_response_at,
231 inputs,
232 cx,
233 )
234 .await,
235 ))
236 })
237}
238
239pub fn zeta2_prompt_input(
240 snapshot: &language::BufferSnapshot,
241 related_files: Vec<zeta_prompt::RelatedFile>,
242 events: Vec<Arc<zeta_prompt::Event>>,
243 excerpt_path: Arc<Path>,
244 cursor_offset: usize,
245 zeta_version: ZetaVersion,
246) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
247 let cursor_point = cursor_offset.to_point(snapshot);
248
249 let (editable_range, context_range) =
250 crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
251 cursor_point,
252 snapshot,
253 max_editable_tokens(zeta_version),
254 MAX_CONTEXT_TOKENS,
255 );
256
257 let related_files = crate::filter_redundant_excerpts(
258 related_files,
259 excerpt_path.as_ref(),
260 context_range.start.row..context_range.end.row,
261 );
262
263 let context_start_offset = context_range.start.to_offset(snapshot);
264 let context_start_row = context_range.start.row;
265 let editable_offset_range = editable_range.to_offset(snapshot);
266 let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
267 let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
268 ..(editable_offset_range.end - context_start_offset);
269
270 let prompt_input = zeta_prompt::ZetaPromptInput {
271 cursor_path: excerpt_path,
272 cursor_excerpt: snapshot
273 .text_for_range(context_range)
274 .collect::<String>()
275 .into(),
276 editable_range_in_excerpt,
277 cursor_offset_in_excerpt,
278 excerpt_start_row: Some(context_start_row),
279 events,
280 related_files,
281 };
282 (editable_offset_range, prompt_input)
283}
284
285pub(crate) fn edit_prediction_accepted(
286 store: &EditPredictionStore,
287 current_prediction: CurrentEditPrediction,
288 cx: &App,
289) {
290 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
291 if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
292 return;
293 }
294
295 let request_id = current_prediction.prediction.id.to_string();
296 let require_auth = custom_accept_url.is_none();
297 let client = store.client.clone();
298 let llm_token = store.llm_token.clone();
299 let app_version = AppVersion::global(cx);
300
301 cx.background_spawn(async move {
302 let url = if let Some(accept_edits_url) = custom_accept_url {
303 gpui::http_client::Url::parse(&accept_edits_url)?
304 } else {
305 client
306 .http_client()
307 .build_zed_llm_url("/predict_edits/accept", &[])?
308 };
309
310 let response = EditPredictionStore::send_api_request::<()>(
311 move |builder| {
312 let req = builder.uri(url.as_ref()).body(
313 serde_json::to_string(&AcceptEditPredictionBody {
314 request_id: request_id.clone(),
315 })?
316 .into(),
317 );
318 Ok(req?)
319 },
320 client,
321 llm_token,
322 app_version,
323 require_auth,
324 )
325 .await;
326
327 response?;
328 anyhow::Ok(())
329 })
330 .detach_and_log_err(cx);
331}