1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::{Client, UserStore};
11use futures::FutureExt;
12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
13use gpui::{
14 AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
15};
16pub use model::*;
17use project::Fs;
18use proto::Plan;
19pub(crate) use rate_limiter::*;
20pub use registry::*;
21pub use request::*;
22pub use role::*;
23use schemars::JsonSchema;
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26use std::{future::Future, sync::Arc};
27use ui::IconName;
28
/// Initializes this crate's global state: language-model settings and the
/// model registry.
pub fn init(
    user_store: Model<UserStore>,
    client: Arc<Client>,
    fs: Arc<dyn Fs>,
    cx: &mut AppContext,
) {
    // NOTE(review): settings are initialized before the registry — presumably
    // provider registration reads settings; confirm before reordering.
    settings::init(fs, cx);
    registry::init(user_store, client, cx);
}
38
/// The availability of a [`LanguageModel`].
///
/// Returned by [`LanguageModel::availability`], which defaults to `Public`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
47
/// Configuration for caching language model messages.
// NOTE(review): the field semantics below are inferred from the field names;
// confirm against the providers that construct this configuration.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// The maximum number of cache anchors a single request may use.
    pub max_cache_anchors: usize,
    /// Whether to speculatively populate the cache.
    pub should_speculate: bool,
    /// The minimum total token count before caching is applied.
    pub min_total_token: usize,
}
55
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The stream has finished for the given reason.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
}
63
/// The reason a language model stopped producing output.
///
/// Serialized in `snake_case` (e.g. `end_turn`, `max_tokens`, `tool_use`).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
}
71
/// A tool invocation requested by a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// The identifier of this particular tool use.
    pub id: String,
    /// The name of the tool being invoked.
    pub name: String,
    /// The arguments passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
78
79pub trait LanguageModel: Send + Sync {
80 fn id(&self) -> LanguageModelId;
81 fn name(&self) -> LanguageModelName;
82 /// If None, falls back to [LanguageModelProvider::icon]
83 fn icon(&self) -> Option<IconName> {
84 None
85 }
86 fn provider_id(&self) -> LanguageModelProviderId;
87 fn provider_name(&self) -> LanguageModelProviderName;
88 fn telemetry_id(&self) -> String;
89
90 /// Returns the availability of this language model.
91 fn availability(&self) -> LanguageModelAvailability {
92 LanguageModelAvailability::Public
93 }
94
95 fn max_token_count(&self) -> usize;
96 fn max_output_tokens(&self) -> Option<u32> {
97 None
98 }
99
100 fn count_tokens(
101 &self,
102 request: LanguageModelRequest,
103 cx: &AppContext,
104 ) -> BoxFuture<'static, Result<usize>>;
105
106 fn stream_completion(
107 &self,
108 request: LanguageModelRequest,
109 cx: &AsyncAppContext,
110 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
111
112 fn stream_completion_text(
113 &self,
114 request: LanguageModelRequest,
115 cx: &AsyncAppContext,
116 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
117 let events = self.stream_completion(request, cx);
118
119 async move {
120 Ok(events
121 .await?
122 .filter_map(|result| async move {
123 match result {
124 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
125 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
126 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
127 Err(err) => Some(Err(err)),
128 }
129 })
130 .boxed())
131 }
132 .boxed()
133 }
134
135 fn use_any_tool(
136 &self,
137 request: LanguageModelRequest,
138 name: String,
139 description: String,
140 schema: serde_json::Value,
141 cx: &AsyncAppContext,
142 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
143
144 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
145 None
146 }
147
148 #[cfg(any(test, feature = "test-support"))]
149 fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
150 unimplemented!()
151 }
152}
153
154impl dyn LanguageModel {
155 pub fn use_tool<T: LanguageModelTool>(
156 &self,
157 request: LanguageModelRequest,
158 cx: &AsyncAppContext,
159 ) -> impl 'static + Future<Output = Result<T>> {
160 let schema = schemars::schema_for!(T);
161 let schema_json = serde_json::to_value(&schema).unwrap();
162 let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
163 async move {
164 let stream = stream.await?;
165 let response = stream.try_collect::<String>().await?;
166 Ok(serde_json::from_str(&response)?)
167 }
168 }
169
170 pub fn use_tool_stream<T: LanguageModelTool>(
171 &self,
172 request: LanguageModelRequest,
173 cx: &AsyncAppContext,
174 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
175 let schema = schemars::schema_for!(T);
176 let schema_json = serde_json::to_value(&schema).unwrap();
177 self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
178 }
179}
180
/// A tool that a language model can be asked to invoke.
///
/// Implementors supply a name and description presented to the model; the
/// argument schema is derived via [`JsonSchema`], and the model's response is
/// deserialized into the implementing type.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The name of the tool, as presented to the language model.
    fn name() -> String;
    /// A description of the tool, as presented to the language model.
    fn description() -> String;
}
185
/// A source of language models (e.g. a single vendor or API endpoint) that
/// manages authentication and configuration for the models it offers.
pub trait LanguageModelProvider: 'static {
    /// A stable, unique identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// The human-readable name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// The icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// All models currently offered by this provider.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Optionally pre-loads the given model; the default implementation is a no-op.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    /// Whether this provider is currently authenticated.
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    /// Starts an authentication flow for this provider.
    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
    /// Builds the view used to configure this provider.
    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
    /// Whether the user must accept terms of service before using this
    /// provider; defaults to `false`.
    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
        false
    }
    /// Renders a terms-of-service acceptance element, if one is needed.
    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
        None
    }
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
205
206pub trait LanguageModelProviderState: 'static {
207 type ObservableEntity;
208
209 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
210
211 fn subscribe<T: 'static>(
212 &self,
213 cx: &mut gpui::ModelContext<T>,
214 callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
215 ) -> Option<gpui::Subscription> {
216 let entity = self.observable_entity()?;
217 Some(cx.observe(&entity, move |this, _, cx| {
218 callback(this, cx);
219 }))
220 }
221}
222
/// The identifier of a [`LanguageModel`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);

/// The display name of a [`LanguageModel`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// The identifier of a [`LanguageModelProvider`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// The display name of a [`LanguageModelProvider`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
234
235impl fmt::Display for LanguageModelProviderId {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 write!(f, "{}", self.0)
238 }
239}
240
241impl From<String> for LanguageModelId {
242 fn from(value: String) -> Self {
243 Self(SharedString::from(value))
244 }
245}
246
247impl From<String> for LanguageModelName {
248 fn from(value: String) -> Self {
249 Self(SharedString::from(value))
250 }
251}
252
253impl From<String> for LanguageModelProviderId {
254 fn from(value: String) -> Self {
255 Self(SharedString::from(value))
256 }
257}
258
259impl From<String> for LanguageModelProviderName {
260 fn from(value: String) -> Self {
261 Self(SharedString::from(value))
262 }
263}