model-resolver.ts

  1import { getModel, type Model } from "@mariozechner/pi-ai";
  2import {
  3  type CustomModelConfig,
  4  type RumiloConfig,
  5} from "../config/schema.js";
  6import { ConfigError } from "../util/errors.js";
  7
  8export function resolveModel(
  9  modelString: string,
 10  config: RumiloConfig,
 11): Model<any> {
 12  const colonIndex = modelString.indexOf(":");
 13  if (colonIndex === -1) {
 14    throw new ConfigError("Model must be in provider:model format");
 15  }
 16
 17  const provider = modelString.slice(0, colonIndex);
 18  const modelName = modelString.slice(colonIndex + 1);
 19
 20  if (!provider || !modelName) {
 21    throw new ConfigError("Model must be in provider:model format");
 22  }
 23
 24  // Handle custom models
 25  if (provider === "custom") {
 26    return resolveCustomModel(modelName, config);
 27  }
 28
 29  // Handle built-in providers
 30  return getModel(provider as any, modelName);
 31}
 32
 33function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
 34  if (!config.custom_models) {
 35    throw new ConfigError(
 36      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
 37    );
 38  }
 39
 40  const customConfig = config.custom_models[modelName];
 41  if (!customConfig) {
 42    const available = Object.keys(config.custom_models).join(", ");
 43    throw new ConfigError(
 44      `Custom model '${modelName}' not found. Available custom models: ${available}`,
 45    );
 46  }
 47
 48  return buildCustomModel(customConfig);
 49}
 50
 51function buildCustomModel(config: CustomModelConfig): Model<any> {
 52  const api = config.api as any;
 53
 54  const cost: any = {
 55    input: config.cost.input,
 56    output: config.cost.output,
 57  };
 58
 59  if (config.cost.cache_read !== undefined) {
 60    cost.cacheRead = config.cost.cache_read;
 61  }
 62
 63  if (config.cost.cache_write !== undefined) {
 64    cost.cacheWrite = config.cost.cache_write;
 65  }
 66
 67  const model: any = {
 68    id: config.id,
 69    name: config.name,
 70    api,
 71    provider: config.provider as any,
 72    baseUrl: config.base_url,
 73    reasoning: config.reasoning,
 74    input: config.input,
 75    cost,
 76    contextWindow: config.context_window,
 77    maxTokens: config.max_tokens,
 78  };
 79
 80  if (config.headers) {
 81    model.headers = config.headers;
 82  }
 83
 84  if (config.compat) {
 85    model.compat = convertCompatConfig(config.compat);
 86  }
 87
 88  return model;
 89}
 90
 91function convertCompatConfig(
 92  compat: CustomModelConfig["compat"],
 93): any {
 94  if (!compat) {
 95    throw new Error("Compat config is expected to be defined");
 96  }
 97
 98  return {
 99    supportsStore: compat.supports_store,
100    supportsDeveloperRole: compat.supports_developer_role,
101    supportsReasoningEffort: compat.supports_reasoning_effort,
102    supportsUsageInStreaming: compat.supports_usage_in_streaming,
103    maxTokensField: compat.max_tokens_field,
104    requiresToolResultName: compat.requires_tool_result_name,
105    requiresAssistantAfterToolResult: compat.requires_assistant_after_tool_result,
106    requiresThinkingAsText: compat.requires_thinking_as_text,
107    requiresMistralToolIds: compat.requires_mistral_tool_ids,
108    thinkingFormat: compat.thinking_format,
109  };
110}