1pub mod assistant;
2mod assistant_settings;
3mod streaming_diff;
4
5use anyhow::{anyhow, Result};
6pub use assistant::AssistantPanel;
7use chrono::{DateTime, Local};
8use collections::HashMap;
9use fs::Fs;
10use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
11use gpui::{executor::Background, AppContext};
12use isahc::{http::StatusCode, Request, RequestExt};
13use regex::Regex;
14use serde::{Deserialize, Serialize};
15use std::{
16 cmp::Reverse,
17 ffi::OsStr,
18 fmt::{self, Display},
19 io,
20 path::PathBuf,
21 sync::Arc,
22};
23use util::paths::CONVERSATIONS_DIR;
24
/// Base URL for all OpenAI REST API requests.
// `'static` is implied for const string slices; writing it out is flagged by
// clippy's `redundant_static_lifetimes` lint.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
26
27// Data types for chat completion requests
#[derive(Debug, Serialize)]
pub struct OpenAIRequest {
    // Model identifier sent to the API (e.g. a chat model name).
    model: String,
    // Conversation history, in the order it should be presented to the model.
    messages: Vec<RequestMessage>,
    // Whether the API should stream the response; `stream_completion` forces
    // this to `true` before sending.
    stream: bool,
}
34
/// Identifier for a message within a conversation; used as the key in
/// [`SavedConversation::message_metadata`].
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
39
/// Per-message bookkeeping persisted alongside the conversation text.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    // Who authored the message: user, assistant, or system.
    role: Role,
    // Local wall-clock time the message was sent.
    sent_at: DateTime<Local>,
    // Whether the message is still streaming in, finished, or failed.
    status: MessageStatus,
}
46
/// Lifecycle state of a single message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Still being produced (e.g. a streaming completion in flight).
    Pending,
    /// Fully received.
    Done,
    /// Failed; carries the error text to display.
    Error(Arc<str>),
}
53
/// Serialized marker for one message inside a saved conversation.
#[derive(Serialize, Deserialize)]
struct SavedMessage {
    id: MessageId,
    // Byte offset into `SavedConversation::text` where this message begins.
    start: usize,
}
59
/// On-disk representation of a conversation (stored as JSON under
/// `CONVERSATIONS_DIR`).
#[derive(Serialize, Deserialize)]
struct SavedConversation {
    // Marker identifying the file as a Zed conversation.
    zed: String,
    // Serialization format version; see `SavedConversation::VERSION`.
    version: String,
    // Full conversation text; messages index into it via `SavedMessage::start`.
    text: String,
    messages: Vec<SavedMessage>,
    message_metadata: HashMap<MessageId, MessageMetadata>,
    // Human-readable summary used as the conversation's title.
    summary: String,
    // Model the conversation was using when saved.
    model: String,
}
70
impl SavedConversation {
    /// Current serialization format version written into the `version` field.
    const VERSION: &'static str = "0.1.0";
}
74
/// Lightweight summary of a saved conversation, derived from its file name
/// and filesystem metadata without parsing the JSON body.
struct SavedConversationMetadata {
    title: String,
    path: PathBuf,
    // Filesystem modification time; used to sort newest-first.
    mtime: chrono::DateTime<chrono::Local>,
}
80
81impl SavedConversationMetadata {
82 pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
83 fs.create_dir(&CONVERSATIONS_DIR).await?;
84
85 let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
86 let mut conversations = Vec::<SavedConversationMetadata>::new();
87 while let Some(path) = paths.next().await {
88 let path = path?;
89 if path.extension() != Some(OsStr::new("json")) {
90 continue;
91 }
92
93 let pattern = r" - \d+.zed.json$";
94 let re = Regex::new(pattern).unwrap();
95
96 let metadata = fs.metadata(&path).await?;
97 if let Some((file_name, metadata)) = path
98 .file_name()
99 .and_then(|name| name.to_str())
100 .zip(metadata)
101 {
102 let title = re.replace(file_name, "");
103 conversations.push(Self {
104 title: title.into_owned(),
105 path,
106 mtime: metadata.mtime.into(),
107 });
108 }
109 }
110 conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
111
112 Ok(conversations)
113 }
114}
115
/// One message in an outgoing chat-completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
121
/// Incremental message delta in a streamed response; either field may be
/// absent in any given chunk, so both are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
127
/// Author of a chat message. Serialized lowercase (`"user"`, `"assistant"`,
/// `"system"`) to match the OpenAI wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
135
136impl Role {
137 pub fn cycle(&mut self) {
138 *self = match self {
139 Role::User => Role::Assistant,
140 Role::Assistant => Role::System,
141 Role::System => Role::User,
142 }
143 }
144}
145
146impl Display for Role {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
148 match self {
149 Role::User => write!(f, "User"),
150 Role::Assistant => write!(f, "Assistant"),
151 Role::System => write!(f, "System"),
152 }
153 }
154}
155
/// One server-sent event from the streaming chat-completions endpoint.
#[derive(Deserialize, Debug)]
pub struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    // Usually a single delta; `stream_completion` checks the last choice's
    // `finish_reason` to detect the end of the stream.
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}
165
/// Token accounting reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
172
/// A single choice within a streamed event, carrying an incremental delta.
#[derive(Deserialize, Debug)]
pub struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    // Set on the terminal chunk (e.g. "stop"); `None` while streaming.
    pub finish_reason: Option<String>,
}
179
// NOTE(review): not referenced anywhere in this file and duplicates `Usage`
// with wider integer types — possibly kept for a non-streaming completions
// response shape; confirm a caller exists before removing.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
186
// NOTE(review): not referenced anywhere in this file — its `text` field
// matches the legacy (non-chat) completions response shape; confirm a caller
// exists before removing.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
194
/// Initializes this crate's UI integration; delegates to [`assistant::init`].
pub fn init(cx: &mut AppContext) {
    assistant::init(cx);
}
198
/// Sends `request` to the OpenAI chat-completions endpoint and returns a
/// stream of parsed server-sent events.
///
/// `request.stream` is forced to `true`. On a non-OK HTTP status, returns an
/// error containing the API's error message when the body parses as the
/// documented error shape, or the raw status and body otherwise. On success,
/// a background task spawned on `executor` reads the response line by line
/// and forwards events over an unbounded channel until the receiver is
/// dropped or a choice reports a `finish_reason`.
pub async fn stream_completion(
    api_key: String,
    executor: Arc<Background>,
    mut request: OpenAIRequest,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    // Streaming must be enabled for the SSE line protocol parsed below.
    request.stream = true;

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();

    let json_data = serde_json::to_string(&request)?;
    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

                // Parses one SSE line: lines prefixed with "data: " carry a
                // JSON event; anything else (e.g. blank keep-alive lines) is
                // skipped by returning Ok(None). A "data: " payload that is
                // not valid JSON surfaces as an Err event.
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        let event = serde_json::from_str(&data)?;
                        Ok(Some(event))
                    } else {
                        Ok(None)
                    }
                }

                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        // A `finish_reason` on the last choice marks the end
                        // of the stream.
                        let done = event.as_ref().map_or(false, |event| {
                            event
                                .choices
                                .last()
                                .map_or(false, |choice| choice.finish_reason.is_some())
                        });
                        // Forward the event (including the terminal one)
                        // before breaking; stop early if the receiver was
                        // dropped.
                        if tx.unbounded_send(event).is_err() {
                            break;
                        }

                        if done {
                            break;
                        }
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        // Error path: read the whole body and try to extract the structured
        // error message the API returns.
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        #[derive(Deserialize)]
        struct OpenAIResponse {
            error: OpenAIError,
        }

        #[derive(Deserialize)]
        struct OpenAIError {
            message: String,
        }

        match serde_json::from_str::<OpenAIResponse>(&body) {
            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
                "Failed to connect to OpenAI API: {}",
                response.error.message,
            )),

            // Fall back to the raw status and body when the error shape
            // doesn't match or carries no message.
            _ => Err(anyhow!(
                "Failed to connect to OpenAI API: {} {}",
                response.status(),
                body,
            )),
        }
    }
}
284
// Runs once before any test (via `ctor`); enables env_logger only when the
// caller opted in by setting RUST_LOG.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}