1use agent::{ThreadEvent, ThreadStore};
2use anyhow::{Context as _, Result, anyhow};
3use assistant_tool::ToolWorkingSet;
4use client::proto::LspWorkProgress;
5use collections::HashMap;
6use dap::DapRegistry;
7use futures::channel::mpsc;
8use futures::{FutureExt, StreamExt as _, select_biased};
9use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
10use handlebars::Handlebars;
11use language::{Buffer, DiagnosticSeverity, OffsetRangeExt};
12use language_model::{
13 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
14 MessageContent, Role, StopReason, TokenUsage,
15};
16use project::{Project, ProjectPath};
17use serde::{Deserialize, Serialize};
18use std::cell::RefCell;
19use std::fmt::Write as _;
20use std::fs::File;
21use std::io::Write as _;
22use std::rc::Rc;
23use std::sync::{Arc, Mutex};
24use std::time::Duration;
25use std::{
26 fs,
27 path::{Path, PathBuf},
28};
29use unindent::Unindent as _;
30use util::ResultExt as _;
31use util::command::new_smol_command;
32use util::markdown::MarkdownString;
33use util::serde::default_true;
34
35use crate::AgentAppState;
36
/// Relative path of the eval examples directory.
pub const EXAMPLES_DIR: &str = "./crates/eval/examples";
/// Relative path of the directory that `repo_path_for_url` resolves repositories under.
pub const REPOS_DIR: &str = "./crates/eval/repos";
/// Relative path of the directory that `Example::worktree_path` resolves worktrees under.
pub const WORKTREES_DIR: &str = "./crates/eval/worktrees";

