1use crate::{
2 FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
3 anthropic_client::AnthropicClient,
4 example::{Example, ExamplePrediction, ExamplePrompt},
5 format_prompt::{TeacherPrompt, run_format_prompt},
6 headless::EpAppState,
7 load_project::run_load_project,
8 openai_client::OpenAiClient,
9 paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
10 progress::{ExampleProgress, InfoStyle, Step},
11 retrieve_context::run_context_retrieval,
12};
13use anyhow::Context as _;
14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
15use futures::{FutureExt as _, StreamExt as _, future::Shared};
16use gpui::{AppContext as _, AsyncApp, Task};
17use std::{
18 fs,
19 sync::{
20 Arc, Mutex, OnceLock,
21 atomic::{AtomicUsize, Ordering::SeqCst},
22 },
23};
24use zeta_prompt::ZetaFormat;
25
// Process-wide LLM clients, created lazily on first use. Whichever mode
// (batched vs. plain) is requested at first initialization is sticky for the
// rest of the process — later calls with a different `batched` flag reuse the
// already-created client.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
28
/// Runs edit prediction for a single example: reuses a compatible existing
/// prediction when possible, retrieves context, then either delegates to a
/// teacher model or drives the `EditPredictionStore` for
/// `args.repetitions` runs, recording outputs into `example`.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Skip work when the example already has a non-empty prediction: keep it
    // if no provider was requested or the provider matches; otherwise discard
    // all existing predictions and re-predict with the requested provider.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    // Teacher providers bypass the EditPredictionStore entirely: format the
    // prompt, then call the LLM API directly (batched or non-batched).
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(example, backend, batched, repetition_count, args.cache_only).await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers need a signed-in client; authenticate at most once per
    // process and share the task so concurrent examples await the same
    // sign-in instead of racing.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                // Teacher providers returned early above; Repair never reaches
                // this path.
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If user specified a non-default Zeta2 version, configure raw endpoint.
        // ZED_ZETA_MODEL env var is optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // A mutex-guarded copy of the example so the background debug-event task
    // can record outputs while this function drives the runs; moved back into
    // `example` at the end.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            // Persist prompt/response debug artifacts for each run and mirror
            // the raw model output into the shared example copy.
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // With multiple repetitions each run writes into its own
                // zero-padded numbered subdirectory.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The driving loop pushes a placeholder prediction
                        // before each request, so the count tracks run_ix + 1.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `run_ix` is always < repetition_count
                        // (it is stored from the 0..repetition_count loop
                        // below), so this condition never holds and the loop
                        // only exits when the debug stream closes. If an early
                        // break after the final run was intended, this likely
                        // should be `run_ix + 1 >= repetition_count` — TODO
                        // confirm.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        // Repoint the "latest run" symlink at this run's directory.
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Push a placeholder prediction; the debug task fills in the raw
        // output and the code below fills in the patch.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Render the predicted edits as a unified diff, if any were produced.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Only the final repetition's outcome is reported in the progress UI.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Removing the project presumably closes the debug stream so the debug
    // task's receive loop can finish — TODO confirm against
    // EditPredictionStore::remove_project.
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // Move the accumulated results back into the caller's example; by now the
    // debug task has dropped its Arc clone.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
