create_file_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use gpui::AnyWindowHandle;
  5use gpui::{App, Entity, Task};
  6use language_model::LanguageModelRequestMessage;
  7use language_model::LanguageModelToolSchemaFormat;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use std::sync::Arc;
 12use ui::IconName;
 13use util::markdown::MarkdownString;
 14
// Input accepted by the `create_file` tool (deserialized in `run`).
//
// NOTE: the `///` comments on this struct and its fields are not only rustdoc.
// `#[derive(JsonSchema)]` (schemars) turns them into the `description`s of the schema
// that `input_schema` returns via `json_schema_for`. That schema is presumably what the
// language model sees for this tool, so treat those comments as prompt text: changing
// them changes tool behavior.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    // Resolved with `Project::find_project_path` in `run`; an unresolvable path is an error.
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    pub path: String,

    // Becomes the buffer's entire text in `run` (via `set_text`) before the buffer is saved.
    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
 36
 37#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 38struct PartialInput {
 39    #[serde(default)]
 40    path: String,
 41    #[serde(default)]
 42    contents: String,
 43}
 44
 45pub struct CreateFileTool;
 46
 47const DEFAULT_UI_TEXT: &str = "Create file";
 48
 49impl Tool for CreateFileTool {
 50    fn name(&self) -> String {
 51        "create_file".into()
 52    }
 53
 54    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 55        false
 56    }
 57
 58    fn description(&self) -> String {
 59        include_str!("./create_file_tool/description.md").into()
 60    }
 61
 62    fn icon(&self) -> IconName {
 63        IconName::FileCreate
 64    }
 65
 66    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 67        json_schema_for::<CreateFileToolInput>(format)
 68    }
 69
 70    fn ui_text(&self, input: &serde_json::Value) -> String {
 71        match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
 72            Ok(input) => {
 73                let path = MarkdownString::inline_code(&input.path);
 74                format!("Create file {path}")
 75            }
 76            Err(_) => DEFAULT_UI_TEXT.to_string(),
 77        }
 78    }
 79
 80    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
 81        match serde_json::from_value::<PartialInput>(input.clone()).ok() {
 82            Some(input) if !input.path.is_empty() => input.path,
 83            _ => DEFAULT_UI_TEXT.to_string(),
 84        }
 85    }
 86
 87    fn run(
 88        self: Arc<Self>,
 89        input: serde_json::Value,
 90        _messages: &[LanguageModelRequestMessage],
 91        project: Entity<Project>,
 92        action_log: Entity<ActionLog>,
 93        _window: Option<AnyWindowHandle>,
 94        cx: &mut App,
 95    ) -> ToolResult {
 96        let input = match serde_json::from_value::<CreateFileToolInput>(input) {
 97            Ok(input) => input,
 98            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 99        };
100        let project_path = match project.read(cx).find_project_path(&input.path, cx) {
101            Some(project_path) => project_path,
102            None => {
103                return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
104            }
105        };
106        let contents: Arc<str> = input.contents.as_str().into();
107        let destination_path: Arc<str> = input.path.as_str().into();
108
109        cx.spawn(async move |cx| {
110            let buffer = project
111                .update(cx, |project, cx| {
112                    project.open_buffer(project_path.clone(), cx)
113                })?
114                .await
115                .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
116            cx.update(|cx| {
117                action_log.update(cx, |action_log, cx| {
118                    action_log.track_buffer(buffer.clone(), cx)
119                });
120                buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
121                action_log.update(cx, |action_log, cx| {
122                    action_log.buffer_edited(buffer.clone(), cx)
123                });
124            })?;
125
126            project
127                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
128                .await
129                .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
130
131            Ok(format!("Created file {destination_path}"))
132        })
133        .into()
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Sample file body shared by the tests; its exact text does not affect the labels.
    const SAMPLE_CONTENTS: &str = "fn main() {\n    println!(\"Hello, world!\");\n}";

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_before_contents_arrive() {
        // A partially streamed input may already carry `path` while `contents` is
        // still missing; the path label must be shown anyway.
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_before_path_arrives() {
        let tool = CreateFileTool;
        let input = json!({ "contents": SAMPLE_CONTENTS });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_without_contents_falls_back() {
        // `ui_text` expects a complete input; a missing `contents` makes it invalid.
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}