example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::{ContextLoadResult, Thread, ThreadEvent};
 14use anyhow::{Result, anyhow};
 15use async_trait::async_trait;
 16use buffer_diff::DiffHunkStatus;
 17use collections::HashMap;
 18use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 19use gpui::{App, AppContext, AsyncApp, Entity};
 20use language_model::{LanguageModel, Role, StopReason};
 21
 22pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 23
 24#[async_trait(?Send)]
 25pub trait Example {
 26    fn meta(&self) -> ExampleMetadata;
 27    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 28    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 29        Vec::new()
 30    }
 31    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 32        Vec::new()
 33    }
 34}
 35
 36#[derive(Clone, Debug)]
 37pub struct JudgeAssertion {
 38    pub id: String,
 39    pub description: String,
 40}
 41
 42#[derive(Clone, Debug)]
 43pub struct ExampleMetadata {
 44    pub name: String,
 45    pub url: String,
 46    pub revision: String,
 47    pub language_server: Option<LanguageServer>,
 48    pub max_assertions: Option<usize>,
 49}
 50
 51#[derive(Clone, Debug)]
 52pub struct LanguageServer {
 53    pub file_extension: String,
 54    pub allow_preexisting_diagnostics: bool,
 55}
 56
 57impl ExampleMetadata {
 58    pub fn repo_name(&self) -> String {
 59        self.url
 60            .split('/')
 61            .next_back()
 62            .unwrap_or(&"")
 63            .trim_end_matches(".git")
 64            .into()
 65    }
 66}
 67
 68pub struct FailedAssertion(pub String);
 69
 70impl fmt::Debug for FailedAssertion {
 71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 72        write!(f, "Assertion failure: {}", self.0)
 73    }
 74}
 75
 76impl fmt::Display for FailedAssertion {
 77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 78        write!(f, "{}", self.0)
 79    }
 80}
 81
 82impl Error for FailedAssertion {}
 83
 84pub struct ExampleContext {
 85    meta: ExampleMetadata,
 86    log_prefix: String,
 87    agent_thread: Entity<agent::Thread>,
 88    app: AsyncApp,
 89    model: Arc<dyn LanguageModel>,
 90    pub assertions: AssertionsReport,
 91    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 92}
 93
 94impl ExampleContext {
 95    pub fn new(
 96        meta: ExampleMetadata,
 97        log_prefix: String,
 98        agent_thread: Entity<agent::Thread>,
 99        model: Arc<dyn LanguageModel>,
100        app: AsyncApp,
101    ) -> Self {
102        let assertions = AssertionsReport::new(meta.max_assertions);
103
104        Self {
105            meta,
106            log_prefix,
107            agent_thread,
108            assertions,
109            model,
110            app,
111            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
112        }
113    }
114
115    pub fn push_user_message(&mut self, text: impl ToString) {
116        self.app
117            .update_entity(&self.agent_thread, |thread, cx| {
118                thread.insert_user_message(
119                    text.to_string(),
120                    ContextLoadResult::default(),
121                    None,
122                    Vec::new(),
123                    cx,
124                );
125            })
126            .unwrap();
127    }
128
129    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
130        let message = message.to_string();
131        self.log_assertion(
132            if expected {
133                Ok(())
134            } else {
135                Err(anyhow::Error::from(FailedAssertion(message.clone())))
136            },
137            message,
138        )
139    }
140
141    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
142        let message = message.to_string();
143        self.log_assertion(
144            match option {
145                Some(value) => Ok(value),
146                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
147            },
148            message,
149        )
150    }
151
152    #[allow(dead_code)]
153    pub fn assert_eq<T: PartialEq + Debug>(
154        &mut self,
155        left: T,
156        right: T,
157        message: impl ToString,
158    ) -> Result<()> {
159        let message = message.to_string();
160        self.log_assertion(
161            if left == right {
162                Ok(())
163            } else {
164                println!(
165                    "{}{}",
166                    self.log_prefix,
167                    pretty_assertions::Comparison::new(&left, &right)
168                );
169                Err(anyhow::Error::from(FailedAssertion(message.clone())))
170            },
171            message,
172        )
173    }
174
175    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
176        if let Some(max) = self.meta.max_assertions {
177            if self.assertions.run_count() > max {
178                return Err(anyhow!(
179                    "More assertions were run than the stated max_assertions of {}",
180                    max
181                ));
182            }
183        }
184
185        self.assertions.ran.push(RanAssertion {
186            id: message.clone(),
187            result: Ok(RanAssertionResult {
188                analysis: None,
189                passed: result.is_ok(),
190            }),
191        });
192
193        if result.is_ok() {
194            println!("{}{}", self.log_prefix, message);
195        } else {
196            println!("{}{}", self.log_prefix, message);
197        }
198
199        result
200    }
201
202    pub async fn run_to_end(&mut self) -> Result<Response> {
203        self.run_turns(u32::MAX).await
204    }
205
206    pub async fn run_turn(&mut self) -> Result<Response> {
207        self.run_turns(1).await
208    }
209
210    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
211        let (mut tx, mut rx) = mpsc::channel(1);
212
213        let tool_metrics = self.tool_metrics.clone();
214        let log_prefix = self.log_prefix.clone();
215        let _subscription = self.app.subscribe(
216            &self.agent_thread,
217            move |thread, event: &ThreadEvent, cx| match event {
218                ThreadEvent::ShowError(thread_error) => {
219                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
220                }
221                ThreadEvent::Stopped(reason) => match reason {
222                    Ok(StopReason::EndTurn) => {
223                        tx.close_channel();
224                    }
225                    Ok(StopReason::ToolUse) => {
226                        if thread.read(cx).remaining_turns() == 0 {
227                            tx.close_channel();
228                        }
229                    }
230                    Ok(StopReason::MaxTokens) => {
231                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
232                    }
233                    Err(err) => {
234                        tx.try_send(Err(anyhow!(err.clone()))).ok();
235                    }
236                },
237                ThreadEvent::NewRequest
238                | ThreadEvent::StreamedAssistantText(_, _)
239                | ThreadEvent::StreamedAssistantThinking(_, _)
240                | ThreadEvent::UsePendingTools { .. }
241                | ThreadEvent::CompletionCanceled => {}
242                ThreadEvent::ToolFinished {
243                    tool_use_id,
244                    pending_tool_use,
245                    ..
246                } => {
247                    thread.update(cx, |thread, _cx| {
248                        if let Some(tool_use) = pending_tool_use {
249                            let mut tool_metrics = tool_metrics.lock().unwrap();
250                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
251                                let message = if tool_result.is_error {
252                                    format!("✖︎ {}", tool_use.name)
253                                } else {
254                                    format!("✔︎ {}", tool_use.name)
255                                };
256                                println!("{log_prefix}{message}");
257                                tool_metrics
258                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
259                            } else {
260                                let message =
261                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
262                                println!("{log_prefix}{message}");
263                                tool_metrics.insert(tool_use.name.clone(), true);
264                            }
265                        }
266                    });
267                }
268                ThreadEvent::InvalidToolInput { .. } => {
269                    println!("{log_prefix} invalid tool input");
270                }
271                ThreadEvent::ToolConfirmationNeeded => {
272                    panic!(
273                        "{}Bug: Tool confirmation should not be required in eval",
274                        log_prefix
275                    );
276                }
277                ThreadEvent::StreamedCompletion
278                | ThreadEvent::MessageAdded(_)
279                | ThreadEvent::MessageEdited(_)
280                | ThreadEvent::MessageDeleted(_)
281                | ThreadEvent::SummaryChanged
282                | ThreadEvent::SummaryGenerated
283                | ThreadEvent::ReceivedTextChunk
284                | ThreadEvent::StreamedToolUse { .. }
285                | ThreadEvent::CheckpointChanged
286                | ThreadEvent::UsageUpdated(_)
287                | ThreadEvent::CancelEditing => {
288                    tx.try_send(Ok(())).ok();
289                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
290                        println!("{}Event: {:#?}", log_prefix, event);
291                    }
292                }
293            },
294        );
295
296        let model = self.model.clone();
297
298        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
299            thread.set_remaining_turns(iterations);
300            thread.send_to_model(model, None, cx);
301            thread.messages().len()
302        })?;
303
304        loop {
305            select_biased! {
306                result = rx.next() => {
307                    if let Some(result) = result {
308                        result?;
309                    } else {
310                        break;
311                    }
312                }
313                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
314                    return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
315                }
316            }
317        }
318
319        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
320            let mut messages = Vec::new();
321            for message in thread.messages().skip(message_count_before) {
322                messages.push(Message {
323                    _role: message.role,
324                    text: message.to_string(),
325                    tool_use: thread
326                        .tool_uses_for_message(message.id, cx)
327                        .into_iter()
328                        .map(|tool_use| ToolUse {
329                            name: tool_use.name.to_string(),
330                            value: tool_use.input,
331                        })
332                        .collect(),
333                });
334            }
335            messages
336        })?;
337
338        let response = Response::new(messages);
339
340        Ok(response)
341    }
342
343    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
344        self.agent_thread
345            .read_with(&self.app, |thread, cx| {
346                let action_log = thread.action_log().read(cx);
347                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
348                    |(buffer, diff)| {
349                        let snapshot = buffer.read(cx).snapshot();
350
351                        let file = snapshot.file().unwrap();
352                        let diff = diff.read(cx);
353                        let base_text = diff.base_text().text();
354
355                        let hunks = diff
356                            .hunks(&snapshot, cx)
357                            .map(|hunk| FileEditHunk {
358                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
359                                text: snapshot
360                                    .text_for_range(hunk.range.clone())
361                                    .collect::<String>(),
362                                status: hunk.status(),
363                            })
364                            .collect();
365
366                        (file.path().clone(), FileEdits { hunks })
367                    },
368                ))
369            })
370            .unwrap()
371    }
372
373    pub fn agent_thread(&self) -> Entity<Thread> {
374        self.agent_thread.clone()
375    }
376}
377
378impl AppContext for ExampleContext {
379    type Result<T> = anyhow::Result<T>;
380
381    fn new<T: 'static>(
382        &mut self,
383        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
384    ) -> Self::Result<Entity<T>> {
385        self.app.new(build_entity)
386    }
387
388    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
389        self.app.reserve_entity()
390    }
391
392    fn insert_entity<T: 'static>(
393        &mut self,
394        reservation: gpui::Reservation<T>,
395        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
396    ) -> Self::Result<Entity<T>> {
397        self.app.insert_entity(reservation, build_entity)
398    }
399
400    fn update_entity<T, R>(
401        &mut self,
402        handle: &Entity<T>,
403        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
404    ) -> Self::Result<R>
405    where
406        T: 'static,
407    {
408        self.app.update_entity(handle, update)
409    }
410
411    fn read_entity<T, R>(
412        &self,
413        handle: &Entity<T>,
414        read: impl FnOnce(&T, &App) -> R,
415    ) -> Self::Result<R>
416    where
417        T: 'static,
418    {
419        self.app.read_entity(handle, read)
420    }
421
422    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
423    where
424        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
425    {
426        self.app.update_window(window, f)
427    }
428
429    fn read_window<T, R>(
430        &self,
431        window: &gpui::WindowHandle<T>,
432        read: impl FnOnce(Entity<T>, &App) -> R,
433    ) -> Result<R>
434    where
435        T: 'static,
436    {
437        self.app.read_window(window, read)
438    }
439
440    fn background_spawn<R>(
441        &self,
442        future: impl std::future::Future<Output = R> + Send + 'static,
443    ) -> gpui::Task<R>
444    where
445        R: Send + 'static,
446    {
447        self.app.background_spawn(future)
448    }
449
450    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
451    where
452        G: gpui::Global,
453    {
454        self.app.read_global(callback)
455    }
456}
457
458#[derive(Debug)]
459pub struct Response {
460    messages: Vec<Message>,
461}
462
463impl Response {
464    pub fn new(messages: Vec<Message>) -> Self {
465        Self { messages }
466    }
467
468    pub fn expect_tool(
469        &self,
470        tool_name: &'static str,
471        cx: &mut ExampleContext,
472    ) -> Result<&ToolUse> {
473        let result = self.messages.iter().find_map(|msg| {
474            msg.tool_use
475                .iter()
476                .find(|tool_use| tool_use.name == tool_name)
477        });
478        cx.assert_some(result, format!("called `{}`", tool_name))
479    }
480
481    #[allow(dead_code)]
482    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
483        self.messages.iter().flat_map(|msg| &msg.tool_use)
484    }
485
486    pub fn texts(&self) -> impl Iterator<Item = String> {
487        self.messages.iter().map(|message| message.text.clone())
488    }
489}
490
491#[derive(Debug)]
492pub struct Message {
493    _role: Role,
494    text: String,
495    tool_use: Vec<ToolUse>,
496}
497
498#[derive(Debug)]
499pub struct ToolUse {
500    pub name: String,
501    value: serde_json::Value,
502}
503
504impl ToolUse {
505    pub fn parse_input<Input>(&self) -> Result<Input>
506    where
507        Input: for<'de> serde::Deserialize<'de>,
508    {
509        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
510    }
511}
512
513#[derive(Debug, Eq, PartialEq)]
514pub struct FileEdits {
515    pub hunks: Vec<FileEditHunk>,
516}
517
518#[derive(Debug, Eq, PartialEq)]
519pub struct FileEditHunk {
520    pub base_text: String,
521    pub text: String,
522    pub status: DiffHunkStatus,
523}
524
525impl FileEdits {
526    pub fn has_added_line(&self, line: &str) -> bool {
527        self.hunks.iter().any(|hunk| {
528            hunk.status == DiffHunkStatus::added_none()
529                && hunk.base_text.is_empty()
530                && hunk.text.contains(line)
531        })
532    }
533}