create_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{App, Entity, Task};
  5use language_model::LanguageModelRequestMessage;
  6use language_model::LanguageModelToolSchemaFormat;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11use ui::IconName;
 12use util::markdown::MarkdownString;
 13
// Arguments the model supplies when it invokes the `create_file` tool.
//
// NOTE(review): this header uses plain `//` comments on purpose. The `///`
// comments on the fields below are presumably emitted as JSON-schema
// `description`s by the `JsonSchema` derive (this type's schema is what
// `CreateFileTool::input_schema` returns), so they are model-facing prompt
// text — changing them changes what the model is told.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    // Resolved against the project via `Project::find_project_path` in `run`;
    // a path that does not resolve there makes the tool fail.
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    pub path: String,

    // Written verbatim into the buffer with `set_text` before saving.
    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
 35
 36#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 37struct PartialInput {
 38    #[serde(default)]
 39    path: String,
 40    #[serde(default)]
 41    contents: String,
 42}
 43
/// Tool exposed to the model as `create_file`: writes the supplied contents
/// to a path inside the project and saves it (see the `Tool` impl below).
pub struct CreateFileTool;
 45
/// Fallback label used when the tool input cannot be parsed, or when no
/// non-empty path has been streamed yet.
const DEFAULT_UI_TEXT: &str = "Create file";
 47
 48impl Tool for CreateFileTool {
 49    fn name(&self) -> String {
 50        "create_file".into()
 51    }
 52
 53    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./create_file_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::FileCreate
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<CreateFileToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
 71            Ok(input) => {
 72                let path = MarkdownString::inline_code(&input.path);
 73                format!("Create file {path}")
 74            }
 75            Err(_) => DEFAULT_UI_TEXT.to_string(),
 76        }
 77    }
 78
 79    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
 80        match serde_json::from_value::<PartialInput>(input.clone()).ok() {
 81            Some(input) if !input.path.is_empty() => input.path,
 82            _ => DEFAULT_UI_TEXT.to_string(),
 83        }
 84    }
 85
 86    fn run(
 87        self: Arc<Self>,
 88        input: serde_json::Value,
 89        _messages: &[LanguageModelRequestMessage],
 90        project: Entity<Project>,
 91        action_log: Entity<ActionLog>,
 92        cx: &mut App,
 93    ) -> ToolResult {
 94        let input = match serde_json::from_value::<CreateFileToolInput>(input) {
 95            Ok(input) => input,
 96            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 97        };
 98        let project_path = match project.read(cx).find_project_path(&input.path, cx) {
 99            Some(project_path) => project_path,
100            None => {
101                return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
102            }
103        };
104        let contents: Arc<str> = input.contents.as_str().into();
105        let destination_path: Arc<str> = input.path.as_str().into();
106
107        cx.spawn(async move |cx| {
108            let buffer = project
109                .update(cx, |project, cx| {
110                    project.open_buffer(project_path.clone(), cx)
111                })?
112                .await
113                .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
114            cx.update(|cx| {
115                buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
116                action_log.update(cx, |action_log, cx| {
117                    action_log.will_create_buffer(buffer.clone(), cx)
118                });
119            })?;
120
121            project
122                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
123                .await
124                .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
125
126            Ok(format!("Created file {destination_path}"))
127        })
128        .into()
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_path_before_contents_arrive() {
        // Mid-stream the model may have emitted `path` but not yet `contents`;
        // the partial input must still surface the path.
        let tool = CreateFileTool;
        let input = json!({ "path": "src/lib.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/lib.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_contents_but_no_path() {
        let tool = CreateFileTool;
        let input = json!({ "contents": "partial" });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_empty_object() {
        let tool = CreateFileTool;
        let input = json!({});

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}