1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    sync::{Arc, Mutex},
  5    time::Duration,
  6    u32,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use acp_thread::UserMessageId;
 14use agent::{Thread, ThreadEvent, UserMessageContent};
 15use agent_client_protocol as acp;
 16use agent_settings::AgentProfileId;
 17use anyhow::{Result, anyhow};
 18use async_trait::async_trait;
 19use buffer_diff::DiffHunkStatus;
 20use collections::HashMap;
 21use futures::{FutureExt as _, StreamExt, select_biased};
 22use gpui::{App, AppContext, AsyncApp, Entity};
 23use language_model::Role;
 24use util::rel_path::RelPath;
 25
 26pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 27
 28#[async_trait(?Send)]
 29pub trait Example {
 30    fn meta(&self) -> ExampleMetadata;
 31    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 32    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 33        Vec::new()
 34    }
 35    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 36        Vec::new()
 37    }
 38}
 39
 40#[derive(Clone, Debug)]
 41pub struct JudgeAssertion {
 42    pub id: String,
 43    pub description: String,
 44}
 45
 46#[derive(Clone, Debug)]
 47pub struct ExampleMetadata {
 48    pub name: String,
 49    pub url: String,
 50    pub revision: String,
 51    pub language_server: Option<LanguageServer>,
 52    pub max_assertions: Option<usize>,
 53    pub profile_id: AgentProfileId,
 54    pub existing_thread_json: Option<String>,
 55    pub max_turns: Option<u32>,
 56}
 57
 58#[derive(Clone, Debug)]
 59pub struct LanguageServer {
 60    pub file_extension: String,
 61    pub allow_preexisting_diagnostics: bool,
 62}
 63
 64impl ExampleMetadata {
 65    pub fn repo_name(&self) -> String {
 66        self.url
 67            .split('/')
 68            .next_back()
 69            .unwrap_or("")
 70            .trim_end_matches(".git")
 71            .into()
 72    }
 73}
 74
 75pub struct FailedAssertion(pub String);
 76
 77impl fmt::Debug for FailedAssertion {
 78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 79        write!(f, "Assertion failure: {}", self.0)
 80    }
 81}
 82
 83impl fmt::Display for FailedAssertion {
 84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 85        write!(f, "{}", self.0)
 86    }
 87}
 88
 89impl Error for FailedAssertion {}
 90
 91pub struct ExampleContext {
 92    meta: ExampleMetadata,
 93    log_prefix: String,
 94    agent_thread: Entity<agent::Thread>,
 95    app: AsyncApp,
 96    pub assertions: AssertionsReport,
 97    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 98}
 99
