mod model;
mod rate_limiter;
mod registry;
mod request;
mod role;
mod telemetry;

#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;

use anyhow::Result;
use futures::FutureExt;
use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
use proto::Plan;
use schemars::JsonSchema;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt;
use std::{future::Future, sync::Arc};
use thiserror::Error;
use ui::IconName;

pub use crate::model::*;
pub use crate::rate_limiter::*;
pub use crate::registry::*;
pub use crate::request::*;
pub use crate::role::*;
pub use crate::telemetry::*;

pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";

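/// Initializes this crate's global state (currently just the language model registry).
/// Call once during application startup, before any providers or models are used.
///
/// Illustrative only; the crate name `language_model` is assumed here:
///
/// ```ignore
/// language_model::init(cx);
/// ```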
pub fn init(cx: &mut App) {
    registry::init(cx);
}

/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}

/// Configuration for caching language model messages.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    pub max_cache_anchors: usize,
    pub should_speculate: bool,
    pub min_total_token: usize,
}

/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    Stop(StopReason),
    Text(String),
    ToolUse(LanguageModelToolUse),
    StartMessage { message_id: String },
}

#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    ToolUse,
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);

impl fmt::Display for LanguageModelToolUseId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl<T> From<T> for LanguageModelToolUseId
where
    T: Into<Arc<str>>,
{
    fn from(value: T) -> Self {
        Self(value.into())
    }
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    pub id: LanguageModelToolUseId,
    pub name: String,
    pub input: serde_json::Value,
}

pub struct LanguageModelTextStream {
    pub message_id: Option<String>,
    pub stream: BoxStream<'static, Result<String>>,
}

impl Default for LanguageModelTextStream {
    fn default() -> Self {
        Self {
            message_id: None,
            stream: Box::pin(futures::stream::empty()),
        }
    }
}

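/// A language model offered by a [`LanguageModelProvider`].
///
/// The snippet below is an illustrative sketch, not part of this crate: it assumes a
/// `model: Arc<dyn LanguageModel>`, a prepared `request: LanguageModelRequest`, and an
/// async context `cx: &AsyncApp` are already in scope.
///
/// ```ignore
/// use futures::StreamExt as _;
///
/// // Stream the completion as plain text; tool-use and stop events are filtered out.
/// let mut text_stream = model.stream_completion_text(request, cx).await?;
/// while let Some(chunk) = text_stream.stream.next().await {
///     print!("{}", chunk?);
/// }
/// ```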
pub trait LanguageModel: Send + Sync {
    fn id(&self) -> LanguageModelId;
    fn name(&self) -> LanguageModelName;
    /// If `None`, falls back to [`LanguageModelProvider::icon`].
    fn icon(&self) -> Option<IconName> {
        None
    }
    fn provider_id(&self) -> LanguageModelProviderId;
    fn provider_name(&self) -> LanguageModelProviderName;
    fn telemetry_id(&self) -> String;

    fn api_key(&self, _cx: &App) -> Option<String> {
        None
    }

    /// Returns the availability of this language model.
    fn availability(&self) -> LanguageModelAvailability {
        LanguageModelAvailability::Public
    }

    fn max_token_count(&self) -> usize;
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<usize>>;

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;

    fn stream_completion_text(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
        let events = self.stream_completion(request, cx);

        async move {
            let mut events = events.await?.fuse();
            let mut message_id = None;
            let mut first_item_text = None;

            if let Some(first_event) = events.next().await {
                match first_event {
                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
                        message_id = Some(id);
                    }
                    Ok(LanguageModelCompletionEvent::Text(text)) => {
                        first_item_text = Some(text);
                    }
                    _ => (),
                }
            }

            let stream = futures::stream::iter(first_item_text.map(Ok))
                .chain(events.filter_map(|result| async move {
                    match result {
                        Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
                        Err(err) => Some(Err(err)),
                    }
                }))
                .boxed();

            Ok(LanguageModelTextStream { message_id, stream })
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        name: String,
        description: String,
        schema: serde_json::Value,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;

    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
        None
    }

    #[cfg(any(test, feature = "test-support"))]
    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
        unimplemented!()
    }
}

impl dyn LanguageModel {
    pub fn use_tool<T: LanguageModelTool>(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> impl 'static + Future<Output = Result<T>> {
        let schema = schemars::schema_for!(T);
        let schema_json = serde_json::to_value(&schema).unwrap();
        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
        async move {
            let stream = stream.await?;
            let response = stream.try_collect::<String>().await?;
            Ok(serde_json::from_str(&response)?)
        }
    }

    pub fn use_tool_stream<T: LanguageModelTool>(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let schema = schemars::schema_for!(T);
        let schema_json = serde_json::to_value(&schema).unwrap();
        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
    }
}

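/// A typed tool that a model can be asked to call through the helpers on `dyn LanguageModel`.
/// The JSON schema sent to the model is derived from the implementor's [`JsonSchema`] impl, and
/// the model's reply is deserialized back into the implementing type.
///
/// The sketch below is illustrative only; `SearchQuery` and its fields are hypothetical.
///
/// ```ignore
/// #[derive(serde::Deserialize, schemars::JsonSchema)]
/// struct SearchQuery {
///     /// The string to search for.
///     query: String,
///     /// The maximum number of results to return.
///     max_results: usize,
/// }
///
/// impl LanguageModelTool for SearchQuery {
///     fn name() -> String {
///         "search".into()
///     }
///
///     fn description() -> String {
///         "Search the project for a string".into()
///     }
/// }
///
/// // With `model: &dyn LanguageModel`, `request`, and `cx: &AsyncApp` in scope:
/// let query = model.use_tool::<SearchQuery>(request, cx).await?;
/// ```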
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    fn name() -> String;
    fn description() -> String;
}

/// An error that occurred when trying to authenticate the language model provider.
#[derive(Debug, Error)]
pub enum AuthenticateError {
    #[error("credentials not found")]
    CredentialsNotFound,
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

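/// A source of language models (typically one API backend), responsible for listing the models it
/// offers and managing the credentials needed to use them.
///
/// A rough usage sketch, assuming a `provider: &dyn LanguageModelProvider` and `cx: &mut App`:
///
/// ```ignore
/// if !provider.is_authenticated(cx) {
///     let authenticate = provider.authenticate(cx);
///     // Await `authenticate` (and handle `AuthenticateError`) before requesting models.
/// }
/// let models = provider.provided_models(cx);
/// ```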
pub trait LanguageModelProvider: 'static {
    fn id(&self) -> LanguageModelProviderId;
    fn name(&self) -> LanguageModelProviderName;
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    fn is_authenticated(&self, cx: &App) -> bool;
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}

#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    ThreadEmptyState,
    PromptEditorPopup,
    Configuration,
}

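/// Exposes a provider's observable state (if any) so that other parts of the app can be notified
/// when it changes.
///
/// An illustrative sketch, assuming `provider` implements this trait and `cx` is a
/// `gpui::Context<MyView>` for some hypothetical `MyView`:
///
/// ```ignore
/// // Keep the returned subscription alive for as long as the observation should last.
/// let _subscription = provider.subscribe(cx, |_this: &mut MyView, cx| {
///     cx.notify();
/// });
/// ```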
pub trait LanguageModelProviderState: 'static {
    type ObservableEntity;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;

    fn subscribe<T: 'static>(
        &self,
        cx: &mut gpui::Context<T>,
        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
    ) -> Option<gpui::Subscription> {
        let entity = self.observable_entity()?;
        Some(cx.observe(&entity, move |this, _, cx| {
            callback(this, cx);
        }))
    }
}

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);

impl fmt::Display for LanguageModelProviderId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}

impl From<String> for LanguageModelId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderId {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}

impl From<String> for LanguageModelProviderName {
    fn from(value: String) -> Self {
        Self(SharedString::from(value))
    }
}