tool.rs

  1use anyhow::Result;
  2use gpui::{AnyElement, AnyView, IntoElement as _, Render, Task, View, WindowContext};
  3use schemars::{schema::RootSchema, schema_for, JsonSchema};
  4use serde::Deserialize;
  5use std::fmt::Display;
  6
  7#[derive(Default, Deserialize)]
  8pub struct ToolFunctionCall {
  9    pub id: String,
 10    pub name: String,
 11    pub arguments: String,
 12    #[serde(skip)]
 13    pub result: Option<ToolFunctionCallResult>,
 14}
 15
 16pub enum ToolFunctionCallResult {
 17    NoSuchTool,
 18    ParsingFailed,
 19    Finished { for_model: String, view: AnyView },
 20}
 21
 22impl ToolFunctionCallResult {
 23    pub fn format(&self, name: &String) -> String {
 24        match self {
 25            ToolFunctionCallResult::NoSuchTool => format!("No tool for {name}"),
 26            ToolFunctionCallResult::ParsingFailed => {
 27                format!("Unable to parse arguments for {name}")
 28            }
 29            ToolFunctionCallResult::Finished { for_model, .. } => for_model.clone(),
 30        }
 31    }
 32
 33    pub fn into_any_element(&self, name: &String) -> AnyElement {
 34        match self {
 35            ToolFunctionCallResult::NoSuchTool => {
 36                format!("Language Model attempted to call {name}").into_any_element()
 37            }
 38            ToolFunctionCallResult::ParsingFailed => {
 39                format!("Language Model called {name} with bad arguments").into_any_element()
 40            }
 41            ToolFunctionCallResult::Finished { view, .. } => view.clone().into_any_element(),
 42        }
 43    }
 44}
 45
 46#[derive(Clone)]
 47pub struct ToolFunctionDefinition {
 48    pub name: String,
 49    pub description: String,
 50    pub parameters: RootSchema,
 51}
 52
 53impl Display for ToolFunctionDefinition {
 54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 55        let schema = serde_json::to_string(&self.parameters).ok();
 56        let schema = schema.unwrap_or("None".to_string());
 57        write!(f, "Name: {}:\n", self.name)?;
 58        write!(f, "Description: {}\n", self.description)?;
 59        write!(f, "Parameters: {}", schema)
 60    }
 61}
 62
 63pub trait LanguageModelTool {
 64    /// The input type that will be passed in to `execute` when the tool is called
 65    /// by the language model.
 66    type Input: for<'de> Deserialize<'de> + JsonSchema;
 67
 68    /// The output returned by executing the tool.
 69    type Output: 'static;
 70
 71    type View: Render;
 72
 73    /// Returns the name of the tool.
 74    ///
 75    /// This name is exposed to the language model to allow the model to pick
 76    /// which tools to use. As this name is used to identify the tool within a
 77    /// tool registry, it should be unique.
 78    fn name(&self) -> String;
 79
 80    /// Returns the description of the tool.
 81    ///
 82    /// This can be used to _prompt_ the model as to what the tool does.
 83    fn description(&self) -> String;
 84
 85    /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
 86    fn definition(&self) -> ToolFunctionDefinition {
 87        let root_schema = schema_for!(Self::Input);
 88
 89        ToolFunctionDefinition {
 90            name: self.name(),
 91            description: self.description(),
 92            parameters: root_schema,
 93        }
 94    }
 95
 96    /// Executes the tool with the given input.
 97    fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 98
 99    fn format(input: &Self::Input, output: &Result<Self::Output>) -> String;
100
101    fn output_view(
102        tool_call_id: String,
103        input: Self::Input,
104        output: Result<Self::Output>,
105        cx: &mut WindowContext,
106    ) -> View<Self::View>;
107}