language_model.rs

  1mod model;
  2mod rate_limiter;
  3mod registry;
  4mod request;
  5mod role;
  6
  7#[cfg(any(test, feature = "test-support"))]
  8pub mod fake_provider;
  9
 10use anyhow::Result;
 11use futures::FutureExt;
 12use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 13use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 14pub use model::*;
 15use proto::Plan;
 16pub use rate_limiter::*;
 17pub use registry::*;
 18pub use request::*;
 19pub use role::*;
 20use schemars::JsonSchema;
 21use serde::{de::DeserializeOwned, Deserialize, Serialize};
 22use std::fmt;
 23use std::{future::Future, sync::Arc};
 24use ui::IconName;
 25
 26pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 27
 28pub fn init(cx: &mut App) {
 29    registry::init(cx);
 30}
 31
 32/// The availability of a [`LanguageModel`].
 33#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 34pub enum LanguageModelAvailability {
 35    /// The language model is available to the general public.
 36    Public,
 37    /// The language model is available to users on the indicated plan.
 38    RequiresPlan(Plan),
 39}
 40
 41/// Configuration for caching language model messages.
 42#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 43pub struct LanguageModelCacheConfiguration {
 44    pub max_cache_anchors: usize,
 45    pub should_speculate: bool,
 46    pub min_total_token: usize,
 47}
 48
 49/// A completion event from a language model.
 50#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 51pub enum LanguageModelCompletionEvent {
 52    Stop(StopReason),
 53    Text(String),
 54    ToolUse(LanguageModelToolUse),
 55    StartMessage { message_id: String },
 56}
 57
 58#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)]
 59#[serde(rename_all = "snake_case")]
 60pub enum StopReason {
 61    EndTurn,
 62    MaxTokens,
 63    ToolUse,
 64}
 65
 66#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 67pub struct LanguageModelToolUseId(Arc<str>);
 68
 69impl fmt::Display for LanguageModelToolUseId {
 70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 71        write!(f, "{}", self.0)
 72    }
 73}
 74
 75impl<T> From<T> for LanguageModelToolUseId
 76where
 77    T: Into<Arc<str>>,
 78{
 79    fn from(value: T) -> Self {
 80        Self(value.into())
 81    }
 82}
 83
 84#[derive(Debug, PartialEq, Eq, Hash, Clone, Serialize, Deserialize)]
 85pub struct LanguageModelToolUse {
 86    pub id: LanguageModelToolUseId,
 87    pub name: String,
 88    pub input: serde_json::Value,
 89}
 90
 91pub struct LanguageModelTextStream {
 92    pub message_id: Option<String>,
 93    pub stream: BoxStream<'static, Result<String>>,
 94}
 95
 96impl Default for LanguageModelTextStream {
 97    fn default() -> Self {
 98        Self {
 99            message_id: None,
100            stream: Box::pin(futures::stream::empty()),
101        }
102    }
103}
104
105pub trait LanguageModel: Send + Sync {
106    fn id(&self) -> LanguageModelId;
107    fn name(&self) -> LanguageModelName;
108    /// If None, falls back to [LanguageModelProvider::icon]
109    fn icon(&self) -> Option<IconName> {
110        None
111    }
112    fn provider_id(&self) -> LanguageModelProviderId;
113    fn provider_name(&self) -> LanguageModelProviderName;
114    fn telemetry_id(&self) -> String;
115
116    fn api_key(&self, _cx: &App) -> Option<String> {
117        None
118    }
119
120    /// Returns the availability of this language model.
121    fn availability(&self) -> LanguageModelAvailability {
122        LanguageModelAvailability::Public
123    }
124
125    fn max_token_count(&self) -> usize;
126    fn max_output_tokens(&self) -> Option<u32> {
127        None
128    }
129
130    fn count_tokens(
131        &self,
132        request: LanguageModelRequest,
133        cx: &App,
134    ) -> BoxFuture<'static, Result<usize>>;
135
136    fn stream_completion(
137        &self,
138        request: LanguageModelRequest,
139        cx: &AsyncApp,
140    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>>;
141
142    fn stream_completion_text(
143        &self,
144        request: LanguageModelRequest,
145        cx: &AsyncApp,
146    ) -> BoxFuture<'static, Result<LanguageModelTextStream>> {
147        let events = self.stream_completion(request, cx);
148
149        async move {
150            let mut events = events.await?.fuse();
151            let mut message_id = None;
152            let mut first_item_text = None;
153
154            if let Some(first_event) = events.next().await {
155                match first_event {
156                    Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
157                        message_id = Some(id.clone());
158                    }
159                    Ok(LanguageModelCompletionEvent::Text(text)) => {
160                        first_item_text = Some(text);
161                    }
162                    _ => (),
163                }
164            }
165
166            let stream = futures::stream::iter(first_item_text.map(Ok))
167                .chain(events.filter_map(|result| async move {
168                    match result {
169                        Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
170                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
171                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
172                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
173                        Err(err) => Some(Err(err)),
174                    }
175                }))
176                .boxed();
177
178            Ok(LanguageModelTextStream { message_id, stream })
179        }
180        .boxed()
181    }
182
183    fn use_any_tool(
184        &self,
185        request: LanguageModelRequest,
186        name: String,
187        description: String,
188        schema: serde_json::Value,
189        cx: &AsyncApp,
190    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
191
192    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
193        None
194    }
195
196    #[cfg(any(test, feature = "test-support"))]
197    fn as_fake(&self) -> &fake_provider::FakeLanguageModel {
198        unimplemented!()
199    }
200}
201
202impl dyn LanguageModel {
203    pub fn use_tool<T: LanguageModelTool>(
204        &self,
205        request: LanguageModelRequest,
206        cx: &AsyncApp,
207    ) -> impl 'static + Future<Output = Result<T>> {
208        let schema = schemars::schema_for!(T);
209        let schema_json = serde_json::to_value(&schema).unwrap();
210        let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
211        async move {
212            let stream = stream.await?;
213            let response = stream.try_collect::<String>().await?;
214            Ok(serde_json::from_str(&response)?)
215        }
216    }
217
218    pub fn use_tool_stream<T: LanguageModelTool>(
219        &self,
220        request: LanguageModelRequest,
221        cx: &AsyncApp,
222    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
223        let schema = schemars::schema_for!(T);
224        let schema_json = serde_json::to_value(&schema).unwrap();
225        self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
226    }
227}
228
229pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
230    fn name() -> String;
231    fn description() -> String;
232}
233
234pub trait LanguageModelProvider: 'static {
235    fn id(&self) -> LanguageModelProviderId;
236    fn name(&self) -> LanguageModelProviderName;
237    fn icon(&self) -> IconName {
238        IconName::ZedAssistant
239    }
240    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
241    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
242    fn is_authenticated(&self, cx: &App) -> bool;
243    fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
244    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView;
245    fn must_accept_terms(&self, _cx: &App) -> bool {
246        false
247    }
248    fn render_accept_terms(
249        &self,
250        _view: LanguageModelProviderTosView,
251        _cx: &mut App,
252    ) -> Option<AnyElement> {
253        None
254    }
255    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>>;
256}
257
/// Identifies the place in the UI for which a provider renders its
/// terms-acceptance element (see `LanguageModelProvider::render_accept_terms`).
///
/// This is a fieldless enum passed by value, so it is `Copy` and `Debug` in
/// addition to being comparable.
// NOTE(review): the per-variant descriptions are inferred from the variant
// names; confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LanguageModelProviderTosView {
    /// Presumably the empty state of a thread.
    ThreadEmptyState,
    /// Presumably the prompt editor popup.
    PromptEditorPopup,
    /// Presumably the provider configuration UI.
    Configuration,
}
264
265pub trait LanguageModelProviderState: 'static {
266    type ObservableEntity;
267
268    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>>;
269
270    fn subscribe<T: 'static>(
271        &self,
272        cx: &mut gpui::Context<T>,
273        callback: impl Fn(&mut T, &mut gpui::Context<T>) + 'static,
274    ) -> Option<gpui::Subscription> {
275        let entity = self.observable_entity()?;
276        Some(cx.observe(&entity, move |this, _, cx| {
277            callback(this, cx);
278        }))
279    }
280}
281
282#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
283pub struct LanguageModelId(pub SharedString);
284
285#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
286pub struct LanguageModelName(pub SharedString);
287
288#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
289pub struct LanguageModelProviderId(pub SharedString);
290
291#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
292pub struct LanguageModelProviderName(pub SharedString);
293
294impl fmt::Display for LanguageModelProviderId {
295    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
296        write!(f, "{}", self.0)
297    }
298}
299
300impl From<String> for LanguageModelId {
301    fn from(value: String) -> Self {
302        Self(SharedString::from(value))
303    }
304}
305
306impl From<String> for LanguageModelName {
307    fn from(value: String) -> Self {
308        Self(SharedString::from(value))
309    }
310}
311
312impl From<String> for LanguageModelProviderId {
313    fn from(value: String) -> Self {
314        Self(SharedString::from(value))
315    }
316}
317
318impl From<String> for LanguageModelProviderName {
319    fn from(value: String) -> Self {
320        Self(SharedString::from(value))
321    }
322}