assistant_tools.rs

  1mod copy_path_tool;
  2mod create_directory_tool;
  3mod create_file_tool;
  4mod delete_path_tool;
  5mod diagnostics_tool;
  6mod edit_agent;
  7mod edit_file_tool;
  8mod fetch_tool;
  9mod find_path_tool;
 10mod grep_tool;
 11mod list_directory_tool;
 12mod move_path_tool;
 13mod now_tool;
 14mod open_tool;
 15mod read_file_tool;
 16mod replace;
 17mod schema;
 18mod streaming_edit_file_tool;
 19mod templates;
 20mod terminal_tool;
 21mod thinking_tool;
 22mod ui;
 23mod web_search_tool;
 24
 25use std::sync::Arc;
 26
 27use assistant_settings::AssistantSettings;
 28use assistant_tool::ToolRegistry;
 29use copy_path_tool::CopyPathTool;
 30use gpui::{App, Entity};
 31use http_client::HttpClientWithUrl;
 32use language_model::LanguageModelRegistry;
 33use move_path_tool::MovePathTool;
 34use settings::{Settings, SettingsStore};
 35use web_search_tool::WebSearchTool;
 36
 37pub(crate) use templates::*;
 38
 39use crate::create_directory_tool::CreateDirectoryTool;
 40use crate::delete_path_tool::DeletePathTool;
 41use crate::diagnostics_tool::DiagnosticsTool;
 42use crate::fetch_tool::FetchTool;
 43use crate::find_path_tool::FindPathTool;
 44use crate::grep_tool::GrepTool;
 45use crate::list_directory_tool::ListDirectoryTool;
 46use crate::now_tool::NowTool;
 47use crate::read_file_tool::ReadFileTool;
 48use crate::streaming_edit_file_tool::StreamingEditFileTool;
 49use crate::thinking_tool::ThinkingTool;
 50
 51pub use create_file_tool::{CreateFileTool, CreateFileToolInput};
 52pub use edit_file_tool::{EditFileTool, EditFileToolInput};
 53pub use find_path_tool::FindPathToolInput;
 54pub use open_tool::OpenTool;
 55pub use read_file_tool::ReadFileToolInput;
 56pub use streaming_edit_file_tool::StreamingEditFileToolInput;
 57pub use terminal_tool::TerminalTool;
 58
 59pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
 60    assistant_tool::init(cx);
 61
 62    let registry = ToolRegistry::global(cx);
 63    registry.register_tool(TerminalTool::new(cx));
 64    registry.register_tool(CreateDirectoryTool);
 65    registry.register_tool(CopyPathTool);
 66    registry.register_tool(DeletePathTool);
 67    registry.register_tool(MovePathTool);
 68    registry.register_tool(DiagnosticsTool);
 69    registry.register_tool(ListDirectoryTool);
 70    registry.register_tool(NowTool);
 71    registry.register_tool(OpenTool);
 72    registry.register_tool(FindPathTool);
 73    registry.register_tool(ReadFileTool);
 74    registry.register_tool(GrepTool);
 75    registry.register_tool(ThinkingTool);
 76    registry.register_tool(FetchTool::new(http_client));
 77
 78    register_edit_file_tool(cx);
 79    cx.observe_global::<SettingsStore>(register_edit_file_tool)
 80        .detach();
 81
 82    register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
 83    cx.subscribe(
 84        &LanguageModelRegistry::global(cx),
 85        move |registry, event, cx| match event {
 86            language_model::Event::DefaultModelChanged => {
 87                register_web_search_tool(&registry, cx);
 88            }
 89            _ => {}
 90        },
 91    )
 92    .detach();
 93}
 94
 95fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
 96    let using_zed_provider = registry
 97        .read(cx)
 98        .default_model()
 99        .map_or(false, |default| default.is_provided_by_zed());
100    if using_zed_provider {
101        ToolRegistry::global(cx).register_tool(WebSearchTool);
102    } else {
103        ToolRegistry::global(cx).unregister_tool(WebSearchTool);
104    }
105}
106
107fn register_edit_file_tool(cx: &mut App) {
108    let registry = ToolRegistry::global(cx);
109
110    registry.unregister_tool(CreateFileTool);
111    registry.unregister_tool(EditFileTool);
112    registry.unregister_tool(StreamingEditFileTool);
113
114    if AssistantSettings::get_global(cx).stream_edits(cx) {
115        registry.register_tool(StreamingEditFileTool);
116    } else {
117        registry.register_tool(CreateFileTool);
118        registry.register_tool(EditFileTool);
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use client::Client;
    use clock::FakeSystemClock;
    use http_client::FakeHttpClient;
    use schemars::JsonSchema;
    use serde::Serialize;

    #[test]
    fn test_json_schema() {
        #[derive(Serialize, JsonSchema)]
        struct GetWeatherTool {
            location: String,
        }

        let schema = schema::json_schema_for::<GetWeatherTool>(
            language_model::LanguageModelToolSchemaFormat::JsonSchema,
        )
        .unwrap();

        assert_eq!(
            schema,
            serde_json::json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string"
                    }
                },
                "required": ["location"],
            })
        );
    }

    /// Every built-in tool must already emit its input schema in the
    /// `JsonSchemaSubset` form, i.e. adapting it to that format is a no-op.
    #[gpui::test]
    fn test_builtin_tool_schema_compatibility(cx: &mut App) {
        settings::init(cx);
        AssistantSettings::register(cx);

        let client = Client::new(
            Arc::new(FakeSystemClock::new()),
            FakeHttpClient::with_200_response(),
            cx,
        );
        language_model::init(client.clone(), cx);
        crate::init(client.http_client(), cx);

        // `init` registers only one flavor of the edit tools (depending on the
        // `stream_edits` setting), and registers `WebSearchTool` only when the
        // default model is provided by Zed. Check those tools explicitly so
        // every built-in tool is covered regardless of this test's setup.
        let conditional_tools: [Arc<dyn assistant_tool::Tool>; 4] = [
            Arc::new(CreateFileTool),
            Arc::new(EditFileTool),
            Arc::new(StreamingEditFileTool),
            Arc::new(WebSearchTool),
        ];

        let registered_tools = ToolRegistry::global(cx).tools();
        for tool in registered_tools.into_iter().chain(conditional_tools) {
            let actual_schema = tool
                .input_schema(language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset)
                .unwrap();
            let mut expected_schema = actual_schema.clone();
            assistant_tool::adapt_schema_to_format(
                &mut expected_schema,
                language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset,
            )
            .unwrap();

            assert_eq!(
                actual_schema,
                expected_schema,
                "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
                Are you using `schema::json_schema_for::<T>(format)` to generate the schema?",
                tool.name(),
            );
        }
    }
}