1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use gpui::{App, Entity, Task};
5use language_model::LanguageModelRequestMessage;
6use language_model::LanguageModelToolSchemaFormat;
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::sync::Arc;
11use ui::IconName;
12use util::markdown::MarkdownString;
13
14#[derive(Debug, Serialize, Deserialize, JsonSchema)]
15pub struct CreateFileToolInput {
16 /// The path where the file should be created.
17 ///
18 /// <example>
19 /// If the project has the following structure:
20 ///
21 /// - directory1/
22 /// - directory2/
23 ///
24 /// You can create a new file by providing a path of "directory1/new_file.txt"
25 /// </example>
26 pub path: String,
27
28 /// The text contents of the file to create.
29 ///
30 /// <example>
31 /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
32 /// </example>
33 pub contents: String,
34}
35
36#[derive(Debug, Serialize, Deserialize, JsonSchema)]
37struct PartialInput {
38 #[serde(default)]
39 path: String,
40 #[serde(default)]
41 contents: String,
42}
43
/// Label shown for this tool when no path can be extracted from its input.
const DEFAULT_UI_TEXT: &str = "Create file";

/// Stateless tool that creates a file in the project with the given contents.
pub struct CreateFileTool;
47
48impl Tool for CreateFileTool {
49 fn name(&self) -> String {
50 "create_file".into()
51 }
52
53 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
54 false
55 }
56
57 fn description(&self) -> String {
58 include_str!("./create_file_tool/description.md").into()
59 }
60
61 fn icon(&self) -> IconName {
62 IconName::FileCreate
63 }
64
65 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
66 json_schema_for::<CreateFileToolInput>(format)
67 }
68
69 fn ui_text(&self, input: &serde_json::Value) -> String {
70 match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
71 Ok(input) => {
72 let path = MarkdownString::inline_code(&input.path);
73 format!("Create file {path}")
74 }
75 Err(_) => DEFAULT_UI_TEXT.to_string(),
76 }
77 }
78
79 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
80 match serde_json::from_value::<PartialInput>(input.clone()).ok() {
81 Some(input) if !input.path.is_empty() => input.path,
82 _ => DEFAULT_UI_TEXT.to_string(),
83 }
84 }
85
86 fn run(
87 self: Arc<Self>,
88 input: serde_json::Value,
89 _messages: &[LanguageModelRequestMessage],
90 project: Entity<Project>,
91 action_log: Entity<ActionLog>,
92 cx: &mut App,
93 ) -> ToolResult {
94 let input = match serde_json::from_value::<CreateFileToolInput>(input) {
95 Ok(input) => input,
96 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
97 };
98 let project_path = match project.read(cx).find_project_path(&input.path, cx) {
99 Some(project_path) => project_path,
100 None => {
101 return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
102 }
103 };
104 let contents: Arc<str> = input.contents.as_str().into();
105 let destination_path: Arc<str> = input.path.as_str().into();
106
107 cx.spawn(async move |cx| {
108 let buffer = project
109 .update(cx, |project, cx| {
110 project.open_buffer(project_path.clone(), cx)
111 })?
112 .await
113 .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
114 cx.update(|cx| {
115 buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
116 action_log.update(cx, |action_log, cx| {
117 action_log.will_create_buffer(buffer.clone(), cx)
118 });
119 })?;
120
121 project
122 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
123 .await
124 .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
125
126 Ok(format!("Created file {destination_path}"))
127 })
128 .into()
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Fixtures shared by the tests below.
    const SAMPLE_PATH: &str = "src/main.rs";
    const SAMPLE_CONTENTS: &str = "fn main() {\n println!(\"Hello, world!\");\n}";

    /// Builds a complete tool input with the sample contents and the given path.
    fn input_with_path(path: &str) -> serde_json::Value {
        json!({
            "path": path,
            "contents": SAMPLE_CONTENTS
        })
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        let label = CreateFileTool.still_streaming_ui_text(&input_with_path(SAMPLE_PATH));

        assert_eq!(label, "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let label = CreateFileTool.still_streaming_ui_text(&input_with_path(""));

        assert_eq!(label, DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let label = CreateFileTool.still_streaming_ui_text(&serde_json::Value::Null);

        assert_eq!(label, DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let label = CreateFileTool.ui_text(&input_with_path(SAMPLE_PATH));

        assert_eq!(label, "Create file `src/main.rs`");
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let unrelated = json!({
            "invalid": "field"
        });

        assert_eq!(CreateFileTool.ui_text(&unrelated), DEFAULT_UI_TEXT);
    }
}