1mod model;
2mod rate_limiter;
3mod registry;
4mod request;
5mod role;
6
7#[cfg(any(test, feature = "test-support"))]
8pub mod fake_provider;
9
10use anyhow::Result;
11use futures::FutureExt;
12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
13use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
14pub use model::*;
15use proto::Plan;
16pub use rate_limiter::*;
17pub use registry::*;
18pub use request::*;
19pub use role::*;
20use schemars::JsonSchema;
21use serde::{de::DeserializeOwned, Deserialize, Serialize};
22use std::fmt;
23use std::{future::Future, sync::Arc};
24use thiserror::Error;
25use ui::IconName;
26
/// Provider id string for the Zed cloud provider (`"zed.dev"`).
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
28
29pub fn init(cx: &mut App) {
30 registry::init(cx);
31}
32
33/// The availability of a [`LanguageModel`].
34#[derive(Debug, PartialEq, Eq, Clone, Copy)]
35pub enum LanguageModelAvailability {
36 /// The language model is available to the general public.
37 Public,
38 /// The language model is available to users on the indicated plan.
39 RequiresPlan(Plan),
40}
41
42/// Configuration for caching language model messages.
43#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
44pub struct LanguageModelCacheConfiguration {
45 pub max_cache_anchors: usize,
46 pub should_speculate: bool,
47 pub min_total_token: usize,
48}
49
50/// A completion event from a language model.
51#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
52pub enum LanguageModelCompletionEvent {
53 Stop(StopReason),
54 Text(String),
55 ToolUse(LanguageModelToolUse),
56 StartMessage { message_id: String },
57}
58
59#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
60#[serde(rename_all = "snake_case")]
61pub enum StopReason {
62 EndTurn,
63 MaxTokens,
64 ToolUse,
65}
66
67#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
68pub struct LanguageModelToolUseId(Arc<str>);
69
70impl fmt::Display for LanguageModelToolUseId {
71 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
72 write!(f, "{}", self.0)
73 }
74}
75
76impl<T> From<T> for LanguageModelToolUseId
77where
78 T: Into<Arc<str>>,
79{
80 fn from(value: T) -> Self {
81 Self(value.into())
82 }
83}
84
85#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
86pub struct LanguageModelToolUse {
87 pub id: LanguageModelToolUseId,
88 pub name: String,
89 pub input: serde_json::Value,
90}
91
92pub struct LanguageModelTextStream {
93 pub message_id: Option<String>,
94 pub stream: BoxStream<'static, Result<String>>,
95}
96
97impl Default for LanguageModelTextStream {
98 fn default() -> Self {
99 Self {
100 message_id: None,
101 stream: Box::pin(futures::stream::empty()),
102 }
103 }
104}
105
106pub trait LanguageModel: Send + Sync {
107 fn id(&self) -> LanguageModelId;
108 fn name(&self) -> LanguageModelName;
109 /// If None, falls back to [LanguageModelProvider::icon]
110 fn icon(&self) -> Option<IconName> {
111 None
112 }
113 fn provider_id(&self) -> LanguageModelProviderId;
114 fn provider_name(&self) -> LanguageModelProviderName;
115 fn telemetry_id(&self) -> String;
116
117 fn api_key(&self, _cx: &App) -> Option<String> {
118 None
119 }
120
121 /// Returns the availability of this language model.
122 fn availability(&self) -> LanguageModelAvailability {
123 LanguageModelAvailability::Public
124 }
125
126 fn max_token_count(&self) -> usize;
127 fn max_output_tokens(&self) -> Option<u32> {
128 None
129 }
130
131 fn count_tokens(
132 &self,
133 request: LanguageModelRequest,
134 cx: &App,
135 ) -> BoxFuture<'static, Result<usize>>;
136
137 fn stream_completion(
138 &self,
139 request: LanguageModelRequest,
140 cx: &AsyncApp,
141 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
142
143 fn stream_completion_text(
144 &self,
145 request: LanguageModelRequest,
146 cx: &AsyncApp,
147 ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
148 let events = self.stream_completion(request, cx);
149
150 async move {
151 let mut events = events.await?.fuse();
152 let mut message_id = None;
153 let mut first_item_text = None;
154
155 if let Some(first_event) = events.next().await {
156 match first_event {
157 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
158 message_id = Some(id.clone());
159 }
160 Ok(LanguageModelCompletionEvent::Text(text)) => {
161 first_item_text = Some(text);
162 }
163 _ => (),
164 }
165 }
166
167 let stream = futures::stream::iter(first_item_text.map(Ok))
168 .chain(events.filter_map(|result| async move {
169 match result {
170 Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
171 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
172 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
173 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
174 Err(err) => Some(Err(err)),
175 }
176 }))
177 .boxed();
178
179 Ok(LanguageModelTextStream { message_id, stream })
180 }
181 .boxed()
182 }
183
184 fn use_any_tool(
185 &self,
186 request: LanguageModelRequest,
187 name: String,
188 description: String,
189 schema: serde_json::Value,
190 cx: &AsyncApp,
191 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
192
193 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
194 None
195 }
196
197 #[cfg(any(test, feature = "test-support"))]
198 fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
199 unimplemented!()
200 }
201}
202
203impl dyn LanguageModel {
204 pub fn use_tool<T: LanguageModelTool>(
205 &self,
206 request: LanguageModelRequest,
207 cx: &AsyncApp,
208 ) -> impl 'static + Future<Output = Result<T>> {
209 let schema = schemars::schema_for!(T);
210 let schema_json = serde_json::to_value(&schema).unwrap();
211 let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
212 async move {
213 let stream = stream.await?;
214 let response = stream.try_collect::<String>().await?;
215 Ok(serde_json::from_str(&response)?)
216 }
217 }
218
219 pub fn use_tool_stream<T: LanguageModelTool>(
220 &self,
221 request: LanguageModelRequest,
222 cx: &AsyncApp,
223 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
224 let schema = schemars::schema_for!(T);
225 let schema_json = serde_json::to_value(&schema).unwrap();
226 self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
227 }
228}
229
230pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
231 fn name() -> String;
232 fn description() -> String;
233}
234
235/// An error that occurred when trying to authenticate the language model provider.
236#[derive(Debug, Error)]
237pub enum AuthenticateError {
238 #[error("credentials not found")]
239 CredentialsNotFound,
240 #[error(transparent)]
241 Other(#[from] anyhow::Error),
242}
243
244pub trait LanguageModelProvider: 'static {
245 fn id(&self) -> LanguageModelProviderId;
246 fn name(&self) -> LanguageModelProviderName;
247 fn icon(&self) -> IconName {
248 IconName::ZedAssistant
249 }
250 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
251 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
252 fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
253 fn is_authenticated(&self, cx: &App) -> bool;
254 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
255 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
256 fn must_accept_terms(&self, _cx: &App) -> bool {
257 false
258 }
259 fn render_accept_terms(
260 &self,
261 _view: LanguageModelProviderTosView,
262 _cx: &mut App,
263 ) -> Option<AnyElement> {
264 None
265 }
266 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
267}
268
/// The UI surface in which `LanguageModelProvider::render_accept_terms` is asked
/// to render.
///
/// Fieldless, so it is cheap to copy. `Debug` is derived so the type can be logged
/// and used in assertions like other public types in this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    ThreadEmptyState,
    PromptEditorPopup,
    Configuration,
}
275
276pub trait LanguageModelProviderState: 'static {
277 type ObservableEntity;
278
279 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
280
281 fn subscribe<T: 'static>(
282 &self,
283 cx: &mut gpui::Context<T>,
284 callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
285 ) -> Option<gpui::Subscription> {
286 let entity = self.observable_entity()?;
287 Some(cx.observe(&entity, move |this, _, cx| {
288 callback(this, cx);
289 }))
290 }
291}
292
293#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
294pub struct LanguageModelId(pub SharedString);
295
296#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
297pub struct LanguageModelName(pub SharedString);
298
299#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
300pub struct LanguageModelProviderId(pub SharedString);
301
302#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
303pub struct LanguageModelProviderName(pub SharedString);
304
305impl fmt::Display for LanguageModelProviderId {
306 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307 write!(f, "{}", self.0)
308 }
309}
310
311impl From<String> for LanguageModelId {
312 fn from(value: String) -> Self {
313 Self(SharedString::from(value))
314 }
315}
316
317impl From<String> for LanguageModelName {
318 fn from(value: String) -> Self {
319 Self(SharedString::from(value))
320 }
321}
322
323impl From<String> for LanguageModelProviderId {
324 fn from(value: String) -> Self {
325 Self(SharedString::from(value))
326 }
327}
328
329impl From<String> for LanguageModelProviderName {
330 fn from(value: String) -> Self {
331 Self(SharedString::from(value))
332 }
333}