language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6
  7#[cfg(any(test, feature = "test-support"))]
  8pub mod fake_provider;
  9
 10use anyhow::Result;
 11use futures::FutureExt;
 12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 13use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 14pub use model::*;
 15use proto::Plan;
 16pub use rate_limiter::*;
 17pub use registry::*;
 18pub use request::*;
 19pub use role::*;
 20use schemars::JsonSchema;
 21use serde::{de::DeserializeOwned, Deserialize, Serialize};
 22use std::fmt;
 23use std::{future::Future, sync::Arc};
 24use thiserror::Error;
 25use ui::IconName;
 26
/// Identifier string (`"zed.dev"`) for the Zed cloud provider.
///
/// NOTE(review): presumably compared against [`LanguageModelProviderId`] values — confirm at
/// the call sites.
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 28
/// Initializes this crate by delegating to [`registry::init`] with the given app context.
pub fn init(cx: &mut App) {
    registry::init(cx);
}
 32
/// The availability of a [`LanguageModel`].
///
/// Returned by [`LanguageModel::availability`], whose default implementation reports
/// [`LanguageModelAvailability::Public`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
    /// The language model is available to the general public.
    Public,
    /// The language model is available to users on the indicated plan.
    RequiresPlan(Plan),
}
 41
/// Configuration for caching language model messages.
///
/// Returned by [`LanguageModel::cache_configuration`]; models whose implementation keeps the
/// default return `None` there.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct LanguageModelCacheConfiguration {
    /// NOTE(review): presumably the maximum number of cache anchors a request may carry —
    /// confirm against the providers that read this field.
    pub max_cache_anchors: usize,
    /// NOTE(review): presumably enables speculative cache warming — confirm against callers.
    pub should_speculate: bool,
    /// NOTE(review): presumably the minimum total token count before caching is attempted —
    /// confirm against callers.
    pub min_total_token: usize,
}
 49
/// A completion event from a language model.
///
/// These are the items of the stream produced by [`LanguageModel::stream_completion`].
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
    /// Carries the [`StopReason`] for the completion.
    Stop(StopReason),
    /// A piece of text; the only event kind forwarded as an item by
    /// [`LanguageModel::stream_completion_text`].
    Text(String),
    /// Carries a [`LanguageModelToolUse`] (tool use id, tool name, and JSON input).
    ToolUse(LanguageModelToolUse),
    /// Carries a message id; [`LanguageModel::stream_completion_text`] records it when this is
    /// the first event of the stream.
    StartMessage { message_id: String },
}
 58
/// The reason carried by [`LanguageModelCompletionEvent::Stop`].
///
/// Variants are (de)serialized in `snake_case` (for example `end_turn`).
#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// NOTE(review): presumably the model finished its turn normally — confirm with providers.
    EndTurn,
    /// NOTE(review): presumably the output token limit was reached — confirm with providers.
    MaxTokens,
    /// NOTE(review): presumably the model stopped in order to use a tool — confirm with providers.
    ToolUse,
}
 66
/// Identifier of a tool use, stored as a shared string (`Arc<str>`).
///
/// Construct it from any value convertible into `Arc<str>` via `From`/`Into`; it formats
/// (`Display`) as the wrapped string.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUseId(Arc<str>);
 69
 70impl fmt::Display for LanguageModelToolUseId {
 71    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 72        write!(f, "{}", self.0)
 73    }
 74}
 75
 76impl<T> From<T> for LanguageModelToolUseId
 77where
 78    T: Into<Arc<str>>,
 79{
 80    fn from(value: T) -> Self {
 81        Self(value.into())
 82    }
 83}
 84
/// A tool use, carried by [`LanguageModelCompletionEvent::ToolUse`].
#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
pub struct LanguageModelToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool.
    pub name: String,
    /// Input for the tool as a JSON value.
    pub input: serde_json::Value,
}
 91
