// code_block_citations.rs

use agent_settings::AgentProfileId;
use anyhow::Result;
use async_trait::async_trait;
use markdown::PathWithRange;

use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};

  8pub struct CodeBlockCitations;
  9
 10const FENCE: &str = "```";
 11
 12#[async_trait(?Send)]
 13impl Example for CodeBlockCitations {
 14    fn meta(&self) -> ExampleMetadata {
 15        ExampleMetadata {
 16            name: "code_block_citations".to_string(),
 17            url: "https://github.com/zed-industries/zed.git".to_string(),
 18            revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
 19            language_server: Some(LanguageServer {
 20                file_extension: "rs".to_string(),
 21                allow_preexisting_diagnostics: false,
 22            }),
 23            max_assertions: None,
 24            profile_id: AgentProfileId::default(),
 25            existing_thread_json: None,
 26            max_turns: None,
 27        }
 28    }
 29
 30    async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
 31        const FILENAME: &str = "assistant_tool.rs";
 32
 33        // Verify that the messages all have the correct formatting.
 34        let texts: Vec<String> = cx
 35            .prompt(format!(
 36                r#"
 37                Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
 38
 39                Please show each method in a separate code snippet.
 40                "#
 41            ))
 42            .await?
 43            .texts()
 44            .collect();
 45        let closing_fence = format!("\n{FENCE}");
 46
 47        for text in texts.iter() {
 48            let mut text = text.as_str();
 49
 50            while let Some(index) = text.find(FENCE) {
 51                // Advance text past the opening backticks.
 52                text = &text[index + FENCE.len()..];
 53
 54                // Find the closing backticks.
 55                let content_len = text.find(&closing_fence);
 56
 57                // Verify the citation format - e.g. ```path/to/foo.txt#L123-456
 58                if let Some(citation_len) = text.find('\n') {
 59                    let citation = &text[..citation_len];
 60
 61                    if let Ok(()) =
 62                        cx.assert(citation.contains("/"), format!("Slash in {citation:?}",))
 63                    {
 64                        let path_range = PathWithRange::new(citation);
 65                        let path = cx.agent_thread().update(cx, |thread, cx| {
 66                            thread
 67                                .project()
 68                                .read(cx)
 69                                .find_project_path(path_range.path.as_ref(), cx)
 70                        });
 71
 72                        if let Ok(path) = cx.assert_some(path, format!("Valid path: {citation:?}"))
 73                        {
 74                            let buffer_text = {
 75                                let buffer = cx
 76                                    .agent_thread()
 77                                    .update(cx, |thread, cx| {
 78                                        thread
 79                                            .project()
 80                                            .update(cx, |project, cx| project.open_buffer(path, cx))
 81                                    })
 82                                    .await
 83                                    .ok();
 84
 85                                let Ok(buffer_text) = cx.assert_some(
 86                                    buffer.map(|buffer| {
 87                                        buffer.read_with(cx, |buffer, _| buffer.text())
 88                                    }),
 89                                    "Reading buffer text succeeded",
 90                                ) else {
 91                                    continue;
 92                                };
 93                                buffer_text
 94                            };
 95
 96                            if let Some(content_len) = content_len {
 97                                // + 1 because there's a newline character after the citation.
 98                                let start_index = citation.len() + 1;
 99                                let end_index = content_len.saturating_sub(start_index);
100
101                                if cx
102                                    .assert(
103                                        start_index <= end_index,
104                                        "Code block had a valid citation",
105                                    )
106                                    .is_ok()
107                                {
108                                    let content = &text[start_index..end_index];
109
110                                    // deindent (trim the start of each line) because sometimes the model
111                                    // chooses to deindent its code snippets for the sake of readability,
112                                    // which in markdown is not only reasonable but usually desirable.
113                                    cx.assert(
114                                        deindent(&buffer_text)
115                                            .trim()
116                                            .contains(deindent(&content).trim()),
117                                        "Code block content was found in file",
118                                    )
119                                    .ok();
120
121                                    if let Some(range) = path_range.range {
122                                        let start_line_index = range.start.line.saturating_sub(1);
123                                        let line_count =
124                                            range.end.line.saturating_sub(start_line_index);
125                                        let mut snippet = buffer_text
126                                            .lines()
127                                            .skip(start_line_index as usize)
128                                            .take(line_count as usize)
129                                            .collect::<Vec<&str>>()
130                                            .join("\n");
131
132                                        if let Some(start_col) = range.start.col {
133                                            snippet = snippet[start_col as usize..].to_string();
134                                        }
135
136                                        if let Some(end_col) = range.end.col {
137                                            let last_line = snippet.lines().last().unwrap();
138                                            snippet = snippet[..snippet.len() - last_line.len()
139                                                + end_col as usize]
140                                                .to_string();
141                                        }
142
143                                        // deindent (trim the start of each line) because sometimes the model
144                                        // chooses to deindent its code snippets for the sake of readability,
145                                        // which in markdown is not only reasonable but usually desirable.
146                                        cx.assert_eq(
147                                            deindent(snippet.as_str()).trim(),
148                                            deindent(content).trim(),
149                                            format!(
150                                                "Code block was at {:?}-{:?}",
151                                                range.start, range.end
152                                            ),
153                                        )
154                                        .ok();
155                                    }
156                                }
157                            }
158                        }
159                    }
160                } else {
161                    cx.assert(
162                        false,
163                        format!("Opening {FENCE} did not have a newline anywhere after it."),
164                    )
165                    .ok();
166                }
167
168                if let Some(content_len) = content_len {
169                    // Advance past the closing backticks
170                    text = &text[content_len + FENCE.len()..];
171                } else {
172                    // There were no closing backticks associated with these opening backticks.
173                    cx.assert(
174                        false,
175                        "Code block opening had matching closing backticks.".to_string(),
176                    )
177                    .ok();
178
179                    // There are no more code blocks to parse, so we're done.
180                    break;
181                }
182            }
183        }
184
185        Ok(())
186    }
187
188    fn thread_assertions(&self) -> Vec<JudgeAssertion> {
189        vec![
190            JudgeAssertion {
191                id: "trait method bodies are shown".to_string(),
192                description:
193                    "All method bodies of the Tool trait are shown."
194                        .to_string(),
195            },
196            JudgeAssertion {
197                id: "code blocks used".to_string(),
198                description:
199                   "All code snippets are rendered inside markdown code blocks (as opposed to any other formatting besides code blocks)."
200                        .to_string(),
201            },
202            JudgeAssertion {
203              id: "code blocks use backticks".to_string(),
204              description:
205                  format!("All markdown code blocks use backtick fences ({FENCE}) rather than indentation.")
206            }
207        ]
208    }
209}
210
/// Strips leading whitespace from every line of the given text, keeping the line order.
///
/// Used to compare code blocks against file contents without caring about indentation,
/// because the model may legitimately deindent snippets for readability. Lines are split
/// with [`str::lines`], so `\r\n` endings become `\n` and a trailing newline is dropped.
fn deindent(as_str: impl AsRef<str>) -> String {
    as_str
        .as_ref()
        .lines()
        .map(str::trim_start)
        .collect::<Vec<&str>>()
        .join("\n")
}