use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction, ExamplePrompt},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();

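/// Runs edit prediction for a single example with the given provider,
/// repeating the request `repetition_count` times and recording each
/// prediction (plus its prompt/response artifacts) under the run directory.
/// Examples that already have predictions from the same provider are
/// skipped; predictions from a different provider are discarded first.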
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = provider.context("provider is required")?;

    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

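    // Teacher providers bypass the EditPredictionStore: the prompt is
    // rendered in the teacher format and sent straight to Anthropic.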
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

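    // All other providers predict against the loaded project and buffer.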
    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

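    // Zeta providers need an authenticated client; sign in once and share
    // the resulting task across every example in this process.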
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx))
        .context("EditPredictionStore not initialized")?;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    });
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

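    // Listen to debug events from the store so that, for each run, the raw
    // prompt and model response are written to disk and the response text is
    // recorded on the in-progress prediction.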
    let mut debug_rx = ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx));
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                            if provider == PredictionProvider::Zeta2 {
                                updated_example.prompt.get_or_insert(ExamplePrompt {
                                    input: prompt,
                                    expected_output: String::new(),
                                    format: PromptFormat::Zeta2,
                                });
                            }
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

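    // Run the prediction `repetition_count` times. Each run gets its own
    // directory (when repeating), a fresh placeholder prediction to fill in,
    // and a "latest" symlink pointing at its output.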
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
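        // Render the prediction as a unified diff against the buffer; an
        // empty patch means the provider produced no prediction.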
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .await?;

        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

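    // Detach the project from the store, wait for the debug listener to
    // drain, and move the accumulated results back into `example`.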
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    });
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

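/// Sends the teacher prompt for `example` to Anthropic and records the
/// response as a prediction. In batched mode the request may only be
/// enqueued for batch processing, in which case no prediction is recorded
/// here (see `sync_batches`).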
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
        let client = if batched {
            AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
        } else {
            AnthropicClient::plain()
        };
        client.expect("Failed to create Anthropic client")
    });

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

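/// Flushes any queued Anthropic batch requests. Only the `Teacher` provider
/// uses batching; for all other providers this is a no-op.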
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let llm_client = ANTHROPIC_CLIENT.get_or_init(|| {
                AnthropicClient::batch(&crate::paths::LLM_CACHE_DB)
                    .expect("Failed to create Anthropic client")
            });
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}