use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{ExampleProgress, InfoStyle, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;

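// Shared Anthropic client, initialized once and reused by both the teacher
// prediction path and batch syncing.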
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

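/// Runs edit prediction for a single example using the provider selected in
/// `args`, reusing any existing prediction that came from the same provider.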
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    example_progress: &ExampleProgress,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;

    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

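    // Context retrieval runs for every provider before any prediction is made.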
    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;

    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        args.provider
    {
        let _step_progress = example_progress.start(Step::Predict);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            example_progress,
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }

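    // Teacher providers are handled above; everything else predicts against a
    // loaded project through the EditPredictionStore.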
    run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;

    let step_progress = example_progress.start(Step::Predict);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

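    // Mirror debug events from the store into the run directory so each run's
    // prompt and raw model output are saved alongside the example.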
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

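    // Run the requested number of repetitions, giving each its own numbered
    // run directory when repeating.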
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
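        // Point the "latest run" symlink at this run's directory.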
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

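/// Sends the example's formatted teacher prompt to Anthropic and records the
/// resulting prediction. In batched mode the request may only be enqueued, in
/// which case no prediction is recorded yet.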
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(&example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}

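/// Flushes any pending Anthropic batch requests for the batched teacher
/// provider; all other providers are a no-op.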
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher(..) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}