use anyhow::{anyhow, Result};
use collections::HashMap;
use editor::Editor;
use futures::AsyncBufReadExt;
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
use gpui::executor::Background;
use gpui::{actions, AppContext, Task, ViewContext};
use isahc::prelude::*;
use isahc::{http::StatusCode, Request};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::fs;
use std::rc::Rc;
use std::{io, sync::Arc};
use util::channel::{ReleaseChannel, RELEASE_CHANNEL};
use util::{ResultExt, TryFutureExt};

use rust_embed::RustEmbed;
use std::str;

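// Prompt templates embedded into the binary at compile time; `system.zmd` from this
// bundle provides the assistant's system prompt.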
#[derive(RustEmbed)]
#[folder = "../../assets/contexts"]
#[exclude = "*.DS_Store"]
pub struct ContextAssets;

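// The `ai::Assist` action drives the assistant from the focused editor.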
actions!(ai, [Assist]);

// Data types for chat completion requests
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<RequestMessage>,
    stream: bool,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}

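// A single server-sent event from the streaming chat completions response.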
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

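// These response types appear to match the legacy (non-chat) completions payload; they
// are not referenced elsewhere in this module.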
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}

#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}

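// Register the assistant's editor actions. The assistant is currently disabled on the
// stable release channel.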
pub fn init(cx: &mut AppContext) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    let assistant = Rc::new(Assistant::default());
    cx.add_action({
        let assistant = assistant.clone();
        move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
            assistant.assist(editor, cx).log_err();
        }
    });
    cx.capture_action({
        let assistant = assistant.clone();
        move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
            if !assistant.cancel_last_assist(cx.view_id()) {
                cx.propagate_action();
            }
        }
    });
}

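// Each editor keeps a stack of in-flight completions, keyed by the editor's view id, so
// that `editor::Cancel` can abort the most recently started one.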
type CompletionId = usize;

#[derive(Default)]
struct Assistant(RefCell<AssistantState>);

#[derive(Default)]
struct AssistantState {
    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
    next_completion_id: CompletionId,
}

impl Assistant {
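    // Build a prompt from the current buffer and selections, then stream the model's
    // response into the buffer near the end of the document.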
    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
        let api_key = std::env::var("OPENAI_API_KEY")?;

        let selections = editor.selections.all(cx);
        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
            // Insert markers around the selected text, as described in the system prompt.
            let snapshot = buffer.snapshot(cx);
            let mut user_message = String::new();
            let mut user_message_suffix = String::new();
            let mut buffer_offset = 0;
            for selection in selections {
                if !selection.is_empty() {
                    if user_message_suffix.is_empty() {
                        user_message_suffix.push_str("\n\n");
                    }
                    user_message_suffix.push_str("[Selected excerpt from above]\n");
                    user_message_suffix
                        .extend(snapshot.text_for_range(selection.start..selection.end));
                    user_message_suffix.push_str("\n\n");
                }

                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
                user_message.push_str("[SELECTION_START]");
                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
                buffer_offset = selection.end;
                user_message.push_str("[SELECTION_END]");
            }
            if buffer_offset < snapshot.len() {
                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
            }
            user_message.push_str(&user_message_suffix);

            // Ensure the document ends with at least four trailing newlines.
            let trailing_newline_count = snapshot
                .reversed_chars_at(snapshot.len())
                .take_while(|c| *c == '\n')
                .take(4);
            let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count());
            buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx);

            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);

            (user_message, insertion_site)
        });

        let this = self.clone();
        let buffer = editor.buffer().clone();
        let executor = cx.background_executor().clone();
        let editor_id = cx.view_id();
        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
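        // Stream the completion on a background task; response deltas are inserted at the
        // anchor as they arrive, and the task unregisters itself when the stream ends.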
        let assist_task = cx.spawn(|_, mut cx| {
            async move {
                // TODO: We should have a get_string method on assets. This is repeated elsewhere.
                let content = ContextAssets::get("system.zmd").unwrap();
                let mut system_message = std::str::from_utf8(content.data.as_ref())
                    .unwrap()
                    .to_string();

                if let Ok(custom_system_message_path) =
                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
                {
                    system_message.push_str(
                        "\n\nAlso consider the following user-defined system prompt:\n\n",
                    );
                    // TODO: Replace this with our file system trait object.
                    system_message.push_str(
                        &cx.background()
                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
                            .await?,
                    );
                }

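                // Send the system and user messages to the chat completions endpoint.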
                let stream = stream_completion(
                    api_key,
                    executor,
                    OpenAIRequest {
                        model: "gpt-4".to_string(),
                        messages: vec![
                            RequestMessage {
                                role: Role::System,
                                content: system_message.to_string(),
                            },
                            RequestMessage {
                                role: Role::User,
                                content: user_message,
                            },
                        ],
                        stream: false,
                    },
                );

                let mut messages = stream.await?;
                while let Some(message) = messages.next().await {
                    let mut message = message?;
                    if let Some(choice) = message.choices.pop() {
                        buffer.update(&mut cx, |buffer, cx| {
                            let text: Arc<str> = choice.delta.content?.into();
                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
                            Some(())
                        });
                    }
                }

                this.0
                    .borrow_mut()
                    .assist_stacks
                    .get_mut(&editor_id)
                    .unwrap()
                    .retain(|(id, _)| *id != assist_id);

                anyhow::Ok(())
            }
            .log_err()
        });

        self.0
            .borrow_mut()
            .assist_stacks
            .entry(cx.view_id())
            .or_default()
            .push((assist_id, assist_task));

        Ok(())
    }

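    // Drop the most recently started assist for the given editor. Returns false if none
    // was running, letting the `Cancel` action propagate to the editor.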
    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
        self.0
            .borrow_mut()
            .assist_stacks
            .get_mut(&editor_id)
            .and_then(|assists| assists.pop())
            .is_some()
    }
}

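// POST the request to OpenAI's chat completions endpoint with `stream` forced on, and
// forward each parsed server-sent event over an unbounded channel.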
async fn stream_completion(
    api_key: String,
    executor: Arc<Background>,
    mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    request.stream = true;

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();

    let json_data = serde_json::to_string(&request)?;
    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

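                // Each event arrives as a line of the form `data: {json}`; other lines are
                // skipped, and payloads that fail to deserialize surface as errors downstream.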
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        tx.unbounded_send(event).log_err();
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to OpenAI API: {} {}",
            response.status(),
            body,
        ))
    }
}