1mod model;
2pub mod provider;
3mod rate_limiter;
4mod registry;
5mod request;
6mod role;
7pub mod settings;
8
9use anyhow::Result;
10use client::{Client, UserStore};
11use futures::FutureExt;
12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
13use gpui::{
14 AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
15};
16pub use model::*;
17use project::Fs;
18use proto::Plan;
19pub(crate) use rate_limiter::*;
20pub use registry::*;
21pub use request::*;
22pub use role::*;
23use schemars::JsonSchema;
24use serde::{de::DeserializeOwned, Deserialize, Serialize};
25use std::fmt;
26use std::{future::Future, sync::Arc};
27use ui::IconName;
28
/// Initializes this crate's global state: language-model settings and the
/// model registry.
pub fn init(
    user_store: Model<UserStore>,
    client: Arc<Client>,
    fs: Arc<dyn Fs>,
    cx: &mut AppContext,
) {
    // NOTE(review): settings are initialized before the registry — presumably
    // provider registration reads settings; confirm before reordering.
    settings::init(fs, cx);
    registry::init(user_store, client, cx);
}
38
/// The availability of a [`LanguageModel`].
///
/// Returned by [`LanguageModel::availability`], which defaults to `Public`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
47
/// Configuration for caching language model messages.
// NOTE(review): the field semantics below are inferred from the field names;
// confirm against the providers that construct this configuration.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// The maximum number of cache anchors a single request may use.
    pub max_cache_anchors: usize,
    /// Whether to speculatively populate the cache.
    pub should_speculate: bool,
    /// The minimum total token count before caching is applied.
    pub min_total_token: usize,
}
55
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// The stream has finished for the given reason.
    Stop(StopReason),
    /// A chunk of generated text.
    Text(String),
    /// The model requested a tool invocation.
    ToolUse(LanguageModelToolUse),
}
63
/// The reason a language model stopped producing output.
///
/// Serialized in `snake_case` (e.g. `end_turn`, `max_tokens`, `tool_use`).
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// The model finished its turn normally.
    EndTurn,
    /// Generation was cut off by the output token limit.
    MaxTokens,
    /// The model stopped in order to invoke a tool.
    ToolUse,
}
71
/// A tool invocation requested by a language model.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// The identifier of this particular tool use.
    pub id: String,
    /// The name of the tool being invoked.
    pub name: String,
    /// The arguments passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
