language_model.rs

  1mod model;
  2pub mod provider;
  3mod rate_limiter;
  4mod registry;
  5mod request;
  6mod role;
  7pub mod settings;
  8
  9use anyhow::Result;
 10use client::Client;
 11use futures::{future::BoxFuture, stream::BoxStream};
 12use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 13pub use model::*;
 14use project::Fs;
 15pub(crate) use rate_limiter::*;
 16pub use registry::*;
 17pub use request::*;
 18pub use role::*;
 19use schemars::JsonSchema;
 20use serde::de::DeserializeOwned;
 21use std::{future::Future, sync::Arc};
 22
 23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
 24    settings::init(fs, cx);
 25    registry::init(client, cx);
 26}
 27
 28pub trait LanguageModel: Send + Sync {
 29    fn id(&self) -> LanguageModelId;
 30    fn name(&self) -> LanguageModelName;
 31    fn provider_id(&self) -> LanguageModelProviderId;
 32    fn provider_name(&self) -> LanguageModelProviderName;
 33    fn telemetry_id(&self) -> String;
 34
 35    fn max_token_count(&self) -> usize;
 36
 37    fn count_tokens(
 38        &self,
 39        request: LanguageModelRequest,
 40        cx: &AppContext,
 41    ) -> BoxFuture<'static, Result<usize>>;
 42
 43    fn stream_completion(
 44        &self,
 45        request: LanguageModelRequest,
 46        cx: &AsyncAppContext,
 47    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 48
 49    fn use_any_tool(
 50        &self,
 51        request: LanguageModelRequest,
 52        name: String,
 53        description: String,
 54        schema: serde_json::Value,
 55        cx: &AsyncAppContext,
 56    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 57}
 58
 59impl dyn LanguageModel {
 60    pub fn use_tool<T: LanguageModelTool>(
 61        &self,
 62        request: LanguageModelRequest,
 63        cx: &AsyncAppContext,
 64    ) -> impl 'static + Future<Output = Result<T>> {
 65        let schema = schemars::schema_for!(T);
 66        let schema_json = serde_json::to_value(&schema).unwrap();
 67        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
 68        async move {
 69            let response = request.await?;
 70            Ok(serde_json::from_value(response)?)
 71        }
 72    }
 73}
 74
 75pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 76    fn name() -> String;
 77    fn description() -> String;
 78}
 79
 80pub trait LanguageModelProvider: 'static {
 81    fn id(&self) -> LanguageModelProviderId;
 82    fn name(&self) -> LanguageModelProviderName;
 83    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
 84    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
 85    fn is_authenticated(&self, cx: &AppContext) -> bool;
 86    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
 87    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
 88    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 89}
 90
 91pub trait LanguageModelProviderState: 'static {
 92    type ObservableEntity;
 93
 94    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
 95
 96    fn subscribe<T: 'static>(
 97        &self,
 98        cx: &mut gpui::ModelContext<T>,
 99        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
100    ) -> Option<gpui::Subscription> {
101        let entity = self.observable_entity()?;
102        Some(cx.observe(&entity, move |this, _, cx| {
103            callback(this, cx);
104        }))
105    }
106}
107
108#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
109pub struct LanguageModelId(pub SharedString);
110
111#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
112pub struct LanguageModelName(pub SharedString);
113
114#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
115pub struct LanguageModelProviderId(pub SharedString);
116
117#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
118pub struct LanguageModelProviderName(pub SharedString);
119
120impl From<String> for LanguageModelId {
121    fn from(value: String) -> Self {
122        Self(SharedString::from(value))
123    }
124}
125
126impl From<String> for LanguageModelName {
127    fn from(value: String) -> Self {
128        Self(SharedString::from(value))
129    }
130}
131
132impl From<String> for LanguageModelProviderId {
133    fn from(value: String) -> Self {
134        Self(SharedString::from(value))
135    }
136}
137
138impl From<String> for LanguageModelProviderName {
139    fn from(value: String) -> Self {
140        Self(SharedString::from(value))
141    }
142}