language_model.rs

  1mod model;
  2pub mod provider;
  3mod rate_limiter;
  4mod registry;
  5mod request;
  6mod role;
  7pub mod settings;
  8
  9use anyhow::Result;
 10use client::Client;
 11use futures::{future::BoxFuture, stream::BoxStream};
 12use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 13pub use model::*;
 14use project::Fs;
 15pub(crate) use rate_limiter::*;
 16pub use registry::*;
 17pub use request::*;
 18pub use role::*;
 19use schemars::JsonSchema;
 20use serde::de::DeserializeOwned;
 21use std::{future::Future, sync::Arc};
 22
 23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
 24    settings::init(fs, cx);
 25    registry::init(client, cx);
 26}
 27
 28pub trait LanguageModel: Send + Sync {
 29    fn id(&self) -> LanguageModelId;
 30    fn name(&self) -> LanguageModelName;
 31    fn provider_id(&self) -> LanguageModelProviderId;
 32    fn provider_name(&self) -> LanguageModelProviderName;
 33    fn telemetry_id(&self) -> String;
 34
 35    fn max_token_count(&self) -> usize;
 36
 37    fn count_tokens(
 38        &self,
 39        request: LanguageModelRequest,
 40        cx: &AppContext,
 41    ) -> BoxFuture<'static, Result<usize>>;
 42
 43    fn stream_completion(
 44        &self,
 45        request: LanguageModelRequest,
 46        cx: &AsyncAppContext,
 47    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 48
 49    fn use_any_tool(
 50        &self,
 51        request: LanguageModelRequest,
 52        name: String,
 53        description: String,
 54        schema: serde_json::Value,
 55        cx: &AsyncAppContext,
 56    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 57}
 58
 59impl dyn LanguageModel {
 60    pub fn use_tool<T: LanguageModelTool>(
 61        &self,
 62        request: LanguageModelRequest,
 63        cx: &AsyncAppContext,
 64    ) -> impl 'static + Future<Output = Result<T>> {
 65        let schema = schemars::schema_for!(T);
 66        let schema_json = serde_json::to_value(&schema).unwrap();
 67        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
 68        async move {
 69            let response = request.await?;
 70            Ok(serde_json::from_value(response)?)
 71        }
 72    }
 73}
 74
 75pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 76    fn name() -> String;
 77    fn description() -> String;
 78}
 79
 80pub trait LanguageModelProvider: 'static {
 81    fn id(&self) -> LanguageModelProviderId;
 82    fn name(&self) -> LanguageModelProviderName;
 83    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
 84    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
 85    fn is_authenticated(&self, cx: &AppContext) -> bool;
 86    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
 87    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
 88    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 89
 90    // fn observable_entity(&self) ;
 91}
 92
 93pub trait LanguageModelProviderState: 'static {
 94    type ObservableEntity;
 95
 96    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>>;
 97
 98    fn subscribe<T: 'static>(
 99        &self,
100        cx: &mut gpui::ModelContext<T>,
101        callback: impl Fn(&mut T, &mut gpui::ModelContext<T>) + 'static,
102    ) -> Option<gpui::Subscription> {
103        let entity = self.observable_entity()?;
104        Some(cx.observe(&entity, move |this, _, cx| {
105            callback(this, cx);
106        }))
107    }
108}
109
110#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
111pub struct LanguageModelId(pub SharedString);
112
113#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
114pub struct LanguageModelName(pub SharedString);
115
116#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
117pub struct LanguageModelProviderId(pub SharedString);
118
119#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
120pub struct LanguageModelProviderName(pub SharedString);
121
122impl From<String> for LanguageModelId {
123    fn from(value: String) -> Self {
124        Self(SharedString::from(value))
125    }
126}
127
128impl From<String> for LanguageModelName {
129    fn from(value: String) -> Self {
130        Self(SharedString::from(value))
131    }
132}
133
134impl From<String> for LanguageModelProviderId {
135    fn from(value: String) -> Self {
136        Self(SharedString::from(value))
137    }
138}
139
140impl From<String> for LanguageModelProviderName {
141    fn from(value: String) -> Self {
142        Self(SharedString::from(value))
143    }
144}