example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    sync::{Arc, Mutex},
  5    time::Duration,
  6};
  7
  8use crate::{
  9    ToolMetrics,
 10    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 11};
 12use agent::ThreadEvent;
 13use anyhow::{Result, anyhow};
 14use async_trait::async_trait;
 15use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 16use gpui::{AppContext, AsyncApp, Entity};
 17use language_model::{LanguageModel, Role, StopReason};
 18
 19pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 20
 21#[async_trait(?Send)]
 22pub trait Example {
 23    fn meta(&self) -> ExampleMetadata;
 24    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 25    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 26        Vec::new()
 27    }
 28    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 29        Vec::new()
 30    }
 31}
 32
 33#[derive(Clone, Debug)]
 34pub struct JudgeAssertion {
 35    pub id: String,
 36    pub description: String,
 37}
 38
 39#[derive(Clone, Debug)]
 40pub struct ExampleMetadata {
 41    pub name: String,
 42    pub url: String,
 43    pub revision: String,
 44    pub language_server: Option<LanguageServer>,
 45    pub max_assertions: Option<usize>,
 46}
 47
 48#[derive(Clone, Debug)]
 49pub struct LanguageServer {
 50    pub file_extension: String,
 51    pub allow_preexisting_diagnostics: bool,
 52}
 53
 54impl ExampleMetadata {
 55    pub fn repo_name(&self) -> String {
 56        self.url
 57            .split('/')
 58            .next_back()
 59            .unwrap_or(&"")
 60            .trim_end_matches(".git")
 61            .into()
 62    }
 63}
 64
 65pub struct FailedAssertion(pub String);
 66
 67impl fmt::Debug for FailedAssertion {
 68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 69        write!(f, "Assertion failure: {}", self.0)
 70    }
 71}
 72
 73impl fmt::Display for FailedAssertion {
 74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 75        write!(f, "{}", self.0)
 76    }
 77}
 78
 79impl Error for FailedAssertion {}
 80
 81pub struct ExampleContext {
 82    meta: ExampleMetadata,
 83    log_prefix: String,
 84    agent_thread: Entity<agent::Thread>,
 85    app: AsyncApp,
 86    model: Arc<dyn LanguageModel>,
 87    pub assertions: AssertionsReport,
 88    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 89}
 90
 91impl ExampleContext {
 92    pub fn new(
 93        meta: ExampleMetadata,
 94        log_prefix: String,
 95        agent_thread: Entity<agent::Thread>,
 96        model: Arc<dyn LanguageModel>,
 97        app: AsyncApp,
 98    ) -> Self {
 99        let assertions = AssertionsReport::new(meta.max_assertions);
100
101        Self {
102            meta,
103            log_prefix,
104            agent_thread,
105            assertions,
106            model,
107            app,
108            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
109        }
110    }
111
112    pub fn push_user_message(&mut self, text: impl ToString) {
113        self.app
114            .update_entity(&self.agent_thread, |thread, cx| {
115                thread.insert_user_message(text.to_string(), vec![], None, cx);
116            })
117            .unwrap();
118    }
119
120    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
121        let message = message.to_string();
122        self.log_assertion(
123            if expected {
124                Ok(())
125            } else {
126                Err(anyhow::Error::from(FailedAssertion(message.clone())))
127            },
128            message,
129        )
130    }
131
132    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
133        let message = message.to_string();
134        self.log_assertion(
135            match option {
136                Some(value) => Ok(value),
137                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
138            },
139            message,
140        )
141    }
142
143    #[allow(dead_code)]
144    pub fn assert_eq<T: PartialEq + Debug>(
145        &mut self,
146        left: T,
147        right: T,
148        message: impl ToString,
149    ) -> Result<()> {
150        let message = message.to_string();
151        self.log_assertion(
152            if left == right {
153                Ok(())
154            } else {
155                println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
156                Err(anyhow::Error::from(FailedAssertion(message.clone())))
157            },
158            message,
159        )
160    }
161
162    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
163        if let Some(max) = self.meta.max_assertions {
164            if self.assertions.run_count() > max {
165                return Err(anyhow!(
166                    "More assertions were run than the stated max_assertions of {}",
167                    max
168                ));
169            }
170        }
171
172        self.assertions.ran.push(RanAssertion {
173            id: message.clone(),
174            result: Ok(RanAssertionResult {
175                analysis: None,
176                passed: result.is_ok(),
177            }),
178        });
179
180        if result.is_ok() {
181            println!("{}{}", self.log_prefix, message);
182        } else {
183            println!("{}{}", self.log_prefix, message);
184        }
185
186        result
187    }
188
189    pub async fn run_to_end(&mut self) -> Result<Response> {
190        self.run_turns(u32::MAX).await
191    }
192
193    pub async fn run_turn(&mut self) -> Result<Response> {
194        self.run_turns(1).await
195    }
196
197    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
198        let (mut tx, mut rx) = mpsc::channel(1);
199
200        let tool_metrics = self.tool_metrics.clone();
201        let log_prefix = self.log_prefix.clone();
202        let _subscription = self.app.subscribe(
203            &self.agent_thread,
204            move |thread, event: &ThreadEvent, cx| match event {
205                ThreadEvent::ShowError(thread_error) => {
206                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
207                }
208                ThreadEvent::Stopped(reason) => match reason {
209                    Ok(StopReason::EndTurn) => {
210                        tx.close_channel();
211                    }
212                    Ok(StopReason::ToolUse) => {
213                        if thread.read(cx).remaining_turns() == 0 {
214                            tx.close_channel();
215                        }
216                    }
217                    Ok(StopReason::MaxTokens) => {
218                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
219                    }
220                    Err(err) => {
221                        tx.try_send(Err(anyhow!(err.clone()))).ok();
222                    }
223                },
224                ThreadEvent::StreamedAssistantText(_, _)
225                | ThreadEvent::StreamedAssistantThinking(_, _)
226                | ThreadEvent::UsePendingTools { .. } => {}
227                ThreadEvent::ToolFinished {
228                    tool_use_id,
229                    pending_tool_use,
230                    ..
231                } => {
232                    thread.update(cx, |thread, _cx| {
233                        if let Some(tool_use) = pending_tool_use {
234                            let mut tool_metrics = tool_metrics.lock().unwrap();
235                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
236                                let message = if tool_result.is_error {
237                                    format!("TOOL FAILED: {}", tool_use.name)
238                                } else {
239                                    format!("TOOL FINISHED: {}", tool_use.name)
240                                };
241                                println!("{log_prefix}{message}");
242                                tool_metrics
243                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
244                            } else {
245                                let message =
246                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
247                                println!("{log_prefix}{message}");
248                                tool_metrics.insert(tool_use.name.clone(), true);
249                            }
250                        }
251                    });
252                }
253                ThreadEvent::ToolConfirmationNeeded => {
254                    panic!(
255                        "{}Bug: Tool confirmation should not be required in eval",
256                        log_prefix
257                    );
258                }
259                ThreadEvent::StreamedCompletion
260                | ThreadEvent::MessageAdded(_)
261                | ThreadEvent::MessageEdited(_)
262                | ThreadEvent::MessageDeleted(_)
263                | ThreadEvent::SummaryChanged
264                | ThreadEvent::SummaryGenerated
265                | ThreadEvent::ReceivedTextChunk
266                | ThreadEvent::StreamedToolUse { .. }
267                | ThreadEvent::CheckpointChanged
268                | ThreadEvent::UsageUpdated(_) => {
269                    tx.try_send(Ok(())).ok();
270                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
271                        println!("{}Event: {:#?}", log_prefix, event);
272                    }
273                }
274            },
275        );
276
277        let model = self.model.clone();
278
279        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
280            thread.set_remaining_turns(iterations);
281            thread.send_to_model(model, None, cx);
282            thread.messages().len()
283        })?;
284
285        loop {
286            select_biased! {
287                result = rx.next() => {
288                    if let Some(result) = result {
289                        result?;
290                    } else {
291                        break;
292                    }
293                }
294                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
295                    return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
296                }
297            }
298        }
299
300        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
301            let mut messages = Vec::new();
302            for message in thread.messages().skip(message_count_before) {
303                messages.push(Message {
304                    _role: message.role,
305                    _text: message.to_string(),
306                    tool_use: thread
307                        .tool_uses_for_message(message.id, cx)
308                        .into_iter()
309                        .map(|tool_use| ToolUse {
310                            name: tool_use.name.to_string(),
311                            value: tool_use.input,
312                        })
313                        .collect(),
314                });
315            }
316            messages
317        })?;
318
319        let response = Response::new(messages);
320
321        Ok(response)
322    }
323}
324
325#[derive(Debug)]
326pub struct Response {
327    messages: Vec<Message>,
328}
329
330impl Response {
331    pub fn new(messages: Vec<Message>) -> Self {
332        Self { messages }
333    }
334
335    pub fn expect_tool(
336        &self,
337        tool_name: &'static str,
338        cx: &mut ExampleContext,
339    ) -> Result<&ToolUse> {
340        let result = self.messages.iter().find_map(|msg| {
341            msg.tool_use
342                .iter()
343                .find(|tool_use| tool_use.name == tool_name)
344        });
345        cx.assert_some(result, format!("called `{}`", tool_name))
346    }
347}
348
349#[derive(Debug)]
350pub struct Message {
351    _role: Role,
352    _text: String,
353    tool_use: Vec<ToolUse>,
354}
355
356#[derive(Debug)]
357pub struct ToolUse {
358    name: String,
359    value: serde_json::Value,
360}
361
362impl ToolUse {
363    pub fn expect_input<Input>(&self, cx: &mut ExampleContext) -> Result<Input>
364    where
365        Input: for<'de> serde::Deserialize<'de>,
366    {
367        let result =
368            serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err));
369        cx.log_assertion(result, format!("valid `{}` input", &self.name))
370    }
371}