language_model.rs

  1mod model;
  2pub mod provider;
  3mod registry;
  4mod request;
  5mod role;
  6pub mod settings;
  7
  8use std::sync::Arc;
  9
 10use anyhow::Result;
 11use client::Client;
 12use futures::{future::BoxFuture, stream::BoxStream};
 13use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
 14
 15pub use model::*;
 16use project::Fs;
 17pub use registry::*;
 18pub use request::*;
 19pub use role::*;
 20use schemars::JsonSchema;
 21use serde::de::DeserializeOwned;
 22
 23pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
 24    settings::init(fs, cx);
 25    registry::init(client, cx);
 26}
 27
 28pub trait LanguageModel: Send + Sync {
 29    fn id(&self) -> LanguageModelId;
 30    fn name(&self) -> LanguageModelName;
 31    fn provider_id(&self) -> LanguageModelProviderId;
 32    fn provider_name(&self) -> LanguageModelProviderName;
 33    fn telemetry_id(&self) -> String;
 34
 35    fn max_token_count(&self) -> usize;
 36
 37    fn count_tokens(
 38        &self,
 39        request: LanguageModelRequest,
 40        cx: &AppContext,
 41    ) -> BoxFuture<'static, Result<usize>>;
 42
 43    fn stream_completion(
 44        &self,
 45        request: LanguageModelRequest,
 46        cx: &AsyncAppContext,
 47    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 48
 49    fn use_tool(
 50        &self,
 51        request: LanguageModelRequest,
 52        name: String,
 53        description: String,
 54        schema: serde_json::Value,
 55        cx: &AsyncAppContext,
 56    ) -> BoxFuture<'static, Result<serde_json::Value>>;
 57}
 58
 59pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 60    fn name() -> String;
 61    fn description() -> String;
 62}
 63
 64pub trait LanguageModelProvider: 'static {
 65    fn id(&self) -> LanguageModelProviderId;
 66    fn name(&self) -> LanguageModelProviderName;
 67    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
 68    fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
 69    fn is_authenticated(&self, cx: &AppContext) -> bool;
 70    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
 71    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
 72    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
 73}
 74
 75pub trait LanguageModelProviderState: 'static {
 76    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription>;
 77}
 78
 79#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 80pub struct LanguageModelId(pub SharedString);
 81
 82#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 83pub struct LanguageModelName(pub SharedString);
 84
 85#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 86pub struct LanguageModelProviderId(pub SharedString);
 87
 88#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
 89pub struct LanguageModelProviderName(pub SharedString);
 90
 91impl From<String> for LanguageModelId {
 92    fn from(value: String) -> Self {
 93        Self(SharedString::from(value))
 94    }
 95}
 96
 97impl From<String> for LanguageModelName {
 98    fn from(value: String) -> Self {
 99        Self(SharedString::from(value))
100    }
101}
102
103impl From<String> for LanguageModelProviderId {
104    fn from(value: String) -> Self {
105        Self(SharedString::from(value))
106    }
107}
108
109impl From<String> for LanguageModelProviderName {
110    fn from(value: String) -> Self {
111        Self(SharedString::from(value))
112    }
113}