Detailed changes
@@ -471,27 +471,6 @@ dependencies = [
"workspace",
]
-[[package]]
-name = "assistant_tooling"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "collections",
- "futures 0.3.28",
- "gpui",
- "log",
- "project",
- "repair_json",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "sum_tree",
- "ui",
- "unindent",
- "util",
-]
-
[[package]]
name = "async-attributes"
version = "1.1.2"
@@ -4811,8 +4790,10 @@ dependencies = [
"anyhow",
"futures 0.3.28",
"http_client",
+ "schemars",
"serde",
"serde_json",
+ "strum",
]
[[package]]
@@ -5988,6 +5969,7 @@ dependencies = [
"env_logger",
"feature_flags",
"futures 0.3.28",
+ "google_ai",
"gpui",
"http_client",
"language",
@@ -8715,15 +8697,6 @@ dependencies = [
"bytecheck",
]
-[[package]]
-name = "repair_json"
-version = "0.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ee191e184125fe72cb59b74160e25584e3908f2aaa84cbda1e161347102aa15"
-dependencies = [
- "thiserror",
-]
-
[[package]]
name = "repl"
version = "0.1.0"
@@ -6,7 +6,6 @@ members = [
"crates/assets",
"crates/assistant",
"crates/assistant_slash_command",
- "crates/assistant_tooling",
"crates/audio",
"crates/auto_update",
"crates/breadcrumbs",
@@ -178,7 +177,6 @@ anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
assistant_slash_command = { path = "crates/assistant_slash_command" }
-assistant_tooling = { path = "crates/assistant_tooling" }
audio = { path = "crates/audio" }
auto_update = { path = "crates/auto_update" }
breadcrumbs = { path = "crates/breadcrumbs" }
@@ -870,6 +870,9 @@
"openai": {
"api_url": "https://api.openai.com/v1"
},
+ "google": {
+ "api_url": "https://generativelanguage.googleapis.com"
+ },
"ollama": {
"api_url": "http://localhost:11434"
}
@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
@@ -98,7 +98,7 @@ impl From<Role> for String {
}
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
pub model: String,
pub messages: Vec<RequestMessage>,
@@ -113,7 +113,7 @@ pub struct RequestMessage {
pub content: String,
}
-#[derive(Deserialize, Debug)]
+#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
MessageStart {
@@ -138,7 +138,7 @@ pub enum ResponseEvent {
MessageStop {},
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseMessage {
#[serde(rename = "type")]
pub message_type: Option<String>,
@@ -151,19 +151,19 @@ pub struct ResponseMessage {
pub usage: Option<Usage>,
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub input_tokens: Option<u32>,
pub output_tokens: Option<u32>,
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
Text { text: String },
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
TextDelta { text: String },
@@ -226,6 +226,25 @@ pub async fn stream_completion(
}
}
+pub fn extract_text_from_events(
+ response: impl Stream<Item = Result<ResponseEvent>>,
+) -> impl Stream<Item = Result<String>> {
+ response.filter_map(|response| async move {
+ match response {
+ Ok(response) => match response {
+ ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
+ ContentBlock::Text { text } => Some(Ok(text)),
+ },
+ ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
+ TextDelta::TextDelta { text } => Some(Ok(text)),
+ },
+ _ => None,
+ },
+ Err(error) => Some(Err(error)),
+ }
+ })
+}
+
// #[cfg(test)]
// mod tests {
// use super::*;
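
Aside, not part of the diff: a minimal sketch of how the new `anthropic::extract_text_from_events` helper composes with `stream_completion`, mirroring the argument order used at the collab call sites later in this diff (client, API URL, API key, request, low-speed timeout). The surrounding function, the `&dyn HttpClient` wiring, and the error handling are assumptions for illustration only.

```rust
use anyhow::Result;
use futures::StreamExt;
use http_client::HttpClient;

// Sketch only: stream the completion events and keep just the text chunks
// carried by ContentBlockStart / ContentBlockDelta.
async fn print_streamed_text(
    client: &dyn HttpClient, // concrete client type/wiring is assumed
    api_key: &str,
    request: anthropic::Request,
) -> Result<()> {
    let events = anthropic::stream_completion(
        client,
        anthropic::ANTHROPIC_API_URL,
        api_key,
        request,
        None, // low-speed timeout, as at the collab call sites below
    )
    .await?;

    let text = anthropic::extract_text_from_events(events);
    futures::pin_mut!(text);
    while let Some(chunk) = text.next().await {
        print!("{}", chunk?);
    }
    Ok(())
}
```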
@@ -249,9 +249,7 @@ impl AssistantSettingsContent {
AssistantSettingsContent::Versioned(settings) => match settings {
VersionedAssistantSettingsContent::V1(settings) => match provider.as_ref() {
"zed.dev" => {
- settings.provider = Some(AssistantProviderContentV1::ZedDotDev {
- default_model: CloudModel::from_id(&model).ok(),
- });
+ log::warn!("attempted to set zed.dev model on outdated settings");
}
"anthropic" => {
let (api_url, low_speed_timeout_in_seconds) = match &settings.provider {
@@ -1,33 +0,0 @@
-[package]
-name = "assistant_tooling"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/assistant_tooling.rs"
-
-[dependencies]
-anyhow.workspace = true
-collections.workspace = true
-futures.workspace = true
-gpui.workspace = true
-log.workspace = true
-project.workspace = true
-repair_json.workspace = true
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-sum_tree.workspace = true
-ui.workspace = true
-util.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
-project = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"] }
-unindent.workspace = true
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,85 +0,0 @@
-# Assistant Tooling
-
-Bringing Language Model tool calling to GPUI.
-
-This unlocks:
-
-- **Structured Extraction** of model responses
-- **Validation** of model inputs
-- **Execution** of chosen tools
-
-## Overview
-
-Language Models can produce structured outputs that are perfect for calling functions. The most famous of these is OpenAI's tool calling. When making a chat completion you can pass a list of tools available to the model. The model will choose `0..n` tools to help them complete a user's task. It's up to _you_ to create the tools that the model can call.
-
-> **User**: "Hey I need help with implementing a collapsible panel in GPUI"
->
-> **Assistant**: "Sure, I can help with that. Let me see what I can find."
->
-> `tool_calls: ["name": "query_codebase", arguments: "{ 'query': 'GPUI collapsible panel' }"]`
->
-> `result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"`
->
-> **Assistant**: "Here are some excerpts from the GPUI codebase that might help you."
-
-This library is designed to facilitate this interaction mode by allowing you to go from `struct` to `tool` with two simple traits, `LanguageModelTool` and `ToolView`.
-
-## Using the Tool Registry
-
-```rust
-let mut tool_registry = ToolRegistry::new();
-tool_registry
- .register(WeatherTool { api_client },
- })
- .unwrap(); // You can only register one tool per name
-
-let completion = cx.update(|cx| {
- CompletionProvider::get(cx).complete(
- model_name,
- messages,
- Vec::new(),
- 1.0,
- // The definitions get passed directly to OpenAI when you want
- // the model to be able to call your tool
- tool_registry.definitions(),
- )
-});
-
-let mut stream = completion?.await?;
-
-let mut message = AssistantMessage::new();
-
-while let Some(delta) = stream.next().await {
- // As messages stream in, you'll get both assistant content
- if let Some(content) = &delta.content {
- message
- .body
- .update(cx, |message, cx| message.append(&content, cx));
- }
-
- // And tool calls!
- for tool_call_delta in delta.tool_calls {
- let index = tool_call_delta.index as usize;
- if index >= message.tool_calls.len() {
- message.tool_calls.resize_with(index + 1, Default::default);
- }
- let tool_call = &mut message.tool_calls[index];
-
- // Build up an ID
- if let Some(id) = &tool_call_delta.id {
- tool_call.id.push_str(id);
- }
-
- tool_registry.update_tool_call(
- tool_call,
- tool_call_delta.name.as_deref(),
- tool_call_delta.arguments.as_deref(),
- cx,
- );
- }
-}
-```
-
-Once the stream of tokens is complete, you can execute the tool call by calling `tool_registry.execute_tool_call(tool_call, cx)`, which returns a `Task<Result<()>>`.
-
-As the tokens stream in and tool calls are executed, your `ToolView` will get updates. Render each tool call by passing that `tool_call` in to `tool_registry.render_tool_call(tool_call, cx)`. The final message for the model can be pulled by calling `self.tool_registry.content_for_tool_call( tool_call, &mut project_context, cx, )`.
@@ -1,13 +0,0 @@
-mod attachment_registry;
-mod project_context;
-mod tool_registry;
-
-pub use attachment_registry::{
- AttachmentOutput, AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment,
- UserAttachment,
-};
-pub use project_context::ProjectContext;
-pub use tool_registry::{
- LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition,
- ToolRegistry, ToolView,
-};
@@ -1,234 +0,0 @@
-use crate::ProjectContext;
-use anyhow::{anyhow, Result};
-use collections::HashMap;
-use futures::future::join_all;
-use gpui::{AnyView, Render, Task, View, WindowContext};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use serde_json::value::RawValue;
-use std::{
- any::TypeId,
- sync::{
- atomic::{AtomicBool, Ordering::SeqCst},
- Arc,
- },
-};
-use util::ResultExt as _;
-
-pub struct AttachmentRegistry {
- registered_attachments: HashMap<TypeId, RegisteredAttachment>,
-}
-
-pub trait AttachmentOutput {
- fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
-}
-
-pub trait LanguageModelAttachment {
- type Output: DeserializeOwned + Serialize + 'static;
- type View: Render + AttachmentOutput;
-
- fn name(&self) -> Arc<str>;
- fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
- fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
-}
-
-/// A collected attachment from running an attachment tool
-pub struct UserAttachment {
- pub view: AnyView,
- name: Arc<str>,
- serialized_output: Result<Box<RawValue>, String>,
- generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
-}
-
-#[derive(Serialize, Deserialize)]
-pub struct SavedUserAttachment {
- name: Arc<str>,
- serialized_output: Result<Box<RawValue>, String>,
-}
-
-/// Internal representation of an attachment tool to allow us to treat them dynamically
-struct RegisteredAttachment {
- name: Arc<str>,
- enabled: AtomicBool,
- call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
- deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
-}
-
-impl AttachmentRegistry {
- pub fn new() -> Self {
- Self {
- registered_attachments: HashMap::default(),
- }
- }
-
- pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
- let attachment = Arc::new(attachment);
-
- let call = Box::new({
- let attachment = attachment.clone();
- move |cx: &mut WindowContext| {
- let result = attachment.run(cx);
- let attachment = attachment.clone();
- cx.spawn(move |mut cx| async move {
- let result: Result<A::Output> = result.await;
- let serialized_output =
- result
- .as_ref()
- .map_err(ToString::to_string)
- .and_then(|output| {
- Ok(RawValue::from_string(
- serde_json::to_string(output).map_err(|e| e.to_string())?,
- )
- .unwrap())
- });
-
- let view = cx.update(|cx| attachment.view(result, cx))?;
-
- Ok(UserAttachment {
- name: attachment.name(),
- view: view.into(),
- generate_fn: generate::<A>,
- serialized_output,
- })
- })
- }
- });
-
- let deserialize = Box::new({
- let attachment = attachment.clone();
- move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
- let serialized_output = saved_attachment.serialized_output.clone();
- let output = match &serialized_output {
- Ok(serialized_output) => {
- Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
- }
- Err(error) => Err(anyhow!("{error}")),
- };
- let view = attachment.view(output, cx).into();
-
- Ok(UserAttachment {
- name: saved_attachment.name.clone(),
- view,
- serialized_output,
- generate_fn: generate::<A>,
- })
- }
- });
-
- self.registered_attachments.insert(
- TypeId::of::<A>(),
- RegisteredAttachment {
- name: attachment.name(),
- call,
- deserialize,
- enabled: AtomicBool::new(true),
- },
- );
- return;
-
- fn generate<T: LanguageModelAttachment>(
- view: AnyView,
- project: &mut ProjectContext,
- cx: &mut WindowContext,
- ) -> String {
- view.downcast::<T::View>()
- .unwrap()
- .update(cx, |view, cx| T::View::generate(view, project, cx))
- }
- }
-
- pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
- &self,
- is_enabled: bool,
- ) {
- if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
- attachment.enabled.store(is_enabled, SeqCst);
- }
- }
-
- pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
- if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
- attachment.enabled.load(SeqCst)
- } else {
- false
- }
- }
-
- pub fn call<A: LanguageModelAttachment + 'static>(
- &self,
- cx: &mut WindowContext,
- ) -> Task<Result<UserAttachment>> {
- let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
- return Task::ready(Err(anyhow!("no attachment tool")));
- };
-
- (attachment.call)(cx)
- }
-
- pub fn call_all_attachment_tools(
- self: Arc<Self>,
- cx: &mut WindowContext<'_>,
- ) -> Task<Result<Vec<UserAttachment>>> {
- let this = self.clone();
- cx.spawn(|mut cx| async move {
- let attachment_tasks = cx.update(|cx| {
- let mut tasks = Vec::new();
- for attachment in this
- .registered_attachments
- .values()
- .filter(|attachment| attachment.enabled.load(SeqCst))
- {
- tasks.push((attachment.call)(cx))
- }
-
- tasks
- })?;
-
- let attachments = join_all(attachment_tasks.into_iter()).await;
-
- Ok(attachments
- .into_iter()
- .filter_map(|attachment| attachment.log_err())
- .collect())
- })
- }
-
- pub fn serialize_user_attachment(
- &self,
- user_attachment: &UserAttachment,
- ) -> SavedUserAttachment {
- SavedUserAttachment {
- name: user_attachment.name.clone(),
- serialized_output: user_attachment.serialized_output.clone(),
- }
- }
-
- pub fn deserialize_user_attachment(
- &self,
- saved_user_attachment: SavedUserAttachment,
- cx: &mut WindowContext,
- ) -> Result<UserAttachment> {
- if let Some(registered_attachment) = self
- .registered_attachments
- .values()
- .find(|attachment| attachment.name == saved_user_attachment.name)
- {
- (registered_attachment.deserialize)(&saved_user_attachment, cx)
- } else {
- Err(anyhow!(
- "no attachment tool for name {}",
- saved_user_attachment.name
- ))
- }
- }
-}
-
-impl UserAttachment {
- pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
- let result = (self.generate_fn)(self.view.clone(), output, cx);
- if result.is_empty() {
- None
- } else {
- Some(result)
- }
- }
-}
@@ -1,296 +0,0 @@
-use anyhow::{anyhow, Result};
-use gpui::{AppContext, Model, Task, WeakModel};
-use project::{Fs, Project, ProjectPath, Worktree};
-use std::{cmp::Ordering, fmt::Write as _, ops::Range, sync::Arc};
-use sum_tree::TreeMap;
-
-pub struct ProjectContext {
- files: TreeMap<ProjectPath, PathState>,
- project: WeakModel<Project>,
- fs: Arc<dyn Fs>,
-}
-
-#[derive(Debug, Clone)]
-enum PathState {
- PathOnly,
- EntireFile,
- Excerpts { ranges: Vec<Range<usize>> },
-}
-
-impl ProjectContext {
- pub fn new(project: WeakModel<Project>, fs: Arc<dyn Fs>) -> Self {
- Self {
- files: TreeMap::default(),
- fs,
- project,
- }
- }
-
- pub fn add_path(&mut self, project_path: ProjectPath) {
- if self.files.get(&project_path).is_none() {
- self.files.insert(project_path, PathState::PathOnly);
- }
- }
-
- pub fn add_excerpts(&mut self, project_path: ProjectPath, new_ranges: &[Range<usize>]) {
- let previous_state = self
- .files
- .get(&project_path)
- .unwrap_or(&PathState::PathOnly);
-
- let mut ranges = match previous_state {
- PathState::EntireFile => return,
- PathState::PathOnly => Vec::new(),
- PathState::Excerpts { ranges } => ranges.to_vec(),
- };
-
- for new_range in new_ranges {
- let ix = ranges.binary_search_by(|probe| {
- if probe.end < new_range.start {
- Ordering::Less
- } else if probe.start > new_range.end {
- Ordering::Greater
- } else {
- Ordering::Equal
- }
- });
-
- match ix {
- Ok(mut ix) => {
- let existing = &mut ranges[ix];
- existing.start = existing.start.min(new_range.start);
- existing.end = existing.end.max(new_range.end);
- while ix + 1 < ranges.len() && ranges[ix + 1].start <= ranges[ix].end {
- ranges[ix].end = ranges[ix].end.max(ranges[ix + 1].end);
- ranges.remove(ix + 1);
- }
- while ix > 0 && ranges[ix - 1].end >= ranges[ix].start {
- ranges[ix].start = ranges[ix].start.min(ranges[ix - 1].start);
- ranges.remove(ix - 1);
- ix -= 1;
- }
- }
- Err(ix) => {
- ranges.insert(ix, new_range.clone());
- }
- }
- }
-
- self.files
- .insert(project_path, PathState::Excerpts { ranges });
- }
-
- pub fn add_file(&mut self, project_path: ProjectPath) {
- self.files.insert(project_path, PathState::EntireFile);
- }
-
- pub fn generate_system_message(&self, cx: &mut AppContext) -> Task<Result<String>> {
- let project = self
- .project
- .upgrade()
- .ok_or_else(|| anyhow!("project dropped"));
- let files = self.files.clone();
- let fs = self.fs.clone();
- cx.spawn(|cx| async move {
- let project = project?;
- let mut result = "project structure:\n".to_string();
-
- let mut last_worktree: Option<Model<Worktree>> = None;
- for (project_path, path_state) in files.iter() {
- if let Some(worktree) = &last_worktree {
- if worktree.read_with(&cx, |tree, _| tree.id())? != project_path.worktree_id {
- last_worktree = None;
- }
- }
-
- let worktree;
- if let Some(last_worktree) = &last_worktree {
- worktree = last_worktree.clone();
- } else if let Some(tree) = project.read_with(&cx, |project, cx| {
- project.worktree_for_id(project_path.worktree_id, cx)
- })? {
- worktree = tree;
- last_worktree = Some(worktree.clone());
- let worktree_name =
- worktree.read_with(&cx, |tree, _cx| tree.root_name().to_string())?;
- writeln!(&mut result, "# {}", worktree_name).unwrap();
- } else {
- continue;
- }
-
- let worktree_abs_path = worktree.read_with(&cx, |tree, _cx| tree.abs_path())?;
- let path = &project_path.path;
- writeln!(&mut result, "## {}", path.display()).unwrap();
-
- match path_state {
- PathState::PathOnly => {}
- PathState::EntireFile => {
- let text = fs.load(&worktree_abs_path.join(&path)).await?;
- writeln!(&mut result, "~~~\n{text}\n~~~").unwrap();
- }
- PathState::Excerpts { ranges } => {
- let text = fs.load(&worktree_abs_path.join(&path)).await?;
-
- writeln!(&mut result, "~~~").unwrap();
-
- // Assumption: ranges are in order, not overlapping
- let mut prev_range_end = 0;
- for range in ranges {
- if range.start > prev_range_end {
- writeln!(&mut result, "...").unwrap();
- prev_range_end = range.end;
- }
-
- let mut start = range.start;
- let mut end = range.end.min(text.len());
- while !text.is_char_boundary(start) {
- start += 1;
- }
- while !text.is_char_boundary(end) {
- end -= 1;
- }
- result.push_str(&text[start..end]);
- if !result.ends_with('\n') {
- result.push('\n');
- }
- }
-
- if prev_range_end < text.len() {
- writeln!(&mut result, "...").unwrap();
- }
-
- writeln!(&mut result, "~~~").unwrap();
- }
- }
- }
- Ok(result)
- })
- }
-}
-
-#[cfg(test)]
-mod tests {
- use std::path::Path;
-
- use super::*;
- use gpui::TestAppContext;
- use project::FakeFs;
- use serde_json::json;
- use settings::SettingsStore;
-
- use unindent::Unindent as _;
-
- #[gpui::test]
- async fn test_system_message_generation(cx: &mut TestAppContext) {
- init_test(cx);
-
- let file_3_contents = r#"
- fn test1() {}
- fn test2() {}
- fn test3() {}
- "#
- .unindent();
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/code",
- json!({
- "root1": {
- "lib": {
- "file1.rs": "mod example;",
- "file2.rs": "",
- },
- "test": {
- "file3.rs": file_3_contents,
- }
- },
- "root2": {
- "src": {
- "main.rs": ""
- }
- }
- }),
- )
- .await;
-
- let project = Project::test(
- fs.clone(),
- ["/code/root1".as_ref(), "/code/root2".as_ref()],
- cx,
- )
- .await;
-
- let worktree_ids = project.read_with(cx, |project, cx| {
- project
- .worktrees(cx)
- .map(|worktree| worktree.read(cx).id())
- .collect::<Vec<_>>()
- });
-
- let mut ax = ProjectContext::new(project.downgrade(), fs);
-
- ax.add_file(ProjectPath {
- worktree_id: worktree_ids[0],
- path: Path::new("lib/file1.rs").into(),
- });
-
- let message = cx
- .update(|cx| ax.generate_system_message(cx))
- .await
- .unwrap();
- assert_eq!(
- r#"
- project structure:
- # root1
- ## lib/file1.rs
- ~~~
- mod example;
- ~~~
- "#
- .unindent(),
- message
- );
-
- ax.add_excerpts(
- ProjectPath {
- worktree_id: worktree_ids[0],
- path: Path::new("test/file3.rs").into(),
- },
- &[
- file_3_contents.find("fn test2").unwrap()
- ..file_3_contents.find("fn test3").unwrap(),
- ],
- );
-
- let message = cx
- .update(|cx| ax.generate_system_message(cx))
- .await
- .unwrap();
- assert_eq!(
- r#"
- project structure:
- # root1
- ## lib/file1.rs
- ~~~
- mod example;
- ~~~
- ## test/file3.rs
- ~~~
- ...
- fn test2() {}
- ...
- ~~~
- "#
- .unindent(),
- message
- );
- }
-
- fn init_test(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- Project::init_settings(cx);
- });
- }
-}
@@ -1,526 +0,0 @@
-use crate::ProjectContext;
-use anyhow::{anyhow, Result};
-use gpui::{AnyElement, AnyView, IntoElement, Render, Task, View, WindowContext};
-use repair_json::repair;
-use schemars::{schema::RootSchema, schema_for, JsonSchema};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use serde_json::value::RawValue;
-use std::{
- any::TypeId,
- collections::HashMap,
- fmt::Display,
- mem,
- sync::atomic::{AtomicBool, Ordering::SeqCst},
-};
-use ui::ViewContext;
-
-pub struct ToolRegistry {
- registered_tools: HashMap<String, RegisteredTool>,
-}
-
-#[derive(Default)]
-pub struct ToolFunctionCall {
- pub id: String,
- pub name: String,
- pub arguments: String,
- state: ToolFunctionCallState,
-}
-
-#[derive(Default)]
-enum ToolFunctionCallState {
- #[default]
- Initializing,
- NoSuchTool,
- KnownTool(Box<dyn InternalToolView>),
- ExecutedTool(Box<dyn InternalToolView>),
-}
-
-trait InternalToolView {
- fn view(&self) -> AnyView;
- fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
- fn try_set_input(&self, input: &str, cx: &mut WindowContext);
- fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
- fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
- fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
-}
-
-#[derive(Default, Serialize, Deserialize)]
-pub struct SavedToolFunctionCall {
- id: String,
- name: String,
- arguments: String,
- state: SavedToolFunctionCallState,
-}
-
-#[derive(Default, Serialize, Deserialize)]
-enum SavedToolFunctionCallState {
- #[default]
- Initializing,
- NoSuchTool,
- KnownTool,
- ExecutedTool(Box<RawValue>),
-}
-
-#[derive(Clone, Debug, PartialEq)]
-pub struct ToolFunctionDefinition {
- pub name: String,
- pub description: String,
- pub parameters: RootSchema,
-}
-
-pub trait LanguageModelTool {
- type View: ToolView;
-
- /// Returns the name of the tool.
- ///
- /// This name is exposed to the language model to allow the model to pick
- /// which tools to use. As this name is used to identify the tool within a
- /// tool registry, it should be unique.
- fn name(&self) -> String;
-
- /// Returns the description of the tool.
- ///
- /// This can be used to _prompt_ the model as to what the tool does.
- fn description(&self) -> String;
-
- /// Returns the OpenAI Function definition for the tool, for direct use with OpenAI's API.
- fn definition(&self) -> ToolFunctionDefinition {
- let root_schema = schema_for!(<Self::View as ToolView>::Input);
-
- ToolFunctionDefinition {
- name: self.name(),
- description: self.description(),
- parameters: root_schema,
- }
- }
-
- /// A view of the output of running the tool, for displaying to the user.
- fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
-}
-
-pub trait ToolView: Render {
- /// The input type that will be passed in to `execute` when the tool is called
- /// by the language model.
- type Input: DeserializeOwned + JsonSchema;
-
- /// The output returned by executing the tool.
- type SerializedState: DeserializeOwned + Serialize;
-
- fn generate(&self, project: &mut ProjectContext, cx: &mut ViewContext<Self>) -> String;
- fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>);
- fn execute(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>>;
-
- fn serialize(&self, cx: &mut ViewContext<Self>) -> Self::SerializedState;
- fn deserialize(
- &mut self,
- output: Self::SerializedState,
- cx: &mut ViewContext<Self>,
- ) -> Result<()>;
-}
-
-struct RegisteredTool {
- enabled: AtomicBool,
- type_id: TypeId,
- build_view: Box<dyn Fn(&mut WindowContext) -> Box<dyn InternalToolView>>,
- definition: ToolFunctionDefinition,
-}
-
-impl ToolRegistry {
- pub fn new() -> Self {
- Self {
- registered_tools: HashMap::new(),
- }
- }
-
- pub fn set_tool_enabled<T: 'static + LanguageModelTool>(&self, is_enabled: bool) {
- for tool in self.registered_tools.values() {
- if tool.type_id == TypeId::of::<T>() {
- tool.enabled.store(is_enabled, SeqCst);
- return;
- }
- }
- }
-
- pub fn is_tool_enabled<T: 'static + LanguageModelTool>(&self) -> bool {
- for tool in self.registered_tools.values() {
- if tool.type_id == TypeId::of::<T>() {
- return tool.enabled.load(SeqCst);
- }
- }
- false
- }
-
- pub fn definitions(&self) -> Vec<ToolFunctionDefinition> {
- self.registered_tools
- .values()
- .filter(|tool| tool.enabled.load(SeqCst))
- .map(|tool| tool.definition.clone())
- .collect()
- }
-
- pub fn update_tool_call(
- &self,
- call: &mut ToolFunctionCall,
- name: Option<&str>,
- arguments: Option<&str>,
- cx: &mut WindowContext,
- ) {
- if let Some(name) = name {
- call.name.push_str(name);
- }
- if let Some(arguments) = arguments {
- if call.arguments.is_empty() {
- if let Some(tool) = self.registered_tools.get(&call.name) {
- let view = (tool.build_view)(cx);
- call.state = ToolFunctionCallState::KnownTool(view);
- } else {
- call.state = ToolFunctionCallState::NoSuchTool;
- }
- }
- call.arguments.push_str(arguments);
-
- if let ToolFunctionCallState::KnownTool(view) = &call.state {
- if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
- view.try_set_input(&repaired_arguments, cx)
- }
- }
- }
- }
-
- pub fn execute_tool_call(
- &self,
- tool_call: &mut ToolFunctionCall,
- cx: &mut WindowContext,
- ) -> Option<Task<Result<()>>> {
- if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
- let task = view.execute(cx);
- tool_call.state = ToolFunctionCallState::ExecutedTool(view);
- Some(task)
- } else {
- None
- }
- }
-
- pub fn render_tool_call(
- &self,
- tool_call: &ToolFunctionCall,
- _cx: &mut WindowContext,
- ) -> Option<AnyElement> {
- match &tool_call.state {
- ToolFunctionCallState::NoSuchTool => {
- Some(ui::Label::new("No such tool").into_any_element())
- }
- ToolFunctionCallState::Initializing => None,
- ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
- Some(view.view().into_any_element())
- }
- }
- }
-
- pub fn content_for_tool_call(
- &self,
- tool_call: &ToolFunctionCall,
- project_context: &mut ProjectContext,
- cx: &mut WindowContext,
- ) -> String {
- match &tool_call.state {
- ToolFunctionCallState::Initializing => String::new(),
- ToolFunctionCallState::NoSuchTool => {
- format!("No such tool: {}", tool_call.name)
- }
- ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
- view.generate(project_context, cx)
- }
- }
- }
-
- pub fn serialize_tool_call(
- &self,
- call: &ToolFunctionCall,
- cx: &mut WindowContext,
- ) -> Result<SavedToolFunctionCall> {
- Ok(SavedToolFunctionCall {
- id: call.id.clone(),
- name: call.name.clone(),
- arguments: call.arguments.clone(),
- state: match &call.state {
- ToolFunctionCallState::Initializing => SavedToolFunctionCallState::Initializing,
- ToolFunctionCallState::NoSuchTool => SavedToolFunctionCallState::NoSuchTool,
- ToolFunctionCallState::KnownTool(_) => SavedToolFunctionCallState::KnownTool,
- ToolFunctionCallState::ExecutedTool(view) => {
- SavedToolFunctionCallState::ExecutedTool(view.serialize_output(cx)?)
- }
- },
- })
- }
-
- pub fn deserialize_tool_call(
- &self,
- call: &SavedToolFunctionCall,
- cx: &mut WindowContext,
- ) -> Result<ToolFunctionCall> {
- let Some(tool) = self.registered_tools.get(&call.name) else {
- return Err(anyhow!("no such tool {}", call.name));
- };
-
- Ok(ToolFunctionCall {
- id: call.id.clone(),
- name: call.name.clone(),
- arguments: call.arguments.clone(),
- state: match &call.state {
- SavedToolFunctionCallState::Initializing => ToolFunctionCallState::Initializing,
- SavedToolFunctionCallState::NoSuchTool => ToolFunctionCallState::NoSuchTool,
- SavedToolFunctionCallState::KnownTool => {
- log::error!("Deserialized tool that had not executed");
- let view = (tool.build_view)(cx);
- view.try_set_input(&call.arguments, cx);
- ToolFunctionCallState::KnownTool(view)
- }
- SavedToolFunctionCallState::ExecutedTool(output) => {
- let view = (tool.build_view)(cx);
- view.try_set_input(&call.arguments, cx);
- view.deserialize_output(output, cx)?;
- ToolFunctionCallState::ExecutedTool(view)
- }
- },
- })
- }
-
- pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
- let name = tool.name();
- let registered_tool = RegisteredTool {
- type_id: TypeId::of::<T>(),
- definition: tool.definition(),
- enabled: AtomicBool::new(true),
- build_view: Box::new(move |cx: &mut WindowContext| Box::new(tool.view(cx))),
- };
-
- let previous = self.registered_tools.insert(name.clone(), registered_tool);
- if previous.is_some() {
- return Err(anyhow!("already registered a tool with name {}", name));
- }
-
- return Ok(());
- }
-}
-
-impl<T: ToolView> InternalToolView for View<T> {
- fn view(&self) -> AnyView {
- self.clone().into()
- }
-
- fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
- self.update(cx, |view, cx| view.generate(project, cx))
- }
-
- fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
- if let Ok(input) = serde_json::from_str::<T::Input>(input) {
- self.update(cx, |view, cx| {
- view.set_input(input, cx);
- cx.notify();
- });
- }
- }
-
- fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>> {
- self.update(cx, |view, cx| view.execute(cx))
- }
-
- fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>> {
- let output = self.update(cx, |view, cx| view.serialize(cx));
- Ok(RawValue::from_string(serde_json::to_string(&output)?)?)
- }
-
- fn deserialize_output(&self, output: &RawValue, cx: &mut WindowContext) -> Result<()> {
- let state = serde_json::from_str::<T::SerializedState>(output.get())?;
- self.update(cx, |view, cx| view.deserialize(state, cx))?;
- Ok(())
- }
-}
-
-impl Display for ToolFunctionDefinition {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let schema = serde_json::to_string(&self.parameters).ok();
- let schema = schema.unwrap_or("None".to_string());
- write!(f, "Name: {}:\n", self.name)?;
- write!(f, "Description: {}\n", self.description)?;
- write!(f, "Parameters: {}", schema)
- }
-}
-
-#[cfg(test)]
-mod test {
- use super::*;
- use gpui::{div, prelude::*, Render, TestAppContext};
- use gpui::{EmptyView, View};
- use schemars::JsonSchema;
- use serde::{Deserialize, Serialize};
- use serde_json::json;
-
- #[derive(Deserialize, Serialize, JsonSchema)]
- struct WeatherQuery {
- location: String,
- unit: String,
- }
-
- #[derive(Clone, Serialize, Deserialize, PartialEq, Debug)]
- struct WeatherResult {
- location: String,
- temperature: f64,
- unit: String,
- }
-
- struct WeatherView {
- input: Option<WeatherQuery>,
- result: Option<WeatherResult>,
-
- // Fake API call
- current_weather: WeatherResult,
- }
-
- #[derive(Clone, Serialize)]
- struct WeatherTool {
- current_weather: WeatherResult,
- }
-
- impl WeatherView {
- fn new(current_weather: WeatherResult) -> Self {
- Self {
- input: None,
- result: None,
- current_weather,
- }
- }
- }
-
- impl Render for WeatherView {
- fn render(&mut self, _cx: &mut gpui::ViewContext<Self>) -> impl IntoElement {
- match self.result {
- Some(ref result) => div()
- .child(format!("temperature: {}", result.temperature))
- .into_any_element(),
- None => div().child("Calculating weather...").into_any_element(),
- }
- }
- }
-
- impl ToolView for WeatherView {
- type Input = WeatherQuery;
-
- type SerializedState = WeatherResult;
-
- fn generate(&self, _output: &mut ProjectContext, _cx: &mut ViewContext<Self>) -> String {
- serde_json::to_string(&self.result).unwrap()
- }
-
- fn set_input(&mut self, input: Self::Input, cx: &mut ViewContext<Self>) {
- self.input = Some(input);
- cx.notify();
- }
-
- fn execute(&mut self, _cx: &mut ViewContext<Self>) -> Task<Result<()>> {
- let input = self.input.as_ref().unwrap();
-
- let _location = input.location.clone();
- let _unit = input.unit.clone();
-
- let weather = self.current_weather.clone();
-
- self.result = Some(weather);
-
- Task::ready(Ok(()))
- }
-
- fn serialize(&self, _cx: &mut ViewContext<Self>) -> Self::SerializedState {
- self.current_weather.clone()
- }
-
- fn deserialize(
- &mut self,
- output: Self::SerializedState,
- _cx: &mut ViewContext<Self>,
- ) -> Result<()> {
- self.current_weather = output;
- Ok(())
- }
- }
-
- impl LanguageModelTool for WeatherTool {
- type View = WeatherView;
-
- fn name(&self) -> String {
- "get_current_weather".to_string()
- }
-
- fn description(&self) -> String {
- "Fetches the current weather for a given location.".to_string()
- }
-
- fn view(&self, cx: &mut WindowContext) -> View<Self::View> {
- cx.new_view(|_cx| WeatherView::new(self.current_weather.clone()))
- }
- }
-
- #[gpui::test]
- async fn test_openai_weather_example(cx: &mut TestAppContext) {
- let (_, cx) = cx.add_window_view(|_cx| EmptyView);
-
- let mut registry = ToolRegistry::new();
- registry
- .register(WeatherTool {
- current_weather: WeatherResult {
- location: "San Francisco".to_string(),
- temperature: 21.0,
- unit: "Celsius".to_string(),
- },
- })
- .unwrap();
-
- let definitions = registry.definitions();
- assert_eq!(
- definitions,
- [ToolFunctionDefinition {
- name: "get_current_weather".to_string(),
- description: "Fetches the current weather for a given location.".to_string(),
- parameters: serde_json::from_value(json!({
- "$schema": "http://json-schema.org/draft-07/schema#",
- "title": "WeatherQuery",
- "type": "object",
- "properties": {
- "location": {
- "type": "string"
- },
- "unit": {
- "type": "string"
- }
- },
- "required": ["location", "unit"]
- }))
- .unwrap(),
- }]
- );
-
- let mut call = ToolFunctionCall {
- id: "the-id".to_string(),
- name: "get_cur".to_string(),
- ..Default::default()
- };
-
- let task = cx.update(|cx| {
- registry.update_tool_call(
- &mut call,
- Some("rent_weather"),
- Some(r#"{"location": "San Francisco","#),
- cx,
- );
- registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
- registry.execute_tool_call(&mut call, cx).unwrap()
- });
- task.await.unwrap();
-
- match &call.state {
- ToolFunctionCallState::ExecutedTool(_view) => {}
- _ => panic!(),
- }
- }
-}
@@ -1,138 +0,0 @@
-use anyhow::{anyhow, Context as _, Result};
-use rpc::proto;
-use util::ResultExt as _;
-
-pub fn language_model_request_to_open_ai(
- request: proto::CompleteWithLanguageModel,
-) -> Result<open_ai::Request> {
- Ok(open_ai::Request {
- model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
- messages: request
- .messages
- .into_iter()
- .map(|message: proto::LanguageModelRequestMessage| {
- let role = proto::LanguageModelRole::from_i32(message.role)
- .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
-
- let openai_message = match role {
- proto::LanguageModelRole::LanguageModelUser => open_ai::RequestMessage::User {
- content: message.content,
- },
- proto::LanguageModelRole::LanguageModelAssistant => {
- open_ai::RequestMessage::Assistant {
- content: Some(message.content),
- tool_calls: message
- .tool_calls
- .into_iter()
- .filter_map(|call| {
- Some(open_ai::ToolCall {
- id: call.id,
- content: match call.variant? {
- proto::tool_call::Variant::Function(f) => {
- open_ai::ToolCallContent::Function {
- function: open_ai::FunctionContent {
- name: f.name,
- arguments: f.arguments,
- },
- }
- }
- },
- })
- })
- .collect(),
- }
- }
- proto::LanguageModelRole::LanguageModelSystem => {
- open_ai::RequestMessage::System {
- content: message.content,
- }
- }
- proto::LanguageModelRole::LanguageModelTool => open_ai::RequestMessage::Tool {
- tool_call_id: message
- .tool_call_id
- .ok_or_else(|| anyhow!("tool message is missing tool call id"))?,
- content: message.content,
- },
- };
-
- Ok(openai_message)
- })
- .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
- stream: true,
- stop: request.stop,
- temperature: request.temperature,
- tools: request
- .tools
- .into_iter()
- .filter_map(|tool| {
- Some(match tool.variant? {
- proto::chat_completion_tool::Variant::Function(f) => {
- open_ai::ToolDefinition::Function {
- function: open_ai::FunctionDefinition {
- name: f.name,
- description: f.description,
- parameters: if let Some(params) = &f.parameters {
- Some(
- serde_json::from_str(params)
- .context("failed to deserialize tool parameters")
- .log_err()?,
- )
- } else {
- None
- },
- },
- }
- }
- })
- })
- .collect(),
- tool_choice: request.tool_choice,
- })
-}
-
-pub fn language_model_request_to_google_ai(
- request: proto::CompleteWithLanguageModel,
-) -> Result<google_ai::GenerateContentRequest> {
- Ok(google_ai::GenerateContentRequest {
- contents: request
- .messages
- .into_iter()
- .map(language_model_request_message_to_google_ai)
- .collect::<Result<Vec<_>>>()?,
- generation_config: None,
- safety_settings: None,
- })
-}
-
-pub fn language_model_request_message_to_google_ai(
- message: proto::LanguageModelRequestMessage,
-) -> Result<google_ai::Content> {
- let role = proto::LanguageModelRole::from_i32(message.role)
- .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
-
- Ok(google_ai::Content {
- parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
- text: message.content,
- })],
- role: match role {
- proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
- proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
- proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
- proto::LanguageModelRole::LanguageModelTool => {
- Err(anyhow!("we don't handle tool calls with google ai yet"))?
- }
- },
- })
-}
-
-pub fn count_tokens_request_to_google_ai(
- request: proto::CountTokensWithLanguageModel,
-) -> Result<google_ai::CountTokensRequest> {
- Ok(google_ai::CountTokensRequest {
- contents: request
- .messages
- .into_iter()
- .map(language_model_request_message_to_google_ai)
- .collect::<Result<Vec<_>>>()?,
- })
-}
@@ -1,4 +1,3 @@
-pub mod ai;
pub mod api;
pub mod auth;
pub mod db;
@@ -46,8 +46,8 @@ use http_client::IsahcHttpClient;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
- self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
- LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
+ self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
+ RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
@@ -618,17 +618,6 @@ impl Server {
)
}
})
- .add_request_handler({
- let app_state = app_state.clone();
- user_handler(move |request, response, session| {
- count_tokens_with_language_model(
- request,
- response,
- session,
- app_state.config.google_ai_api_key.clone(),
- )
- })
- })
.add_request_handler({
user_handler(move |request, response, session| {
get_cached_embeddings(request, response, session)
@@ -4514,8 +4503,8 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
- mut request: proto::CompleteWithLanguageModel,
- response: StreamingResponse<proto::CompleteWithLanguageModel>,
+ query: proto::QueryLanguageModel,
+ response: StreamingResponse<proto::QueryLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
@@ -4525,287 +4514,95 @@ async fn complete_with_language_model(
return Err(anyhow!("user not found"))?;
};
authorize_access_to_language_models(&session).await?;
- session
- .rate_limiter
- .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
- .await?;
- let mut provider_and_model = request.model.split('/');
- let (provider, model) = match (
- provider_and_model.next().unwrap(),
- provider_and_model.next(),
- ) {
- (provider, Some(model)) => (provider, model),
- (model, None) => {
- if model.starts_with("gpt") {
- ("openai", model)
- } else if model.starts_with("gemini") {
- ("google", model)
- } else if model.starts_with("claude") {
- ("anthropic", model)
- } else {
- ("unknown", model)
- }
+ match proto::LanguageModelRequestKind::from_i32(query.kind) {
+ Some(proto::LanguageModelRequestKind::Complete) => {
+ session
+ .rate_limiter
+ .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+ .await?;
}
- };
- let provider = provider.to_string();
- request.model = model.to_string();
-
- match provider.as_str() {
- "openai" => {
- let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
- complete_with_open_ai(request, response, session, api_key).await?;
+ Some(proto::LanguageModelRequestKind::CountTokens) => {
+ session
+ .rate_limiter
+ .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
+ .await?;
}
- "anthropic" => {
+ None => Err(anyhow!("unknown request kind"))?,
+ }
+
+ match proto::LanguageModelProvider::from_i32(query.provider) {
+ Some(proto::LanguageModelProvider::Anthropic) => {
let api_key =
anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
- complete_with_anthropic(request, response, session, api_key).await?;
+ let mut chunks = anthropic::stream_completion(
+ session.http_client.as_ref(),
+ anthropic::ANTHROPIC_API_URL,
+ &api_key,
+ serde_json::from_str(&query.request)?,
+ None,
+ )
+ .await?;
+ while let Some(chunk) = chunks.next().await {
+ let chunk = chunk?;
+ response.send(proto::QueryLanguageModelResponse {
+ response: serde_json::to_string(&chunk)?,
+ })?;
+ }
}
- "google" => {
+ Some(proto::LanguageModelProvider::OpenAi) => {
+ let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
+ let mut chunks = open_ai::stream_completion(
+ session.http_client.as_ref(),
+ open_ai::OPEN_AI_API_URL,
+ &api_key,
+ serde_json::from_str(&query.request)?,
+ None,
+ )
+ .await?;
+ while let Some(chunk) = chunks.next().await {
+ let chunk = chunk?;
+ response.send(proto::QueryLanguageModelResponse {
+ response: serde_json::to_string(&chunk)?,
+ })?;
+ }
+ }
+ Some(proto::LanguageModelProvider::Google) => {
let api_key =
google_ai_api_key.context("no Google AI API key configured on the server")?;
- complete_with_google_ai(request, response, session, api_key).await?;
- }
- provider => return Err(anyhow!("unknown provider {:?}", provider))?,
- }
-
- Ok(())
-}
-
-async fn complete_with_open_ai(
- request: proto::CompleteWithLanguageModel,
- response: StreamingResponse<proto::CompleteWithLanguageModel>,
- session: UserSession,
- api_key: Arc<str>,
-) -> Result<()> {
- let mut completion_stream = open_ai::stream_completion(
- session.http_client.as_ref(),
- OPEN_AI_API_URL,
- &api_key,
- crate::ai::language_model_request_to_open_ai(request)?,
- None,
- )
- .await
- .context("open_ai::stream_completion request failed within collab")?;
-
- while let Some(event) = completion_stream.next().await {
- let event = event?;
- response.send(proto::LanguageModelResponse {
- choices: event
- .choices
- .into_iter()
- .map(|choice| proto::LanguageModelChoiceDelta {
- index: choice.index,
- delta: Some(proto::LanguageModelResponseMessage {
- role: choice.delta.role.map(|role| match role {
- open_ai::Role::User => LanguageModelRole::LanguageModelUser,
- open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
- open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
- open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
- } as i32),
- content: choice.delta.content,
- tool_calls: choice
- .delta
- .tool_calls
- .unwrap_or_default()
- .into_iter()
- .map(|delta| proto::ToolCallDelta {
- index: delta.index as u32,
- id: delta.id,
- variant: match delta.function {
- Some(function) => {
- let name = function.name;
- let arguments = function.arguments;
-
- Some(proto::tool_call_delta::Variant::Function(
- proto::tool_call_delta::FunctionCallDelta {
- name,
- arguments,
- },
- ))
- }
- None => None,
- },
- })
- .collect(),
- }),
- finish_reason: choice.finish_reason,
- })
- .collect(),
- })?;
- }
-
- Ok(())
-}
-
-async fn complete_with_google_ai(
- request: proto::CompleteWithLanguageModel,
- response: StreamingResponse<proto::CompleteWithLanguageModel>,
- session: UserSession,
- api_key: Arc<str>,
-) -> Result<()> {
- let mut stream = google_ai::stream_generate_content(
- session.http_client.clone(),
- google_ai::API_URL,
- api_key.as_ref(),
- &request.model.clone(),
- crate::ai::language_model_request_to_google_ai(request)?,
- )
- .await
- .context("google_ai::stream_generate_content request failed")?;
-
- while let Some(event) = stream.next().await {
- let event = event?;
- response.send(proto::LanguageModelResponse {
- choices: event
- .candidates
- .unwrap_or_default()
- .into_iter()
- .map(|candidate| proto::LanguageModelChoiceDelta {
- index: candidate.index as u32,
- delta: Some(proto::LanguageModelResponseMessage {
- role: Some(match candidate.content.role {
- google_ai::Role::User => LanguageModelRole::LanguageModelUser,
- google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
- } as i32),
- content: Some(
- candidate
- .content
- .parts
- .into_iter()
- .filter_map(|part| match part {
- google_ai::Part::TextPart(part) => Some(part.text),
- google_ai::Part::InlineDataPart(_) => None,
- })
- .collect(),
- ),
- // Tool calls are not supported for Google
- tool_calls: Vec::new(),
- }),
- finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
- })
- .collect(),
- })?;
- }
-
- Ok(())
-}
-
-async fn complete_with_anthropic(
- request: proto::CompleteWithLanguageModel,
- response: StreamingResponse<proto::CompleteWithLanguageModel>,
- session: UserSession,
- api_key: Arc<str>,
-) -> Result<()> {
- let mut system_message = String::new();
- let messages = request
- .messages
- .into_iter()
- .filter_map(|message| {
- match message.role() {
- LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
- role: anthropic::Role::User,
- content: message.content,
- }),
- LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
- role: anthropic::Role::Assistant,
- content: message.content,
- }),
- // Anthropic's API breaks system instructions out as a separate field rather
- // than having a system message role.
- LanguageModelRole::LanguageModelSystem => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.content);
-
- None
- }
- // We don't yet support tool calls for Anthropic
- LanguageModelRole::LanguageModelTool => None,
- }
- })
- .collect();
-
- let mut stream = anthropic::stream_completion(
- session.http_client.as_ref(),
- anthropic::ANTHROPIC_API_URL,
- &api_key,
- anthropic::Request {
- model: request.model,
- messages,
- stream: true,
- system: system_message,
- max_tokens: 4092,
- },
- None,
- )
- .await?;
-
- let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
-
- while let Some(event) = stream.next().await {
- let event = event?;
- match event {
- anthropic::ResponseEvent::MessageStart { message } => {
- if let Some(role) = message.role {
- if role == "assistant" {
- current_role = proto::LanguageModelRole::LanguageModelAssistant;
- } else if role == "user" {
- current_role = proto::LanguageModelRole::LanguageModelUser;
- }
- }
- }
- anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
- match content_block {
- anthropic::ContentBlock::Text { text } => {
- if !text.is_empty() {
- response.send(proto::LanguageModelResponse {
- choices: vec![proto::LanguageModelChoiceDelta {
- index: 0,
- delta: Some(proto::LanguageModelResponseMessage {
- role: Some(current_role as i32),
- content: Some(text),
- tool_calls: Vec::new(),
- }),
- finish_reason: None,
- }],
- })?;
- }
+ match proto::LanguageModelRequestKind::from_i32(query.kind) {
+ Some(proto::LanguageModelRequestKind::Complete) => {
+ let mut chunks = google_ai::stream_generate_content(
+ session.http_client.as_ref(),
+ google_ai::API_URL,
+ &api_key,
+ serde_json::from_str(&query.request)?,
+ )
+ .await?;
+ while let Some(chunk) = chunks.next().await {
+ let chunk = chunk?;
+ response.send(proto::QueryLanguageModelResponse {
+ response: serde_json::to_string(&chunk)?,
+ })?;
}
}
- }
- anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
- anthropic::TextDelta::TextDelta { text } => {
- response.send(proto::LanguageModelResponse {
- choices: vec![proto::LanguageModelChoiceDelta {
- index: 0,
- delta: Some(proto::LanguageModelResponseMessage {
- role: Some(current_role as i32),
- content: Some(text),
- tool_calls: Vec::new(),
- }),
- finish_reason: None,
- }],
- })?;
- }
- },
- anthropic::ResponseEvent::MessageDelta { delta, .. } => {
- if let Some(stop_reason) = delta.stop_reason {
- response.send(proto::LanguageModelResponse {
- choices: vec![proto::LanguageModelChoiceDelta {
- index: 0,
- delta: None,
- finish_reason: Some(stop_reason),
- }],
+ Some(proto::LanguageModelRequestKind::CountTokens) => {
+ let tokens_response = google_ai::count_tokens(
+ session.http_client.as_ref(),
+ google_ai::API_URL,
+ &api_key,
+ serde_json::from_str(&query.request)?,
+ )
+ .await?;
+ response.send(proto::QueryLanguageModelResponse {
+ response: serde_json::to_string(&tokens_response)?,
})?;
}
+ None => Err(anyhow!("unknown request kind"))?,
}
- anthropic::ResponseEvent::ContentBlockStop { .. } => {}
- anthropic::ResponseEvent::MessageStop {} => {}
- anthropic::ResponseEvent::Ping {} => {}
}
+ None => return Err(anyhow!("unknown provider"))?,
}
Ok(())
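
Aside, not part of the diff: the handler above now branches on a single `proto::QueryLanguageModel` message instead of provider-specific conversion code. A hedged sketch of how a caller might build one for an Anthropic completion; only the `provider`, `kind`, and serialized `request` fields visible in this hunk are taken from the source, and the helper itself plus any additional proto fields are hypothetical.

```rust
use anyhow::Result;

// Sketch only: field names come from the reads above (query.provider,
// query.kind, query.request); `anthropic_request` is a hypothetical
// anthropic::Request built by the caller.
fn build_anthropic_completion_query(
    anthropic_request: &anthropic::Request,
) -> Result<proto::QueryLanguageModel> {
    Ok(proto::QueryLanguageModel {
        provider: proto::LanguageModelProvider::Anthropic as i32,
        kind: proto::LanguageModelRequestKind::Complete as i32,
        request: serde_json::to_string(anthropic_request)?,
    })
}
// Each streamed reply then arrives as proto::QueryLanguageModelResponse { response },
// where `response` is the provider's own event re-serialized as JSON.
```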
@@ -4830,41 +4627,6 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
}
}
-async fn count_tokens_with_language_model(
- request: proto::CountTokensWithLanguageModel,
- response: Response<proto::CountTokensWithLanguageModel>,
- session: UserSession,
- google_ai_api_key: Option<Arc<str>>,
-) -> Result<()> {
- authorize_access_to_language_models(&session).await?;
-
- if !request.model.starts_with("gemini") {
- return Err(anyhow!(
- "counting tokens for model: {:?} is not supported",
- request.model
- ))?;
- }
-
- session
- .rate_limiter
- .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
- .await?;
-
- let api_key = google_ai_api_key
- .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
- let tokens_response = google_ai::count_tokens(
- session.http_client.as_ref(),
- google_ai::API_URL,
- &api_key,
- crate::ai::count_tokens_request_to_google_ai(request)?,
- )
- .await?;
- response.send(proto::CountTokensResponse {
- token_count: tokens_response.total_tokens as u32,
- })?;
- Ok(())
-}
-
struct ComputeEmbeddingsRateLimit;
impl RateLimit for ComputeEmbeddingsRateLimit {
@@ -11,9 +11,14 @@ workspace = true
[lib]
path = "src/google_ai.rs"
+[features]
+schemars = ["dep:schemars"]
+
[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
+schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
+strum.workspace = true
@@ -1,23 +1,21 @@
-use std::sync::Arc;
-
use anyhow::{anyhow, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::HttpClient;
use serde::{Deserialize, Serialize};
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
pub async fn stream_generate_content(
- client: Arc<dyn HttpClient>,
+ client: &dyn HttpClient,
api_url: &str,
api_key: &str,
- model: &str,
- request: GenerateContentRequest,
+ mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
let uri = format!(
- "{}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={}",
- api_url, api_key
+ "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
+ model = request.model
);
+ request.model.clear();
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
@@ -52,8 +50,8 @@ pub async fn stream_generate_content(
}
}
-pub async fn count_tokens<T: HttpClient>(
- client: &T,
+pub async fn count_tokens(
+ client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CountTokensRequest,
@@ -91,22 +89,24 @@ pub enum Task {
BatchEmbedContents,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
+ #[serde(default, skip_serializing_if = "String::is_empty")]
+ pub model: String,
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
pub candidates: Option<Vec<GenerateContentCandidate>>,
pub prompt_feedback: Option<PromptFeedback>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentCandidate {
pub index: usize,
@@ -157,7 +157,7 @@ pub struct GenerativeContentBlob {
pub data: String,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationSource {
pub start_index: Option<usize>,
@@ -166,13 +166,13 @@ pub struct CitationSource {
pub license: Option<String>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CitationMetadata {
pub citation_sources: Vec<CitationSource>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
pub block_reason: Option<String>,
@@ -180,7 +180,7 @@ pub struct PromptFeedback {
pub block_reason_message: Option<String>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
pub candidate_count: Option<usize>,
@@ -191,7 +191,7 @@ pub struct GenerationConfig {
pub top_k: Option<usize>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetySetting {
pub category: HarmCategory,
@@ -224,7 +224,7 @@ pub enum HarmCategory {
DangerousContent,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
pub enum HarmBlockThreshold {
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
Unspecified,
@@ -238,7 +238,7 @@ pub enum HarmBlockThreshold {
BlockNone,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum HarmProbability {
#[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
@@ -249,21 +249,85 @@ pub enum HarmProbability {
High,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SafetyRating {
pub category: HarmCategory,
pub probability: HarmProbability,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub contents: Vec<Content>,
}
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensResponse {
pub total_tokens: usize,
}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
+pub enum Model {
+ #[serde(rename = "gemini-1.5-pro")]
+ Gemini15Pro,
+ #[serde(rename = "gemini-1.5-flash")]
+ Gemini15Flash,
+ #[serde(rename = "custom")]
+ Custom { name: String, max_tokens: usize },
+}
+
+impl Model {
+ pub fn id(&self) -> &str {
+ match self {
+ Model::Gemini15Pro => "gemini-1.5-pro",
+ Model::Gemini15Flash => "gemini-1.5-flash",
+ Model::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Model::Gemini15Pro => "Gemini 1.5 Pro",
+ Model::Gemini15Flash => "Gemini 1.5 Flash",
+ Model::Custom { name, .. } => name,
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ Model::Gemini15Pro => 2_000_000,
+ Model::Gemini15Flash => 1_000_000,
+ Model::Custom { max_tokens, .. } => *max_tokens,
+ }
+ }
+}
+
+impl std::fmt::Display for Model {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", self.id())
+ }
+}
+
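+/// Yields the text of the first part of the first candidate from each streamed
+/// `GenerateContentResponse`, skipping events that carry no text.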
+pub fn extract_text_from_events(
+ events: impl Stream<Item = Result<GenerateContentResponse>>,
+) -> impl Stream<Item = Result<String>> {
+ events.filter_map(|event| async move {
+ match event {
+ Ok(event) => event.candidates.and_then(|candidates| {
+ candidates.into_iter().next().and_then(|candidate| {
+ candidate.content.parts.into_iter().next().and_then(|part| {
+ if let Part::TextPart(TextPart { text }) = part {
+ Some(Ok(text))
+ } else {
+ None
+ }
+ })
+ })
+ }),
+ Err(error) => Some(Err(error)),
+ }
+ })
+}
@@ -28,6 +28,7 @@ collections.workspace = true
editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
+google_ai = { workspace = true, features = ["schemars"] }
gpui.workspace = true
http_client.workspace = true
menu.workspace = true
@@ -1,108 +1,42 @@
-pub use anthropic::Model as AnthropicModel;
-use anyhow::{anyhow, Result};
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use strum::EnumIter;
-#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
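+/// A model served through zed.dev, wrapping the upstream provider's model type and
+/// tagged by `provider` when (de)serialized from settings.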
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "provider", rename_all = "lowercase")]
pub enum CloudModel {
- #[serde(rename = "gpt-3.5-turbo")]
- Gpt3Point5Turbo,
- #[serde(rename = "gpt-4")]
- Gpt4,
- #[serde(rename = "gpt-4-turbo-preview")]
- Gpt4Turbo,
- #[serde(rename = "gpt-4o")]
- #[default]
- Gpt4Omni,
- #[serde(rename = "gpt-4o-mini")]
- Gpt4OmniMini,
- #[serde(rename = "claude-3-5-sonnet")]
- Claude3_5Sonnet,
- #[serde(rename = "claude-3-opus")]
- Claude3Opus,
- #[serde(rename = "claude-3-sonnet")]
- Claude3Sonnet,
- #[serde(rename = "claude-3-haiku")]
- Claude3Haiku,
- #[serde(rename = "gemini-1.5-pro")]
- Gemini15Pro,
- #[serde(rename = "gemini-1.5-flash")]
- Gemini15Flash,
- #[serde(rename = "custom")]
- Custom {
- name: String,
- max_tokens: Option<usize>,
- },
+ Anthropic(anthropic::Model),
+ OpenAi(open_ai::Model),
+ Google(google_ai::Model),
}
-impl CloudModel {
- pub fn from_id(value: &str) -> Result<Self> {
- match value {
- "gpt-3.5-turbo" => Ok(Self::Gpt3Point5Turbo),
- "gpt-4" => Ok(Self::Gpt4),
- "gpt-4-turbo-preview" => Ok(Self::Gpt4Turbo),
- "gpt-4o" => Ok(Self::Gpt4Omni),
- "gpt-4o-mini" => Ok(Self::Gpt4OmniMini),
- "claude-3-5-sonnet" => Ok(Self::Claude3_5Sonnet),
- "claude-3-opus" => Ok(Self::Claude3Opus),
- "claude-3-sonnet" => Ok(Self::Claude3Sonnet),
- "claude-3-haiku" => Ok(Self::Claude3Haiku),
- "gemini-1.5-pro" => Ok(Self::Gemini15Pro),
- "gemini-1.5-flash" => Ok(Self::Gemini15Flash),
- _ => Err(anyhow!("invalid model id")),
- }
+impl Default for CloudModel {
+ fn default() -> Self {
+ Self::Anthropic(anthropic::Model::default())
}
+}
+impl CloudModel {
pub fn id(&self) -> &str {
match self {
- Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
- Self::Gpt4 => "gpt-4",
- Self::Gpt4Turbo => "gpt-4-turbo-preview",
- Self::Gpt4Omni => "gpt-4o",
- Self::Gpt4OmniMini => "gpt-4o-mini",
- Self::Claude3_5Sonnet => "claude-3-5-sonnet",
- Self::Claude3Opus => "claude-3-opus",
- Self::Claude3Sonnet => "claude-3-sonnet",
- Self::Claude3Haiku => "claude-3-haiku",
- Self::Gemini15Pro => "gemini-1.5-pro",
- Self::Gemini15Flash => "gemini-1.5-flash",
- Self::Custom { name, .. } => name,
+ CloudModel::Anthropic(model) => model.id(),
+ CloudModel::OpenAi(model) => model.id(),
+ CloudModel::Google(model) => model.id(),
}
}
pub fn display_name(&self) -> &str {
match self {
- Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
- Self::Gpt4 => "GPT 4",
- Self::Gpt4Turbo => "GPT 4 Turbo",
- Self::Gpt4Omni => "GPT 4 Omni",
- Self::Gpt4OmniMini => "GPT 4 Omni Mini",
- Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
- Self::Claude3Opus => "Claude 3 Opus",
- Self::Claude3Sonnet => "Claude 3 Sonnet",
- Self::Claude3Haiku => "Claude 3 Haiku",
- Self::Gemini15Pro => "Gemini 1.5 Pro",
- Self::Gemini15Flash => "Gemini 1.5 Flash",
- Self::Custom { name, .. } => name,
+ CloudModel::Anthropic(model) => model.display_name(),
+ CloudModel::OpenAi(model) => model.display_name(),
+ CloudModel::Google(model) => model.display_name(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
- Self::Gpt3Point5Turbo => 2048,
- Self::Gpt4 => 4096,
- Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
- Self::Gpt4OmniMini => 128000,
- Self::Claude3_5Sonnet
- | Self::Claude3Opus
- | Self::Claude3Sonnet
- | Self::Claude3Haiku => 200000,
- Self::Gemini15Pro => 128000,
- Self::Gemini15Flash => 32000,
- Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+ CloudModel::Anthropic(model) => model.max_token_count(),
+ CloudModel::OpenAi(model) => model.max_token_count(),
+ CloudModel::Google(model) => model.max_token_count(),
}
}
}
@@ -2,5 +2,6 @@ pub mod anthropic;
pub mod cloud;
#[cfg(any(test, feature = "test-support"))]
pub mod fake;
+pub mod google;
pub mod ollama;
pub mod open_ai;
@@ -1,4 +1,4 @@
-use anthropic::{stream_completion, Request, RequestMessage};
+use anthropic::stream_completion;
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
@@ -18,7 +18,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ LanguageModelProviderState, LanguageModelRequest, Role,
};
const PROVIDER_ID: &str = "anthropic";
@@ -160,40 +160,6 @@ pub struct AnthropicModel {
http_client: Arc<dyn HttpClient>,
}
-impl AnthropicModel {
- fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
- preprocess_anthropic_request(&mut request);
-
- let mut system_message = String::new();
- if request
- .messages
- .first()
- .map_or(false, |message| message.role == Role::System)
- {
- system_message = request.messages.remove(0).content;
- }
-
- Request {
- model: self.model.id().to_string(),
- messages: request
- .messages
- .iter()
- .map(|msg| RequestMessage {
- role: match msg.role {
- Role::User => anthropic::Role::User,
- Role::Assistant => anthropic::Role::Assistant,
- Role::System => unreachable!("filtered out by preprocess_request"),
- },
- content: msg.content.clone(),
- })
- .collect(),
- stream: true,
- system: system_message,
- max_tokens: 4092,
- }
- }
-}
-
pub fn count_anthropic_tokens(
request: LanguageModelRequest,
cx: &AppContext,
@@ -260,7 +226,7 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = self.to_anthropic_request(request);
+ let request = request.into_anthropic(self.model.id().into());
let http_client = self.http_client.clone();
@@ -285,75 +251,12 @@ impl LanguageModel for AnthropicModel {
low_speed_timeout,
);
let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(response) => match response {
- anthropic::ResponseEvent::ContentBlockStart {
- content_block, ..
- } => match content_block {
- anthropic::ContentBlock::Text { text } => Some(Ok(text)),
- },
- anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => {
- match delta {
- anthropic::TextDelta::TextDelta { text } => Some(Ok(text)),
- }
- }
- _ => None,
- },
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
+ Ok(anthropic::extract_text_from_events(response).boxed())
}
.boxed()
}
}
-pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
- let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
- let mut system_message = String::new();
-
- for message in request.messages.drain(..) {
- if message.content.is_empty() {
- continue;
- }
-
- match message.role {
- Role::User | Role::Assistant => {
- if let Some(last_message) = new_messages.last_mut() {
- if last_message.role == message.role {
- last_message.content.push_str("\n\n");
- last_message.content.push_str(&message.content);
- continue;
- }
- }
-
- new_messages.push(message);
- }
- Role::System => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.content);
- }
- }
- }
-
- if !system_message.is_empty() {
- new_messages.insert(
- 0,
- LanguageModelRequestMessage {
- role: Role::System,
- content: system_message,
- },
- );
- }
-
- request.messages = new_messages;
-}
-
struct AuthenticationPrompt {
api_key: View<Editor>,
state: gpui::Model<State>,
@@ -7,8 +7,10 @@ use crate::{
use anyhow::Result;
use client::Client;
use collections::BTreeMap;
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;
use strum::IntoEnumIterator;
@@ -16,14 +18,29 @@ use ui::prelude::*;
use crate::LanguageModelProvider;
-use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
+use super::anthropic::count_anthropic_tokens;
pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct ZedDotDevSettings {
- pub available_models: Vec<CloudModel>,
+ pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(rename_all = "lowercase")]
+pub enum AvailableProvider {
+ Anthropic,
+ OpenAi,
+ Google,
+}
+
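+/// A model exposed through the zed.dev provider, qualified by its upstream provider.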
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+ provider: AvailableProvider,
+ name: String,
+ max_tokens: usize,
}
pub struct CloudLanguageModelProvider {
@@ -100,10 +117,19 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();
- // Add base models from CloudModel::iter()
- for model in CloudModel::iter() {
- if !matches!(model, CloudModel::Custom { .. }) {
- models.insert(model.id().to_string(), model);
+ for model in anthropic::Model::iter() {
+ if !matches!(model, anthropic::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), CloudModel::Anthropic(model));
+ }
+ }
+ for model in open_ai::Model::iter() {
+ if !matches!(model, open_ai::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), CloudModel::OpenAi(model));
+ }
+ }
+ for model in google_ai::Model::iter() {
+ if !matches!(model, google_ai::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), CloudModel::Google(model));
}
}
@@ -112,6 +138,20 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.zed_dot_dev
.available_models
{
+ let model = match model.provider {
+ AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
+ name: model.name.clone(),
+ max_tokens: model.max_tokens,
+ }),
+ AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
+ name: model.name.clone(),
+ max_tokens: model.max_tokens,
+ }),
+ AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
+ name: model.name.clone(),
+ max_tokens: model.max_tokens,
+ }),
+ };
models.insert(model.id().to_string(), model.clone());
}
@@ -183,35 +223,26 @@ impl LanguageModel for CloudLanguageModel {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
- match &self.model {
- CloudModel::Gpt3Point5Turbo => {
- count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
- }
- CloudModel::Gpt4 => count_open_ai_tokens(request, open_ai::Model::Four, cx),
- CloudModel::Gpt4Turbo => count_open_ai_tokens(request, open_ai::Model::FourTurbo, cx),
- CloudModel::Gpt4Omni => count_open_ai_tokens(request, open_ai::Model::FourOmni, cx),
- CloudModel::Gpt4OmniMini => {
- count_open_ai_tokens(request, open_ai::Model::FourOmniMini, cx)
- }
- CloudModel::Claude3_5Sonnet
- | CloudModel::Claude3Opus
- | CloudModel::Claude3Sonnet
- | CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
- CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
- count_anthropic_tokens(request, cx)
- }
- _ => {
- let request = self.client.request(proto::CountTokensWithLanguageModel {
- model: self.model.id().to_string(),
- messages: request
- .messages
- .iter()
- .map(|message| message.to_proto())
- .collect(),
- });
+ match self.model.clone() {
+ CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
+ CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
+ CloudModel::Google(model) => {
+ let client = self.client.clone();
+ let request = request.into_google(model.id().into());
+ let request = google_ai::CountTokensRequest {
+ contents: request.contents,
+ };
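+                // Token counting for Google models is delegated to zed.dev through
+                // the QueryLanguageModel RPC rather than computed locally.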
async move {
- let response = request.await?;
- Ok(response.token_count as usize)
+ let request = serde_json::to_string(&request)?;
+ let response = client.request(proto::QueryLanguageModel {
+ provider: proto::LanguageModelProvider::Google as i32,
+ kind: proto::LanguageModelRequestKind::CountTokens as i32,
+ request,
+ });
+ let response = response.await?;
+ let response =
+ serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
+ Ok(response.total_tokens)
}
.boxed()
}
@@ -220,46 +251,65 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
- mut request: LanguageModelRequest,
+ request: LanguageModelRequest,
_: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
- CloudModel::Claude3Opus
- | CloudModel::Claude3Sonnet
- | CloudModel::Claude3Haiku
- | CloudModel::Claude3_5Sonnet => preprocess_anthropic_request(&mut request),
- CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
- preprocess_anthropic_request(&mut request)
+ CloudModel::Anthropic(model) => {
+ let client = self.client.clone();
+ let request = request.into_anthropic(model.id().into());
+ async move {
+ let request = serde_json::to_string(&request)?;
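+                    // Proxy the Anthropic-shaped request through zed.dev as a
+                    // QueryLanguageModel RPC and decode the streamed JSON events.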
+ let response = client.request_stream(proto::QueryLanguageModel {
+ provider: proto::LanguageModelProvider::Anthropic as i32,
+ kind: proto::LanguageModelRequestKind::Complete as i32,
+ request,
+ });
+ let chunks = response.await?;
+ Ok(anthropic::extract_text_from_events(
+ chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ )
+ .boxed())
+ }
+ .boxed()
+ }
+ CloudModel::OpenAi(model) => {
+ let client = self.client.clone();
+ let request = request.into_open_ai(model.id().into());
+ async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client.request_stream(proto::QueryLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ kind: proto::LanguageModelRequestKind::Complete as i32,
+ request,
+ });
+ let chunks = response.await?;
+ Ok(open_ai::extract_text_from_events(
+ chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ )
+ .boxed())
+ }
+ .boxed()
+ }
+ CloudModel::Google(model) => {
+ let client = self.client.clone();
+ let request = request.into_google(model.id().into());
+ async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client.request_stream(proto::QueryLanguageModel {
+ provider: proto::LanguageModelProvider::Google as i32,
+ kind: proto::LanguageModelRequestKind::Complete as i32,
+ request,
+ });
+ let chunks = response.await?;
+ Ok(google_ai::extract_text_from_events(
+ chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+ )
+ .boxed())
+ }
+ .boxed()
}
- _ => {}
}
-
- let request = proto::CompleteWithLanguageModel {
- model: self.id.0.to_string(),
- messages: request
- .messages
- .iter()
- .map(|message| message.to_proto())
- .collect(),
- stop: request.stop,
- temperature: request.temperature,
- tools: Vec::new(),
- tool_choice: None,
- };
-
- self.client
- .request_stream(request)
- .map_ok(|stream| {
- stream
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed()
- })
- .boxed()
}
}
@@ -0,0 +1,351 @@
+use anyhow::{anyhow, Result};
+use collections::BTreeMap;
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, FutureExt, StreamExt};
+use google_ai::stream_generate_content;
+use gpui::{
+ AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
+ WhiteSpace,
+};
+use http_client::HttpClient;
+use settings::{Settings, SettingsStore};
+use std::{sync::Arc, time::Duration};
+use strum::IntoEnumIterator;
+use theme::ThemeSettings;
+use ui::prelude::*;
+use util::ResultExt;
+
+use crate::{
+ settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest,
+};
+
+const PROVIDER_ID: &str = "google";
+const PROVIDER_NAME: &str = "Google AI";
+
+#[derive(Default, Clone, Debug, PartialEq)]
+pub struct GoogleSettings {
+ pub api_url: String,
+ pub low_speed_timeout: Option<Duration>,
+ pub available_models: Vec<google_ai::Model>,
+}
+
+pub struct GoogleLanguageModelProvider {
+ http_client: Arc<dyn HttpClient>,
+ state: gpui::Model<State>,
+}
+
+struct State {
+ api_key: Option<String>,
+ _subscription: Subscription,
+}
+
+impl GoogleLanguageModelProvider {
+ pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
+ let state = cx.new_model(|cx| State {
+ api_key: None,
+ _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
+ cx.notify();
+ }),
+ });
+
+ Self { http_client, state }
+ }
+}
+
+impl LanguageModelProviderState for GoogleLanguageModelProvider {
+ fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
+ Some(cx.observe(&self.state, |_, _, cx| {
+ cx.notify();
+ }))
+ }
+}
+
+impl LanguageModelProvider for GoogleLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
+ let mut models = BTreeMap::default();
+
+ // Add base models from google_ai::Model::iter()
+ for model in google_ai::Model::iter() {
+ if !matches!(model, google_ai::Model::Custom { .. }) {
+ models.insert(model.id().to_string(), model);
+ }
+ }
+
+ // Override with available models from settings
+ for model in &AllLanguageModelSettings::get_global(cx)
+ .google
+ .available_models
+ {
+ models.insert(model.id().to_string(), model.clone());
+ }
+
+ models
+ .into_values()
+ .map(|model| {
+ Arc::new(GoogleLanguageModel {
+ id: LanguageModelId::from(model.id().to_string()),
+ model,
+ state: self.state.clone(),
+ http_client: self.http_client.clone(),
+ }) as Arc<dyn LanguageModel>
+ })
+ .collect()
+ }
+
+ fn is_authenticated(&self, cx: &AppContext) -> bool {
+ self.state.read(cx).api_key.is_some()
+ }
+
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ if self.is_authenticated(cx) {
+ Task::ready(Ok(()))
+ } else {
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .google
+ .api_url
+ .clone();
+ let state = self.state.clone();
+ cx.spawn(|mut cx| async move {
+ let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
+ api_key
+ } else {
+ let (_, api_key) = cx
+ .update(|cx| cx.read_credentials(&api_url))?
+ .await?
+ .ok_or_else(|| anyhow!("credentials not found"))?;
+ String::from_utf8(api_key)?
+ };
+
+ state.update(&mut cx, |this, cx| {
+ this.api_key = Some(api_key);
+ cx.notify();
+ })
+ })
+ }
+ }
+
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ let state = self.state.clone();
+ let delete_credentials =
+ cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
+ cx.spawn(|mut cx| async move {
+ delete_credentials.await.log_err();
+ state.update(&mut cx, |this, cx| {
+ this.api_key = None;
+ cx.notify();
+ })
+ })
+ }
+}
+
+pub struct GoogleLanguageModel {
+ id: LanguageModelId,
+ model: google_ai::Model,
+ state: gpui::Model<State>,
+ http_client: Arc<dyn HttpClient>,
+}
+
+impl LanguageModel for GoogleLanguageModel {
+ fn id(&self) -> LanguageModelId {
+ self.id.clone()
+ }
+
+ fn name(&self) -> LanguageModelName {
+ LanguageModelName::from(self.model.display_name().to_string())
+ }
+
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
+ fn telemetry_id(&self) -> String {
+ format!("google/{}", self.model.id())
+ }
+
+ fn max_token_count(&self) -> usize {
+ self.model.max_token_count()
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ let request = request.into_google(self.model.id().to_string());
+ let http_client = self.http_client.clone();
+ let api_key = self.state.read(cx).api_key.clone();
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .google
+ .api_url
+ .clone();
+
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let response = google_ai::count_tokens(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ google_ai::CountTokensRequest {
+ contents: request.contents,
+ },
+ )
+ .await?;
+ Ok(response.total_tokens)
+ }
+ .boxed()
+ }
+
+ fn stream_completion(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
+ let request = request.into_google(self.model.id().to_string());
+
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).google;
+ (state.api_key.clone(), settings.api_url.clone())
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let response =
+ stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
+ let events = response.await?;
+ Ok(google_ai::extract_text_from_events(events).boxed())
+ }
+ .boxed()
+ }
+}
+
+struct AuthenticationPrompt {
+ api_key: View<Editor>,
+ state: gpui::Model<State>,
+}
+
+impl AuthenticationPrompt {
+ fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
+ Self {
+ api_key: cx.new_view(|cx| {
+ let mut editor = Editor::single_line(cx);
+ editor.set_placeholder_text("AIzaSy...", cx);
+ editor
+ }),
+ state,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ let api_key = self.api_key.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+ let settings = &AllLanguageModelSettings::get_global(cx).google;
+ let write_credentials =
+ cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
+ let state = self.state.clone();
+ cx.spawn(|_, mut cx| async move {
+ write_credentials.await?;
+ state.update(&mut cx, |this, cx| {
+ this.api_key = Some(api_key);
+ cx.notify();
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_fallbacks: settings.ui_font.fallbacks.clone(),
+ font_size: rems(0.875).into(),
+ font_weight: settings.ui_font.weight,
+ font_style: FontStyle::Normal,
+ line_height: relative(1.3),
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ white_space: WhiteSpace::Normal,
+ };
+ EditorElement::new(
+ &self.api_key,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+}
+
+impl Render for AuthenticationPrompt {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ const INSTRUCTIONS: [&str; 4] = [
+ "To use the Google AI assistant, you need to add your Google AI API key.",
+ "You can create an API key at: https://makersuite.google.com/app/apikey",
+ "",
+ "Paste your Google AI API key below and hit enter to use the assistant:",
+ ];
+
+ v_flex()
+ .p_4()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .children(
+ INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .my_2()
+ .px_2()
+ .py_1()
+ .bg(cx.theme().colors().editor_background)
+ .rounded_md()
+ .child(self.render_api_key_editor(cx)),
+ )
+ .child(
+ Label::new(
+ "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
+ )
+ .size(LabelSize::Small),
+ )
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new("Click on").size(LabelSize::Small))
+ .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
+ .child(
+ Label::new("in the status bar to close this panel.").size(LabelSize::Small),
+ ),
+ )
+ .into_any()
+ }
+}
@@ -7,7 +7,7 @@ use gpui::{
WhiteSpace,
};
use http_client::HttpClient;
-use open_ai::{stream_completion, Request, RequestMessage};
+use open_ai::stream_completion;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
@@ -159,35 +159,6 @@ pub struct OpenAiLanguageModel {
http_client: Arc<dyn HttpClient>,
}
-impl OpenAiLanguageModel {
- fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
- Request {
- model: self.model.clone(),
- messages: request
- .messages
- .into_iter()
- .map(|msg| match msg.role {
- Role::User => RequestMessage::User {
- content: msg.content,
- },
- Role::Assistant => RequestMessage::Assistant {
- content: Some(msg.content),
- tool_calls: Vec::new(),
- },
- Role::System => RequestMessage::System {
- content: msg.content,
- },
- })
- .collect(),
- stream: true,
- stop: request.stop,
- temperature: request.temperature,
- tools: Vec::new(),
- tool_choice: None,
- }
- }
-}
-
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -226,7 +197,7 @@ impl LanguageModel for OpenAiLanguageModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
- let request = self.to_open_ai_request(request);
+ let request = request.into_open_ai(self.model.id().into());
let http_client = self.http_client.clone();
let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
@@ -250,15 +221,7 @@ impl LanguageModel for OpenAiLanguageModel {
low_speed_timeout,
);
let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
+ Ok(open_ai::extract_text_from_events(response).boxed())
}
.boxed()
}
@@ -1,16 +1,16 @@
-use client::Client;
-use collections::BTreeMap;
-use gpui::{AppContext, Global, Model, ModelContext};
-use std::sync::Arc;
-use ui::Context;
-
use crate::{
provider::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
- ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
+ google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider,
+ open_ai::OpenAiLanguageModelProvider,
},
LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
+use client::Client;
+use collections::BTreeMap;
+use gpui::{AppContext, Global, Model, ModelContext};
+use std::sync::Arc;
+use ui::Context;
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
let registry = cx.new_model(|cx| {
@@ -40,6 +40,10 @@ fn register_language_model_providers(
OllamaLanguageModelProvider::new(client.http_client(), cx),
cx,
);
+ registry.register_provider(
+ GoogleLanguageModelProvider::new(client.http_client(), cx),
+ cx,
+ );
cx.observe_flag::<feature_flags::LanguageModels, _>(move |enabled, cx| {
let client = client.clone();
@@ -1,4 +1,4 @@
-use crate::{role::Role, LanguageModelId};
+use crate::role::Role;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -7,17 +7,6 @@ pub struct LanguageModelRequestMessage {
pub content: String,
}
-impl LanguageModelRequestMessage {
- pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
- proto::LanguageModelRequestMessage {
- role: self.role.to_proto() as i32,
- content: self.content.clone(),
- tool_calls: Vec::new(),
- tool_call_id: None,
- }
- }
-}
-
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub messages: Vec<LanguageModelRequestMessage>,
@@ -26,14 +15,110 @@ pub struct LanguageModelRequest {
}
impl LanguageModelRequest {
- pub fn to_proto(&self, model_id: LanguageModelId) -> proto::CompleteWithLanguageModel {
- proto::CompleteWithLanguageModel {
- model: model_id.0.to_string(),
- messages: self.messages.iter().map(|m| m.to_proto()).collect(),
- stop: self.stop.clone(),
+ pub fn into_open_ai(self, model: String) -> open_ai::Request {
+ open_ai::Request {
+ model,
+ messages: self
+ .messages
+ .into_iter()
+ .map(|msg| match msg.role {
+ Role::User => open_ai::RequestMessage::User {
+ content: msg.content,
+ },
+ Role::Assistant => open_ai::RequestMessage::Assistant {
+ content: Some(msg.content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => open_ai::RequestMessage::System {
+ content: msg.content,
+ },
+ })
+ .collect(),
+ stream: true,
+ stop: self.stop,
temperature: self.temperature,
- tool_choice: None,
tools: Vec::new(),
+ tool_choice: None,
+ }
+ }
+
+ pub fn into_google(self, model: String) -> google_ai::GenerateContentRequest {
+ google_ai::GenerateContentRequest {
+ model,
+ contents: self
+ .messages
+ .into_iter()
+ .map(|msg| google_ai::Content {
+ parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+ text: msg.content,
+ })],
+ role: match msg.role {
+ Role::User => google_ai::Role::User,
+ Role::Assistant => google_ai::Role::Model,
+ Role::System => google_ai::Role::User, // Google AI doesn't have a system role
+ },
+ })
+ .collect(),
+ generation_config: Some(google_ai::GenerationConfig {
+ candidate_count: Some(1),
+ stop_sequences: Some(self.stop),
+ max_output_tokens: None,
+ temperature: Some(self.temperature as f64),
+ top_p: None,
+ top_k: None,
+ }),
+ safety_settings: None,
+ }
+ }
+
+ pub fn into_anthropic(self, model: String) -> anthropic::Request {
+ let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+ let mut system_message = String::new();
+
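+        // Anthropic expects alternating user/assistant turns and a separate system
+        // prompt, so consecutive same-role messages are merged and system text is
+        // collected into the request's `system` field.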
+ for message in self.messages {
+ if message.content.is_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ if let Some(last_message) = new_messages.last_mut() {
+ if last_message.role == message.role {
+ last_message.content.push_str("\n\n");
+ last_message.content.push_str(&message.content);
+ continue;
+ }
+ }
+
+ new_messages.push(message);
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.content);
+ }
+ }
+ }
+
+ anthropic::Request {
+ model,
+ messages: new_messages
+ .into_iter()
+ .filter_map(|message| {
+ Some(anthropic::RequestMessage {
+ role: match message.role {
+ Role::User => anthropic::Role::User,
+ Role::Assistant => anthropic::Role::Assistant,
+ Role::System => return None,
+ },
+ content: message.content,
+ })
+ })
+ .collect(),
+ stream: true,
+ max_tokens: 4092,
+ system: system_message,
}
}
}
@@ -15,7 +15,6 @@ impl Role {
Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
- Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
None => Role::User,
}
}
@@ -6,12 +6,12 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
-use crate::{
- provider::{
- anthropic::AnthropicSettings, cloud::ZedDotDevSettings, ollama::OllamaSettings,
- open_ai::OpenAiSettings,
- },
- CloudModel,
+use crate::provider::{
+ anthropic::AnthropicSettings,
+ cloud::{self, ZedDotDevSettings},
+ google::GoogleSettings,
+ ollama::OllamaSettings,
+ open_ai::OpenAiSettings,
};
/// Initializes the language model settings.
@@ -25,6 +25,7 @@ pub struct AllLanguageModelSettings {
pub ollama: OllamaSettings,
pub openai: OpenAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
+ pub google: GoogleSettings,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -34,6 +35,7 @@ pub struct AllLanguageModelSettingsContent {
pub openai: Option<OpenAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
+ pub google: Option<GoogleSettingsContent>,
}
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -56,9 +58,16 @@ pub struct OpenAiSettingsContent {
pub available_models: Option<Vec<open_ai::Model>>,
}
+#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
+pub struct GoogleSettingsContent {
+ pub api_url: Option<String>,
+ pub low_speed_timeout_in_seconds: Option<u64>,
+ pub available_models: Option<Vec<google_ai::Model>>,
+}
+
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
pub struct ZedDotDevSettingsContent {
- available_models: Option<Vec<CloudModel>>,
+ available_models: Option<Vec<cloud::AvailableModel>>,
}
impl settings::Settings for AllLanguageModelSettings {
@@ -136,6 +145,26 @@ impl settings::Settings for AllLanguageModelSettings {
.as_ref()
.and_then(|s| s.available_models.clone()),
);
+
+ merge(
+ &mut settings.google.api_url,
+ value.google.as_ref().and_then(|s| s.api_url.clone()),
+ );
+ if let Some(low_speed_timeout_in_seconds) = value
+ .google
+ .as_ref()
+ .and_then(|s| s.low_speed_timeout_in_seconds)
+ {
+ settings.google.low_speed_timeout =
+ Some(Duration::from_secs(low_speed_timeout_in_seconds));
+ }
+ merge(
+ &mut settings.google.available_models,
+ value
+ .google
+ .as_ref()
+ .and_then(|s| s.available_models.clone()),
+ );
}
Ok(settings)
@@ -1,5 +1,5 @@
use anyhow::{anyhow, Context, Result};
-use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
@@ -111,38 +111,27 @@ impl Model {
}
}
-fn serialize_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
-where
- S: serde::Serializer,
-{
- match model {
- Model::Custom { name, .. } => serializer.serialize_str(name),
- _ => serializer.serialize_str(model.id()),
- }
-}
-
-#[derive(Debug, Serialize)]
+#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
- #[serde(serialize_with = "serialize_model")]
- pub model: Model,
+ pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
pub stop: Vec<String>,
pub temperature: f32,
- #[serde(skip_serializing_if = "Option::is_none")]
+ #[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
- #[serde(skip_serializing_if = "Vec::is_empty")]
+ #[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
-#[derive(Debug, Serialize)]
+#[derive(Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Map<String, Value>>,
}
-#[derive(Serialize, Debug)]
+#[derive(Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
@@ -213,21 +202,21 @@ pub struct FunctionChunk {
pub arguments: Option<String>,
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
-#[derive(Deserialize, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
@@ -369,3 +358,14 @@ pub fn embed<'a>(
}
}
}
+
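+/// Extracts the streamed content delta from each `ResponseStreamEvent`, skipping
+/// events whose choices carry no content.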
+pub fn extract_text_from_events(
+ response: impl Stream<Item = Result<ResponseStreamEvent>>,
+) -> impl Stream<Item = Result<String>> {
+ response.filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+}
@@ -13,13 +13,6 @@ message Envelope {
optional uint32 responding_to = 2;
optional PeerId original_sender_id = 3;
- /*
- When you are adding a new message type, instead of adding it in semantic order
- and bumping the message ID's of everything that follows, add it at the end of the
- file and bump the max number. See this
- https://github.com/zed-industries/zed/pull/7890#discussion_r1496621823
-
- */
oneof payload {
Hello hello = 4;
Ack ack = 5;
@@ -201,10 +194,8 @@ message Envelope {
JoinHostedProject join_hosted_project = 164;
- CompleteWithLanguageModel complete_with_language_model = 166;
- LanguageModelResponse language_model_response = 167;
- CountTokensWithLanguageModel count_tokens_with_language_model = 168;
- CountTokensResponse count_tokens_response = 169;
+ QueryLanguageModel query_language_model = 224;
+ QueryLanguageModelResponse query_language_model_response = 225; // current max
GetCachedEmbeddings get_cached_embeddings = 189;
GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
ComputeEmbeddings compute_embeddings = 191;
@@ -271,10 +262,11 @@ message Envelope {
UpdateDevServerProject update_dev_server_project = 221;
AddWorktree add_worktree = 222;
- AddWorktreeResponse add_worktree_response = 223; // current max
+ AddWorktreeResponse add_worktree_response = 223;
}
reserved 158 to 161;
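+  // Message IDs 166-169 previously carried CompleteWithLanguageModel,
+  // LanguageModelResponse, CountTokensWithLanguageModel, and CountTokensResponse.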
+ reserved 166 to 169;
}
// Messages
@@ -2051,94 +2043,32 @@ message SetRoomParticipantRole {
ChannelRole role = 3;
}
-message CompleteWithLanguageModel {
- string model = 1;
- repeated LanguageModelRequestMessage messages = 2;
- repeated string stop = 3;
- float temperature = 4;
- repeated ChatCompletionTool tools = 5;
- optional string tool_choice = 6;
-}
-
-// A tool presented to the language model for its use
-message ChatCompletionTool {
- oneof variant {
- FunctionObject function = 1;
- }
-
- message FunctionObject {
- string name = 1;
- optional string description = 2;
- optional string parameters = 3;
- }
-}
-
-// A message to the language model
-message LanguageModelRequestMessage {
- LanguageModelRole role = 1;
- string content = 2;
- optional string tool_call_id = 3;
- repeated ToolCall tool_calls = 4;
-}
-
enum LanguageModelRole {
LanguageModelUser = 0;
LanguageModelAssistant = 1;
LanguageModelSystem = 2;
- LanguageModelTool = 3;
-}
-
-message LanguageModelResponseMessage {
- optional LanguageModelRole role = 1;
- optional string content = 2;
- repeated ToolCallDelta tool_calls = 3;
-}
-
-// A request to call a tool, by the language model
-message ToolCall {
- string id = 1;
-
- oneof variant {
- FunctionCall function = 2;
- }
-
- message FunctionCall {
- string name = 1;
- string arguments = 2;
- }
-}
-
-message ToolCallDelta {
- uint32 index = 1;
- optional string id = 2;
-
- oneof variant {
- FunctionCallDelta function = 3;
- }
-
- message FunctionCallDelta {
- optional string name = 1;
- optional string arguments = 2;
- }
+ reserved 3;
}
-message LanguageModelResponse {
- repeated LanguageModelChoiceDelta choices = 1;
+message QueryLanguageModel {
+ LanguageModelProvider provider = 1;
+ LanguageModelRequestKind kind = 2;
+ string request = 3;
}
-message LanguageModelChoiceDelta {
- uint32 index = 1;
- LanguageModelResponseMessage delta = 2;
- optional string finish_reason = 3;
+enum LanguageModelProvider {
+ Anthropic = 0;
+ OpenAI = 1;
+ Google = 2;
}
-message CountTokensWithLanguageModel {
- string model = 1;
- repeated LanguageModelRequestMessage messages = 2;
+enum LanguageModelRequestKind {
+ Complete = 0;
+ CountTokens = 1;
}
-message CountTokensResponse {
- uint32 token_count = 1;
+message QueryLanguageModelResponse {
+ string response = 1;
}
message GetCachedEmbeddings {
@@ -203,12 +203,9 @@ messages!(
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
(ChannelMessageUpdate, Foreground),
- (CompleteWithLanguageModel, Background),
(ComputeEmbeddings, Background),
(ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
- (CountTokensWithLanguageModel, Background),
- (CountTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@@ -278,7 +275,6 @@ messages!(
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
- (LanguageModelResponse, Background),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
@@ -298,6 +294,8 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
+ (QueryLanguageModel, Background),
+ (QueryLanguageModelResponse, Background),
(RefreshInlayHints, Foreground),
(RejoinChannelBuffers, Foreground),
(RejoinChannelBuffersResponse, Foreground),
@@ -412,9 +410,7 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
- (CompleteWithLanguageModel, LanguageModelResponse),
(ComputeEmbeddings, ComputeEmbeddingsResponse),
- (CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
@@ -467,6 +463,7 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
+ (QueryLanguageModel, QueryLanguageModelResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),
(RejoinRoom, RejoinRoomResponse),