1pub mod assistant;
2mod assistant_settings;
3mod codegen;
4mod streaming_diff;
5
6use anyhow::{anyhow, Result};
7pub use assistant::AssistantPanel;
8use assistant_settings::OpenAIModel;
9use chrono::{DateTime, Local};
10use collections::HashMap;
11use fs::Fs;
12use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
13use gpui::{executor::Background, AppContext};
14use isahc::{http::StatusCode, Request, RequestExt};
15use regex::Regex;
16use serde::{Deserialize, Serialize};
17use std::{
18 cmp::Reverse,
19 ffi::OsStr,
20 fmt::{self, Display},
21 io,
22 path::PathBuf,
23 sync::Arc,
24};
25use util::paths::CONVERSATIONS_DIR;
26
/// Base URL of the OpenAI REST API; endpoint paths such as `/chat/completions` are appended.
const OPENAI_API_URL: &str = "https://api.openai.com/v1";
28
29// Data types for chat completion requests
30#[derive(Debug, Default, Serialize)]
31pub struct OpenAIRequest {
32 model: String,
33 messages: Vec<RequestMessage>,
34 stream: bool,
35}
36
37#[derive(
38 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
39)]
40struct MessageId(usize);
41
42#[derive(Clone, Debug, Serialize, Deserialize)]
43struct MessageMetadata {
44 role: Role,
45 sent_at: DateTime<Local>,
46 status: MessageStatus,
47}
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
50enum MessageStatus {
51 Pending,
52 Done,
53 Error(Arc<str>),
54}
55
56#[derive(Serialize, Deserialize)]
57struct SavedMessage {
58 id: MessageId,
59 start: usize,
60}
61
62#[derive(Serialize, Deserialize)]
63struct SavedConversation {
64 id: Option<String>,
65 zed: String,
66 version: String,
67 text: String,
68 messages: Vec<SavedMessage>,
69 message_metadata: HashMap<MessageId, MessageMetadata>,
70 summary: String,
71 model: OpenAIModel,
72}
73
74impl SavedConversation {
75 const VERSION: &'static str = "0.1.0";
76}
77
78struct SavedConversationMetadata {
79 title: String,
80 path: PathBuf,
81 mtime: chrono::DateTime<chrono::Local>,
82}
83
84impl SavedConversationMetadata {
85 pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
86 fs.create_dir(&CONVERSATIONS_DIR).await?;
87
88 let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
89 let mut conversations = Vec::<SavedConversationMetadata>::new();
90 while let Some(path) = paths.next().await {
91 let path = path?;
92 if path.extension() != Some(OsStr::new("json")) {
93 continue;
94 }
95
96 let pattern = r" - \d+.zed.json$";
97 let re = Regex::new(pattern).unwrap();
98
99 let metadata = fs.metadata(&path).await?;
100 if let Some((file_name, metadata)) = path
101 .file_name()
102 .and_then(|name| name.to_str())
103 .zip(metadata)
104 {
105 let title = re.replace(file_name, "");
106 conversations.push(Self {
107 title: title.into_owned(),
108 path,
109 mtime: metadata.mtime.into(),
110 });
111 }
112 }
113 conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
114
115 Ok(conversations)
116 }
117}
118
119#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
120struct RequestMessage {
121 role: Role,
122 content: String,
123}
124
125#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
126pub struct ResponseMessage {
127 role: Option<Role>,
128 content: Option<String>,
129}
130
131#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
132#[serde(rename_all = "lowercase")]
133enum Role {
134 User,
135 Assistant,
136 System,
137}
138
139impl Role {
140 pub fn cycle(&mut self) {
141 *self = match self {
142 Role::User => Role::Assistant,
143 Role::Assistant => Role::System,
144 Role::System => Role::User,
145 }
146 }
147}
148
149impl Display for Role {
150 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 Role::User => write!(f, "User"),
153 Role::Assistant => write!(f, "Assistant"),
154 Role::System => write!(f, "System"),
155 }
156 }
157}
158
159#[derive(Deserialize, Debug)]
160pub struct OpenAIResponseStreamEvent {
161 pub id: Option<String>,
162 pub object: String,
163 pub created: u32,
164 pub model: String,
165 pub choices: Vec<ChatChoiceDelta>,
166 pub usage: Option<Usage>,
167}
168
169#[derive(Deserialize, Debug)]
170pub struct Usage {
171 pub prompt_tokens: u32,
172 pub completion_tokens: u32,
173 pub total_tokens: u32,
174}
175
176#[derive(Deserialize, Debug)]
177pub struct ChatChoiceDelta {
178 pub index: u32,
179 pub delta: ResponseMessage,
180 pub finish_reason: Option<String>,
181}
182
183#[derive(Deserialize, Debug)]
184struct OpenAIUsage {
185 prompt_tokens: u64,
186 completion_tokens: u64,
187 total_tokens: u64,
188}
189
190#[derive(Deserialize, Debug)]
191struct OpenAIChoice {
192 text: String,
193 index: u32,
194 logprobs: Option<serde_json::Value>,
195 finish_reason: Option<String>,
196}
197
198pub fn init(cx: &mut AppContext) {
199 assistant::init(cx);
200}
201
202pub async fn stream_completion(
203 api_key: String,
204 executor: Arc<Background>,
205 mut request: OpenAIRequest,
206) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
207 request.stream = true;
208
209 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
210
211 let json_data = serde_json::to_string(&request)?;
212 let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
213 .header("Content-Type", "application/json")
214 .header("Authorization", format!("Bearer {}", api_key))
215 .body(json_data)?
216 .send_async()
217 .await?;
218
219 let status = response.status();
220 if status == StatusCode::OK {
221 executor
222 .spawn(async move {
223 let mut lines = BufReader::new(response.body_mut()).lines();
224
225 fn parse_line(
226 line: Result<String, io::Error>,
227 ) -> Result<Option<OpenAIResponseStreamEvent>> {
228 if let Some(data) = line?.strip_prefix("data: ") {
229 let event = serde_json::from_str(&data)?;
230 Ok(Some(event))
231 } else {
232 Ok(None)
233 }
234 }
235
236 while let Some(line) = lines.next().await {
237 if let Some(event) = parse_line(line).transpose() {
238 let done = event.as_ref().map_or(false, |event| {
239 event
240 .choices
241 .last()
242 .map_or(false, |choice| choice.finish_reason.is_some())
243 });
244 if tx.unbounded_send(event).is_err() {
245 break;
246 }
247
248 if done {
249 break;
250 }
251 }
252 }
253
254 anyhow::Ok(())
255 })
256 .detach();
257
258 Ok(rx)
259 } else {
260 let mut body = String::new();
261 response.body_mut().read_to_string(&mut body).await?;
262
263 #[derive(Deserialize)]
264 struct OpenAIResponse {
265 error: OpenAIError,
266 }
267
268 #[derive(Deserialize)]
269 struct OpenAIError {
270 message: String,
271 }
272
273 match serde_json::from_str::<OpenAIResponse>(&body) {
274 Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
275 "Failed to connect to OpenAI API: {}",
276 response.error.message,
277 )),
278
279 _ => Err(anyhow!(
280 "Failed to connect to OpenAI API: {} {}",
281 response.status(),
282 body,
283 )),
284 }
285 }
286}
287
/// Test-only constructor that installs `env_logger`, but only when `RUST_LOG` is set to a
/// readable (Unicode) value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}