language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6mod telemetry;
  7
  8#[cfg(any(test, feature = "test-support"))]
  9pub mod fake_provider;
 10
 11use anyhow::Result;
 12use futures::FutureExt;
 13use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 14use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 15use proto::Plan;
 16use schemars::JsonSchema;
 17use serde::{de::DeserializeOwned, Deserialize, Serialize};
 18use std::fmt;
 19use std::{future::Future, sync::Arc};
 20use thiserror::Error;
 21use ui::IconName;
 22
 23pub use crate::model::*;
 24pub use crate::rate_limiter::*;
 25pub use crate::registry::*;
 26pub use crate::request::*;
 27pub use crate::role::*;
 28pub use crate::telemetry::*;
 29
/// Identifier string (`"zed.dev"`) of the Zed cloud provider.
// NOTE(review): presumably compared against `LanguageModelProviderId` values — confirm at call sites.
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 31
/// Initializes this crate for the given application context.
///
/// Currently this only forwards to [`registry::init`].
pub fn init(cx: &mut App) {
    registry::init(cx);
}
 35
/// The availability of a [`LanguageModel`].
///
/// Reported by [`LanguageModel::availability`], whose default implementation
/// returns [`LanguageModelAvailability::Public`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
 44
/// Configuration for caching language model messages.
// Exposed per model through `LanguageModel::cache_configuration`, which returns
// `None` by default.
//
// NOTE: this type derives `JsonSchema`; plain `//` comments are used for the
// notes here on the assumption that `///` text is picked up by the generated
// schema — confirm before converting them to doc comments.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    // NOTE(review): presumably the maximum number of cache anchors a request
    // may use — confirm against the provider implementations.
    pub max_cache_anchors: usize,
    // NOTE(review): presumably enables speculative cache warming — confirm.
    pub should_speculate: bool,
    // NOTE(review): presumably the minimum total token count before caching is
    // attempted — confirm.
    pub min_total_token: usize,
}
 52
/// A completion event from a language model.
///
/// These are the items of the stream returned by
/// [`LanguageModel::stream_completion`].
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Generation stopped, carrying the [`StopReason`].
    Stop(StopReason),
    /// A piece of generated text. These are the only events that
    /// [`LanguageModel::stream_completion_text`] forwards as stream items.
    Text(String),
    /// A tool invocation described by a [`LanguageModelToolUse`].
    ToolUse(LanguageModelToolUse),
    /// Carries the id of a message. When it is the first event of a stream,
    /// [`LanguageModel::stream_completion_text`] stores the id in
    /// [`LanguageModelTextStream::message_id`].
    StartMessage { message_id: String },
}
 61
/// The reason attached to a [`LanguageModelCompletionEvent::Stop`] event.
///
/// Variants are (de)serialized in `snake_case`, e.g. `end_turn`.
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    // NOTE(review): presumably the model finished its turn on its own — confirm.
    EndTurn,
    // NOTE(review): presumably the output token limit was reached — confirm.
    MaxTokens,
    // NOTE(review): presumably the model stopped in order to call a tool — confirm.
    ToolUse,
}
 69
/// Identifier of a single tool use, stored as a shared string (`Arc<str>`).
///
/// Can be built from anything convertible into `Arc<str>` (see the `From`
/// impl) and displays as the wrapped string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 72
 73impl fmt::Display for LanguageModelToolUseId {
 74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 75        write!(f, "{}", self.0)
 76    }
 77}
 78
 79impl<T> From<T> for LanguageModelToolUseId
 80where
 81    T: Into<Arc<str>>,
 82{
 83    fn from(value: T) -> Self {
 84        Self(value.into())
 85    }
 86}
 87
/// A tool invocation carried by [`LanguageModelCompletionEvent::ToolUse`].
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool being invoked.
    pub name: String,
    /// Tool input as a JSON value.
    // NOTE(review): presumably conforms to the tool's input schema — confirm.
    pub input: serde_json::Value,
}
 94
/// The text-only view of a completion, as produced by
/// [`LanguageModel::stream_completion_text`].
pub struct LanguageModelTextStream {
    /// Message id taken from a leading
    /// [`LanguageModelCompletionEvent::StartMessage`] event, if there was one.
    pub message_id: Option<String>,
    /// The text chunks of the completion; errors are yielded as `Err` items.
    pub stream: BoxStream<'static, Result<String>>,
}
 99