269
270async fn predict_teacher(
271 example: &mut Example,
272 backend: TeacherBackend,
273 batched: bool,
274 repetition_count: usize,
275 cache_only: bool,
276) -> anyhow::Result<()> {
277 match backend {
278 TeacherBackend::Sonnet45 => {
279 predict_anthropic(example, backend, batched, repetition_count, cache_only).await
280 }
281 TeacherBackend::Gpt52 => {
282 predict_openai(example, backend, batched, repetition_count, cache_only).await
283 }
284 }
285}
286
287async fn predict_anthropic(
288 example: &mut Example,
289 backend: TeacherBackend,
290 batched: bool,
291 repetition_count: usize,
292 cache_only: bool,
293) -> anyhow::Result<()> {
294 let llm_model_name = backend.model_name();
295 let max_tokens = 16384;
296 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
297 let client = if batched {
298 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
299 } else {
300 AnthropicClient::plain()
301 };
302 client.expect("Failed to create Anthropic client")
303 });
304
305 let prompt = example.prompt.as_ref().context("Prompt is required")?;
306
307 for ix in 0..repetition_count {
308 let messages = vec![anthropic::Message {
309 role: anthropic::Role::User,
310 content: vec![anthropic::RequestContent::Text {
311 text: prompt.input.clone(),
312 cache_control: None,
313 }],
314 }];
315
316 let seed = if repetition_count > 1 { Some(ix) } else { None };
317 let Some(response) = llm_client
318 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
319 .await?
320 else {
321 // Request stashed for batched processing
322 return Ok(());
323 };
324
325 let actual_output = response
326 .content
327 .into_iter()
328 .filter_map(|content| match content {
329 anthropic::ResponseContent::Text { text } => Some(text),
330 _ => None,
331 })
332 .collect::<Vec<String>>()
333 .join("\n");
334
335 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
336
337 let prediction = ExamplePrediction {
338 actual_patch: Some(actual_patch),
339 actual_output,
340 actual_cursor,
341 error: None,
342 provider: if batched {
343 PredictionProvider::Teacher(backend)
344 } else {
345 PredictionProvider::TeacherNonBatching(backend)
346 },
347 };
348
349 example.predictions.push(prediction);
350 }
351 Ok(())
352}
353
354async fn predict_openai(
355 example: &mut Example,
356 backend: TeacherBackend,
357 batched: bool,
358 repetition_count: usize,
359 cache_only: bool,
360) -> anyhow::Result<()> {
361 let llm_model_name = backend.model_name();
362 let max_tokens = 16384;
363 let llm_client = OPENAI_CLIENT.get_or_init(|| {
364 let client = if batched {
365 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
366 } else {
367 OpenAiClient::plain()
368 };
369 client.expect("Failed to create OpenAI client")
370 });
371
372 let prompt = example.prompt.as_ref().context("Prompt is required")?;
373
374 for ix in 0..repetition_count {
375 let messages = vec![open_ai::RequestMessage::User {
376 content: open_ai::MessageContent::Plain(prompt.input.clone()),
377 }];
378
379 let seed = if repetition_count > 1 { Some(ix) } else { None };
380 let Some(response) = llm_client
381 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
382 .await?
383 else {
384 // Request stashed for batched processing
385 return Ok(());
386 };
387
388 let actual_output = response
389 .choices
390 .into_iter()
391 .filter_map(|choice| match choice.message {
392 open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
393 open_ai::MessageContent::Plain(text) => text,
394 open_ai::MessageContent::Multipart(parts) => parts
395 .into_iter()
396 .filter_map(|p| match p {
397 open_ai::MessagePart::Text { text } => Some(text),
398 _ => None,
399 })
400 .collect::<Vec<_>>()
401 .join(""),
402 }),
403 _ => None,
404 })
405 .collect::<Vec<String>>()
406 .join("\n");
407
408 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
409
410 let prediction = ExamplePrediction {
411 actual_patch: Some(actual_patch),
412 actual_output,
413 actual_cursor,
414 error: None,
415 provider: if batched {
416 PredictionProvider::Teacher(backend)
417 } else {
418 PredictionProvider::TeacherNonBatching(backend)
419 },
420 };
421
422 example.predictions.push(prediction);
423 }
424 Ok(())
425}
426
427pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
428 match provider {
429 Some(PredictionProvider::Teacher(backend)) => match backend {
430 TeacherBackend::Sonnet45 => {
431 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
432 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
433 .expect("Failed to create Anthropic client")
434 });
435 llm_client
436 .sync_batches()
437 .await
438 .context("Failed to sync Anthropic batches")?;
439 }
440 TeacherBackend::Gpt52 => {
441 let llm_client = OPENAI_CLIENT.get_or_init(|| {
442 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
443 .expect("Failed to create OpenAI client")
444 });
445 llm_client
446 .sync_batches()
447 .await
448 .context("Failed to sync OpenAI batches")?;
449 }
450 },
451 _ => (),
452 };
453 Ok(())
454}