1pub mod assistant;
2mod assistant_settings;
3mod codegen;
4mod streaming_diff;
5
6use anyhow::{anyhow, Result};
7pub use assistant::AssistantPanel;
8use assistant_settings::OpenAIModel;
9use chrono::{DateTime, Local};
10use collections::HashMap;
11use fs::Fs;
12use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
13use gpui::{executor::Background, AppContext};
14use isahc::{http::StatusCode, Request, RequestExt};
15use regex::Regex;
16use serde::{Deserialize, Serialize};
17use std::{
18 cmp::Reverse,
19 ffi::OsStr,
20 fmt::{self, Display},
21 io,
22 path::PathBuf,
23 sync::Arc,
24};
25use util::paths::CONVERSATIONS_DIR;
26
/// Base URL of the OpenAI REST API; endpoint paths such as `/chat/completions` are
/// appended to it when building requests.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
28
29// Data types for chat completion requests
30#[derive(Debug, Default, Serialize)]
31pub struct OpenAIRequest {
32 model: String,
33 messages: Vec<RequestMessage>,
34 stream: bool,
35}
36
37#[derive(
38 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
39)]
40struct MessageId(usize);
41
42#[derive(Clone, Debug, Serialize, Deserialize)]
43struct MessageMetadata {
44 role: Role,
45 sent_at: DateTime<Local>,
46 status: MessageStatus,
47}
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
50enum MessageStatus {
51 Pending,
52 Done,
53 Error(Arc<str>),
54}
55
56#[derive(Serialize, Deserialize)]
57struct SavedMessage {
58 id: MessageId,
59 start: usize,
60}
61
62#[derive(Serialize, Deserialize)]
63struct SavedConversation {
64 zed: String,
65 version: String,
66 text: String,
67 messages: Vec<SavedMessage>,
68 message_metadata: HashMap<MessageId, MessageMetadata>,
69 summary: String,
70 model: OpenAIModel,
71}
72
73impl SavedConversation {
74 const VERSION: &'static str = "0.1.0";
75}
76
77struct SavedConversationMetadata {
78 title: String,
79 path: PathBuf,
80 mtime: chrono::DateTime<chrono::Local>,
81}
82
83impl SavedConversationMetadata {
84 pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
85 fs.create_dir(&CONVERSATIONS_DIR).await?;
86
87 let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
88 let mut conversations = Vec::<SavedConversationMetadata>::new();
89 while let Some(path) = paths.next().await {
90 let path = path?;
91 if path.extension() != Some(OsStr::new("json")) {
92 continue;
93 }
94
95 let pattern = r" - \d+.zed.json$";
96 let re = Regex::new(pattern).unwrap();
97
98 let metadata = fs.metadata(&path).await?;
99 if let Some((file_name, metadata)) = path
100 .file_name()
101 .and_then(|name| name.to_str())
102 .zip(metadata)
103 {
104 let title = re.replace(file_name, "");
105 conversations.push(Self {
106 title: title.into_owned(),
107 path,
108 mtime: metadata.mtime.into(),
109 });
110 }
111 }
112 conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
113
114 Ok(conversations)
115 }
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119struct RequestMessage {
120 role: Role,
121 content: String,
122}
123
124#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
125pub struct ResponseMessage {
126 role: Option<Role>,
127 content: Option<String>,
128}
129
130#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
131#[serde(rename_all = "lowercase")]
132enum Role {
133 User,
134 Assistant,
135 System,
136}
137
138impl Role {
139 pub fn cycle(&mut self) {
140 *self = match self {
141 Role::User => Role::Assistant,
142 Role::Assistant => Role::System,
143 Role::System => Role::User,
144 }
145 }
146}
147
148impl Display for Role {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 Role::User => write!(f, "User"),
152 Role::Assistant => write!(f, "Assistant"),
153 Role::System => write!(f, "System"),
154 }
155 }
156}
157
158#[derive(Deserialize, Debug)]
159pub struct OpenAIResponseStreamEvent {
160 pub id: Option<String>,
161 pub object: String,
162 pub created: u32,
163 pub model: String,
164 pub choices: Vec<ChatChoiceDelta>,
165 pub usage: Option<Usage>,
166}
167
168#[derive(Deserialize, Debug)]
169pub struct Usage {
170 pub prompt_tokens: u32,
171 pub completion_tokens: u32,
172 pub total_tokens: u32,
173}
174
175#[derive(Deserialize, Debug)]
176pub struct ChatChoiceDelta {
177 pub index: u32,
178 pub delta: ResponseMessage,
179 pub finish_reason: Option<String>,
180}
181
182#[derive(Deserialize, Debug)]
183struct OpenAIUsage {
184 prompt_tokens: u64,
185 completion_tokens: u64,
186 total_tokens: u64,
187}
188
189#[derive(Deserialize, Debug)]
190struct OpenAIChoice {
191 text: String,
192 index: u32,
193 logprobs: Option<serde_json::Value>,
194 finish_reason: Option<String>,
195}
196
197pub fn init(cx: &mut AppContext) {
198 assistant::init(cx);
199}
200
201pub async fn stream_completion(
202 api_key: String,
203 executor: Arc<Background>,
204 mut request: OpenAIRequest,
205) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
206 request.stream = true;
207
208 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
209
210 let json_data = serde_json::to_string(&request)?;
211 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
212 .header("Content-Type", "application/json")
213 .header("Authorization", format!("Bearer {}", api_key))
214 .body(json_data)?
215 .send_async()
216 .await?;
217
218 let status = response.status();
219 if status == StatusCode::OK {
220 executor
221 .spawn(async move {
222 let mut lines = BufReader::new(response.body_mut()).lines();
223
224 fn parse_line(
225 line: Result<String, io::Error>,
226 ) -> Result<Option<OpenAIResponseStreamEvent>> {
227 if let Some(data) = line?.strip_prefix("data: ") {
228 let event = serde_json::from_str(&data)?;
229 Ok(Some(event))
230 } else {
231 Ok(None)
232 }
233 }
234
235 while let Some(line) = lines.next().await {
236 if let Some(event) = parse_line(line).transpose() {
237 let done = event.as_ref().map_or(false, |event| {
238 event
239 .choices
240 .last()
241 .map_or(false, |choice| choice.finish_reason.is_some())
242 });
243 if tx.unbounded_send(event).is_err() {
244 break;
245 }
246
247 if done {
248 break;
249 }
250 }
251 }
252
253 anyhow::Ok(())
254 })
255 .detach();
256
257 Ok(rx)
258 } else {
259 let mut body = String::new();
260 response.body_mut().read_to_string(&mut body).await?;
261
262 #[derive(Deserialize)]
263 struct OpenAIResponse {
264 error: OpenAIError,
265 }
266
267 #[derive(Deserialize)]
268 struct OpenAIError {
269 message: String,
270 }
271
272 match serde_json::from_str::<OpenAIResponse>(&body) {
273 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
274 "Failed to connect to OpenAI API: {}",
275 response.error.message,
276 )),
277
278 _ => Err(anyhow!(
279 "Failed to connect to OpenAI API: {} {}",
280 response.status(),
281 body,
282 )),
283 }
284 }
285}
286
/// Test builds only: installs `env_logger` at program load (via `ctor`). It does so only
/// when `RUST_LOG` is set to a value `std::env::var` can read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}