1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use gpui::AnyWindowHandle;
5use gpui::{App, Entity, Task};
6use language_model::LanguageModelRequestMessage;
7use language_model::LanguageModelToolSchemaFormat;
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use std::sync::Arc;
12use ui::IconName;
13use util::markdown::MarkdownString;
14
// Complete, validated input for the `create_file` tool.
//
// NOTE: the `///` comments on the fields below are not just rustdoc. The `JsonSchema`
// derive turns them into `description`s in the schema returned by `input_schema`, so
// they are the instructions the model reads. Edit them as prompt text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct CreateFileToolInput {
    // Path of the file to create; resolved against the project in `run` and rejected
    // if it falls outside the project.
    /// The path where the file should be created.
    ///
    /// <example>
    /// If the project has the following structure:
    ///
    /// - directory1/
    /// - directory2/
    ///
    /// You can create a new file by providing a path of "directory1/new_file.txt"
    /// </example>
    pub path: String,

    // Full text written into the new file's buffer before it is saved.
    /// The text contents of the file to create.
    ///
    /// <example>
    /// To create a file with the text "Hello, World!", provide contents of "Hello, World!"
    /// </example>
    pub contents: String,
}
36
37#[derive(Debug, Serialize, Deserialize, JsonSchema)]
38struct PartialInput {
39 #[serde(default)]
40 path: String,
41 #[serde(default)]
42 contents: String,
43}
44
/// Agent tool (`create_file`) that writes a file at a project-relative path with the
/// given contents, rejecting paths outside the project.
pub struct CreateFileTool;

/// Generic label for this tool, used when the input is missing or does not yet
/// contain a usable path.
const DEFAULT_UI_TEXT: &str = "Create file";
48
49impl Tool for CreateFileTool {
50 fn name(&self) -> String {
51 "create_file".into()
52 }
53
54 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
55 false
56 }
57
58 fn description(&self) -> String {
59 include_str!("./create_file_tool/description.md").into()
60 }
61
62 fn icon(&self) -> IconName {
63 IconName::FileCreate
64 }
65
66 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
67 json_schema_for::<CreateFileToolInput>(format)
68 }
69
70 fn ui_text(&self, input: &serde_json::Value) -> String {
71 match serde_json::from_value::<CreateFileToolInput>(input.clone()) {
72 Ok(input) => {
73 let path = MarkdownString::inline_code(&input.path);
74 format!("Create file {path}")
75 }
76 Err(_) => DEFAULT_UI_TEXT.to_string(),
77 }
78 }
79
80 fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
81 match serde_json::from_value::<PartialInput>(input.clone()).ok() {
82 Some(input) if !input.path.is_empty() => input.path,
83 _ => DEFAULT_UI_TEXT.to_string(),
84 }
85 }
86
87 fn run(
88 self: Arc<Self>,
89 input: serde_json::Value,
90 _messages: &[LanguageModelRequestMessage],
91 project: Entity<Project>,
92 action_log: Entity<ActionLog>,
93 _window: Option<AnyWindowHandle>,
94 cx: &mut App,
95 ) -> ToolResult {
96 let input = match serde_json::from_value::<CreateFileToolInput>(input) {
97 Ok(input) => input,
98 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
99 };
100 let project_path = match project.read(cx).find_project_path(&input.path, cx) {
101 Some(project_path) => project_path,
102 None => {
103 return Task::ready(Err(anyhow!("Path to create was outside the project"))).into();
104 }
105 };
106 let contents: Arc<str> = input.contents.as_str().into();
107 let destination_path: Arc<str> = input.path.as_str().into();
108
109 cx.spawn(async move |cx| {
110 let buffer = project
111 .update(cx, |project, cx| {
112 project.open_buffer(project_path.clone(), cx)
113 })?
114 .await
115 .map_err(|err| anyhow!("Unable to open buffer for {destination_path}: {err}"))?;
116 cx.update(|cx| {
117 action_log.update(cx, |action_log, cx| {
118 action_log.track_buffer(buffer.clone(), cx)
119 });
120 buffer.update(cx, |buffer, cx| buffer.set_text(contents, cx));
121 action_log.update(cx, |action_log, cx| {
122 action_log.buffer_edited(buffer.clone(), cx)
123 });
124 })?;
125
126 project
127 .update(cx, |project, cx| project.save_buffer(buffer, cx))?
128 .await
129 .map_err(|err| anyhow!("Unable to save buffer for {destination_path}: {err}"))?;
130
131 Ok(format!("Created file {destination_path}"))
132 })
133 .into()
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_without_path() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "",
            "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    /// While streaming, `contents` may not have arrived yet; the path alone must
    /// still be shown.
    #[test]
    fn still_streaming_ui_text_with_only_path() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/lib.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/lib.rs");
    }

    /// Before any field has streamed in, the generic label is shown.
    #[test]
    fn still_streaming_ui_text_with_empty_object() {
        let tool = CreateFileTool;
        let input = json!({});

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = CreateFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_valid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "path": "src/main.rs",
            "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
        });

        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
    }

    /// `ui_text` requires the complete input, so a missing `contents` falls back.
    #[test]
    fn ui_text_with_missing_contents() {
        let tool = CreateFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn ui_text_with_invalid_input() {
        let tool = CreateFileTool;
        let input = json!({
            "invalid": "field"
        });

        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}