// diagnostics_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{AnyWindowHandle, App, Entity, Task};
  5use language::{DiagnosticSeverity, OffsetRangeExt};
  6use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::{fmt::Write, path::Path, sync::Arc};
 11use ui::IconName;
 12use util::markdown::MarkdownInlineCode;
 13
 14#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 15pub struct DiagnosticsToolInput {
 16    /// The path to get diagnostics for. If not provided, returns a project-wide summary.
 17    ///
 18    /// This path should never be absolute, and the first component
 19    /// of the path should always be a root directory in a project.
 20    ///
 21    /// <example>
 22    /// If the project has the following root directories:
 23    ///
 24    /// - lorem
 25    /// - ipsum
 26    ///
 27    /// If you wanna access diagnostics for `dolor.txt` in `ipsum`, you should use the path `ipsum/dolor.txt`.
 28    /// </example>
 29    #[serde(deserialize_with = "deserialize_path")]
 30    pub path: Option<String>,
 31}
 32
 33fn deserialize_path<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
 34where
 35    D: serde::Deserializer<'de>,
 36{
 37    let opt = Option::<String>::deserialize(deserializer)?;
 38    // The model passes an empty string sometimes
 39    Ok(opt.filter(|s| !s.is_empty()))
 40}
 41
 42pub struct DiagnosticsTool;
 43
 44impl Tool for DiagnosticsTool {
 45    fn name(&self) -> String {
 46        "diagnostics".into()
 47    }
 48
 49    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 50        false
 51    }
 52
 53    fn may_perform_edits(&self) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./diagnostics_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::ToolDiagnostics
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<DiagnosticsToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        if let Some(path) = serde_json::from_value::<DiagnosticsToolInput>(input.clone())
 71            .ok()
 72            .and_then(|input| match input.path {
 73                Some(path) if !path.is_empty() => Some(path),
 74                _ => None,
 75            })
 76        {
 77            format!("Check diagnostics for {}", MarkdownInlineCode(&path))
 78        } else {
 79            "Check project diagnostics".to_string()
 80        }
 81    }
 82
 83    fn run(
 84        self: Arc<Self>,
 85        input: serde_json::Value,
 86        _request: Arc<LanguageModelRequest>,
 87        project: Entity<Project>,
 88        action_log: Entity<ActionLog>,
 89        _model: Arc<dyn LanguageModel>,
 90        _window: Option<AnyWindowHandle>,
 91        cx: &mut App,
 92    ) -> ToolResult {
 93        match serde_json::from_value::<DiagnosticsToolInput>(input)
 94            .ok()
 95            .and_then(|input| input.path)
 96        {
 97            Some(path) if !path.is_empty() => {
 98                let Some(project_path) = project.read(cx).find_project_path(&path, cx) else {
 99                    return Task::ready(Err(anyhow!("Could not find path {path} in project",)))
100                        .into();
101                };
102
103                let buffer =
104                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
105
106                cx.spawn(async move |cx| {
107                    let mut output = String::new();
108                    let buffer = buffer.await?;
109                    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
110
111                    for (_, group) in snapshot.diagnostic_groups(None) {
112                        let entry = &group.entries[group.primary_ix];
113                        let range = entry.range.to_point(&snapshot);
114                        let severity = match entry.diagnostic.severity {
115                            DiagnosticSeverity::ERROR => "error",
116                            DiagnosticSeverity::WARNING => "warning",
117                            _ => continue,
118                        };
119
120                        writeln!(
121                            output,
122                            "{} at line {}: {}",
123                            severity,
124                            range.start.row + 1,
125                            entry.diagnostic.message
126                        )?;
127                    }
128
129                    if output.is_empty() {
130                        Ok("File doesn't have errors or warnings!".to_string().into())
131                    } else {
132                        Ok(output.into())
133                    }
134                })
135                .into()
136            }
137            _ => {
138                let project = project.read(cx);
139                let mut output = String::new();
140                let mut has_diagnostics = false;
141
142                for (project_path, _, summary) in project.diagnostic_summaries(true, cx) {
143                    if summary.error_count > 0 || summary.warning_count > 0 {
144                        let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx)
145                        else {
146                            continue;
147                        };
148
149                        has_diagnostics = true;
150                        output.push_str(&format!(
151                            "{}: {} error(s), {} warning(s)\n",
152                            Path::new(worktree.read(cx).root_name())
153                                .join(project_path.path)
154                                .display(),
155                            summary.error_count,
156                            summary.warning_count
157                        ));
158                    }
159                }
160
161                action_log.update(cx, |action_log, _cx| {
162                    action_log.checked_project_diagnostics();
163                });
164
165                if has_diagnostics {
166                    Task::ready(Ok(output.into())).into()
167                } else {
168                    Task::ready(Ok("No errors or warnings found in the project."
169                        .to_string()
170                        .into()))
171                    .into()
172                }
173            }
174        }
175    }
176}