example.rs

  1use std::{
  2    error::Error,
  3    fmt::{self, Debug},
  4    path::Path,
  5    sync::{Arc, Mutex},
  6    time::Duration,
  7};
  8
  9use crate::{
 10    ToolMetrics,
 11    assertions::{AssertionsReport, RanAssertion, RanAssertionResult},
 12};
 13use agent::{ContextLoadResult, Thread, ThreadEvent};
 14use anyhow::{Result, anyhow};
 15use assistant_settings::AgentProfileId;
 16use async_trait::async_trait;
 17use buffer_diff::DiffHunkStatus;
 18use collections::HashMap;
 19use futures::{FutureExt as _, StreamExt, channel::mpsc, select_biased};
 20use gpui::{App, AppContext, AsyncApp, Entity};
 21use language_model::{LanguageModel, Role, StopReason};
 22
 23pub const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(60 * 2);
 24
 25#[async_trait(?Send)]
 26pub trait Example {
 27    fn meta(&self) -> ExampleMetadata;
 28    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()>;
 29    fn diff_assertions(&self) -> Vec<JudgeAssertion> {
 30        Vec::new()
 31    }
 32    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
 33        Vec::new()
 34    }
 35}
 36
 37#[derive(Clone, Debug)]
 38pub struct JudgeAssertion {
 39    pub id: String,
 40    pub description: String,
 41}
 42
 43#[derive(Clone, Debug)]
 44pub struct ExampleMetadata {
 45    pub name: String,
 46    pub url: String,
 47    pub revision: String,
 48    pub language_server: Option<LanguageServer>,
 49    pub max_assertions: Option<usize>,
 50    pub profile_id: AgentProfileId,
 51}
 52
 53#[derive(Clone, Debug)]
 54pub struct LanguageServer {
 55    pub file_extension: String,
 56    pub allow_preexisting_diagnostics: bool,
 57}
 58
 59impl ExampleMetadata {
 60    pub fn repo_name(&self) -> String {
 61        self.url
 62            .split('/')
 63            .next_back()
 64            .unwrap_or(&"")
 65            .trim_end_matches(".git")
 66            .into()
 67    }
 68}
 69
 70pub struct FailedAssertion(pub String);
 71
 72impl fmt::Debug for FailedAssertion {
 73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 74        write!(f, "Assertion failure: {}", self.0)
 75    }
 76}
 77
 78impl fmt::Display for FailedAssertion {
 79    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 80        write!(f, "{}", self.0)
 81    }
 82}
 83
 84impl Error for FailedAssertion {}
 85
 86pub struct ExampleContext {
 87    meta: ExampleMetadata,
 88    log_prefix: String,
 89    agent_thread: Entity<agent::Thread>,
 90    app: AsyncApp,
 91    model: Arc<dyn LanguageModel>,
 92    pub assertions: AssertionsReport,
 93    pub tool_metrics: Arc<Mutex<ToolMetrics>>,
 94}
 95
 96impl ExampleContext {
 97    pub fn new(
 98        meta: ExampleMetadata,
 99        log_prefix: String,
100        agent_thread: Entity<agent::Thread>,
101        model: Arc<dyn LanguageModel>,
102        app: AsyncApp,
103    ) -> Self {
104        let assertions = AssertionsReport::new(meta.max_assertions);
105
106        Self {
107            meta,
108            log_prefix,
109            agent_thread,
110            assertions,
111            model,
112            app,
113            tool_metrics: Arc::new(Mutex::new(ToolMetrics::default())),
114        }
115    }
116
117    pub fn push_user_message(&mut self, text: impl ToString) {
118        self.app
119            .update_entity(&self.agent_thread, |thread, cx| {
120                thread.insert_user_message(
121                    text.to_string(),
122                    ContextLoadResult::default(),
123                    None,
124                    Vec::new(),
125                    cx,
126                );
127            })
128            .unwrap();
129    }
130
131    pub fn assert(&mut self, expected: bool, message: impl ToString) -> Result<()> {
132        let message = message.to_string();
133        self.log_assertion(
134            if expected {
135                Ok(())
136            } else {
137                Err(anyhow::Error::from(FailedAssertion(message.clone())))
138            },
139            message,
140        )
141    }
142
143    pub fn assert_some<T>(&mut self, option: Option<T>, message: impl ToString) -> Result<T> {
144        let message = message.to_string();
145        self.log_assertion(
146            match option {
147                Some(value) => Ok(value),
148                None => Err(anyhow::Error::from(FailedAssertion(message.clone()))),
149            },
150            message,
151        )
152    }
153
154    #[allow(dead_code)]
155    pub fn assert_eq<T: PartialEq + Debug>(
156        &mut self,
157        left: T,
158        right: T,
159        message: impl ToString,
160    ) -> Result<()> {
161        let message = message.to_string();
162        self.log_assertion(
163            if left == right {
164                Ok(())
165            } else {
166                println!(
167                    "{}{}",
168                    self.log_prefix,
169                    pretty_assertions::Comparison::new(&left, &right)
170                );
171                Err(anyhow::Error::from(FailedAssertion(message.clone())))
172            },
173            message,
174        )
175    }
176
177    fn log_assertion<T>(&mut self, result: Result<T>, message: String) -> Result<T> {
178        if let Some(max) = self.meta.max_assertions {
179            if self.assertions.run_count() > max {
180                return Err(anyhow!(
181                    "More assertions were run than the stated max_assertions of {}",
182                    max
183                ));
184            }
185        }
186
187        self.assertions.ran.push(RanAssertion {
188            id: message.clone(),
189            result: Ok(RanAssertionResult {
190                analysis: None,
191                passed: result.is_ok(),
192            }),
193        });
194
195        if result.is_ok() {
196            println!("{}{}", self.log_prefix, message);
197        } else {
198            println!("{}{}", self.log_prefix, message);
199        }
200
201        result
202    }
203
204    pub async fn run_to_end(&mut self) -> Result<Response> {
205        self.run_turns(u32::MAX).await
206    }
207
208    pub async fn run_turn(&mut self) -> Result<Response> {
209        self.run_turns(1).await
210    }
211
212    pub async fn run_turns(&mut self, iterations: u32) -> Result<Response> {
213        let (mut tx, mut rx) = mpsc::channel(1);
214
215        let tool_metrics = self.tool_metrics.clone();
216        let log_prefix = self.log_prefix.clone();
217        let _subscription = self.app.subscribe(
218            &self.agent_thread,
219            move |thread, event: &ThreadEvent, cx| match event {
220                ThreadEvent::ShowError(thread_error) => {
221                    tx.try_send(Err(anyhow!(thread_error.clone()))).ok();
222                }
223                ThreadEvent::Stopped(reason) => match reason {
224                    Ok(StopReason::EndTurn) => {
225                        tx.close_channel();
226                    }
227                    Ok(StopReason::ToolUse) => {
228                        if thread.read(cx).remaining_turns() == 0 {
229                            tx.close_channel();
230                        }
231                    }
232                    Ok(StopReason::MaxTokens) => {
233                        tx.try_send(Err(anyhow!("Exceeded maximum tokens"))).ok();
234                    }
235                    Err(err) => {
236                        tx.try_send(Err(anyhow!(err.clone()))).ok();
237                    }
238                },
239                ThreadEvent::NewRequest
240                | ThreadEvent::StreamedAssistantText(_, _)
241                | ThreadEvent::StreamedAssistantThinking(_, _)
242                | ThreadEvent::UsePendingTools { .. }
243                | ThreadEvent::CompletionCanceled => {}
244                ThreadEvent::ToolFinished {
245                    tool_use_id,
246                    pending_tool_use,
247                    ..
248                } => {
249                    thread.update(cx, |thread, _cx| {
250                        if let Some(tool_use) = pending_tool_use {
251                            let mut tool_metrics = tool_metrics.lock().unwrap();
252                            if let Some(tool_result) = thread.tool_result(&tool_use_id) {
253                                let message = if tool_result.is_error {
254                                    format!("✖︎ {}", tool_use.name)
255                                } else {
256                                    format!("✔︎ {}", tool_use.name)
257                                };
258                                println!("{log_prefix}{message}");
259                                tool_metrics
260                                    .insert(tool_result.tool_name.clone(), !tool_result.is_error);
261                            } else {
262                                let message =
263                                    format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
264                                println!("{log_prefix}{message}");
265                                tool_metrics.insert(tool_use.name.clone(), true);
266                            }
267                        }
268                    });
269                }
270                ThreadEvent::InvalidToolInput { .. } => {
271                    println!("{log_prefix} invalid tool input");
272                }
273                ThreadEvent::MissingToolUse {
274                    tool_use_id: _,
275                    ui_text,
276                } => {
277                    println!("{log_prefix} {ui_text}");
278                }
279                ThreadEvent::ToolConfirmationNeeded => {
280                    panic!(
281                        "{}Bug: Tool confirmation should not be required in eval",
282                        log_prefix
283                    );
284                }
285                ThreadEvent::StreamedCompletion
286                | ThreadEvent::MessageAdded(_)
287                | ThreadEvent::MessageEdited(_)
288                | ThreadEvent::MessageDeleted(_)
289                | ThreadEvent::SummaryChanged
290                | ThreadEvent::SummaryGenerated
291                | ThreadEvent::ReceivedTextChunk
292                | ThreadEvent::StreamedToolUse { .. }
293                | ThreadEvent::CheckpointChanged
294                | ThreadEvent::CancelEditing => {
295                    tx.try_send(Ok(())).ok();
296                    if std::env::var("ZED_EVAL_DEBUG").is_ok() {
297                        println!("{}Event: {:#?}", log_prefix, event);
298                    }
299                }
300            },
301        );
302
303        let model = self.model.clone();
304
305        let message_count_before = self.app.update_entity(&self.agent_thread, |thread, cx| {
306            thread.set_remaining_turns(iterations);
307            thread.send_to_model(model, None, cx);
308            thread.messages().len()
309        })?;
310
311        loop {
312            select_biased! {
313                result = rx.next() => {
314                    if let Some(result) = result {
315                        result?;
316                    } else {
317                        break;
318                    }
319                }
320                _ = self.app.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
321                    return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
322                }
323            }
324        }
325
326        let messages = self.app.read_entity(&self.agent_thread, |thread, cx| {
327            let mut messages = Vec::new();
328            for message in thread.messages().skip(message_count_before) {
329                messages.push(Message {
330                    _role: message.role,
331                    text: message.to_string(),
332                    tool_use: thread
333                        .tool_uses_for_message(message.id, cx)
334                        .into_iter()
335                        .map(|tool_use| ToolUse {
336                            name: tool_use.name.to_string(),
337                            value: tool_use.input,
338                        })
339                        .collect(),
340                });
341            }
342            messages
343        })?;
344
345        let response = Response::new(messages);
346
347        Ok(response)
348    }
349
350    pub fn edits(&self) -> HashMap<Arc<Path>, FileEdits> {
351        self.agent_thread
352            .read_with(&self.app, |thread, cx| {
353                let action_log = thread.action_log().read(cx);
354                HashMap::from_iter(action_log.changed_buffers(cx).into_iter().map(
355                    |(buffer, diff)| {
356                        let snapshot = buffer.read(cx).snapshot();
357
358                        let file = snapshot.file().unwrap();
359                        let diff = diff.read(cx);
360                        let base_text = diff.base_text().text();
361
362                        let hunks = diff
363                            .hunks(&snapshot, cx)
364                            .map(|hunk| FileEditHunk {
365                                base_text: base_text[hunk.diff_base_byte_range.clone()].to_string(),
366                                text: snapshot
367                                    .text_for_range(hunk.range.clone())
368                                    .collect::<String>(),
369                                status: hunk.status(),
370                            })
371                            .collect();
372
373                        (file.path().clone(), FileEdits { hunks })
374                    },
375                ))
376            })
377            .unwrap()
378    }
379
380    pub fn agent_thread(&self) -> Entity<Thread> {
381        self.agent_thread.clone()
382    }
383}
384
385impl AppContext for ExampleContext {
386    type Result<T> = anyhow::Result<T>;
387
388    fn new<T: 'static>(
389        &mut self,
390        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
391    ) -> Self::Result<Entity<T>> {
392        self.app.new(build_entity)
393    }
394
395    fn reserve_entity<T: 'static>(&mut self) -> Self::Result<gpui::Reservation<T>> {
396        self.app.reserve_entity()
397    }
398
399    fn insert_entity<T: 'static>(
400        &mut self,
401        reservation: gpui::Reservation<T>,
402        build_entity: impl FnOnce(&mut gpui::Context<T>) -> T,
403    ) -> Self::Result<Entity<T>> {
404        self.app.insert_entity(reservation, build_entity)
405    }
406
407    fn update_entity<T, R>(
408        &mut self,
409        handle: &Entity<T>,
410        update: impl FnOnce(&mut T, &mut gpui::Context<T>) -> R,
411    ) -> Self::Result<R>
412    where
413        T: 'static,
414    {
415        self.app.update_entity(handle, update)
416    }
417
418    fn read_entity<T, R>(
419        &self,
420        handle: &Entity<T>,
421        read: impl FnOnce(&T, &App) -> R,
422    ) -> Self::Result<R>
423    where
424        T: 'static,
425    {
426        self.app.read_entity(handle, read)
427    }
428
429    fn update_window<T, F>(&mut self, window: gpui::AnyWindowHandle, f: F) -> Result<T>
430    where
431        F: FnOnce(gpui::AnyView, &mut gpui::Window, &mut App) -> T,
432    {
433        self.app.update_window(window, f)
434    }
435
436    fn read_window<T, R>(
437        &self,
438        window: &gpui::WindowHandle<T>,
439        read: impl FnOnce(Entity<T>, &App) -> R,
440    ) -> Result<R>
441    where
442        T: 'static,
443    {
444        self.app.read_window(window, read)
445    }
446
447    fn background_spawn<R>(
448        &self,
449        future: impl std::future::Future<Output = R> + Send + 'static,
450    ) -> gpui::Task<R>
451    where
452        R: Send + 'static,
453    {
454        self.app.background_spawn(future)
455    }
456
457    fn read_global<G, R>(&self, callback: impl FnOnce(&G, &App) -> R) -> Self::Result<R>
458    where
459        G: gpui::Global,
460    {
461        self.app.read_global(callback)
462    }
463}
464
465#[derive(Debug)]
466pub struct Response {
467    messages: Vec<Message>,
468}
469
470impl Response {
471    pub fn new(messages: Vec<Message>) -> Self {
472        Self { messages }
473    }
474
475    pub fn expect_tool(
476        &self,
477        tool_name: &'static str,
478        cx: &mut ExampleContext,
479    ) -> Result<&ToolUse> {
480        let result = self.messages.iter().find_map(|msg| {
481            msg.tool_use
482                .iter()
483                .find(|tool_use| tool_use.name == tool_name)
484        });
485        cx.assert_some(result, format!("called `{}`", tool_name))
486    }
487
488    #[allow(dead_code)]
489    pub fn tool_uses(&self) -> impl Iterator<Item = &ToolUse> {
490        self.messages.iter().flat_map(|msg| &msg.tool_use)
491    }
492
493    pub fn texts(&self) -> impl Iterator<Item = String> {
494        self.messages.iter().map(|message| message.text.clone())
495    }
496}
497
498#[derive(Debug)]
499pub struct Message {
500    _role: Role,
501    text: String,
502    tool_use: Vec<ToolUse>,
503}
504
505#[derive(Debug)]
506pub struct ToolUse {
507    pub name: String,
508    value: serde_json::Value,
509}
510
511impl ToolUse {
512    pub fn parse_input<Input>(&self) -> Result<Input>
513    where
514        Input: for<'de> serde::Deserialize<'de>,
515    {
516        serde_json::from_value::<Input>(self.value.clone()).map_err(|err| anyhow!(err))
517    }
518}
519
520#[derive(Debug, Eq, PartialEq)]
521pub struct FileEdits {
522    pub hunks: Vec<FileEditHunk>,
523}
524
525#[derive(Debug, Eq, PartialEq)]
526pub struct FileEditHunk {
527    pub base_text: String,
528    pub text: String,
529    pub status: DiffHunkStatus,
530}
531
532impl FileEdits {
533    pub fn has_added_line(&self, line: &str) -> bool {
534        self.hunks.iter().any(|hunk| {
535            hunk.status == DiffHunkStatus::added_none()
536                && hunk.base_text.is_empty()
537                && hunk.text.contains(line)
538        })
539    }
540}