model-resolver.ts

  1import { getModel, type Model } from "@mariozechner/pi-ai";
  2import {
  3  type CustomModelConfig,
  4  type RumiloConfig,
  5} from "../config/schema.js";
  6import { ConfigError } from "../util/errors.js";
  7
  8export function resolveModel(
  9  modelString: string,
 10  config: RumiloConfig,
 11): Model<any> {
 12  const [provider, modelName] = modelString.split(":");
 13
 14  if (!provider || !modelName) {
 15    throw new ConfigError("Model must be in provider:model format");
 16  }
 17
 18  // Handle custom models
 19  if (provider === "custom") {
 20    return resolveCustomModel(modelName, config);
 21  }
 22
 23  // Handle built-in providers
 24  return getModel(provider as any, modelName);
 25}
 26
 27function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
 28  if (!config.custom_models) {
 29    throw new ConfigError(
 30      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
 31    );
 32  }
 33
 34  const customConfig = config.custom_models[modelName];
 35  if (!customConfig) {
 36    const available = Object.keys(config.custom_models).join(", ");
 37    throw new ConfigError(
 38      `Custom model '${modelName}' not found. Available custom models: ${available}`,
 39    );
 40  }
 41
 42  return buildCustomModel(customConfig);
 43}
 44
 45function buildCustomModel(config: CustomModelConfig): Model<any> {
 46  const api = config.api as any;
 47
 48  const cost: any = {
 49    input: config.cost.input,
 50    output: config.cost.output,
 51  };
 52
 53  if (config.cost.cache_read !== undefined) {
 54    cost.cacheRead = config.cost.cache_read;
 55  }
 56
 57  if (config.cost.cache_write !== undefined) {
 58    cost.cacheWrite = config.cost.cache_write;
 59  }
 60
 61  const model: any = {
 62    id: config.id,
 63    name: config.name,
 64    api,
 65    provider: config.provider as any,
 66    baseUrl: config.base_url,
 67    reasoning: config.reasoning,
 68    input: config.input,
 69    cost,
 70    contextWindow: config.context_window,
 71    maxTokens: config.max_tokens,
 72  };
 73
 74  if (config.headers) {
 75    model.headers = config.headers;
 76  }
 77
 78  if (config.compat) {
 79    model.compat = convertCompatConfig(config.compat);
 80  }
 81
 82  return model;
 83}
 84
 85function convertCompatConfig(
 86  compat: CustomModelConfig["compat"],
 87): any {
 88  if (!compat) {
 89    throw new Error("Compat config is expected to be defined");
 90  }
 91
 92  return {
 93    supportsStore: compat.supports_store,
 94    supportsDeveloperRole: compat.supports_developer_role,
 95    supportsReasoningEffort: compat.supports_reasoning_effort,
 96    supportsUsageInStreaming: compat.supports_usage_in_streaming,
 97    maxTokensField: compat.max_tokens_field,
 98    requiresToolResultName: compat.requires_tool_result_name,
 99    requiresAssistantAfterToolResult: compat.requires_assistant_after_tool_result,
100    requiresThinkingAsText: compat.requires_thinking_as_text,
101    requiresMistralToolIds: compat.requires_mistral_tool_ids,
102    thinkingFormat: compat.thinking_format,
103  };
104}