/// Longest gap tolerated between two thread events before the agentic loop is
/// reported as stalled.
const THREAD_EVENT_TIMEOUT: Duration = Duration::from_secs(2 * 60);
42
43#[derive(Clone, Debug, Deserialize)]
44pub struct ExampleBase {
45 pub url: String,
46 pub revision: String,
47 pub language_extension: Option<String>,
48 pub insert_id: Option<String>,
49 #[serde(default = "default_true")]
50 pub require_lsp: bool,
51 #[serde(default)]
52 pub allow_preexisting_diagnostics: bool,
53}
54
55impl ExampleBase {
56 pub fn repo_name(&self) -> String {
57 self.url
58 .split('/')
59 .next_back()
60 .unwrap_or(&"")
61 .trim_end_matches(".git")
62 .into()
63 }
64}
65
66#[derive(Clone, Debug)]
67pub struct Example {
68 pub name: String,
69 /// Content of `base.toml`
70 pub base: ExampleBase,
71 /// Content of `prompt.md`
72 pub prompt: String,
73 /// Content of `diff_criteria.md`
74 pub diff_criteria: String,
75 /// Content of `thread_criteria.md`, if that file exists (it's optional)
76 pub thread_criteria: Option<String>,
77 /// Path to the directory containing the requests and responses for the agentic loop
78 pub run_directory_path: PathBuf,
79 /// Prefix used for logging that identifies this example
80 pub log_prefix: String,
81}
82
83#[derive(Debug, Serialize, Deserialize, Clone)]
84pub struct RunOutput {
85 pub repository_diff: String,
86 pub ran_diagnostics_check: bool,
87 pub diagnostics_before: Option<String>,
88 pub diagnostics_after: Option<String>,
89 pub response_count: usize,
90 pub token_usage: TokenUsage,
91 pub tool_use_counts: HashMap<Arc<str>, u32>,
92 pub last_request: LanguageModelRequest,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct JudgeDiffInput {
97 pub repository_diff: String,
98 pub ran_diagnostics_check: bool,
99 #[serde(skip_serializing_if = "Option::is_none")]
100 pub diagnostics_before: Option<String>,
101 #[serde(skip_serializing_if = "Option::is_none")]
102 pub diagnostics_after: Option<String>,
103 pub criteria: String,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct JudgeThreadInput {
108 pub messages: String,
109 pub criteria: String,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct JudgeResponse {
114 pub analysis: String,
115 pub score: u32,
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct JudgeOutput {
120 pub thread: Option<JudgeResponse>,
121 pub diff: JudgeResponse,
122}
123
124impl Example {
125 /// Load an example from a directory containing base.toml, prompt.md, and criteria.md
126 pub fn load_from_directory(dir_path: &Path, run_dir: &Path) -> Result<Self> {
127 let name = Self::name_from_path(dir_path);
128 let base_path = dir_path.join("base.toml");
129 let prompt_path = dir_path.join("prompt.md");
130 let diff_criteria_path = dir_path.join("diff_criteria.md");
131 let thread_criteria_path = dir_path.join("thread_criteria.md");
132 let thread_criteria = if thread_criteria_path.exists() {
133 Some(fs::read_to_string(thread_criteria_path.clone())?)
134 } else {
135 None
136 };
137
138 Ok(Example {
139 name: name.clone(),
140 base: toml::from_str(&fs::read_to_string(&base_path)?)?,
141 prompt: fs::read_to_string(prompt_path.clone())?,
142 thread_criteria,
143 diff_criteria: fs::read_to_string(diff_criteria_path.clone())?,
144 run_directory_path: run_dir.to_path_buf(),
145 log_prefix: name,
146 })
147 }
148
149 pub fn set_repetition_number(&mut self, repetition_number: u32) {
150 if repetition_number > 0 {
151 self.name = format!("{}-{}", self.name, repetition_number);
152 }
153 }
154
155 pub fn example_output_directory(&self) -> PathBuf {
156 self.run_directory_path.join(&self.name)
157 }
158
159 pub fn set_log_prefix_style(&mut self, color: &str, name_width: usize) {
160 self.log_prefix = format!(
161 "{}{:<width$}\x1b[0m | ",
162 color,
163 self.name,
164 width = name_width
165 );
166 }
167
168 pub fn name_from_path(path: &Path) -> String {
169 path.file_name().unwrap().to_string_lossy().to_string()
170 }
171
172 pub fn worktree_path(&self) -> PathBuf {
173 Path::new(WORKTREES_DIR)
174 .canonicalize()
175 .context(format!("No such directory {WORKTREES_DIR}"))
176 .unwrap()
177 .join(&self.name)
178 .join(self.base.repo_name())
179 }
180
181 /// Set up the example by checking out the specified Git revision
182 pub async fn setup(&mut self) -> Result<()> {
183 let repo_path = repo_path_for_url(&self.base.url);
184
185 let revision_exists = run_git(
186 &repo_path,
187 &["rev-parse", &format!("{}^{{commit}}", self.base.revision)],
188 )
189 .await
190 .is_ok();
191
192 if !revision_exists {
193 println!(
194 "{}Fetching revision {}",
195 self.log_prefix, &self.base.revision
196 );
197 run_git(
198 &repo_path,
199 &["fetch", "--depth", "1", "origin", &self.base.revision],
200 )
201 .await?;
202 }
203
204 let worktree_path = self.worktree_path();
205
206 if worktree_path.is_dir() {
207 println!("{}Resetting existing worktree", self.log_prefix);
208
209 // TODO: consider including "-x" to remove ignored files. The downside of this is that
210 // it will also remove build artifacts, and so prevent incremental reuse there.
211 run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
212 run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
213 run_git(&worktree_path, &["checkout", &self.base.revision]).await?;
214 } else {
215 println!("{}Creating worktree", self.log_prefix);
216
217 let worktree_path_string = worktree_path.to_string_lossy().to_string();
218
219 run_git(
220 &repo_path,
221 &[
222 "worktree",
223 "add",
224 "-f",
225 &worktree_path_string,
226 &self.base.revision,
227 ],
228 )
229 .await?;
230 }
231
232 std::fs::create_dir_all(self.example_output_directory())?;
233
234 Ok(())
235 }
236
237 pub fn run(
238 &self,
239 model: Arc<dyn LanguageModel>,
240 app_state: Arc<AgentAppState>,
241 cx: &mut App,
242 ) -> Task<Result<RunOutput>> {
243 let project = Project::local(
244 app_state.client.clone(),
245 app_state.node_runtime.clone(),
246 app_state.user_store.clone(),
247 app_state.languages.clone(),
248 Arc::new(DapRegistry::default()),
249 app_state.fs.clone(),
250 None,
251 cx,
252 );
253
254 let worktree_path = self.worktree_path();
255 let worktree = project.update(cx, |project, cx| {
256 project.create_worktree(&worktree_path, true, cx)
257 });
258
259 let tools = cx.new(|_| ToolWorkingSet::default());
260 let thread_store =
261 ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
262 let this = self.clone();
263
264 cx.spawn(async move |cx| {
265 let worktree = worktree.await?;
266
267 // Wait for worktree scan to finish before choosing a file to open.
268 worktree
269 .update(cx, |worktree, _cx| {
270 worktree.as_local().unwrap().scan_complete()
271 })?
272 .await;
273
274 let lsp = if this.base.require_lsp {
275 let language_extension = this.base.language_extension.as_deref().context(
276 "language_extension field is required in base.toml when `require_lsp == true`",
277 )?;
278
279 // Open a file that matches the language to cause LSP to start.
280 let language_file = worktree.read_with(cx, |worktree, _cx| {
281 worktree
282 .files(false, 0)
283 .find_map(|e| {
284 if e.path.clone().extension().and_then(|ext| ext.to_str())
285 == Some(language_extension)
286 {
287 Some(ProjectPath {
288 worktree_id: worktree.id(),
289 path: e.path.clone(),
290 })
291 } else {
292 None
293 }
294 })
295 .context("Failed to find a file for example language")
296 })??;
297
298 let open_language_file_buffer_task = project.update(cx, |project, cx| {
299 project.open_buffer(language_file.clone(), cx)
300 })?;
301
302 let language_file_buffer = open_language_file_buffer_task.await?;
303
304 let lsp_open_handle = project.update(cx, |project, cx| {
305 project.register_buffer_with_language_servers(&language_file_buffer, cx)
306 })?;
307
308 wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
309
310 Some((lsp_open_handle, language_file_buffer))
311 } else {
312 None
313 };
314
315 let diagnostics_before = query_lsp_diagnostics(project.clone(), cx).await?;
316 if diagnostics_before.is_some() && !this.base.allow_preexisting_diagnostics {
317 return Err(anyhow!("Example has pre-existing diagnostics. If you want to run this example regardless, set `allow_preexisting_diagnostics` to `true` in `base.toml`"));
318 }
319
320 if std::env::var("ZED_EVAL_SETUP_ONLY").is_ok() {
321 return Err(anyhow!("Setup only mode"));
322 }
323
324 let thread_store = thread_store.await?;
325 let thread =
326 thread_store.update(cx, |thread_store, cx| thread_store.create_thread(cx))?;
327 let last_request = Rc::new(RefCell::new(None));
328
329 thread.update(cx, |thread, _cx| {
330 let mut request_count = 0;
331 let example_dir_path = this.example_output_directory();
332
333 let last_request = Rc::clone(&last_request);
334 thread.set_request_callback(move |request, response_events| {
335 *last_request.borrow_mut() = Some(request.clone());
336
337 request_count += 1;
338 let messages_file_path = example_dir_path.join(format!("{request_count}.messages.md"));
339 let last_messages_file_path = example_dir_path.join("last.messages.md");
340 let request_markdown = RequestMarkdown::new(request);
341 let response_events_markdown = response_events_to_markdown(response_events);
342
343 let messages = format!("{}\n\n{}", request_markdown.messages, response_events_markdown);
344 fs::write(messages_file_path, messages.clone()).expect("failed to write messages file");
345 fs::write(last_messages_file_path, messages).expect("failed to write last messages file");
346
347 if request_count == 1 {
348 let tools_file_path = example_dir_path.join("tools.md");
349 fs::write(tools_file_path, request_markdown.tools).expect("failed to write tools file");
350 }
351 });
352 })?;
353
354 let tool_use_counts: Arc<Mutex<HashMap<Arc<str>, u32>>> =
355 Mutex::new(HashMap::default()).into();
356
357 let (thread_event_tx, mut thread_event_rx) = mpsc::unbounded();
358
359 let subscription = cx.subscribe(&thread, move |_thread, event: &ThreadEvent, _cx| {
360 thread_event_tx.unbounded_send(event.clone()).log_err();
361 });
362
363 let event_handler_task = cx.spawn({
364 let log_prefix = this.log_prefix.clone();
365 let tool_use_counts = tool_use_counts.clone();
366 let thread = thread.downgrade();
367 async move |cx| {
368 loop {
369 let event = select_biased! {
370 event = thread_event_rx.next() => event,
371 _ = cx.background_executor().timer(THREAD_EVENT_TIMEOUT).fuse() => {
372 return Err(anyhow!("Agentic loop stalled - waited {:?} without any events", THREAD_EVENT_TIMEOUT));
373 }
374 };
375 let Some(event) = event else {
376 return Err(anyhow!("ThreadEvent channel ended early"));
377 };
378
379 match event {
380 ThreadEvent::Stopped(reason) => match reason {
381 Ok(StopReason::EndTurn) => {
382 return Ok(());
383 }
384 Ok(StopReason::MaxTokens) => {
385 return Err(anyhow!("Exceeded maximum tokens"));
386 }
387 Ok(StopReason::ToolUse) => {
388 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
389 println!("{}StopReason: Tool use", log_prefix);
390 }
391 }
392 Err(error) => {
393 return Err(anyhow!(error.clone()));
394 }
395 },
396 ThreadEvent::ShowError(thread_error) => {
397 break Err(anyhow!(thread_error.clone()));
398 }
399 ThreadEvent::StreamedAssistantText(_, _)| ThreadEvent::StreamedAssistantThinking(_, _) | ThreadEvent::UsePendingTools { .. } => {
400 }
401 ThreadEvent::ToolFinished {
402 tool_use_id,
403 pending_tool_use,
404 ..
405 } => {
406 thread.update(cx, |thread, _cx| {
407 if let Some(tool_use) = pending_tool_use {
408 if let Some(tool_result) = thread.tool_result(&tool_use_id) {
409 let message = if tool_result.is_error {
410 format!("TOOL FAILED: {}", tool_use.name)
411 } else {
412 format!("TOOL FINISHED: {}", tool_use.name)
413 };
414 println!("{log_prefix}{message}");
415 let mut tool_use_counts = tool_use_counts.lock().unwrap();
416 *tool_use_counts
417 .entry(tool_result.tool_name.clone())
418 .or_insert(0) += 1;
419 } else {
420 let message = format!("TOOL FINISHED WITHOUT RESULT: {}", tool_use.name);
421 println!("{log_prefix}{message}");
422 }
423 }
424 })?;
425 }
426 ThreadEvent::ToolConfirmationNeeded => {
427 panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
428 },
429 ThreadEvent::StreamedCompletion |
430 ThreadEvent::MessageAdded(_) |
431 ThreadEvent::MessageEdited(_) |
432 ThreadEvent::MessageDeleted(_) |
433 ThreadEvent::SummaryChanged |
434 ThreadEvent::SummaryGenerated |
435 ThreadEvent::CheckpointChanged |
436 ThreadEvent::UsageUpdated(_) => {
437 if std::env::var("ZED_EVAL_DEBUG").is_ok() {
438 println!("{}Event: {:#?}", log_prefix, event);
439 }
440 }
441 }
442 }
443 }
444 });
445
446 thread.update(cx, |thread, cx| {
447 let context = vec![];
448 thread.insert_user_message(this.prompt.clone(), context, None, cx);
449 thread.send_to_model(model, cx);
450 })?;
451
452 event_handler_task.await?;
453
454 println!("{}Stopped", this.log_prefix);
455
456 if let Some((_, language_file_buffer)) = lsp.as_ref() {
457 wait_for_lang_server(&project, &language_file_buffer, this.log_prefix.clone(), cx).await?;
458 }
459
460 println!("{}Getting repository diff", this.log_prefix);
461 let repository_diff = this.repository_diff().await?;
462
463 let example_output_dir = this.example_output_directory();
464 let repository_diff_path = example_output_dir.join("patch.diff");
465 let mut repository_diff_output_file = File::create(&repository_diff_path)?;
466 writeln!(&mut repository_diff_output_file, "{}", &repository_diff).log_err();
467
468 println!("{}Getting diagnostics", this.log_prefix);
469 let diagnostics_after = cx
470 .update(move |cx| {
471 cx.spawn(async move |cx| query_lsp_diagnostics(project, cx).await)
472 })?
473 .await?;
474 println!("{}Got diagnostics", this.log_prefix);
475
476 let Some(last_request) = last_request.borrow_mut().take() else {
477 return Err(anyhow!("No requests ran."));
478 };
479
480 drop(subscription);
481 drop(lsp);
482
483 if let Some(diagnostics_before) = &diagnostics_before {
484 fs::write(example_output_dir.join("diagnostics_before.txt"), diagnostics_before)?;
485 }
486
487 if let Some(diagnostics_after) = &diagnostics_after {
488 fs::write(example_output_dir.join("diagnostics_after.txt"), diagnostics_after)?;
489 }
490
491
492 thread.update(cx, |thread, _cx| {
493 let response_count = thread
494 .messages()
495 .filter(|message| message.role == language_model::Role::Assistant)
496 .count();
497 RunOutput {
498 repository_diff,
499 ran_diagnostics_check: this.base.require_lsp,
500 diagnostics_before,
501 diagnostics_after,
502 response_count,
503 token_usage: thread.cumulative_token_usage(),
504 tool_use_counts: tool_use_counts.lock().unwrap().clone(),
505 last_request,
506 }
507 })
508 })
509 }
510
511 async fn judge_diff(
512 &self,
513 model: Arc<dyn LanguageModel>,
514 run_output: &RunOutput,
515 judge_number: u32,
516 cx: &AsyncApp,
517 ) -> Result<(String, JudgeResponse)> {
518 let judge_diff_prompt = include_str!("judge_diff_prompt.hbs");
519 let judge_diff_prompt_name = "judge_diff_prompt";
520 let mut hbs = Handlebars::new();
521 hbs.register_template_string(judge_diff_prompt_name, judge_diff_prompt)?;
522
523 let diff_prompt = hbs.render(
524 judge_diff_prompt_name,
525 &JudgeDiffInput {
526 repository_diff: run_output.repository_diff.clone(),
527 ran_diagnostics_check: run_output.ran_diagnostics_check,
528 diagnostics_before: run_output.diagnostics_before.clone(),
529 diagnostics_after: run_output.diagnostics_after.clone(),
530 criteria: self.diff_criteria.clone(),
531 },
532 )?;
533
534 let request = LanguageModelRequest {
535 thread_id: None,
536 prompt_id: None,
537 messages: vec![LanguageModelRequestMessage {
538 role: Role::User,
539 content: vec![MessageContent::Text(diff_prompt)],
540 cache: false,
541 }],
542 temperature: None,
543 tools: Vec::new(),
544 stop: Vec::new(),
545 };
546
547 let diff_response = send_language_model_request(model, request, cx).await?;
548 let diff_output = JudgeResponse::parse(&diff_response)?;
549
550 println!(
551 "{}Judge #{judge_number} - Diff score: {}",
552 self.log_prefix, diff_output.score
553 );
554
555 Ok((diff_response, diff_output))
556 }
557
558 async fn judge_thread(
559 &self,
560 model: Arc<dyn LanguageModel>,
561 run_output: &RunOutput,
562 judge_number: u32,
563 cx: &AsyncApp,
564 ) -> Result<(String, Option<JudgeResponse>)> {
565 if let Some(criteria) = self.thread_criteria.clone() {
566 let judge_thread_prompt = include_str!("judge_thread_prompt.hbs");
567 let judge_thread_prompt_name = "judge_thread_prompt";
568 let mut hbs = Handlebars::new();
569 hbs.register_template_string(judge_thread_prompt_name, judge_thread_prompt)?;
570
571 let request_markdown = RequestMarkdown::new(&run_output.last_request);
572 let thread_prompt = hbs.render(
573 judge_thread_prompt_name,
574 &JudgeThreadInput {
575 messages: request_markdown.messages,
576 criteria,
577 },
578 )?;
579
580 let request = LanguageModelRequest {
581 thread_id: None,
582 prompt_id: None,
583 messages: vec![LanguageModelRequestMessage {
584 role: Role::User,
585 content: vec![MessageContent::Text(thread_prompt)],
586 cache: false,
587 }],
588 temperature: None,
589 tools: Vec::new(),
590 stop: Vec::new(),
591 };
592
593 let thread_response = send_language_model_request(model, request, cx).await?;
594 let thread_output = JudgeResponse::parse(&thread_response)?;
595
596 println!(
597 "{}Judge #{judge_number} - Thread score: {}",
598 self.log_prefix, thread_output.score
599 );
600
601 Ok((thread_response, Some(thread_output)))
602 } else {
603 let msg = "There were no criteria specified for this thread, so this example was not judged on its thread.".to_string();
604 Ok((msg, None))
605 }
606 }
607
608 pub async fn judge(
609 &self,
610 model: Arc<dyn LanguageModel>,
611 run_output: &RunOutput,
612 judge_number: u32,
613 cx: &AsyncApp,
614 ) -> Result<JudgeOutput> {
615 let mut output_file = File::create(
616 self.example_output_directory()
617 .join(format!("judge_{}.md", judge_number)),
618 )
619 .expect("failed to create judge.md");
620
621 println!("{}Running judge #{judge_number}", self.log_prefix);
622
623 let diff_task = self.judge_diff(model.clone(), &run_output, judge_number, cx);
624 let thread_task = self.judge_thread(model.clone(), &run_output, judge_number, cx);
625
626 let (diff_result, thread_result) = futures::join!(diff_task, thread_task);
627
628 let (diff_response, diff_output) = diff_result?;
629 let (thread_response, thread_output) = thread_result?;
630
631 writeln!(
632 &mut output_file,
633 "# Judgment\n\n## Thread\n\n{thread_response}\n\n## Diff\n\n{diff_response}",
634 )
635 .log_err();
636
637 Ok(JudgeOutput {
638 thread: thread_output,
639 diff: diff_output,
640 })
641 }
642
643 async fn repository_diff(&self) -> Result<String> {
644 let worktree_path = self.worktree_path();
645 run_git(&worktree_path, &["add", "."]).await?;
646 run_git(&worktree_path, &["diff", "--staged"]).await
647 }
648}
649
650fn wait_for_lang_server(
651 project: &Entity<Project>,
652 buffer: &Entity<Buffer>,
653 log_prefix: String,
654 cx: &mut AsyncApp,
655) -> Task<Result<()>> {
656 println!("{}⏵ Waiting for language server", log_prefix);
657
658 let (mut tx, mut rx) = mpsc::channel(1);
659
660 let lsp_store = project
661 .update(cx, |project, _| project.lsp_store())
662 .unwrap();
663
664 let has_lang_server = buffer
665 .update(cx, |buffer, cx| {
666 lsp_store.update(cx, |lsp_store, cx| {
667 lsp_store
668 .language_servers_for_local_buffer(&buffer, cx)
669 .next()
670 .is_some()
671 })
672 })
673 .unwrap_or(false);
674
675 if has_lang_server {
676 project
677 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
678 .unwrap()
679 .detach();
680 }
681
682 let subscriptions =
683 [
684 cx.subscribe(&lsp_store, {
685 let log_prefix = log_prefix.clone();
686 move |_, event, _| match event {
687 project::LspStoreEvent::LanguageServerUpdate {
688 message:
689 client::proto::update_language_server::Variant::WorkProgress(
690 LspWorkProgress {
691 message: Some(message),
692 ..
693 },
694 ),
695 ..
696 } => println!("{}⟲ {message}", log_prefix),
697 _ => {}
698 }
699 }),
700 cx.subscribe(&project, {
701 let buffer = buffer.clone();
702 move |project, event, cx| match event {
703 project::Event::LanguageServerAdded(_, _, _) => {
704 let buffer = buffer.clone();
705 project
706 .update(cx, |project, cx| project.save_buffer(buffer, cx))
707 .detach();
708 }
709 project::Event::DiskBasedDiagnosticsFinished { .. } => {
710 tx.try_send(()).ok();
711 }
712 _ => {}
713 }
714 }),
715 ];
716
717 cx.spawn(async move |cx| {
718 let timeout = cx.background_executor().timer(Duration::new(60 * 5, 0));
719 let result = futures::select! {
720 _ = rx.next() => {
721 println!("{}⚑ Language server idle", log_prefix);
722 anyhow::Ok(())
723 },
724 _ = timeout.fuse() => {
725 Err(anyhow!("LSP wait timed out after 5 minutes"))
726 }
727 };
728 drop(subscriptions);
729 result
730 })
731}
732
733async fn query_lsp_diagnostics(
734 project: Entity<Project>,
735 cx: &mut AsyncApp,
736) -> Result<Option<String>> {
737 let paths_with_diagnostics = project.update(cx, |project, cx| {
738 project
739 .diagnostic_summaries(true, cx)
740 .filter(|(_, _, summary)| summary.error_count > 0 || summary.warning_count > 0)
741 .map(|(project_path, _, _)| project_path)
742 .collect::<Vec<_>>()
743 })?;
744
745 if paths_with_diagnostics.is_empty() {
746 return Ok(None);
747 }
748
749 let mut output = String::new();
750 for project_path in paths_with_diagnostics {
751 let buffer = project
752 .update(cx, |project, cx| project.open_buffer(project_path, cx))?
753 .await?;
754 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
755
756 for (_, group) in snapshot.diagnostic_groups(None) {
757 let entry = &group.entries[group.primary_ix];
758 let range = entry.range.to_point(&snapshot);
759 let severity = match entry.diagnostic.severity {
760 DiagnosticSeverity::ERROR => "error",
761 DiagnosticSeverity::WARNING => "warning",
762 _ => continue,
763 };
764
765 writeln!(
766 output,
767 "{} at line {}: {}",
768 severity,
769 range.start.row + 1,
770 entry.diagnostic.message
771 )?;
772 }
773 }
774 anyhow::Ok(Some(output))
775}
776
777impl JudgeResponse {
778 fn parse(response: &str) -> Result<Self> {
779 let analysis = get_tag("analysis", response)?.to_string();
780 let score = get_tag("score", response)?
781 .parse()
782 .context("error parsing score")?;
783
784 Ok(Self { analysis, score })
785 }
786}
787
788fn get_tag(name: &'static str, response: &str) -> Result<String> {
789 let start_tag = format!("<{}>", name);
790 let end_tag = format!("</{}>", name);
791
792 let start_ix = response
793 .find(&start_tag)
794 .context(format!("{} start tag not found", name))?;
795 let content_start_ix = start_ix + start_tag.len();
796
797 let end_ix = content_start_ix
798 + response[content_start_ix..]
799 .find(&end_tag)
800 .context(format!("{} end tag not found", name))?;
801
802 let content = response[content_start_ix..end_ix].trim().unindent();
803
804 anyhow::Ok(content)
805}
806
807pub fn repo_path_for_url(repo_url: &str) -> PathBuf {
808 let repo_name = repo_url
809 .trim_start_matches("https://")
810 .replace(|c: char| !c.is_alphanumeric(), "-");
811 Path::new(REPOS_DIR)
812 .canonicalize()
813 .context(format!("No such directory {REPOS_DIR}"))
814 .unwrap()
815 .join(repo_name)
816}
817
818pub async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
819 let output = new_smol_command("git")
820 .current_dir(repo_path)
821 .args(args)
822 .output()
823 .await?;
824
825 if output.status.success() {
826 Ok(String::from_utf8(output.stdout)?.trim().to_string())
827 } else {
828 Err(anyhow!(
829 "`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
830 args.join(" "),
831 repo_path.display(),
832 output.status,
833 String::from_utf8_lossy(&output.stderr),
834 String::from_utf8_lossy(&output.stdout),
835 ))
836 }
837}
838
839pub async fn send_language_model_request(
840 model: Arc<dyn LanguageModel>,
841 request: LanguageModelRequest,
842 cx: &AsyncApp,
843) -> anyhow::Result<String> {
844 match model.stream_completion_text(request, &cx).await {
845 Ok(mut stream) => {
846 let mut full_response = String::new();
847 while let Some(chunk_result) = stream.stream.next().await {
848 match chunk_result {
849 Ok(chunk_str) => {
850 full_response.push_str(&chunk_str);
851 }
852 Err(err) => {
853 return Err(anyhow!(
854 "Error receiving response from language model: {err}"
855 ));
856 }
857 }
858 }
859 Ok(full_response)
860 }
861 Err(err) => Err(anyhow!(
862 "Failed to get response from language model. Error was: {err}"
863 )),
864 }
865}
866
/// Markdown rendering of a language model request.
struct RequestMarkdown {
    /// One section per tool: name, description and JSON input schema.
    tools: String,
    /// The request's messages, each under a role heading.
    messages: String,
}
871
872impl RequestMarkdown {
873 fn new(request: &LanguageModelRequest) -> Self {
874 let mut tools = String::new();
875 let mut messages = String::new();
876 let mut assistant_message_number: u32 = 1;
877
878 // Print the tools
879 if !request.tools.is_empty() {
880 for tool in &request.tools {
881 write!(&mut tools, "# {}\n\n", tool.name).unwrap();
882 write!(&mut tools, "{}\n\n", tool.description).unwrap();
883 write!(
884 &mut tools,
885 "{}\n",
886 MarkdownString::code_block("json", &format!("{:#}", tool.input_schema))
887 )
888 .unwrap();
889 }
890 }
891
892 // Print the messages
893 for message in &request.messages {
894 match message.role {
895 Role::System => messages.push_str("# ⚙️ SYSTEM\n\n"),
896 Role::User => messages.push_str("# 👤 USER\n\n"),
897 Role::Assistant => {
898 messages.push_str(&format!("# 🤖 ASSISTANT {assistant_message_number}\n\n"));
899 assistant_message_number += 1;
900 }
901 };
902
903 for content in &message.content {
904 match content {
905 MessageContent::Text(text) => {
906 messages.push_str(text);
907 messages.push_str("\n\n");
908 }
909 MessageContent::Image(_) => {
910 messages.push_str("[IMAGE DATA]\n\n");
911 }
912 MessageContent::Thinking { text, signature } => {
913 messages.push_str("**Thinking**:\n\n");
914 if let Some(sig) = signature {
915 messages.push_str(&format!("Signature: {}\n\n", sig));
916 }
917 messages.push_str(text);
918 messages.push_str("\n");
919 }
920 MessageContent::RedactedThinking(items) => {
921 messages.push_str(&format!(
922 "**Redacted Thinking**: {} item(s)\n\n",
923 items.len()
924 ));
925 }
926 MessageContent::ToolUse(tool_use) => {
927 messages.push_str(&format!(
928 "**Tool Use**: {} (ID: {})\n",
929 tool_use.name, tool_use.id
930 ));
931 messages.push_str(&format!(
932 "{}\n",
933 MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
934 ));
935 }
936 MessageContent::ToolResult(tool_result) => {
937 messages.push_str(&format!(
938 "**Tool Result**: {} (ID: {})\n\n",
939 tool_result.tool_name, tool_result.tool_use_id
940 ));
941 if tool_result.is_error {
942 messages.push_str("**ERROR:**\n");
943 }
944 messages.push_str(&format!("{}\n\n", tool_result.content));
945 }
946 }
947 }
948 }
949
950 Self { tools, messages }
951 }
952}
953
954fn response_events_to_markdown(
955 response_events: &[std::result::Result<LanguageModelCompletionEvent, String>],
956) -> String {
957 let mut response = String::new();
958 // Print the response events if any
959 response.push_str("# Response\n\n");
960 let mut text_buffer = String::new();
961 let mut thinking_buffer = String::new();
962
963 let flush_buffers =
964 |output: &mut String, text_buffer: &mut String, thinking_buffer: &mut String| {
965 if !text_buffer.is_empty() {
966 output.push_str(&format!("**Text**:\n{}\n\n", text_buffer));
967 text_buffer.clear();
968 }
969 if !thinking_buffer.is_empty() {
970 output.push_str(&format!("**Thinking**:\n{}\n\n", thinking_buffer));
971 thinking_buffer.clear();
972 }
973 };
974
975 for event in response_events {
976 match event {
977 Ok(LanguageModelCompletionEvent::Text(text)) => {
978 text_buffer.push_str(text);
979 }
980 Ok(LanguageModelCompletionEvent::Thinking { text, .. }) => {
981 thinking_buffer.push_str(text);
982 }
983 Ok(LanguageModelCompletionEvent::Stop(reason)) => {
984 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
985 response.push_str(&format!("**Stop**: {:?}\n\n", reason));
986 }
987 Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
988 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
989 response.push_str(&format!(
990 "**Tool Use**: {} (ID: {})\n",
991 tool_use.name, tool_use.id
992 ));
993 response.push_str(&format!(
994 "{}\n",
995 MarkdownString::code_block("json", &format!("{:#}", tool_use.input))
996 ));
997 }
998 Ok(
999 LanguageModelCompletionEvent::UsageUpdate(_)
1000 | LanguageModelCompletionEvent::StartMessage { .. },
1001 ) => {}
1002 Err(error) => {
1003 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1004 response.push_str(&format!("**Error**: {}\n\n", error));
1005 }
1006 }
1007 }
1008
1009 flush_buffers(&mut response, &mut text_buffer, &mut thinking_buffer);
1010
1011 response
1012}
1013
#[cfg(test)]
mod test {
    use super::*;
    use handlebars::Handlebars;

    /// Name under which the judge prompt template is registered in these tests.
    const JUDGE_PROMPT_NAME: &str = "judge_prompt";

    /// Builds a Handlebars registry that contains only the judge diff prompt.
    ///
    /// The template text is run through `LineEnding::normalize` before it is
    /// registered, so the rendered output is compared against the `\n`-only
    /// expectations written in the tests below.
    fn templates() -> Handlebars<'static> {
        let mut template_source = String::from(include_str!("judge_diff_prompt.hbs"));
        language::LineEnding::normalize(&mut template_source);

        let mut registry = Handlebars::new();
        registry
            .register_template_string(JUDGE_PROMPT_NAME, template_source)
            .unwrap();
        registry
    }

    /// Assembles a `JudgeDiffInput` with the diff and criteria shared by every
    /// prompt test; only the diagnostics-related fields vary between cases.
    fn judge_input(
        ran_diagnostics_check: bool,
        diagnostics_before: Option<&str>,
        diagnostics_after: Option<&str>,
    ) -> JudgeDiffInput {
        JudgeDiffInput {
            repository_diff: String::from("diff content goes here"),
            ran_diagnostics_check,
            diagnostics_before: diagnostics_before.map(String::from),
            diagnostics_after: diagnostics_after.map(String::from),
            criteria: String::from("Fix all bugs"),
        }
    }

    /// Renders the judge prompt for `input` using `registry`, panicking if
    /// rendering fails.
    fn render_judge_prompt(registry: &Handlebars<'_>, input: &JudgeDiffInput) -> String {
        registry.render(JUDGE_PROMPT_NAME, input).unwrap()
    }

    #[test]
    fn test_parse_judge_output() {
        // A reply consisting of nothing but the two expected tags.
        let compact_reply = r#"
            <analysis>The model did a good job but there were still compilations errors.</analysis>
            <score>3</score>
        "#
        .unindent();

        let compact = JudgeResponse::parse(&compact_reply).unwrap();
        assert_eq!(
            compact.analysis,
            "The model did a good job but there were still compilations errors."
        );
        assert_eq!(compact.score, 3);

        // A reply with surrounding text and a multi-line analysis: the text
        // outside the tags is dropped and the analysis comes back trimmed.
        let verbose_reply = r#"
            Text around ignored

            <analysis>
            Failed to compile:
            - Error 1
            - Error 2
            </analysis>

            <score>1</score>
        "#
        .unindent();

        let verbose = JudgeResponse::parse(&verbose_reply).unwrap();
        assert_eq!(verbose.analysis, "Failed to compile:\n- Error 1\n- Error 2");
        assert_eq!(verbose.score, 1);
    }

    #[test]
    fn test_judge_prompt_with_diagnostics() {
        // Diagnostics check ran and produced output both before and after.
        let input = judge_input(
            true,
            Some("Error at line 10: variable not found"),
            Some("Error at line 15: missing semicolon"),
        );
        let prompt = render_judge_prompt(&templates(), &input);

        let expected_block = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
        "#
        .unindent();

        assert!(prompt.contains(expected_block.as_str()));
    }

    #[test]
    fn test_judge_prompt_with_empty_diagnostics() {
        // Diagnostics check ran, but nothing was reported on either side, so
        // both sections must carry their placeholder sentence.
        let input = judge_input(true, None, None);
        let prompt = render_judge_prompt(&templates(), &input);

        let expected_block = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
        "#
        .unindent();

        assert!(prompt.contains(expected_block.as_str()));
    }

    #[test]
    fn test_judge_prompt_with_mixed_diagnostics() {
        // One registry is shared by both renders in this test.
        let registry = templates();

        // Diagnostics exist before the edits and are gone afterwards.
        let only_before = judge_input(true, Some("Error at line 10: variable not found"), None);
        let prompt = render_judge_prompt(&registry, &only_before);

        let expected_block = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            Error at line 10: variable not found
            </diagnostics_before>

            <diagnostics_after>
            No diagnostics after applying the edits.
            </diagnostics_after>
        "#
        .unindent();

        assert!(prompt.contains(expected_block.as_str()));

        // No diagnostics before the edits, but some appear afterwards.
        let only_after = judge_input(true, None, Some("Error at line 15: missing semicolon"));
        let prompt = render_judge_prompt(&registry, &only_after);

        let expected_block = r#"
            Take into account the diagnostics before and after applying the change:

            <diagnostics_before>
            No diagnostics before applying the edits.
            </diagnostics_before>

            <diagnostics_after>
            Error at line 15: missing semicolon
            </diagnostics_after>
        "#
        .unindent();

        assert!(prompt.contains(expected_block.as_str()));
    }

    #[test]
    fn test_judge_prompt_without_diagnostics() {
        let registry = templates();

        // The diagnostics check was never run.
        let input = judge_input(false, None, None);
        let prompt = render_judge_prompt(&registry, &input);

        // The prompt states that no check happened and contains neither
        // diagnostics tag.
        assert!(prompt.contains("No diagnostic checks were performed."));
        assert!(!prompt.contains("<diagnostics_before>"));
        assert!(!prompt.contains("<diagnostics_after>"));
    }
}