1use anyhow::Result;
2use async_trait::async_trait;
3use markdown::PathWithRange;
4
5use crate::example::{Example, ExampleContext, ExampleMetadata, JudgeAssertion, LanguageServer};
6
/// Eval example that asks the agent to show the methods of the `Tool` trait and
/// then checks that every fenced code block in the response starts with a
/// citation (e.g. `path/to/foo.txt#L123-456`) that resolves to a project file
/// whose text contains the quoted code.
pub struct CodeBlockCitations;
8
/// Markdown code fence delimiter (three backticks) searched for in responses.
const FENCE: &str = "```";
10
11#[async_trait(?Send)]
12impl Example for CodeBlockCitations {
13 fn meta(&self) -> ExampleMetadata {
14 ExampleMetadata {
15 name: "code_block_citations".to_string(),
16 url: "https://github.com/zed-industries/zed.git".to_string(),
17 revision: "f69aeb6311dde3c0b8979c293d019d66498d54f2".to_string(),
18 language_server: Some(LanguageServer {
19 file_extension: "rs".to_string(),
20 allow_preexisting_diagnostics: false,
21 }),
22 max_assertions: None,
23 }
24 }
25
26 async fn conversation(&self, cx: &mut ExampleContext) -> Result<()> {
27 const FILENAME: &str = "assistant_tool.rs";
28 cx.push_user_message(format!(
29 r#"
30 Show me the method bodies of all the methods of the `Tool` trait in {FILENAME}.
31
32 Please show each method in a separate code snippet.
33 "#
34 ));
35
36 // Verify that the messages all have the correct formatting.
37 let texts: Vec<String> = cx.run_to_end().await?.texts().collect();
38 let closing_fence = format!("\n{FENCE}");
39
40 for text in texts.iter() {
41 let mut text = text.as_str();
42
43 while let Some(index) = text.find(FENCE) {
44 // Advance text past the opening backticks.
45 text = &text[index + FENCE.len()..];
46
47 // Find the closing backticks.
48 let content_len = text.find(&closing_fence);
49
50 // Verify the citation format - e.g. ```path/to/foo.txt#L123-456
51 if let Some(citation_len) = text.find('\n') {
52 let citation = &text[..citation_len];
53
54 if let Ok(()) =
55 cx.assert(citation.contains("/"), format!("Slash in {citation:?}",))
56 {
57 let path_range = PathWithRange::new(citation);
58 let path = cx
59 .agent_thread()
60 .update(cx, |thread, cx| {
61 thread
62 .project()
63 .read(cx)
64 .find_project_path(path_range.path, cx)
65 })
66 .ok()
67 .flatten();
68
69 if let Ok(path) = cx.assert_some(path, format!("Valid path: {citation:?}"))
70 {
71 let buffer_text = {
72 let buffer = match cx.agent_thread().update(cx, |thread, cx| {
73 thread
74 .project()
75 .update(cx, |project, cx| project.open_buffer(path, cx))
76 }) {
77 Ok(buffer_task) => buffer_task.await.ok(),
78 Err(err) => {
79 cx.assert(
80 false,
81 format!("Expected Ok(buffer), not {err:?}"),
82 )
83 .ok();
84 break;
85 }
86 };
87
88 let Ok(buffer_text) = cx.assert_some(
89 buffer.and_then(|buffer| {
90 buffer.read_with(cx, |buffer, _| buffer.text()).ok()
91 }),
92 "Reading buffer text succeeded",
93 ) else {
94 continue;
95 };
96 buffer_text
97 };
98
99 if let Some(content_len) = content_len {
100 // + 1 because there's a newline character after the citation.
101 let content =
102 &text[(citation.len() + 1)..content_len - (citation.len() + 1)];
103
104 cx.assert(
105 buffer_text.contains(&content),
106 "Code block content was found in file",
107 )
108 .ok();
109
110 if let Some(range) = path_range.range {
111 let start_line_index = range.start.line.saturating_sub(1);
112 let line_count =
113 range.end.line.saturating_sub(start_line_index);
114 let mut snippet = buffer_text
115 .lines()
116 .skip(start_line_index as usize)
117 .take(line_count as usize)
118 .collect::<Vec<&str>>()
119 .join("\n");
120
121 if let Some(start_col) = range.start.col {
122 snippet = snippet[start_col as usize..].to_string();
123 }
124
125 if let Some(end_col) = range.end.col {
126 let last_line = snippet.lines().last().unwrap();
127 snippet = snippet
128 [..snippet.len() - last_line.len() + end_col as usize]
129 .to_string();
130 }
131
132 cx.assert_eq(
133 snippet.as_str(),
134 content,
135 "Code block snippet was at specified line/col",
136 )
137 .ok();
138 }
139 }
140 }
141 }
142 } else {
143 cx.assert(
144 false,
145 format!("Opening {FENCE} did not have a newline anywhere after it."),
146 )
147 .ok();
148 }
149
150 if let Some(content_len) = content_len {
151 // Advance past the closing backticks
152 text = &text[content_len + FENCE.len()..];
153 } else {
154 // There were no closing backticks associated with these opening backticks.
155 cx.assert(
156 false,
157 "Code block opening had matching closing backticks.".to_string(),
158 )
159 .ok();
160
161 // There are no more code blocks to parse, so we're done.
162 break;
163 }
164 }
165 }
166
167 Ok(())
168 }
169
170 fn thread_assertions(&self) -> Vec<JudgeAssertion> {
171 vec![
172 JudgeAssertion {
173 id: "trait method bodies are shown".to_string(),
174 description:
175 "All method bodies of the Tool trait are shown."
176 .to_string(),
177 },
178 JudgeAssertion {
179 id: "code blocks used".to_string(),
180 description:
181 "All code snippets are rendered inside markdown code blocks (as opposed to any other formatting besides code blocks)."
182 .to_string(),
183 },
184 JudgeAssertion {
185 id: "code blocks use backticks".to_string(),
186 description:
187 format!("All markdown code blocks use backtick fences ({FENCE}) rather than indentation.")
188 }
189 ]
190 }
191}