assistant_tools.rs

  1mod batch_tool;
  2mod code_action_tool;
  3mod code_symbols_tool;
  4mod contents_tool;
  5mod copy_path_tool;
  6mod create_directory_tool;
  7mod create_file_tool;
  8mod delete_path_tool;
  9mod diagnostics_tool;
 10mod edit_agent;
 11mod edit_file_tool;
 12mod fetch_tool;
 13mod find_path_tool;
 14mod grep_tool;
 15mod list_directory_tool;
 16mod move_path_tool;
 17mod now_tool;
 18mod open_tool;
 19mod read_file_tool;
 20mod rename_tool;
 21mod replace;
 22mod schema;
 23mod streaming_edit_file_tool;
 24mod symbol_info_tool;
 25mod templates;
 26mod terminal_tool;
 27mod thinking_tool;
 28mod ui;
 29mod web_search_tool;
 30
 31use std::sync::Arc;
 32
 33use assistant_settings::AssistantSettings;
 34use assistant_tool::ToolRegistry;
 35use copy_path_tool::CopyPathTool;
 36use feature_flags::{AgentStreamEditsFeatureFlag, FeatureFlagAppExt};
 37use gpui::{App, Entity};
 38use http_client::HttpClientWithUrl;
 39use language_model::LanguageModelRegistry;
 40use move_path_tool::MovePathTool;
 41use settings::{Settings, SettingsStore};
 42use web_search_tool::WebSearchTool;
 43
 44pub(crate) use templates::*;
 45
 46use crate::batch_tool::BatchTool;
 47use crate::code_action_tool::CodeActionTool;
 48use crate::code_symbols_tool::CodeSymbolsTool;
 49use crate::contents_tool::ContentsTool;
 50use crate::create_directory_tool::CreateDirectoryTool;
 51use crate::create_file_tool::CreateFileTool;
 52use crate::delete_path_tool::DeletePathTool;
 53use crate::diagnostics_tool::DiagnosticsTool;
 54use crate::edit_file_tool::EditFileTool;
 55use crate::fetch_tool::FetchTool;
 56use crate::find_path_tool::FindPathTool;
 57use crate::grep_tool::GrepTool;
 58use crate::list_directory_tool::ListDirectoryTool;
 59use crate::now_tool::NowTool;
 60use crate::open_tool::OpenTool;
 61use crate::read_file_tool::ReadFileTool;
 62use crate::rename_tool::RenameTool;
 63use crate::streaming_edit_file_tool::StreamingEditFileTool;
 64use crate::symbol_info_tool::SymbolInfoTool;
 65use crate::terminal_tool::TerminalTool;
 66use crate::thinking_tool::ThinkingTool;
 67
 68pub use create_file_tool::CreateFileToolInput;
 69pub use edit_file_tool::EditFileToolInput;
 70pub use find_path_tool::FindPathToolInput;
 71pub use read_file_tool::ReadFileToolInput;
 72
 73pub fn init(http_client: Arc<HttpClientWithUrl>, cx: &mut App) {
 74    assistant_tool::init(cx);
 75
 76    let registry = ToolRegistry::global(cx);
 77    registry.register_tool(TerminalTool);
 78    registry.register_tool(BatchTool);
 79    registry.register_tool(CreateDirectoryTool);
 80    registry.register_tool(CopyPathTool);
 81    registry.register_tool(DeletePathTool);
 82    registry.register_tool(SymbolInfoTool);
 83    registry.register_tool(CodeActionTool);
 84    registry.register_tool(MovePathTool);
 85    registry.register_tool(DiagnosticsTool);
 86    registry.register_tool(ListDirectoryTool);
 87    registry.register_tool(NowTool);
 88    registry.register_tool(OpenTool);
 89    registry.register_tool(CodeSymbolsTool);
 90    registry.register_tool(ContentsTool);
 91    registry.register_tool(FindPathTool);
 92    registry.register_tool(ReadFileTool);
 93    registry.register_tool(GrepTool);
 94    registry.register_tool(RenameTool);
 95    registry.register_tool(ThinkingTool);
 96    registry.register_tool(FetchTool::new(http_client));
 97
 98    register_edit_file_tool(cx);
 99    cx.observe_flag::<AgentStreamEditsFeatureFlag, _>(|_, cx| register_edit_file_tool(cx))
100        .detach();
101    cx.observe_global::<SettingsStore>(register_edit_file_tool)
102        .detach();
103
104    register_web_search_tool(&LanguageModelRegistry::global(cx), cx);
105    cx.subscribe(
106        &LanguageModelRegistry::global(cx),
107        move |registry, event, cx| match event {
108            language_model::Event::DefaultModelChanged => {
109                register_web_search_tool(&registry, cx);
110            }
111            _ => {}
112        },
113    )
114    .detach();
115}
116
117fn register_web_search_tool(registry: &Entity<LanguageModelRegistry>, cx: &mut App) {
118    let using_zed_provider = registry
119        .read(cx)
120        .default_model()
121        .map_or(false, |default| default.is_provided_by_zed());
122    if using_zed_provider {
123        ToolRegistry::global(cx).register_tool(WebSearchTool);
124    } else {
125        ToolRegistry::global(cx).unregister_tool(WebSearchTool);
126    }
127}
128
129fn register_edit_file_tool(cx: &mut App) {
130    let registry = ToolRegistry::global(cx);
131
132    registry.unregister_tool(CreateFileTool);
133    registry.unregister_tool(EditFileTool);
134    registry.unregister_tool(StreamingEditFileTool);
135
136    if AssistantSettings::get_global(cx).stream_edits(cx) {
137        registry.register_tool(StreamingEditFileTool);
138    } else {
139        registry.register_tool(CreateFileTool);
140        registry.register_tool(EditFileTool);
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use client::Client;
    use clock::FakeSystemClock;
    use http_client::FakeHttpClient;
    use language_model::LanguageModelToolSchemaFormat;
    use schemars::JsonSchema;
    use serde::Serialize;

    // `schema::json_schema_for` in `JsonSchema` format must produce exactly
    // this object schema for a struct with a single `String` field.
    #[test]
    fn test_json_schema() {
        #[derive(Serialize, JsonSchema)]
        struct GetWeatherTool {
            location: String,
        }

        let expected = serde_json::json!({
            "type": "object",
            "properties": {
                "location": {
                    "type": "string"
                }
            },
            "required": ["location"],
        });

        let actual =
            schema::json_schema_for::<GetWeatherTool>(LanguageModelToolSchemaFormat::JsonSchema)
                .unwrap();

        assert_eq!(actual, expected);
    }

    // For every tool registered by `crate::init`, the `JsonSchemaSubset`
    // input schema must be left unchanged by
    // `assistant_tool::adapt_schema_to_format` for that same format.
    #[gpui::test]
    fn test_builtin_tool_schema_compatibility(cx: &mut App) {
        settings::init(cx);
        AssistantSettings::register(cx);

        let clock = Arc::new(FakeSystemClock::new());
        let client = Client::new(clock, FakeHttpClient::with_200_response(), cx);
        language_model::init(client.clone(), cx);
        crate::init(client.http_client(), cx);

        for tool in ToolRegistry::global(cx).tools() {
            let actual_schema = tool
                .input_schema(LanguageModelToolSchemaFormat::JsonSchemaSubset)
                .unwrap();

            // Adapt a copy; any difference means the tool's schema was not
            // already in the subset format.
            let mut expected_schema = actual_schema.clone();
            assistant_tool::adapt_schema_to_format(
                &mut expected_schema,
                LanguageModelToolSchemaFormat::JsonSchemaSubset,
            )
            .unwrap();

            assert_eq!(
                actual_schema,
                expected_schema,
                "Tool schema for `{}` is not compatible with `language_model::LanguageModelToolSchemaFormat::JsonSchemaSubset` (Gemini Models).\n\
                Are you using `schema::json_schema_for<T>(format)` to generate the schema?",
                tool.name(),
            );
        }
    }
}