tool.rs

  1use anyhow::Result;
  2use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
  3use schemars::{schema::RootSchema, schema_for, JsonSchema};
  4use serde::Deserialize;
  5use std::{
  6    any::Any,
  7    fmt::{Debug, Display},
  8};
  9
 10#[derive(Default, Deserialize)]
 11pub struct ToolFunctionCall {
 12    pub id: String,
 13    pub name: String,
 14    pub arguments: String,
 15    #[serde(skip)]
 16    pub result: Option<ToolFunctionCallResult>,
 17}
 18
 19pub enum ToolFunctionCallResult {
 20    NoSuchTool,
 21    ParsingFailed,
 22    ExecutionFailed {
 23        input: Box<dyn Any>,
 24    },
 25    Finished {
 26        input: Box<dyn Any>,
 27        output: Box<dyn Any>,
 28        render_fn: fn(
 29            // tool_call_id
 30            &str,
 31            // LanguageModelTool::Input
 32            &Box<dyn Any>,
 33            // LanguageModelTool::Output
 34            &Box<dyn Any>,
 35            &mut WindowContext,
 36        ) -> AnyElement,
 37        format_fn: fn(
 38            // LanguageModelTool::Input
 39            &Box<dyn Any>,
 40            // LanguageModelTool::Output
 41            &Box<dyn Any>,
 42        ) -> String,
 43    },
 44}
 45
 46impl ToolFunctionCallResult {
 47    pub fn render(
 48        &self,
 49        tool_name: &str,
 50        tool_call_id: &str,
 51        cx: &mut WindowContext,
 52    ) -> AnyElement {
 53        match self {
 54            ToolFunctionCallResult::NoSuchTool => {
 55                div().child(format!("no such tool {tool_name}")).into_any()
 56            }
 57            ToolFunctionCallResult::ParsingFailed => div()
 58                .child(format!("failed to parse input for tool {tool_name}"))
 59                .into_any(),
 60            ToolFunctionCallResult::ExecutionFailed { .. } => div()
 61                .child(format!("failed to execute tool {tool_name}"))
 62                .into_any(),
 63            ToolFunctionCallResult::Finished {
 64                input,
 65                output,
 66                render_fn,
 67                ..
 68            } => render_fn(tool_call_id, input, output, cx),
 69        }
 70    }
 71
 72    pub fn format(&self, tool: &str) -> String {
 73        match self {
 74            ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
 75            ToolFunctionCallResult::ParsingFailed => {
 76                format!("failed to parse input for tool {tool}")
 77            }
 78            ToolFunctionCallResult::ExecutionFailed { input: _input } => {
 79                format!("failed to execute tool {tool}")
 80            }
 81            ToolFunctionCallResult::Finished {
 82                input,
 83                output,
 84                format_fn,
 85                ..
 86            } => format_fn(input, output),
 87        }
 88    }
 89}
 90
 91#[derive(Clone)]
 92pub struct ToolFunctionDefinition {
 93    pub name: String,
 94    pub description: String,
 95    pub parameters: RootSchema,
 96}
 97
 98impl Display for ToolFunctionDefinition {
 99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        let schema = serde_json::to_string(&self.parameters).ok();
101        let schema = schema.unwrap_or("None".to_string());
102        write!(f, "Name: {}:\n", self.name)?;
103        write!(f, "Description: {}\n", self.description)?;
104        write!(f, "Parameters: {}", schema)
105    }
106}
107
108impl Debug for ToolFunctionDefinition {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        let schema = serde_json::to_string(&self.parameters).ok();
111        let schema = schema.unwrap_or("None".to_string());
112
113        f.debug_struct("ToolFunctionDefinition")
114            .field("name", &self.name)
115            .field("description", &self.description)
116            .field("parameters", &schema)
117            .finish()
118    }
119}
120
121pub trait LanguageModelTool {
122    /// The input type that will be passed in to `execute` when the tool is called
123    /// by the language model.
124    type Input: for<'de> Deserialize<'de> + JsonSchema;
125
126    /// The output returned by executing the tool.
127    type Output: 'static;
128
129    /// The name of the tool is exposed to the language model to allow
130    /// the model to pick which tools to use. As this name is used to
131    /// identify the tool within a tool registry, it should be unique.
132    fn name(&self) -> String;
133
134    /// A description of the tool that can be used to _prompt_ the model
135    /// as to what the tool does.
136    fn description(&self) -> String;
137
138    /// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
139    fn definition(&self) -> ToolFunctionDefinition {
140        let root_schema = schema_for!(Self::Input);
141
142        ToolFunctionDefinition {
143            name: self.name(),
144            description: self.description(),
145            parameters: root_schema,
146        }
147    }
148
149    /// Execute the tool
150    fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
151
152    fn render(
153        tool_call_id: &str,
154        input: &Self::Input,
155        output: &Self::Output,
156        cx: &mut WindowContext,
157    ) -> AnyElement;
158
159    fn format(input: &Self::Input, output: &Self::Output) -> String;
160}