1package providertests
2
3import (
4 "net/http"
5 "os"
6 "testing"
7
8 "charm.land/fantasy"
9 "charm.land/fantasy/providers/bedrock"
10 "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
11)
12
13func TestBedrockCommon(t *testing.T) {
14 testCommon(t, []builderPair{
15 {"bedrock-anthropic-claude-3-sonnet", builderBedrockClaude3Sonnet, nil},
16 {"bedrock-anthropic-claude-3-opus", builderBedrockClaude3Opus, nil},
17 {"bedrock-anthropic-claude-3-haiku", builderBedrockClaude3Haiku, nil},
18 })
19}
20
21func TestBedrockBasicAuth(t *testing.T) {
22 testSimple(t, builderPair{"bedrock-anthropic-claude-3-sonnet", buildersBedrockBasicAuth, nil})
23}
24
25func builderBedrockClaude3Sonnet(r *recorder.Recorder) (fantasy.LanguageModel, error) {
26 provider := bedrock.New(
27 bedrock.WithHTTPClient(&http.Client{Transport: r}),
28 bedrock.WithSkipAuth(!r.IsRecording()),
29 )
30 return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
31}
32
33func builderBedrockClaude3Opus(r *recorder.Recorder) (fantasy.LanguageModel, error) {
34 provider := bedrock.New(
35 bedrock.WithHTTPClient(&http.Client{Transport: r}),
36 bedrock.WithSkipAuth(!r.IsRecording()),
37 )
38 return provider.LanguageModel("us.anthropic.claude-3-opus-20240229-v1:0")
39}
40
41func builderBedrockClaude3Haiku(r *recorder.Recorder) (fantasy.LanguageModel, error) {
42 provider := bedrock.New(
43 bedrock.WithHTTPClient(&http.Client{Transport: r}),
44 bedrock.WithSkipAuth(!r.IsRecording()),
45 )
46 return provider.LanguageModel("us.anthropic.claude-3-haiku-20240307-v1:0")
47}
48
49func buildersBedrockBasicAuth(r *recorder.Recorder) (fantasy.LanguageModel, error) {
50 provider := bedrock.New(
51 bedrock.WithHTTPClient(&http.Client{Transport: r}),
52 bedrock.WithAPIKey(os.Getenv("FANTASY_BEDROCK_API_KEY")),
53 bedrock.WithSkipAuth(true),
54 )
55 return provider.LanguageModel("us.anthropic.claude-3-sonnet-20240229-v1:0")
56}