language_model.rs

  1mod model;
  2pub mod provider;
  3mod registry;
  4mod request;
  5mod role;
  6pub mod settings;
  7
  8use std::sync::Arc;
  9
 10use anyhow::Result;
 11use client::Client;
 12use futures::{future::BoxFuture, stream::BoxStream};
 13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 14
 15pub use model::*;
 16pub use registry::*;
 17pub use request::*;
 18pub use role::*;
 19use schemars::JsonSchema;
 20use serde::de::DeserializeOwned;
 21
 22pub fn init(client: Arc<Client>, cx: &mut AppContext) {
 23    settings::init(cx);
 24    registry::init(client, cx);
 25}
 26
 27pub trait LanguageModel: Send + Sync {
 28    fn id(&self) -> LanguageModelId;
 29    fn name(&self) -> LanguageModelName;
 30    fn provider_id(&self) -> LanguageModelProviderId;
 31    fn provider_name(&self) -> LanguageModelProviderName;
 32    fn telemetry_id(&self) -> String;
 33
 34    fn max_token_count(&self) -> usize;
 35
 36    fn count_tokens(
 37        &self,
 38        request: LanguageModelRequest,
 39        cx: &AppContext,
 40    ) -> BoxFuture<'static, Result<usize>>;
 41
 42    fn stream_completion(
 43        &self,
 44        request: LanguageModelRequest,
 45        cx: &AsyncAppContext,
 46    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 47
 48    fn use_tool(
 49        &self,
 50        request: LanguageModelRequest,
 51        name: String,
 52        description: String,
 53        schema: serde_json::Value,
 54        cx: &AsyncAppContext,
 55    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 56}
 57
 58pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 59    fn name() -> String;
 60    fn description() -> String;
 61}
 62
 63pub trait LanguageModelProvider: 'static {
 64    fn id(&self) -> LanguageModelProviderId;
 65    fn name(&self) -> LanguageModelProviderName;
 66    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
 67    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
 68    fn is_authenticated(&self, cx: &AppContext) -> bool;
 69    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
 70    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
 71    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
 72}
 73
 74pub trait LanguageModelProviderState: 'static {
 75    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
 76}
 77
 78#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 79pub struct LanguageModelId(pub SharedString);
 80
 81#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 82pub struct LanguageModelName(pub SharedString);
 83
 84#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 85pub struct LanguageModelProviderId(pub SharedString);
 86
 87#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 88pub struct LanguageModelProviderName(pub SharedString);
 89
 90impl From<String> for LanguageModelId {
 91    fn from(value: String) -> Self {
 92        Self(SharedString::from(value))
 93    }
 94}
 95
 96impl From<String> for LanguageModelName {
 97    fn from(value: String) -> Self {
 98        Self(SharedString::from(value))
 99    }
100}
101
102impl From<String> for LanguageModelProviderId {
103    fn from(value: String) -> Self {
104        Self(SharedString::from(value))
105    }
106}
107
108impl From<String> for LanguageModelProviderName {
109    fn from(value: String) -> Self {
110        Self(SharedString::from(value))
111    }
112}