language_model.rs

  1mod model;
  2pub mod provider;
  3mod rate_limiter;
  4mod registry;
  5mod request;
  6mod role;
  7pub mod settings;
  8
  9use anyhow::Result;
 10use client::Client;
 11use futures::{future::BoxFuture, stream::BoxStream};
 12use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 13pub use model::*;
 14use project::Fs;
 15pub(crate) use rate_limiter::*;
 16pub use registry::*;
 17pub use request::*;
 18pub use role::*;
 19use schemars::JsonSchema;
 20use serde::de::DeserializeOwned;
 21use std::{future::Future, sync::Arc};
 22
 23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
 24    settings::init(fs, cx);
 25    registry::init(client, cx);
 26}
 27
 28pub trait LanguageModel: Send + Sync {
 29    fn id(&self) -> LanguageModelId;
 30    fn name(&self) -> LanguageModelName;
 31    fn provider_id(&self) -> LanguageModelProviderId;
 32    fn provider_name(&self) -> LanguageModelProviderName;
 33    fn telemetry_id(&self) -> String;
 34
 35    fn max_token_count(&self) -> usize;
 36
 37    fn count_tokens(
 38        &self,
 39        request: LanguageModelRequest,
 40        cx: &AppContext,
 41    ) -> BoxFuture<'static, Result<usize>>;
 42
 43    fn stream_completion(
 44        &self,
 45        request: LanguageModelRequest,
 46        cx: &AsyncAppContext,
 47    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 48
 49    fn use_any_tool(
 50        &self,
 51        request: LanguageModelRequest,
 52        name: String,
 53        description: String,
 54        schema: serde_json::Value,
 55        cx: &AsyncAppContext,
 56    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 57}
 58
 59impl dyn LanguageModel {
 60    pub fn use_tool<T: LanguageModelTool>(
 61        &self,
 62        request: LanguageModelRequest,
 63        cx: &AsyncAppContext,
 64    ) -> impl 'static + Future<Output = Result<T>> {
 65        let schema = schemars::schema_for!(T);
 66        let schema_json = serde_json::to_value(&schema).unwrap();
 67        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
 68        async move {
 69            let response = request.await?;
 70            Ok(serde_json::from_value(response)?)
 71        }
 72    }
 73}
 74
 75pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 76    fn name() -> String;
 77    fn description() -> String;
 78}
 79
 80pub trait LanguageModelProvider: 'static {
 81    fn id(&self) -> LanguageModelProviderId;
 82    fn name(&self) -> LanguageModelProviderName;
 83    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
 84    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
 85    fn is_authenticated(&self, cx: &AppContext) -> bool;
 86    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
 87    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
 88    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 89}
 90
 91pub trait LanguageModelProviderState: 'static {
 92    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
 93}
 94
 95#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 96pub struct LanguageModelId(pub SharedString);
 97
 98#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 99pub struct LanguageModelName(pub SharedString);
100
101#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
102pub struct LanguageModelProviderId(pub SharedString);
103
104#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
105pub struct LanguageModelProviderName(pub SharedString);
106
107impl From<String> for LanguageModelId {
108    fn from(value: String) -> Self {
109        Self(SharedString::from(value))
110    }
111}
112
113impl From<String> for LanguageModelName {
114    fn from(value: String) -> Self {
115        Self(SharedString::from(value))
116    }
117}
118
119impl From<String> for LanguageModelProviderId {
120    fn from(value: String) -> Self {
121        Self(SharedString::from(value))
122    }
123}
124
125impl From<String> for LanguageModelProviderName {
126    fn from(value: String) -> Self {
127        Self(SharedString::from(value))
128    }
129}