/// The text-only view of a completion, as produced by
/// [`LanguageModel::stream_completion_text`].
///
/// The `Default` value has no message id and an empty stream.
pub struct LanguageModelTextStream {
    /// Message id taken from a leading [`LanguageModelCompletionEvent::StartMessage`], if any.
    pub message_id: Option<String>,
    /// Stream of text items; each item may instead be an error.
    pub stream: BoxStream<'static, Result<String>>,
}
 96
 97impl Default for LanguageModelTextStream {
 98    fn default() -> Self {
 99        Self {
100            message_id: None,
101            stream: Box::pin(futures::stream::empty()),
102        }
103    }
104}
105
106pub trait LanguageModel: Send + Sync {
107    fn id(&self) -> LanguageModelId;
108    fn name(&self) -> LanguageModelName;
109    /// If None, falls back to [LanguageModelProvider::icon]
110    fn icon(&self) -> Option<IconName> {
111        None
112    }
113    fn provider_id(&self) -> LanguageModelProviderId;
114    fn provider_name(&self) -> LanguageModelProviderName;
115    fn telemetry_id(&self) -> String;
116
117    fn api_key(&self, _cx: &App) -> Option<String> {
118        None
119    }
120
121    /// Returns the availability of this language model.
122    fn availability(&self) -> LanguageModelAvailability {
123        LanguageModelAvailability::Public
124    }
125
126    fn max_token_count(&self) -> usize;
127    fn max_output_tokens(&self) -> Option<u32> {
128        None
129    }
130
131    fn count_tokens(
132        &self,
133        request: LanguageModelRequest,
134        cx: &App,
135    ) -> BoxFuture<'static, Result<usize>>;
136
137    fn stream_completion(
138        &self,
139        request: LanguageModelRequest,
140        cx: &AsyncApp,
141    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
142
143    fn stream_completion_text(
144        &self,
145        request: LanguageModelRequest,
146        cx: &AsyncApp,
147    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
148        let events = self.stream_completion(request, cx);
149
150        async move {
151            let mut events = events.await?.fuse();
152            let mut message_id = None;
153            let mut first_item_text = None;
154
155            if let Some(first_event) = events.next().await {
156                match first_event {
157                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
158                        message_id = Some(id.clone());
159                    }
160                    Ok(LanguageModelCompletionEvent::Text(text)) => {
161                        first_item_text = Some(text);
162                    }
163                    _ => (),
164                }
165            }
166
167            let stream = futures::stream::iter(first_item_text.map(Ok))
168                .chain(events.filter_map(|result| async move {
169                    match result {
170                        Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
171                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
172                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
173                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
174                        Err(err) => Some(Err(err)),
175                    }
176                }))
177                .boxed();
178
179            Ok(LanguageModelTextStream { message_id, stream })
180        }
181        .boxed()
182    }
183
184    fn use_any_tool(
185        &self,
186        request: LanguageModelRequest,
187        name: String,
188        description: String,
189        schema: serde_json::Value,
190        cx: &AsyncApp,
191    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
192
193    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
194        None
195    }
196
197    #[cfg(any(test, feature = "test-support"))]
198    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
199        unimplemented!()
200    }
201}
202
203impl dyn LanguageModel {
204    pub fn use_tool<T: LanguageModelTool>(
205        &self,
206        request: LanguageModelRequest,
207        cx: &AsyncApp,
208    ) -> impl 'static + Future<Output = Result<T>> {
209        let schema = schemars::schema_for!(T);
210        let schema_json = serde_json::to_value(&schema).unwrap();
211        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
212        async move {
213            let stream = stream.await?;
214            let response = stream.try_collect::<String>().await?;
215            Ok(serde_json::from_str(&response)?)
216        }
217    }
218
219    pub fn use_tool_stream<T: LanguageModelTool>(
220        &self,
221        request: LanguageModelRequest,
222        cx: &AsyncApp,
223    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
224        let schema = schemars::schema_for!(T);
225        let schema_json = serde_json::to_value(&schema).unwrap();
226        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
227    }
228}
229
/// A tool whose input schema is derived from the implementing type.
///
/// The `use_tool` and `use_tool_stream` helpers on `dyn LanguageModel` generate a JSON schema
/// from the type (`JsonSchema`), and `use_tool` deserializes the response into it
/// (`DeserializeOwned`).
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
    /// Returns the tool name passed to [`LanguageModel::use_any_tool`].
    fn name() -> String;
    /// Returns the tool description passed to [`LanguageModel::use_any_tool`].
    fn description() -> String;
}
234
/// An error that occurred when trying to authenticate the language model provider.
///
/// This is the error type of [`LanguageModelProvider::authenticate`].
#[derive(Debug, Error)]
pub enum AuthenticateError {
    /// No credentials were found; displays as `credentials not found`.
    #[error("credentials not found")]
    CredentialsNotFound,
    /// Any other failure. Display and source are forwarded to the wrapped error
    /// (`transparent`), and `#[from]` provides `From<anyhow::Error>`.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
243
/// A provider of [`LanguageModel`]s, with hooks for authentication, configuration UI, and
/// terms acceptance.
pub trait LanguageModelProvider: 'static {
    /// Returns the id of this provider.
    fn id(&self) -> LanguageModelProviderId;
    /// Returns the name of this provider.
    fn name(&self) -> LanguageModelProviderName;
    /// Returns the icon of this provider. Defaults to `IconName::ZedAssistant`.
    fn icon(&self) -> IconName {
        IconName::ZedAssistant
    }
    /// Returns the provider's default model, if it has one.
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
    /// Returns the models offered by this provider.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
    /// Hook for loading `_model`; the default implementation does nothing.
    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
    /// Returns whether the provider is currently authenticated.
    fn is_authenticated(&self, cx: &App) -> bool;
    /// Starts authentication; the task resolves to `Ok(())` or an [`AuthenticateError`].
    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;
    /// Builds the view used to configure this provider.
    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
    /// Returns whether terms must be accepted before using this provider.
    /// Defaults to `false`.
    fn must_accept_terms(&self, _cx: &App) -> bool {
        false
    }
    /// Renders the accept-terms element for the given view, if the provider has one.
    /// Defaults to `None`.
    fn render_accept_terms(
        &self,
        _view: LanguageModelProviderTosView,
        _cx: &mut App,
    ) -> Option<AnyElement> {
        None
    }
    /// Starts resetting the provider's credentials; the task resolves when done.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
}
268
/// The view passed to [`LanguageModelProvider::render_accept_terms`].
///
/// NOTE(review): the variants presumably name the UI surface in which the terms prompt is
/// shown — confirm at the call sites.
#[derive(PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    ThreadEmptyState,
    PromptEditorPopup,
    Configuration,
}
275
276pub trait LanguageModelProviderState: 'static {
277    type ObservableEntity;
278
279    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
280
281    fn subscribe<T: 'static>(
282        &self,
283        cx: &mut gpui::Context<T>,
284        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
285    ) -> Option<gpui::Subscription> {
286        let entity = self.observable_entity()?;
287        Some(cx.observe(&entity, move |this, _, cx| {
288            callback(this, cx);
289        }))
290    }
291}
292
/// The id of a language model: a newtype over `SharedString`, as returned by
/// [`LanguageModel::id`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelId(pub SharedString);
295
/// The name of a language model: a newtype over `SharedString`, as returned by
/// [`LanguageModel::name`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelName(pub SharedString);
298
/// The id of a language model provider: a newtype over `SharedString`, as returned by
/// [`LanguageModelProvider::id`] and [`LanguageModel::provider_id`].
///
/// Formats (`Display`) as the wrapped string.
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderId(pub SharedString);
301
/// The name of a language model provider: a newtype over `SharedString`, as returned by
/// [`LanguageModelProvider::name`] and [`LanguageModel::provider_name`].
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
pub struct LanguageModelProviderName(pub SharedString);
304
305impl fmt::Display for LanguageModelProviderId {
306    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307        write!(f, "{}", self.0)
308    }
309}
310
311impl From<String> for LanguageModelId {
312    fn from(value: String) -> Self {
313        Self(SharedString::from(value))
314    }
315}
316
317impl From<String> for LanguageModelName {
318    fn from(value: String) -> Self {
319        Self(SharedString::from(value))
320    }
321}
322
323impl From<String> for LanguageModelProviderId {
324    fn from(value: String) -> Self {
325        Self(SharedString::from(value))
326    }
327}
328
329impl From<String> for LanguageModelProviderName {
330    fn from(value: String) -> Self {
331        Self(SharedString::from(value))
332    }
333}