1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use gpui::{App, Entity, Task};
5use language_model::LanguageModelRequestMessage;
6use language_model::LanguageModelToolSchemaFormat;
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use ui::IconName;
12use util::markdown::MarkdownString;
13
// Complete input for the `create_file` tool. `run` requires it to deserialize
// fully, and `ui_text` falls back to `DEFAULT_UI_TEXT` when it does not.
//
// NOTE: the `///` comments on the fields below are not just rustdoc. The
// `JsonSchema` derive copies them into the schema `description`s that
// `input_schema` returns for the language model, so editing them changes what
// the model is told. Keep maintainer-facing notes in `//` comments like this
// one, and do not add a `///` comment to the struct itself for the same reason.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    // `run` resolves this with `Project::find_project_path` and fails with
    // "Path to create was outside the project" when that returns `None`.
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    pub path: String,

    // `run` writes this with `Buffer::set_text`, i.e. it becomes the buffer's
    // whole text. NOTE(review): if a file already exists at `path`, presumably
    // its contents get replaced rather than rejected — confirm intended.
    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
35
36#[derive(Debug, Serialize, Deserialize, JsonSchema)]
37struct PartialInput {
38 #[serde(default)]
39 path: String,
40 #[serde(default)]
41 contents: String,
42}
43
/// Assistant tool that writes `contents` to `path` inside the project.
///
/// See `CreateFileToolInput` for the expected input shape.
pub struct CreateFileTool;

/// Label used when `ui_text` cannot parse a complete input, and by
/// `still_streaming_ui_text` until a non-empty `path` has arrived.
const DEFAULT_UI_TEXT: &str = "Create file";
47
48impl Tool for CreateFileTool {
49 fn name(&self) -> String {
50 "create_file".into()
51 }
52
53 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
54 false
55 }
56
57 fn description(&self) -> String {
58 include_str!("./create_file_tool/description.md").into()
59 }
60
61 fn icon(&self) -> IconName {
62 IconName::FileCreate
63 }
64
65 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
66 json_schema_for::<CreateFileToolInput>(format)
67 }
68
69 fn ui_text(&self, input: &serde_json::Value) -> String {
70 match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
71 Ok(input) => {
72 let path = MarkdownString::inline_code(&input.path);
73 format!("Create file {path}")
74 }
75 Err(_) => DEFAULT_UI_TEXT.to_string(),
76 }
77 }
78
79 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
80 match serde_json::from_value::<PartialInput>(input.clone()).ok() {
81 Some(input) if !input.path.is_empty() => input.path,
82 _ => DEFAULT_UI_TEXT.to_string(),
83 }
84 }
85
86 fn run(
87 self: Arc<Self>,
88 input: serde_json::Value,
89 _messages: &[LanguageModelRequestMessage],
90 project: Entity<Project>,
91 action_log: Entity<ActionLog>,
92 cx: &mut App,
93 ) -> ToolResult {
94 let input = match serde_json::from_value::<CreateFileToolInput>(input) {
95 Ok(input) => input,
96 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
97 };
98 let project_path = match project.read(cx).find_project_path(&input.path, cx) {
99 Some(project_path) => project_path,
100 None => {
101 return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
102 }
103 };
104 let contents: Arc<str> = input.contents.as_str().into();
105 let destination_path: Arc<str> = input.path.as_str().into();
106
107 cx.spawn(async move |cx| {
108 let buffer = project
109 .update(cx, |project, cx| {
110 project.open_buffer(project_path.clone(), cx)
111 })?
112 .await
113 .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
114 cx.update(|cx| {
115 action_log.update(cx, |action_log, cx| {
116 action_log.track_buffer(buffer.clone(), cx)
117 });
118 buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
119 action_log.update(cx, |action_log, cx| {
120 action_log.buffer_edited(buffer.clone(), cx)
121 });
122 })?;
123
124 project
125 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
126 .await
127 .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
128
129 Ok(format!("Created file {destination_path}"))
130 })
131 .into()
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Sample file body shared by the tests; its value never affects a label.
    const SAMPLE_CONTENTS: &str = "fn main() {\n    println!(\"Hello, world!\");\n}";

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    // Mid-stream, `contents` has not arrived yet: the path should still show.
    #[test]
    fn still_streaming_ui_text_before_contents_arrive() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    // Mid-stream, `path` has not arrived yet: fall back to the default label.
    #[test]
    fn still_streaming_ui_text_before_path_arrives() {
        let tool = CreateFileTool;
        let input = json!({ "contents": SAMPLE_CONTENTS });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": SAMPLE_CONTENTS
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }

    // Unlike the streaming label, `ui_text` requires a complete input.
    #[test]
    fn ui_text_with_incomplete_input() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}