model-resolver.ts

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: GPL-3.0-or-later
  4
  5import { getModel, type Model } from "@mariozechner/pi-ai";
  6import type { CustomModelConfig, RumiloConfig } from "../config/schema.js";
  7import { resolveHeaders } from "../util/env.js";
  8import { ConfigError } from "../util/errors.js";
  9
 10export function resolveModel(
 11  modelString: string,
 12  config: RumiloConfig,
 13): Model<any> {
 14  const colonIndex = modelString.indexOf(":");
 15  if (colonIndex === -1) {
 16    throw new ConfigError(
 17      `Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`,
 18    );
 19  }
 20
 21  const provider = modelString.slice(0, colonIndex);
 22  const modelName = modelString.slice(colonIndex + 1);
 23
 24  if (!provider || !modelName) {
 25    throw new ConfigError(
 26      `Invalid model format: "${modelString}". Expected provider:model (e.g. anthropic:claude-sonnet-4-20250514)`,
 27    );
 28  }
 29
 30  // Handle custom models
 31  if (provider === "custom") {
 32    return resolveCustomModel(modelName, config);
 33  }
 34
 35  // Handle built-in providers
 36  return getModel(provider as any, modelName);
 37}
 38
 39function resolveCustomModel(
 40  modelName: string,
 41  config: RumiloConfig,
 42): Model<any> {
 43  if (!config.custom_models) {
 44    throw new ConfigError(
 45      `No custom models defined in config. Add a [custom_models.${modelName}] section to use custom:${modelName}`,
 46    );
 47  }
 48
 49  const customConfig = config.custom_models[modelName];
 50  if (!customConfig) {
 51    const available = Object.keys(config.custom_models).join(", ");
 52    throw new ConfigError(
 53      `Custom model '${modelName}' not found. Available custom models: ${available}`,
 54    );
 55  }
 56
 57  return buildCustomModel(customConfig);
 58}
 59
 60function buildCustomModel(config: CustomModelConfig): Model<any> {
 61  const api = config.api as any;
 62
 63  const cost: any = {
 64    input: config.cost.input,
 65    output: config.cost.output,
 66  };
 67
 68  if (config.cost.cache_read !== undefined) {
 69    cost.cacheRead = config.cost.cache_read;
 70  }
 71
 72  if (config.cost.cache_write !== undefined) {
 73    cost.cacheWrite = config.cost.cache_write;
 74  }
 75
 76  const model: any = {
 77    id: config.id,
 78    name: config.name,
 79    api,
 80    provider: config.provider as any,
 81    baseUrl: config.base_url,
 82    reasoning: config.reasoning,
 83    input: config.input,
 84    cost,
 85    contextWindow: config.context_window,
 86    maxTokens: config.max_tokens,
 87  };
 88
 89  const resolvedHeaders = resolveHeaders(config.headers);
 90  if (resolvedHeaders) {
 91    model.headers = resolvedHeaders;
 92  }
 93
 94  if (config.compat) {
 95    model.compat = convertCompatConfig(config.compat);
 96  }
 97
 98  return model;
 99}
100
/**
 * Map the snake_case `compat` section of a custom model config to the
 * camelCase compat object attached to the model.
 *
 * Only options that are actually set appear in the result. Explicit
 * `undefined` keys would otherwise overwrite defaults in any consumer that
 * merges this object with object spread.
 *
 * @param compat - The compat section; callers only pass it when defined.
 * @returns An object containing the configured compat flags.
 * @throws Error if `compat` is missing (internal invariant violation).
 */
function convertCompatConfig(compat: CustomModelConfig["compat"]): any {
  if (!compat) {
    throw new Error("Compat config is expected to be defined");
  }

  const mapped: Record<string, unknown> = {
    supportsStore: compat.supports_store,
    supportsDeveloperRole: compat.supports_developer_role,
    supportsReasoningEffort: compat.supports_reasoning_effort,
    supportsUsageInStreaming: compat.supports_usage_in_streaming,
    maxTokensField: compat.max_tokens_field,
    requiresToolResultName: compat.requires_tool_result_name,
    requiresAssistantAfterToolResult:
      compat.requires_assistant_after_tool_result,
    requiresThinkingAsText: compat.requires_thinking_as_text,
    requiresMistralToolIds: compat.requires_mistral_tool_ids,
    thinkingFormat: compat.thinking_format,
  };

  return Object.fromEntries(
    Object.entries(mapped).filter(([, value]) => value !== undefined),
  );
}