example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::{ContextLoadResult, Thread, ThreadEvent};
 14use agent_settings::AgentProfileId;
 15use anyhow::{Result, anyhow};
 16use async_trait::async_trait;
 17use buffer_diff::DiffHunkStatus;
 18use collections::HashMap;
 19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 20use gpui::{App, AppContext, AsyncApp, Entity};
 21use language_model::{LanguageModel, Role, StopReason};
 22
 23pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 24
 25#[async_trait(?Send)]
 26pub trait Example {
 27    fn meta(&self) -> ExampleMetadata;
 28    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 29    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 30        Vec::new()
 31    }
 32    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 33        Vec::new()
 34    }
 35}
 36
 37#[derive(Clone, Debug)]
 38pub struct JudgeAssertion {
 39    pub id: String,
 40    pub description: String,
 41}
 42
 43#[derive(Clone, Debug)]
 44pub struct ExampleMetadata {
 45    pub name: String,
 46    pub url: String,
 47    pub revision: String,
 48    pub language_server: Option<LanguageServer>,
 49    pub max_assertions: Option<usize>,
 50    pub profile_id: AgentProfileId,
 51    pub existing_thread_json: Option<String>,
 52}
 53
 54#[derive(Clone, Debug)]
 55pub struct LanguageServer {
 56    pub file_extension: String,
 57    pub allow_preexisting_diagnostics: bool,
 58}
 59
 60impl ExampleMetadata {
 61    pub fn repo_name(&self) -> String {
 62        self.url
 63            .split('/')
 64            .next_back()
 65            .unwrap_or(&"")
 66            .trim_end_matches(".git")
 67            .into()
 68    }
 69}
 70
 71pub struct FailedAssertion(pub String);
 72
 73impl fmt::Debug for FailedAssertion {
 74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 75        write!(f, "Assertion failure: {}", self.0)
 76    }
 77}
 78
 79impl fmt::Display for FailedAssertion {
 80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 81        write!(f, "{}", self.0)
 82    }
 83}
 84
 85impl Error for FailedAssertion {}
 86
 87pub struct ExampleContext {
 88    meta: ExampleMetadata,
 89    log_prefix: String,
 90    agent_thread: Entity<agent::Thread>,
 91    app: AsyncApp,
 92    model: Arc<dyn LanguageModel>,
 93    pub assertions: AssertionsReport,
 94    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 95}
 96
 97impl ExampleContext {
 98    pub fn new(
 99        meta: ExampleMetadata,
100        log_prefix: String,
101        agent_thread: Entity<agent::Thread>,
102        model: Arc<dyn LanguageModel>,
103        app: AsyncApp,
104    ) -> Self {
105        let assertions = AssertionsReport::new(meta.max_assertions);
106
107        Self {
108            meta,
109            log_prefix,
110            agent_thread,
111            assertions,
112            model,
113            app,
114            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
115        }
116    }
117
118    pub fn push_user_message(&mut self, text: impl ToString) {
119        self.app
120            .update_entity(&self.agent_thread, |thread, cx| {
121                thread.insert_user_message(
122                    text.to_string(),
123                    ContextLoadResult::default(),
124                    None,
125                    Vec::new(),
126                    cx,
127                );
128            })
129            .unwrap();
130    }
131
132    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
133        let message = message.to_string();
134        self.log_assertion(
135            if expected {
136                Ok(())
137            } else {
138                Err(anyhow::Error::from(FailedAssertion(message.clone())))
139            },
140            message,
141        )
142    }
143
144    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
145        let message = message.to_string();
146        self.log_assertion(
147            match option {
148                Some(value) => Ok(value),
149                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
150            },
151            message,
152        )
153    }
154
155    #[allow(dead_code)]
156    pub fn assert_eq<T: PartialEq + Debug>(
157        &mut self,
158        left: T,
159        right: T,
160        message: impl ToString,
161    ) -> Result<()> {
162        let message = message.to_string();
163        self.log_assertion(
164            if left == right {
165                Ok(())
166            } else {
167                println!(
168                    "{}{}",
169                    self.log_prefix,
170                    pretty_assertions::Comparison::new(&left, &right)
171                );
172                Err(anyhow::Error::from(FailedAssertion(message.clone())))
173            },
174            message,
175        )
176    }
177
178    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
179        if let Some(max) = self.meta.max_assertions {
180            anyhow::ensure!(
181                self.assertions.run_count() <= max,
182                "More assertions were run than the stated max_assertions of {max}"
183            );
184        }
185
186        self.assertions.ran.push(RanAssertion {
187            id: message.clone(),
188            result: Ok(RanAssertionResult {
189                analysis: None,
190                passed: result.is_ok(),
191            }),
192        });
193
194        if result.is_ok() {
195            println!("{}{}", self.log_prefix, message);
196        } else {
197            println!("{}{}", self.log_prefix, message);
198        }
199
200        result
201    }
202
203    pub async fn run_to_end(&mut self) -> Result<Response> {
204        self.run_turns(u32::MAX).await
205    }
206
207    pub async fn run_turn(&mut self) -> Result<Response> {
208        self.run_turns(1).await
209    }
210
211    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
212        let (mut tx, mut rx) = mpsc::channel(1);
213
214        let tool_metrics = self.tool_metrics.clone();
215        let log_prefix = self.log_prefix.clone();
216        let _subscription = self.app.subscribe(
217            &self.agent_thread,
218            move |thread, event: &ThreadEvent, cx| match event {
219                ThreadEvent::ShowError(thread_error) => {
220                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
221                }
222                ThreadEvent::Stopped(reason) => match reason {
223                    Ok(StopReason::EndTurn) => {
224                        tx.close_channel();
225                    }
226                    Ok(StopReason::ToolUse) => {
227                        if thread.read(cx).remaining_turns() == 0 {
228                            tx.close_channel();
229                        }
230                    }
231                    Ok(StopReason::MaxTokens) => {
232                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
233                    }
234                    Ok(StopReason::Refusal) => {
235                        tx.try_send(Err(anyhow!("Model refused to generate content")))
236                            .ok();
237                    }
238                    Err(err) => {
239                        tx.try_send(Err(anyhow!(err.clone()))).ok();
240                    }
241                },
242                ThreadEvent::NewRequest
243                | ThreadEvent::StreamedAssistantText(_, _)
244                | ThreadEvent::StreamedAssistantThinking(_, _)
245                | ThreadEvent::UsePendingTools { .. }
246                | ThreadEvent::CompletionCanceled => {}
247                ThreadEvent::ToolFinished {
248                    tool_use_id,
249                    pending_tool_use,
250                    ..
251                } => {
252                    thread.update(cx, |thread, _cx| {
253                        if let Some(tool_use) = pending_tool_use {
254                            let mut tool_metrics = tool_metrics.lock().unwrap();
255                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
256                                let message = if tool_result.is_error {
257                                    format!("✖︎ {}", tool_use.name)
258                                } else {
259                                    format!("✔︎ {}", tool_use.name)
260                                };
261                                println!("{log_prefix}{message}");
262                                tool_metrics
263                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
264                            } else {
265                                let message =
266                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
267                                println!("{log_prefix}{message}");
268                                tool_metrics.insert(tool_use.name.clone(), true);
269                            }
270                        }
271                    });
272                }
273                ThreadEvent::InvalidToolInput { .. } => {
274                    println!("{log_prefix} invalid tool input");
275                }
276                ThreadEvent::MissingToolUse {
277                    tool_use_id: _,
278                    ui_text,
279                } => {
280                    println!("{log_prefix} {ui_text}");
281                }
282                ThreadEvent::ToolConfirmationNeeded => {
283                    panic!(
284                        "{}Bug: Tool confirmation should not be required in eval",
285                        log_prefix
286                    );
287                }
288                ThreadEvent::StreamedCompletion
289                | ThreadEvent::MessageAdded(_)
290                | ThreadEvent::MessageEdited(_)
291                | ThreadEvent::MessageDeleted(_)
292                | ThreadEvent::SummaryChanged
293                | ThreadEvent::SummaryGenerated
294                | ThreadEvent::ReceivedTextChunk
295                | ThreadEvent::StreamedToolUse { .. }
296                | ThreadEvent::CheckpointChanged
297                | ThreadEvent::CancelEditing => {
298                    tx.try_send(Ok(())).ok();
299                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
300                        println!("{}Event: {:#?}", log_prefix, event);
301                    }
302                }
303            },
304        );
305
306        let model = self.model.clone();
307
308        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
309            thread.set_remaining_turns(iterations);
310            thread.send_to_model(model, None, cx);
311            thread.messages().len()
312        })?;
313
314        loop {
315            select_biased! {
316                result = rx.next() => {
317                    if let Some(result) = result {
318                        result?;
319                    } else {
320                        break;
321                    }
322                }
323                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
324                    anyhow::bail!("Agentic loop stalled - waited {THREAD_EVENT_TIMEOUT:?} without any events");
325                }
326            }
327        }
328
329        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
330            let mut messages = Vec::new();
331            for message in thread.messages().skip(message_count_before) {
332                messages.push(Message {
333                    _role: message.role,
334                    text: message.to_string(),
335                    tool_use: thread
336                        .tool_uses_for_message(message.id, cx)
337                        .into_iter()
338                        .map(|tool_use| ToolUse {
339                            name: tool_use.name.to_string(),
340                            value: tool_use.input,
341                        })
342                        .collect(),
343                });
344            }
345            messages
346        })?;
347
348        let response = Response::new(messages);
349
350        Ok(response)
351    }
352
353    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
354        self.agent_thread
355            .read_with(&self.app, |thread, cx| {
356                let action_log = thread.action_log().read(cx);
357                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
358                    |(buffer, diff)| {
359                        let snapshot = buffer.read(cx).snapshot();
360
361                        let file = snapshot.file().unwrap();
362                        let diff = diff.read(cx);
363                        let base_text = diff.base_text().text();
364
365                        let hunks = diff
366                            .hunks(&snapshot, cx)
367                            .map(|hunk| FileEditHunk {
368                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
369                                text: snapshot
370                                    .text_for_range(hunk.range.clone())
371                                    .collect::<String>(),
372                                status: hunk.status(),
373                            })
374                            .collect();
375
376                        (file.path().clone(), FileEdits { hunks })
377                    },
378                ))
379            })
380            .unwrap()
381    }
382
383    pub fn agent_thread(&self) -> Entity<Thread> {
384        self.agent_thread.clone()
385    }
386}
387
388impl AppContext for ExampleContext {
389    type Result<T> = anyhow::Result<T>;
390
391    fn new<T: 'static>(
392        &mut self,
393        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
394    ) -> Self::Result<Entity<T>> {
395        self.app.new(build_entity)
396    }
397
398    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
399        self.app.reserve_entity()
400    }
401
402    fn insert_entity<T: 'static>(
403        &mut self,
404        reservation: gpui::Reservation<T>,
405        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
406    ) -> Self::Result<Entity<T>> {
407        self.app.insert_entity(reservation, build_entity)
408    }
409
410    fn update_entity<T, R>(
411        &mut self,
412        handle: &Entity<T>,
413        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
414    ) -> Self::Result<R>
415    where
416        T: 'static,
417    {
418        self.app.update_entity(handle, update)
419    }
420
421    fn read_entity<T, R>(
422        &self,
423        handle: &Entity<T>,
424        read: impl FnOnce(&T, &App) -> R,
425    ) -> Self::Result<R>
426    where
427        T: 'static,
428    {
429        self.app.read_entity(handle, read)
430    }
431
432    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
433    where
434        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
435    {
436        self.app.update_window(window, f)
437    }
438
439    fn read_window<T, R>(
440        &self,
441        window: &gpui::WindowHandle<T>,
442        read: impl FnOnce(Entity<T>, &App) -> R,
443    ) -> Result<R>
444    where
445        T: 'static,
446    {
447        self.app.read_window(window, read)
448    }
449
450    fn background_spawn<R>(
451        &self,
452        future: impl std::future::Future<Output = R> + Send + 'static,
453    ) -> gpui::Task<R>
454    where
455        R: Send + 'static,
456    {
457        self.app.background_spawn(future)
458    }
459
460    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
461    where
462        G: gpui::Global,
463    {
464        self.app.read_global(callback)
465    }
466}
467
468#[derive(Debug)]
469pub struct Response {
470    messages: Vec<Message>,
471}
472
473impl Response {
474    pub fn new(messages: Vec<Message>) -> Self {
475        Self { messages }
476    }
477
478    pub fn expect_tool(
479        &self,
480        tool_name: &'static str,
481        cx: &mut ExampleContext,
482    ) -> Result<&ToolUse> {
483        let result = self.find_tool_call(tool_name);
484        cx.assert_some(result, format!("called `{}`", tool_name))
485    }
486
487    pub fn find_tool_call(&self, tool_name: &str) -> Option<&ToolUse> {
488        self.messages.iter().rev().find_map(|msg| {
489            msg.tool_use
490                .iter()
491                .find(|tool_use| tool_use.name == tool_name)
492        })
493    }
494
495    #[allow(dead_code)]
496    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
497        self.messages.iter().flat_map(|msg| &msg.tool_use)
498    }
499
500    pub fn texts(&self) -> impl Iterator<Item = String> {
501        self.messages.iter().map(|message| message.text.clone())
502    }
503}
504
505#[derive(Debug)]
506pub struct Message {
507    _role: Role,
508    text: String,
509    tool_use: Vec<ToolUse>,
510}
511
512#[derive(Debug)]
513pub struct ToolUse {
514    pub name: String,
515    value: serde_json::Value,
516}
517
518impl ToolUse {
519    pub fn parse_input<Input>(&self) -> Result<Input>
520    where
521        Input: for<'de> serde::Deserialize<'de>,
522    {
523        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
524    }
525}
526
527#[derive(Debug, Eq, PartialEq)]
528pub struct FileEdits {
529    pub hunks: Vec<FileEditHunk>,
530}
531
532#[derive(Debug, Eq, PartialEq)]
533pub struct FileEditHunk {
534    pub base_text: String,
535    pub text: String,
536    pub status: DiffHunkStatus,
537}
538
539impl FileEdits {
540    pub fn has_added_line(&self, line: &str) -> bool {
541        self.hunks.iter().any(|hunk| {
542            hunk.status == DiffHunkStatus::added_none()
543                && hunk.base_text.is_empty()
544                && hunk.text.contains(line)
545        })
546    }
547}