1use crate::{
2 FormatPromptArgs, PredictArgs, PredictionProvider, TeacherBackend,
3 anthropic_client::AnthropicClient,
4 example::{Example, ExamplePrediction, ExamplePrompt},
5 format_prompt::{TeacherPrompt, run_format_prompt},
6 headless::EpAppState,
7 load_project::run_load_project,
8 openai_client::OpenAiClient,
9 paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
10 progress::{ExampleProgress, InfoStyle, Step},
11 retrieve_context::run_context_retrieval,
12};
13use anyhow::Context as _;
14use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
15use futures::{FutureExt as _, StreamExt as _, future::Shared};
16use gpui::{AppContext as _, AsyncApp, Task};
17use std::{
18 fs,
19 sync::{
20 Arc, Mutex, OnceLock,
21 atomic::{AtomicUsize, Ordering::SeqCst},
22 },
23};
24use zeta_prompt::ZetaFormat;
25
// Process-wide LLM clients, created lazily on first use. Whichever call site
// initializes one first decides its mode (batched vs. plain) for the rest of
// the process — see `predict_anthropic` / `predict_openai` / `sync_batches`.
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
28
/// Runs edit prediction for `example`, appending one `ExamplePrediction` per
/// repetition and writing prompt/response artifacts under the run directory.
///
/// Existing predictions are reused when they match the requested provider.
/// Teacher providers are routed to `predict_teacher`; all other providers go
/// through the shared `EditPredictionStore`.
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let repetition_count = args.repetitions;

    // Reuse cached predictions: keep them when no provider was requested or
    // the cached provider matches; otherwise clear and re-predict.
    if let Some(existing_prediction) = example.predictions.first() {
        let has_prediction = existing_prediction.actual_patch.is_some()
            || !existing_prediction.actual_output.is_empty();
        if has_prediction {
            match args.provider {
                None => return Ok(()),
                Some(provider) if existing_prediction.provider == provider => return Ok(()),
                Some(_) => example.predictions.clear(),
            }
        }
    }

    let Some(provider) = args.provider else {
        anyhow::bail!(
            "No existing predictions found. Use --provider to specify which model to use for prediction."
        );
    };

    // Teacher providers bypass the EditPredictionStore entirely: format the
    // prompt here and call the LLM backend directly.
    if let PredictionProvider::Teacher(backend) | PredictionProvider::TeacherNonBatching(backend) =
        provider
    {
        run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let step_progress = example_progress.start(Step::Predict);
        // `Teacher` uses the batching client; `TeacherNonBatching` does not.
        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_teacher(
            example,
            backend,
            batched,
            repetition_count,
            args.cache_only,
            &step_progress,
        )
        .await;
    }

    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    // Zeta providers require a signed-in client. Sign in at most once per
    // process; concurrent callers await the same shared task.
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        // Only logged here; any resulting failure surfaces
                        // from the prediction request below.
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..)
            | PredictionProvider::TeacherNonBatching(..)
            | PredictionProvider::Repair => {
                // Teacher providers returned early above; Repair is not
                // handled by this code path.
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);

        // If user specified a non-default Zeta2 version, configure raw endpoint.
        // ZED_ZETA_MODEL env var is optional.
        if let PredictionProvider::Zeta2(format) = provider {
            if format != ZetaFormat::default() {
                let model_id = std::env::var("ZED_ZETA_MODEL").ok();
                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
            }
        }
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    // Shared with the debug listener task below, which fills in the prompt
    // and raw model output as debug events arrive.
    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                // Per-repetition subdirectory; mirrors the layout chosen in
                // the request loop below.
                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        // The request loop pushes a placeholder prediction
                        // before requesting, so one must exist for this run.
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    rejected_output: None,
                                    provider,
                                    prefill: None,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // NOTE(review): `run_ix` is stored from the
                        // `0..repetition_count` loop below, so for
                        // repetitions >= 1 it never reaches
                        // `repetition_count` and this break never fires;
                        // the task instead ends when the debug channel
                        // closes. Confirm whether
                        // `run_ix + 1 >= repetition_count` was intended.
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        if repetition_count > 1 {
            step_progress.set_substatus(format!(
                "running prediction {}/{}",
                ix + 1,
                repetition_count
            ));
        } else {
            step_progress.set_substatus("running prediction");
        }

        // Repoint the "latest example run" symlink at this run's directory.
        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        // Placeholder entry for this run; the debug task and the code below
        // fill in `actual_output` / `actual_patch`.
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: None,
                actual_output: String::new(),
                actual_cursor: None,
                error: None,
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        // Render predicted edits as a unified diff; `None` means no usable
        // prediction was produced.
        let actual_patch = prediction.and_then(|prediction| {
            let prediction = prediction.prediction.ok()?;
            prediction
                .edit_preview
                .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
        });

        let has_prediction = actual_patch.as_ref().is_some_and(|p| !p.is_empty());

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        // Only the final repetition's outcome is reported in the progress UI.
        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    // Removing the project presumably closes the debug event stream so the
    // listener task's `while let` loop can end — TODO confirm, since the
    // loop's own break condition never fires (see NOTE above).
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    // The debug task has dropped its Arc clone by now, so the accumulated
    // example can be moved back out.
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
286
287async fn predict_teacher(
288 example: &mut Example,
289 backend: TeacherBackend,
290 batched: bool,
291 repetition_count: usize,
292 cache_only: bool,
293 step_progress: &crate::progress::StepProgress,
294) -> anyhow::Result<()> {
295 match backend {
296 TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
297 predict_anthropic(
298 example,
299 backend,
300 batched,
301 repetition_count,
302 cache_only,
303 step_progress,
304 )
305 .await
306 }
307 TeacherBackend::Gpt52 => {
308 predict_openai(
309 example,
310 backend,
311 batched,
312 repetition_count,
313 cache_only,
314 step_progress,
315 )
316 .await
317 }
318 }
319}
320
321async fn predict_anthropic(
322 example: &mut Example,
323 backend: TeacherBackend,
324 batched: bool,
325 repetition_count: usize,
326 cache_only: bool,
327 step_progress: &crate::progress::StepProgress,
328) -> anyhow::Result<()> {
329 let llm_model_name = backend.model_name();
330 let max_tokens = 16384;
331 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
332 let client = if batched {
333 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
334 } else {
335 AnthropicClient::plain()
336 };
337 client.expect("Failed to create Anthropic client")
338 });
339
340 let prompt = example.prompt.as_ref().context("Prompt is required")?;
341
342 for ix in 0..repetition_count {
343 if repetition_count > 1 {
344 step_progress.set_substatus(format!(
345 "running prediction {}/{}",
346 ix + 1,
347 repetition_count
348 ));
349 } else {
350 step_progress.set_substatus("running prediction");
351 }
352
353 let messages = vec![anthropic::Message {
354 role: anthropic::Role::User,
355 content: vec![anthropic::RequestContent::Text {
356 text: prompt.input.clone(),
357 cache_control: None,
358 }],
359 }];
360
361 let seed = if repetition_count > 1 { Some(ix) } else { None };
362 let Some(response) = llm_client
363 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
364 .await?
365 else {
366 // Request stashed for batched processing
367 return Ok(());
368 };
369
370 let actual_output = response
371 .content
372 .into_iter()
373 .filter_map(|content| match content {
374 anthropic::ResponseContent::Text { text } => Some(text),
375 _ => None,
376 })
377 .collect::<Vec<String>>()
378 .join("\n");
379
380 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
381
382 let prediction = ExamplePrediction {
383 actual_patch: Some(actual_patch),
384 actual_output,
385 actual_cursor,
386 error: None,
387 provider: if batched {
388 PredictionProvider::Teacher(backend)
389 } else {
390 PredictionProvider::TeacherNonBatching(backend)
391 },
392 };
393
394 example.predictions.push(prediction);
395 }
396 Ok(())
397}
398
399async fn predict_openai(
400 example: &mut Example,
401 backend: TeacherBackend,
402 batched: bool,
403 repetition_count: usize,
404 cache_only: bool,
405 step_progress: &crate::progress::StepProgress,
406) -> anyhow::Result<()> {
407 let llm_model_name = backend.model_name();
408 let max_tokens = 16384;
409 let llm_client = OPENAI_CLIENT.get_or_init(|| {
410 let client = if batched {
411 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
412 } else {
413 OpenAiClient::plain()
414 };
415 client.expect("Failed to create OpenAI client")
416 });
417
418 let prompt = example.prompt.as_ref().context("Prompt is required")?;
419
420 for ix in 0..repetition_count {
421 if repetition_count > 1 {
422 step_progress.set_substatus(format!(
423 "running prediction {}/{}",
424 ix + 1,
425 repetition_count
426 ));
427 } else {
428 step_progress.set_substatus("running prediction");
429 }
430
431 let messages = vec![open_ai::RequestMessage::User {
432 content: open_ai::MessageContent::Plain(prompt.input.clone()),
433 }];
434
435 let seed = if repetition_count > 1 { Some(ix) } else { None };
436 let Some(response) = llm_client
437 .generate(llm_model_name, max_tokens, messages, seed, cache_only)
438 .await?
439 else {
440 // Request stashed for batched processing
441 return Ok(());
442 };
443
444 let actual_output = response
445 .choices
446 .into_iter()
447 .filter_map(|choice| match choice.message {
448 open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
449 open_ai::MessageContent::Plain(text) => text,
450 open_ai::MessageContent::Multipart(parts) => parts
451 .into_iter()
452 .filter_map(|p| match p {
453 open_ai::MessagePart::Text { text } => Some(text),
454 _ => None,
455 })
456 .collect::<Vec<_>>()
457 .join(""),
458 }),
459 _ => None,
460 })
461 .collect::<Vec<String>>()
462 .join("\n");
463
464 let (actual_patch, actual_cursor) = TeacherPrompt::parse(example, &actual_output)?;
465
466 let prediction = ExamplePrediction {
467 actual_patch: Some(actual_patch),
468 actual_output,
469 actual_cursor,
470 error: None,
471 provider: if batched {
472 PredictionProvider::Teacher(backend)
473 } else {
474 PredictionProvider::TeacherNonBatching(backend)
475 },
476 };
477
478 example.predictions.push(prediction);
479 }
480 Ok(())
481}
482
483pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
484 match provider {
485 Some(PredictionProvider::Teacher(backend)) => match backend {
486 TeacherBackend::Sonnet45 | TeacherBackend::Sonnet46 => {
487 let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
488 AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
489 .expect("Failed to create Anthropic client")
490 });
491 llm_client
492 .sync_batches()
493 .await
494 .context("Failed to sync Anthropic batches")?;
495 }
496 TeacherBackend::Gpt52 => {
497 let llm_client = OPENAI_CLIENT.get_or_init(|| {
498 OpenAiClient::batch(&crate::paths::LLM_CACHE_DB)
499 .expect("Failed to create OpenAI client")
500 });
501 llm_client
502 .sync_batches()
503 .await
504 .context("Failed to sync OpenAI batches")?;
505 }
506 },
507 _ => (),
508 };
509 Ok(())
510}