100impl ExampleContext {
101    pub fn new(
102        meta: ExampleMetadata,
103        log_prefix: String,
104        agent_thread: Entity<Thread>,
105        app: AsyncApp,
106    ) -> Self {
107        let assertions = AssertionsReport::new(meta.max_assertions);
108
109        Self {
110            meta,
111            log_prefix,
112            agent_thread,
113            assertions,
114            app,
115            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
116        }
117    }
118
119    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
120        let message = message.to_string();
121        self.log_assertion(
122            if expected {
123                Ok(())
124            } else {
125                Err(anyhow::Error::from(FailedAssertion(message.clone())))
126            },
127            message,
128        )
129    }
130
131    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
132        let message = message.to_string();
133        self.log_assertion(
134            match option {
135                Some(value) => Ok(value),
136                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
137            },
138            message,
139        )
140    }
141
142    #[allow(dead_code)]
143    pub fn assert_eq<T: PartialEq + Debug>(
144        &mut self,
145        left: T,
146        right: T,
147        message: impl ToString,
148    ) -> Result<()> {
149        let message = message.to_string();
150        self.log_assertion(
151            if left == right {
152                Ok(())
153            } else {
154                println!(
155                    "{}{}",
156                    self.log_prefix,
157                    pretty_assertions::Comparison::new(&left, &right)
158                );
159                Err(anyhow::Error::from(FailedAssertion(message.clone())))
160            },
161            message,
162        )
163    }
164
165    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
166        if let Some(max) = self.meta.max_assertions {
167            anyhow::ensure!(
168                self.assertions.run_count() <= max,
169                "More assertions were run than the stated max_assertions of {max}"
170            );
171        }
172
173        self.assertions.ran.push(RanAssertion {
174            id: message.clone(),
175            result: Ok(RanAssertionResult {
176                analysis: None,
177                passed: result.is_ok(),
178            }),
179        });
180
181        if result.is_ok() {
182            println!("{}✅ {}", self.log_prefix, message);
183        } else {
184            println!("{}❌ {}", self.log_prefix, message);
185        }
186
187        result
188    }
189
190    pub async fn prompt(&mut self, prompt: impl Into<String>) -> Result<Response> {
191        self.prompt_with_max_turns(prompt, u32::MAX).await
192    }
193
194    pub async fn prompt_with_max_turns(
195        &mut self,
196        prompt: impl Into<String>,
197        max_turns: u32,
198    ) -> Result<Response> {
199        let content = vec![UserMessageContent::Text(prompt.into())];
200        self.run_turns(Some(content), max_turns).await
201    }
202
203    pub async fn proceed_with_max_turns(&mut self, max_turns: u32) -> Result<Response> {
204        self.run_turns(None, max_turns).await
205    }
206
207    async fn run_turns(
208        &mut self,
209        prompt: Option<Vec<UserMessageContent>>,
210        max_turns: u32,
211    ) -> Result<Response> {
212        let tool_metrics = self.tool_metrics.clone();
213        let log_prefix = self.log_prefix.clone();
214
215        let mut remaining_turns = max_turns;
216
217        let mut event_stream = self.agent_thread.update(&mut self.app, |thread, cx| {
218            if let Some(prompt) = prompt {
219                let id = UserMessageId::new();
220                thread.send(id, prompt, cx)
221            } else {
222                thread.proceed(cx)
223            }
224        })??;
225
226        let task = self.app.background_spawn(async move {
227            let mut messages = Vec::new();
228            let mut tool_uses_by_id = HashMap::default();
229            while let Some(event) = event_stream.next().await {
230                match event? {
231                    ThreadEvent::UserMessage(user_message) => {
232                        messages.push(Message {
233                            role: Role::User,
234                            text: user_message.to_markdown(),
235                            tool_use: Vec::new(),
236                        });
237                    }
238                    ThreadEvent::AgentThinking(text) | ThreadEvent::AgentText(text) => {
239                        if matches!(
240                            messages.last(),
241                            Some(Message {
242                                role: Role::Assistant,
243                                ..
244                            })
245                        ) {
246                            messages.last_mut().unwrap().text.push_str(&text);
247                        } else {
248                            messages.push(Message {
249                                role: Role::Assistant,
250                                text,
251                                tool_use: Vec::new(),
252                            });
253                        }
254                    }
255                    ThreadEvent::ToolCall(tool_call) => {
256                        let meta = tool_call.meta.expect("Missing meta field in tool_call");
257                        let tool_name = meta
258                            .get("tool_name")
259                            .expect("Missing tool_name field in meta")
260                            .as_str()
261                            .expect("Unknown tool_name content in meta");
262
263                        tool_uses_by_id.insert(
264                            tool_call.id,
265                            ToolUse {
266                                name: tool_name.to_string(),
267                                value: tool_call.raw_input.unwrap_or_default(),
268                            },
269                        );
270                        if matches!(
271                            tool_call.status,
272                            acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed
273                        ) {
274                            panic!("Tool call completed without update");
275                        }
276                    }
277                    ThreadEvent::ToolCallUpdate(tool_call_update) => {
278                        if let acp_thread::ToolCallUpdate::UpdateFields(update) = tool_call_update {
279                            if let Some(raw_input) = update.fields.raw_input {
280                                if let Some(tool_use) = tool_uses_by_id.get_mut(&update.id) {
281                                    tool_use.value = raw_input;
282                                }
283                            }
284
285                            if matches!(
286                                update.fields.status,
287                                Some(acp::ToolCallStatus::Completed | acp::ToolCallStatus::Failed)
288                            ) {
289                                let succeeded =
290                                    update.fields.status == Some(acp::ToolCallStatus::Completed);
291
292                                let tool_use = tool_uses_by_id
293                                    .remove(&update.id)
294                                    .expect("Unrecognized tool call completed");
295
296                                let log_message = if succeeded {
297                                    format!("✔︎ {}", tool_use.name)
298                                } else {
299                                    format!("✖︎ {}", tool_use.name)
300                                };
301                                println!("{log_prefix}{log_message}");
302
303                                tool_metrics
304                                    .lock()
305                                    .unwrap()
306                                    .insert(tool_use.name.clone().into(), succeeded);
307
308                                if let Some(message) = messages.last_mut() {
309                                    message.tool_use.push(tool_use);
310                                } else {
311                                    messages.push(Message {
312                                        role: Role::Assistant,
313                                        text: "".to_string(),
314                                        tool_use: vec![tool_use],
315                                    });
316                                }
317
318                                remaining_turns -= 1;
319                                if remaining_turns == 0 {
320                                    return Ok(messages);
321                                }
322                            }
323                        }
324                    }
325                    ThreadEvent::ToolCallAuthorization(_) => panic!(
326                        "{}Bug: Tool confirmation should not be required in eval",
327                        log_prefix
328                    ),
329                    ThreadEvent::Retry(status) => {
330                        println!("{log_prefix} Got retry: {status:?}");
331                    }
332                    ThreadEvent::Stop(stop_reason) => match stop_reason {
333                        acp::StopReason::EndTurn => {}
334                        acp::StopReason::MaxTokens => {
335                            return Err(anyhow!("Exceeded maximum tokens"));
336                        }
337                        acp::StopReason::MaxTurnRequests => {
338                            return Err(anyhow!("Exceeded maximum turn requests"));
339                        }
340                        acp::StopReason::Refusal => {
341                            return Err(anyhow!("Refusal"));
342                        }
343                        acp::StopReason::Cancelled => return Err(anyhow!("Cancelled")),
344                    },
345                }
346            }
347            Ok(messages)
348        });
349
350        select_biased! {
351            result = task.fuse() => {
352                Ok(Response::new(result?))
353            }
354            _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
355                anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
356            }
357        }
358    }
359
360    pub fn edits(&self) -> HashMap<Arc<RelPath>, FileEdits> {
361        self.agent_thread
362            .read_with(&self.app, |thread, cx| {
363                let action_log = thread.action_log().read(cx);
364                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
365                    |(buffer, diff)| {
366                        let snapshot = buffer.read(cx).snapshot();
367
368                        let file = snapshot.file().unwrap();
369                        let diff = diff.read(cx);
370                        let base_text = diff.base_text().text();
371
372                        let hunks = diff
373                            .hunks(&snapshot, cx)
374                            .map(|hunk| FileEditHunk {
375                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
376                                text: snapshot
377                                    .text_for_range(hunk.range.clone())
378                                    .collect::<String>(),
379                                status: hunk.status(),
380                            })
381                            .collect();
382
383                        (file.path().clone(), FileEdits { hunks })
384                    },
385                ))
386            })
387            .unwrap()
388    }
389
390    pub fn agent_thread(&self) -> Entity<Thread> {
391        self.agent_thread.clone()
392    }
393}
394
395impl AppContext for ExampleContext {
396    type Result<T> = anyhow::Result<T>;
397
398    fn new<T: 'static>(
399        &mut self,
400        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
401    ) -> Self::Result<Entity<T>> {
402        self.app.new(build_entity)
403    }
404
405    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
406        self.app.reserve_entity()
407    }
408
409    fn insert_entity<T: 'static>(
410        &mut self,
411        reservation: gpui::Reservation<T>,
412        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
413    ) -> Self::Result<Entity<T>> {
414        self.app.insert_entity(reservation, build_entity)
415    }
416
417    fn update_entity<T, R>(
418        &mut self,
419        handle: &Entity<T>,
420        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
421    ) -> Self::Result<R>
422    where
423        T: 'static,
424    {
425        self.app.update_entity(handle, update)
426    }
427
428    fn as_mut<'a, T>(&'a mut self, handle: &Entity<T>) -> Self::Result<gpui::GpuiBorrow<'a, T>>
429    where
430        T: 'static,
431    {
432        self.app.as_mut(handle)
433    }
434
435    fn read_entity<T, R>(
436        &self,
437        handle: &Entity<T>,
438        read: impl FnOnce(&T, &App) -> R,
439    ) -> Self::Result<R>
440    where
441        T: 'static,
442    {
443        self.app.read_entity(handle, read)
444    }
445
446    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
447    where
448        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
449    {
450        self.app.update_window(window, f)
451    }
452
453    fn read_window<T, R>(
454        &self,
455        window: &gpui::WindowHandle<T>,
456        read: impl FnOnce(Entity<T>, &App) -> R,
457    ) -> Result<R>
458    where
459        T: 'static,
460    {
461        self.app.read_window(window, read)
462    }
463
464    fn background_spawn<R>(
465        &self,
466        future: impl std::future::Future<Output = R> + Send + 'static,
467    ) -> gpui::Task<R>
468    where
469        R: Send + 'static,
470    {
471        self.app.background_spawn(future)
472    }
473
474    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
475    where
476        G: gpui::Global,
477    {
478        self.app.read_global(callback)
479    }
480}
481
482#[derive(Debug)]
483pub struct Response {
484    messages: Vec<Message>,
485}
486
487impl Response {
488    pub fn new(messages: Vec<Message>) -> Self {
489        Self { messages }
490    }
491
492    pub fn expect_tool_call(
493        &self,
494        tool_name: &'static str,
495        cx: &mut ExampleContext,
496    ) -> Result<&ToolUse> {
497        let result = self.find_tool_call(tool_name);
498        cx.assert_some(result, format!("called `{}`", tool_name))
499    }
500
501    pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
502        self.messages.iter().rev().find_map(|msg| {
503            msg.tool_use
504                .iter()
505                .find(|tool_use| tool_use.name == tool_name)
506        })
507    }
508
509    pub fn tool_calls(&self) -> impl Iterator<Item = &ToolUse> {
510        self.messages.iter().flat_map(|msg| &msg.tool_use)
511    }
512
513    pub fn texts(&self) -> impl Iterator<Item = String> {
514        self.messages.iter().map(|message| message.text.clone())
515    }
516}
517
/// One entry of a `Response` transcript.
#[derive(Debug)]
pub struct Message {
    /// Author of the message; `run_turns` creates `Role::User` and `Role::Assistant` entries.
    role: Role,
    /// Message text; consecutive assistant thinking/text chunks are concatenated here.
    text: String,
    /// Tool calls attached to this message once they completed or failed.
    tool_use: Vec<ToolUse>,
}
524
525#[derive(Debug)]
526pub struct ToolUse {
527    pub name: String,
528    value: serde_json::Value,
529}
530
531impl ToolUse {
532    pub fn parse_input<Input>(&self) -> Result<Input>
533    where
534        Input: for<'de> serde::Deserialize<'de>,
535    {
536        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
537    }
538}
539
540#[derive(Debug, Eq, PartialEq)]
541pub struct FileEdits {
542    pub hunks: Vec<FileEditHunk>,
543}
544
545#[derive(Debug, Eq, PartialEq)]
546pub struct FileEditHunk {
547    pub base_text: String,
548    pub text: String,
549    pub status: DiffHunkStatus,
550}
551
552impl FileEdits {
553    pub fn has_added_line(&self, line: &str) -> bool {
554        self.hunks.iter().any(|hunk| {
555            hunk.status == DiffHunkStatus::added_none()
556                && hunk.base_text.is_empty()
557                && hunk.text.contains(line)
558        })
559    }
560}