use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    progress::{InfoStyle, Progress, Step},
    retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

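/// Runs edit prediction for `example` using the given `provider`, repeating the
/// request `repetition_count` times and recording each result on the example.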
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) -> anyhow::Result<()> {
    let provider = provider.context("provider is required")?;

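    // Skip examples that already have a prediction from this provider; discard
    // stale predictions made by a different provider.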
    if let Some(existing_prediction) = example.predictions.first() {
        if existing_prediction.provider == provider {
            return Ok(());
        } else {
            example.predictions.clear();
        }
    }

    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;

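    // Teacher providers bypass the EditPredictionStore and query Anthropic
    // directly with the formatted teacher prompt.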
    if matches!(
        provider,
        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
    ) {
        let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);

        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
        }

        let batched = matches!(provider, PredictionProvider::Teacher);
        return predict_anthropic(example, repetition_count, batched).await;
    }

    run_load_project(example, app_state.clone(), cx.clone()).await?;

    let step_progress = Progress::global().start(Step::Predict, &example.spec.name);

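    // Zeta providers require an authenticated client; sign in once and share
    // the task across every example in this process.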
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        step_progress.set_substatus("authenticating");
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
                        eprintln!("Authentication failed: {}", e);
                    }
                })
                .shared()
            })
            .clone()
            .await;
    }

    let ep_store = cx.update(|cx| {
        EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
    })??;

    step_progress.set_substatus("configuring model");
    ep_store.update(&mut cx, |store, _cx| {
        let model = match provider {
            PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
            PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
            PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
            PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
            PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
                unreachable!()
            }
        };
        store.set_edit_prediction_model(model);
    })?;
    let state = example.state.as_ref().context("state must be set")?;
    let run_dir = RUN_DIR.join(&example.spec.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

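    // Listen for debug events from the store so each run's prompt and raw model
    // output get written into the run directory as they arrive.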
    let mut debug_rx =
        ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished.
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

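    // Run the prediction `repetition_count` times; with more than one
    // repetition each run writes into its own numbered subdirectory.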
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir)?;
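        // Repoint the "latest run" symlink at this run's directory.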
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;

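        // Push a placeholder prediction up front so the debug listener has an
        // entry to fill in as events for this run arrive.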
        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        step_progress.set_substatus("requesting prediction");
        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })?
            .await?;

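        // Render the predicted edits as a unified diff; an empty patch means
        // the provider produced no prediction.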
        let actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction
                    .edit_preview
                    .as_unified_diff(prediction.snapshot.file(), &prediction.edits)
            })
            .unwrap_or_default();

        let has_prediction = !actual_patch.is_empty();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = actual_patch;

        if ix == repetition_count - 1 {
            let (info, style) = if has_prediction {
                ("predicted", InfoStyle::Normal)
            } else {
                ("no prediction", InfoStyle::Warning)
            };
            step_progress.set_info(info, style);
        }
    }

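    // Drop the project from the store and wait for the debug listener to
    // finish before handing the accumulated predictions back to the caller.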
    ep_store.update(&mut cx, |store, _| {
        store.remove_project(&state.project);
    })?;
    debug_task.await?;

    *example = Arc::into_inner(updated_example)
        .ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
        .into_inner()
        .map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
    Ok(())
}

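/// Sends the teacher prompt for `example` to Anthropic and records the parsed
/// patch as a prediction. In batched mode the request may only be enqueued, in
/// which case nothing is recorded until the batch is synced.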
async fn predict_anthropic(
    example: &mut Example,
    _repetition_count: usize,
    batched: bool,
) -> anyhow::Result<()> {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
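    // Batched requests go through the on-disk LLM cache and may be deferred
    // until `sync_batches` runs; plain mode calls the API immediately.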
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.context("Failed to create LLM client")?;

    let prompt = example.prompt.as_ref().context("Prompt is required")?;

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await?
    else {
        // Request stashed for batched processing
        return Ok(());
    };

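    // Concatenate the text blocks from the response; any non-text content is
    // ignored.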
    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
    Ok(())
}

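/// Syncs pending batched Anthropic requests. Only the `Teacher` provider uses
/// batching; for every other provider this is a no-op.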
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
            llm_client
                .sync_batches()
                .await
                .context("Failed to sync batches")?;
        }
        _ => (),
    };
    Ok(())
}