use crate::{
    FormatPromptArgs, PredictArgs, PredictionProvider,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

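/// Runs edit prediction for `example`, dispatching either to the Anthropic
/// "teacher" providers or to a model registered with `EditPredictionStore`.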
pub async fn run_prediction(
    example: &mut Example,
    args: &PredictArgs,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = args.provider;
    let repetition_count = args.repetitions;
    let zeta_version = args.version;

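    // Skip examples that already have a prediction from this provider;
    // predictions from a different provider are discarded and regenerated.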
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

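    // Teacher providers bypass the EditPredictionStore: they format the
    // teacher prompt and send it directly to Anthropic.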
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(
            example,
            &FormatPromptArgs {
                provider,
                version: args.version,
            },
            app_state.clone(),
            cx,
        )
        .await?;

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        step_progress.set_substatus("authenticating");
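        // Sign in at most once per process; the shared task lets concurrent
        // examples await the same authentication attempt.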
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
                version: zeta_version,
            },
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    step_progress.set_substatus("configuring model");
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

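    // Listen for debug events so the raw prompt and model output of each run
    // can be written to the run directory and recorded on the example.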
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if provider == PredictionProvider::Zeta2 {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    provider,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
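        // Repoint the "latest run" symlink at this run's directory.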
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

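        // Render the predicted edits as a unified diff; an empty patch means
        // the model produced no prediction for this run.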
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

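    // The debug task has completed and dropped its clone, so the Arc/Mutex can
    // be unwrapped and the accumulated results written back to the caller.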
    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

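/// Sends the formatted teacher prompt to Anthropic and records the response as
/// a prediction on the example. In batched mode the request may only be queued,
/// in which case no prediction is recorded yet.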
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

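/// Flushes any queued Anthropic batch requests. Only the batching `Teacher`
/// provider has pending batches; for all other providers this is a no-op.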
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}