1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::{Client, UserStore};
11use futures::FutureExt;
12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
13use gpui::{
14 AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
15};
16pub use model::*;
17use project::Fs;
18use proto::Plan;
19pub(crate) use rate_limiter::*;
20pub use registry::*;
21pub use request::*;
22pub use role::*;
23use schemars::JsonSchema;
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::{future::Future, sync::Arc};
26use ui::IconName;
27
/// Initializes the language model system: registers language-model settings
/// and the provider registry in the global application context.
pub fn init(
    user_store: Model<UserStore>,
    client: Arc<Client>,
    fs: Arc<dyn Fs>,
    cx: &mut AppContext,
) {
    // NOTE(review): settings are initialized before the registry — presumably
    // `registry::init` relies on settings being registered; confirm before reordering.
    settings::init(fs, cx);
    registry::init(user_store, client, cx);
}
37
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
46
/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// Maximum number of cache anchors permitted in a request.
    // NOTE(review): semantics inferred from the name — confirm against provider usage.
    pub max_cache_anchors: usize,
    /// Whether the provider should speculate when populating the cache.
    // NOTE(review): inferred from the name — verify against the consuming provider.
    pub should_speculate: bool,
    /// Minimum total token count below which caching is presumably skipped.
    // NOTE(review): inferred — confirm threshold semantics with callers.
    pub min_total_token: usize,
}
54
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Generation stopped; carries the reason it stopped.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
}
62
/// The reason a completion stream stopped, serialized in `snake_case`
/// (e.g. `"end_turn"`, `"max_tokens"`, `"tool_use"`).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model ended its turn normally.
    EndTurn,
    /// Generation was cut off by an output token limit.
    MaxTokens,
    /// The model stopped in order to use a tool.
    ToolUse,
}
70
/// A tool invocation requested by the model via
/// [`LanguageModelCompletionEvent::ToolUse`].
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier for this specific tool call.
    pub id: String,
    /// The name of the tool to invoke.
    pub name: String,
    /// Arguments for the tool, as free-form JSON.
    pub input: serde_json::Value,
}
77
78pub trait LanguageModel: Send + Sync {
79 fn id(&self) -> LanguageModelId;
80 fn name(&self) -> LanguageModelName;
81 /// If None, falls back to [LanguageModelProvider::icon]
82 fn icon(&self) -> Option<IconName> {
83 None
84 }
85 fn provider_id(&self) -> LanguageModelProviderId;
86 fn provider_name(&self) -> LanguageModelProviderName;
87 fn telemetry_id(&self) -> String;
88
89 /// Returns the availability of this language model.
90 fn availability(&self) -> LanguageModelAvailability {
91 LanguageModelAvailability::Public
92 }
93
94 fn max_token_count(&self) -> usize;
95 fn max_output_tokens(&self) -> Option<u32> {
96 None
97 }
98
99 fn count_tokens(
100 &self,
101 request: LanguageModelRequest,
102 cx: &AppContext,
103 ) -> BoxFuture<'static, Result<usize>>;
104
105 fn stream_completion(
106 &self,
107 request: LanguageModelRequest,
108 cx: &AsyncAppContext,
109 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
110
111 fn stream_completion_text(
112 &self,
113 request: LanguageModelRequest,
114 cx: &AsyncAppContext,
115 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
116 let events = self.stream_completion(request, cx);
117
118 async move {
119 Ok(events
120 .await?
121 .filter_map(|result| async move {
122 match result {
123 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
124 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
125 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
126 Err(err) => Some(Err(err)),
127 }
128 })
129 .boxed())
130 }
131 .boxed()
132 }
133
134 fn use_any_tool(
135 &self,
136 request: LanguageModelRequest,
137 name: String,
138 description: String,
139 schema: serde_json::Value,
140 cx: &AsyncAppContext,
141 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
142
143 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
144 None
145 }
146
147 #[cfg(any(test, feature = "test-support"))]
148 fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
149 unimplemented!()
150 }
151}
152
153impl dyn LanguageModel {
154 pub fn use_tool<T: LanguageModelTool>(
155 &self,
156 request: LanguageModelRequest,
157 cx: &AsyncAppContext,
158 ) -> impl 'static + Future<Output = Result<T>> {
159 let schema = schemars::schema_for!(T);
160 let schema_json = serde_json::to_value(&schema).unwrap();
161 let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
162 async move {
163 let stream = stream.await?;
164 let response = stream.try_collect::<String>().await?;
165 Ok(serde_json::from_str(&response)?)
166 }
167 }
168
169 pub fn use_tool_stream<T: LanguageModelTool>(
170 &self,
171 request: LanguageModelRequest,
172 cx: &AsyncAppContext,
173 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
174 let schema = schemars::schema_for!(T);
175 let schema_json = serde_json::to_value(&schema).unwrap();
176 self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
177 }
178}
179
/// A tool a language model can invoke: a deserializable value whose JSON
/// schema (derived via [`JsonSchema`]) is sent to the model along with a
/// name and description.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The name the tool is presented to the model under.
    fn name() -> String;
    /// A description of the tool, passed to the model alongside its schema.
    fn description() -> String;
}
184
/// A source of language models (e.g. a hosted API or local backend),
/// responsible for authentication, configuration UI, and enumerating models.
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    /// Icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// All models this provider currently offers.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Optionally pre-loads a model; the default implementation is a no-op.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
    /// A view for configuring this provider (e.g. entering credentials).
    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
    /// Whether the user must accept terms of service before using the provider.
    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
        false
    }
    /// UI element prompting the user to accept terms, if any are pending.
    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
204
205pub trait LanguageModelProviderState: 'static {
206 type ObservableEntity;
207
208 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
209
210 fn subscribe<T: 'static>(
211 &self,
212 cx: &mut gpui::ModelContext<T>,
213 callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
214 ) -> Option<gpui::Subscription> {
215 let entity = self.observable_entity()?;
216 Some(cx.observe(&entity, move |this, _, cx| {
217 callback(this, cx);
218 }))
219 }
220}
221
/// Unique identifier of a language model (newtype over [`SharedString`]).
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);

/// Human-readable display name of a language model.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// Unique identifier of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// Human-readable display name of a language model provider.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
233
234impl From<String> for LanguageModelId {
235 fn from(value: String) -> Self {
236 Self(SharedString::from(value))
237 }
238}
239
240impl From<String> for LanguageModelName {
241 fn from(value: String) -> Self {
242 Self(SharedString::from(value))
243 }
244}
245
246impl From<String> for LanguageModelProviderId {
247 fn from(value: String) -> Self {
248 Self(SharedString::from(value))
249 }
250}
251
252impl From<String> for LanguageModelProviderName {
253 fn from(value: String) -> Self {
254 Self(SharedString::from(value))
255 }
256}