tool.rs

  1use anyhow::Result;
  2use gpui::{div, AnyElement, AppContext, Element, ParentElement as _, Task, WindowContext};
  3use schemars::{schema::SchemaObject, schema_for, JsonSchema};
  4use serde::Deserialize;
  5use std::{any::Any, fmt::Debug};
  6
  7#[derive(Default, Deserialize)]
  8pub struct ToolFunctionCall {
  9    pub id: String,
 10    pub name: String,
 11    pub arguments: String,
 12    #[serde(skip)]
 13    pub result: Option<ToolFunctionCallResult>,
 14}
 15
 16pub enum ToolFunctionCallResult {
 17    NoSuchTool,
 18    ParsingFailed,
 19    ExecutionFailed {
 20        input: Box<dyn Any>,
 21    },
 22    Finished {
 23        input: Box<dyn Any>,
 24        output: Box<dyn Any>,
 25        render_fn: fn(
 26            // tool_call_id
 27            &str,
 28            // LanguageModelTool::Input
 29            &Box<dyn Any>,
 30            // LanguageModelTool::Output
 31            &Box<dyn Any>,
 32            &mut WindowContext,
 33        ) -> AnyElement,
 34        format_fn: fn(
 35            // LanguageModelTool::Input
 36            &Box<dyn Any>,
 37            // LanguageModelTool::Output
 38            &Box<dyn Any>,
 39        ) -> String,
 40    },
 41}
 42
 43impl ToolFunctionCallResult {
 44    pub fn render(
 45        &self,
 46        tool_name: &str,
 47        tool_call_id: &str,
 48        cx: &mut WindowContext,
 49    ) -> AnyElement {
 50        match self {
 51            ToolFunctionCallResult::NoSuchTool => {
 52                div().child(format!("no such tool {tool_name}")).into_any()
 53            }
 54            ToolFunctionCallResult::ParsingFailed => div()
 55                .child(format!("failed to parse input for tool {tool_name}"))
 56                .into_any(),
 57            ToolFunctionCallResult::ExecutionFailed { .. } => div()
 58                .child(format!("failed to execute tool {tool_name}"))
 59                .into_any(),
 60            ToolFunctionCallResult::Finished {
 61                input,
 62                output,
 63                render_fn,
 64                ..
 65            } => render_fn(tool_call_id, input, output, cx),
 66        }
 67    }
 68
 69    pub fn format(&self, tool: &str) -> String {
 70        match self {
 71            ToolFunctionCallResult::NoSuchTool => format!("no such tool {tool}"),
 72            ToolFunctionCallResult::ParsingFailed => {
 73                format!("failed to parse input for tool {tool}")
 74            }
 75            ToolFunctionCallResult::ExecutionFailed { input: _input } => {
 76                format!("failed to execute tool {tool}")
 77            }
 78            ToolFunctionCallResult::Finished {
 79                input,
 80                output,
 81                format_fn,
 82                ..
 83            } => format_fn(input, output),
 84        }
 85    }
 86}
 87
 88#[derive(Clone)]
 89pub struct ToolFunctionDefinition {
 90    pub name: String,
 91    pub description: String,
 92    pub parameters: SchemaObject,
 93}
 94
 95impl Debug for ToolFunctionDefinition {
 96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 97        let schema = serde_json::to_string(&self.parameters).ok();
 98        let schema = schema.unwrap_or("None".to_string());
 99
100        f.debug_struct("ToolFunctionDefinition")
101            .field("name", &self.name)
102            .field("description", &self.description)
103            .field("parameters", &schema)
104            .finish()
105    }
106}
107
108pub trait LanguageModelTool {
109    /// The input type that will be passed in to `execute` when the tool is called
110    /// by the language model.
111    type Input: for<'de> Deserialize<'de> + JsonSchema;
112
113    /// The output returned by executing the tool.
114    type Output: 'static;
115
116    /// The name of the tool is exposed to the language model to allow
117    /// the model to pick which tools to use. As this name is used to
118    /// identify the tool within a tool registry, it should be unique.
119    fn name(&self) -> String;
120
121    /// A description of the tool that can be used to _prompt_ the model
122    /// as to what the tool does.
123    fn description(&self) -> String;
124
125    /// The OpenAI Function definition for the tool, for direct use with OpenAI's API.
126    fn definition(&self) -> ToolFunctionDefinition {
127        ToolFunctionDefinition {
128            name: self.name(),
129            description: self.description(),
130            parameters: schema_for!(Self::Input).schema,
131        }
132    }
133
134    /// Execute the tool
135    fn execute(&self, input: &Self::Input, cx: &AppContext) -> Task<Result<Self::Output>>;
136
137    fn render(
138        tool_call_id: &str,
139        input: &Self::Input,
140        output: &Self::Output,
141        cx: &mut WindowContext,
142    ) -> AnyElement;
143
144    fn format(input: &Self::Input, output: &Self::Output) -> String;
145}