mirror of
https://github.com/k8sgpt-ai/k8sgpt.git
synced 2026-06-01 15:37:27 +02:00
feat: add Anthropic API support (#1652)
Build container / Prepare CI Run (push) Has been cancelled
release / release-please (push) Has been cancelled
Run tests / build (push) Has been cancelled
Build container / Build and Push Multi-arch Image (push) Has been cancelled
release / goreleaser (push) Has been cancelled
release / build-container (push) Has been cancelled
Build container / Prepare CI Run (push) Has been cancelled
release / release-please (push) Has been cancelled
Run tests / build (push) Has been cancelled
Build container / Build and Push Multi-arch Image (push) Has been cancelled
release / goreleaser (push) Has been cancelled
release / build-container (push) Has been cancelled
Implement the Anthropic backend client for Claude models, including: - AnthropicClient with Configure/GetCompletion/GetName methods - Proxy, custom headers, and base URL support - Default model (claude-3-5-sonnet-latest) fallback in auth add command - Unit tests with httptest mock server Signed-off-by: Alex <alexsimonjones@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
+14
-8
@@ -20,16 +20,18 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/k8sgpt-ai/k8sgpt/pkg/ai"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBackend = "openai"
|
||||
defaultModel = "gpt-4o"
|
||||
defaultBackend = "openai"
|
||||
defaultModel = "gpt-4o"
|
||||
anthropicDefaultModel = "claude-3-5-sonnet-latest"
|
||||
)
|
||||
|
||||
var addCmd = &cobra.Command{
|
||||
@@ -112,8 +114,12 @@ var addCmd = &cobra.Command{
|
||||
|
||||
// check if model is not empty
|
||||
if model == "" {
|
||||
model = defaultModel
|
||||
color.Yellow(fmt.Sprintf("Warning: model input is empty, will use the default value: %s", defaultModel))
|
||||
fallbackModel := defaultModel
|
||||
if backend == "anthropic" {
|
||||
fallbackModel = anthropicDefaultModel
|
||||
}
|
||||
model = fallbackModel
|
||||
color.Yellow(fmt.Sprintf("Warning: model input is empty, will use the default value: %s", fallbackModel))
|
||||
}
|
||||
if temperature > 1.0 || temperature < 0.0 {
|
||||
color.Red("Error: temperature ranges from 0 to 1.")
|
||||
@@ -195,11 +201,11 @@ func init() {
|
||||
addCmd.Flags().Float32VarP(&temperature, "temperature", "t", 0.7, "The sampling temperature, value ranges between 0 ( output be more deterministic) and 1 (more random)")
|
||||
// add flag for azure open ai engine/deployment name
|
||||
addCmd.Flags().StringVarP(&engine, "engine", "e", "", "Azure AI deployment name (only for azureopenai backend)")
|
||||
//add flag for amazonbedrock region name
|
||||
// add flag for amazonbedrock region name
|
||||
addCmd.Flags().StringVarP(&providerRegion, "providerRegion", "r", "", "Provider Region name (only for amazonbedrock, amazonbedrockconverse, googlevertexai backend)")
|
||||
//add flag for vertexAI/WatsonxAI Project ID
|
||||
// add flag for vertexAI/WatsonxAI Project ID
|
||||
addCmd.Flags().StringVarP(&providerId, "providerId", "i", "", "Provider specific ID for e.g. project (only for googlevertexai/ibmwatsonxai backend)")
|
||||
//add flag for OCI Compartment ID
|
||||
// add flag for OCI Compartment ID
|
||||
addCmd.Flags().StringVarP(&compartmentId, "compartmentId", "k", "", "Compartment ID for generative AI model (only for oci backend)")
|
||||
// add flag for openai organization
|
||||
addCmd.Flags().StringVarP(&organizationId, "organizationId", "o", "", "OpenAI or AzureOpenAI Organization ID (only for openai and azureopenai backend)")
|
||||
|
||||
@@ -21,7 +21,6 @@ require (
|
||||
k8s.io/apimachinery v0.32.3
|
||||
k8s.io/client-go v0.32.3
|
||||
k8s.io/kubectl v0.32.2 // indirect
|
||||
|
||||
)
|
||||
|
||||
require github.com/adrg/xdg v0.5.3
|
||||
@@ -38,6 +37,7 @@ require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.5.0
|
||||
github.com/IBM/watsonx-go v1.0.1
|
||||
github.com/agiledragon/gomonkey/v2 v2.13.0
|
||||
github.com/anthropics/anthropic-sdk-go v1.44.0
|
||||
github.com/aws/aws-sdk-go v1.55.7
|
||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||
@@ -97,13 +97,12 @@ require (
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/blang/semver/v4 v4.0.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/buger/jsonparser v1.1.2 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect
|
||||
github.com/containerd/console v1.0.4 // indirect
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/creack/pty v1.1.21 // indirect
|
||||
github.com/distribution/reference v0.6.0 // indirect
|
||||
github.com/docker/libtrust v0.0.0-20160708172513-aabc10ec26b7 // indirect
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
|
||||
@@ -140,7 +139,12 @@ require (
|
||||
github.com/sony/gobreaker v0.5.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect
|
||||
github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/x448/float16 v0.8.4 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
|
||||
@@ -724,6 +724,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
|
||||
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0=
|
||||
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b/go.mod h1:fvzegU4vN3H1qMT+8wDmzjAcDONcgo2/SZ/TyfdUOFs=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/anthropics/anthropic-sdk-go v1.44.0 h1:qCNYFccgCf3Zxi8eaqBX9zYmCepsOG1jGqQegu8w8aw=
|
||||
github.com/anthropics/anthropic-sdk-go v1.44.0/go.mod h1:bx5vWuHFuGPkELH8Z4KUiNSohFnUwScdpTyr+50myPo=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
github.com/apache/arrow/go/v10 v10.0.1/go.mod h1:YvhnlEePVnBS4+0z3fhPfUy7W1Ikj0Ih0vcRo/gZ1M0=
|
||||
github.com/apache/arrow/go/v11 v11.0.0/go.mod h1:Eg5OsL5H+e299f7u5ssuXsuHQVEGC4xei5aX110hRiI=
|
||||
@@ -781,8 +783,9 @@ github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl
|
||||
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bshuster-repo/logrus-logstash-hook v1.0.0 h1:e+C0SB5R1pu//O4MQ3f9cFuPGoOVeF2fE4Og9otCc70=
|
||||
github.com/bshuster-repo/logrus-logstash-hook v1.0.0/go.mod h1:zsTqEiSzDgAa/8GZR7E1qaXrhYNDKBYy5/dWPTIflbk=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk=
|
||||
github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd h1:rFt+Y/IK1aEZkEHchZRSq9OQbsSzIT/OrI8YFFmRIng=
|
||||
github.com/bugsnag/bugsnag-go v0.0.0-20141110184014-b1d153021fcd/go.mod h1:2oa8nejYd4cQ/b0hMIopN0lCRxU0bueqREvZLWFrtK8=
|
||||
github.com/bugsnag/osext v0.0.0-20130617224835-0dd3f918b21b h1:otBG+dV+YK+Soembjv71DPz3uX/V/6MMlSyD9JBQ6kQ=
|
||||
@@ -839,8 +842,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
|
||||
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0=
|
||||
github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/cyphar/filepath-securejoin v0.3.6 h1:4d9N5ykBnSp5Xn2JkhocYDkOpURL/18CYMpo6xB9uWM=
|
||||
github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -855,6 +858,8 @@ github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2 h1:aB
|
||||
github.com/distribution/distribution/v3 v3.0.0-20221208165359-362910506bc2/go.mod h1:WHNsWjnIn2V1LYOrME7e8KxSeKunYHsxEm4am0BUtcI=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/docker/cli v26.1.4+incompatible h1:I8PHdc0MtxEADqYJZvhBrW9bo8gawKwwenxRM7/rLu8=
|
||||
github.com/docker/cli v26.1.4+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
|
||||
github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=
|
||||
@@ -1451,6 +1456,8 @@ github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMps
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
|
||||
github.com/stackitcloud/stackit-sdk-go/core v0.17.2 h1:jPyn+i8rkp2hM80+hOg0B/1EVRbMt778Tr5RWyK1m2E=
|
||||
github.com/stackitcloud/stackit-sdk-go/core v0.17.2/go.mod h1:8KIw3czdNJ9sdil9QQimxjR6vHjeINFrRv0iZ67wfn0=
|
||||
github.com/standard-webhooks/standard-webhooks/libraries v0.0.1 h1:uOfcYT+3QungH6tIGSVCR/Y3KJmgJiHcojJbMTPDZAI=
|
||||
github.com/standard-webhooks/standard-webhooks/libraries v0.0.1/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
|
||||
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -1473,6 +1480,16 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/vultr/govultr/v2 v2.17.2 h1:gej/rwr91Puc/tgh+j33p/BLR16UrIPnSr+AIwYWZQs=
|
||||
github.com/vultr/govultr/v2 v2.17.2/go.mod h1:ZFOKGWmgjytfyjeyAdhQlSWwTjh2ig+X49cAp50dzXI=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
)
|
||||
|
||||
const (
|
||||
anthropicClientName = "anthropic"
|
||||
anthropicDefaultBaseURL = "https://api.anthropic.com"
|
||||
anthropicMessagesPath = "/v1/messages"
|
||||
anthropicDefaultModel = "claude-3-5-sonnet-latest"
|
||||
anthropicDefaultMaxTokens = 2048
|
||||
)
|
||||
|
||||
type AnthropicClient struct {
|
||||
nopCloser
|
||||
|
||||
client anthropic.Client
|
||||
token string
|
||||
model string
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
maxTokens int
|
||||
stopSequences []string
|
||||
customHeaders []http.Header
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) Configure(config IAIConfig) error {
|
||||
baseURL, err := anthropicBaseURL(config.GetBaseURL())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var opts []option.RequestOption
|
||||
opts = append(opts,
|
||||
option.WithoutEnvironmentDefaults(),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
if token := config.GetPassword(); token != "" {
|
||||
opts = append(opts, option.WithAPIKey(token))
|
||||
}
|
||||
|
||||
proxyEndpoint := config.GetProxyEndpoint()
|
||||
if proxyEndpoint != "" {
|
||||
proxyURL, err := url.Parse(proxyEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
||||
}))
|
||||
}
|
||||
|
||||
model := config.GetModel()
|
||||
if model == "" {
|
||||
model = anthropicDefaultModel
|
||||
}
|
||||
maxTokens := config.GetMaxTokens()
|
||||
if maxTokens <= 0 {
|
||||
maxTokens = anthropicDefaultMaxTokens
|
||||
}
|
||||
|
||||
for key, values := range mergeCustomHeaders(config.GetCustomHeaders()) {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
opts = append(opts, option.WithHeaderDel(key), option.WithHeader(key, values[0]))
|
||||
for _, value := range values[1:] {
|
||||
opts = append(opts, option.WithHeaderAdd(key, value))
|
||||
}
|
||||
}
|
||||
|
||||
c.client = anthropic.NewClient(opts...)
|
||||
c.token = config.GetPassword()
|
||||
c.model = model
|
||||
c.temperature = config.GetTemperature()
|
||||
c.topP = config.GetTopP()
|
||||
c.topK = config.GetTopK()
|
||||
c.maxTokens = maxTokens
|
||||
c.stopSequences = config.GetStopSequences()
|
||||
c.customHeaders = config.GetCustomHeaders()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
|
||||
params := anthropic.MessageNewParams{
|
||||
Model: c.model,
|
||||
MaxTokens: int64(c.maxTokens),
|
||||
Messages: []anthropic.MessageParam{anthropic.NewUserMessage(anthropic.NewTextBlock(prompt))},
|
||||
Temperature: anthropic.Float(float64(c.temperature)),
|
||||
TopP: anthropic.Float(float64(c.topP)),
|
||||
}
|
||||
if c.topK > 0 {
|
||||
params.TopK = anthropic.Int(int64(c.topK))
|
||||
}
|
||||
if len(c.stopSequences) > 0 {
|
||||
params.StopSequences = c.stopSequences
|
||||
}
|
||||
|
||||
var textBlocks []string
|
||||
message, err := c.client.Messages.New(ctx, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, content := range message.Content {
|
||||
if content.Type == "text" {
|
||||
text := content.AsText().Text
|
||||
if text != "" {
|
||||
textBlocks = append(textBlocks, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(textBlocks) == 0 {
|
||||
return "", errors.New("anthropic response did not include any text content")
|
||||
}
|
||||
return strings.Join(textBlocks, "\n"), nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) GetName() string {
|
||||
return anthropicClientName
|
||||
}
|
||||
|
||||
func anthropicBaseURL(rawBaseURL string) (string, error) {
|
||||
if rawBaseURL == "" {
|
||||
rawBaseURL = anthropicDefaultBaseURL
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(rawBaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if baseURL.Scheme == "" || baseURL.Host == "" {
|
||||
return "", fmt.Errorf("invalid anthropic base URL %q", rawBaseURL)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasSuffix(baseURL.Path, anthropicMessagesPath), strings.HasSuffix(baseURL.Path, "/messages"):
|
||||
baseURL.Path = strings.TrimSuffix(baseURL.Path, anthropicMessagesPath)
|
||||
baseURL.Path = strings.TrimSuffix(baseURL.Path, "/messages")
|
||||
if baseURL.Path == "" {
|
||||
baseURL.Path = "/"
|
||||
}
|
||||
return strings.TrimRight(baseURL.String(), "/"), nil
|
||||
default:
|
||||
baseURL.Path = path.Clean(baseURL.Path)
|
||||
if baseURL.Path == "." {
|
||||
baseURL.Path = ""
|
||||
}
|
||||
return strings.TrimRight(baseURL.String(), "/"), nil
|
||||
}
|
||||
}
|
||||
|
||||
func mergeCustomHeaders(headers []http.Header) http.Header {
|
||||
merged := http.Header{}
|
||||
for _, header := range headers {
|
||||
for key, values := range header {
|
||||
copiedValues := make([]string, len(values))
|
||||
copy(copiedValues, values)
|
||||
merged[key] = append(merged[key], copiedValues...)
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type anthropicMockConfig struct {
|
||||
baseURL string
|
||||
password string
|
||||
model string
|
||||
temperature float32
|
||||
topP float32
|
||||
topK int32
|
||||
maxTokens int
|
||||
stopSequences []string
|
||||
customHeaders []http.Header
|
||||
}
|
||||
|
||||
func (m *anthropicMockConfig) GetPassword() string { return m.password }
|
||||
func (m *anthropicMockConfig) GetModel() string { return m.model }
|
||||
func (m *anthropicMockConfig) GetBaseURL() string { return m.baseURL }
|
||||
func (m *anthropicMockConfig) GetProxyEndpoint() string { return "" }
|
||||
func (m *anthropicMockConfig) GetEndpointName() string { return "" }
|
||||
func (m *anthropicMockConfig) GetEngine() string { return "" }
|
||||
func (m *anthropicMockConfig) GetTemperature() float32 { return m.temperature }
|
||||
func (m *anthropicMockConfig) GetProviderRegion() string { return "" }
|
||||
func (m *anthropicMockConfig) GetTopP() float32 { return m.topP }
|
||||
func (m *anthropicMockConfig) GetTopK() int32 { return m.topK }
|
||||
func (m *anthropicMockConfig) GetMaxTokens() int { return m.maxTokens }
|
||||
func (m *anthropicMockConfig) GetStopSequences() []string { return m.stopSequences }
|
||||
func (m *anthropicMockConfig) GetProviderId() string { return "" }
|
||||
func (m *anthropicMockConfig) GetCompartmentId() string { return "" }
|
||||
func (m *anthropicMockConfig) GetOrganizationId() string { return "" }
|
||||
func (m *anthropicMockConfig) GetAzureAPIType() string { return "" }
|
||||
func (m *anthropicMockConfig) GetCustomHeaders() []http.Header { return m.customHeaders }
|
||||
|
||||
func TestAnthropicClientGetCompletion(t *testing.T) {
|
||||
type requestBody struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
TopP float32 `json:"top_p"`
|
||||
TopK int32 `json:"top_k"`
|
||||
StopSequences []string `json:"stop_sequences"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1/messages", r.URL.Path)
|
||||
assert.Equal(t, "test-token", r.Header.Get("X-Api-Key"))
|
||||
assert.NotEmpty(t, r.Header.Get("Anthropic-Version"))
|
||||
assert.Equal(t, "test-value", r.Header.Get("X-Test-Header"))
|
||||
|
||||
var body requestBody
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
|
||||
assert.Equal(t, "claude-test", body.Model)
|
||||
assert.Equal(t, 1024, body.MaxTokens)
|
||||
assert.Equal(t, float32(0.1), body.Temperature)
|
||||
assert.Equal(t, float32(0.8), body.TopP)
|
||||
assert.Equal(t, int32(25), body.TopK)
|
||||
assert.Equal(t, []string{"STOP"}, body.StopSequences)
|
||||
require.Len(t, body.Messages, 1)
|
||||
assert.Equal(t, "user", body.Messages[0].Role)
|
||||
require.Len(t, body.Messages[0].Content, 1)
|
||||
assert.Equal(t, "text", body.Messages[0].Content[0].Type)
|
||||
assert.Equal(t, "hello cluster", body.Messages[0].Content[0].Text)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write([]byte(`{"content":[{"type":"text","text":"diagnosis"}]}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &AnthropicClient{}
|
||||
err := client.Configure(&anthropicMockConfig{
|
||||
baseURL: server.URL,
|
||||
password: "test-token",
|
||||
model: "claude-test",
|
||||
temperature: 0.1,
|
||||
topP: 0.8,
|
||||
topK: 25,
|
||||
maxTokens: 1024,
|
||||
stopSequences: []string{"STOP"},
|
||||
customHeaders: []http.Header{{"X-Test-Header": []string{"test-value"}}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
completion, err := client.GetCompletion(context.Background(), "hello cluster")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "diagnosis", completion)
|
||||
}
|
||||
|
||||
func TestAnthropicClientHonorsExplicitMessagesURLAndDefaultModel(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/custom/v1/messages", r.URL.Path)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write([]byte(`{"content":[{"type":"text","text":"ok"}]}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &AnthropicClient{}
|
||||
err := client.Configure(&anthropicMockConfig{
|
||||
baseURL: server.URL + "/custom/v1/messages",
|
||||
password: "test-token",
|
||||
maxTokens: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, anthropicDefaultModel, client.model)
|
||||
assert.Equal(t, anthropicDefaultMaxTokens, client.maxTokens)
|
||||
|
||||
completion, err := client.GetCompletion(context.Background(), "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", completion)
|
||||
}
|
||||
|
||||
func TestAnthropicClientReturnsStructuredErrors(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, err := w.Write([]byte(`{"error":{"type":"invalid_request_error","message":"bad prompt"}}`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := &AnthropicClient{}
|
||||
err := client.Configure(&anthropicMockConfig{baseURL: server.URL, password: "test-token"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.GetCompletion(context.Background(), "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "bad prompt")
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
var (
|
||||
clients = []IAI{
|
||||
&OpenAIClient{},
|
||||
&AnthropicClient{},
|
||||
&AzureAIClient{},
|
||||
&LocalAIClient{},
|
||||
&OllamaClient{},
|
||||
@@ -39,6 +40,7 @@ var (
|
||||
}
|
||||
Backends = []string{
|
||||
openAIClientName,
|
||||
anthropicClientName,
|
||||
localAIClientName,
|
||||
ollamaClientName,
|
||||
azureAIClientName,
|
||||
|
||||
Reference in New Issue
Block a user