use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};
use zeta_prompt::ZetaVersion;
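
// Process-wide Anthropic client, initialized lazily on first use and shared by
// teacher predictions and batch syncing.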
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
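
/// Runs a prediction for `example` using the provider selected in `args`.
/// Prompts, raw model output, and the resulting patches are recorded on the
/// example and written into the run directory.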
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;
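
    // If the example already has a prediction from the requested provider there
    // is nothing to do; predictions from a different provider are discarded.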
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }
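
    // Retrieve context for the example before any prompt is built.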
    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
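
    // Teacher providers bypass the local EditPredictionStore entirely: the
    // prompt is formatted here and sent directly to Anthropic.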
    if let PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) =
        args.provider
    {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(
            example,
            &FormatPromptArgs { provider },
            app_state.clone(),
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher(..));
        return predict_anthropic(example, repetition_count, version, batched).await;
    }
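
    // All remaining providers run through the local EditPredictionStore, which
    // needs the example's project loaded into the headless app first.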
    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);
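
    // Zeta predictions need a signed-in client; authenticate once and share the
    // task so every subsequent example awaits the same sign-in.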
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2(_)
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }
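
    // Resolve the global store and select the model backing the requested
    // provider. Teacher providers were handled above, hence `unreachable!()`.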
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))?
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2(version) => {
                edit_prediction::EditPredictionModel::Zeta2 { version }
            }
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher(..) | PredictionProvider::TeacherNonBatching(..) => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
105 step_progress.set_substatus("configuring model");
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));
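
    // Capture debug events in the background so each run's prompt and raw model
    // response can be written out and stored on the example as they arrive.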
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if matches!(provider, PredictionProvider::Zeta2(_)) {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });
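
    // Request the prediction once per repetition. Each run gets its own
    // zero-padded subdirectory when repeating, and LATEST_EXAMPLE_RUN_DIR is
    // re-pointed at the directory for the current run.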
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
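
        // Push an empty prediction up front; the debug task fills in the raw
        // output and the code below fills in the patch for this run.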
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;
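
        // Render the predicted edits as a unified diff. An empty patch means
        // the provider produced no prediction for this run.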
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }
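
    // Detach the project from the store, let the debug task drain its remaining
    // events, then move the accumulated results back into `example`.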
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}
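
/// Sends the formatted teacher prompt to Anthropic and records the response as
/// a prediction. In batched mode a request may be stashed instead of executed
/// immediately, in which case nothing is recorded until batches are synced.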
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    version: ZetaVersion,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };
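
    // Concatenate the text blocks of the response and parse the teacher output
    // into a patch.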
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: if batched {
            PredictionProvider::Teacher(version)
        } else {
            PredictionProvider::TeacherNonBatching(version)
        },
    };

    example.predictions.push(prediction);
    Ok(())
}
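
/// Flushes pending Anthropic batch requests. Only the batched `Teacher`
/// provider has batches to sync; every other provider is a no-op.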
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher(..) => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}