language_model.go

 1package bedrock
 2
 3import (
 4	"context"
 5	"encoding/json"
 6	"errors"
 7	"fmt"
 8
 9	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
10	"github.com/charmbracelet/fantasy/ai"
11)
12
13type languageModel struct {
14	provider string
15	modelID  string
16	client   *bedrockruntime.Client
17}
18
19func (b languageModel) Model() string {
20	return b.modelID
21}
22
23func (b languageModel) Provider() string {
24	return b.provider
25}
26
27func (b languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
28	params, err := b.prepareParams(call)
29	if err != nil {
30		return nil, err
31	}
32
33	output, err := b.client.InvokeModel(ctx, params)
34	if err != nil {
35		return nil, err
36	}
37
38	panic(fmt.Sprintf("bedrock output: %+v", output))
39
40	// return &ai.Response{
41	// 	Content: content,
42	// 	// Usage: ai.Usage{
43	// 	// 	InputTokens:         response.Usage.InputTokens,
44	// 	// 	OutputTokens:        response.Usage.OutputTokens,
45	// 	// 	TotalTokens:         response.Usage.InputTokens + response.Usage.OutputTokens,
46	// 	// 	CacheCreationTokens: response.Usage.CacheCreationInputTokens,
47	// 	// 	CacheReadTokens:     response.Usage.CacheReadInputTokens,
48	// 	// },
49	// 	// FinishReason:     mapFinishReason(string(response.StopReason)),
50	// 	ProviderMetadata: ai.ProviderMetadata{},
51	// 	// Warnings:         warnings,
52	// }, nil
53}
54
55func (b languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
56	return nil, errors.New("bedrock provider not fully implemented")
57}
58
59func (b languageModel) prepareParams(call ai.Call) (*bedrockruntime.InvokeModelInput, error) {
60	input := bedrockruntime.InvokeModelInput{
61		// ModelId:     ptr(fmt.Sprintf("us-east-1.%s", b.modelID)),
62		ModelId: ptr(b.modelID),
63		// ModelId:     ptr("us-east-1.anthropic.claude-sonnet-4-5-v2:0"),
64		ContentType: ptr("application/json"),
65		// Body:        body,
66	}
67
68	// call.Prompt
69
70	switch {
71	case containsAny(b.modelID, "anthropic", "claude", "sonnet"):
72		i, err := toClaudeInput(call)
73		if err != nil {
74			return nil, err
75		}
76		input.Body, err = json.Marshal(i)
77		if err != nil {
78			return nil, err
79		}
80	default:
81		return nil, fmt.Errorf("fantasy: bedrock provider does not support model: %s", b.modelID)
82	}
83
84	return &input, nil
85}