100impl Default for LanguageModelTextStream {
101    fn default() -> Self {
102        Self {
103            message_id: None,
104            stream: Box::pin(futures::stream::empty()),
105        }
106    }
107}
108
109pub trait LanguageModel: Send + Sync {
110    fn id(&self) -> LanguageModelId;
111    fn name(&self) -> LanguageModelName;
112    /// If None, falls back to [LanguageModelProvider::icon]
113    fn icon(&self) -> Option<IconName> {
114        None
115    }
116    fn provider_id(&self) -> LanguageModelProviderId;
117    fn provider_name(&self) -> LanguageModelProviderName;
118    fn telemetry_id(&self) -> String;
119
120    fn api_key(&self, _cx: &App) -> Option<String> {
121        None
122    }
123
124    /// Returns the availability of this language model.
125    fn availability(&self) -> LanguageModelAvailability {
126        LanguageModelAvailability::Public
127    }
128
129    fn max_token_count(&self) -> usize;
130    fn max_output_tokens(&self) -> Option<u32> {
131        None
132    }
133
134    fn count_tokens(
135        &self,
136        request: LanguageModelRequest,
137        cx: &App,
138    ) -> BoxFuture<'static, Result<usize>>;
139
140    fn stream_completion(
141        &self,
142        request: LanguageModelRequest,
143        cx: &AsyncApp,
144    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
145
146    fn stream_completion_text(
147        &self,
148        request: LanguageModelRequest,
149        cx: &AsyncApp,
150    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
151        let events = self.stream_completion(request, cx);
152
153        async move {
154            let mut events = events.await?.fuse();
155            let mut message_id = None;
156            let mut first_item_text = None;
157
158            if let Some(first_event) = events.next().await {
159                match first_event {
160                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
161                        message_id = Some(id.clone());
162                    }
163                    Ok(LanguageModelCompletionEvent::Text(text)) => {
164                        first_item_text = Some(text);
165                    }
166                    _ => (),
167                }
168            }
169
170            let stream = futures::stream::iter(first_item_text.map(Ok))
171                .chain(events.filter_map(|result| async move {
172                    match result {
173                        Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
174                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
175                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
176                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
177                        Err(err) => Some(Err(err)),
178                    }
179                }))
180                .boxed();
181
182            Ok(LanguageModelTextStream { message_id, stream })
183        }
184        .boxed()
185    }
186
187    fn use_any_tool(
188        &self,
189        request: LanguageModelRequest,
190        name: String,
191        description: String,
192        schema: serde_json::Value,
193        cx: &AsyncApp,
194    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
195
196    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
197        None
198    }
199
200    #[cfg(any(test, feature = "test-support"))]
201    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
202        unimplemented!()
203    }
204}
205
206impl dyn LanguageModel {
207    pub fn use_tool<T: LanguageModelTool>(
208        &self,
209        request: LanguageModelRequest,
210        cx: &AsyncApp,
211    ) -> impl 'static + Future<Output = Result<T>> {
212        let schema = schemars::schema_for!(T);
213        let schema_json = serde_json::to_value(&schema).unwrap();
214        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
215        async move {
216            let stream = stream.await?;
217            let response = stream.try_collect::<String>().await?;
218            Ok(serde_json::from_str(&response)?)
219        }
220    }
221
222    pub fn use_tool_stream<T: LanguageModelTool>(
223        &self,
224        request: LanguageModelRequest,
225        cx: &AsyncApp,
226    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
227        let schema = schemars::schema_for!(T);
228        let schema_json = serde_json::to_value(&schema).unwrap();
229        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
230    }
231}
232
/// A typed tool whose output can be deserialized from the model's response.
///
/// The `JsonSchema` implementation supplies the schema that
/// `use_tool` / `use_tool_stream` (on `dyn LanguageModel`) pass to
/// [`LanguageModel::use_any_tool`]; `DeserializeOwned` is used to parse the result.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Returns the tool name passed to the model.
    fn name() -> String;
    /// Returns the tool description passed to the model.
    fn description() -> String;
}
237
/// An error that occurred when trying to authenticate the language model provider.
///
/// Returned by [`LanguageModelProvider::authenticate`].
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No credentials were found.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure; displays as, and converts from, the wrapped
    /// [`anyhow::Error`].
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
246
/// A source of [`LanguageModel`]s, together with its authentication and
/// configuration hooks.
pub trait LanguageModelProvider: 'static {
    /// Returns the identifier of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Returns the name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Returns the icon of this provider; defaults to `IconName::ZedAssistant`.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// Returns the provider's default model, if it has one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// Returns the models offered by this provider.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Hook for loading `_model`; the default implementation does nothing.
    // NOTE(review): presumably used by providers that must preload a model — confirm.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    /// Returns whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication, returning a task that resolves to `Ok(())` or an
    /// [`AuthenticateError`].
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Returns the view used to configure this provider.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Returns whether terms must be accepted for this provider; defaults to
    /// `false`.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Returns an element for accepting the terms in the given `_view`
    /// location; the default implementation returns `None`.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Starts resetting the provider's credentials, returning a task with the
    /// outcome.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
271
/// Identifies the view passed to
/// [`LanguageModelProvider::render_accept_terms`].
// NOTE(review): variant meanings below are inferred from their names — confirm
// against the UI call sites.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    // Presumably the empty state of a thread.
    ThreadEmptyState,
    // Presumably the prompt editor popup.
    PromptEditorPopup,
    // Presumably the provider configuration view.
    Configuration,
}
278
279pub trait LanguageModelProviderState: 'static {
280    type ObservableEntity;
281
282    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
283
284    fn subscribe<T: 'static>(
285        &self,
286        cx: &mut gpui::Context<T>,
287        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
288    ) -> Option<gpui::Subscription> {
289        let entity = self.observable_entity()?;
290        Some(cx.observe(&entity, move |this, _, cx| {
291            callback(this, cx);
292        }))
293    }
294}
295
/// Identifier of a language model, wrapping a `SharedString`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);
298
/// Name of a language model, wrapping a `SharedString`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
301
/// Identifier of a language model provider, wrapping a `SharedString`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
304
/// Name of a language model provider, wrapping a `SharedString`.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
307
308impl fmt::Display for LanguageModelProviderId {
309    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
310        write!(f, "{}", self.0)
311    }
312}
313
314impl From<String> for LanguageModelId {
315    fn from(value: String) -> Self {
316        Self(SharedString::from(value))
317    }
318}
319
320impl From<String> for LanguageModelName {
321    fn from(value: String) -> Self {
322        Self(SharedString::from(value))
323    }
324}
325
326impl From<String> for LanguageModelProviderId {
327    fn from(value: String) -> Self {
328        Self(SharedString::from(value))
329    }
330}
331
332impl From<String> for LanguageModelProviderName {
333    fn from(value: String) -> Self {
334        Self(SharedString::from(value))
335    }
336}