language_model.rs

  1mod model;
  2pub mod provider;
  3mod rate_limiter;
  4mod registry;
  5mod request;
  6mod role;
  7pub mod settings;
  8
  9use anyhow::Result;
 10use client::{Client, UserStore};
 11use futures::FutureExt;
 12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 13use gpui::{
 14    AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 15};
 16pub use model::*;
 17use project::Fs;
 18use proto::Plan;
 19pub(crate) use rate_limiter::*;
 20pub use registry::*;
 21pub use request::*;
 22pub use role::*;
 23use schemars::JsonSchema;
 24use serde::{de::DeserializeOwned, Deserialize, Serialize};
 25use std::{future::Future, sync::Arc};
 26use ui::IconName;
 27
 28pub fn init(
 29    user_store: Model<UserStore>,
 30    client: Arc<Client>,
 31    fs: Arc<dyn Fs>,
 32    cx: &mut AppContext,
 33) {
 34    settings::init(fs, cx);
 35    registry::init(user_store, client, cx);
 36}
 37
 38/// The availability of a [`LanguageModel`].
 39#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 40pub enum LanguageModelAvailability {
 41    /// The language model is available to the general public.
 42    Public,
 43    /// The language model is available to users on the indicated plan.
 44    RequiresPlan(Plan),
 45}
 46
 47/// Configuration for caching language model messages.
 48#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 49pub struct LanguageModelCacheConfiguration {
 50    pub max_cache_anchors: usize,
 51    pub should_speculate: bool,
 52    pub min_total_token: usize,
 53}
 54
 55/// A completion event from a language model.
 56#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 57pub enum LanguageModelCompletionEvent {
 58    Stop(StopReason),
 59    Text(String),
 60    ToolUse(LanguageModelToolUse),
 61}
 62
 63#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 64#[serde(rename_all = "snake_case")]
 65pub enum StopReason {
 66    EndTurn,
 67    MaxTokens,
 68    ToolUse,
 69}
 70
 71#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 72pub struct LanguageModelToolUse {
 73    pub id: String,
 74    pub name: String,
 75    pub input: serde_json::Value,
 76}
 77
 78pub trait LanguageModel: Send + Sync {
 79    fn id(&self) -> LanguageModelId;
 80    fn name(&self) -> LanguageModelName;
 81    /// If None, falls back to [LanguageModelProvider::icon]
 82    fn icon(&self) -> Option<IconName> {
 83        None
 84    }
 85    fn provider_id(&self) -> LanguageModelProviderId;
 86    fn provider_name(&self) -> LanguageModelProviderName;
 87    fn telemetry_id(&self) -> String;
 88
 89    /// Returns the availability of this language model.
 90    fn availability(&self) -> LanguageModelAvailability {
 91        LanguageModelAvailability::Public
 92    }
 93
 94    fn max_token_count(&self) -> usize;
 95    fn max_output_tokens(&self) -> Option<u32> {
 96        None
 97    }
 98
 99    fn count_tokens(
100        &self,
101        request: LanguageModelRequest,
102        cx: &AppContext,
103    ) -> BoxFuture<'static, Result<usize>>;
104
105    fn stream_completion(
106        &self,
107        request: LanguageModelRequest,
108        cx: &AsyncAppContext,
109    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
110
111    fn stream_completion_text(
112        &self,
113        request: LanguageModelRequest,
114        cx: &AsyncAppContext,
115    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
116        let events = self.stream_completion(request, cx);
117
118        async move {
119            Ok(events
120                .await?
121                .filter_map(|result| async move {
122                    match result {
123                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
124                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
125                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
126                        Err(err) => Some(Err(err)),
127                    }
128                })
129                .boxed())
130        }
131        .boxed()
132    }
133
134    fn use_any_tool(
135        &self,
136        request: LanguageModelRequest,
137        name: String,
138        description: String,
139        schema: serde_json::Value,
140        cx: &AsyncAppContext,
141    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
142
143    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
144        None
145    }
146
147    #[cfg(any(test, feature = "test-support"))]
148    fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
149        unimplemented!()
150    }
151}
152
153impl dyn LanguageModel {
154    pub fn use_tool<T: LanguageModelTool>(
155        &self,
156        request: LanguageModelRequest,
157        cx: &AsyncAppContext,
158    ) -> impl 'static + Future<Output = Result<T>> {
159        let schema = schemars::schema_for!(T);
160        let schema_json = serde_json::to_value(&schema).unwrap();
161        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
162        async move {
163            let stream = stream.await?;
164            let response = stream.try_collect::<String>().await?;
165            Ok(serde_json::from_str(&response)?)
166        }
167    }
168
169    pub fn use_tool_stream<T: LanguageModelTool>(
170        &self,
171        request: LanguageModelRequest,
172        cx: &AsyncAppContext,
173    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
174        let schema = schemars::schema_for!(T);
175        let schema_json = serde_json::to_value(&schema).unwrap();
176        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
177    }
178}
179
180pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
181    fn name() -> String;
182    fn description() -> String;
183}
184
185pub trait LanguageModelProvider: 'static {
186    fn id(&self) -> LanguageModelProviderId;
187    fn name(&self) -> LanguageModelProviderName;
188    fn icon(&self) -> IconName {
189        IconName::ZedAssistant
190    }
191    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
192    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
193    fn is_authenticated(&self, cx: &AppContext) -> bool;
194    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
195    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
196    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
197        false
198    }
199    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
200        None
201    }
202    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
203}
204
205pub trait LanguageModelProviderState: 'static {
206    type ObservableEntity;
207
208    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
209
210    fn subscribe<T: 'static>(
211        &self,
212        cx: &mut gpui::ModelContext<T>,
213        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
214    ) -> Option<gpui::Subscription> {
215        let entity = self.observable_entity()?;
216        Some(cx.observe(&entity, move |this, _, cx| {
217            callback(this, cx);
218        }))
219    }
220}
221
222#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
223pub struct LanguageModelId(pub SharedString);
224
225#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
226pub struct LanguageModelName(pub SharedString);
227
228#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
229pub struct LanguageModelProviderId(pub SharedString);
230
231#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
232pub struct LanguageModelProviderName(pub SharedString);
233
234impl From<String> for LanguageModelId {
235    fn from(value: String) -> Self {
236        Self(SharedString::from(value))
237    }
238}
239
240impl From<String> for LanguageModelName {
241    fn from(value: String) -> Self {
242        Self(SharedString::from(value))
243    }
244}
245
246impl From<String> for LanguageModelProviderId {
247    fn from(value: String) -> Self {
248        Self(SharedString::from(value))
249    }
250}
251
252impl From<String> for LanguageModelProviderName {
253    fn from(value: String) -> Self {
254        Self(SharedString::from(value))
255    }
256}