78
79pub trait LanguageModel: Send + Sync {
80 fn id(&self) -> LanguageModelId;
81 fn name(&self) -> LanguageModelName;
82 /// If None, falls back to [LanguageModelProvider::icon]
83 fn icon(&self) -> Option<IconName> {
84 None
85 }
86 fn provider_id(&self) -> LanguageModelProviderId;
87 fn provider_name(&self) -> LanguageModelProviderName;
88 fn telemetry_id(&self) -> String;
89
90 /// Returns the availability of this language model.
91 fn availability(&self) -> LanguageModelAvailability {
92 LanguageModelAvailability::Public
93 }
94
95 fn max_token_count(&self) -> usize;
96 fn max_output_tokens(&self) -> Option<u32> {
97 None
98 }
99
100 fn count_tokens(
101 &self,
102 request: LanguageModelRequest,
103 cx: &AppContext,
104 ) -> BoxFuture<'static, Result<usize>>;
105
106 fn stream_completion(
107 &self,
108 request: LanguageModelRequest,
109 cx: &AsyncAppContext,
110 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
111
112 fn stream_completion_text(
113 &self,
114 request: LanguageModelRequest,
115 cx: &AsyncAppContext,
116 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
117 let events = self.stream_completion(request, cx);
118
119 async move {
120 Ok(events
121 .await?
122 .filter_map(|result| async move {
123 match result {
124 Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
125 Ok(LanguageModelCompletionEvent::Stop(_)) => None,
126 Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
127 Err(err) => Some(Err(err)),
128 }
129 })
130 .boxed())
131 }
132 .boxed()
133 }
134
135 fn use_any_tool(
136 &self,
137 request: LanguageModelRequest,
138 name: String,
139 description: String,
140 schema: serde_json::Value,
141 cx: &AsyncAppContext,
142 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
143
144 fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
145 None
146 }
147
148 #[cfg(any(test, feature = "test-support"))]
149 fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
150 unimplemented!()
151 }
152}
153
154impl dyn LanguageModel {
155 pub fn use_tool<T: LanguageModelTool>(
156 &self,
157 request: LanguageModelRequest,
158 cx: &AsyncAppContext,
159 ) -> impl 'static + Future<Output = Result<T>> {
160 let schema = schemars::schema_for!(T);
161 let schema_json = serde_json::to_value(&schema).unwrap();
162 let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
163 async move {
164 let stream = stream.await?;
165 let response = stream.try_collect::<String>().await?;
166 Ok(serde_json::from_str(&response)?)
167 }
168 }
169
170 pub fn use_tool_stream<T: LanguageModelTool>(
171 &self,
172 request: LanguageModelRequest,
173 cx: &AsyncAppContext,
174 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
175 let schema = schemars::schema_for!(T);
176 let schema_json = serde_json::to_value(&schema).unwrap();
177 self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
178 }
179}
180
/// A tool that a language model can be asked to invoke.
///
/// Implementors supply a name and description presented to the model; the
/// argument schema is derived via [`JsonSchema`], and the model's response is
/// deserialized into the implementing type.
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// The name of the tool, as presented to the language model.
    fn name() -> String;
    /// A description of the tool, as presented to the language model.
    fn description() -> String;
}
185
/// A source of language models (e.g. a single vendor or API endpoint) that
/// manages authentication and configuration for the models it offers.
pub trait LanguageModelProvider: 'static {
    /// A stable, unique identifier for this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// The human-readable name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// The icon shown for this provider; defaults to the Zed assistant icon.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// All models currently offered by this provider.
    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
    /// Optionally pre-loads the given model; the default implementation is a no-op.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
    /// Whether this provider is currently authenticated.
    fn is_authenticated(&self, cx: &AppContext) -> bool;
    /// Starts an authentication flow for this provider.
    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
    /// Builds the view used to configure this provider.
    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
    /// Whether the user must accept terms of service before using this
    /// provider; defaults to `false`.
    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
        false
    }
    /// Renders a terms-of-service acceptance element, if one is needed.
    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
        None
    }
    /// Clears any stored credentials for this provider.
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
205
206pub trait LanguageModelProviderState: 'static {
207 type ObservableEntity;
208
209 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
210
211 fn subscribe<T: 'static>(
212 &self,
213 cx: &mut gpui::ModelContext<T>,
214 callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
215 ) -> Option<gpui::Subscription> {
216 let entity = self.observable_entity()?;
217 Some(cx.observe(&entity, move |this, _, cx| {
218 callback(this, cx);
219 }))
220 }
221}
222
/// The identifier of a [`LanguageModel`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);

/// The display name of a [`LanguageModel`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);

/// The identifier of a [`LanguageModelProvider`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);

/// The display name of a [`LanguageModelProvider`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
234
235impl fmt::Display for LanguageModelProviderId {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 write!(f, "{}", self.0)
238 }
239}
240
241impl From<String> for LanguageModelId {
242 fn from(value: String) -> Self {
243 Self(SharedString::from(value))
244 }
245}
246
247impl From<String> for LanguageModelName {
248 fn from(value: String) -> Self {
249 Self(SharedString::from(value))
250 }
251}
252
253impl From<String> for LanguageModelProviderId {
254 fn from(value: String) -> Self {
255 Self(SharedString::from(value))
256 }
257}
258
259impl From<String> for LanguageModelProviderName {
260 fn from(value: String) -> Self {
261 Self(SharedString::from(value))
262 }
263}