Detailed changes
diff --git a/Cargo.lock b/Cargo.lock
@@ -86,6 +86,20 @@ dependencies = [
"memchr",
]
+[[package]]
+name = "ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "ctor",
+ "futures 0.3.28",
+ "gpui",
+ "isahc",
+ "regex",
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "alacritty_config"
version = "0.1.2-dev"
@@ -272,6 +286,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
name = "assistant"
version = "0.1.0"
dependencies = [
+ "ai",
"anyhow",
"chrono",
"client",
diff --git a/Cargo.toml b/Cargo.toml
@@ -1,6 +1,7 @@
[workspace]
members = [
"crates/activity_indicator",
+ "crates/ai",
"crates/assistant",
"crates/audio",
"crates/auto_update",
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[dependencies]
+gpui = { path = "../gpui" }
+anyhow.workspace = true
+futures.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+
+[dev-dependencies]
+ctor.workspace = true
diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
@@ -0,0 +1 @@
+pub mod completion;
diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
@@ -0,0 +1,212 @@
+use anyhow::{anyhow, Result};
+use futures::{
+ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+ Stream, StreamExt,
+};
+use gpui::executor::Background;
+use isahc::{http::StatusCode, Request, RequestExt};
+use serde::{Deserialize, Serialize};
+use std::{
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+
+pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
+
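+/// The author of a chat message; serialized in lowercase to match the
+/// OpenAI chat API's `role` field.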
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
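+ /// Rotate to the next role: User → Assistant → System → User.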
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "User"),
+ Role::Assistant => write!(f, "Assistant"),
+ Role::System => write!(f, "System"),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
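+/// Request body for the OpenAI `/chat/completions` endpoint.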
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+ pub model: String,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+ pub id: Option<String>,
+ pub object: String,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChatChoiceDelta>,
+ pub usage: Option<OpenAIUsage>,
+}
+
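+/// Starts a streaming chat completion and returns a stream of decoded events.
+/// Forces `stream: true` on the request, POSTs it to `/chat/completions`, and
+/// spawns a background task that parses each `data: {json}` line from the
+/// response body, forwarding events until a choice reports a `finish_reason`.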
+pub async fn stream_completion(
+ api_key: String,
+ executor: Arc<Background>,
+ mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ request.stream = true;
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = serde_json::to_string(&request)?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
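+ // Server-sent events arrive one per line in the form `data: {json}`;
+ // lines without that prefix are skipped.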
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
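+ // Forward each event to the receiver, stopping once a choice reports
+ // a `finish_reason` or the receiving end has been dropped.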
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
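+ // Non-OK status: read the whole body and surface OpenAI's error
+ // message when the body contains one.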
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
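+/// Abstraction over a streaming completion backend: implementors turn a
+/// request into a stream of text chunks.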
+pub trait CompletionProvider {
+ fn complete(
+ &self,
+ prompt: OpenAIRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+}
+
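+/// A `CompletionProvider` that calls the OpenAI chat API through
+/// `stream_completion`.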
+pub struct OpenAICompletionProvider {
+ api_key: String,
+ executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(api_key: String, executor: Arc<Background>) -> Self {
+ Self { api_key, executor }
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn complete(
+ &self,
+ prompt: OpenAIRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
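+ // Keep only the text deltas; events with no choices or no content
+ // (e.g. role-only chunks) produce nothing.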
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+}
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
@@ -9,6 +9,7 @@ path = "src/assistant.rs"
doctest = false
[dependencies]
+ai = { path = "../ai" }
client = { path = "../client" }
collections = { path = "../collections"}
editor = { path = "../editor" }
diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
@@ -3,37 +3,20 @@ mod assistant_settings;
mod codegen;
mod streaming_diff;
-use anyhow::{anyhow, Result};
+use ai::completion::Role;
+use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
use chrono::{DateTime, Local};
use collections::HashMap;
use fs::Fs;
-use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
-use gpui::{executor::Background, AppContext};
-use isahc::{http::StatusCode, Request, RequestExt};
+use futures::StreamExt;
+use gpui::AppContext;
use regex::Regex;
use serde::{Deserialize, Serialize};
-use std::{
- cmp::Reverse,
- ffi::OsStr,
- fmt::{self, Display},
- io,
- path::PathBuf,
- sync::Arc,
-};
+use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
use util::paths::CONVERSATIONS_DIR;
-const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
-// Data types for chat completion requests
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
- model: String,
- messages: Vec<RequestMessage>,
- stream: bool,
-}
-
#[derive(
Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
@@ -116,175 +99,10 @@ impl SavedConversationMetadata {
}
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-struct RequestMessage {
- role: Role,
- content: String,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
- role: Option<Role>,
- content: Option<String>,
-}
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-enum Role {
- User,
- Assistant,
- System,
-}
-
-impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "User"),
- Role::Assistant => write!(f, "Assistant"),
- Role::System => write!(f, "System"),
- }
- }
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
- pub id: Option<String>,
- pub object: String,
- pub created: u32,
- pub model: String,
- pub choices: Vec<ChatChoiceDelta>,
- pub usage: Option<Usage>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct Usage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
- pub index: u32,
- pub delta: ResponseMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-struct OpenAIUsage {
- prompt_tokens: u64,
- completion_tokens: u64,
- total_tokens: u64,
-}
-
-#[derive(Deserialize, Debug)]
-struct OpenAIChoice {
- text: String,
- index: u32,
- logprobs: Option<serde_json::Value>,
- finish_reason: Option<String>,
-}
-
pub fn init(cx: &mut AppContext) {
assistant_panel::init(cx);
}
-pub async fn stream_completion(
- api_key: String,
- executor: Arc<Background>,
- mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
- request.stream = true;
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
- let json_data = serde_json::to_string(&request)?;
- let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAIResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(&data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAIResponse {
- error: OpenAIError,
- }
-
- #[derive(Deserialize)]
- struct OpenAIError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAIResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
@@ -1,8 +1,11 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
- codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
- stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
- Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
+ codegen::{self, Codegen, CodegenKind},
+ MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
+ SavedMessage,
+};
+use ai::completion::{
+ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
@@ -1,59 +1,14 @@
-use crate::{
- stream_completion,
- streaming_diff::{Hunk, StreamingDiff},
- OpenAIRequest,
-};
+use crate::streaming_diff::{Hunk, StreamingDiff};
+use ai::completion::{CompletionProvider, OpenAIRequest};
use anyhow::Result;
use editor::{
multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
-use futures::{
- channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
-};
-use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
+use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
+use gpui::{Entity, ModelContext, ModelHandle, Task};
use language::{Rope, TransactionId};
use std::{cmp, future, ops::Range, sync::Arc};
-pub trait CompletionProvider {
- fn complete(
- &self,
- prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
-}
-
-pub struct OpenAICompletionProvider {
- api_key: String,
- executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
- pub fn new(api_key: String, executor: Arc<Background>) -> Self {
- Self { api_key, executor }
- }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
- fn complete(
- &self,
- prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
- async move {
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
- }
-}
-
pub enum Event {
Finished,
Undone,
@@ -397,13 +352,17 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use futures::stream;
+ use futures::{
+ future::BoxFuture,
+ stream::{self, BoxStream},
+ };
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
use parking_lot::Mutex;
use rand::prelude::*;
use settings::SettingsStore;
+ use smol::future::FutureExt;
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs
@@ -5,9 +5,9 @@ pub mod only_instance;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
-use assistant::AssistantPanel;
use anyhow::Context;
use assets::Assets;
+use assistant::AssistantPanel;
use breadcrumbs::Breadcrumbs;
pub use client;
use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
@@ -2418,7 +2418,7 @@ mod tests {
pane::init(cx);
project_panel::init((), cx);
terminal_view::init(cx);
- ai::init(cx);
+ assistant::init(cx);
app_state
})
}