language_model.rs

  1mod model;
  2pub mod provider;
  3mod rate_limiter;
  4mod registry;
  5mod request;
  6mod role;
  7pub mod settings;
  8
  9use anyhow::Result;
 10use client::{Client, UserStore};
 11use futures::FutureExt;
 12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 13use gpui::{
 14    AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
 15};
 16pub use model::*;
 17use project::Fs;
 18use proto::Plan;
 19pub(crate) use rate_limiter::*;
 20pub use registry::*;
 21pub use request::*;
 22pub use role::*;
 23use schemars::JsonSchema;
 24use serde::{de::DeserializeOwned, Deserialize, Serialize};
 25use std::{future::Future, sync::Arc};
 26use ui::IconName;
 27
 28pub fn init(
 29    user_store: Model<UserStore>,
 30    client: Arc<Client>,
 31    fs: Arc<dyn Fs>,
 32    cx: &mut AppContext,
 33) {
 34    settings::init(fs, cx);
 35    registry::init(user_store, client, cx);
 36}
 37
 38/// The availability of a [`LanguageModel`].
 39#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 40pub enum LanguageModelAvailability {
 41    /// The language model is available to the general public.
 42    Public,
 43    /// The language model is available to users on the indicated plan.
 44    RequiresPlan(Plan),
 45}
 46
 47/// Configuration for caching language model messages.
 48#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 49pub struct LanguageModelCacheConfiguration {
 50    pub max_cache_anchors: usize,
 51    pub should_speculate: bool,
 52    pub min_total_token: usize,
 53}
 54
 55/// A completion event from a language model.
 56#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 57pub enum LanguageModelCompletionEvent {
 58    Text(String),
 59    ToolUse(LanguageModelToolUse),
 60}
 61
 62#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 63pub struct LanguageModelToolUse {
 64    pub id: String,
 65    pub name: String,
 66    pub input: serde_json::Value,
 67}
 68
 69pub trait LanguageModel: Send + Sync {
 70    fn id(&self) -> LanguageModelId;
 71    fn name(&self) -> LanguageModelName;
 72    /// If None, falls back to [LanguageModelProvider::icon]
 73    fn icon(&self) -> Option<IconName> {
 74        None
 75    }
 76    fn provider_id(&self) -> LanguageModelProviderId;
 77    fn provider_name(&self) -> LanguageModelProviderName;
 78    fn telemetry_id(&self) -> String;
 79
 80    /// Returns the availability of this language model.
 81    fn availability(&self) -> LanguageModelAvailability {
 82        LanguageModelAvailability::Public
 83    }
 84
 85    fn max_token_count(&self) -> usize;
 86    fn max_output_tokens(&self) -> Option<u32> {
 87        None
 88    }
 89
 90    fn count_tokens(
 91        &self,
 92        request: LanguageModelRequest,
 93        cx: &AppContext,
 94    ) -> BoxFuture<'static, Result<usize>>;
 95
 96    fn stream_completion(
 97        &self,
 98        request: LanguageModelRequest,
 99        cx: &AsyncAppContext,
100    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
101
102    fn stream_completion_text(
103        &self,
104        request: LanguageModelRequest,
105        cx: &AsyncAppContext,
106    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
107        let events = self.stream_completion(request, cx);
108
109        async move {
110            Ok(events
111                .await?
112                .filter_map(|result| async move {
113                    match result {
114                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
115                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
116                        Err(err) => Some(Err(err)),
117                    }
118                })
119                .boxed())
120        }
121        .boxed()
122    }
123
124    fn use_any_tool(
125        &self,
126        request: LanguageModelRequest,
127        name: String,
128        description: String,
129        schema: serde_json::Value,
130        cx: &AsyncAppContext,
131    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
132
133    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
134        None
135    }
136
137    #[cfg(any(test, feature = "test-support"))]
138    fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
139        unimplemented!()
140    }
141}
142
143impl dyn LanguageModel {
144    pub fn use_tool<T: LanguageModelTool>(
145        &self,
146        request: LanguageModelRequest,
147        cx: &AsyncAppContext,
148    ) -> impl 'static + Future<Output = Result<T>> {
149        let schema = schemars::schema_for!(T);
150        let schema_json = serde_json::to_value(&schema).unwrap();
151        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
152        async move {
153            let stream = stream.await?;
154            let response = stream.try_collect::<String>().await?;
155            Ok(serde_json::from_str(&response)?)
156        }
157    }
158
159    pub fn use_tool_stream<T: LanguageModelTool>(
160        &self,
161        request: LanguageModelRequest,
162        cx: &AsyncAppContext,
163    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
164        let schema = schemars::schema_for!(T);
165        let schema_json = serde_json::to_value(&schema).unwrap();
166        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
167    }
168}
169
170pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
171    fn name() -> String;
172    fn description() -> String;
173}
174
175pub trait LanguageModelProvider: 'static {
176    fn id(&self) -> LanguageModelProviderId;
177    fn name(&self) -> LanguageModelProviderName;
178    fn icon(&self) -> IconName {
179        IconName::ZedAssistant
180    }
181    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
182    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
183    fn is_authenticated(&self, cx: &AppContext) -> bool;
184    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
185    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
186    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
187        false
188    }
189    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
190        None
191    }
192    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
193}
194
195pub trait LanguageModelProviderState: 'static {
196    type ObservableEntity;
197
198    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
199
200    fn subscribe<T: 'static>(
201        &self,
202        cx: &mut gpui::ModelContext<T>,
203        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
204    ) -> Option<gpui::Subscription> {
205        let entity = self.observable_entity()?;
206        Some(cx.observe(&entity, move |this, _, cx| {
207            callback(this, cx);
208        }))
209    }
210}
211
212#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
213pub struct LanguageModelId(pub SharedString);
214
215#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
216pub struct LanguageModelName(pub SharedString);
217
218#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
219pub struct LanguageModelProviderId(pub SharedString);
220
221#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
222pub struct LanguageModelProviderName(pub SharedString);
223
224impl From<String> for LanguageModelId {
225    fn from(value: String) -> Self {
226        Self(SharedString::from(value))
227    }
228}
229
230impl From<String> for LanguageModelName {
231    fn from(value: String) -> Self {
232        Self(SharedString::from(value))
233    }
234}
235
236impl From<String> for LanguageModelProviderId {
237    fn from(value: String) -> Self {
238        Self(SharedString::from(value))
239    }
240}
241
242impl From<String> for LanguageModelProviderName {
243    fn from(value: String) -> Self {
244        Self(SharedString::from(value))
245    }
246}