1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::{Client, UserStore};
11use futures::FutureExt;
12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
13use gpui::{
14 AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
15};
16pub use model::*;
17use project::Fs;
18use proto::Plan;
19pub(crate) use rate_limiter::*;
20pub use registry::*;
21pub use request::*;
22pub use role::*;
23use schemars::JsonSchema;
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::{future::Future, sync::Arc};
26use ui::IconName;
27
28pub fn init(
29 user_store: Model<UserStore>,
30 client: Arc<Client>,
31 fs: Arc<dyn Fs>,
32 cx: &mut AppContext,
33) {
34 settings::init(fs, cx);
35 registry::init(user_store, client, cx);
36}
37
38/// The availability of a [`LanguageModel`].
39#[derive(Debug, PartialEq, Eq, Clone, Copy)]
40pub enum LanguageModelAvailability {
41 /// The language model is available to the general public.
42 Public,
43 /// The language model is available to users on the indicated plan.
44 RequiresPlan(Plan),
45}
46
47/// Configuration for caching language model messages.
48#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
49pub struct LanguageModelCacheConfiguration {
50 pub max_cache_anchors: usize,
51 pub should_speculate: bool,
52 pub min_total_token: usize,
53}
54
55/// A completion event from a language model.
56#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
57pub enum LanguageModelCompletionEvent {
58 Text(String),
59 ToolUse(LanguageModelToolUse),
60}
61
62#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
63pub struct LanguageModelToolUse {
64 pub id: String,
65 pub name: String,
66 pub input: serde_json::Value,
67}
68
69pub trait LanguageModel: Send + Sync {
70 fn id(&self) -> LanguageModelId;
71 fn name(&self) -> LanguageModelName;
72 /// If None, falls back to [LanguageModelProvider::icon]
73 fn icon(&self) -> Option<IconName> {
74 None
75 }
76 fn provider_id(&self) -> LanguageModelProviderId;
77 fn provider_name(&self) -> LanguageModelProviderName;
78 fn telemetry_id(&self) -> String;
79
80 /// Returns the availability of this language model.
81 fn availability(&self) -> LanguageModelAvailability {
82 LanguageModelAvailability::Public
83 }
84
85 fn max_token_count(&self) -> usize;
86 fn max_output_tokens(&self) -> Option<u32> {
87 None
88 }
89
90 fn count_tokens(
91 &self,
92 request: LanguageModelRequest,
93 cx: &AppContext,
94 ) -> BoxFuture<'static, Result<usize>>;
95
96 fn stream_completion(
97 &self,
98 request: LanguageModelRequest,
99 cx: &AsyncAppContext,
100 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
101
102 fn stream_completion_text(
103 &self,
104 request: LanguageModelRequest,
105 cx: &AsyncAppContext,
106 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
107 let events = self.stream_completion(request, cx);
108
109 async move {
110 Ok(events
111 .await?
112 .filter_map(|result| async move {
113 match result {
114 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
115 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
116 Err(err) => Some(Err(err)),
117 }
118 })
119 .boxed())
120 }
121 .boxed()
122 }
123
124 fn use_any_tool(
125 &self,
126 request: LanguageModelRequest,
127 name: String,
128 description: String,
129 schema: serde_json::Value,
130 cx: &AsyncAppContext,
131 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
132
133 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
134 None
135 }
136
137 #[cfg(any(test, feature = "test-support"))]
138 fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
139 unimplemented!()
140 }
141}
142
143impl dyn LanguageModel {
144 pub fn use_tool<T: LanguageModelTool>(
145 &self,
146 request: LanguageModelRequest,
147 cx: &AsyncAppContext,
148 ) -> impl 'static + Future<Output = Result<T>> {
149 let schema = schemars::schema_for!(T);
150 let schema_json = serde_json::to_value(&schema).unwrap();
151 let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
152 async move {
153 let stream = stream.await?;
154 let response = stream.try_collect::<String>().await?;
155 Ok(serde_json::from_str(&response)?)
156 }
157 }
158
159 pub fn use_tool_stream<T: LanguageModelTool>(
160 &self,
161 request: LanguageModelRequest,
162 cx: &AsyncAppContext,
163 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
164 let schema = schemars::schema_for!(T);
165 let schema_json = serde_json::to_value(&schema).unwrap();
166 self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
167 }
168}
169
170pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
171 fn name() -> String;
172 fn description() -> String;
173}
174
175pub trait LanguageModelProvider: 'static {
176 fn id(&self) -> LanguageModelProviderId;
177 fn name(&self) -> LanguageModelProviderName;
178 fn icon(&self) -> IconName {
179 IconName::ZedAssistant
180 }
181 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
182 fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
183 fn is_authenticated(&self, cx: &AppContext) -> bool;
184 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
185 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
186 fn must_accept_terms(&self, _cx: &AppContext) -> bool {
187 false
188 }
189 fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
190 None
191 }
192 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
193}
194
195pub trait LanguageModelProviderState: 'static {
196 type ObservableEntity;
197
198 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
199
200 fn subscribe<T: 'static>(
201 &self,
202 cx: &mut gpui::ModelContext<T>,
203 callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
204 ) -> Option<gpui::Subscription> {
205 let entity = self.observable_entity()?;
206 Some(cx.observe(&entity, move |this, _, cx| {
207 callback(this, cx);
208 }))
209 }
210}
211
212#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
213pub struct LanguageModelId(pub SharedString);
214
215#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
216pub struct LanguageModelName(pub SharedString);
217
218#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
219pub struct LanguageModelProviderId(pub SharedString);
220
221#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
222pub struct LanguageModelProviderName(pub SharedString);
223
224impl From<String> for LanguageModelId {
225 fn from(value: String) -> Self {
226 Self(SharedString::from(value))
227 }
228}
229
230impl From<String> for LanguageModelName {
231 fn from(value: String) -> Self {
232 Self(SharedString::from(value))
233 }
234}
235
236impl From<String> for LanguageModelProviderId {
237 fn from(value: String) -> Self {
238 Self(SharedString::from(value))
239 }
240}
241
242impl From<String> for LanguageModelProviderName {
243 fn from(value: String) -> Self {
244 Self(SharedString::from(value))
245 }
246}