create_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::{App, Entity, Task};
  5use language_model::LanguageModelRequestMessage;
  6use language_model::LanguageModelToolSchemaFormat;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::sync::Arc;
 11use ui::IconName;
 12use util::markdown::MarkdownString;
 13
// Strict, fully-received input for the `create_file` tool; `run` and `ui_text`
// deserialize into this type, and both fields are required (no `#[serde(default)]`).
//
// NOTE(review): the `JsonSchema` derive normally turns the `///` field comments
// below into schema `description`s, and `input_schema` builds this tool's schema
// from this type — so those comments are likely model-facing prompt text. Confirm
// against `json_schema_for` before rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    pub path: String,

    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
 35
 36#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 37struct PartialInput {
 38    #[serde(default)]
 39    path: String,
 40    #[serde(default)]
 41    contents: String,
 42}
 43
/// Agent tool, registered as `create_file`, that writes the given `contents` to a
/// project path by opening a buffer for it, replacing the buffer's text, and saving it.
///
/// NOTE(review): the buffer's text is replaced wholesale (`set_text`), so if a file
/// already exists at the path it presumably gets overwritten — confirm this is intended.
pub struct CreateFileTool;
 45
/// Fallback label shown when the tool input cannot be parsed, or (while streaming)
/// does not contain a non-empty `path` yet.
const DEFAULT_UI_TEXT: &str = "Create file";
 47
 48impl Tool for CreateFileTool {
 49    fn name(&self) -> String {
 50        "create_file".into()
 51    }
 52
 53    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 54        false
 55    }
 56
 57    fn description(&self) -> String {
 58        include_str!("./create_file_tool/description.md").into()
 59    }
 60
 61    fn icon(&self) -> IconName {
 62        IconName::FileCreate
 63    }
 64
 65    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 66        json_schema_for::<CreateFileToolInput>(format)
 67    }
 68
 69    fn ui_text(&self, input: &serde_json::Value) -> String {
 70        match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
 71            Ok(input) => {
 72                let path = MarkdownString::inline_code(&input.path);
 73                format!("Create file {path}")
 74            }
 75            Err(_) => DEFAULT_UI_TEXT.to_string(),
 76        }
 77    }
 78
 79    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
 80        match serde_json::from_value::<PartialInput>(input.clone()).ok() {
 81            Some(input) if !input.path.is_empty() => input.path,
 82            _ => DEFAULT_UI_TEXT.to_string(),
 83        }
 84    }
 85
 86    fn run(
 87        self: Arc<Self>,
 88        input: serde_json::Value,
 89        _messages: &[LanguageModelRequestMessage],
 90        project: Entity<Project>,
 91        action_log: Entity<ActionLog>,
 92        cx: &mut App,
 93    ) -> ToolResult {
 94        let input = match serde_json::from_value::<CreateFileToolInput>(input) {
 95            Ok(input) => input,
 96            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 97        };
 98        let project_path = match project.read(cx).find_project_path(&input.path, cx) {
 99            Some(project_path) => project_path,
100            None => {
101                return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
102            }
103        };
104        let contents: Arc<str> = input.contents.as_str().into();
105        let destination_path: Arc<str> = input.path.as_str().into();
106
107        cx.spawn(async move |cx| {
108            let buffer = project
109                .update(cx, |project, cx| {
110                    project.open_buffer(project_path.clone(), cx)
111                })?
112                .await
113                .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
114            cx.update(|cx| {
115                action_log.update(cx, |action_log, cx| {
116                    action_log.track_buffer(buffer.clone(), cx)
117                });
118                buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
119                action_log.update(cx, |action_log, cx| {
120                    action_log.buffer_edited(buffer.clone(), cx)
121                });
122            })?;
123
124            project
125                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
126                .await
127                .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
128
129            Ok(format!("Created file {destination_path}"))
130        })
131        .into()
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Sample file body; `contents` never influences the UI text, only `path` does.
    const SAMPLE_CONTENTS: &str = "fn main() {\n    println!(\"Hello, world!\");\n}";

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_path_before_contents() {
        // Mid-stream, `contents` may not have arrived yet; the lenient
        // `PartialInput` must still surface the path.
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_empty_object() {
        let tool = CreateFileTool;
        let input = json!({});

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_without_contents() {
        // `ui_text` parses the strict `CreateFileToolInput`, where `contents` is
        // required, so a missing `contents` falls back to the default label.
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}