From d6bdaa8a9141e181ec91ffb634cccae03e46ba08 Mon Sep 17 00:00:00 2001
From: Antonio Scandurra
Date: Sun, 28 Jul 2024 11:07:10 +0200
Subject: [PATCH] Simplify LLM protocol (#15366)

In this pull request, we change the zed.dev protocol so that we pass the raw JSON for the specified provider directly to our server. This avoids the need to define a protobuf message that's a superset of all these formats.

@bennetbo: We also changed the settings for available_models under zed.dev to be a flat format, because the nesting seemed too confusing. Can you help us upgrade the local provider configuration to be consistent with this? We do whatever we need to do when parsing the settings to make this simple for users, even if it's a bit more complex on our end. We want to use versioning to avoid breaking existing users, but need to keep making progress.

```json
"zed.dev": {
  "available_models": [
    {
      "provider": "anthropic",
      "name": "some-newly-released-model-we-havent-added",
      "max_tokens": 200000
    }
  ]
}
```

Release Notes:

- N/A

---------

Co-authored-by: Nathan
---
 Cargo.lock | 33 +-
 Cargo.toml | 2 -
 assets/settings/default.json | 3 +
 crates/anthropic/src/anthropic.rs | 33 +-
 crates/assistant/src/assistant_settings.rs | 4 +-
 crates/assistant_tooling/Cargo.toml | 33 --
 crates/assistant_tooling/LICENSE-GPL | 1 -
 crates/assistant_tooling/README.md | 85 ---
 .../src/assistant_tooling.rs | 13 -
 .../src/attachment_registry.rs | 234 --------
 .../assistant_tooling/src/project_context.rs | 296 ----------
 crates/assistant_tooling/src/tool_registry.rs | 526 ------------------
 crates/collab/src/ai.rs | 138 -----
 crates/collab/src/lib.rs | 1 -
 crates/collab/src/rpc.rs | 394 +++----
 crates/google_ai/Cargo.toml | 5 +
 crates/google_ai/src/google_ai.rs | 110 +++-
 crates/language_model/Cargo.toml | 1 +
 .../language_model/src/model/cloud_model.rs | 104 +---
 crates/language_model/src/provider.rs | 1 +
 .../language_model/src/provider/anthropic.rs | 105 +---
 crates/language_model/src/provider/cloud.rs | 190 ++++---
 crates/language_model/src/provider/google.rs | 351 ++++++++++++
 crates/language_model/src/provider/open_ai.rs | 43 +-
 crates/language_model/src/registry.rs | 18 +-
 crates/language_model/src/request.rs | 121 +++-
 crates/language_model/src/role.rs | 1 -
 crates/language_model/src/settings.rs | 43 +-
 crates/open_ai/src/open_ai.rs | 42 +-
 crates/proto/proto/zed.proto | 106 +---
 crates/proto/src/proto.rs | 9 +-
 31 files changed, 894 insertions(+), 2152 deletions(-)
 delete mode 100644 crates/assistant_tooling/Cargo.toml
 delete mode 120000 crates/assistant_tooling/LICENSE-GPL
 delete mode 100644 crates/assistant_tooling/README.md
 delete mode 100644 crates/assistant_tooling/src/assistant_tooling.rs
 delete mode 100644 crates/assistant_tooling/src/attachment_registry.rs
 delete mode 100644 crates/assistant_tooling/src/project_context.rs
 delete mode 100644 crates/assistant_tooling/src/tool_registry.rs
 delete mode 100644 crates/collab/src/ai.rs
 create mode 100644 crates/language_model/src/provider/google.rs

diff --git a/Cargo.lock b/Cargo.lock
index 92dd5d9a8e33689234e09d93508a12a9736b3243..2876ec86a49da30e281166bc644032c8d7898cb9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -471,27 +471,6 @@ dependencies = [
  "workspace",
 ]
 
-[[package]]
-name = "assistant_tooling"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "collections",
- "futures 0.3.28",
- "gpui",
- "log",
- "project",
- "repair_json",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "sum_tree",
- "ui",
- "unindent",
- "util",
-]
-
 [[package]]
 name = 
"async-attributes" version = "1.1.2" @@ -4811,8 +4790,10 @@ dependencies = [ "anyhow", "futures 0.3.28", "http_client", + "schemars", "serde", "serde_json", + "strum", ] [[package]] @@ -5988,6 +5969,7 @@ dependencies = [ "env_logger", "feature_flags", "futures 0.3.28", + "google_ai", "gpui", "http_client", "language", @@ -8715,15 +8697,6 @@ dependencies = [ "bytecheck", ] -[[package]] -name = "repair_json" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15" -dependencies = [ - "thiserror", -] - [[package]] name = "repl" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index b289d083bb35733b99ff7b9e78772db81d31aa9f..19a6b6b83662a3ca71754ef8532030e1e71e5c39 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ members = [ "crates/assets", "crates/assistant", "crates/assistant_slash_command", - "crates/assistant_tooling", "crates/audio", "crates/auto_update", "crates/breadcrumbs", @@ -178,7 +177,6 @@ anthropic = { path = "crates/anthropic" } assets = { path = "crates/assets" } assistant = { path = "crates/assistant" } assistant_slash_command = { path = "crates/assistant_slash_command" } -assistant_tooling = { path = "crates/assistant_tooling" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } breadcrumbs = { path = "crates/breadcrumbs" } diff --git a/assets/settings/default.json b/assets/settings/default.json index 529b91b7cdd5f84b30f0aff667a613ff4700ace7..a26c7d27a07d1953e3976e5fe4231a5e5517ab60 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -870,6 +870,9 @@ "openai": { "api_url": "https://api.openai.com/v1" }, + "google": { + "api_url": "https://generativelanguage.googleapis.com" + }, "ollama": { "api_url": "http://localhost:11434" } diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 2d9bd311b89b21a56b8b2cdbd602597b148ec7f9..45a4dfc0d34646b4361111af09e1da8bb42a9320 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -98,7 +98,7 @@ impl From for String { } } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, pub messages: Vec, @@ -113,7 +113,7 @@ pub struct RequestMessage { pub content: String, } -#[derive(Deserialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ResponseEvent { MessageStart { @@ -138,7 +138,7 @@ pub enum ResponseEvent { MessageStop {}, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseMessage { #[serde(rename = "type")] pub message_type: Option, @@ -151,19 +151,19 @@ pub struct ResponseMessage { pub usage: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub input_tokens: Option, pub output_tokens: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ContentBlock { Text { text: String }, } -#[derive(Deserialize, Debug)] 
+#[derive(Serialize, Deserialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum TextDelta { TextDelta { text: String }, @@ -226,6 +226,25 @@ pub async fn stream_completion( } } +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(response) => match response { + ResponseEvent::ContentBlockStart { content_block, .. } => match content_block { + ContentBlock::Text { text } => Some(Ok(text)), + }, + ResponseEvent::ContentBlockDelta { delta, .. } => match delta { + TextDelta::TextDelta { text } => Some(Ok(text)), + }, + _ => None, + }, + Err(error) => Some(Err(error)), + } + }) +} + // #[cfg(test)] // mod tests { // use super::*; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 05c5b56f1ce476f285b4fd1051d87f6d3da7f4b2..0d4dbd68240353961d84a65e3845e8801ecb0f24 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -249,9 +249,7 @@ impl AssistantSettingsContent { AssistantSettingsContent::Versioned(settings) => match settings { VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() { "zed.dev" => { - settings.provider = Some(AssistantProviderContentV1::ZedDotDev { - default_model: CloudModel::from_id(&model).ok(), - }); + log::warn!("attempted to set zed.dev model on outdated settings"); } "anthropic" => { let (api_url, low_speed_timeout_in_seconds) = match &settings.provider { diff --git a/crates/assistant_tooling/Cargo.toml b/crates/assistant_tooling/Cargo.toml deleted file mode 100644 index 79f41faad279b0366beea25eebad9edfc549ff52..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "assistant_tooling" -version = "0.1.0" -edition = "2021" -publish = false -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/assistant_tooling.rs" - -[dependencies] -anyhow.workspace = true -collections.workspace = true -futures.workspace = true -gpui.workspace = true -log.workspace = true -project.workspace = true -repair_json.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -sum_tree.workspace = true -ui.workspace = true -util.workspace = true - -[dev-dependencies] -gpui = { workspace = true, features = ["test-support"] } -project = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"] } -unindent.workspace = true diff --git a/crates/assistant_tooling/LICENSE-GPL b/crates/assistant_tooling/LICENSE-GPL deleted file mode 120000 index 89e542f750cd3860a0598eff0dc34b56d7336dc4..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant_tooling/README.md b/crates/assistant_tooling/README.md deleted file mode 100644 index 160869ae974c97871e167018d786b8e454908fae..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/README.md +++ /dev/null @@ -1,85 +0,0 @@ -# Assistant Tooling - -Bringing Language Model tool calling to GPUI. - -This unlocks: - -- **Structured Extraction** of model responses -- **Validation** of model inputs -- **Execution** of chosen tools - -## Overview - -Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. 
When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call. - -> **User**: "Hey I need help with implementing a collapsible panel in GPUI" -> -> **Assistant**: "Sure, I can help with that. Let me see what I can find." -> -> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]` -> -> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"` -> -> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you." - -This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`. - -## Using the Tool Registry - -```rust -let mut tool_registry = ToolRegistry::new(); -tool_registry - .register(WeatherTool { api_client }, - }) - .unwrap(); // You can only register one tool per name - -let completion = cx.update(|cx| { - CompletionProvider::get(cx).complete( - model_name, - messages, - Vec::new(), - 1.0, - // The definitions get passed directly to OpenAI when you want - // the model to be able to call your tool - tool_registry.definitions(), - ) -}); - -let mut stream = completion?.await?; - -let mut message = AssistantMessage::new(); - -while let Some(delta) = stream.next().await { - // As messages stream in, you'll get both assistant content - if let Some(content) = &delta.content { - message - .body - .update(cx, |message, cx| message.append(&content, cx)); - } - - // And tool calls! - for tool_call_delta in delta.tool_calls { - let index = tool_call_delta.index as usize; - if index >= message.tool_calls.len() { - message.tool_calls.resize_with(index + 1, Default::default); - } - let tool_call = &mut message.tool_calls[index]; - - // Build up an ID - if let Some(id) = &tool_call_delta.id { - tool_call.id.push_str(id); - } - - tool_registry.update_tool_call( - tool_call, - tool_call_delta.name.as_deref(), - tool_call_delta.arguments.as_deref(), - cx, - ); - } -} -``` - -Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task>`. - -As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`. 
diff --git a/crates/assistant_tooling/src/assistant_tooling.rs b/crates/assistant_tooling/src/assistant_tooling.rs deleted file mode 100644 index 9dcf2908e92c54f8cdfe0e0140f037334d3d0bc5..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/src/assistant_tooling.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod attachment_registry; -mod project_context; -mod tool_registry; - -pub use attachment_registry::{ - AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, - UserAttachment, -}; -pub use project_context::ProjectContext; -pub use tool_registry::{ - LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, - ToolRegistry, ToolView, -}; diff --git a/crates/assistant_tooling/src/attachment_registry.rs b/crates/assistant_tooling/src/attachment_registry.rs deleted file mode 100644 index e8b52d26f08d86968393f425039424f3d9a6365c..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/src/attachment_registry.rs +++ /dev/null @@ -1,234 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use collections::HashMap; -use futures::future::join_all; -use gpui::{AnyView, Render, Task, View, WindowContext}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - sync::{ - atomic::{AtomicBool, Ordering::SeqCst}, - Arc, - }, -}; -use util::ResultExt as _; - -pub struct AttachmentRegistry { - registered_attachments: HashMap, -} - -pub trait AttachmentOutput { - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; -} - -pub trait LanguageModelAttachment { - type Output: DeserializeOwned + Serialize + 'static; - type View: Render + AttachmentOutput; - - fn name(&self) -> Arc; - fn run(&self, cx: &mut WindowContext) -> Task>; - fn view(&self, output: Result, cx: &mut WindowContext) -> View; -} - -/// A collected attachment from running an attachment tool -pub struct UserAttachment { - pub view: AnyView, - name: Arc, - serialized_output: Result, String>, - generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String, -} - -#[derive(Serialize, Deserialize)] -pub struct SavedUserAttachment { - name: Arc, - serialized_output: Result, String>, -} - -/// Internal representation of an attachment tool to allow us to treat them dynamically -struct RegisteredAttachment { - name: Arc, - enabled: AtomicBool, - call: Box Task>>, - deserialize: Box Result>, -} - -impl AttachmentRegistry { - pub fn new() -> Self { - Self { - registered_attachments: HashMap::default(), - } - } - - pub fn register(&mut self, attachment: A) { - let attachment = Arc::new(attachment); - - let call = Box::new({ - let attachment = attachment.clone(); - move |cx: &mut WindowContext| { - let result = attachment.run(cx); - let attachment = attachment.clone(); - cx.spawn(move |mut cx| async move { - let result: Result = result.await; - let serialized_output = - result - .as_ref() - .map_err(ToString::to_string) - .and_then(|output| { - Ok(RawValue::from_string( - serde_json::to_string(output).map_err(|e| e.to_string())?, - ) - .unwrap()) - }); - - let view = cx.update(|cx| attachment.view(result, cx))?; - - Ok(UserAttachment { - name: attachment.name(), - view: view.into(), - generate_fn: generate::, - serialized_output, - }) - }) - } - }); - - let deserialize = Box::new({ - let attachment = attachment.clone(); - move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| { - let serialized_output = 
saved_attachment.serialized_output.clone(); - let output = match &serialized_output { - Ok(serialized_output) => { - Ok(serde_json::from_str::(serialized_output.get())?) - } - Err(error) => Err(anyhow!("{error}")), - }; - let view = attachment.view(output, cx).into(); - - Ok(UserAttachment { - name: saved_attachment.name.clone(), - view, - serialized_output, - generate_fn: generate::, - }) - } - }); - - self.registered_attachments.insert( - TypeId::of::(), - RegisteredAttachment { - name: attachment.name(), - call, - deserialize, - enabled: AtomicBool::new(true), - }, - ); - return; - - fn generate( - view: AnyView, - project: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - view.downcast::() - .unwrap() - .update(cx, |view, cx| T::View::generate(view, project, cx)) - } - } - - pub fn set_attachment_tool_enabled( - &self, - is_enabled: bool, - ) { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.store(is_enabled, SeqCst); - } - } - - pub fn is_attachment_tool_enabled(&self) -> bool { - if let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) { - attachment.enabled.load(SeqCst) - } else { - false - } - } - - pub fn call( - &self, - cx: &mut WindowContext, - ) -> Task> { - let Some(attachment) = self.registered_attachments.get(&TypeId::of::()) else { - return Task::ready(Err(anyhow!("no attachment tool"))); - }; - - (attachment.call)(cx) - } - - pub fn call_all_attachment_tools( - self: Arc, - cx: &mut WindowContext<'_>, - ) -> Task>> { - let this = self.clone(); - cx.spawn(|mut cx| async move { - let attachment_tasks = cx.update(|cx| { - let mut tasks = Vec::new(); - for attachment in this - .registered_attachments - .values() - .filter(|attachment| attachment.enabled.load(SeqCst)) - { - tasks.push((attachment.call)(cx)) - } - - tasks - })?; - - let attachments = join_all(attachment_tasks.into_iter()).await; - - Ok(attachments - .into_iter() - .filter_map(|attachment| attachment.log_err()) - .collect()) - }) - } - - pub fn serialize_user_attachment( - &self, - user_attachment: &UserAttachment, - ) -> SavedUserAttachment { - SavedUserAttachment { - name: user_attachment.name.clone(), - serialized_output: user_attachment.serialized_output.clone(), - } - } - - pub fn deserialize_user_attachment( - &self, - saved_user_attachment: SavedUserAttachment, - cx: &mut WindowContext, - ) -> Result { - if let Some(registered_attachment) = self - .registered_attachments - .values() - .find(|attachment| attachment.name == saved_user_attachment.name) - { - (registered_attachment.deserialize)(&saved_user_attachment, cx) - } else { - Err(anyhow!( - "no attachment tool for name {}", - saved_user_attachment.name - )) - } - } -} - -impl UserAttachment { - pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option { - let result = (self.generate_fn)(self.view.clone(), output, cx); - if result.is_empty() { - None - } else { - Some(result) - } - } -} diff --git a/crates/assistant_tooling/src/project_context.rs b/crates/assistant_tooling/src/project_context.rs deleted file mode 100644 index 2640ce1ed556acb83d6d0a4e1dd3d98ad6593677..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/src/project_context.rs +++ /dev/null @@ -1,296 +0,0 @@ -use anyhow::{anyhow, Result}; -use gpui::{AppContext, Model, Task, WeakModel}; -use project::{Fs, Project, ProjectPath, Worktree}; -use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc}; -use sum_tree::TreeMap; - -pub struct ProjectContext { 
- files: TreeMap, - project: WeakModel, - fs: Arc, -} - -#[derive(Debug, Clone)] -enum PathState { - PathOnly, - EntireFile, - Excerpts { ranges: Vec> }, -} - -impl ProjectContext { - pub fn new(project: WeakModel, fs: Arc) -> Self { - Self { - files: TreeMap::default(), - fs, - project, - } - } - - pub fn add_path(&mut self, project_path: ProjectPath) { - if self.files.get(&project_path).is_none() { - self.files.insert(project_path, PathState::PathOnly); - } - } - - pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range]) { - let previous_state = self - .files - .get(&project_path) - .unwrap_or(&PathState::PathOnly); - - let mut ranges = match previous_state { - PathState::EntireFile => return, - PathState::PathOnly => Vec::new(), - PathState::Excerpts { ranges } => ranges.to_vec(), - }; - - for new_range in new_ranges { - let ix = ranges.binary_search_by(|probe| { - if probe.end < new_range.start { - Ordering::Less - } else if probe.start > new_range.end { - Ordering::Greater - } else { - Ordering::Equal - } - }); - - match ix { - Ok(mut ix) => { - let existing = &mut ranges[ix]; - existing.start = existing.start.min(new_range.start); - existing.end = existing.end.max(new_range.end); - while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end { - ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end); - ranges.remove(ix + 1); - } - while ix > 0 && ranges[ix - 1].end >= ranges[ix].start { - ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start); - ranges.remove(ix - 1); - ix -= 1; - } - } - Err(ix) => { - ranges.insert(ix, new_range.clone()); - } - } - } - - self.files - .insert(project_path, PathState::Excerpts { ranges }); - } - - pub fn add_file(&mut self, project_path: ProjectPath) { - self.files.insert(project_path, PathState::EntireFile); - } - - pub fn generate_system_message(&self, cx: &mut AppContext) -> Task> { - let project = self - .project - .upgrade() - .ok_or_else(|| anyhow!("project dropped")); - let files = self.files.clone(); - let fs = self.fs.clone(); - cx.spawn(|cx| async move { - let project = project?; - let mut result = "project structure:\n".to_string(); - - let mut last_worktree: Option> = None; - for (project_path, path_state) in files.iter() { - if let Some(worktree) = &last_worktree { - if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id { - last_worktree = None; - } - } - - let worktree; - if let Some(last_worktree) = &last_worktree { - worktree = last_worktree.clone(); - } else if let Some(tree) = project.read_with(&cx, |project, cx| { - project.worktree_for_id(project_path.worktree_id, cx) - })? 
{ - worktree = tree; - last_worktree = Some(worktree.clone()); - let worktree_name = - worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?; - writeln!(&mut result, "# {}", worktree_name).unwrap(); - } else { - continue; - } - - let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?; - let path = &project_path.path; - writeln!(&mut result, "## {}", path.display()).unwrap(); - - match path_state { - PathState::PathOnly => {} - PathState::EntireFile => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - writeln!(&mut result, "~~~\n{text}\n~~~").unwrap(); - } - PathState::Excerpts { ranges } => { - let text = fs.load(&worktree_abs_path.join(&path)).await?; - - writeln!(&mut result, "~~~").unwrap(); - - // Assumption: ranges are in order, not overlapping - let mut prev_range_end = 0; - for range in ranges { - if range.start > prev_range_end { - writeln!(&mut result, "...").unwrap(); - prev_range_end = range.end; - } - - let mut start = range.start; - let mut end = range.end.min(text.len()); - while !text.is_char_boundary(start) { - start += 1; - } - while !text.is_char_boundary(end) { - end -= 1; - } - result.push_str(&text[start..end]); - if !result.ends_with('\n') { - result.push('\n'); - } - } - - if prev_range_end < text.len() { - writeln!(&mut result, "...").unwrap(); - } - - writeln!(&mut result, "~~~").unwrap(); - } - } - } - Ok(result) - }) - } -} - -#[cfg(test)] -mod tests { - use std::path::Path; - - use super::*; - use gpui::TestAppContext; - use project::FakeFs; - use serde_json::json; - use settings::SettingsStore; - - use unindent::Unindent as _; - - #[gpui::test] - async fn test_system_message_generation(cx: &mut TestAppContext) { - init_test(cx); - - let file_3_contents = r#" - fn test1() {} - fn test2() {} - fn test3() {} - "# - .unindent(); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - "/code", - json!({ - "root1": { - "lib": { - "file1.rs": "mod example;", - "file2.rs": "", - }, - "test": { - "file3.rs": file_3_contents, - } - }, - "root2": { - "src": { - "main.rs": "" - } - } - }), - ) - .await; - - let project = Project::test( - fs.clone(), - ["/code/root1".as_ref(), "/code/root2".as_ref()], - cx, - ) - .await; - - let worktree_ids = project.read_with(cx, |project, cx| { - project - .worktrees(cx) - .map(|worktree| worktree.read(cx).id()) - .collect::>() - }); - - let mut ax = ProjectContext::new(project.downgrade(), fs); - - ax.add_file(ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("lib/file1.rs").into(), - }); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - "# - .unindent(), - message - ); - - ax.add_excerpts( - ProjectPath { - worktree_id: worktree_ids[0], - path: Path::new("test/file3.rs").into(), - }, - &[ - file_3_contents.find("fn test2").unwrap() - ..file_3_contents.find("fn test3").unwrap(), - ], - ); - - let message = cx - .update(|cx| ax.generate_system_message(cx)) - .await - .unwrap(); - assert_eq!( - r#" - project structure: - # root1 - ## lib/file1.rs - ~~~ - mod example; - ~~~ - ## test/file3.rs - ~~~ - ... - fn test2() {} - ... 
- ~~~ - "# - .unindent(), - message - ); - } - - fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - Project::init_settings(cx); - }); - } -} diff --git a/crates/assistant_tooling/src/tool_registry.rs b/crates/assistant_tooling/src/tool_registry.rs deleted file mode 100644 index e5f8914eb57c4ca5d65447990b677bdf1fe5c944..0000000000000000000000000000000000000000 --- a/crates/assistant_tooling/src/tool_registry.rs +++ /dev/null @@ -1,526 +0,0 @@ -use crate::ProjectContext; -use anyhow::{anyhow, Result}; -use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext}; -use repair_json::repair; -use schemars::{schema::RootSchema, schema_for, JsonSchema}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_json::value::RawValue; -use std::{ - any::TypeId, - collections::HashMap, - fmt::Display, - mem, - sync::atomic::{AtomicBool, Ordering::SeqCst}, -}; -use ui::ViewContext; - -pub struct ToolRegistry { - registered_tools: HashMap, -} - -#[derive(Default)] -pub struct ToolFunctionCall { - pub id: String, - pub name: String, - pub arguments: String, - state: ToolFunctionCallState, -} - -#[derive(Default)] -enum ToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool(Box), - ExecutedTool(Box), -} - -trait InternalToolView { - fn view(&self) -> AnyView; - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String; - fn try_set_input(&self, input: &str, cx: &mut WindowContext); - fn execute(&self, cx: &mut WindowContext) -> Task>; - fn serialize_output(&self, cx: &mut WindowContext) -> Result>; - fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>; -} - -#[derive(Default, Serialize, Deserialize)] -pub struct SavedToolFunctionCall { - id: String, - name: String, - arguments: String, - state: SavedToolFunctionCallState, -} - -#[derive(Default, Serialize, Deserialize)] -enum SavedToolFunctionCallState { - #[default] - Initializing, - NoSuchTool, - KnownTool, - ExecutedTool(Box), -} - -#[derive(Clone, Debug, PartialEq)] -pub struct ToolFunctionDefinition { - pub name: String, - pub description: String, - pub parameters: RootSchema, -} - -pub trait LanguageModelTool { - type View: ToolView; - - /// Returns the name of the tool. - /// - /// This name is exposed to the language model to allow the model to pick - /// which tools to use. As this name is used to identify the tool within a - /// tool registry, it should be unique. - fn name(&self) -> String; - - /// Returns the description of the tool. - /// - /// This can be used to _prompt_ the model as to what the tool does. - fn description(&self) -> String; - - /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API. - fn definition(&self) -> ToolFunctionDefinition { - let root_schema = schema_for!(::Input); - - ToolFunctionDefinition { - name: self.name(), - description: self.description(), - parameters: root_schema, - } - } - - /// A view of the output of running the tool, for displaying to the user. - fn view(&self, cx: &mut WindowContext) -> View; -} - -pub trait ToolView: Render { - /// The input type that will be passed in to `execute` when the tool is called - /// by the language model. - type Input: DeserializeOwned + JsonSchema; - - /// The output returned by executing the tool. 
- type SerializedState: DeserializeOwned + Serialize; - - fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext) -> String; - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext); - fn execute(&mut self, cx: &mut ViewContext) -> Task>; - - fn serialize(&self, cx: &mut ViewContext) -> Self::SerializedState; - fn deserialize( - &mut self, - output: Self::SerializedState, - cx: &mut ViewContext, - ) -> Result<()>; -} - -struct RegisteredTool { - enabled: AtomicBool, - type_id: TypeId, - build_view: Box Box>, - definition: ToolFunctionDefinition, -} - -impl ToolRegistry { - pub fn new() -> Self { - Self { - registered_tools: HashMap::new(), - } - } - - pub fn set_tool_enabled(&self, is_enabled: bool) { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - tool.enabled.store(is_enabled, SeqCst); - return; - } - } - } - - pub fn is_tool_enabled(&self) -> bool { - for tool in self.registered_tools.values() { - if tool.type_id == TypeId::of::() { - return tool.enabled.load(SeqCst); - } - } - false - } - - pub fn definitions(&self) -> Vec { - self.registered_tools - .values() - .filter(|tool| tool.enabled.load(SeqCst)) - .map(|tool| tool.definition.clone()) - .collect() - } - - pub fn update_tool_call( - &self, - call: &mut ToolFunctionCall, - name: Option<&str>, - arguments: Option<&str>, - cx: &mut WindowContext, - ) { - if let Some(name) = name { - call.name.push_str(name); - } - if let Some(arguments) = arguments { - if call.arguments.is_empty() { - if let Some(tool) = self.registered_tools.get(&call.name) { - let view = (tool.build_view)(cx); - call.state = ToolFunctionCallState::KnownTool(view); - } else { - call.state = ToolFunctionCallState::NoSuchTool; - } - } - call.arguments.push_str(arguments); - - if let ToolFunctionCallState::KnownTool(view) = &call.state { - if let Ok(repaired_arguments) = repair(call.arguments.clone()) { - view.try_set_input(&repaired_arguments, cx) - } - } - } - } - - pub fn execute_tool_call( - &self, - tool_call: &mut ToolFunctionCall, - cx: &mut WindowContext, - ) -> Option>> { - if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) { - let task = view.execute(cx); - tool_call.state = ToolFunctionCallState::ExecutedTool(view); - Some(task) - } else { - None - } - } - - pub fn render_tool_call( - &self, - tool_call: &ToolFunctionCall, - _cx: &mut WindowContext, - ) -> Option { - match &tool_call.state { - ToolFunctionCallState::NoSuchTool => { - Some(ui::Label::new("No such tool").into_any_element()) - } - ToolFunctionCallState::Initializing => None, - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - Some(view.view().into_any_element()) - } - } - } - - pub fn content_for_tool_call( - &self, - tool_call: &ToolFunctionCall, - project_context: &mut ProjectContext, - cx: &mut WindowContext, - ) -> String { - match &tool_call.state { - ToolFunctionCallState::Initializing => String::new(), - ToolFunctionCallState::NoSuchTool => { - format!("No such tool: {}", tool_call.name) - } - ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => { - view.generate(project_context, cx) - } - } - } - - pub fn serialize_tool_call( - &self, - call: &ToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - Ok(SavedToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - ToolFunctionCallState::Initializing => 
SavedToolFunctionCallState::Initializing, - ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool, - ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool, - ToolFunctionCallState::ExecutedTool(view) => { - SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?) - } - }, - }) - } - - pub fn deserialize_tool_call( - &self, - call: &SavedToolFunctionCall, - cx: &mut WindowContext, - ) -> Result { - let Some(tool) = self.registered_tools.get(&call.name) else { - return Err(anyhow!("no such tool {}", call.name)); - }; - - Ok(ToolFunctionCall { - id: call.id.clone(), - name: call.name.clone(), - arguments: call.arguments.clone(), - state: match &call.state { - SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing, - SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool, - SavedToolFunctionCallState::KnownTool => { - log::error!("Deserialized tool that had not executed"); - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - ToolFunctionCallState::KnownTool(view) - } - SavedToolFunctionCallState::ExecutedTool(output) => { - let view = (tool.build_view)(cx); - view.try_set_input(&call.arguments, cx); - view.deserialize_output(output, cx)?; - ToolFunctionCallState::ExecutedTool(view) - } - }, - }) - } - - pub fn register(&mut self, tool: T) -> Result<()> { - let name = tool.name(); - let registered_tool = RegisteredTool { - type_id: TypeId::of::(), - definition: tool.definition(), - enabled: AtomicBool::new(true), - build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))), - }; - - let previous = self.registered_tools.insert(name.clone(), registered_tool); - if previous.is_some() { - return Err(anyhow!("already registered a tool with name {}", name)); - } - - return Ok(()); - } -} - -impl InternalToolView for View { - fn view(&self) -> AnyView { - self.clone().into() - } - - fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String { - self.update(cx, |view, cx| view.generate(project, cx)) - } - - fn try_set_input(&self, input: &str, cx: &mut WindowContext) { - if let Ok(input) = serde_json::from_str::(input) { - self.update(cx, |view, cx| { - view.set_input(input, cx); - cx.notify(); - }); - } - } - - fn execute(&self, cx: &mut WindowContext) -> Task> { - self.update(cx, |view, cx| view.execute(cx)) - } - - fn serialize_output(&self, cx: &mut WindowContext) -> Result> { - let output = self.update(cx, |view, cx| view.serialize(cx)); - Ok(RawValue::from_string(serde_json::to_string(&output)?)?) 
- } - - fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> { - let state = serde_json::from_str::(output.get())?; - self.update(cx, |view, cx| view.deserialize(state, cx))?; - Ok(()) - } -} - -impl Display for ToolFunctionDefinition { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let schema = serde_json::to_string(&self.parameters).ok(); - let schema = schema.unwrap_or("None".to_string()); - write!(f, "Name: {}:\n", self.name)?; - write!(f, "Description: {}\n", self.description)?; - write!(f, "Parameters: {}", schema) - } -} - -#[cfg(test)] -mod test { - use super::*; - use gpui::{div, prelude::*, Render, TestAppContext}; - use gpui::{EmptyView, View}; - use schemars::JsonSchema; - use serde::{Deserialize, Serialize}; - use serde_json::json; - - #[derive(Deserialize, Serialize, JsonSchema)] - struct WeatherQuery { - location: String, - unit: String, - } - - #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)] - struct WeatherResult { - location: String, - temperature: f64, - unit: String, - } - - struct WeatherView { - input: Option, - result: Option, - - // Fake API call - current_weather: WeatherResult, - } - - #[derive(Clone, Serialize)] - struct WeatherTool { - current_weather: WeatherResult, - } - - impl WeatherView { - fn new(current_weather: WeatherResult) -> Self { - Self { - input: None, - result: None, - current_weather, - } - } - } - - impl Render for WeatherView { - fn render(&mut self, _cx: &mut gpui::ViewContext) -> impl IntoElement { - match self.result { - Some(ref result) => div() - .child(format!("temperature: {}", result.temperature)) - .into_any_element(), - None => div().child("Calculating weather...").into_any_element(), - } - } - } - - impl ToolView for WeatherView { - type Input = WeatherQuery; - - type SerializedState = WeatherResult; - - fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext) -> String { - serde_json::to_string(&self.result).unwrap() - } - - fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext) { - self.input = Some(input); - cx.notify(); - } - - fn execute(&mut self, _cx: &mut ViewContext) -> Task> { - let input = self.input.as_ref().unwrap(); - - let _location = input.location.clone(); - let _unit = input.unit.clone(); - - let weather = self.current_weather.clone(); - - self.result = Some(weather); - - Task::ready(Ok(())) - } - - fn serialize(&self, _cx: &mut ViewContext) -> Self::SerializedState { - self.current_weather.clone() - } - - fn deserialize( - &mut self, - output: Self::SerializedState, - _cx: &mut ViewContext, - ) -> Result<()> { - self.current_weather = output; - Ok(()) - } - } - - impl LanguageModelTool for WeatherTool { - type View = WeatherView; - - fn name(&self) -> String { - "get_current_weather".to_string() - } - - fn description(&self) -> String { - "Fetches the current weather for a given location.".to_string() - } - - fn view(&self, cx: &mut WindowContext) -> View { - cx.new_view(|_cx| WeatherView::new(self.current_weather.clone())) - } - } - - #[gpui::test] - async fn test_openai_weather_example(cx: &mut TestAppContext) { - let (_, cx) = cx.add_window_view(|_cx| EmptyView); - - let mut registry = ToolRegistry::new(); - registry - .register(WeatherTool { - current_weather: WeatherResult { - location: "San Francisco".to_string(), - temperature: 21.0, - unit: "Celsius".to_string(), - }, - }) - .unwrap(); - - let definitions = registry.definitions(); - assert_eq!( - definitions, - [ToolFunctionDefinition { - name: 
"get_current_weather".to_string(), - description: "Fetches the current weather for a given location.".to_string(), - parameters: serde_json::from_value(json!({ - "$schema": "http://json-schema.org/draft-07/schema#", - "title": "WeatherQuery", - "type": "object", - "properties": { - "location": { - "type": "string" - }, - "unit": { - "type": "string" - } - }, - "required": ["location", "unit"] - })) - .unwrap(), - }] - ); - - let mut call = ToolFunctionCall { - id: "the-id".to_string(), - name: "get_cur".to_string(), - ..Default::default() - }; - - let task = cx.update(|cx| { - registry.update_tool_call( - &mut call, - Some("rent_weather"), - Some(r#"{"location": "San Francisco","#), - cx, - ); - registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx); - registry.execute_tool_call(&mut call, cx).unwrap() - }); - task.await.unwrap(); - - match &call.state { - ToolFunctionCallState::ExecutedTool(_view) => {} - _ => panic!(), - } - } -} diff --git a/crates/collab/src/ai.rs b/crates/collab/src/ai.rs deleted file mode 100644 index 06c6e77dfddcadaaa162f7bac9c680159e33a708..0000000000000000000000000000000000000000 --- a/crates/collab/src/ai.rs +++ /dev/null @@ -1,138 +0,0 @@ -use anyhow::{anyhow, Context as _, Result}; -use rpc::proto; -use util::ResultExt as _; - -pub fn language_model_request_to_open_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(open_ai::Request { - model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo), - messages: request - .messages - .into_iter() - .map(|message: proto::LanguageModelRequestMessage| { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - let openai_message = match role { - proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User { - content: message.content, - }, - proto::LanguageModelRole::LanguageModelAssistant => { - open_ai::RequestMessage::Assistant { - content: Some(message.content), - tool_calls: message - .tool_calls - .into_iter() - .filter_map(|call| { - Some(open_ai::ToolCall { - id: call.id, - content: match call.variant? { - proto::tool_call::Variant::Function(f) => { - open_ai::ToolCallContent::Function { - function: open_ai::FunctionContent { - name: f.name, - arguments: f.arguments, - }, - } - } - }, - }) - }) - .collect(), - } - } - proto::LanguageModelRole::LanguageModelSystem => { - open_ai::RequestMessage::System { - content: message.content, - } - } - proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool { - tool_call_id: message - .tool_call_id - .ok_or_else(|| anyhow!("tool message is missing tool call id"))?, - content: message.content, - }, - }; - - Ok(openai_message) - }) - .collect::>>()?, - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: request - .tools - .into_iter() - .filter_map(|tool| { - Some(match tool.variant? 
{ - proto::chat_completion_tool::Variant::Function(f) => { - open_ai::ToolDefinition::Function { - function: open_ai::FunctionDefinition { - name: f.name, - description: f.description, - parameters: if let Some(params) = &f.parameters { - Some( - serde_json::from_str(params) - .context("failed to deserialize tool parameters") - .log_err()?, - ) - } else { - None - }, - }, - } - } - }) - }) - .collect(), - tool_choice: request.tool_choice, - }) -} - -pub fn language_model_request_to_google_ai( - request: proto::CompleteWithLanguageModel, -) -> Result { - Ok(google_ai::GenerateContentRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - generation_config: None, - safety_settings: None, - }) -} - -pub fn language_model_request_message_to_google_ai( - message: proto::LanguageModelRequestMessage, -) -> Result { - let role = proto::LanguageModelRole::from_i32(message.role) - .ok_or_else(|| anyhow!("invalid role {}", message.role))?; - - Ok(google_ai::Content { - parts: vec![google_ai::Part::TextPart(google_ai::TextPart { - text: message.content, - })], - role: match role { - proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model, - proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User, - proto::LanguageModelRole::LanguageModelTool => { - Err(anyhow!("we don't handle tool calls with google ai yet"))? - } - }, - }) -} - -pub fn count_tokens_request_to_google_ai( - request: proto::CountTokensWithLanguageModel, -) -> Result { - Ok(google_ai::CountTokensRequest { - contents: request - .messages - .into_iter() - .map(language_model_request_message_to_google_ai) - .collect::>>()?, - }) -} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index ae83fccb98cf622600e6b22a83c7f76e2abd95a0..2673ca3fb8640375c0a7d377732656193a921729 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,4 +1,3 @@ -pub mod ai; pub mod api; pub mod auth; pub mod db; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3ec13ce04573a04c15bb4093f50046026d01ed37..92e5b1a58411d0cadca68f6947ff174499739eae 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -46,8 +46,8 @@ use http_client::IsahcHttpClient; use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole, - LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, + RequestMessage, ShareProject, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope, }; @@ -618,17 +618,6 @@ impl Server { ) } }) - .add_request_handler({ - let app_state = app_state.clone(); - user_handler(move |request, response, session| { - count_tokens_with_language_model( - request, - response, - session, - app_state.config.google_ai_api_key.clone(), - ) - }) - }) .add_request_handler({ user_handler(move |request, response, session| { get_cached_embeddings(request, response, session) @@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit { } async fn complete_with_language_model( - mut request: proto::CompleteWithLanguageModel, - response: StreamingResponse, + query: proto::QueryLanguageModel, + response: StreamingResponse, session: Session, open_ai_api_key: 
Option>, google_ai_api_key: Option>, @@ -4525,287 +4514,95 @@ async fn complete_with_language_model( return Err(anyhow!("user not found"))?; }; authorize_access_to_language_models(&session).await?; - session - .rate_limiter - .check::(session.user_id()) - .await?; - let mut provider_and_model = request.model.split('/'); - let (provider, model) = match ( - provider_and_model.next().unwrap(), - provider_and_model.next(), - ) { - (provider, Some(model)) => (provider, model), - (model, None) => { - if model.starts_with("gpt") { - ("openai", model) - } else if model.starts_with("gemini") { - ("google", model) - } else if model.starts_with("claude") { - ("anthropic", model) - } else { - ("unknown", model) - } + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; } - }; - let provider = provider.to_string(); - request.model = model.to_string(); - - match provider.as_str() { - "openai" => { - let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; - complete_with_open_ai(request, response, session, api_key).await?; + Some(proto::LanguageModelRequestKind::CountTokens) => { + session + .rate_limiter + .check::(session.user_id()) + .await?; } - "anthropic" => { + None => Err(anyhow!("unknown request kind"))?, + } + + match proto::LanguageModelProvider::from_i32(query.provider) { + Some(proto::LanguageModelProvider::Anthropic) => { let api_key = anthropic_api_key.context("no Anthropic AI API key configured on the server")?; - complete_with_anthropic(request, response, session, api_key).await?; + let mut chunks = anthropic::stream_completion( + session.http_client.as_ref(), + anthropic::ANTHROPIC_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } } - "google" => { + Some(proto::LanguageModelProvider::OpenAi) => { + let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?; + let mut chunks = open_ai::stream_completion( + session.http_client.as_ref(), + open_ai::OPEN_AI_API_URL, + &api_key, + serde_json::from_str(&query.request)?, + None, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; + } + } + Some(proto::LanguageModelProvider::Google) => { let api_key = google_ai_api_key.context("no Google AI API key configured on the server")?; - complete_with_google_ai(request, response, session, api_key).await?; - } - provider => return Err(anyhow!("unknown provider {:?}", provider))?, - } - - Ok(()) -} - -async fn complete_with_open_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut completion_stream = open_ai::stream_completion( - session.http_client.as_ref(), - OPEN_AI_API_URL, - &api_key, - crate::ai::language_model_request_to_open_ai(request)?, - None, - ) - .await - .context("open_ai::stream_completion request failed within collab")?; - - while let Some(event) = completion_stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .choices - .into_iter() - .map(|choice| proto::LanguageModelChoiceDelta { - index: choice.index, - 
delta: Some(proto::LanguageModelResponseMessage { - role: choice.delta.role.map(|role| match role { - open_ai::Role::User => LanguageModelRole::LanguageModelUser, - open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant, - open_ai::Role::System => LanguageModelRole::LanguageModelSystem, - open_ai::Role::Tool => LanguageModelRole::LanguageModelTool, - } as i32), - content: choice.delta.content, - tool_calls: choice - .delta - .tool_calls - .unwrap_or_default() - .into_iter() - .map(|delta| proto::ToolCallDelta { - index: delta.index as u32, - id: delta.id, - variant: match delta.function { - Some(function) => { - let name = function.name; - let arguments = function.arguments; - - Some(proto::tool_call_delta::Variant::Function( - proto::tool_call_delta::FunctionCallDelta { - name, - arguments, - }, - )) - } - None => None, - }, - }) - .collect(), - }), - finish_reason: choice.finish_reason, - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_google_ai( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut stream = google_ai::stream_generate_content( - session.http_client.clone(), - google_ai::API_URL, - api_key.as_ref(), - &request.model.clone(), - crate::ai::language_model_request_to_google_ai(request)?, - ) - .await - .context("google_ai::stream_generate_content request failed")?; - - while let Some(event) = stream.next().await { - let event = event?; - response.send(proto::LanguageModelResponse { - choices: event - .candidates - .unwrap_or_default() - .into_iter() - .map(|candidate| proto::LanguageModelChoiceDelta { - index: candidate.index as u32, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(match candidate.content.role { - google_ai::Role::User => LanguageModelRole::LanguageModelUser, - google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant, - } as i32), - content: Some( - candidate - .content - .parts - .into_iter() - .filter_map(|part| match part { - google_ai::Part::TextPart(part) => Some(part.text), - google_ai::Part::InlineDataPart(_) => None, - }) - .collect(), - ), - // Tool calls are not supported for Google - tool_calls: Vec::new(), - }), - finish_reason: candidate.finish_reason.map(|reason| reason.to_string()), - }) - .collect(), - })?; - } - - Ok(()) -} - -async fn complete_with_anthropic( - request: proto::CompleteWithLanguageModel, - response: StreamingResponse, - session: UserSession, - api_key: Arc, -) -> Result<()> { - let mut system_message = String::new(); - let messages = request - .messages - .into_iter() - .filter_map(|message| { - match message.role() { - LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage { - role: anthropic::Role::User, - content: message.content, - }), - LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage { - role: anthropic::Role::Assistant, - content: message.content, - }), - // Anthropic's API breaks system instructions out as a separate field rather - // than having a system message role. 
- LanguageModelRole::LanguageModelSystem => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - - None - } - // We don't yet support tool calls for Anthropic - LanguageModelRole::LanguageModelTool => None, - } - }) - .collect(); - - let mut stream = anthropic::stream_completion( - session.http_client.as_ref(), - anthropic::ANTHROPIC_API_URL, - &api_key, - anthropic::Request { - model: request.model, - messages, - stream: true, - system: system_message, - max_tokens: 4092, - }, - None, - ) - .await?; - - let mut current_role = proto::LanguageModelRole::LanguageModelAssistant; - - while let Some(event) = stream.next().await { - let event = event?; - match event { - anthropic::ResponseEvent::MessageStart { message } => { - if let Some(role) = message.role { - if role == "assistant" { - current_role = proto::LanguageModelRole::LanguageModelAssistant; - } else if role == "user" { - current_role = proto::LanguageModelRole::LanguageModelUser; - } - } - } - anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => { - match content_block { - anthropic::ContentBlock::Text { text } => { - if !text.is_empty() { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } + match proto::LanguageModelRequestKind::from_i32(query.kind) { + Some(proto::LanguageModelRequestKind::Complete) => { + let mut chunks = google_ai::stream_generate_content( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + while let Some(chunk) = chunks.next().await { + let chunk = chunk?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&chunk)?, + })?; } } - } - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta { - anthropic::TextDelta::TextDelta { text } => { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: Some(proto::LanguageModelResponseMessage { - role: Some(current_role as i32), - content: Some(text), - tool_calls: Vec::new(), - }), - finish_reason: None, - }], - })?; - } - }, - anthropic::ResponseEvent::MessageDelta { delta, .. } => { - if let Some(stop_reason) = delta.stop_reason { - response.send(proto::LanguageModelResponse { - choices: vec![proto::LanguageModelChoiceDelta { - index: 0, - delta: None, - finish_reason: Some(stop_reason), - }], + Some(proto::LanguageModelRequestKind::CountTokens) => { + let tokens_response = google_ai::count_tokens( + session.http_client.as_ref(), + google_ai::API_URL, + &api_key, + serde_json::from_str(&query.request)?, + ) + .await?; + response.send(proto::QueryLanguageModelResponse { + response: serde_json::to_string(&tokens_response)?, })?; } + None => Err(anyhow!("unknown request kind"))?, } - anthropic::ResponseEvent::ContentBlockStop { .. 
} => {} - anthropic::ResponseEvent::MessageStop {} => {} - anthropic::ResponseEvent::Ping {} => {} } + None => return Err(anyhow!("unknown provider"))?, } Ok(()) @@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit { } } -async fn count_tokens_with_language_model( - request: proto::CountTokensWithLanguageModel, - response: Response, - session: UserSession, - google_ai_api_key: Option>, -) -> Result<()> { - authorize_access_to_language_models(&session).await?; - - if !request.model.starts_with("gemini") { - return Err(anyhow!( - "counting tokens for model: {:?} is not supported", - request.model - ))?; - } - - session - .rate_limiter - .check::(session.user_id()) - .await?; - - let api_key = google_ai_api_key - .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; - let tokens_response = google_ai::count_tokens( - session.http_client.as_ref(), - google_ai::API_URL, - &api_key, - crate::ai::count_tokens_request_to_google_ai(request)?, - ) - .await?; - response.send(proto::CountTokensResponse { - token_count: tokens_response.total_tokens as u32, - })?; - Ok(()) -} - struct ComputeEmbeddingsRateLimit; impl RateLimit for ComputeEmbeddingsRateLimit { diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml index 1495f55a31e0cbcfbdfc163afa9758c9f067bcfe..f923e0ec91742652a70e9e99ddd6604321e401a5 100644 --- a/crates/google_ai/Cargo.toml +++ b/crates/google_ai/Cargo.toml @@ -11,9 +11,14 @@ workspace = true [lib] path = "src/google_ai.rs" +[features] +schemars = ["dep:schemars"] + [dependencies] anyhow.workspace = true futures.workspace = true http_client.workspace = true +schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +strum.workspace = true diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 34c43176d0a218f71bf0dfc319350f8f38c5bdf3..b2ecf332433320ab2841ceac4a26027dd0e73910 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,23 +1,21 @@ -use std::sync::Arc; - use anyhow::{anyhow, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::HttpClient; use serde::{Deserialize, Serialize}; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; pub async fn stream_generate_content( - client: Arc, + client: &dyn HttpClient, api_url: &str, api_key: &str, - model: &str, - request: GenerateContentRequest, + mut request: GenerateContentRequest, ) -> Result>> { let uri = format!( - "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}", - api_url, api_key + "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}", + model = request.model ); + request.model.clear(); let request = serde_json::to_string(&request)?; let mut response = client.post_json(&uri, request.into()).await?; @@ -52,8 +50,8 @@ pub async fn stream_generate_content( } } -pub async fn count_tokens( - client: &T, +pub async fn count_tokens( + client: &dyn HttpClient, api_url: &str, api_key: &str, request: CountTokensRequest, @@ -91,22 +89,24 @@ pub enum Task { BatchEmbedContents, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentRequest { + #[serde(default, skip_serializing_if = "String::is_empty")] + pub model: String, pub contents: Vec, pub generation_config: Option, pub 
safety_settings: Option>, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentResponse { pub candidates: Option>, pub prompt_feedback: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct GenerateContentCandidate { pub index: usize, @@ -157,7 +157,7 @@ pub struct GenerativeContentBlob { pub data: String, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationSource { pub start_index: Option, @@ -166,13 +166,13 @@ pub struct CitationSource { pub license: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CitationMetadata { pub citation_sources: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct PromptFeedback { pub block_reason: Option, @@ -180,7 +180,7 @@ pub struct PromptFeedback { pub block_reason_message: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct GenerationConfig { pub candidate_count: Option, @@ -191,7 +191,7 @@ pub struct GenerationConfig { pub top_k: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting { pub category: HarmCategory, @@ -224,7 +224,7 @@ pub enum HarmCategory { DangerousContent, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub enum HarmBlockThreshold { #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] Unspecified, @@ -238,7 +238,7 @@ pub enum HarmBlockThreshold { BlockNone, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum HarmProbability { #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] @@ -249,21 +249,85 @@ pub enum HarmProbability { High, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetyRating { pub category: HarmCategory, pub probability: HarmProbability, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensRequest { pub contents: Vec, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct CountTokensResponse { pub total_tokens: usize, } + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)] +pub enum Model { + #[serde(rename = "gemini-1.5-pro")] + Gemini15Pro, + #[serde(rename = "gemini-1.5-flash")] + Gemini15Flash, + #[serde(rename = "custom")] + Custom { name: String, max_tokens: usize }, +} + +impl Model { + pub fn id(&self) -> &str { + match self { + Model::Gemini15Pro => "gemini-1.5-pro", + Model::Gemini15Flash => "gemini-1.5-flash", + Model::Custom { name, .. } => name, + } + } + + pub fn display_name(&self) -> &str { + match self { + Model::Gemini15Pro => "Gemini 1.5 Pro", + Model::Gemini15Flash => "Gemini 1.5 Flash", + Model::Custom { name, .. } => name, + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Model::Gemini15Pro => 2_000_000, + Model::Gemini15Flash => 1_000_000, + Model::Custom { max_tokens, .. 
} => *max_tokens, + } + } +} + +impl std::fmt::Display for Model { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.id()) + } +} + +pub fn extract_text_from_events( + events: impl Stream>, +) -> impl Stream> { + events.filter_map(|event| async move { + match event { + Ok(event) => event.candidates.and_then(|candidates| { + candidates.into_iter().next().and_then(|candidate| { + candidate.content.parts.into_iter().next().and_then(|part| { + if let Part::TextPart(TextPart { text }) = part { + Some(Ok(text)) + } else { + None + } + }) + }) + }), + Err(error) => Some(Err(error)), + } + }) +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 1a099897a359a04b284f519d1b0518c6ba522a5f..de3ba8ef650c7cf17484a499ef6ed0468726e444 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -28,6 +28,7 @@ collections.workspace = true editor.workspace = true feature_flags.workspace = true futures.workspace = true +google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true http_client.workspace = true menu.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index b7b304a65d683ed981faaa7153ef60b087dc1950..1023ee337a66649411af6486cf2b0e4289a98140 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,108 +1,42 @@ -pub use anthropic::Model as AnthropicModel; -use anyhow::{anyhow, Result}; -pub use ollama::Model as OllamaModel; -pub use open_ai::Model as OpenAiModel; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use strum::EnumIter; -#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(tag = "provider", rename_all = "lowercase")] pub enum CloudModel { - #[serde(rename = "gpt-3.5-turbo")] - Gpt3Point5Turbo, - #[serde(rename = "gpt-4")] - Gpt4, - #[serde(rename = "gpt-4-turbo-preview")] - Gpt4Turbo, - #[serde(rename = "gpt-4o")] - #[default] - Gpt4Omni, - #[serde(rename = "gpt-4o-mini")] - Gpt4OmniMini, - #[serde(rename = "claude-3-5-sonnet")] - Claude3_5Sonnet, - #[serde(rename = "claude-3-opus")] - Claude3Opus, - #[serde(rename = "claude-3-sonnet")] - Claude3Sonnet, - #[serde(rename = "claude-3-haiku")] - Claude3Haiku, - #[serde(rename = "gemini-1.5-pro")] - Gemini15Pro, - #[serde(rename = "gemini-1.5-flash")] - Gemini15Flash, - #[serde(rename = "custom")] - Custom { - name: String, - max_tokens: Option, - }, + Anthropic(anthropic::Model), + OpenAi(open_ai::Model), + Google(google_ai::Model), } -impl CloudModel { - pub fn from_id(value: &str) -> Result { - match value { - "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo), - "gpt-4" => Ok(Self::Gpt4), - "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo), - "gpt-4o" => Ok(Self::Gpt4Omni), - "gpt-4o-mini" => Ok(Self::Gpt4OmniMini), - "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet), - "claude-3-opus" => Ok(Self::Claude3Opus), - "claude-3-sonnet" => Ok(Self::Claude3Sonnet), - "claude-3-haiku" => Ok(Self::Claude3Haiku), - "gemini-1.5-pro" => Ok(Self::Gemini15Pro), - "gemini-1.5-flash" => Ok(Self::Gemini15Flash), - _ => Err(anyhow!("invalid model id")), - } +impl Default for CloudModel { + fn default() -> Self { + Self::Anthropic(anthropic::Model::default()) } +} +impl CloudModel { pub fn id(&self) -> &str { match self { - Self::Gpt3Point5Turbo => "gpt-3.5-turbo", - Self::Gpt4 => 
"gpt-4", - Self::Gpt4Turbo => "gpt-4-turbo-preview", - Self::Gpt4Omni => "gpt-4o", - Self::Gpt4OmniMini => "gpt-4o-mini", - Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.id(), + CloudModel::OpenAi(model) => model.id(), + CloudModel::Google(model) => model.id(), } } pub fn display_name(&self) -> &str { match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom { name, .. } => name, + CloudModel::Anthropic(model) => model.display_name(), + CloudModel::OpenAi(model) => model.display_name(), + CloudModel::Google(model) => model.display_name(), } } pub fn max_token_count(&self) -> usize { match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000), + CloudModel::Anthropic(model) => model.max_token_count(), + CloudModel::OpenAi(model) => model.max_token_count(), + CloudModel::Google(model) => model.max_token_count(), } } } diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs index f2713db003e6bfcfbde230af978286a9f52e9783..6fe0bfd7a1e71c8c63221600564d5afde0dd9871 100644 --- a/crates/language_model/src/provider.rs +++ b/crates/language_model/src/provider.rs @@ -2,5 +2,6 @@ pub mod anthropic; pub mod cloud; #[cfg(any(test, feature = "test-support"))] pub mod fake; +pub mod google; pub mod ollama; pub mod open_ai; diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 52ac22b29f33c53d0b871592e184dbc5c78a0d40..7cc9922546d4ad2540a1ed9a5e45e971752d2709 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -1,4 +1,4 @@ -use anthropic::{stream_completion, Request, RequestMessage}; +use anthropic::stream_completion; use anyhow::{anyhow, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; @@ -18,7 +18,7 @@ use util::ResultExt; use crate::{ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, - LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role, + LanguageModelProviderState, LanguageModelRequest, Role, }; const PROVIDER_ID: &str = "anthropic"; @@ -160,40 +160,6 @@ pub struct AnthropicModel { http_client: Arc, } -impl AnthropicModel { - fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); - - let mut system_message = String::new(); - if request - .messages - .first() - .map_or(false, |message| 
message.role == Role::System) - { - system_message = request.messages.remove(0).content; - } - - Request { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|msg| RequestMessage { - role: match msg.role { - Role::User => anthropic::Role::User, - Role::Assistant => anthropic::Role::Assistant, - Role::System => unreachable!("filtered out by preprocess_request"), - }, - content: msg.content.clone(), - }) - .collect(), - stream: true, - system: system_message, - max_tokens: 4092, - } - } -} - pub fn count_anthropic_tokens( request: LanguageModelRequest, cx: &AppContext, @@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_anthropic_request(request); + let request = request.into_anthropic(self.model.id().into()); let http_client = self.http_client.clone(); @@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(response) => match response { - anthropic::ResponseEvent::ContentBlockStart { - content_block, .. - } => match content_block { - anthropic::ContentBlock::Text { text } => Some(Ok(text)), - }, - anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => { - match delta { - anthropic::TextDelta::TextDelta { text } => Some(Ok(text)), - } - } - _ => None, - }, - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(anthropic::extract_text_from_events(response).boxed()) } .boxed() } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) 
{ - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, state: gpui::Model, diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 1cd8b99e98acc661939821b115cc3fda53ded86d..d290876ad9ccf50acf07d05a77b4298fb3a9d5f3 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -7,8 +7,10 @@ use crate::{ use anyhow::Result; use client::Client; use collections::BTreeMap; -use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsStore}; use std::sync::Arc; use strum::IntoEnumIterator; @@ -16,14 +18,29 @@ use ui::prelude::*; use crate::LanguageModelProvider; -use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request}; +use super::anthropic::count_anthropic_tokens; pub const PROVIDER_ID: &str = "zed.dev"; pub const PROVIDER_NAME: &str = "zed.dev"; #[derive(Default, Clone, Debug, PartialEq)] pub struct ZedDotDevSettings { - pub available_models: Vec, + pub available_models: Vec, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "lowercase")] +pub enum AvailableProvider { + Anthropic, + OpenAi, + Google, +} + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] +pub struct AvailableModel { + provider: AvailableProvider, + name: String, + max_tokens: usize, } pub struct CloudLanguageModelProvider { @@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider { fn provided_models(&self, cx: &AppContext) -> Vec> { let mut models = BTreeMap::default(); - // Add base models from CloudModel::iter() - for model in CloudModel::iter() { - if !matches!(model, CloudModel::Custom { .. }) { - models.insert(model.id().to_string(), model); + for model in anthropic::Model::iter() { + if !matches!(model, anthropic::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::Anthropic(model)); + } + } + for model in open_ai::Model::iter() { + if !matches!(model, open_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), CloudModel::OpenAi(model)); + } + } + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. 
}) { + models.insert(model.id().to_string(), CloudModel::Google(model)); } } @@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider { .zed_dot_dev .available_models { + let model = match model.provider { + AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom { + name: model.name.clone(), + max_tokens: model.max_tokens, + }), + }; models.insert(model.id().to_string(), model.clone()); } @@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel { request: LanguageModelRequest, cx: &AppContext, ) -> BoxFuture<'static, Result> { - match &self.model { - CloudModel::Gpt3Point5Turbo => { - count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx) - } - CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx), - CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx), - CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx), - CloudModel::Gpt4OmniMini => { - count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx) - } - CloudModel::Claude3_5Sonnet - | CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx), - CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => { - count_anthropic_tokens(request, cx) - } - _ => { - let request = self.client.request(proto::CountTokensWithLanguageModel { - model: self.model.id().to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - }); + match self.model.clone() { + CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx), + CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx), + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + let request = google_ai::CountTokensRequest { + contents: request.contents, + }; async move { - let response = request.await?; - Ok(response.token_count as usize) + let request = serde_json::to_string(&request)?; + let response = client.request(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::CountTokens as i32, + request, + }); + let response = response.await?; + let response = + serde_json::from_str::(&response.response)?; + Ok(response.total_tokens) } .boxed() } @@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel { fn stream_completion( &self, - mut request: LanguageModelRequest, + request: LanguageModelRequest, _: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { match &self.model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request), - CloudModel::Custom { name, .. 
} if name.starts_with("anthropic/") => { - preprocess_anthropic_request(&mut request) + CloudModel::Anthropic(model) => { + let client = self.client.clone(); + let request = request.into_anthropic(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Anthropic as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(anthropic::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::OpenAi(model) => { + let client = self.client.clone(); + let request = request.into_open_ai(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::OpenAi as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(open_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() + } + CloudModel::Google(model) => { + let client = self.client.clone(); + let request = request.into_google(model.id().into()); + async move { + let request = serde_json::to_string(&request)?; + let response = client.request_stream(proto::QueryLanguageModel { + provider: proto::LanguageModelProvider::Google as i32, + kind: proto::LanguageModelRequestKind::Complete as i32, + request, + }); + let chunks = response.await?; + Ok(google_ai::extract_text_from_events( + chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)), + ) + .boxed()) + } + .boxed() } - _ => {} } - - let request = proto::CompleteWithLanguageModel { - model: self.id.0.to_string(), - messages: request - .messages - .iter() - .map(|message| message.to_proto()) - .collect(), - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - }; - - self.client - .request_stream(request) - .map_ok(|stream| { - stream - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed() - }) - .boxed() } } diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a0c0a3f7e25db16c3e0517b68ec33a14e7f0aa7 --- /dev/null +++ b/crates/language_model/src/provider/google.rs @@ -0,0 +1,351 @@ +use anyhow::{anyhow, Result}; +use collections::BTreeMap; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; +use google_ai::stream_generate_content; +use gpui::{ + AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View, + WhiteSpace, +}; +use http_client::HttpClient; +use settings::{Settings, SettingsStore}; +use std::{sync::Arc, time::Duration}; +use strum::IntoEnumIterator; +use theme::ThemeSettings; +use ui::prelude::*; +use util::ResultExt; + +use crate::{ + settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName, + LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, + LanguageModelProviderState, LanguageModelRequest, +}; + +const PROVIDER_ID: &str = "google"; +const PROVIDER_NAME: &str = "Google AI"; + +#[derive(Default, Clone, Debug, 
PartialEq)] +pub struct GoogleSettings { + pub api_url: String, + pub low_speed_timeout: Option, + pub available_models: Vec, +} + +pub struct GoogleLanguageModelProvider { + http_client: Arc, + state: gpui::Model, +} + +struct State { + api_key: Option, + _subscription: Subscription, +} + +impl GoogleLanguageModelProvider { + pub fn new(http_client: Arc, cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| State { + api_key: None, + _subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + }); + + Self { http_client, state } + } +} + +impl LanguageModelProviderState for GoogleLanguageModelProvider { + fn subscribe(&self, cx: &mut gpui::ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for GoogleLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, cx: &AppContext) -> Vec> { + let mut models = BTreeMap::default(); + + // Add base models from google_ai::Model::iter() + for model in google_ai::Model::iter() { + if !matches!(model, google_ai::Model::Custom { .. }) { + models.insert(model.id().to_string(), model); + } + } + + // Override with available models from settings + for model in &AllLanguageModelSettings::get_global(cx) + .google + .available_models + { + models.insert(model.id().to_string(), model.clone()); + } + + models + .into_values() + .map(|model| { + Arc::new(GoogleLanguageModel { + id: LanguageModelId::from(model.id().to_string()), + model, + state: self.state.clone(), + http_client: self.http_client.clone(), + }) as Arc + }) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + self.state.read(cx).api_key.is_some() + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated(cx) { + Task::ready(Ok(())) + } else { + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + let state = self.state.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? 
+ }; + + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + } + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx)) + .into() + } + + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let state = self.state.clone(); + let delete_credentials = + cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + state.update(&mut cx, |this, cx| { + this.api_key = None; + cx.notify(); + }) + }) + } +} + +pub struct GoogleLanguageModel { + id: LanguageModelId, + model: google_ai::Model, + state: gpui::Model, + http_client: Arc, +} + +impl LanguageModel for GoogleLanguageModel { + fn id(&self) -> LanguageModelId { + self.id.clone() + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("google/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + let request = request.into_google(self.model.id().to_string()); + let http_client = self.http_client.clone(); + let api_key = self.state.read(cx).api_key.clone(); + let api_url = AllLanguageModelSettings::get_global(cx) + .google + .api_url + .clone(); + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = google_ai::count_tokens( + http_client.as_ref(), + &api_url, + &api_key, + google_ai::CountTokensRequest { + contents: request.contents, + }, + ) + .await?; + Ok(response.total_tokens) + } + .boxed() + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + let request = request.into_google(self.model.id().to_string()); + + let http_client = self.http_client.clone(); + let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| { + let settings = &AllLanguageModelSettings::get_global(cx).google; + (state.api_key.clone(), settings.api_url.clone()) + }) else { + return futures::future::ready(Err(anyhow!("App state dropped"))).boxed(); + }; + + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let response = + stream_generate_content(http_client.as_ref(), &api_url, &api_key, request); + let events = response.await?; + Ok(google_ai::extract_text_from_events(events).boxed()) + } + .boxed() + } +} + +struct AuthenticationPrompt { + api_key: View, + state: gpui::Model, +} + +impl AuthenticationPrompt { + fn new(state: gpui::Model, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text("AIzaSy...", cx); + editor + }), + state, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let settings = &AllLanguageModelSettings::get_global(cx).google; + let write_credentials = + cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes()); + let state = self.state.clone(); + 
cx.spawn(|_, mut cx| async move { + write_credentials.await?; + state.update(&mut cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_size: rems(0.875).into(), + font_weight: settings.ui_font.weight, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 4] = [ + "To use the Google AI assistant, you need to add your Google AI API key.", + "You can create an API key at: https://makersuite.google.com/app/apikey", + "", + "Paste your Google AI API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + .bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs index c81a43594630ffafb4ae7399bc25df52597b8067..1b3bf18dd5a1208e98d4d7eb29685eced319a060 100644 --- a/crates/language_model/src/provider/open_ai.rs +++ b/crates/language_model/src/provider/open_ai.rs @@ -7,7 +7,7 @@ use gpui::{ WhiteSpace, }; use http_client::HttpClient; -use open_ai::{stream_completion, Request, RequestMessage}; +use open_ai::stream_completion; use settings::{Settings, SettingsStore}; use std::{sync::Arc, time::Duration}; use strum::IntoEnumIterator; @@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel { http_client: Arc, } -impl OpenAiLanguageModel { - fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { - Request { - model: self.model.clone(), - messages: request - .messages - .into_iter() - .map(|msg| match msg.role { - Role::User => RequestMessage::User { - content: msg.content, - }, - Role::Assistant => RequestMessage::Assistant { - content: Some(msg.content), - tool_calls: Vec::new(), - }, - Role::System => RequestMessage::System { - content: msg.content, - }, - }) - .collect(), - stream: true, - stop: request.stop, - temperature: request.temperature, - tools: Vec::new(), - tool_choice: None, - } - } -} - impl LanguageModel for OpenAiLanguageModel { fn id(&self) -> LanguageModelId { self.id.clone() @@ -226,7 +197,7 @@ impl LanguageModel 
for OpenAiLanguageModel { request: LanguageModelRequest, cx: &AsyncAppContext, ) -> BoxFuture<'static, Result>>> { - let request = self.to_open_ai_request(request); + let request = request.into_open_ai(self.model.id().into()); let http_client = self.http_client.clone(); let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { @@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel { low_speed_timeout, ); let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) + Ok(open_ai::extract_text_from_events(response).boxed()) } .boxed() } diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index e787f5f7e75c85f3fbda6570a7db5deddaaffd06..05dcbced5ddedd5df3a13ebed4aa1ad0bcce8886 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,16 +1,16 @@ -use client::Client; -use collections::BTreeMap; -use gpui::{AppContext, Global, Model, ModelContext}; -use std::sync::Arc; -use ui::Context; - use crate::{ provider::{ anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, - ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, + google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider, + open_ai::OpenAiLanguageModelProvider, }, LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; +use client::Client; +use collections::BTreeMap; +use gpui::{AppContext, Global, Model, ModelContext}; +use std::sync::Arc; +use ui::Context; pub fn init(client: Arc, cx: &mut AppContext) { let registry = cx.new_model(|cx| { @@ -40,6 +40,10 @@ fn register_language_model_providers( OllamaLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider( + GoogleLanguageModelProvider::new(client.http_client(), cx), + cx, + ); cx.observe_flag::(move |enabled, cx| { let client = client.clone(); diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index e3e1d3e77b6067a4d5ca945a8e8f96152e2ef878..fc3b8c019282c5d8ec445cfad51ff557ba582dbe 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -1,4 +1,4 @@ -use crate::{role::Role, LanguageModelId}; +use crate::role::Role; use serde::{Deserialize, Serialize}; #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage { pub content: String, } -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - #[derive(Debug, Default, Serialize, Deserialize)] pub struct LanguageModelRequest { pub messages: Vec, @@ -26,14 +15,110 @@ pub struct LanguageModelRequest { } impl LanguageModelRequest { - pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: model_id.0.to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), + pub fn into_open_ai(self, model: String) -> open_ai::Request { + open_ai::Request { + model, + messages: self + .messages + .into_iter() + .map(|msg| match msg.role { + Role::User => 
open_ai::RequestMessage::User { + content: msg.content, + }, + Role::Assistant => open_ai::RequestMessage::Assistant { + content: Some(msg.content), + tool_calls: Vec::new(), + }, + Role::System => open_ai::RequestMessage::System { + content: msg.content, + }, + }) + .collect(), + stream: true, + stop: self.stop, temperature: self.temperature, - tool_choice: None, tools: Vec::new(), + tool_choice: None, + } + } + + pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest { + google_ai::GenerateContentRequest { + model, + contents: self + .messages + .into_iter() + .map(|msg| google_ai::Content { + parts: vec![google_ai::Part::TextPart(google_ai::TextPart { + text: msg.content, + })], + role: match msg.role { + Role::User => google_ai::Role::User, + Role::Assistant => google_ai::Role::Model, + Role::System => google_ai::Role::User, // Google AI doesn't have a system role + }, + }) + .collect(), + generation_config: Some(google_ai::GenerationConfig { + candidate_count: Some(1), + stop_sequences: Some(self.stop), + max_output_tokens: None, + temperature: Some(self.temperature as f64), + top_p: None, + top_k: None, + }), + safety_settings: None, + } + } + + pub fn into_anthropic(self, model: String) -> anthropic::Request { + let mut new_messages: Vec = Vec::new(); + let mut system_message = String::new(); + + for message in self.messages { + if message.content.is_empty() { + continue; + } + + match message.role { + Role::User | Role::Assistant => { + if let Some(last_message) = new_messages.last_mut() { + if last_message.role == message.role { + last_message.content.push_str("\n\n"); + last_message.content.push_str(&message.content); + continue; + } + } + + new_messages.push(message); + } + Role::System => { + if !system_message.is_empty() { + system_message.push_str("\n\n"); + } + system_message.push_str(&message.content); + } + } + } + + anthropic::Request { + model, + messages: new_messages + .into_iter() + .filter_map(|message| { + Some(anthropic::RequestMessage { + role: match message.role { + Role::User => anthropic::Role::User, + Role::Assistant => anthropic::Role::Assistant, + Role::System => return None, + }, + content: message.content, + }) + }) + .collect(), + stream: true, + max_tokens: 4092, + system: system_message, } } } diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs index f6276a4823651c200b207c44eb612e16283ac913..82184038f63834f4513fcacd181ffb527295a04b 100644 --- a/crates/language_model/src/role.rs +++ b/crates/language_model/src/role.rs @@ -15,7 +15,6 @@ impl Role { Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, None => Role::User, } } diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 262e14937a67b019ed5ef71d0d7f1c4d546c8e98..85ae91649a58cb66bda36172d02f8b4f063e3b07 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -6,12 +6,12 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use crate::{ - provider::{ - anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings, - open_ai::OpenAiSettings, - }, - CloudModel, +use crate::provider::{ + anthropic::AnthropicSettings, + cloud::{self, ZedDotDevSettings}, + 
google::GoogleSettings, + ollama::OllamaSettings, + open_ai::OpenAiSettings, }; /// Initializes the language model settings. @@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings { pub ollama: OllamaSettings, pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, + pub google: GoogleSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent { pub openai: Option, #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, + pub google: Option, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent { pub available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct GoogleSettingsContent { + pub api_url: Option, + pub low_speed_timeout_in_seconds: Option, + pub available_models: Option>, +} + #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] pub struct ZedDotDevSettingsContent { - available_models: Option>, + available_models: Option>, } impl settings::Settings for AllLanguageModelSettings { @@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + + merge( + &mut settings.google.api_url, + value.google.as_ref().and_then(|s| s.api_url.clone()), + ); + if let Some(low_speed_timeout_in_seconds) = value + .google + .as_ref() + .and_then(|s| s.low_speed_timeout_in_seconds) + { + settings.google.low_speed_timeout = + Some(Duration::from_secs(low_speed_timeout_in_seconds)); + } + merge( + &mut settings.google.available_models, + value + .google + .as_ref() + .and_then(|s| s.available_models.clone()), + ); } Ok(settings) diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index dfcd6646d14cf5068234e9e3f70599c9d115f367..13a6eb11d11900e5ae777bb6154889e4c699993e 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,5 +1,5 @@ use anyhow::{anyhow, Context, Result}; -use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; @@ -111,38 +111,27 @@ impl Model { } } -fn serialize_model(model: &Model, serializer: S) -> Result -where - S: serde::Serializer, -{ - match model { - Model::Custom { name, .. 
} => serializer.serialize_str(name), - _ => serializer.serialize_str(model.id()), - } -} - -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Request { - #[serde(serialize_with = "serialize_model")] - pub model: Model, + pub model: String, pub messages: Vec, pub stream: bool, pub stop: Vec, pub temperature: f32, - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default, skip_serializing_if = "Option::is_none")] pub tool_choice: Option, - #[serde(skip_serializing_if = "Vec::is_empty")] + #[serde(default, skip_serializing_if = "Vec::is_empty")] pub tools: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct FunctionDefinition { pub name: String, pub description: Option, pub parameters: Option>, } -#[derive(Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ToolDefinition { #[allow(dead_code)] @@ -213,21 +202,21 @@ pub struct FunctionChunk { pub arguments: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ChoiceDelta { pub index: u32, pub delta: ResponseMessageDelta, pub finish_reason: Option, } -#[derive(Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub struct ResponseStreamEvent { pub created: u32, pub model: String, @@ -369,3 +358,14 @@ pub fn embed<'a>( } } } + +pub fn extract_text_from_events( + response: impl Stream>, +) -> impl Stream> { + response.filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) +} diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto index 60f8d01558a75cf38bf55c4e9120e26bb6cd12f1..658d552848b3da22229405964c685ee6df2837c5 100644 --- a/crates/proto/proto/zed.proto +++ b/crates/proto/proto/zed.proto @@ -13,13 +13,6 @@ message Envelope { optional uint32 responding_to = 2; optional PeerId original_sender_id = 3; - /* - When you are adding a new message type, instead of adding it in semantic order - and bumping the message ID's of everything that follows, add it at the end of the - file and bump the max number. 
See this - https://github.com/zed-industries/zed/pull/7890#discussion_r1496621823 - - */ oneof payload { Hello hello = 4; Ack ack = 5; @@ -201,10 +194,8 @@ message Envelope { JoinHostedProject join_hosted_project = 164; - CompleteWithLanguageModel complete_with_language_model = 166; - LanguageModelResponse language_model_response = 167; - CountTokensWithLanguageModel count_tokens_with_language_model = 168; - CountTokensResponse count_tokens_response = 169; + QueryLanguageModel query_language_model = 224; + QueryLanguageModelResponse query_language_model_response = 225; // current max GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; @@ -271,10 +262,11 @@ message Envelope { UpdateDevServerProject update_dev_server_project = 221; AddWorktree add_worktree = 222; - AddWorktreeResponse add_worktree_response = 223; // current max + AddWorktreeResponse add_worktree_response = 223; } reserved 158 to 161; + reserved 166 to 169; } // Messages @@ -2051,94 +2043,32 @@ message SetRoomParticipantRole { ChannelRole role = 3; } -message CompleteWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; - repeated string stop = 3; - float temperature = 4; - repeated ChatCompletionTool tools = 5; - optional string tool_choice = 6; -} - -// A tool presented to the language model for its use -message ChatCompletionTool { - oneof variant { - FunctionObject function = 1; - } - - message FunctionObject { - string name = 1; - optional string description = 2; - optional string parameters = 3; - } -} - -// A message to the language model -message LanguageModelRequestMessage { - LanguageModelRole role = 1; - string content = 2; - optional string tool_call_id = 3; - repeated ToolCall tool_calls = 4; -} - enum LanguageModelRole { LanguageModelUser = 0; LanguageModelAssistant = 1; LanguageModelSystem = 2; - LanguageModelTool = 3; -} - -message LanguageModelResponseMessage { - optional LanguageModelRole role = 1; - optional string content = 2; - repeated ToolCallDelta tool_calls = 3; -} - -// A request to call a tool, by the language model -message ToolCall { - string id = 1; - - oneof variant { - FunctionCall function = 2; - } - - message FunctionCall { - string name = 1; - string arguments = 2; - } -} - -message ToolCallDelta { - uint32 index = 1; - optional string id = 2; - - oneof variant { - FunctionCallDelta function = 3; - } - - message FunctionCallDelta { - optional string name = 1; - optional string arguments = 2; - } + reserved 3; } -message LanguageModelResponse { - repeated LanguageModelChoiceDelta choices = 1; +message QueryLanguageModel { + LanguageModelProvider provider = 1; + LanguageModelRequestKind kind = 2; + string request = 3; } -message LanguageModelChoiceDelta { - uint32 index = 1; - LanguageModelResponseMessage delta = 2; - optional string finish_reason = 3; +enum LanguageModelProvider { + Anthropic = 0; + OpenAI = 1; + Google = 2; } -message CountTokensWithLanguageModel { - string model = 1; - repeated LanguageModelRequestMessage messages = 2; +enum LanguageModelRequestKind { + Complete = 0; + CountTokens = 1; } -message CountTokensResponse { - uint32 token_count = 1; +message QueryLanguageModelResponse { + string response = 1; } message GetCachedEmbeddings { diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs index a205b79ecbb319c3194b64c56bbcd16e9499b5c9..7ef1866acd3e299dfb583e4a62a1beffb17fcdc6 100644 --- a/crates/proto/src/proto.rs +++ 
b/crates/proto/src/proto.rs @@ -203,12 +203,9 @@ messages!( (CancelCall, Foreground), (ChannelMessageSent, Foreground), (ChannelMessageUpdate, Foreground), - (CompleteWithLanguageModel, Background), (ComputeEmbeddings, Background), (ComputeEmbeddingsResponse, Background), (CopyProjectEntry, Foreground), - (CountTokensWithLanguageModel, Background), - (CountTokensResponse, Background), (CreateBufferForPeer, Foreground), (CreateChannel, Foreground), (CreateChannelResponse, Foreground), @@ -278,7 +275,6 @@ messages!( (JoinProjectResponse, Foreground), (JoinRoom, Foreground), (JoinRoomResponse, Foreground), - (LanguageModelResponse, Background), (LeaveChannelBuffer, Background), (LeaveChannelChat, Foreground), (LeaveProject, Foreground), @@ -298,6 +294,8 @@ messages!( (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), + (QueryLanguageModel, Background), + (QueryLanguageModelResponse, Background), (RefreshInlayHints, Foreground), (RejoinChannelBuffers, Foreground), (RejoinChannelBuffersResponse, Foreground), @@ -412,9 +410,7 @@ request_messages!( (Call, Ack), (CancelCall, Ack), (CopyProjectEntry, ProjectEntryResponse), - (CompleteWithLanguageModel, LanguageModelResponse), (ComputeEmbeddings, ComputeEmbeddingsResponse), - (CountTokensWithLanguageModel, CountTokensResponse), (CreateChannel, CreateChannelResponse), (CreateProjectEntry, ProjectEntryResponse), (CreateRoom, CreateRoomResponse), @@ -467,6 +463,7 @@ request_messages!( (PerformRename, PerformRenameResponse), (Ping, Ack), (PrepareRename, PrepareRenameResponse), + (QueryLanguageModel, QueryLanguageModelResponse), (RefreshInlayHints, Ack), (RejoinChannelBuffers, RejoinChannelBuffersResponse), (RejoinRoom, RejoinRoomResponse),
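
Two illustrative snippets follow. They are editor-added sketches inferred from the structs and proto messages in this patch, not part of the diff itself.

First, an example of how the settings introduced above might look in a user's settings file: the new `google` provider entry only needs an `api_url` (plus an optional `low_speed_timeout_in_seconds`), while `zed.dev` uses the flat `available_models` format with an explicit `provider` field. The entries are shown without the surrounding settings section they live under, and the model name in `available_models` is a hypothetical placeholder.

```json
{
  "google": {
    "api_url": "https://generativelanguage.googleapis.com",
    "low_speed_timeout_in_seconds": 600
  },
  "zed.dev": {
    "available_models": [
      {
        "provider": "google",
        "name": "gemini-1.5-pro-experimental",
        "max_tokens": 2000000
      }
    ]
  }
}
```

Second, a sketch of the provider-native body carried in the `request` field of the new `QueryLanguageModel` message for an Anthropic completion, as produced by `LanguageModelRequest::into_anthropic` and passed through to the server as an opaque JSON string. The model id and role spellings shown are assumptions for illustration.

```json
{
  "model": "claude-3-5-sonnet-20240620",
  "messages": [
    { "role": "user", "content": "Hello!" }
  ],
  "stream": true,
  "system": "You are a helpful coding assistant.",
  "max_tokens": 4092
}
```

Treating the request body as an opaque, provider-specific string is what allows this patch to drop the `CompleteWithLanguageModel`, `LanguageModelResponse`, `CountTokensWithLanguageModel`, and `CountTokensResponse` messages in favor of the single `QueryLanguageModel`/`QueryLanguageModelResponse` pair.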