diff --git a/crates/agent/src/agent_profile.rs b/crates/agent/src/agent_profile.rs index a89857e71a6b8ed0f4e7a397be2bcd1bce4b1d7a..08aa72a46a0aa35c62048f825e673f44a0c15cbf 100644 --- a/crates/agent/src/agent_profile.rs +++ b/crates/agent/src/agent_profile.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use agent_settings::{AgentProfileId, AgentProfileSettings, AgentSettings}; -use assistant_tool::{Tool, ToolSource, ToolWorkingSet, UniqueToolName}; +use assistant_tool::{AnyTool, ToolSource, ToolWorkingSet, UniqueToolName}; use collections::IndexMap; use convert_case::{Case, Casing}; use fs::Fs; @@ -72,7 +72,7 @@ impl AgentProfile { &self.id } - pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc)> { + pub fn enabled_tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> { let Some(settings) = AgentSettings::get_global(cx).profiles.get(&self.id) else { return Vec::new(); }; @@ -108,7 +108,7 @@ impl AgentProfile { #[cfg(test)] mod tests { use agent_settings::ContextServerPreset; - use assistant_tool::ToolRegistry; + use assistant_tool::{Tool, ToolRegistry}; use collections::IndexMap; use gpui::SharedString; use gpui::{AppContext, TestAppContext}; @@ -269,8 +269,14 @@ mod tests { fn default_tool_set(cx: &mut TestAppContext) -> Entity { cx.new(|cx| { let mut tool_set = ToolWorkingSet::default(); - tool_set.insert(Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")), cx); - tool_set.insert(Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")), cx); + tool_set.insert( + Arc::new(FakeTool::new("enabled_mcp_tool", "mcp")).into(), + cx, + ); + tool_set.insert( + Arc::new(FakeTool::new("disabled_mcp_tool", "mcp")).into(), + cx, + ); tool_set }) } @@ -290,6 +296,8 @@ mod tests { } impl Tool for FakeTool { + type Input = (); + fn name(&self) -> String { self.name.clone() } @@ -308,17 +316,17 @@ mod tests { unimplemented!() } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool { unimplemented!() } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { unimplemented!() } fn run( self: Arc, - _input: serde_json::Value, + _input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, diff --git a/crates/agent/src/context_server_tool.rs b/crates/agent/src/context_server_tool.rs index da7de1e312cea24c1be63568cc796a49ddfa178c..7164708cf8dce500d31660abc57ef9d2ff08acd4 100644 --- a/crates/agent/src/context_server_tool.rs +++ b/crates/agent/src/context_server_tool.rs @@ -29,6 +29,8 @@ impl ContextServerTool { } impl Tool for ContextServerTool { + type Input = serde_json::Value; + fn name(&self) -> String { self.tool.name.clone() } @@ -47,7 +49,7 @@ impl Tool for ContextServerTool { } } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { true } @@ -69,13 +71,13 @@ impl Tool for ContextServerTool { }) } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { format!("Run MCP tool `{}`", self.tool.name) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 815b9e86ea8a7c4c0879e81028c4ee42e3a84ca8..45e85b9abd385ed948779b9263541d613aa8db6f 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -10,7 +10,7 @@ use crate::{ }; use agent_settings::{AgentProfileId, AgentSettings, CompletionMode}; use anyhow::{Result, anyhow}; -use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet}; +use assistant_tool::{ActionLog, AnyTool, AnyToolCard, ToolWorkingSet}; use chrono::{DateTime, Utc}; use client::{ModelRequestUsage, RequestUsage}; use collections::HashMap; @@ -2452,7 +2452,7 @@ impl Thread { ui_text: impl Into, input: serde_json::Value, request: Arc, - tool: Arc, + tool: AnyTool, model: Arc, window: Option, cx: &mut Context, @@ -2468,7 +2468,7 @@ impl Thread { tool_use_id: LanguageModelToolUseId, request: Arc, input: serde_json::Value, - tool: Arc, + tool: AnyTool, model: Arc, window: Option, cx: &mut Context, diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs index 0347156cd4df0d8b5d953def949739cab1135025..7e4564bfa0dd7ae3fbd95fd9134bcd5fb147a412 100644 --- a/crates/agent/src/thread_store.rs +++ b/crates/agent/src/thread_store.rs @@ -6,7 +6,7 @@ use crate::{ }; use agent_settings::{AgentProfileId, CompletionMode}; use anyhow::{Context as _, Result, anyhow}; -use assistant_tool::{Tool, ToolId, ToolWorkingSet}; +use assistant_tool::{ToolId, ToolWorkingSet}; use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::ContextServerId; @@ -576,7 +576,8 @@ impl ThreadStore { context_server_store.clone(), server.id(), tool, - )) as Arc + )) + .into() }), cx, ) diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs index 76de3d20223fcd1c22631029d2040c9109d9ac0d..a2acde3032a3ccd4ef604d15825708501ebb41a0 100644 --- a/crates/agent/src/tool_use.rs +++ b/crates/agent/src/tool_use.rs @@ -4,7 +4,7 @@ use crate::{ }; use anyhow::Result; use assistant_tool::{ - AnyToolCard, Tool, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, + AnyTool, AnyToolCard, ToolResultContent, ToolResultOutput, ToolUseStatus, ToolWorkingSet, }; use collections::HashMap; use futures::{FutureExt as _, future::Shared}; @@ -378,7 +378,7 @@ impl ToolUseState { ui_text: impl Into>, input: serde_json::Value, request: Arc, - tool: Arc, + tool: AnyTool, ) { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { let ui_text = ui_text.into(); @@ -533,7 +533,7 @@ pub struct Confirmation { pub input: serde_json::Value, pub ui_text: Arc, pub request: Arc, - pub tool: Arc, + pub tool: AnyTool, } #[derive(Debug, Clone)] diff --git a/crates/agent_ui/src/tool_compatibility.rs b/crates/agent_ui/src/tool_compatibility.rs index d4e1da5bb0a532c8307364582349378d98c51a26..ec354dc06fdfde76368e5aaa4623592bc39237de 100644 --- a/crates/agent_ui/src/tool_compatibility.rs +++ b/crates/agent_ui/src/tool_compatibility.rs @@ -1,5 +1,5 @@ use agent::{Thread, ThreadEvent}; -use assistant_tool::{Tool, ToolSource}; +use assistant_tool::{AnyTool, ToolSource}; use collections::HashMap; use gpui::{App, Context, Entity, IntoElement, Render, Subscription, Window}; use language_model::{LanguageModel, LanguageModelToolSchemaFormat}; @@ -7,7 +7,7 @@ use std::sync::Arc; use ui::prelude::*; pub struct IncompatibleToolsState { - cache: HashMap>>, + cache: HashMap>, thread: Entity, _thread_subscription: Subscription, } @@ -29,11 +29,7 @@ impl IncompatibleToolsState { } } - pub fn incompatible_tools( - &mut self, - model: &Arc, - cx: &App, - ) -> &[Arc] { + pub fn incompatible_tools(&mut self, model: &Arc, cx: &App) -> &[AnyTool] { self.cache .entry(model.tool_input_format()) .or_insert_with(|| { @@ -50,7 +46,7 @@ impl IncompatibleToolsState { } pub struct IncompatibleToolsTooltip { - pub incompatible_tools: Vec>, + pub incompatible_tools: Vec, } impl Render for IncompatibleToolsTooltip { diff --git a/crates/assistant_tool/src/assistant_tool.rs b/crates/assistant_tool/src/assistant_tool.rs index 554b3f3f3cf7eb0bc369ee6fed67722755704443..976d5ee697cfe546b608b62260529ff87ccabe10 100644 --- a/crates/assistant_tool/src/assistant_tool.rs +++ b/crates/assistant_tool/src/assistant_tool.rs @@ -4,25 +4,19 @@ mod tool_registry; mod tool_schema; mod tool_working_set; -use std::fmt; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::ops::Deref; -use std::sync::Arc; +use std::{fmt, fmt::Debug, fmt::Formatter, ops::Deref, sync::Arc}; use anyhow::Result; -use gpui::AnyElement; -use gpui::AnyWindowHandle; -use gpui::Context; -use gpui::IntoElement; -use gpui::Window; -use gpui::{App, Entity, SharedString, Task, WeakEntity}; +use gpui::{ + AnyElement, AnyWindowHandle, App, Context, Entity, IntoElement, SharedString, Task, WeakEntity, + Window, +}; use icons::IconName; -use language_model::LanguageModel; -use language_model::LanguageModelImage; -use language_model::LanguageModelRequest; -use language_model::LanguageModelToolSchemaFormat; +use language_model::{ + LanguageModel, LanguageModelImage, LanguageModelRequest, LanguageModelToolSchemaFormat, +}; use project::Project; +use serde::de::DeserializeOwned; use workspace::Workspace; pub use crate::action_log::*; @@ -199,7 +193,10 @@ pub enum ToolSource { } /// A tool that can be used by a language model. -pub trait Tool: 'static + Send + Sync { +pub trait Tool: Send + Sync + 'static { + /// The input type that is accepted by the tool. + type Input: DeserializeOwned; + /// Returns the name of the tool. fn name(&self) -> String; @@ -216,7 +213,7 @@ pub trait Tool: 'static + Send + Sync { /// Returns true if the tool needs the users's confirmation /// before having permission to run. - fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; + fn needs_confirmation(&self, input: &Self::Input, cx: &App) -> bool; /// Returns true if the tool may perform edits. fn may_perform_edits(&self) -> bool; @@ -227,18 +224,18 @@ pub trait Tool: 'static + Send + Sync { } /// Returns markdown to be displayed in the UI for this tool. - fn ui_text(&self, input: &serde_json::Value) -> String; + fn ui_text(&self, input: &Self::Input) -> String; /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming /// (so information may be missing). - fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + fn still_streaming_ui_text(&self, input: &Self::Input) -> String { self.ui_text(input) } /// Runs the tool with the provided input. fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, request: Arc, project: Entity, action_log: Entity, @@ -258,7 +255,199 @@ pub trait Tool: 'static + Send + Sync { } } -impl Debug for dyn Tool { +#[derive(Clone)] +pub struct AnyTool { + inner: Arc, +} + +/// Copy of `Tool` where the Input type is erased. +trait ErasedTool: Send + Sync { + fn name(&self) -> String; + fn description(&self) -> String; + fn icon(&self) -> IconName; + fn source(&self) -> ToolSource; + fn may_perform_edits(&self) -> bool; + fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool; + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result; + fn ui_text(&self, input: &serde_json::Value) -> String; + fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String; + fn run( + &self, + input: serde_json::Value, + request: Arc, + project: Entity, + action_log: Entity, + model: Arc, + window: Option, + cx: &mut App, + ) -> ToolResult; + fn deserialize_card( + &self, + output: serde_json::Value, + project: Entity, + window: &mut Window, + cx: &mut App, + ) -> Option; +} + +struct ErasedToolWrapper { + tool: Arc, +} + +impl ErasedTool for ErasedToolWrapper { + fn name(&self) -> String { + self.tool.name() + } + + fn description(&self) -> String { + self.tool.description() + } + + fn icon(&self) -> IconName { + self.tool.icon() + } + + fn source(&self) -> ToolSource { + self.tool.source() + } + + fn may_perform_edits(&self) -> bool { + self.tool.may_perform_edits() + } + + fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool { + match serde_json::from_value::(input.clone()) { + Ok(parsed_input) => self.tool.needs_confirmation(&parsed_input, cx), + Err(_) => true, + } + } + + fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + self.tool.input_schema(format) + } + + fn ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(parsed_input) => self.tool.ui_text(&parsed_input), + Err(_) => "Invalid input".to_string(), + } + } + + fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + match serde_json::from_value::(input.clone()) { + Ok(parsed_input) => self.tool.still_streaming_ui_text(&parsed_input), + Err(_) => "Invalid input".to_string(), + } + } + + fn run( + &self, + input: serde_json::Value, + request: Arc, + project: Entity, + action_log: Entity, + model: Arc, + window: Option, + cx: &mut App, + ) -> ToolResult { + match serde_json::from_value::(input) { + Ok(parsed_input) => self.tool.clone().run( + parsed_input, + request, + project, + action_log, + model, + window, + cx, + ), + Err(err) => ToolResult::from(Task::ready(Err(err.into()))), + } + } + + fn deserialize_card( + &self, + output: serde_json::Value, + project: Entity, + window: &mut Window, + cx: &mut App, + ) -> Option { + self.tool + .clone() + .deserialize_card(output, project, window, cx) + } +} + +impl From> for AnyTool { + fn from(tool: Arc) -> Self { + Self { + inner: Arc::new(ErasedToolWrapper { tool }), + } + } +} + +impl AnyTool { + pub fn name(&self) -> String { + self.inner.name() + } + + pub fn description(&self) -> String { + self.inner.description() + } + + pub fn icon(&self) -> IconName { + self.inner.icon() + } + + pub fn source(&self) -> ToolSource { + self.inner.source() + } + + pub fn may_perform_edits(&self) -> bool { + self.inner.may_perform_edits() + } + + pub fn needs_confirmation(&self, input: &serde_json::Value, cx: &App) -> bool { + self.inner.needs_confirmation(input, cx) + } + + pub fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result { + self.inner.input_schema(format) + } + + pub fn ui_text(&self, input: &serde_json::Value) -> String { + self.inner.ui_text(input) + } + + pub fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { + self.inner.still_streaming_ui_text(input) + } + + pub fn run( + &self, + input: serde_json::Value, + request: Arc, + project: Entity, + action_log: Entity, + model: Arc, + window: Option, + cx: &mut App, + ) -> ToolResult { + self.inner + .run(input, request, project, action_log, model, window, cx) + } + + pub fn deserialize_card( + &self, + output: serde_json::Value, + project: Entity, + window: &mut Window, + cx: &mut App, + ) -> Option { + self.inner.deserialize_card(output, project, window, cx) + } +} + +impl Debug for AnyTool { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Tool").field("name", &self.name()).finish() } diff --git a/crates/assistant_tool/src/tool_registry.rs b/crates/assistant_tool/src/tool_registry.rs index 26b4821a6d1af05a5e42d639f465486b9311d427..8583688bb81a3b7d83f351857f836530c86ca5e7 100644 --- a/crates/assistant_tool/src/tool_registry.rs +++ b/crates/assistant_tool/src/tool_registry.rs @@ -6,7 +6,7 @@ use gpui::Global; use gpui::{App, ReadGlobal}; use parking_lot::RwLock; -use crate::Tool; +use crate::{AnyTool, Tool}; #[derive(Default, Deref, DerefMut)] struct GlobalToolRegistry(Arc); @@ -15,7 +15,7 @@ impl Global for GlobalToolRegistry {} #[derive(Default)] struct ToolRegistryState { - tools: HashMap, Arc>, + tools: HashMap, AnyTool>, } #[derive(Default)] @@ -48,7 +48,7 @@ impl ToolRegistry { pub fn register_tool(&self, tool: impl Tool) { let mut state = self.state.write(); let tool_name: Arc = tool.name().into(); - state.tools.insert(tool_name, Arc::new(tool)); + state.tools.insert(tool_name, Arc::new(tool).into()); } /// Unregisters the provided [`Tool`]. @@ -63,12 +63,12 @@ impl ToolRegistry { } /// Returns the list of tools in the registry. - pub fn tools(&self) -> Vec> { + pub fn tools(&self) -> Vec { self.state.read().tools.values().cloned().collect() } /// Returns the [`Tool`] with the given name. - pub fn tool(&self, name: &str) -> Option> { + pub fn tool(&self, name: &str) -> Option { self.state.read().tools.get(name).cloned() } } diff --git a/crates/assistant_tool/src/tool_working_set.rs b/crates/assistant_tool/src/tool_working_set.rs index 9a6ec49914eea3cd22f014ce2a5c014d1dca1220..72d3c06fbb462ba55ef680e9002b73bb32b4bb8f 100644 --- a/crates/assistant_tool/src/tool_working_set.rs +++ b/crates/assistant_tool/src/tool_working_set.rs @@ -1,6 +1,6 @@ -use std::{borrow::Borrow, sync::Arc}; +use std::borrow::Borrow; -use crate::{Tool, ToolRegistry, ToolSource}; +use crate::{AnyTool, ToolRegistry, ToolSource}; use collections::{HashMap, HashSet, IndexMap}; use gpui::{App, SharedString}; use util::debug_panic; @@ -45,20 +45,20 @@ impl std::fmt::Display for UniqueToolName { /// A working set of tools for use in one instance of the Assistant Panel. #[derive(Default)] pub struct ToolWorkingSet { - context_server_tools_by_id: HashMap>, - context_server_tools_by_name: HashMap>, + context_server_tools_by_id: HashMap, + context_server_tools_by_name: HashMap, next_tool_id: ToolId, } impl ToolWorkingSet { - pub fn tool(&self, name: &str, cx: &App) -> Option> { + pub fn tool(&self, name: &str, cx: &App) -> Option { self.context_server_tools_by_name .get(name) .cloned() .or_else(|| ToolRegistry::global(cx).tool(name)) } - pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, Arc)> { + pub fn tools(&self, cx: &App) -> Vec<(UniqueToolName, AnyTool)> { let mut tools = ToolRegistry::global(cx) .tools() .into_iter() @@ -68,7 +68,7 @@ impl ToolWorkingSet { tools } - pub fn tools_by_source(&self, cx: &App) -> IndexMap>> { + pub fn tools_by_source(&self, cx: &App) -> IndexMap> { let mut tools_by_source = IndexMap::default(); for (_, tool) in self.tools(cx) { @@ -87,13 +87,13 @@ impl ToolWorkingSet { tools_by_source } - pub fn insert(&mut self, tool: Arc, cx: &App) -> ToolId { + pub fn insert(&mut self, tool: AnyTool, cx: &App) -> ToolId { let tool_id = self.register_tool(tool); self.tools_changed(cx); tool_id } - pub fn extend(&mut self, tools: impl Iterator>, cx: &App) -> Vec { + pub fn extend(&mut self, tools: impl Iterator, cx: &App) -> Vec { let ids = tools.map(|tool| self.register_tool(tool)).collect(); self.tools_changed(cx); ids @@ -105,7 +105,7 @@ impl ToolWorkingSet { self.tools_changed(cx); } - fn register_tool(&mut self, tool: Arc) -> ToolId { + fn register_tool(&mut self, tool: AnyTool) -> ToolId { let tool_id = self.next_tool_id; self.next_tool_id.0 += 1; self.context_server_tools_by_id @@ -126,10 +126,10 @@ impl ToolWorkingSet { } fn resolve_context_server_tool_name_conflicts( - context_server_tools: &[Arc], - native_tools: &[Arc], -) -> HashMap> { - fn resolve_tool_name(tool: &Arc) -> String { + context_server_tools: &[AnyTool], + native_tools: &[AnyTool], +) -> HashMap { + fn resolve_tool_name(tool: &AnyTool) -> String { let mut tool_name = tool.name(); tool_name.truncate(MAX_TOOL_NAME_LENGTH); tool_name @@ -201,11 +201,13 @@ fn resolve_context_server_tool_name_conflicts( } #[cfg(test)] mod tests { + use std::sync::Arc; + use gpui::{AnyWindowHandle, Entity, Task, TestAppContext}; use language_model::{LanguageModel, LanguageModelRequest}; use project::Project; - use crate::{ActionLog, ToolResult}; + use crate::{ActionLog, Tool, ToolResult}; use super::*; @@ -234,11 +236,13 @@ mod tests { Arc::new(TestTool::new( "tool2", ToolSource::ContextServer { id: "mcp-1".into() }, - )) as Arc, + )) + .into(), Arc::new(TestTool::new( "tool2", ToolSource::ContextServer { id: "mcp-2".into() }, - )) as Arc, + )) + .into(), ] .into_iter(), cx, @@ -324,13 +328,13 @@ mod tests { context_server_tools: Vec, expected: Vec<&'static str>, ) { - let context_server_tools: Vec> = context_server_tools + let context_server_tools: Vec = context_server_tools .into_iter() - .map(|t| Arc::new(t) as Arc) + .map(|t| Arc::new(t).into()) .collect(); - let builtin_tools: Vec> = builtin_tools + let builtin_tools: Vec = builtin_tools .into_iter() - .map(|t| Arc::new(t) as Arc) + .map(|t| Arc::new(t).into()) .collect(); let tools = resolve_context_server_tool_name_conflicts(&context_server_tools, &builtin_tools); @@ -363,6 +367,8 @@ mod tests { } impl Tool for TestTool { + type Input = (); + fn name(&self) -> String { self.name.clone() } @@ -375,7 +381,7 @@ mod tests { false } - fn needs_confirmation(&self, _input: &serde_json::Value, _cx: &App) -> bool { + fn needs_confirmation(&self, _input: &Self::Input, _cx: &App) -> bool { true } @@ -387,13 +393,13 @@ mod tests { "Test tool".to_string() } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { "Test tool".to_string() } fn run( self: Arc, - _input: serde_json::Value, + _input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, diff --git a/crates/assistant_tools/src/copy_path_tool.rs b/crates/assistant_tools/src/copy_path_tool.rs index 28d6bef9dd899360cd08e28b876830f81a5bb50a..4363cdd791048ff3264c1047735ae57bae4267dc 100644 --- a/crates/assistant_tools/src/copy_path_tool.rs +++ b/crates/assistant_tools/src/copy_path_tool.rs @@ -40,11 +40,13 @@ pub struct CopyPathToolInput { pub struct CopyPathTool; impl Tool for CopyPathTool { + type Input = CopyPathToolInput; + fn name(&self) -> String { "copy_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -64,20 +66,15 @@ impl Tool for CopyPathTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let src = MarkdownInlineCode(&input.source_path); - let dest = MarkdownInlineCode(&input.destination_path); - format!("Copy {src} to {dest}") - } - Err(_) => "Copy path".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + let src = MarkdownInlineCode(&input.source_path); + let dest = MarkdownInlineCode(&input.destination_path); + format!("Copy {src} to {dest}") } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -85,10 +82,6 @@ impl Tool for CopyPathTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; let copy_task = project.update(cx, |project, cx| { match project .find_project_path(&input.source_path, cx) diff --git a/crates/assistant_tools/src/create_directory_tool.rs b/crates/assistant_tools/src/create_directory_tool.rs index b3e198c1b5e276032846dc8a6c2b67b02c917379..2f93d615eca668bb01b81230cccaf3d1ad52275c 100644 --- a/crates/assistant_tools/src/create_directory_tool.rs +++ b/crates/assistant_tools/src/create_directory_tool.rs @@ -29,6 +29,8 @@ pub struct CreateDirectoryToolInput { pub struct CreateDirectoryTool; impl Tool for CreateDirectoryTool { + type Input = CreateDirectoryToolInput; + fn name(&self) -> String { "create_directory".into() } @@ -37,7 +39,7 @@ impl Tool for CreateDirectoryTool { include_str!("./create_directory_tool/description.md").into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -53,18 +55,13 @@ impl Tool for CreateDirectoryTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - format!("Create directory {}", MarkdownInlineCode(&input.path)) - } - Err(_) => "Create directory".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + format!("Create directory {}", MarkdownInlineCode(&input.path)) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -72,10 +69,6 @@ impl Tool for CreateDirectoryTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; let project_path = match project.read(cx).find_project_path(&input.path, cx) { Some(project_path) => project_path, None => { diff --git a/crates/assistant_tools/src/delete_path_tool.rs b/crates/assistant_tools/src/delete_path_tool.rs index e45c1976d1f32642b4091e9fad75385a5b4a7c93..32921c3887280e694cb77cd569bd3fd4f4dfe049 100644 --- a/crates/assistant_tools/src/delete_path_tool.rs +++ b/crates/assistant_tools/src/delete_path_tool.rs @@ -29,11 +29,13 @@ pub struct DeletePathToolInput { pub struct DeletePathTool; impl Tool for DeletePathTool { + type Input = DeletePathToolInput; + fn name(&self) -> String { "delete_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -53,16 +55,13 @@ impl Tool for DeletePathTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Delete “`{}`”", input.path), - Err(_) => "Delete path".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + format!("Delete “`{}`”", input.path) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, action_log: Entity, @@ -70,10 +69,7 @@ impl Tool for DeletePathTool { _window: Option, cx: &mut App, ) -> ToolResult { - let path_str = match serde_json::from_value::(input) { - Ok(input) => input.path, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; + let path_str = input.path; let Some(project_path) = project.read(cx).find_project_path(&path_str, cx) else { return Task::ready(Err(anyhow!( "Couldn't delete {path_str} because that path isn't in this project." diff --git a/crates/assistant_tools/src/diagnostics_tool.rs b/crates/assistant_tools/src/diagnostics_tool.rs index 3b6d38fc06c0e9f8b95f031cb900ace74c5c6b04..cd8e0e831767bbc53acdb6757509286bf214d4af 100644 --- a/crates/assistant_tools/src/diagnostics_tool.rs +++ b/crates/assistant_tools/src/diagnostics_tool.rs @@ -42,11 +42,13 @@ where pub struct DiagnosticsTool; impl Tool for DiagnosticsTool { + type Input = DiagnosticsToolInput; + fn name(&self) -> String { "diagnostics".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -66,15 +68,9 @@ impl Tool for DiagnosticsTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - if let Some(path) = serde_json::from_value::(input.clone()) - .ok() - .and_then(|input| match input.path { - Some(path) if !path.is_empty() => Some(path), - _ => None, - }) - { - format!("Check diagnostics for {}", MarkdownInlineCode(&path)) + fn ui_text(&self, input: &Self::Input) -> String { + if let Some(path) = input.path.as_ref().filter(|p| !p.is_empty()) { + format!("Check diagnostics for {}", MarkdownInlineCode(path)) } else { "Check project diagnostics".to_string() } @@ -82,7 +78,7 @@ impl Tool for DiagnosticsTool { fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, action_log: Entity, @@ -90,10 +86,7 @@ impl Tool for DiagnosticsTool { _window: Option, cx: &mut App, ) -> ToolResult { - match serde_json::from_value::(input) - .ok() - .and_then(|input| input.path) - { + match input.path { Some(path) if !path.is_empty() => { let Some(project_path) = project.read(cx).find_project_path(&path, cx) else { return Task::ready(Err(anyhow!("Could not find path {path} in project",))) diff --git a/crates/assistant_tools/src/edit_file_tool.rs b/crates/assistant_tools/src/edit_file_tool.rs index 8c7728b4b72c9aa52c717e58fbdd63591dd88f0f..34b7d4e486e2effa07c7ba9c4157f832941f2fd0 100644 --- a/crates/assistant_tools/src/edit_file_tool.rs +++ b/crates/assistant_tools/src/edit_file_tool.rs @@ -121,11 +121,13 @@ struct PartialInput { const DEFAULT_UI_TEXT: &str = "Editing file"; impl Tool for EditFileTool { + type Input = EditFileToolInput; + fn name(&self) -> String { "edit_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -145,24 +147,20 @@ impl Tool for EditFileTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => input.display_description, - Err(_) => "Editing file".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + input.display_description.clone() } - fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String { - if let Some(input) = serde_json::from_value::(input.clone()).ok() { - let description = input.display_description.trim(); - if !description.is_empty() { - return description.to_string(); - } + fn still_streaming_ui_text(&self, input: &Self::Input) -> String { + let description = input.display_description.trim(); + if !description.is_empty() { + return description.to_string(); + } - let path = input.path.trim(); - if !path.is_empty() { - return path.to_string(); - } + let path = input.path.to_string_lossy(); + let path = path.trim(); + if !path.is_empty() { + return path.to_string(); } DEFAULT_UI_TEXT.to_string() @@ -170,7 +168,7 @@ impl Tool for EditFileTool { fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, request: Arc, project: Entity, action_log: Entity, @@ -178,11 +176,6 @@ impl Tool for EditFileTool { window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let project_path = match resolve_path(&input, project.clone(), cx) { Ok(path) => path, Err(err) => return Task::ready(Err(anyhow!(err))).into(), @@ -1169,12 +1162,11 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let result = cx .update(|cx| { - let input = serde_json::to_value(EditFileToolInput { + let input = EditFileToolInput { display_description: "Some edit".into(), path: "root/nonexistent_file.txt".into(), mode: EditFileMode::Edit, - }) - .unwrap(); + }; Arc::new(EditFileTool) .run( input, @@ -1288,24 +1280,22 @@ mod tests { #[test] fn still_streaming_ui_text_with_path() { - let input = json!({ - "path": "src/main.rs", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); + let input = EditFileToolInput { + path: "src/main.rs".into(), + display_description: "".into(), + mode: EditFileMode::Edit, + }; assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs"); } #[test] fn still_streaming_ui_text_with_description() { - let input = json!({ - "path": "", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); + let input = EditFileToolInput { + path: "".into(), + display_description: "Fix error handling".into(), + mode: EditFileMode::Edit, + }; assert_eq!( EditFileTool.still_streaming_ui_text(&input), @@ -1315,12 +1305,11 @@ mod tests { #[test] fn still_streaming_ui_text_with_path_and_description() { - let input = json!({ - "path": "src/main.rs", - "display_description": "Fix error handling", - "old_string": "old code", - "new_string": "new code" - }); + let input = EditFileToolInput { + path: "src/main.rs".into(), + display_description: "Fix error handling".into(), + mode: EditFileMode::Edit, + }; assert_eq!( EditFileTool.still_streaming_ui_text(&input), @@ -1330,12 +1319,11 @@ mod tests { #[test] fn still_streaming_ui_text_no_path_or_description() { - let input = json!({ - "path": "", - "display_description": "", - "old_string": "old code", - "new_string": "new code" - }); + let input = EditFileToolInput { + path: "".into(), + display_description: "".into(), + mode: EditFileMode::Edit, + }; assert_eq!( EditFileTool.still_streaming_ui_text(&input), @@ -1345,7 +1333,11 @@ mod tests { #[test] fn still_streaming_ui_text_with_null() { - let input = serde_json::Value::Null; + let input = EditFileToolInput { + path: "".into(), + display_description: "".into(), + mode: EditFileMode::Edit, + }; assert_eq!( EditFileTool.still_streaming_ui_text(&input), @@ -1457,12 +1449,11 @@ mod tests { // Have the model stream unformatted content let edit_result = { let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { + let input = EditFileToolInput { display_description: "Create main function".into(), path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, - }) - .unwrap(); + }; Arc::new(EditFileTool) .run( input, @@ -1521,12 +1512,11 @@ mod tests { // Stream unformatted edits again let edit_result = { let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { + let input = EditFileToolInput { display_description: "Update main function".into(), path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, - }) - .unwrap(); + }; Arc::new(EditFileTool) .run( input, @@ -1600,12 +1590,11 @@ mod tests { // Have the model stream content that contains trailing whitespace let edit_result = { let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { + let input = EditFileToolInput { display_description: "Create main function".into(), path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, - }) - .unwrap(); + }; Arc::new(EditFileTool) .run( input, @@ -1657,12 +1646,11 @@ mod tests { // Stream edits again with trailing whitespace let edit_result = { let edit_task = cx.update(|cx| { - let input = serde_json::to_value(EditFileToolInput { + let input = EditFileToolInput { display_description: "Update main function".into(), path: "root/src/main.rs".into(), mode: EditFileMode::Overwrite, - }) - .unwrap(); + }; Arc::new(EditFileTool) .run( input, diff --git a/crates/assistant_tools/src/fetch_tool.rs b/crates/assistant_tools/src/fetch_tool.rs index 82b15b7a86905219167d4f4fb630e6c9bab2c79d..6a48941d5bdc096cd620ea51030edee617aad3ff 100644 --- a/crates/assistant_tools/src/fetch_tool.rs +++ b/crates/assistant_tools/src/fetch_tool.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use std::{borrow::Cow, cell::RefCell}; use crate::schema::json_schema_for; -use anyhow::{Context as _, Result, anyhow, bail}; +use anyhow::{Context as _, Result, bail}; use assistant_tool::{ActionLog, Tool, ToolResult}; use futures::AsyncReadExt as _; -use gpui::{AnyWindowHandle, App, AppContext as _, Entity, Task}; +use gpui::{AnyWindowHandle, App, AppContext as _, Entity}; use html_to_markdown::{TagHandler, convert_html_to_markdown, markdown}; use http_client::{AsyncBody, HttpClientWithUrl}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; @@ -113,11 +113,13 @@ impl FetchTool { } impl Tool for FetchTool { + type Input = FetchToolInput; + fn name(&self) -> String { "fetch".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -137,16 +139,13 @@ impl Tool for FetchTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Fetch {}", MarkdownEscaped(&input.url)), - Err(_) => "Fetch URL".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + format!("Fetch {}", MarkdownEscaped(&input.url)) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, @@ -154,11 +153,6 @@ impl Tool for FetchTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let text = cx.background_spawn({ let http_client = self.http_client.clone(); async move { Self::build_message(http_client, &input.url).await } diff --git a/crates/assistant_tools/src/find_path_tool.rs b/crates/assistant_tools/src/find_path_tool.rs index 86e67a8f58cd71aedd163e15cb95aeb9e3357a87..ffd6a28ed259294538e2a430fb16388e0b907d74 100644 --- a/crates/assistant_tools/src/find_path_tool.rs +++ b/crates/assistant_tools/src/find_path_tool.rs @@ -51,11 +51,13 @@ const RESULTS_PER_PAGE: usize = 50; pub struct FindPathTool; impl Tool for FindPathTool { + type Input = FindPathToolInput; + fn name(&self) -> String { "find_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -75,16 +77,13 @@ impl Tool for FindPathTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Find paths matching “`{}`”", input.glob), - Err(_) => "Search paths".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + format!("Find paths matching \"`{}`\"", input.glob) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -92,10 +91,7 @@ impl Tool for FindPathTool { _window: Option, cx: &mut App, ) -> ToolResult { - let (offset, glob) = match serde_json::from_value::(input) { - Ok(input) => (input.offset, input.glob), - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; + let (offset, glob) = (input.offset, input.glob); let (sender, receiver) = oneshot::channel(); diff --git a/crates/assistant_tools/src/grep_tool.rs b/crates/assistant_tools/src/grep_tool.rs index eb4c8d38e5a586ca7d236906ab537754deb36f1f..384a1093df7a634734d30f839c06d632d53789c3 100644 --- a/crates/assistant_tools/src/grep_tool.rs +++ b/crates/assistant_tools/src/grep_tool.rs @@ -53,11 +53,13 @@ const RESULTS_PER_PAGE: u32 = 20; pub struct GrepTool; impl Tool for GrepTool { + type Input = GrepToolInput; + fn name(&self) -> String { "grep".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -77,30 +79,25 @@ impl Tool for GrepTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let page = input.page(); - let regex_str = MarkdownInlineCode(&input.regex); - let case_info = if input.case_sensitive { - " (case-sensitive)" - } else { - "" - }; + fn ui_text(&self, input: &Self::Input) -> String { + let page = input.page(); + let regex_str = MarkdownInlineCode(&input.regex); + let case_info = if input.case_sensitive { + " (case-sensitive)" + } else { + "" + }; - if page > 1 { - format!("Get page {page} of search results for regex {regex_str}{case_info}") - } else { - format!("Search files for regex {regex_str}{case_info}") - } - } - Err(_) => "Search with regex".to_string(), + if page > 1 { + format!("Get page {page} of search results for regex {regex_str}{case_info}") + } else { + format!("Search files for regex {regex_str}{case_info}") } } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -111,13 +108,6 @@ impl Tool for GrepTool { const CONTEXT_LINES: u32 = 2; const MAX_ANCESTOR_LINES: u32 = 10; - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(error) => { - return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into(); - } - }; - let include_matcher = match PathMatcher::new( input .include_pattern @@ -348,13 +338,12 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; // Test with include pattern for Rust files inside the root of the project - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "println".to_string(), include_pattern: Some("root/**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!(result.contains("main.rs"), "Should find matches in main.rs"); @@ -368,13 +357,12 @@ mod tests { ); // Test with include pattern for src directory only - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "fn".to_string(), include_pattern: Some("root/**/src/**".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!( @@ -391,13 +379,12 @@ mod tests { ); // Test with empty include pattern (should default to all files) - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "fn".to_string(), include_pattern: None, offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!(result.contains("main.rs"), "Should find matches in main.rs"); @@ -428,13 +415,12 @@ mod tests { let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; // Test case-insensitive search (default) - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "uppercase".to_string(), include_pattern: Some("**/*.txt".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!( @@ -443,13 +429,12 @@ mod tests { ); // Test case-sensitive search - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "uppercase".to_string(), include_pattern: Some("**/*.txt".to_string()), offset: 0, case_sensitive: true, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!( @@ -458,13 +443,12 @@ mod tests { ); // Test case-sensitive search - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "LOWERCASE".to_string(), include_pattern: Some("**/*.txt".to_string()), offset: 0, case_sensitive: true, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; @@ -474,13 +458,12 @@ mod tests { ); // Test case-sensitive search for lowercase pattern - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "lowercase".to_string(), include_pattern: Some("**/*.txt".to_string()), offset: 0, case_sensitive: true, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; assert!( @@ -576,13 +559,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line at the top level of the file - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "This is at the top level".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -606,13 +588,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line inside a function body - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "Function in nested module".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -638,13 +619,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line with a function argument - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "second_arg".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -674,13 +654,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line inside an if block - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "Inside if block".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -705,13 +684,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line in the middle of a long function - should show message about remaining lines - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "Line 5".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -746,13 +724,12 @@ mod tests { let project = setup_syntax_test(cx).await; // Test: Line in the long function - let input = serde_json::to_value(GrepToolInput { + let input = GrepToolInput { regex: "Line 12".to_string(), include_pattern: Some("**/*.rs".to_string()), offset: 0, case_sensitive: false, - }) - .unwrap(); + }; let result = run_grep_tool(input, project.clone(), cx).await; let expected = r#" @@ -774,7 +751,7 @@ mod tests { } async fn run_grep_tool( - input: serde_json::Value, + input: GrepToolInput, project: Entity, cx: &mut TestAppContext, ) -> String { @@ -876,9 +853,12 @@ mod tests { // Searching for files outside the project worktree should return no results let result = cx .update(|cx| { - let input = json!({ - "regex": "outside_function" - }); + let input = GrepToolInput { + regex: "outside_function".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -902,9 +882,12 @@ mod tests { // Searching within the project should succeed let result = cx .update(|cx| { - let input = json!({ - "regex": "main" - }); + let input = GrepToolInput { + regex: "main".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -928,9 +911,12 @@ mod tests { // Searching files that match file_scan_exclusions should return no results let result = cx .update(|cx| { - let input = json!({ - "regex": "special_configuration" - }); + let input = GrepToolInput { + regex: "special_configuration".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -953,9 +939,12 @@ mod tests { let result = cx .update(|cx| { - let input = json!({ - "regex": "custom_metadata" - }); + let input = GrepToolInput { + regex: "custom_metadata".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -979,9 +968,12 @@ mod tests { // Searching private files should return no results let result = cx .update(|cx| { - let input = json!({ - "regex": "SECRET_KEY" - }); + let input = GrepToolInput { + regex: "SECRET_KEY".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1004,9 +996,12 @@ mod tests { let result = cx .update(|cx| { - let input = json!({ - "regex": "private_key_content" - }); + let input = GrepToolInput { + regex: "private_key_content".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1029,9 +1024,12 @@ mod tests { let result = cx .update(|cx| { - let input = json!({ - "regex": "sensitive_data" - }); + let input = GrepToolInput { + regex: "sensitive_data".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1055,9 +1053,12 @@ mod tests { // Searching a normal file should still work, even with private_files configured let result = cx .update(|cx| { - let input = json!({ - "regex": "normal_file_content" - }); + let input = GrepToolInput { + regex: "normal_file_content".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1081,10 +1082,12 @@ mod tests { // Path traversal attempts with .. in include_pattern should not escape project let result = cx .update(|cx| { - let input = json!({ - "regex": "outside_function", - "include_pattern": "../outside_project/**/*.rs" - }); + let input = GrepToolInput { + regex: "outside_function".to_string(), + include_pattern: Some("../outside_project/**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1185,10 +1188,12 @@ mod tests { // Search for "secret" - should exclude files based on worktree-specific settings let result = cx .update(|cx| { - let input = json!({ - "regex": "secret", - "case_sensitive": false - }); + let input = GrepToolInput { + regex: "secret".to_string(), + include_pattern: None, + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, @@ -1250,10 +1255,12 @@ mod tests { // Test with `include_pattern` specific to one worktree let result = cx .update(|cx| { - let input = json!({ - "regex": "secret", - "include_pattern": "worktree1/**/*.rs" - }); + let input = GrepToolInput { + regex: "secret".to_string(), + include_pattern: Some("worktree1/**/*.rs".to_string()), + offset: 0, + case_sensitive: false, + }; Arc::new(GrepTool) .run( input, diff --git a/crates/assistant_tools/src/list_directory_tool.rs b/crates/assistant_tools/src/list_directory_tool.rs index aef186b9ae5adcc0e7d1625d483b1e4d6d9d51ca..7c9a4a00b33adb14ad907787310278536f20dc01 100644 --- a/crates/assistant_tools/src/list_directory_tool.rs +++ b/crates/assistant_tools/src/list_directory_tool.rs @@ -41,11 +41,13 @@ pub struct ListDirectoryToolInput { pub struct ListDirectoryTool; impl Tool for ListDirectoryTool { + type Input = ListDirectoryToolInput; + fn name(&self) -> String { "list_directory".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -65,19 +67,14 @@ impl Tool for ListDirectoryTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let path = MarkdownInlineCode(&input.path); - format!("List the {path} directory's contents") - } - Err(_) => "List directory".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + let path = MarkdownInlineCode(&input.path); + format!("List the {path} directory's contents") } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -85,11 +82,6 @@ impl Tool for ListDirectoryTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - // Sometimes models will return these even though we tell it to give a path and not a glob. // When this happens, just list the root worktree directories. if matches!(input.path.as_str(), "." | "" | "./" | "*") { @@ -285,9 +277,9 @@ mod tests { let tool = Arc::new(ListDirectoryTool); // Test listing root directory - let input = json!({ - "path": "project" - }); + let input = ListDirectoryToolInput { + path: "project".to_string(), + }; let result = cx .update(|cx| { @@ -320,9 +312,9 @@ mod tests { ); // Test listing src directory - let input = json!({ - "path": "project/src" - }); + let input = ListDirectoryToolInput { + path: "project/src".to_string(), + }; let result = cx .update(|cx| { @@ -355,9 +347,9 @@ mod tests { ); // Test listing directory with only files - let input = json!({ - "path": "project/tests" - }); + let input = ListDirectoryToolInput { + path: "project/tests".to_string(), + }; let result = cx .update(|cx| { @@ -399,9 +391,9 @@ mod tests { let model = Arc::new(FakeLanguageModel::default()); let tool = Arc::new(ListDirectoryTool); - let input = json!({ - "path": "project/empty_dir" - }); + let input = ListDirectoryToolInput { + path: "project/empty_dir".to_string(), + }; let result = cx .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) @@ -432,9 +424,9 @@ mod tests { let tool = Arc::new(ListDirectoryTool); // Test non-existent path - let input = json!({ - "path": "project/nonexistent" - }); + let input = ListDirectoryToolInput { + path: "project/nonexistent".to_string(), + }; let result = cx .update(|cx| { @@ -455,9 +447,9 @@ mod tests { assert!(result.unwrap_err().to_string().contains("Path not found")); // Test trying to list a file instead of directory - let input = json!({ - "path": "project/file.txt" - }); + let input = ListDirectoryToolInput { + path: "project/file.txt".to_string(), + }; let result = cx .update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx)) @@ -527,9 +519,9 @@ mod tests { let tool = Arc::new(ListDirectoryTool); // Listing root directory should exclude private and excluded files - let input = json!({ - "path": "project" - }); + let input = ListDirectoryToolInput { + path: "project".to_string(), + }; let result = cx .update(|cx| { @@ -568,9 +560,9 @@ mod tests { ); // Trying to list an excluded directory should fail - let input = json!({ - "path": "project/.secretdir" - }); + let input = ListDirectoryToolInput { + path: "project/.secretdir".to_string(), + }; let result = cx .update(|cx| { @@ -600,9 +592,9 @@ mod tests { ); // Listing a directory should exclude private files within it - let input = json!({ - "path": "project/visible_dir" - }); + let input = ListDirectoryToolInput { + path: "project/visible_dir".to_string(), + }; let result = cx .update(|cx| { @@ -720,9 +712,9 @@ mod tests { let tool = Arc::new(ListDirectoryTool); // Test listing worktree1/src - should exclude secret.rs and config.toml based on local settings - let input = json!({ - "path": "worktree1/src" - }); + let input = ListDirectoryToolInput { + path: "worktree1/src".to_string(), + }; let result = cx .update(|cx| { @@ -752,9 +744,9 @@ mod tests { ); // Test listing worktree1/tests - should exclude fixture.sql based on local settings - let input = json!({ - "path": "worktree1/tests" - }); + let input = ListDirectoryToolInput { + path: "worktree1/tests".to_string(), + }; let result = cx .update(|cx| { @@ -780,9 +772,9 @@ mod tests { ); // Test listing worktree2/lib - should exclude private.js and data.json based on local settings - let input = json!({ - "path": "worktree2/lib" - }); + let input = ListDirectoryToolInput { + path: "worktree2/lib".to_string(), + }; let result = cx .update(|cx| { @@ -812,9 +804,9 @@ mod tests { ); // Test listing worktree2/docs - should exclude internal.md based on local settings - let input = json!({ - "path": "worktree2/docs" - }); + let input = ListDirectoryToolInput { + path: "worktree2/docs".to_string(), + }; let result = cx .update(|cx| { @@ -840,9 +832,9 @@ mod tests { ); // Test trying to list an excluded directory directly - let input = json!({ - "path": "worktree1/src/secret.rs" - }); + let input = ListDirectoryToolInput { + path: "worktree1/src/secret.rs".to_string(), + }; let result = cx .update(|cx| { diff --git a/crates/assistant_tools/src/move_path_tool.rs b/crates/assistant_tools/src/move_path_tool.rs index 27ae10151d4e91f951e198e850e5ff6fc2fb331b..def76eb914dc732fb0767f232b9ff3d65ab4462e 100644 --- a/crates/assistant_tools/src/move_path_tool.rs +++ b/crates/assistant_tools/src/move_path_tool.rs @@ -38,11 +38,13 @@ pub struct MovePathToolInput { pub struct MovePathTool; impl Tool for MovePathTool { + type Input = MovePathToolInput; + fn name(&self) -> String { "move_path".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -62,34 +64,29 @@ impl Tool for MovePathTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let src = MarkdownInlineCode(&input.source_path); - let dest = MarkdownInlineCode(&input.destination_path); - let src_path = Path::new(&input.source_path); - let dest_path = Path::new(&input.destination_path); + fn ui_text(&self, input: &Self::Input) -> String { + let src = MarkdownInlineCode(&input.source_path); + let dest = MarkdownInlineCode(&input.destination_path); + let src_path = Path::new(&input.source_path); + let dest_path = Path::new(&input.destination_path); - match dest_path - .file_name() - .and_then(|os_str| os_str.to_os_string().into_string().ok()) - { - Some(filename) if src_path.parent() == dest_path.parent() => { - let filename = MarkdownInlineCode(&filename); - format!("Rename {src} to {filename}") - } - _ => { - format!("Move {src} to {dest}") - } - } + match dest_path + .file_name() + .and_then(|os_str| os_str.to_os_string().into_string().ok()) + { + Some(filename) if src_path.parent() == dest_path.parent() => { + let filename = MarkdownInlineCode(&filename); + format!("Rename {src} to {filename}") + } + _ => { + format!("Move {src} to {dest}") } - Err(_) => "Move path".to_string(), } } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -97,10 +94,6 @@ impl Tool for MovePathTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; let rename_task = project.update(cx, |project, cx| { match project .find_project_path(&input.source_path, cx) diff --git a/crates/assistant_tools/src/now_tool.rs b/crates/assistant_tools/src/now_tool.rs index b6b1cf90a43b487684b9c8f0d4f6a69a14af6455..50c59c134d02431aed6a3c5e351d9d863bc2d449 100644 --- a/crates/assistant_tools/src/now_tool.rs +++ b/crates/assistant_tools/src/now_tool.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::schema::json_schema_for; -use anyhow::{Result, anyhow}; +use anyhow::Result; use assistant_tool::{ActionLog, Tool, ToolResult}; use chrono::{Local, Utc}; use gpui::{AnyWindowHandle, App, Entity, Task}; @@ -29,11 +29,13 @@ pub struct NowToolInput { pub struct NowTool; impl Tool for NowTool { + type Input = NowToolInput; + fn name(&self) -> String { "now".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -53,13 +55,13 @@ impl Tool for NowTool { json_schema_for::(format) } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { "Get current time".to_string() } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, @@ -67,11 +69,6 @@ impl Tool for NowTool { _window: Option, _cx: &mut App, ) -> ToolResult { - let input: NowToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let now = match input.timezone { Timezone::Utc => Utc::now().to_rfc3339(), Timezone::Local => Local::now().to_rfc3339(), diff --git a/crates/assistant_tools/src/open_tool.rs b/crates/assistant_tools/src/open_tool.rs index 97a4769e19e60758fe509fab56bf7329ac7f30b6..7966ccb614f7eabe20c3fad13e85a1d70c926a0c 100644 --- a/crates/assistant_tools/src/open_tool.rs +++ b/crates/assistant_tools/src/open_tool.rs @@ -1,7 +1,7 @@ use crate::schema::json_schema_for; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use assistant_tool::{ActionLog, Tool, ToolResult}; -use gpui::{AnyWindowHandle, App, AppContext, Entity, Task}; +use gpui::{AnyWindowHandle, App, AppContext, Entity}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; use project::Project; use schemars::JsonSchema; @@ -19,11 +19,13 @@ pub struct OpenToolInput { pub struct OpenTool; impl Tool for OpenTool { + type Input = OpenToolInput; + fn name(&self) -> String { "open".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { true } fn may_perform_edits(&self) -> bool { @@ -41,16 +43,13 @@ impl Tool for OpenTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => format!("Open `{}`", MarkdownEscaped(&input.path_or_url)), - Err(_) => "Open file or URL".to_string(), - } + fn ui_text(&self, input: &Self::Input) -> String { + format!("Open `{}`", MarkdownEscaped(&input.path_or_url)) } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -58,11 +57,6 @@ impl Tool for OpenTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input: OpenToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - // If path_or_url turns out to be a path in the project, make it absolute. let abs_path = to_absolute_path(&input.path_or_url, project, cx); diff --git a/crates/assistant_tools/src/read_file_tool.rs b/crates/assistant_tools/src/read_file_tool.rs index 4d40fc6a7c71fc41cb23f689f3e9dc6b699f81c1..4c5fe7f0077a411036451ccd8756d6186c4168c1 100644 --- a/crates/assistant_tools/src/read_file_tool.rs +++ b/crates/assistant_tools/src/read_file_tool.rs @@ -51,11 +51,13 @@ pub struct ReadFileToolInput { pub struct ReadFileTool; impl Tool for ReadFileTool { + type Input = ReadFileToolInput; + fn name(&self) -> String { "read_file".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -75,23 +77,18 @@ impl Tool for ReadFileTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let path = MarkdownInlineCode(&input.path); - match (input.start_line, input.end_line) { - (Some(start), None) => format!("Read file {path} (from line {start})"), - (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"), - _ => format!("Read file {path}"), - } - } - Err(_) => "Read file".to_string(), + fn ui_text(&self, input: &Self::Input) -> String { + let path = MarkdownInlineCode(&input.path); + match (input.start_line, input.end_line) { + (Some(start), None) => format!("Read file {path} (from line {start})"), + (Some(start), Some(end)) => format!("Read file {path} (lines {start}-{end})"), + _ => format!("Read file {path}"), } } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, action_log: Entity, @@ -99,11 +96,6 @@ impl Tool for ReadFileTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else { return Task::ready(Err(anyhow!("Path {} not found in project", &input.path))).into(); }; @@ -308,9 +300,12 @@ mod test { let model = Arc::new(FakeLanguageModel::default()); let result = cx .update(|cx| { - let input = json!({ - "path": "root/nonexistent_file.txt" - }); + let input = ReadFileToolInput { + path: "root/nonexistent_file.txt".to_string(), + start_line: None, + end_line: None, + }; + Arc::new(ReadFileTool) .run( input, @@ -347,9 +342,11 @@ mod test { let model = Arc::new(FakeLanguageModel::default()); let result = cx .update(|cx| { - let input = json!({ - "path": "root/small_file.txt" - }); + let input = ReadFileToolInput { + path: "root/small_file.txt".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -389,9 +386,11 @@ mod test { let result = cx .update(|cx| { - let input = json!({ - "path": "root/large_file.rs" - }); + let input = ReadFileToolInput { + path: "root/large_file.rs".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -421,10 +420,11 @@ mod test { let result = cx .update(|cx| { - let input = json!({ - "path": "root/large_file.rs", - "offset": 1 - }); + let input = ReadFileToolInput { + path: "root/large_file.rs".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -477,11 +477,11 @@ mod test { let model = Arc::new(FakeLanguageModel::default()); let result = cx .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 2, - "end_line": 4 - }); + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(2), + end_line: Some(4), + }; Arc::new(ReadFileTool) .run( input, @@ -520,11 +520,11 @@ mod test { // start_line of 0 should be treated as 1 let result = cx .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 0, - "end_line": 2 - }); + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(0), + end_line: Some(2), + }; Arc::new(ReadFileTool) .run( input, @@ -543,11 +543,11 @@ mod test { // end_line of 0 should result in at least 1 line let result = cx .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 1, - "end_line": 0 - }); + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(1), + end_line: Some(0), + }; Arc::new(ReadFileTool) .run( input, @@ -566,11 +566,11 @@ mod test { // when start_line > end_line, should still return at least 1 line let result = cx .update(|cx| { - let input = json!({ - "path": "root/multiline.txt", - "start_line": 3, - "end_line": 2 - }); + let input = ReadFileToolInput { + path: "root/multiline.txt".to_string(), + start_line: Some(3), + end_line: Some(2), + }; Arc::new(ReadFileTool) .run( input, @@ -694,9 +694,11 @@ mod test { // Reading a file outside the project worktree should fail let result = cx .update(|cx| { - let input = json!({ - "path": "/outside_project/sensitive_file.txt" - }); + let input = ReadFileToolInput { + path: "/outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -718,9 +720,11 @@ mod test { // Reading a file within the project should succeed let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/allowed_file.txt" - }); + let input = ReadFileToolInput { + path: "project_root/allowed_file.txt".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -742,9 +746,11 @@ mod test { // Reading files that match file_scan_exclusions should fail let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/.secretdir/config" - }); + let input = ReadFileToolInput { + path: "project_root/.secretdir/config".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -765,9 +771,11 @@ mod test { let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/.mymetadata" - }); + let input = ReadFileToolInput { + path: "project_root/.mymetadata".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -789,9 +797,11 @@ mod test { // Reading private files should fail let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/.mysecrets" - }); + let input = ReadFileToolInput { + path: "project_root/secrets/.mysecrets".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -812,9 +822,11 @@ mod test { let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/subdir/special.privatekey" - }); + let input = ReadFileToolInput { + path: "project_root/subdir/special.privatekey".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -835,9 +847,11 @@ mod test { let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/subdir/data.mysensitive" - }); + let input = ReadFileToolInput { + path: "project_root/subdir/data.mysensitive".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -859,9 +873,11 @@ mod test { // Reading a normal file should still work, even with private_files configured let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/subdir/normal_file.txt" - }); + let input = ReadFileToolInput { + path: "project_root/subdir/normal_file.txt".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -884,9 +900,11 @@ mod test { // Path traversal attempts with .. should fail let result = cx .update(|cx| { - let input = json!({ - "path": "project_root/../outside_project/sensitive_file.txt" - }); + let input = ReadFileToolInput { + path: "project_root/../outside_project/sensitive_file.txt".to_string(), + start_line: None, + end_line: None, + }; Arc::new(ReadFileTool) .run( input, @@ -981,9 +999,11 @@ mod test { let tool = Arc::new(ReadFileTool); // Test reading allowed files in worktree1 - let input = json!({ - "path": "worktree1/src/main.rs" - }); + let input = ReadFileToolInput { + path: "worktree1/src/main.rs".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1007,9 +1027,11 @@ mod test { ); // Test reading private file in worktree1 should fail - let input = json!({ - "path": "worktree1/src/secret.rs" - }); + let input = ReadFileToolInput { + path: "worktree1/src/secret.rs".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1036,9 +1058,11 @@ mod test { ); // Test reading excluded file in worktree1 should fail - let input = json!({ - "path": "worktree1/tests/fixture.sql" - }); + let input = ReadFileToolInput { + path: "worktree1/tests/fixture.sql".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1065,9 +1089,11 @@ mod test { ); // Test reading allowed files in worktree2 - let input = json!({ - "path": "worktree2/lib/public.js" - }); + let input = ReadFileToolInput { + path: "worktree2/lib/public.js".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1091,9 +1117,11 @@ mod test { ); // Test reading private file in worktree2 should fail - let input = json!({ - "path": "worktree2/lib/private.js" - }); + let input = ReadFileToolInput { + path: "worktree2/lib/private.js".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1120,9 +1148,11 @@ mod test { ); // Test reading excluded file in worktree2 should fail - let input = json!({ - "path": "worktree2/docs/internal.md" - }); + let input = ReadFileToolInput { + path: "worktree2/docs/internal.md".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { @@ -1150,9 +1180,11 @@ mod test { // Test that files allowed in one worktree but not in another are handled correctly // (e.g., config.toml is private in worktree1 but doesn't exist in worktree2) - let input = json!({ - "path": "worktree1/src/config.toml" - }); + let input = ReadFileToolInput { + path: "worktree1/src/config.toml".to_string(), + start_line: None, + end_line: None, + }; let result = cx .update(|cx| { diff --git a/crates/assistant_tools/src/terminal_tool.rs b/crates/assistant_tools/src/terminal_tool.rs index 2c582a531069eb9a81340af7eb07731e8df8a96e..43b1309e786457d312e082c2bfed7a8fc9d3a799 100644 --- a/crates/assistant_tools/src/terminal_tool.rs +++ b/crates/assistant_tools/src/terminal_tool.rs @@ -2,7 +2,7 @@ use crate::{ schema::json_schema_for, ui::{COLLAPSED_LINES, ToolOutputPreview}, }; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result}; use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus}; use futures::{FutureExt as _, future::Shared}; use gpui::{ @@ -72,11 +72,13 @@ impl TerminalTool { } impl Tool for TerminalTool { + type Input = TerminalToolInput; + fn name(&self) -> String { Self::NAME.to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { true } @@ -96,30 +98,24 @@ impl Tool for TerminalTool { json_schema_for::(format) } - fn ui_text(&self, input: &serde_json::Value) -> String { - match serde_json::from_value::(input.clone()) { - Ok(input) => { - let mut lines = input.command.lines(); - let first_line = lines.next().unwrap_or_default(); - let remaining_line_count = lines.count(); - match remaining_line_count { - 0 => MarkdownInlineCode(&first_line).to_string(), - 1 => MarkdownInlineCode(&format!( - "{} - {} more line", - first_line, remaining_line_count - )) - .to_string(), - n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)) - .to_string(), - } - } - Err(_) => "Run terminal command".to_string(), + fn ui_text(&self, input: &Self::Input) -> String { + let mut lines = input.command.lines(); + let first_line = lines.next().unwrap_or_default(); + let remaining_line_count = lines.count(); + match remaining_line_count { + 0 => MarkdownInlineCode(&first_line).to_string(), + 1 => MarkdownInlineCode(&format!( + "{} - {} more line", + first_line, remaining_line_count + )) + .to_string(), + n => MarkdownInlineCode(&format!("{} - {} more lines", first_line, n)).to_string(), } } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, project: Entity, _action_log: Entity, @@ -127,11 +123,6 @@ impl Tool for TerminalTool { window: Option, cx: &mut App, ) -> ToolResult { - let input: TerminalToolInput = match serde_json::from_value(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; - let working_dir = match working_dir(&input, &project, cx) { Ok(dir) => dir, Err(err) => return Task::ready(Err(err)).into(), @@ -756,7 +747,7 @@ mod tests { let result = cx.update(|cx| { TerminalTool::run( Arc::new(TerminalTool::new(cx)), - serde_json::to_value(input).unwrap(), + input, Arc::default(), project.clone(), action_log.clone(), @@ -791,7 +782,7 @@ mod tests { let check = |input, expected, cx: &mut App| { let headless_result = TerminalTool::run( Arc::new(TerminalTool::new(cx)), - serde_json::to_value(input).unwrap(), + input, Arc::default(), project.clone(), action_log.clone(), diff --git a/crates/assistant_tools/src/thinking_tool.rs b/crates/assistant_tools/src/thinking_tool.rs index 4641b7359e1039cefb80e2a4f97ec5db94bfd90e..4d672dedfc906fcf985ba37476eefc8022d3a8e3 100644 --- a/crates/assistant_tools/src/thinking_tool.rs +++ b/crates/assistant_tools/src/thinking_tool.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::schema::json_schema_for; -use anyhow::{Result, anyhow}; +use anyhow::Result; use assistant_tool::{ActionLog, Tool, ToolResult}; use gpui::{AnyWindowHandle, App, Entity, Task}; use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat}; @@ -20,11 +20,13 @@ pub struct ThinkingToolInput { pub struct ThinkingTool; impl Tool for ThinkingTool { + type Input = ThinkingToolInput; + fn name(&self) -> String { "thinking".to_string() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -44,13 +46,13 @@ impl Tool for ThinkingTool { json_schema_for::(format) } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { "Thinking".to_string() } fn run( self: Arc, - input: serde_json::Value, + _input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, @@ -59,10 +61,6 @@ impl Tool for ThinkingTool { _cx: &mut App, ) -> ToolResult { // This tool just "thinks out loud" and doesn't perform any actions. - Task::ready(match serde_json::from_value::(input) { - Ok(_input) => Ok("Finished thinking.".to_string().into()), - Err(err) => Err(anyhow!(err)), - }) - .into() + Task::ready(Ok("Finished thinking.".to_string().into())).into() } } diff --git a/crates/assistant_tools/src/web_search_tool.rs b/crates/assistant_tools/src/web_search_tool.rs index 9430ac9d9e245d4f8871fcf120cba9ed48a5ba97..b3c445a4f512f949f326df07a7b12bd431bbbbd1 100644 --- a/crates/assistant_tools/src/web_search_tool.rs +++ b/crates/assistant_tools/src/web_search_tool.rs @@ -28,11 +28,13 @@ pub struct WebSearchToolInput { pub struct WebSearchTool; impl Tool for WebSearchTool { + type Input = WebSearchToolInput; + fn name(&self) -> String { "web_search".into() } - fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool { + fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool { false } @@ -52,13 +54,13 @@ impl Tool for WebSearchTool { json_schema_for::(format) } - fn ui_text(&self, _input: &serde_json::Value) -> String { + fn ui_text(&self, _input: &Self::Input) -> String { "Searching the Web".to_string() } fn run( self: Arc, - input: serde_json::Value, + input: Self::Input, _request: Arc, _project: Entity, _action_log: Entity, @@ -66,10 +68,6 @@ impl Tool for WebSearchTool { _window: Option, cx: &mut App, ) -> ToolResult { - let input = match serde_json::from_value::(input) { - Ok(input) => input, - Err(err) => return Task::ready(Err(anyhow!(err))).into(), - }; let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else { return Task::ready(Err(anyhow!("Web search is not available."))).into(); }; diff --git a/crates/remote_server/src/remote_editing_tests.rs b/crates/remote_server/src/remote_editing_tests.rs index 9730984f2632be65330203fcd93350cf29233435..b25fd400bbf8fc0bd21e2b09408be5a4d55828cf 100644 --- a/crates/remote_server/src/remote_editing_tests.rs +++ b/crates/remote_server/src/remote_editing_tests.rs @@ -1736,7 +1736,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let exists_result = cx.update(|cx| { ReadFileTool::run( Arc::new(ReadFileTool), - serde_json::to_value(input).unwrap(), + input, request.clone(), project.clone(), action_log.clone(), @@ -1756,7 +1756,7 @@ async fn test_remote_agent_fs_tool_calls(cx: &mut TestAppContext, server_cx: &mu let does_not_exist_result = cx.update(|cx| { ReadFileTool::run( Arc::new(ReadFileTool), - serde_json::to_value(input).unwrap(), + input, request.clone(), project.clone(), action_log.clone(),