create_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::AnyWindowHandle;
  5use gpui::{App, Entity, Task};
  6use language_model::LanguageModelRequestMessage;
  7use language_model::LanguageModelToolSchemaFormat;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::sync::Arc;
 12use ui::IconName;
 13use util::markdown::MarkdownInlineCode;
 14
// Fully-received input for the `create_file` tool; `run` and `ui_text` deserialize into this type.
//
// NOTE(review): `JsonSchema` is derived and `input_schema` builds the tool schema from this type,
// so the `///` doc comments below are presumably emitted as schema `description`s and shown to
// the model — treat them as prompt text. Plain `//` comments (like this one) stay out of it.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    ///
    /// Make sure to include this field before the `contents` field in the input object
    /// so that we can display it immediately.
    // Declared before `contents`, matching the ordering the doc above asks the model for: the
    // streaming preview (`still_streaming_ui_text`) can only show the path once it has arrived.
    pub path: String,

    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
 39
 40#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 41struct PartialInput {
 42    #[serde(default)]
 43    path: String,
 44    #[serde(default)]
 45    contents: String,
 46}
 47
 48pub struct CreateFileTool;
 49
/// Fallback label used by `ui_text` when the input fails to deserialize, and by
/// `still_streaming_ui_text` when no non-empty `path` is available yet.
const DEFAULT_UI_TEXT: &str = "Create file";
 51
 52impl Tool for CreateFileTool {
 53    fn name(&self) -> String {
 54        "create_file".into()
 55    }
 56
 57    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 58        false
 59    }
 60
 61    fn description(&self) -> String {
 62        include_str!("./create_file_tool/description.md").into()
 63    }
 64
 65    fn icon(&self) -> IconName {
 66        IconName::FileCreate
 67    }
 68
 69    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 70        json_schema_for::<CreateFileToolInput>(format)
 71    }
 72
 73    fn ui_text(&self, input: &serde_json::Value) -> String {
 74        match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
 75            Ok(input) => {
 76                let path = MarkdownInlineCode(&input.path);
 77                format!("Create file {path}")
 78            }
 79            Err(_) => DEFAULT_UI_TEXT.to_string(),
 80        }
 81    }
 82
 83    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
 84        match serde_json::from_value::<PartialInput>(input.clone()).ok() {
 85            Some(input) if !input.path.is_empty() => input.path,
 86            _ => DEFAULT_UI_TEXT.to_string(),
 87        }
 88    }
 89
 90    fn run(
 91        self: Arc<Self>,
 92        input: serde_json::Value,
 93        _messages: &[LanguageModelRequestMessage],
 94        project: Entity<Project>,
 95        action_log: Entity<ActionLog>,
 96        _window: Option<AnyWindowHandle>,
 97        cx: &mut App,
 98    ) -> ToolResult {
 99        let input = match serde_json::from_value::<CreateFileToolInput>(input) {
100            Ok(input) => input,
101            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
102        };
103        let project_path = match project.read(cx).find_project_path(&input.path, cx) {
104            Some(project_path) => project_path,
105            None => {
106                return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
107            }
108        };
109        let contents: Arc<str> = input.contents.as_str().into();
110        let destination_path: Arc<str> = input.path.as_str().into();
111
112        cx.spawn(async move |cx| {
113            let buffer = project
114                .update(cx, |project, cx| {
115                    project.open_buffer(project_path.clone(), cx)
116                })?
117                .await
118                .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
119            cx.update(|cx| {
120                action_log.update(cx, |action_log, cx| {
121                    action_log.buffer_created(buffer.clone(), cx)
122                });
123                buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
124                action_log.update(cx, |action_log, cx| {
125                    action_log.buffer_edited(buffer.clone(), cx)
126                });
127            })?;
128
129            project
130                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
131                .await
132                .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
133
134            Ok(format!("Created file {destination_path}"))
135        })
136        .into()
137    }
138}
139
// Unit tests for the UI-label helpers; `run` needs a live project and is not covered here.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    // Mid-stream: the path has arrived but `contents` has not started yet.
    #[test]
    fn still_streaming_ui_text_before_contents_arrive() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    // Mid-stream with no `path` key at all: `#[serde(default)]` must yield the fallback label.
    #[test]
    fn still_streaming_ui_text_before_path_arrives() {
        let tool = CreateFileTool;
        let input = json!({ "contents": "fn main() {}" });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }

    // `ui_text` expects a complete input; a missing required `contents` field falls back.
    #[test]
    fn ui_text_with_missing_contents() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}