use crate::{
    PredictionProvider, PromptFormat,
    anthropic_client::AnthropicClient,
    example::{Example, ExamplePrediction},
    format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
    headless::EpAppState,
    load_project::run_load_project,
    paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
    retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
    fs,
    sync::{
        Arc, Mutex, OnceLock,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
};

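/// Runs edit prediction for `example`: loads the project, retrieves context,
/// then either queries Anthropic directly (Teacher) or goes through the
/// `EditPredictionStore`, recording each run's prompt, response, and patch.
///
/// Illustrative call (hypothetical setup; assumes `example`, `app_state`, and
/// `cx` are already constructed):
///
/// ```ignore
/// run_prediction(&mut example, Some(PredictionProvider::Zeta2), 1, app_state, cx).await;
/// ```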
pub async fn run_prediction(
    example: &mut Example,
    provider: Option<PredictionProvider>,
    repetition_count: usize,
    app_state: Arc<EpAppState>,
    mut cx: AsyncApp,
) {
    if !example.predictions.is_empty() {
        return;
    }

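    // Make sure the example's project is loaded and its context retrieved
    // before any prediction is requested.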
    run_load_project(example, app_state.clone(), cx.clone()).await;
    run_context_retrieval(example, app_state.clone(), cx.clone()).await;

    let provider = provider.unwrap();

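    // The Teacher provider bypasses the edit prediction store and queries
    // Anthropic directly, formatting a teacher prompt first if needed.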
    if matches!(provider, PredictionProvider::Teacher) {
        if example.prompt.is_none() {
            run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
        }

        let batched = true;
        return predict_anthropic(example, repetition_count, batched).await;
    }

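    // Zeta providers require a signed-in client; sign in once and share the
    // task so repeated calls don't authenticate again.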
    if matches!(
        provider,
        PredictionProvider::Zeta1 | PredictionProvider::Zeta2
    ) {
        static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
        AUTHENTICATED
            .get_or_init(|| {
                let client = app_state.client.clone();
                cx.spawn(async move |cx| {
                    client
                        .sign_in_with_optional_connect(true, cx)
                        .await
                        .unwrap();
                })
                .shared()
            })
            .clone()
            .await;
    }

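    // Select the edit prediction model corresponding to the requested provider.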
    let ep_store = cx
        .update(|cx| EditPredictionStore::try_global(cx).unwrap())
        .unwrap();

    ep_store
        .update(&mut cx, |store, _cx| {
            let model = match provider {
                PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
                PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
                PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
                PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
                PredictionProvider::Teacher => unreachable!(),
            };
            store.set_edit_prediction_model(model);
        })
        .unwrap();
    let state = example.state.as_ref().unwrap();
    let run_dir = RUN_DIR.join(&example.name);

    let updated_example = Arc::new(Mutex::new(example.clone()));
    let current_run_ix = Arc::new(AtomicUsize::new(0));

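    // Listen to debug events from the store so each run's prompt and model
    // output can be written to its run directory as they arrive.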
    let mut debug_rx = ep_store
        .update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
        .unwrap();
    let debug_task = cx.background_spawn({
        let updated_example = updated_example.clone();
        let current_run_ix = current_run_ix.clone();
        let run_dir = run_dir.clone();
        async move {
            while let Some(event) = debug_rx.next().await {
                let run_ix = current_run_ix.load(SeqCst);
                let mut updated_example = updated_example.lock().unwrap();

                let run_dir = if repetition_count > 1 {
                    run_dir.join(format!("{:03}", run_ix))
                } else {
                    run_dir.clone()
                };

                match event {
                    DebugEvent::EditPredictionStarted(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(prompt) = request.prompt {
                            fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
                        }
                    }
                    DebugEvent::EditPredictionFinished(request) => {
                        assert_eq!(updated_example.predictions.len(), run_ix + 1);

                        if let Some(output) = request.model_output {
                            fs::write(run_dir.join("prediction_response.md"), &output)?;
                            updated_example
                                .predictions
                                .last_mut()
                                .unwrap()
                                .actual_output = output;
                        }
                        // Stop listening once the final repetition has finished
                        // (`run_ix` is zero-based, so the last run is `repetition_count - 1`).
                        if run_ix + 1 >= repetition_count {
                            break;
                        }
                    }
                    _ => {}
                }
            }
            anyhow::Ok(())
        }
    });

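    // Run the requested number of repetitions, each writing into its own
    // zero-padded run directory when more than one repetition is requested.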
    for ix in 0..repetition_count {
        current_run_ix.store(ix, SeqCst);
        let run_dir = if repetition_count > 1 {
            run_dir.join(format!("{:03}", ix))
        } else {
            run_dir.clone()
        };

        fs::create_dir_all(&run_dir).unwrap();
        if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
            fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
        }
        #[cfg(unix)]
        std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
        #[cfg(windows)]
        std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .push(ExamplePrediction {
                actual_patch: String::new(),
                actual_output: String::new(),
                provider,
            });

        let prediction = ep_store
            .update(&mut cx, |store, cx| {
                store.request_prediction(
                    &state.project,
                    &state.buffer,
                    state.cursor_position,
                    cloud_llm_client::PredictEditsRequestTrigger::Cli,
                    cx,
                )
            })
            .unwrap()
            .await
            .unwrap();

        updated_example
            .lock()
            .unwrap()
            .predictions
            .last_mut()
            .unwrap()
            .actual_patch = prediction
            .and_then(|prediction| {
                let prediction = prediction.prediction.ok()?;
                prediction.edit_preview.as_unified_diff(&prediction.edits)
            })
            .unwrap_or_default();
    }

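    // Detach from the store, wait for the debug listener to finish, and write
    // the accumulated predictions back into the caller's example.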
    ep_store
        .update(&mut cx, |store, _| {
            store.remove_project(&state.project);
        })
        .unwrap();
    debug_task.await.unwrap();

    *example = Arc::into_inner(updated_example)
        .unwrap()
        .into_inner()
        .unwrap();
}

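/// Requests a single prediction from Anthropic for the Teacher provider.
///
/// In batched mode the request may be stashed for later processing, in which
/// case `generate` returns `None` and no prediction is recorded
/// (see `sync_batches`).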
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
    let llm_model_name = "claude-sonnet-4-5";
    let max_tokens = 16384;
    let llm_client = if batched {
        AnthropicClient::batch(crate::paths::LLM_CACHE_DB.as_ref())
    } else {
        AnthropicClient::plain()
    };
    let llm_client = llm_client.expect("Failed to create LLM client");

    let prompt = example
        .prompt
        .as_ref()
        .unwrap_or_else(|| panic!("Prompt is required for example {}", &example.name));

    let messages = vec![anthropic::Message {
        role: anthropic::Role::User,
        content: vec![anthropic::RequestContent::Text {
            text: prompt.input.clone(),
            cache_control: None,
        }],
    }];

    let Some(response) = llm_client
        .generate(llm_model_name, max_tokens, messages)
        .await
        .unwrap()
    else {
        // Request stashed for batched processing
        return;
    };

    let actual_output = response
        .content
        .into_iter()
        .filter_map(|content| match content {
            anthropic::ResponseContent::Text { text } => Some(text),
            _ => None,
        })
        .collect::<Vec<String>>()
        .join("\n");

    let actual_patch = TeacherPrompt::parse(example, &actual_output);

    let prediction = ExamplePrediction {
        actual_patch,
        actual_output,
        provider: PredictionProvider::Teacher,
    };

    example.predictions.push(prediction);
}

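/// Syncs any pending Anthropic batch requests. Only the Teacher provider uses
/// batching; for every other provider this is a no-op.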
pub async fn sync_batches(provider: &PredictionProvider) {
    match provider {
        PredictionProvider::Teacher => {
            let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
            let llm_client =
                AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
            llm_client
                .sync_batches()
                .await
                .expect("Failed to sync batches");
        }
        _ => (),
    }
}