example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::ThreadEvent;
 14use anyhow::{Result, anyhow};
 15use async_trait::async_trait;
 16use buffer_diff::DiffHunkStatus;
 17use collections::HashMap;
 18use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 19use gpui::{AppContext, AsyncApp, Entity};
 20use language_model::{LanguageModel, Role, StopReason};
 21
 22pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 23
 24#[async_trait(?Send)]
 25pub trait Example {
 26    fn meta(&self) -> ExampleMetadata;
 27    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 28    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 29        Vec::new()
 30    }
 31    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 32        Vec::new()
 33    }
 34}
 35
 36#[derive(Clone, Debug)]
 37pub struct JudgeAssertion {
 38    pub id: String,
 39    pub description: String,
 40}
 41
 42#[derive(Clone, Debug)]
 43pub struct ExampleMetadata {
 44    pub name: String,
 45    pub url: String,
 46    pub revision: String,
 47    pub language_server: Option<LanguageServer>,
 48    pub max_assertions: Option<usize>,
 49}
 50
 51#[derive(Clone, Debug)]
 52pub struct LanguageServer {
 53    pub file_extension: String,
 54    pub allow_preexisting_diagnostics: bool,
 55}
 56
 57impl ExampleMetadata {
 58    pub fn repo_name(&self) -> String {
 59        self.url
 60            .split('/')
 61            .next_back()
 62            .unwrap_or(&"")
 63            .trim_end_matches(".git")
 64            .into()
 65    }
 66}
 67
 68pub struct FailedAssertion(pub String);
 69
 70impl fmt::Debug for FailedAssertion {
 71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 72        write!(f, "Assertion failure: {}", self.0)
 73    }
 74}
 75
 76impl fmt::Display for FailedAssertion {
 77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 78        write!(f, "{}", self.0)
 79    }
 80}
 81
 82impl Error for FailedAssertion {}
 83
 84pub struct ExampleContext {
 85    meta: ExampleMetadata,
 86    log_prefix: String,
 87    agent_thread: Entity<agent::Thread>,
 88    app: AsyncApp,
 89    model: Arc<dyn LanguageModel>,
 90    pub assertions: AssertionsReport,
 91    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 92}
 93
 94impl ExampleContext {
 95    pub fn new(
 96        meta: ExampleMetadata,
 97        log_prefix: String,
 98        agent_thread: Entity<agent::Thread>,
 99        model: Arc<dyn LanguageModel>,
100        app: AsyncApp,
101    ) -> Self {
102        let assertions = AssertionsReport::new(meta.max_assertions);
103
104        Self {
105            meta,
106            log_prefix,
107            agent_thread,
108            assertions,
109            model,
110            app,
111            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
112        }
113    }
114
115    pub fn push_user_message(&mut self, text: impl ToString) {
116        self.app
117            .update_entity(&self.agent_thread, |thread, cx| {
118                thread.insert_user_message(text.to_string(), vec![], None, cx);
119            })
120            .unwrap();
121    }
122
123    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
124        let message = message.to_string();
125        self.log_assertion(
126            if expected {
127                Ok(())
128            } else {
129                Err(anyhow::Error::from(FailedAssertion(message.clone())))
130            },
131            message,
132        )
133    }
134
135    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
136        let message = message.to_string();
137        self.log_assertion(
138            match option {
139                Some(value) => Ok(value),
140                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
141            },
142            message,
143        )
144    }
145
146    #[allow(dead_code)]
147    pub fn assert_eq<T: PartialEq + Debug>(
148        &mut self,
149        left: T,
150        right: T,
151        message: impl ToString,
152    ) -> Result<()> {
153        let message = message.to_string();
154        self.log_assertion(
155            if left == right {
156                Ok(())
157            } else {
158                println!("{}{:#?} != {:#?}", self.log_prefix, left, right);
159                Err(anyhow::Error::from(FailedAssertion(message.clone())))
160            },
161            message,
162        )
163    }
164
165    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
166        if let Some(max) = self.meta.max_assertions {
167            if self.assertions.run_count() > max {
168                return Err(anyhow!(
169                    "More assertions were run than the stated max_assertions of {}",
170                    max
171                ));
172            }
173        }
174
175        self.assertions.ran.push(RanAssertion {
176            id: message.clone(),
177            result: Ok(RanAssertionResult {
178                analysis: None,
179                passed: result.is_ok(),
180            }),
181        });
182
183        if result.is_ok() {
184            println!("{}{}", self.log_prefix, message);
185        } else {
186            println!("{}{}", self.log_prefix, message);
187        }
188
189        result
190    }
191
192    pub async fn run_to_end(&mut self) -> Result<Response> {
193        self.run_turns(u32::MAX).await
194    }
195
196    pub async fn run_turn(&mut self) -> Result<Response> {
197        self.run_turns(1).await
198    }
199
200    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
201        let (mut tx, mut rx) = mpsc::channel(1);
202
203        let tool_metrics = self.tool_metrics.clone();
204        let log_prefix = self.log_prefix.clone();
205        let _subscription = self.app.subscribe(
206            &self.agent_thread,
207            move |thread, event: &ThreadEvent, cx| match event {
208                ThreadEvent::ShowError(thread_error) => {
209                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
210                }
211                ThreadEvent::Stopped(reason) => match reason {
212                    Ok(StopReason::EndTurn) => {
213                        tx.close_channel();
214                    }
215                    Ok(StopReason::ToolUse) => {
216                        if thread.read(cx).remaining_turns() == 0 {
217                            tx.close_channel();
218                        }
219                    }
220                    Ok(StopReason::MaxTokens) => {
221                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
222                    }
223                    Err(err) => {
224                        tx.try_send(Err(anyhow!(err.clone()))).ok();
225                    }
226                },
227                ThreadEvent::StreamedAssistantText(_, _)
228                | ThreadEvent::StreamedAssistantThinking(_, _)
229                | ThreadEvent::UsePendingTools { .. } => {}
230                ThreadEvent::ToolFinished {
231                    tool_use_id,
232                    pending_tool_use,
233                    ..
234                } => {
235                    thread.update(cx, |thread, _cx| {
236                        if let Some(tool_use) = pending_tool_use {
237                            let mut tool_metrics = tool_metrics.lock().unwrap();
238                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
239                                let message = if tool_result.is_error {
240                                    format!("✖︎ {}", tool_use.name)
241                                } else {
242                                    format!("✔︎ {}", tool_use.name)
243                                };
244                                println!("{log_prefix}{message}");
245                                tool_metrics
246                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
247                            } else {
248                                let message =
249                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
250                                println!("{log_prefix}{message}");
251                                tool_metrics.insert(tool_use.name.clone(), true);
252                            }
253                        }
254                    });
255                }
256                ThreadEvent::InvalidToolInput { .. } => {
257                    println!("{log_prefix} invalid tool input");
258                }
259                ThreadEvent::ToolConfirmationNeeded => {
260                    panic!(
261                        "{}Bug: Tool confirmation should not be required in eval",
262                        log_prefix
263                    );
264                }
265                ThreadEvent::StreamedCompletion
266                | ThreadEvent::MessageAdded(_)
267                | ThreadEvent::MessageEdited(_)
268                | ThreadEvent::MessageDeleted(_)
269                | ThreadEvent::SummaryChanged
270                | ThreadEvent::SummaryGenerated
271                | ThreadEvent::ReceivedTextChunk
272                | ThreadEvent::StreamedToolUse { .. }
273                | ThreadEvent::CheckpointChanged
274                | ThreadEvent::UsageUpdated(_) => {
275                    tx.try_send(Ok(())).ok();
276                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
277                        println!("{}Event: {:#?}", log_prefix, event);
278                    }
279                }
280            },
281        );
282
283        let model = self.model.clone();
284
285        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
286            thread.set_remaining_turns(iterations);
287            thread.send_to_model(model, None, cx);
288            thread.messages().len()
289        })?;
290
291        loop {
292            select_biased! {
293                result = rx.next() => {
294                    if let Some(result) = result {
295                        result?;
296                    } else {
297                        break;
298                    }
299                }
300                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
301                    return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
302                }
303            }
304        }
305
306        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
307            let mut messages = Vec::new();
308            for message in thread.messages().skip(message_count_before) {
309                messages.push(Message {
310                    _role: message.role,
311                    _text: message.to_string(),
312                    tool_use: thread
313                        .tool_uses_for_message(message.id, cx)
314                        .into_iter()
315                        .map(|tool_use| ToolUse {
316                            name: tool_use.name.to_string(),
317                            value: tool_use.input,
318                        })
319                        .collect(),
320                });
321            }
322            messages
323        })?;
324
325        let response = Response::new(messages);
326
327        Ok(response)
328    }
329
330    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
331        self.app
332            .read_entity(&self.agent_thread, |thread, cx| {
333                let action_log = thread.action_log().read(cx);
334                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
335                    |(buffer, diff)| {
336                        let snapshot = buffer.read(cx).snapshot();
337
338                        let file = snapshot.file().unwrap();
339                        let diff = diff.read(cx);
340                        let base_text = diff.base_text().text();
341
342                        let hunks = diff
343                            .hunks(&snapshot, cx)
344                            .map(|hunk| FileEditHunk {
345                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
346                                text: snapshot
347                                    .text_for_range(hunk.range.clone())
348                                    .collect::<String>(),
349                                status: hunk.status(),
350                            })
351                            .collect();
352
353                        (file.path().clone(), FileEdits { hunks })
354                    },
355                ))
356            })
357            .unwrap()
358    }
359}
360
361#[derive(Debug)]
362pub struct Response {
363    messages: Vec<Message>,
364}
365
366impl Response {
367    pub fn new(messages: Vec<Message>) -> Self {
368        Self { messages }
369    }
370
371    pub fn expect_tool(
372        &self,
373        tool_name: &'static str,
374        cx: &mut ExampleContext,
375    ) -> Result<&ToolUse> {
376        let result = self.messages.iter().find_map(|msg| {
377            msg.tool_use
378                .iter()
379                .find(|tool_use| tool_use.name == tool_name)
380        });
381        cx.assert_some(result, format!("called `{}`", tool_name))
382    }
383
384    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
385        self.messages.iter().flat_map(|msg| &msg.tool_use)
386    }
387}
388
389#[derive(Debug)]
390pub struct Message {
391    _role: Role,
392    _text: String,
393    tool_use: Vec<ToolUse>,
394}
395
396#[derive(Debug)]
397pub struct ToolUse {
398    pub name: String,
399    value: serde_json::Value,
400}
401
402impl ToolUse {
403    pub fn parse_input<Input>(&self) -> Result<Input>
404    where
405        Input: for<'de> serde::Deserialize<'de>,
406    {
407        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
408    }
409}
410
411#[derive(Debug)]
412pub struct FileEdits {
413    hunks: Vec<FileEditHunk>,
414}
415
416#[derive(Debug)]
417struct FileEditHunk {
418    base_text: String,
419    text: String,
420    status: DiffHunkStatus,
421}
422
423impl FileEdits {
424    pub fn has_added_line(&self, line: &str) -> bool {
425        self.hunks.iter().any(|hunk| {
426            hunk.status == DiffHunkStatus::added_none()
427                && hunk.base_text.is_empty()
428                && hunk.text.contains(line)
429        })
430    }
431}