1pub mod assistant;
2mod assistant_settings;
3mod streaming_diff;
4
5use anyhow::{anyhow, Result};
6pub use assistant::AssistantPanel;
7use assistant_settings::OpenAIModel;
8use chrono::{DateTime, Local};
9use collections::HashMap;
10use fs::Fs;
11use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
12use gpui::{executor::Background, AppContext};
13use isahc::{http::StatusCode, Request, RequestExt};
14use regex::Regex;
15use serde::{Deserialize, Serialize};
16use std::{
17 cmp::Reverse,
18 ffi::OsStr,
19 fmt::{self, Display},
20 io,
21 path::PathBuf,
22 sync::Arc,
23};
24use util::paths::CONVERSATIONS_DIR;
25
/// Base URL of the OpenAI REST API; endpoint paths are appended to it.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
27
28// Data types for chat completion requests
29#[derive(Debug, Serialize)]
30pub struct OpenAIRequest {
31 model: String,
32 messages: Vec<RequestMessage>,
33 stream: bool,
34}
35
36#[derive(
37 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
38)]
39struct MessageId(usize);
40
41#[derive(Clone, Debug, Serialize, Deserialize)]
42struct MessageMetadata {
43 role: Role,
44 sent_at: DateTime<Local>,
45 status: MessageStatus,
46}
47
48#[derive(Clone, Debug, Serialize, Deserialize)]
49enum MessageStatus {
50 Pending,
51 Done,
52 Error(Arc<str>),
53}
54
55#[derive(Serialize, Deserialize)]
56struct SavedMessage {
57 id: MessageId,
58 start: usize,
59}
60
61#[derive(Serialize, Deserialize)]
62struct SavedConversation {
63 zed: String,
64 version: String,
65 text: String,
66 messages: Vec<SavedMessage>,
67 message_metadata: HashMap<MessageId, MessageMetadata>,
68 summary: String,
69 model: OpenAIModel,
70}
71
72impl SavedConversation {
73 const VERSION: &'static str = "0.1.0";
74}
75
76struct SavedConversationMetadata {
77 title: String,
78 path: PathBuf,
79 mtime: chrono::DateTime<chrono::Local>,
80}
81
82impl SavedConversationMetadata {
83 pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
84 fs.create_dir(&CONVERSATIONS_DIR).await?;
85
86 let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
87 let mut conversations = Vec::<SavedConversationMetadata>::new();
88 while let Some(path) = paths.next().await {
89 let path = path?;
90 if path.extension() != Some(OsStr::new("json")) {
91 continue;
92 }
93
94 let pattern = r" - \d+.zed.json$";
95 let re = Regex::new(pattern).unwrap();
96
97 let metadata = fs.metadata(&path).await?;
98 if let Some((file_name, metadata)) = path
99 .file_name()
100 .and_then(|name| name.to_str())
101 .zip(metadata)
102 {
103 let title = re.replace(file_name, "");
104 conversations.push(Self {
105 title: title.into_owned(),
106 path,
107 mtime: metadata.mtime.into(),
108 });
109 }
110 }
111 conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
112
113 Ok(conversations)
114 }
115}
116
117#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
118struct RequestMessage {
119 role: Role,
120 content: String,
121}
122
123#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
124pub struct ResponseMessage {
125 role: Option<Role>,
126 content: Option<String>,
127}
128
129#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
130#[serde(rename_all = "lowercase")]
131enum Role {
132 User,
133 Assistant,
134 System,
135}
136
137impl Role {
138 pub fn cycle(&mut self) {
139 *self = match self {
140 Role::User => Role::Assistant,
141 Role::Assistant => Role::System,
142 Role::System => Role::User,
143 }
144 }
145}
146
147impl Display for Role {
148 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
149 match self {
150 Role::User => write!(f, "User"),
151 Role::Assistant => write!(f, "Assistant"),
152 Role::System => write!(f, "System"),
153 }
154 }
155}
156
157#[derive(Deserialize, Debug)]
158pub struct OpenAIResponseStreamEvent {
159 pub id: Option<String>,
160 pub object: String,
161 pub created: u32,
162 pub model: String,
163 pub choices: Vec<ChatChoiceDelta>,
164 pub usage: Option<Usage>,
165}
166
167#[derive(Deserialize, Debug)]
168pub struct Usage {
169 pub prompt_tokens: u32,
170 pub completion_tokens: u32,
171 pub total_tokens: u32,
172}
173
174#[derive(Deserialize, Debug)]
175pub struct ChatChoiceDelta {
176 pub index: u32,
177 pub delta: ResponseMessage,
178 pub finish_reason: Option<String>,
179}
180
181#[derive(Deserialize, Debug)]
182struct OpenAIUsage {
183 prompt_tokens: u64,
184 completion_tokens: u64,
185 total_tokens: u64,
186}
187
188#[derive(Deserialize, Debug)]
189struct OpenAIChoice {
190 text: String,
191 index: u32,
192 logprobs: Option<serde_json::Value>,
193 finish_reason: Option<String>,
194}
195
196pub fn init(cx: &mut AppContext) {
197 assistant::init(cx);
198}
199
200pub async fn stream_completion(
201 api_key: String,
202 executor: Arc<Background>,
203 mut request: OpenAIRequest,
204) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
205 request.stream = true;
206
207 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
208
209 let json_data = serde_json::to_string(&request)?;
210 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
211 .header("Content-Type", "application/json")
212 .header("Authorization", format!("Bearer {}", api_key))
213 .body(json_data)?
214 .send_async()
215 .await?;
216
217 let status = response.status();
218 if status == StatusCode::OK {
219 executor
220 .spawn(async move {
221 let mut lines = BufReader::new(response.body_mut()).lines();
222
223 fn parse_line(
224 line: Result<String, io::Error>,
225 ) -> Result<Option<OpenAIResponseStreamEvent>> {
226 if let Some(data) = line?.strip_prefix("data: ") {
227 let event = serde_json::from_str(&data)?;
228 Ok(Some(event))
229 } else {
230 Ok(None)
231 }
232 }
233
234 while let Some(line) = lines.next().await {
235 if let Some(event) = parse_line(line).transpose() {
236 let done = event.as_ref().map_or(false, |event| {
237 event
238 .choices
239 .last()
240 .map_or(false, |choice| choice.finish_reason.is_some())
241 });
242 if tx.unbounded_send(event).is_err() {
243 break;
244 }
245
246 if done {
247 break;
248 }
249 }
250 }
251
252 anyhow::Ok(())
253 })
254 .detach();
255
256 Ok(rx)
257 } else {
258 let mut body = String::new();
259 response.body_mut().read_to_string(&mut body).await?;
260
261 #[derive(Deserialize)]
262 struct OpenAIResponse {
263 error: OpenAIError,
264 }
265
266 #[derive(Deserialize)]
267 struct OpenAIError {
268 message: String,
269 }
270
271 match serde_json::from_str::<OpenAIResponse>(&body) {
272 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
273 "Failed to connect to OpenAI API: {}",
274 response.error.message,
275 )),
276
277 _ => Err(anyhow!(
278 "Failed to connect to OpenAI API: {} {}",
279 response.status(),
280 body,
281 )),
282 }
283 }
284}
285
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Keep test output quiet unless RUST_LOG is set (and is valid Unicode).
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}