assistant_tools.rs

  1mod copy_path_tool;
  2mod create_directory_tool;
  3mod create_file_tool;
  4mod delete_path_tool;
  5mod diagnostics_tool;
  6mod edit_agent;
  7mod edit_file_tool;
  8mod fetch_tool;
  9mod find_path_tool;
 10mod grep_tool;
 11mod list_directory_tool;
 12mod move_path_tool;
 13mod now_tool;
 14mod open_tool;
 15mod read_file_tool;
 16mod replace;
 17mod schema;
 18mod streaming_edit_file_tool;
 19mod templates;
 20mod terminal_tool;
 21mod thinking_tool;
 22mod ui;
 23mod web_search_tool;
 24
 25use std::sync::Arc;
 26
 27use assistant_settings::AssistantSettings;
 28use assistant_tool::ToolRegistry;
 29use copy_path_tool::CopyPathTool;
 30use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
 31use gpui::{App, Entity};
 32use http_client::HttpClientWithUrl;
 33use language_model::LanguageModelRegistry;
 34use move_path_tool::MovePathTool;
 35use settings::{Settings, SettingsStore};
 36use web_search_tool::WebSearchTool;
 37
 38pub(crate) use templates::*;
 39
 40use crate::create_directory_tool::CreateDirectoryTool;
 41use crate::delete_path_tool::DeletePathTool;
 42use crate::diagnostics_tool::DiagnosticsTool;
 43use crate::fetch_tool::FetchTool;
 44use crate::find_path_tool::FindPathTool;
 45use crate::grep_tool::GrepTool;
 46use crate::list_directory_tool::ListDirectoryTool;
 47use crate::now_tool::NowTool;
 48use crate::read_file_tool::ReadFileTool;
 49use crate::streaming_edit_file_tool::StreamingEditFileTool;
 50use crate::thinking_tool::ThinkingTool;
 51
 52pub use create_file_tool::{CreateFileTool, CreateFileToolInput};
 53pub use edit_file_tool::{EditFileTool, EditFileToolInput};
 54pub use find_path_tool::FindPathToolInput;
 55pub use open_tool::OpenTool;
 56pub use read_file_tool::ReadFileToolInput;
 57pub use streaming_edit_file_tool::StreamingEditFileToolInput;
 58pub use terminal_tool::TerminalTool;
 59
 60pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
 61    assistant_tool::init(cx);
 62
 63    let registry = ToolRegistry::global(cx);
 64    registry.register_tool(TerminalTool::new(cx));
 65    registry.register_tool(CreateDirectoryTool);
 66    registry.register_tool(CopyPathTool);
 67    registry.register_tool(DeletePathTool);
 68    registry.register_tool(MovePathTool);
 69    registry.register_tool(DiagnosticsTool);
 70    registry.register_tool(ListDirectoryTool);
 71    registry.register_tool(NowTool);
 72    registry.register_tool(OpenTool);
 73    registry.register_tool(FindPathTool);
 74    registry.register_tool(ReadFileTool);
 75    registry.register_tool(GrepTool);
 76    registry.register_tool(ThinkingTool);
 77    registry.register_tool(FetchTool::new(http_client));
 78
 79    register_edit_file_tool(cx);
 80    cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
 81        .detach();
 82    cx.observe_global::<SettingsStore>(register_edit_file_tool)
 83        .detach();
 84
 85    register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
 86    cx.subscribe(
 87        &LanguageModelRegistry::global(cx),
 88        move |registry, event, cx| match event {
 89            language_model::Event::DefaultModelChanged => {
 90                register_web_search_tool(&registry, cx);
 91            }
 92            _ => {}
 93        },
 94    )
 95    .detach();
 96}
 97
 98fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
 99    let using_zed_provider = registry
100        .read(cx)
101        .default_model()
102        .map_or(false, |default| default.is_provided_by_zed());
103    if using_zed_provider {
104        ToolRegistry::global(cx).register_tool(WebSearchTool);
105    } else {
106        ToolRegistry::global(cx).unregister_tool(WebSearchTool);
107    }
108}
109
110fn register_edit_file_tool(cx: &mut App) {
111    let registry = ToolRegistry::global(cx);
112
113    registry.unregister_tool(CreateFileTool);
114    registry.unregister_tool(EditFileTool);
115    registry.unregister_tool(StreamingEditFileTool);
116
117    if AssistantSettings::get_global(cx).stream_edits(cx) {
118        registry.register_tool(StreamingEditFileTool);
119    } else {
120        registry.register_tool(CreateFileTool);
121        registry.register_tool(EditFileTool);
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use client::Client;
    use clock::FakeSystemClock;
    use http_client::FakeHttpClient;
    use schemars::JsonSchema;
    use serde::Serialize;

    /// Checks that `schema::json_schema_for` produces a plain JSON Schema object for a struct
    /// with a single required `String` field.
    #[test]
    fn test_json_schema() {
        #[derive(Serialize, JsonSchema)]
        struct GetWeatherTool {
            location: String,
        }

        let expected = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string"
                }
            },
            "required": ["location"],
        });

        let actual = schema::json_schema_for::<GetWeatherTool>(
            language_model::LanguageModelToolSchemaFormat::JsonSchema,
        )
        .unwrap();

        assert_eq!(actual, expected);
    }

    /// Initializes the crate and asserts that every registered tool's `JsonSchemaSubset`
    /// schema is already in adapted form. In other words, running
    /// `adapt_schema_to_format` over it must be a no-op.
    #[gpui::test]
    fn test_builtin_tool_schema_compatibility(cx: &mut App) {
        use language_model::LanguageModelToolSchemaFormat as Format;

        settings::init(cx);
        AssistantSettings::register(cx);

        let client = Client::new(
            Arc::new(FakeSystemClock::new()),
            FakeHttpClient::with_200_response(),
            cx,
        );
        language_model::init(client.clone(), cx);
        crate::init(client.http_client(), cx);

        for tool in ToolRegistry::global(cx).tools() {
            let actual = tool.input_schema(Format::JsonSchemaSubset).unwrap();

            let mut adapted = actual.clone();
            assistant_tool::adapt_schema_to_format(&mut adapted, Format::JsonSchemaSubset)
                .unwrap();

            assert_eq!(
                actual,
                adapted,
                "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
                Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
                tool.name(),
            );
        }
    }
}