1package bedrock
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8
9 "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
10 "github.com/charmbracelet/fantasy/ai"
11)
12
13type languageModel struct {
14 provider string
15 modelID string
16 client *bedrockruntime.Client
17}
18
19func (b languageModel) Model() string {
20 return b.modelID
21}
22
23func (b languageModel) Provider() string {
24 return b.provider
25}
26
27func (b languageModel) Generate(ctx context.Context, call ai.Call) (*ai.Response, error) {
28 params, err := b.prepareParams(call)
29 if err != nil {
30 return nil, err
31 }
32
33 output, err := b.client.InvokeModel(ctx, params)
34 if err != nil {
35 return nil, err
36 }
37
38 panic(fmt.Sprintf("bedrock output: %+v", output))
39
40 // return &ai.Response{
41 // Content: content,
42 // // Usage: ai.Usage{
43 // // InputTokens: response.Usage.InputTokens,
44 // // OutputTokens: response.Usage.OutputTokens,
45 // // TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
46 // // CacheCreationTokens: response.Usage.CacheCreationInputTokens,
47 // // CacheReadTokens: response.Usage.CacheReadInputTokens,
48 // // },
49 // // FinishReason: mapFinishReason(string(response.StopReason)),
50 // ProviderMetadata: ai.ProviderMetadata{},
51 // // Warnings: warnings,
52 // }, nil
53}
54
55func (b languageModel) Stream(ctx context.Context, call ai.Call) (ai.StreamResponse, error) {
56 return nil, errors.New("bedrock provider not fully implemented")
57}
58
59func (b languageModel) prepareParams(call ai.Call) (*bedrockruntime.InvokeModelInput, error) {
60 input := bedrockruntime.InvokeModelInput{
61 // ModelId: ptr(fmt.Sprintf("us-east-1.%s", b.modelID)),
62 ModelId: ptr(b.modelID),
63 // ModelId: ptr("us-east-1.anthropic.claude-sonnet-4-5-v2:0"),
64 ContentType: ptr("application/json"),
65 // Body: body,
66 }
67
68 // call.Prompt
69
70 switch {
71 case containsAny(b.modelID, "anthropic", "claude", "sonnet"):
72 i, err := toClaudeInput(call)
73 if err != nil {
74 return nil, err
75 }
76 input.Body, err = json.Marshal(i)
77 if err != nil {
78 return nil, err
79 }
80 default:
81 return nil, fmt.Errorf("fantasy: bedrock provider does not support model: %s", b.modelID)
82 }
83
84 return &input, nil
85}