1use crate::prediction::EditPredictionResult;
2use crate::{
3 CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
4 EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
5 EditPredictionStore,
6};
7use anyhow::{Result, anyhow};
8use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
9use cloud_llm_client::{AcceptEditPredictionBody, EditPredictionRejectReason};
10use gpui::{App, Task, prelude::*};
11use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
12use release_channel::AppVersion;
13
14use std::env;
15use std::{path::Path, sync::Arc, time::Instant};
16use zeta_prompt::format_zeta_prompt;
17use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
18
19pub const MAX_CONTEXT_TOKENS: usize = 350;
20
21pub fn max_editable_tokens(version: ZetaVersion) -> usize {
22 match version {
23 ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
24 ZetaVersion::V0114180EditableRegion | ZetaVersion::V0120GitMergeMarkers => 180,
25 }
26}
27
28pub fn request_prediction_with_zeta2(
29 store: &mut EditPredictionStore,
30 EditPredictionModelInput {
31 buffer,
32 snapshot,
33 position,
34 related_files,
35 events,
36 debug_tx,
37 ..
38 }: EditPredictionModelInput,
39 zeta_version: ZetaVersion,
40 cx: &mut Context<EditPredictionStore>,
41) -> Task<Result<Option<EditPredictionResult>>> {
42 let buffer_snapshotted_at = Instant::now();
43 let custom_url = store.custom_predict_edits_url.clone();
44
45 let Some(excerpt_path) = snapshot
46 .file()
47 .map(|file| -> Arc<Path> { file.full_path(cx).into() })
48 else {
49 return Task::ready(Err(anyhow!("No file path for excerpt")));
50 };
51
52 let client = store.client.clone();
53 let llm_token = store.llm_token.clone();
54 let app_version = AppVersion::global(cx);
55
56 let request_task = cx.background_spawn({
57 async move {
58 let cursor_offset = position.to_offset(&snapshot);
59 let (editable_offset_range, prompt_input) = zeta2_prompt_input(
60 &snapshot,
61 related_files,
62 events,
63 excerpt_path,
64 cursor_offset,
65 zeta_version,
66 );
67
68 if let Some(debug_tx) = &debug_tx {
69 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
70 debug_tx
71 .unbounded_send(DebugEvent::EditPredictionStarted(
72 EditPredictionStartedDebugEvent {
73 buffer: buffer.downgrade(),
74 prompt: Some(prompt),
75 position,
76 },
77 ))
78 .ok();
79 }
80
81 log::trace!("Sending edit prediction request");
82
83 let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
84 // Use raw endpoint with custom URL
85 let prompt = format_zeta_prompt(&prompt_input, zeta_version);
86 let request = RawCompletionRequest {
87 model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
88 prompt,
89 temperature: None,
90 stop: vec![],
91 max_tokens: Some(2048),
92 };
93
94 let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
95 request,
96 client,
97 Some(custom_url),
98 llm_token,
99 app_version,
100 )
101 .await?;
102
103 let request_id = EditPredictionId(response.id.clone().into());
104 let output_text = response.choices.pop().map(|choice| choice.text);
105 (request_id, output_text, usage)
106 } else {
107 let (response, usage) = EditPredictionStore::send_v3_request(
108 prompt_input.clone(),
109 zeta_version,
110 client,
111 llm_token,
112 app_version,
113 )
114 .await?;
115
116 let request_id = EditPredictionId(response.request_id.into());
117 let output_text = if response.output.is_empty() {
118 None
119 } else {
120 Some(response.output)
121 };
122 (request_id, output_text, usage)
123 };
124
125 let received_response_at = Instant::now();
126
127 log::trace!("Got edit prediction response");
128
129 let Some(mut output_text) = output_text else {
130 return Ok((Some((request_id, None)), usage));
131 };
132
133 if let Some(debug_tx) = &debug_tx {
134 debug_tx
135 .unbounded_send(DebugEvent::EditPredictionFinished(
136 EditPredictionFinishedDebugEvent {
137 buffer: buffer.downgrade(),
138 position,
139 model_output: Some(output_text.clone()),
140 },
141 ))
142 .ok();
143 }
144
145 if output_text.contains(CURSOR_MARKER) {
146 log::trace!("Stripping out {CURSOR_MARKER} from response");
147 output_text = output_text.replace(CURSOR_MARKER, "");
148 }
149
150 if zeta_version == ZetaVersion::V0120GitMergeMarkers {
151 if let Some(stripped) =
152 output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
153 {
154 output_text = stripped.to_string();
155 }
156 }
157
158 let mut old_text = snapshot
159 .text_for_range(editable_offset_range.clone())
160 .collect::<String>();
161
162 if !output_text.is_empty() && !output_text.ends_with('\n') {
163 output_text.push('\n');
164 }
165 if !old_text.is_empty() && !old_text.ends_with('\n') {
166 old_text.push('\n');
167 }
168
169 let edits: Vec<_> = language::text_diff(&old_text, &output_text)
170 .into_iter()
171 .map(|(range, text)| {
172 (
173 snapshot.anchor_after(editable_offset_range.start + range.start)
174 ..snapshot.anchor_before(editable_offset_range.start + range.end),
175 text,
176 )
177 })
178 .collect();
179
180 anyhow::Ok((
181 Some((
182 request_id,
183 Some((
184 prompt_input,
185 buffer,
186 snapshot.clone(),
187 edits,
188 received_response_at,
189 )),
190 )),
191 usage,
192 ))
193 }
194 });
195
196 cx.spawn(async move |this, cx| {
197 let Some((id, prediction)) =
198 EditPredictionStore::handle_api_response(&this, request_task.await, cx)?
199 else {
200 return Ok(None);
201 };
202
203 let Some((inputs, edited_buffer, edited_buffer_snapshot, edits, received_response_at)) =
204 prediction
205 else {
206 return Ok(Some(EditPredictionResult {
207 id,
208 prediction: Err(EditPredictionRejectReason::Empty),
209 }));
210 };
211
212 Ok(Some(
213 EditPredictionResult::new(
214 id,
215 &edited_buffer,
216 &edited_buffer_snapshot,
217 edits.into(),
218 buffer_snapshotted_at,
219 received_response_at,
220 inputs,
221 cx,
222 )
223 .await,
224 ))
225 })
226}
227
228pub fn zeta2_prompt_input(
229 snapshot: &language::BufferSnapshot,
230 related_files: Vec<zeta_prompt::RelatedFile>,
231 events: Vec<Arc<zeta_prompt::Event>>,
232 excerpt_path: Arc<Path>,
233 cursor_offset: usize,
234 zeta_version: ZetaVersion,
235) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
236 let cursor_point = cursor_offset.to_point(snapshot);
237
238 let (editable_range, context_range) =
239 crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
240 cursor_point,
241 snapshot,
242 max_editable_tokens(zeta_version),
243 MAX_CONTEXT_TOKENS,
244 );
245
246 let related_files = crate::filter_redundant_excerpts(
247 related_files,
248 excerpt_path.as_ref(),
249 context_range.start.row..context_range.end.row,
250 );
251
252 let context_start_offset = context_range.start.to_offset(snapshot);
253 let editable_offset_range = editable_range.to_offset(snapshot);
254 let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
255 let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
256 ..(editable_offset_range.end - context_start_offset);
257
258 let prompt_input = zeta_prompt::ZetaPromptInput {
259 cursor_path: excerpt_path,
260 cursor_excerpt: snapshot
261 .text_for_range(context_range)
262 .collect::<String>()
263 .into(),
264 editable_range_in_excerpt,
265 cursor_offset_in_excerpt,
266 events,
267 related_files,
268 };
269 (editable_offset_range, prompt_input)
270}
271
272pub(crate) fn edit_prediction_accepted(
273 store: &EditPredictionStore,
274 current_prediction: CurrentEditPrediction,
275 cx: &App,
276) {
277 let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
278 if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
279 return;
280 }
281
282 let request_id = current_prediction.prediction.id.to_string();
283 let require_auth = custom_accept_url.is_none();
284 let client = store.client.clone();
285 let llm_token = store.llm_token.clone();
286 let app_version = AppVersion::global(cx);
287
288 cx.background_spawn(async move {
289 let url = if let Some(accept_edits_url) = custom_accept_url {
290 gpui::http_client::Url::parse(&accept_edits_url)?
291 } else {
292 client
293 .http_client()
294 .build_zed_llm_url("/predict_edits/accept", &[])?
295 };
296
297 let response = EditPredictionStore::send_api_request::<()>(
298 move |builder| {
299 let req = builder.uri(url.as_ref()).body(
300 serde_json::to_string(&AcceptEditPredictionBody {
301 request_id: request_id.clone(),
302 })?
303 .into(),
304 );
305 Ok(req?)
306 },
307 client,
308 llm_token,
309 app_version,
310 require_auth,
311 )
312 .await;
313
314 response?;
315 anyhow::Ok(())
316 })
317 .detach_and_log_err(cx);
318}